Merge branch 'develop' of github.com:matrix-org/synapse into erikj/state_storage

Erik Johnston 2016-09-02 11:04:48 +01:00
commit 657847e4c6
14 changed files with 230 additions and 84 deletions


@@ -134,6 +134,12 @@ Installing prerequisites on Raspbian::
     sudo pip install --upgrade ndg-httpsclient
     sudo pip install --upgrade virtualenv
 
+Installing prerequisites on openSUSE::
+
+    sudo zypper in -t pattern devel_basis
+    sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
+        python-devel libffi-devel libopenssl-devel libjpeg62-devel
+
 To install the synapse homeserver run::
 
     virtualenv -p python2.7 ~/.synapse


@@ -67,6 +67,8 @@ class ApplicationServiceApi(SimpleHttpClient):
     @defer.inlineCallbacks
     def query_user(self, service, user_id):
+        if service.url is None:
+            defer.returnValue(False)
         uri = service.url + ("/users/%s" % urllib.quote(user_id))
         response = None
         try:
@@ -86,6 +88,8 @@ class ApplicationServiceApi(SimpleHttpClient):
     @defer.inlineCallbacks
     def query_alias(self, service, alias):
+        if service.url is None:
+            defer.returnValue(False)
         uri = service.url + ("/rooms/%s" % urllib.quote(alias))
         response = None
         try:
@@ -113,6 +117,8 @@ class ApplicationServiceApi(SimpleHttpClient):
             raise ValueError(
                 "Unrecognised 'kind' argument %r to query_3pe()", kind
             )
+        if service.url is None:
+            defer.returnValue([])
 
         uri = "%s%s/thirdparty/%s/%s" % (
             service.url,
@@ -145,6 +151,9 @@ class ApplicationServiceApi(SimpleHttpClient):
             defer.returnValue([])
 
     def get_3pe_protocol(self, service, protocol):
+        if service.url is None:
+            defer.returnValue({})
+
         @defer.inlineCallbacks
         def _get():
             uri = "%s%s/thirdparty/protocol/%s" % (
@@ -166,6 +175,9 @@ class ApplicationServiceApi(SimpleHttpClient):
     @defer.inlineCallbacks
     def push_bulk(self, service, events, txn_id=None):
+        if service.url is None:
+            defer.returnValue(True)
+
         events = self._serialize(events)
 
         if txn_id is None:
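These guards mean an application service registered with an explicitly null url never generates outbound HTTP: every query short-circuits to its default answer. A minimal sketch of the resulting behaviour, assuming an existing ApplicationServiceApi instance (the stub class is hypothetical):

    from twisted.internet import defer

    class _NullUrlService(object):
        """Stand-in for an ApplicationService registered with url: null."""
        url = None

    @defer.inlineCallbacks
    def check_defaults(api):
        # No outbound request is made; the query short-circuits to its default.
        exists = yield api.query_user(_NullUrlService(), "@alice:example.com")
        assert exists is False  # a URL-less AS can never claim a user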


@@ -86,7 +86,7 @@ def load_appservices(hostname, config_files):
 def _load_appservice(hostname, as_info, config_filename):
     required_string_fields = [
-        "id", "url", "as_token", "hs_token", "sender_localpart"
+        "id", "as_token", "hs_token", "sender_localpart"
     ]
     for field in required_string_fields:
         if not isinstance(as_info.get(field), basestring):
@@ -94,6 +94,14 @@ def _load_appservice(hostname, as_info, config_filename):
                 field, config_filename,
             ))
 
+    # 'url' must either be a string or explicitly null, not missing
+    # to avoid accidentally turning off push for ASes.
+    if (not isinstance(as_info.get("url"), basestring) and
+            as_info.get("url", "") is not None):
+        raise KeyError(
+            "Required string field or explicit null: 'url' (%s)" % (config_filename,)
+        )
+
     localpart = as_info["sender_localpart"]
     if urllib.quote(localpart) != localpart:
         raise ValueError(
@@ -132,6 +140,13 @@ def _load_appservice(hostname, as_info, config_filename):
     for p in protocols:
         if not isinstance(p, str):
             raise KeyError("Bad value for 'protocols' item")
 
+    if as_info["url"] is None:
+        logger.info(
+            "(%s) Explicitly empty 'url' provided. This application service"
+            " will not receive events or queries.",
+            config_filename,
+        )
+
     return ApplicationService(
         token=as_info["as_token"],
         url=as_info["url"],
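A sketch of the new contract (the registration dict is hypothetical and trimmed; real registrations also carry namespaces and other fields):

    as_info = {
        "id": "demo-bridge",
        "url": None,  # explicitly null: accepted, but the AS gets no traffic
        "as_token": "as_secret",
        "hs_token": "hs_secret",
        "sender_localpart": "demo",
    }
    service = _load_appservice("example.com", as_info, "demo.yaml")  # logs the notice above

    del as_info["url"]  # omitting 'url' entirely is now a hard error
    _load_appservice("example.com", as_info, "demo.yaml")  # raises KeyError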


@@ -269,7 +269,7 @@ class FederationClient(FederationBase):
         pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
 
-        pdu = None
+        signed_pdu = None
         for destination in destinations:
             now = self._clock.time_msec()
             last_attempt = pdu_attempts.get(destination, 0)
@@ -299,7 +299,7 @@ class FederationClient(FederationBase):
                 pdu = pdu_list[0]
 
                 # Check signatures are correct.
-                pdu = yield self._check_sigs_and_hashes([pdu])[0]
+                signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
 
                 break
@@ -322,10 +322,10 @@ class FederationClient(FederationBase):
             )
             continue
 
-        if self._get_pdu_cache is not None and pdu:
-            self._get_pdu_cache[event_id] = pdu
+        if self._get_pdu_cache is not None and signed_pdu:
+            self._get_pdu_cache[event_id] = signed_pdu
 
-        defer.returnValue(pdu)
+        defer.returnValue(signed_pdu)
 
     @defer.inlineCallbacks
     @log_function


@@ -230,7 +230,7 @@ class PresenceHandler(object):
         """
         logger.info(
             "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
-            len(self.user_to_current_state)
+            len(self.unpersisted_users_changes)
         )
 
         unpersisted = self.unpersisted_users_changes


@@ -338,7 +338,7 @@ class Mailer(object):
                 # want the generated-from-names one here otherwise we'll
                 # end up with, "new message from Bob in the Bob room"
                 room_name = yield calculate_room_name(
-                    state_by_room[room_id], user_id, fallback_to_members=False
+                    self.store, state_by_room[room_id], user_id, fallback_to_members=False
                 )
 
                 my_member_event = state_by_room[room_id][("m.room.member", user_id)]


@@ -41,6 +41,7 @@ STREAM_NAMES = (
     ("push_rules",),
     ("pushers",),
     ("caches",),
+    ("to_device",),
 )
@@ -142,6 +143,7 @@ class ReplicationResource(Resource):
             pushers_token,
             0, # State stream is no longer a thing
             caches_token,
+            int(stream_token.to_device_key),
         ))
@@ -190,6 +192,7 @@ class ReplicationResource(Resource):
         yield self.push_rules(writer, current_token, limit, request_streams)
         yield self.pushers(writer, current_token, limit, request_streams)
         yield self.caches(writer, current_token, limit, request_streams)
+        yield self.to_device(writer, current_token, limit, request_streams)
         self.streams(writer, current_token, request_streams)
 
         logger.info("Replicated %d rows", writer.total)
@@ -376,6 +379,20 @@ class ReplicationResource(Resource):
             "position", "cache_func", "keys", "invalidation_ts"
         ))
 
+    @defer.inlineCallbacks
+    def to_device(self, writer, current_token, limit, request_streams):
+        current_position = current_token.to_device
+
+        to_device = request_streams.get("to_device")
+
+        if to_device is not None:
+            to_device_rows = yield self.store.get_all_new_device_messages(
+                to_device, current_position, limit
+            )
+            writer.write_header_and_rows("to_device", to_device_rows, (
+                "position", "user_id", "device_id", "message_json"
+            ))
+
 
 class _Writer(object):
     """Writes the streams as a JSON object as the response to the request"""
@@ -404,7 +421,7 @@
 class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
     "events", "presence", "typing", "receipts", "account_data", "backfill",
-    "push_rules", "pushers", "state", "caches",
+    "push_rules", "pushers", "state", "caches", "to_device",
 ))):
     __slots__ = []


@@ -28,3 +28,15 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
     get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
     get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
     delete_messages_for_device = DataStore.delete_messages_for_device.__func__
+
+    def stream_positions(self):
+        result = super(SlavedDeviceInboxStore, self).stream_positions()
+        result["to_device"] = self._device_inbox_id_gen.get_current_token()
+        return result
+
+    def process_replication(self, result):
+        stream = result.get("to_device")
+        if stream:
+            self._device_inbox_id_gen.advance(int(stream["position"]))
+
+        return super(SlavedDeviceInboxStore, self).process_replication(result)
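Together these two hooks let a worker follow the new stream: stream_positions() advertises the slave's current to_device token, the master replies with any newer device_inbox rows, and process_replication() advances the local id generator. A minimal sketch of that loop (the replication transport object and its replicate() call are hypothetical stand-ins for the HTTP long-poll against the resource above):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def replicate_once(slave_store, replication_client):
        # e.g. {"to_device": 1234, ...} from stream_positions()
        positions = slave_store.stream_positions()
        # hypothetical transport: fetch rows past our positions from the master
        result = yield replication_client.replicate(positions)
        # advances _device_inbox_id_gen to the returned "position"
        slave_store.process_replication(result)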


@@ -40,6 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
+        self.notifier = hs.get_notifier()
         self.is_mine_id = hs.is_mine_id
         self.txns = HttpTransactionStore()
@@ -71,9 +72,14 @@ class SendToDeviceRestServlet(servlet.RestServlet):
                 }
                 for device_id, message_content in by_device.items()
             }
-            local_messages[user_id] = messages_by_device
+            if messages_by_device:
+                local_messages[user_id] = messages_by_device
 
-        yield self.store.add_messages_to_device_inbox(local_messages)
+        stream_id = yield self.store.add_messages_to_device_inbox(local_messages)
+
+        self.notifier.on_new_event(
+            "to_device", stream_id, users=local_messages.keys()
+        )
 
         response = (200, {})
         self.txns.store_client_transaction(request, txn_id, response)
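For context, a sketch of the client call this servlet serves (Python 2 to match the codebase; the endpoint was still on the unstable prefix at the time, so treat the exact path, event type and ids as illustrative):

    import json
    import urllib2

    body = {
        "messages": {
            "@alice:example.com": {
                "ALICEDEVICE": {"example_key": "example_value"},
            },
        },
    }
    req = urllib2.Request(
        "https://hs.example.com/_matrix/client/unstable/sendToDevice/m.example.type/txn1",
        data=json.dumps(body),
        headers={
            "Authorization": "Bearer <access_token>",
            "Content-Type": "application/json",
        },
    )
    req.get_method = lambda: "PUT"  # urllib2 defaults to POST when data is set
    urllib2.urlopen(req)

With this change the servlet not only stores each message in the recipient's device inbox but also pokes the notifier, so long-polling /sync requests pick the message up immediately.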


@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.auth import AuthEventTypes
 from synapse.events.snapshot import EventContext
+from synapse.util.async import Linearizer
 
 from collections import namedtuple
@@ -87,6 +88,7 @@ class StateHandler(object):
         # dict of set of event_ids -> _StateCacheEntry.
         self._state_cache = None
+        self.resolve_linearizer = Linearizer()
 
     def start_caching(self):
         logger.debug("start_caching")
@@ -297,85 +299,85 @@ class StateHandler(object):
                 delta_ids={},
             ))
 
-        if self._state_cache is not None:
-            cache = self._state_cache.get(group_names, None)
-            if cache:
-                defer.returnValue(cache)
-
-        logger.info(
-            "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
-        )
-
-        state = {}
-        for st in state_groups_ids.values():
-            for key, e_id in st.items():
-                state.setdefault(key, set()).add(e_id)
-
-        conflicted_state = {
-            k: list(v)
-            for k, v in state.items()
-            if len(v) > 1
-        }
-
-        if conflicted_state:
-            logger.info("Resolving conflicted state for %r", room_id)
-            state_map = yield self.store.get_events(
-                [e_id for st in state_groups_ids.values() for e_id in st.values()],
-                get_prev_content=False
-            )
-            state_sets = [
-                [state_map[e_id] for key, e_id in st.items() if e_id in state_map]
-                for st in state_groups_ids.values()
-            ]
-            new_state, _ = self._resolve_events(
-                state_sets, event_type, state_key
-            )
-            new_state = {
-                key: e.event_id for key, e in new_state.items()
-            }
-        else:
-            new_state = {
-                key: e_ids.pop() for key, e_ids in state.items()
-            }
-
-        state_group = None
-        new_state_event_ids = frozenset(new_state.values())
-        for sg, events in state_groups_ids.items():
-            if new_state_event_ids == frozenset(e_id for e_id in events):
-                state_group = sg
-                break
-
-        if state_group is None:
-            # Worker instances don't have access to this method, but we want
-            # to set the state_group on the main instance to increase cache
-            # hits.
-            if hasattr(self.store, "get_next_state_group"):
-                state_group = self.store.get_next_state_group()
-
-        prev_group = None
-        delta_ids = None
-        for old_group, old_ids in state_groups_ids.items():
-            if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
-                n_delta_ids = {
-                    k: v
-                    for k, v in new_state.items()
-                    if old_ids.get(k) != v
-                }
-                if not delta_ids or len(n_delta_ids) < len(delta_ids):
-                    prev_group = old_group
-                    delta_ids = n_delta_ids
-
-        cache = _StateCacheEntry(
-            state=new_state,
-            state_group=state_group,
-            prev_group=prev_group,
-            delta_ids=delta_ids,
-        )
-
-        if self._state_cache is not None:
-            self._state_cache[group_names] = cache
-
-        defer.returnValue(cache)
+        with (yield self.resolve_linearizer.queue(group_names)):
+            if self._state_cache is not None:
+                cache = self._state_cache.get(group_names, None)
+                if cache:
+                    defer.returnValue(cache)
+
+            logger.info(
+                "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
+            )
+
+            state = {}
+            for st in state_groups_ids.values():
+                for key, e_id in st.items():
+                    state.setdefault(key, set()).add(e_id)
+
+            conflicted_state = {
+                k: list(v)
+                for k, v in state.items()
+                if len(v) > 1
+            }
+
+            if conflicted_state:
+                logger.info("Resolving conflicted state for %r", room_id)
+                state_map = yield self.store.get_events(
+                    [e_id for st in state_groups_ids.values() for e_id in st.values()],
+                    get_prev_content=False
+                )
+                state_sets = [
+                    [state_map[e_id] for key, e_id in st.items() if e_id in state_map]
+                    for st in state_groups_ids.values()
+                ]
+                new_state, _ = self._resolve_events(
+                    state_sets, event_type, state_key
+                )
+                new_state = {
+                    key: e.event_id for key, e in new_state.items()
+                }
+            else:
+                new_state = {
+                    key: e_ids.pop() for key, e_ids in state.items()
+                }
+
+            state_group = None
+            new_state_event_ids = frozenset(new_state.values())
+            for sg, events in state_groups_ids.items():
+                if new_state_event_ids == frozenset(e_id for e_id in events):
+                    state_group = sg
+                    break
+
+            if state_group is None:
+                # Worker instances don't have access to this method, but we want
+                # to set the state_group on the main instance to increase cache
+                # hits.
+                if hasattr(self.store, "get_next_state_group"):
+                    state_group = self.store.get_next_state_group()
+
+            prev_group = None
+            delta_ids = None
+            for old_group, old_ids in state_groups_ids.items():
+                if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
+                    n_delta_ids = {
+                        k: v
+                        for k, v in new_state.items()
+                        if old_ids.get(k) != v
+                    }
+                    if not delta_ids or len(n_delta_ids) < len(delta_ids):
+                        prev_group = old_group
+                        delta_ids = n_delta_ids
+
+            cache = _StateCacheEntry(
+                state=new_state,
+                state_group=state_group,
+                prev_group=prev_group,
+                delta_ids=delta_ids,
+            )
+
+            if self._state_cache is not None:
+                self._state_cache[group_names] = cache
+
+            defer.returnValue(cache)
 
     def resolve_events(self, state_sets, event):
         logger.info(
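The Linearizer here serializes concurrent resolutions of the same set of state groups, so only the first caller pays for the expensive resolution and later callers find the cache freshly populated when their turn comes. A minimal sketch of the pattern (illustrative key, cache and work function):

    from twisted.internet import defer
    from synapse.util.async import Linearizer

    linearizer = Linearizer()

    @defer.inlineCallbacks
    def resolve_once(key, cache, compute):
        # Callers asking for the same key queue up here; queue() yields a
        # context manager that holds the per-key lock for the with block.
        with (yield linearizer.queue(key)):
            if key in cache:
                defer.returnValue(cache[key])
            result = yield compute()
            cache[key] = result
            defer.returnValue(result)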


@@ -33,7 +33,8 @@ class DeviceInboxStore(SQLBaseStore):
             messages_by_user_and_device(dict):
                 Dictionary of user_id to device_id to message.
         Returns:
-            A deferred that resolves when the messages have been inserted.
+            A deferred stream_id that resolves when the messages have been
+            inserted.
         """
 
         def select_devices_txn(txn, user_id, devices):
@@ -81,6 +82,8 @@ class DeviceInboxStore(SQLBaseStore):
             stream_id
         )
 
+        defer.returnValue(self._device_inbox_id_gen.get_current_token())
+
     def get_new_messages_for_device(
         self, user_id, device_id, current_stream_id, limit=100
     ):
@@ -136,5 +139,44 @@ class DeviceInboxStore(SQLBaseStore):
             "delete_messages_for_device", delete_messages_for_device_txn
         )
 
+    def get_all_new_device_messages(self, last_pos, current_pos, limit):
+        """
+        Args:
+            last_pos(int): The stream position to return messages after.
+            current_pos(int): The current stream position; only messages up
+                to and including this position are returned.
+            limit(int): The maximum number of distinct stream positions to
+                return in one batch.
+        Returns:
+            A deferred list of rows from the device inbox
+        """
+        if last_pos == current_pos:
+            return defer.succeed([])
+
+        def get_all_new_device_messages_txn(txn):
+            sql = (
+                "SELECT stream_id FROM device_inbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY stream_id"
+                " ORDER BY stream_id ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (last_pos, current_pos, limit))
+            stream_ids = txn.fetchall()
+            if not stream_ids:
+                return []
+            # fetchall() returns one-element tuples, so unpack the last
+            # stream_id before binding it as a parameter below.
+            max_stream_id_in_limit = stream_ids[-1][0]
+
+            sql = (
+                "SELECT stream_id, user_id, device_id, message_json"
+                " FROM device_inbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+            )
+            txn.execute(sql, (last_pos, max_stream_id_in_limit))
+            return txn.fetchall()
+
+        return self.runInteraction(
+            "get_all_new_device_messages", get_all_new_device_messages_txn
+        )
+
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
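The GROUP BY in the first query bounds the batch by distinct stream positions, so a single stream_id's rows are never split across batches. A sketch of how a consumer might page through the stream (the helper is illustrative; the replication resource above is the real caller):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def drain_device_messages(store, last_pos, batch_size=100):
        current_pos = store.get_to_device_stream_token()
        while last_pos < current_pos:
            rows = yield store.get_all_new_device_messages(
                last_pos, current_pos, batch_size
            )
            if not rows:
                break
            for stream_id, user_id, device_id, message_json in rows:
                pass  # hand message_json to the (user_id, device_id) consumer
            last_pos = rows[-1][0]  # advance past the last stream_id seen
        defer.returnValue(last_pos)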


@@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
         defer.returnValue([ev for res in results.values() for ev in res])
 
-    @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
+    @cachedInlineCallbacks(num_args=3, tree=True)
     def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
         """Get receipts for a single room for sending to clients.


@@ -0,0 +1,32 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    if isinstance(database_engine, PostgresEngine):
+        cur.execute("TRUNCATE sent_transactions")
+    else:
+        cur.execute("DELETE FROM sent_transactions")
+
+    cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)")
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+    pass


@@ -387,8 +387,10 @@ class TransactionStore(SQLBaseStore):
     def _cleanup_transactions(self):
         now = self._clock.time_msec()
         month_ago = now - 30 * 24 * 60 * 60 * 1000
+        six_hours_ago = now - 6 * 60 * 60 * 1000
 
         def _cleanup_transactions_txn(txn):
             txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
+            txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,))
 
         return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)