Don't push if an user account has expired (#8353)

This commit is contained in:
Mathieu Velten 2020-09-23 17:06:28 +02:00 committed by GitHub
parent 4bb203ea4f
commit 916bb9d0d1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 5 deletions

1
changelog.d/8353.bugfix Normal file
View file

@ -0,0 +1 @@
Don't send push notifications to expired user accounts.

View file

@ -218,11 +218,7 @@ class Auth:
# Deny the request if the user account has expired. # Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired: if self._account_validity.enabled and not allow_expired:
user_id = user.to_string() user_id = user.to_string()
expiration_ts = await self.store.get_expiration_ts_for_user(user_id) if await self.store.is_account_expired(user_id, self.clock.time_msec()):
if (
expiration_ts is not None
and self.clock.time_msec() >= expiration_ts
):
raise AuthError( raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
) )

View file

@ -60,6 +60,8 @@ class PusherPool:
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self._account_validity = hs.config.account_validity
# We shard the handling of push notifications by user ID. # We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
@ -202,6 +204,14 @@ class PusherPool:
) )
for u in users_affected: for u in users_affected:
# Don't push if the user account has expired
if self._account_validity.enabled:
expired = await self.store.is_account_expired(
u, self.clock.time_msec()
)
if expired:
continue
if u in self.pushers: if u in self.pushers:
for p in self.pushers[u].values(): for p in self.pushers[u].values():
p.on_new_notifications(max_stream_id) p.on_new_notifications(max_stream_id)
@ -222,6 +232,14 @@ class PusherPool:
) )
for u in users_affected: for u in users_affected:
# Don't push if the user account has expired
if self._account_validity.enabled:
expired = await self.store.is_account_expired(
u, self.clock.time_msec()
)
if expired:
continue
if u in self.pushers: if u in self.pushers:
for p in self.pushers[u].values(): for p in self.pushers[u].values():
p.on_new_receipts(min_stream_id, max_stream_id) p.on_new_receipts(min_stream_id, max_stream_id)

View file

@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_expiration_ts_for_user", desc="get_expiration_ts_for_user",
) )
async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
"""
Returns whether an user account is expired.
Args:
user_id: The user's ID
current_ts: The current timestamp
Returns:
Whether the user account has expired
"""
expiration_ts = await self.get_expiration_ts_for_user(user_id)
return expiration_ts is not None and current_ts >= expiration_ts
async def set_account_validity_for_user( async def set_account_validity_for_user(
self, self,
user_id: str, user_id: str,