0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-13 11:23:25 +01:00

Add local_current_membership table (#6655)

Currently we rely on `current_state_events` to figure out what rooms a
user was in and their last membership event in there. However, if the
server leaves the room then the table may be cleaned up and that
information is lost. So lets add a table that separately holds that
information.
This commit is contained in:
Erik Johnston 2020-01-15 14:59:33 +00:00 committed by GitHub
parent b5ce7f5874
commit 28c98e51ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 264 additions and 108 deletions

1
changelog.d/6655.misc Normal file
View file

@ -0,0 +1 @@
Add `local_current_membership` table for tracking local user membership state in rooms.

View file

@ -470,7 +470,7 @@ class Porter(object):
engine.check_database(
db_conn, allow_outdated_version=allow_outdated_version
)
prepare_database(db_conn, engine, config=None)
prepare_database(db_conn, engine, config=self.hs_config)
store = Store(Database(hs, db_config, engine), db_conn, hs)
db_conn.commit()

View file

@ -134,7 +134,7 @@ class AdminHandler(BaseHandler):
The returned value is that returned by `writer.finished()`.
"""
# Get all rooms the user is in or has been in
rooms = await self.store.get_rooms_for_user_where_membership_is(
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user_id,
membership_list=(
Membership.JOIN,

View file

@ -140,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler):
user_id (str): The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
pending_invites = await self.store.get_invited_rooms_for_user(user_id)
pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
for room in pending_invites:
try:

View file

@ -101,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
if include_archived:
memberships.append(Membership.LEAVE)
room_list = await self.store.get_rooms_for_user_where_membership_is(
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id, membership_list=memberships
)

View file

@ -690,7 +690,7 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
invite = yield self.store.get_invite_for_local_user_in_room(
user_id=user_id, room_id=room_id
)
if invite:

View file

@ -179,7 +179,7 @@ class SearchHandler(BaseHandler):
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too
rooms = yield self.store.get_rooms_for_user_where_membership_is(
rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],

View file

@ -1662,7 +1662,7 @@ class SyncHandler(object):
Membership.BAN,
)
room_list = await self.store.get_rooms_for_user_where_membership_is(
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id, membership_list=membership_list
)

View file

@ -21,7 +21,7 @@ from synapse.storage import Storage
@defer.inlineCallbacks
def get_badge_count(store, user_id):
invites = yield store.get_invited_rooms_for_user(user_id)
invites = yield store.get_invited_rooms_for_local_user(user_id)
joins = yield store.get_rooms_for_user(user_id)
my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")

View file

@ -152,7 +152,7 @@ class SlavedEventStore(
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
self.get_invited_rooms_for_user.invalidate((state_key,))
self.get_invited_rooms_for_local_user.invalidate((state_key,))
if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,))

View file

@ -105,7 +105,7 @@ class ServerNoticesManager(object):
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
rooms = yield self._store.get_rooms_for_user_where_membership_is(
rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
system_mxid = self._config.server_notices_mxid

View file

@ -128,6 +128,7 @@ class EventsStore(
hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks
def _read_forward_extremities(self):
@ -547,6 +548,34 @@ class EventsStore(
],
)
# Note: Do we really want to delete rows here (that we do not
# subsequently reinsert below)? While technically correct it means
# we have no record of the fact the user *was* a member of the
# room but got, say, state reset out of it.
if to_delete or to_insert:
txn.executemany(
"DELETE FROM local_current_membership"
" WHERE room_id = ? AND user_id = ?",
(
(room_id, state_key)
for etype, state_key in itertools.chain(to_delete, to_insert)
if etype == EventTypes.Member and self.is_mine_id(state_key)
),
)
if to_insert:
txn.executemany(
"""INSERT INTO local_current_membership
(room_id, user_id, event_id, membership)
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
""",
[
(room_id, key[1], ev_id, ev_id)
for key, ev_id in to_insert.items()
if key[0] == EventTypes.Member and self.is_mine_id(key[1])
],
)
txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
room_id,
@ -1724,6 +1753,7 @@ class EventsStore(
"local_invites",
"room_account_data",
"room_tags",
"local_current_membership",
):
logger.info("[purge] removing %s from %s", room_id, table)
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))

View file

@ -297,19 +297,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {row[0]: row[1] for row in txn}
@cached()
def get_invited_rooms_for_user(self, user_id):
""" Get all the rooms the user is invited to
def get_invited_rooms_for_local_user(self, user_id):
""" Get all the rooms the *local* user is invited to
Args:
user_id (str): The user ID.
Returns:
A deferred list of RoomsForUser.
"""
return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
return self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
@defer.inlineCallbacks
def get_invite_for_user_in_room(self, user_id, room_id):
"""Gets the invite for the given user and room
def get_invite_for_local_user_in_room(self, user_id, room_id):
"""Gets the invite for the given *local* user and room
Args:
user_id (str)
@ -319,15 +322,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Deferred: Resolves to either a RoomsForUser or None if no invite was
found.
"""
invites = yield self.get_invited_rooms_for_user(user_id)
invites = yield self.get_invited_rooms_for_local_user(user_id)
for invite in invites:
if invite.room_id == room_id:
return invite
return None
@defer.inlineCallbacks
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
Filters out forgotten rooms.
@ -344,8 +347,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return defer.succeed(None)
rooms = yield self.db.runInteraction(
"get_rooms_for_user_where_membership_is",
self._get_rooms_for_user_where_membership_is_txn,
"get_rooms_for_local_user_where_membership_is",
self._get_rooms_for_local_user_where_membership_is_txn,
user_id,
membership_list,
)
@ -354,76 +357,42 @@ class RoomMemberWorkerStore(EventsWorkerStore):
forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
return [room for room in rooms if room.room_id not in forgotten_rooms]
def _get_rooms_for_user_where_membership_is_txn(
def _get_rooms_for_local_user_where_membership_is_txn(
self, txn, user_id, membership_list
):
do_invite = Membership.INVITE in membership_list
membership_list = [m for m in membership_list if m != Membership.INVITE]
results = []
if membership_list:
if self._current_state_events_membership_up_to_date:
clause, args = make_in_list_sql_clause(
self.database_engine, "c.membership", membership_list
)
sql = """
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
FROM current_state_events AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND %s
""" % (
clause,
)
else:
clause, args = make_in_list_sql_clause(
self.database_engine, "m.membership", membership_list
)
sql = """
SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (room_id, event_id)
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND %s
""" % (
clause,
)
txn.execute(sql, (user_id, *args))
results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
if do_invite:
sql = (
"SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
" FROM local_invites as i"
" INNER JOIN events as e USING (event_id)"
" WHERE invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
# Paranoia check.
if not self.hs.is_mine_id(user_id):
raise Exception(
"Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
% (user_id,),
)
txn.execute(sql, (user_id,))
results.extend(
RoomsForUser(
room_id=r["room_id"],
sender=r["inviter"],
event_id=r["event_id"],
stream_ordering=r["stream_ordering"],
membership=Membership.INVITE,
)
for r in self.db.cursor_to_dict(txn)
)
clause, args = make_in_list_sql_clause(
self.database_engine, "c.membership", membership_list
)
sql = """
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
FROM local_current_membership AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
user_id = ?
AND %s
""" % (
clause,
)
txn.execute(sql, (user_id, *args))
results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
return results
@cachedInlineCallbacks(max_entries=500000, iterable=True)
@cached(max_entries=500000, iterable=True)
def get_rooms_for_user_with_stream_ordering(self, user_id):
"""Returns a set of room_ids the user is currently joined to
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
Args:
user_id (str)
@ -433,17 +402,49 @@ class RoomMemberWorkerStore(EventsWorkerStore):
the rooms the user is in currently, along with the stream ordering
of the most recent join for that user and room.
"""
rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN]
)
return frozenset(
GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
for r in rooms
return self.db.runInteraction(
"get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn,
user_id,
)
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
if self._current_state_events_membership_up_to_date:
sql = """
SELECT room_id, e.stream_ordering
FROM current_state_events AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND c.membership = ?
"""
else:
sql = """
SELECT room_id, e.stream_ordering
FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (room_id, event_id)
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND m.membership = ?
"""
txn.execute(sql, (user_id, Membership.JOIN))
results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
return results
@defer.inlineCallbacks
def get_rooms_for_user(self, user_id, on_invalidate=None):
"""Returns a set of room_ids the user is currently joined to
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
"""
rooms = yield self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate
@ -1022,7 +1023,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
@ -1064,6 +1065,27 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
),
)
# We also update the `local_current_membership` table with
# latest invite info. This will usually get updated by the
# `current_state_events` handling, unless its an outlier.
if event.internal_metadata.is_outlier():
# This should only happen for out of band memberships, so
# we add a paranoia check.
assert event.internal_metadata.is_out_of_band_membership()
self.db.simple_upsert_txn(
txn,
table="local_current_membership",
keyvalues={
"room_id": event.room_id,
"user_id": event.state_key,
},
values={
"event_id": event.event_id,
"membership": event.membership,
},
)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
@ -1075,6 +1097,15 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
# We also clear this entry from `local_current_membership`.
# Ideally we'd point to a leave event, but we don't have one, so
# nevermind.
self.db.simple_delete_txn(
txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
)
with self._stream_id_gen.get_next() as stream_ordering:
yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)

View file

@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
# Copyright 2020 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# We create a new table called `local_current_membership` that stores the latest
# membership state of local users in rooms, which helps track leaves/bans/etc
# even if the server has left the room (and so has deleted the room from
# `current_state_events`). This will also include outstanding invites for local
# users for rooms the server isn't in.
#
# If the server isn't and hasn't been in the room then it will only include
# outsstanding invites, and not e.g. pre-emptive bans of local users.
#
# If the server later rejoins a room `local_current_membership` can simply be
# replaced with the new current state of the room (which results in the
# equivalent behaviour as if the server had remained in the room).
def run_upgrade(cur, database_engine, config, *args, **kwargs):
# We need to do the insert in `run_upgrade` section as we don't have access
# to `config` in `run_create`.
# This upgrade may take a bit of time for large servers (e.g. one minute for
# matrix.org) but means we avoid a lots of book keeping required to do it as
# a background update.
# We check if the `current_state_events.membership` is up to date by
# checking if the relevant background update has finished. If it has
# finished we can avoid doing a join against `room_memberships`, which
# speesd things up.
cur.execute(
"""SELECT 1 FROM background_updates
WHERE update_name = 'current_state_events_membership'
"""
)
current_state_membership_up_to_date = not bool(cur.fetchone())
# Cheekily drop and recreate indices, as that is faster.
cur.execute("DROP INDEX local_current_membership_idx")
cur.execute("DROP INDEX local_current_membership_room_idx")
if current_state_membership_up_to_date:
sql = """
INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
SELECT c.room_id, state_key AS user_id, event_id, c.membership
FROM current_state_events AS c
WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key like '%' || ?
"""
else:
# We can't rely on the membership column, so we need to join against
# `room_memberships`.
sql = """
INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
SELECT c.room_id, state_key AS user_id, event_id, r.membership
FROM current_state_events AS c
INNER JOIN room_memberships AS r USING (event_id)
WHERE type = 'm.room.member' and state_key like '%' || ?
"""
cur.execute(sql, (config.server_name,))
cur.execute(
"CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
)
cur.execute(
"CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
)
def run_create(cur, database_engine, *args, **kwargs):
cur.execute(
"""
CREATE TABLE local_current_membership (
room_id TEXT NOT NULL,
user_id TEXT NOT NULL,
event_id TEXT NOT NULL,
membership TEXT NOT NULL
)"""
)
cur.execute(
"CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
)
cur.execute(
"CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
)

View file

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 56
SCHEMA_VERSION = 57
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -32,8 +32,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get not 0 time

View file

@ -115,13 +115,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def test_invites(self):
self.persist(type="m.room.create", key="", creator=USER_ID)
self.check("get_invited_rooms_for_user", [USER_ID_2], [])
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
self.replicate()
self.check(
"get_invited_rooms_for_user",
"get_invited_rooms_for_local_user",
[USER_ID_2],
[
RoomsForUser(

View file

@ -285,7 +285,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
)
# Make sure the invite is here.
pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
pending_invites = self.get_success(
store.get_invited_rooms_for_local_user(invitee_id)
)
self.assertEqual(len(pending_invites), 1, pending_invites)
self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
@ -293,12 +295,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.deactivate(invitee_id, invitee_tok)
# Check that the invite isn't there anymore.
pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
pending_invites = self.get_success(
store.get_invited_rooms_for_local_user(invitee_id)
)
self.assertEqual(len(pending_invites), 0, pending_invites)
# Check that the membership of @invitee:test in the room is now "leave".
memberships = self.get_success(
store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
store.get_rooms_for_local_user_where_membership_is(
invitee_id, [Membership.LEAVE]
)
)
self.assertEqual(len(memberships), 1, memberships)
self.assertEqual(memberships[0].room_id, room_id, memberships)

View file

@ -15,8 +15,6 @@
# limitations under the License.
import json
from mock import Mock
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
@ -36,13 +34,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
"red", http_client=None, federation_client=Mock()
)
return hs
def test_sync_argless(self):
request, channel = self.make_request("GET", "/sync")
self.render(request)

View file

@ -57,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
rooms_for_user = self.get_success(
self.store.get_rooms_for_user_where_membership_is(
self.store.get_rooms_for_local_user_where_membership_is(
self.u_alice, [Membership.JOIN]
)
)