0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 15:43:52 +01:00

Merge pull request #199 from matrix-org/erikj/receipts

Implement read receipts.
This commit is contained in:
Erik Johnston 2015-07-16 18:18:36 +01:00
commit b6d4a4c6d8
19 changed files with 724 additions and 48 deletions

View file

@ -32,6 +32,7 @@ from .appservice import ApplicationServicesHandler
from .sync import SyncHandler from .sync import SyncHandler
from .auth import AuthHandler from .auth import AuthHandler
from .identity import IdentityHandler from .identity import IdentityHandler
from .receipts import ReceiptsHandler
class Handlers(object): class Handlers(object):
@ -57,6 +58,7 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs)
asapi = ApplicationServiceApi(hs) asapi = ApplicationServiceApi(hs)
self.appservice_handler = ApplicationServicesHandler( self.appservice_handler = ApplicationServicesHandler(
hs, asapi, AppServiceScheduler( hs, asapi, AppServiceScheduler(

View file

@ -334,6 +334,11 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("presence"), None user, pagination_config.get_source_config("presence"), None
) )
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
public_room_ids = yield self.store.get_public_room_ids() public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit limit = pagin_config.limit
@ -404,7 +409,8 @@ class MessageHandler(BaseHandler):
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,
"presence": presence, "presence": presence,
"end": now_token.to_string() "receipts": receipt,
"end": now_token.to_string(),
} }
defer.returnValue(ret) defer.returnValue(ret)
@ -465,9 +471,12 @@ class MessageHandler(BaseHandler):
defer.returnValue([p for success, p in presence_defs if success]) defer.returnValue([p for success, p in presence_defs if success])
presence, (messages, token) = yield defer.gatherResults( receipts_handler = self.hs.get_handlers().receipts_handler
presence, receipts, (messages, token) = yield defer.gatherResults(
[ [
get_presence(), get_presence(),
receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key),
self.store.get_recent_events_for_room( self.store.get_recent_events_for_room(
room_id, room_id,
limit=limit, limit=limit,
@ -495,5 +504,6 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(), "end": end_token.to_string(),
}, },
"state": state, "state": state,
"presence": presence "presence": presence,
"receipts": receipts,
}) })

View file

@ -992,7 +992,7 @@ class PresenceHandler(BaseHandler):
room_ids([str]): List of room_ids to notify. room_ids([str]): List of room_ids to notify.
""" """
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notifier.on_new_user_event( self.notifier.on_new_event(
"presence_key", "presence_key",
self._user_cachemap_latest_serial, self._user_cachemap_latest_serial,
users_to_push, users_to_push,

View file

@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseHandler
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
import logging
logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler):
def __init__(self, hs):
super(ReceiptsHandler, self).__init__(hs)
self.hs = hs
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()
self._receipt_cache = None
@defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id,
event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
receipt = {
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": [event_id],
"data": {
"ts": int(self.clock.time_msec()),
}
}
is_new = yield self._handle_new_receipts([receipt])
if is_new:
self._push_remotes([receipt])
@defer.inlineCallbacks
def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = [
{
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": user_values["event_ids"],
"data": user_values.get("data", {}),
}
for room_id, room_values in content.items()
for receipt_type, users in room_values.items()
for user_id, user_values in users.items()
]
yield self._handle_new_receipts(receipts)
@defer.inlineCallbacks
def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier.
"""
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
res = yield self.store.insert_receipt(
room_id, receipt_type, user_id, event_ids, data
)
if not res:
# res will be None if this read receipt is 'old'
defer.returnValue(False)
stream_id, max_persisted_id = res
with PreserveLoggingContext():
self.notifier.on_new_event(
"receipt_key", max_persisted_id, rooms=[room_id]
)
defer.returnValue(True)
@defer.inlineCallbacks
def _push_remotes(self, receipts):
"""Given a list of receipts, works out which remote servers should be
poked and pokes them.
"""
# TODO: Some of this stuff should be coallesced.
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
remotedomains = set()
rm_handler = self.hs.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=None, remotedomains=remotedomains
)
logger.debug("Sending receipt to: %r", remotedomains)
for domain in remotedomains:
self.federation.send_edu(
destination=domain,
edu_type="m.receipt",
content={
room_id: {
receipt_type: {
user_id: {
"event_ids": event_ids,
"data": data,
}
}
},
},
)
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):
"""Gets all receipts for a room, upto the given key.
"""
result = yield self.store.get_linearized_receipts_for_room(
room_id,
to_key=to_key,
)
if not result:
defer.returnValue([])
event = {
"type": "m.receipt",
"room_id": room_id,
"content": result,
}
defer.returnValue([event])
class ReceiptEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
to_key = yield self.get_current_key()
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms(
rooms,
from_key=from_key,
to_key=to_key,
)
defer.returnValue((events, to_key))
def get_current_key(self, direction='f'):
return self.store.get_max_receipt_stream_id()
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key)
if config.to_key:
from_key = int(config.to_key)
else:
from_key = None
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms(
rooms,
from_key=from_key,
to_key=to_key,
)
defer.returnValue((events, to_key))

View file

@ -218,7 +218,7 @@ class TypingNotificationHandler(BaseHandler):
self._room_serials[room_id] = self._latest_room_serial self._room_serials[room_id] = self._latest_room_serial
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notifier.on_new_user_event( self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[room_id] "typing_key", self._latest_room_serial, rooms=[room_id]
) )

View file

@ -221,16 +221,7 @@ class Notifier(object):
event event
) )
room_id = event.room_id app_streams = set()
room_user_streams = self.room_to_user_streams.get(room_id, set())
user_streams = room_user_streams.copy()
for user in extra_users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
for appservice in self.appservice_to_user_streams: for appservice in self.appservice_to_user_streams:
# TODO (kegan): Redundant appservice listener checks? # TODO (kegan): Redundant appservice listener checks?
@ -242,24 +233,20 @@ class Notifier(object):
app_user_streams = self.appservice_to_user_streams.get( app_user_streams = self.appservice_to_user_streams.get(
appservice, set() appservice, set()
) )
user_streams |= app_user_streams app_streams |= app_user_streams
logger.debug("on_new_room_event listeners %s", user_streams) self.on_new_event(
"room_key", room_stream_id,
time_now_ms = self.clock.time_msec() users=extra_users,
for user_stream in user_streams: rooms=[event.room_id],
try: extra_streams=app_streams,
user_stream.notify(
"room_key", "s%d" % (room_stream_id,), time_now_ms
) )
except:
logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]): def on_new_event(self, stream_key, new_token, users=[], rooms=[],
""" Used to inform listeners that something has happend extra_streams=set()):
presence/user event wise. """ Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
@ -283,7 +270,7 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, rooms, timeout, callback, def wait_for_events(self, user, rooms, timeout, callback,
from_token=StreamToken("s0", "0", "0")): from_token=StreamToken("s0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """

View file

@ -31,6 +31,7 @@ REQUIREMENTS = {
"pillow": ["PIL"], "pillow": ["PIL"],
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
"ujson": ["ujson"], "ujson": ["ujson"],
"blist": ["blist"],
"pysaml2": ["saml2"], "pysaml2": ["saml2"],
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {

View file

@ -19,6 +19,7 @@ from . import (
account, account,
register, register,
auth, auth,
receipts,
keys, keys,
) )
@ -39,4 +40,5 @@ class ClientV2AlphaRestResource(JsonResource):
account.register_servlets(hs, client_resource) account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource) register.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource) auth.register_servlets(hs, client_resource)
receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource)

View file

@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern
import logging
logger = logging.getLogger(__name__)
class ReceiptRestServlet(RestServlet):
PATTERN = client_v2_pattern(
"/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
super(ReceiptRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_handlers().receipts_handler
@defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id):
user, client = yield self.auth.get_user_by_req(request)
yield self.receipts_handler.received_client_receipt(
room_id,
receipt_type,
user_id=user.to_string(),
event_id=event_id
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReceiptRestServlet(hs).register(http_server)

View file

@ -39,6 +39,8 @@ from .signatures import SignatureStore
from .filtering import FilteringStore from .filtering import FilteringStore
from .end_to_end_keys import EndToEndKeyStore from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore
import fnmatch import fnmatch
import imp import imp
@ -75,6 +77,7 @@ class DataStore(RoomMemberStore, RoomStore,
PushRuleStore, PushRuleStore,
ApplicationServiceTransactionStore, ApplicationServiceTransactionStore,
EventsStore, EventsStore,
ReceiptsStore,
EndToEndKeyStore, EndToEndKeyStore,
): ):

View file

@ -329,13 +329,14 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator() self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
def start_profiling(self): def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec() self._previous_loop_ts = self._clock.time_msec()

348
synapse/storage/receipts.py Normal file
View file

@ -0,0 +1,348 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from twisted.internet import defer
from synapse.util import unwrapFirstError
from blist import sorteddict
import logging
import ujson as json
logger = logging.getLogger(__name__)
class ReceiptsStore(SQLBaseStore):
def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = _RoomStreamChangeCache()
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
"""Get receipts for multiple rooms for sending to clients.
Args:
room_ids (list): List of room_ids.
to_key (int): Max stream id to fetch receipts upto.
from_key (int): Min stream id to fetch receipts from. None fetches
from the start.
Returns:
list: A list of receipts.
"""
room_ids = set(room_ids)
if from_key:
room_ids = yield self._receipts_stream_cache.get_rooms_changed(
self, room_ids, from_key
)
results = yield defer.gatherResults(
[
self.get_linearized_receipts_for_room(
room_id, to_key, from_key=from_key
)
for room_id in room_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue([ev for res in results for ev in res])
@defer.inlineCallbacks
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
Args:
room_ids (str): The room id.
to_key (int): Max stream id to fetch receipts upto.
from_key (int): Min stream id to fetch receipts from. None fetches
from the start.
Returns:
list: A list of receipts.
"""
def f(txn):
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id = ? AND stream_id > ? AND stream_id <= ?"
)
txn.execute(
sql,
(room_id, from_key, to_key)
)
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id = ? AND stream_id <= ?"
)
txn.execute(
sql,
(room_id, to_key)
)
rows = self.cursor_to_dict(txn)
return rows
rows = yield self.runInteraction(
"get_linearized_receipts_for_room", f
)
if not rows:
defer.returnValue([])
content = {}
for row in rows:
content.setdefault(
row["event_id"], {}
).setdefault(
row["receipt_type"], {}
)[row["user_id"]] = json.loads(row["data"])
defer.returnValue([{
"type": "m.receipt",
"room_id": room_id,
"content": content,
}])
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self)
@cached
@defer.inlineCallbacks
def get_graph_receipts_for_room(self, room_id):
"""Get receipts for sending to remote servers.
"""
rows = yield self._simple_select_list(
table="receipts_graph",
keyvalues={"room_id": room_id},
retcols=["receipt_type", "user_id", "event_id"],
desc="get_linearized_receipts_for_room",
)
result = {}
for row in rows:
result.setdefault(
row["user_id"], {}
).setdefault(
row["receipt_type"], []
).append(row["event_id"])
defer.returnValue(result)
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
sql = (
"SELECT topological_ordering, stream_ordering, event_id FROM events"
" INNER JOIN receipts_linearized as r USING (event_id, room_id)"
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
)
txn.execute(sql, (room_id, receipt_type, user_id))
results = txn.fetchall()
if results:
res = self._simple_select_one_txn(
txn,
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
)
topological_ordering = int(res["topological_ordering"])
stream_ordering = int(res["stream_ordering"])
for to, so, _ in results:
if int(to) > topological_ordering:
return False
elif int(to) == topological_ordering and int(so) >= stream_ordering:
return False
self._simple_delete_txn(
txn,
table="receipts_linearized",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
)
self._simple_insert_txn(
txn,
table="receipts_linearized",
values={
"stream_id": stream_id,
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_id": event_id,
"data": json.dumps(data),
}
)
return True
@defer.inlineCallbacks
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
representations.
"""
if not event_ids:
return
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
else:
# we need to points in graph -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn):
query = (
"SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
" SELECT max(stream_ordering) WHERE event_id IN (%s)"
")"
) % (",".join(["?"] * len(event_ids)))
txn.execute(query, [room_id] + event_ids)
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
linearized_event_id = yield self.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = yield self._receipts_id_gen.get_next(self)
with stream_id_manager as stream_id:
yield self._receipts_stream_cache.room_has_changed(
self, room_id, stream_id
)
have_persisted = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id, receipt_type, user_id, linearized_event_id,
data,
stream_id=stream_id,
)
if not have_persisted:
defer.returnValue(None)
yield self.insert_graph_receipt(
room_id, receipt_type, user_id, event_ids, data
)
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
defer.returnValue((stream_id, max_persisted_id))
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
data):
return self.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id, receipt_type, user_id, event_ids, data
)
def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_ids, data):
self._simple_delete_txn(
txn,
table="receipts_graph",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
)
self._simple_insert_txn(
txn,
table="receipts_graph",
values={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": json.dumps(event_ids),
"data": json.dumps(data),
}
)
class _RoomStreamChangeCache(object):
"""Keeps track of the stream_id of the latest change in rooms.
Given a list of rooms and stream key, it will give a subset of rooms that
may have changed since that key. If the key is too old then the cache
will simply return all rooms.
"""
def __init__(self, size_of_cache=1000):
self._size_of_cache = size_of_cache
self._room_to_key = {}
self._cache = sorteddict()
self._earliest_key = None
@defer.inlineCallbacks
def get_rooms_changed(self, store, room_ids, key):
"""Returns subset of room ids that have had new receipts since the
given key. If the key is too old it will just return the given list.
"""
if key > (yield self._get_earliest_key(store)):
keys = self._cache.keys()
i = keys.bisect_right(key)
result = set(
self._cache[k] for k in keys[i:]
).intersection(room_ids)
else:
result = room_ids
defer.returnValue(result)
@defer.inlineCallbacks
def room_has_changed(self, store, room_id, key):
"""Informs the cache that the room has been changed at the given key.
"""
if key > (yield self._get_earliest_key(store)):
old_key = self._room_to_key.get(room_id, None)
if old_key:
key = max(key, old_key)
self._cache.pop(old_key, None)
self._cache[key] = room_id
while len(self._cache) > self._size_of_cache:
k, r = self._cache.popitem()
self._earliest_key = max(k, self._earliest_key)
self._room_to_key.pop(r, None)
@defer.inlineCallbacks
def _get_earliest_key(self, store):
if self._earliest_key is None:
self._earliest_key = yield store.get_max_receipt_stream_id()
self._earliest_key = int(self._earliest_key)
defer.returnValue(self._earliest_key)

View file

@ -0,0 +1,38 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS receipts_graph(
room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL,
event_ids TEXT NOT NULL,
data TEXT NOT NULL,
CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id)
);
CREATE TABLE IF NOT EXISTS receipts_linearized (
stream_id BIGINT NOT NULL,
room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL,
event_id TEXT NOT NULL,
data TEXT NOT NULL,
CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id)
);
CREATE INDEX receipts_linearized_id ON receipts_linearized(
stream_id
);

View file

@ -72,7 +72,10 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id: with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ... # ... persist event ...
""" """
def __init__(self): def __init__(self, table, column):
self.table = table
self.column = column
self._lock = threading.Lock() self._lock = threading.Lock()
self._current_max = None self._current_max = None
@ -157,7 +160,7 @@ class StreamIdGenerator(object):
def _get_or_compute_current_max(self, txn): def _get_or_compute_current_max(self, txn):
with self._lock: with self._lock:
txn.execute("SELECT MAX(stream_ordering) FROM events") txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
rows = txn.fetchall() rows = txn.fetchall()
val, = rows[0] val, = rows[0]

View file

@ -20,6 +20,7 @@ from synapse.types import StreamToken
from synapse.handlers.presence import PresenceEventSource from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource
class NullSource(object): class NullSource(object):
@ -43,6 +44,7 @@ class EventSources(object):
"room": RoomEventSource, "room": RoomEventSource,
"presence": PresenceEventSource, "presence": PresenceEventSource,
"typing": TypingNotificationEventSource, "typing": TypingNotificationEventSource,
"receipt": ReceiptEventSource,
} }
def __init__(self, hs): def __init__(self, hs):
@ -62,7 +64,10 @@ class EventSources(object):
), ),
typing_key=( typing_key=(
yield self.sources["typing"].get_current_key() yield self.sources["typing"].get_current_key()
) ),
receipt_key=(
yield self.sources["receipt"].get_current_key()
),
) )
defer.returnValue(token) defer.returnValue(token)

View file

@ -100,7 +100,7 @@ class EventID(DomainSpecificString):
class StreamToken( class StreamToken(
namedtuple( namedtuple(
"Token", "Token",
("room_key", "presence_key", "typing_key") ("room_key", "presence_key", "typing_key", "receipt_key")
) )
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -109,6 +109,9 @@ class StreamToken(
def from_string(cls, string): def from_string(cls, string):
try: try:
keys = string.split(cls._SEPARATOR) keys = string.split(cls._SEPARATOR)
if len(keys) == len(cls._fields) - 1:
# i.e. old token from before receipt_key
keys.append("0")
return cls(*keys) return cls(*keys)
except: except:
raise SynapseError(400, "Invalid Token") raise SynapseError(400, "Invalid Token")
@ -131,6 +134,7 @@ class StreamToken(
(other_token.room_stream_id < self.room_stream_id) (other_token.room_stream_id < self.room_stream_id)
or (int(other_token.presence_key) < int(self.presence_key)) or (int(other_token.presence_key) < int(self.presence_key))
or (int(other_token.typing_key) < int(self.typing_key)) or (int(other_token.typing_key) < int(self.typing_key))
or (int(other_token.receipt_key) < int(self.receipt_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):

View file

@ -66,8 +66,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.mock_federation_resource = MockHttpResource() self.mock_federation_resource = MockHttpResource()
mock_notifier = Mock(spec=["on_new_user_event"]) mock_notifier = Mock(spec=["on_new_event"])
self.on_new_user_event = mock_notifier.on_new_user_event self.on_new_event = mock_notifier.on_new_event
self.auth = Mock(spec=[]) self.auth = Mock(spec=[])
@ -182,7 +182,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=20000, timeout=20000,
) )
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
@ -245,7 +245,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
) )
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
@ -299,7 +299,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
room_id=self.room_id, room_id=self.room_id,
) )
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
@ -331,10 +331,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=10000, timeout=10000,
) )
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
self.on_new_user_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
@ -351,7 +351,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.clock.advance_time(11) self.clock.advance_time(11)
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 2, rooms=[self.room_id]), call('typing_key', 2, rooms=[self.room_id]),
]) ])
@ -377,10 +377,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=10000, timeout=10000,
) )
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 3, rooms=[self.room_id]), call('typing_key', 3, rooms=[self.room_id]),
]) ])
self.on_new_user_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3) self.assertEquals(self.event_source.get_current_key(), 3)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)

View file

@ -183,7 +183,17 @@ class EventStreamPermissionsTestCase(RestTestCase):
) )
self.assertEquals(200, code, msg=str(response)) self.assertEquals(200, code, msg=str(response))
self.assertEquals(0, len(response["chunk"])) # We may get a presence event for ourselves down
self.assertEquals(
0,
len([
c for c in response["chunk"]
if not (
c.get("type") == "m.presence"
and c["content"].get("user_id") == self.user_id
)
])
)
# joined room (expect all content for room) # joined room (expect all content for room)
yield self.join(room=room_id, user=self.user_id, tok=self.token) yield self.join(room=room_id, user=self.user_id, tok=self.token)

View file

@ -357,7 +357,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
# all be ours # all be ours
# I'll already get my own presence state change # I'll already get my own presence state change
self.assertEquals({"start": "0_1_0", "end": "0_1_0", "chunk": []}, self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []},
response response
) )
@ -376,7 +376,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
"/events?from=s0_1_0&timeout=0", None) "/events?from=s0_1_0&timeout=0", None)
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [ self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [
{"type": "m.presence", {"type": "m.presence",
"content": { "content": {
"user_id": "@banana:test", "user_id": "@banana:test",