0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-07-05 02:38:54 +02:00

Merge branch 'develop' into daniel/removesomelies

Conflicts:
	synapse/notifier.py
This commit is contained in:
Mark Haines 2015-11-04 16:01:00 +00:00
commit 33b3e04049
20 changed files with 580 additions and 271 deletions

View file

@ -147,6 +147,10 @@ class FilterCollection(object):
self.filter_json.get("room", {}).get("ephemeral", {})
)
self.room_private_user_data = Filter(
self.filter_json.get("room", {}).get("private_user_data", {})
)
self.presence_filter = Filter(
self.filter_json.get("presence", {})
)
@ -172,6 +176,9 @@ class FilterCollection(object):
def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(events)
def filter_room_private_user_data(self, events):
return self.room_private_user_data.filter(events)
class Filter(object):
def __init__(self, filter_json):

View file

@ -202,6 +202,7 @@ class TransactionQueue(object):
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
@ -213,9 +214,6 @@ class TransactionQueue(object):
)
return
logger.debug("TX [%s] _attempt_new_transaction", destination)
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
@ -228,20 +226,22 @@ class TransactionQueue(object):
logger.debug("TX [%s] Nothing to send", destination)
return
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] _attempt_new_transaction", destination)
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
txn_id = str(self._next_txn_id)
limiter = yield get_retry_limiter(

View file

@ -72,8 +72,6 @@ class FederationHandler(BaseHandler):
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.lock_manager = hs.get_room_lock_manager()
self.replication_layer.set_handler(self)
# When joining a room we need to queue any events for that room up

View file

@ -322,6 +322,8 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("receipt"), None
)
tags_by_room = yield self.store.get_tags_for_user(user_id)
public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit
@ -398,6 +400,15 @@ class MessageHandler(BaseHandler):
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
]
private_user_data = []
tags = tags_by_room.get(event.room_id)
if tags:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
d["private_user_data"] = private_user_data
except:
logger.exception("Failed to get snapshot")
@ -447,6 +458,16 @@ class MessageHandler(BaseHandler):
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, member_event
)
private_user_data = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
result["private_user_data"] = private_user_data
defer.returnValue(result)
@defer.inlineCallbacks
@ -476,8 +497,8 @@ class MessageHandler(BaseHandler):
user_id, messages
)
start_token = StreamToken(token[0], 0, 0, 0)
end_token = StreamToken(token[1], 0, 0, 0)
start_token = StreamToken(token[0], 0, 0, 0, 0)
end_token = StreamToken(token[1], 0, 0, 0, 0)
time_now = self.clock.time_msec()

View file

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
class PrivateUserDataEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
def get_current_key(self, direction='f'):
return self.store.get_max_private_user_data_stream_id()
@defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit):
user_id = user.to_string()
last_stream_id = from_key
current_stream_id = yield self.store.get_max_private_user_data_stream_id()
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
results = []
for room_id, room_tags in tags.items():
results.append({
"type": "m.tag",
"content": {"tags": room_tags},
"room_id": room_id,
})
defer.returnValue((results, current_stream_id))
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
defer.returnValue(([], config.to_id))

View file

@ -51,6 +51,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"timeline",
"state",
"ephemeral",
"private_user_data",
])):
__slots__ = []
@ -58,13 +59,19 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(self.timeline or self.state or self.ephemeral)
return bool(
self.timeline
or self.state
or self.ephemeral
or self.private_user_data
)
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id",
"timeline",
"state",
"private_user_data",
])):
__slots__ = []
@ -72,7 +79,11 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(self.timeline or self.state)
return bool(
self.timeline
or self.state
or self.private_user_data
)
class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
@ -161,7 +172,7 @@ class SyncHandler(BaseHandler):
"""
now_token = yield self.event_sources.get_current_token()
now_token, typing_by_room = yield self.typing_by_room(
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token
)
@ -184,6 +195,10 @@ class SyncHandler(BaseHandler):
)
)
tags_by_room = yield self.store.get_tags_for_user(
sync_config.user.to_string()
)
joined = []
invited = []
archived = []
@ -194,7 +209,8 @@ class SyncHandler(BaseHandler):
sync_config=sync_config,
now_token=now_token,
timeline_since_token=timeline_since_token,
typing_by_room=typing_by_room
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
)
joined.append(room_sync)
elif event.membership == Membership.INVITE:
@ -213,6 +229,7 @@ class SyncHandler(BaseHandler):
leave_event_id=event.event_id,
leave_token=leave_token,
timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room,
)
archived.append(room_sync)
@ -227,7 +244,7 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks
def full_state_sync_for_joined_room(self, room_id, sync_config,
now_token, timeline_since_token,
typing_by_room):
ephemeral_by_room, tags_by_room):
"""Sync a room for a client which is starting without any state
Returns:
A Deferred JoinedSyncResult.
@ -246,12 +263,25 @@ class SyncHandler(BaseHandler):
room_id=room_id,
timeline=batch,
state=current_state_events,
ephemeral=typing_by_room.get(room_id, []),
ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
))
def private_user_data_for_room(self, room_id, tags_by_room):
private_user_data = []
tags = tags_by_room.get(room_id)
if tags:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
return private_user_data
@defer.inlineCallbacks
def typing_by_room(self, sync_config, now_token, since_token=None):
"""Get the typing events for each room the user is in
def ephemeral_by_room(self, sync_config, now_token, since_token=None):
"""Get the ephemeral events for each room the user is in
Args:
sync_config (SyncConfig): The flags, filters and user for the sync.
now_token (StreamToken): Where the server is currently up to.
@ -273,17 +303,32 @@ class SyncHandler(BaseHandler):
)
now_token = now_token.copy_and_replace("typing_key", typing_key)
typing_by_room = {event["room_id"]: [event] for event in typing}
for event in typing:
event.pop("room_id")
logger.debug("Typing %r", typing_by_room)
ephemeral_by_room = {}
defer.returnValue((now_token, typing_by_room))
for event in typing:
room_id = event.pop("room_id")
ephemeral_by_room.setdefault(room_id, []).append(event)
receipt_key = since_token.receipt_key if since_token else "0"
receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = yield receipt_source.get_new_events_for_user(
user=sync_config.user,
from_key=receipt_key,
limit=sync_config.filter.ephemeral_limit(),
)
now_token = now_token.copy_and_replace("receipt_key", receipt_key)
for event in receipts:
room_id = event.pop("room_id")
ephemeral_by_room.setdefault(room_id, []).append(event)
defer.returnValue((now_token, ephemeral_by_room))
@defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token,
timeline_since_token):
timeline_since_token, tags_by_room):
"""Sync a room for a client which is starting without any state
Returns:
A Deferred JoinedSyncResult.
@ -301,6 +346,9 @@ class SyncHandler(BaseHandler):
room_id=room_id,
timeline=batch,
state=leave_state[leave_event_id].values(),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
))
@defer.inlineCallbacks
@ -320,7 +368,7 @@ class SyncHandler(BaseHandler):
)
now_token = now_token.copy_and_replace("presence_key", presence_key)
now_token, typing_by_room = yield self.typing_by_room(
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, since_token
)
@ -345,6 +393,11 @@ class SyncHandler(BaseHandler):
limit=timeline_limit + 1,
)
tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(),
since_token.private_user_data_key,
)
joined = []
archived = []
if len(room_events) <= timeline_limit:
@ -385,7 +438,10 @@ class SyncHandler(BaseHandler):
limited=limited,
),
state=state,
ephemeral=typing_by_room.get(room_id, [])
ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
)
if room_sync:
joined.append(room_sync)
@ -402,14 +458,14 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token,
typing_by_room
ephemeral_by_room, tags_by_room
)
if room_sync:
joined.append(room_sync)
for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token
sync_config, leave_event, since_token, tags_by_room
)
archived.append(room_sync)
@ -473,7 +529,7 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks
def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token,
typing_by_room):
ephemeral_by_room, tags_by_room):
""" Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to
state.
@ -514,7 +570,10 @@ class SyncHandler(BaseHandler):
room_id=room_id,
timeline=batch,
state=state_events_delta,
ephemeral=typing_by_room.get(room_id, [])
ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
)
logging.debug("Room sync: %r", room_sync)
@ -523,7 +582,7 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event,
since_token):
since_token, tags_by_room):
""" Get the incremental delta needed to bring the client up to date for
the archived room.
Returns:
@ -564,6 +623,9 @@ class SyncHandler(BaseHandler):
room_id=leave_event.room_id,
timeline=batch,
state=state_events_delta,
private_user_data=self.private_user_data_for_room(
leave_event.room_id, tags_by_room
),
)
logging.debug("Room sync: %r", room_sync)

View file

@ -270,7 +270,7 @@ class Notifier(object):
@defer.inlineCallbacks
def wait_for_events(self, user, timeout, callback,
from_token=StreamToken("s0", "0", "0", "0")):
from_token=StreamToken("s0", "0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""

View file

@ -22,6 +22,7 @@ from . import (
receipts,
keys,
tokenrefresh,
tags,
)
from synapse.http.server import JsonResource
@ -44,3 +45,4 @@ class ClientV2AlphaRestResource(JsonResource):
receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource)
tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource)

View file

@ -220,6 +220,10 @@ class SyncRestServlet(RestServlet):
)
timeline_event_ids.append(event.event_id)
private_user_data = filter.filter_room_private_user_data(
room.private_user_data
)
result = {
"event_map": event_map,
"timeline": {
@ -228,6 +232,7 @@ class SyncRestServlet(RestServlet):
"limited": room.timeline.limited,
},
"state": {"events": state_event_ids},
"private_user_data": {"events": private_user_data},
}
if joined:

View file

@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import client_v2_pattern
from synapse.http.servlet import RestServlet
from synapse.api.errors import AuthError, SynapseError
from twisted.internet import defer
import logging
import simplejson as json
logger = logging.getLogger(__name__)
class TagListServlet(RestServlet):
"""
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
"""
PATTERN = client_v2_pattern(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
)
def __init__(self, hs):
super(TagListServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_GET(self, request, user_id, room_id):
auth_user, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
tags = yield self.store.get_tags_for_room(user_id, room_id)
defer.returnValue((200, {"tags": tags}))
class TagServlet(RestServlet):
"""
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
"""
PATTERN = client_v2_pattern(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
def __init__(self, hs):
super(TagServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, tag):
auth_user, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
try:
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid tag JSON")
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_DELETE(self, request, user_id, room_id, tag):
auth_user, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
TagListServlet(hs).register(http_server)
TagServlet(hs).register(http_server)

View file

@ -29,7 +29,6 @@ from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring
@ -70,7 +69,6 @@ class BaseHomeServer(object):
'auth',
'rest_servlet_factory',
'state_handler',
'room_lock_manager',
'notifier',
'distributor',
'resource_for_client',
@ -201,9 +199,6 @@ class HomeServer(BaseHomeServer):
def build_state_handler(self):
return StateHandler(self)
def build_room_lock_manager(self):
return LockManager()
def build_distributor(self):
return Distributor()

View file

@ -41,6 +41,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore
from .search import SearchStore
from .tags import TagsStore
import logging
@ -71,6 +72,7 @@ class DataStore(RoomMemberStore, RoomStore,
ReceiptsStore,
EndToEndKeyStore,
SearchStore,
TagsStore,
):
def __init__(self, hs):

View file

@ -0,0 +1,38 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS room_tags(
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
tag TEXT NOT NULL, -- The name of the tag.
content TEXT NOT NULL, -- The JSON content of the tag.
CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag)
);
CREATE TABLE IF NOT EXISTS room_tags_revisions (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
stream_id BIGINT NOT NULL, -- The current version of the room tags.
CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id)
);
CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id(
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_id BIGINT NOT NULL,
CHECK (Lock='X')
);
INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0);

216
synapse/storage/tags.py Normal file
View file

@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
from .util.id_generators import StreamIdGenerator
import ujson as json
import logging
logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore):
def __init__(self, hs):
super(TagsStore, self).__init__(hs)
self._private_user_data_id_gen = StreamIdGenerator(
"private_user_data_max_stream_id", "stream_id"
)
def get_max_private_user_data_stream_id(self):
"""Get the current max stream id for the private user data stream
Returns:
A deferred int.
"""
return self._private_user_data_id_gen.get_max_token(self)
@cached()
def get_tags_for_user(self, user_id):
"""Get all the tags for a user.
Args:
user_id(str): The user to get the tags for.
Returns:
A deferred dict mapping from room_id strings to lists of tag
strings.
"""
deferred = self._simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
@deferred.addCallback
def tags_by_room(rows):
tags_by_room = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = json.loads(row["content"])
return tags_by_room
return deferred
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
"""Get all the tags for the rooms where the tags have changed since the
given version
Args:
user_id(str): The user to get the tags for.
stream_id(int): The earliest update to get for the user.
Returns:
A deferred dict mapping from room_id strings to lists of tag
strings for all the rooms that changed since the stream_id token.
"""
def get_updated_tags_txn(txn):
sql = (
"SELECT room_id from room_tags_revisions"
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
room_ids = [row[0] for row in txn.fetchall()]
return room_ids
room_ids = yield self.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
results = {}
if room_ids:
tags_by_room = yield self.get_tags_for_user(user_id)
for room_id in room_ids:
results[room_id] = tags_by_room[room_id]
defer.returnValue(results)
def get_tags_for_room(self, user_id, room_id):
"""Get all the tags for the given room
Args:
user_id(str): The user to get tags for
room_id(str): The room to get tags for
Returns:
A deferred list of string tags.
"""
return self._simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
).addCallback(lambda rows: {
row["tag"]: json.loads(row["content"]) for row in rows
})
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
"""Add a tag to a room for a user.
Args:
user_id(str): The user to add a tag for.
room_id(str): The room to add a tag for.
tag(str): The tag name to add.
content(dict): A json object to associate with the tag.
Returns:
A deferred that completes once the tag has been added.
"""
content_json = json.dumps(content)
def add_tag_txn(txn, next_id):
self._simple_upsert_txn(
txn,
table="room_tags",
keyvalues={
"user_id": user_id,
"room_id": room_id,
"tag": tag,
},
values={
"content": content_json,
}
)
self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self)
defer.returnValue(result)
@defer.inlineCallbacks
def remove_tag_from_room(self, user_id, room_id, tag):
"""Remove a tag from a room for a user.
Returns:
A deferred that completes once the tag has been removed
"""
def remove_tag_txn(txn, next_id):
sql = (
"DELETE FROM room_tags "
" WHERE user_id = ? AND room_id = ? AND tag = ?"
)
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self)
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):
"""Update the latest revision of the tags for the given user and room.
Args:
txn: The database cursor
user_id(str): The ID of the user.
room_id(str): The ID of the room.
next_id(int): The the revision to advance to.
"""
update_max_id_sql = (
"UPDATE private_user_data_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
txn.execute(update_max_id_sql, (next_id, next_id))
update_sql = (
"UPDATE room_tags_revisions"
" SET stream_id = ?"
" WHERE user_id = ?"
" AND room_id = ?"
)
txn.execute(update_sql, (next_id, user_id, room_id))
if txn.rowcount == 0:
insert_sql = (
"INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
" VALUES (?, ?, ?)"
)
try:
txn.execute(insert_sql, (user_id, room_id, next_id))
except self.database_engine.module.IntegrityError:
# Ignore insertion errors. It doesn't matter if the row wasn't
# inserted because if two updates happend concurrently the one
# with the higher stream_id will not be reported to a client
# unless the previous update has completed. It doesn't matter
# which stream_id ends up in the table, as long as it is higher
# than the id that the client has.
pass

View file

@ -253,16 +253,6 @@ class TransactionStore(SQLBaseStore):
retry_interval (int) - how long until next retry in ms
"""
# As this is the new value, we might as well prefill the cache
self.get_destination_retry_timings.prefill(
destination,
{
"destination": destination,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval
},
)
# XXX: we could chose to not bother persisting this if our cache thinks
# this is a NOOP
return self.runInteraction(
@ -275,31 +265,25 @@ class TransactionStore(SQLBaseStore):
def _set_destination_retry_timings(self, txn, destination,
retry_last_ts, retry_interval):
query = (
"UPDATE destinations"
" SET retry_last_ts = ?, retry_interval = ?"
" WHERE destination = ?"
)
txn.call_after(self.get_destination_retry_timings.invalidate, (destination,))
txn.execute(
query,
(
retry_last_ts, retry_interval, destination,
)
self._simple_upsert_txn(
txn,
"destinations",
keyvalues={
"destination": destination,
},
values={
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
},
insertion_values={
"destination": destination,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
}
)
if txn.rowcount == 0:
# destination wasn't already in table. Insert it.
self._simple_insert_txn(
txn,
table="destinations",
values={
"destination": destination,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
}
)
def get_destinations_needing_retry(self):
"""Get all destinations which are due a retry for sending a transaction.

View file

@ -21,6 +21,7 @@ from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.private_user_data import PrivateUserDataEventSource
class EventSources(object):
@ -29,6 +30,7 @@ class EventSources(object):
"presence": PresenceEventSource,
"typing": TypingNotificationEventSource,
"receipt": ReceiptEventSource,
"private_user_data": PrivateUserDataEventSource,
}
def __init__(self, hs):
@ -52,5 +54,8 @@ class EventSources(object):
receipt_key=(
yield self.sources["receipt"].get_current_key()
),
private_user_data_key=(
yield self.sources["private_user_data"].get_current_key()
),
)
defer.returnValue(token)

View file

@ -98,10 +98,13 @@ class EventID(DomainSpecificString):
class StreamToken(
namedtuple(
"Token",
("room_key", "presence_key", "typing_key", "receipt_key")
)
namedtuple("Token", (
"room_key",
"presence_key",
"typing_key",
"receipt_key",
"private_user_data_key",
))
):
_SEPARATOR = "_"
@ -109,7 +112,7 @@ class StreamToken(
def from_string(cls, string):
try:
keys = string.split(cls._SEPARATOR)
if len(keys) == len(cls._fields) - 1:
while len(keys) < len(cls._fields):
# i.e. old token from before receipt_key
keys.append("0")
return cls(*keys)
@ -128,13 +131,14 @@ class StreamToken(
else:
return int(self.room_key[1:].split("-")[-1])
def is_after(self, other_token):
def is_after(self, other):
"""Does this token contain events that the other doesn't?"""
return (
(other_token.room_stream_id < self.room_stream_id)
or (int(other_token.presence_key) < int(self.presence_key))
or (int(other_token.typing_key) < int(self.typing_key))
or (int(other_token.receipt_key) < int(self.receipt_key))
(other.room_stream_id < self.room_stream_id)
or (int(other.presence_key) < int(self.presence_key))
or (int(other.typing_key) < int(self.typing_key))
or (int(other.receipt_key) < int(self.receipt_key))
or (int(other.private_user_data_key) < int(self.private_user_data_key))
)
def copy_and_advance(self, key, new_value):

View file

@ -1,74 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class Lock(object):
def __init__(self, deferred, key):
self._deferred = deferred
self.released = False
self.key = key
def release(self):
self.released = True
self._deferred.callback(None)
def __del__(self):
if not self.released:
logger.critical("Lock was destructed but never released!")
self.release()
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
logger.debug("Releasing lock for key=%r", self.key)
self.release()
class LockManager(object):
""" Utility class that allows us to lock based on a `key` """
def __init__(self):
self._lock_deferreds = {}
@defer.inlineCallbacks
def lock(self, key):
""" Allows us to block until it is our turn.
Args:
key (str)
Returns:
Lock
"""
new_deferred = defer.Deferred()
old_deferred = self._lock_deferreds.get(key)
self._lock_deferreds[key] = new_deferred
if old_deferred:
logger.debug("Queueing on lock for key=%r", key)
yield old_deferred
logger.debug("Obtained lock for key=%r", key)
else:
logger.debug("Entering uncontended lock for key=%r", key)
defer.returnValue(Lock(new_deferred, key))

View file

@ -369,7 +369,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
# all be ours
# I'll already get my own presence state change
self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []},
self.assertEquals({"start": "0_1_0_0_0", "end": "0_1_0_0_0", "chunk": []},
response
)
@ -388,7 +388,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
"/events?from=s0_1_0&timeout=0", None)
self.assertEquals(200, code)
self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [
self.assertEquals({"start": "s0_1_0_0_0", "end": "s0_2_0_0_0", "chunk": [
{"type": "m.presence",
"content": {
"user_id": "@banana:test",

View file

@ -1,108 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from tests import unittest
from synapse.util.lockutils import LockManager
class LockManagerTestCase(unittest.TestCase):
def setUp(self):
self.lock_manager = LockManager()
@defer.inlineCallbacks
def test_one_lock(self):
key = "test"
deferred_lock1 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock1.called)
lock1 = yield deferred_lock1
self.assertFalse(lock1.released)
lock1.release()
self.assertTrue(lock1.released)
@defer.inlineCallbacks
def test_concurrent_locks(self):
key = "test"
deferred_lock1 = self.lock_manager.lock(key)
deferred_lock2 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock1.called)
self.assertFalse(deferred_lock2.called)
lock1 = yield deferred_lock1
self.assertFalse(lock1.released)
self.assertFalse(deferred_lock2.called)
lock1.release()
self.assertTrue(lock1.released)
self.assertTrue(deferred_lock2.called)
lock2 = yield deferred_lock2
lock2.release()
@defer.inlineCallbacks
def test_sequential_locks(self):
key = "test"
deferred_lock1 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock1.called)
lock1 = yield deferred_lock1
self.assertFalse(lock1.released)
lock1.release()
self.assertTrue(lock1.released)
deferred_lock2 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock2.called)
lock2 = yield deferred_lock2
self.assertFalse(lock2.released)
lock2.release()
self.assertTrue(lock2.released)
@defer.inlineCallbacks
def test_with_statement(self):
key = "test"
with (yield self.lock_manager.lock(key)) as lock:
self.assertFalse(lock.released)
self.assertTrue(lock.released)
@defer.inlineCallbacks
def test_two_with_statement(self):
key = "test"
with (yield self.lock_manager.lock(key)):
pass
with (yield self.lock_manager.lock(key)):
pass