0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-04 17:54:04 +01:00

Merge pull request from matrix-org/erikj/group_sync_support

Add groups to sync stream
This commit is contained in:
Erik Johnston 2017-07-21 11:05:39 +01:00 committed by GitHub
commit 96917d5552
12 changed files with 283 additions and 12 deletions
synapse
app
handlers
replication
slave/storage
tcp
rest/client/v2_alpha
storage
streams
types.py
tests/rest/client/v1

View file

@ -41,6 +41,7 @@ from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
@ -75,6 +76,7 @@ class SynchrotronSlavedStore(
SlavedRegistrationStore, SlavedRegistrationStore,
SlavedFilteringStore, SlavedFilteringStore,
SlavedPresenceStore, SlavedPresenceStore,
SlavedGroupServerStore,
SlavedDeviceInboxStore, SlavedDeviceInboxStore,
SlavedDeviceStore, SlavedDeviceStore,
SlavedClientIpStore, SlavedClientIpStore,
@ -409,6 +411,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
) )
elif stream_name == "presence": elif stream_name == "presence":
yield self.presence_handler.process_replication_rows(token, rows) yield self.presence_handler.process_replication_rows(token, rows)
elif stream_name == "receipts":
self.notifier.on_new_event(
"groups_key", token, users=[row.user_id for row in rows],
)
def start(config_options): def start(config_options):

View file

@ -63,6 +63,7 @@ class GroupsLocalHandler(object):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.signing_key = hs.config.signing_key[0] self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname self.server_name = hs.hostname
self.notifier = hs.get_notifier()
self.attestations = hs.get_groups_attestation_signing() self.attestations = hs.get_groups_attestation_signing()
# Ensure attestations get renewed # Ensure attestations get renewed
@ -212,13 +213,16 @@ class GroupsLocalHandler(object):
user_id=user_id, user_id=user_id,
) )
yield self.store.register_user_group_membership( token = yield self.store.register_user_group_membership(
group_id, user_id, group_id, user_id,
membership="join", membership="join",
is_admin=False, is_admin=False,
local_attestation=local_attestation, local_attestation=local_attestation,
remote_attestation=remote_attestation, remote_attestation=remote_attestation,
) )
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
defer.returnValue({}) defer.returnValue({})
@ -258,11 +262,14 @@ class GroupsLocalHandler(object):
if "avatar_url" in content["profile"]: if "avatar_url" in content["profile"]:
local_profile["avatar_url"] = content["profile"]["avatar_url"] local_profile["avatar_url"] = content["profile"]["avatar_url"]
yield self.store.register_user_group_membership( token = yield self.store.register_user_group_membership(
group_id, user_id, group_id, user_id,
membership="invite", membership="invite",
content={"profile": local_profile, "inviter": content["inviter"]}, content={"profile": local_profile, "inviter": content["inviter"]},
) )
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
defer.returnValue({"state": "invite"}) defer.returnValue({"state": "invite"})
@ -271,10 +278,13 @@ class GroupsLocalHandler(object):
"""Remove a user from a group """Remove a user from a group
""" """
if user_id == requester_user_id: if user_id == requester_user_id:
yield self.store.register_user_group_membership( token = yield self.store.register_user_group_membership(
group_id, user_id, group_id, user_id,
membership="leave", membership="leave",
) )
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
# TODO: Should probably remember that we tried to leave so that we can # TODO: Should probably remember that we tried to leave so that we can
# retry if the group server is currently down. # retry if the group server is currently down.
@ -297,10 +307,13 @@ class GroupsLocalHandler(object):
"""One of our users was removed/kicked from a group """One of our users was removed/kicked from a group
""" """
# TODO: Check if user in group # TODO: Check if user in group
yield self.store.register_user_group_membership( token = yield self.store.register_user_group_membership(
group_id, user_id, group_id, user_id,
membership="leave", membership="leave",
) )
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_joined_groups(self, user_id): def get_joined_groups(self, user_id):

View file

@ -108,6 +108,17 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
return True return True
class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
"join",
"invite",
"leave",
])):
__slots__ = []
def __nonzero__(self):
return bool(self.join or self.invite or self.leave)
class SyncResult(collections.namedtuple("SyncResult", [ class SyncResult(collections.namedtuple("SyncResult", [
"next_batch", # Token for the next sync "next_batch", # Token for the next sync
"presence", # List of presence events for the user. "presence", # List of presence events for the user.
@ -119,6 +130,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"device_lists", # List of user_ids whose devices have chanegd "device_lists", # List of user_ids whose devices have chanegd
"device_one_time_keys_count", # Dict of algorithm to count for one time keys "device_one_time_keys_count", # Dict of algorithm to count for one time keys
# for this device # for this device
"groups",
])): ])):
__slots__ = [] __slots__ = []
@ -134,7 +146,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.archived or self.archived or
self.account_data or self.account_data or
self.to_device or self.to_device or
self.device_lists self.device_lists or
self.groups
) )
@ -560,6 +573,8 @@ class SyncHandler(object):
user_id, device_id user_id, device_id
) )
yield self._generate_sync_entry_for_groups(sync_result_builder)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=sync_result_builder.presence, presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data, account_data=sync_result_builder.account_data,
@ -568,10 +583,56 @@ class SyncHandler(object):
archived=sync_result_builder.archived, archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device, to_device=sync_result_builder.to_device,
device_lists=device_lists, device_lists=device_lists,
groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts, device_one_time_keys_count=one_time_key_counts,
next_batch=sync_result_builder.now_token, next_batch=sync_result_builder.now_token,
)) ))
@measure_func("_generate_sync_entry_for_groups")
@defer.inlineCallbacks
def _generate_sync_entry_for_groups(self, sync_result_builder):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
if since_token and since_token.groups_key:
results = yield self.store.get_groups_changes_for_user(
user_id, since_token.groups_key, now_token.groups_key,
)
else:
results = yield self.store.get_all_groups_for_user(
user_id, now_token.groups_key,
)
invited = {}
joined = {}
left = {}
for result in results:
membership = result["membership"]
group_id = result["group_id"]
gtype = result["type"]
content = result["content"]
if membership == "join":
if gtype == "membership":
content.pop("membership", None)
invited[group_id] = content["content"]
else:
joined.setdefault(group_id, {})[gtype] = content
elif membership == "invite":
if gtype == "membership":
content.pop("membership", None)
invited[group_id] = content["content"]
else:
if gtype == "membership":
left[group_id] = content["content"]
sync_result_builder.groups = GroupsSyncResult(
join=joined,
invite=invited,
leave=left,
)
@measure_func("_generate_sync_entry_for_device_list") @measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder): def _generate_sync_entry_for_device_list(self, sync_result_builder):
@ -1260,6 +1321,7 @@ class SyncResultBuilder(object):
self.invited = [] self.invited = []
self.archived = [] self.archived = []
self.device = [] self.device = []
self.groups = None
class RoomSyncResultBuilder(object): class RoomSyncResultBuilder(object):

View file

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedGroupServerStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(db_conn, hs)
self.hs = hs
self._group_updates_id_gen = SlavedIdTracker(
db_conn, "local_group_updates", "stream_id",
)
self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
)
get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
get_group_stream_token = DataStore.get_group_stream_token.__func__
get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
def stream_positions(self):
result = super(SlavedGroupServerStore, self).stream_positions()
result["groups"] = self._group_updates_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(
row.user_id, token
)
return super(SlavedGroupServerStore, self).process_replication_rows(
stream_name, token, rows
)

View file

@ -118,6 +118,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
"state_key", # str "state_key", # str
"event_id", # str, optional "event_id", # str, optional
)) ))
GroupsStreamRow = namedtuple("GroupsStreamRow", (
"group_id", # str
"user_id", # str
"type", # str
"content", # dict
))
class Stream(object): class Stream(object):
@ -464,6 +470,19 @@ class CurrentStateDeltaStream(Stream):
super(CurrentStateDeltaStream, self).__init__(hs) super(CurrentStateDeltaStream, self).__init__(hs)
class GroupServerStream(Stream):
NAME = "groups"
ROW_TYPE = GroupsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token
self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs)
STREAMS_MAP = { STREAMS_MAP = {
stream.NAME: stream stream.NAME: stream
for stream in ( for stream in (
@ -482,5 +501,6 @@ STREAMS_MAP = {
TagAccountDataStream, TagAccountDataStream,
AccountDataStream, AccountDataStream,
CurrentStateDeltaStream, CurrentStateDeltaStream,
GroupServerStream,
) )
} }

View file

@ -199,6 +199,11 @@ class SyncRestServlet(RestServlet):
"invite": invited, "invite": invited,
"leave": archived, "leave": archived,
}, },
"groups": {
"join": sync_result.groups.join,
"invite": sync_result.groups.invite,
"leave": sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count, "device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(), "next_batch": sync_result.next_batch.to_string(),
} }

View file

@ -136,6 +136,9 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "pushers", "id", db_conn, "pushers", "id",
extra_tables=[("deleted_pushers", "stream_id")], extra_tables=[("deleted_pushers", "stream_id")],
) )
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id",
)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator( self._cache_id_gen = StreamIdGenerator(
@ -236,6 +239,18 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=curr_state_delta_prefill, prefilled_cache=curr_state_delta_prefill,
) )
_group_updates_prefill, min_group_updates_id = self._get_cache_dict(
db_conn, "local_group_updates",
entity_column="user_id",
stream_column="stream_id",
max_value=self._group_updates_id_gen.get_current_token(),
limit=1000,
)
self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache", min_group_updates_id,
prefilled_cache=_group_updates_prefill,
)
cur = LoggingTransaction( cur = LoggingTransaction(
db_conn.cursor(), db_conn.cursor(),
name="_find_stream_orderings_for_times_txn", name="_find_stream_orderings_for_times_txn",

View file

@ -776,7 +776,7 @@ class GroupServerStore(SQLBaseStore):
remote_attestation (dict): If remote group then store the remote remote_attestation (dict): If remote group then store the remote
attestation from the group, else None. attestation from the group, else None.
""" """
def _register_user_group_membership_txn(txn): def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert? # TODO: Upsert?
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
@ -798,6 +798,19 @@ class GroupServerStore(SQLBaseStore):
}, },
) )
self._simple_insert_txn(
txn,
table="local_group_updates",
values={
"stream_id": next_id,
"group_id": group_id,
"user_id": user_id,
"type": "membership",
"content": json.dumps({"membership": membership, "content": content}),
}
)
self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
# TODO: Insert profile to ensure it comes down stream if its a join. # TODO: Insert profile to ensure it comes down stream if its a join.
if membership == "join": if membership == "join":
@ -840,9 +853,12 @@ class GroupServerStore(SQLBaseStore):
}, },
) )
return next_id
with self._group_updates_id_gen.get_next() as next_id:
yield self.runInteraction( yield self.runInteraction(
"register_user_group_membership", "register_user_group_membership",
_register_user_group_membership_txn, _register_user_group_membership_txn, next_id,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -948,3 +964,68 @@ class GroupServerStore(SQLBaseStore):
retcol="group_id", retcol="group_id",
desc="get_joined_groups", desc="get_joined_groups",
) )
def get_all_groups_for_user(self, user_id, now_token):
def _get_all_groups_for_user_txn(txn):
sql = """
SELECT group_id, type, membership, u.content
FROM local_group_updates AS u
INNER JOIN local_group_membership USING (group_id, user_id)
WHERE user_id = ? AND membership != 'leave'
AND stream_id <= ?
"""
txn.execute(sql, (user_id, now_token,))
return self.cursor_to_dict(txn)
return self.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn,
)
def get_groups_changes_for_user(self, user_id, from_token, to_token):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_entity_changed(
user_id, from_token,
)
if not has_changed:
return []
def _get_groups_changes_for_user_txn(txn):
sql = """
SELECT group_id, membership, type, u.content
FROM local_group_updates AS u
INNER JOIN local_group_membership USING (group_id, user_id)
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
"""
txn.execute(sql, (user_id, from_token, to_token,))
return [{
"group_id": group_id,
"membership": membership,
"type": gtype,
"content": json.loads(content_json),
} for group_id, membership, gtype, content_json in txn]
return self.runInteraction(
"get_groups_changes_for_user", _get_groups_changes_for_user_txn,
)
def get_all_groups_changes(self, from_token, to_token, limit):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(
from_token,
)
if not has_changed:
return []
def _get_all_groups_changes_txn(txn):
sql = """
SELECT stream_id, group_id, user_id, type, content
FROM local_group_updates
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit,))
return txn.fetchall()
return self.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn,
)
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()

View file

@ -155,3 +155,12 @@ CREATE TABLE local_group_membership (
CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
CREATE TABLE local_group_updates (
stream_id BIGINT NOT NULL,
group_id TEXT NOT NULL,
user_id TEXT NOT NULL,
type TEXT NOT NULL,
content TEXT NOT NULL
);

View file

@ -45,6 +45,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token() device_list_key = self.store.get_device_stream_token()
groups_key = self.store.get_group_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
@ -65,6 +66,7 @@ class EventSources(object):
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key, to_device_key=to_device_key,
device_list_key=device_list_key, device_list_key=device_list_key,
groups_key=groups_key,
) )
defer.returnValue(token) defer.returnValue(token)
@ -73,6 +75,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token() device_list_key = self.store.get_device_stream_token()
groups_key = self.store.get_group_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
@ -93,5 +96,6 @@ class EventSources(object):
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key, to_device_key=to_device_key,
device_list_key=device_list_key, device_list_key=device_list_key,
groups_key=groups_key,
) )
defer.returnValue(token) defer.returnValue(token)

View file

@ -171,6 +171,7 @@ class StreamToken(
"push_rules_key", "push_rules_key",
"to_device_key", "to_device_key",
"device_list_key", "device_list_key",
"groups_key",
)) ))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -209,6 +210,7 @@ class StreamToken(
or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key)) or (int(other.to_device_key) < int(self.to_device_key))
or (int(other.device_list_key) < int(self.device_list_key)) or (int(other.device_list_key) < int(self.device_list_key))
or (int(other.groups_key) < int(self.groups_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):

View file

@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_topo_token_is_accepted(self): def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0" token = "t1-0_0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0" token = "s0_0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))