0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 15:53:51 +01:00

Merge pull request #2097 from matrix-org/erikj/repl_tcp_client

Move to using TCP replication
This commit is contained in:
Erik Johnston 2017-04-05 09:36:21 +01:00 committed by GitHub
commit a5c401bd12
21 changed files with 601 additions and 581 deletions

View file

@ -26,17 +26,17 @@ from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from synapse import events from synapse import events
from twisted.internet import reactor, defer from twisted.internet import reactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
from daemonize import Daemonize from daemonize import Daemonize
@ -120,30 +120,25 @@ class AppserviceServer(HomeServer):
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks self.get_tcp_replication().start_replication(self)
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
appservice_handler = self.get_application_service_handler()
@defer.inlineCallbacks def build_tcp_replication(self):
def replicate(results): return ASReplicationHandler(self)
stream = results.get("events")
if stream:
max_stream_id = stream["position"]
yield appservice_handler.notify_interested_services(max_stream_id)
while True:
try: class ASReplicationHandler(ReplicationClientHandler):
args = store.stream_positions() def __init__(self, hs):
args["timeout"] = 30000 super(ASReplicationHandler, self).__init__(hs.get_datastore())
result = yield http_client.get_json(replication_url, args=args) self.appservice_handler = hs.get_application_service_handler()
yield store.process_replication(result)
replicate(result) def on_rdata(self, stream_name, token, rows):
except: super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
logger.exception("Error replicating from %r", replication_url)
yield sleep(30) if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering()
preserve_fn(
self.appservice_handler.notify_interested_services
)(max_stream_id)
def start(config_options): def start(config_options):
@ -199,7 +194,6 @@ def start(config_options):
reactor.run() reactor.run()
def start(): def start():
ps.replicate()
ps.get_datastore().start_profiling() ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching() ps.get_state_handler().start_caching()

View file

@ -30,11 +30,11 @@ from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.room import PublicRoomListRestServlet from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
@ -45,7 +45,7 @@ from synapse.crypto import context_factory
from synapse import events from synapse import events
from twisted.internet import reactor, defer from twisted.internet import reactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
from daemonize import Daemonize from daemonize import Daemonize
@ -145,21 +145,10 @@ class ClientReaderServer(HomeServer):
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks self.get_tcp_replication().start_replication(self)
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True: def build_tcp_replication(self):
try: return ReplicationClientHandler(self.get_datastore())
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options): def start(config_options):
@ -209,7 +198,6 @@ def start(config_options):
def start(): def start():
ss.get_state_handler().start_caching() ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View file

@ -27,9 +27,9 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
@ -42,7 +42,7 @@ from synapse.crypto import context_factory
from synapse import events from synapse import events
from twisted.internet import reactor, defer from twisted.internet import reactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
from daemonize import Daemonize from daemonize import Daemonize
@ -134,21 +134,10 @@ class FederationReaderServer(HomeServer):
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks self.get_tcp_replication().start_replication(self)
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True: def build_tcp_replication(self):
try: return ReplicationClientHandler(self.get_datastore())
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options): def start(config_options):
@ -198,7 +187,6 @@ def start(config_options):
def start(): def start():
ss.get_state_handler().start_caching() ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View file

@ -31,11 +31,12 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.util.async import sleep from synapse.util.async import Linearizer
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
@ -59,7 +60,28 @@ class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore, SlavedDeviceStore, SlavedRegistrationStore, SlavedDeviceStore,
): ):
pass def __init__(self, db_conn, hs):
super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
# We pull out the current federation stream position now so that we
# always have a known value for the federation position in memory so
# that we don't have to bounce via a deferred once when we start the
# replication streams.
self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
def _get_federation_out_pos(self, db_conn):
sql = (
"SELECT stream_id FROM federation_stream_position"
" WHERE type = ?"
)
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, ("federation",))
rows = txn.fetchall()
txn.close()
return rows[0][0] if rows else -1
class FederationSenderServer(HomeServer): class FederationSenderServer(HomeServer):
@ -127,26 +149,27 @@ class FederationSenderServer(HomeServer):
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks self.get_tcp_replication().start_replication(self)
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
send_handler = FederationSenderHandler(self)
send_handler.on_start() def build_tcp_replication(self):
return FederationSenderReplicationHandler(self)
while True:
try: class FederationSenderReplicationHandler(ReplicationClientHandler):
args = store.stream_positions() def __init__(self, hs):
args.update((yield send_handler.stream_positions())) super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
args["timeout"] = 30000 self.send_handler = FederationSenderHandler(hs, self)
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result) def on_rdata(self, stream_name, token, rows):
yield send_handler.process_replication(result) super(FederationSenderReplicationHandler, self).on_rdata(
except: stream_name, token, rows
logger.exception("Error replicating from %r", replication_url) )
yield sleep(30) self.send_handler.process_replication_rows(stream_name, token, rows)
def get_streams_to_replicate(self):
args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate()
args.update(self.send_handler.stream_positions())
return args
def start(config_options): def start(config_options):
@ -205,7 +228,6 @@ def start(config_options):
reactor.run() reactor.run()
def start(): def start():
ps.replicate()
ps.get_datastore().start_profiling() ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching() ps.get_state_handler().start_caching()
@ -229,9 +251,15 @@ class FederationSenderHandler(object):
"""Processes the replication stream and forwards the appropriate entries """Processes the replication stream and forwards the appropriate entries
to the federation sender. to the federation sender.
""" """
def __init__(self, hs): def __init__(self, hs, replication_client):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
self.replication_client = replication_client
self.federation_position = self.store.federation_out_pos_startup
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
self._last_ack = self.federation_position
self._room_serials = {} self._room_serials = {}
self._room_typing = {} self._room_typing = {}
@ -243,25 +271,13 @@ class FederationSenderHandler(object):
self.store.get_room_max_stream_ordering() self.store.get_room_max_stream_ordering()
) )
@defer.inlineCallbacks
def stream_positions(self): def stream_positions(self):
stream_id = yield self.store.get_federation_out_pos("federation") return {"federation": self.federation_position}
defer.returnValue({
"federation": stream_id,
# Ack stuff we've "processed", this should only be called from def process_replication_rows(self, stream_name, token, rows):
# one process.
"federation_ack": stream_id,
})
@defer.inlineCallbacks
def process_replication(self, result):
# The federation stream contains things that we want to send out, e.g. # The federation stream contains things that we want to send out, e.g.
# presence, typing, etc. # presence, typing, etc.
fed_stream = result.get("federation") if stream_name == "federation":
if fed_stream:
latest_id = int(fed_stream["position"])
# The federation stream containis a bunch of different types of # The federation stream containis a bunch of different types of
# rows that need to be handled differently. We parse the rows, put # rows that need to be handled differently. We parse the rows, put
# them into the appropriate collection and then send them off. # them into the appropriate collection and then send them off.
@ -272,8 +288,9 @@ class FederationSenderHandler(object):
device_destinations = set() device_destinations = set()
# Parse the rows in the stream # Parse the rows in the stream
for row in fed_stream["rows"]: for row in rows:
position, typ, content_js = row typ = row.type
content_js = row.data
content = json.loads(content_js) content = json.loads(content_js)
if typ == send_queue.PRESENCE_TYPE: if typ == send_queue.PRESENCE_TYPE:
@ -325,16 +342,27 @@ class FederationSenderHandler(object):
for destination in device_destinations: for destination in device_destinations:
self.federation_sender.send_device_messages(destination) self.federation_sender.send_device_messages(destination)
# Record where we are in the stream. preserve_fn(self.update_token)(token)
yield self.store.update_federation_out_pos(
"federation", latest_id
)
# We also need to poke the federation sender when new events happen # We also need to poke the federation sender when new events happen
event_stream = result.get("events") elif stream_name == "events":
if event_stream: self.federation_sender.notify_new_events(token)
latest_pos = event_stream["position"]
self.federation_sender.notify_new_events(latest_pos) @defer.inlineCallbacks
def update_token(self, token):
self.federation_position = token
# We linearize here to ensure we don't have races updating the token
with (yield self._fed_position_linearizer.queue(None)):
if self._last_ack < self.federation_position:
yield self.store.update_federation_out_pos(
"federation", self.federation_position
)
# We ACK this token over replication so that the master can drop
# its in memory queues
self.replication_client.send_federation_ack(self.federation_position)
self._last_ack = self.federation_position
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -25,13 +25,13 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
@ -45,7 +45,7 @@ from synapse.crypto import context_factory
from synapse import events from synapse import events
from twisted.internet import reactor, defer from twisted.internet import reactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
from daemonize import Daemonize from daemonize import Daemonize
@ -142,21 +142,10 @@ class MediaRepositoryServer(HomeServer):
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks self.get_tcp_replication().start_replication(self)
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True: def build_tcp_replication(self):
try: return ReplicationClientHandler(self.get_datastore())
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options): def start(config_options):
@ -206,7 +195,6 @@ def start(config_options):
def start(): def start():
ss.get_state_handler().start_caching() ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View file

@ -27,9 +27,9 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, preserve_fn, \ from synapse.util.logcontext import LoggingContext, preserve_fn, \
PreserveLoggingContext PreserveLoggingContext
@ -89,7 +89,6 @@ class PusherSlaveStore(
class PusherServer(HomeServer): class PusherServer(HomeServer):
def get_db_conn(self, run_new_connection=True): def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should # Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine. # not be passed to the database engine.
@ -109,16 +108,7 @@ class PusherServer(HomeServer):
logger.info("Finished setting up.") logger.info("Finished setting up.")
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
http_client = self.get_simple_http_client() self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
replication_url = self.config.worker_replication_url
url = replication_url + "/remove_pushers"
return http_client.post_json_get_json(url, {
"remove": [{
"app_id": app_id,
"push_key": push_key,
"user_id": user_id,
}]
})
def _listen_http(self, listener_config): def _listen_http(self, listener_config):
port = listener_config["port"] port = listener_config["port"]
@ -166,73 +156,52 @@ class PusherServer(HomeServer):
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
def build_tcp_replication(self):
return PusherReplicationHandler(self)
class PusherReplicationHandler(ReplicationClientHandler):
def __init__(self, hs):
super(PusherReplicationHandler, self).__init__(hs.get_datastore())
self.pusher_pool = hs.get_pusherpool()
def on_rdata(self, stream_name, token, rows):
super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
preserve_fn(self.poke_pushers)(stream_name, token, rows)
@defer.inlineCallbacks @defer.inlineCallbacks
def replicate(self): def poke_pushers(self, stream_name, token, rows):
http_client = self.get_simple_http_client() if stream_name == "pushers":
store = self.get_datastore() for row in rows:
replication_url = self.config.worker_replication_url if row.deleted:
pusher_pool = self.get_pusherpool() yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
else:
def stop_pusher(user_id, app_id, pushkey): yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
key = "%s:%s" % (app_id, pushkey) elif stream_name == "events":
pushers_for_user = pusher_pool.pushers.get(user_id, {}) yield self.pusher_pool.on_new_notifications(
pusher = pushers_for_user.pop(key, None) token, token,
if pusher is None:
return
logger.info("Stopping pusher %r / %r", user_id, key)
pusher.on_stop()
def start_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
logger.info("Starting pusher %r / %r", user_id, key)
return pusher_pool._refresh_pusher(app_id, pushkey, user_id)
@defer.inlineCallbacks
def poke_pushers(results):
pushers_rows = set(
map(tuple, results.get("pushers", {}).get("rows", []))
) )
deleted_pushers_rows = set( elif stream_name == "receipts":
map(tuple, results.get("deleted_pushers", {}).get("rows", [])) yield self.pusher_pool.on_new_receipts(
token, token, set(row.room_id for row in rows)
) )
for row in sorted(pushers_rows | deleted_pushers_rows):
if row in deleted_pushers_rows:
user_id, app_id, pushkey = row[1:4]
stop_pusher(user_id, app_id, pushkey)
elif row in pushers_rows:
user_id = row[1]
app_id = row[5]
pushkey = row[8]
yield start_pusher(user_id, app_id, pushkey)
stream = results.get("events") def stop_pusher(self, user_id, app_id, pushkey):
if stream and stream["rows"]: key = "%s:%s" % (app_id, pushkey)
min_stream_id = stream["rows"][0][0] pushers_for_user = self.pusher_pool.pushers.get(user_id, {})
max_stream_id = stream["position"] pusher = pushers_for_user.pop(key, None)
preserve_fn(pusher_pool.on_new_notifications)( if pusher is None:
min_stream_id, max_stream_id return
) logger.info("Stopping pusher %r / %r", user_id, key)
pusher.on_stop()
stream = results.get("receipts") def start_pusher(self, user_id, app_id, pushkey):
if stream and stream["rows"]: key = "%s:%s" % (app_id, pushkey)
rows = stream["rows"] logger.info("Starting pusher %r / %r", user_id, key)
affected_room_ids = set(row[1] for row in rows) return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id)
min_stream_id = rows[0][0]
max_stream_id = stream["position"]
preserve_fn(pusher_pool.on_new_receipts)(
min_stream_id, max_stream_id, affected_room_ids
)
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
poke_pushers(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def start(config_options): def start(config_options):
@ -288,7 +257,6 @@ def start(config_options):
reactor.run() reactor.run()
def start(): def start():
ps.replicate()
ps.get_pusherpool().start() ps.get_pusherpool().start()
ps.get_datastore().start_profiling() ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching() ps.get_state_handler().start_caching()

View file

@ -16,7 +16,7 @@
import synapse import synapse
from synapse.api.constants import EventTypes, PresenceState from synapse.api.constants import EventTypes
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging from synapse.config.logger import setup_logging
@ -40,15 +40,14 @@ from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.presence import PresenceStore, UserPresenceState from synapse.storage.presence import PresenceStore, UserPresenceState
from synapse.storage.roommember import RoomMemberStore from synapse.storage.roommember import RoomMemberStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, preserve_fn, \ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
PreserveLoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit from synapse.util.rlimit import change_resource_limit
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -107,11 +106,11 @@ UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object): class SynchrotronPresence(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.user_to_num_current_syncs = {} self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@ -124,14 +123,8 @@ class SynchrotronPresence(object):
self.process_id = random_string(16) self.process_id = random_string(16)
logger.info("Presence process_id is %r", self.process_id) logger.info("Presence process_id is %r", self.process_id)
self._sending_sync = False def send_user_sync(self, user_id, is_syncing, last_sync_ms):
self._need_to_send_sync = False self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms)
self.clock.looping_call(
self._send_syncing_users_regularly,
UPDATE_SYNCING_USERS_MS,
)
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
def set_state(self, user, state, ignore_status_msg=False): def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work? # TODO Hows this supposed to work?
@ -142,15 +135,15 @@ class SynchrotronPresence(object):
_get_interested_parties = PresenceHandler._get_interested_parties.__func__ _get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks
def user_syncing(self, user_id, affect_presence): def user_syncing(self, user_id, affect_presence):
if affect_presence: if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0) curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1 self.user_to_num_current_syncs[user_id] = curr_sync + 1
prev_states = yield self.current_state_for_users([user_id])
if prev_states[user_id].state == PresenceState.OFFLINE: # If we went from no in flight sync to some, notify replication
# TODO: Don't block the sync request on this HTTP hit. if self.user_to_num_current_syncs[user_id] == 1:
yield self._send_syncing_users_now() now = self.clock.time_msec()
self.send_user_sync(user_id, True, now)
def _end(): def _end():
# We check that the user_id is in user_to_num_current_syncs because # We check that the user_id is in user_to_num_current_syncs because
@ -159,6 +152,11 @@ class SynchrotronPresence(object):
if affect_presence and user_id in self.user_to_num_current_syncs: if affect_presence and user_id in self.user_to_num_current_syncs:
self.user_to_num_current_syncs[user_id] -= 1 self.user_to_num_current_syncs[user_id] -= 1
# If we went from one in flight sync to non, notify replication
if self.user_to_num_current_syncs[user_id] == 0:
now = self.clock.time_msec()
self.send_user_sync(user_id, False, now)
@contextlib.contextmanager @contextlib.contextmanager
def _user_syncing(): def _user_syncing():
try: try:
@ -166,49 +164,7 @@ class SynchrotronPresence(object):
finally: finally:
_end() _end()
defer.returnValue(_user_syncing()) return defer.succeed(_user_syncing())
@defer.inlineCallbacks
def _on_shutdown(self):
# When the synchrotron is shutdown tell the master to clear the in
# progress syncs for this process
self.user_to_num_current_syncs.clear()
yield self._send_syncing_users_now()
def _send_syncing_users_regularly(self):
# Only send an update if we aren't in the middle of sending one.
if not self._sending_sync:
preserve_fn(self._send_syncing_users_now)()
@defer.inlineCallbacks
def _send_syncing_users_now(self):
if self._sending_sync:
# We don't want to race with sending another update.
# Instead we wait for that update to finish and send another
# update afterwards.
self._need_to_send_sync = True
return
# Flag that we are sending an update.
self._sending_sync = True
yield self.http_client.post_json_get_json(self.syncing_users_url, {
"process_id": self.process_id,
"syncing_users": [
user_id for user_id, count in self.user_to_num_current_syncs.items()
if count > 0
],
})
# Unset the flag as we are no longer sending an update.
self._sending_sync = False
if self._need_to_send_sync:
# If something happened while we were sending the update then
# we might need to send another update.
# TODO: Check if the update that was sent matches the current state
# as we only need to send an update if they are different.
self._need_to_send_sync = False
yield self._send_syncing_users_now()
@defer.inlineCallbacks @defer.inlineCallbacks
def notify_from_replication(self, states, stream_id): def notify_from_replication(self, states, stream_id):
@ -223,26 +179,24 @@ class SynchrotronPresence(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def process_replication(self, result): def process_replication_rows(self, token, rows):
stream = result.get("presence", {"rows": []}) states = [UserPresenceState(
states = [] row.user_id, row.state, row.last_active_ts,
for row in stream["rows"]: row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg,
( row.currently_active
position, user_id, state, last_active_ts, ) for row in rows]
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
) = row
state = UserPresenceState(
user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
)
self.user_to_current_state[user_id] = state
states.append(state)
if states and "position" in stream: for state in states:
stream_id = int(stream["position"]) self.user_to_current_state[row.user_id] = state
yield self.notify_from_replication(states, stream_id)
stream_id = token
yield self.notify_from_replication(states, stream_id)
def get_currently_syncing_users(self):
return [
user_id for user_id, count in self.user_to_num_current_syncs.iteritems()
if count > 0
]
class SynchrotronTyping(object): class SynchrotronTyping(object):
@ -257,16 +211,13 @@ class SynchrotronTyping(object):
# value which we *must* use for the next replication request. # value which we *must* use for the next replication request.
return {"typing": self._latest_room_serial} return {"typing": self._latest_room_serial}
def process_replication(self, result): def process_replication_rows(self, token, rows):
stream = result.get("typing") self._latest_room_serial = token
if stream:
self._latest_room_serial = int(stream["position"])
for row in stream["rows"]: for row in rows:
position, room_id, typing_json = row typing = json.loads(row.user_ids)
typing = json.loads(typing_json) self._room_serials[row.room_id] = token
self._room_serials[room_id] = position self._room_typing[row.room_id] = typing
self._room_typing[room_id] = typing
class SynchrotronApplicationService(object): class SynchrotronApplicationService(object):
@ -351,118 +302,10 @@ class SynchrotronServer(HomeServer):
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks self.get_tcp_replication().start_replication(self)
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
notifier = self.get_notifier()
presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler()
def notify_from_stream( def build_tcp_replication(self):
result, stream_name, stream_key, room=None, user=None return SyncReplicationHandler(self)
):
stream = result.get(stream_name)
if stream:
position_index = stream["field_names"].index("position")
if room:
room_index = stream["field_names"].index(room)
if user:
user_index = stream["field_names"].index(user)
users = ()
rooms = ()
for row in stream["rows"]:
position = row[position_index]
if user:
users = (row[user_index],)
if room:
rooms = (row[room_index],)
notifier.on_new_event(
stream_key, position, users=users, rooms=rooms
)
@defer.inlineCallbacks
def notify_device_list_update(result):
stream = result.get("device_lists")
if not stream:
return
position_index = stream["field_names"].index("position")
user_index = stream["field_names"].index("user_id")
for row in stream["rows"]:
position = row[position_index]
user_id = row[user_index]
room_ids = yield store.get_rooms_for_user(user_id)
notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
@defer.inlineCallbacks
def notify(result):
stream = result.get("events")
if stream:
max_position = stream["position"]
event_map = yield store.get_events([row[1] for row in stream["rows"]])
for row in stream["rows"]:
position = row[0]
event_id = row[1]
event = event_map.get(event_id, None)
if not event:
continue
extra_users = ()
if event.type == EventTypes.Member:
extra_users = (event.state_key,)
notifier.on_new_room_event(
event, position, max_position, extra_users
)
notify_from_stream(
result, "push_rules", "push_rules_key", user="user_id"
)
notify_from_stream(
result, "user_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "room_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "tag_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "receipts", "receipt_key", room="room_id"
)
notify_from_stream(
result, "typing", "typing_key", room="room_id"
)
notify_from_stream(
result, "to_device", "to_device_key", user="user_id"
)
yield notify_device_list_update(result)
while True:
try:
args = store.stream_positions()
args.update(typing_handler.stream_positions())
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
typing_handler.process_replication(result)
yield presence_handler.process_replication(result)
yield notify(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def build_presence_handler(self): def build_presence_handler(self):
return SynchrotronPresence(self) return SynchrotronPresence(self)
@ -471,6 +314,79 @@ class SynchrotronServer(HomeServer):
return SynchrotronTyping(self) return SynchrotronTyping(self)
class SyncReplicationHandler(ReplicationClientHandler):
def __init__(self, hs):
super(SyncReplicationHandler, self).__init__(hs.get_datastore())
self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler()
self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier()
self.presence_handler.sync_callback = self.send_user_sync
def on_rdata(self, stream_name, token, rows):
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
preserve_fn(self.process_and_notify)(stream_name, token, rows)
def get_streams_to_replicate(self):
args = super(SyncReplicationHandler, self).get_streams_to_replicate()
args.update(self.typing_handler.stream_positions())
return args
def get_currently_syncing_users(self):
return self.presence_handler.get_currently_syncing_users()
@defer.inlineCallbacks
def process_and_notify(self, stream_name, token, rows):
if stream_name == "events":
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
for row in rows:
event = yield self.store.get_event(row.event_id)
extra_users = ()
if event.type == EventTypes.Member:
extra_users = (event.state_key,)
max_token = self.store.get_room_max_stream_ordering()
self.notifier.on_new_room_event(
event, token, max_token, extra_users
)
elif stream_name == "push_rules":
self.notifier.on_new_event(
"push_rules_key", token, users=[row.user_id for row in rows],
)
elif stream_name in ("account_data", "tag_account_data",):
self.notifier.on_new_event(
"account_data_key", token, users=[row.user_id for row in rows],
)
elif stream_name == "receipts":
self.notifier.on_new_event(
"receipt_key", token, rooms=[row.room_id for row in rows],
)
elif stream_name == "typing":
self.typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows],
)
elif stream_name == "to_device":
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
self.notifier.on_new_event(
"to_device_key", token, users=entities,
)
elif stream_name == "device_lists":
all_room_ids = set()
for row in rows:
room_ids = yield self.store.get_rooms_for_user(row.user_id)
all_room_ids.update(room_ids)
self.notifier.on_new_event(
"device_list_key", token, rooms=all_room_ids,
)
elif stream_name == "presence":
yield self.presence_handler.process_replication_rows(token, rows)
def start(config_options): def start(config_options):
try: try:
config = HomeServerConfig.load_config( config = HomeServerConfig.load_config(
@ -514,7 +430,6 @@ def start(config_options):
def start(): def start():
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
ss.replicate()
ss.get_state_handler().start_caching() ss.get_state_handler().start_caching()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View file

@ -28,7 +28,9 @@ class WorkerConfig(Config):
self.worker_pid_file = config.get("worker_pid_file") self.worker_pid_file = config.get("worker_pid_file")
self.worker_log_file = config.get("worker_log_file") self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config") self.worker_log_config = config.get("worker_log_config")
self.worker_replication_url = config.get("worker_replication_url") self.worker_replication_host = config.get("worker_replication_host", None)
self.worker_replication_port = config.get("worker_replication_port", None)
self.worker_name = config.get("worker_name", self.worker_app)
if self.worker_listeners: if self.worker_listeners:
for listener in self.worker_listeners: for listener in self.worker_listeners:

View file

@ -15,7 +15,6 @@
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore):
else: else:
self._cache_id_gen = None self._cache_id_gen = None
self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache" self.hs = hs
self.http_client = hs.get_simple_http_client()
def stream_positions(self): def stream_positions(self):
pos = {} pos = {}
@ -43,35 +41,20 @@ class BaseSlavedStore(SQLBaseStore):
pos["caches"] = self._cache_id_gen.get_current_token() pos["caches"] = self._cache_id_gen.get_current_token()
return pos return pos
def process_replication(self, result): def process_replication_rows(self, stream_name, token, rows):
stream = result.get("caches") if stream_name == "caches":
if stream: self._cache_id_gen.advance(token)
for row in stream["rows"]: for row in rows:
(
position, cache_func, keys, invalidation_ts,
) = row
try: try:
getattr(self, cache_func).invalidate(tuple(keys)) getattr(self, row.cache_func).invalidate(tuple(row.keys))
except AttributeError: except AttributeError:
# We probably haven't pulled in the cache in this worker, # We probably haven't pulled in the cache in this worker,
# which is fine. # which is fine.
pass pass
self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None)
def _invalidate_cache_and_stream(self, txn, cache_func, keys): def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys) txn.call_after(cache_func.invalidate, keys)
txn.call_after(self._send_invalidation_poke, cache_func, keys) txn.call_after(self._send_invalidation_poke, cache_func, keys)
@defer.inlineCallbacks
def _send_invalidation_poke(self, cache_func, keys): def _send_invalidation_poke(self, cache_func, keys):
try: self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
yield self.http_client.post_json_get_json(self.expire_cache_url, {
"invalidate": [{
"name": cache_func.__name__,
"keys": list(keys),
}]
})
except:
logger.exception("Failed to poke on expire_cache")

View file

@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore):
result["tag_account_data"] = position result["tag_account_data"] = position
return result return result
def process_replication(self, result): def process_replication_rows(self, stream_name, token, rows):
stream = result.get("user_account_data") if stream_name == "tag_account_data":
if stream: self._account_data_id_gen.advance(token)
self._account_data_id_gen.advance(int(stream["position"])) for row in rows:
for row in stream["rows"]: self.get_tags_for_user.invalidate((row.user_id,))
position, user_id, data_type = row[:3]
self.get_global_account_data_by_type_for_user.invalidate(
(data_type, user_id,)
)
self.get_account_data_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed( self._account_data_stream_cache.entity_has_changed(
user_id, position row.user_id, token
) )
elif stream_name == "account_data":
stream = result.get("room_account_data") self._account_data_id_gen.advance(token)
if stream: for row in rows:
self._account_data_id_gen.advance(int(stream["position"])) if not row.room_id:
for row in stream["rows"]: self.get_global_account_data_by_type_for_user.invalidate(
position, user_id = row[:2] (row.data_type, row.user_id,)
self.get_account_data_for_user.invalidate((user_id,)) )
self.get_account_data_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed( self._account_data_stream_cache.entity_has_changed(
user_id, position row.user_id, token
) )
return super(SlavedAccountDataStore, self).process_replication_rows(
stream = result.get("tag_account_data") stream_name, token, rows
if stream: )
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_tags_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedAccountDataStore, self).process_replication(result)

View file

@ -53,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
result["to_device"] = self._device_inbox_id_gen.get_current_token() result["to_device"] = self._device_inbox_id_gen.get_current_token()
return result return result
def process_replication(self, result): def process_replication_rows(self, stream_name, token, rows):
stream = result.get("to_device") if stream_name == "to_device":
if stream: self._device_inbox_id_gen.advance(token)
self._device_inbox_id_gen.advance(int(stream["position"])) for row in rows:
for row in stream["rows"]: if row.entity.startswith("@"):
stream_id = row[0]
entity = row[1]
if entity.startswith("@"):
self._device_inbox_stream_cache.entity_has_changed( self._device_inbox_stream_cache.entity_has_changed(
entity, stream_id row.entity, token
) )
else: else:
self._device_federation_outbox_stream_cache.entity_has_changed( self._device_federation_outbox_stream_cache.entity_has_changed(
entity, stream_id row.entity, token
) )
return super(SlavedDeviceInboxStore, self).process_replication_rows(
return super(SlavedDeviceInboxStore, self).process_replication(result) stream_name, token, rows
)

View file

@ -51,22 +51,18 @@ class SlavedDeviceStore(BaseSlavedStore):
result["device_lists"] = self._device_list_id_gen.get_current_token() result["device_lists"] = self._device_list_id_gen.get_current_token()
return result return result
def process_replication(self, result): def process_replication_rows(self, stream_name, token, rows):
stream = result.get("device_lists") if stream_name == "device_lists":
if stream: self._device_list_id_gen.advance(token)
self._device_list_id_gen.advance(int(stream["position"])) for row in rows:
for row in stream["rows"]:
stream_id = row[0]
user_id = row[1]
destination = row[2]
self._device_list_stream_cache.entity_has_changed( self._device_list_stream_cache.entity_has_changed(
user_id, stream_id row.user_id, token
) )
if destination: if row.destination:
self._device_list_federation_stream_cache.entity_has_changed( self._device_list_federation_stream_cache.entity_has_changed(
destination, stream_id row.destination, token
) )
return super(SlavedDeviceStore, self).process_replication_rows(
return super(SlavedDeviceStore, self).process_replication(result) stream_name, token, rows
)

View file

@ -201,48 +201,25 @@ class SlavedEventStore(BaseSlavedStore):
result["backfill"] = -self._backfill_id_gen.get_current_token() result["backfill"] = -self._backfill_id_gen.get_current_token()
return result return result
def process_replication(self, result): def process_replication_rows(self, stream_name, token, rows):
stream = result.get("events") if stream_name == "events":
if stream: self._stream_id_gen.advance(token)
self._stream_id_gen.advance(int(stream["position"])) for row in rows:
self.invalidate_caches_for_event(
if stream["rows"]: token, row.event_id, row.room_id, row.type, row.state_key,
logger.info("Got %d event rows", len(stream["rows"])) row.redacts,
backfilled=False,
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=False,
) )
elif stream_name == "backfill":
stream = result.get("backfill") self._backfill_id_gen.advance(-token)
if stream: for row in rows:
self._backfill_id_gen.advance(-int(stream["position"])) self.invalidate_caches_for_event(
for row in stream["rows"]: -token, row.event_id, row.room_id, row.type, row.state_key,
self._process_replication_row( row.redacts,
row, backfilled=True, backfilled=True,
) )
return super(SlavedEventStore, self).process_replication_rows(
stream = result.get("forward_ex_outliers") stream_name, token, rows
if stream:
self._stream_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
stream = result.get("backward_ex_outliers")
if stream:
self._backfill_id_gen.advance(-int(stream["position"]))
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
return super(SlavedEventStore, self).process_replication(result)
def _process_replication_row(self, row, backfilled):
stream_ordering = row[0] if not backfilled else -row[0]
self.invalidate_caches_for_event(
stream_ordering, row[1], row[2], row[3], row[4], row[5],
backfilled=backfilled,
) )
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,

View file

@ -48,15 +48,14 @@ class SlavedPresenceStore(BaseSlavedStore):
result["presence"] = position result["presence"] = position
return result return result
def process_replication(self, result): def process_replication_rows(self, stream_name, token, rows):
stream = result.get("presence") if stream_name == "presence":
if stream: self._presence_id_gen.advance(token)
self._presence_id_gen.advance(int(stream["position"])) for row in rows:
for row in stream["rows"]:
position, user_id = row[:2]
self.presence_stream_cache.entity_has_changed( self.presence_stream_cache.entity_has_changed(
user_id, position row.user_id, token
) )
self._get_presence_for_user.invalidate((user_id,)) self._get_presence_for_user.invalidate((row.user_id,))
return super(SlavedPresenceStore, self).process_replication_rows(
return super(SlavedPresenceStore, self).process_replication(result) stream_name, token, rows
)

View file

@ -50,18 +50,15 @@ class SlavedPushRuleStore(SlavedEventStore):
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result return result
def process_replication(self, result): def process_replication_rows(self, stream_name, token, rows):
stream = result.get("push_rules") if stream_name == "push_rules":
if stream: self._push_rules_stream_id_gen.advance(token)
for row in stream["rows"]: for row in rows:
position = row[0] self.get_push_rules_for_user.invalidate((row.user_id,))
user_id = row[2] self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.get_push_rules_for_user.invalidate((user_id,))
self.get_push_rules_enabled_for_user.invalidate((user_id,))
self.push_rules_stream_cache.entity_has_changed( self.push_rules_stream_cache.entity_has_changed(
user_id, position row.user_id, token
) )
return super(SlavedPushRuleStore, self).process_replication_rows(
self._push_rules_stream_id_gen.advance(int(stream["position"])) stream_name, token, rows
)
return super(SlavedPushRuleStore, self).process_replication(result)

View file

@ -40,13 +40,9 @@ class SlavedPusherStore(BaseSlavedStore):
result["pushers"] = self._pushers_id_gen.get_current_token() result["pushers"] = self._pushers_id_gen.get_current_token()
return result return result
def process_replication(self, result): def process_replication_rows(self, stream_name, token, rows):
stream = result.get("pushers") if stream_name == "pushers":
if stream: self._pushers_id_gen.advance(token)
self._pushers_id_gen.advance(int(stream["position"])) return super(SlavedPusherStore, self).process_replication_rows(
stream_name, token, rows
stream = result.get("deleted_pushers") )
if stream:
self._pushers_id_gen.advance(int(stream["position"]))
return super(SlavedPusherStore, self).process_replication(result)

View file

@ -65,20 +65,22 @@ class SlavedReceiptsStore(BaseSlavedStore):
result["receipts"] = self._receipts_id_gen.get_current_token() result["receipts"] = self._receipts_id_gen.get_current_token()
return result return result
def process_replication(self, result):
stream = result.get("receipts")
if stream:
self._receipts_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, room_id, receipt_type, user_id = row[:4]
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
self._receipts_stream_cache.entity_has_changed(room_id, position)
return super(SlavedReceiptsStore, self).process_replication(result)
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type)) self.get_receipts_for_user.invalidate((user_id, receipt_type))
self.get_linearized_receipts_for_room.invalidate_many((room_id,)) self.get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate( self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type) (user_id, room_id, receipt_type)
) )
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "receipts":
self._receipts_id_gen.advance(token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
return super(SlavedReceiptsStore, self).process_replication_rows(
stream_name, token, rows
)

View file

@ -46,9 +46,10 @@ class RoomStore(BaseSlavedStore):
result["public_rooms"] = self._public_room_id_gen.get_current_token() result["public_rooms"] = self._public_room_id_gen.get_current_token()
return result return result
def process_replication(self, result): def process_replication_rows(self, stream_name, token, rows):
stream = result.get("public_rooms") if stream_name == "public_rooms":
if stream: self._public_room_id_gen.advance(token)
self._public_room_id_gen.advance(int(stream["position"]))
return super(RoomStore, self).process_replication(result) return super(RoomStore, self).process_replication_rows(
stream_name, token, rows
)

View file

@ -0,0 +1,196 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A replication client for use by synapse workers.
"""
from twisted.internet import reactor, defer
from twisted.internet.protocol import ReconnectingClientFactory
from .commands import (
FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
)
from .protocol import ClientReplicationStreamProtocol
import logging
logger = logging.getLogger(__name__)
class ReplicationClientFactory(ReconnectingClientFactory):
"""Factory for building connections to the master. Will reconnect if the
connection is lost.
Accepts a handler that will be called when new data is available or data
is required.
"""
maxDelay = 5 # Try at least once every N seconds
def __init__(self, hs, client_name, handler):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
self._clock = hs.get_clock() # As self.clock is defined in super class
reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
def startedConnecting(self, connector):
logger.info("Connecting to replication: %r", connector.getDestination())
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
self.resetDelay()
return ClientReplicationStreamProtocol(
self.client_name, self.server_name, self._clock, self.handler
)
def clientConnectionLost(self, connector, reason):
logger.error("Lost replication conn: %r", reason)
ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
def clientConnectionFailed(self, connector, reason):
logger.error("Failed to connect to replication: %r", reason)
ReconnectingClientFactory.clientConnectionFailed(
self, connector, reason
)
class ReplicationClientHandler(object):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
def __init__(self, store):
self.store = store
# The current connection. None if we are currently (re)connecting
self.connection = None
# Any pending commands to be sent once a new connection has been
# established
self.pending_commands = []
# Map from string -> deferred, to wake up when receiveing a SYNC with
# the given string.
# Used for tests.
self.awaiting_syncs = {}
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
client_name = hs.config.worker_name
factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
reactor.connectTCP(host, port, factory)
def on_rdata(self, stream_name, token, rows):
"""Called when we get new replication data. By default this just pokes
the slave store.
Can be overriden in subclasses to handle more.
"""
logger.info("Received rdata %s -> %s", stream_name, token)
self.store.process_replication_rows(stream_name, token, rows)
def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.
Can be overriden in subclasses to handle more.
"""
self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
for the sync with the given data.
Used by tests.
"""
d = self.awaiting_syncs.pop(data, None)
if d:
d.callback(data)
def get_streams_to_replicate(self):
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns a dictionary of stream name to token.
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
room_account_data = args.pop("room_account_data", None)
if user_account_data:
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users. (Overriden by the synchrotron's only)
"""
return []
def send_command(self, cmd):
"""Send a command to master (when we get establish a connection if we
don't have one already.)
"""
if self.connection:
self.connection.send_command(cmd)
else:
logger.warn("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd)
def send_federation_ack(self, token):
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
self.send_command(FederationAckCommand(token))
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
"""Poke the master that a user has started/stopped syncing.
"""
self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
def send_remove_pusher(self, app_id, push_key, user_id):
"""Poke the master to remove a pusher for a user
"""
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func, keys):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func, keys)
self.send_command(cmd)
def await_sync(self, data):
"""Returns a deferred that is resolved when we receive a SYNC command
with given data.
Used by tests.
"""
return self.awaiting_syncs.setdefault(data, defer.Deferred())
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
self.connection = connection
if connection:
for cmd in self.pending_commands:
connection.send_command(cmd)
self.pending_commands = []

View file

@ -132,6 +132,7 @@ class HomeServer(object):
'federation_sender', 'federation_sender',
'receipts_handler', 'receipts_handler',
'macaroon_generator', 'macaroon_generator',
'tcp_replication',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -290,6 +291,9 @@ class HomeServer(object):
def build_receipts_handler(self): def build_receipts_handler(self):
return ReceiptsHandler(self) return ReceiptsHandler(self)
def build_tcp_replication(self):
raise NotImplementedError()
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View file

@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer, reactor
from tests import unittest from tests import unittest
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
from synapse.replication.resource import ReplicationResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.replication.tcp.client import (
ReplicationClientHandler, ReplicationClientFactory,
)
class BaseSlavedStoreTestCase(unittest.TestCase): class BaseSlavedStoreTestCase(unittest.TestCase):
@ -33,18 +36,29 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
) )
self.hs.get_ratelimiter().send_message.return_value = (True, 0) self.hs.get_ratelimiter().send_message.return_value = (True, 0)
self.replication = ReplicationResource(self.hs)
self.master_store = self.hs.get_datastore() self.master_store = self.hs.get_datastore()
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0 self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
listener = reactor.listenUNIX("\0xxx", server_factory)
self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
self.replication_handler = ReplicationClientHandler(self.slaved_store)
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
client_connector = reactor.connectUNIX("\0xxx", client_factory)
self.addCleanup(client_factory.stopTrying)
self.addCleanup(client_connector.disconnect)
@defer.inlineCallbacks @defer.inlineCallbacks
def replicate(self): def replicate(self):
streams = self.slaved_store.stream_positions() yield self.streamer.on_notifier_poke()
writer = yield self.replication.replicate(streams, 100) d = self.replication_handler.await_sync("replication_test")
result = writer.finish() self.streamer.send_sync_to_all_connections("replication_test")
yield self.slaved_store.process_replication(result) yield d
@defer.inlineCallbacks @defer.inlineCallbacks
def check(self, method, args, expected_result=None): def check(self, method, args, expected_result=None):