mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-22 09:24:00 +01:00
Reduce the amount of state we pull from the DB (#12811)
This commit is contained in:
parent
6b46c3eb3d
commit
e3163e2e11
23 changed files with 162 additions and 147 deletions
1
changelog.d/12811.misc
Normal file
1
changelog.d/12811.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Reduce the amount of state we pull from the DB.
|
|
@ -29,12 +29,11 @@ from synapse.api.errors import (
|
||||||
MissingClientTokenError,
|
MissingClientTokenError,
|
||||||
)
|
)
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.events import EventBase
|
|
||||||
from synapse.http import get_request_user_agent
|
from synapse.http import get_request_user_agent
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
|
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
|
||||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||||
from synapse.types import Requester, StateMap, UserID, create_requester
|
from synapse.types import Requester, UserID, create_requester
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||||
|
|
||||||
|
@ -61,8 +60,8 @@ class Auth:
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.state = hs.get_state_handler()
|
|
||||||
self._account_validity_handler = hs.get_account_validity_handler()
|
self._account_validity_handler = hs.get_account_validity_handler()
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
|
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
|
||||||
10000, "token_cache"
|
10000, "token_cache"
|
||||||
|
@ -79,9 +78,8 @@ class Auth:
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
current_state: Optional[StateMap[EventBase]] = None,
|
|
||||||
allow_departed_users: bool = False,
|
allow_departed_users: bool = False,
|
||||||
) -> EventBase:
|
) -> Tuple[str, Optional[str]]:
|
||||||
"""Check if the user is in the room, or was at some point.
|
"""Check if the user is in the room, or was at some point.
|
||||||
Args:
|
Args:
|
||||||
room_id: The room to check.
|
room_id: The room to check.
|
||||||
|
@ -99,29 +97,28 @@ class Auth:
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if the user is/was not in the room.
|
AuthError if the user is/was not in the room.
|
||||||
Returns:
|
Returns:
|
||||||
Membership event for the user if the user was in the
|
The current membership of the user in the room and the
|
||||||
room. This will be the join event if they are currently joined to
|
membership event ID of the user.
|
||||||
the room. This will be the leave event if they have left the room.
|
|
||||||
"""
|
"""
|
||||||
if current_state:
|
|
||||||
member = current_state.get((EventTypes.Member, user_id), None)
|
(
|
||||||
else:
|
membership,
|
||||||
member = await self.state.get_current_state(
|
member_event_id,
|
||||||
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
) = await self.store.get_local_current_membership_for_user_in_room(
|
||||||
|
user_id=user_id,
|
||||||
|
room_id=room_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if member:
|
if membership:
|
||||||
membership = member.membership
|
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
return member
|
return membership, member_event_id
|
||||||
|
|
||||||
# XXX this looks totally bogus. Why do we not allow users who have been banned,
|
# XXX this looks totally bogus. Why do we not allow users who have been banned,
|
||||||
# or those who were members previously and have been re-invited?
|
# or those who were members previously and have been re-invited?
|
||||||
if allow_departed_users and membership == Membership.LEAVE:
|
if allow_departed_users and membership == Membership.LEAVE:
|
||||||
forgot = await self.store.did_forget(user_id, room_id)
|
forgot = await self.store.did_forget(user_id, room_id)
|
||||||
if not forgot:
|
if not forgot:
|
||||||
return member
|
return membership, member_event_id
|
||||||
|
|
||||||
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
|
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
|
||||||
|
|
||||||
|
@ -602,9 +599,12 @@ class Auth:
|
||||||
# We currently require the user is a "moderator" in the room. We do this
|
# We currently require the user is a "moderator" in the room. We do this
|
||||||
# by checking if they would (theoretically) be able to change the
|
# by checking if they would (theoretically) be able to change the
|
||||||
# m.room.canonical_alias events
|
# m.room.canonical_alias events
|
||||||
power_level_event = await self.state.get_current_state(
|
|
||||||
|
power_level_event = (
|
||||||
|
await self._storage_controllers.state.get_current_state_event(
|
||||||
room_id, EventTypes.PowerLevels, ""
|
room_id, EventTypes.PowerLevels, ""
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
auth_events = {}
|
auth_events = {}
|
||||||
if power_level_event:
|
if power_level_event:
|
||||||
|
@ -693,12 +693,11 @@ class Auth:
|
||||||
# * The user is a non-guest user, and was ever in the room
|
# * The user is a non-guest user, and was ever in the room
|
||||||
# * The user is a guest user, and has joined the room
|
# * The user is a guest user, and has joined the room
|
||||||
# else it will throw.
|
# else it will throw.
|
||||||
member_event = await self.check_user_in_room(
|
return await self.check_user_in_room(
|
||||||
room_id, user_id, allow_departed_users=allow_departed_users
|
room_id, user_id, allow_departed_users=allow_departed_users
|
||||||
)
|
)
|
||||||
return member_event.membership, member_event.event_id
|
|
||||||
except AuthError:
|
except AuthError:
|
||||||
visibility = await self.state.get_current_state(
|
visibility = await self._storage_controllers.state.get_current_state_event(
|
||||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -53,6 +53,7 @@ class FederationBase:
|
||||||
self.spam_checker = hs.get_spam_checker()
|
self.spam_checker = hs.get_spam_checker()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
async def _check_sigs_and_hash(
|
async def _check_sigs_and_hash(
|
||||||
self, room_version: RoomVersion, pdu: EventBase
|
self, room_version: RoomVersion, pdu: EventBase
|
||||||
|
|
|
@ -1223,14 +1223,10 @@ class FederationServer(FederationBase):
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if the server does not match the ACL
|
AuthError if the server does not match the ACL
|
||||||
"""
|
"""
|
||||||
state_ids = await self._state_storage_controller.get_current_state_ids(room_id)
|
acl_event = await self._storage_controllers.state.get_current_state_event(
|
||||||
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
|
room_id, EventTypes.ServerACL, ""
|
||||||
|
)
|
||||||
if not acl_event_id:
|
if not acl_event or server_matches_acl_event(server_name, acl_event):
|
||||||
return
|
|
||||||
|
|
||||||
acl_event = await self.store.get_event(acl_event_id)
|
|
||||||
if server_matches_acl_event(server_name, acl_event):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
raise AuthError(code=403, msg="Server is banned from room")
|
raise AuthError(code=403, msg="Server is banned from room")
|
||||||
|
|
|
@ -320,7 +320,7 @@ class DirectoryHandler:
|
||||||
Raises:
|
Raises:
|
||||||
ShadowBanError if the requester has been shadow-banned.
|
ShadowBanError if the requester has been shadow-banned.
|
||||||
"""
|
"""
|
||||||
alias_event = await self.state.get_current_state(
|
alias_event = await self._storage_controllers.state.get_current_state_event(
|
||||||
room_id, EventTypes.CanonicalAlias, ""
|
room_id, EventTypes.CanonicalAlias, ""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -371,7 +371,7 @@ class FederationHandler:
|
||||||
# First we try hosts that are already in the room
|
# First we try hosts that are already in the room
|
||||||
# TODO: HEURISTIC ALERT.
|
# TODO: HEURISTIC ALERT.
|
||||||
|
|
||||||
curr_state = await self.state_handler.get_current_state(room_id)
|
curr_state = await self._storage_controllers.state.get_current_state(room_id)
|
||||||
|
|
||||||
curr_domains = get_domains_from_state(curr_state)
|
curr_domains = get_domains_from_state(curr_state)
|
||||||
|
|
||||||
|
|
|
@ -1584,9 +1584,11 @@ class FederationEventHandler:
|
||||||
if guest_access == GuestAccess.CAN_JOIN:
|
if guest_access == GuestAccess.CAN_JOIN:
|
||||||
return
|
return
|
||||||
|
|
||||||
current_state_map = await self._state_handler.get_current_state(event.room_id)
|
current_state = await self._storage_controllers.state.get_current_state(
|
||||||
current_state = list(current_state_map.values())
|
event.room_id
|
||||||
await self._get_room_member_handler().kick_guest_users(current_state)
|
)
|
||||||
|
current_state_list = list(current_state.values())
|
||||||
|
await self._get_room_member_handler().kick_guest_users(current_state_list)
|
||||||
|
|
||||||
async def _check_for_soft_fail(
|
async def _check_for_soft_fail(
|
||||||
self,
|
self,
|
||||||
|
@ -1614,6 +1616,9 @@ class FederationEventHandler:
|
||||||
room_version = await self._store.get_room_version_id(event.room_id)
|
room_version = await self._store.get_room_version_id(event.room_id)
|
||||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||||
|
|
||||||
|
# The event types we want to pull from the "current" state.
|
||||||
|
auth_types = auth_types_for_event(room_version_obj, event)
|
||||||
|
|
||||||
# Calculate the "current state".
|
# Calculate the "current state".
|
||||||
if state_ids is not None:
|
if state_ids is not None:
|
||||||
# If we're explicitly given the state then we won't have all the
|
# If we're explicitly given the state then we won't have all the
|
||||||
|
@ -1643,8 +1648,10 @@ class FederationEventHandler:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
current_state_ids = await self._state_handler.get_current_state_ids(
|
current_state_ids = (
|
||||||
event.room_id, latest_event_ids=extrem_ids
|
await self._state_storage_controller.get_current_state_ids(
|
||||||
|
event.room_id, StateFilter.from_types(auth_types)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -1654,7 +1661,6 @@ class FederationEventHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now check if event pass auth against said current state
|
# Now check if event pass auth against said current state
|
||||||
auth_types = auth_types_for_event(room_version_obj, event)
|
|
||||||
current_state_ids_list = [
|
current_state_ids_list = [
|
||||||
e for k, e in current_state_ids.items() if k in auth_types
|
e for k, e in current_state_ids.items() if k in auth_types
|
||||||
]
|
]
|
||||||
|
|
|
@ -190,7 +190,7 @@ class InitialSyncHandler:
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
room_end_token = now_token.room_key
|
room_end_token = now_token.room_key
|
||||||
deferred_room_state = run_in_background(
|
deferred_room_state = run_in_background(
|
||||||
self.state_handler.get_current_state, event.room_id
|
self._state_storage_controller.get_current_state, event.room_id
|
||||||
)
|
)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
room_end_token = RoomStreamToken(
|
room_end_token = RoomStreamToken(
|
||||||
|
@ -407,7 +407,9 @@ class InitialSyncHandler:
|
||||||
membership: str,
|
membership: str,
|
||||||
is_peeking: bool,
|
is_peeking: bool,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
current_state = await self.state.get_current_state(room_id=room_id)
|
current_state = await self._storage_controllers.state.get_current_state(
|
||||||
|
room_id=room_id
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: These concurrently
|
# TODO: These concurrently
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
|
@ -125,7 +125,9 @@ class MessageHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
data = await self.state.get_current_state(room_id, event_type, state_key)
|
data = await self._storage_controllers.state.get_current_state_event(
|
||||||
|
room_id, event_type, state_key
|
||||||
|
)
|
||||||
elif membership == Membership.LEAVE:
|
elif membership == Membership.LEAVE:
|
||||||
key = (event_type, state_key)
|
key = (event_type, state_key)
|
||||||
# If the membership is not JOIN, then the event ID should exist.
|
# If the membership is not JOIN, then the event ID should exist.
|
||||||
|
|
|
@ -1333,6 +1333,7 @@ class TimestampLookupHandler:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
self.federation_client = hs.get_federation_client()
|
self.federation_client = hs.get_federation_client()
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
async def get_event_for_timestamp(
|
async def get_event_for_timestamp(
|
||||||
self,
|
self,
|
||||||
|
@ -1406,7 +1407,9 @@ class TimestampLookupHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find other homeservers from the given state in the room
|
# Find other homeservers from the given state in the room
|
||||||
curr_state = await self.state_handler.get_current_state(room_id)
|
curr_state = await self._storage_controllers.state.get_current_state(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
curr_domains = get_domains_from_state(curr_state)
|
curr_domains = get_domains_from_state(curr_state)
|
||||||
likely_domains = [
|
likely_domains = [
|
||||||
domain for domain, depth in curr_domains if domain != self.server_name
|
domain for domain, depth in curr_domains if domain != self.server_name
|
||||||
|
|
|
@ -1401,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
txn_id: Optional[str],
|
txn_id: Optional[str],
|
||||||
id_access_token: Optional[str] = None,
|
id_access_token: Optional[str] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
room_state = await self.state_handler.get_current_state(room_id)
|
room_state = await self._storage_controllers.state.get_current_state(
|
||||||
|
room_id,
|
||||||
|
StateFilter.from_types(
|
||||||
|
[
|
||||||
|
(EventTypes.Member, user.to_string()),
|
||||||
|
(EventTypes.CanonicalAlias, ""),
|
||||||
|
(EventTypes.Name, ""),
|
||||||
|
(EventTypes.Create, ""),
|
||||||
|
(EventTypes.JoinRules, ""),
|
||||||
|
(EventTypes.RoomAvatar, ""),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
inviter_display_name = ""
|
inviter_display_name = ""
|
||||||
inviter_avatar_url = ""
|
inviter_avatar_url = ""
|
||||||
|
@ -1797,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
||||||
async def forget(self, user: UserID, room_id: str) -> None:
|
async def forget(self, user: UserID, room_id: str) -> None:
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
member = await self.state_handler.get_current_state(
|
member = await self._storage_controllers.state.get_current_state_event(
|
||||||
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
||||||
)
|
)
|
||||||
membership = member.membership if member else None
|
membership = member.membership if member else None
|
||||||
|
|
|
@ -348,7 +348,7 @@ class SearchHandler:
|
||||||
state_results = {}
|
state_results = {}
|
||||||
if include_state:
|
if include_state:
|
||||||
for room_id in {e.room_id for e in search_result.allowed_events}:
|
for room_id in {e.room_id for e in search_result.allowed_events}:
|
||||||
state = await self.state_handler.get_current_state(room_id)
|
state = await self._storage_controllers.state.get_current_state(room_id)
|
||||||
state_results[room_id] = list(state.values())
|
state_results[room_id] = list(state.values())
|
||||||
|
|
||||||
aggregations = await self._relations_handler.get_bundled_aggregations(
|
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||||
|
|
|
@ -681,7 +681,7 @@ class Notifier:
|
||||||
return joined_room_ids, True
|
return joined_room_ids, True
|
||||||
|
|
||||||
async def _is_world_readable(self, room_id: str) -> bool:
|
async def _is_world_readable(self, room_id: str) -> bool:
|
||||||
state = await self.state_handler.get_current_state(
|
state = await self._storage_controllers.state.get_current_state_event(
|
||||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||||
)
|
)
|
||||||
if state and "history_visibility" in state.content:
|
if state and "history_visibility" in state.content:
|
||||||
|
|
|
@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
|
||||||
assert_user_is_admin,
|
assert_user_is_admin,
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.room import RoomSortOrder
|
from synapse.storage.databases.main.room import RoomSortOrder
|
||||||
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import JsonDict, RoomID, UserID, create_requester
|
from synapse.types import JsonDict, RoomID, UserID, create_requester
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
|
||||||
|
@ -448,7 +449,8 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.admin_handler = hs.get_admin_handler()
|
self.admin_handler = hs.get_admin_handler()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
|
|
||||||
async def on_POST(
|
async def on_POST(
|
||||||
|
@ -490,8 +492,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
# send invite if room has "JoinRules.INVITE"
|
# send invite if room has "JoinRules.INVITE"
|
||||||
room_state = await self.state_handler.get_current_state(room_id)
|
join_rules_event = (
|
||||||
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
|
await self._storage_controllers.state.get_current_state_event(
|
||||||
|
room_id, EventTypes.JoinRules, ""
|
||||||
|
)
|
||||||
|
)
|
||||||
if join_rules_event:
|
if join_rules_event:
|
||||||
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
|
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
|
||||||
# update_membership with an action of "invite" can raise a
|
# update_membership with an action of "invite" can raise a
|
||||||
|
@ -536,6 +541,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._state_storage_controller = hs.get_storage_controllers().state
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
@ -553,12 +559,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||||
user_to_add = content.get("user_id", requester.user.to_string())
|
user_to_add = content.get("user_id", requester.user.to_string())
|
||||||
|
|
||||||
# Figure out which local users currently have power in the room, if any.
|
# Figure out which local users currently have power in the room, if any.
|
||||||
room_state = await self.state_handler.get_current_state(room_id)
|
filtered_room_state = await self._state_storage_controller.get_current_state(
|
||||||
if not room_state:
|
room_id,
|
||||||
|
StateFilter.from_types(
|
||||||
|
[
|
||||||
|
(EventTypes.Create, ""),
|
||||||
|
(EventTypes.PowerLevels, ""),
|
||||||
|
(EventTypes.JoinRules, ""),
|
||||||
|
(EventTypes.Member, user_to_add),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not filtered_room_state:
|
||||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
|
raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
|
||||||
|
|
||||||
create_event = room_state[(EventTypes.Create, "")]
|
create_event = filtered_room_state[(EventTypes.Create, "")]
|
||||||
power_levels = room_state.get((EventTypes.PowerLevels, ""))
|
power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
|
||||||
|
|
||||||
if power_levels is not None:
|
if power_levels is not None:
|
||||||
# We pick the local user with the highest power.
|
# We pick the local user with the highest power.
|
||||||
|
@ -634,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||||
|
|
||||||
# Now we check if the user we're granting admin rights to is already in
|
# Now we check if the user we're granting admin rights to is already in
|
||||||
# the room. If not and it's not a public room we invite them.
|
# the room. If not and it's not a public room we invite them.
|
||||||
member_event = room_state.get((EventTypes.Member, user_to_add))
|
member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
|
||||||
is_joined = False
|
is_joined = False
|
||||||
if member_event:
|
if member_event:
|
||||||
is_joined = member_event.content["membership"] in (
|
is_joined = member_event.content["membership"] in (
|
||||||
|
@ -645,7 +661,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||||
if is_joined:
|
if is_joined:
|
||||||
return HTTPStatus.OK, {}
|
return HTTPStatus.OK, {}
|
||||||
|
|
||||||
join_rules = room_state.get((EventTypes.JoinRules, ""))
|
join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
|
||||||
is_public = False
|
is_public = False
|
||||||
if join_rules:
|
if join_rules:
|
||||||
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
|
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
|
||||||
|
|
|
@ -650,6 +650,7 @@ class RoomEventServlet(RestServlet):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
self._state = hs.get_state_handler()
|
self._state = hs.get_state_handler()
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.event_handler = hs.get_event_handler()
|
self.event_handler = hs.get_event_handler()
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
self._relations_handler = hs.get_relations_handler()
|
self._relations_handler = hs.get_relations_handler()
|
||||||
|
@ -673,9 +674,11 @@ class RoomEventServlet(RestServlet):
|
||||||
if include_unredacted_content and not await self.auth.is_server_admin(
|
if include_unredacted_content and not await self.auth.is_server_admin(
|
||||||
requester.user
|
requester.user
|
||||||
):
|
):
|
||||||
power_level_event = await self._state.get_current_state(
|
power_level_event = (
|
||||||
|
await self._storage_controllers.state.get_current_state_event(
|
||||||
room_id, EventTypes.PowerLevels, ""
|
room_id, EventTypes.PowerLevels, ""
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
auth_events = {}
|
auth_events = {}
|
||||||
if power_level_event:
|
if power_level_event:
|
||||||
|
|
|
@ -36,6 +36,7 @@ class ResourceLimitsServerNotices:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self._server_notices_manager = hs.get_server_notices_manager()
|
self._server_notices_manager = hs.get_server_notices_manager()
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self._auth = hs.get_auth()
|
self._auth = hs.get_auth()
|
||||||
self._config = hs.config
|
self._config = hs.config
|
||||||
self._resouce_limited = False
|
self._resouce_limited = False
|
||||||
|
@ -178,8 +179,10 @@ class ResourceLimitsServerNotices:
|
||||||
currently_blocked = False
|
currently_blocked = False
|
||||||
pinned_state_event = None
|
pinned_state_event = None
|
||||||
try:
|
try:
|
||||||
pinned_state_event = await self._state.get_current_state(
|
pinned_state_event = (
|
||||||
room_id, event_type=EventTypes.Pinned
|
await self._storage_controllers.state.get_current_state_event(
|
||||||
|
room_id, event_type=EventTypes.Pinned, state_key=""
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except AuthError:
|
except AuthError:
|
||||||
# The user has yet to join the server notices room
|
# The user has yet to join the server notices room
|
||||||
|
|
|
@ -32,13 +32,11 @@ from typing import (
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
overload,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
from prometheus_client import Counter, Histogram
|
from prometheus_client import Counter, Histogram
|
||||||
from typing_extensions import Literal
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
|
||||||
|
@ -132,85 +130,20 @@ class StateHandler:
|
||||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||||
self._storage_controllers = hs.get_storage_controllers()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
@overload
|
|
||||||
async def get_current_state(
|
|
||||||
self,
|
|
||||||
room_id: str,
|
|
||||||
event_type: Literal[None] = None,
|
|
||||||
state_key: str = "",
|
|
||||||
latest_event_ids: Optional[List[str]] = None,
|
|
||||||
) -> StateMap[EventBase]:
|
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
async def get_current_state(
|
|
||||||
self,
|
|
||||||
room_id: str,
|
|
||||||
event_type: str,
|
|
||||||
state_key: str = "",
|
|
||||||
latest_event_ids: Optional[List[str]] = None,
|
|
||||||
) -> Optional[EventBase]:
|
|
||||||
...
|
|
||||||
|
|
||||||
async def get_current_state(
|
|
||||||
self,
|
|
||||||
room_id: str,
|
|
||||||
event_type: Optional[str] = None,
|
|
||||||
state_key: str = "",
|
|
||||||
latest_event_ids: Optional[List[str]] = None,
|
|
||||||
) -> Union[Optional[EventBase], StateMap[EventBase]]:
|
|
||||||
"""Retrieves the current state for the room. This is done by
|
|
||||||
calling `get_latest_events_in_room` to get the leading edges of the
|
|
||||||
event graph and then resolving any of the state conflicts.
|
|
||||||
|
|
||||||
This is equivalent to getting the state of an event that were to send
|
|
||||||
next before receiving any new events.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
If `event_type` is specified, then the method returns only the one
|
|
||||||
event (or None) with that `event_type` and `state_key`.
|
|
||||||
|
|
||||||
Otherwise, a map from (type, state_key) to event.
|
|
||||||
"""
|
|
||||||
if not latest_event_ids:
|
|
||||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
|
||||||
assert latest_event_ids is not None
|
|
||||||
|
|
||||||
logger.debug("calling resolve_state_groups from get_current_state")
|
|
||||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
|
||||||
state = ret.state
|
|
||||||
|
|
||||||
if event_type:
|
|
||||||
event_id = state.get((event_type, state_key))
|
|
||||||
event = None
|
|
||||||
if event_id:
|
|
||||||
event = await self.store.get_event(event_id, allow_none=True)
|
|
||||||
return event
|
|
||||||
|
|
||||||
state_map = await self.store.get_events(
|
|
||||||
list(state.values()), get_prev_content=False
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
|
|
||||||
}
|
|
||||||
|
|
||||||
async def get_current_state_ids(
|
async def get_current_state_ids(
|
||||||
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
|
self,
|
||||||
|
room_id: str,
|
||||||
|
latest_event_ids: Collection[str],
|
||||||
) -> StateMap[str]:
|
) -> StateMap[str]:
|
||||||
"""Get the current state, or the state at a set of events, for a room
|
"""Get the current state, or the state at a set of events, for a room
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id:
|
room_id:
|
||||||
latest_event_ids: if given, the forward extremities to resolve. If
|
latest_event_ids: The forward extremities to resolve.
|
||||||
None, we look them up from the database (via a cache).
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
the state dict, mapping from (event_type, state_key) -> event_id
|
the state dict, mapping from (event_type, state_key) -> event_id
|
||||||
"""
|
"""
|
||||||
if not latest_event_ids:
|
|
||||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
|
||||||
assert latest_event_ids is not None
|
|
||||||
|
|
||||||
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
||||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||||
return ret.state
|
return ret.state
|
||||||
|
|
|
@ -455,3 +455,30 @@ class StateStorageController:
|
||||||
return await self.stores.main.get_partial_current_state_deltas(
|
return await self.stores.main.get_partial_current_state_deltas(
|
||||||
prev_stream_id, max_stream_id
|
prev_stream_id, max_stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_current_state(
|
||||||
|
self, room_id: str, state_filter: Optional[StateFilter] = None
|
||||||
|
) -> StateMap[EventBase]:
|
||||||
|
"""Same as `get_current_state_ids` but also fetches the events"""
|
||||||
|
state_map_ids = await self.get_current_state_ids(room_id, state_filter)
|
||||||
|
|
||||||
|
event_map = await self.stores.main.get_events(list(state_map_ids.values()))
|
||||||
|
|
||||||
|
state_map = {}
|
||||||
|
for key, event_id in state_map_ids.items():
|
||||||
|
event = event_map.get(event_id)
|
||||||
|
if event:
|
||||||
|
state_map[key] = event
|
||||||
|
|
||||||
|
return state_map
|
||||||
|
|
||||||
|
async def get_current_state_event(
|
||||||
|
self, room_id: str, event_type: str, state_key: str
|
||||||
|
) -> Optional[EventBase]:
|
||||||
|
"""Get the current state event for the given type/state_key."""
|
||||||
|
|
||||||
|
key = (event_type, state_key)
|
||||||
|
state_map = await self.get_current_state(
|
||||||
|
room_id, StateFilter.from_types((key,))
|
||||||
|
)
|
||||||
|
return state_map.get(key)
|
||||||
|
|
|
@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
||||||
super().prepare(reactor, clock, hs)
|
super().prepare(reactor, clock, hs)
|
||||||
|
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
# create the room
|
# create the room
|
||||||
creator_user_id = self.register_user("kermit", "test")
|
creator_user_id = self.register_user("kermit", "test")
|
||||||
tok = self.login("kermit", "test")
|
tok = self.login("kermit", "test")
|
||||||
|
@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
# the room should show that the new user is a member
|
# the room should show that the new user is a member
|
||||||
r = self.get_success(
|
r = self.get_success(
|
||||||
self.hs.get_state_handler().get_current_state(self._room_id)
|
self._storage_controllers.state.get_current_state(self._room_id)
|
||||||
)
|
)
|
||||||
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
|
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
|
||||||
|
|
||||||
|
@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
# the room should show that the new user is a member
|
# the room should show that the new user is a member
|
||||||
r = self.get_success(
|
r = self.get_success(
|
||||||
self.hs.get_state_handler().get_current_state(self._room_id)
|
self._storage_controllers.state.get_current_state(self._room_id)
|
||||||
)
|
)
|
||||||
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
|
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
|
||||||
|
|
||||||
|
|
|
@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.handler = hs.get_directory_handler()
|
self.handler = hs.get_directory_handler()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
# Create user
|
# Create user
|
||||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||||
|
@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
def _get_canonical_alias(self):
|
def _get_canonical_alias(self):
|
||||||
"""Get the canonical alias state of the room."""
|
"""Get the canonical alias state of the room."""
|
||||||
return self.get_success(
|
return self.get_success(
|
||||||
self.state_handler.get_current_state(
|
self._storage_controllers.state.get_current_state_event(
|
||||||
self.room_id, EventTypes.CanonicalAlias, ""
|
self.room_id, EventTypes.CanonicalAlias, ""
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -32,6 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor, clock, homeserver):
|
||||||
self.state = self.hs.get_state_handler()
|
self.state = self.hs.get_state_handler()
|
||||||
self._persistence = self.hs.get_storage_controllers().persistence
|
self._persistence = self.hs.get_storage_controllers().persistence
|
||||||
|
self._state_storage_controller = self.hs.get_storage_controllers().state
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
self.register_user("user", "pass")
|
self.register_user("user", "pass")
|
||||||
|
@ -104,7 +105,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(
|
state_before_gap = self.get_success(
|
||||||
self.state.get_current_state_ids(self.room_id)
|
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap)
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
@ -137,7 +138,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
# setting. The state resolution across the old and new event will then
|
# setting. The state resolution across the old and new event will then
|
||||||
# include it, and so the resolved state won't match the new state.
|
# include it, and so the resolved state won't match the new state.
|
||||||
state_before_gap = dict(
|
state_before_gap = dict(
|
||||||
self.get_success(self.state.get_current_state_ids(self.room_id))
|
self.get_success(
|
||||||
|
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
state_before_gap.pop(("m.room.history_visibility", ""))
|
state_before_gap.pop(("m.room.history_visibility", ""))
|
||||||
|
|
||||||
|
@ -181,7 +184,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(
|
state_before_gap = self.get_success(
|
||||||
self.state.get_current_state_ids(self.room_id)
|
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap)
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
@ -213,7 +216,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(
|
state_before_gap = self.get_success(
|
||||||
self.state.get_current_state_ids(self.room_id)
|
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap)
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
@ -255,7 +258,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(
|
state_before_gap = self.get_success(
|
||||||
self.state.get_current_state_ids(self.room_id)
|
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap)
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
@ -299,7 +302,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(
|
state_before_gap = self.get_success(
|
||||||
self.state.get_current_state_ids(self.room_id)
|
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap)
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
@ -335,7 +338,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
state_before_gap = self.get_success(
|
state_before_gap = self.get_success(
|
||||||
self.state.get_current_state_ids(self.room_id)
|
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.persist_event(remote_event_2, state=state_before_gap)
|
self.persist_event(remote_event_2, state=state_before_gap)
|
||||||
|
|
|
@ -102,9 +102,10 @@ class PurgeTests(HomeserverTestCase):
|
||||||
first = self.helper.send(self.room_id, body="test1")
|
first = self.helper.send(self.room_id, body="test1")
|
||||||
|
|
||||||
# Get the current room state.
|
# Get the current room state.
|
||||||
state_handler = self.hs.get_state_handler()
|
|
||||||
create_event = self.get_success(
|
create_event = self.get_success(
|
||||||
state_handler.get_current_state(self.room_id, "m.room.create", "")
|
self._storage_controllers.state.get_current_state_event(
|
||||||
|
self.room_id, "m.room.create", ""
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(create_event)
|
self.assertIsNotNone(create_event)
|
||||||
|
|
||||||
|
|
|
@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
||||||
# Room events need the full datastore, for persist_event() and
|
# Room events need the full datastore, for persist_event() and
|
||||||
# get_room_state()
|
# get_room_state()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self._storage = hs.get_storage_controllers()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.event_factory = hs.get_event_factory()
|
self.event_factory = hs.get_event_factory()
|
||||||
|
|
||||||
self.room = RoomID.from_string("!abcde:test")
|
self.room = RoomID.from_string("!abcde:test")
|
||||||
|
@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
def inject_room_event(self, **kwargs):
|
def inject_room_event(self, **kwargs):
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self._storage.persistence.persist_event(
|
self._storage_controllers.persistence.persist_event(
|
||||||
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
|
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
state = self.get_success(
|
state = self.get_success(
|
||||||
self.store.get_current_state(room_id=self.room.to_string())
|
self._storage_controllers.state.get_current_state(
|
||||||
|
room_id=self.room.to_string()
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(1, len(state))
|
self.assertEqual(1, len(state))
|
||||||
|
@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
state = self.get_success(
|
state = self.get_success(
|
||||||
self.store.get_current_state(room_id=self.room.to_string())
|
self._storage_controllers.state.get_current_state(
|
||||||
|
room_id=self.room.to_string()
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(1, len(state))
|
self.assertEqual(1, len(state))
|
||||||
|
|
Loading…
Reference in a new issue