Mirror of https://mau.dev/maunium/synapse.git, synced 2024-12-14 15:53:51 +01:00
Merge pull request #2097 from matrix-org/erikj/repl_tcp_client
Move to using TCP replication
commit a5c401bd12
21 changed files with 601 additions and 581 deletions
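Before the per-file hunks, a note on the shape of the change: each worker previously long-polled the master over HTTP (worker_replication_url) in a while/try loop, fed the JSON result to store.process_replication(result), and dispatched updates itself. After this commit the master pushes rows to workers over a persistent TCP connection: each worker returns a ReplicationClientHandler (or a subclass) from build_tcp_replication() and calls self.get_tcp_replication().start_replication(self) once its listeners are up. Below is a minimal sketch of the subclass pattern, assuming a synapse checkout of this vintage; ExampleReplicationHandler is hypothetical, while on_rdata, build_tcp_replication and start_replication are the real hooks the hunks use.

    import logging

    from synapse.replication.tcp.client import ReplicationClientHandler

    logger = logging.getLogger(__name__)


    class ExampleReplicationHandler(ReplicationClientHandler):
        def __init__(self, hs):
            # The base class advances the worker datastore's stream
            # positions as replicated rows arrive.
            super(ExampleReplicationHandler, self).__init__(hs.get_datastore())

        def on_rdata(self, stream_name, token, rows):
            # Push-style callback: called for each batch of rows the master
            # sends on a stream, replacing the old HTTP long-poll loop.
            super(ExampleReplicationHandler, self).on_rdata(stream_name, token, rows)
            logger.info("%d row(s) on %r up to token %s", len(rows), stream_name, token)

The extraction lost the per-file headers; judging by the hunk contexts, the hunks below cover the worker entry points (appservice, client_reader, federation_reader, federation_sender, media_repository, pusher, synchrotron), then the worker config, then the slaved storage classes (_base, account_data, deviceinbox, devices).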
@@ -26,17 +26,17 @@ from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.storage.engines import create_engine
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
 from synapse.util.manhole import manhole
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string

 from synapse import events

-from twisted.internet import reactor, defer
+from twisted.internet import reactor
 from twisted.web.resource import Resource

 from daemonize import Daemonize
@@ -120,30 +120,25 @@ class AppserviceServer(HomeServer):
         else:
             logger.warn("Unrecognized listener type: %s", listener["type"])

-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-        appservice_handler = self.get_application_service_handler()
-
-        @defer.inlineCallbacks
-        def replicate(results):
-            stream = results.get("events")
-            if stream:
-                max_stream_id = stream["position"]
-                yield appservice_handler.notify_interested_services(max_stream_id)
-
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-                replicate(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(30)
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return ASReplicationHandler(self)
+
+
+class ASReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(ASReplicationHandler, self).__init__(hs.get_datastore())
+        self.appservice_handler = hs.get_application_service_handler()
+
+    def on_rdata(self, stream_name, token, rows):
+        super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
+
+        if stream_name == "events":
+            max_stream_id = self.store.get_room_max_stream_ordering()
+            preserve_fn(
+                self.appservice_handler.notify_interested_services
+            )(max_stream_id)


 def start(config_options):
@@ -199,7 +194,6 @@ def start(config_options):
         reactor.run()

     def start():
-        ps.replicate()
         ps.get_datastore().start_profiling()
         ps.get_state_handler().start_caching()

@@ -30,11 +30,11 @@ from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.rest.client.v1.room import PublicRoomListRestServlet
 from synapse.server import HomeServer
 from synapse.storage.client_ips import ClientIpStore
 from synapse.storage.engines import create_engine
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.manhole import manhole
@@ -45,7 +45,7 @@ from synapse.crypto import context_factory
 from synapse import events

-from twisted.internet import reactor, defer
+from twisted.internet import reactor
 from twisted.web.resource import Resource

 from daemonize import Daemonize
@@ -145,21 +145,10 @@ class ClientReaderServer(HomeServer):
         else:
             logger.warn("Unrecognized listener type: %s", listener["type"])

-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(5)
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return ReplicationClientHandler(self.get_datastore())


 def start(config_options):
@@ -209,7 +198,6 @@ def start(config_options):
     def start():
         ss.get_state_handler().start_caching()
         ss.get_datastore().start_profiling()
-        ss.replicate()

     reactor.callWhenRunning(start)

@@ -27,9 +27,9 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.transactions import TransactionStore
 from synapse.replication.slave.storage.directory import DirectoryStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.manhole import manhole
@@ -42,7 +42,7 @@ from synapse.crypto import context_factory
 from synapse import events

-from twisted.internet import reactor, defer
+from twisted.internet import reactor
 from twisted.web.resource import Resource

 from daemonize import Daemonize
@@ -134,21 +134,10 @@ class FederationReaderServer(HomeServer):
         else:
             logger.warn("Unrecognized listener type: %s", listener["type"])

-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(5)
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return ReplicationClientHandler(self.get_datastore())


 def start(config_options):
@@ -198,7 +187,6 @@ def start(config_options):
     def start():
         ss.get_state_handler().start_caching()
         ss.get_datastore().start_profiling()
-        ss.replicate()

     reactor.callWhenRunning(start)

@@ -31,11 +31,12 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.transactions import TransactionStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.storage.engines import create_engine
 from synapse.storage.presence import UserPresenceState
-from synapse.util.async import sleep
+from synapse.util.async import Linearizer
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
 from synapse.util.manhole import manhole
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
@@ -59,7 +60,28 @@ class FederationSenderSlaveStore(
     SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
     SlavedRegistrationStore, SlavedDeviceStore,
 ):
-    pass
+    def __init__(self, db_conn, hs):
+        super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
+
+        # We pull out the current federation stream position now so that we
+        # always have a known value for the federation position in memory so
+        # that we don't have to bounce via a deferred once when we start the
+        # replication streams.
+        self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
+
+    def _get_federation_out_pos(self, db_conn):
+        sql = (
+            "SELECT stream_id FROM federation_stream_position"
+            " WHERE type = ?"
+        )
+        sql = self.database_engine.convert_param_style(sql)
+
+        txn = db_conn.cursor()
+        txn.execute(sql, ("federation",))
+        rows = txn.fetchall()
+        txn.close()
+
+        return rows[0][0] if rows else -1


 class FederationSenderServer(HomeServer):
@@ -127,26 +149,27 @@ class FederationSenderServer(HomeServer):
         else:
             logger.warn("Unrecognized listener type: %s", listener["type"])

-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-        send_handler = FederationSenderHandler(self)
-
-        send_handler.on_start()
-
-        while True:
-            try:
-                args = store.stream_positions()
-                args.update((yield send_handler.stream_positions()))
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-                yield send_handler.process_replication(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(30)
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return FederationSenderReplicationHandler(self)
+
+
+class FederationSenderReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
+        self.send_handler = FederationSenderHandler(hs, self)
+
+    def on_rdata(self, stream_name, token, rows):
+        super(FederationSenderReplicationHandler, self).on_rdata(
+            stream_name, token, rows
+        )
+        self.send_handler.process_replication_rows(stream_name, token, rows)
+
+    def get_streams_to_replicate(self):
+        args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate()
+        args.update(self.send_handler.stream_positions())
+        return args


 def start(config_options):
@@ -205,7 +228,6 @@ def start(config_options):
         reactor.run()

     def start():
-        ps.replicate()
         ps.get_datastore().start_profiling()
         ps.get_state_handler().start_caching()

@@ -229,9 +251,15 @@ class FederationSenderHandler(object):
     """Processes the replication stream and forwards the appropriate entries
     to the federation sender.
     """
-    def __init__(self, hs):
+    def __init__(self, hs, replication_client):
         self.store = hs.get_datastore()
         self.federation_sender = hs.get_federation_sender()
+        self.replication_client = replication_client
+
+        self.federation_position = self.store.federation_out_pos_startup
+        self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
+
+        self._last_ack = self.federation_position
+
         self._room_serials = {}
         self._room_typing = {}
@@ -243,25 +271,13 @@ class FederationSenderHandler(object):
             self.store.get_room_max_stream_ordering()
         )

-    @defer.inlineCallbacks
     def stream_positions(self):
-        stream_id = yield self.store.get_federation_out_pos("federation")
-        defer.returnValue({
-            "federation": stream_id,
-
-            # Ack stuff we've "processed", this should only be called from
-            # one process.
-            "federation_ack": stream_id,
-        })
-
-    @defer.inlineCallbacks
-    def process_replication(self, result):
+        return {"federation": self.federation_position}
+
+    def process_replication_rows(self, stream_name, token, rows):
         # The federation stream contains things that we want to send out, e.g.
         # presence, typing, etc.
-        fed_stream = result.get("federation")
-        if fed_stream:
-            latest_id = int(fed_stream["position"])
-
+        if stream_name == "federation":
             # The federation stream containis a bunch of different types of
             # rows that need to be handled differently. We parse the rows, put
             # them into the appropriate collection and then send them off.
@@ -272,8 +288,9 @@ class FederationSenderHandler(object):
             device_destinations = set()

             # Parse the rows in the stream
-            for row in fed_stream["rows"]:
-                position, typ, content_js = row
+            for row in rows:
+                typ = row.type
+                content_js = row.data
                 content = json.loads(content_js)

                 if typ == send_queue.PRESENCE_TYPE:
@@ -325,16 +342,27 @@ class FederationSenderHandler(object):
             for destination in device_destinations:
                 self.federation_sender.send_device_messages(destination)

-            # Record where we are in the stream.
-            yield self.store.update_federation_out_pos(
-                "federation", latest_id
-            )
+            preserve_fn(self.update_token)(token)

         # We also need to poke the federation sender when new events happen
-        event_stream = result.get("events")
-        if event_stream:
-            latest_pos = event_stream["position"]
-            self.federation_sender.notify_new_events(latest_pos)
+        elif stream_name == "events":
+            self.federation_sender.notify_new_events(token)
+
+    @defer.inlineCallbacks
+    def update_token(self, token):
+        self.federation_position = token
+
+        # We linearize here to ensure we don't have races updating the token
+        with (yield self._fed_position_linearizer.queue(None)):
+            if self._last_ack < self.federation_position:
+                yield self.store.update_federation_out_pos(
+                    "federation", self.federation_position
+                )
+
+                # We ACK this token over replication so that the master can drop
+                # its in memory queues
+                self.replication_client.send_federation_ack(self.federation_position)
+                self._last_ack = self.federation_position


 if __name__ == '__main__':
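A design note on the federation sender hunks above: this is the one worker that must write back to the master. It keeps the federation stream position in memory (seeded from the database at startup via federation_out_pos_startup so nothing has to bounce through a deferred), serialises updates with a Linearizer to avoid racing token updates, persists the position with update_federation_out_pos, and then acks it over the replication connection via send_federation_ack so the master can drop its in-memory queues.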
@@ -25,13 +25,13 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.rest.media.v0.content_repository import ContentRepoResource
 from synapse.rest.media.v1.media_repository import MediaRepositoryResource
 from synapse.server import HomeServer
 from synapse.storage.client_ips import ClientIpStore
 from synapse.storage.engines import create_engine
 from synapse.storage.media_repository import MediaRepositoryStore
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.manhole import manhole
@@ -45,7 +45,7 @@ from synapse.crypto import context_factory
 from synapse import events

-from twisted.internet import reactor, defer
+from twisted.internet import reactor
 from twisted.web.resource import Resource

 from daemonize import Daemonize
@@ -142,21 +142,10 @@ class MediaRepositoryServer(HomeServer):
         else:
             logger.warn("Unrecognized listener type: %s", listener["type"])

-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(5)
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return ReplicationClientHandler(self.get_datastore())


 def start(config_options):
@@ -206,7 +195,6 @@ def start(config_options):
     def start():
         ss.get_state_handler().start_caching()
         ss.get_datastore().start_profiling()
-        ss.replicate()

     reactor.callWhenRunning(start)

@@ -27,9 +27,9 @@ from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.pushers import SlavedPusherStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.storage.engines import create_engine
 from synapse.storage import DataStore
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext, preserve_fn, \
     PreserveLoggingContext
@@ -89,7 +89,6 @@ class PusherSlaveStore(

-
 class PusherServer(HomeServer):

     def get_db_conn(self, run_new_connection=True):
         # Any param beginning with cp_ is a parameter for adbapi, and should
         # not be passed to the database engine.
@@ -109,16 +108,7 @@ class PusherServer(HomeServer):
         logger.info("Finished setting up.")

     def remove_pusher(self, app_id, push_key, user_id):
-        http_client = self.get_simple_http_client()
-        replication_url = self.config.worker_replication_url
-        url = replication_url + "/remove_pushers"
-        return http_client.post_json_get_json(url, {
-            "remove": [{
-                "app_id": app_id,
-                "push_key": push_key,
-                "user_id": user_id,
-            }]
-        })
+        self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)

     def _listen_http(self, listener_config):
         port = listener_config["port"]
@@ -166,73 +156,52 @@ class PusherServer(HomeServer):
         else:
             logger.warn("Unrecognized listener type: %s", listener["type"])

-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-        pusher_pool = self.get_pusherpool()
-
-        def stop_pusher(user_id, app_id, pushkey):
-            key = "%s:%s" % (app_id, pushkey)
-            pushers_for_user = pusher_pool.pushers.get(user_id, {})
-            pusher = pushers_for_user.pop(key, None)
-            if pusher is None:
-                return
-            logger.info("Stopping pusher %r / %r", user_id, key)
-            pusher.on_stop()
-
-        def start_pusher(user_id, app_id, pushkey):
-            key = "%s:%s" % (app_id, pushkey)
-            logger.info("Starting pusher %r / %r", user_id, key)
-            return pusher_pool._refresh_pusher(app_id, pushkey, user_id)
-
-        @defer.inlineCallbacks
-        def poke_pushers(results):
-            pushers_rows = set(
-                map(tuple, results.get("pushers", {}).get("rows", []))
-            )
-            deleted_pushers_rows = set(
-                map(tuple, results.get("deleted_pushers", {}).get("rows", []))
-            )
-            for row in sorted(pushers_rows | deleted_pushers_rows):
-                if row in deleted_pushers_rows:
-                    user_id, app_id, pushkey = row[1:4]
-                    stop_pusher(user_id, app_id, pushkey)
-                elif row in pushers_rows:
-                    user_id = row[1]
-                    app_id = row[5]
-                    pushkey = row[8]
-                    yield start_pusher(user_id, app_id, pushkey)
-
-            stream = results.get("events")
-            if stream and stream["rows"]:
-                min_stream_id = stream["rows"][0][0]
-                max_stream_id = stream["position"]
-                preserve_fn(pusher_pool.on_new_notifications)(
-                    min_stream_id, max_stream_id
-                )
-
-            stream = results.get("receipts")
-            if stream and stream["rows"]:
-                rows = stream["rows"]
-                affected_room_ids = set(row[1] for row in rows)
-                min_stream_id = rows[0][0]
-                max_stream_id = stream["position"]
-                preserve_fn(pusher_pool.on_new_receipts)(
-                    min_stream_id, max_stream_id, affected_room_ids
-                )
-
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-                poke_pushers(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(30)
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return PusherReplicationHandler(self)
+
+
+class PusherReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(PusherReplicationHandler, self).__init__(hs.get_datastore())
+
+        self.pusher_pool = hs.get_pusherpool()
+
+    def on_rdata(self, stream_name, token, rows):
+        super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
+        preserve_fn(self.poke_pushers)(stream_name, token, rows)
+
+    @defer.inlineCallbacks
+    def poke_pushers(self, stream_name, token, rows):
+        if stream_name == "pushers":
+            for row in rows:
+                if row.deleted:
+                    yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+                else:
+                    yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
+        elif stream_name == "events":
+            yield self.pusher_pool.on_new_notifications(
+                token, token,
+            )
+        elif stream_name == "receipts":
+            yield self.pusher_pool.on_new_receipts(
+                token, token, set(row.room_id for row in rows)
+            )
+
+    def stop_pusher(self, user_id, app_id, pushkey):
+        key = "%s:%s" % (app_id, pushkey)
+        pushers_for_user = self.pusher_pool.pushers.get(user_id, {})
+        pusher = pushers_for_user.pop(key, None)
+        if pusher is None:
+            return
+        logger.info("Stopping pusher %r / %r", user_id, key)
+        pusher.on_stop()

+    def start_pusher(self, user_id, app_id, pushkey):
+        key = "%s:%s" % (app_id, pushkey)
+        logger.info("Starting pusher %r / %r", user_id, key)
+        return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id)


 def start(config_options):
@@ -288,7 +257,6 @@ def start(config_options):
         reactor.run()

     def start():
-        ps.replicate()
         ps.get_pusherpool().start()
         ps.get_datastore().start_profiling()
         ps.get_state_handler().start_caching()

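One detail in the pusher hunks worth calling out: the old loop passed a (min_stream_id, max_stream_id) window into pusher_pool.on_new_notifications and on_new_receipts, computed from the polled batch; the new on_rdata path passes (token, token), since each TCP callback carries a single stream token rather than a polled window.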
@@ -16,7 +16,7 @@

 import synapse

-from synapse.api.constants import EventTypes, PresenceState
+from synapse.api.constants import EventTypes
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
@@ -40,15 +40,14 @@ from synapse.replication.slave.storage.presence import SlavedPresenceStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.server import HomeServer
 from synapse.storage.client_ips import ClientIpStore
 from synapse.storage.engines import create_engine
 from synapse.storage.presence import PresenceStore, UserPresenceState
 from synapse.storage.roommember import RoomMemberStore
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn, \
-    PreserveLoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
 from synapse.util.manhole import manhole
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.stringutils import random_string
@@ -107,11 +106,11 @@ UPDATE_SYNCING_USERS_MS = 10 * 1000

 class SynchrotronPresence(object):
     def __init__(self, hs):
+        self.hs = hs
         self.is_mine_id = hs.is_mine_id
         self.http_client = hs.get_simple_http_client()
         self.store = hs.get_datastore()
         self.user_to_num_current_syncs = {}
-        self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()

@@ -124,14 +123,8 @@ class SynchrotronPresence(object):
         self.process_id = random_string(16)
         logger.info("Presence process_id is %r", self.process_id)

-        self._sending_sync = False
-        self._need_to_send_sync = False
-        self.clock.looping_call(
-            self._send_syncing_users_regularly,
-            UPDATE_SYNCING_USERS_MS,
-        )
-
-        reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
+    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+        self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms)

     def set_state(self, user, state, ignore_status_msg=False):
         # TODO Hows this supposed to work?
@@ -142,15 +135,15 @@ class SynchrotronPresence(object):
     _get_interested_parties = PresenceHandler._get_interested_parties.__func__
     current_state_for_users = PresenceHandler.current_state_for_users.__func__

-    @defer.inlineCallbacks
     def user_syncing(self, user_id, affect_presence):
         if affect_presence:
             curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
             self.user_to_num_current_syncs[user_id] = curr_sync + 1
-            prev_states = yield self.current_state_for_users([user_id])
-            if prev_states[user_id].state == PresenceState.OFFLINE:
-                # TODO: Don't block the sync request on this HTTP hit.
-                yield self._send_syncing_users_now()
+
+            # If we went from no in flight sync to some, notify replication
+            if self.user_to_num_current_syncs[user_id] == 1:
+                now = self.clock.time_msec()
+                self.send_user_sync(user_id, True, now)

         def _end():
             # We check that the user_id is in user_to_num_current_syncs because
@@ -159,6 +152,11 @@ class SynchrotronPresence(object):
         if affect_presence and user_id in self.user_to_num_current_syncs:
             self.user_to_num_current_syncs[user_id] -= 1

+            # If we went from one in flight sync to non, notify replication
+            if self.user_to_num_current_syncs[user_id] == 0:
+                now = self.clock.time_msec()
+                self.send_user_sync(user_id, False, now)
+
         @contextlib.contextmanager
         def _user_syncing():
             try:
@@ -166,49 +164,7 @@ class SynchrotronPresence(object):
         finally:
             _end()

-        defer.returnValue(_user_syncing())
-
-    @defer.inlineCallbacks
-    def _on_shutdown(self):
-        # When the synchrotron is shutdown tell the master to clear the in
-        # progress syncs for this process
-        self.user_to_num_current_syncs.clear()
-        yield self._send_syncing_users_now()
-
-    def _send_syncing_users_regularly(self):
-        # Only send an update if we aren't in the middle of sending one.
-        if not self._sending_sync:
-            preserve_fn(self._send_syncing_users_now)()
-
-    @defer.inlineCallbacks
-    def _send_syncing_users_now(self):
-        if self._sending_sync:
-            # We don't want to race with sending another update.
-            # Instead we wait for that update to finish and send another
-            # update afterwards.
-            self._need_to_send_sync = True
-            return
-
-        # Flag that we are sending an update.
-        self._sending_sync = True
-
-        yield self.http_client.post_json_get_json(self.syncing_users_url, {
-            "process_id": self.process_id,
-            "syncing_users": [
-                user_id for user_id, count in self.user_to_num_current_syncs.items()
-                if count > 0
-            ],
-        })
-
-        # Unset the flag as we are no longer sending an update.
-        self._sending_sync = False
-        if self._need_to_send_sync:
-            # If something happened while we were sending the update then
-            # we might need to send another update.
-            # TODO: Check if the update that was sent matches the current state
-            # as we only need to send an update if they are different.
-            self._need_to_send_sync = False
-            yield self._send_syncing_users_now()
+        return defer.succeed(_user_syncing())

     @defer.inlineCallbacks
     def notify_from_replication(self, states, stream_id):
@@ -223,26 +179,24 @@ class SynchrotronPresence(object):
         )

     @defer.inlineCallbacks
-    def process_replication(self, result):
-        stream = result.get("presence", {"rows": []})
-        states = []
-        for row in stream["rows"]:
-            (
-                position, user_id, state, last_active_ts,
-                last_federation_update_ts, last_user_sync_ts, status_msg,
-                currently_active
-            ) = row
-            state = UserPresenceState(
-                user_id, state, last_active_ts,
-                last_federation_update_ts, last_user_sync_ts, status_msg,
-                currently_active
-            )
-            self.user_to_current_state[user_id] = state
-            states.append(state)
-
-        if states and "position" in stream:
-            stream_id = int(stream["position"])
-            yield self.notify_from_replication(states, stream_id)
+    def process_replication_rows(self, token, rows):
+        states = [UserPresenceState(
+            row.user_id, row.state, row.last_active_ts,
+            row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg,
+            row.currently_active
+        ) for row in rows]
+
+        for state in states:
+            self.user_to_current_state[row.user_id] = state
+
+        stream_id = token
+        yield self.notify_from_replication(states, stream_id)
+
+    def get_currently_syncing_users(self):
+        return [
+            user_id for user_id, count in self.user_to_num_current_syncs.iteritems()
+            if count > 0
+        ]


 class SynchrotronTyping(object):
@@ -257,16 +211,13 @@ class SynchrotronTyping(object):
         # value which we *must* use for the next replication request.
         return {"typing": self._latest_room_serial}

-    def process_replication(self, result):
-        stream = result.get("typing")
-        if stream:
-            self._latest_room_serial = int(stream["position"])
-
-            for row in stream["rows"]:
-                position, room_id, typing_json = row
-                typing = json.loads(typing_json)
-                self._room_serials[room_id] = position
-                self._room_typing[room_id] = typing
+    def process_replication_rows(self, token, rows):
+        self._latest_room_serial = token
+
+        for row in rows:
+            typing = json.loads(row.user_ids)
+            self._room_serials[row.room_id] = token
+            self._room_typing[row.room_id] = typing


 class SynchrotronApplicationService(object):
@@ -351,118 +302,10 @@ class SynchrotronServer(HomeServer):
         else:
             logger.warn("Unrecognized listener type: %s", listener["type"])

-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-        notifier = self.get_notifier()
-        presence_handler = self.get_presence_handler()
-        typing_handler = self.get_typing_handler()
-
-        def notify_from_stream(
-            result, stream_name, stream_key, room=None, user=None
-        ):
-            stream = result.get(stream_name)
-            if stream:
-                position_index = stream["field_names"].index("position")
-                if room:
-                    room_index = stream["field_names"].index(room)
-                if user:
-                    user_index = stream["field_names"].index(user)
-
-                users = ()
-                rooms = ()
-                for row in stream["rows"]:
-                    position = row[position_index]
-
-                    if user:
-                        users = (row[user_index],)
-
-                    if room:
-                        rooms = (row[room_index],)
-
-                    notifier.on_new_event(
-                        stream_key, position, users=users, rooms=rooms
-                    )
-
-        @defer.inlineCallbacks
-        def notify_device_list_update(result):
-            stream = result.get("device_lists")
-            if not stream:
-                return
-
-            position_index = stream["field_names"].index("position")
-            user_index = stream["field_names"].index("user_id")
-
-            for row in stream["rows"]:
-                position = row[position_index]
-                user_id = row[user_index]
-
-                room_ids = yield store.get_rooms_for_user(user_id)
-
-                notifier.on_new_event(
-                    "device_list_key", position, rooms=room_ids,
-                )
-
-        @defer.inlineCallbacks
-        def notify(result):
-            stream = result.get("events")
-            if stream:
-                max_position = stream["position"]
-
-                event_map = yield store.get_events([row[1] for row in stream["rows"]])
-
-                for row in stream["rows"]:
-                    position = row[0]
-                    event_id = row[1]
-                    event = event_map.get(event_id, None)
-                    if not event:
-                        continue
-
-                    extra_users = ()
-                    if event.type == EventTypes.Member:
-                        extra_users = (event.state_key,)
-                    notifier.on_new_room_event(
-                        event, position, max_position, extra_users
-                    )
-
-            notify_from_stream(
-                result, "push_rules", "push_rules_key", user="user_id"
-            )
-            notify_from_stream(
-                result, "user_account_data", "account_data_key", user="user_id"
-            )
-            notify_from_stream(
-                result, "room_account_data", "account_data_key", user="user_id"
-            )
-            notify_from_stream(
-                result, "tag_account_data", "account_data_key", user="user_id"
-            )
-            notify_from_stream(
-                result, "receipts", "receipt_key", room="room_id"
-            )
-            notify_from_stream(
-                result, "typing", "typing_key", room="room_id"
-            )
-            notify_from_stream(
-                result, "to_device", "to_device_key", user="user_id"
-            )
-            yield notify_device_list_update(result)
-
-        while True:
-            try:
-                args = store.stream_positions()
-                args.update(typing_handler.stream_positions())
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-                typing_handler.process_replication(result)
-                yield presence_handler.process_replication(result)
-                yield notify(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(5)
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return SyncReplicationHandler(self)

     def build_presence_handler(self):
         return SynchrotronPresence(self)
@@ -471,6 +314,79 @@ class SynchrotronServer(HomeServer):
         return SynchrotronTyping(self)


+class SyncReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(SyncReplicationHandler, self).__init__(hs.get_datastore())
+
+        self.store = hs.get_datastore()
+        self.typing_handler = hs.get_typing_handler()
+        self.presence_handler = hs.get_presence_handler()
+        self.notifier = hs.get_notifier()
+
+        self.presence_handler.sync_callback = self.send_user_sync
+
+    def on_rdata(self, stream_name, token, rows):
+        super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
+
+        preserve_fn(self.process_and_notify)(stream_name, token, rows)
+
+    def get_streams_to_replicate(self):
+        args = super(SyncReplicationHandler, self).get_streams_to_replicate()
+        args.update(self.typing_handler.stream_positions())
+        return args
+
+    def get_currently_syncing_users(self):
+        return self.presence_handler.get_currently_syncing_users()
+
+    @defer.inlineCallbacks
+    def process_and_notify(self, stream_name, token, rows):
+        if stream_name == "events":
+            # We shouldn't get multiple rows per token for events stream, so
+            # we don't need to optimise this for multiple rows.
+            for row in rows:
+                event = yield self.store.get_event(row.event_id)
+                extra_users = ()
+                if event.type == EventTypes.Member:
+                    extra_users = (event.state_key,)
+                max_token = self.store.get_room_max_stream_ordering()
+                self.notifier.on_new_room_event(
+                    event, token, max_token, extra_users
+                )
+        elif stream_name == "push_rules":
+            self.notifier.on_new_event(
+                "push_rules_key", token, users=[row.user_id for row in rows],
+            )
+        elif stream_name in ("account_data", "tag_account_data",):
+            self.notifier.on_new_event(
+                "account_data_key", token, users=[row.user_id for row in rows],
+            )
+        elif stream_name == "receipts":
+            self.notifier.on_new_event(
+                "receipt_key", token, rooms=[row.room_id for row in rows],
+            )
+        elif stream_name == "typing":
+            self.typing_handler.process_replication_rows(token, rows)
+            self.notifier.on_new_event(
+                "typing_key", token, rooms=[row.room_id for row in rows],
+            )
+        elif stream_name == "to_device":
+            entities = [row.entity for row in rows if row.entity.startswith("@")]
+            if entities:
+                self.notifier.on_new_event(
+                    "to_device_key", token, users=entities,
+                )
+        elif stream_name == "device_lists":
+            all_room_ids = set()
+            for row in rows:
+                room_ids = yield self.store.get_rooms_for_user(row.user_id)
+                all_room_ids.update(room_ids)
+            self.notifier.on_new_event(
+                "device_list_key", token, rooms=all_room_ids,
+            )
+        elif stream_name == "presence":
+            yield self.presence_handler.process_replication_rows(token, rows)
+
+
 def start(config_options):
     try:
         config = HomeServerConfig.load_config(
@@ -514,7 +430,6 @@ def start(config_options):

     def start():
         ss.get_datastore().start_profiling()
-        ss.replicate()
         ss.get_state_handler().start_caching()

     reactor.callWhenRunning(start)

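The synchrotron rewrite above is the broadest of the worker changes: the notify_from_stream helper, which looked columns up by index into stream["field_names"], is replaced by SyncReplicationHandler.process_and_notify, which switches on the stream name and reads named row attributes. Presence now flows the other way over the same link too: presence_handler.sync_callback is wired to send_user_sync, so user-sync state is reported to the master over the TCP connection instead of periodic POSTs to the old /syncing_users endpoint.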
@@ -28,7 +28,9 @@ class WorkerConfig(Config):
         self.worker_pid_file = config.get("worker_pid_file")
         self.worker_log_file = config.get("worker_log_file")
         self.worker_log_config = config.get("worker_log_config")
-        self.worker_replication_url = config.get("worker_replication_url")
+        self.worker_replication_host = config.get("worker_replication_host", None)
+        self.worker_replication_port = config.get("worker_replication_port", None)
+        self.worker_name = config.get("worker_name", self.worker_app)

         if self.worker_listeners:
             for listener in self.worker_listeners:
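Deployment-wise, the hunk above means worker config files replace the single worker_replication_url with a worker_replication_host / worker_replication_port pair pointing at the master's replication listener, and may optionally set worker_name, which defaults to the worker app name.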
@@ -15,7 +15,6 @@

 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine
-from twisted.internet import defer

 from ._slaved_id_tracker import SlavedIdTracker

@@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore):
         else:
             self._cache_id_gen = None

-        self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
-        self.http_client = hs.get_simple_http_client()
+        self.hs = hs

     def stream_positions(self):
         pos = {}
@@ -43,35 +41,20 @@ class BaseSlavedStore(SQLBaseStore):
             pos["caches"] = self._cache_id_gen.get_current_token()
         return pos

-    def process_replication(self, result):
-        stream = result.get("caches")
-        if stream:
-            for row in stream["rows"]:
-                (
-                    position, cache_func, keys, invalidation_ts,
-                ) = row
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "caches":
+            self._cache_id_gen.advance(token)
+            for row in rows:
                 try:
-                    getattr(self, cache_func).invalidate(tuple(keys))
+                    getattr(self, row.cache_func).invalidate(tuple(row.keys))
                 except AttributeError:
                     # We probably haven't pulled in the cache in this worker,
                     # which is fine.
                     pass
-            self._cache_id_gen.advance(int(stream["position"]))
-        return defer.succeed(None)

     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         txn.call_after(cache_func.invalidate, keys)
         txn.call_after(self._send_invalidation_poke, cache_func, keys)

-    @defer.inlineCallbacks
     def _send_invalidation_poke(self, cache_func, keys):
-        try:
-            yield self.http_client.post_json_get_json(self.expire_cache_url, {
-                "invalidate": [{
-                    "name": cache_func.__name__,
-                    "keys": list(keys),
-                }]
-            })
-        except:
-            logger.exception("Failed to poke on expire_cache")
+        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
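The storage-side pattern that this and the remaining slaved-store hunks follow: process_replication(result), which parsed a dict of stream names to positional rows, becomes process_replication_rows(stream_name, token, rows), which receives one stream's rows as objects with named attributes, advances the matching id generator to token, and chains to the superclass. A condensed sketch of the contract (ExampleSlavedStore, the "example" stream, and its caches are hypothetical):

    class ExampleSlavedStore(BaseSlavedStore):
        def process_replication_rows(self, stream_name, token, rows):
            if stream_name == "example":
                # Advance this store's copy of the stream position first.
                self._example_id_gen.advance(token)
                for row in rows:
                    # Rows expose named attributes instead of positional
                    # indexing into lists.
                    self.get_things_for_user.invalidate((row.user_id,))
                    self._example_stream_cache.entity_has_changed(row.user_id, token)
            return super(ExampleSlavedStore, self).process_replication_rows(
                stream_name, token, rows
            )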
@@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore):
         result["tag_account_data"] = position
         return result
 
-    def process_replication(self, result):
-        stream = result.get("user_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id, data_type = row[:3]
-                self.get_global_account_data_by_type_for_user.invalidate(
-                    (data_type, user_id,)
-                )
-                self.get_account_data_for_user.invalidate((user_id,))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "tag_account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-        stream = result.get("room_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_account_data_for_user.invalidate((user_id,))
+        elif stream_name == "account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                if not row.room_id:
+                    self.get_global_account_data_by_type_for_user.invalidate(
+                        (row.data_type, row.user_id,)
+                    )
+                self.get_account_data_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-        stream = result.get("tag_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_tags_for_user.invalidate((user_id,))
-                self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
-                )
-
-        return super(SlavedAccountDataStore, self).process_replication(result)
+        return super(SlavedAccountDataStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
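Note that the two old HTTP streams, `user_account_data` and `room_account_data`, collapse into a single `account_data` TCP stream; the `if not row.room_id:` branch above is what now distinguishes global from room-scoped entries. A hypothetical row shape consistent with that branch (field names are assumptions, mirroring the attributes used above):

    from collections import namedtuple

    # Illustrative row for the unified "account_data" stream: room_id is
    # None/empty for global account data, set for room-scoped data.
    AccountDataStreamRow = namedtuple(
        "AccountDataStreamRow", ("user_id", "room_id", "data_type")
    )

    row = AccountDataStreamRow("@alice:example.com", None, "m.push_rules")
    scope = "global" if not row.room_id else "room-scoped"
    assert scope == "global"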
@@ -53,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
         result["to_device"] = self._device_inbox_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("to_device")
-        if stream:
-            self._device_inbox_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                entity = row[1]
-
-                if entity.startswith("@"):
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "to_device":
+            self._device_inbox_id_gen.advance(token)
+            for row in rows:
+                if row.entity.startswith("@"):
                     self._device_inbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
                 else:
                     self._device_federation_outbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
-
-        return super(SlavedDeviceInboxStore, self).process_replication(result)
+        return super(SlavedDeviceInboxStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
@@ -51,22 +51,18 @@ class SlavedDeviceStore(BaseSlavedStore):
         result["device_lists"] = self._device_list_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("device_lists")
-        if stream:
-            self._device_list_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                user_id = row[1]
-                destination = row[2]
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "device_lists":
+            self._device_list_id_gen.advance(token)
+            for row in rows:
                 self._device_list_stream_cache.entity_has_changed(
-                    user_id, stream_id
+                    row.user_id, token
                 )
 
-                if destination:
+                if row.destination:
                     self._device_list_federation_stream_cache.entity_has_changed(
-                        destination, stream_id
+                        row.destination, token
                     )
-
-        return super(SlavedDeviceStore, self).process_replication(result)
+        return super(SlavedDeviceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
@@ -201,48 +201,25 @@ class SlavedEventStore(BaseSlavedStore):
         result["backfill"] = -self._backfill_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("events")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-
-            if stream["rows"]:
-                logger.info("Got %d event rows", len(stream["rows"]))
-
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=False,
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "events":
+            self._stream_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=False,
                 )
-
-        stream = result.get("backfill")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=True,
+        elif stream_name == "backfill":
+            self._backfill_id_gen.advance(-token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    -token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=True,
                 )
-
-        stream = result.get("forward_ex_outliers")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        stream = result.get("backward_ex_outliers")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        return super(SlavedEventStore, self).process_replication(result)
-
-    def _process_replication_row(self, row, backfilled):
-        stream_ordering = row[0] if not backfilled else -row[0]
-        self.invalidate_caches_for_event(
-            stream_ordering, row[1], row[2], row[3], row[4], row[5],
-            backfilled=backfilled,
+        return super(SlavedEventStore, self).process_replication_rows(
+            stream_name, token, rows
         )
 
     def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
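One subtlety worth calling out: backfill positions grow more negative over time, while the replication token on the wire stays positive, which is why `stream_positions` negates the generator's token and `process_replication_rows` negates it back. A toy illustration of the sign convention (not synapse code):

    # The slave's internal backfill position; it only ever moves further
    # into the past (more negative).
    internal_position = -7

    # stream_positions() sends the positive form over the wire...
    wire_token = -internal_position

    # ...and on receipt, process_replication_rows applies -token again,
    # recovering the internal (negative) position to advance to.
    assert -wire_token == internal_position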
@@ -48,15 +48,14 @@ class SlavedPresenceStore(BaseSlavedStore):
         result["presence"] = position
         return result
 
-    def process_replication(self, result):
-        stream = result.get("presence")
-        if stream:
-            self._presence_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "presence":
+            self._presence_id_gen.advance(token)
+            for row in rows:
                 self.presence_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-                self._get_presence_for_user.invalidate((user_id,))
-
-        return super(SlavedPresenceStore, self).process_replication(result)
+                self._get_presence_for_user.invalidate((row.user_id,))
+        return super(SlavedPresenceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
@@ -50,18 +50,15 @@ class SlavedPushRuleStore(SlavedEventStore):
         result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("push_rules")
-        if stream:
-            for row in stream["rows"]:
-                position = row[0]
-                user_id = row[2]
-                self.get_push_rules_for_user.invalidate((user_id,))
-                self.get_push_rules_enabled_for_user.invalidate((user_id,))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "push_rules":
+            self._push_rules_stream_id_gen.advance(token)
+            for row in rows:
+                self.get_push_rules_for_user.invalidate((row.user_id,))
+                self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-            self._push_rules_stream_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPushRuleStore, self).process_replication(result)
+        return super(SlavedPushRuleStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
@@ -40,13 +40,9 @@ class SlavedPusherStore(BaseSlavedStore):
         result["pushers"] = self._pushers_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        stream = result.get("deleted_pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPusherStore, self).process_replication(result)
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "pushers":
+            self._pushers_id_gen.advance(token)
+        return super(SlavedPusherStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
@@ -65,20 +65,22 @@ class SlavedReceiptsStore(BaseSlavedStore):
         result["receipts"] = self._receipts_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("receipts")
-        if stream:
-            self._receipts_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, room_id, receipt_type, user_id = row[:4]
-                self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
-                self._receipts_stream_cache.entity_has_changed(room_id, position)
-
-        return super(SlavedReceiptsStore, self).process_replication(result)
-
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self.get_linearized_receipts_for_room.invalidate_many((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
+
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "receipts":
+            self._receipts_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_receipt(
+                    row.room_id, row.receipt_type, row.user_id
+                )
+                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
+
+        return super(SlavedReceiptsStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
@@ -46,9 +46,10 @@ class RoomStore(BaseSlavedStore):
         result["public_rooms"] = self._public_room_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("public_rooms")
-        if stream:
-            self._public_room_id_gen.advance(int(stream["position"]))
-
-        return super(RoomStore, self).process_replication(result)
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "public_rooms":
+            self._public_room_id_gen.advance(token)
+
+        return super(RoomStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
synapse/replication/tcp/client.py (new file, 196 lines)

# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A replication client for use by synapse workers.
"""

from twisted.internet import reactor, defer
from twisted.internet.protocol import ReconnectingClientFactory

from .commands import (
    FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
)
from .protocol import ClientReplicationStreamProtocol

import logging

logger = logging.getLogger(__name__)


class ReplicationClientFactory(ReconnectingClientFactory):
    """Factory for building connections to the master. Will reconnect if the
    connection is lost.

    Accepts a handler that will be called when new data is available or data
    is required.
    """
    maxDelay = 5  # Try at least once every N seconds

    def __init__(self, hs, client_name, handler):
        self.client_name = client_name
        self.handler = handler
        self.server_name = hs.config.server_name
        self._clock = hs.get_clock()  # As self.clock is defined in super class

        reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)

    def startedConnecting(self, connector):
        logger.info("Connecting to replication: %r", connector.getDestination())

    def buildProtocol(self, addr):
        logger.info("Connected to replication: %r", addr)
        self.resetDelay()
        return ClientReplicationStreamProtocol(
            self.client_name, self.server_name, self._clock, self.handler
        )

    def clientConnectionLost(self, connector, reason):
        logger.error("Lost replication conn: %r", reason)
        ReconnectingClientFactory.clientConnectionLost(self, connector, reason)

    def clientConnectionFailed(self, connector, reason):
        logger.error("Failed to connect to replication: %r", reason)
        ReconnectingClientFactory.clientConnectionFailed(
            self, connector, reason
        )


class ReplicationClientHandler(object):
    """A base handler that can be passed to the ReplicationClientFactory.

    By default proxies incoming replication data to the SlaveStore.
    """
    def __init__(self, store):
        self.store = store

        # The current connection. None if we are currently (re)connecting
        self.connection = None

        # Any pending commands to be sent once a new connection has been
        # established
        self.pending_commands = []

        # Map from string -> deferred, to wake up when receiving a SYNC with
        # the given string.
        # Used for tests.
        self.awaiting_syncs = {}

    def start_replication(self, hs):
        """Helper method to start a replication connection to the remote server
        using TCP.
        """
        client_name = hs.config.worker_name
        factory = ReplicationClientFactory(hs, client_name, self)
        host = hs.config.worker_replication_host
        port = hs.config.worker_replication_port
        reactor.connectTCP(host, port, factory)

    def on_rdata(self, stream_name, token, rows):
        """Called when we get new replication data. By default this just pokes
        the slave store.

        Can be overridden in subclasses to handle more.
        """
        logger.info("Received rdata %s -> %s", stream_name, token)
        self.store.process_replication_rows(stream_name, token, rows)

    def on_position(self, stream_name, token):
        """Called when we get new position data. By default this just pokes
        the slave store.

        Can be overridden in subclasses to handle more.
        """
        self.store.process_replication_rows(stream_name, token, [])

    def on_sync(self, data):
        """When we receive a SYNC we wake up any deferreds that were waiting
        for the sync with the given data.

        Used by tests.
        """
        d = self.awaiting_syncs.pop(data, None)
        if d:
            d.callback(data)

    def get_streams_to_replicate(self):
        """Called when a new connection has been established and we need to
        subscribe to streams.

        Returns a dictionary of stream name to token.
        """
        args = self.store.stream_positions()
        user_account_data = args.pop("user_account_data", None)
        room_account_data = args.pop("room_account_data", None)
        if user_account_data:
            args["account_data"] = user_account_data
        elif room_account_data:
            args["account_data"] = room_account_data
        return args

    def get_currently_syncing_users(self):
        """Get the list of currently syncing users (if any). This is called
        when a connection has been established and we need to send the
        currently syncing users. (Overridden only by the synchrotron.)
        """
        return []

    def send_command(self, cmd):
        """Send a command to the master (or, if we don't have a connection
        yet, queue it to be sent once one is established).
        """
        if self.connection:
            self.connection.send_command(cmd)
        else:
            logger.warn("Queuing command as not connected: %r", cmd.NAME)
            self.pending_commands.append(cmd)

    def send_federation_ack(self, token):
        """Ack data for the federation stream. This allows the master to drop
        data stored purely in memory.
        """
        self.send_command(FederationAckCommand(token))

    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
        """Poke the master that a user has started/stopped syncing.
        """
        self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))

    def send_remove_pusher(self, app_id, push_key, user_id):
        """Poke the master to remove a pusher for a user.
        """
        cmd = RemovePusherCommand(app_id, push_key, user_id)
        self.send_command(cmd)

    def send_invalidate_cache(self, cache_func, keys):
        """Poke the master to invalidate a cache.
        """
        cmd = InvalidateCacheCommand(cache_func, keys)
        self.send_command(cmd)

    def await_sync(self, data):
        """Returns a deferred that is resolved when we receive a SYNC command
        with the given data.

        Used by tests.
        """
        return self.awaiting_syncs.setdefault(data, defer.Deferred())

    def update_connection(self, connection):
        """Called when a connection has been established (or lost with None).
        """
        self.connection = connection
        if connection:
            for cmd in self.pending_commands:
                connection.send_command(cmd)
            self.pending_commands = []
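With this file in place, a worker only needs a store-aware handler plus a call to `start_replication`. A minimal sketch of wiring a custom handler, assuming an `hs` homeserver object carrying the config attributes used above (`worker_name`, `worker_replication_host`, `worker_replication_port`); the subclass and the `receipts` stream check are illustrative:

    import logging

    from synapse.replication.tcp.client import ReplicationClientHandler

    logger = logging.getLogger(__name__)


    class ExampleReplicationHandler(ReplicationClientHandler):
        """Illustrative subclass: logs one stream, then defers to the
        default behaviour of poking the slave store."""

        def on_rdata(self, stream_name, token, rows):
            super(ExampleReplicationHandler, self).on_rdata(stream_name, token, rows)
            if stream_name == "receipts":
                logger.info("saw %d receipt rows up to %s", len(rows), token)


    # handler = ExampleReplicationHandler(hs.get_datastore())
    # handler.start_replication(hs)  # connects via reactor.connectTCP(...)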
@@ -132,6 +132,7 @@ class HomeServer(object):
         'federation_sender',
         'receipts_handler',
         'macaroon_generator',
+        'tcp_replication',
     ]
 
     def __init__(self, hostname, **kwargs):
@@ -290,6 +291,9 @@ class HomeServer(object):
     def build_receipts_handler(self):
         return ReceiptsHandler(self)
 
+    def build_tcp_replication(self):
+        raise NotImplementedError()
+
     def remove_pusher(self, app_id, push_key, user_id):
         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
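Adding `'tcp_replication'` to the dependency list and a `build_tcp_replication` stub means every worker gets a lazily-built, cached `get_tcp_replication()` accessor. A simplified sketch of that build/get caching pattern (names and structure simplified; not the actual synapse machinery):

    class MiniHomeServer(object):
        """Toy version of the homeserver's lazy dependency pattern."""

        def __init__(self):
            self._deps = {}

        def _get(self, name):
            # build_<name>() runs at most once; later calls hit the cache.
            if name not in self._deps:
                self._deps[name] = getattr(self, "build_%s" % name)()
            return self._deps[name]

        def build_tcp_replication(self):
            # The base class raises; worker subclasses return their handler.
            raise NotImplementedError()

        def get_tcp_replication(self):
            return self._get("tcp_replication")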
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from tests import unittest
 
 from mock import Mock, NonCallableMock
 from tests.utils import setup_test_homeserver
-from synapse.replication.resource import ReplicationResource
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.client import (
+    ReplicationClientHandler, ReplicationClientFactory,
+)
 
 
 class BaseSlavedStoreTestCase(unittest.TestCase):
@@ -33,18 +36,29 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         )
         self.hs.get_ratelimiter().send_message.return_value = (True, 0)
 
-        self.replication = ReplicationResource(self.hs)
-
         self.master_store = self.hs.get_datastore()
         self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
         self.event_id = 0
 
+        server_factory = ReplicationStreamProtocolFactory(self.hs)
+        listener = reactor.listenUNIX("\0xxx", server_factory)
+        self.addCleanup(listener.stopListening)
+        self.streamer = server_factory.streamer
+
+        self.replication_handler = ReplicationClientHandler(self.slaved_store)
+        client_factory = ReplicationClientFactory(
+            self.hs, "client_name", self.replication_handler
+        )
+        client_connector = reactor.connectUNIX("\0xxx", client_factory)
+        self.addCleanup(client_factory.stopTrying)
+        self.addCleanup(client_connector.disconnect)
+
     @defer.inlineCallbacks
     def replicate(self):
-        streams = self.slaved_store.stream_positions()
-        writer = yield self.replication.replicate(streams, 100)
-        result = writer.finish()
-        yield self.slaved_store.process_replication(result)
+        yield self.streamer.on_notifier_poke()
+        d = self.replication_handler.await_sync("replication_test")
+        self.streamer.send_sync_to_all_connections("replication_test")
+        yield d
 
     @defer.inlineCallbacks
     def check(self, method, args, expected_result=None):
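The rewritten `replicate()` helper uses the SYNC command as a barrier: commands are delivered in order, so once the echoed SYNC payload comes back, every row queued before it has been processed by the slaved store. A trimmed-down, runnable copy of the `await_sync()`/`on_sync()` pair from the new client, showing the barrier in isolation:

    from twisted.internet import defer


    class BarrierDemo(object):
        def __init__(self):
            self.awaiting_syncs = {}

        def await_sync(self, data):
            # Park a deferred keyed on the payload we expect back.
            return self.awaiting_syncs.setdefault(data, defer.Deferred())

        def on_sync(self, data):
            # Fired when the server echoes the SYNC payload back to us.
            d = self.awaiting_syncs.pop(data, None)
            if d:
                d.callback(data)


    h = BarrierDemo()
    results = []
    h.await_sync("replication_test").addCallback(results.append)
    h.on_sync("replication_test")  # simulate the echoed SYNC arriving
    assert results == ["replication_test"]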