Hook device list updates to replication

This commit is contained in:
Erik Johnston 2017-01-27 13:36:39 +00:00
parent 84a35f32c7
commit 252b503fc8
7 changed files with 159 additions and 29 deletions

View file

@ -30,6 +30,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.util.async import sleep from synapse.util.async import sleep
@ -56,7 +57,7 @@ logger = logging.getLogger("synapse.app.appservice")
class FederationSenderSlaveStore( class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore, SlavedRegistrationStore, SlavedDeviceStore,
): ):
pass pass

View file

@ -39,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore from synapse.storage.client_ips import ClientIpStore
@ -77,6 +78,7 @@ class SynchrotronSlavedStore(
SlavedFilteringStore, SlavedFilteringStore,
SlavedPresenceStore, SlavedPresenceStore,
SlavedDeviceInboxStore, SlavedDeviceInboxStore,
SlavedDeviceStore,
RoomStore, RoomStore,
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
@ -380,6 +382,28 @@ class SynchrotronServer(HomeServer):
stream_key, position, users=users, rooms=rooms stream_key, position, users=users, rooms=rooms
) )
@defer.inlineCallbacks
def notify_device_list_update(result):
stream = result.get("device_lists")
if not stream:
return
position_index = stream["field_names"].index("position")
user_index = stream["field_names"].index("user_id")
for row in stream["rows"]:
logger.info("Handling device list row: %r", row)
position = row[position_index]
user_id = row[user_index]
rooms = yield store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
@defer.inlineCallbacks
def notify(result): def notify(result):
stream = result.get("events") stream = result.get("events")
if stream: if stream:
@ -417,6 +441,7 @@ class SynchrotronServer(HomeServer):
notify_from_stream( notify_from_stream(
result, "to_device", "to_device_key", user="user_id" result, "to_device", "to_device_key", user="user_id"
) )
yield notify_device_list_update(result)
while True: while True:
try: try:
@ -427,7 +452,7 @@ class SynchrotronServer(HomeServer):
yield store.process_replication(result) yield store.process_replication(result)
typing_handler.process_replication(result) typing_handler.process_replication(result)
yield presence_handler.process_replication(result) yield presence_handler.process_replication(result)
notify(result) yield notify(result)
except: except:
logger.exception("Error replicating from %r", replication_url) logger.exception("Error replicating from %r", replication_url)
yield sleep(5) yield sleep(5)

View file

@ -220,22 +220,6 @@ class DeviceHandler(BaseHandler):
for host in hosts: for host in hosts:
self.federation_sender.send_device_messages(host) self.federation_sender.send_device_messages(host)
@defer.inlineCallbacks
def get_device_list_changes(self, user_id, room_ids, from_key):
"""For a user and their joined rooms, calculate which device updates
we need to return.
"""
room_ids = frozenset(room_ids)
user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(from_key)
for other_user_id in changed:
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
defer.returnValue(user_ids_changed)
@defer.inlineCallbacks @defer.inlineCallbacks
def _incoming_device_list_update(self, origin, edu_content): def _incoming_device_list_update(self, origin, edu_content):
user_id = edu_content["user_id"] user_id = edu_content["user_id"]

View file

@ -144,7 +144,6 @@ class SyncHandler(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.response_cache = ResponseCache(hs) self.response_cache = ResponseCache(hs)
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.device_handler = hs.get_device_handler()
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False): full_state=False):
@ -546,15 +545,9 @@ class SyncHandler(object):
yield self._generate_sync_entry_for_to_device(sync_result_builder) yield self._generate_sync_entry_for_to_device(sync_result_builder)
if since_token and since_token.device_list_key: device_lists = yield self._generate_sync_entry_for_device_list(
user_id = sync_config.user.to_string() sync_result_builder
rooms = yield self.store.get_rooms_for_user(user_id) )
joined_room_ids = set(r.room_id for r in rooms)
device_lists = yield self.device_handler.get_device_list_changes(
user_id, joined_room_ids, since_token.device_list_key
)
else:
device_lists = []
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=sync_result_builder.presence, presence=sync_result_builder.presence,
@ -567,6 +560,28 @@ class SyncHandler(object):
next_batch=sync_result_builder.now_token, next_batch=sync_result_builder.now_token,
)) ))
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and since_token.device_list_key:
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(r.room_id for r in rooms)
user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(
since_token.device_list_key
)
for other_user_id in changed:
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
defer.returnValue(user_ids_changed)
else:
defer.returnValue([])
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder): def _generate_sync_entry_for_to_device(self, sync_result_builder):
"""Generates the portion of the sync response. Populates """Generates the portion of the sync response. Populates

View file

@ -46,6 +46,7 @@ STREAM_NAMES = (
("to_device",), ("to_device",),
("public_rooms",), ("public_rooms",),
("federation",), ("federation",),
("device_lists",),
) )
@ -140,6 +141,7 @@ class ReplicationResource(Resource):
caches_token = self.store.get_cache_stream_token() caches_token = self.store.get_cache_stream_token()
public_rooms_token = self.store.get_current_public_room_stream_id() public_rooms_token = self.store.get_current_public_room_stream_id()
federation_token = self.federation_sender.get_current_token() federation_token = self.federation_sender.get_current_token()
device_list_token = self.store.get_device_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -155,6 +157,7 @@ class ReplicationResource(Resource):
int(stream_token.to_device_key), int(stream_token.to_device_key),
int(public_rooms_token), int(public_rooms_token),
int(federation_token), int(federation_token),
int(device_list_token),
)) ))
@request_handler() @request_handler()
@ -214,6 +217,7 @@ class ReplicationResource(Resource):
yield self.caches(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams)
yield self.public_rooms(writer, current_token, limit, request_streams) yield self.public_rooms(writer, current_token, limit, request_streams)
yield self.device_lists(writer, current_token, limit, request_streams)
self.federation(writer, current_token, limit, request_streams, federation_ack) self.federation(writer, current_token, limit, request_streams, federation_ack)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
@ -495,6 +499,20 @@ class ReplicationResource(Resource):
"position", "type", "content", "position", "type", "content",
), position=upto_token) ), position=upto_token)
@defer.inlineCallbacks
def device_lists(self, writer, current_token, limit, request_streams):
current_position = current_token.device_lists
device_lists = request_streams.get("device_lists")
if device_lists is not None and device_lists != current_position:
changes = yield self.store.get_users_and_hosts_device_list_changes(
device_lists,
)
writer.write_header_and_rows("device_lists", changes, (
"position", "user_id", "destination",
), position=current_position)
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
@ -527,7 +545,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms", "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
"federation", "federation", "device_lists",
))): ))):
__slots__ = [] __slots__ = []

View file

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
db_conn, "device_lists_stream", "stream_id",
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
)
get_device_stream_token = DataStore.get_device_stream_token.__func__
get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__
get_devices_by_remote = DataStore.get_devices_by_remote.__func__
_get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__
_get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__
mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__
_mark_as_sent_devices_by_remote_txn = (
DataStore._mark_as_sent_devices_by_remote_txn.__func__
)
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("device_lists")
if stream:
self._device_list_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
user_id = row[1]
destination = row[2]
self._device_list_stream_cache.entity_has_changed(
user_id, stream_id
)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, stream_id
)
return super(SlavedDeviceStore, self).process_replication(result)

View file

@ -458,6 +458,21 @@ class DeviceStore(SQLBaseStore):
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row["user_id"] for row in rows)) defer.returnValue(set(row["user_id"] for row in rows))
def get_users_and_hosts_device_list_changes(self, from_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
sql = """
SELECT stream_id, user_id, destination FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE stream_id > ?
"""
return self._execute(
"get_users_and_hosts_device_list", None,
sql, from_key,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts): def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts """Persist that a user's devices have been updated, and which hosts