mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-19 03:01:57 +01:00
Convert additional databases to async/await part 3 (#8201)
This commit is contained in:
parent
7d103a594e
commit
37db6252b7
7 changed files with 121 additions and 87 deletions
1
changelog.d/8201.misc
Normal file
1
changelog.d/8201.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -433,7 +433,7 @@ class BackgroundUpdater(object):
|
||||||
"background_updates", keyvalues={"update_name": update_name}
|
"background_updates", keyvalues={"update_name": update_name}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _background_update_progress(self, update_name: str, progress: dict):
|
async def _background_update_progress(self, update_name: str, progress: dict):
|
||||||
"""Update the progress of a background update
|
"""Update the progress of a background update
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -441,7 +441,7 @@ class BackgroundUpdater(object):
|
||||||
progress: The progress of the update.
|
progress: The progress of the update.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"background_update_progress",
|
"background_update_progress",
|
||||||
self._background_update_progress_txn,
|
self._background_update_progress_txn,
|
||||||
update_name,
|
update_name,
|
||||||
|
|
|
@ -16,9 +16,7 @@
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
|
@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_account_data_for_user(self, user_id):
|
async def get_account_data_for_user(
|
||||||
|
self, user_id: str
|
||||||
|
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
|
||||||
"""Get all the client account_data for a user.
|
"""Get all the client account_data for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to get the account_data for.
|
user_id: The user to get the account_data for.
|
||||||
Returns:
|
Returns:
|
||||||
A deferred pair of a dict of global account_data and a dict
|
A 2-tuple of a dict of global account_data and a dict mapping from
|
||||||
mapping from room_id string to per room account_data dicts.
|
room_id string to per room account_data dicts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_account_data_for_user_txn(txn):
|
def get_account_data_for_user_txn(txn):
|
||||||
|
@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return global_account_data, by_room
|
return global_account_data, by_room
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_account_data_for_user", get_account_data_for_user_txn
|
"get_account_data_for_user", get_account_data_for_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@cached(num_args=2)
|
@cached(num_args=2)
|
||||||
def get_account_data_for_room(self, user_id, room_id):
|
async def get_account_data_for_room(
|
||||||
|
self, user_id: str, room_id: str
|
||||||
|
) -> Dict[str, JsonDict]:
|
||||||
"""Get all the client account_data for a user for a room.
|
"""Get all the client account_data for a user for a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to get the account_data for.
|
user_id: The user to get the account_data for.
|
||||||
room_id(str): The room to get the account_data for.
|
room_id: The room to get the account_data for.
|
||||||
Returns:
|
Returns:
|
||||||
A deferred dict of the room account_data
|
A dict of the room account_data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_account_data_for_room_txn(txn):
|
def get_account_data_for_room_txn(txn):
|
||||||
|
@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||||
row["account_data_type"]: db_to_json(row["content"]) for row in rows
|
row["account_data_type"]: db_to_json(row["content"]) for row in rows
|
||||||
}
|
}
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_account_data_for_room", get_account_data_for_room_txn
|
"get_account_data_for_room", get_account_data_for_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(num_args=3, max_entries=5000)
|
@cached(num_args=3, max_entries=5000)
|
||||||
def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
|
async def get_account_data_for_room_and_type(
|
||||||
|
self, user_id: str, room_id: str, account_data_type: str
|
||||||
|
) -> Optional[JsonDict]:
|
||||||
"""Get the client account_data of given type for a user for a room.
|
"""Get the client account_data of given type for a user for a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to get the account_data for.
|
user_id: The user to get the account_data for.
|
||||||
room_id(str): The room to get the account_data for.
|
room_id: The room to get the account_data for.
|
||||||
account_data_type (str): The account data type to get.
|
account_data_type: The account data type to get.
|
||||||
Returns:
|
Returns:
|
||||||
A deferred of the room account_data for that type, or None if
|
The room account_data for that type, or None if there isn't any set.
|
||||||
there isn't any set.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_account_data_for_room_and_type_txn(txn):
|
def get_account_data_for_room_and_type_txn(txn):
|
||||||
|
@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return db_to_json(content_json) if content_json else None
|
return db_to_json(content_json) if content_json else None
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
|
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||||
"get_updated_room_account_data", get_updated_room_account_data_txn
|
"get_updated_room_account_data", get_updated_room_account_data_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_updated_account_data_for_user(self, user_id, stream_id):
|
async def get_updated_account_data_for_user(
|
||||||
|
self, user_id: str, stream_id: int
|
||||||
|
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
|
||||||
"""Get all the client account_data for a that's changed for a user
|
"""Get all the client account_data for a that's changed for a user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to get the account_data for.
|
user_id: The user to get the account_data for.
|
||||||
stream_id(int): The point in the stream since which to get updates
|
stream_id: The point in the stream since which to get updates
|
||||||
Returns:
|
Returns:
|
||||||
A deferred pair of a dict of global account_data and a dict
|
A deferred pair of a dict of global account_data and a dict
|
||||||
mapping from room_id string to per room account_data dicts.
|
mapping from room_id string to per room account_data dicts.
|
||||||
|
@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||||
user_id, int(stream_id)
|
user_id, int(stream_id)
|
||||||
)
|
)
|
||||||
if not changed:
|
if not changed:
|
||||||
return defer.succeed(({}, {}))
|
return ({}, {})
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
|
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||||
|
|
||||||
return self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
|
|
||||||
def _update_max_stream_id(self, next_id: int):
|
async def _update_max_stream_id(self, next_id: int) -> None:
|
||||||
"""Update the max stream_id
|
"""Update the max stream_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||||
)
|
)
|
||||||
txn.execute(update_max_id_sql, (next_id, next_id))
|
txn.execute(update_max_id_sql, (next_id, next_id))
|
||||||
|
|
||||||
return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
|
await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
|
||||||
|
|
|
@ -34,13 +34,15 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
def get_e2e_device_keys_for_federation_query(self, user_id: str):
|
async def get_e2e_device_keys_for_federation_query(
|
||||||
|
self, user_id: str
|
||||||
|
) -> Tuple[int, List[JsonDict]]:
|
||||||
"""Get all devices (with any device keys) for a user
|
"""Get all devices (with any device keys) for a user
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred which resolves to (stream_id, devices)
|
(stream_id, devices)
|
||||||
"""
|
"""
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_e2e_device_keys_for_federation_query",
|
"get_e2e_device_keys_for_federation_query",
|
||||||
self._get_e2e_device_keys_for_federation_query_txn,
|
self._get_e2e_device_keys_for_federation_query_txn,
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -292,10 +294,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(max_entries=10000)
|
@cached(max_entries=10000)
|
||||||
def count_e2e_one_time_keys(self, user_id, device_id):
|
async def count_e2e_one_time_keys(
|
||||||
|
self, user_id: str, device_id: str
|
||||||
|
) -> Dict[str, int]:
|
||||||
""" Count the number of one time keys the server has for a device
|
""" Count the number of one time keys the server has for a device
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping from algorithm to number of keys for that algorithm.
|
A mapping from algorithm to number of keys for that algorithm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _count_e2e_one_time_keys(txn):
|
def _count_e2e_one_time_keys(txn):
|
||||||
|
@ -310,7 +314,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
result[algorithm] = key_count
|
result[algorithm] = key_count
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -348,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
list_name="user_ids",
|
list_name="user_ids",
|
||||||
num_args=1,
|
num_args=1,
|
||||||
)
|
)
|
||||||
def _get_bare_e2e_cross_signing_keys_bulk(
|
async def _get_bare_e2e_cross_signing_keys_bulk(
|
||||||
self, user_ids: List[str]
|
self, user_ids: List[str]
|
||||||
) -> Dict[str, Dict[str, dict]]:
|
) -> Dict[str, Dict[str, dict]]:
|
||||||
"""Returns the cross-signing keys for a set of users. The output of this
|
"""Returns the cross-signing keys for a set of users. The output of this
|
||||||
|
@ -356,16 +360,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
the signatures for the calling user need to be fetched.
|
the signatures for the calling user need to be fetched.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_ids (list[str]): the users whose keys are being requested
|
user_ids: the users whose keys are being requested
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, dict[str, dict]]: mapping from user ID to key type to key
|
A mapping from user ID to key type to key data. If a user's cross-signing
|
||||||
data. If a user's cross-signing keys were not found, either
|
keys were not found, either their user ID will not be in the dict, or
|
||||||
their user ID will not be in the dict, or their user ID will map
|
their user ID will map to None.
|
||||||
to None.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_bare_e2e_cross_signing_keys_bulk",
|
"get_bare_e2e_cross_signing_keys_bulk",
|
||||||
self._get_bare_e2e_cross_signing_keys_bulk_txn,
|
self._get_bare_e2e_cross_signing_keys_bulk_txn,
|
||||||
user_ids,
|
user_ids,
|
||||||
|
@ -588,7 +591,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
async def set_e2e_device_keys(
|
||||||
|
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
|
||||||
|
) -> bool:
|
||||||
"""Stores device keys for a device. Returns whether there was a change
|
"""Stores device keys for a device. Returns whether there was a change
|
||||||
or the keys were already in the database.
|
or the keys were already in the database.
|
||||||
"""
|
"""
|
||||||
|
@ -624,12 +629,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
log_kv({"message": "Device keys stored."})
|
log_kv({"message": "Device keys stored."})
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"set_e2e_device_keys", _set_e2e_device_keys_txn
|
"set_e2e_device_keys", _set_e2e_device_keys_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def claim_e2e_one_time_keys(self, query_list):
|
async def claim_e2e_one_time_keys(
|
||||||
"""Take a list of one time keys out of the database"""
|
self, query_list: Iterable[Tuple[str, str, str]]
|
||||||
|
) -> Dict[str, Dict[str, Dict[str, bytes]]]:
|
||||||
|
"""Take a list of one time keys out of the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||||
|
"""
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
def _claim_e2e_one_time_keys(txn):
|
def _claim_e2e_one_time_keys(txn):
|
||||||
|
@ -665,11 +679,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_e2e_keys_by_device(self, user_id, device_id):
|
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
|
||||||
def delete_e2e_keys_by_device_txn(txn):
|
def delete_e2e_keys_by_device_txn(txn):
|
||||||
log_kv(
|
log_kv(
|
||||||
{
|
{
|
||||||
|
@ -692,7 +706,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
|
@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
desc="mark_local_media_as_safe",
|
desc="mark_local_media_as_safe",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_url_cache(self, url, ts):
|
async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
|
||||||
"""Get the media_id and ts for a cached URL as of the given timestamp
|
"""Get the media_id and ts for a cached URL as of the given timestamp
|
||||||
Returns:
|
Returns:
|
||||||
None if the URL isn't cached.
|
None if the URL isn't cached.
|
||||||
|
@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
|
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
|
||||||
|
|
||||||
async def store_url_cache(
|
async def store_url_cache(
|
||||||
self, url, response_code, etag, expires_ts, og, media_id, download_ts
|
self, url, response_code, etag, expires_ts, og, media_id, download_ts
|
||||||
|
@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
desc="store_cached_remote_media",
|
desc="store_cached_remote_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_cached_last_access_time(self, local_media, remote_media, time_ms):
|
async def update_cached_last_access_time(
|
||||||
|
self,
|
||||||
|
local_media: Iterable[str],
|
||||||
|
remote_media: Iterable[Tuple[str, str]],
|
||||||
|
time_ms: int,
|
||||||
|
):
|
||||||
"""Updates the last access time of the given media
|
"""Updates the last access time of the given media
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
local_media (iterable[str]): Set of media_ids
|
local_media: Set of media_ids
|
||||||
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
|
remote_media: Set of (server_name, media_id)
|
||||||
time_ms: Current time in milliseconds
|
time_ms: Current time in milliseconds
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
|
|
||||||
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
|
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"update_cached_last_access_time", update_cache_txn
|
"update_cached_last_access_time", update_cache_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
|
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_remote_media(self, media_origin, media_id):
|
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
|
||||||
def delete_remote_media_txn(txn):
|
def delete_remote_media_txn(txn):
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
keyvalues={"media_origin": media_origin, "media_id": media_id},
|
keyvalues={"media_origin": media_origin, "media_id": media_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"delete_remote_media", delete_remote_media_txn
|
"delete_remote_media", delete_remote_media_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_expired_url_cache(self, now_ts):
|
async def get_expired_url_cache(self, now_ts: int) -> List[str]:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT media_id FROM local_media_repository_url_cache"
|
"SELECT media_id FROM local_media_repository_url_cache"
|
||||||
" WHERE expires_ts < ?"
|
" WHERE expires_ts < ?"
|
||||||
|
@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
txn.execute(sql, (now_ts,))
|
txn.execute(sql, (now_ts,))
|
||||||
return [row[0] for row in txn]
|
return [row[0] for row in txn]
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_expired_url_cache", _get_expired_url_cache_txn
|
"get_expired_url_cache", _get_expired_url_cache_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
"delete_url_cache", _delete_url_cache_txn
|
"delete_url_cache", _delete_url_cache_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_url_cache_media_before(self, before_ts):
|
async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT media_id FROM local_media_repository"
|
"SELECT media_id FROM local_media_repository"
|
||||||
" WHERE created_ts < ? AND url_cache IS NOT NULL"
|
" WHERE created_ts < ? AND url_cache IS NOT NULL"
|
||||||
|
@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
txn.execute(sql, (before_ts,))
|
txn.execute(sql, (before_ts,))
|
||||||
return [row[0] for row in txn]
|
return [row[0] for row in txn]
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_url_cache_media_before", _get_url_cache_media_before_txn
|
"get_url_cache_media_before", _get_url_cache_media_before_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -16,9 +16,10 @@
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Set
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
|
@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||||
"count": count,
|
"count": count,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _find_highlights_in_postgres(self, search_query, events):
|
async def _find_highlights_in_postgres(
|
||||||
|
self, search_query: str, events: List[EventBase]
|
||||||
|
) -> Set[str]:
|
||||||
"""Given a list of events and a search term, return a list of words
|
"""Given a list of events and a search term, return a list of words
|
||||||
that match from the content of the event.
|
that match from the content of the event.
|
||||||
|
|
||||||
|
@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||||
highlight the matching parts.
|
highlight the matching parts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
search_query (str)
|
search_query
|
||||||
events (list): A list of events
|
events: A list of events
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
deferred : A set of strings.
|
A set of strings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||||
|
|
||||||
return highlight_words
|
return highlight_words
|
||||||
|
|
||||||
return self.db_pool.runInteraction("_find_highlights", f)
|
return await self.db_pool.runInteraction("_find_highlights", f)
|
||||||
|
|
||||||
|
|
||||||
def _to_postgres_options(options_dict):
|
def _to_postgres_options(options_dict):
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, JoinRules
|
from synapse.api.constants import EventTypes, JoinRules
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
|
@ -365,7 +365,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
|
async def update_profile_in_user_dir(
|
||||||
|
self, user_id: str, display_name: str, avatar_url: str
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Update or add a user's profile in the user directory.
|
Update or add a user's profile in the user directory.
|
||||||
"""
|
"""
|
||||||
|
@ -458,17 +460,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
|
|
||||||
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
|
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_users_who_share_private_room(self, room_id, user_id_tuples):
|
async def add_users_who_share_private_room(
|
||||||
|
self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
|
||||||
|
) -> None:
|
||||||
"""Insert entries into the users_who_share_private_rooms table. The first
|
"""Insert entries into the users_who_share_private_rooms table. The first
|
||||||
user should be a local user.
|
user should be a local user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id
|
||||||
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
|
user_id_tuples: iterable of 2-tuple of user IDs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _add_users_who_share_room_txn(txn):
|
def _add_users_who_share_room_txn(txn):
|
||||||
|
@ -484,17 +488,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
value_values=None,
|
value_values=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_users_who_share_room", _add_users_who_share_room_txn
|
"add_users_who_share_room", _add_users_who_share_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_users_in_public_rooms(self, room_id, user_ids):
|
async def add_users_in_public_rooms(
|
||||||
|
self, room_id: str, user_ids: Iterable[str]
|
||||||
|
) -> None:
|
||||||
"""Insert entries into the users_who_share_private_rooms table. The first
|
"""Insert entries into the users_who_share_private_rooms table. The first
|
||||||
user should be a local user.
|
user should be a local user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id
|
||||||
user_ids (list[str])
|
user_ids
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _add_users_in_public_rooms_txn(txn):
|
def _add_users_in_public_rooms_txn(txn):
|
||||||
|
@ -508,11 +514,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
value_values=None,
|
value_values=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
|
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_all_from_user_dir(self):
|
async def delete_all_from_user_dir(self) -> None:
|
||||||
"""Delete the entire user directory
|
"""Delete the entire user directory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -523,7 +529,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
txn.execute("DELETE FROM users_who_share_private_rooms")
|
txn.execute("DELETE FROM users_who_share_private_rooms")
|
||||||
txn.call_after(self.get_user_in_directory.invalidate_all)
|
txn.call_after(self.get_user_in_directory.invalidate_all)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -555,7 +561,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||||
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
|
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
|
||||||
|
|
||||||
def remove_from_user_dir(self, user_id):
|
async def remove_from_user_dir(self, user_id: str) -> None:
|
||||||
def _remove_from_user_dir_txn(txn):
|
def _remove_from_user_dir_txn(txn):
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn, table="user_directory", keyvalues={"user_id": user_id}
|
txn, table="user_directory", keyvalues={"user_id": user_id}
|
||||||
|
@ -578,7 +584,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"remove_from_user_dir", _remove_from_user_dir_txn
|
"remove_from_user_dir", _remove_from_user_dir_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -605,14 +611,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
|
|
||||||
return user_ids
|
return user_ids
|
||||||
|
|
||||||
def remove_user_who_share_room(self, user_id, room_id):
|
async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Deletes entries in the users_who_share_*_rooms table. The first
|
Deletes entries in the users_who_share_*_rooms table. The first
|
||||||
user should be a local user.
|
user should be a local user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str)
|
user_id
|
||||||
room_id (str)
|
room_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _remove_user_who_share_room_txn(txn):
|
def _remove_user_who_share_room_txn(txn):
|
||||||
|
@ -632,7 +638,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
keyvalues={"user_id": user_id, "room_id": room_id},
|
keyvalues={"user_id": user_id, "room_id": room_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"remove_user_who_share_room", _remove_user_who_share_room_txn
|
"remove_user_who_share_room", _remove_user_who_share_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue