forked from MirrorHub/synapse
Merge pull request #6504 from matrix-org/erikj/account_validity_async_await
Port handlers.account_validity to async/await.
This commit is contained in:
commit
31905a518e
4 changed files with 47 additions and 57 deletions
1
changelog.d/6504.misc
Normal file
1
changelog.d/6504.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Port handlers.account_data and handlers.account_validity to async/await.
|
|
@ -13,8 +13,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
|
|
||||||
class AccountDataEventSource(object):
|
class AccountDataEventSource(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -23,15 +21,14 @@ class AccountDataEventSource(object):
|
||||||
def get_current_key(self, direction="f"):
|
def get_current_key(self, direction="f"):
|
||||||
return self.store.get_max_account_data_stream_id()
|
return self.store.get_max_account_data_stream_id()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_new_events(self, user, from_key, **kwargs):
|
||||||
def get_new_events(self, user, from_key, **kwargs):
|
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
last_stream_id = from_key
|
last_stream_id = from_key
|
||||||
|
|
||||||
current_stream_id = yield self.store.get_max_account_data_stream_id()
|
current_stream_id = self.store.get_max_account_data_stream_id()
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
|
tags = await self.store.get_updated_tags(user_id, last_stream_id)
|
||||||
|
|
||||||
for room_id, room_tags in tags.items():
|
for room_id, room_tags in tags.items():
|
||||||
results.append(
|
results.append(
|
||||||
|
@ -41,7 +38,7 @@ class AccountDataEventSource(object):
|
||||||
(
|
(
|
||||||
account_data,
|
account_data,
|
||||||
room_account_data,
|
room_account_data,
|
||||||
) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
|
) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
|
||||||
|
|
||||||
for account_data_type, content in account_data.items():
|
for account_data_type, content in account_data.items():
|
||||||
results.append({"type": account_data_type, "content": content})
|
results.append({"type": account_data_type, "content": content})
|
||||||
|
@ -54,6 +51,5 @@ class AccountDataEventSource(object):
|
||||||
|
|
||||||
return results, current_stream_id
|
return results, current_stream_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_pagination_rows(self, user, config, key):
|
||||||
def get_pagination_rows(self, user, config, key):
|
|
||||||
return [], config.to_id
|
return [], config.to_id
|
||||||
|
|
|
@ -18,8 +18,7 @@ import email.utils
|
||||||
import logging
|
import logging
|
||||||
from email.mime.multipart import MIMEMultipart
|
from email.mime.multipart import MIMEMultipart
|
||||||
from email.mime.text import MIMEText
|
from email.mime.text import MIMEText
|
||||||
|
from typing import List
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
|
@ -78,42 +77,39 @@ class AccountValidityHandler(object):
|
||||||
# run as a background process to make sure that the database transactions
|
# run as a background process to make sure that the database transactions
|
||||||
# have a logcontext to report to
|
# have a logcontext to report to
|
||||||
return run_as_background_process(
|
return run_as_background_process(
|
||||||
"send_renewals", self.send_renewal_emails
|
"send_renewals", self._send_renewal_emails
|
||||||
)
|
)
|
||||||
|
|
||||||
self.clock.looping_call(send_emails, 30 * 60 * 1000)
|
self.clock.looping_call(send_emails, 30 * 60 * 1000)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _send_renewal_emails(self):
|
||||||
def send_renewal_emails(self):
|
|
||||||
"""Gets the list of users whose account is expiring in the amount of time
|
"""Gets the list of users whose account is expiring in the amount of time
|
||||||
configured in the ``renew_at`` parameter from the ``account_validity``
|
configured in the ``renew_at`` parameter from the ``account_validity``
|
||||||
configuration, and sends renewal emails to all of these users as long as they
|
configuration, and sends renewal emails to all of these users as long as they
|
||||||
have an email 3PID attached to their account.
|
have an email 3PID attached to their account.
|
||||||
"""
|
"""
|
||||||
expiring_users = yield self.store.get_users_expiring_soon()
|
expiring_users = await self.store.get_users_expiring_soon()
|
||||||
|
|
||||||
if expiring_users:
|
if expiring_users:
|
||||||
for user in expiring_users:
|
for user in expiring_users:
|
||||||
yield self._send_renewal_email(
|
await self._send_renewal_email(
|
||||||
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
|
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def send_renewal_email_to_user(self, user_id: str):
|
||||||
def send_renewal_email_to_user(self, user_id):
|
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
|
||||||
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
|
await self._send_renewal_email(user_id, expiration_ts)
|
||||||
yield self._send_renewal_email(user_id, expiration_ts)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _send_renewal_email(self, user_id: str, expiration_ts: int):
|
||||||
def _send_renewal_email(self, user_id, expiration_ts):
|
|
||||||
"""Sends out a renewal email to every email address attached to the given user
|
"""Sends out a renewal email to every email address attached to the given user
|
||||||
with a unique link allowing them to renew their account.
|
with a unique link allowing them to renew their account.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): ID of the user to send email(s) to.
|
user_id: ID of the user to send email(s) to.
|
||||||
expiration_ts (int): Timestamp in milliseconds for the expiration date of
|
expiration_ts: Timestamp in milliseconds for the expiration date of
|
||||||
this user's account (used in the email templates).
|
this user's account (used in the email templates).
|
||||||
"""
|
"""
|
||||||
addresses = yield self._get_email_addresses_for_user(user_id)
|
addresses = await self._get_email_addresses_for_user(user_id)
|
||||||
|
|
||||||
# Stop right here if the user doesn't have at least one email address.
|
# Stop right here if the user doesn't have at least one email address.
|
||||||
# In this case, they will have to ask their server admin to renew their
|
# In this case, they will have to ask their server admin to renew their
|
||||||
|
@ -125,7 +121,7 @@ class AccountValidityHandler(object):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_display_name = yield self.store.get_profile_displayname(
|
user_display_name = await self.store.get_profile_displayname(
|
||||||
UserID.from_string(user_id).localpart
|
UserID.from_string(user_id).localpart
|
||||||
)
|
)
|
||||||
if user_display_name is None:
|
if user_display_name is None:
|
||||||
|
@ -133,7 +129,7 @@ class AccountValidityHandler(object):
|
||||||
except StoreError:
|
except StoreError:
|
||||||
user_display_name = user_id
|
user_display_name = user_id
|
||||||
|
|
||||||
renewal_token = yield self._get_renewal_token(user_id)
|
renewal_token = await self._get_renewal_token(user_id)
|
||||||
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
|
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
|
||||||
self.hs.config.public_baseurl,
|
self.hs.config.public_baseurl,
|
||||||
renewal_token,
|
renewal_token,
|
||||||
|
@ -165,7 +161,7 @@ class AccountValidityHandler(object):
|
||||||
|
|
||||||
logger.info("Sending renewal email to %s", address)
|
logger.info("Sending renewal email to %s", address)
|
||||||
|
|
||||||
yield make_deferred_yieldable(
|
await make_deferred_yieldable(
|
||||||
self.sendmail(
|
self.sendmail(
|
||||||
self.hs.config.email_smtp_host,
|
self.hs.config.email_smtp_host,
|
||||||
self._raw_from,
|
self._raw_from,
|
||||||
|
@ -180,19 +176,18 @@ class AccountValidityHandler(object):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
|
await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
|
||||||
def _get_email_addresses_for_user(self, user_id):
|
|
||||||
"""Retrieve the list of email addresses attached to a user's account.
|
"""Retrieve the list of email addresses attached to a user's account.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): ID of the user to lookup email addresses for.
|
user_id: ID of the user to lookup email addresses for.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[list[str]]: Email addresses for this account.
|
Email addresses for this account.
|
||||||
"""
|
"""
|
||||||
threepids = yield self.store.user_get_threepids(user_id)
|
threepids = await self.store.user_get_threepids(user_id)
|
||||||
|
|
||||||
addresses = []
|
addresses = []
|
||||||
for threepid in threepids:
|
for threepid in threepids:
|
||||||
|
@ -201,16 +196,15 @@ class AccountValidityHandler(object):
|
||||||
|
|
||||||
return addresses
|
return addresses
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _get_renewal_token(self, user_id: str) -> str:
|
||||||
def _get_renewal_token(self, user_id):
|
|
||||||
"""Generates a 32-byte long random string that will be inserted into the
|
"""Generates a 32-byte long random string that will be inserted into the
|
||||||
user's renewal email's unique link, then saves it into the database.
|
user's renewal email's unique link, then saves it into the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): ID of the user to generate a string for.
|
user_id: ID of the user to generate a string for.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[str]: The generated string.
|
The generated string.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
StoreError(500): Couldn't generate a unique string after 5 attempts.
|
StoreError(500): Couldn't generate a unique string after 5 attempts.
|
||||||
|
@ -219,52 +213,52 @@ class AccountValidityHandler(object):
|
||||||
while attempts < 5:
|
while attempts < 5:
|
||||||
try:
|
try:
|
||||||
renewal_token = stringutils.random_string(32)
|
renewal_token = stringutils.random_string(32)
|
||||||
yield self.store.set_renewal_token_for_user(user_id, renewal_token)
|
await self.store.set_renewal_token_for_user(user_id, renewal_token)
|
||||||
return renewal_token
|
return renewal_token
|
||||||
except StoreError:
|
except StoreError:
|
||||||
attempts += 1
|
attempts += 1
|
||||||
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
|
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def renew_account(self, renewal_token: str) -> bool:
|
||||||
def renew_account(self, renewal_token):
|
|
||||||
"""Renews the account attached to a given renewal token by pushing back the
|
"""Renews the account attached to a given renewal token by pushing back the
|
||||||
expiration date by the current validity period in the server's configuration.
|
expiration date by the current validity period in the server's configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
renewal_token (str): Token sent with the renewal request.
|
renewal_token: Token sent with the renewal request.
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether the provided token is valid.
|
Whether the provided token is valid.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
user_id = yield self.store.get_user_from_renewal_token(renewal_token)
|
user_id = await self.store.get_user_from_renewal_token(renewal_token)
|
||||||
except StoreError:
|
except StoreError:
|
||||||
defer.returnValue(False)
|
return False
|
||||||
|
|
||||||
logger.debug("Renewing an account for user %s", user_id)
|
logger.debug("Renewing an account for user %s", user_id)
|
||||||
yield self.renew_account_for_user(user_id)
|
await self.renew_account_for_user(user_id)
|
||||||
|
|
||||||
defer.returnValue(True)
|
return True
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def renew_account_for_user(
|
||||||
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
|
self, user_id: str, expiration_ts: int = None, email_sent: bool = False
|
||||||
|
) -> int:
|
||||||
"""Renews the account attached to a given user by pushing back the
|
"""Renews the account attached to a given user by pushing back the
|
||||||
expiration date by the current validity period in the server's
|
expiration date by the current validity period in the server's
|
||||||
configuration.
|
configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
renewal_token (str): Token sent with the renewal request.
|
renewal_token: Token sent with the renewal request.
|
||||||
expiration_ts (int): New expiration date. Defaults to now + validity period.
|
expiration_ts: New expiration date. Defaults to now + validity period.
|
||||||
email_sent (bool): Whether an email has been sent for this validity period.
|
email_sen: Whether an email has been sent for this validity period.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[int]: New expiration date for this account, as a timestamp
|
New expiration date for this account, as a timestamp in
|
||||||
in milliseconds since epoch.
|
milliseconds since epoch.
|
||||||
"""
|
"""
|
||||||
if expiration_ts is None:
|
if expiration_ts is None:
|
||||||
expiration_ts = self.clock.time_msec() + self._account_validity.period
|
expiration_ts = self.clock.time_msec() + self._account_validity.period
|
||||||
|
|
||||||
yield self.store.set_account_validity_for_user(
|
await self.store.set_account_validity_for_user(
|
||||||
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
|
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
# Email config.
|
# Email config.
|
||||||
self.email_attempts = []
|
self.email_attempts = []
|
||||||
|
|
||||||
def sendmail(*args, **kwargs):
|
async def sendmail(*args, **kwargs):
|
||||||
self.email_attempts.append((args, kwargs))
|
self.email_attempts.append((args, kwargs))
|
||||||
return
|
|
||||||
|
|
||||||
config["email"] = {
|
config["email"] = {
|
||||||
"enable_notifs": True,
|
"enable_notifs": True,
|
||||||
|
|
Loading…
Reference in a new issue