mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-19 02:32:06 +01:00
Merge pull request #789 from matrix-org/markjh/member_cleanup
Cleanup room member handler
This commit is contained in:
commit
1ed33784a6
8 changed files with 37 additions and 174 deletions
|
@ -29,6 +29,8 @@ class ReceiptsHandler(BaseHandler):
|
|||
def __init__(self, hs):
|
||||
super(ReceiptsHandler, self).__init__(hs)
|
||||
|
||||
self.server_name = hs.config.server_name
|
||||
self.store = hs.get_datastore()
|
||||
self.hs = hs
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.federation.register_edu_handler(
|
||||
|
@ -131,12 +133,9 @@ class ReceiptsHandler(BaseHandler):
|
|||
event_ids = receipt["event_ids"]
|
||||
data = receipt["data"]
|
||||
|
||||
remotedomains = set()
|
||||
|
||||
rm_handler = self.hs.get_handlers().room_member_handler
|
||||
yield rm_handler.fetch_room_distributions_into(
|
||||
room_id, localusers=None, remotedomains=remotedomains
|
||||
)
|
||||
remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
remotedomains = remotedomains.copy()
|
||||
remotedomains.discard(self.server_name)
|
||||
|
||||
logger.debug("Sending receipt to: %r", remotedomains)
|
||||
|
||||
|
|
|
@ -55,35 +55,6 @@ class RoomMemberHandler(BaseHandler):
|
|||
self.distributor.declare("user_joined_room")
|
||||
self.distributor.declare("user_left_room")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_members(self, room_id):
|
||||
users = yield self.store.get_users_in_room(room_id)
|
||||
|
||||
defer.returnValue([UserID.from_string(u) for u in users])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_room_distributions_into(self, room_id, localusers=None,
|
||||
remotedomains=None, ignore_user=None):
|
||||
"""Fetch the distribution of a room, adding elements to either
|
||||
'localusers' or 'remotedomains', which should be a set() if supplied.
|
||||
If ignore_user is set, ignore that user.
|
||||
|
||||
This function returns nothing; its result is performed by the
|
||||
side-effect on the two passed sets. This allows easy accumulation of
|
||||
member lists of multiple rooms at once if required.
|
||||
"""
|
||||
members = yield self.get_room_members(room_id)
|
||||
for member in members:
|
||||
if ignore_user is not None and member == ignore_user:
|
||||
continue
|
||||
|
||||
if self.hs.is_mine(member):
|
||||
if localusers is not None:
|
||||
localusers.add(member)
|
||||
else:
|
||||
if remotedomains is not None:
|
||||
remotedomains.add(member.domain)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _local_membership_update(
|
||||
self, requester, target, room_id, membership,
|
||||
|
@ -426,21 +397,6 @@ class RoomMemberHandler(BaseHandler):
|
|||
if invite:
|
||||
defer.returnValue(UserID.from_string(invite.sender))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_joined_rooms_for_user(self, user):
|
||||
"""Returns a list of roomids that the user has any of the given
|
||||
membership states in."""
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(
|
||||
user.to_string(),
|
||||
)
|
||||
|
||||
# For some reason the list of events contains duplicates
|
||||
# TODO(paul): work out why because I really don't think it should
|
||||
room_ids = set(r.room_id for r in rooms)
|
||||
|
||||
defer.returnValue(room_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_3pid_invite(
|
||||
self,
|
||||
|
@ -457,8 +413,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
)
|
||||
|
||||
if invitee:
|
||||
handler = self.hs.get_handlers().room_member_handler
|
||||
yield handler.update_membership(
|
||||
yield self.update_membership(
|
||||
requester,
|
||||
UserID.from_string(invitee),
|
||||
room_id,
|
||||
|
|
|
@ -485,7 +485,6 @@ class SyncHandler(BaseHandler):
|
|||
sync_config, now_token, since_token
|
||||
)
|
||||
|
||||
rm_handler = self.hs.get_handlers().room_member_handler
|
||||
app_service = yield self.store.get_app_service_by_user_id(
|
||||
sync_config.user.to_string()
|
||||
)
|
||||
|
@ -493,9 +492,10 @@ class SyncHandler(BaseHandler):
|
|||
rooms = yield self.store.get_app_service_rooms(app_service)
|
||||
joined_room_ids = set(r.room_id for r in rooms)
|
||||
else:
|
||||
joined_room_ids = yield rm_handler.get_joined_rooms_for_user(
|
||||
sync_config.user
|
||||
rooms = yield self.store.get_rooms_for_user(
|
||||
sync_config.user.to_string()
|
||||
)
|
||||
joined_room_ids = set(r.room_id for r in rooms)
|
||||
|
||||
user_id = sync_config.user.to_string()
|
||||
|
||||
|
|
|
@ -39,7 +39,8 @@ class TypingNotificationHandler(BaseHandler):
|
|||
def __init__(self, hs):
|
||||
super(TypingNotificationHandler, self).__init__(hs)
|
||||
|
||||
self.homeserver = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.server_name = hs.config.server_name
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
|
@ -157,32 +158,26 @@ class TypingNotificationHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _push_update(self, room_id, user, typing):
|
||||
localusers = set()
|
||||
remotedomains = set()
|
||||
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
yield rm_handler.fetch_room_distributions_into(
|
||||
room_id, localusers=localusers, remotedomains=remotedomains
|
||||
)
|
||||
|
||||
if localusers:
|
||||
self._push_update_local(
|
||||
room_id=room_id,
|
||||
user=user,
|
||||
typing=typing
|
||||
)
|
||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
|
||||
deferreds = []
|
||||
for domain in remotedomains:
|
||||
deferreds.append(self.federation.send_edu(
|
||||
destination=domain,
|
||||
edu_type="m.typing",
|
||||
content={
|
||||
"room_id": room_id,
|
||||
"user_id": user.to_string(),
|
||||
"typing": typing,
|
||||
},
|
||||
))
|
||||
for domain in domains:
|
||||
if domain == self.server_name:
|
||||
self._push_update_local(
|
||||
room_id=room_id,
|
||||
user=user,
|
||||
typing=typing
|
||||
)
|
||||
else:
|
||||
deferreds.append(self.federation.send_edu(
|
||||
destination=domain,
|
||||
edu_type="m.typing",
|
||||
content={
|
||||
"room_id": room_id,
|
||||
"user_id": user.to_string(),
|
||||
"typing": typing,
|
||||
},
|
||||
))
|
||||
|
||||
yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||
|
||||
|
@ -191,14 +186,9 @@ class TypingNotificationHandler(BaseHandler):
|
|||
room_id = content["room_id"]
|
||||
user = UserID.from_string(content["user_id"])
|
||||
|
||||
localusers = set()
|
||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
yield rm_handler.fetch_room_distributions_into(
|
||||
room_id, localusers=localusers
|
||||
)
|
||||
|
||||
if localusers:
|
||||
if self.server_name in domains:
|
||||
self._push_update_local(
|
||||
room_id=room_id,
|
||||
user=user,
|
||||
|
@ -239,7 +229,6 @@ class TypingNotificationEventSource(object):
|
|||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self._handler = None
|
||||
self._room_member_handler = None
|
||||
|
||||
def handler(self):
|
||||
# Avoid cyclic dependency in handler setup
|
||||
|
@ -247,11 +236,6 @@ class TypingNotificationEventSource(object):
|
|||
self._handler = self.hs.get_handlers().typing_notification_handler
|
||||
return self._handler
|
||||
|
||||
def room_member_handler(self):
|
||||
if not self._room_member_handler:
|
||||
self._room_member_handler = self.hs.get_handlers().room_member_handler
|
||||
return self._room_member_handler
|
||||
|
||||
def _make_event_for(self, room_id):
|
||||
typing = self.handler()._room_typing[room_id]
|
||||
return {
|
||||
|
|
|
@ -137,24 +137,6 @@ class RoomMemberStore(SQLBaseStore):
|
|||
return [r["user_id"] for r in rows]
|
||||
return self.runInteraction("get_users_in_room", f)
|
||||
|
||||
def get_room_members(self, room_id, membership=None):
|
||||
"""Retrieve the current room member list for a room.
|
||||
|
||||
Args:
|
||||
room_id (str): The room to get the list of members.
|
||||
membership (synapse.api.constants.Membership): The filter to apply
|
||||
to this list, or None to return all members with some state
|
||||
associated with this room.
|
||||
Returns:
|
||||
list of namedtuples representing the members in this room.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_room_members",
|
||||
self._get_members_events_txn,
|
||||
room_id,
|
||||
membership=membership,
|
||||
).addCallback(self._get_events)
|
||||
|
||||
@cached()
|
||||
def get_invited_rooms_for_user(self, user_id):
|
||||
""" Get all the rooms the user is invited to
|
||||
|
|
|
@ -71,6 +71,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
self.auth = Mock(spec=[])
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
"test",
|
||||
auth=self.auth,
|
||||
clock=self.clock,
|
||||
datastore=Mock(spec=[
|
||||
|
@ -110,56 +111,16 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
|
||||
self.room_id = "a-room"
|
||||
|
||||
# Mock the RoomMemberHandler
|
||||
hs.handlers.room_member_handler = Mock(spec=[])
|
||||
self.room_member_handler = hs.handlers.room_member_handler
|
||||
|
||||
self.room_members = []
|
||||
|
||||
def get_rooms_for_user(user):
|
||||
if user in self.room_members:
|
||||
return defer.succeed([self.room_id])
|
||||
else:
|
||||
return defer.succeed([])
|
||||
self.room_member_handler.get_rooms_for_user = get_rooms_for_user
|
||||
|
||||
def get_room_members(room_id):
|
||||
if room_id == self.room_id:
|
||||
return defer.succeed(self.room_members)
|
||||
else:
|
||||
return defer.succeed([])
|
||||
self.room_member_handler.get_room_members = get_room_members
|
||||
|
||||
def get_joined_rooms_for_user(user):
|
||||
if user in self.room_members:
|
||||
return defer.succeed([self.room_id])
|
||||
else:
|
||||
return defer.succeed([])
|
||||
self.room_member_handler.get_joined_rooms_for_user = get_joined_rooms_for_user
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_room_distributions_into(
|
||||
room_id, localusers=None, remotedomains=None, ignore_user=None
|
||||
):
|
||||
members = yield get_room_members(room_id)
|
||||
for member in members:
|
||||
if ignore_user is not None and member == ignore_user:
|
||||
continue
|
||||
|
||||
if hs.is_mine(member):
|
||||
if localusers is not None:
|
||||
localusers.add(member)
|
||||
else:
|
||||
if remotedomains is not None:
|
||||
remotedomains.add(member.domain)
|
||||
self.room_member_handler.fetch_room_distributions_into = (
|
||||
fetch_room_distributions_into
|
||||
)
|
||||
|
||||
def check_joined_room(room_id, user_id):
|
||||
if user_id not in [u.to_string() for u in self.room_members]:
|
||||
raise AuthError(401, "User is not in the room")
|
||||
|
||||
def get_joined_hosts_for_room(room_id):
|
||||
return set(member.domain for member in self.room_members)
|
||||
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
||||
|
||||
self.auth.check_joined_room = check_joined_room
|
||||
|
||||
# Some local users to test with
|
||||
|
|
|
@ -70,12 +70,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
|
|||
def test_one_member(self):
|
||||
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
|
||||
|
||||
self.assertEquals(
|
||||
[self.u_alice.to_string()],
|
||||
[m.user_id for m in (
|
||||
yield self.store.get_room_members(self.room.to_string())
|
||||
)]
|
||||
)
|
||||
self.assertEquals(
|
||||
[self.room.to_string()],
|
||||
[m.room_id for m in (
|
||||
|
@ -85,18 +79,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
|
|||
)]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_two_members(self):
|
||||
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
|
||||
yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
|
||||
|
||||
self.assertEquals(
|
||||
{self.u_alice.to_string(), self.u_bob.to_string()},
|
||||
{m.user_id for m in (
|
||||
yield self.store.get_room_members(self.room.to_string())
|
||||
)}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_room_hosts(self):
|
||||
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
|
||||
|
|
|
@ -50,7 +50,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|||
config.enable_registration = True
|
||||
config.macaroon_secret_key = "not even a little secret"
|
||||
config.expire_access_token = False
|
||||
config.server_name = "server.under.test"
|
||||
config.server_name = name
|
||||
config.trusted_third_party_id_servers = []
|
||||
config.room_invite_state_types = []
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue