mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-11 12:31:58 +01:00
Replaces calls to fetch_room_distributions_into with get_joined_hosts_for_room
This commit is contained in:
parent
1a3a2002ff
commit
821306120a
5 changed files with 33 additions and 112 deletions
|
@ -29,6 +29,8 @@ class ReceiptsHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ReceiptsHandler, self).__init__(hs)
|
super(ReceiptsHandler, self).__init__(hs)
|
||||||
|
|
||||||
|
self.server_name = hs.config.server_name
|
||||||
|
self.store = hs.get_datastore()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_replication_layer()
|
||||||
self.federation.register_edu_handler(
|
self.federation.register_edu_handler(
|
||||||
|
@ -131,12 +133,9 @@ class ReceiptsHandler(BaseHandler):
|
||||||
event_ids = receipt["event_ids"]
|
event_ids = receipt["event_ids"]
|
||||||
data = receipt["data"]
|
data = receipt["data"]
|
||||||
|
|
||||||
remotedomains = set()
|
remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||||
|
remotedomains = remotedomains.copy()
|
||||||
rm_handler = self.hs.get_handlers().room_member_handler
|
remotedomains.discard(self.server_name)
|
||||||
yield rm_handler.fetch_room_distributions_into(
|
|
||||||
room_id, localusers=None, remotedomains=remotedomains
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Sending receipt to: %r", remotedomains)
|
logger.debug("Sending receipt to: %r", remotedomains)
|
||||||
|
|
||||||
|
|
|
@ -55,35 +55,6 @@ class RoomMemberHandler(BaseHandler):
|
||||||
self.distributor.declare("user_joined_room")
|
self.distributor.declare("user_joined_room")
|
||||||
self.distributor.declare("user_left_room")
|
self.distributor.declare("user_left_room")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_room_members(self, room_id):
|
|
||||||
users = yield self.store.get_users_in_room(room_id)
|
|
||||||
|
|
||||||
defer.returnValue([UserID.from_string(u) for u in users])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def fetch_room_distributions_into(self, room_id, localusers=None,
|
|
||||||
remotedomains=None, ignore_user=None):
|
|
||||||
"""Fetch the distribution of a room, adding elements to either
|
|
||||||
'localusers' or 'remotedomains', which should be a set() if supplied.
|
|
||||||
If ignore_user is set, ignore that user.
|
|
||||||
|
|
||||||
This function returns nothing; its result is performed by the
|
|
||||||
side-effect on the two passed sets. This allows easy accumulation of
|
|
||||||
member lists of multiple rooms at once if required.
|
|
||||||
"""
|
|
||||||
members = yield self.get_room_members(room_id)
|
|
||||||
for member in members:
|
|
||||||
if ignore_user is not None and member == ignore_user:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if self.hs.is_mine(member):
|
|
||||||
if localusers is not None:
|
|
||||||
localusers.add(member)
|
|
||||||
else:
|
|
||||||
if remotedomains is not None:
|
|
||||||
remotedomains.add(member.domain)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _local_membership_update(
|
def _local_membership_update(
|
||||||
self, requester, target, room_id, membership,
|
self, requester, target, room_id, membership,
|
||||||
|
|
|
@ -39,7 +39,8 @@ class TypingNotificationHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(TypingNotificationHandler, self).__init__(hs)
|
super(TypingNotificationHandler, self).__init__(hs)
|
||||||
|
|
||||||
self.homeserver = hs
|
self.store = hs.get_datastore()
|
||||||
|
self.server_name = hs.config.server_name
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@ -157,32 +158,26 @@ class TypingNotificationHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _push_update(self, room_id, user, typing):
|
def _push_update(self, room_id, user, typing):
|
||||||
localusers = set()
|
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||||
remotedomains = set()
|
|
||||||
|
|
||||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
|
||||||
yield rm_handler.fetch_room_distributions_into(
|
|
||||||
room_id, localusers=localusers, remotedomains=remotedomains
|
|
||||||
)
|
|
||||||
|
|
||||||
if localusers:
|
|
||||||
self._push_update_local(
|
|
||||||
room_id=room_id,
|
|
||||||
user=user,
|
|
||||||
typing=typing
|
|
||||||
)
|
|
||||||
|
|
||||||
deferreds = []
|
deferreds = []
|
||||||
for domain in remotedomains:
|
for domain in domains:
|
||||||
deferreds.append(self.federation.send_edu(
|
if domain == self.server_name:
|
||||||
destination=domain,
|
self._push_update_local(
|
||||||
edu_type="m.typing",
|
room_id=room_id,
|
||||||
content={
|
user=user,
|
||||||
"room_id": room_id,
|
typing=typing
|
||||||
"user_id": user.to_string(),
|
)
|
||||||
"typing": typing,
|
else:
|
||||||
},
|
deferreds.append(self.federation.send_edu(
|
||||||
))
|
destination=domain,
|
||||||
|
edu_type="m.typing",
|
||||||
|
content={
|
||||||
|
"room_id": room_id,
|
||||||
|
"user_id": user.to_string(),
|
||||||
|
"typing": typing,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
yield defer.DeferredList(deferreds, consumeErrors=True)
|
yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||||
|
|
||||||
|
@ -191,14 +186,9 @@ class TypingNotificationHandler(BaseHandler):
|
||||||
room_id = content["room_id"]
|
room_id = content["room_id"]
|
||||||
user = UserID.from_string(content["user_id"])
|
user = UserID.from_string(content["user_id"])
|
||||||
|
|
||||||
localusers = set()
|
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||||
|
|
||||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
if self.server_name in domains:
|
||||||
yield rm_handler.fetch_room_distributions_into(
|
|
||||||
room_id, localusers=localusers
|
|
||||||
)
|
|
||||||
|
|
||||||
if localusers:
|
|
||||||
self._push_update_local(
|
self._push_update_local(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
user=user,
|
user=user,
|
||||||
|
|
|
@ -71,6 +71,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
self.auth = Mock(spec=[])
|
self.auth = Mock(spec=[])
|
||||||
|
|
||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(
|
||||||
|
"test",
|
||||||
auth=self.auth,
|
auth=self.auth,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
datastore=Mock(spec=[
|
datastore=Mock(spec=[
|
||||||
|
@ -110,56 +111,16 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.room_id = "a-room"
|
self.room_id = "a-room"
|
||||||
|
|
||||||
# Mock the RoomMemberHandler
|
|
||||||
hs.handlers.room_member_handler = Mock(spec=[])
|
|
||||||
self.room_member_handler = hs.handlers.room_member_handler
|
|
||||||
|
|
||||||
self.room_members = []
|
self.room_members = []
|
||||||
|
|
||||||
def get_rooms_for_user(user):
|
|
||||||
if user in self.room_members:
|
|
||||||
return defer.succeed([self.room_id])
|
|
||||||
else:
|
|
||||||
return defer.succeed([])
|
|
||||||
self.room_member_handler.get_rooms_for_user = get_rooms_for_user
|
|
||||||
|
|
||||||
def get_room_members(room_id):
|
|
||||||
if room_id == self.room_id:
|
|
||||||
return defer.succeed(self.room_members)
|
|
||||||
else:
|
|
||||||
return defer.succeed([])
|
|
||||||
self.room_member_handler.get_room_members = get_room_members
|
|
||||||
|
|
||||||
def get_joined_rooms_for_user(user):
|
|
||||||
if user in self.room_members:
|
|
||||||
return defer.succeed([self.room_id])
|
|
||||||
else:
|
|
||||||
return defer.succeed([])
|
|
||||||
self.room_member_handler.get_joined_rooms_for_user = get_joined_rooms_for_user
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def fetch_room_distributions_into(
|
|
||||||
room_id, localusers=None, remotedomains=None, ignore_user=None
|
|
||||||
):
|
|
||||||
members = yield get_room_members(room_id)
|
|
||||||
for member in members:
|
|
||||||
if ignore_user is not None and member == ignore_user:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if hs.is_mine(member):
|
|
||||||
if localusers is not None:
|
|
||||||
localusers.add(member)
|
|
||||||
else:
|
|
||||||
if remotedomains is not None:
|
|
||||||
remotedomains.add(member.domain)
|
|
||||||
self.room_member_handler.fetch_room_distributions_into = (
|
|
||||||
fetch_room_distributions_into
|
|
||||||
)
|
|
||||||
|
|
||||||
def check_joined_room(room_id, user_id):
|
def check_joined_room(room_id, user_id):
|
||||||
if user_id not in [u.to_string() for u in self.room_members]:
|
if user_id not in [u.to_string() for u in self.room_members]:
|
||||||
raise AuthError(401, "User is not in the room")
|
raise AuthError(401, "User is not in the room")
|
||||||
|
|
||||||
|
def get_joined_hosts_for_room(room_id):
|
||||||
|
return set(member.domain for member in self.room_members)
|
||||||
|
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
||||||
|
|
||||||
self.auth.check_joined_room = check_joined_room
|
self.auth.check_joined_room = check_joined_room
|
||||||
|
|
||||||
# Some local users to test with
|
# Some local users to test with
|
||||||
|
|
|
@ -50,7 +50,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
config.enable_registration = True
|
config.enable_registration = True
|
||||||
config.macaroon_secret_key = "not even a little secret"
|
config.macaroon_secret_key = "not even a little secret"
|
||||||
config.expire_access_token = False
|
config.expire_access_token = False
|
||||||
config.server_name = "server.under.test"
|
config.server_name = name
|
||||||
config.trusted_third_party_id_servers = []
|
config.trusted_third_party_id_servers = []
|
||||||
config.room_invite_state_types = []
|
config.room_invite_state_types = []
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue