0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-13 21:43:22 +01:00

Convert user_get_threepids response to attrs. (#16468)

This improves type annotations by not having a dictionary of Any values.
This commit is contained in:
Patrick Cloke 2023-10-11 20:08:11 -04:00 committed by GitHub
parent a4904dcb04
commit cc865fffc0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 31 additions and 18 deletions

1
changelog.d/16468.misc Normal file
View file

@ -0,0 +1 @@
Improve type hints.

View file

@ -212,8 +212,8 @@ class AccountValidityHandler:
addresses = [] addresses = []
for threepid in threepids: for threepid in threepids:
if threepid["medium"] == "email": if threepid.medium == "email":
addresses.append(threepid["address"]) addresses.append(threepid.address)
return addresses return addresses

View file

@ -16,6 +16,8 @@ import abc
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set
import attr
from synapse.api.constants import Direction, Membership from synapse.api.constants import Direction, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo
@ -93,7 +95,7 @@ class AdminHandler:
] ]
user_info_dict["displayname"] = profile.display_name user_info_dict["displayname"] = profile.display_name
user_info_dict["avatar_url"] = profile.avatar_url user_info_dict["avatar_url"] = profile.avatar_url
user_info_dict["threepids"] = threepids user_info_dict["threepids"] = [attr.asdict(t) for t in threepids]
user_info_dict["external_ids"] = external_ids user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string()) user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())

View file

@ -117,9 +117,9 @@ class DeactivateAccountHandler:
# Remove any local threepid associations for this account. # Remove any local threepid associations for this account.
local_threepids = await self.store.user_get_threepids(user_id) local_threepids = await self.store.user_get_threepids(user_id)
for threepid in local_threepids: for local_threepid in local_threepids:
await self._auth_handler.delete_local_threepid( await self._auth_handler.delete_local_threepid(
user_id, threepid["medium"], threepid["address"] user_id, local_threepid.medium, local_threepid.address
) )
# delete any devices belonging to the user, which will also # delete any devices belonging to the user, which will also

View file

@ -678,7 +678,7 @@ class ModuleApi:
"msisdn" for phone numbers, and an "address" key which value is the "msisdn" for phone numbers, and an "address" key which value is the
threepid's address. threepid's address.
""" """
return await self._store.user_get_threepids(user_id) return [attr.asdict(t) for t in await self._store.user_get_threepids(user_id)]
def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]": def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
"""Check if user exists. """Check if user exists.

View file

@ -329,9 +329,8 @@ class UserRestServletV2(RestServlet):
if threepids is not None: if threepids is not None:
# get changed threepids (added and removed) # get changed threepids (added and removed)
# convert List[Dict[str, Any]] into Set[Tuple[str, str]]
cur_threepids = { cur_threepids = {
(threepid["medium"], threepid["address"]) (threepid.medium, threepid.address)
for threepid in await self.store.user_get_threepids(user_id) for threepid in await self.store.user_get_threepids(user_id)
} }
add_threepids = new_threepids - cur_threepids add_threepids = new_threepids - cur_threepids

View file

@ -24,6 +24,8 @@ if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import StrictBool, StrictStr, constr from pydantic.v1 import StrictBool, StrictStr, constr
else: else:
from pydantic import StrictBool, StrictStr, constr from pydantic import StrictBool, StrictStr, constr
import attr
from typing_extensions import Literal from typing_extensions import Literal
from twisted.web.server import Request from twisted.web.server import Request
@ -595,7 +597,7 @@ class ThreepidRestServlet(RestServlet):
threepids = await self.datastore.user_get_threepids(requester.user.to_string()) threepids = await self.datastore.user_get_threepids(requester.user.to_string())
return 200, {"threepids": threepids} return 200, {"threepids": [attr.asdict(t) for t in threepids]}
# NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because # NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because
# the endpoint is deprecated. (If you really want to, you could do this by reusing # the endpoint is deprecated. (If you really want to, you could do this by reusing

View file

@ -143,6 +143,14 @@ class LoginTokenLookupResult:
"""The session ID advertised by the SSO Identity Provider.""" """The session ID advertised by the SSO Identity Provider."""
@attr.s(frozen=True, slots=True, auto_attribs=True)
class ThreepidResult:
medium: str
address: str
validated_at: int
added_at: int
class RegistrationWorkerStore(CacheInvalidationWorkerStore): class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__( def __init__(
self, self,
@ -988,13 +996,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
) )
async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]: async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
return await self.db_pool.simple_select_list( results = await self.db_pool.simple_select_list(
"user_threepids", "user_threepids",
{"user_id": user_id}, keyvalues={"user_id": user_id},
["medium", "address", "validated_at", "added_at"], retcols=["medium", "address", "validated_at", "added_at"],
"user_get_threepids", desc="user_get_threepids",
) )
return [ThreepidResult(**r) for r in results]
async def user_delete_threepid( async def user_delete_threepid(
self, user_id: str, medium: str, address: str self, user_id: str, medium: str, address: str

View file

@ -94,12 +94,12 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(len(emails), 1) self.assertEqual(len(emails), 1)
email = emails[0] email = emails[0]
self.assertEqual(email["medium"], "email") self.assertEqual(email.medium, "email")
self.assertEqual(email["address"], "bob@bobinator.bob") self.assertEqual(email.address, "bob@bobinator.bob")
# Should these be 0? # Should these be 0?
self.assertEqual(email["validated_at"], 0) self.assertEqual(email.validated_at, 0)
self.assertEqual(email["added_at"], 0) self.assertEqual(email.added_at, 0)
# Check that the displayname was assigned # Check that the displayname was assigned
displayname = self.get_success( displayname = self.get_success(