0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 09:33:53 +01:00

Optionally include account validity in MSC3720 account status responses (#12266)

This commit is contained in:
Brendan Abolivier 2022-03-24 11:19:41 +01:00 committed by GitHub
parent e78d4f61fc
commit 5436b014f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 73 additions and 1 deletions

1
changelog.d/12266.misc Normal file
View file

@ -0,0 +1 @@
Optionally include account validity expiration information to experimental [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) account status responses.

View file

@ -676,6 +676,10 @@ class ServerConfig(Config):
): ):
raise ConfigError("'custom_template_directory' must be a string") raise ConfigError("'custom_template_directory' must be a string")
self.use_account_validity_in_account_status: bool = (
config.get("use_account_validity_in_account_status") or False
)
def has_tls_listener(self) -> bool: def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners) return any(listener.tls for listener in self.listeners)

View file

@ -26,6 +26,10 @@ class AccountHandler:
self._main_store = hs.get_datastores().main self._main_store = hs.get_datastores().main
self._is_mine = hs.is_mine self._is_mine = hs.is_mine
self._federation_client = hs.get_federation_client() self._federation_client = hs.get_federation_client()
self._use_account_validity_in_account_status = (
hs.config.server.use_account_validity_in_account_status
)
self._account_validity_handler = hs.get_account_validity_handler()
async def get_account_statuses( async def get_account_statuses(
self, self,
@ -106,6 +110,13 @@ class AccountHandler:
"deactivated": userinfo.is_deactivated, "deactivated": userinfo.is_deactivated,
} }
if self._use_account_validity_in_account_status:
status[
"org.matrix.expired"
] = await self._account_validity_handler.is_user_expired(
user_id.to_string()
)
return status return status
async def _get_remote_account_statuses( async def _get_remote_account_statuses(

View file

@ -31,7 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict, UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[users[2]], expected_failures=[users[2]],
) )
@unittest.override_config(
{
"use_account_validity_in_account_status": True,
}
)
def test_no_account_validity(self) -> None:
"""Tests that if we decide to include account validity in the response but no
account validity 'is_user_expired' callback is provided, we default to marking all
users as not expired.
"""
user = self.register_user("someuser", "password")
self._test_status(
users=[user],
expected_statuses={
user: {
"exists": True,
"deactivated": False,
"org.matrix.expired": False,
},
},
expected_failures=[],
)
@unittest.override_config(
{
"use_account_validity_in_account_status": True,
}
)
def test_account_validity_expired(self) -> None:
"""Test that if we decide to include account validity in the response and the user
is expired, we return the correct info.
"""
user = self.register_user("someuser", "password")
async def is_expired(user_id: str) -> bool:
# We can't blindly say everyone is expired, otherwise the request to get the
# account status will fail.
return UserID.from_string(user_id).localpart == "someuser"
self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
is_expired
)
self._test_status(
users=[user],
expected_statuses={
user: {
"exists": True,
"deactivated": False,
"org.matrix.expired": True,
},
},
expected_failures=[],
)
def _test_status( def _test_status(
self, self,
users: Optional[List[str]], users: Optional[List[str]],