0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-16 15:01:23 +01:00

Merge pull request #582 from matrix-org/erikj/presence

Rewrite presence for performance.
This commit is contained in:
Erik Johnston 2016-02-19 09:37:50 +00:00
commit e5ad2e5267
30 changed files with 1572 additions and 3224 deletions

View file

@ -32,7 +32,6 @@ class PresenceState(object):
OFFLINE = u"offline"
UNAVAILABLE = u"unavailable"
ONLINE = u"online"
FREE_FOR_CHAT = u"free_for_chat"
class JoinRules(object):

View file

@ -19,6 +19,8 @@ from synapse.util.logutils import log_function
from synapse.types import UserID
from synapse.events.utils import serialize_event
from synapse.util.logcontext import preserve_context_over_fn
from synapse.api.constants import Membership, EventTypes
from synapse.events import EventBase
from ._base import BaseHandler
@ -126,11 +128,12 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down.
"""
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_handlers().presence_handler
try:
if affect_presence:
yield self.started_stream(auth_user)
context = yield presence_handler.user_syncing(
auth_user_id, affect_presence=affect_presence,
)
with context:
if timeout:
# If they've set a timeout set a minimum limit.
timeout = max(timeout, 500)
@ -145,6 +148,34 @@ class EventStreamHandler(BaseHandler):
is_guest=is_guest, explicit_room_id=room_id
)
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = []
for event in events:
if not isinstance(event, EventBase):
continue
if event.type == EventTypes.Member:
if event.membership != Membership.JOIN:
continue
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = yield self.store.get_users_in_room(event.room_id)
states = yield presence_handler.get_states(
users,
as_event=True,
)
to_add.extend(states)
else:
ev = yield presence_handler.get_state(
UserID.from_string(event.state_key),
as_event=True,
)
to_add.append(ev)
events.extend(to_add)
time_now = self.clock.time_msec()
chunks = [
@ -159,10 +190,6 @@ class EventStreamHandler(BaseHandler):
defer.returnValue(chunk)
finally:
if affect_presence:
self.stopped_stream(auth_user)
class EventHandler(BaseHandler):

View file

@ -21,7 +21,6 @@ from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.types import UserID, RoomStreamToken, StreamToken
@ -249,8 +248,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Message:
presence = self.hs.get_handlers().presence_handler
with PreserveLoggingContext():
presence.bump_presence_active_time(user)
yield presence.bump_presence_active_time(user)
def deduplicate_state_event(self, event, context):
"""
@ -674,10 +672,6 @@ class MessageHandler(BaseHandler):
room_id=room_id,
)
# TODO(paul): I wish I was called with user objects not user_id
# strings...
auth_user = UserID.from_string(user_id)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = [
@ -702,13 +696,11 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def get_presence():
states = yield presence_handler.get_states(
target_users=[UserID.from_string(m.user_id) for m in room_members],
auth_user=auth_user,
[m.user_id for m in room_members],
as_event=True,
check_auth=False,
)
defer.returnValue(states.values())
defer.returnValue(states)
@defer.inlineCallbacks
def get_receipts():

File diff suppressed because it is too large Load diff

View file

@ -48,6 +48,9 @@ class ProfileHandler(BaseHandler):
distributor = hs.get_distributor()
self.distributor = distributor
distributor.declare("collect_presencelike_data")
distributor.declare("changed_presencelike_data")
distributor.observe("registered_user", self.registered_user)
distributor.observe(

View file

@ -582,6 +582,28 @@ class SyncHandler(BaseHandler):
if room_sync:
joined.append(room_sync)
# For each newly joined room, we want to send down presence of
# existing users.
presence_handler = self.hs.get_handlers().presence_handler
extra_presence_users = set()
for room_id in newly_joined_rooms:
users = yield self.store.get_users_in_room(event.room_id)
extra_presence_users.update(users)
# For each new member, send down presence.
for joined_sync in joined:
it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values())
for event in it:
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
extra_presence_users.add(event.state_key)
states = yield presence_handler.get_states(
[u for u in extra_presence_users if u != user_id],
as_event=True,
)
presence.extend(states)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
)

View file

@ -17,7 +17,7 @@
"""
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns
@ -35,8 +35,15 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state(
target_user=user, auth_user=requester.user)
if requester.user != user:
allowed = yield self.handlers.presence_handler.is_visible(
observed_user=user, observer_user=requester.user,
)
if not allowed:
raise AuthError(403, "You are not allowed to see their presence.")
state = yield self.handlers.presence_handler.get_state(target_user=user)
defer.returnValue((200, state))
@ -45,6 +52,9 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if requester.user != user:
raise AuthError(403, "Can only set your own presence state")
state = {}
try:
content = json.loads(request.content.read())
@ -63,8 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
except:
raise SynapseError(400, "Unable to parse state")
yield self.handlers.presence_handler.set_state(
target_user=user, auth_user=requester.user, state=state)
yield self.handlers.presence_handler.set_state(user, state)
defer.returnValue((200, {}))
@ -87,11 +96,8 @@ class PresenceListRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.handlers.presence_handler.get_presence_list(
observer_user=user, accepted=True)
for p in presence:
observed_user = p.pop("observed_user")
p["user_id"] = observed_user.to_string()
observer_user=user, accepted=True
)
defer.returnValue((200, presence))

View file

@ -298,18 +298,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
if event["type"] != EventTypes.Member:
continue
chunk.append(event)
# FIXME: should probably be state_key here, not user_id
target_user = UserID.from_string(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it
try:
presence_handler = self.handlers.presence_handler
presence_state = yield presence_handler.get_state(
target_user=target_user,
auth_user=requester.user,
)
event["content"].update(presence_state)
except:
pass
defer.returnValue((200, {
"chunk": chunk
@ -535,6 +523,10 @@ class RoomTypingRestServlet(ClientV1RestServlet):
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
)
def __init__(self, hs):
super(RoomTypingRestServlet, self).__init__(hs)
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
requester = yield self.auth.get_user_by_req(request)
@ -546,6 +538,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
typing_handler = self.handlers.typing_notification_handler
yield self.presence_handler.bump_presence_active_time(requester.user)
if content["typing"]:
yield typing_handler.started_typing(
target_user=target_user,

View file

@ -37,6 +37,7 @@ class ReceiptRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_handlers().receipts_handler
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id):
@ -45,6 +46,8 @@ class ReceiptRestServlet(RestServlet):
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
yield self.presence_handler.bump_presence_active_time(requester.user)
yield self.receipts_handler.received_client_receipt(
room_id,
receipt_type,

View file

@ -25,6 +25,7 @@ from synapse.events.utils import (
)
from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from ._base import client_v2_patterns
import copy
@ -82,6 +83,7 @@ class SyncRestServlet(RestServlet):
self.sync_handler = hs.get_handlers().sync_handler
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks
def on_GET(self, request):
@ -139,17 +141,19 @@ class SyncRestServlet(RestServlet):
else:
since_token = None
if set_presence == "online":
yield self.event_stream_handler.started_stream(user)
affect_presence = set_presence != PresenceState.OFFLINE
try:
if affect_presence:
yield self.presence_handler.set_state(user, {"presence": set_presence})
context = yield self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence,
)
with context:
sync_result = yield self.sync_handler.wait_for_sync_for_user(
sync_config, since_token=since_token, timeout=timeout,
full_state=full_state
)
finally:
if set_presence == "online":
self.event_stream_handler.stopped_stream(user)
time_now = self.clock.time_msec()

View file

@ -20,7 +20,7 @@ from .appservice import (
from ._base import Cache
from .directory import DirectoryStore
from .events import EventsStore
from .presence import PresenceStore
from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
from .registration import RegistrationStore
from .room import RoomStore
@ -47,6 +47,7 @@ from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -110,6 +111,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
@ -119,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore,
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
events_max = self._stream_id_gen.get_max_token(None)
events_max = self._stream_id_gen.get_max_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
@ -135,13 +139,31 @@ class DataStore(RoomMemberStore, RoomStore,
"MembershipStreamChangeCache", events_max,
)
account_max = self._account_data_id_gen.get_max_token(None)
account_max = self._account_data_id_gen.get_max_token()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
self.__presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn, "presence_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._presence_id_gen.get_max_token(),
)
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", min_presence_val,
prefilled_cache=presence_cache_prefill
)
super(DataStore, self).__init__(hs)
def take_presence_startup_info(self):
active_on_startup = self.__presence_on_startup
self.__presence_on_startup = None
return active_on_startup
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
@ -161,6 +183,7 @@ class DataStore(RoomMemberStore, RoomStore,
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = {
row[0]: int(row[1])
@ -174,6 +197,28 @@ class DataStore(RoomMemberStore, RoomStore,
return cache, min_val
def _get_active_presence(self, db_conn):
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
sql = (
"SELECT user_id, state, last_active_ts, last_federation_update_ts,"
" last_user_sync_ts, status_msg, currently_active FROM presence_stream"
" WHERE state != ?"
)
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.cursor_to_dict(txn)
txn.close()
for row in rows:
row["currently_active"] = bool(row["currently_active"])
return [UserPresenceState(**row) for row in rows]
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec())

View file

@ -168,7 +168,7 @@ class AccountDataStore(SQLBaseStore):
"add_room_account_data", add_account_data_txn, next_id
)
result = yield self._account_data_id_gen.get_max_token(self)
result = yield self._account_data_id_gen.get_max_token()
defer.returnValue(result)
@defer.inlineCallbacks
@ -207,7 +207,7 @@ class AccountDataStore(SQLBaseStore):
"add_user_account_data", add_account_data_txn, next_id
)
result = yield self._account_data_id_gen.get_max_token(self)
result = yield self._account_data_id_gen.get_max_token()
defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id):

View file

@ -131,7 +131,7 @@ class EventsStore(SQLBaseStore):
except _RollbackButIsFineException:
pass
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
max_persisted_id = yield self._stream_id_gen.get_max_token()
defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 29
SCHEMA_VERSION = 30
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -14,73 +14,129 @@
# limitations under the License.
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
from synapse.api.constants import PresenceState
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from collections import namedtuple
from twisted.internet import defer
class UserPresenceState(namedtuple("UserPresenceState",
("user_id", "state", "last_active_ts",
"last_federation_update_ts", "last_user_sync_ts",
"status_msg", "currently_active"))):
"""Represents the current presence state of the user.
user_id (str)
last_active (int): Time in msec that the user last interacted with server.
last_federation_update (int): Time in msec since either a) we sent a presence
update to other servers or b) we received a presence update, depending
on if is a local user or not.
last_user_sync (int): Time in msec that the user last *completed* a sync
(or event stream).
status_msg (str): User set status message.
"""
def copy_and_replace(self, **kwargs):
return self._replace(**kwargs)
@classmethod
def default(cls, user_id):
"""Returns a default presence state.
"""
return cls(
user_id=user_id,
state=PresenceState.OFFLINE,
last_active_ts=0,
last_federation_update_ts=0,
last_user_sync_ts=0,
status_msg=None,
currently_active=False,
)
class PresenceStore(SQLBaseStore):
def create_presence(self, user_localpart):
res = self._simple_insert(
table="presence",
values={"user_id": user_localpart},
desc="create_presence",
@defer.inlineCallbacks
def update_presence(self, presence_states):
stream_id_manager = yield self._presence_id_gen.get_next(self)
with stream_id_manager as stream_id:
yield self.runInteraction(
"update_presence",
self._update_presence_txn, stream_id, presence_states,
)
defer.returnValue((stream_id, self._presence_id_gen.get_max_token()))
def _update_presence_txn(self, txn, stream_id, presence_states):
for state in presence_states:
txn.call_after(
self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id,
)
# Actually insert new rows
self._simple_insert_many_txn(
txn,
table="presence_stream",
values=[
{
"stream_id": stream_id,
"user_id": state.user_id,
"state": state.state,
"last_active_ts": state.last_active_ts,
"last_federation_update_ts": state.last_federation_update_ts,
"last_user_sync_ts": state.last_user_sync_ts,
"status_msg": state.status_msg,
"currently_active": state.currently_active,
}
for state in presence_states
],
)
self.get_presence_state.invalidate((user_localpart,))
return res
def has_presence_state(self, user_localpart):
return self._simple_select_one(
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["user_id"],
allow_none=True,
desc="has_presence_state",
# Delete old rows to stop database from getting really big
sql = (
"DELETE FROM presence_stream WHERE"
" stream_id < ?"
" AND user_id IN (%s)"
)
@cached(max_entries=2000)
def get_presence_state(self, user_localpart):
return self._simple_select_one(
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["state", "status_msg", "mtime"],
desc="get_presence_state",
batches = (
presence_states[i:i + 50]
for i in xrange(0, len(presence_states), 50)
)
@cachedList(get_presence_state.cache, list_name="user_localparts",
inlineCallbacks=True)
def get_presence_states(self, user_localparts):
rows = yield self._simple_select_many_batch(
table="presence",
column="user_id",
iterable=user_localparts,
retcols=("user_id", "state", "status_msg", "mtime",),
desc="get_presence_states",
)
defer.returnValue({
row["user_id"]: {
"state": row["state"],
"status_msg": row["status_msg"],
"mtime": row["mtime"],
}
for row in rows
})
for states in batches:
args = [stream_id]
args.extend(s.user_id for s in states)
txn.execute(
sql % (",".join("?" for _ in states),),
args
)
@defer.inlineCallbacks
def set_presence_state(self, user_localpart, new_state):
res = yield self._simple_update_one(
table="presence",
keyvalues={"user_id": user_localpart},
updatevalues={"state": new_state["state"],
"status_msg": new_state["status_msg"],
"mtime": self._clock.time_msec()},
desc="set_presence_state",
def get_presence_for_users(self, user_ids):
rows = yield self._simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
keyvalues={},
retcols=(
"user_id",
"state",
"last_active_ts",
"last_federation_update_ts",
"last_user_sync_ts",
"status_msg",
"currently_active",
),
)
self.get_presence_state.invalidate((user_localpart,))
defer.returnValue(res)
for row in rows:
row["currently_active"] = bool(row["currently_active"])
defer.returnValue([UserPresenceState(**row) for row in rows])
def get_current_presence_token(self):
return self._presence_id_gen.get_max_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
@ -128,6 +184,7 @@ class PresenceStore(SQLBaseStore):
desc="set_presence_list_accepted",
)
self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,))
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
@ -154,6 +211,19 @@ class PresenceStore(SQLBaseStore):
desc="get_presence_list_accepted",
)
@cachedInlineCallbacks()
def get_presence_list_observers_accepted(self, observed_userid):
user_localparts = yield self._simple_select_onecol(
table="presence_list",
keyvalues={"observed_user_id": observed_userid, "accepted": True},
retcol="user_id",
desc="get_presence_list_accepted",
)
defer.returnValue([
"@%s:%s" % (u, self.hs.hostname,) for u in user_localparts
])
@defer.inlineCallbacks
def del_presence_list(self, observer_localpart, observed_userid):
yield self._simple_delete_one(
@ -163,3 +233,4 @@ class PresenceStore(SQLBaseStore):
desc="del_presence_list",
)
self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,))

View file

@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore):
super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
)
@cached(num_args=2)
@ -222,7 +222,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue(results)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self)
return self._receipts_id_gen.get_max_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data
)
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
max_persisted_id = yield self._stream_id_gen.get_max_token()
defer.returnValue((stream_id, max_persisted_id))

View file

@ -0,0 +1,30 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE presence_stream(
stream_id BIGINT,
user_id TEXT,
state TEXT,
last_active_ts BIGINT,
last_federation_update_ts BIGINT,
last_user_sync_ts BIGINT,
status_msg TEXT,
currently_active BOOLEAN
);
CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id);
CREATE INDEX presence_stream_user_id ON presence_stream(user_id);
CREATE INDEX presence_stream_state ON presence_stream(state);

View file

@ -531,7 +531,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'):
token = yield self._stream_id_gen.get_max_token(self)
token = yield self._stream_id_gen.get_max_token()
if direction != 'b':
defer.returnValue("s%d" % (token,))
else:

View file

@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore):
Returns:
A deferred int.
"""
return self._account_data_id_gen.get_max_token(self)
return self._account_data_id_gen.get_max_token()
@cached()
def get_tags_for_user(self, user_id):
@ -147,7 +147,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
result = yield self._account_data_id_gen.get_max_token(self)
result = yield self._account_data_id_gen.get_max_token()
defer.returnValue(result)
@defer.inlineCallbacks
@ -169,7 +169,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
result = yield self._account_data_id_gen.get_max_token(self)
result = yield self._account_data_id_gen.get_max_token()
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):

View file

@ -130,7 +130,7 @@ class StreamIdGenerator(object):
return manager()
def get_max_token(self, store):
def get_max_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""

View file

@ -42,7 +42,7 @@ class Clock(object):
def time_msec(self):
"""Returns the current system time in miliseconds since epoch."""
return self.time() * 1000
return int(self.time() * 1000)
def looping_call(self, f, msec):
l = task.LoopingCall(f)

View file

@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class _Entry(object):
__slots__ = ["end_key", "queue"]
def __init__(self, end_key):
self.end_key = end_key
self.queue = []
class WheelTimer(object):
"""Stores arbitrary objects that will be returned after their timers have
expired.
"""
def __init__(self, bucket_size=5000):
"""
Args:
bucket_size (int): Size of buckets in ms. Corresponds roughly to the
accuracy of the timer.
"""
self.bucket_size = bucket_size
self.entries = []
self.current_tick = 0
def insert(self, now, obj, then):
"""Inserts object into timer.
Args:
now (int): Current time in msec
obj (object): Object to be inserted
then (int): When to return the object strictly after.
"""
then_key = int(then / self.bucket_size) + 1
if self.entries:
min_key = self.entries[0].end_key
max_key = self.entries[-1].end_key
if then_key <= max_key:
# The max here is to protect against inserts for times in the past
self.entries[max(min_key, then_key) - min_key].queue.append(obj)
return
next_key = int(now / self.bucket_size) + 1
if self.entries:
last_key = self.entries[-1].end_key
else:
last_key = next_key
# Handle the case when `then` is in the past and `entries` is empty.
then_key = max(last_key, then_key)
# Add empty entries between the end of the current list and when we want
# to insert. This ensures there are no gaps.
self.entries.extend(
_Entry(key) for key in xrange(last_key, then_key + 1)
)
self.entries[-1].queue.append(obj)
def fetch(self, now):
"""Fetch any objects that have timed out
Args:
now (ms): Current time in msec
Returns:
list: List of objects that have timed out
"""
now_key = int(now / self.bucket_size)
ret = []
while self.entries and self.entries[0].end_key <= now_key:
ret.extend(self.entries.pop(0).queue)
return ret

File diff suppressed because it is too large Load diff

View file

@ -1,311 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains tests of the "presence-like" data that is shared between
presence and profiles; namely, the displayname and avatar_url."""
from tests import unittest
from twisted.internet import defer
from mock import Mock, call, ANY, NonCallableMock
from ..utils import MockClock, setup_test_homeserver
from synapse.api.constants import PresenceState
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID
OFFLINE = PresenceState.OFFLINE
UNAVAILABLE = PresenceState.UNAVAILABLE
ONLINE = PresenceState.ONLINE
class MockReplication(object):
def __init__(self):
self.edu_handlers = {}
def register_edu_handler(self, edu_type, handler):
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
pass
def received_edu(self, origin, edu_type, content):
self.edu_handlers[edu_type](origin, content)
class PresenceAndProfileHandlers(object):
def __init__(self, hs):
self.presence_handler = PresenceHandler(hs)
self.profile_handler = ProfileHandler(hs)
class PresenceProfilelikeDataTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(
clock=MockClock(),
datastore=Mock(spec=[
"set_presence_state",
"is_presence_visible",
"set_profile_displayname",
"get_rooms_for_user",
]),
handlers=None,
resource_for_federation=Mock(),
http_client=None,
replication_layer=MockReplication(),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.handlers = PresenceAndProfileHandlers(hs)
self.datastore = hs.get_datastore()
self.replication = hs.get_replication_layer()
self.replication.send_edu = Mock()
def send_edu(*args, **kwargs):
# print "send_edu: %s, %s" % (args, kwargs)
return defer.succeed((200, "OK"))
self.replication.send_edu.side_effect = send_edu
def get_profile_displayname(user_localpart):
return defer.succeed("Frank")
self.datastore.get_profile_displayname = get_profile_displayname
def is_presence_visible(*args, **kwargs):
return defer.succeed(False)
self.datastore.is_presence_visible = is_presence_visible
def get_profile_avatar_url(user_localpart):
return defer.succeed("http://foo")
self.datastore.get_profile_avatar_url = get_profile_avatar_url
self.presence_list = [
{"observed_user_id": "@banana:test", "accepted": True},
{"observed_user_id": "@clementine:test", "accepted": True},
]
def get_presence_list(user_localpart, accepted=None):
return defer.succeed(self.presence_list)
self.datastore.get_presence_list = get_presence_list
def user_rooms_intersect(userlist):
return defer.succeed(False)
self.datastore.user_rooms_intersect = user_rooms_intersect
self.handlers = hs.get_handlers()
self.mock_update_client = Mock()
def update(*args, **kwargs):
# print "mock_update_client: %s, %s" %(args, kwargs)
return defer.succeed(None)
self.mock_update_client.side_effect = update
self.handlers.presence_handler.push_update_to_clients = (
self.mock_update_client)
hs.handlers.room_member_handler = Mock(spec=[
"get_joined_rooms_for_user",
])
hs.handlers.room_member_handler.get_joined_rooms_for_user = (
lambda u: defer.succeed([]))
# Some local users to test with
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
self.u_clementine = UserID.from_string("@clementine:test")
# Remote user
self.u_potato = UserID.from_string("@potato:remote")
self.mock_get_joined = (
self.datastore.get_rooms_for_user
)
@defer.inlineCallbacks
def test_set_my_state(self):
self.presence_list = [
{"observed_user_id": "@banana:test", "accepted": True},
{"observed_user_id": "@clementine:test", "accepted": True},
]
mocked_set = self.datastore.set_presence_state
mocked_set.return_value = defer.succeed({"state": OFFLINE})
yield self.handlers.presence_handler.set_state(
target_user=self.u_apple, auth_user=self.u_apple,
state={"presence": UNAVAILABLE, "status_msg": "Away"})
mocked_set.assert_called_with("apple",
{"state": UNAVAILABLE, "status_msg": "Away"}
)
@defer.inlineCallbacks
def test_push_local(self):
def get_joined(*args):
return defer.succeed([])
self.mock_get_joined.side_effect = get_joined
self.presence_list = [
{"observed_user_id": "@banana:test", "accepted": True},
{"observed_user_id": "@clementine:test", "accepted": True},
]
self.datastore.set_presence_state.return_value = defer.succeed(
{"state": ONLINE}
)
# TODO(paul): Gut-wrenching
from synapse.handlers.presence import UserPresenceCache
self.handlers.presence_handler._user_cachemap[self.u_apple] = (
UserPresenceCache()
)
self.handlers.presence_handler._user_cachemap[self.u_apple].update(
{"presence": OFFLINE}, serial=0
)
apple_set = self.handlers.presence_handler._local_pushmap.setdefault(
"apple", set())
apple_set.add(self.u_banana)
apple_set.add(self.u_clementine)
yield self.handlers.presence_handler.set_state(self.u_apple,
self.u_apple, {"presence": ONLINE}
)
yield self.handlers.presence_handler.set_state(self.u_banana,
self.u_banana, {"presence": ONLINE}
)
presence = yield self.handlers.presence_handler.get_presence_list(
observer_user=self.u_apple, accepted=True)
self.assertEquals([
{"observed_user": self.u_banana,
"presence": ONLINE,
"last_active_ago": 0,
"displayname": "Frank",
"avatar_url": "http://foo",
"accepted": True},
{"observed_user": self.u_clementine,
"presence": OFFLINE,
"accepted": True}
], presence)
self.mock_update_client.assert_has_calls([
call(
users_to_push={self.u_apple, self.u_banana, self.u_clementine},
room_ids=[]
),
], any_order=True)
self.mock_update_client.reset_mock()
self.datastore.set_profile_displayname.return_value = defer.succeed(
None)
yield self.handlers.profile_handler.set_displayname(self.u_apple,
self.u_apple, "I am an Apple")
self.mock_update_client.assert_has_calls([
call(
users_to_push={self.u_apple, self.u_banana, self.u_clementine},
room_ids=[],
),
], any_order=True)
@defer.inlineCallbacks
def test_push_remote(self):
self.presence_list = [
{"observed_user_id": "@potato:remote", "accepted": True},
]
self.datastore.set_presence_state.return_value = defer.succeed(
{"state": ONLINE}
)
# TODO(paul): Gut-wrenching
from synapse.handlers.presence import UserPresenceCache
self.handlers.presence_handler._user_cachemap[self.u_apple] = (
UserPresenceCache()
)
self.handlers.presence_handler._user_cachemap[self.u_apple].update(
{"presence": OFFLINE}, serial=0
)
apple_set = self.handlers.presence_handler._remote_sendmap.setdefault(
"apple", set())
apple_set.add(self.u_potato.domain)
yield self.handlers.presence_handler.set_state(self.u_apple,
self.u_apple, {"presence": ONLINE}
)
self.replication.send_edu.assert_called_with(
destination="remote",
edu_type="m.presence",
content={
"push": [
{"user_id": "@apple:test",
"presence": "online",
"last_active_ago": 0,
"displayname": "Frank",
"avatar_url": "http://foo"},
],
},
)
@defer.inlineCallbacks
def test_recv_remote(self):
self.presence_list = [
{"observed_user_id": "@banana:test"},
{"observed_user_id": "@clementine:test"},
]
# TODO(paul): Gut-wrenching
potato_set = self.handlers.presence_handler._remote_recvmap.setdefault(
self.u_potato, set()
)
potato_set.add(self.u_apple)
yield self.replication.received_edu(
"remote", "m.presence", {
"push": [
{"user_id": "@potato:remote",
"presence": "online",
"displayname": "Frank",
"avatar_url": "http://foo"},
],
}
)
self.mock_update_client.assert_called_with(
users_to_push=set([self.u_apple]),
room_ids=[],
)
state = yield self.handlers.presence_handler.get_state(self.u_potato,
self.u_apple)
self.assertEquals(
{"presence": ONLINE,
"displayname": "Frank",
"avatar_url": "http://foo"},
state)

View file

@ -70,9 +70,6 @@ class ProfileTestCase(unittest.TestCase):
self.handler = hs.get_handlers().profile_handler
# TODO(paul): Icky signal declarings.. booo
hs.get_distributor().declare("changed_presencelike_data")
@defer.inlineCallbacks
def test_get_my_name(self):
yield self.store.set_profile_displayname(

View file

@ -1,412 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests REST events for /presence paths."""
from tests import unittest
from twisted.internet import defer
from mock import Mock
from ....utils import MockHttpResource, setup_test_homeserver
from synapse.api.constants import PresenceState
from synapse.handlers.presence import PresenceHandler
from synapse.rest.client.v1 import presence
from synapse.rest.client.v1 import events
from synapse.types import Requester, UserID
from synapse.util.async import run_on_reactor
from collections import namedtuple
OFFLINE = PresenceState.OFFLINE
UNAVAILABLE = PresenceState.UNAVAILABLE
ONLINE = PresenceState.ONLINE
myid = "@apple:test"
PATH_PREFIX = "/_matrix/client/api/v1"
class NullSource(object):
"""This event source never yields any events and its token remains at
zero. It may be useful for unit-testing."""
def __init__(self, hs):
pass
def get_new_events(
self,
user,
from_key,
room_ids=None,
limit=None,
is_guest=None
):
return defer.succeed(([], from_key))
def get_current_key(self, direction='f'):
return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key):
return defer.succeed(([], pagination_config.from_key))
class JustPresenceHandlers(object):
def __init__(self, hs):
self.presence_handler = PresenceHandler(hs)
class PresenceStateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
hs = yield setup_test_homeserver(
datastore=Mock(spec=[
"get_presence_state",
"set_presence_state",
"insert_client_ip",
]),
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
)
hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore()
self.datastore.get_app_service_by_token = Mock(return_value=None)
def get_presence_list(*a, **kw):
return defer.succeed([])
self.datastore.get_presence_list = get_presence_list
def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(myid),
"token_id": 1,
"is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[
"get_joined_rooms_for_user",
]
)
def get_rooms_for_user(user):
return defer.succeed([])
room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
presence.register_servlets(hs, self.mock_resource)
self.u_apple = UserID.from_string(myid)
@defer.inlineCallbacks
def test_get_my_status(self):
mocked_get = self.datastore.get_presence_state
mocked_get.return_value = defer.succeed(
{"state": ONLINE, "status_msg": "Available"}
)
(code, response) = yield self.mock_resource.trigger("GET",
"/presence/%s/status" % (myid), None)
self.assertEquals(200, code)
self.assertEquals(
{"presence": ONLINE, "status_msg": "Available"},
response
)
mocked_get.assert_called_with("apple")
@defer.inlineCallbacks
def test_set_my_status(self):
mocked_set = self.datastore.set_presence_state
mocked_set.return_value = defer.succeed({"state": OFFLINE})
(code, response) = yield self.mock_resource.trigger("PUT",
"/presence/%s/status" % (myid),
'{"presence": "unavailable", "status_msg": "Away"}')
self.assertEquals(200, code)
mocked_set.assert_called_with("apple",
{"state": UNAVAILABLE, "status_msg": "Away"}
)
class PresenceListTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
hs = yield setup_test_homeserver(
datastore=Mock(spec=[
"has_presence_state",
"get_presence_state",
"allow_presence_visible",
"is_presence_visible",
"add_presence_list_pending",
"set_presence_list_accepted",
"del_presence_list",
"get_presence_list",
"insert_client_ip",
]),
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
)
hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore()
self.datastore.get_app_service_by_token = Mock(return_value=None)
def has_presence_state(user_localpart):
return defer.succeed(
user_localpart in ("apple", "banana",)
)
self.datastore.has_presence_state = has_presence_state
def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(myid),
"token_id": 1,
"is_guest": False,
}
hs.handlers.room_member_handler = Mock(
spec=[
"get_joined_rooms_for_user",
]
)
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
presence.register_servlets(hs, self.mock_resource)
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
@defer.inlineCallbacks
def test_get_my_list(self):
self.datastore.get_presence_list.return_value = defer.succeed(
[{"observed_user_id": "@banana:test", "accepted": True}],
)
(code, response) = yield self.mock_resource.trigger("GET",
"/presence/list/%s" % (myid), None)
self.assertEquals(200, code)
self.assertEquals([
{"user_id": "@banana:test", "presence": OFFLINE, "accepted": True},
], response)
self.datastore.get_presence_list.assert_called_with(
"apple", accepted=True
)
@defer.inlineCallbacks
def test_invite(self):
self.datastore.add_presence_list_pending.return_value = (
defer.succeed(())
)
self.datastore.is_presence_visible.return_value = defer.succeed(
True
)
(code, response) = yield self.mock_resource.trigger("POST",
"/presence/list/%s" % (myid),
"""{"invite": ["@banana:test"]}"""
)
self.assertEquals(200, code)
self.datastore.add_presence_list_pending.assert_called_with(
"apple", "@banana:test"
)
self.datastore.set_presence_list_accepted.assert_called_with(
"apple", "@banana:test"
)
@defer.inlineCallbacks
def test_drop(self):
self.datastore.del_presence_list.return_value = (
defer.succeed(())
)
(code, response) = yield self.mock_resource.trigger("POST",
"/presence/list/%s" % (myid),
"""{"drop": ["@banana:test"]}"""
)
self.assertEquals(200, code)
self.datastore.del_presence_list.assert_called_with(
"apple", "@banana:test"
)
class PresenceEventStreamTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
# HIDEOUS HACKERY
# TODO(paul): This should be injected in via the HomeServer DI system
from synapse.streams.events import (
PresenceEventSource, EventSources
)
old_SOURCE_TYPES = EventSources.SOURCE_TYPES
def tearDown():
EventSources.SOURCE_TYPES = old_SOURCE_TYPES
self.tearDown = tearDown
EventSources.SOURCE_TYPES = {
k: NullSource for k in old_SOURCE_TYPES.keys()
}
EventSources.SOURCE_TYPES["presence"] = PresenceEventSource
clock = Mock(spec=[
"call_later",
"cancel_call_later",
"time_msec",
"looping_call",
])
clock.time_msec.return_value = 1000000
hs = yield setup_test_homeserver(
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
datastore=Mock(spec=[
"set_presence_state",
"get_presence_list",
"get_rooms_for_user",
]),
clock=clock,
)
def _get_user_by_req(req=None, allow_guest=False):
return Requester(UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req
presence.register_servlets(hs, self.mock_resource)
events.register_servlets(hs, self.mock_resource)
hs.handlers.room_member_handler = Mock(spec=[])
self.room_members = []
def get_rooms_for_user(user):
if user in self.room_members:
return ["a-room"]
else:
return []
hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
hs.handlers.room_member_handler.get_room_members = (
lambda r: self.room_members if r == "a-room" else []
)
hs.handlers.room_member_handler._filter_events_for_client = (
lambda user_id, events, **kwargs: events
)
self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
self.mock_datastore.get_app_service_by_user_id = Mock(
return_value=defer.succeed(None)
)
self.mock_datastore.get_rooms_for_user = (
lambda u: [
namedtuple("Room", "room_id")(r)
for r in get_rooms_for_user(UserID.from_string(u))
]
)
def get_profile_displayname(user_id):
return defer.succeed("Frank")
self.mock_datastore.get_profile_displayname = get_profile_displayname
def get_profile_avatar_url(user_id):
return defer.succeed(None)
self.mock_datastore.get_profile_avatar_url = get_profile_avatar_url
def user_rooms_intersect(user_list):
room_member_ids = map(lambda u: u.to_string(), self.room_members)
shared = all(map(lambda i: i in room_member_ids, user_list))
return defer.succeed(shared)
self.mock_datastore.user_rooms_intersect = user_rooms_intersect
def get_joined_hosts_for_room(room_id):
return []
self.mock_datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
self.presence = hs.get_handlers().presence_handler
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
@defer.inlineCallbacks
def test_shortpoll(self):
self.room_members = [self.u_apple, self.u_banana]
self.mock_datastore.set_presence_state.return_value = defer.succeed(
{"state": ONLINE}
)
self.mock_datastore.get_presence_list.return_value = defer.succeed(
[]
)
(code, response) = yield self.mock_resource.trigger("GET",
"/events?timeout=0", None)
self.assertEquals(200, code)
# We've forced there to be only one data stream so the tokens will
# all be ours
# I'll already get my own presence state change
self.assertEquals({"start": "0_1_0_0_0", "end": "0_1_0_0_0", "chunk": []},
response
)
self.mock_datastore.set_presence_state.return_value = defer.succeed(
{"state": ONLINE}
)
self.mock_datastore.get_presence_list.return_value = defer.succeed([])
yield self.presence.set_state(self.u_banana, self.u_banana,
state={"presence": ONLINE}
)
yield run_on_reactor()
(code, response) = yield self.mock_resource.trigger("GET",
"/events?from=s0_1_0&timeout=0", None)
self.assertEquals(200, code)
self.assertEquals({"start": "s0_1_0_0_0", "end": "s0_2_0_0_0", "chunk": [
{"type": "m.presence",
"content": {
"user_id": "@banana:test",
"presence": ONLINE,
"displayname": "Frank",
"last_active_ago": 0,
}},
]}, response)

View file

@ -953,12 +953,6 @@ class RoomInitialSyncTestCase(RestTestCase):
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
# Since I'm getting my own presence I need to exist as far as presence
# is concerned.
hs.get_handlers().presence_handler.registered_user(
UserID.from_string(self.user_id)
)
# create the room
self.room_id = yield self.create_room_as(self.user_id)

View file

@ -34,32 +34,6 @@ class PresenceStoreTestCase(unittest.TestCase):
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
@defer.inlineCallbacks
def test_state(self):
yield self.store.create_presence(
self.u_apple.localpart
)
state = yield self.store.get_presence_state(
self.u_apple.localpart
)
self.assertEquals(
{"state": None, "status_msg": None, "mtime": None}, state
)
yield self.store.set_presence_state(
self.u_apple.localpart, {"state": "online", "status_msg": "Here"}
)
state = yield self.store.get_presence_state(
self.u_apple.localpart
)
self.assertEquals(
{"state": "online", "status_msg": "Here", "mtime": 1000000}, state
)
@defer.inlineCallbacks
def test_visibility(self):
self.assertFalse((yield self.store.is_presence_visible(

View file

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import unittest
from synapse.util.wheel_timer import WheelTimer
class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self):
wheel = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 150)
self.assertListEqual(wheel.fetch(101), [])
self.assertListEqual(wheel.fetch(110), [])
self.assertListEqual(wheel.fetch(120), [])
self.assertListEqual(wheel.fetch(130), [])
self.assertListEqual(wheel.fetch(149), [])
self.assertListEqual(wheel.fetch(156), [obj])
self.assertListEqual(wheel.fetch(170), [])
def test_mutli_insert(self):
wheel = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()
obj3 = object()
wheel.insert(100, obj1, 150)
wheel.insert(105, obj2, 130)
wheel.insert(106, obj3, 160)
self.assertListEqual(wheel.fetch(110), [])
self.assertListEqual(wheel.fetch(135), [obj2])
self.assertListEqual(wheel.fetch(149), [])
self.assertListEqual(wheel.fetch(158), [obj1])
self.assertListEqual(wheel.fetch(160), [])
self.assertListEqual(wheel.fetch(200), [obj3])
self.assertListEqual(wheel.fetch(210), [])
def test_insert_past(self):
wheel = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 50)
self.assertListEqual(wheel.fetch(120), [obj])
def test_insert_past_mutli(self):
wheel = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()
obj3 = object()
wheel.insert(100, obj1, 150)
wheel.insert(100, obj2, 140)
wheel.insert(100, obj3, 50)
self.assertListEqual(wheel.fetch(110), [obj3])
self.assertListEqual(wheel.fetch(120), [])
self.assertListEqual(wheel.fetch(147), [obj2])
self.assertListEqual(wheel.fetch(200), [obj1])
self.assertListEqual(wheel.fetch(240), [])

View file

@ -224,12 +224,12 @@ class MockClock(object):
def time_msec(self):
return self.time() * 1000
def call_later(self, delay, callback):
def call_later(self, delay, callback, *args, **kwargs):
current_context = LoggingContext.current_context()
def wrapped_callback():
LoggingContext.thread_local.current_context = current_context
callback()
callback(*args, **kwargs)
t = [self.now + delay, wrapped_callback, False]
self.timers.append(t)