forked from MirrorHub/synapse
Add typing information to federation_server. (#7219)
This commit is contained in:
parent
ec5ac8e2b1
commit
d78cb31588
3 changed files with 108 additions and 65 deletions
1
changelog.d/7219.misc
Normal file
1
changelog.d/7219.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add typing information to federation server code.
|
|
@ -15,7 +15,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict
|
from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from six import iteritems
|
from six import iteritems
|
||||||
|
@ -38,6 +38,7 @@ from synapse.api.errors import (
|
||||||
UnsupportedRoomVersionError,
|
UnsupportedRoomVersionError,
|
||||||
)
|
)
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||||
from synapse.federation.persistence import TransactionActions
|
from synapse.federation.persistence import TransactionActions
|
||||||
from synapse.federation.units import Edu, Transaction
|
from synapse.federation.units import Edu, Transaction
|
||||||
|
@ -94,7 +95,9 @@ class FederationServer(FederationBase):
|
||||||
# come in waves.
|
# come in waves.
|
||||||
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
|
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
|
||||||
|
|
||||||
async def on_backfill_request(self, origin, room_id, versions, limit):
|
async def on_backfill_request(
|
||||||
|
self, origin: str, room_id: str, versions: List[str], limit: int
|
||||||
|
) -> Tuple[int, Dict[str, Any]]:
|
||||||
with (await self._server_linearizer.queue((origin, room_id))):
|
with (await self._server_linearizer.queue((origin, room_id))):
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
await self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
@ -107,23 +110,25 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
return 200, res
|
return 200, res
|
||||||
|
|
||||||
async def on_incoming_transaction(self, origin, transaction_data):
|
async def on_incoming_transaction(
|
||||||
|
self, origin: str, transaction_data: JsonDict
|
||||||
|
) -> Tuple[int, Dict[str, Any]]:
|
||||||
# keep this as early as possible to make the calculated origin ts as
|
# keep this as early as possible to make the calculated origin ts as
|
||||||
# accurate as possible.
|
# accurate as possible.
|
||||||
request_time = self._clock.time_msec()
|
request_time = self._clock.time_msec()
|
||||||
|
|
||||||
transaction = Transaction(**transaction_data)
|
transaction = Transaction(**transaction_data)
|
||||||
|
|
||||||
if not transaction.transaction_id:
|
if not transaction.transaction_id: # type: ignore
|
||||||
raise Exception("Transaction missing transaction_id")
|
raise Exception("Transaction missing transaction_id")
|
||||||
|
|
||||||
logger.debug("[%s] Got transaction", transaction.transaction_id)
|
logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore
|
||||||
|
|
||||||
# use a linearizer to ensure that we don't process the same transaction
|
# use a linearizer to ensure that we don't process the same transaction
|
||||||
# multiple times in parallel.
|
# multiple times in parallel.
|
||||||
with (
|
with (
|
||||||
await self._transaction_linearizer.queue(
|
await self._transaction_linearizer.queue(
|
||||||
(origin, transaction.transaction_id)
|
(origin, transaction.transaction_id) # type: ignore
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
result = await self._handle_incoming_transaction(
|
result = await self._handle_incoming_transaction(
|
||||||
|
@ -132,31 +137,33 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _handle_incoming_transaction(self, origin, transaction, request_time):
|
async def _handle_incoming_transaction(
|
||||||
|
self, origin: str, transaction: Transaction, request_time: int
|
||||||
|
) -> Tuple[int, Dict[str, Any]]:
|
||||||
""" Process an incoming transaction and return the HTTP response
|
""" Process an incoming transaction and return the HTTP response
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
origin (unicode): the server making the request
|
origin: the server making the request
|
||||||
transaction (Transaction): incoming transaction
|
transaction: incoming transaction
|
||||||
request_time (int): timestamp that the HTTP request arrived at
|
request_time: timestamp that the HTTP request arrived at
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[(int, object)]: http response code and body
|
HTTP response code and body
|
||||||
"""
|
"""
|
||||||
response = await self.transaction_actions.have_responded(origin, transaction)
|
response = await self.transaction_actions.have_responded(origin, transaction)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"[%s] We've already responded to this request",
|
"[%s] We've already responded to this request",
|
||||||
transaction.transaction_id,
|
transaction.transaction_id, # type: ignore
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
logger.debug("[%s] Transaction is new", transaction.transaction_id)
|
logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
|
||||||
|
|
||||||
# Reject if PDU count > 50 or EDU count > 100
|
# Reject if PDU count > 50 or EDU count > 100
|
||||||
if len(transaction.pdus) > 50 or (
|
if len(transaction.pdus) > 50 or ( # type: ignore
|
||||||
hasattr(transaction, "edus") and len(transaction.edus) > 100
|
hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
|
||||||
):
|
):
|
||||||
|
|
||||||
logger.info("Transaction PDU or EDU count too large. Returning 400")
|
logger.info("Transaction PDU or EDU count too large. Returning 400")
|
||||||
|
@ -204,13 +211,13 @@ class FederationServer(FederationBase):
|
||||||
report back to the sending server.
|
report back to the sending server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
received_pdus_counter.inc(len(transaction.pdus))
|
received_pdus_counter.inc(len(transaction.pdus)) # type: ignore
|
||||||
|
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
|
|
||||||
pdus_by_room = {}
|
pdus_by_room = {} # type: Dict[str, List[EventBase]]
|
||||||
|
|
||||||
for p in transaction.pdus:
|
for p in transaction.pdus: # type: ignore
|
||||||
if "unsigned" in p:
|
if "unsigned" in p:
|
||||||
unsigned = p["unsigned"]
|
unsigned = p["unsigned"]
|
||||||
if "age" in unsigned:
|
if "age" in unsigned:
|
||||||
|
@ -254,7 +261,7 @@ class FederationServer(FederationBase):
|
||||||
# require callouts to other servers to fetch missing events), but
|
# require callouts to other servers to fetch missing events), but
|
||||||
# impose a limit to avoid going too crazy with ram/cpu.
|
# impose a limit to avoid going too crazy with ram/cpu.
|
||||||
|
|
||||||
async def process_pdus_for_room(room_id):
|
async def process_pdus_for_room(room_id: str):
|
||||||
logger.debug("Processing PDUs for %s", room_id)
|
logger.debug("Processing PDUs for %s", room_id)
|
||||||
try:
|
try:
|
||||||
await self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
@ -310,7 +317,9 @@ class FederationServer(FederationBase):
|
||||||
TRANSACTION_CONCURRENCY_LIMIT,
|
TRANSACTION_CONCURRENCY_LIMIT,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_context_state_request(self, origin, room_id, event_id):
|
async def on_context_state_request(
|
||||||
|
self, origin: str, room_id: str, event_id: str
|
||||||
|
) -> Tuple[int, Dict[str, Any]]:
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
await self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
|
||||||
|
@ -338,7 +347,9 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
return 200, resp
|
return 200, resp
|
||||||
|
|
||||||
async def on_state_ids_request(self, origin, room_id, event_id):
|
async def on_state_ids_request(
|
||||||
|
self, origin: str, room_id: str, event_id: str
|
||||||
|
) -> Tuple[int, Dict[str, Any]]:
|
||||||
if not event_id:
|
if not event_id:
|
||||||
raise NotImplementedError("Specify an event")
|
raise NotImplementedError("Specify an event")
|
||||||
|
|
||||||
|
@ -354,7 +365,9 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
|
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
|
||||||
|
|
||||||
async def _on_context_state_request_compute(self, room_id, event_id):
|
async def _on_context_state_request_compute(
|
||||||
|
self, room_id: str, event_id: str
|
||||||
|
) -> Dict[str, list]:
|
||||||
if event_id:
|
if event_id:
|
||||||
pdus = await self.handler.get_state_for_pdu(room_id, event_id)
|
pdus = await self.handler.get_state_for_pdu(room_id, event_id)
|
||||||
else:
|
else:
|
||||||
|
@ -367,7 +380,9 @@ class FederationServer(FederationBase):
|
||||||
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def on_pdu_request(self, origin, event_id):
|
async def on_pdu_request(
|
||||||
|
self, origin: str, event_id: str
|
||||||
|
) -> Tuple[int, Union[JsonDict, str]]:
|
||||||
pdu = await self.handler.get_persisted_pdu(origin, event_id)
|
pdu = await self.handler.get_persisted_pdu(origin, event_id)
|
||||||
|
|
||||||
if pdu:
|
if pdu:
|
||||||
|
@ -375,12 +390,16 @@ class FederationServer(FederationBase):
|
||||||
else:
|
else:
|
||||||
return 404, ""
|
return 404, ""
|
||||||
|
|
||||||
async def on_query_request(self, query_type, args):
|
async def on_query_request(
|
||||||
|
self, query_type: str, args: Dict[str, str]
|
||||||
|
) -> Tuple[int, Dict[str, Any]]:
|
||||||
received_queries_counter.labels(query_type).inc()
|
received_queries_counter.labels(query_type).inc()
|
||||||
resp = await self.registry.on_query(query_type, args)
|
resp = await self.registry.on_query(query_type, args)
|
||||||
return 200, resp
|
return 200, resp
|
||||||
|
|
||||||
async def on_make_join_request(self, origin, room_id, user_id, supported_versions):
|
async def on_make_join_request(
|
||||||
|
self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
await self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
|
||||||
|
@ -397,7 +416,7 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
async def on_invite_request(
|
async def on_invite_request(
|
||||||
self, origin: str, content: JsonDict, room_version_id: str
|
self, origin: str, content: JsonDict, room_version_id: str
|
||||||
):
|
) -> Dict[str, Any]:
|
||||||
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
|
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
|
||||||
if not room_version:
|
if not room_version:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
@ -414,7 +433,9 @@ class FederationServer(FederationBase):
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
return {"event": ret_pdu.get_pdu_json(time_now)}
|
return {"event": ret_pdu.get_pdu_json(time_now)}
|
||||||
|
|
||||||
async def on_send_join_request(self, origin, content, room_id):
|
async def on_send_join_request(
|
||||||
|
self, origin: str, content: JsonDict, room_id: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
logger.debug("on_send_join_request: content: %s", content)
|
logger.debug("on_send_join_request: content: %s", content)
|
||||||
|
|
||||||
room_version = await self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
|
@ -434,7 +455,9 @@ class FederationServer(FederationBase):
|
||||||
"auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
|
"auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def on_make_leave_request(self, origin, room_id, user_id):
|
async def on_make_leave_request(
|
||||||
|
self, origin: str, room_id: str, user_id: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
await self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
|
pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
|
||||||
|
@ -444,7 +467,9 @@ class FederationServer(FederationBase):
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
|
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
|
||||||
|
|
||||||
async def on_send_leave_request(self, origin, content, room_id):
|
async def on_send_leave_request(
|
||||||
|
self, origin: str, content: JsonDict, room_id: str
|
||||||
|
) -> dict:
|
||||||
logger.debug("on_send_leave_request: content: %s", content)
|
logger.debug("on_send_leave_request: content: %s", content)
|
||||||
|
|
||||||
room_version = await self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
|
@ -460,7 +485,9 @@ class FederationServer(FederationBase):
|
||||||
await self.handler.on_send_leave_request(origin, pdu)
|
await self.handler.on_send_leave_request(origin, pdu)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def on_event_auth(self, origin, room_id, event_id):
|
async def on_event_auth(
|
||||||
|
self, origin: str, room_id: str, event_id: str
|
||||||
|
) -> Tuple[int, Dict[str, Any]]:
|
||||||
with (await self._server_linearizer.queue((origin, room_id))):
|
with (await self._server_linearizer.queue((origin, room_id))):
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
await self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
@ -471,15 +498,21 @@ class FederationServer(FederationBase):
|
||||||
return 200, res
|
return 200, res
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def on_query_client_keys(self, origin, content):
|
async def on_query_client_keys(
|
||||||
return self.on_query_request("client_keys", content)
|
self, origin: str, content: Dict[str, str]
|
||||||
|
) -> Tuple[int, Dict[str, Any]]:
|
||||||
|
return await self.on_query_request("client_keys", content)
|
||||||
|
|
||||||
async def on_query_user_devices(self, origin: str, user_id: str):
|
async def on_query_user_devices(
|
||||||
|
self, origin: str, user_id: str
|
||||||
|
) -> Tuple[int, Dict[str, Any]]:
|
||||||
keys = await self.device_handler.on_federation_query_user_devices(user_id)
|
keys = await self.device_handler.on_federation_query_user_devices(user_id)
|
||||||
return 200, keys
|
return 200, keys
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def on_claim_client_keys(self, origin, content):
|
async def on_claim_client_keys(
|
||||||
|
self, origin: str, content: JsonDict
|
||||||
|
) -> Dict[str, Any]:
|
||||||
query = []
|
query = []
|
||||||
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
||||||
for device_id, algorithm in device_keys.items():
|
for device_id, algorithm in device_keys.items():
|
||||||
|
@ -488,7 +521,7 @@ class FederationServer(FederationBase):
|
||||||
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
||||||
results = await self.store.claim_e2e_one_time_keys(query)
|
results = await self.store.claim_e2e_one_time_keys(query)
|
||||||
|
|
||||||
json_result = {}
|
json_result = {} # type: Dict[str, Dict[str, dict]]
|
||||||
for user_id, device_keys in results.items():
|
for user_id, device_keys in results.items():
|
||||||
for device_id, keys in device_keys.items():
|
for device_id, keys in device_keys.items():
|
||||||
for key_id, json_bytes in keys.items():
|
for key_id, json_bytes in keys.items():
|
||||||
|
@ -511,8 +544,13 @@ class FederationServer(FederationBase):
|
||||||
return {"one_time_keys": json_result}
|
return {"one_time_keys": json_result}
|
||||||
|
|
||||||
async def on_get_missing_events(
|
async def on_get_missing_events(
|
||||||
self, origin, room_id, earliest_events, latest_events, limit
|
self,
|
||||||
):
|
origin: str,
|
||||||
|
room_id: str,
|
||||||
|
earliest_events: List[str],
|
||||||
|
latest_events: List[str],
|
||||||
|
limit: int,
|
||||||
|
) -> Dict[str, list]:
|
||||||
with (await self._server_linearizer.queue((origin, room_id))):
|
with (await self._server_linearizer.queue((origin, room_id))):
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
await self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
@ -541,11 +579,11 @@ class FederationServer(FederationBase):
|
||||||
return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
|
return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def on_openid_userinfo(self, token):
|
async def on_openid_userinfo(self, token: str) -> Optional[str]:
|
||||||
ts_now_ms = self._clock.time_msec()
|
ts_now_ms = self._clock.time_msec()
|
||||||
return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
|
return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
|
||||||
|
|
||||||
def _transaction_from_pdus(self, pdu_list):
|
def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction:
|
||||||
"""Returns a new Transaction containing the given PDUs suitable for
|
"""Returns a new Transaction containing the given PDUs suitable for
|
||||||
transmission.
|
transmission.
|
||||||
"""
|
"""
|
||||||
|
@ -558,7 +596,7 @@ class FederationServer(FederationBase):
|
||||||
destination=None,
|
destination=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_received_pdu(self, origin, pdu):
|
async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
|
||||||
""" Process a PDU received in a federation /send/ transaction.
|
""" Process a PDU received in a federation /send/ transaction.
|
||||||
|
|
||||||
If the event is invalid, then this method throws a FederationError.
|
If the event is invalid, then this method throws a FederationError.
|
||||||
|
@ -579,10 +617,8 @@ class FederationServer(FederationBase):
|
||||||
until we try to backfill across the discontinuity.
|
until we try to backfill across the discontinuity.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
origin (str): server which sent the pdu
|
origin: server which sent the pdu
|
||||||
pdu (FrozenEvent): received pdu
|
pdu: received pdu
|
||||||
|
|
||||||
Returns (Deferred): completes with None
|
|
||||||
|
|
||||||
Raises: FederationError if the signatures / hash do not match, or
|
Raises: FederationError if the signatures / hash do not match, or
|
||||||
if the event was unacceptable for any other reason (eg, too large,
|
if the event was unacceptable for any other reason (eg, too large,
|
||||||
|
@ -625,25 +661,27 @@ class FederationServer(FederationBase):
|
||||||
return "<ReplicationLayer(%s)>" % self.server_name
|
return "<ReplicationLayer(%s)>" % self.server_name
|
||||||
|
|
||||||
async def exchange_third_party_invite(
|
async def exchange_third_party_invite(
|
||||||
self, sender_user_id, target_user_id, room_id, signed
|
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
|
||||||
):
|
):
|
||||||
ret = await self.handler.exchange_third_party_invite(
|
ret = await self.handler.exchange_third_party_invite(
|
||||||
sender_user_id, target_user_id, room_id, signed
|
sender_user_id, target_user_id, room_id, signed
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def on_exchange_third_party_invite_request(self, room_id, event_dict):
|
async def on_exchange_third_party_invite_request(
|
||||||
|
self, room_id: str, event_dict: Dict
|
||||||
|
):
|
||||||
ret = await self.handler.on_exchange_third_party_invite_request(
|
ret = await self.handler.on_exchange_third_party_invite_request(
|
||||||
room_id, event_dict
|
room_id, event_dict
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def check_server_matches_acl(self, server_name, room_id):
|
async def check_server_matches_acl(self, server_name: str, room_id: str):
|
||||||
"""Check if the given server is allowed by the server ACLs in the room
|
"""Check if the given server is allowed by the server ACLs in the room
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_name (str): name of server, *without any port part*
|
server_name: name of server, *without any port part*
|
||||||
room_id (str): ID of the room to check
|
room_id: ID of the room to check
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if the server does not match the ACL
|
AuthError if the server does not match the ACL
|
||||||
|
@ -661,15 +699,15 @@ class FederationServer(FederationBase):
|
||||||
raise AuthError(code=403, msg="Server is banned from room")
|
raise AuthError(code=403, msg="Server is banned from room")
|
||||||
|
|
||||||
|
|
||||||
def server_matches_acl_event(server_name, acl_event):
|
def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
|
||||||
"""Check if the given server is allowed by the ACL event
|
"""Check if the given server is allowed by the ACL event
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_name (str): name of server, without any port part
|
server_name: name of server, without any port part
|
||||||
acl_event (EventBase): m.room.server_acl event
|
acl_event: m.room.server_acl event
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if this server is allowed by the ACLs
|
True if this server is allowed by the ACLs
|
||||||
"""
|
"""
|
||||||
logger.debug("Checking %s against acl %s", server_name, acl_event.content)
|
logger.debug("Checking %s against acl %s", server_name, acl_event.content)
|
||||||
|
|
||||||
|
@ -713,7 +751,7 @@ def server_matches_acl_event(server_name, acl_event):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _acl_entry_matches(server_name, acl_entry):
|
def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
|
||||||
if not isinstance(acl_entry, six.string_types):
|
if not isinstance(acl_entry, six.string_types):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
|
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
|
||||||
|
@ -732,13 +770,13 @@ class FederationHandlerRegistry(object):
|
||||||
self.edu_handlers = {}
|
self.edu_handlers = {}
|
||||||
self.query_handlers = {}
|
self.query_handlers = {}
|
||||||
|
|
||||||
def register_edu_handler(self, edu_type, handler):
|
def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]):
|
||||||
"""Sets the handler callable that will be used to handle an incoming
|
"""Sets the handler callable that will be used to handle an incoming
|
||||||
federation EDU of the given type.
|
federation EDU of the given type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
edu_type (str): The type of the incoming EDU to register handler for
|
edu_type: The type of the incoming EDU to register handler for
|
||||||
handler (Callable[[str, dict]]): A callable invoked on incoming EDU
|
handler: A callable invoked on incoming EDU
|
||||||
of the given type. The arguments are the origin server name and
|
of the given type. The arguments are the origin server name and
|
||||||
the EDU contents.
|
the EDU contents.
|
||||||
"""
|
"""
|
||||||
|
@ -749,14 +787,16 @@ class FederationHandlerRegistry(object):
|
||||||
|
|
||||||
self.edu_handlers[edu_type] = handler
|
self.edu_handlers[edu_type] = handler
|
||||||
|
|
||||||
def register_query_handler(self, query_type, handler):
|
def register_query_handler(
|
||||||
|
self, query_type: str, handler: Callable[[dict], defer.Deferred]
|
||||||
|
):
|
||||||
"""Sets the handler callable that will be used to handle an incoming
|
"""Sets the handler callable that will be used to handle an incoming
|
||||||
federation query of the given type.
|
federation query of the given type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_type (str): Category name of the query, which should match
|
query_type: Category name of the query, which should match
|
||||||
the string used by make_query.
|
the string used by make_query.
|
||||||
handler (Callable[[dict], Deferred[dict]]): Invoked to handle
|
handler: Invoked to handle
|
||||||
incoming queries of this type. The return will be yielded
|
incoming queries of this type. The return will be yielded
|
||||||
on and the result used as the response to the query request.
|
on and the result used as the response to the query request.
|
||||||
"""
|
"""
|
||||||
|
@ -767,10 +807,11 @@ class FederationHandlerRegistry(object):
|
||||||
|
|
||||||
self.query_handlers[query_type] = handler
|
self.query_handlers[query_type] = handler
|
||||||
|
|
||||||
async def on_edu(self, edu_type, origin, content):
|
async def on_edu(self, edu_type: str, origin: str, content: dict):
|
||||||
handler = self.edu_handlers.get(edu_type)
|
handler = self.edu_handlers.get(edu_type)
|
||||||
if not handler:
|
if not handler:
|
||||||
logger.warning("No handler registered for EDU type %s", edu_type)
|
logger.warning("No handler registered for EDU type %s", edu_type)
|
||||||
|
return
|
||||||
|
|
||||||
with start_active_span_from_edu(content, "handle_edu"):
|
with start_active_span_from_edu(content, "handle_edu"):
|
||||||
try:
|
try:
|
||||||
|
@ -780,7 +821,7 @@ class FederationHandlerRegistry(object):
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to handle edu %r", edu_type)
|
logger.exception("Failed to handle edu %r", edu_type)
|
||||||
|
|
||||||
def on_query(self, query_type, args):
|
def on_query(self, query_type: str, args: dict) -> defer.Deferred:
|
||||||
handler = self.query_handlers.get(query_type)
|
handler = self.query_handlers.get(query_type)
|
||||||
if not handler:
|
if not handler:
|
||||||
logger.warning("No handler registered for query type %s", query_type)
|
logger.warning("No handler registered for query type %s", query_type)
|
||||||
|
@ -807,7 +848,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
|
||||||
|
|
||||||
super(ReplicationFederationHandlerRegistry, self).__init__()
|
super(ReplicationFederationHandlerRegistry, self).__init__()
|
||||||
|
|
||||||
async def on_edu(self, edu_type, origin, content):
|
async def on_edu(self, edu_type: str, origin: str, content: dict):
|
||||||
"""Overrides FederationHandlerRegistry
|
"""Overrides FederationHandlerRegistry
|
||||||
"""
|
"""
|
||||||
if not self.config.use_presence and edu_type == "m.presence":
|
if not self.config.use_presence and edu_type == "m.presence":
|
||||||
|
@ -821,7 +862,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
|
||||||
|
|
||||||
return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
|
return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
|
||||||
|
|
||||||
async def on_query(self, query_type, args):
|
async def on_query(self, query_type: str, args: dict):
|
||||||
"""Overrides FederationHandlerRegistry
|
"""Overrides FederationHandlerRegistry
|
||||||
"""
|
"""
|
||||||
handler = self.query_handlers.get(query_type)
|
handler = self.query_handlers.get(query_type)
|
||||||
|
|
1
tox.ini
1
tox.ini
|
@ -183,6 +183,7 @@ commands = mypy \
|
||||||
synapse/events/spamcheck.py \
|
synapse/events/spamcheck.py \
|
||||||
synapse/federation/federation_base.py \
|
synapse/federation/federation_base.py \
|
||||||
synapse/federation/federation_client.py \
|
synapse/federation/federation_client.py \
|
||||||
|
synapse/federation/federation_server.py \
|
||||||
synapse/federation/sender \
|
synapse/federation/sender \
|
||||||
synapse/federation/transport \
|
synapse/federation/transport \
|
||||||
synapse/handlers/auth.py \
|
synapse/handlers/auth.py \
|
||||||
|
|
Loading…
Reference in a new issue