mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-16 23:11:34 +01:00
Add type hints to user admin API. (#9521)
This commit is contained in:
parent
0c330423bc
commit
d790d0d314
4 changed files with 63 additions and 35 deletions
1
changelog.d/9521.misc
Normal file
1
changelog.d/9521.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to user admin API.
|
|
@ -16,7 +16,7 @@ import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, Tuple
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||||
|
@ -47,13 +47,15 @@ logger = logging.getLogger(__name__)
|
||||||
class UsersRestServlet(RestServlet):
|
class UsersRestServlet(RestServlet):
|
||||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
|
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.admin_handler = hs.get_admin_handler()
|
self.admin_handler = hs.get_admin_handler()
|
||||||
|
|
||||||
async def on_GET(self, request, user_id):
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, List[JsonDict]]:
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
|
@ -153,7 +155,7 @@ class UserRestServletV2(RestServlet):
|
||||||
otherwise an error.
|
otherwise an error.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.admin_handler = hs.get_admin_handler()
|
self.admin_handler = hs.get_admin_handler()
|
||||||
|
@ -165,7 +167,9 @@ class UserRestServletV2(RestServlet):
|
||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
self.pusher_pool = hs.get_pusherpool()
|
self.pusher_pool = hs.get_pusherpool()
|
||||||
|
|
||||||
async def on_GET(self, request, user_id):
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
|
@ -179,7 +183,9 @@ class UserRestServletV2(RestServlet):
|
||||||
|
|
||||||
return 200, ret
|
return 200, ret
|
||||||
|
|
||||||
async def on_PUT(self, request, user_id):
|
async def on_PUT(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
|
||||||
|
@ -273,6 +279,8 @@ class UserRestServletV2(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
user = await self.admin_handler.get_user(target_user)
|
user = await self.admin_handler.get_user(target_user)
|
||||||
|
assert user is not None
|
||||||
|
|
||||||
return 200, user
|
return 200, user
|
||||||
|
|
||||||
else: # create user
|
else: # create user
|
||||||
|
@ -330,9 +338,10 @@ class UserRestServletV2(RestServlet):
|
||||||
target_user, requester, body["avatar_url"], True
|
target_user, requester, body["avatar_url"], True
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = await self.admin_handler.get_user(target_user)
|
user = await self.admin_handler.get_user(target_user)
|
||||||
|
assert user is not None
|
||||||
|
|
||||||
return 201, ret
|
return 201, user
|
||||||
|
|
||||||
|
|
||||||
class UserRegisterServlet(RestServlet):
|
class UserRegisterServlet(RestServlet):
|
||||||
|
@ -346,10 +355,10 @@ class UserRegisterServlet(RestServlet):
|
||||||
PATTERNS = admin_patterns("/register")
|
PATTERNS = admin_patterns("/register")
|
||||||
NONCE_TIMEOUT = 60
|
NONCE_TIMEOUT = 60
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.reactor = hs.get_reactor()
|
self.reactor = hs.get_reactor()
|
||||||
self.nonces = {}
|
self.nonces = {} # type: Dict[str, int]
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
def _clear_old_nonces(self):
|
def _clear_old_nonces(self):
|
||||||
|
@ -362,7 +371,7 @@ class UserRegisterServlet(RestServlet):
|
||||||
if now - v > self.NONCE_TIMEOUT:
|
if now - v > self.NONCE_TIMEOUT:
|
||||||
del self.nonces[k]
|
del self.nonces[k]
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
"""
|
"""
|
||||||
Generate a new nonce.
|
Generate a new nonce.
|
||||||
"""
|
"""
|
||||||
|
@ -372,7 +381,7 @@ class UserRegisterServlet(RestServlet):
|
||||||
self.nonces[nonce] = int(self.reactor.seconds())
|
self.nonces[nonce] = int(self.reactor.seconds())
|
||||||
return 200, {"nonce": nonce}
|
return 200, {"nonce": nonce}
|
||||||
|
|
||||||
async def on_POST(self, request):
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
self._clear_old_nonces()
|
self._clear_old_nonces()
|
||||||
|
|
||||||
if not self.hs.config.registration_shared_secret:
|
if not self.hs.config.registration_shared_secret:
|
||||||
|
@ -478,12 +487,14 @@ class WhoisRestServlet(RestServlet):
|
||||||
client_patterns("/admin" + path_regex, v1=True)
|
client_patterns("/admin" + path_regex, v1=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.admin_handler = hs.get_admin_handler()
|
self.admin_handler = hs.get_admin_handler()
|
||||||
|
|
||||||
async def on_GET(self, request, user_id):
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
auth_user = requester.user
|
auth_user = requester.user
|
||||||
|
@ -508,7 +519,9 @@ class DeactivateAccountRestServlet(RestServlet):
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, target_user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
|
||||||
|
@ -550,7 +563,7 @@ class AccountValidityRenewServlet(RestServlet):
|
||||||
self.account_activity_handler = hs.get_account_validity_handler()
|
self.account_activity_handler = hs.get_account_validity_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(self, request):
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
@ -584,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet):
|
||||||
|
|
||||||
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
|
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self._set_password_handler = hs.get_set_password_handler()
|
self._set_password_handler = hs.get_set_password_handler()
|
||||||
|
|
||||||
async def on_POST(self, request, target_user_id):
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, target_user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
"""Post request to allow an administrator reset password for a user.
|
"""Post request to allow an administrator reset password for a user.
|
||||||
This needs user to have administrator access in Synapse.
|
This needs user to have administrator access in Synapse.
|
||||||
"""
|
"""
|
||||||
|
@ -626,12 +641,14 @@ class SearchUsersRestServlet(RestServlet):
|
||||||
|
|
||||||
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
|
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_GET(self, request, target_user_id):
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, target_user_id: str
|
||||||
|
) -> Tuple[int, Optional[List[JsonDict]]]:
|
||||||
"""Get request to search user table for specific users according to
|
"""Get request to search user table for specific users according to
|
||||||
search term.
|
search term.
|
||||||
This needs user to have a administrator access in Synapse.
|
This needs user to have a administrator access in Synapse.
|
||||||
|
@ -682,12 +699,14 @@ class UserAdminServlet(RestServlet):
|
||||||
|
|
||||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
|
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_GET(self, request, user_id):
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
|
@ -699,7 +718,9 @@ class UserAdminServlet(RestServlet):
|
||||||
|
|
||||||
return 200, {"admin": is_admin}
|
return 200, {"admin": is_admin}
|
||||||
|
|
||||||
async def on_PUT(self, request, user_id):
|
async def on_PUT(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
auth_user = requester.user
|
auth_user = requester.user
|
||||||
|
@ -730,12 +751,14 @@ class UserMembershipRestServlet(RestServlet):
|
||||||
|
|
||||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
|
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
async def on_GET(self, request, user_id):
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||||
|
@ -758,7 +781,7 @@ class PushersRestServlet(RestServlet):
|
||||||
|
|
||||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
|
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
@ -799,7 +822,7 @@ class UserMediaRestServlet(RestServlet):
|
||||||
|
|
||||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
|
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
@ -891,7 +914,9 @@ class UserTokenRestServlet(RestServlet):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
async def on_POST(self, request, user_id):
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
auth_user = requester.user
|
auth_user = requester.user
|
||||||
|
@ -943,7 +968,9 @@ class ShadowBanRestServlet(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(self, request, user_id):
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
if not self.hs.is_mine_id(user_id):
|
if not self.hs.is_mine_id(user_id):
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.constants import PresenceState
|
from synapse.api.constants import PresenceState
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
@ -27,7 +27,7 @@ from synapse.storage.util.id_generators import (
|
||||||
MultiWriterIdGenerator,
|
MultiWriterIdGenerator,
|
||||||
StreamIdGenerator,
|
StreamIdGenerator,
|
||||||
)
|
)
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import JsonDict, get_domain_from_id
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
from .account_data import AccountDataStore
|
from .account_data import AccountDataStore
|
||||||
|
@ -264,7 +264,7 @@ class DataStore(
|
||||||
|
|
||||||
return [UserPresenceState(**row) for row in rows]
|
return [UserPresenceState(**row) for row in rows]
|
||||||
|
|
||||||
async def get_users(self) -> List[Dict[str, Any]]:
|
async def get_users(self) -> List[JsonDict]:
|
||||||
"""Function to retrieve a list of users in users table.
|
"""Function to retrieve a list of users in users table.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -292,7 +292,7 @@ class DataStore(
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
guests: bool = True,
|
guests: bool = True,
|
||||||
deactivated: bool = False,
|
deactivated: bool = False,
|
||||||
) -> Tuple[List[Dict[str, Any]], int]:
|
) -> Tuple[List[JsonDict], int]:
|
||||||
"""Function to retrieve a paginated list of users from
|
"""Function to retrieve a paginated list of users from
|
||||||
users list. This will return a json list of users and the
|
users list. This will return a json list of users and the
|
||||||
total number of users matching the filter criteria.
|
total number of users matching the filter criteria.
|
||||||
|
@ -353,7 +353,7 @@ class DataStore(
|
||||||
"get_users_paginate_txn", get_users_paginate_txn
|
"get_users_paginate_txn", get_users_paginate_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
|
async def search_users(self, term: str) -> Optional[List[JsonDict]]:
|
||||||
"""Function to search users list for one or more users with
|
"""Function to search users list for one or more users with
|
||||||
the matched term.
|
the matched term.
|
||||||
|
|
||||||
|
|
|
@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
start: int,
|
start: int,
|
||||||
limit: int,
|
limit: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value,
|
order_by: str = MediaSortOrder.CREATED_TS.value,
|
||||||
direction: str = "f",
|
direction: str = "f",
|
||||||
) -> Tuple[List[Dict[str, Any]], int]:
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
"""Get a paginated list of metadata for a local piece of media
|
"""Get a paginated list of metadata for a local piece of media
|
||||||
|
|
Loading…
Reference in a new issue