0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-21 01:24:04 +01:00

Use state handler instead of get_users_in_room/get_joined_hosts

This commit is contained in:
Erik Johnston 2016-08-26 14:54:30 +01:00
parent 3cf15edef7
commit bed10f9880
12 changed files with 44 additions and 27 deletions

View file

@ -29,6 +29,7 @@ from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.types import get_domain_from_id
import synapse.metrics import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@ -63,6 +64,7 @@ class FederationClient(FederationBase):
self._clock.looping_call( self._clock.looping_call(
self._clear_tried_cache, 60 * 1000, self._clear_tried_cache, 60 * 1000,
) )
self.state = hs.get_state_handler()
def _clear_tried_cache(self): def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache""" """Clear pdu_destination_tried cache"""
@ -811,7 +813,8 @@ class FederationClient(FederationBase):
if len(signed_events) >= limit: if len(signed_events) >= limit:
defer.returnValue(signed_events) defer.returnValue(signed_events)
servers = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
servers = set(get_domain_from_id(u) for u in users)
servers = set(servers) servers = set(servers)
servers.discard(self.server_name) servers.discard(self.server_name)

View file

@ -19,7 +19,7 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomAlias, UserID from synapse.types import RoomAlias, UserID, get_domain_from_id
import logging import logging
import string import string
@ -55,7 +55,8 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions. # TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association. # TODO(erikj): Check if there is a current association.
if not servers: if not servers:
servers = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
servers = set(get_domain_from_id(u) for u in users)
if not servers: if not servers:
raise SynapseError(400, "Failed to get server list") raise SynapseError(400, "Failed to get server list")
@ -193,7 +194,8 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND Codes.NOT_FOUND
) )
extra_servers = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
extra_servers = set(get_domain_from_id(u) for u in users)
servers = set(extra_servers) | set(servers) servers = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first. # If this server is in the list of servers, return it first.

View file

@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler):
# Send down presence. # Send down presence.
if event.state_key == auth_user_id: if event.state_key == auth_user_id:
# Send down presence for everyone in the room. # Send down presence for everyone in the room.
users = yield self.store.get_users_in_room(event.room_id) users = yield self.state.get_current_user_in_room(event.room_id)
states = yield presence_handler.get_states( states = yield presence_handler.get_states(
users, users,
as_event=True, as_event=True,

View file

@ -88,6 +88,8 @@ class PresenceHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.state = hs.get_state_handler()
self.federation.register_edu_handler( self.federation.register_edu_handler(
"m.presence", self.incoming_presence "m.presence", self.incoming_presence
) )
@ -532,7 +534,9 @@ class PresenceHandler(object):
if not local_states: if not local_states:
continue continue
hosts = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
hosts = set(get_domain_from_id(u) for u in users)
for host in hosts: for host in hosts:
hosts_to_states.setdefault(host, []).extend(local_states) hosts_to_states.setdefault(host, []).extend(local_states)
@ -725,13 +729,13 @@ class PresenceHandler(object):
# don't need to send to local clients here, as that is done as part # don't need to send to local clients here, as that is done as part
# of the event stream/sync. # of the event stream/sync.
# TODO: Only send to servers not already in the room. # TODO: Only send to servers not already in the room.
user_ids = yield self.state.get_current_user_in_room(room_id)
if self.is_mine(user): if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string()) state = yield self.current_state_for_user(user.to_string())
hosts = yield self.store.get_joined_hosts_for_room(room_id) hosts = set(get_domain_from_id(u) for u in user_ids)
self._push_to_remotes({host: (state,) for host in hosts}) self._push_to_remotes({host: (state,) for host in hosts})
else: else:
user_ids = yield self.store.get_users_in_room(room_id)
user_ids = filter(self.is_mine_id, user_ids) user_ids = filter(self.is_mine_id, user_ids)
states = yield self.current_state_for_users(user_ids) states = yield self.current_state_for_users(user_ids)
@ -955,6 +959,7 @@ class PresenceEventSource(object):
self.get_presence_handler = hs.get_presence_handler self.get_presence_handler = hs.get_presence_handler
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -1017,7 +1022,7 @@ class PresenceEventSource(object):
user_ids_to_check = set() user_ids_to_check = set()
for room_id in room_ids: for room_id in room_ids:
users = yield self.store.get_users_in_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
user_ids_to_check.update(users) user_ids_to_check.update(users)
user_ids_to_check.update(friends) user_ids_to_check.update(friends)

View file

@ -18,6 +18,7 @@ from ._base import BaseHandler
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import get_domain_from_id
import logging import logging
@ -37,6 +38,7 @@ class ReceiptsHandler(BaseHandler):
"m.receipt", self._received_remote_receipt "m.receipt", self._received_remote_receipt
) )
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id, def received_client_receipt(self, room_id, receipt_type, user_id,
@ -133,7 +135,8 @@ class ReceiptsHandler(BaseHandler):
event_ids = receipt["event_ids"] event_ids = receipt["event_ids"]
data = receipt["data"] data = receipt["data"]
remotedomains = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
remotedomains = set(get_domain_from_id(u) for u in users)
remotedomains = remotedomains.copy() remotedomains = remotedomains.copy()
remotedomains.discard(self.server_name) remotedomains.discard(self.server_name)

View file

@ -139,6 +139,7 @@ class SyncHandler(object):
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.response_cache = ResponseCache(hs) self.response_cache = ResponseCache(hs)
self.state = hs.get_state_handler()
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False): full_state=False):
@ -630,7 +631,7 @@ class SyncHandler(object):
extra_users_ids = set(newly_joined_users) extra_users_ids = set(newly_joined_users)
for room_id in newly_joined_rooms: for room_id in newly_joined_rooms:
users = yield self.store.get_users_in_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
extra_users_ids.update(users) extra_users_ids.update(users)
extra_users_ids.discard(user.to_string()) extra_users_ids.discard(user.to_string())

View file

@ -20,7 +20,7 @@ from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
) )
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.types import UserID from synapse.types import UserID, get_domain_from_id
import logging import logging
@ -42,6 +42,7 @@ class TypingHandler(object):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -166,7 +167,8 @@ class TypingHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_update(self, room_id, user_id, typing): def _push_update(self, room_id, user_id, typing):
domains = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
deferreds = [] deferreds = []
for domain in domains: for domain in domains:
@ -199,7 +201,8 @@ class TypingHandler(object):
# Check that the string is a valid user id # Check that the string is a valid user id
UserID.from_string(user_id) UserID.from_string(user_id)
domains = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
if self.server_name in domains: if self.server_name in domains:
self._push_update_local( self._push_update_local(

View file

@ -87,7 +87,7 @@ class BulkPushRuleEvaluator:
) )
room_members = yield self.store.get_joined_users_from_context( room_members = yield self.store.get_joined_users_from_context(
event.room_id, context, event.room_id, context.state_group, context.current_state_ids
) )
evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) evaluator = PushRuleEvaluatorForEvent(event, len(room_members))

View file

@ -216,7 +216,6 @@ class SlavedEventStore(BaseSlavedStore):
self._get_current_state_for_key.invalidate_all() self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all() self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,)) self.get_users_in_room.invalidate((event.room_id,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id) self._invalidate_get_event_cache(event.event_id)
@ -240,7 +239,6 @@ class SlavedEventStore(BaseSlavedStore):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
self.get_rooms_for_user.invalidate((event.state_key,)) self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,)) self.get_users_in_room.invalidate((event.room_id,))
self._membership_stream_cache.entity_has_changed( self._membership_stream_cache.entity_has_changed(
event.state_key, event.internal_metadata.stream_ordering event.state_key, event.internal_metadata.stream_ordering

View file

@ -124,6 +124,15 @@ class StateHandler(object):
defer.returnValue(state) defer.returnValue(state)
@defer.inlineCallbacks
def get_current_user_in_room(self, room_id):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_context(
room_id, group, state_ids
)
defer.returnValue(joined_users)
@defer.inlineCallbacks @defer.inlineCallbacks
def compute_event_context(self, event, old_state=None): def compute_event_context(self, event, old_state=None):
""" Fills out the context with the `current state` of the graph. The """ Fills out the context with the `current state` of the graph. The

View file

@ -393,7 +393,6 @@ class EventsStore(SQLBaseStore):
txn.call_after(self._get_current_state_for_key.invalidate_all) txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point # Add an entry to the current_state_resets table to record the point
# where we clobbered the current state # where we clobbered the current state

View file

@ -56,7 +56,6 @@ class RoomMemberStore(SQLBaseStore):
for event in events: for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after( txn.call_after(
self._membership_stream_cache.entity_has_changed, self._membership_stream_cache.entity_has_changed,
@ -238,11 +237,6 @@ class RoomMemberStore(SQLBaseStore):
return results return results
@cachedInlineCallbacks(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):
user_ids = yield self.get_users_in_room(room_id)
defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?" where_clause = "c.room_id = ?"
where_values = [room_id] where_values = [room_id]
@ -360,8 +354,7 @@ class RoomMemberStore(SQLBaseStore):
desc="who_forgot" desc="who_forgot"
) )
def get_joined_users_from_context(self, room_id, context): def get_joined_users_from_context(self, room_id, state_group, state_ids):
state_group = context.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group # state group, i.e. we need to make sure that calls with a state_group
@ -370,7 +363,7 @@ class RoomMemberStore(SQLBaseStore):
state_group = object() state_group = object()
return self._get_joined_users_from_context( return self._get_joined_users_from_context(
room_id, state_group, context.current_state_ids room_id, state_group, state_ids
) )
@cachedInlineCallbacks(num_args=2, cache_context=True) @cachedInlineCallbacks(num_args=2, cache_context=True)