0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-30 17:04:00 +01:00

Merge branch 'notifier_performance' into markjh/presence_performance

This commit is contained in:
Mark Haines 2015-05-18 14:33:58 +01:00
commit 880fb46de0
13 changed files with 295 additions and 160 deletions

View file

@ -105,7 +105,9 @@ class BaseHandler(object):
if not suppress_auth: if not suppress_auth:
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self.store.persist_event(event, context=context) (event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
federation_handler = self.hs.get_handlers().federation_handler federation_handler = self.hs.get_handlers().federation_handler
@ -142,7 +144,8 @@ class BaseHandler(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners. # Don't block waiting on waking up all the listeners.
notify_d = self.notifier.on_new_room_event( notify_d = self.notifier.on_new_room_event(
event, extra_users=extra_users event, event_stream_id, max_stream_id,
extra_users=extra_users
) )
def log_failure(f): def log_failure(f):

View file

@ -160,7 +160,7 @@ class FederationHandler(BaseHandler):
) )
try: try:
yield self._handle_new_event( _, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, origin,
event, event,
state=state, state=state,
@ -203,7 +203,8 @@ class FederationHandler(BaseHandler):
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( d = self.notifier.on_new_room_event(
event, extra_users=extra_users event, event_stream_id, max_stream_id,
extra_users=extra_users
) )
def log_failure(f): def log_failure(f):
@ -563,7 +564,7 @@ class FederationHandler(BaseHandler):
if e.event_id in auth_ids if e.event_id in auth_ids
} }
yield self._handle_new_event( _, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, origin,
new_event, new_event,
state=state, state=state,
@ -573,7 +574,8 @@ class FederationHandler(BaseHandler):
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( d = self.notifier.on_new_room_event(
new_event, extra_users=[joinee] new_event, event_stream_id, max_stream_id,
extra_users=[joinee]
) )
def log_failure(f): def log_failure(f):
@ -639,7 +641,9 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False event.internal_metadata.outlier = False
context = yield self._handle_new_event(origin, event) context, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, event
)
logger.debug( logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s", "on_send_join_request: After _handle_new_event: %s, sigs: %s",
@ -655,7 +659,7 @@ class FederationHandler(BaseHandler):
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( d = self.notifier.on_new_room_event(
event, extra_users=extra_users event, event_stream_id, max_stream_id, extra_users=extra_users
) )
def log_failure(f): def log_failure(f):
@ -729,7 +733,7 @@ class FederationHandler(BaseHandler):
context = yield self.state_handler.compute_event_context(event) context = yield self.state_handler.compute_event_context(event)
yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
backfilled=False, backfilled=False,
@ -738,7 +742,8 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( d = self.notifier.on_new_room_event(
event, extra_users=[target_user], event, event_stream_id, max_stream_id,
extra_users=[target_user],
) )
def log_failure(f): def log_failure(f):
@ -916,7 +921,7 @@ class FederationHandler(BaseHandler):
) )
raise raise
yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
backfilled=backfilled, backfilled=backfilled,
@ -924,7 +929,7 @@ class FederationHandler(BaseHandler):
current_state=current_state, current_state=current_state,
) )
defer.returnValue(context) defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,

View file

@ -363,6 +363,8 @@ class PresenceHandler(BaseHandler):
curr_users = yield rm_handler.get_room_members(room_id) curr_users = yield rm_handler.get_room_members(room_id)
for local_user in [c for c in curr_users if self.hs.is_mine(c)]: for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
statuscache = self._get_or_offline_usercache(local_user)
statuscache.update({}, serial=self._user_cachemap_latest_serial)
self.push_update_to_local_and_remote( self.push_update_to_local_and_remote(
observed_user=local_user, observed_user=local_user,
users_to_push=[user], users_to_push=[user],
@ -952,6 +954,8 @@ class PresenceHandler(BaseHandler):
""" """
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notifier.on_new_user_event( self.notifier.on_new_user_event(
"presence_key",
self._user_cachemap_latest_serial,
users_to_push, users_to_push,
room_ids, room_ids,
) )

View file

@ -92,7 +92,7 @@ class SyncHandler(BaseHandler):
result = yield self.current_sync_for_user(sync_config, since_token) result = yield self.current_sync_for_user(sync_config, since_token)
defer.returnValue(result) defer.returnValue(result)
else: else:
def current_sync_callback(): def current_sync_callback(before_token, after_token):
return self.current_sync_for_user(sync_config, since_token) return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler

View file

@ -218,7 +218,9 @@ class TypingNotificationHandler(BaseHandler):
self._room_serials[room_id] = self._latest_room_serial self._room_serials[room_id] = self._latest_room_serial
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notifier.on_new_user_event(rooms=[room_id]) self.notifier.on_new_user_event(
"typing_key", self._latest_room_serial, rooms=[room_id]
)
class TypingNotificationEventSource(object): class TypingNotificationEventSource(object):

View file

@ -42,56 +42,78 @@ def count(func, l):
class _NotificationListener(object): class _NotificationListener(object):
""" This represents a single client connection to the events stream. """ This represents a single client connection to the events stream.
The events stream handler will have yielded to the deferred, so to The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred. notify the handler it is sufficient to resolve the deferred.
"""
def __init__(self, deferred):
self.deferred = deferred
def notified(self):
return self.deferred.called
def notify(self, token):
""" Inform whoever is listening about the new events.
"""
try:
self.deferred.callback(token)
except defer.AlreadyCalledError:
pass
class _NotifierUserStream(object):
"""This represents a user connected to the event stream.
It tracks the most recent stream token for that user.
At a given point a user may have a number of streams listening for
events.
This listener will also keep track of which rooms it is listening in This listener will also keep track of which rooms it is listening in
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user, rooms, deferred, appservice=None): def __init__(self, user, rooms, current_token, time_now_ms,
self.user = user appservice=None):
self.user = str(user)
self.appservice = appservice self.appservice = appservice
self.deferred = deferred self.listeners = set()
self.rooms = rooms self.rooms = set(rooms)
self.timer = None self.current_token = current_token
self.last_notified_ms = time_now_ms
def notified(self): def notify(self, stream_key, stream_id, time_now_ms):
return self.deferred.called """Notify any listeners for this user of a new event from an
event source.
Args:
stream_key(str): The stream the event came from.
stream_id(str): The new id for the stream the event came from.
time_now_ms(int): The current time in milliseconds.
"""
self.current_token = self.current_token.copy_and_advance(
stream_key, stream_id
)
if self.listeners:
self.last_notified_ms = time_now_ms
listeners = self.listeners
self.listeners = set()
for listener in listeners:
listener.notify(self.current_token)
def notify(self, notifier): def remove(self, notifier):
""" Inform whoever is listening about the new events. This will """ Remove this listener from all the indexes in the Notifier
also remove this listener from all the indexes in the Notifier
it knows about. it knows about.
""" """
try:
self.deferred.callback(None)
except defer.AlreadyCalledError:
pass
# Should the following be done be using intrusively linked lists?
# -- erikj
for room in self.rooms: for room in self.rooms:
lst = notifier.room_to_listeners.get(room, set()) lst = notifier.room_to_user_streams.get(room, set())
lst.discard(self) lst.discard(self)
notifier.user_to_listeners.get(self.user, set()).discard(self) notifier.user_to_user_stream.pop(self.user)
if self.appservice: if self.appservice:
notifier.appservice_to_listeners.get( notifier.appservice_to_user_streams.get(
self.appservice, set() self.appservice, set()
).discard(self) ).discard(self)
# Cancel the timeout for this notifer if one exists.
if self.timer is not None:
try:
notifier.clock.cancel_call_later(self.timer)
except:
logger.warn("Failed to cancel notifier timer")
class Notifier(object): class Notifier(object):
""" This class is responsible for notifying any listeners when there are """ This class is responsible for notifying any listeners when there are
@ -100,14 +122,18 @@ class Notifier(object):
Primarily used from the /events stream. Primarily used from the /events stream.
""" """
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.room_to_listeners = {} self.user_to_user_stream = {}
self.user_to_listeners = {} self.room_to_user_streams = {}
self.appservice_to_listeners = {} self.appservice_to_user_streams = {}
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -115,46 +141,80 @@ class Notifier(object):
"user_joined_room", self._user_joined_room "user_joined_room", self._user_joined_room
) )
self.clock.looping_call(
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
)
# This is not a very cheap test to perform, but it's only executed # This is not a very cheap test to perform, but it's only executed
# when rendering the metrics page, which is likely once per minute at # when rendering the metrics page, which is likely once per minute at
# most when scraping it. # most when scraping it.
def count_listeners(): def count_listeners():
all_listeners = set() all_user_streams = set()
for x in self.room_to_listeners.values(): for x in self.room_to_user_streams.values():
all_listeners |= x all_user_streams |= x
for x in self.user_to_listeners.values(): for x in self.user_to_user_stream:
all_listeners |= x all_user_streams.add(x)
for x in self.appservice_to_listeners.values(): for x in self.appservice_to_user_streams.values():
all_listeners |= x all_user_streams |= x
return len(all_listeners) return sum(len(stream.listeners) for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners) metrics.register_callback("listeners", count_listeners)
metrics.register_callback( metrics.register_callback(
"rooms", "rooms",
lambda: count(bool, self.room_to_listeners.values()), lambda: count(bool, self.room_to_user_streams.values()),
) )
metrics.register_callback( metrics.register_callback(
"users", "users",
lambda: count(bool, self.user_to_listeners.values()), lambda: len(self.user_to_user_stream),
) )
metrics.register_callback( metrics.register_callback(
"appservices", "appservices",
lambda: count(bool, self.appservice_to_listeners.values()), lambda: count(bool, self.appservice_to_user_streams.values()),
) )
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_new_room_event(self, event, extra_users=[]): def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
in the room, room event wise. in the room, room event wise.
This triggers the notifier to wake up any listeners that are This triggers the notifier to wake up any listeners that are
listening to the room, and any listeners for the users in the listening to the room, and any listeners for the users in the
`extra_users` param. `extra_users` param.
The events can be peristed out of order. The notifier will wait
until all previous events have been persisted before notifying
the client streams.
""" """
yield run_on_reactor() yield run_on_reactor()
self.pending_new_room_events.append((
room_stream_id, event, extra_users
))
self._notify_pending_new_room_events(max_room_stream_id)
def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
Args:
max_room_stream_id(int): The highest stream_id below which all
events have been persisted.
"""
pending = self.pending_new_room_events
self.pending_new_room_events = []
for room_stream_id, event, extra_users in pending:
if room_stream_id > max_room_stream_id:
self.pending_new_room_events.append((
room_stream_id, event, extra_users
))
else:
self._on_new_room_event(event, room_stream_id, extra_users)
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
self.hs.get_handlers().appservice_handler.notify_interested_services( self.hs.get_handlers().appservice_handler.notify_interested_services(
event event
@ -162,70 +222,61 @@ class Notifier(object):
room_id = event.room_id room_id = event.room_id
room_listeners = self.room_to_listeners.get(room_id, set()) room_user_streams = self.room_to_user_streams.get(room_id, set())
_discard_if_notified(room_listeners) user_streams = room_user_streams.copy()
listeners = room_listeners.copy()
for user in extra_users: for user in extra_users:
user_listeners = self.user_to_listeners.get(user, set()) user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
_discard_if_notified(user_listeners) for appservice in self.appservice_to_user_streams:
listeners |= user_listeners
for appservice in self.appservice_to_listeners:
# TODO (kegan): Redundant appservice listener checks? # TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_listeners set, but # App services will already be in the room_to_user_streams set, but
# that isn't enough. They need to be checked here in order to # that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this # receive *invites* for users they are interested in. Does this
# make the room_to_listeners check somewhat obselete? # make the room_to_user_streams check somewhat obselete?
if appservice.is_interested(event): if appservice.is_interested(event):
app_listeners = self.appservice_to_listeners.get( app_user_streams = self.appservice_to_user_streams.get(
appservice, set() appservice, set()
) )
user_streams |= app_user_streams
_discard_if_notified(app_listeners) logger.debug("on_new_room_event listeners %s", user_streams)
listeners |= app_listeners time_now_ms = self.clock.time_msec()
for user_stream in user_streams:
logger.debug("on_new_room_event listeners %s", listeners)
for listener in listeners:
try: try:
listener.notify(self) user_stream.notify(
"room_key", "s%d" % (room_stream_id,), time_now_ms
)
except: except:
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_new_user_event(self, users=[], rooms=[]): def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend """ Used to inform listeners that something has happend
presence/user event wise. presence/user event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
yield run_on_reactor() yield run_on_reactor()
listeners = set() user_streams = set()
for user in users: for user in users:
user_listeners = self.user_to_listeners.get(user, set()) user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
_discard_if_notified(user_listeners) user_streams.add(user_stream)
listeners |= user_listeners
for room in rooms: for room in rooms:
room_listeners = self.room_to_listeners.get(room, set()) user_streams |= self.room_to_user_streams.get(room, set())
_discard_if_notified(room_listeners) time_now_ms = self.clock.time_msec()
for user_stream in user_streams:
listeners |= room_listeners
for listener in listeners:
try: try:
listener.notify(self) user_stream.notify(stream_key, new_token, time_now_ms)
except: except:
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
@ -237,46 +288,62 @@ class Notifier(object):
""" """
deferred = defer.Deferred() deferred = defer.Deferred()
appservice = yield self.hs.get_datastore().get_app_service_by_user_id( time_now_ms = self.clock.time_msec()
user.to_string()
)
listener = [_NotificationListener( user = str(user)
user_stream = self.user_to_user_stream.get(user)
if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user)
current_token = yield self.event_sources.get_current_token()
rooms = yield self.store.get_rooms_for_user(user)
user_stream = _NotifierUserStream(
user=user, user=user,
rooms=rooms, rooms=rooms,
deferred=deferred,
appservice=appservice, appservice=appservice,
)] current_token=current_token,
time_now_ms=time_now_ms,
)
self._register_with_keys(user_stream)
else:
current_token = user_stream.current_token
if timeout: listener = [_NotificationListener(deferred)]
self._register_with_keys(listener[0])
if timeout and not current_token.is_after(from_token):
user_stream.listeners.add(listener[0])
if current_token.is_after(from_token):
result = yield callback(from_token, current_token)
else:
result = None
result = yield callback()
timer = [None] timer = [None]
if result:
user_stream.listeners.discard(listener[0])
defer.returnValue(result)
return
if timeout: if timeout:
timed_out = [False] timed_out = [False]
def _timeout_listener(): def _timeout_listener():
timed_out[0] = True timed_out[0] = True
timer[0] = None timer[0] = None
listener[0].notify(self) user_stream.listeners.discard(listener[0])
listener[0].notify(current_token)
# We create multiple notification listeners so we have to manage # We create multiple notification listeners so we have to manage
# canceling the timeout ourselves. # canceling the timeout ourselves.
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener) timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
while not result and not timed_out[0]: while not result and not timed_out[0]:
yield deferred new_token = yield deferred
deferred = defer.Deferred() deferred = defer.Deferred()
listener[0] = _NotificationListener( listener[0] = _NotificationListener(deferred)
user=user, user_stream.listeners.add(listener[0])
rooms=rooms, result = yield callback(current_token, new_token)
deferred=deferred, current_token = new_token
appservice=appservice,
)
self._register_with_keys(listener[0])
result = yield callback()
if timer[0] is not None: if timer[0] is not None:
try: try:
@ -299,11 +366,15 @@ class Notifier(object):
limit = pagination_config.limit limit = pagination_config.limit
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(): def check_for_updates(before_token, after_token):
events = [] events = []
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.items(): for name, source in self.event_sources.sources.items():
keyname = "%s_key" % name keyname = "%s_key" % name
before_id = getattr(before_token, keyname)
after_id = getattr(after_token, keyname)
if before_id == after_id:
continue
stuff, new_key = yield source.get_new_events_for_user( stuff, new_key = yield source.get_new_events_for_user(
user, getattr(from_token, keyname), limit, user, getattr(from_token, keyname), limit,
) )
@ -325,34 +396,36 @@ class Notifier(object):
defer.returnValue(result) defer.returnValue(result)
@log_function @log_function
def _register_with_keys(self, listener): def remove_expired_streams(self):
for room in listener.rooms: time_now_ms = self.clock.time_msec()
s = self.room_to_listeners.setdefault(room, set()) expired_streams = []
s.add(listener) expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
for stream in self.user_to_user_stream.values():
if stream.listeners:
continue
if stream.last_notified_ms < expire_before_ts:
expired_streams.append(stream)
self.user_to_listeners.setdefault(listener.user, set()).add(listener) for expired_stream in expired_streams:
expired_stream.remove(self)
if listener.appservice: @log_function
self.appservice_to_listeners.setdefault( def _register_with_keys(self, user_stream):
listener.appservice, set() self.user_to_user_stream[user_stream.user] = user_stream
).add(listener)
for room in user_stream.rooms:
s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)
if user_stream.appservice:
self.appservice_to_user_stream.setdefault(
user_stream.appservice, set()
).add(user_stream)
def _user_joined_room(self, user, room_id): def _user_joined_room(self, user, room_id):
new_listeners = self.user_to_listeners.get(user, set()) user = str(user)
new_user_stream = self.user_to_user_stream.get(user)
listeners = self.room_to_listeners.setdefault(room_id, set()) if new_user_stream is not None:
listeners |= new_listeners room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream)
for l in new_listeners: new_user_stream.rooms.add(room_id)
l.rooms.add(room_id)
def _discard_if_notified(listener_set):
"""Remove any 'stale' listeners from the given set.
"""
to_discard = set()
for l in listener_set:
if l.notified():
to_discard.add(l)
listener_set -= to_discard

View file

@ -65,6 +65,9 @@ class EventsStore(SQLBaseStore):
except _RollbackButIsFineException: except _RollbackButIsFineException:
pass pass
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True, def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False, get_prev_content=False, allow_rejected=False,

View file

@ -70,6 +70,8 @@ class DomainSpecificString(
"""Return a string encoding the fields of the structure object.""" """Return a string encoding the fields of the structure object."""
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
__str__ = to_string
@classmethod @classmethod
def create(cls, localpart, domain,): def create(cls, localpart, domain,):
return cls(localpart=localpart, domain=domain) return cls(localpart=localpart, domain=domain)
@ -107,7 +109,6 @@ class StreamToken(
def from_string(cls, string): def from_string(cls, string):
try: try:
keys = string.split(cls._SEPARATOR) keys = string.split(cls._SEPARATOR)
return cls(*keys) return cls(*keys)
except: except:
raise SynapseError(400, "Invalid Token") raise SynapseError(400, "Invalid Token")
@ -115,6 +116,39 @@ class StreamToken(
def to_string(self): def to_string(self):
return self._SEPARATOR.join([str(k) for k in self]) return self._SEPARATOR.join([str(k) for k in self])
@property
def room_stream_id(self):
# TODO(markjh): Awful hack to work around hacks in the presence tests
# which assume that the keys are integers.
if type(self.room_key) is int:
return self.room_key
else:
return int(self.room_key[1:].split("-")[-1])
def is_after(self, other_token):
"""Does this token contain events that the other doesn't?"""
return (
(other_token.room_stream_id < self.room_stream_id)
or (int(other_token.presence_key) < int(self.presence_key))
or (int(other_token.typing_key) < int(self.typing_key))
)
def copy_and_advance(self, key, new_value):
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
"""
new_token = self.copy_and_replace(key, new_value)
if key == "room_key":
new_id = new_token.room_stream_id
old_id = self.room_stream_id
else:
new_id = int(getattr(new_token, key))
old_id = int(getattr(self, key))
if old_id < new_id:
return new_token
else:
return self
def copy_and_replace(self, key, new_value): def copy_and_replace(self, key, new_value):
d = self._asdict() d = self._asdict()
d[key] = new_value d[key] = new_value

View file

@ -83,7 +83,7 @@ class FederationTestCase(unittest.TestCase):
"hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"}, "hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
}) })
self.datastore.persist_event.return_value = defer.succeed(None) self.datastore.persist_event.return_value = defer.succeed((1,1))
self.datastore.get_room.return_value = defer.succeed(True) self.datastore.get_room.return_value = defer.succeed(True)
self.auth.check_host_in_room.return_value = defer.succeed(True) self.auth.check_host_in_room.return_value = defer.succeed(True)
@ -126,5 +126,5 @@ class FederationTestCase(unittest.TestCase):
self.auth.check.assert_called_once_with(ANY, auth_events={}) self.auth.check.assert_called_once_with(ANY, auth_events={})
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
ANY, extra_users=[] ANY, 1, 1, extra_users=[]
) )

View file

@ -87,6 +87,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
self.datastore.persist_event.return_value = (1,1)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite(self): def test_invite(self):
room_id = "!foo:red" room_id = "!foo:red"
@ -160,7 +162,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
event, context=context, event, context=context,
) )
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[UserID.from_string(target_user_id)] event, 1, 1, extra_users=[UserID.from_string(target_user_id)]
) )
self.assertFalse(self.datastore.get_room.called) self.assertFalse(self.datastore.get_room.called)
self.assertFalse(self.datastore.store_room.called) self.assertFalse(self.datastore.store_room.called)
@ -226,7 +228,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
event, context=context event, context=context
) )
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[user] event, 1, 1, extra_users=[user]
) )
join_signal_observer.assert_called_with( join_signal_observer.assert_called_with(
@ -304,7 +306,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
event, context=context event, context=context
) )
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[user] event, 1, 1, extra_users=[user]
) )
leave_signal_observer.assert_called_with( leave_signal_observer.assert_called_with(

View file

@ -183,7 +183,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -246,7 +246,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -300,7 +300,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
yield put_json.await_calls() yield put_json.await_calls()
@ -332,7 +332,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
self.on_new_user_event.reset_mock() self.on_new_user_event.reset_mock()
@ -352,7 +352,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.clock.advance_time(11) self.clock.advance_time(11)
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 2, rooms=[self.room_id]),
]) ])
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
@ -378,7 +378,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 3, rooms=[self.room_id]),
]) ])
self.on_new_user_event.reset_mock() self.on_new_user_event.reset_mock()

View file

@ -27,6 +27,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.rest.client.v1 import presence from synapse.rest.client.v1 import presence
from synapse.rest.client.v1 import events from synapse.rest.client.v1 import events
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import run_on_reactor
OFFLINE = PresenceState.OFFLINE OFFLINE = PresenceState.OFFLINE
@ -264,11 +265,13 @@ class PresenceEventStreamTestCase(unittest.TestCase):
datastore=Mock(spec=[ datastore=Mock(spec=[
"set_presence_state", "set_presence_state",
"get_presence_list", "get_presence_list",
"get_rooms_for_user",
]), ]),
clock=Mock(spec=[ clock=Mock(spec=[
"call_later", "call_later",
"cancel_call_later", "cancel_call_later",
"time_msec", "time_msec",
"looping_call",
]), ]),
) )
@ -298,6 +301,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
self.mock_datastore.get_app_service_by_user_id = Mock( self.mock_datastore.get_app_service_by_user_id = Mock(
return_value=defer.succeed(None) return_value=defer.succeed(None)
) )
self.mock_datastore.get_rooms_for_user = (
lambda u: get_rooms_for_user(UserID.from_string(u))
)
def get_profile_displayname(user_id): def get_profile_displayname(user_id):
return defer.succeed("Frank") return defer.succeed("Frank")
@ -350,19 +356,19 @@ class PresenceEventStreamTestCase(unittest.TestCase):
self.mock_datastore.set_presence_state.return_value = defer.succeed( self.mock_datastore.set_presence_state.return_value = defer.succeed(
{"state": ONLINE} {"state": ONLINE}
) )
self.mock_datastore.get_presence_list.return_value = defer.succeed( self.mock_datastore.get_presence_list.return_value = defer.succeed([])
[]
)
yield self.presence.set_state(self.u_banana, self.u_banana, yield self.presence.set_state(self.u_banana, self.u_banana,
state={"presence": ONLINE} state={"presence": ONLINE}
) )
yield run_on_reactor()
(code, response) = yield self.mock_resource.trigger("GET", (code, response) = yield self.mock_resource.trigger("GET",
"/events?from=0_1_0&timeout=0", None) "/events?from=s0_1_0&timeout=0", None)
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals({"start": "0_1_0", "end": "0_2_0", "chunk": [ self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [
{"type": "m.presence", {"type": "m.presence",
"content": { "content": {
"user_id": "@banana:test", "user_id": "@banana:test",

View file

@ -197,6 +197,9 @@ class MockClock(object):
return t return t
def looping_call(self, function, interval):
pass
def cancel_call_later(self, timer): def cancel_call_later(self, timer):
if timer[2]: if timer[2]:
raise Exception("Cannot cancel an expired timer") raise Exception("Cannot cancel an expired timer")
@ -355,7 +358,7 @@ class MemoryDataStore(object):
return [] return []
def get_room_events_max_id(self): def get_room_events_max_id(self):
return 0 # TODO (erikj) return "s0" # TODO (erikj)
def get_send_event_level(self, room_id): def get_send_event_level(self, room_id):
return defer.succeed(0) return defer.succeed(0)