forked from MirrorHub/synapse
Convert some of the data store to async. (#7976)
This commit is contained in:
parent
3950ae51ef
commit
b3a97d6dac
10 changed files with 190 additions and 207 deletions
1
changelog.d/7976.misc
Normal file
1
changelog.d/7976.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -15,11 +15,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
|
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import Database
|
from synapse.storage.database import Database
|
||||||
|
@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return {"notify_count": notify_count, "highlight_count": highlight_count}
|
return {"notify_count": notify_count, "highlight_count": highlight_count}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_push_action_users_in_range(
|
||||||
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
|
self, min_stream_ordering, max_stream_ordering
|
||||||
|
):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
|
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
|
||||||
|
@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
|
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
|
||||||
return [r[0] for r in txn]
|
return [r[0] for r in txn]
|
||||||
|
|
||||||
ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
|
ret = await self.db.runInteraction("get_push_action_users_in_range", f)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_unread_push_actions_for_user_in_range_for_http(
|
||||||
def get_unread_push_actions_for_user_in_range_for_http(
|
self,
|
||||||
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
user_id: str,
|
||||||
):
|
min_stream_ordering: int,
|
||||||
|
max_stream_ordering: int,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> List[dict]:
|
||||||
"""Get a list of the most recent unread push actions for a given user,
|
"""Get a list of the most recent unread push actions for a given user,
|
||||||
within the given stream ordering range. Called by the httppusher.
|
within the given stream ordering range. Called by the httppusher.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user to fetch push actions for.
|
user_id: The user to fetch push actions for.
|
||||||
min_stream_ordering(int): The exclusive lower bound on the
|
min_stream_ordering: The exclusive lower bound on the
|
||||||
stream ordering of event push actions to fetch.
|
stream ordering of event push actions to fetch.
|
||||||
max_stream_ordering(int): The inclusive upper bound on the
|
max_stream_ordering: The inclusive upper bound on the
|
||||||
stream ordering of event push actions to fetch.
|
stream ordering of event push actions to fetch.
|
||||||
limit (int): The maximum number of rows to return.
|
limit: The maximum number of rows to return.
|
||||||
Returns:
|
Returns:
|
||||||
A promise which resolves to a list of dicts with the keys "event_id",
|
A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions".
|
||||||
"room_id", "stream_ordering", "actions".
|
|
||||||
The list will be ordered by ascending stream_ordering.
|
The list will be ordered by ascending stream_ordering.
|
||||||
The list will have between 0~limit entries.
|
The list will have between 0~limit entries.
|
||||||
"""
|
"""
|
||||||
|
@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
return txn.fetchall()
|
return txn.fetchall()
|
||||||
|
|
||||||
after_read_receipt = yield self.db.runInteraction(
|
after_read_receipt = await self.db.runInteraction(
|
||||||
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
|
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
return txn.fetchall()
|
return txn.fetchall()
|
||||||
|
|
||||||
no_read_receipt = yield self.db.runInteraction(
|
no_read_receipt = await self.db.runInteraction(
|
||||||
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
|
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
# one of the subqueries may have hit the limit.
|
# one of the subqueries may have hit the limit.
|
||||||
return notifs[:limit]
|
return notifs[:limit]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_unread_push_actions_for_user_in_range_for_email(
|
||||||
def get_unread_push_actions_for_user_in_range_for_email(
|
self,
|
||||||
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
user_id: str,
|
||||||
):
|
min_stream_ordering: int,
|
||||||
|
max_stream_ordering: int,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> List[dict]:
|
||||||
"""Get a list of the most recent unread push actions for a given user,
|
"""Get a list of the most recent unread push actions for a given user,
|
||||||
within the given stream ordering range. Called by the emailpusher
|
within the given stream ordering range. Called by the emailpusher
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user to fetch push actions for.
|
user_id: The user to fetch push actions for.
|
||||||
min_stream_ordering(int): The exclusive lower bound on the
|
min_stream_ordering: The exclusive lower bound on the
|
||||||
stream ordering of event push actions to fetch.
|
stream ordering of event push actions to fetch.
|
||||||
max_stream_ordering(int): The inclusive upper bound on the
|
max_stream_ordering: The inclusive upper bound on the
|
||||||
stream ordering of event push actions to fetch.
|
stream ordering of event push actions to fetch.
|
||||||
limit (int): The maximum number of rows to return.
|
limit: The maximum number of rows to return.
|
||||||
Returns:
|
Returns:
|
||||||
A promise which resolves to a list of dicts with the keys "event_id",
|
A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts".
|
||||||
"room_id", "stream_ordering", "actions", "received_ts".
|
|
||||||
The list will be ordered by descending received_ts.
|
The list will be ordered by descending received_ts.
|
||||||
The list will have between 0~limit entries.
|
The list will have between 0~limit entries.
|
||||||
"""
|
"""
|
||||||
|
@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
return txn.fetchall()
|
return txn.fetchall()
|
||||||
|
|
||||||
after_read_receipt = yield self.db.runInteraction(
|
after_read_receipt = await self.db.runInteraction(
|
||||||
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
|
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
return txn.fetchall()
|
return txn.fetchall()
|
||||||
|
|
||||||
no_read_receipt = yield self.db.runInteraction(
|
no_read_receipt = await self.db.runInteraction(
|
||||||
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
|
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -461,17 +465,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
|
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def remove_push_actions_from_staging(self, event_id: str) -> None:
|
||||||
def remove_push_actions_from_staging(self, event_id):
|
|
||||||
"""Called if we failed to persist the event to ensure that stale push
|
"""Called if we failed to persist the event to ensure that stale push
|
||||||
actions don't build up in the DB
|
actions don't build up in the DB
|
||||||
|
|
||||||
Args:
|
|
||||||
event_id (str)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = yield self.db.simple_delete(
|
res = await self.db.simple_delete(
|
||||||
table="event_push_actions_staging",
|
table="event_push_actions_staging",
|
||||||
keyvalues={"event_id": event_id},
|
keyvalues={"event_id": event_id},
|
||||||
desc="remove_push_actions_from_staging",
|
desc="remove_push_actions_from_staging",
|
||||||
|
@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return range_end
|
return range_end
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_time_of_last_push_action_before(self, stream_ordering):
|
||||||
def get_time_of_last_push_action_before(self, stream_ordering):
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT e.received_ts"
|
"SELECT e.received_ts"
|
||||||
|
@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
txn.execute(sql, (stream_ordering,))
|
txn.execute(sql, (stream_ordering,))
|
||||||
return txn.fetchone()
|
return txn.fetchone()
|
||||||
|
|
||||||
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
|
result = await self.db.runInteraction("get_time_of_last_push_action_before", f)
|
||||||
return result[0] if result else None
|
return result[0] if result else None
|
||||||
|
|
||||||
|
|
||||||
|
@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||||
self._start_rotate_notifs, 30 * 60 * 1000
|
self._start_rotate_notifs, 30 * 60 * 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_push_actions_for_user(
|
||||||
def get_push_actions_for_user(
|
|
||||||
self, user_id, before=None, limit=50, only_highlight=False
|
self, user_id, before=None, limit=50, only_highlight=False
|
||||||
):
|
):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
return self.db.cursor_to_dict(txn)
|
return self.db.cursor_to_dict(txn)
|
||||||
|
|
||||||
push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
|
push_actions = await self.db.runInteraction("get_push_actions_for_user", f)
|
||||||
for pa in push_actions:
|
for pa in push_actions:
|
||||||
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
|
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
|
||||||
return push_actions
|
return push_actions
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_latest_push_action_stream_ordering(self):
|
||||||
def get_latest_push_action_stream_ordering(self):
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
|
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
|
||||||
return txn.fetchone()
|
return txn.fetchone()
|
||||||
|
|
||||||
result = yield self.db.runInteraction(
|
result = await self.db.runInteraction(
|
||||||
"get_latest_push_action_stream_ordering", f
|
"get_latest_push_action_stream_ordering", f
|
||||||
)
|
)
|
||||||
return result[0] or 0
|
return result[0] or 0
|
||||||
|
@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||||
def _start_rotate_notifs(self):
|
def _start_rotate_notifs(self):
|
||||||
return run_as_background_process("rotate_notifs", self._rotate_notifs)
|
return run_as_background_process("rotate_notifs", self._rotate_notifs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _rotate_notifs(self):
|
||||||
def _rotate_notifs(self):
|
|
||||||
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
|
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
|
||||||
return
|
return
|
||||||
self._doing_notif_rotation = True
|
self._doing_notif_rotation = True
|
||||||
|
@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||||
while True:
|
while True:
|
||||||
logger.info("Rotating notifications")
|
logger.info("Rotating notifications")
|
||||||
|
|
||||||
caught_up = yield self.db.runInteraction(
|
caught_up = await self.db.runInteraction(
|
||||||
"_rotate_notifs", self._rotate_notifs_txn
|
"_rotate_notifs", self._rotate_notifs_txn
|
||||||
)
|
)
|
||||||
if caught_up:
|
if caught_up:
|
||||||
break
|
break
|
||||||
yield self.hs.get_clock().sleep(self._rotate_delay)
|
await self.hs.get_clock().sleep(self._rotate_delay)
|
||||||
finally:
|
finally:
|
||||||
self._doing_notif_rotation = False
|
self._doing_notif_rotation = False
|
||||||
|
|
||||||
|
|
|
@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.api.room_versions import RoomVersion, RoomVersions
|
from synapse.api.room_versions import RoomVersion, RoomVersions
|
||||||
|
@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.data_stores.main.search import SearchStore
|
from synapse.storage.data_stores.main.search import SearchStore
|
||||||
from synapse.storage.database import Database, LoggingTransaction
|
from synapse.storage.database import Database, LoggingTransaction
|
||||||
from synapse.types import ThirdPartyInstanceID
|
from synapse.types import ThirdPartyInstanceID
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
|
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_largest_public_rooms(
|
||||||
def get_largest_public_rooms(
|
|
||||||
self,
|
self,
|
||||||
network_tuple: Optional[ThirdPartyInstanceID],
|
network_tuple: Optional[ThirdPartyInstanceID],
|
||||||
search_filter: Optional[dict],
|
search_filter: Optional[dict],
|
||||||
|
@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
ret_val = yield self.db.runInteraction(
|
ret_val = await self.db.runInteraction(
|
||||||
"get_largest_public_rooms", _get_largest_public_rooms_txn
|
"get_largest_public_rooms", _get_largest_public_rooms_txn
|
||||||
)
|
)
|
||||||
defer.returnValue(ret_val)
|
return ret_val
|
||||||
|
|
||||||
@cached(max_entries=10000)
|
@cached(max_entries=10000)
|
||||||
def is_room_blocked(self, room_id):
|
def is_room_blocked(self, room_id):
|
||||||
|
@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore):
|
||||||
"get_rooms_paginate", _get_rooms_paginate_txn,
|
"get_rooms_paginate", _get_rooms_paginate_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=10000)
|
@cached(max_entries=10000)
|
||||||
def get_ratelimit_for_user(self, user_id):
|
async def get_ratelimit_for_user(self, user_id):
|
||||||
"""Check if there are any overrides for ratelimiting for the given
|
"""Check if there are any overrides for ratelimiting for the given
|
||||||
user
|
user
|
||||||
|
|
||||||
|
@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||||
of RatelimitOverride are None or 0 then ratelimitng has been
|
of RatelimitOverride are None or 0 then ratelimitng has been
|
||||||
disabled for that user entirely.
|
disabled for that user entirely.
|
||||||
"""
|
"""
|
||||||
row = yield self.db.simple_select_one(
|
row = await self.db.simple_select_one(
|
||||||
table="ratelimit_override",
|
table="ratelimit_override",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
retcols=("messages_per_second", "burst_count"),
|
retcols=("messages_per_second", "burst_count"),
|
||||||
|
@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore):
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cached()
|
||||||
def get_retention_policy_for_room(self, room_id):
|
async def get_retention_policy_for_room(self, room_id):
|
||||||
"""Get the retention policy for a given room.
|
"""Get the retention policy for a given room.
|
||||||
|
|
||||||
If no retention policy has been found for this room, returns a policy defined
|
If no retention policy has been found for this room, returns a policy defined
|
||||||
|
@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return self.db.cursor_to_dict(txn)
|
return self.db.cursor_to_dict(txn)
|
||||||
|
|
||||||
ret = yield self.db.runInteraction(
|
ret = await self.db.runInteraction(
|
||||||
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
|
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If we don't know this room ID, ret will be None, in this case return the default
|
# If we don't know this room ID, ret will be None, in this case return the default
|
||||||
# policy.
|
# policy.
|
||||||
if not ret:
|
if not ret:
|
||||||
defer.returnValue(
|
return {
|
||||||
{
|
|
||||||
"min_lifetime": self.config.retention_default_min_lifetime,
|
"min_lifetime": self.config.retention_default_min_lifetime,
|
||||||
"max_lifetime": self.config.retention_default_max_lifetime,
|
"max_lifetime": self.config.retention_default_max_lifetime,
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
|
||||||
row = ret[0]
|
row = ret[0]
|
||||||
|
|
||||||
|
@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||||
if row["max_lifetime"] is None:
|
if row["max_lifetime"] is None:
|
||||||
row["max_lifetime"] = self.config.retention_default_max_lifetime
|
row["max_lifetime"] = self.config.retention_default_max_lifetime
|
||||||
|
|
||||||
defer.returnValue(row)
|
return row
|
||||||
|
|
||||||
def get_media_mxcs_in_room(self, room_id):
|
def get_media_mxcs_in_room(self, room_id):
|
||||||
"""Retrieves all the local and remote media MXC URIs in a given room
|
"""Retrieves all the local and remote media MXC URIs in a given room
|
||||||
|
@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
||||||
self._background_add_rooms_room_version_column,
|
self._background_add_rooms_room_version_column,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _background_insert_retention(self, progress, batch_size):
|
||||||
def _background_insert_retention(self, progress, batch_size):
|
|
||||||
"""Retrieves a list of all rooms within a range and inserts an entry for each of
|
"""Retrieves a list of all rooms within a range and inserts an entry for each of
|
||||||
them into the room_retention table.
|
them into the room_retention table.
|
||||||
NULLs the property's columns if missing from the retention event in the room's
|
NULLs the property's columns if missing from the retention event in the room's
|
||||||
|
@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
end = yield self.db.runInteraction(
|
end = await self.db.runInteraction(
|
||||||
"insert_room_retention", _background_insert_retention_txn,
|
"insert_room_retention", _background_insert_retention_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
if end:
|
if end:
|
||||||
yield self.db.updates._end_background_update("insert_room_retention")
|
await self.db.updates._end_background_update("insert_room_retention")
|
||||||
|
|
||||||
defer.returnValue(batch_size)
|
return batch_size
|
||||||
|
|
||||||
async def _background_add_rooms_room_version_column(
|
async def _background_add_rooms_room_version_column(
|
||||||
self, progress: dict, batch_size: int
|
self, progress: dict, batch_size: int
|
||||||
|
@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
lock=False,
|
lock=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def store_room(
|
||||||
def store_room(
|
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
room_creator_user_id: str,
|
room_creator_user_id: str,
|
||||||
|
@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._public_room_id_gen.get_next() as next_id:
|
with self._public_room_id_gen.get_next() as next_id:
|
||||||
yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
|
await self.db.runInteraction("store_room_txn", store_room_txn, next_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||||
raise StoreError(500, "Problem creating room.")
|
raise StoreError(500, "Problem creating room.")
|
||||||
|
@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
lock=False,
|
lock=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def set_room_is_public(self, room_id, is_public):
|
||||||
def set_room_is_public(self, room_id, is_public):
|
|
||||||
def set_room_is_public_txn(txn, next_id):
|
def set_room_is_public_txn(txn, next_id):
|
||||||
self.db.simple_update_one_txn(
|
self.db.simple_update_one_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._public_room_id_gen.get_next() as next_id:
|
with self._public_room_id_gen.get_next() as next_id:
|
||||||
yield self.db.runInteraction(
|
await self.db.runInteraction(
|
||||||
"set_room_is_public", set_room_is_public_txn, next_id
|
"set_room_is_public", set_room_is_public_txn, next_id
|
||||||
)
|
)
|
||||||
self.hs.get_notifier().on_new_replication_data()
|
self.hs.get_notifier().on_new_replication_data()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def set_room_is_public_appservice(
|
||||||
def set_room_is_public_appservice(
|
|
||||||
self, room_id, appservice_id, network_id, is_public
|
self, room_id, appservice_id, network_id, is_public
|
||||||
):
|
):
|
||||||
"""Edit the appservice/network specific public room list.
|
"""Edit the appservice/network specific public room list.
|
||||||
|
@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._public_room_id_gen.get_next() as next_id:
|
with self._public_room_id_gen.get_next() as next_id:
|
||||||
yield self.db.runInteraction(
|
await self.db.runInteraction(
|
||||||
"set_room_is_public_appservice",
|
"set_room_is_public_appservice",
|
||||||
set_room_is_public_appservice_txn,
|
set_room_is_public_appservice_txn,
|
||||||
next_id,
|
next_id,
|
||||||
|
@ -1327,49 +1318,44 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
def get_current_public_room_stream_id(self):
|
def get_current_public_room_stream_id(self):
|
||||||
return self._public_room_id_gen.get_current_token()
|
return self._public_room_id_gen.get_current_token()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def block_room(self, room_id: str, user_id: str) -> None:
|
||||||
def block_room(self, room_id, user_id):
|
|
||||||
"""Marks the room as blocked. Can be called multiple times.
|
"""Marks the room as blocked. Can be called multiple times.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str): Room to block
|
room_id: Room to block
|
||||||
user_id (str): Who blocked it
|
user_id: Who blocked it
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred
|
|
||||||
"""
|
"""
|
||||||
yield self.db.simple_upsert(
|
await self.db.simple_upsert(
|
||||||
table="blocked_rooms",
|
table="blocked_rooms",
|
||||||
keyvalues={"room_id": room_id},
|
keyvalues={"room_id": room_id},
|
||||||
values={},
|
values={},
|
||||||
insertion_values={"user_id": user_id},
|
insertion_values={"user_id": user_id},
|
||||||
desc="block_room",
|
desc="block_room",
|
||||||
)
|
)
|
||||||
yield self.db.runInteraction(
|
await self.db.runInteraction(
|
||||||
"block_room_invalidation",
|
"block_room_invalidation",
|
||||||
self._invalidate_cache_and_stream,
|
self._invalidate_cache_and_stream,
|
||||||
self.is_room_blocked,
|
self.is_room_blocked,
|
||||||
(room_id,),
|
(room_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_rooms_for_retention_period_in_range(
|
||||||
def get_rooms_for_retention_period_in_range(
|
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
|
||||||
self, min_ms, max_ms, include_null=False
|
) -> Dict[str, dict]:
|
||||||
):
|
|
||||||
"""Retrieves all of the rooms within the given retention range.
|
"""Retrieves all of the rooms within the given retention range.
|
||||||
|
|
||||||
Optionally includes the rooms which don't have a retention policy.
|
Optionally includes the rooms which don't have a retention policy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
min_ms (int|None): Duration in milliseconds that define the lower limit of
|
min_ms: Duration in milliseconds that define the lower limit of
|
||||||
the range to handle (exclusive). If None, doesn't set a lower limit.
|
the range to handle (exclusive). If None, doesn't set a lower limit.
|
||||||
max_ms (int|None): Duration in milliseconds that define the upper limit of
|
max_ms: Duration in milliseconds that define the upper limit of
|
||||||
the range to handle (inclusive). If None, doesn't set an upper limit.
|
the range to handle (inclusive). If None, doesn't set an upper limit.
|
||||||
include_null (bool): Whether to include rooms which retention policy is NULL
|
include_null: Whether to include rooms which retention policy is NULL
|
||||||
in the returned set.
|
in the returned set.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, dict]: The rooms within this range, along with their retention
|
The rooms within this range, along with their retention
|
||||||
policy. The key is "room_id", and maps to a dict describing the retention
|
policy. The key is "room_id", and maps to a dict describing the retention
|
||||||
policy associated with this room ID. The keys for this nested dict are
|
policy associated with this room ID. The keys for this nested dict are
|
||||||
"min_lifetime" (int|None), and "max_lifetime" (int|None).
|
"min_lifetime" (int|None), and "max_lifetime" (int|None).
|
||||||
|
@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
|
|
||||||
return rooms_dict
|
return rooms_dict
|
||||||
|
|
||||||
rooms = yield self.db.runInteraction(
|
rooms = await self.db.runInteraction(
|
||||||
"get_rooms_for_retention_period_in_range",
|
"get_rooms_for_retention_period_in_range",
|
||||||
get_rooms_for_retention_period_in_range_txn,
|
get_rooms_for_retention_period_in_range_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(rooms)
|
return rooms
|
||||||
|
|
|
@ -16,12 +16,12 @@
|
||||||
import collections.abc
|
import collections.abc
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from typing import Iterable, Optional, Set
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
|
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
|
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
|
||||||
|
@ -108,16 +108,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
create_event = await self.get_create_event_for_room(room_id)
|
create_event = await self.get_create_event_for_room(room_id)
|
||||||
return create_event.content.get("room_version", "1")
|
return create_event.content.get("room_version", "1")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
|
||||||
def get_room_predecessor(self, room_id):
|
|
||||||
"""Get the predecessor of an upgraded room if it exists.
|
"""Get the predecessor of an upgraded room if it exists.
|
||||||
Otherwise return None.
|
Otherwise return None.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id: The room ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict|None]: A dictionary containing the structure of the predecessor
|
A dictionary containing the structure of the predecessor
|
||||||
field from the room's create event. The structure is subject to other servers,
|
field from the room's create event. The structure is subject to other servers,
|
||||||
but it is expected to be:
|
but it is expected to be:
|
||||||
* room_id (str): The room ID of the predecessor room
|
* room_id (str): The room ID of the predecessor room
|
||||||
|
@ -129,7 +128,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
NotFoundError if the given room is unknown
|
NotFoundError if the given room is unknown
|
||||||
"""
|
"""
|
||||||
# Retrieve the room's create event
|
# Retrieve the room's create event
|
||||||
create_event = yield self.get_create_event_for_room(room_id)
|
create_event = await self.get_create_event_for_room(room_id)
|
||||||
|
|
||||||
# Retrieve the predecessor key of the create event
|
# Retrieve the predecessor key of the create event
|
||||||
predecessor = create_event.content.get("predecessor", None)
|
predecessor = create_event.content.get("predecessor", None)
|
||||||
|
@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
return predecessor
|
return predecessor
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_create_event_for_room(self, room_id: str) -> EventBase:
|
||||||
def get_create_event_for_room(self, room_id):
|
|
||||||
"""Get the create state event for a room.
|
"""Get the create state event for a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id: The room ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[EventBase]: The room creation event.
|
The room creation event.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotFoundError if the room is unknown
|
NotFoundError if the room is unknown
|
||||||
"""
|
"""
|
||||||
state_ids = yield self.get_current_state_ids(room_id)
|
state_ids = await self.get_current_state_ids(room_id)
|
||||||
create_id = state_ids.get((EventTypes.Create, ""))
|
create_id = state_ids.get((EventTypes.Create, ""))
|
||||||
|
|
||||||
# If we can't find the create event, assume we've hit a dead end
|
# If we can't find the create event, assume we've hit a dead end
|
||||||
|
@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
raise NotFoundError("Unknown room %s" % (room_id,))
|
raise NotFoundError("Unknown room %s" % (room_id,))
|
||||||
|
|
||||||
# Retrieve the room's create event and return
|
# Retrieve the room's create event and return
|
||||||
create_event = yield self.get_event(create_id)
|
create_event = await self.get_event(create_id)
|
||||||
return create_event
|
return create_event
|
||||||
|
|
||||||
@cached(max_entries=100000, iterable=True)
|
@cached(max_entries=100000, iterable=True)
|
||||||
|
@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
|
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
|
||||||
def get_canonical_alias_for_room(self, room_id):
|
|
||||||
"""Get canonical alias for room, if any
|
"""Get canonical alias for room, if any
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id: The room ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[str|None]: The canonical alias, if any
|
The canonical alias, if any
|
||||||
"""
|
"""
|
||||||
|
|
||||||
state = yield self.get_filtered_current_state_ids(
|
state = await self.get_filtered_current_state_ids(
|
||||||
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
|
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
if not event_id:
|
if not event_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
event = yield self.get_event(event_id, allow_none=True)
|
event = await self.get_event(event_id, allow_none=True)
|
||||||
if not event:
|
if not event:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
return {row["event_id"]: row["state_group"] for row in rows}
|
return {row["event_id"]: row["state_group"] for row in rows}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_referenced_state_groups(
|
||||||
def get_referenced_state_groups(self, state_groups):
|
self, state_groups: Iterable[int]
|
||||||
|
) -> Set[int]:
|
||||||
"""Check if the state groups are referenced by events.
|
"""Check if the state groups are referenced by events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_groups (Iterable[int])
|
state_groups
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[set[int]]: The subset of state groups that are
|
The subset of state groups that are referenced.
|
||||||
referenced.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = yield self.db.simple_select_many_batch(
|
rows = await self.db.simple_select_many_batch(
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
column="state_group",
|
column="state_group",
|
||||||
iterable=state_groups,
|
iterable=state_groups,
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.defer import DeferredLock
|
from twisted.internet.defer import DeferredLock
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
|
@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore):
|
||||||
"""
|
"""
|
||||||
return (ts // self.stats_bucket_size) * self.stats_bucket_size
|
return (ts // self.stats_bucket_size) * self.stats_bucket_size
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _populate_stats_process_users(self, progress, batch_size):
|
||||||
def _populate_stats_process_users(self, progress, batch_size):
|
|
||||||
"""
|
"""
|
||||||
This is a background update which regenerates statistics for users.
|
This is a background update which regenerates statistics for users.
|
||||||
"""
|
"""
|
||||||
if not self.stats_enabled:
|
if not self.stats_enabled:
|
||||||
yield self.db.updates._end_background_update("populate_stats_process_users")
|
await self.db.updates._end_background_update("populate_stats_process_users")
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
last_user_id = progress.get("last_user_id", "")
|
last_user_id = progress.get("last_user_id", "")
|
||||||
|
@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore):
|
||||||
txn.execute(sql, (last_user_id, batch_size))
|
txn.execute(sql, (last_user_id, batch_size))
|
||||||
return [r for r, in txn]
|
return [r for r, in txn]
|
||||||
|
|
||||||
users_to_work_on = yield self.db.runInteraction(
|
users_to_work_on = await self.db.runInteraction(
|
||||||
"_populate_stats_process_users", _get_next_batch
|
"_populate_stats_process_users", _get_next_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
# No more rooms -- complete the transaction.
|
# No more rooms -- complete the transaction.
|
||||||
if not users_to_work_on:
|
if not users_to_work_on:
|
||||||
yield self.db.updates._end_background_update("populate_stats_process_users")
|
await self.db.updates._end_background_update("populate_stats_process_users")
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
for user_id in users_to_work_on:
|
for user_id in users_to_work_on:
|
||||||
yield self._calculate_and_set_initial_state_for_user(user_id)
|
await self._calculate_and_set_initial_state_for_user(user_id)
|
||||||
progress["last_user_id"] = user_id
|
progress["last_user_id"] = user_id
|
||||||
|
|
||||||
yield self.db.runInteraction(
|
await self.db.runInteraction(
|
||||||
"populate_stats_process_users",
|
"populate_stats_process_users",
|
||||||
self.db.updates._background_update_progress_txn,
|
self.db.updates._background_update_progress_txn,
|
||||||
"populate_stats_process_users",
|
"populate_stats_process_users",
|
||||||
|
@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore):
|
||||||
|
|
||||||
return len(users_to_work_on)
|
return len(users_to_work_on)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _populate_stats_process_rooms(self, progress, batch_size):
|
||||||
def _populate_stats_process_rooms(self, progress, batch_size):
|
|
||||||
"""
|
"""
|
||||||
This is a background update which regenerates statistics for rooms.
|
This is a background update which regenerates statistics for rooms.
|
||||||
"""
|
"""
|
||||||
if not self.stats_enabled:
|
if not self.stats_enabled:
|
||||||
yield self.db.updates._end_background_update("populate_stats_process_rooms")
|
await self.db.updates._end_background_update("populate_stats_process_rooms")
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
last_room_id = progress.get("last_room_id", "")
|
last_room_id = progress.get("last_room_id", "")
|
||||||
|
@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore):
|
||||||
txn.execute(sql, (last_room_id, batch_size))
|
txn.execute(sql, (last_room_id, batch_size))
|
||||||
return [r for r, in txn]
|
return [r for r, in txn]
|
||||||
|
|
||||||
rooms_to_work_on = yield self.db.runInteraction(
|
rooms_to_work_on = await self.db.runInteraction(
|
||||||
"populate_stats_rooms_get_batch", _get_next_batch
|
"populate_stats_rooms_get_batch", _get_next_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
# No more rooms -- complete the transaction.
|
# No more rooms -- complete the transaction.
|
||||||
if not rooms_to_work_on:
|
if not rooms_to_work_on:
|
||||||
yield self.db.updates._end_background_update("populate_stats_process_rooms")
|
await self.db.updates._end_background_update("populate_stats_process_rooms")
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
for room_id in rooms_to_work_on:
|
for room_id in rooms_to_work_on:
|
||||||
yield self._calculate_and_set_initial_state_for_room(room_id)
|
await self._calculate_and_set_initial_state_for_room(room_id)
|
||||||
progress["last_room_id"] = room_id
|
progress["last_room_id"] = room_id
|
||||||
|
|
||||||
yield self.db.runInteraction(
|
await self.db.runInteraction(
|
||||||
"_populate_stats_process_rooms",
|
"_populate_stats_process_rooms",
|
||||||
self.db.updates._background_update_progress_txn,
|
self.db.updates._background_update_progress_txn,
|
||||||
"populate_stats_process_rooms",
|
"populate_stats_process_rooms",
|
||||||
|
@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore):
|
||||||
|
|
||||||
return room_deltas, user_deltas
|
return room_deltas, user_deltas
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _calculate_and_set_initial_state_for_room(
|
||||||
def _calculate_and_set_initial_state_for_room(self, room_id):
|
self, room_id: str
|
||||||
|
) -> Tuple[dict, dict, int]:
|
||||||
"""Calculate and insert an entry into room_stats_current.
|
"""Calculate and insert an entry into room_stats_current.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id: The room ID under calculation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[tuple[dict, dict, int]]: A tuple of room state, membership
|
A tuple of room state, membership counts and stream position.
|
||||||
counts and stream position.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _fetch_current_state_stats(txn):
|
def _fetch_current_state_stats(txn):
|
||||||
|
@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore):
|
||||||
current_state_events_count,
|
current_state_events_count,
|
||||||
users_in_room,
|
users_in_room,
|
||||||
pos,
|
pos,
|
||||||
) = yield self.db.runInteraction(
|
) = await self.db.runInteraction(
|
||||||
"get_initial_state_for_room", _fetch_current_state_stats
|
"get_initial_state_for_room", _fetch_current_state_stats
|
||||||
)
|
)
|
||||||
|
|
||||||
state_event_map = yield self.get_events(event_ids, get_prev_content=False)
|
state_event_map = await self.get_events(event_ids, get_prev_content=False)
|
||||||
|
|
||||||
room_state = {
|
room_state = {
|
||||||
"join_rules": None,
|
"join_rules": None,
|
||||||
|
@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore):
|
||||||
event.content.get("m.federate", True) is True
|
event.content.get("m.federate", True) is True
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.update_room_state(room_id, room_state)
|
await self.update_room_state(room_id, room_state)
|
||||||
|
|
||||||
local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
|
local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
|
||||||
|
|
||||||
yield self.update_stats_delta(
|
await self.update_stats_delta(
|
||||||
ts=self.clock.time_msec(),
|
ts=self.clock.time_msec(),
|
||||||
stats_type="room",
|
stats_type="room",
|
||||||
stats_id=room_id,
|
stats_id=room_id,
|
||||||
|
@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _calculate_and_set_initial_state_for_user(self, user_id):
|
||||||
def _calculate_and_set_initial_state_for_user(self, user_id):
|
|
||||||
def _calculate_and_set_initial_state_for_user_txn(txn):
|
def _calculate_and_set_initial_state_for_user_txn(txn):
|
||||||
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
|
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
|
||||||
|
|
||||||
|
@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore):
|
||||||
(count,) = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return count, pos
|
return count, pos
|
||||||
|
|
||||||
joined_rooms, pos = yield self.db.runInteraction(
|
joined_rooms, pos = await self.db.runInteraction(
|
||||||
"calculate_and_set_initial_state_for_user",
|
"calculate_and_set_initial_state_for_user",
|
||||||
_calculate_and_set_initial_state_for_user_txn,
|
_calculate_and_set_initial_state_for_user_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.update_stats_delta(
|
await self.update_stats_delta(
|
||||||
ts=self.clock.time_msec(),
|
ts=self.clock.time_msec(),
|
||||||
stats_type="user",
|
stats_type="user",
|
||||||
stats_id=user_id,
|
stats_id=user_id,
|
||||||
|
|
|
@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
"get_state_group_delta", _get_state_group_delta_txn
|
"get_state_group_delta", _get_state_group_delta_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _get_state_groups_from_groups(
|
||||||
def _get_state_groups_from_groups(
|
|
||||||
self, groups: List[int], state_filter: StateFilter
|
self, groups: List[int], state_filter: StateFilter
|
||||||
):
|
) -> Dict[int, StateMap[str]]:
|
||||||
"""Returns the state groups for a given set of groups from the
|
"""Returns the state groups for a given set of groups from the
|
||||||
database, filtering on types of state events.
|
database, filtering on types of state events.
|
||||||
|
|
||||||
|
@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
state_filter: The state filter used to fetch state
|
state_filter: The state filter used to fetch state
|
||||||
from the database.
|
from the database.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
|
Dict of state group to state map.
|
||||||
"""
|
"""
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
|
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
res = yield self.db.runInteraction(
|
res = await self.db.runInteraction(
|
||||||
"_get_state_groups_from_groups",
|
"_get_state_groups_from_groups",
|
||||||
self._get_state_groups_from_groups_txn,
|
self._get_state_groups_from_groups_txn,
|
||||||
chunk,
|
chunk,
|
||||||
|
@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
|
|
||||||
return state_filter.filter_state(state_dict_ids), not missing_types
|
return state_filter.filter_state(state_dict_ids), not missing_types
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _get_state_for_groups(
|
||||||
def _get_state_for_groups(
|
|
||||||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
||||||
):
|
) -> Dict[int, StateMap[str]]:
|
||||||
"""Gets the state at each of a list of state groups, optionally
|
"""Gets the state at each of a list of state groups, optionally
|
||||||
filtering by type/state_key
|
filtering by type/state_key
|
||||||
|
|
||||||
|
@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
state_filter: The state filter used to fetch state
|
state_filter: The state filter used to fetch state
|
||||||
from the database.
|
from the database.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
|
Dict of state group to state map.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
member_filter, non_member_filter = state_filter.get_member_split()
|
member_filter, non_member_filter = state_filter.get_member_split()
|
||||||
|
@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
(
|
(
|
||||||
non_member_state,
|
non_member_state,
|
||||||
incomplete_groups_nm,
|
incomplete_groups_nm,
|
||||||
) = yield self._get_state_for_groups_using_cache(
|
) = self._get_state_for_groups_using_cache(
|
||||||
groups, self._state_group_cache, state_filter=non_member_filter
|
groups, self._state_group_cache, state_filter=non_member_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
(
|
(member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
|
||||||
member_state,
|
|
||||||
incomplete_groups_m,
|
|
||||||
) = yield self._get_state_for_groups_using_cache(
|
|
||||||
groups, self._state_group_members_cache, state_filter=member_filter
|
groups, self._state_group_members_cache, state_filter=member_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
# Help the cache hit ratio by expanding the filter a bit
|
# Help the cache hit ratio by expanding the filter a bit
|
||||||
db_state_filter = state_filter.return_expanded()
|
db_state_filter = state_filter.return_expanded()
|
||||||
|
|
||||||
group_to_state_dict = yield self._get_state_groups_from_groups(
|
group_to_state_dict = await self._get_state_groups_from_groups(
|
||||||
list(incomplete_groups), state_filter=db_state_filter
|
list(incomplete_groups), state_filter=db_state_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
((sg,) for sg in state_groups_to_delete),
|
((sg,) for sg in state_groups_to_delete),
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_previous_state_groups(
|
||||||
def get_previous_state_groups(self, state_groups):
|
self, state_groups: Iterable[int]
|
||||||
|
) -> Dict[int, int]:
|
||||||
"""Fetch the previous groups of the given state groups.
|
"""Fetch the previous groups of the given state groups.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_groups (Iterable[int])
|
state_groups
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[int, int]]: mapping from state group to previous
|
A mapping from state group to previous state group.
|
||||||
state group.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = yield self.db.simple_select_many_batch(
|
rows = await self.db.simple_select_many_batch(
|
||||||
table="state_group_edges",
|
table="state_group_edges",
|
||||||
column="prev_state_group",
|
column="prev_state_group",
|
||||||
iterable=state_groups,
|
iterable=state_groups,
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar
|
from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -419,7 +419,7 @@ class StateGroupStorage(object):
|
||||||
|
|
||||||
def _get_state_groups_from_groups(
|
def _get_state_groups_from_groups(
|
||||||
self, groups: List[int], state_filter: StateFilter
|
self, groups: List[int], state_filter: StateFilter
|
||||||
):
|
) -> Awaitable[Dict[int, StateMap[str]]]:
|
||||||
"""Returns the state groups for a given set of groups, filtering on
|
"""Returns the state groups for a given set of groups, filtering on
|
||||||
types of state events.
|
types of state events.
|
||||||
|
|
||||||
|
@ -429,7 +429,7 @@ class StateGroupStorage(object):
|
||||||
from the database.
|
from the database.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
|
Dict of state group to state map.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
||||||
|
@ -532,7 +532,7 @@ class StateGroupStorage(object):
|
||||||
|
|
||||||
def _get_state_for_groups(
|
def _get_state_for_groups(
|
||||||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
||||||
):
|
) -> Awaitable[Dict[int, StateMap[str]]]:
|
||||||
"""Gets the state at each of a list of state groups, optionally
|
"""Gets the state at each of a list of state groups, optionally
|
||||||
filtering by type/state_key
|
filtering by type/state_key
|
||||||
|
|
||||||
|
@ -540,8 +540,9 @@ class StateGroupStorage(object):
|
||||||
groups: list of state groups for which we want to get the state.
|
groups: list of state groups for which we want to get the state.
|
||||||
state_filter: The state filter used to fetch state.
|
state_filter: The state filter used to fetch state.
|
||||||
from the database.
|
from the database.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
|
Dict of state group to state map.
|
||||||
"""
|
"""
|
||||||
return self.stores.state._get_state_for_groups(groups, state_filter)
|
return self.stores.state._get_state_for_groups(groups, state_filter)
|
||||||
|
|
||||||
|
|
|
@ -39,15 +39,19 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_unread_push_actions_for_user_in_range_for_http(self):
|
def test_get_unread_push_actions_for_user_in_range_for_http(self):
|
||||||
yield self.store.get_unread_push_actions_for_user_in_range_for_http(
|
yield defer.ensureDeferred(
|
||||||
|
self.store.get_unread_push_actions_for_user_in_range_for_http(
|
||||||
USER_ID, 0, 1000, 20
|
USER_ID, 0, 1000, 20
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_unread_push_actions_for_user_in_range_for_email(self):
|
def test_get_unread_push_actions_for_user_in_range_for_email(self):
|
||||||
yield self.store.get_unread_push_actions_for_user_in_range_for_email(
|
yield defer.ensureDeferred(
|
||||||
|
self.store.get_unread_push_actions_for_user_in_range_for_email(
|
||||||
USER_ID, 0, 1000, 20
|
USER_ID, 0, 1000, 20
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_count_aggregation(self):
|
def test_count_aggregation(self):
|
||||||
|
|
|
@ -37,12 +37,14 @@ class RoomStoreTestCase(unittest.TestCase):
|
||||||
self.alias = RoomAlias.from_string("#a-room-name:test")
|
self.alias = RoomAlias.from_string("#a-room-name:test")
|
||||||
self.u_creator = UserID.from_string("@creator:test")
|
self.u_creator = UserID.from_string("@creator:test")
|
||||||
|
|
||||||
yield self.store.store_room(
|
yield defer.ensureDeferred(
|
||||||
|
self.store.store_room(
|
||||||
self.room.to_string(),
|
self.room.to_string(),
|
||||||
room_creator_user_id=self.u_creator.to_string(),
|
room_creator_user_id=self.u_creator.to_string(),
|
||||||
is_public=True,
|
is_public=True,
|
||||||
room_version=RoomVersions.V1,
|
room_version=RoomVersions.V1,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_room(self):
|
def test_get_room(self):
|
||||||
|
@ -88,12 +90,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.room = RoomID.from_string("!abcde:test")
|
self.room = RoomID.from_string("!abcde:test")
|
||||||
|
|
||||||
yield self.store.store_room(
|
yield defer.ensureDeferred(
|
||||||
|
self.store.store_room(
|
||||||
self.room.to_string(),
|
self.room.to_string(),
|
||||||
room_creator_user_id="@creator:text",
|
room_creator_user_id="@creator:text",
|
||||||
is_public=True,
|
is_public=True,
|
||||||
room_version=RoomVersions.V1,
|
room_version=RoomVersions.V1,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def inject_room_event(self, **kwargs):
|
def inject_room_event(self, **kwargs):
|
||||||
|
|
|
@ -44,12 +44,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
|
|
||||||
self.room = RoomID.from_string("!abc123:test")
|
self.room = RoomID.from_string("!abc123:test")
|
||||||
|
|
||||||
yield self.store.store_room(
|
yield defer.ensureDeferred(
|
||||||
|
self.store.store_room(
|
||||||
self.room.to_string(),
|
self.room.to_string(),
|
||||||
room_creator_user_id="@creator:text",
|
room_creator_user_id="@creator:text",
|
||||||
is_public=True,
|
is_public=True,
|
||||||
room_version=RoomVersions.V1,
|
room_version=RoomVersions.V1,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def inject_state_event(self, room, sender, typ, state_key, content):
|
def inject_state_event(self, room, sender, typ, state_key, content):
|
||||||
|
|
Loading…
Reference in a new issue