0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-21 03:42:07 +01:00

Convert synapse.api to async/await (#8031)

This commit is contained in:
Patrick Cloke 2020-08-06 08:30:06 -04:00 committed by GitHub
parent c36228c403
commit d4a7829b12
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 171 additions and 159 deletions

1
changelog.d/8031.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -13,12 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional from typing import List, Optional, Tuple
import pymacaroons import pymacaroons
from netaddr import IPAddress from netaddr import IPAddress
from twisted.internet import defer
from twisted.web.server import Request from twisted.web.server import Request
import synapse.types import synapse.types
@ -80,13 +79,14 @@ class Auth(object):
self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key self._macaroon_secret_key = hs.config.macaroon_secret_key
@defer.inlineCallbacks async def check_from_context(
def check_from_context(self, room_version: str, event, context, do_sig_check=True): self, room_version: str, event, context, do_sig_check=True
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) ):
auth_events_ids = yield self.compute_auth_events( prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True event, prev_state_ids, for_verification=True
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()} auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@ -94,14 +94,13 @@ class Auth(object):
room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
) )
@defer.inlineCallbacks async def check_user_in_room(
def check_user_in_room(
self, self,
room_id: str, room_id: str,
user_id: str, user_id: str,
current_state: Optional[StateMap[EventBase]] = None, current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False, allow_departed_users: bool = False,
): ) -> EventBase:
"""Check if the user is in the room, or was at some point. """Check if the user is in the room, or was at some point.
Args: Args:
room_id: The room to check. room_id: The room to check.
@ -119,37 +118,35 @@ class Auth(object):
Raises: Raises:
AuthError if the user is/was not in the room. AuthError if the user is/was not in the room.
Returns: Returns:
Deferred[Optional[EventBase]]: Membership event for the user if the user was in the
Membership event for the user if the user was in the room. This will be the join event if they are currently joined to
room. This will be the join event if they are currently joined to the room. This will be the leave event if they have left the room.
the room. This will be the leave event if they have left the room.
""" """
if current_state: if current_state:
member = current_state.get((EventTypes.Member, user_id), None) member = current_state.get((EventTypes.Member, user_id), None)
else: else:
member = yield defer.ensureDeferred( member = await self.state.get_current_state(
self.state.get_current_state( room_id=room_id, event_type=EventTypes.Member, state_key=user_id
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
) )
membership = member.membership if member else None
if membership == Membership.JOIN: if member:
return member membership = member.membership
# XXX this looks totally bogus. Why do we not allow users who have been banned, if membership == Membership.JOIN:
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = yield self.store.did_forget(user_id, room_id)
if not forgot:
return member return member
# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return member
raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
@defer.inlineCallbacks async def check_host_in_room(self, room_id, host):
def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"): with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.is_host_joined(room_id, host) latest_event_ids = await self.store.is_host_joined(room_id, host)
return latest_event_ids return latest_event_ids
def can_federate(self, event, auth_events): def can_federate(self, event, auth_events):
@ -160,14 +157,13 @@ class Auth(object):
def get_public_keys(self, invite_event): def get_public_keys(self, invite_event):
return event_auth.get_public_keys(invite_event) return event_auth.get_public_keys(invite_event)
@defer.inlineCallbacks async def get_user_by_req(
def get_user_by_req(
self, self,
request: Request, request: Request,
allow_guest: bool = False, allow_guest: bool = False,
rights: str = "access", rights: str = "access",
allow_expired: bool = False, allow_expired: bool = False,
): ) -> synapse.types.Requester:
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -180,7 +176,7 @@ class Auth(object):
/login will deliver access tokens regardless of expiration. /login will deliver access tokens regardless of expiration.
Returns: Returns:
defer.Deferred: resolves to a `synapse.types.Requester` object Resolves to the requester
Raises: Raises:
InvalidClientCredentialsError if no user by that token exists or the token InvalidClientCredentialsError if no user by that token exists or the token
is invalid. is invalid.
@ -194,14 +190,14 @@ class Auth(object):
access_token = self.get_access_token_from_request(request) access_token = self.get_access_token_from_request(request)
user_id, app_service = yield self._get_appservice_user_id(request) user_id, app_service = await self._get_appservice_user_id(request)
if user_id: if user_id:
request.authenticated_entity = user_id request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id) opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self._track_appservice_user_ips: if ip_addr and self._track_appservice_user_ips:
yield self.store.insert_client_ip( await self.store.insert_client_ip(
user_id=user_id, user_id=user_id,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
@ -211,7 +207,7 @@ class Auth(object):
return synapse.types.create_requester(user_id, app_service=app_service) return synapse.types.create_requester(user_id, app_service=app_service)
user_info = yield self.get_user_by_access_token( user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired access_token, rights, allow_expired=allow_expired
) )
user = user_info["user"] user = user_info["user"]
@ -221,7 +217,7 @@ class Auth(object):
# Deny the request if the user account has expired. # Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired: if self._account_validity.enabled and not allow_expired:
user_id = user.to_string() user_id = user.to_string()
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
if ( if (
expiration_ts is not None expiration_ts is not None
and self.clock.time_msec() >= expiration_ts and self.clock.time_msec() >= expiration_ts
@ -235,7 +231,7 @@ class Auth(object):
device_id = user_info.get("device_id") device_id = user_info.get("device_id")
if user and access_token and ip_addr: if user and access_token and ip_addr:
yield self.store.insert_client_ip( await self.store.insert_client_ip(
user_id=user.to_string(), user_id=user.to_string(),
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
@ -261,8 +257,7 @@ class Auth(object):
except KeyError: except KeyError:
raise MissingClientTokenError() raise MissingClientTokenError()
@defer.inlineCallbacks async def _get_appservice_user_id(self, request):
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request) self.get_access_token_from_request(request)
) )
@ -283,14 +278,13 @@ class Auth(object):
if not app_service.is_interested_in_user(user_id): if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.") raise AuthError(403, "Application service cannot masquerade as this user.")
if not (yield self.store.get_user_by_id(user_id)): if not (await self.store.get_user_by_id(user_id)):
raise AuthError(403, "Application service has not registered this user") raise AuthError(403, "Application service has not registered this user")
return user_id, app_service return user_id, app_service
@defer.inlineCallbacks async def get_user_by_access_token(
def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False, self, token: str, rights: str = "access", allow_expired: bool = False,
): ) -> dict:
""" Validate access token and get user_id from it """ Validate access token and get user_id from it
Args: Args:
@ -300,7 +294,7 @@ class Auth(object):
allow_expired: If False, raises an InvalidClientTokenError allow_expired: If False, raises an InvalidClientTokenError
if the token is expired if the token is expired
Returns: Returns:
Deferred[dict]: dict that includes: dict that includes:
`user` (UserID) `user` (UserID)
`is_guest` (bool) `is_guest` (bool)
`token_id` (int|None): access token id. May be None if guest `token_id` (int|None): access token id. May be None if guest
@ -314,7 +308,7 @@ class Auth(object):
if rights == "access": if rights == "access":
# first look in the database # first look in the database
r = yield self._look_up_user_by_access_token(token) r = await self._look_up_user_by_access_token(token)
if r: if r:
valid_until_ms = r["valid_until_ms"] valid_until_ms = r["valid_until_ms"]
if ( if (
@ -352,7 +346,7 @@ class Auth(object):
# It would of course be much easier to store guest access # It would of course be much easier to store guest access
# tokens in the database as well, but that would break existing # tokens in the database as well, but that would break existing
# guest tokens. # guest tokens.
stored_user = yield self.store.get_user_by_id(user_id) stored_user = await self.store.get_user_by_id(user_id)
if not stored_user: if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id) raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]: if not stored_user["is_guest"]:
@ -482,9 +476,8 @@ class Auth(object):
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
return now < expiry return now < expiry
@defer.inlineCallbacks async def _look_up_user_by_access_token(self, token):
def _look_up_user_by_access_token(self, token): ret = await self.store.get_user_by_access_token(token)
ret = yield self.store.get_user_by_access_token(token)
if not ret: if not ret:
return None return None
@ -507,7 +500,7 @@ class Auth(object):
logger.warning("Unrecognised appservice access token.") logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError() raise InvalidClientTokenError()
request.authenticated_entity = service.sender request.authenticated_entity = service.sender
return defer.succeed(service) return service
async def is_server_admin(self, user: UserID) -> bool: async def is_server_admin(self, user: UserID) -> bool:
""" Check if the given user is a local server admin. """ Check if the given user is a local server admin.
@ -522,7 +515,7 @@ class Auth(object):
def compute_auth_events( def compute_auth_events(
self, event, current_state_ids: StateMap[str], for_verification: bool = False, self, event, current_state_ids: StateMap[str], for_verification: bool = False,
): ) -> List[str]:
"""Given an event and current state return the list of event IDs used """Given an event and current state return the list of event IDs used
to auth an event. to auth an event.
@ -530,11 +523,11 @@ class Auth(object):
should be added to the event's `auth_events`. should be added to the event's `auth_events`.
Returns: Returns:
defer.Deferred(list[str]): List of event IDs. List of event IDs.
""" """
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
return defer.succeed([]) return []
# Currently we ignore the `for_verification` flag even though there are # Currently we ignore the `for_verification` flag even though there are
# some situations where we can drop particular auth events when adding # some situations where we can drop particular auth events when adding
@ -553,7 +546,7 @@ class Auth(object):
if auth_ev_id: if auth_ev_id:
auth_ids.append(auth_ev_id) auth_ids.append(auth_ev_id)
return defer.succeed(auth_ids) return auth_ids
async def check_can_change_room_list(self, room_id: str, user: UserID): async def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the """Determine whether the user is allowed to edit the room's entry in the
@ -636,10 +629,9 @@ class Auth(object):
return query_params[0].decode("ascii") return query_params[0].decode("ascii")
@defer.inlineCallbacks async def check_user_in_room_or_world_readable(
def check_user_in_room_or_world_readable(
self, room_id: str, user_id: str, allow_departed_users: bool = False self, room_id: str, user_id: str, allow_departed_users: bool = False
): ) -> Tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world """Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised. readable. If it isn't then an exception is raised.
@ -650,10 +642,9 @@ class Auth(object):
members but have now departed members but have now departed
Returns: Returns:
Deferred[tuple[str, str|None]]: Resolves to the current membership of Resolves to the current membership of the user in the room and the
the user in the room and the membership event ID of the user. If membership event ID of the user. If the user is not in the room and
the user is not in the room and never has been, then never has been, then `(Membership.JOIN, None)` is returned.
`(Membership.JOIN, None)` is returned.
""" """
try: try:
@ -662,15 +653,13 @@ class Auth(object):
# * The user is a non-guest user, and was ever in the room # * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room # * The user is a guest user, and has joined the room
# else it will throw. # else it will throw.
member_event = yield self.check_user_in_room( member_event = await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users room_id, user_id, allow_departed_users=allow_departed_users
) )
return member_event.membership, member_event.event_id return member_event.membership, member_event.event_id
except AuthError: except AuthError:
visibility = yield defer.ensureDeferred( visibility = await self.state.get_current_state(
self.state.get_current_state( room_id, EventTypes.RoomHistoryVisibility, ""
room_id, EventTypes.RoomHistoryVisibility, ""
)
) )
if ( if (
visibility visibility

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
@ -36,8 +34,7 @@ class AuthBlocking(object):
self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
@defer.inlineCallbacks async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
"""Checks if the user should be rejected for some external reason, """Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag such as monthly active user limiting or global disable flag
@ -60,7 +57,7 @@ class AuthBlocking(object):
if user_id is not None: if user_id is not None:
if user_id == self._server_notices_mxid: if user_id == self._server_notices_mxid:
return return
if (yield self.store.is_support_user(user_id)): if await self.store.is_support_user(user_id):
return return
if self._hs_disabled: if self._hs_disabled:
@ -76,11 +73,11 @@ class AuthBlocking(object):
# If the user is already part of the MAU cohort or a trial user # If the user is already part of the MAU cohort or a trial user
if user_id: if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id) timestamp = await self.store.user_last_seen_monthly_active(user_id)
if timestamp: if timestamp:
return return
is_trial = yield self.store.is_trial_user(user_id) is_trial = await self.store.is_trial_user(user_id)
if is_trial: if is_trial:
return return
elif threepid: elif threepid:
@ -93,7 +90,7 @@ class AuthBlocking(object):
# allow registration. Support users are excluded from MAU checks. # allow registration. Support users are excluded from MAU checks.
return return
# Else if there is no room in the MAU bucket, bail # Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count() current_mau = await self.store.get_monthly_active_count()
if current_mau >= self._max_mau_value: if current_mau >= self._max_mau_value:
raise ResourceLimitError( raise ResourceLimitError(
403, 403,

View file

@ -21,8 +21,6 @@ import jsonschema
from canonicaljson import json from canonicaljson import json
from jsonschema import FormatChecker from jsonschema import FormatChecker
from twisted.internet import defer
from synapse.api.constants import EventContentFields from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
@ -137,9 +135,8 @@ class Filtering(object):
super(Filtering, self).__init__() super(Filtering, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def get_user_filter(self, user_localpart, filter_id):
def get_user_filter(self, user_localpart, filter_id): result = await self.store.get_user_filter(user_localpart, filter_id)
result = yield self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(result) return FilterCollection(result)
def add_user_filter(self, user_localpart, user_filter): def add_user_filter(self, user_localpart, user_filter):

View file

@ -106,7 +106,7 @@ class EventBuilder(object):
state_ids = await self._state.get_current_state_ids( state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids self.room_id, prev_event_ids
) )
auth_ids = await self._auth.compute_auth_events(self, state_ids) auth_ids = self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1: if format_version == EventFormatVersions.V1:

View file

@ -2064,7 +2064,7 @@ class FederationHandler(BaseHandler):
if not auth_events: if not auth_events:
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = await self.auth.compute_auth_events( auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True event, prev_state_ids, for_verification=True
) )
auth_events_x = await self.store.get_events(auth_events_ids) auth_events_x = await self.store.get_events(auth_events_ids)

View file

@ -1061,7 +1061,7 @@ class EventCreationHandler(object):
raise SynapseError(400, "Cannot redact event from a different room") raise SynapseError(400, "Cannot redact event from a different room")
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = await self.auth.compute_auth_events( auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True event, prev_state_ids, for_verification=True
) )
auth_events = await self.store.get_events(auth_events_ids) auth_events = await self.store.get_events(auth_events_ids)

View file

@ -194,12 +194,16 @@ class ModuleApi(object):
synapse.api.errors.AuthError: the access token is invalid synapse.api.errors.AuthError: the access token is invalid
""" """
# see if the access token corresponds to a device # see if the access token corresponds to a device
user_info = yield self._auth.get_user_by_access_token(access_token) user_info = yield defer.ensureDeferred(
self._auth.get_user_by_access_token(access_token)
)
device_id = user_info.get("device_id") device_id = user_info.get("device_id")
user_id = user_info["user"].to_string() user_id = user_info["user"].to_string()
if device_id: if device_id:
# delete the device, which will also delete its access tokens # delete the device, which will also delete its access tokens
yield self._hs.get_device_handler().delete_device(user_id, device_id) yield defer.ensureDeferred(
self._hs.get_device_handler().delete_device(user_id, device_id)
)
else: else:
# no associated device. Just delete the access token. # no associated device. Just delete the access token.
yield defer.ensureDeferred( yield defer.ensureDeferred(

View file

@ -120,7 +120,7 @@ class BulkPushRuleEvaluator(object):
pl_event = await self.store.get_event(pl_event_id) pl_event = await self.store.get_event(pl_event_id)
auth_events = {POWER_KEY: pl_event} auth_events = {POWER_KEY: pl_event}
else: else:
auth_events_ids = await self.auth.compute_auth_events( auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False event, prev_state_ids, for_verification=False
) )
auth_events = await self.store.get_events(auth_events_ids) auth_events = await self.store.get_events(auth_events_ids)

View file

@ -28,7 +28,7 @@ class SlavedClientIpStore(BaseSlavedStore):
name="client_ip_last_seen", keylen=4, max_entries=50000 name="client_ip_last_seen", keylen=4, max_entries=50000
) )
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())
key = (user_id, access_token, ip) key = (user_id, access_token, ip)

View file

@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler dir_handler = self.handlers.directory_handler
try: try:
service = await self.auth.get_appservice_by_req(request) service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
await dir_handler.delete_appservice_association(service, room_alias) await dir_handler.delete_appservice_association(service, room_alias)
logger.info( logger.info(

View file

@ -424,7 +424,7 @@ class RegisterRestServlet(RestServlet):
appservice = None appservice = None
if self.auth.has_access_token(request): if self.auth.has_access_token(request):
appservice = await self.auth.get_appservice_by_req(request) appservice = self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes which have completely # fork off as soon as possible for ASes which have completely
# different registration flows to normal users # different registration flows to normal users

View file

@ -380,8 +380,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
if self.user_ips_max_age: if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks async def insert_client_ip(
def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None self, user_id, access_token, ip, user_agent, device_id, now=None
): ):
if not now: if not now:
@ -392,7 +391,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key) last_seen = self.client_ip_last_seen.get(key)
except KeyError: except KeyError:
last_seen = None last_seen = None
yield self.populate_monthly_active_users(user_id) await self.populate_monthly_active_users(user_id)
# Rate-limited inserts # Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return return

View file

@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests # this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
self.store.is_support_user = Mock(return_value=defer.succeed(False)) self.store.is_support_user = Mock(return_value=defer.succeed(False))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self): def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
self.store.get_user_by_access_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase):
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self): def test_get_user_by_req_user_bad_token(self):
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self): def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"} user_info = {"name": self.test_user, "token_id": "ditto"}
self.store.get_user_by_access_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN") self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase):
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]), ip_range_whitelist=IPSet(["192.168/16"]),
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10" request.getClientIP.return_value = "192.168.10.10"
@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]), ip_range_whitelist=IPSet(["192.168/16"]),
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42" request.getClientIP.return_value = "131.111.8.42"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self): def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_missing_token(self): def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN") self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase):
) )
app_service.is_interested_in_user = Mock(return_value=True) app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None) # This just needs to return a truth-y value.
self.store.get_user_by_id = Mock(
return_value=defer.succeed({"is_guest": False})
)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase):
) )
app_service.is_interested_in_user = Mock(return_value=False) app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = defer.ensureDeferred(self.auth.get_user_by_req(request))
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_from_macaroon(self): def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org", "device_id": "device"} return_value=defer.succeed(
{"name": "@baldrick:matrix.org", "device_id": "device"}
)
) )
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self): def test_get_guest_user_from_macaroon(self):
self.store.get_user_by_id = Mock(return_value={"is_guest": True}) self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase):
def get_user(tok): def get_user(tok):
if token != tok: if token != tok:
return None return defer.succeed(None)
return { return defer.succeed(
"name": USER_ID, {
"is_guest": False, "name": USER_ID,
"token_id": 1234, "is_guest": False,
"device_id": "DEVICE", "token_id": 1234,
} "device_id": "DEVICE",
}
)
self.store.get_user_by_access_token = get_user self.store.get_user_by_access_token = get_user
self.store.get_user_by_id = Mock(return_value={"is_guest": False}) self.store.get_user_by_id = Mock(
return_value=defer.succeed({"is_guest": False})
)
# check the token works # check the token works
request = Mock(args={}) request = Mock(args={})

View file

@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.profile") event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield defer.ensureDeferred(
user_localpart=user_localpart, filter_id=filter_id self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
) )
results = user_filter.filter_presence(events=events) results = user_filter.filter_presence(events=events)
@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase):
) )
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield defer.ensureDeferred(
user_localpart=user_localpart + "2", filter_id=filter_id self.filtering.get_user_filter(
user_localpart=user_localpart + "2", filter_id=filter_id
)
) )
results = user_filter.filter_presence(events=events) results = user_filter.filter_presence(events=events)
@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield defer.ensureDeferred(
user_localpart=user_localpart, filter_id=filter_id self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
) )
results = user_filter.filter_room_state(events=events) results = user_filter.filter_room_state(events=events)
@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase):
) )
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield defer.ensureDeferred(
user_localpart=user_localpart, filter_id=filter_id self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
) )
results = user_filter.filter_room_state(events) results = user_filter.filter_room_state(events)
@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
user_filter_json, user_filter_json,
( (
yield self.datastore.get_user_filter( yield defer.ensureDeferred(
user_localpart=user_localpart, filter_id=0 self.datastore.get_user_filter(
user_localpart=user_localpart, filter_id=0
)
) )
), ),
) )
@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase):
user_localpart=user_localpart, user_filter=user_filter_json user_localpart=user_localpart, user_filter=user_filter_json
) )
filter = yield self.filtering.get_user_filter( filter = yield defer.ensureDeferred(
user_localpart=user_localpart, filter_id=filter_id self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
) )
self.assertEquals(filter.get_filter_json(), user_filter_json) self.assertEquals(filter.get_filter_json(), user_filter_json)

View file

@ -126,10 +126,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = [] self.room_members = []
def check_user_in_room(room_id, user_id): async def check_user_in_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]: if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room") raise AuthError(401, "User is not in the room")
return defer.succeed(None) return None
hs.get_auth().check_user_in_room = check_user_in_room hs.get_auth().check_user_in_room = check_user_in_room

View file

@ -20,6 +20,8 @@ import urllib.parse
from mock import Mock from mock import Mock
from twisted.internet import defer
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import HttpResponseException, ResourceLimitError from synapse.api.errors import HttpResponseException, ResourceLimitError
@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore() store = self.hs.get_datastore()
# Set monthly active users to the limit # Set monthly active users to the limit
store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value) store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit
self.get_failure( self.get_failure(
@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit # Set monthly active users to the limit
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=self.hs.config.max_mau_value return_value=defer.succeed(self.hs.config.max_mau_value)
) )
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit
@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit # Set monthly active users to the limit
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=self.hs.config.max_mau_value return_value=defer.succeed(self.hs.config.max_mau_value)
) )
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit

View file

@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
profile_handler=self.mock_handler, profile_handler=self.mock_handler,
) )
def _get_user_by_req(request=None, allow_guest=False): async def _get_user_by_req(request=None, allow_guest=False):
return defer.succeed(synapse.types.create_requester(myid)) return synapse.types.create_requester(myid)
hs.get_auth().get_user_by_req = _get_user_by_req hs.get_auth().get_user_by_req = _get_user_by_req

View file

@ -23,8 +23,6 @@ from urllib import parse as urlparse
from mock import Mock from mock import Mock
from twisted.internet import defer
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus from synapse.handlers.pagination import PurgeStatus
@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
self.hs.get_federation_handler = Mock(return_value=Mock()) self.hs.get_federation_handler = Mock(return_value=Mock())
def _insert_client_ip(*args, **kwargs): async def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return None
self.hs.get_datastore().insert_client_ip = _insert_client_ip self.hs.get_datastore().insert_client_ip = _insert_client_ip

View file

@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def get_user_by_access_token(token=None, allow_guest=False): async def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_auth().get_user_by_access_token = get_user_by_access_token hs.get_auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs): async def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return None
hs.get_datastore().insert_client_ip = _insert_client_ip hs.get_datastore().insert_client_ip = _insert_client_ip

View file

@ -258,7 +258,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test" self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self): def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000))
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000) return_value=defer.succeed(1000)

View file

@ -241,20 +241,16 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"): if hasattr(self, "user_id"):
if self.hijack_auth: if self.hijack_auth:
def get_user_by_access_token(token=None, allow_guest=False): async def get_user_by_access_token(token=None, allow_guest=False):
return succeed( return {
{ "user": UserID.from_string(self.helper.auth_user_id),
"user": UserID.from_string(self.helper.auth_user_id), "token_id": 1,
"token_id": 1, "is_guest": False,
"is_guest": False, }
}
)
def get_user_by_req(request, allow_guest=False, rights="access"): async def get_user_by_req(request, allow_guest=False, rights="access"):
return succeed( return create_requester(
create_requester( UserID.from_string(self.helper.auth_user_id), 1, False, None
UserID.from_string(self.helper.auth_user_id), 1, False, None
)
) )
self.hs.get_auth().get_user_by_req = get_user_by_req self.hs.get_auth().get_user_by_req = get_user_by_req