Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2021-03-03 16:08:05 +00:00
commit 50c242fa29
13 changed files with 109 additions and 49 deletions

1
changelog.d/9498.bugfix Normal file
View file

@ -0,0 +1 @@
Properly purge the event chain cover index when purging history.

1
changelog.d/9521.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to user admin API.

1
changelog.d/9529.misc Normal file
View file

@ -0,0 +1 @@
Bump the versions of mypy and mypy-zope used for static type checking.

1
changelog.d/9537.bugfix Normal file
View file

@ -0,0 +1 @@
Fix rare edge case that caused a background update to fail if the server had rejected an event that had duplicate auth events.

View file

@ -102,7 +102,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"flake8", "flake8",
] ]
CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"] CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.11"]
# Dependencies which are exclusively required by unit test code. This is # Dependencies which are exclusively required by unit test code. This is
# NOT a list of all modules that are necessary to run the unit tests. # NOT a list of all modules that are necessary to run the unit tests.

View file

@ -502,7 +502,7 @@ class AccountDataStream(Stream):
"""Global or per room account data was changed""" """Global or per room account data was changed"""
AccountDataStreamRow = namedtuple( AccountDataStreamRow = namedtuple(
"AccountDataStream", "AccountDataStreamRow",
("user_id", "room_id", "data_type"), # str # Optional[str] # str ("user_id", "room_id", "data_type"), # str # Optional[str] # str
) )

View file

@ -16,7 +16,7 @@ import hashlib
import hmac import hmac
import logging import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
@ -47,13 +47,15 @@ logger = logging.getLogger(__name__)
class UsersRestServlet(RestServlet): class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, List[JsonDict]]:
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
@ -153,7 +155,7 @@ class UserRestServletV2(RestServlet):
otherwise an error. otherwise an error.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
@ -165,7 +167,9 @@ class UserRestServletV2(RestServlet):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -179,7 +183,9 @@ class UserRestServletV2(RestServlet):
return 200, ret return 200, ret
async def on_PUT(self, request, user_id): async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
@ -273,6 +279,8 @@ class UserRestServletV2(RestServlet):
) )
user = await self.admin_handler.get_user(target_user) user = await self.admin_handler.get_user(target_user)
assert user is not None
return 200, user return 200, user
else: # create user else: # create user
@ -330,9 +338,10 @@ class UserRestServletV2(RestServlet):
target_user, requester, body["avatar_url"], True target_user, requester, body["avatar_url"], True
) )
ret = await self.admin_handler.get_user(target_user) user = await self.admin_handler.get_user(target_user)
assert user is not None
return 201, ret return 201, user
class UserRegisterServlet(RestServlet): class UserRegisterServlet(RestServlet):
@ -346,10 +355,10 @@ class UserRegisterServlet(RestServlet):
PATTERNS = admin_patterns("/register") PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60 NONCE_TIMEOUT = 60
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor() self.reactor = hs.get_reactor()
self.nonces = {} self.nonces = {} # type: Dict[str, int]
self.hs = hs self.hs = hs
def _clear_old_nonces(self): def _clear_old_nonces(self):
@ -362,7 +371,7 @@ class UserRegisterServlet(RestServlet):
if now - v > self.NONCE_TIMEOUT: if now - v > self.NONCE_TIMEOUT:
del self.nonces[k] del self.nonces[k]
def on_GET(self, request): def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
""" """
Generate a new nonce. Generate a new nonce.
""" """
@ -372,7 +381,7 @@ class UserRegisterServlet(RestServlet):
self.nonces[nonce] = int(self.reactor.seconds()) self.nonces[nonce] = int(self.reactor.seconds())
return 200, {"nonce": nonce} return 200, {"nonce": nonce}
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
self._clear_old_nonces() self._clear_old_nonces()
if not self.hs.config.registration_shared_secret: if not self.hs.config.registration_shared_secret:
@ -478,12 +487,14 @@ class WhoisRestServlet(RestServlet):
client_patterns("/admin" + path_regex, v1=True) client_patterns("/admin" + path_regex, v1=True)
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
auth_user = requester.user auth_user = requester.user
@ -508,7 +519,9 @@ class DeactivateAccountRestServlet(RestServlet):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]: async def on_POST(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
@ -550,7 +563,7 @@ class AccountValidityRenewServlet(RestServlet):
self.account_activity_handler = hs.get_account_validity_handler() self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -584,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet):
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)") PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler() self._set_password_handler = hs.get_set_password_handler()
async def on_POST(self, request, target_user_id): async def on_POST(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
"""Post request to allow an administrator reset password for a user. """Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
""" """
@ -626,12 +641,14 @@ class SearchUsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)") PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, target_user_id): async def on_GET(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, Optional[List[JsonDict]]]:
"""Get request to search user table for specific users according to """Get request to search user table for specific users according to
search term. search term.
This needs user to have a administrator access in Synapse. This needs user to have a administrator access in Synapse.
@ -682,12 +699,14 @@ class UserAdminServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -699,7 +718,9 @@ class UserAdminServlet(RestServlet):
return 200, {"admin": is_admin} return 200, {"admin": is_admin}
async def on_PUT(self, request, user_id): async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user auth_user = requester.user
@ -730,12 +751,14 @@ class UserMembershipRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
room_ids = await self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
@ -758,7 +781,7 @@ class PushersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -799,7 +822,7 @@ class UserMediaRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -891,7 +914,9 @@ class UserTokenRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request, user_id): async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user auth_user = requester.user
@ -943,7 +968,9 @@ class ShadowBanRestServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request, user_id): async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):

View file

@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from typing import List, Optional, Tuple
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
@ -27,7 +27,7 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator, MultiWriterIdGenerator,
StreamIdGenerator, StreamIdGenerator,
) )
from synapse.types import get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore from .account_data import AccountDataStore
@ -264,7 +264,7 @@ class DataStore(
return [UserPresenceState(**row) for row in rows] return [UserPresenceState(**row) for row in rows]
async def get_users(self) -> List[Dict[str, Any]]: async def get_users(self) -> List[JsonDict]:
"""Function to retrieve a list of users in users table. """Function to retrieve a list of users in users table.
Returns: Returns:
@ -292,7 +292,7 @@ class DataStore(
name: Optional[str] = None, name: Optional[str] = None,
guests: bool = True, guests: bool = True,
deactivated: bool = False, deactivated: bool = False,
) -> Tuple[List[Dict[str, Any]], int]: ) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from """Function to retrieve a paginated list of users from
users list. This will return a json list of users and the users list. This will return a json list of users and the
total number of users matching the filter criteria. total number of users matching the filter criteria.
@ -353,7 +353,7 @@ class DataStore(
"get_users_paginate_txn", get_users_paginate_txn "get_users_paginate_txn", get_users_paginate_txn
) )
async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]: async def search_users(self, term: str) -> Optional[List[JsonDict]]:
"""Function to search users list for one or more users with """Function to search users list for one or more users with
the matched term. the matched term.

View file

@ -696,7 +696,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
) )
if not has_event_auth: if not has_event_auth:
for auth_id in event.auth_event_ids(): # Old, dodgy, events may have duplicate auth events, which we
# need to deduplicate as we have a unique constraint.
for auth_id in set(event.auth_event_ids()):
auth_events.append( auth_events.append(
{ {
"room_id": event.room_id, "room_id": event.room_id,

View file

@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
start: int, start: int,
limit: int, limit: int,
user_id: str, user_id: str,
order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value, order_by: str = MediaSortOrder.CREATED_TS.value,
direction: str = "f", direction: str = "f",
) -> Tuple[List[Dict[str, Any]], int]: ) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media """Get a paginated list of metadata for a local piece of media

View file

@ -28,7 +28,10 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
async def purge_history( async def purge_history(
self, room_id: str, token: str, delete_local_events: bool self, room_id: str, token: str, delete_local_events: bool
) -> Set[int]: ) -> Set[int]:
"""Deletes room history before a certain point """Deletes room history before a certain point.
Note that only a single purge can occur at once, this is guaranteed via
a higher level (in the PaginationHandler).
Args: Args:
room_id: room_id:
@ -52,7 +55,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
delete_local_events, delete_local_events,
) )
def _purge_history_txn(self, txn, room_id, token, delete_local_events): def _purge_history_txn(
self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
) -> Set[int]:
# Tables that should be pruned: # Tables that should be pruned:
# event_auth # event_auth
# event_backward_extremities # event_backward_extremities
@ -103,7 +108,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
if max_depth < token.topological: if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database # We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not # otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties) # having any backwards extremities)
raise SynapseError( raise SynapseError(
400, "topological_ordering is greater than forward extremeties" 400, "topological_ordering is greater than forward extremeties"
) )
@ -154,7 +159,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
logger.info("[purge] Finding new backward extremities") logger.info("[purge] Finding new backward extremities")
# We calculate the new entries for the backward extremeties by finding # We calculate the new entries for the backward extremities by finding
# events to be purged that are pointed to by events we're not going to # events to be purged that are pointed to by events we're not going to
# purge. # purge.
txn.execute( txn.execute(
@ -296,7 +301,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"purge_room", self._purge_room_txn, room_id "purge_room", self._purge_room_txn, room_id
) )
def _purge_room_txn(self, txn, room_id): def _purge_room_txn(self, txn, room_id: str) -> List[int]:
# First we fetch all the state groups that should be deleted, before # First we fetch all the state groups that should be deleted, before
# we delete that information. # we delete that information.
txn.execute( txn.execute(
@ -310,6 +315,31 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
state_groups = [row[0] for row in txn] state_groups = [row[0] for row in txn]
# Get all the auth chains that are referenced by events that are to be
# deleted.
txn.execute(
"""
SELECT chain_id, sequence_number FROM events
LEFT JOIN event_auth_chains USING (event_id)
WHERE room_id = ?
""",
(room_id,),
)
referenced_chain_id_tuples = list(txn)
logger.info("[purge] removing events from event_auth_chain_links")
txn.executemany(
"""
DELETE FROM event_auth_chain_links WHERE
(origin_chain_id = ? AND origin_sequence_number = ?) OR
(target_chain_id = ? AND target_sequence_number = ?)
""",
(
(chain_id, seq_num, chain_id, seq_num)
for (chain_id, seq_num) in referenced_chain_id_tuples
),
)
# Now we delete tables which lack an index on room_id but have one on event_id # Now we delete tables which lack an index on room_id but have one on event_id
for table in ( for table in (
"event_auth", "event_auth",
@ -319,6 +349,8 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_reference_hashes", "event_reference_hashes",
"event_relations", "event_relations",
"event_to_state_groups", "event_to_state_groups",
"event_auth_chains",
"event_auth_chain_to_calculate",
"redactions", "redactions",
"rejections", "rejections",
"state_events", "state_events",

View file

@ -73,9 +73,6 @@ class PurgeEventsStorage:
Returns: Returns:
The set of state groups that can be deleted. The set of state groups that can be deleted.
""" """
# Graph of state group -> previous group
graph = {}
# Set of events that we have found to be referenced by events # Set of events that we have found to be referenced by events
referenced_groups = set() referenced_groups = set()
@ -111,8 +108,6 @@ class PurgeEventsStorage:
next_to_search |= prevs next_to_search |= prevs
state_groups_seen |= prevs state_groups_seen |= prevs
graph.update(edges)
to_delete = state_groups_seen - referenced_groups to_delete = state_groups_seen - referenced_groups
return to_delete return to_delete

View file

@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
) )
GetRoomsForUserWithStreamOrdering = namedtuple( GetRoomsForUserWithStreamOrdering = namedtuple(
"_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos") "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
) )