mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-22 07:30:26 +01:00
Merge pull request #2429 from matrix-org/erikj/groups_profile_cache
Add a remote user profile cache
This commit is contained in:
commit
6d8799af1a
15 changed files with 292 additions and 47 deletions
|
@ -503,6 +503,13 @@ class GroupsServerHandler(object):
|
||||||
get_domain_from_id(user_id), group_id, user_id, content
|
get_domain_from_id(user_id), group_id, user_id, content
|
||||||
)
|
)
|
||||||
|
|
||||||
|
user_profile = res.get("user_profile", {})
|
||||||
|
yield self.store.add_remote_profile_cache(
|
||||||
|
user_id,
|
||||||
|
displayname=user_profile.get("displayname"),
|
||||||
|
avatar_url=user_profile.get("avatar_url"),
|
||||||
|
)
|
||||||
|
|
||||||
if res["state"] == "join":
|
if res["state"] == "join":
|
||||||
if not self.hs.is_mine_id(user_id):
|
if not self.hs.is_mine_id(user_id):
|
||||||
remote_attestation = res["attestation"]
|
remote_attestation = res["attestation"]
|
||||||
|
@ -627,6 +634,9 @@ class GroupsServerHandler(object):
|
||||||
get_domain_from_id(user_id), group_id, user_id, {}
|
get_domain_from_id(user_id), group_id, user_id, {}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.hs.is_mine_id(user_id):
|
||||||
|
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
||||||
|
|
||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -647,6 +657,7 @@ class GroupsServerHandler(object):
|
||||||
avatar_url = profile.get("avatar_url")
|
avatar_url = profile.get("avatar_url")
|
||||||
short_description = profile.get("short_description")
|
short_description = profile.get("short_description")
|
||||||
long_description = profile.get("long_description")
|
long_description = profile.get("long_description")
|
||||||
|
user_profile = content.get("user_profile", {})
|
||||||
|
|
||||||
yield self.store.create_group(
|
yield self.store.create_group(
|
||||||
group_id,
|
group_id,
|
||||||
|
@ -679,6 +690,13 @@ class GroupsServerHandler(object):
|
||||||
remote_attestation=remote_attestation,
|
remote_attestation=remote_attestation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.hs.is_mine_id(user_id):
|
||||||
|
yield self.store.add_remote_profile_cache(
|
||||||
|
user_id,
|
||||||
|
displayname=user_profile.get("displayname"),
|
||||||
|
avatar_url=user_profile.get("avatar_url"),
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"group_id": group_id,
|
"group_id": group_id,
|
||||||
})
|
})
|
||||||
|
|
|
@ -20,7 +20,6 @@ from .room import (
|
||||||
from .room_member import RoomMemberHandler
|
from .room_member import RoomMemberHandler
|
||||||
from .message import MessageHandler
|
from .message import MessageHandler
|
||||||
from .federation import FederationHandler
|
from .federation import FederationHandler
|
||||||
from .profile import ProfileHandler
|
|
||||||
from .directory import DirectoryHandler
|
from .directory import DirectoryHandler
|
||||||
from .admin import AdminHandler
|
from .admin import AdminHandler
|
||||||
from .identity import IdentityHandler
|
from .identity import IdentityHandler
|
||||||
|
@ -52,7 +51,6 @@ class Handlers(object):
|
||||||
self.room_creation_handler = RoomCreationHandler(hs)
|
self.room_creation_handler = RoomCreationHandler(hs)
|
||||||
self.room_member_handler = RoomMemberHandler(hs)
|
self.room_member_handler = RoomMemberHandler(hs)
|
||||||
self.federation_handler = FederationHandler(hs)
|
self.federation_handler = FederationHandler(hs)
|
||||||
self.profile_handler = ProfileHandler(hs)
|
|
||||||
self.directory_handler = DirectoryHandler(hs)
|
self.directory_handler = DirectoryHandler(hs)
|
||||||
self.admin_handler = AdminHandler(hs)
|
self.admin_handler = AdminHandler(hs)
|
||||||
self.identity_handler = IdentityHandler(hs)
|
self.identity_handler = IdentityHandler(hs)
|
||||||
|
|
|
@ -56,6 +56,8 @@ class GroupsLocalHandler(object):
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.attestations = hs.get_groups_attestation_signing()
|
self.attestations = hs.get_groups_attestation_signing()
|
||||||
|
|
||||||
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
# Ensure attestations get renewed
|
# Ensure attestations get renewed
|
||||||
hs.get_groups_attestation_renewer()
|
hs.get_groups_attestation_renewer()
|
||||||
|
|
||||||
|
@ -123,6 +125,7 @@ class GroupsLocalHandler(object):
|
||||||
|
|
||||||
defer.returnValue(res)
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def create_group(self, group_id, user_id, content):
|
def create_group(self, group_id, user_id, content):
|
||||||
"""Create a group
|
"""Create a group
|
||||||
"""
|
"""
|
||||||
|
@ -130,13 +133,16 @@ class GroupsLocalHandler(object):
|
||||||
logger.info("Asking to create group with ID: %r", group_id)
|
logger.info("Asking to create group with ID: %r", group_id)
|
||||||
|
|
||||||
if self.is_mine_id(group_id):
|
if self.is_mine_id(group_id):
|
||||||
return self.groups_server_handler.create_group(
|
res = yield self.groups_server_handler.create_group(
|
||||||
group_id, user_id, content
|
group_id, user_id, content
|
||||||
)
|
)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
return self.transport_client.create_group(
|
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
|
||||||
|
res = yield self.transport_client.create_group(
|
||||||
get_domain_from_id(group_id), group_id, user_id, content,
|
get_domain_from_id(group_id), group_id, user_id, content,
|
||||||
) # TODO
|
)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_users_in_group(self, group_id, requester_user_id):
|
def get_users_in_group(self, group_id, requester_user_id):
|
||||||
|
@ -265,7 +271,9 @@ class GroupsLocalHandler(object):
|
||||||
"groups_key", token, users=[user_id],
|
"groups_key", token, users=[user_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue({"state": "invite"})
|
user_profile = yield self.profile_handler.get_profile(user_id)
|
||||||
|
|
||||||
|
defer.returnValue({"state": "invite", "user_profile": user_profile})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
||||||
|
|
|
@ -47,6 +47,7 @@ class MessageHandler(BaseHandler):
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.validator = EventValidator()
|
self.validator = EventValidator()
|
||||||
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
self.pagination_lock = ReadWriteLock()
|
self.pagination_lock = ReadWriteLock()
|
||||||
|
|
||||||
|
@ -210,7 +211,7 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
if membership in {Membership.JOIN, Membership.INVITE}:
|
if membership in {Membership.JOIN, Membership.INVITE}:
|
||||||
# If event doesn't include a display name, add one.
|
# If event doesn't include a display name, add one.
|
||||||
profile = self.hs.get_handlers().profile_handler
|
profile = self.profile_handler
|
||||||
content = builder.content
|
content = builder.content
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -19,14 +19,15 @@ from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.types
|
import synapse.types
|
||||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID, get_domain_from_id
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProfileHandler(BaseHandler):
|
class ProfileHandler(BaseHandler):
|
||||||
|
PROFILE_UPDATE_MS = 60 * 1000
|
||||||
|
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ProfileHandler, self).__init__(hs)
|
super(ProfileHandler, self).__init__(hs)
|
||||||
|
@ -36,6 +37,40 @@ class ProfileHandler(BaseHandler):
|
||||||
"profile", self.on_profile_query
|
"profile", self.on_profile_query
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_profile(self, user_id):
|
||||||
|
target_user = UserID.from_string(user_id)
|
||||||
|
if self.hs.is_mine(target_user):
|
||||||
|
displayname = yield self.store.get_profile_displayname(
|
||||||
|
target_user.localpart
|
||||||
|
)
|
||||||
|
avatar_url = yield self.store.get_profile_avatar_url(
|
||||||
|
target_user.localpart
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"displayname": displayname,
|
||||||
|
"avatar_url": avatar_url,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
result = yield self.federation.make_query(
|
||||||
|
destination=target_user.domain,
|
||||||
|
query_type="profile",
|
||||||
|
args={
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
defer.returnValue(result)
|
||||||
|
except CodeMessageException as e:
|
||||||
|
if e.code != 404:
|
||||||
|
logger.exception("Failed to get displayname")
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_displayname(self, target_user):
|
def get_displayname(self, target_user):
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
|
@ -182,3 +217,44 @@ class ProfileHandler(BaseHandler):
|
||||||
"Failed to update join event for room %s - %s",
|
"Failed to update join event for room %s - %s",
|
||||||
room_id, str(e.message)
|
room_id, str(e.message)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _update_remote_profile_cache(self):
|
||||||
|
"""Called periodically to check profiles of remote users we haven't
|
||||||
|
checked in a while.
|
||||||
|
"""
|
||||||
|
entries = yield self.store.get_remote_profile_cache_entries_that_expire(
|
||||||
|
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
|
||||||
|
)
|
||||||
|
|
||||||
|
for user_id, displayname, avatar_url in entries:
|
||||||
|
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
if not is_subscribed:
|
||||||
|
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
profile = yield self.federation.make_query(
|
||||||
|
destination=get_domain_from_id(user_id),
|
||||||
|
query_type="profile",
|
||||||
|
args={
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
logger.exception("Failed to get avatar_url")
|
||||||
|
|
||||||
|
yield self.store.update_remote_profile_cache(
|
||||||
|
user_id, displayname, avatar_url
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_name = profile.get("displayname")
|
||||||
|
new_avatar = profile.get("avatar_url")
|
||||||
|
|
||||||
|
# We always hit update to update the last_check timestamp
|
||||||
|
yield self.store.update_remote_profile_cache(
|
||||||
|
user_id, new_name, new_avatar
|
||||||
|
)
|
||||||
|
|
|
@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
super(RegistrationHandler, self).__init__(hs)
|
super(RegistrationHandler, self).__init__(hs)
|
||||||
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self.profile_handler = hs.get_profile_handler()
|
||||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||||
|
|
||||||
self._next_generated_user_id = None
|
self._next_generated_user_id = None
|
||||||
|
@ -423,8 +424,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||||
profile_handler = self.hs.get_handlers().profile_handler
|
yield self.profile_handler.set_displayname(
|
||||||
yield profile_handler.set_displayname(
|
|
||||||
user, requester, displayname, by_admin=True,
|
user, requester, displayname, by_admin=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,8 @@ class RoomMemberHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RoomMemberHandler, self).__init__(hs)
|
super(RoomMemberHandler, self).__init__(hs)
|
||||||
|
|
||||||
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
self.member_linearizer = Linearizer(name="member")
|
self.member_linearizer = Linearizer(name="member")
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
@ -255,7 +257,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
|
|
||||||
content["membership"] = Membership.JOIN
|
content["membership"] = Membership.JOIN
|
||||||
|
|
||||||
profile = self.hs.get_handlers().profile_handler
|
profile = self.profile_handler
|
||||||
if not content_specified:
|
if not content_specified:
|
||||||
content["displayname"] = yield profile.get_displayname(target)
|
content["displayname"] = yield profile.get_displayname(target)
|
||||||
content["avatar_url"] = yield profile.get_avatar_url(target)
|
content["avatar_url"] = yield profile.get_avatar_url(target)
|
||||||
|
|
|
@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ProfileDisplaynameRestServlet, self).__init__(hs)
|
super(ProfileDisplaynameRestServlet, self).__init__(hs)
|
||||||
self.handlers = hs.get_handlers()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
displayname = yield self.handlers.profile_handler.get_displayname(
|
displayname = yield self.profile_handler.get_displayname(
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||||
except:
|
except:
|
||||||
defer.returnValue((400, "Unable to parse name"))
|
defer.returnValue((400, "Unable to parse name"))
|
||||||
|
|
||||||
yield self.handlers.profile_handler.set_displayname(
|
yield self.profile_handler.set_displayname(
|
||||||
user, requester, new_name, is_admin)
|
user, requester, new_name, is_admin)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ProfileAvatarURLRestServlet, self).__init__(hs)
|
super(ProfileAvatarURLRestServlet, self).__init__(hs)
|
||||||
self.handlers = hs.get_handlers()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
|
avatar_url = yield self.profile_handler.get_avatar_url(
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -97,7 +97,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||||
except:
|
except:
|
||||||
defer.returnValue((400, "Unable to parse name"))
|
defer.returnValue((400, "Unable to parse name"))
|
||||||
|
|
||||||
yield self.handlers.profile_handler.set_avatar_url(
|
yield self.profile_handler.set_avatar_url(
|
||||||
user, requester, new_name, is_admin)
|
user, requester, new_name, is_admin)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ProfileRestServlet, self).__init__(hs)
|
super(ProfileRestServlet, self).__init__(hs)
|
||||||
self.handlers = hs.get_handlers()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
displayname = yield self.handlers.profile_handler.get_displayname(
|
displayname = yield self.profile_handler.get_displayname(
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
|
avatar_url = yield self.profile_handler.get_avatar_url(
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -51,6 +51,7 @@ from synapse.handlers.receipts import ReceiptsHandler
|
||||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||||
from synapse.handlers.user_directory import UserDirectoyHandler
|
from synapse.handlers.user_directory import UserDirectoyHandler
|
||||||
from synapse.handlers.groups_local import GroupsLocalHandler
|
from synapse.handlers.groups_local import GroupsLocalHandler
|
||||||
|
from synapse.handlers.profile import ProfileHandler
|
||||||
from synapse.groups.groups_server import GroupsServerHandler
|
from synapse.groups.groups_server import GroupsServerHandler
|
||||||
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
|
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
|
||||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||||
|
@ -114,6 +115,7 @@ class HomeServer(object):
|
||||||
'application_service_scheduler',
|
'application_service_scheduler',
|
||||||
'application_service_handler',
|
'application_service_handler',
|
||||||
'device_message_handler',
|
'device_message_handler',
|
||||||
|
'profile_handler',
|
||||||
'notifier',
|
'notifier',
|
||||||
'distributor',
|
'distributor',
|
||||||
'client_resource',
|
'client_resource',
|
||||||
|
@ -258,6 +260,9 @@ class HomeServer(object):
|
||||||
def build_initial_sync_handler(self):
|
def build_initial_sync_handler(self):
|
||||||
return InitialSyncHandler(self)
|
return InitialSyncHandler(self)
|
||||||
|
|
||||||
|
def build_profile_handler(self):
|
||||||
|
return ProfileHandler(self)
|
||||||
|
|
||||||
def build_event_sources(self):
|
def build_event_sources(self):
|
||||||
return EventSources(self)
|
return EventSources(self)
|
||||||
|
|
||||||
|
|
|
@ -743,6 +743,33 @@ class SQLBaseStore(object):
|
||||||
txn.execute(sql, values)
|
txn.execute(sql, values)
|
||||||
return cls.cursor_to_dict(txn)
|
return cls.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
def _simple_update(self, table, keyvalues, updatevalues, desc):
|
||||||
|
return self.runInteraction(
|
||||||
|
desc,
|
||||||
|
self._simple_update_txn,
|
||||||
|
table, keyvalues, updatevalues,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _simple_update_txn(txn, table, keyvalues, updatevalues):
|
||||||
|
if keyvalues:
|
||||||
|
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||||
|
else:
|
||||||
|
where = ""
|
||||||
|
|
||||||
|
update_sql = "UPDATE %s SET %s %s" % (
|
||||||
|
table,
|
||||||
|
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||||
|
where,
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
update_sql,
|
||||||
|
updatevalues.values() + keyvalues.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
return txn.rowcount
|
||||||
|
|
||||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||||
desc="_simple_update_one"):
|
desc="_simple_update_one"):
|
||||||
"""Executes an UPDATE query on the named table, setting new values for
|
"""Executes an UPDATE query on the named table, setting new values for
|
||||||
|
@ -768,27 +795,13 @@ class SQLBaseStore(object):
|
||||||
table, keyvalues, updatevalues,
|
table, keyvalues, updatevalues,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
|
||||||
if keyvalues:
|
rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
|
||||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
|
||||||
else:
|
|
||||||
where = ""
|
|
||||||
|
|
||||||
update_sql = "UPDATE %s SET %s %s" % (
|
if rowcount == 0:
|
||||||
table,
|
|
||||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
|
||||||
where,
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(
|
|
||||||
update_sql,
|
|
||||||
updatevalues.values() + keyvalues.values()
|
|
||||||
)
|
|
||||||
|
|
||||||
if txn.rowcount == 0:
|
|
||||||
raise StoreError(404, "No row found")
|
raise StoreError(404, "No row found")
|
||||||
if txn.rowcount > 1:
|
if rowcount > 1:
|
||||||
raise StoreError(500, "More than one row matched")
|
raise StoreError(500, "More than one row matched")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,3 +57,99 @@ class ProfileStore(SQLBaseStore):
|
||||||
updatevalues={"avatar_url": new_avatar_url},
|
updatevalues={"avatar_url": new_avatar_url},
|
||||||
desc="set_profile_avatar_url",
|
desc="set_profile_avatar_url",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_from_remote_profile_cache(self, user_id):
|
||||||
|
return self._simple_select_one(
|
||||||
|
table="remote_profile_cache",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcols=("displayname", "avatar_url", "last_check"),
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_from_remote_profile_cache",
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
|
||||||
|
"""Ensure we are caching the remote user's profiles.
|
||||||
|
|
||||||
|
This should only be called when `is_subscribed_remote_profile_for_user`
|
||||||
|
would return true for the user.
|
||||||
|
"""
|
||||||
|
return self._simple_upsert(
|
||||||
|
table="remote_profile_cache",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
values={
|
||||||
|
"displayname": displayname,
|
||||||
|
"avatar_url": avatar_url,
|
||||||
|
"last_check": self._clock.time_msec(),
|
||||||
|
},
|
||||||
|
desc="add_remote_profile_cache",
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
|
||||||
|
return self._simple_update(
|
||||||
|
table="remote_profile_cache",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
values={
|
||||||
|
"displayname": displayname,
|
||||||
|
"avatar_url": avatar_url,
|
||||||
|
"last_check": self._clock.time_msec(),
|
||||||
|
},
|
||||||
|
desc="update_remote_profile_cache",
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def maybe_delete_remote_profile_cache(self, user_id):
|
||||||
|
"""Check if we still care about the remote user's profile, and if we
|
||||||
|
don't then remove their profile from the cache
|
||||||
|
"""
|
||||||
|
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
|
||||||
|
if not subscribed:
|
||||||
|
yield self._simple_delete(
|
||||||
|
table="remote_profile_cache",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
desc="delete_remote_profile_cache",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_remote_profile_cache_entries_that_expire(self, last_checked):
|
||||||
|
"""Get all users who haven't been checked since `last_checked`
|
||||||
|
"""
|
||||||
|
def _get_remote_profile_cache_entries_that_expire_txn(txn):
|
||||||
|
sql = """
|
||||||
|
SELECT user_id, displayname, avatar_url
|
||||||
|
FROM remote_profile_cache
|
||||||
|
WHERE last_check < ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (last_checked,))
|
||||||
|
|
||||||
|
return self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_remote_profile_cache_entries_that_expire",
|
||||||
|
_get_remote_profile_cache_entries_that_expire_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def is_subscribed_remote_profile_for_user(self, user_id):
|
||||||
|
"""Check whether we are interested in a remote user's profile.
|
||||||
|
"""
|
||||||
|
res = yield self._simple_select_one_onecol(
|
||||||
|
table="group_users",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcol="user_id",
|
||||||
|
allow_none=True,
|
||||||
|
desc="should_update_remote_profile_cache_for_user",
|
||||||
|
)
|
||||||
|
|
||||||
|
if res:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
res = yield self._simple_select_one_onecol(
|
||||||
|
table="group_invites",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcol="user_id",
|
||||||
|
allow_none=True,
|
||||||
|
desc="should_update_remote_profile_cache_for_user",
|
||||||
|
)
|
||||||
|
|
||||||
|
if res:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
28
synapse/storage/schema/delta/43/profile_cache.sql
Normal file
28
synapse/storage/schema/delta/43/profile_cache.sql
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
/* Copyright 2017 New Vector Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
-- A subset of remote users whose profiles we have cached.
|
||||||
|
-- Whether a user is in this table or not is defined by the storage function
|
||||||
|
-- `is_subscribed_remote_profile_for_user`
|
||||||
|
CREATE TABLE remote_profile_cache (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
displayname TEXT,
|
||||||
|
avatar_url TEXT,
|
||||||
|
last_check BIGINT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id);
|
||||||
|
CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check);
|
|
@ -62,8 +62,6 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.send_message.return_value = (True, 0)
|
||||||
|
|
||||||
hs.handlers = ProfileHandlers(hs)
|
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
self.frank = UserID.from_string("@1234ABCD:test")
|
self.frank = UserID.from_string("@1234ABCD:test")
|
||||||
|
@ -72,7 +70,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
|
|
||||||
yield self.store.create_profile(self.frank.localpart)
|
yield self.store.create_profile(self.frank.localpart)
|
||||||
|
|
||||||
self.handler = hs.get_handlers().profile_handler
|
self.handler = hs.get_profile_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_my_name(self):
|
def test_get_my_name(self):
|
||||||
|
|
|
@ -40,13 +40,14 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
self.hs = yield setup_test_homeserver(
|
self.hs = yield setup_test_homeserver(
|
||||||
handlers=None,
|
handlers=None,
|
||||||
http_client=None,
|
http_client=None,
|
||||||
expire_access_token=True)
|
expire_access_token=True,
|
||||||
|
profile_handler=Mock(),
|
||||||
|
)
|
||||||
self.macaroon_generator = Mock(
|
self.macaroon_generator = Mock(
|
||||||
generate_access_token=Mock(return_value='secret'))
|
generate_access_token=Mock(return_value='secret'))
|
||||||
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
||||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||||
self.handler = self.hs.get_handlers().registration_handler
|
self.handler = self.hs.get_handlers().registration_handler
|
||||||
self.hs.get_handlers().profile_handler = Mock()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||||
|
|
|
@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
resource_for_client=self.mock_resource,
|
resource_for_client=self.mock_resource,
|
||||||
federation=Mock(),
|
federation=Mock(),
|
||||||
replication_layer=Mock(),
|
replication_layer=Mock(),
|
||||||
|
profile_handler=self.mock_handler
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_user_by_req(request=None, allow_guest=False):
|
def _get_user_by_req(request=None, allow_guest=False):
|
||||||
|
@ -53,8 +54,6 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
|
|
||||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||||
|
|
||||||
hs.get_handlers().profile_handler = self.mock_handler
|
|
||||||
|
|
||||||
profile.register_servlets(hs, self.mock_resource)
|
profile.register_servlets(hs, self.mock_resource)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
Loading…
Add table
Reference in a new issue