mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-21 01:24:04 +01:00
Use state handler instead of get_users_in_room/get_joined_hosts
This commit is contained in:
parent
3cf15edef7
commit
bed10f9880
12 changed files with 44 additions and 27 deletions
|
@ -29,6 +29,7 @@ from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
|
from synapse.types import get_domain_from_id
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
||||||
|
@ -63,6 +64,7 @@ class FederationClient(FederationBase):
|
||||||
self._clock.looping_call(
|
self._clock.looping_call(
|
||||||
self._clear_tried_cache, 60 * 1000,
|
self._clear_tried_cache, 60 * 1000,
|
||||||
)
|
)
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
def _clear_tried_cache(self):
|
def _clear_tried_cache(self):
|
||||||
"""Clear pdu_destination_tried cache"""
|
"""Clear pdu_destination_tried cache"""
|
||||||
|
@ -811,7 +813,8 @@ class FederationClient(FederationBase):
|
||||||
if len(signed_events) >= limit:
|
if len(signed_events) >= limit:
|
||||||
defer.returnValue(signed_events)
|
defer.returnValue(signed_events)
|
||||||
|
|
||||||
servers = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
servers = set(get_domain_from_id(u) for u in users)
|
||||||
|
|
||||||
servers = set(servers)
|
servers = set(servers)
|
||||||
servers.discard(self.server_name)
|
servers.discard(self.server_name)
|
||||||
|
|
|
@ -19,7 +19,7 @@ from ._base import BaseHandler
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
|
from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.types import RoomAlias, UserID
|
from synapse.types import RoomAlias, UserID, get_domain_from_id
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import string
|
import string
|
||||||
|
@ -55,7 +55,8 @@ class DirectoryHandler(BaseHandler):
|
||||||
# TODO(erikj): Add transactions.
|
# TODO(erikj): Add transactions.
|
||||||
# TODO(erikj): Check if there is a current association.
|
# TODO(erikj): Check if there is a current association.
|
||||||
if not servers:
|
if not servers:
|
||||||
servers = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
servers = set(get_domain_from_id(u) for u in users)
|
||||||
|
|
||||||
if not servers:
|
if not servers:
|
||||||
raise SynapseError(400, "Failed to get server list")
|
raise SynapseError(400, "Failed to get server list")
|
||||||
|
@ -193,7 +194,8 @@ class DirectoryHandler(BaseHandler):
|
||||||
Codes.NOT_FOUND
|
Codes.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
extra_servers = set(get_domain_from_id(u) for u in users)
|
||||||
servers = set(extra_servers) | set(servers)
|
servers = set(extra_servers) | set(servers)
|
||||||
|
|
||||||
# If this server is in the list of servers, return it first.
|
# If this server is in the list of servers, return it first.
|
||||||
|
|
|
@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler):
|
||||||
# Send down presence.
|
# Send down presence.
|
||||||
if event.state_key == auth_user_id:
|
if event.state_key == auth_user_id:
|
||||||
# Send down presence for everyone in the room.
|
# Send down presence for everyone in the room.
|
||||||
users = yield self.store.get_users_in_room(event.room_id)
|
users = yield self.state.get_current_user_in_room(event.room_id)
|
||||||
states = yield presence_handler.get_states(
|
states = yield presence_handler.get_states(
|
||||||
users,
|
users,
|
||||||
as_event=True,
|
as_event=True,
|
||||||
|
|
|
@ -88,6 +88,8 @@ class PresenceHandler(object):
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_replication_layer()
|
||||||
|
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
self.federation.register_edu_handler(
|
self.federation.register_edu_handler(
|
||||||
"m.presence", self.incoming_presence
|
"m.presence", self.incoming_presence
|
||||||
)
|
)
|
||||||
|
@ -532,7 +534,9 @@ class PresenceHandler(object):
|
||||||
if not local_states:
|
if not local_states:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
hosts = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
hosts = set(get_domain_from_id(u) for u in users)
|
||||||
|
|
||||||
for host in hosts:
|
for host in hosts:
|
||||||
hosts_to_states.setdefault(host, []).extend(local_states)
|
hosts_to_states.setdefault(host, []).extend(local_states)
|
||||||
|
|
||||||
|
@ -725,13 +729,13 @@ class PresenceHandler(object):
|
||||||
# don't need to send to local clients here, as that is done as part
|
# don't need to send to local clients here, as that is done as part
|
||||||
# of the event stream/sync.
|
# of the event stream/sync.
|
||||||
# TODO: Only send to servers not already in the room.
|
# TODO: Only send to servers not already in the room.
|
||||||
|
user_ids = yield self.state.get_current_user_in_room(room_id)
|
||||||
if self.is_mine(user):
|
if self.is_mine(user):
|
||||||
state = yield self.current_state_for_user(user.to_string())
|
state = yield self.current_state_for_user(user.to_string())
|
||||||
|
|
||||||
hosts = yield self.store.get_joined_hosts_for_room(room_id)
|
hosts = set(get_domain_from_id(u) for u in user_ids)
|
||||||
self._push_to_remotes({host: (state,) for host in hosts})
|
self._push_to_remotes({host: (state,) for host in hosts})
|
||||||
else:
|
else:
|
||||||
user_ids = yield self.store.get_users_in_room(room_id)
|
|
||||||
user_ids = filter(self.is_mine_id, user_ids)
|
user_ids = filter(self.is_mine_id, user_ids)
|
||||||
|
|
||||||
states = yield self.current_state_for_users(user_ids)
|
states = yield self.current_state_for_users(user_ids)
|
||||||
|
@ -955,6 +959,7 @@ class PresenceEventSource(object):
|
||||||
self.get_presence_handler = hs.get_presence_handler
|
self.get_presence_handler = hs.get_presence_handler
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -1017,7 +1022,7 @@ class PresenceEventSource(object):
|
||||||
|
|
||||||
user_ids_to_check = set()
|
user_ids_to_check = set()
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
users = yield self.store.get_users_in_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
user_ids_to_check.update(users)
|
user_ids_to_check.update(users)
|
||||||
|
|
||||||
user_ids_to_check.update(friends)
|
user_ids_to_check.update(friends)
|
||||||
|
|
|
@ -18,6 +18,7 @@ from ._base import BaseHandler
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
|
from synapse.types import get_domain_from_id
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -37,6 +38,7 @@ class ReceiptsHandler(BaseHandler):
|
||||||
"m.receipt", self._received_remote_receipt
|
"m.receipt", self._received_remote_receipt
|
||||||
)
|
)
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def received_client_receipt(self, room_id, receipt_type, user_id,
|
def received_client_receipt(self, room_id, receipt_type, user_id,
|
||||||
|
@ -133,7 +135,8 @@ class ReceiptsHandler(BaseHandler):
|
||||||
event_ids = receipt["event_ids"]
|
event_ids = receipt["event_ids"]
|
||||||
data = receipt["data"]
|
data = receipt["data"]
|
||||||
|
|
||||||
remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
remotedomains = set(get_domain_from_id(u) for u in users)
|
||||||
remotedomains = remotedomains.copy()
|
remotedomains = remotedomains.copy()
|
||||||
remotedomains.discard(self.server_name)
|
remotedomains.discard(self.server_name)
|
||||||
|
|
||||||
|
|
|
@ -139,6 +139,7 @@ class SyncHandler(object):
|
||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.response_cache = ResponseCache(hs)
|
self.response_cache = ResponseCache(hs)
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
|
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
|
||||||
full_state=False):
|
full_state=False):
|
||||||
|
@ -630,7 +631,7 @@ class SyncHandler(object):
|
||||||
|
|
||||||
extra_users_ids = set(newly_joined_users)
|
extra_users_ids = set(newly_joined_users)
|
||||||
for room_id in newly_joined_rooms:
|
for room_id in newly_joined_rooms:
|
||||||
users = yield self.store.get_users_in_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
extra_users_ids.update(users)
|
extra_users_ids.update(users)
|
||||||
extra_users_ids.discard(user.to_string())
|
extra_users_ids.discard(user.to_string())
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ from synapse.util.logcontext import (
|
||||||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
||||||
)
|
)
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID, get_domain_from_id
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -42,6 +42,7 @@ class TypingHandler(object):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@ -166,7 +167,8 @@ class TypingHandler(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _push_update(self, room_id, user_id, typing):
|
def _push_update(self, room_id, user_id, typing):
|
||||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
domains = set(get_domain_from_id(u) for u in users)
|
||||||
|
|
||||||
deferreds = []
|
deferreds = []
|
||||||
for domain in domains:
|
for domain in domains:
|
||||||
|
@ -199,7 +201,8 @@ class TypingHandler(object):
|
||||||
# Check that the string is a valid user id
|
# Check that the string is a valid user id
|
||||||
UserID.from_string(user_id)
|
UserID.from_string(user_id)
|
||||||
|
|
||||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
domains = set(get_domain_from_id(u) for u in users)
|
||||||
|
|
||||||
if self.server_name in domains:
|
if self.server_name in domains:
|
||||||
self._push_update_local(
|
self._push_update_local(
|
||||||
|
|
|
@ -87,7 +87,7 @@ class BulkPushRuleEvaluator:
|
||||||
)
|
)
|
||||||
|
|
||||||
room_members = yield self.store.get_joined_users_from_context(
|
room_members = yield self.store.get_joined_users_from_context(
|
||||||
event.room_id, context,
|
event.room_id, context.state_group, context.current_state_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
||||||
|
|
|
@ -216,7 +216,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
self._get_current_state_for_key.invalidate_all()
|
self._get_current_state_for_key.invalidate_all()
|
||||||
self.get_rooms_for_user.invalidate_all()
|
self.get_rooms_for_user.invalidate_all()
|
||||||
self.get_users_in_room.invalidate((event.room_id,))
|
self.get_users_in_room.invalidate((event.room_id,))
|
||||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
|
||||||
|
|
||||||
self._invalidate_get_event_cache(event.event_id)
|
self._invalidate_get_event_cache(event.event_id)
|
||||||
|
|
||||||
|
@ -240,7 +239,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
self.get_rooms_for_user.invalidate((event.state_key,))
|
self.get_rooms_for_user.invalidate((event.state_key,))
|
||||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
|
||||||
self.get_users_in_room.invalidate((event.room_id,))
|
self.get_users_in_room.invalidate((event.room_id,))
|
||||||
self._membership_stream_cache.entity_has_changed(
|
self._membership_stream_cache.entity_has_changed(
|
||||||
event.state_key, event.internal_metadata.stream_ordering
|
event.state_key, event.internal_metadata.stream_ordering
|
||||||
|
|
|
@ -124,6 +124,15 @@ class StateHandler(object):
|
||||||
|
|
||||||
defer.returnValue(state)
|
defer.returnValue(state)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_current_user_in_room(self, room_id):
|
||||||
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
|
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
|
joined_users = yield self.store.get_joined_users_from_context(
|
||||||
|
room_id, group, state_ids
|
||||||
|
)
|
||||||
|
defer.returnValue(joined_users)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def compute_event_context(self, event, old_state=None):
|
def compute_event_context(self, event, old_state=None):
|
||||||
""" Fills out the context with the `current state` of the graph. The
|
""" Fills out the context with the `current state` of the graph. The
|
||||||
|
|
|
@ -393,7 +393,6 @@ class EventsStore(SQLBaseStore):
|
||||||
txn.call_after(self._get_current_state_for_key.invalidate_all)
|
txn.call_after(self._get_current_state_for_key.invalidate_all)
|
||||||
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
||||||
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
|
||||||
|
|
||||||
# Add an entry to the current_state_resets table to record the point
|
# Add an entry to the current_state_resets table to record the point
|
||||||
# where we clobbered the current state
|
# where we clobbered the current state
|
||||||
|
|
|
@ -56,7 +56,6 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
|
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
|
||||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
|
||||||
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self._membership_stream_cache.entity_has_changed,
|
self._membership_stream_cache.entity_has_changed,
|
||||||
|
@ -238,11 +237,6 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=5000)
|
|
||||||
def get_joined_hosts_for_room(self, room_id):
|
|
||||||
user_ids = yield self.get_users_in_room(room_id)
|
|
||||||
defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))
|
|
||||||
|
|
||||||
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
|
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
|
||||||
where_clause = "c.room_id = ?"
|
where_clause = "c.room_id = ?"
|
||||||
where_values = [room_id]
|
where_values = [room_id]
|
||||||
|
@ -360,8 +354,7 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
desc="who_forgot"
|
desc="who_forgot"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_joined_users_from_context(self, room_id, context):
|
def get_joined_users_from_context(self, room_id, state_group, state_ids):
|
||||||
state_group = context.state_group
|
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
@ -370,7 +363,7 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
return self._get_joined_users_from_context(
|
return self._get_joined_users_from_context(
|
||||||
room_id, state_group, context.current_state_ids
|
room_id, state_group, state_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||||
|
|
Loading…
Reference in a new issue