diff --git a/changelog.d/6504.misc b/changelog.d/6504.misc new file mode 100644 index 0000000000..7c873459af --- /dev/null +++ b/changelog.d/6504.misc @@ -0,0 +1 @@ +Port handlers.account_data and handlers.account_validity to async/await. diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 2d7e6df6e4..20ec1ca01b 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - class AccountDataEventSource(object): def __init__(self, hs): @@ -23,15 +21,14 @@ class AccountDataEventSource(object): def get_current_key(self, direction="f"): return self.store.get_max_account_data_stream_id() - @defer.inlineCallbacks - def get_new_events(self, user, from_key, **kwargs): + async def get_new_events(self, user, from_key, **kwargs): user_id = user.to_string() last_stream_id = from_key - current_stream_id = yield self.store.get_max_account_data_stream_id() + current_stream_id = self.store.get_max_account_data_stream_id() results = [] - tags = yield self.store.get_updated_tags(user_id, last_stream_id) + tags = await self.store.get_updated_tags(user_id, last_stream_id) for room_id, room_tags in tags.items(): results.append( @@ -41,7 +38,7 @@ class AccountDataEventSource(object): ( account_data, room_account_data, - ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) + ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) @@ -54,6 +51,5 @@ class AccountDataEventSource(object): return results, current_stream_id - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): + async def get_pagination_rows(self, user, config, key): return [], config.to_id diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d04e0fe576..829f52eca1 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -18,8 +18,7 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText - -from twisted.internet import defer +from typing import List from synapse.api.errors import StoreError from synapse.logging.context import make_deferred_yieldable @@ -78,42 +77,39 @@ class AccountValidityHandler(object): # run as a background process to make sure that the database transactions # have a logcontext to report to return run_as_background_process( - "send_renewals", self.send_renewal_emails + "send_renewals", self._send_renewal_emails ) self.clock.looping_call(send_emails, 30 * 60 * 1000) - @defer.inlineCallbacks - def send_renewal_emails(self): + async def _send_renewal_emails(self): """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` configuration, and sends renewal emails to all of these users as long as they have an email 3PID attached to their account. """ - expiring_users = yield self.store.get_users_expiring_soon() + expiring_users = await self.store.get_users_expiring_soon() if expiring_users: for user in expiring_users: - yield self._send_renewal_email( + await self._send_renewal_email( user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) - @defer.inlineCallbacks - def send_renewal_email_to_user(self, user_id): - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - yield self._send_renewal_email(user_id, expiration_ts) + async def send_renewal_email_to_user(self, user_id: str): + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) + await self._send_renewal_email(user_id, expiration_ts) - @defer.inlineCallbacks - def _send_renewal_email(self, user_id, expiration_ts): + async def _send_renewal_email(self, user_id: str, expiration_ts: int): """Sends out a renewal email to every email address attached to the given user with a unique link allowing them to renew their account. Args: - user_id (str): ID of the user to send email(s) to. - expiration_ts (int): Timestamp in milliseconds for the expiration date of + user_id: ID of the user to send email(s) to. + expiration_ts: Timestamp in milliseconds for the expiration date of this user's account (used in the email templates). """ - addresses = yield self._get_email_addresses_for_user(user_id) + addresses = await self._get_email_addresses_for_user(user_id) # Stop right here if the user doesn't have at least one email address. # In this case, they will have to ask their server admin to renew their @@ -125,7 +121,7 @@ class AccountValidityHandler(object): return try: - user_display_name = yield self.store.get_profile_displayname( + user_display_name = await self.store.get_profile_displayname( UserID.from_string(user_id).localpart ) if user_display_name is None: @@ -133,7 +129,7 @@ class AccountValidityHandler(object): except StoreError: user_display_name = user_id - renewal_token = yield self._get_renewal_token(user_id) + renewal_token = await self._get_renewal_token(user_id) url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( self.hs.config.public_baseurl, renewal_token, @@ -165,7 +161,7 @@ class AccountValidityHandler(object): logger.info("Sending renewal email to %s", address) - yield make_deferred_yieldable( + await make_deferred_yieldable( self.sendmail( self.hs.config.email_smtp_host, self._raw_from, @@ -180,19 +176,18 @@ class AccountValidityHandler(object): ) ) - yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) + await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) - @defer.inlineCallbacks - def _get_email_addresses_for_user(self, user_id): + async def _get_email_addresses_for_user(self, user_id: str) -> List[str]: """Retrieve the list of email addresses attached to a user's account. Args: - user_id (str): ID of the user to lookup email addresses for. + user_id: ID of the user to lookup email addresses for. Returns: - defer.Deferred[list[str]]: Email addresses for this account. + Email addresses for this account. """ - threepids = yield self.store.user_get_threepids(user_id) + threepids = await self.store.user_get_threepids(user_id) addresses = [] for threepid in threepids: @@ -201,16 +196,15 @@ class AccountValidityHandler(object): return addresses - @defer.inlineCallbacks - def _get_renewal_token(self, user_id): + async def _get_renewal_token(self, user_id: str) -> str: """Generates a 32-byte long random string that will be inserted into the user's renewal email's unique link, then saves it into the database. Args: - user_id (str): ID of the user to generate a string for. + user_id: ID of the user to generate a string for. Returns: - defer.Deferred[str]: The generated string. + The generated string. Raises: StoreError(500): Couldn't generate a unique string after 5 attempts. @@ -219,52 +213,52 @@ class AccountValidityHandler(object): while attempts < 5: try: renewal_token = stringutils.random_string(32) - yield self.store.set_renewal_token_for_user(user_id, renewal_token) + await self.store.set_renewal_token_for_user(user_id, renewal_token) return renewal_token except StoreError: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - @defer.inlineCallbacks - def renew_account(self, renewal_token): + async def renew_account(self, renewal_token: str) -> bool: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. + renewal_token: Token sent with the renewal request. Returns: - bool: Whether the provided token is valid. + Whether the provided token is valid. """ try: - user_id = yield self.store.get_user_from_renewal_token(renewal_token) + user_id = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - defer.returnValue(False) + return False logger.debug("Renewing an account for user %s", user_id) - yield self.renew_account_for_user(user_id) + await self.renew_account_for_user(user_id) - defer.returnValue(True) + return True - @defer.inlineCallbacks - def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): + async def renew_account_for_user( + self, user_id: str, expiration_ts: int = None, email_sent: bool = False + ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. - expiration_ts (int): New expiration date. Defaults to now + validity period. - email_sent (bool): Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. + expiration_ts: New expiration date. Defaults to now + validity period. + email_sen: Whether an email has been sent for this validity period. Defaults to False. Returns: - defer.Deferred[int]: New expiration date for this account, as a timestamp - in milliseconds since epoch. + New expiration date for this account, as a timestamp in + milliseconds since epoch. """ if expiration_ts is None: expiration_ts = self.clock.time_msec() + self._account_validity.period - yield self.store.set_account_validity_for_user( + await self.store.set_account_validity_for_user( user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index c0d0d2b44e..d0c997e385 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Email config. self.email_attempts = [] - def sendmail(*args, **kwargs): + async def sendmail(*args, **kwargs): self.email_attempts.append((args, kwargs)) - return config["email"] = { "enable_notifs": True,