Refactor getting replication updates from database v2. (#7740)

This commit is contained in:
Erik Johnston 2020-07-07 12:11:35 +01:00 committed by GitHub
parent d378c3da78
commit 67d7756fcf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 336 additions and 195 deletions

1
changelog.d/7740.misc Normal file
View file

@ -0,0 +1 @@
Refactor getting replication updates from database.

View file

@ -294,6 +294,9 @@ class TypingHandler(object):
rows.sort()
limited = False
# We, unusually, use a strict limit here as we have all the rows in
# memory rather than pulling them out of the database with a `LIMIT ?`
# clause.
if len(rows) > limit:
rows = rows[:limit]
current_id = rows[-1][0]

View file

@ -198,26 +198,6 @@ def current_token_without_instance(
return lambda instance_name: current_token()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return update_function
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
@ -393,7 +373,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows),
store.get_all_updated_pushers_rows,
)
@ -421,27 +401,13 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
self.store = hs.get_datastore()
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self.store.get_cache_stream_token,
self._update_function,
store.get_cache_stream_token,
store.get_all_updated_caches,
)
async def _update_function(
self, instance_name: str, from_token: int, upto_token: int, limit: int
):
rows = await self.store.get_all_updated_caches(
instance_name, from_token, upto_token, limit
)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
class PublicRoomsStream(Stream):
"""The public rooms list changed
@ -465,7 +431,7 @@ class PublicRoomsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms),
store.get_all_new_public_rooms,
)
@ -486,7 +452,7 @@ class DeviceListsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
store.get_all_device_list_changes_for_remotes,
)
@ -504,7 +470,7 @@ class ToDeviceStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages),
store.get_all_new_device_messages,
)
@ -524,7 +490,7 @@ class TagAccountDataStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags),
store.get_all_updated_tags,
)
@ -612,7 +578,7 @@ class GroupServerStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes),
store.get_all_groups_changes,
)
@ -630,7 +596,5 @@ class UserSignatureStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),
store.get_all_user_signature_changes_for_remotes,
)

View file

@ -16,7 +16,7 @@
import itertools
import logging
from typing import Any, Iterable, Optional, Tuple
from typing import Any, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
@ -46,13 +46,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
):
"""Fetches cache invalidation rows between the two given IDs written
by the given instance. Returns at most `limit` rows.
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for caches replication stream.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
return []
return [], current_id, False
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
@ -66,7 +83,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, instance_name, limit))
return txn.fetchall()
updates = [(row[0], row[1:]) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn

View file

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import List, Tuple
from canonicaljson import json
@ -207,31 +208,46 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
async def get_all_new_device_messages(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for to device replication stream.
Args:
last_pos(int):
current_pos(int):
limit(int):
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A deferred list of rows from the device inbox
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
"""
if last_pos == current_pos:
return defer.succeed([])
if last_id == current_id:
return [], current_id, False
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_pos, last_pos + limit)
upper_pos = min(current_id, last_id + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
txn.execute(sql, (last_id, upper_pos))
updates = [(row[0], row[1:]) for row in txn]
sql = (
"SELECT max(stream_id), destination"
@ -239,15 +255,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn)
txn.execute(sql, (last_id, upper_pos))
updates.extend((row[0], row[1:]) for row in txn)
# Order by ascending stream ordering
rows.sort()
updates.sort()
return rows
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return self.db.runInteraction(
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)

View file

@ -582,32 +582,58 @@ class DeviceWorkerStore(SQLBaseStore):
return set()
async def get_all_device_list_changes_for_remotes(
self, from_key: int, to_key: int, limit: int,
) -> List[Tuple[int, str]]:
"""Return a list of `(stream_id, entity)` which is the combined list of
changes to devices and which destinations need to be poked. Entity is
either a user ID (starting with '@') or a remote destination.
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for device lists replication stream.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
"""
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
SELECT stream_id, entity FROM (
SELECT stream_id, user_id AS entity FROM device_lists_stream
UNION ALL
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
if last_id == current_id:
return [], current_id, False
return await self.db.execute(
def _get_all_device_list_changes_for_remotes(txn):
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
SELECT stream_id, entity FROM (
SELECT stream_id, user_id AS entity FROM device_lists_stream
UNION ALL
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], row[1:]) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_device_list_changes_for_remotes",
None,
sql,
from_key,
to_key,
limit,
_get_all_device_list_changes_for_remotes,
)
@cached(max_entries=10000)

View file

@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
from typing import Dict, List, Tuple
from canonicaljson import encode_canonical_json, json
@ -479,34 +479,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result
def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
"""Return a list of changes from the user signature stream to notify remotes.
async def get_all_user_signature_changes_for_remotes(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for groups replication stream.
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
users or servers, so no `destination` is needed in the returned
list. However, this is needed to poke workers.
Args:
from_key (int): the stream ID to start at (exclusive)
to_key (int): the stream ID to end at (inclusive)
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
"""
sql = """
SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
return self.db.execute(
if last_id == current_id:
return [], current_id, False
def _get_all_user_signature_changes_for_remotes_txn(txn):
sql = """
SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], (row[1:])) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_user_signature_changes_for_remotes",
None,
sql,
from_key,
to_key,
limit,
_get_all_user_signature_changes_for_remotes_txn,
)

View file

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from canonicaljson import json
from twisted.internet import defer
@ -526,13 +528,35 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
)
def get_all_groups_changes(self, from_token, to_token, limit):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(
from_token
)
async def get_all_groups_changes(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for groups replication stream.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
"""
last_id = int(last_id)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
if not has_changed:
return defer.succeed([])
return [], current_id, False
def _get_all_groups_changes_txn(txn):
sql = """
@ -541,13 +565,21 @@ class GroupServerWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit))
return [
(stream_id, group_id, user_id, gtype, json.loads(content_json))
txn.execute(sql, (last_id, current_id, limit))
updates = [
(stream_id, (group_id, user_id, gtype, json.loads(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
]
return self.db.runInteraction(
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn
)

View file

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import Iterable, Iterator
from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json, json
@ -98,77 +98,69 @@ class PusherWorkerStore(SQLBaseStore):
rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
return rows
def get_all_updated_pushers(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed(([], []))
async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for pushers replication stream.
def get_all_updated_pushers_txn(txn):
sql = (
"SELECT id, user_name, access_token, profile_tag, kind,"
" app_id, app_display_name, device_display_name, pushkey, ts,"
" lang, data"
" FROM pushers"
" WHERE ? < id AND id <= ?"
" ORDER BY id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
updated = txn.fetchall()
sql = (
"SELECT stream_id, user_id, app_id, pushkey"
" FROM deleted_pushers"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
deleted = txn.fetchall()
return updated, deleted
return self.db.runInteraction(
"get_all_updated_pushers", get_all_updated_pushers_txn
)
def get_all_updated_pushers_rows(self, last_id, current_id, limit):
"""Get all the pushers that have changed between the given tokens.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
Deferred(list(tuple)): each tuple consists of:
stream_id (str)
user_id (str)
app_id (str)
pushkey (str)
was_deleted (bool): whether the pusher was added/updated (False)
or deleted (True)
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
return defer.succeed([])
return [], current_id, False
def get_all_updated_pushers_rows_txn(txn):
sql = (
"SELECT id, user_name, app_id, pushkey"
" FROM pushers"
" WHERE ? < id AND id <= ?"
" ORDER BY id ASC LIMIT ?"
)
sql = """
SELECT id, user_name, app_id, pushkey
FROM pushers
WHERE ? < id AND id <= ?
ORDER BY id ASC LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
results = [list(row) + [False] for row in txn]
updates = [
(stream_id, (user_name, app_id, pushkey, False))
for stream_id, user_name, app_id, pushkey in txn
]
sql = (
"SELECT stream_id, user_id, app_id, pushkey"
" FROM deleted_pushers"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
sql = """
SELECT stream_id, user_id, app_id, pushkey
FROM deleted_pushers
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates.extend(
(stream_id, (user_name, app_id, pushkey, True))
for stream_id, user_name, app_id, pushkey in txn
)
results.extend(list(row) + [True] for row in txn)
results.sort() # Sort so that they're ordered by stream id
updates.sort() # Sort so that they're ordered by stream id
return results
limited = False
upper_bound = current_id
if len(updates) >= limit:
limited = True
upper_bound = updates[-1][0]
return self.db.runInteraction(
return updates, upper_bound, limited
return await self.db.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)

View file

@ -803,7 +803,32 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined
def get_all_new_public_rooms(self, prev_id, current_id, limit):
async def get_all_new_public_rooms(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for public rooms replication stream.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
return [], current_id, False
def get_all_new_public_rooms(txn):
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
@ -813,13 +838,17 @@ class RoomWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall()
txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], row[1:]) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
if prev_id == current_id:
return defer.succeed([])
return updates, upto_token, limited
return self.db.runInteraction(
return await self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)

View file

@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import List, Tuple
from canonicaljson import json
@ -53,18 +54,32 @@ class TagsWorkerStore(AccountDataWorkerStore):
return deferred
@defer.inlineCallbacks
def get_all_updated_tags(self, last_id, current_id, limit):
"""Get all the client tags that have changed on the server
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for tags replication stream.
Args:
last_id(int): The position to fetch from.
current_id(int): The position to fetch up to.
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A deferred list of tuples of stream_id int, user_id string,
room_id string, tag string and content string.
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
return []
return [], current_id, False
def get_all_updated_tags_txn(txn):
sql = (
@ -76,7 +91,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
tag_ids = yield self.db.runInteraction(
tag_ids = await self.db.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
@ -89,21 +104,27 @@ class TagsWorkerStore(AccountDataWorkerStore):
for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, user_id, room_id, tag_json))
results.append((stream_id, (user_id, room_id, tag_json)))
return results
batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
tags = yield self.db.runInteraction(
tags = await self.db.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
)
results.extend(tags)
return results
limited = False
upto_token = current_id
if len(results) >= limit:
upto_token = results[-1][0]
limited = True
return results, upto_token, limited
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):