Support any process writing to the cache invalidation stream. (#7436)

This commit is contained in:
Erik Johnston 2020-05-07 13:51:08 +01:00 committed by GitHub
parent 2929ce29d6
commit d7983b63a6
26 changed files with 225 additions and 230 deletions

changelog.d/7436.misc Normal file
View file

@@ -0,0 +1 @@
Support any process writing to the cache invalidation stream.

View file

@@ -219,10 +219,6 @@ Asks the server for the current position of all streams.
Inform the server a pusher should be removed
#### INVALIDATE_CACHE (C)
Inform the server a cache should be invalidated
### REMOTE_SERVER_UP (S, C)
Inform other processes that a remote server may have come back online.

View file

@@ -122,7 +122,7 @@ APPEND_ONLY_TABLES = [
"presence_stream",
"push_rules_stream",
"ex_outlier_stream",
"cache_invalidation_stream",
"cache_invalidation_stream_by_instance",
"public_room_list_stream",
"state_group_edges",
"stream_ordering_to_exterm",
@@ -188,7 +188,7 @@ class MockHomeserver:
self.clock = Clock(reactor)
self.config = config
self.hostname = config.server_name
self.version_string = "Synapse/"+get_version_string(synapse)
self.version_string = "Synapse/" + get_version_string(synapse)
def get_clock(self):
return self.clock

View file

@@ -18,14 +18,10 @@ from typing import Optional
import six
from synapse.storage.data_stores.main.cache import (
CURRENT_STATE_CACHE_NAME,
CacheInvalidationWorkerStore,
)
from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage.util.id_generators import MultiWriterIdGenerator
logger = logging.getLogger(__name__)
@@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id"
) # type: Optional[SlavedIdTracker]
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
) # type: Optional[MultiWriterIdGenerator]
else:
self._cache_id_gen = None
self.hs = hs
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:
self._cache_id_gen.advance(token)
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None:
raise Exception(
"Can't send an 'invalidate all' for current state cache"
)
room_id = row.keys[0]
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
txn.call_after(self._send_invalidation_poke, cache_func, keys)
def _send_invalidation_poke(self, cache_func, keys):
self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
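For context, a minimal sketch of how the new multi-writer ID generator is wired up and used. The instance names and token values are illustrative; only the constructor arguments and the advance/get_current_token calls appear in this commit.

from synapse.storage.util.id_generators import MultiWriterIdGenerator

# Sketch, assuming `db_conn`, `database` and `hs` as in BaseSlavedStore.__init__
# above. Each process names itself; IDs come from a shared Postgres sequence.
cache_id_gen = MultiWriterIdGenerator(
    db_conn,
    database,
    instance_name=hs.get_instance_name(),
    table="cache_invalidation_stream_by_instance",
    instance_column="instance_name",
    id_column="stream_id",
    sequence_name="cache_invalidation_stream_seq",
)

# When replication reports that another writer reached position 42 on the
# "caches" stream, record it against that writer (per-instance advance, as
# used in cache.py later in this commit):
cache_id_gen.advance("worker1", 42)

# The position of a given writer, as seen from this process (see
# get_cache_stream_token further down):
current = cache_id_gen.get_current_token("worker1")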

View file

@@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token)
for row in rows:
@@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedAccountDataStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)
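The same mechanical change repeats across the slaved stores below: process_replication_rows grows an instance_name argument identifying which writer produced the batch, and it must be threaded through to super(). A sketch of the updated contract (the stream name, ID generator and cache are illustrative):

class ExampleSlavedStore(BaseSlavedStore):
    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == "example_stream":
            # Single-writer streams can ignore instance_name here...
            self._example_id_gen.advance(token)
            for row in rows:
                self._example_cache.invalidate((row.user_id,))
        # ...but it must still reach the base class, where the multi-writer
        # "caches" stream advances the position of the correct writer.
        return super().process_replication_rows(stream_name, instance_name, token, rows)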

View file

@@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "to_device":
self._device_inbox_id_gen.advance(token)
for row in rows:
@@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
self._device_federation_outbox_stream_cache.entity_has_changed(
row.entity, token
)
return super(SlavedDeviceInboxStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View file

@@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
self._invalidate_caches_for_devices(token, rows)
@@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)
def _invalidate_caches_for_devices(self, token, rows):
for row in rows:

View file

@@ -93,7 +93,7 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "events":
self._stream_id_gen.advance(token)
for row in rows:
@@ -111,9 +111,7 @@ class SlavedEventStore(
row.relates_to,
backfilled=True,
)
return super(SlavedEventStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)
def _process_event_stream_row(self, token, row):
data = row.data

View file

@@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedGroupServerStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View file

@@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "presence":
self._presence_id_gen.advance(token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
return super(SlavedPresenceStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View file

@@ -37,13 +37,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedPushRuleStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View file

@@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "pushers":
self._pushers_id_gen.advance(token)
return super(SlavedPusherStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View file

@@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "receipts":
self._receipts_id_gen.advance(token)
for row in rows:
@@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
return super(SlavedReceiptsStore, self).process_replication_rows(
stream_name, token, rows
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View file

@@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)
return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View file

@@ -100,10 +100,10 @@ class ReplicationDataHandler:
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, token, rows)
self.store.process_replication_rows(stream_name, instance_name, token, rows)
async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, [])
async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""

View file

@@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
return " ".join((self.app_id, self.push_key, self.user_id))
class InvalidateCacheCommand(Command):
"""Sent by the client to invalidate an upstream cache.
THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED EXCEPT FOR THINGS THAT ARE
NOT DISASTROUS IF WE DROP THEM ON THE FLOOR.
Mainly used to invalidate destination retry timing caches.
Format::
INVALIDATE_CACHE <cache_func> <keys_json>
Where <keys_json> is a json list.
"""
NAME = "INVALIDATE_CACHE"
def __init__(self, cache_func, keys):
self.cache_func = cache_func
self.keys = keys
@classmethod
def from_line(cls, line):
cache_func, keys_json = line.split(" ", 1)
return cls(cache_func, json.loads(keys_json))
def to_line(self):
return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
@@ -439,7 +408,6 @@ _COMMANDS = (
UserSyncCommand,
FederationAckCommand,
RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
@@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,

View file

@@ -15,18 +15,7 @@
# limitations under the License.
import logging
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
)
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
from prometheus_client import Counter
@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@@ -171,7 +159,7 @@ class ReplicationCommandHandler:
return
for stream_name, stream in self._streams.items():
current_token = stream.current_token()
current_token = stream.current_token(self._instance_name)
self.send_command(
PositionCommand(stream_name, self._instance_name, current_token)
)
@@ -210,18 +198,6 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(
self, conn: AbstractConnection, cmd: InvalidateCacheCommand
):
invalidate_cache_counter.inc()
if self._is_master:
# We invalidate the cache locally, but then also stream that to other
# workers.
await self._store.invalidate_cache_and_stream(
cmd.cache_func, tuple(cmd.keys)
)
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
@@ -295,7 +271,7 @@ class ReplicationCommandHandler:
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
@@ -326,7 +302,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to.
current_token = stream.current_token()
current_token = stream.current_token(cmd.instance_name)
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
@@ -363,7 +339,9 @@ class ReplicationCommandHandler:
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(stream_name, cmd.token)
await self._replication_data_handler.on_position(
cmd.stream_name, cmd.instance_name, cmd.token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
@@ -491,12 +469,6 @@ class ReplicationCommandHandler:
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip(
self,
user_id: str,
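A POSITION line now carries the writer's instance name (e.g. POSITION caches worker1 1057), and the catch-up check above compares against our view of that specific writer. A condensed sketch of the on_POSITION path, assuming stream.get_updates_since as defined in the streams base class later in this diff:

async def handle_position(stream, cmd):
    # compare our view of *this writer's* position against the advertised one
    current_token = stream.current_token(cmd.instance_name)
    if current_token < cmd.token:
        # fetch what we missed from that writer before resuming the stream
        updates, current_token, limited = await stream.get_updates_since(
            cmd.instance_name, current_token, cmd.token
        )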

View file

@@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
from synapse.replication.tcp.streams import (
STREAMS_MAP,
CachesStream,
FederationStream,
Stream,
)
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -71,11 +76,16 @@ class ReplicationStreamer(object):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self._instance_name = hs.get_instance_name()
self._replication_torture_level = hs.config.replication_torture_level
# Work out list of streams that this instance is the source of.
self.streams = [] # type: List[Stream]
# All workers can write to the cache invalidation stream.
self.streams.append(CachesStream(hs))
if hs.config.worker_app is None:
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
@@ -83,6 +93,10 @@ class ReplicationStreamer(object):
# has been disabled on the master.
continue
if stream == CachesStream:
# We've already added it above.
continue
self.streams.append(stream(hs))
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
@@ -145,7 +159,9 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
if stream.last_token == stream.current_token():
if stream.last_token == stream.current_token(
self._instance_name
):
continue
if self._replication_torture_level:
@@ -157,7 +173,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
stream.current_token(),
stream.current_token(self._instance_name),
)
try:
updates, current_token, limited = await stream.get_updates()

View file

@@ -95,20 +95,25 @@ class Stream(object):
def __init__(
self,
local_instance_name: str,
current_token_function: Callable[[], Token],
current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
current_token_function and update_function are callbacks which should be
implemented by subclasses.
`current_token_function` and `update_function` are callbacks which
should be implemented by subclasses.
current_token_function is called to get the current token of the underlying
stream. It is only meaningful on the process that is the source of the
replication stream (ie, usually the master).
`current_token_function` takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process). On the writer process this is where
the writer has successfully written up to, whereas on other processes
this is the position which we have received updates up to over
replication. (Note that most streams have a single writer and so their
implementations ignore the instance name passed in).
update_function is called to get updates for this stream between a pair of
stream tokens. See the UpdateFunction type definition for more info.
`update_function` is called to get updates for this stream between a
pair of stream tokens. See the `UpdateFunction` type definition for more
info.
Args:
local_instance_name: The instance name of the current process
@@ -120,13 +125,13 @@ class Stream(object):
self.update_function = update_function
# The token from which we last asked for updates
self.last_token = self.current_token()
self.last_token = self.current_token(self.local_instance_name)
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
self.last_token = self.current_token()
self.last_token = self.current_token(self.local_instance_name)
async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
@@ -138,7 +143,7 @@ class Stream(object):
position in stream, and `limited` is whether there are more updates
to fetch.
"""
current_token = self.current_token()
current_token = self.current_token(self.local_instance_name)
updates, current_token, limited = await self.get_updates_since(
self.local_instance_name, self.last_token, current_token
)
@@ -170,6 +175,16 @@ class Stream(object):
return updates, upto_token, limited
def current_token_without_instance(
current_token: Callable[[], int]
) -> Callable[[str], int]:
"""Takes a current token callback function for a single writer stream
that doesn't take an instance name parameter and wraps it in a function that
does accept an instance name parameter but ignores it.
"""
return lambda instance_name: current_token()
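Most streams still have exactly one writer, so their existing zero-argument token getters are adapted rather than rewritten; for example:

# store.get_pushers_stream_token is one of the real call sites converted below.
current_token = current_token_without_instance(store.get_pushers_stream_token)
token = current_token("worker1")  # the instance name is accepted but ignored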
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
@@ -235,7 +250,7 @@ class BackfillStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_backfill_token,
current_token_without_instance(store.get_current_backfill_token),
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@@ -271,7 +286,9 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
hs.get_instance_name(), store.get_current_presence_token, update_function
hs.get_instance_name(),
current_token_without_instance(store.get_current_presence_token),
update_function,
)
@@ -296,7 +313,9 @@ class TypingStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
hs.get_instance_name(), typing_handler.get_current_token, update_function
hs.get_instance_name(),
current_token_without_instance(typing_handler.get_current_token),
update_function,
)
@@ -319,7 +338,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_receipt_stream_id,
current_token_without_instance(store.get_max_receipt_stream_id),
db_query_to_update_function(store.get_all_updated_receipts),
)
@@ -339,7 +358,7 @@ class PushRulesStream(Stream):
hs.get_instance_name(), self._current_token, self._update_function
)
def _current_token(self) -> int:
def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
@@ -373,7 +392,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
store.get_pushers_stream_token,
current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@@ -402,13 +421,27 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
self.store.get_cache_stream_token,
self._update_function,
)
async def _update_function(
self, instance_name: str, from_token: int, upto_token: int, limit: int
):
rows = await self.store.get_all_updated_caches(
instance_name, from_token, upto_token, limit
)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
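The clamping at the end is the pagination contract: if the store hands back limit rows, there may be more, so the stream reports progress only up to the last row it actually returned and flags limited so the caller fetches again. A worked sketch with made-up rows:

# Hypothetical batch with limit=3; rows are (stream_id, cache_func, keys, ts).
rows = [
    (101, "get_user_by_id", ["@a:example.com"], 0),
    (102, "get_user_by_id", ["@b:example.com"], 0),
    (103, "cs_cache_fake", ["!room:example.com", "@c:example.com"], 0),
]
updates = [(row[0], row[1:]) for row in rows]  # token paired with the rest
limit, upto_token = 3, 110  # we asked for everything up to 110
limited = False
if len(updates) >= limit:
    upto_token = updates[-1][0]  # only safe to claim progress up to 103
    limited = True               # caller should request 103 -> 110 next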
class PublicRoomsStream(Stream):
"""The public rooms list changed
@@ -431,7 +464,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_public_room_stream_id,
current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms),
)
@@ -452,7 +485,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@@ -470,7 +503,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_to_device_stream_token,
current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages),
)
@@ -490,7 +523,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_account_data_stream_id,
current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags),
)
@@ -510,7 +543,7 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self.store.get_max_account_data_stream_id,
current_token_without_instance(self.store.get_max_account_data_stream_id),
db_query_to_update_function(self._update_function),
)
@@ -541,7 +574,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_group_stream_token,
current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes),
)
@@ -559,7 +592,7 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),

View file

@@ -20,7 +20,7 @@ from typing import List, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream
@@ -119,7 +119,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self._store.get_current_events_token,
current_token_without_instance(self._store.get_current_events_token),
self._update_function,
)

View file

@@ -15,7 +15,11 @@
# limitations under the License.
from collections import namedtuple
from synapse.replication.tcp.streams._base import Stream, make_http_update_function
from synapse.replication.tcp.streams._base import (
Stream,
current_token_without_instance,
make_http_update_function,
)
class FederationStream(Stream):
@@ -41,7 +45,9 @@ class FederationStream(Stream):
# will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.)
federation_sender = hs.get_federation_sender()
current_token = federation_sender.get_current_token
current_token = current_token_without_instance(
federation_sender.get_current_token
)
update_function = federation_sender.get_replication_rows
elif hs.should_send_federation():
@@ -58,7 +64,7 @@ class FederationStream(Stream):
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
def _stub_current_token():
def _stub_current_token(instance_name: str) -> int:
# dummy current-token method for use on workers
return 0

View file

@@ -47,6 +47,9 @@ class SQLBaseStore(metaclass=ABCMeta):
self.db = database
self.rand = random.SystemRandom()
def process_replication_rows(self, stream_name, instance_name, token, rows):
pass
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.

View file

@@ -26,13 +26,14 @@ from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
ChainedIdGenerator,
IdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
from .cache import CacheInvalidationStore
from .cache import CacheInvalidationWorkerStore
from .client_ips import ClientIpStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
@@ -112,8 +113,8 @@ class DataStore(
MonthlyActiveUsersStore,
StatsStore,
RelationsStore,
CacheInvalidationStore,
UIAuthStore,
CacheInvalidationWorkerStore,
):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
@@ -170,8 +171,14 @@ class DataStore(
)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
db_conn, "cache_invalidation_stream", "stream_id"
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
instance_name="master",
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
)
else:
self._cache_id_gen = None

View file

@@ -16,11 +16,10 @@
import itertools
import logging
from typing import Any, Iterable, Optional, Tuple
from twisted.internet import defer
from typing import Any, Iterable, Optional
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
def get_all_updated_caches(self, last_id, current_id, limit):
def __init__(self, database: Database, db_conn, hs):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
):
"""Fetches cache invalidation rows between the two given IDs written
by the given instance. Returns at most `limit` rows.
"""
if last_id == current_id:
return defer.succeed([])
return []
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit))
sql = """
SELECT stream_id, cache_func, keys, invalidation_ts
FROM cache_invalidation_stream_by_instance
WHERE stream_id > ? AND instance_name = ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, instance_name, limit))
return txn.fetchall()
return self.db.runInteraction(
return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
class CacheInvalidationStore(CacheInvalidationWorkerStore):
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None:
raise Exception(
"Can't send an 'invalidate all' for current state cache"
)
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
be invalidated.
"""
cache_func = getattr(self, cache_name, None)
if not cache_func:
return
room_id = row.keys[0]
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
cache_func.invalidate(keys)
await self.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
cache_func.__name__,
keys,
)
super().process_replication_rows(stream_name, instance_name, token, rows)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
@@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
txn.call_on_exception(ctx.__exit__, None, None, None)
txn.call_after(ctx.__exit__, None, None, None)
stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
if keys is not None:
@@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
self.db.simple_insert_txn(
txn,
table="cache_invalidation_stream",
table="cache_invalidation_stream_by_instance",
values={
"stream_id": stream_id,
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
},
)
def get_cache_stream_token(self):
def get_cache_stream_token(self, instance_name):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
return self._cache_id_gen.get_current_token(instance_name)
else:
return 0
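One subtlety above: allocating a stream ID inside a transaction previously required manually driving the ID generator's context manager; MultiWriterIdGenerator.get_next_txn now owns that lifecycle. A condensed before/after, with id_gen and txn standing in for self._cache_id_gen and the live transaction:

def allocate_stream_id_old(id_gen, txn):
    # removed above: enter the context by hand, then make sure __exit__ runs
    # on both the success and failure paths of the transaction
    ctx = id_gen.get_next()
    stream_id = ctx.__enter__()
    txn.call_on_exception(ctx.__exit__, None, None, None)
    txn.call_after(ctx.__exit__, None, None, None)
    return stream_id

def allocate_stream_id_new(id_gen, txn):
    # added above: the generator ties the ID's lifetime to the transaction
    return id_gen.get_next_txn(txn)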

View file

@@ -0,0 +1,30 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- We keep the old table here to enable us to roll back. It doesn't matter
-- that we have dropped all the data here.
TRUNCATE cache_invalidation_stream;
CREATE TABLE cache_invalidation_stream_by_instance (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
cache_func TEXT NOT NULL,
keys TEXT[],
invalidation_ts BIGINT
);
CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id);
CREATE SEQUENCE cache_invalidation_stream_seq;
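Illustrative only: roughly the write each instance performs against this schema. The real path is simple_insert_txn in cache.py above, with the stream ID drawn from the sequence via MultiWriterIdGenerator; placeholders here are Postgres-style.

def write_invalidation(txn, instance_name, cache_name, keys, now_ms):
    # keys (a Python list) maps onto the TEXT[] column; NULL means
    # "invalidate everything" for this cache.
    txn.execute(
        """
        INSERT INTO cache_invalidation_stream_by_instance
            (stream_id, instance_name, cache_func, keys, invalidation_ts)
        VALUES (nextval('cache_invalidation_stream_seq'), %s, %s, %s, %s)
        """,
        (instance_name, cache_name, keys, now_ms),
    )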

View file

@@ -29,6 +29,8 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
# XXX: If you're about to bump this to 59 (or higher) please create an update
# that drops the unused `cache_invalidation_stream` table, as per #7436!
SCHEMA_VERSION = 58
dir_path = os.path.abspath(os.path.dirname(__file__))