forked from MirrorHub/synapse
Don't push if an user account has expired (#8353)
This commit is contained in:
parent
4bb203ea4f
commit
916bb9d0d1
4 changed files with 34 additions and 5 deletions
1
changelog.d/8353.bugfix
Normal file
1
changelog.d/8353.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Don't send push notifications to expired user accounts.
|
|
@ -218,11 +218,7 @@ class Auth:
|
||||||
# Deny the request if the user account has expired.
|
# Deny the request if the user account has expired.
|
||||||
if self._account_validity.enabled and not allow_expired:
|
if self._account_validity.enabled and not allow_expired:
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
|
if await self.store.is_account_expired(user_id, self.clock.time_msec()):
|
||||||
if (
|
|
||||||
expiration_ts is not None
|
|
||||||
and self.clock.time_msec() >= expiration_ts
|
|
||||||
):
|
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
||||||
)
|
)
|
||||||
|
|
|
@ -60,6 +60,8 @@ class PusherPool:
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
|
|
||||||
|
self._account_validity = hs.config.account_validity
|
||||||
|
|
||||||
# We shard the handling of push notifications by user ID.
|
# We shard the handling of push notifications by user ID.
|
||||||
self._pusher_shard_config = hs.config.push.pusher_shard_config
|
self._pusher_shard_config = hs.config.push.pusher_shard_config
|
||||||
self._instance_name = hs.get_instance_name()
|
self._instance_name = hs.get_instance_name()
|
||||||
|
@ -202,6 +204,14 @@ class PusherPool:
|
||||||
)
|
)
|
||||||
|
|
||||||
for u in users_affected:
|
for u in users_affected:
|
||||||
|
# Don't push if the user account has expired
|
||||||
|
if self._account_validity.enabled:
|
||||||
|
expired = await self.store.is_account_expired(
|
||||||
|
u, self.clock.time_msec()
|
||||||
|
)
|
||||||
|
if expired:
|
||||||
|
continue
|
||||||
|
|
||||||
if u in self.pushers:
|
if u in self.pushers:
|
||||||
for p in self.pushers[u].values():
|
for p in self.pushers[u].values():
|
||||||
p.on_new_notifications(max_stream_id)
|
p.on_new_notifications(max_stream_id)
|
||||||
|
@ -222,6 +232,14 @@ class PusherPool:
|
||||||
)
|
)
|
||||||
|
|
||||||
for u in users_affected:
|
for u in users_affected:
|
||||||
|
# Don't push if the user account has expired
|
||||||
|
if self._account_validity.enabled:
|
||||||
|
expired = await self.store.is_account_expired(
|
||||||
|
u, self.clock.time_msec()
|
||||||
|
)
|
||||||
|
if expired:
|
||||||
|
continue
|
||||||
|
|
||||||
if u in self.pushers:
|
if u in self.pushers:
|
||||||
for p in self.pushers[u].values():
|
for p in self.pushers[u].values():
|
||||||
p.on_new_receipts(min_stream_id, max_stream_id)
|
p.on_new_receipts(min_stream_id, max_stream_id)
|
||||||
|
|
|
@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
desc="get_expiration_ts_for_user",
|
desc="get_expiration_ts_for_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
|
||||||
|
"""
|
||||||
|
Returns whether an user account is expired.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
current_ts: The current timestamp
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the user account has expired
|
||||||
|
"""
|
||||||
|
expiration_ts = await self.get_expiration_ts_for_user(user_id)
|
||||||
|
return expiration_ts is not None and current_ts >= expiration_ts
|
||||||
|
|
||||||
async def set_account_validity_for_user(
|
async def set_account_validity_for_user(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|
Loading…
Reference in a new issue