Merge pull request #6504 from matrix-org/erikj/account_validity_async_await

Port handlers.account_validity to async/await.
This commit is contained in:
Erik Johnston 2019-12-11 14:49:26 +00:00 committed by GitHub
commit 31905a518e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 57 deletions

1
changelog.d/6504.misc Normal file
View file

@ -0,0 +1 @@
Port handlers.account_data and handlers.account_validity to async/await.

View file

@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
class AccountDataEventSource(object): class AccountDataEventSource(object):
def __init__(self, hs): def __init__(self, hs):
@ -23,15 +21,14 @@ class AccountDataEventSource(object):
def get_current_key(self, direction="f"): def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id() return self.store.get_max_account_data_stream_id()
@defer.inlineCallbacks async def get_new_events(self, user, from_key, **kwargs):
def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string() user_id = user.to_string()
last_stream_id = from_key last_stream_id = from_key
current_stream_id = yield self.store.get_max_account_data_stream_id() current_stream_id = self.store.get_max_account_data_stream_id()
results = [] results = []
tags = yield self.store.get_updated_tags(user_id, last_stream_id) tags = await self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items(): for room_id, room_tags in tags.items():
results.append( results.append(
@ -41,7 +38,7 @@ class AccountDataEventSource(object):
( (
account_data, account_data,
room_account_data, room_account_data,
) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
for account_data_type, content in account_data.items(): for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content}) results.append({"type": account_data_type, "content": content})
@ -54,6 +51,5 @@ class AccountDataEventSource(object):
return results, current_stream_id return results, current_stream_id
@defer.inlineCallbacks async def get_pagination_rows(self, user, config, key):
def get_pagination_rows(self, user, config, key):
return [], config.to_id return [], config.to_id

View file

@ -18,8 +18,7 @@ import email.utils
import logging import logging
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText from email.mime.text import MIMEText
from typing import List
from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@ -78,42 +77,39 @@ class AccountValidityHandler(object):
# run as a background process to make sure that the database transactions # run as a background process to make sure that the database transactions
# have a logcontext to report to # have a logcontext to report to
return run_as_background_process( return run_as_background_process(
"send_renewals", self.send_renewal_emails "send_renewals", self._send_renewal_emails
) )
self.clock.looping_call(send_emails, 30 * 60 * 1000) self.clock.looping_call(send_emails, 30 * 60 * 1000)
@defer.inlineCallbacks async def _send_renewal_emails(self):
def send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time """Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity`` configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account. have an email 3PID attached to their account.
""" """
expiring_users = yield self.store.get_users_expiring_soon() expiring_users = await self.store.get_users_expiring_soon()
if expiring_users: if expiring_users:
for user in expiring_users: for user in expiring_users:
yield self._send_renewal_email( await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
) )
@defer.inlineCallbacks async def send_renewal_email_to_user(self, user_id: str):
def send_renewal_email_to_user(self, user_id): expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) await self._send_renewal_email(user_id, expiration_ts)
yield self._send_renewal_email(user_id, expiration_ts)
@defer.inlineCallbacks async def _send_renewal_email(self, user_id: str, expiration_ts: int):
def _send_renewal_email(self, user_id, expiration_ts):
"""Sends out a renewal email to every email address attached to the given user """Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account. with a unique link allowing them to renew their account.
Args: Args:
user_id (str): ID of the user to send email(s) to. user_id: ID of the user to send email(s) to.
expiration_ts (int): Timestamp in milliseconds for the expiration date of expiration_ts: Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates). this user's account (used in the email templates).
""" """
addresses = yield self._get_email_addresses_for_user(user_id) addresses = await self._get_email_addresses_for_user(user_id)
# Stop right here if the user doesn't have at least one email address. # Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their # In this case, they will have to ask their server admin to renew their
@ -125,7 +121,7 @@ class AccountValidityHandler(object):
return return
try: try:
user_display_name = yield self.store.get_profile_displayname( user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart UserID.from_string(user_id).localpart
) )
if user_display_name is None: if user_display_name is None:
@ -133,7 +129,7 @@ class AccountValidityHandler(object):
except StoreError: except StoreError:
user_display_name = user_id user_display_name = user_id
renewal_token = yield self._get_renewal_token(user_id) renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl, self.hs.config.public_baseurl,
renewal_token, renewal_token,
@ -165,7 +161,7 @@ class AccountValidityHandler(object):
logger.info("Sending renewal email to %s", address) logger.info("Sending renewal email to %s", address)
yield make_deferred_yieldable( await make_deferred_yieldable(
self.sendmail( self.sendmail(
self.hs.config.email_smtp_host, self.hs.config.email_smtp_host,
self._raw_from, self._raw_from,
@ -180,19 +176,18 @@ class AccountValidityHandler(object):
) )
) )
yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
@defer.inlineCallbacks async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
def _get_email_addresses_for_user(self, user_id):
"""Retrieve the list of email addresses attached to a user's account. """Retrieve the list of email addresses attached to a user's account.
Args: Args:
user_id (str): ID of the user to lookup email addresses for. user_id: ID of the user to lookup email addresses for.
Returns: Returns:
defer.Deferred[list[str]]: Email addresses for this account. Email addresses for this account.
""" """
threepids = yield self.store.user_get_threepids(user_id) threepids = await self.store.user_get_threepids(user_id)
addresses = [] addresses = []
for threepid in threepids: for threepid in threepids:
@ -201,16 +196,15 @@ class AccountValidityHandler(object):
return addresses return addresses
@defer.inlineCallbacks async def _get_renewal_token(self, user_id: str) -> str:
def _get_renewal_token(self, user_id):
"""Generates a 32-byte long random string that will be inserted into the """Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database. user's renewal email's unique link, then saves it into the database.
Args: Args:
user_id (str): ID of the user to generate a string for. user_id: ID of the user to generate a string for.
Returns: Returns:
defer.Deferred[str]: The generated string. The generated string.
Raises: Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts. StoreError(500): Couldn't generate a unique string after 5 attempts.
@ -219,52 +213,52 @@ class AccountValidityHandler(object):
while attempts < 5: while attempts < 5:
try: try:
renewal_token = stringutils.random_string(32) renewal_token = stringutils.random_string(32)
yield self.store.set_renewal_token_for_user(user_id, renewal_token) await self.store.set_renewal_token_for_user(user_id, renewal_token)
return renewal_token return renewal_token
except StoreError: except StoreError:
attempts += 1 attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.") raise StoreError(500, "Couldn't generate a unique string as refresh string.")
@defer.inlineCallbacks async def renew_account(self, renewal_token: str) -> bool:
def renew_account(self, renewal_token):
"""Renews the account attached to a given renewal token by pushing back the """Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration. expiration date by the current validity period in the server's configuration.
Args: Args:
renewal_token (str): Token sent with the renewal request. renewal_token: Token sent with the renewal request.
Returns: Returns:
bool: Whether the provided token is valid. Whether the provided token is valid.
""" """
try: try:
user_id = yield self.store.get_user_from_renewal_token(renewal_token) user_id = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError: except StoreError:
defer.returnValue(False) return False
logger.debug("Renewing an account for user %s", user_id) logger.debug("Renewing an account for user %s", user_id)
yield self.renew_account_for_user(user_id) await self.renew_account_for_user(user_id)
defer.returnValue(True) return True
@defer.inlineCallbacks async def renew_account_for_user(
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): self, user_id: str, expiration_ts: int = None, email_sent: bool = False
) -> int:
"""Renews the account attached to a given user by pushing back the """Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's expiration date by the current validity period in the server's
configuration. configuration.
Args: Args:
renewal_token (str): Token sent with the renewal request. renewal_token: Token sent with the renewal request.
expiration_ts (int): New expiration date. Defaults to now + validity period. expiration_ts: New expiration date. Defaults to now + validity period.
email_sent (bool): Whether an email has been sent for this validity period. email_sen: Whether an email has been sent for this validity period.
Defaults to False. Defaults to False.
Returns: Returns:
defer.Deferred[int]: New expiration date for this account, as a timestamp New expiration date for this account, as a timestamp in
in milliseconds since epoch. milliseconds since epoch.
""" """
if expiration_ts is None: if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period expiration_ts = self.clock.time_msec() + self._account_validity.period
yield self.store.set_account_validity_for_user( await self.store.set_account_validity_for_user(
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
) )

View file

@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Email config. # Email config.
self.email_attempts = [] self.email_attempts = []
def sendmail(*args, **kwargs): async def sendmail(*args, **kwargs):
self.email_attempts.append((args, kwargs)) self.email_attempts.append((args, kwargs))
return
config["email"] = { config["email"] = {
"enable_notifs": True, "enable_notifs": True,