0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-22 07:30:26 +01:00

Merge pull request #2429 from matrix-org/erikj/groups_profile_cache

Add a remote user profile cache
This commit is contained in:
Erik Johnston 2017-08-25 15:52:42 +01:00 committed by GitHub
commit 6d8799af1a
15 changed files with 292 additions and 47 deletions

View file

@ -503,6 +503,13 @@ class GroupsServerHandler(object):
get_domain_from_id(user_id), group_id, user_id, content get_domain_from_id(user_id), group_id, user_id, content
) )
user_profile = res.get("user_profile", {})
yield self.store.add_remote_profile_cache(
user_id,
displayname=user_profile.get("displayname"),
avatar_url=user_profile.get("avatar_url"),
)
if res["state"] == "join": if res["state"] == "join":
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):
remote_attestation = res["attestation"] remote_attestation = res["attestation"]
@ -627,6 +634,9 @@ class GroupsServerHandler(object):
get_domain_from_id(user_id), group_id, user_id, {} get_domain_from_id(user_id), group_id, user_id, {}
) )
if not self.hs.is_mine_id(user_id):
yield self.store.maybe_delete_remote_profile_cache(user_id)
defer.returnValue({}) defer.returnValue({})
@defer.inlineCallbacks @defer.inlineCallbacks
@ -647,6 +657,7 @@ class GroupsServerHandler(object):
avatar_url = profile.get("avatar_url") avatar_url = profile.get("avatar_url")
short_description = profile.get("short_description") short_description = profile.get("short_description")
long_description = profile.get("long_description") long_description = profile.get("long_description")
user_profile = content.get("user_profile", {})
yield self.store.create_group( yield self.store.create_group(
group_id, group_id,
@ -679,6 +690,13 @@ class GroupsServerHandler(object):
remote_attestation=remote_attestation, remote_attestation=remote_attestation,
) )
if not self.hs.is_mine_id(user_id):
yield self.store.add_remote_profile_cache(
user_id,
displayname=user_profile.get("displayname"),
avatar_url=user_profile.get("avatar_url"),
)
defer.returnValue({ defer.returnValue({
"group_id": group_id, "group_id": group_id,
}) })

View file

@ -20,7 +20,6 @@ from .room import (
from .room_member import RoomMemberHandler from .room_member import RoomMemberHandler
from .message import MessageHandler from .message import MessageHandler
from .federation import FederationHandler from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
from .admin import AdminHandler from .admin import AdminHandler
from .identity import IdentityHandler from .identity import IdentityHandler
@ -52,7 +51,6 @@ class Handlers(object):
self.room_creation_handler = RoomCreationHandler(hs) self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs) self.room_member_handler = RoomMemberHandler(hs)
self.federation_handler = FederationHandler(hs) self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)

View file

@ -56,6 +56,8 @@ class GroupsLocalHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.attestations = hs.get_groups_attestation_signing() self.attestations = hs.get_groups_attestation_signing()
self.profile_handler = hs.get_profile_handler()
# Ensure attestations get renewed # Ensure attestations get renewed
hs.get_groups_attestation_renewer() hs.get_groups_attestation_renewer()
@ -123,6 +125,7 @@ class GroupsLocalHandler(object):
defer.returnValue(res) defer.returnValue(res)
@defer.inlineCallbacks
def create_group(self, group_id, user_id, content): def create_group(self, group_id, user_id, content):
"""Create a group """Create a group
""" """
@ -130,13 +133,16 @@ class GroupsLocalHandler(object):
logger.info("Asking to create group with ID: %r", group_id) logger.info("Asking to create group with ID: %r", group_id)
if self.is_mine_id(group_id): if self.is_mine_id(group_id):
return self.groups_server_handler.create_group( res = yield self.groups_server_handler.create_group(
group_id, user_id, content group_id, user_id, content
) )
defer.returnValue(res)
return self.transport_client.create_group( content["user_profile"] = yield self.profile_handler.get_profile(user_id)
res = yield self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content, get_domain_from_id(group_id), group_id, user_id, content,
) # TODO )
defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_users_in_group(self, group_id, requester_user_id): def get_users_in_group(self, group_id, requester_user_id):
@ -265,7 +271,9 @@ class GroupsLocalHandler(object):
"groups_key", token, users=[user_id], "groups_key", token, users=[user_id],
) )
defer.returnValue({"state": "invite"}) user_profile = yield self.profile_handler.get_profile(user_id)
defer.returnValue({"state": "invite", "user_profile": user_profile})
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_user_from_group(self, group_id, user_id, requester_user_id, content): def remove_user_from_group(self, group_id, user_id, requester_user_id, content):

View file

@ -47,6 +47,7 @@ class MessageHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.validator = EventValidator() self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
@ -210,7 +211,7 @@ class MessageHandler(BaseHandler):
if membership in {Membership.JOIN, Membership.INVITE}: if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
profile = self.hs.get_handlers().profile_handler profile = self.profile_handler
content = builder.content content = builder.content
try: try:

View file

@ -19,14 +19,15 @@ from twisted.internet import defer
import synapse.types import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID from synapse.types import UserID, get_domain_from_id
from ._base import BaseHandler from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ProfileHandler(BaseHandler): class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
super(ProfileHandler, self).__init__(hs) super(ProfileHandler, self).__init__(hs)
@ -36,6 +37,40 @@ class ProfileHandler(BaseHandler):
"profile", self.on_profile_query "profile", self.on_profile_query
) )
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
@defer.inlineCallbacks
def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
displayname = yield self.store.get_profile_displayname(
target_user.localpart
)
avatar_url = yield self.store.get_profile_avatar_url(
target_user.localpart
)
defer.returnValue({
"displayname": displayname,
"avatar_url": avatar_url,
})
else:
try:
result = yield self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={
"user_id": user_id,
},
ignore_backoff=True,
)
defer.returnValue(result)
except CodeMessageException as e:
if e.code != 404:
logger.exception("Failed to get displayname")
raise
@defer.inlineCallbacks @defer.inlineCallbacks
def get_displayname(self, target_user): def get_displayname(self, target_user):
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
@ -182,3 +217,44 @@ class ProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s", "Failed to update join event for room %s - %s",
room_id, str(e.message) room_id, str(e.message)
) )
def _update_remote_profile_cache(self):
"""Called periodically to check profiles of remote users we haven't
checked in a while.
"""
entries = yield self.store.get_remote_profile_cache_entries_that_expire(
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
)
for user_id, displayname, avatar_url in entries:
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
user_id,
)
if not is_subscribed:
yield self.store.maybe_delete_remote_profile_cache(user_id)
continue
try:
profile = yield self.federation.make_query(
destination=get_domain_from_id(user_id),
query_type="profile",
args={
"user_id": user_id,
},
ignore_backoff=True,
)
except:
logger.exception("Failed to get avatar_url")
yield self.store.update_remote_profile_cache(
user_id, displayname, avatar_url
)
continue
new_name = profile.get("displayname")
new_avatar = profile.get("avatar_url")
# We always hit update to update the last_check timestamp
yield self.store.update_remote_profile_cache(
user_id, new_name, new_avatar
)

View file

@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
super(RegistrationHandler, self).__init__(hs) super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.profile_handler = hs.get_profile_handler()
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None self._next_generated_user_id = None
@ -423,8 +424,7 @@ class RegistrationHandler(BaseHandler):
if displayname is not None: if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname) logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler yield self.profile_handler.set_displayname(
yield profile_handler.set_displayname(
user, requester, displayname, by_admin=True, user, requester, displayname, by_admin=True,
) )

View file

@ -45,6 +45,8 @@ class RoomMemberHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs) super(RoomMemberHandler, self).__init__(hs)
self.profile_handler = hs.get_profile_handler()
self.member_linearizer = Linearizer(name="member") self.member_linearizer = Linearizer(name="member")
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -255,7 +257,7 @@ class RoomMemberHandler(BaseHandler):
content["membership"] = Membership.JOIN content["membership"] = Membership.JOIN
profile = self.hs.get_handlers().profile_handler profile = self.profile_handler
if not content_specified: if not content_specified:
content["displayname"] = yield profile.get_displayname(target) content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target) content["avatar_url"] = yield profile.get_avatar_url(target)

View file

@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs) super(ProfileDisplaynameRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
displayname = yield self.handlers.profile_handler.get_displayname( displayname = yield self.profile_handler.get_displayname(
user, user,
) )
@ -55,7 +55,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
except: except:
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname( yield self.profile_handler.set_displayname(
user, requester, new_name, is_admin) user, requester, new_name, is_admin)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs) super(ProfileAvatarURLRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
avatar_url = yield self.handlers.profile_handler.get_avatar_url( avatar_url = yield self.profile_handler.get_avatar_url(
user, user,
) )
@ -97,7 +97,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
except: except:
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url( yield self.profile_handler.set_avatar_url(
user, requester, new_name, is_admin) user, requester, new_name, is_admin)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs) super(ProfileRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
displayname = yield self.handlers.profile_handler.get_displayname( displayname = yield self.profile_handler.get_displayname(
user, user,
) )
avatar_url = yield self.handlers.profile_handler.get_avatar_url( avatar_url = yield self.profile_handler.get_avatar_url(
user, user,
) )

View file

@ -51,6 +51,7 @@ from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoyHandler from synapse.handlers.user_directory import UserDirectoyHandler
from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.profile import ProfileHandler
from synapse.groups.groups_server import GroupsServerHandler from synapse.groups.groups_server import GroupsServerHandler
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
@ -114,6 +115,7 @@ class HomeServer(object):
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
'device_message_handler', 'device_message_handler',
'profile_handler',
'notifier', 'notifier',
'distributor', 'distributor',
'client_resource', 'client_resource',
@ -258,6 +260,9 @@ class HomeServer(object):
def build_initial_sync_handler(self): def build_initial_sync_handler(self):
return InitialSyncHandler(self) return InitialSyncHandler(self)
def build_profile_handler(self):
return ProfileHandler(self)
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)

View file

@ -743,6 +743,33 @@ class SQLBaseStore(object):
txn.execute(sql, values) txn.execute(sql, values)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
def _simple_update(self, table, keyvalues, updatevalues, desc):
return self.runInteraction(
desc,
self._simple_update_txn,
table, keyvalues, updatevalues,
)
@staticmethod
def _simple_update_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
update_sql = "UPDATE %s SET %s %s" % (
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
where,
)
txn.execute(
update_sql,
updatevalues.values() + keyvalues.values()
)
return txn.rowcount
def _simple_update_one(self, table, keyvalues, updatevalues, def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"): desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for """Executes an UPDATE query on the named table, setting new values for
@ -768,27 +795,13 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues, table, keyvalues, updatevalues,
) )
@staticmethod @classmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues): def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
if keyvalues: rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
update_sql = "UPDATE %s SET %s %s" % ( if rowcount == 0:
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
where,
)
txn.execute(
update_sql,
updatevalues.values() + keyvalues.values()
)
if txn.rowcount == 0:
raise StoreError(404, "No row found") raise StoreError(404, "No row found")
if txn.rowcount > 1: if rowcount > 1:
raise StoreError(500, "More than one row matched") raise StoreError(500, "More than one row matched")
@staticmethod @staticmethod

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
@ -55,3 +57,99 @@ class ProfileStore(SQLBaseStore):
updatevalues={"avatar_url": new_avatar_url}, updatevalues={"avatar_url": new_avatar_url},
desc="set_profile_avatar_url", desc="set_profile_avatar_url",
) )
def get_from_remote_profile_cache(self, user_id):
return self._simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url", "last_check"),
allow_none=True,
desc="get_from_remote_profile_cache",
)
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
return self._simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
},
desc="add_remote_profile_cache",
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
return self._simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
},
desc="update_remote_profile_cache",
)
@defer.inlineCallbacks
def maybe_delete_remote_profile_cache(self, user_id):
"""Check if we still care about the remote user's profile, and if we
don't then remove their profile from the cache
"""
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed:
yield self._simple_delete(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache",
)
def get_remote_profile_cache_entries_that_expire(self, last_checked):
"""Get all users who haven't been checked since `last_checked`
"""
def _get_remote_profile_cache_entries_that_expire_txn(txn):
sql = """
SELECT user_id, displayname, avatar_url
FROM remote_profile_cache
WHERE last_check < ?
"""
txn.execute(sql, (last_checked,))
return self.cursor_to_dict(txn)
return self.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
@defer.inlineCallbacks
def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile.
"""
res = yield self._simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
allow_none=True,
desc="should_update_remote_profile_cache_for_user",
)
if res:
defer.returnValue(True)
res = yield self._simple_select_one_onecol(
table="group_invites",
keyvalues={"user_id": user_id},
retcol="user_id",
allow_none=True,
desc="should_update_remote_profile_cache_for_user",
)
if res:
defer.returnValue(True)

View file

@ -0,0 +1,28 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- A subset of remote users whose profiles we have cached.
-- Whether a user is in this table or not is defined by the storage function
-- `is_subscribed_remote_profile_for_user`
CREATE TABLE remote_profile_cache (
user_id TEXT NOT NULL,
displayname TEXT,
avatar_url TEXT,
last_check BIGINT NOT NULL
);
CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id);
CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check);

View file

@ -62,8 +62,6 @@ class ProfileTestCase(unittest.TestCase):
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
hs.handlers = ProfileHandlers(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234ABCD:test") self.frank = UserID.from_string("@1234ABCD:test")
@ -72,7 +70,7 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.create_profile(self.frank.localpart) yield self.store.create_profile(self.frank.localpart)
self.handler = hs.get_handlers().profile_handler self.handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_name(self): def test_get_my_name(self):

View file

@ -40,13 +40,14 @@ class RegistrationTestCase(unittest.TestCase):
self.hs = yield setup_test_homeserver( self.hs = yield setup_test_homeserver(
handlers=None, handlers=None,
http_client=None, http_client=None,
expire_access_token=True) expire_access_token=True,
profile_handler=Mock(),
)
self.macaroon_generator = Mock( self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')) generate_access_token=Mock(return_value='secret'))
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs) self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler self.handler = self.hs.get_handlers().registration_handler
self.hs.get_handlers().profile_handler = Mock()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self): def test_user_is_created_and_logged_in_if_doesnt_exist(self):

View file

@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase):
resource_for_client=self.mock_resource, resource_for_client=self.mock_resource,
federation=Mock(), federation=Mock(),
replication_layer=Mock(), replication_layer=Mock(),
profile_handler=self.mock_handler
) )
def _get_user_by_req(request=None, allow_guest=False): def _get_user_by_req(request=None, allow_guest=False):
@ -53,8 +54,6 @@ class ProfileTestCase(unittest.TestCase):
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req
hs.get_handlers().profile_handler = self.mock_handler
profile.register_servlets(hs, self.mock_resource) profile.register_servlets(hs, self.mock_resource)
@defer.inlineCallbacks @defer.inlineCallbacks