Convert synapse.api to async/await (#8031)

This commit is contained in:
Patrick Cloke 2020-08-06 08:30:06 -04:00 committed by GitHub
parent c36228c403
commit d4a7829b12
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 171 additions and 159 deletions

1
changelog.d/8031.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -13,12 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import List, Optional, Tuple
import pymacaroons
from netaddr import IPAddress
from twisted.internet import defer
from twisted.web.server import Request
import synapse.types
@ -80,13 +79,14 @@ class Auth(object):
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
@defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
auth_events_ids = yield self.compute_auth_events(
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
):
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@ -94,14 +94,13 @@ class Auth(object):
room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
)
@defer.inlineCallbacks
def check_user_in_room(
async def check_user_in_room(
self,
room_id: str,
user_id: str,
current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False,
):
) -> EventBase:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@ -119,37 +118,35 @@ class Auth(object):
Raises:
AuthError if the user is/was not in the room.
Returns:
Deferred[Optional[EventBase]]:
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
"""
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = yield defer.ensureDeferred(
self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
member = await self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
if membership == Membership.JOIN:
return member
if member:
membership = member.membership
# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = yield self.store.did_forget(user_id, room_id)
if not forgot:
if membership == Membership.JOIN:
return member
# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return member
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
async def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.is_host_joined(room_id, host)
latest_event_ids = await self.store.is_host_joined(room_id, host)
return latest_event_ids
def can_federate(self, event, auth_events):
@ -160,14 +157,13 @@ class Auth(object):
def get_public_keys(self, invite_event):
return event_auth.get_public_keys(invite_event)
@defer.inlineCallbacks
def get_user_by_req(
async def get_user_by_req(
self,
request: Request,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
):
) -> synapse.types.Requester:
""" Get a registered user's ID.
Args:
@ -180,7 +176,7 @@ class Auth(object):
/login will deliver access tokens regardless of expiration.
Returns:
defer.Deferred: resolves to a `synapse.types.Requester` object
Resolves to the requester
Raises:
InvalidClientCredentialsError if no user by that token exists or the token
is invalid.
@ -194,14 +190,14 @@ class Auth(object):
access_token = self.get_access_token_from_request(request)
user_id, app_service = yield self._get_appservice_user_id(request)
user_id, app_service = await self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self._track_appservice_user_ips:
yield self.store.insert_client_ip(
await self.store.insert_client_ip(
user_id=user_id,
access_token=access_token,
ip=ip_addr,
@ -211,7 +207,7 @@ class Auth(object):
return synapse.types.create_requester(user_id, app_service=app_service)
user_info = yield self.get_user_by_access_token(
user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
user = user_info["user"]
@ -221,7 +217,7 @@ class Auth(object):
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
if (
expiration_ts is not None
and self.clock.time_msec() >= expiration_ts
@ -235,7 +231,7 @@ class Auth(object):
device_id = user_info.get("device_id")
if user and access_token and ip_addr:
yield self.store.insert_client_ip(
await self.store.insert_client_ip(
user_id=user.to_string(),
access_token=access_token,
ip=ip_addr,
@ -261,8 +257,7 @@ class Auth(object):
except KeyError:
raise MissingClientTokenError()
@defer.inlineCallbacks
def _get_appservice_user_id(self, request):
async def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
@ -283,14 +278,13 @@ class Auth(object):
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
if not (yield self.store.get_user_by_id(user_id)):
if not (await self.store.get_user_by_id(user_id)):
raise AuthError(403, "Application service has not registered this user")
return user_id, app_service
@defer.inlineCallbacks
def get_user_by_access_token(
async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
):
) -> dict:
""" Validate access token and get user_id from it
Args:
@ -300,7 +294,7 @@ class Auth(object):
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
Returns:
Deferred[dict]: dict that includes:
dict that includes:
`user` (UserID)
`is_guest` (bool)
`token_id` (int|None): access token id. May be None if guest
@ -314,7 +308,7 @@ class Auth(object):
if rights == "access":
# first look in the database
r = yield self._look_up_user_by_access_token(token)
r = await self._look_up_user_by_access_token(token)
if r:
valid_until_ms = r["valid_until_ms"]
if (
@ -352,7 +346,7 @@ class Auth(object):
# It would of course be much easier to store guest access
# tokens in the database as well, but that would break existing
# guest tokens.
stored_user = yield self.store.get_user_by_id(user_id)
stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
@ -482,9 +476,8 @@ class Auth(object):
now = self.hs.get_clock().time_msec()
return now < expiry
@defer.inlineCallbacks
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
async def _look_up_user_by_access_token(self, token):
ret = await self.store.get_user_by_access_token(token)
if not ret:
return None
@ -507,7 +500,7 @@ class Auth(object):
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.authenticated_entity = service.sender
return defer.succeed(service)
return service
async def is_server_admin(self, user: UserID) -> bool:
""" Check if the given user is a local server admin.
@ -522,7 +515,7 @@ class Auth(object):
def compute_auth_events(
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
):
) -> List[str]:
"""Given an event and current state return the list of event IDs used
to auth an event.
@ -530,11 +523,11 @@ class Auth(object):
should be added to the event's `auth_events`.
Returns:
defer.Deferred(list[str]): List of event IDs.
List of event IDs.
"""
if event.type == EventTypes.Create:
return defer.succeed([])
return []
# Currently we ignore the `for_verification` flag even though there are
# some situations where we can drop particular auth events when adding
@ -553,7 +546,7 @@ class Auth(object):
if auth_ev_id:
auth_ids.append(auth_ev_id)
return defer.succeed(auth_ids)
return auth_ids
async def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the
@ -636,10 +629,9 @@ class Auth(object):
return query_params[0].decode("ascii")
@defer.inlineCallbacks
def check_user_in_room_or_world_readable(
async def check_user_in_room_or_world_readable(
self, room_id: str, user_id: str, allow_departed_users: bool = False
):
) -> Tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.
@ -650,10 +642,9 @@ class Auth(object):
members but have now departed
Returns:
Deferred[tuple[str, str|None]]: Resolves to the current membership of
the user in the room and the membership event ID of the user. If
the user is not in the room and never has been, then
`(Membership.JOIN, None)` is returned.
Resolves to the current membership of the user in the room and the
membership event ID of the user. If the user is not in the room and
never has been, then `(Membership.JOIN, None)` is returned.
"""
try:
@ -662,15 +653,13 @@ class Auth(object):
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.check_user_in_room(
member_event = await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users
)
return member_event.membership, member_event.event_id
except AuthError:
visibility = yield defer.ensureDeferred(
self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
visibility = await self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility

View file

@ -15,8 +15,6 @@
import logging
from twisted.internet import defer
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
@ -36,8 +34,7 @@ class AuthBlocking(object):
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
@defer.inlineCallbacks
def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
@ -60,7 +57,7 @@ class AuthBlocking(object):
if user_id is not None:
if user_id == self._server_notices_mxid:
return
if (yield self.store.is_support_user(user_id)):
if await self.store.is_support_user(user_id):
return
if self._hs_disabled:
@ -76,11 +73,11 @@ class AuthBlocking(object):
# If the user is already part of the MAU cohort or a trial user
if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
timestamp = await self.store.user_last_seen_monthly_active(user_id)
if timestamp:
return
is_trial = yield self.store.is_trial_user(user_id)
is_trial = await self.store.is_trial_user(user_id)
if is_trial:
return
elif threepid:
@ -93,7 +90,7 @@ class AuthBlocking(object):
# allow registration. Support users are excluded from MAU checks.
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
current_mau = await self.store.get_monthly_active_count()
if current_mau >= self._max_mau_value:
raise ResourceLimitError(
403,

View file

@ -21,8 +21,6 @@ import jsonschema
from canonicaljson import json
from jsonschema import FormatChecker
from twisted.internet import defer
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
@ -137,9 +135,8 @@ class Filtering(object):
super(Filtering, self).__init__()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
result = yield self.store.get_user_filter(user_localpart, filter_id)
async def get_user_filter(self, user_localpart, filter_id):
result = await self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(result)
def add_user_filter(self, user_localpart, user_filter):

View file

@ -106,7 +106,7 @@ class EventBuilder(object):
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
auth_ids = await self._auth.compute_auth_events(self, state_ids)
auth_ids = self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:

View file

@ -2064,7 +2064,7 @@ class FederationHandler(BaseHandler):
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = await self.auth.compute_auth_events(
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)

View file

@ -1061,7 +1061,7 @@ class EventCreationHandler(object):
raise SynapseError(400, "Cannot redact event from a different room")
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = await self.auth.compute_auth_events(
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = await self.store.get_events(auth_events_ids)

View file

@ -194,12 +194,16 @@ class ModuleApi(object):
synapse.api.errors.AuthError: the access token is invalid
"""
# see if the access token corresponds to a device
user_info = yield self._auth.get_user_by_access_token(access_token)
user_info = yield defer.ensureDeferred(
self._auth.get_user_by_access_token(access_token)
)
device_id = user_info.get("device_id")
user_id = user_info["user"].to_string()
if device_id:
# delete the device, which will also delete its access tokens
yield self._hs.get_device_handler().delete_device(user_id, device_id)
yield defer.ensureDeferred(
self._hs.get_device_handler().delete_device(user_id, device_id)
)
else:
# no associated device. Just delete the access token.
yield defer.ensureDeferred(

View file

@ -120,7 +120,7 @@ class BulkPushRuleEvaluator(object):
pl_event = await self.store.get_event(pl_event_id)
auth_events = {POWER_KEY: pl_event}
else:
auth_events_ids = await self.auth.compute_auth_events(
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events = await self.store.get_events(auth_events_ids)

View file

@ -28,7 +28,7 @@ class SlavedClientIpStore(BaseSlavedStore):
name="client_ip_last_seen", keylen=4, max_entries=50000
)
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)

View file

@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler
try:
service = await self.auth.get_appservice_by_req(request)
service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
await dir_handler.delete_appservice_association(service, room_alias)
logger.info(

View file

@ -424,7 +424,7 @@ class RegisterRestServlet(RestServlet):
appservice = None
if self.auth.has_access_token(request):
appservice = await self.auth.get_appservice_by_req(request)
appservice = self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users

View file

@ -380,8 +380,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks
def insert_client_ip(
async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
if not now:
@ -392,7 +391,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
yield self.populate_monthly_active_users(user_id)
await self.populate_monthly_active_users(user_id)
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return

View file

@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
self.store.is_support_user = Mock(return_value=defer.succeed(False))
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
self.store.get_user_by_access_token = Mock(return_value=user_info)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase):
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
self.store.get_user_by_access_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"}
self.store.get_user_by_access_token = Mock(return_value=user_info)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase):
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
# This just needs to return a truth-y value.
self.store.get_user_by_id = Mock(
return_value=defer.succeed({"is_guest": False})
)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
return_value=defer.succeed(
{"name": "@baldrick:matrix.org", "device_id": "device"}
)
)
user_id = "@baldrick:matrix.org"
@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
self.store.get_user_by_id = Mock(return_value={"is_guest": True})
self.store.get_user_by_access_token = Mock(return_value=None)
self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase):
def get_user(tok):
if token != tok:
return None
return {
"name": USER_ID,
"is_guest": False,
"token_id": 1234,
"device_id": "DEVICE",
}
return defer.succeed(None)
return defer.succeed(
{
"name": USER_ID,
"is_guest": False,
"token_id": 1234,
"device_id": "DEVICE",
}
)
self.store.get_user_by_access_token = get_user
self.store.get_user_by_id = Mock(return_value={"is_guest": False})
self.store.get_user_by_id = Mock(
return_value=defer.succeed({"is_guest": False})
)
# check the token works
request = Mock(args={})

View file

@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event]
user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
user_filter = yield defer.ensureDeferred(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
)
results = user_filter.filter_presence(events=events)
@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart + "2", filter_id=filter_id
user_filter = yield defer.ensureDeferred(
self.filtering.get_user_filter(
user_localpart=user_localpart + "2", filter_id=filter_id
)
)
results = user_filter.filter_presence(events=events)
@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event]
user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
user_filter = yield defer.ensureDeferred(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
)
results = user_filter.filter_room_state(events=events)
@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
user_filter = yield defer.ensureDeferred(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
)
results = user_filter.filter_room_state(events)
@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(
user_filter_json,
(
yield self.datastore.get_user_filter(
user_localpart=user_localpart, filter_id=0
yield defer.ensureDeferred(
self.datastore.get_user_filter(
user_localpart=user_localpart, filter_id=0
)
)
),
)
@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase):
user_localpart=user_localpart, user_filter=user_filter_json
)
filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
filter = yield defer.ensureDeferred(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
)
self.assertEquals(filter.get_filter_json(), user_filter_json)

View file

@ -126,10 +126,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
def check_user_in_room(room_id, user_id):
async def check_user_in_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
return defer.succeed(None)
return None
hs.get_auth().check_user_in_room = check_user_in_room

View file

@ -20,6 +20,8 @@ import urllib.parse
from mock import Mock
from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import HttpResponseException, ResourceLimitError
@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore()
# Set monthly active users to the limit
store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
self.get_failure(
@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
return_value=self.hs.config.max_mau_value
return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
return_value=self.hs.config.max_mau_value
return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit

View file

@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
profile_handler=self.mock_handler,
)
def _get_user_by_req(request=None, allow_guest=False):
return defer.succeed(synapse.types.create_requester(myid))
async def _get_user_by_req(request=None, allow_guest=False):
return synapse.types.create_requester(myid)
hs.get_auth().get_user_by_req = _get_user_by_req

View file

@ -23,8 +23,6 @@ from urllib import parse as urlparse
from mock import Mock
from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
self.hs.get_federation_handler = Mock(return_value=Mock())
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
async def _insert_client_ip(*args, **kwargs):
return None
self.hs.get_datastore().insert_client_ip = _insert_client_ip

View file

@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_handlers().federation_handler = Mock()
def get_user_by_access_token(token=None, allow_guest=False):
async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
async def _insert_client_ip(*args, **kwargs):
return None
hs.get_datastore().insert_client_ip = _insert_client_ip

View file

@ -258,7 +258,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000)
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000))
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)

View file

@ -241,20 +241,16 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
def get_user_by_access_token(token=None, allow_guest=False):
return succeed(
{
"user": UserID.from_string(self.helper.auth_user_id),
"token_id": 1,
"is_guest": False,
}
)
async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.helper.auth_user_id),
"token_id": 1,
"is_guest": False,
}
def get_user_by_req(request, allow_guest=False, rights="access"):
return succeed(
create_requester(
UserID.from_string(self.helper.auth_user_id), 1, False, None
)
async def get_user_by_req(request, allow_guest=False, rights="access"):
return create_requester(
UserID.from_string(self.helper.auth_user_id), 1, False, None
)
self.hs.get_auth().get_user_by_req = get_user_by_req