mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 23:43:55 +01:00
Update EventContext get_current_event_ids
and get_prev_event_ids
to accept state filters and update calls where possible (#12791)
This commit is contained in:
parent
2be5a2b07b
commit
71e8afe34d
10 changed files with 65 additions and 18 deletions
1
changelog.d/12791.misc
Normal file
1
changelog.d/12791.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update EventContext `get_current_event_ids` and `get_prev_event_ids` to accept state filters and update calls where possible.
|
|
@ -24,6 +24,7 @@ from synapse.types import JsonDict, StateMap
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.storage import Storage
|
from synapse.storage import Storage
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
|
from synapse.storage.state import StateFilter
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, auto_attribs=True)
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
|
@ -196,7 +197,9 @@ class EventContext:
|
||||||
|
|
||||||
return self._state_group
|
return self._state_group
|
||||||
|
|
||||||
async def get_current_state_ids(self) -> Optional[StateMap[str]]:
|
async def get_current_state_ids(
|
||||||
|
self, state_filter: Optional["StateFilter"] = None
|
||||||
|
) -> Optional[StateMap[str]]:
|
||||||
"""
|
"""
|
||||||
Gets the room state map, including this event - ie, the state in ``state_group``
|
Gets the room state map, including this event - ie, the state in ``state_group``
|
||||||
|
|
||||||
|
@ -204,6 +207,9 @@ class EventContext:
|
||||||
not make it into the room state. This method will raise an exception if
|
not make it into the room state. This method will raise an exception if
|
||||||
``rejected`` is set.
|
``rejected`` is set.
|
||||||
|
|
||||||
|
Arg:
|
||||||
|
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns None if state_group is None, which happens when the associated
|
Returns None if state_group is None, which happens when the associated
|
||||||
event is an outlier.
|
event is an outlier.
|
||||||
|
@ -216,7 +222,7 @@ class EventContext:
|
||||||
|
|
||||||
assert self._state_delta_due_to_event is not None
|
assert self._state_delta_due_to_event is not None
|
||||||
|
|
||||||
prev_state_ids = await self.get_prev_state_ids()
|
prev_state_ids = await self.get_prev_state_ids(state_filter)
|
||||||
|
|
||||||
if self._state_delta_due_to_event:
|
if self._state_delta_due_to_event:
|
||||||
prev_state_ids = dict(prev_state_ids)
|
prev_state_ids = dict(prev_state_ids)
|
||||||
|
@ -224,12 +230,17 @@ class EventContext:
|
||||||
|
|
||||||
return prev_state_ids
|
return prev_state_ids
|
||||||
|
|
||||||
async def get_prev_state_ids(self) -> StateMap[str]:
|
async def get_prev_state_ids(
|
||||||
|
self, state_filter: Optional["StateFilter"] = None
|
||||||
|
) -> StateMap[str]:
|
||||||
"""
|
"""
|
||||||
Gets the room state map, excluding this event.
|
Gets the room state map, excluding this event.
|
||||||
|
|
||||||
For a non-state event, this will be the same as get_current_state_ids().
|
For a non-state event, this will be the same as get_current_state_ids().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns {} if state_group is None, which happens when the associated
|
Returns {} if state_group is None, which happens when the associated
|
||||||
event is an outlier.
|
event is an outlier.
|
||||||
|
@ -239,7 +250,7 @@ class EventContext:
|
||||||
"""
|
"""
|
||||||
assert self.state_group_before_event is not None
|
assert self.state_group_before_event is not None
|
||||||
return await self._storage.state.get_state_ids_for_group(
|
return await self._storage.state.get_state_ids_for_group(
|
||||||
self.state_group_before_event
|
self.state_group_before_event, state_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,7 @@ from synapse.replication.http.federation import (
|
||||||
ReplicationStoreRoomOnOutlierMembershipRestServlet,
|
ReplicationStoreRoomOnOutlierMembershipRestServlet,
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import JsonDict, StateMap, get_domain_from_id
|
from synapse.types import JsonDict, StateMap, get_domain_from_id
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
@ -1259,7 +1260,9 @@ class FederationHandler:
|
||||||
event.content["third_party_invite"]["signed"]["token"],
|
event.content["third_party_invite"]["signed"]["token"],
|
||||||
)
|
)
|
||||||
original_invite = None
|
original_invite = None
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids(
|
||||||
|
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
|
||||||
|
)
|
||||||
original_invite_id = prev_state_ids.get(key)
|
original_invite_id = prev_state_ids.get(key)
|
||||||
if original_invite_id:
|
if original_invite_id:
|
||||||
original_invite = await self.store.get_event(
|
original_invite = await self.store.get_event(
|
||||||
|
@ -1308,7 +1311,9 @@ class FederationHandler:
|
||||||
signed = event.content["third_party_invite"]["signed"]
|
signed = event.content["third_party_invite"]["signed"]
|
||||||
token = signed["token"]
|
token = signed["token"]
|
||||||
|
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids(
|
||||||
|
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
|
||||||
|
)
|
||||||
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
|
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
|
||||||
|
|
||||||
invite_event = None
|
invite_event = None
|
||||||
|
|
|
@ -30,6 +30,7 @@ from typing import (
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
|
from synapse import event_auth
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
EventContentFields,
|
EventContentFields,
|
||||||
EventTypes,
|
EventTypes,
|
||||||
|
@ -63,6 +64,7 @@ from synapse.replication.http.federation import (
|
||||||
)
|
)
|
||||||
from synapse.state import StateResolutionStore
|
from synapse.state import StateResolutionStore
|
||||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
PersistedEventPosition,
|
PersistedEventPosition,
|
||||||
RoomStreamToken,
|
RoomStreamToken,
|
||||||
|
@ -1500,7 +1502,11 @@ class FederationEventHandler:
|
||||||
return context
|
return context
|
||||||
|
|
||||||
# now check auth against what we think the auth events *should* be.
|
# now check auth against what we think the auth events *should* be.
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
event_types = event_auth.auth_types_for_event(event.room_version, event)
|
||||||
|
prev_state_ids = await context.get_prev_state_ids(
|
||||||
|
StateFilter.from_types(event_types)
|
||||||
|
)
|
||||||
|
|
||||||
auth_events_ids = self._event_auth_handler.compute_auth_events(
|
auth_events_ids = self._event_auth_handler.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
|
|
|
@ -634,7 +634,9 @@ class EventCreationHandler:
|
||||||
# federation as well as those created locally. As of room v3, aliases events
|
# federation as well as those created locally. As of room v3, aliases events
|
||||||
# can be created by users that are not in the room, therefore we have to
|
# can be created by users that are not in the room, therefore we have to
|
||||||
# tolerate them in event_auth.check().
|
# tolerate them in event_auth.check().
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids(
|
||||||
|
StateFilter.from_types([(EventTypes.Member, None)])
|
||||||
|
)
|
||||||
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
|
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
|
||||||
prev_event = (
|
prev_event = (
|
||||||
await self.store.get_event(prev_event_id, allow_none=True)
|
await self.store.get_event(prev_event_id, allow_none=True)
|
||||||
|
@ -761,7 +763,9 @@ class EventCreationHandler:
|
||||||
# This can happen due to out of band memberships
|
# This can happen due to out of band memberships
|
||||||
return None
|
return None
|
||||||
|
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids(
|
||||||
|
StateFilter.from_types([(event.type, None)])
|
||||||
|
)
|
||||||
prev_event_id = prev_state_ids.get((event.type, event.state_key))
|
prev_event_id = prev_state_ids.get((event.type, event.state_key))
|
||||||
if not prev_event_id:
|
if not prev_event_id:
|
||||||
return None
|
return None
|
||||||
|
@ -1547,7 +1551,11 @@ class EventCreationHandler:
|
||||||
"Redacting MSC2716 events is not supported in this room version",
|
"Redacting MSC2716 events is not supported in this room version",
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
event_types = event_auth.auth_types_for_event(event.room_version, event)
|
||||||
|
prev_state_ids = await context.get_prev_state_ids(
|
||||||
|
StateFilter.from_types(event_types)
|
||||||
|
)
|
||||||
|
|
||||||
auth_events_ids = self._event_auth_handler.compute_auth_events(
|
auth_events_ids = self._event_auth_handler.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
|
|
|
@ -303,7 +303,10 @@ class RoomCreationHandler:
|
||||||
context=tombstone_context,
|
context=tombstone_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
old_room_state = await tombstone_context.get_current_state_ids()
|
state_filter = StateFilter.from_types(
|
||||||
|
[(EventTypes.CanonicalAlias, ""), (EventTypes.PowerLevels, "")]
|
||||||
|
)
|
||||||
|
old_room_state = await tombstone_context.get_current_state_ids(state_filter)
|
||||||
|
|
||||||
# We know the tombstone event isn't an outlier so it has current state.
|
# We know the tombstone event isn't an outlier so it has current state.
|
||||||
assert old_room_state is not None
|
assert old_room_state is not None
|
||||||
|
|
|
@ -38,6 +38,7 @@ from synapse.event_auth import get_named_level, get_power_level_event
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
|
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
|
||||||
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
JsonDict,
|
JsonDict,
|
||||||
Requester,
|
Requester,
|
||||||
|
@ -362,7 +363,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
historical=historical,
|
historical=historical,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids(
|
||||||
|
StateFilter.from_types([(EventTypes.Member, None)])
|
||||||
|
)
|
||||||
|
|
||||||
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
||||||
|
|
||||||
|
@ -1160,7 +1163,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
else:
|
else:
|
||||||
requester = types.create_requester(target_user)
|
requester = types.create_requester(target_user)
|
||||||
|
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids(
|
||||||
|
StateFilter.from_types([(EventTypes.GuestAccess, None)])
|
||||||
|
)
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
guest_can_join = await self._can_guest_join(prev_state_ids)
|
guest_can_join = await self._can_guest_join(prev_state_ids)
|
||||||
|
|
|
@ -20,7 +20,7 @@ import attr
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership, RelationTypes
|
from synapse.api.constants import EventTypes, Membership, RelationTypes
|
||||||
from synapse.event_auth import get_user_power_level
|
from synapse.event_auth import auth_types_for_event, get_user_power_level
|
||||||
from synapse.events import EventBase, relation_from_event
|
from synapse.events import EventBase, relation_from_event
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.state import POWER_KEY
|
from synapse.state import POWER_KEY
|
||||||
|
@ -31,6 +31,7 @@ from synapse.util.caches.descriptors import lru_cache
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
|
|
||||||
|
from ..storage.state import StateFilter
|
||||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -168,8 +169,12 @@ class BulkPushRuleEvaluator:
|
||||||
async def _get_power_levels_and_sender_level(
|
async def _get_power_levels_and_sender_level(
|
||||||
self, event: EventBase, context: EventContext
|
self, event: EventBase, context: EventContext
|
||||||
) -> Tuple[dict, int]:
|
) -> Tuple[dict, int]:
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
event_types = auth_types_for_event(event.room_version, event)
|
||||||
|
prev_state_ids = await context.get_prev_state_ids(
|
||||||
|
StateFilter.from_types(event_types)
|
||||||
|
)
|
||||||
pl_event_id = prev_state_ids.get(POWER_KEY)
|
pl_event_id = prev_state_ids.get(POWER_KEY)
|
||||||
|
|
||||||
if pl_event_id:
|
if pl_event_id:
|
||||||
# fastpath: if there's a power level event, that's all we need, and
|
# fastpath: if there's a power level event, that's all we need, and
|
||||||
# not having a power level event is an extreme edge case
|
# not having a power level event is an extreme edge case
|
||||||
|
|
|
@ -634,16 +634,19 @@ class StateGroupStorage:
|
||||||
|
|
||||||
return group_to_state
|
return group_to_state
|
||||||
|
|
||||||
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
|
async def get_state_ids_for_group(
|
||||||
|
self, state_group: int, state_filter: Optional[StateFilter] = None
|
||||||
|
) -> StateMap[str]:
|
||||||
"""Get the event IDs of all the state in the given state group
|
"""Get the event IDs of all the state in the given state group
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_group: A state group for which we want to get the state IDs.
|
state_group: A state group for which we want to get the state IDs.
|
||||||
|
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Resolves to a map of (type, state_key) -> event_id
|
Resolves to a map of (type, state_key) -> event_id
|
||||||
"""
|
"""
|
||||||
group_to_state = await self.get_state_for_groups((state_group,))
|
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
|
||||||
|
|
||||||
return group_to_state[state_group]
|
return group_to_state[state_group]
|
||||||
|
|
||||||
|
|
|
@ -88,7 +88,7 @@ class _DummyStore:
|
||||||
|
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
async def get_state_ids_for_group(self, state_group):
|
async def get_state_ids_for_group(self, state_group, state_filter=None):
|
||||||
return self._group_to_state[state_group]
|
return self._group_to_state[state_group]
|
||||||
|
|
||||||
async def store_state_group(
|
async def store_state_group(
|
||||||
|
|
Loading…
Reference in a new issue