forked from MirrorHub/synapse
Convert _base, profile, and _receipts handlers to async/await (#7860)
This commit is contained in:
parent
fff483ea96
commit
6fca1b3506
6 changed files with 53 additions and 59 deletions
1
changelog.d/7860.misc
Normal file
1
changelog.d/7860.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Convert _base, profile, and _receipts handlers to async/await.
|
|
@ -15,8 +15,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
import synapse.state
|
import synapse.state
|
||||||
import synapse.storage
|
import synapse.storage
|
||||||
import synapse.types
|
import synapse.types
|
||||||
|
@ -66,8 +64,7 @@ class BaseHandler(object):
|
||||||
|
|
||||||
self.event_builder_factory = hs.get_event_builder_factory()
|
self.event_builder_factory = hs.get_event_builder_factory()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def ratelimit(self, requester, update=True, is_admin_redaction=False):
|
||||||
def ratelimit(self, requester, update=True, is_admin_redaction=False):
|
|
||||||
"""Ratelimits requests.
|
"""Ratelimits requests.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -99,7 +96,7 @@ class BaseHandler(object):
|
||||||
burst_count = self._rc_message.burst_count
|
burst_count = self._rc_message.burst_count
|
||||||
|
|
||||||
# Check if there is a per user override in the DB.
|
# Check if there is a per user override in the DB.
|
||||||
override = yield self.store.get_ratelimit_for_user(user_id)
|
override = await self.store.get_ratelimit_for_user(user_id)
|
||||||
if override:
|
if override:
|
||||||
# If overridden with a null Hz then ratelimiting has been entirely
|
# If overridden with a null Hz then ratelimiting has been entirely
|
||||||
# disabled for the user
|
# disabled for the user
|
||||||
|
|
|
@ -488,11 +488,15 @@ class EventCreationHandler(object):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if "displayname" not in content:
|
if "displayname" not in content:
|
||||||
displayname = yield profile.get_displayname(target)
|
displayname = yield defer.ensureDeferred(
|
||||||
|
profile.get_displayname(target)
|
||||||
|
)
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
content["displayname"] = displayname
|
content["displayname"] = displayname
|
||||||
if "avatar_url" not in content:
|
if "avatar_url" not in content:
|
||||||
avatar_url = yield profile.get_avatar_url(target)
|
avatar_url = yield defer.ensureDeferred(
|
||||||
|
profile.get_avatar_url(target)
|
||||||
|
)
|
||||||
if avatar_url is not None:
|
if avatar_url is not None:
|
||||||
content["avatar_url"] = avatar_url
|
content["avatar_url"] = avatar_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
Codes,
|
Codes,
|
||||||
|
@ -54,16 +52,15 @@ class BaseProfileHandler(BaseHandler):
|
||||||
|
|
||||||
self.user_directory_handler = hs.get_user_directory_handler()
|
self.user_directory_handler = hs.get_user_directory_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_profile(self, user_id):
|
||||||
def get_profile(self, user_id):
|
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
|
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
displayname = yield self.store.get_profile_displayname(
|
displayname = await self.store.get_profile_displayname(
|
||||||
target_user.localpart
|
target_user.localpart
|
||||||
)
|
)
|
||||||
avatar_url = yield self.store.get_profile_avatar_url(
|
avatar_url = await self.store.get_profile_avatar_url(
|
||||||
target_user.localpart
|
target_user.localpart
|
||||||
)
|
)
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
|
@ -74,7 +71,7 @@ class BaseProfileHandler(BaseHandler):
|
||||||
return {"displayname": displayname, "avatar_url": avatar_url}
|
return {"displayname": displayname, "avatar_url": avatar_url}
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
result = yield self.federation.make_query(
|
result = await self.federation.make_query(
|
||||||
destination=target_user.domain,
|
destination=target_user.domain,
|
||||||
query_type="profile",
|
query_type="profile",
|
||||||
args={"user_id": user_id},
|
args={"user_id": user_id},
|
||||||
|
@ -86,8 +83,7 @@ class BaseProfileHandler(BaseHandler):
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
raise e.to_synapse_error()
|
raise e.to_synapse_error()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_profile_from_cache(self, user_id):
|
||||||
def get_profile_from_cache(self, user_id):
|
|
||||||
"""Get the profile information from our local cache. If the user is
|
"""Get the profile information from our local cache. If the user is
|
||||||
ours then the profile information will always be corect. Otherwise,
|
ours then the profile information will always be corect. Otherwise,
|
||||||
it may be out of date/missing.
|
it may be out of date/missing.
|
||||||
|
@ -95,10 +91,10 @@ class BaseProfileHandler(BaseHandler):
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
displayname = yield self.store.get_profile_displayname(
|
displayname = await self.store.get_profile_displayname(
|
||||||
target_user.localpart
|
target_user.localpart
|
||||||
)
|
)
|
||||||
avatar_url = yield self.store.get_profile_avatar_url(
|
avatar_url = await self.store.get_profile_avatar_url(
|
||||||
target_user.localpart
|
target_user.localpart
|
||||||
)
|
)
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
|
@ -108,14 +104,13 @@ class BaseProfileHandler(BaseHandler):
|
||||||
|
|
||||||
return {"displayname": displayname, "avatar_url": avatar_url}
|
return {"displayname": displayname, "avatar_url": avatar_url}
|
||||||
else:
|
else:
|
||||||
profile = yield self.store.get_from_remote_profile_cache(user_id)
|
profile = await self.store.get_from_remote_profile_cache(user_id)
|
||||||
return profile or {}
|
return profile or {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_displayname(self, target_user):
|
||||||
def get_displayname(self, target_user):
|
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
displayname = yield self.store.get_profile_displayname(
|
displayname = await self.store.get_profile_displayname(
|
||||||
target_user.localpart
|
target_user.localpart
|
||||||
)
|
)
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
|
@ -126,7 +121,7 @@ class BaseProfileHandler(BaseHandler):
|
||||||
return displayname
|
return displayname
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
result = yield self.federation.make_query(
|
result = await self.federation.make_query(
|
||||||
destination=target_user.domain,
|
destination=target_user.domain,
|
||||||
query_type="profile",
|
query_type="profile",
|
||||||
args={"user_id": target_user.to_string(), "field": "displayname"},
|
args={"user_id": target_user.to_string(), "field": "displayname"},
|
||||||
|
@ -189,11 +184,10 @@ class BaseProfileHandler(BaseHandler):
|
||||||
|
|
||||||
await self._update_join_states(requester, target_user)
|
await self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_avatar_url(self, target_user):
|
||||||
def get_avatar_url(self, target_user):
|
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
avatar_url = yield self.store.get_profile_avatar_url(
|
avatar_url = await self.store.get_profile_avatar_url(
|
||||||
target_user.localpart
|
target_user.localpart
|
||||||
)
|
)
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
|
@ -203,7 +197,7 @@ class BaseProfileHandler(BaseHandler):
|
||||||
return avatar_url
|
return avatar_url
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
result = yield self.federation.make_query(
|
result = await self.federation.make_query(
|
||||||
destination=target_user.domain,
|
destination=target_user.domain,
|
||||||
query_type="profile",
|
query_type="profile",
|
||||||
args={"user_id": target_user.to_string(), "field": "avatar_url"},
|
args={"user_id": target_user.to_string(), "field": "avatar_url"},
|
||||||
|
@ -253,8 +247,7 @@ class BaseProfileHandler(BaseHandler):
|
||||||
|
|
||||||
await self._update_join_states(requester, target_user)
|
await self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_profile_query(self, args):
|
||||||
def on_profile_query(self, args):
|
|
||||||
user = UserID.from_string(args["user_id"])
|
user = UserID.from_string(args["user_id"])
|
||||||
if not self.hs.is_mine(user):
|
if not self.hs.is_mine(user):
|
||||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||||
|
@ -264,12 +257,12 @@ class BaseProfileHandler(BaseHandler):
|
||||||
response = {}
|
response = {}
|
||||||
try:
|
try:
|
||||||
if just_field is None or just_field == "displayname":
|
if just_field is None or just_field == "displayname":
|
||||||
response["displayname"] = yield self.store.get_profile_displayname(
|
response["displayname"] = await self.store.get_profile_displayname(
|
||||||
user.localpart
|
user.localpart
|
||||||
)
|
)
|
||||||
|
|
||||||
if just_field is None or just_field == "avatar_url":
|
if just_field is None or just_field == "avatar_url":
|
||||||
response["avatar_url"] = yield self.store.get_profile_avatar_url(
|
response["avatar_url"] = await self.store.get_profile_avatar_url(
|
||||||
user.localpart
|
user.localpart
|
||||||
)
|
)
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
|
@ -304,8 +297,7 @@ class BaseProfileHandler(BaseHandler):
|
||||||
"Failed to update join event for room %s - %s", room_id, str(e)
|
"Failed to update join event for room %s - %s", room_id, str(e)
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_profile_query_allowed(self, target_user, requester=None):
|
||||||
def check_profile_query_allowed(self, target_user, requester=None):
|
|
||||||
"""Checks whether a profile query is allowed. If the
|
"""Checks whether a profile query is allowed. If the
|
||||||
'require_auth_for_profile_requests' config flag is set to True and a
|
'require_auth_for_profile_requests' config flag is set to True and a
|
||||||
'requester' is provided, the query is only allowed if the two users
|
'requester' is provided, the query is only allowed if the two users
|
||||||
|
@ -337,8 +329,8 @@ class BaseProfileHandler(BaseHandler):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
requester_rooms = yield self.store.get_rooms_for_user(requester.to_string())
|
requester_rooms = await self.store.get_rooms_for_user(requester.to_string())
|
||||||
target_user_rooms = yield self.store.get_rooms_for_user(
|
target_user_rooms = await self.store.get_rooms_for_user(
|
||||||
target_user.to_string()
|
target_user.to_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -371,25 +363,24 @@ class MasterProfileHandler(BaseProfileHandler):
|
||||||
"Update remote profile", self._update_remote_profile_cache
|
"Update remote profile", self._update_remote_profile_cache
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _update_remote_profile_cache(self):
|
||||||
def _update_remote_profile_cache(self):
|
|
||||||
"""Called periodically to check profiles of remote users we haven't
|
"""Called periodically to check profiles of remote users we haven't
|
||||||
checked in a while.
|
checked in a while.
|
||||||
"""
|
"""
|
||||||
entries = yield self.store.get_remote_profile_cache_entries_that_expire(
|
entries = await self.store.get_remote_profile_cache_entries_that_expire(
|
||||||
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
|
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
|
||||||
)
|
)
|
||||||
|
|
||||||
for user_id, displayname, avatar_url in entries:
|
for user_id, displayname, avatar_url in entries:
|
||||||
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
|
is_subscribed = await self.store.is_subscribed_remote_profile_for_user(
|
||||||
user_id
|
user_id
|
||||||
)
|
)
|
||||||
if not is_subscribed:
|
if not is_subscribed:
|
||||||
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
await self.store.maybe_delete_remote_profile_cache(user_id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
profile = yield self.federation.make_query(
|
profile = await self.federation.make_query(
|
||||||
destination=get_domain_from_id(user_id),
|
destination=get_domain_from_id(user_id),
|
||||||
query_type="profile",
|
query_type="profile",
|
||||||
args={"user_id": user_id},
|
args={"user_id": user_id},
|
||||||
|
@ -398,7 +389,7 @@ class MasterProfileHandler(BaseProfileHandler):
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to get avatar_url")
|
logger.exception("Failed to get avatar_url")
|
||||||
|
|
||||||
yield self.store.update_remote_profile_cache(
|
await self.store.update_remote_profile_cache(
|
||||||
user_id, displayname, avatar_url
|
user_id, displayname, avatar_url
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
@ -407,4 +398,4 @@ class MasterProfileHandler(BaseProfileHandler):
|
||||||
new_avatar = profile.get("avatar_url")
|
new_avatar = profile.get("avatar_url")
|
||||||
|
|
||||||
# We always hit update to update the last_check timestamp
|
# We always hit update to update the last_check timestamp
|
||||||
yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
|
await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
|
||||||
|
|
|
@ -14,8 +14,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.handlers._base import BaseHandler
|
from synapse.handlers._base import BaseHandler
|
||||||
from synapse.types import ReadReceipt, get_domain_from_id
|
from synapse.types import ReadReceipt, get_domain_from_id
|
||||||
from synapse.util.async_helpers import maybe_awaitable
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
@ -129,15 +127,14 @@ class ReceiptEventSource(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_new_events(self, from_key, room_ids, **kwargs):
|
||||||
def get_new_events(self, from_key, room_ids, **kwargs):
|
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
to_key = yield self.get_current_key()
|
to_key = self.get_current_key()
|
||||||
|
|
||||||
if from_key == to_key:
|
if from_key == to_key:
|
||||||
return [], to_key
|
return [], to_key
|
||||||
|
|
||||||
events = yield self.store.get_linearized_receipts_for_rooms(
|
events = await self.store.get_linearized_receipts_for_rooms(
|
||||||
room_ids, from_key=from_key, to_key=to_key
|
room_ids, from_key=from_key, to_key=to_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -146,8 +143,7 @@ class ReceiptEventSource(object):
|
||||||
def get_current_key(self, direction="f"):
|
def get_current_key(self, direction="f"):
|
||||||
return self.store.get_max_receipt_stream_id()
|
return self.store.get_max_receipt_stream_id()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_pagination_rows(self, user, config, key):
|
||||||
def get_pagination_rows(self, user, config, key):
|
|
||||||
to_key = int(config.from_key)
|
to_key = int(config.from_key)
|
||||||
|
|
||||||
if config.to_key:
|
if config.to_key:
|
||||||
|
@ -155,8 +151,8 @@ class ReceiptEventSource(object):
|
||||||
else:
|
else:
|
||||||
from_key = None
|
from_key = None
|
||||||
|
|
||||||
room_ids = yield self.store.get_rooms_for_user(user.to_string())
|
room_ids = await self.store.get_rooms_for_user(user.to_string())
|
||||||
events = yield self.store.get_linearized_receipts_for_rooms(
|
events = await self.store.get_linearized_receipts_for_rooms(
|
||||||
room_ids, from_key=from_key, to_key=to_key
|
room_ids, from_key=from_key, to_key=to_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -72,7 +72,9 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
def test_get_my_name(self):
|
def test_get_my_name(self):
|
||||||
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||||
|
|
||||||
displayname = yield self.handler.get_displayname(self.frank)
|
displayname = yield defer.ensureDeferred(
|
||||||
|
self.handler.get_displayname(self.frank)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEquals("Frank", displayname)
|
self.assertEquals("Frank", displayname)
|
||||||
|
|
||||||
|
@ -140,7 +142,9 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
{"displayname": "Alice"}
|
{"displayname": "Alice"}
|
||||||
)
|
)
|
||||||
|
|
||||||
displayname = yield self.handler.get_displayname(self.alice)
|
displayname = yield defer.ensureDeferred(
|
||||||
|
self.handler.get_displayname(self.alice)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEquals(displayname, "Alice")
|
self.assertEquals(displayname, "Alice")
|
||||||
self.mock_federation.make_query.assert_called_with(
|
self.mock_federation.make_query.assert_called_with(
|
||||||
|
@ -155,9 +159,11 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
yield self.store.create_profile("caroline")
|
yield self.store.create_profile("caroline")
|
||||||
yield self.store.set_profile_displayname("caroline", "Caroline")
|
yield self.store.set_profile_displayname("caroline", "Caroline")
|
||||||
|
|
||||||
response = yield self.query_handlers["profile"](
|
response = yield defer.ensureDeferred(
|
||||||
|
self.query_handlers["profile"](
|
||||||
{"user_id": "@caroline:test", "field": "displayname"}
|
{"user_id": "@caroline:test", "field": "displayname"}
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEquals({"displayname": "Caroline"}, response)
|
self.assertEquals({"displayname": "Caroline"}, response)
|
||||||
|
|
||||||
|
@ -166,8 +172,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
yield self.store.set_profile_avatar_url(
|
yield self.store.set_profile_avatar_url(
|
||||||
self.frank.localpart, "http://my.server/me.png"
|
self.frank.localpart, "http://my.server/me.png"
|
||||||
)
|
)
|
||||||
|
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
|
||||||
avatar_url = yield self.handler.get_avatar_url(self.frank)
|
|
||||||
|
|
||||||
self.assertEquals("http://my.server/me.png", avatar_url)
|
self.assertEquals("http://my.server/me.png", avatar_url)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue