forked from MirrorHub/synapse
Update get_users_in_room
mis-use to get hosts with dedicated get_current_hosts_in_room
(#13605)
See https://github.com/matrix-org/synapse/pull/13575#discussion_r953023755
This commit is contained in:
parent
d58615c82c
commit
1a209efdb2
6 changed files with 31 additions and 17 deletions
1
changelog.d/13605.misc
Normal file
1
changelog.d/13605.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Refactor `get_users_in_room(room_id)` mis-use with dedicated `get_current_hosts_in_room(room_id)` function.
|
|
@ -310,6 +310,7 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.federation_sender = hs.get_federation_sender()
|
self.federation_sender = hs.get_federation_sender()
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
self.device_list_updater = DeviceListUpdater(hs, self)
|
self.device_list_updater = DeviceListUpdater(hs, self)
|
||||||
|
|
||||||
|
@ -694,8 +695,11 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
|
|
||||||
# Ignore any users that aren't ours
|
# Ignore any users that aren't ours
|
||||||
if self.hs.is_mine_id(user_id):
|
if self.hs.is_mine_id(user_id):
|
||||||
joined_user_ids = await self.store.get_users_in_room(room_id)
|
hosts = set(
|
||||||
hosts = {get_domain_from_id(u) for u in joined_user_ids}
|
await self._storage_controllers.state.get_current_hosts_in_room(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
)
|
||||||
hosts.discard(self.server_name)
|
hosts.discard(self.server_name)
|
||||||
|
|
||||||
# Check if we've already sent this update to some hosts
|
# Check if we've already sent this update to some hosts
|
||||||
|
|
|
@ -30,7 +30,7 @@ from synapse.api.errors import (
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.module_api import NOT_SPAM
|
from synapse.module_api import NOT_SPAM
|
||||||
from synapse.storage.databases.main.directory import RoomAliasMapping
|
from synapse.storage.databases.main.directory import RoomAliasMapping
|
||||||
from synapse.types import JsonDict, Requester, RoomAlias, get_domain_from_id
|
from synapse.types import JsonDict, Requester, RoomAlias
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -83,8 +83,9 @@ class DirectoryHandler:
|
||||||
# TODO(erikj): Add transactions.
|
# TODO(erikj): Add transactions.
|
||||||
# TODO(erikj): Check if there is a current association.
|
# TODO(erikj): Check if there is a current association.
|
||||||
if not servers:
|
if not servers:
|
||||||
users = await self.store.get_users_in_room(room_id)
|
servers = await self._storage_controllers.state.get_current_hosts_in_room(
|
||||||
servers = {get_domain_from_id(u) for u in users}
|
room_id
|
||||||
|
)
|
||||||
|
|
||||||
if not servers:
|
if not servers:
|
||||||
raise SynapseError(400, "Failed to get server list")
|
raise SynapseError(400, "Failed to get server list")
|
||||||
|
@ -287,8 +288,9 @@ class DirectoryHandler:
|
||||||
Codes.NOT_FOUND,
|
Codes.NOT_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
users = await self.store.get_users_in_room(room_id)
|
extra_servers = await self._storage_controllers.state.get_current_hosts_in_room(
|
||||||
extra_servers = {get_domain_from_id(u) for u in users}
|
room_id
|
||||||
|
)
|
||||||
servers_set = set(extra_servers) | set(servers)
|
servers_set = set(extra_servers) | set(servers)
|
||||||
|
|
||||||
# If this server is in the list of servers, return it first.
|
# If this server is in the list of servers, return it first.
|
||||||
|
|
|
@ -2051,8 +2051,7 @@ async def get_interested_remotes(
|
||||||
)
|
)
|
||||||
|
|
||||||
for room_id, states in room_ids_to_states.items():
|
for room_id, states in room_ids_to_states.items():
|
||||||
user_ids = await store.get_users_in_room(room_id)
|
hosts = await store.get_current_hosts_in_room(room_id)
|
||||||
hosts = {get_domain_from_id(user_id) for user_id in user_ids}
|
|
||||||
for host in hosts:
|
for host in hosts:
|
||||||
hosts_and_states.setdefault(host, set()).update(states)
|
hosts_and_states.setdefault(host, set()).update(states)
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import (
|
||||||
)
|
)
|
||||||
from synapse.replication.tcp.streams import TypingStream
|
from synapse.replication.tcp.streams import TypingStream
|
||||||
from synapse.streams import EventSource
|
from synapse.streams import EventSource
|
||||||
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
|
from synapse.types import JsonDict, Requester, StreamKeyType, UserID
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.util.wheel_timer import WheelTimer
|
from synapse.util.wheel_timer import WheelTimer
|
||||||
|
@ -362,8 +362,9 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
users = await self.store.get_users_in_room(room_id)
|
domains = await self._storage_controllers.state.get_current_hosts_in_room(
|
||||||
domains = {get_domain_from_id(u) for u in users}
|
room_id
|
||||||
|
)
|
||||||
|
|
||||||
if self.server_name in domains:
|
if self.server_name in domains:
|
||||||
logger.info("Got typing update from %s: %r", user_id, content)
|
logger.info("Got typing update from %s: %r", user_id, content)
|
||||||
|
|
|
@ -173,17 +173,24 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
# stub out `get_rooms_for_user` and `get_users_in_room` so that the
|
test_room_id = "!room:host1"
|
||||||
|
|
||||||
|
# stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
|
||||||
# server thinks the user shares a room with `@user2:host2`
|
# server thinks the user shares a room with `@user2:host2`
|
||||||
def get_rooms_for_user(user_id):
|
def get_rooms_for_user(user_id):
|
||||||
return defer.succeed({"!room:host1"})
|
return defer.succeed({test_room_id})
|
||||||
|
|
||||||
hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
|
hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
|
||||||
|
|
||||||
def get_users_in_room(room_id):
|
async def get_current_hosts_in_room(room_id):
|
||||||
return defer.succeed({"@user2:host2"})
|
if room_id == test_room_id:
|
||||||
|
return ["host2"]
|
||||||
|
|
||||||
hs.get_datastores().main.get_users_in_room = get_users_in_room
|
# TODO: We should fail the test when we encounter an unxpected room ID.
|
||||||
|
# We can't just use `self.fail(...)` here because the app code is greedy
|
||||||
|
# with `Exception` and will catch it before the test can see it.
|
||||||
|
|
||||||
|
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room
|
||||||
|
|
||||||
# whenever send_transaction is called, record the edu data
|
# whenever send_transaction is called, record the edu data
|
||||||
self.edus = []
|
self.edus = []
|
||||||
|
|
Loading…
Reference in a new issue