Reduce the amount of state we pull from the DB (#12811)

This commit is contained in:
Erik Johnston 2022-06-06 11:24:12 +03:00 committed by GitHub
parent 6b46c3eb3d
commit e3163e2e11
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 162 additions and 147 deletions

1
changelog.d/12811.misc Normal file
View file

@ -0,0 +1 @@
Reduce the amount of state we pull from the DB.

View file

@ -29,12 +29,11 @@ from synapse.api.errors import (
MissingClientTokenError, MissingClientTokenError,
) )
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.http import get_request_user_agent from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import active_span, force_tracing, start_active_span from synapse.logging.opentracing import active_span, force_tracing, start_active_span
from synapse.storage.databases.main.registration import TokenLookupResult from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester, StateMap, UserID, create_requester from synapse.types import Requester, UserID, create_requester
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
@ -61,8 +60,8 @@ class Auth:
self.hs = hs self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.state = hs.get_state_handler()
self._account_validity_handler = hs.get_account_validity_handler() self._account_validity_handler = hs.get_account_validity_handler()
self._storage_controllers = hs.get_storage_controllers()
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache( self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
10000, "token_cache" 10000, "token_cache"
@ -79,9 +78,8 @@ class Auth:
self, self,
room_id: str, room_id: str,
user_id: str, user_id: str,
current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False, allow_departed_users: bool = False,
) -> EventBase: ) -> Tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point. """Check if the user is in the room, or was at some point.
Args: Args:
room_id: The room to check. room_id: The room to check.
@ -99,29 +97,28 @@ class Auth:
Raises: Raises:
AuthError if the user is/was not in the room. AuthError if the user is/was not in the room.
Returns: Returns:
Membership event for the user if the user was in the The current membership of the user in the room and the
room. This will be the join event if they are currently joined to membership event ID of the user.
the room. This will be the leave event if they have left the room.
""" """
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = await self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
if member: (
membership = member.membership membership,
member_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
if membership:
if membership == Membership.JOIN: if membership == Membership.JOIN:
return member return membership, member_event_id
# XXX this looks totally bogus. Why do we not allow users who have been banned, # XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited? # or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE: if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id) forgot = await self.store.did_forget(user_id, room_id)
if not forgot: if not forgot:
return member return membership, member_event_id
raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
@ -602,8 +599,11 @@ class Auth:
# We currently require the user is a "moderator" in the room. We do this # We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the # by checking if they would (theoretically) be able to change the
# m.room.canonical_alias events # m.room.canonical_alias events
power_level_event = await self.state.get_current_state(
room_id, EventTypes.PowerLevels, "" power_level_event = (
await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.PowerLevels, ""
)
) )
auth_events = {} auth_events = {}
@ -693,12 +693,11 @@ class Auth:
# * The user is a non-guest user, and was ever in the room # * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room # * The user is a guest user, and has joined the room
# else it will throw. # else it will throw.
member_event = await self.check_user_in_room( return await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users room_id, user_id, allow_departed_users=allow_departed_users
) )
return member_event.membership, member_event.event_id
except AuthError: except AuthError:
visibility = await self.state.get_current_state( visibility = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, "" room_id, EventTypes.RoomHistoryVisibility, ""
) )
if ( if (

View file

@ -53,6 +53,7 @@ class FederationBase:
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self._clock = hs.get_clock() self._clock = hs.get_clock()
self._storage_controllers = hs.get_storage_controllers()
async def _check_sigs_and_hash( async def _check_sigs_and_hash(
self, room_version: RoomVersion, pdu: EventBase self, room_version: RoomVersion, pdu: EventBase

View file

@ -1223,14 +1223,10 @@ class FederationServer(FederationBase):
Raises: Raises:
AuthError if the server does not match the ACL AuthError if the server does not match the ACL
""" """
state_ids = await self._state_storage_controller.get_current_state_ids(room_id) acl_event = await self._storage_controllers.state.get_current_state_event(
acl_event_id = state_ids.get((EventTypes.ServerACL, "")) room_id, EventTypes.ServerACL, ""
)
if not acl_event_id: if not acl_event or server_matches_acl_event(server_name, acl_event):
return
acl_event = await self.store.get_event(acl_event_id)
if server_matches_acl_event(server_name, acl_event):
return return
raise AuthError(code=403, msg="Server is banned from room") raise AuthError(code=403, msg="Server is banned from room")

View file

@ -320,7 +320,7 @@ class DirectoryHandler:
Raises: Raises:
ShadowBanError if the requester has been shadow-banned. ShadowBanError if the requester has been shadow-banned.
""" """
alias_event = await self.state.get_current_state( alias_event = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.CanonicalAlias, "" room_id, EventTypes.CanonicalAlias, ""
) )

View file

@ -371,7 +371,7 @@ class FederationHandler:
# First we try hosts that are already in the room # First we try hosts that are already in the room
# TODO: HEURISTIC ALERT. # TODO: HEURISTIC ALERT.
curr_state = await self.state_handler.get_current_state(room_id) curr_state = await self._storage_controllers.state.get_current_state(room_id)
curr_domains = get_domains_from_state(curr_state) curr_domains = get_domains_from_state(curr_state)

View file

@ -1584,9 +1584,11 @@ class FederationEventHandler:
if guest_access == GuestAccess.CAN_JOIN: if guest_access == GuestAccess.CAN_JOIN:
return return
current_state_map = await self._state_handler.get_current_state(event.room_id) current_state = await self._storage_controllers.state.get_current_state(
current_state = list(current_state_map.values()) event.room_id
await self._get_room_member_handler().kick_guest_users(current_state) )
current_state_list = list(current_state.values())
await self._get_room_member_handler().kick_guest_users(current_state_list)
async def _check_for_soft_fail( async def _check_for_soft_fail(
self, self,
@ -1614,6 +1616,9 @@ class FederationEventHandler:
room_version = await self._store.get_room_version_id(event.room_id) room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# The event types we want to pull from the "current" state.
auth_types = auth_types_for_event(room_version_obj, event)
# Calculate the "current state". # Calculate the "current state".
if state_ids is not None: if state_ids is not None:
# If we're explicitly given the state then we won't have all the # If we're explicitly given the state then we won't have all the
@ -1643,8 +1648,10 @@ class FederationEventHandler:
) )
) )
else: else:
current_state_ids = await self._state_handler.get_current_state_ids( current_state_ids = (
event.room_id, latest_event_ids=extrem_ids await self._state_storage_controller.get_current_state_ids(
event.room_id, StateFilter.from_types(auth_types)
)
) )
logger.debug( logger.debug(
@ -1654,7 +1661,6 @@ class FederationEventHandler:
) )
# Now check if event pass auth against said current state # Now check if event pass auth against said current state
auth_types = auth_types_for_event(room_version_obj, event)
current_state_ids_list = [ current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types e for k, e in current_state_ids.items() if k in auth_types
] ]

View file

@ -190,7 +190,7 @@ class InitialSyncHandler:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
room_end_token = now_token.room_key room_end_token = now_token.room_key
deferred_room_state = run_in_background( deferred_room_state = run_in_background(
self.state_handler.get_current_state, event.room_id self._state_storage_controller.get_current_state, event.room_id
) )
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken( room_end_token = RoomStreamToken(
@ -407,7 +407,9 @@ class InitialSyncHandler:
membership: str, membership: str,
is_peeking: bool, is_peeking: bool,
) -> JsonDict: ) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id) current_state = await self._storage_controllers.state.get_current_state(
room_id=room_id
)
# TODO: These concurrently # TODO: These concurrently
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -125,7 +125,9 @@ class MessageHandler:
) )
if membership == Membership.JOIN: if membership == Membership.JOIN:
data = await self.state.get_current_state(room_id, event_type, state_key) data = await self._storage_controllers.state.get_current_state_event(
room_id, event_type, state_key
)
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
key = (event_type, state_key) key = (event_type, state_key)
# If the membership is not JOIN, then the event ID should exist. # If the membership is not JOIN, then the event ID should exist.

View file

@ -1333,6 +1333,7 @@ class TimestampLookupHandler:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.federation_client = hs.get_federation_client() self.federation_client = hs.get_federation_client()
self._storage_controllers = hs.get_storage_controllers()
async def get_event_for_timestamp( async def get_event_for_timestamp(
self, self,
@ -1406,7 +1407,9 @@ class TimestampLookupHandler:
) )
# Find other homeservers from the given state in the room # Find other homeservers from the given state in the room
curr_state = await self.state_handler.get_current_state(room_id) curr_state = await self._storage_controllers.state.get_current_state(
room_id
)
curr_domains = get_domains_from_state(curr_state) curr_domains = get_domains_from_state(curr_state)
likely_domains = [ likely_domains = [
domain for domain, depth in curr_domains if domain != self.server_name domain for domain, depth in curr_domains if domain != self.server_name

View file

@ -1401,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id: Optional[str], txn_id: Optional[str],
id_access_token: Optional[str] = None, id_access_token: Optional[str] = None,
) -> int: ) -> int:
room_state = await self.state_handler.get_current_state(room_id) room_state = await self._storage_controllers.state.get_current_state(
room_id,
StateFilter.from_types(
[
(EventTypes.Member, user.to_string()),
(EventTypes.CanonicalAlias, ""),
(EventTypes.Name, ""),
(EventTypes.Create, ""),
(EventTypes.JoinRules, ""),
(EventTypes.RoomAvatar, ""),
]
),
)
inviter_display_name = "" inviter_display_name = ""
inviter_avatar_url = "" inviter_avatar_url = ""
@ -1797,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
async def forget(self, user: UserID, room_id: str) -> None: async def forget(self, user: UserID, room_id: str) -> None:
user_id = user.to_string() user_id = user.to_string()
member = await self.state_handler.get_current_state( member = await self._storage_controllers.state.get_current_state_event(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id room_id=room_id, event_type=EventTypes.Member, state_key=user_id
) )
membership = member.membership if member else None membership = member.membership if member else None

View file

@ -348,7 +348,7 @@ class SearchHandler:
state_results = {} state_results = {}
if include_state: if include_state:
for room_id in {e.room_id for e in search_result.allowed_events}: for room_id in {e.room_id for e in search_result.allowed_events}:
state = await self.state_handler.get_current_state(room_id) state = await self._storage_controllers.state.get_current_state(room_id)
state_results[room_id] = list(state.values()) state_results[room_id] = list(state.values())
aggregations = await self._relations_handler.get_bundled_aggregations( aggregations = await self._relations_handler.get_bundled_aggregations(

View file

@ -681,7 +681,7 @@ class Notifier:
return joined_room_ids, True return joined_room_ids, True
async def _is_world_readable(self, room_id: str) -> bool: async def _is_world_readable(self, room_id: str) -> bool:
state = await self.state_handler.get_current_state( state = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, "" room_id, EventTypes.RoomHistoryVisibility, ""
) )
if state and "history_visibility" in state.content: if state and "history_visibility" in state.content:

View file

@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin, assert_user_is_admin,
) )
from synapse.storage.databases.main.room import RoomSortOrder from synapse.storage.databases.main.room import RoomSortOrder
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, RoomID, UserID, create_requester from synapse.types import JsonDict, RoomID, UserID, create_requester
from synapse.util import json_decoder from synapse.util import json_decoder
@ -448,7 +449,8 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
super().__init__(hs) super().__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler() self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
async def on_POST( async def on_POST(
@ -490,8 +492,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
) )
# send invite if room has "JoinRules.INVITE" # send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id) join_rules_event = (
join_rules_event = room_state.get((EventTypes.JoinRules, "")) await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.JoinRules, ""
)
)
if join_rules_event: if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
# update_membership with an action of "invite" can raise a # update_membership with an action of "invite" can raise a
@ -536,6 +541,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
super().__init__(hs) super().__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self._state_storage_controller = hs.get_storage_controllers().state
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
@ -553,12 +559,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
user_to_add = content.get("user_id", requester.user.to_string()) user_to_add = content.get("user_id", requester.user.to_string())
# Figure out which local users currently have power in the room, if any. # Figure out which local users currently have power in the room, if any.
room_state = await self.state_handler.get_current_state(room_id) filtered_room_state = await self._state_storage_controller.get_current_state(
if not room_state: room_id,
StateFilter.from_types(
[
(EventTypes.Create, ""),
(EventTypes.PowerLevels, ""),
(EventTypes.JoinRules, ""),
(EventTypes.Member, user_to_add),
]
),
)
if not filtered_room_state:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room") raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
create_event = room_state[(EventTypes.Create, "")] create_event = filtered_room_state[(EventTypes.Create, "")]
power_levels = room_state.get((EventTypes.PowerLevels, "")) power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
if power_levels is not None: if power_levels is not None:
# We pick the local user with the highest power. # We pick the local user with the highest power.
@ -634,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
# Now we check if the user we're granting admin rights to is already in # Now we check if the user we're granting admin rights to is already in
# the room. If not and it's not a public room we invite them. # the room. If not and it's not a public room we invite them.
member_event = room_state.get((EventTypes.Member, user_to_add)) member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
is_joined = False is_joined = False
if member_event: if member_event:
is_joined = member_event.content["membership"] in ( is_joined = member_event.content["membership"] in (
@ -645,7 +661,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
if is_joined: if is_joined:
return HTTPStatus.OK, {} return HTTPStatus.OK, {}
join_rules = room_state.get((EventTypes.JoinRules, "")) join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
is_public = False is_public = False
if join_rules: if join_rules:
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC

View file

@ -650,6 +650,7 @@ class RoomEventServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
self._state = hs.get_state_handler() self._state = hs.get_state_handler()
self._storage_controllers = hs.get_storage_controllers()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler() self._relations_handler = hs.get_relations_handler()
@ -673,8 +674,10 @@ class RoomEventServlet(RestServlet):
if include_unredacted_content and not await self.auth.is_server_admin( if include_unredacted_content and not await self.auth.is_server_admin(
requester.user requester.user
): ):
power_level_event = await self._state.get_current_state( power_level_event = (
room_id, EventTypes.PowerLevels, "" await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.PowerLevels, ""
)
) )
auth_events = {} auth_events = {}

View file

@ -36,6 +36,7 @@ class ResourceLimitsServerNotices:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._server_notices_manager = hs.get_server_notices_manager() self._server_notices_manager = hs.get_server_notices_manager()
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth() self._auth = hs.get_auth()
self._config = hs.config self._config = hs.config
self._resouce_limited = False self._resouce_limited = False
@ -178,8 +179,10 @@ class ResourceLimitsServerNotices:
currently_blocked = False currently_blocked = False
pinned_state_event = None pinned_state_event = None
try: try:
pinned_state_event = await self._state.get_current_state( pinned_state_event = (
room_id, event_type=EventTypes.Pinned await self._storage_controllers.state.get_current_state_event(
room_id, event_type=EventTypes.Pinned, state_key=""
)
) )
except AuthError: except AuthError:
# The user has yet to join the server notices room # The user has yet to join the server notices room

View file

@ -32,13 +32,11 @@ from typing import (
Set, Set,
Tuple, Tuple,
Union, Union,
overload,
) )
import attr import attr
from frozendict import frozendict from frozendict import frozendict
from prometheus_client import Counter, Histogram from prometheus_client import Counter, Histogram
from typing_extensions import Literal
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@ -132,85 +130,20 @@ class StateHandler:
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage_controllers = hs.get_storage_controllers() self._storage_controllers = hs.get_storage_controllers()
@overload
async def get_current_state(
self,
room_id: str,
event_type: Literal[None] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> StateMap[EventBase]:
...
@overload
async def get_current_state(
self,
room_id: str,
event_type: str,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Optional[EventBase]:
...
async def get_current_state(
self,
room_id: str,
event_type: Optional[str] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Union[Optional[EventBase], StateMap[EventBase]]:
"""Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
This is equivalent to getting the state of an event that were to send
next before receiving any new events.
Returns:
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Otherwise, a map from (type, state_key) to event.
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
event = await self.store.get_event(event_id, allow_none=True)
return event
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
return {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
async def get_current_state_ids( async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None self,
room_id: str,
latest_event_ids: Collection[str],
) -> StateMap[str]: ) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room """Get the current state, or the state at a set of events, for a room
Args: Args:
room_id: room_id:
latest_event_ids: if given, the forward extremities to resolve. If latest_event_ids: The forward extremities to resolve.
None, we look them up from the database (via a cache).
Returns: Returns:
the state dict, mapping from (event_type, state_key) -> event_id the state dict, mapping from (event_type, state_key) -> event_id
""" """
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids") logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return ret.state return ret.state

View file

@ -455,3 +455,30 @@ class StateStorageController:
return await self.stores.main.get_partial_current_state_deltas( return await self.stores.main.get_partial_current_state_deltas(
prev_stream_id, max_stream_id prev_stream_id, max_stream_id
) )
async def get_current_state(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
"""Same as `get_current_state_ids` but also fetches the events"""
state_map_ids = await self.get_current_state_ids(room_id, state_filter)
event_map = await self.stores.main.get_events(list(state_map_ids.values()))
state_map = {}
for key, event_id in state_map_ids.items():
event = event_map.get(event_id)
if event:
state_map[key] = event
return state_map
async def get_current_state_event(
self, room_id: str, event_type: str, state_key: str
) -> Optional[EventBase]:
"""Get the current state event for the given type/state_key."""
key = (event_type, state_key)
state_map = await self.get_current_state(
room_id, StateFilter.from_types((key,))
)
return state_map.get(key)

View file

@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
super().prepare(reactor, clock, hs) super().prepare(reactor, clock, hs)
self._storage_controllers = hs.get_storage_controllers()
# create the room # create the room
creator_user_id = self.register_user("kermit", "test") creator_user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test") tok = self.login("kermit", "test")
@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member # the room should show that the new user is a member
r = self.get_success( r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id) self._storage_controllers.state.get_current_state(self._room_id)
) )
self.assertEqual(r[("m.room.member", joining_user)].membership, "join") self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member # the room should show that the new user is a member
r = self.get_success( r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id) self._storage_controllers.state.get_current_state(self._room_id)
) )
self.assertEqual(r[("m.room.member", joining_user)].membership, "join") self.assertEqual(r[("m.room.member", joining_user)].membership, "join")

View file

@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler() self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self._storage_controllers = hs.get_storage_controllers()
# Create user # Create user
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def _get_canonical_alias(self): def _get_canonical_alias(self):
"""Get the canonical alias state of the room.""" """Get the canonical alias state of the room."""
return self.get_success( return self.get_success(
self.state_handler.get_current_state( self._storage_controllers.state.get_current_state_event(
self.room_id, EventTypes.CanonicalAlias, "" self.room_id, EventTypes.CanonicalAlias, ""
) )
) )

View file

@ -32,6 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler() self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence self._persistence = self.hs.get_storage_controllers().persistence
self._state_storage_controller = self.hs.get_storage_controllers().state
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
self.register_user("user", "pass") self.register_user("user", "pass")
@ -104,7 +105,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
) )
state_before_gap = self.get_success( state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id) self._state_storage_controller.get_current_state_ids(self.room_id)
) )
self.persist_event(remote_event_2, state=state_before_gap) self.persist_event(remote_event_2, state=state_before_gap)
@ -137,7 +138,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
# setting. The state resolution across the old and new event will then # setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state. # include it, and so the resolved state won't match the new state.
state_before_gap = dict( state_before_gap = dict(
self.get_success(self.state.get_current_state_ids(self.room_id)) self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
) )
state_before_gap.pop(("m.room.history_visibility", "")) state_before_gap.pop(("m.room.history_visibility", ""))
@ -181,7 +184,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
) )
state_before_gap = self.get_success( state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id) self._state_storage_controller.get_current_state_ids(self.room_id)
) )
self.persist_event(remote_event_2, state=state_before_gap) self.persist_event(remote_event_2, state=state_before_gap)
@ -213,7 +216,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
) )
state_before_gap = self.get_success( state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id) self._state_storage_controller.get_current_state_ids(self.room_id)
) )
self.persist_event(remote_event_2, state=state_before_gap) self.persist_event(remote_event_2, state=state_before_gap)
@ -255,7 +258,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
) )
state_before_gap = self.get_success( state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id) self._state_storage_controller.get_current_state_ids(self.room_id)
) )
self.persist_event(remote_event_2, state=state_before_gap) self.persist_event(remote_event_2, state=state_before_gap)
@ -299,7 +302,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
) )
state_before_gap = self.get_success( state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id) self._state_storage_controller.get_current_state_ids(self.room_id)
) )
self.persist_event(remote_event_2, state=state_before_gap) self.persist_event(remote_event_2, state=state_before_gap)
@ -335,7 +338,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
) )
state_before_gap = self.get_success( state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id) self._state_storage_controller.get_current_state_ids(self.room_id)
) )
self.persist_event(remote_event_2, state=state_before_gap) self.persist_event(remote_event_2, state=state_before_gap)

View file

@ -102,9 +102,10 @@ class PurgeTests(HomeserverTestCase):
first = self.helper.send(self.room_id, body="test1") first = self.helper.send(self.room_id, body="test1")
# Get the current room state. # Get the current room state.
state_handler = self.hs.get_state_handler()
create_event = self.get_success( create_event = self.get_success(
state_handler.get_current_state(self.room_id, "m.room.create", "") self._storage_controllers.state.get_current_state_event(
self.room_id, "m.room.create", ""
)
) )
self.assertIsNotNone(create_event) self.assertIsNotNone(create_event)

View file

@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
# Room events need the full datastore, for persist_event() and # Room events need the full datastore, for persist_event() and
# get_room_state() # get_room_state()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self._storage = hs.get_storage_controllers() self._storage_controllers = hs.get_storage_controllers()
self.event_factory = hs.get_event_factory() self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")
@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
def inject_room_event(self, **kwargs): def inject_room_event(self, **kwargs):
self.get_success( self.get_success(
self._storage.persistence.persist_event( self._storage_controllers.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
) )
) )
@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
) )
state = self.get_success( state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string()) self._storage_controllers.state.get_current_state(
room_id=self.room.to_string()
)
) )
self.assertEqual(1, len(state)) self.assertEqual(1, len(state))
@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
) )
state = self.get_success( state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string()) self._storage_controllers.state.get_current_state(
room_id=self.room.to_string()
)
) )
self.assertEqual(1, len(state)) self.assertEqual(1, len(state))