mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 21:43:22 +01:00
Convert user_get_threepids response to attrs. (#16468)
This improves type annotations by not having a dictionary of Any values.
This commit is contained in:
parent
a4904dcb04
commit
cc865fffc0
9 changed files with 31 additions and 18 deletions
1
changelog.d/16468.misc
Normal file
1
changelog.d/16468.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve type hints.
|
|
@ -212,8 +212,8 @@ class AccountValidityHandler:
|
||||||
|
|
||||||
addresses = []
|
addresses = []
|
||||||
for threepid in threepids:
|
for threepid in threepids:
|
||||||
if threepid["medium"] == "email":
|
if threepid.medium == "email":
|
||||||
addresses.append(threepid["address"])
|
addresses.append(threepid.address)
|
||||||
|
|
||||||
return addresses
|
return addresses
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,8 @@ import abc
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set
|
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
from synapse.api.constants import Direction, Membership
|
from synapse.api.constants import Direction, Membership
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo
|
from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo
|
||||||
|
@ -93,7 +95,7 @@ class AdminHandler:
|
||||||
]
|
]
|
||||||
user_info_dict["displayname"] = profile.display_name
|
user_info_dict["displayname"] = profile.display_name
|
||||||
user_info_dict["avatar_url"] = profile.avatar_url
|
user_info_dict["avatar_url"] = profile.avatar_url
|
||||||
user_info_dict["threepids"] = threepids
|
user_info_dict["threepids"] = [attr.asdict(t) for t in threepids]
|
||||||
user_info_dict["external_ids"] = external_ids
|
user_info_dict["external_ids"] = external_ids
|
||||||
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
|
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
|
||||||
|
|
||||||
|
|
|
@ -117,9 +117,9 @@ class DeactivateAccountHandler:
|
||||||
|
|
||||||
# Remove any local threepid associations for this account.
|
# Remove any local threepid associations for this account.
|
||||||
local_threepids = await self.store.user_get_threepids(user_id)
|
local_threepids = await self.store.user_get_threepids(user_id)
|
||||||
for threepid in local_threepids:
|
for local_threepid in local_threepids:
|
||||||
await self._auth_handler.delete_local_threepid(
|
await self._auth_handler.delete_local_threepid(
|
||||||
user_id, threepid["medium"], threepid["address"]
|
user_id, local_threepid.medium, local_threepid.address
|
||||||
)
|
)
|
||||||
|
|
||||||
# delete any devices belonging to the user, which will also
|
# delete any devices belonging to the user, which will also
|
||||||
|
|
|
@ -678,7 +678,7 @@ class ModuleApi:
|
||||||
"msisdn" for phone numbers, and an "address" key which value is the
|
"msisdn" for phone numbers, and an "address" key which value is the
|
||||||
threepid's address.
|
threepid's address.
|
||||||
"""
|
"""
|
||||||
return await self._store.user_get_threepids(user_id)
|
return [attr.asdict(t) for t in await self._store.user_get_threepids(user_id)]
|
||||||
|
|
||||||
def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
|
def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
|
||||||
"""Check if user exists.
|
"""Check if user exists.
|
||||||
|
|
|
@ -329,9 +329,8 @@ class UserRestServletV2(RestServlet):
|
||||||
|
|
||||||
if threepids is not None:
|
if threepids is not None:
|
||||||
# get changed threepids (added and removed)
|
# get changed threepids (added and removed)
|
||||||
# convert List[Dict[str, Any]] into Set[Tuple[str, str]]
|
|
||||||
cur_threepids = {
|
cur_threepids = {
|
||||||
(threepid["medium"], threepid["address"])
|
(threepid.medium, threepid.address)
|
||||||
for threepid in await self.store.user_get_threepids(user_id)
|
for threepid in await self.store.user_get_threepids(user_id)
|
||||||
}
|
}
|
||||||
add_threepids = new_threepids - cur_threepids
|
add_threepids = new_threepids - cur_threepids
|
||||||
|
|
|
@ -24,6 +24,8 @@ if TYPE_CHECKING or HAS_PYDANTIC_V2:
|
||||||
from pydantic.v1 import StrictBool, StrictStr, constr
|
from pydantic.v1 import StrictBool, StrictStr, constr
|
||||||
else:
|
else:
|
||||||
from pydantic import StrictBool, StrictStr, constr
|
from pydantic import StrictBool, StrictStr, constr
|
||||||
|
|
||||||
|
import attr
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
@ -595,7 +597,7 @@ class ThreepidRestServlet(RestServlet):
|
||||||
|
|
||||||
threepids = await self.datastore.user_get_threepids(requester.user.to_string())
|
threepids = await self.datastore.user_get_threepids(requester.user.to_string())
|
||||||
|
|
||||||
return 200, {"threepids": threepids}
|
return 200, {"threepids": [attr.asdict(t) for t in threepids]}
|
||||||
|
|
||||||
# NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because
|
# NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because
|
||||||
# the endpoint is deprecated. (If you really want to, you could do this by reusing
|
# the endpoint is deprecated. (If you really want to, you could do this by reusing
|
||||||
|
|
|
@ -143,6 +143,14 @@ class LoginTokenLookupResult:
|
||||||
"""The session ID advertised by the SSO Identity Provider."""
|
"""The session ID advertised by the SSO Identity Provider."""
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||||
|
class ThreepidResult:
|
||||||
|
medium: str
|
||||||
|
address: str
|
||||||
|
validated_at: int
|
||||||
|
added_at: int
|
||||||
|
|
||||||
|
|
||||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -988,13 +996,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
|
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
|
async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
|
||||||
return await self.db_pool.simple_select_list(
|
results = await self.db_pool.simple_select_list(
|
||||||
"user_threepids",
|
"user_threepids",
|
||||||
{"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
["medium", "address", "validated_at", "added_at"],
|
retcols=["medium", "address", "validated_at", "added_at"],
|
||||||
"user_get_threepids",
|
desc="user_get_threepids",
|
||||||
)
|
)
|
||||||
|
return [ThreepidResult(**r) for r in results]
|
||||||
|
|
||||||
async def user_delete_threepid(
|
async def user_delete_threepid(
|
||||||
self, user_id: str, medium: str, address: str
|
self, user_id: str, medium: str, address: str
|
||||||
|
|
|
@ -94,12 +94,12 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
||||||
self.assertEqual(len(emails), 1)
|
self.assertEqual(len(emails), 1)
|
||||||
|
|
||||||
email = emails[0]
|
email = emails[0]
|
||||||
self.assertEqual(email["medium"], "email")
|
self.assertEqual(email.medium, "email")
|
||||||
self.assertEqual(email["address"], "bob@bobinator.bob")
|
self.assertEqual(email.address, "bob@bobinator.bob")
|
||||||
|
|
||||||
# Should these be 0?
|
# Should these be 0?
|
||||||
self.assertEqual(email["validated_at"], 0)
|
self.assertEqual(email.validated_at, 0)
|
||||||
self.assertEqual(email["added_at"], 0)
|
self.assertEqual(email.added_at, 0)
|
||||||
|
|
||||||
# Check that the displayname was assigned
|
# Check that the displayname was assigned
|
||||||
displayname = self.get_success(
|
displayname = self.get_success(
|
||||||
|
|
Loading…
Reference in a new issue