Initial cut

This commit is contained in:
Erik Johnston 2016-02-15 17:10:40 +00:00
parent 9e7900da1e
commit e5999bfb1a
16 changed files with 1004 additions and 1207 deletions

View file

@ -19,6 +19,8 @@ from synapse.util.logutils import log_function
from synapse.types import UserID
from synapse.events.utils import serialize_event
from synapse.util.logcontext import preserve_context_over_fn
from synapse.api.constants import Membership, EventTypes
from synapse.events import EventBase
from ._base import BaseHandler
@ -126,11 +128,12 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down.
"""
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_handlers().presence_handler
try:
if affect_presence:
yield self.started_stream(auth_user)
context = yield presence_handler.user_syncing(
auth_user_id, affect_presence=affect_presence,
)
with context:
if timeout:
# If they've set a timeout set a minimum limit.
timeout = max(timeout, 500)
@ -145,6 +148,34 @@ class EventStreamHandler(BaseHandler):
is_guest=is_guest, explicit_room_id=room_id
)
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = []
for event in events:
if not isinstance(event, EventBase):
continue
if event.type == EventTypes.Member:
if event.membership != Membership.JOIN:
continue
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = yield self.store.get_users_in_room(event.room_id)
states = yield presence_handler.get_states(
users,
as_event=True,
)
to_add.extend(states)
else:
ev = yield presence_handler.get_state(
UserID.from_string(event.state_key),
as_event=True,
)
to_add.append(ev)
events.extend(to_add)
time_now = self.clock.time_msec()
chunks = [
@ -159,10 +190,6 @@ class EventStreamHandler(BaseHandler):
defer.returnValue(chunk)
finally:
if affect_presence:
self.stopped_stream(auth_user)
class EventHandler(BaseHandler):

View file

@ -21,7 +21,6 @@ from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.types import UserID, RoomStreamToken, StreamToken
@ -254,8 +253,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Message:
presence = self.hs.get_handlers().presence_handler
with PreserveLoggingContext():
presence.bump_presence_active_time(user)
yield presence.bump_presence_active_time(user)
@defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True,
@ -660,10 +658,6 @@ class MessageHandler(BaseHandler):
room_id=room_id,
)
# TODO(paul): I wish I was called with user objects not user_id
# strings...
auth_user = UserID.from_string(user_id)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = [
@ -688,13 +682,11 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def get_presence():
states = yield presence_handler.get_states(
target_users=[UserID.from_string(m.user_id) for m in room_members],
auth_user=auth_user,
[m.user_id for m in room_members],
as_event=True,
check_auth=False,
)
defer.returnValue(states.values())
defer.returnValue(states)
@defer.inlineCallbacks
def get_receipts():

File diff suppressed because it is too large Load diff

View file

@ -49,6 +49,9 @@ class ProfileHandler(BaseHandler):
distributor = hs.get_distributor()
self.distributor = distributor
distributor.declare("collect_presencelike_data")
distributor.declare("changed_presencelike_data")
distributor.observe("registered_user", self.registered_user)
distributor.observe(

View file

@ -582,6 +582,28 @@ class SyncHandler(BaseHandler):
if room_sync:
joined.append(room_sync)
# For each newly joined room, we want to send down presence of
# existing users.
presence_handler = self.hs.get_handlers().presence_handler
extra_presence_users = set()
for room_id in newly_joined_rooms:
users = yield self.store.get_users_in_room(event.room_id)
extra_presence_users.update(users)
# For each new member, send down presence.
for joined_sync in joined:
it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values())
for event in it:
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
extra_presence_users.add(event.state_key)
states = yield presence_handler.get_states(
[u for u in extra_presence_users if u != user_id],
as_event=True,
)
presence.extend(states)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
)

View file

@ -17,7 +17,7 @@
"""
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns
@ -35,8 +35,15 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state(
target_user=user, auth_user=requester.user)
if requester.user != user:
allowed = yield self.handlers.presence_handler.is_visible(
observed_user=user, observer_user=requester.user,
)
if not allowed:
raise AuthError(403, "You are allowed to see their presence.")
state = yield self.handlers.presence_handler.get_state(target_user=user)
defer.returnValue((200, state))
@ -45,6 +52,9 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if requester.user != user:
raise AuthError(403, "Can only set your own presence state")
state = {}
try:
content = json.loads(request.content.read())
@ -63,8 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
except:
raise SynapseError(400, "Unable to parse state")
yield self.handlers.presence_handler.set_state(
target_user=user, auth_user=requester.user, state=state)
yield self.handlers.presence_handler.set_state(user, state)
defer.returnValue((200, {}))
@ -87,11 +96,8 @@ class PresenceListRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.handlers.presence_handler.get_presence_list(
observer_user=user, accepted=True)
for p in presence:
observed_user = p.pop("observed_user")
p["user_id"] = observed_user.to_string()
observer_user=user, accepted=True
)
defer.returnValue((200, presence))

View file

@ -304,18 +304,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
if event["type"] != EventTypes.Member:
continue
chunk.append(event)
# FIXME: should probably be state_key here, not user_id
target_user = UserID.from_string(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it
try:
presence_handler = self.handlers.presence_handler
presence_state = yield presence_handler.get_state(
target_user=target_user,
auth_user=requester.user,
)
event["content"].update(presence_state)
except:
pass
defer.returnValue((200, {
"chunk": chunk
@ -541,6 +529,10 @@ class RoomTypingRestServlet(ClientV1RestServlet):
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
)
def __init__(self, hs):
super(RoomTypingRestServlet, self).__init__(hs)
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
requester = yield self.auth.get_user_by_req(request)
@ -552,6 +544,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
typing_handler = self.handlers.typing_notification_handler
yield self.presence_handler.bump_presence_active_time(requester.user)
if content["typing"]:
yield typing_handler.started_typing(
target_user=target_user,

View file

@ -37,6 +37,7 @@ class ReceiptRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_handlers().receipts_handler
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id):
@ -45,6 +46,8 @@ class ReceiptRestServlet(RestServlet):
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
yield self.presence_handler.bump_presence_active_time(requester.user)
yield self.receipts_handler.received_client_receipt(
room_id,
receipt_type,

View file

@ -25,6 +25,7 @@ from synapse.events.utils import (
)
from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from ._base import client_v2_patterns
import copy
@ -82,6 +83,7 @@ class SyncRestServlet(RestServlet):
self.sync_handler = hs.get_handlers().sync_handler
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks
def on_GET(self, request):
@ -139,17 +141,19 @@ class SyncRestServlet(RestServlet):
else:
since_token = None
if set_presence == "online":
yield self.event_stream_handler.started_stream(user)
affect_presence = set_presence != PresenceState.OFFLINE
try:
if affect_presence:
yield self.presence_handler.set_state(user, {"presence": set_presence})
context = yield self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence,
)
with context:
sync_result = yield self.sync_handler.wait_for_sync_for_user(
sync_config, since_token=since_token, timeout=timeout,
full_state=full_state
)
finally:
if set_presence == "online":
self.event_stream_handler.stopped_stream(user)
time_now = self.clock.time_msec()

View file

@ -20,7 +20,7 @@ from .appservice import (
from ._base import Cache
from .directory import DirectoryStore
from .events import EventsStore
from .presence import PresenceStore
from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
from .registration import RegistrationStore
from .room import RoomStore
@ -47,6 +47,7 @@ from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -110,6 +111,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
@ -119,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore,
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
events_max = self._stream_id_gen.get_max_token(None)
events_max = self._stream_id_gen.get_max_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
@ -135,13 +139,31 @@ class DataStore(RoomMemberStore, RoomStore,
"MembershipStreamChangeCache", events_max,
)
account_max = self._account_data_id_gen.get_max_token(None)
account_max = self._account_data_id_gen.get_max_token()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
self.__presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn, "presence_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._presence_id_gen.get_max_token(),
)
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", min_presence_val,
prefilled_cache=presence_cache_prefill
)
super(DataStore, self).__init__(hs)
def take_presence_startup_info(self):
active_on_startup = self.__presence_on_startup
self.__presence_on_startup = None
return active_on_startup
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
@ -161,6 +183,7 @@ class DataStore(RoomMemberStore, RoomStore,
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = {
row[0]: int(row[1])
@ -174,6 +197,27 @@ class DataStore(RoomMemberStore, RoomStore,
return cache, min_val
def _get_active_presence(self, db_conn):
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
sql = (
"SELECT user_id, state, last_active, last_federation_update,"
" last_user_sync, status_msg, currently_active FROM presence_stream"
" WHERE state != ?"
)
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.cursor_to_dict(txn)
for row in rows:
row["currently_active"] = bool(row["currently_active"])
return [UserPresenceState(**row) for row in rows]
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec())

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 29
SCHEMA_VERSION = 30
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -14,73 +14,128 @@
# limitations under the License.
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
from synapse.api.constants import PresenceState
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from collections import namedtuple
from twisted.internet import defer
class UserPresenceState(namedtuple("UserPresenceState",
("user_id", "state", "last_active", "last_federation_update",
"last_user_sync", "status_msg", "currently_active"))):
"""Represents the current presence state of the user.
user_id (str)
last_active (int): Time in msec that the user last interacted with server.
last_federation_update (int): Time in msec since either a) we sent a presence
update to other servers or b) we received a presence update, depending
on if is a local user or not.
last_user_sync (int): Time in msec that the user last *completed* a sync
(or event stream).
status_msg (str): User set status message.
"""
def copy_and_replace(self, **kwargs):
return self._replace(**kwargs)
@classmethod
def default(cls, user_id):
"""Returns a default presence state.
"""
return cls(
user_id=user_id,
state=PresenceState.OFFLINE,
last_active=0,
last_federation_update=0,
last_user_sync=0,
status_msg=None,
currently_active=False,
)
class PresenceStore(SQLBaseStore):
def create_presence(self, user_localpart):
res = self._simple_insert(
table="presence",
values={"user_id": user_localpart},
desc="create_presence",
@defer.inlineCallbacks
def update_presence(self, presence_states):
stream_id_manager = yield self._presence_id_gen.get_next(self)
with stream_id_manager as stream_id:
yield self.runInteraction(
"update_presence",
self._update_presence_txn, stream_id, presence_states,
)
self.get_presence_state.invalidate((user_localpart,))
return res
defer.returnValue((stream_id, self._presence_id_gen.get_max_token()))
def has_presence_state(self, user_localpart):
return self._simple_select_one(
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["user_id"],
allow_none=True,
desc="has_presence_state",
def _update_presence_txn(self, txn, stream_id, presence_states):
for state in presence_states:
txn.call_after(
self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id,
)
@cached(max_entries=2000)
def get_presence_state(self, user_localpart):
return self._simple_select_one(
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["state", "status_msg", "mtime"],
desc="get_presence_state",
)
@cachedList(get_presence_state.cache, list_name="user_localparts",
inlineCallbacks=True)
def get_presence_states(self, user_localparts):
rows = yield self._simple_select_many_batch(
table="presence",
column="user_id",
iterable=user_localparts,
retcols=("user_id", "state", "status_msg", "mtime",),
desc="get_presence_states",
)
defer.returnValue({
row["user_id"]: {
"state": row["state"],
"status_msg": row["status_msg"],
"mtime": row["mtime"],
# Actually insert new rows
self._simple_insert_many_txn(
txn,
table="presence_stream",
values=[
{
"stream_id": stream_id,
"user_id": state.user_id,
"state": state.state,
"last_active": state.last_active,
"last_federation_update": state.last_federation_update,
"last_user_sync": state.last_user_sync,
"status_msg": state.status_msg,
"currently_active": state.currently_active,
}
for row in rows
})
for state in presence_states
],
)
# Delete old rows to stop database from getting really big
sql = (
"DELETE FROM presence_stream WHERE"
" stream_id < ?"
" AND user_id IN (%s)"
)
batches = (
presence_states[i:i + 50]
for i in xrange(0, len(presence_states), 50)
)
for states in batches:
args = [stream_id]
args.extend(s.user_id for s in states)
txn.execute(
sql % (",".join("?" for _ in states),),
args
)
@defer.inlineCallbacks
def set_presence_state(self, user_localpart, new_state):
res = yield self._simple_update_one(
table="presence",
keyvalues={"user_id": user_localpart},
updatevalues={"state": new_state["state"],
"status_msg": new_state["status_msg"],
"mtime": self._clock.time_msec()},
desc="set_presence_state",
def get_presence_for_users(self, user_ids):
rows = yield self._simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
keyvalues={},
retcols=(
"user_id",
"state",
"last_active",
"last_federation_update",
"last_user_sync",
"status_msg",
"currently_active",
),
)
self.get_presence_state.invalidate((user_localpart,))
defer.returnValue(res)
for row in rows:
row["currently_active"] = bool(row["currently_active"])
defer.returnValue([UserPresenceState(**row) for row in rows])
def get_current_presence_token(self):
return self._presence_id_gen.get_max_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
@ -128,6 +183,7 @@ class PresenceStore(SQLBaseStore):
desc="set_presence_list_accepted",
)
self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,))
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
@ -154,6 +210,19 @@ class PresenceStore(SQLBaseStore):
desc="get_presence_list_accepted",
)
@cachedInlineCallbacks()
def get_presence_list_observers_accepted(self, observed_userid):
user_localparts = yield self._simple_select_onecol(
table="presence_list",
keyvalues={"observed_user_id": observed_userid, "accepted": True},
retcol="user_id",
desc="get_presence_list_accepted",
)
defer.returnValue([
"@%s:%s" % (u, self.hs.hostname,) for u in user_localparts
])
@defer.inlineCallbacks
def del_presence_list(self, observer_localpart, observed_userid):
yield self._simple_delete_one(
@ -163,3 +232,4 @@ class PresenceStore(SQLBaseStore):
desc="del_presence_list",
)
self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,))

View file

@ -0,0 +1,30 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE presence_stream(
stream_id BIGINT,
user_id TEXT,
state TEXT,
last_active BIGINT,
last_federation_update BIGINT,
last_user_sync BIGINT,
status_msg TEXT,
currently_active BOOLEAN
);
CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id);
CREATE INDEX presence_stream_user_id ON presence_stream(user_id);
CREATE INDEX presence_stream_state ON presence_stream(state);

View file

@ -130,9 +130,11 @@ class StreamIdGenerator(object):
return manager()
def get_max_token(self, store):
def get_max_token(self, *args):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Used to take a DataStore param, which is no longer needed.
"""
with self._lock:
if self._unfinished_ids:

View file

@ -42,7 +42,7 @@ class Clock(object):
def time_msec(self):
"""Returns the current system time in miliseconds since epoch."""
return self.time() * 1000
return int(self.time() * 1000)
def looping_call(self, f, msec):
l = task.LoopingCall(f)

View file

@ -224,12 +224,12 @@ class MockClock(object):
def time_msec(self):
return self.time() * 1000
def call_later(self, delay, callback):
def call_later(self, delay, callback, *args, **kwargs):
current_context = LoggingContext.current_context()
def wrapped_callback():
LoggingContext.thread_local.current_context = current_context
callback()
callback(*args, **kwargs)
t = [self.now + delay, wrapped_callback, False]
self.timers.append(t)