Add type hints to auth and auth_blocking. (#9876)

This commit is contained in:
Patrick Cloke 2021-04-23 12:02:16 -04:00 committed by GitHub
parent a15c003e5b
commit e83627926f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 44 deletions

1
changelog.d/9876.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `synapse.api.auth` and `synapse.api.auth_blocking` modules.

View file

@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import pymacaroons import pymacaroons
from netaddr import IPAddress from netaddr import IPAddress
from twisted.web.server import Request from twisted.web.server import Request
import synapse.types
from synapse import event_auth from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.constants import EventTypes, HistoryVisibility, Membership
@ -36,11 +35,14 @@ from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID from synapse.types import Requester, StateMap, UserID, create_requester
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,7 +70,7 @@ class Auth:
The latter should be moved to synapse.handlers.event_auth.EventAuthHandler. The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -88,13 +90,13 @@ class Auth:
async def check_from_context( async def check_from_context(
self, room_version: str, event, context, do_sig_check=True self, room_version: str, event, context, do_sig_check=True
): ) -> None:
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.compute_auth_events( auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True event, prev_state_ids, for_verification=True
) )
auth_events = await self.store.get_events(auth_events_ids) auth_events_by_id = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()} auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
event_auth.check( event_auth.check(
@ -151,17 +153,11 @@ class Auth:
raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
async def check_host_in_room(self, room_id, host): async def check_host_in_room(self, room_id: str, host: str) -> bool:
with Measure(self.clock, "check_host_in_room"): with Measure(self.clock, "check_host_in_room"):
latest_event_ids = await self.store.is_host_joined(room_id, host) return await self.store.is_host_joined(room_id, host)
return latest_event_ids
def can_federate(self, event, auth_events): def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
creation_event = auth_events.get((EventTypes.Create, ""))
return creation_event.content.get("m.federate", True) is True
def get_public_keys(self, invite_event):
return event_auth.get_public_keys(invite_event) return event_auth.get_public_keys(invite_event)
async def get_user_by_req( async def get_user_by_req(
@ -170,7 +166,7 @@ class Auth:
allow_guest: bool = False, allow_guest: bool = False,
rights: str = "access", rights: str = "access",
allow_expired: bool = False, allow_expired: bool = False,
) -> synapse.types.Requester: ) -> Requester:
"""Get a registered user's ID. """Get a registered user's ID.
Args: Args:
@ -196,7 +192,7 @@ class Auth:
access_token = self.get_access_token_from_request(request) access_token = self.get_access_token_from_request(request)
user_id, app_service = await self._get_appservice_user_id(request) user_id, app_service = await self._get_appservice_user_id(request)
if user_id: if user_id and app_service:
if ip_addr and self._track_appservice_user_ips: if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip( await self.store.insert_client_ip(
user_id=user_id, user_id=user_id,
@ -206,9 +202,7 @@ class Auth:
device_id="dummy-device", # stubbed device_id="dummy-device", # stubbed
) )
requester = synapse.types.create_requester( requester = create_requester(user_id, app_service=app_service)
user_id, app_service=app_service
)
request.requester = user_id request.requester = user_id
opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("authenticated_entity", user_id)
@ -251,7 +245,7 @@ class Auth:
errcode=Codes.GUEST_ACCESS_FORBIDDEN, errcode=Codes.GUEST_ACCESS_FORBIDDEN,
) )
requester = synapse.types.create_requester( requester = create_requester(
user_info.user_id, user_info.user_id,
token_id, token_id,
is_guest, is_guest,
@ -271,7 +265,9 @@ class Auth:
except KeyError: except KeyError:
raise MissingClientTokenError() raise MissingClientTokenError()
async def _get_appservice_user_id(self, request): async def _get_appservice_user_id(
self, request: Request
) -> Tuple[Optional[str], Optional[ApplicationService]]:
app_service = self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request) self.get_access_token_from_request(request)
) )
@ -283,6 +279,9 @@ class Auth:
if ip_address not in app_service.ip_range_whitelist: if ip_address not in app_service.ip_range_whitelist:
return None, None return None, None
# This will always be set by the time Twisted calls us.
assert request.args is not None
if b"user_id" not in request.args: if b"user_id" not in request.args:
return app_service.sender, app_service return app_service.sender, app_service
@ -387,7 +386,9 @@ class Auth:
logger.warning("Invalid macaroon in auth: %s %s", type(e), e) logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
raise InvalidClientTokenError("Invalid macaroon passed.") raise InvalidClientTokenError("Invalid macaroon passed.")
def _parse_and_validate_macaroon(self, token, rights="access"): def _parse_and_validate_macaroon(
self, token: str, rights: str = "access"
) -> Tuple[str, bool]:
"""Takes a macaroon and tries to parse and validate it. This is cached """Takes a macaroon and tries to parse and validate it. This is cached
if and only if rights == access and there isn't an expiry. if and only if rights == access and there isn't an expiry.
@ -432,15 +433,16 @@ class Auth:
return user_id, guest return user_id, guest
def validate_macaroon(self, macaroon, type_string, user_id): def validate_macaroon(
self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str
) -> None:
""" """
validate that a Macaroon is understood by and was signed by this server. validate that a Macaroon is understood by and was signed by this server.
Args: Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate macaroon: The macaroon to validate
type_string(str): The kind of token required (e.g. "access", type_string: The kind of token required (e.g. "access", "delete_pusher")
"delete_pusher") user_id: The user_id required
user_id (str): The user_id required
""" """
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
@ -465,9 +467,7 @@ class Auth:
if not service: if not service:
logger.warning("Unrecognised appservice access token.") logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError() raise InvalidClientTokenError()
request.requester = synapse.types.create_requester( request.requester = create_requester(service.sender, app_service=service)
service.sender, app_service=service
)
return service return service
async def is_server_admin(self, user: UserID) -> bool: async def is_server_admin(self, user: UserID) -> bool:
@ -519,7 +519,7 @@ class Auth:
return auth_ids return auth_ids
async def check_can_change_room_list(self, room_id: str, user: UserID): async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
"""Determine whether the user is allowed to edit the room's entry in the """Determine whether the user is allowed to edit the room's entry in the
published room list. published room list.
@ -554,11 +554,11 @@ class Auth:
return user_level >= send_level return user_level >= send_level
@staticmethod @staticmethod
def has_access_token(request: Request): def has_access_token(request: Request) -> bool:
"""Checks if the request has an access_token. """Checks if the request has an access_token.
Returns: Returns:
bool: False if no access_token was given, True otherwise. False if no access_token was given, True otherwise.
""" """
# This will always be set by the time Twisted calls us. # This will always be set by the time Twisted calls us.
assert request.args is not None assert request.args is not None
@ -568,13 +568,13 @@ class Auth:
return bool(query_params) or bool(auth_headers) return bool(query_params) or bool(auth_headers)
@staticmethod @staticmethod
def get_access_token_from_request(request: Request): def get_access_token_from_request(request: Request) -> str:
"""Extracts the access_token from the request. """Extracts the access_token from the request.
Args: Args:
request: The http request. request: The http request.
Returns: Returns:
unicode: The access_token The access_token
Raises: Raises:
MissingClientTokenError: If there isn't a single access_token in the MissingClientTokenError: If there isn't a single access_token in the
request request
@ -649,5 +649,5 @@ class Auth:
% (user_id, room_id), % (user_id, room_id),
) )
def check_auth_blocking(self, *args, **kwargs): async def check_auth_blocking(self, *args, **kwargs) -> None:
return self._auth_blocking.check_auth_blocking(*args, **kwargs) await self._auth_blocking.check_auth_blocking(*args, **kwargs)

View file

@ -13,18 +13,21 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional from typing import TYPE_CHECKING, Optional
from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.types import Requester from synapse.types import Requester
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AuthBlocking: class AuthBlocking:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
@ -43,7 +46,7 @@ class AuthBlocking:
threepid: Optional[dict] = None, threepid: Optional[dict] = None,
user_type: Optional[str] = None, user_type: Optional[str] = None,
requester: Optional[Requester] = None, requester: Optional[Requester] = None,
): ) -> None:
"""Checks if the user should be rejected for some external reason, """Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag such as monthly active user limiting or global disable flag

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Optional, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -688,7 +688,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase
return False return False
def get_public_keys(invite_event): def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
public_keys = [] public_keys = []
if "public_key" in invite_event.content: if "public_key" in invite_event.content:
o = {"public_key": invite_event.content["public_key"]} o = {"public_key": invite_event.content["public_key"]}