mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-25 05:14:20 +01:00
Reduce state pulled from DB due to sending typing and receipts over federation (#12964)
Reducing the amount of state we pull from the DB is useful as fetching state is expensive in terms of DB, CPU and memory.
This commit is contained in:
parent
148fe58a24
commit
44de53bb79
9 changed files with 68 additions and 16 deletions
1
changelog.d/12964.misc
Normal file
1
changelog.d/12964.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Reduce the amount of state we pull from the DB.
|
|
@ -245,6 +245,8 @@ class FederationSender(AbstractFederationSender):
|
|||
self.store = hs.get_datastores().main
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
|
@ -602,7 +604,9 @@ class FederationSender(AbstractFederationSender):
|
|||
room_id = receipt.room_id
|
||||
|
||||
# Work out which remote servers should be poked and poke them.
|
||||
domains_set = await self.state.get_current_hosts_in_room(room_id)
|
||||
domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
|
||||
room_id
|
||||
)
|
||||
domains = [
|
||||
d
|
||||
for d in domains_set
|
||||
|
|
|
@ -59,6 +59,7 @@ class FollowerTypingHandler:
|
|||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.server_name = hs.config.server.server_name
|
||||
self.clock = hs.get_clock()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
@ -131,7 +132,6 @@ class FollowerTypingHandler:
|
|||
return
|
||||
|
||||
try:
|
||||
users = await self.store.get_users_in_room(member.room_id)
|
||||
self._member_last_federation_poke[member] = self.clock.time_msec()
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
@ -139,7 +139,10 @@ class FollowerTypingHandler:
|
|||
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
|
||||
)
|
||||
|
||||
for domain in {get_domain_from_id(u) for u in users}:
|
||||
hosts = await self._storage_controllers.state.get_current_hosts_in_room(
|
||||
member.room_id
|
||||
)
|
||||
for domain in hosts:
|
||||
if domain != self.server_name:
|
||||
logger.debug("sending typing update to %s", domain)
|
||||
self.federation.build_and_send_edu(
|
||||
|
|
|
@ -172,10 +172,6 @@ class StateHandler:
|
|||
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
return await self.store.get_joined_users_from_state(room_id, entry)
|
||||
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
|
||||
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
return await self.get_hosts_in_room_at_events(room_id, event_ids)
|
||||
|
||||
async def get_hosts_in_room_at_events(
|
||||
self, room_id: str, event_ids: Collection[str]
|
||||
) -> FrozenSet[str]:
|
||||
|
|
|
@ -71,6 +71,7 @@ class SQLBaseStore(metaclass=ABCMeta):
|
|||
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
||||
if members_changed:
|
||||
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
|
||||
self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_users_in_room_with_profiles", (room_id,)
|
||||
)
|
||||
|
|
|
@ -23,6 +23,7 @@ from typing import (
|
|||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
|
@ -482,3 +483,10 @@ class StateStorageController:
|
|||
room_id, StateFilter.from_types((key,))
|
||||
)
|
||||
return state_map.get(key)
|
||||
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||
"""Get current hosts in room based on current state."""
|
||||
|
||||
await self._partial_state_room_tracker.await_full_state(room_id)
|
||||
|
||||
return await self.stores.main.get_current_hosts_in_room(room_id)
|
||||
|
|
|
@ -893,6 +893,43 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
return True
|
||||
|
||||
@cached(iterable=True, max_entries=10000)
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||
"""Get current hosts in room based on current state."""
|
||||
|
||||
# First we check if we already have `get_users_in_room` in the cache, as
|
||||
# we can just calculate result from that
|
||||
users = self.get_users_in_room.cache.get_immediate(
|
||||
(room_id,), None, update_metrics=False
|
||||
)
|
||||
if users is not None:
|
||||
return {get_domain_from_id(u) for u in users}
|
||||
|
||||
if isinstance(self.database_engine, Sqlite3Engine):
|
||||
# If we're using SQLite then let's just always use
|
||||
# `get_users_in_room` rather than funky SQL.
|
||||
users = await self.get_users_in_room(room_id)
|
||||
return {get_domain_from_id(u) for u in users}
|
||||
|
||||
# For PostgreSQL we can use a regex to pull out the domains from the
|
||||
# joined users in `current_state_events` via regex.
|
||||
|
||||
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
|
||||
sql = """
|
||||
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
|
||||
FROM current_state_events
|
||||
WHERE
|
||||
type = 'm.room.member'
|
||||
AND membership = 'join'
|
||||
AND room_id = ?
|
||||
"""
|
||||
txn.execute(sql, (room_id,))
|
||||
return {d for d, in txn}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_current_hosts_in_room", get_current_hosts_in_room_txn
|
||||
)
|
||||
|
||||
async def get_joined_hosts(
|
||||
self, room_id: str, state_entry: "_StateCacheEntry"
|
||||
) -> FrozenSet[str]:
|
||||
|
|
|
@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config
|
|||
|
||||
class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
|
||||
# Ensure a new Awaitable is created for each call.
|
||||
mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
|
||||
["test", "host2"]
|
||||
)
|
||||
return self.setup_test_homeserver(
|
||||
state_handler=mock_state_handler,
|
||||
hs = self.setup_test_homeserver(
|
||||
federation_transport_client=Mock(spec=["send_transaction"]),
|
||||
)
|
||||
|
||||
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
|
||||
return_value=make_awaitable({"test", "host2"})
|
||||
)
|
||||
|
||||
return hs
|
||||
|
||||
@override_config({"send_federation": True})
|
||||
def test_send_receipts(self):
|
||||
mock_send_transaction = (
|
||||
|
|
|
@ -129,10 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
|
||||
|
||||
def get_joined_hosts_for_room(room_id: str):
|
||||
async def get_current_hosts_in_room(room_id: str):
|
||||
return {member.domain for member in self.room_members}
|
||||
|
||||
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
||||
hs.get_storage_controllers().state.get_current_hosts_in_room = (
|
||||
get_current_hosts_in_room
|
||||
)
|
||||
|
||||
async def get_users_in_room(room_id: str):
|
||||
return {str(u) for u in self.room_members}
|
||||
|
|
Loading…
Reference in a new issue