mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-25 17:54:16 +01:00
Reduce state pulled from DB due to sending typing and receipts over federation (#12964)
Reducing the amount of state we pull from the DB is useful as fetching state is expensive in terms of DB, CPU and memory.
This commit is contained in:
parent
148fe58a24
commit
44de53bb79
9 changed files with 68 additions and 16 deletions
1
changelog.d/12964.misc
Normal file
1
changelog.d/12964.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Reduce the amount of state we pull from the DB.
|
|
@ -245,6 +245,8 @@ class FederationSender(AbstractFederationSender):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
|
||||||
|
@ -602,7 +604,9 @@ class FederationSender(AbstractFederationSender):
|
||||||
room_id = receipt.room_id
|
room_id = receipt.room_id
|
||||||
|
|
||||||
# Work out which remote servers should be poked and poke them.
|
# Work out which remote servers should be poked and poke them.
|
||||||
domains_set = await self.state.get_current_hosts_in_room(room_id)
|
domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
domains = [
|
domains = [
|
||||||
d
|
d
|
||||||
for d in domains_set
|
for d in domains_set
|
||||||
|
|
|
@ -59,6 +59,7 @@ class FollowerTypingHandler:
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.server_name = hs.config.server.server_name
|
self.server_name = hs.config.server.server_name
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
@ -131,7 +132,6 @@ class FollowerTypingHandler:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
users = await self.store.get_users_in_room(member.room_id)
|
|
||||||
self._member_last_federation_poke[member] = self.clock.time_msec()
|
self._member_last_federation_poke[member] = self.clock.time_msec()
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
|
@ -139,7 +139,10 @@ class FollowerTypingHandler:
|
||||||
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
|
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
|
||||||
)
|
)
|
||||||
|
|
||||||
for domain in {get_domain_from_id(u) for u in users}:
|
hosts = await self._storage_controllers.state.get_current_hosts_in_room(
|
||||||
|
member.room_id
|
||||||
|
)
|
||||||
|
for domain in hosts:
|
||||||
if domain != self.server_name:
|
if domain != self.server_name:
|
||||||
logger.debug("sending typing update to %s", domain)
|
logger.debug("sending typing update to %s", domain)
|
||||||
self.federation.build_and_send_edu(
|
self.federation.build_and_send_edu(
|
||||||
|
|
|
@ -172,10 +172,6 @@ class StateHandler:
|
||||||
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||||
return await self.store.get_joined_users_from_state(room_id, entry)
|
return await self.store.get_joined_users_from_state(room_id, entry)
|
||||||
|
|
||||||
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
|
|
||||||
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
|
||||||
return await self.get_hosts_in_room_at_events(room_id, event_ids)
|
|
||||||
|
|
||||||
async def get_hosts_in_room_at_events(
|
async def get_hosts_in_room_at_events(
|
||||||
self, room_id: str, event_ids: Collection[str]
|
self, room_id: str, event_ids: Collection[str]
|
||||||
) -> FrozenSet[str]:
|
) -> FrozenSet[str]:
|
||||||
|
|
|
@ -71,6 +71,7 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
||||||
if members_changed:
|
if members_changed:
|
||||||
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
|
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
|
||||||
|
self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
|
||||||
self._attempt_to_invalidate_cache(
|
self._attempt_to_invalidate_cache(
|
||||||
"get_users_in_room_with_profiles", (room_id,)
|
"get_users_in_room_with_profiles", (room_id,)
|
||||||
)
|
)
|
||||||
|
|
|
@ -23,6 +23,7 @@ from typing import (
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -482,3 +483,10 @@ class StateStorageController:
|
||||||
room_id, StateFilter.from_types((key,))
|
room_id, StateFilter.from_types((key,))
|
||||||
)
|
)
|
||||||
return state_map.get(key)
|
return state_map.get(key)
|
||||||
|
|
||||||
|
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||||
|
"""Get current hosts in room based on current state."""
|
||||||
|
|
||||||
|
await self._partial_state_room_tracker.await_full_state(room_id)
|
||||||
|
|
||||||
|
return await self.stores.main.get_current_hosts_in_room(room_id)
|
||||||
|
|
|
@ -893,6 +893,43 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@cached(iterable=True, max_entries=10000)
|
||||||
|
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||||
|
"""Get current hosts in room based on current state."""
|
||||||
|
|
||||||
|
# First we check if we already have `get_users_in_room` in the cache, as
|
||||||
|
# we can just calculate result from that
|
||||||
|
users = self.get_users_in_room.cache.get_immediate(
|
||||||
|
(room_id,), None, update_metrics=False
|
||||||
|
)
|
||||||
|
if users is not None:
|
||||||
|
return {get_domain_from_id(u) for u in users}
|
||||||
|
|
||||||
|
if isinstance(self.database_engine, Sqlite3Engine):
|
||||||
|
# If we're using SQLite then let's just always use
|
||||||
|
# `get_users_in_room` rather than funky SQL.
|
||||||
|
users = await self.get_users_in_room(room_id)
|
||||||
|
return {get_domain_from_id(u) for u in users}
|
||||||
|
|
||||||
|
# For PostgreSQL we can use a regex to pull out the domains from the
|
||||||
|
# joined users in `current_state_events` via regex.
|
||||||
|
|
||||||
|
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
|
||||||
|
sql = """
|
||||||
|
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
|
||||||
|
FROM current_state_events
|
||||||
|
WHERE
|
||||||
|
type = 'm.room.member'
|
||||||
|
AND membership = 'join'
|
||||||
|
AND room_id = ?
|
||||||
|
"""
|
||||||
|
txn.execute(sql, (room_id,))
|
||||||
|
return {d for d, in txn}
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_current_hosts_in_room", get_current_hosts_in_room_txn
|
||||||
|
)
|
||||||
|
|
||||||
async def get_joined_hosts(
|
async def get_joined_hosts(
|
||||||
self, room_id: str, state_entry: "_StateCacheEntry"
|
self, room_id: str, state_entry: "_StateCacheEntry"
|
||||||
) -> FrozenSet[str]:
|
) -> FrozenSet[str]:
|
||||||
|
|
|
@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
|
hs = self.setup_test_homeserver(
|
||||||
# Ensure a new Awaitable is created for each call.
|
|
||||||
mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
|
|
||||||
["test", "host2"]
|
|
||||||
)
|
|
||||||
return self.setup_test_homeserver(
|
|
||||||
state_handler=mock_state_handler,
|
|
||||||
federation_transport_client=Mock(spec=["send_transaction"]),
|
federation_transport_client=Mock(spec=["send_transaction"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
|
||||||
|
return_value=make_awaitable({"test", "host2"})
|
||||||
|
)
|
||||||
|
|
||||||
|
return hs
|
||||||
|
|
||||||
@override_config({"send_federation": True})
|
@override_config({"send_federation": True})
|
||||||
def test_send_receipts(self):
|
def test_send_receipts(self):
|
||||||
mock_send_transaction = (
|
mock_send_transaction = (
|
||||||
|
|
|
@ -129,10 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
|
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
|
||||||
|
|
||||||
def get_joined_hosts_for_room(room_id: str):
|
async def get_current_hosts_in_room(room_id: str):
|
||||||
return {member.domain for member in self.room_members}
|
return {member.domain for member in self.room_members}
|
||||||
|
|
||||||
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
hs.get_storage_controllers().state.get_current_hosts_in_room = (
|
||||||
|
get_current_hosts_in_room
|
||||||
|
)
|
||||||
|
|
||||||
async def get_users_in_room(room_id: str):
|
async def get_users_in_room(room_id: str):
|
||||||
return {str(u) for u in self.room_members}
|
return {str(u) for u in self.room_members}
|
||||||
|
|
Loading…
Reference in a new issue