mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 21:23:51 +01:00
Stronger typing in the federation handler (#6480)
replace the event_info dict with an attrs thing
This commit is contained in:
parent
e1f4c83f41
commit
63d6ad1064
2 changed files with 58 additions and 24 deletions
1
changelog.d/6480.misc
Normal file
1
changelog.d/6480.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Refactor some code in the event authentication path for clarity.
|
|
@ -19,11 +19,13 @@
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Dict, Iterable, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from six import iteritems, itervalues
|
from six import iteritems, itervalues
|
||||||
from six.moves import http_client, zip
|
from six.moves import http_client, zip
|
||||||
|
|
||||||
|
import attr
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
from signedjson.sign import verify_signed_json
|
from signedjson.sign import verify_signed_json
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
|
@ -45,6 +47,7 @@ from synapse.api.errors import (
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||||
from synapse.crypto.event_signing import compute_event_signature
|
from synapse.crypto.event_signing import compute_event_signature
|
||||||
from synapse.event_auth import auth_types_for_event
|
from synapse.event_auth import auth_types_for_event
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
|
@ -72,6 +75,23 @@ from ._base import BaseHandler
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class _NewEventInfo:
|
||||||
|
"""Holds information about a received event, ready for passing to _handle_new_events
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
event: the received event
|
||||||
|
|
||||||
|
state: the state at that event
|
||||||
|
|
||||||
|
auth_events: the auth_event map for that event
|
||||||
|
"""
|
||||||
|
|
||||||
|
event = attr.ib(type=EventBase)
|
||||||
|
state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
|
||||||
|
auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None)
|
||||||
|
|
||||||
|
|
||||||
def shortstr(iterable, maxitems=5):
|
def shortstr(iterable, maxitems=5):
|
||||||
"""If iterable has maxitems or fewer, return the stringification of a list
|
"""If iterable has maxitems or fewer, return the stringification of a list
|
||||||
containing those items.
|
containing those items.
|
||||||
|
@ -597,14 +617,14 @@ class FederationHandler(BaseHandler):
|
||||||
for e in auth_chain
|
for e in auth_chain
|
||||||
if e.event_id in auth_ids or e.type == EventTypes.Create
|
if e.event_id in auth_ids or e.type == EventTypes.Create
|
||||||
}
|
}
|
||||||
event_infos.append({"event": e, "auth_events": auth})
|
event_infos.append(_NewEventInfo(event=e, auth_events=auth))
|
||||||
seen_ids.add(e.event_id)
|
seen_ids.add(e.event_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"[%s %s] persisting newly-received auth/state events %s",
|
"[%s %s] persisting newly-received auth/state events %s",
|
||||||
room_id,
|
room_id,
|
||||||
event_id,
|
event_id,
|
||||||
[e["event"].event_id for e in event_infos],
|
[e.event.event_id for e in event_infos],
|
||||||
)
|
)
|
||||||
yield self._handle_new_events(origin, event_infos)
|
yield self._handle_new_events(origin, event_infos)
|
||||||
|
|
||||||
|
@ -795,9 +815,9 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
a.internal_metadata.outlier = True
|
a.internal_metadata.outlier = True
|
||||||
ev_infos.append(
|
ev_infos.append(
|
||||||
{
|
_NewEventInfo(
|
||||||
"event": a,
|
event=a,
|
||||||
"auth_events": {
|
auth_events={
|
||||||
(
|
(
|
||||||
auth_events[a_id].type,
|
auth_events[a_id].type,
|
||||||
auth_events[a_id].state_key,
|
auth_events[a_id].state_key,
|
||||||
|
@ -805,7 +825,7 @@ class FederationHandler(BaseHandler):
|
||||||
for a_id in a.auth_event_ids()
|
for a_id in a.auth_event_ids()
|
||||||
if a_id in auth_events
|
if a_id in auth_events
|
||||||
},
|
},
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 1b: persist the events in the chunk we fetched state for (i.e.
|
# Step 1b: persist the events in the chunk we fetched state for (i.e.
|
||||||
|
@ -817,10 +837,10 @@ class FederationHandler(BaseHandler):
|
||||||
assert not ev.internal_metadata.is_outlier()
|
assert not ev.internal_metadata.is_outlier()
|
||||||
|
|
||||||
ev_infos.append(
|
ev_infos.append(
|
||||||
{
|
_NewEventInfo(
|
||||||
"event": ev,
|
event=ev,
|
||||||
"state": events_to_state[e_id],
|
state=events_to_state[e_id],
|
||||||
"auth_events": {
|
auth_events={
|
||||||
(
|
(
|
||||||
auth_events[a_id].type,
|
auth_events[a_id].type,
|
||||||
auth_events[a_id].state_key,
|
auth_events[a_id].state_key,
|
||||||
|
@ -828,7 +848,7 @@ class FederationHandler(BaseHandler):
|
||||||
for a_id in ev.auth_event_ids()
|
for a_id in ev.auth_event_ids()
|
||||||
if a_id in auth_events
|
if a_id in auth_events
|
||||||
},
|
},
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._handle_new_events(dest, ev_infos, backfilled=True)
|
yield self._handle_new_events(dest, ev_infos, backfilled=True)
|
||||||
|
@ -1713,7 +1733,12 @@ class FederationHandler(BaseHandler):
|
||||||
return context
|
return context
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _handle_new_events(self, origin, event_infos, backfilled=False):
|
def _handle_new_events(
|
||||||
|
self,
|
||||||
|
origin: str,
|
||||||
|
event_infos: Iterable[_NewEventInfo],
|
||||||
|
backfilled: bool = False,
|
||||||
|
):
|
||||||
"""Creates the appropriate contexts and persists events. The events
|
"""Creates the appropriate contexts and persists events. The events
|
||||||
should not depend on one another, e.g. this should be used to persist
|
should not depend on one another, e.g. this should be used to persist
|
||||||
a bunch of outliers, but not a chunk of individual events that depend
|
a bunch of outliers, but not a chunk of individual events that depend
|
||||||
|
@ -1723,14 +1748,14 @@ class FederationHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def prep(ev_info):
|
def prep(ev_info: _NewEventInfo):
|
||||||
event = ev_info["event"]
|
event = ev_info.event
|
||||||
with nested_logging_context(suffix=event.event_id):
|
with nested_logging_context(suffix=event.event_id):
|
||||||
res = yield self._prep_event(
|
res = yield self._prep_event(
|
||||||
origin,
|
origin,
|
||||||
event,
|
event,
|
||||||
state=ev_info.get("state"),
|
state=ev_info.state,
|
||||||
auth_events=ev_info.get("auth_events"),
|
auth_events=ev_info.auth_events,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
return res
|
return res
|
||||||
|
@ -1744,7 +1769,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
yield self.persist_events_and_notify(
|
yield self.persist_events_and_notify(
|
||||||
[
|
[
|
||||||
(ev_info["event"], context)
|
(ev_info.event, context)
|
||||||
for ev_info, context in zip(event_infos, contexts)
|
for ev_info, context in zip(event_infos, contexts)
|
||||||
],
|
],
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
|
@ -1846,7 +1871,14 @@ class FederationHandler(BaseHandler):
|
||||||
yield self.persist_events_and_notify([(event, new_event_context)])
|
yield self.persist_events_and_notify([(event, new_event_context)])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _prep_event(self, origin, event, state, auth_events, backfilled):
|
def _prep_event(
|
||||||
|
self,
|
||||||
|
origin: str,
|
||||||
|
event: EventBase,
|
||||||
|
state: Optional[Iterable[EventBase]],
|
||||||
|
auth_events: Optional[Dict[Tuple[str, str], EventBase]],
|
||||||
|
backfilled: bool,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1854,7 +1886,7 @@ class FederationHandler(BaseHandler):
|
||||||
event:
|
event:
|
||||||
state:
|
state:
|
||||||
auth_events:
|
auth_events:
|
||||||
backfilled (bool)
|
backfilled:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred, which resolves to synapse.events.snapshot.EventContext
|
Deferred, which resolves to synapse.events.snapshot.EventContext
|
||||||
|
@ -1890,15 +1922,16 @@ class FederationHandler(BaseHandler):
|
||||||
return context
|
return context
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_for_soft_fail(self, event, state, backfilled):
|
def _check_for_soft_fail(
|
||||||
|
self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
|
||||||
|
):
|
||||||
"""Checks if we should soft fail the event, if so marks the event as
|
"""Checks if we should soft fail the event, if so marks the event as
|
||||||
such.
|
such.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (FrozenEvent)
|
event
|
||||||
state (dict|None): The state at the event if we don't have all the
|
state: The state at the event if we don't have all the event's prev events
|
||||||
event's prev events
|
backfilled: Whether the event is from backfill
|
||||||
backfilled (bool): Whether the event is from backfill
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred
|
Deferred
|
||||||
|
|
Loading…
Reference in a new issue