0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 16:53:51 +01:00

Merge pull request #1074 from matrix-org/markjh/direct_to_device_federation

Send device messages over federation
This commit is contained in:
Mark Haines 2016-09-08 14:26:47 +01:00 committed by GitHub
commit 2117c409a0
11 changed files with 459 additions and 74 deletions

View file

@ -137,6 +137,12 @@ class FederationClient(FederationBase):
self._transaction_queue.enqueue_edu(edu) self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None) return defer.succeed(None)
@log_function
def send_device_messages(self, destination):
"""Sends the device messages in the local database to the remote
destination"""
self._transaction_queue.enqueue_device_messages(destination)
@log_function @log_function
def send_failure(self, failure, destination): def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination) self._transaction_queue.enqueue_failure(failure, destination)

View file

@ -188,7 +188,7 @@ class FederationServer(FederationBase):
except SynapseError as e: except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e) logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e: except Exception as e:
logger.exception("Failed to handle edu %r", edu_type, e) logger.exception("Failed to handle edu %r", edu_type)
else: else:
logger.warn("Received EDU of type %s with no handler", edu_type) logger.warn("Received EDU of type %s with no handler", edu_type)

View file

@ -17,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from .persistence import TransactionActions from .persistence import TransactionActions
from .units import Transaction from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -81,6 +81,8 @@ class TransactionQueue(object):
# destination -> list of tuple(failure, deferred) # destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {} self.pending_failures_by_dest = {}
self.last_device_stream_id_by_dest = {}
# HACK to get unique tx id # HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec()) self._next_txn_id = int(self.clock.time_msec())
@ -155,6 +157,17 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination self._attempt_new_transaction, destination
) )
def enqueue_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
return
if not self.can_send_to(destination):
return
preserve_context_over_fn(
self._attempt_new_transaction, destination
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _attempt_new_transaction(self, destination): def _attempt_new_transaction(self, destination):
yield run_on_reactor() yield run_on_reactor()
@ -175,6 +188,12 @@ class TransactionQueue(object):
pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, []) pending_failures = self.pending_failures_by_dest.pop(destination, [])
device_message_edus, device_stream_id = (
yield self._get_new_device_messages(destination)
)
pending_edus.extend(device_message_edus)
if pending_pdus: if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus)) destination, len(pending_pdus))
@ -184,13 +203,34 @@ class TransactionQueue(object):
return return
yield self._send_new_transaction( yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures destination, pending_pdus, pending_edus, pending_failures,
device_stream_id,
should_delete_from_device_stream=bool(device_message_edus)
) )
@defer.inlineCallbacks
def _get_new_device_messages(self, destination):
last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
to_device_stream_id = self.store.get_to_device_stream_token()
contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
destination, last_device_stream_id, to_device_stream_id
)
edus = [
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.direct_to_device",
content=content,
)
for content in contents
]
defer.returnValue((edus, stream_id))
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus, def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures): pending_failures, device_stream_id,
should_delete_from_device_stream):
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[1]) pending_pdus.sort(key=lambda t: t[1])
@ -215,9 +255,9 @@ class TransactionQueue(object):
"TX [%s] {%s} Attempting new transaction" "TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)", " (pdus: %d, edus: %d, failures: %d)",
destination, txn_id, destination, txn_id,
len(pending_pdus), len(pdus),
len(pending_edus), len(edus),
len(pending_failures) len(failures)
) )
logger.debug("TX [%s] Persisting transaction...", destination) logger.debug("TX [%s] Persisting transaction...", destination)
@ -242,9 +282,9 @@ class TransactionQueue(object):
" (PDUs: %d, EDUs: %d, failures: %d)", " (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id, destination, txn_id,
transaction.transaction_id, transaction.transaction_id,
len(pending_pdus), len(pdus),
len(pending_edus), len(edus),
len(pending_failures), len(failures),
) )
with limiter: with limiter:
@ -299,6 +339,13 @@ class TransactionQueue(object):
logger.info( logger.info(
"Failed to send event %s to %s", p.event_id, destination "Failed to send event %s to %s", p.event_id, destination
) )
else:
# Remove the acknowledged device messages from the database
if should_delete_from_device_stream:
yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
)
self.last_device_stream_id_by_dest[destination] = device_stream_id
except NotRetryingDestination: except NotRetryingDestination:
logger.info( logger.info(
"TX [%s] not ready for retry yet - " "TX [%s] not ready for retry yet - "

View file

@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.types import get_domain_from_id
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
class DeviceMessageHandler(object):
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu
)
@defer.inlineCallbacks
def on_direct_to_device_edu(self, origin, content):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
logger.warn(
"Dropping device message from %r with spoofed sender %r",
origin, sender_user_id
)
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
)
@defer.inlineCallbacks
def send_device_message(self, sender_user_id, message_type, messages):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
if self.is_mine_id(user_id):
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
else:
destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device
message_id = random_string(16)
remote_edu_contents = {}
for destination, messages in remote_messages.items():
remote_edu_contents[destination] = {
"messages": messages,
"sender": sender_user_id,
"type": message_type,
"message_id": message_id,
}
stream_id = yield self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
)
for destination in remote_messages.keys():
# Enqueue a new federation transaction to send the new
# device messages to each remote destination.
self.federation.send_device_messages(destination)

View file

@ -16,6 +16,7 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(BaseSlavedStore): class SlavedDeviceInboxStore(BaseSlavedStore):
@ -24,6 +25,10 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
self._device_inbox_id_gen = SlavedIdTracker( self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_inbox", "stream_id", db_conn, "device_inbox", "stream_id",
) )
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
self._device_inbox_id_gen.get_current_token()
)
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
@ -38,5 +43,11 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
stream = result.get("to_device") stream = result.get("to_device")
if stream: if stream:
self._device_inbox_id_gen.advance(int(stream["position"])) self._device_inbox_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
user_id = row[1]
self._device_inbox_stream_cache.entity_has_changed(
user_id, stream_id
)
return super(SlavedDeviceInboxStore, self).process_replication(result) return super(SlavedDeviceInboxStore, self).process_replication(result)

View file

@ -16,10 +16,11 @@
import logging import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request
from synapse.http import servlet from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.v1.transactions import HttpTransactionStore from synapse.rest.client.v1.transactions import HttpTransactionStore
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,10 +40,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__() super(SendToDeviceRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.txns = HttpTransactionStore() self.txns = HttpTransactionStore()
self.device_message_handler = hs.get_device_message_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, message_type, txn_id): def on_PUT(self, request, message_type, txn_id):
@ -57,28 +56,10 @@ class SendToDeviceRestServlet(servlet.RestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
# TODO: Prod the notifier to wake up sync streams. sender_user_id = requester.user.to_string()
# TODO: Implement replication for the messages.
# TODO: Send the messages to remote servers if needed.
local_messages = {} yield self.device_message_handler.send_device_message(
for user_id, by_device in content["messages"].items(): sender_user_id, message_type, content["messages"]
if self.is_mine_id(user_id):
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": requester.user.to_string(),
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
stream_id = yield self.store.add_messages_to_device_inbox(local_messages)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
) )
response = (200, {}) response = (200, {})

View file

@ -35,6 +35,7 @@ from synapse.federation import initialize_http_replication
from synapse.handlers import Handlers from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler from synapse.handlers.presence import PresenceHandler
@ -100,6 +101,7 @@ class HomeServer(object):
'application_service_api', 'application_service_api',
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
'device_message_handler',
'notifier', 'notifier',
'distributor', 'distributor',
'client_resource', 'client_resource',
@ -205,6 +207,9 @@ class HomeServer(object):
def build_device_handler(self): def build_device_handler(self):
return DeviceHandler(self) return DeviceHandler(self)
def build_device_message_handler(self):
return DeviceMessageHandler(self)
def build_e2e_keys_handler(self): def build_e2e_keys_handler(self):
return E2eKeysHandler(self) return E2eKeysHandler(self)

View file

@ -182,6 +182,30 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=push_rules_prefill, prefilled_cache=push_rules_prefill,
) )
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox",
entity_column="user_id",
stream_column="stream_id",
max_value=max_device_inbox_id
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache", min_device_inbox_id,
prefilled_cache=device_inbox_prefill,
)
# The federation outbox and the local device inbox uses the same
# stream_id generator.
device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
db_conn, "device_federation_outbox",
entity_column="destination",
stream_column="stream_id",
max_value=max_device_inbox_id,
)
self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
prefilled_cache=device_outbox_prefill,
)
cur = LoggingTransaction( cur = LoggingTransaction(
db_conn.cursor(), db_conn.cursor(),
name="_find_stream_orderings_for_times_txn", name="_find_stream_orderings_for_times_txn",

View file

@ -27,19 +27,112 @@ logger = logging.getLogger(__name__)
class DeviceInboxStore(SQLBaseStore): class DeviceInboxStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_messages_to_device_inbox(self, messages_by_user_then_device): def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
""" remote_messages_by_destination):
"""Used to send messages from this server.
Args: Args:
messages_by_user_and_device(dict): sender_user_id(str): The ID of the user sending these messages.
local_messages_by_user_and_device(dict):
Dictionary of user_id to device_id to message. Dictionary of user_id to device_id to message.
remote_messages_by_destination(dict):
Dictionary of destination server_name to the EDU JSON to send.
Returns: Returns:
A deferred stream_id that resolves when the messages have been A deferred stream_id that resolves when the messages have been
inserted. inserted.
""" """
def select_devices_txn(txn, user_id, devices): def add_messages_txn(txn, now_ms, stream_id):
if not devices: # Add the local messages directly to the local inbox.
return [] self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device
)
# Add the remote messages to the federation outbox.
# We'll send them to a remote server when we next send a
# federation transaction to that destination.
sql = (
"INSERT INTO device_federation_outbox"
" (destination, stream_id, queued_ts, messages_json)"
" VALUES (?,?,?,?)"
)
rows = []
for destination, edu in remote_messages_by_destination.items():
edu_json = ujson.dumps(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.runInteraction(
"add_messages_to_device_inbox",
add_messages_txn,
now_ms,
stream_id,
)
for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed(
user_id, stream_id
)
for destination in remote_messages_by_destination.keys():
self._device_federation_outbox_stream_cache.entity_has_changed(
destination, stream_id
)
defer.returnValue(self._device_inbox_id_gen.get_current_token())
@defer.inlineCallbacks
def add_messages_from_remote_to_device_inbox(
self, origin, message_id, local_messages_by_user_then_device
):
def add_messages_txn(txn, now_ms, stream_id):
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
already_inserted = self._simple_select_one_txn(
txn, table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
retcols=("message_id",),
allow_none=True,
)
if already_inserted is not None:
return
# Add an entry for this message_id so that we know we've processed
# it.
self._simple_insert_txn(
txn, table="device_federation_inbox",
values={
"origin": origin,
"message_id": message_id,
"received_ts": now_ms,
},
)
# Add the messages to the approriate local device inboxes so that
# they'll be sent to the devices when they next sync.
self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device
)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
now_ms,
stream_id,
)
for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed(
user_id, stream_id
)
def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
messages_by_user_then_device):
local_users_and_devices = set()
for user_id, messages_by_device in messages_by_user_then_device.items():
devices = messages_by_device.keys()
sql = ( sql = (
"SELECT user_id, device_id FROM devices" "SELECT user_id, device_id FROM devices"
" WHERE user_id = ? AND device_id IN (" " WHERE user_id = ? AND device_id IN ("
@ -48,41 +141,24 @@ class DeviceInboxStore(SQLBaseStore):
) )
# TODO: Maybe this needs to be done in batches if there are # TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user. # too many local devices for a given user.
args = [user_id] + devices txn.execute(sql, [user_id] + devices)
txn.execute(sql, args) local_users_and_devices.update(map(tuple, txn.fetchall()))
return [tuple(row) for row in txn.fetchall()]
def add_messages_to_device_inbox_txn(txn, stream_id): sql = (
local_users_and_devices = set() "INSERT INTO device_inbox"
for user_id, messages_by_device in messages_by_user_then_device.items(): " (user_id, device_id, stream_id, message_json)"
local_users_and_devices.update( " VALUES (?,?,?,?)"
select_devices_txn(txn, user_id, messages_by_device.keys()) )
) rows = []
for user_id, messages_by_device in messages_by_user_then_device.items():
for device_id, message in messages_by_device.items():
message_json = ujson.dumps(message)
# Only insert into the local inbox if the device exists on
# this server
if (user_id, device_id) in local_users_and_devices:
rows.append((user_id, device_id, stream_id, message_json))
sql = ( txn.executemany(sql, rows)
"INSERT INTO device_inbox"
" (user_id, device_id, stream_id, message_json)"
" VALUES (?,?,?,?)"
)
rows = []
for user_id, messages_by_device in messages_by_user_then_device.items():
for device_id, message in messages_by_device.items():
message_json = ujson.dumps(message)
# Only insert into the local inbox if the device exists on
# this server
if (user_id, device_id) in local_users_and_devices:
rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)
with self._device_inbox_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_messages_to_device_inbox",
add_messages_to_device_inbox_txn,
stream_id
)
defer.returnValue(self._device_inbox_id_gen.get_current_token())
def get_new_messages_for_device( def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100 self, user_id, device_id, last_stream_id, current_stream_id, limit=100
@ -97,6 +173,12 @@ class DeviceInboxStore(SQLBaseStore):
Deferred ([dict], int): List of messages for the device and where Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to. in the stream the messages got to.
""" """
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_device_txn(txn): def get_new_messages_for_device_txn(txn):
sql = ( sql = (
"SELECT stream_id, message_json FROM device_inbox" "SELECT stream_id, message_json FROM device_inbox"
@ -182,3 +264,71 @@ class DeviceInboxStore(SQLBaseStore):
def get_to_device_stream_token(self): def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token() return self._device_inbox_id_gen.get_current_token()
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit=100
):
"""
Args:
destination(str): The name of the remote server.
last_stream_id(int): The last position of the device message stream
that the server sent up to.
current_stream_id(int): The current position of the device
message stream.
Returns:
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
destination, last_stream_id
)
if not has_changed:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_remote_destination_txn(txn):
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
destination, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn.fetchall():
stream_pos = row[0]
messages.append(ujson.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
destination(str): The destination server_name
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
"""
def delete_messages_for_remote_destination_txn(txn):
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
)

View file

@ -0,0 +1,36 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE device_federation_outbox (
destination TEXT NOT NULL,
stream_id BIGINT NOT NULL,
queued_ts BIGINT NOT NULL,
messages_json TEXT NOT NULL
);
CREATE INDEX device_federation_outbox_destination_id
ON device_federation_outbox(destination, stream_id);
CREATE TABLE device_federation_inbox (
origin TEXT NOT NULL,
message_id TEXT NOT NULL,
received_ts BIGINT NOT NULL
);
CREATE INDEX device_federation_inbox_sender_id
ON device_federation_inbox(origin, message_id);

View file

@ -121,6 +121,14 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.auth.check_joined_room = check_joined_room self.auth.check_joined_room = check_joined_room
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = (
lambda *args, **kargs: ([], 0)
)
self.datastore.delete_device_msgs_for_remote = (
lambda *args, **kargs: None
)
# Some local users to test with # Some local users to test with
self.u_apple = UserID.from_string("@apple:test") self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test") self.u_banana = UserID.from_string("@banana:test")