forked from MirrorHub/synapse
Convert additional databases to async/await part 3 (#8201)
This commit is contained in:
parent
7d103a594e
commit
37db6252b7
7 changed files with 121 additions and 87 deletions
1
changelog.d/8201.misc
Normal file
1
changelog.d/8201.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -433,7 +433,7 @@ class BackgroundUpdater(object):
|
|||
"background_updates", keyvalues={"update_name": update_name}
|
||||
)
|
||||
|
||||
def _background_update_progress(self, update_name: str, progress: dict):
|
||||
async def _background_update_progress(self, update_name: str, progress: dict):
|
||||
"""Update the progress of a background update
|
||||
|
||||
Args:
|
||||
|
@ -441,7 +441,7 @@ class BackgroundUpdater(object):
|
|||
progress: The progress of the update.
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"background_update_progress",
|
||||
self._background_update_progress_txn,
|
||||
update_name,
|
||||
|
|
|
@ -16,9 +16,7 @@
|
|||
|
||||
import abc
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool
|
||||
|
@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
raise NotImplementedError()
|
||||
|
||||
@cached()
|
||||
def get_account_data_for_user(self, user_id):
|
||||
async def get_account_data_for_user(
|
||||
self, user_id: str
|
||||
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
|
||||
"""Get all the client account_data for a user.
|
||||
|
||||
Args:
|
||||
user_id(str): The user to get the account_data for.
|
||||
user_id: The user to get the account_data for.
|
||||
Returns:
|
||||
A deferred pair of a dict of global account_data and a dict
|
||||
mapping from room_id string to per room account_data dicts.
|
||||
A 2-tuple of a dict of global account_data and a dict mapping from
|
||||
room_id string to per room account_data dicts.
|
||||
"""
|
||||
|
||||
def get_account_data_for_user_txn(txn):
|
||||
|
@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
|
||||
return global_account_data, by_room
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_account_data_for_user", get_account_data_for_user_txn
|
||||
)
|
||||
|
||||
|
@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
return None
|
||||
|
||||
@cached(num_args=2)
|
||||
def get_account_data_for_room(self, user_id, room_id):
|
||||
async def get_account_data_for_room(
|
||||
self, user_id: str, room_id: str
|
||||
) -> Dict[str, JsonDict]:
|
||||
"""Get all the client account_data for a user for a room.
|
||||
|
||||
Args:
|
||||
user_id(str): The user to get the account_data for.
|
||||
room_id(str): The room to get the account_data for.
|
||||
user_id: The user to get the account_data for.
|
||||
room_id: The room to get the account_data for.
|
||||
Returns:
|
||||
A deferred dict of the room account_data
|
||||
A dict of the room account_data
|
||||
"""
|
||||
|
||||
def get_account_data_for_room_txn(txn):
|
||||
|
@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
row["account_data_type"]: db_to_json(row["content"]) for row in rows
|
||||
}
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_account_data_for_room", get_account_data_for_room_txn
|
||||
)
|
||||
|
||||
@cached(num_args=3, max_entries=5000)
|
||||
def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
|
||||
async def get_account_data_for_room_and_type(
|
||||
self, user_id: str, room_id: str, account_data_type: str
|
||||
) -> Optional[JsonDict]:
|
||||
"""Get the client account_data of given type for a user for a room.
|
||||
|
||||
Args:
|
||||
user_id(str): The user to get the account_data for.
|
||||
room_id(str): The room to get the account_data for.
|
||||
account_data_type (str): The account data type to get.
|
||||
user_id: The user to get the account_data for.
|
||||
room_id: The room to get the account_data for.
|
||||
account_data_type: The account data type to get.
|
||||
Returns:
|
||||
A deferred of the room account_data for that type, or None if
|
||||
there isn't any set.
|
||||
The room account_data for that type, or None if there isn't any set.
|
||||
"""
|
||||
|
||||
def get_account_data_for_room_and_type_txn(txn):
|
||||
|
@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
|
||||
return db_to_json(content_json) if content_json else None
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
|
||||
)
|
||||
|
||||
|
@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
"get_updated_room_account_data", get_updated_room_account_data_txn
|
||||
)
|
||||
|
||||
def get_updated_account_data_for_user(self, user_id, stream_id):
|
||||
async def get_updated_account_data_for_user(
|
||||
self, user_id: str, stream_id: int
|
||||
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
|
||||
"""Get all the client account_data for a that's changed for a user
|
||||
|
||||
Args:
|
||||
user_id(str): The user to get the account_data for.
|
||||
stream_id(int): The point in the stream since which to get updates
|
||||
user_id: The user to get the account_data for.
|
||||
stream_id: The point in the stream since which to get updates
|
||||
Returns:
|
||||
A deferred pair of a dict of global account_data and a dict
|
||||
mapping from room_id string to per room account_data dicts.
|
||||
|
@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
user_id, int(stream_id)
|
||||
)
|
||||
if not changed:
|
||||
return defer.succeed(({}, {}))
|
||||
return ({}, {})
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
|
||||
)
|
||||
|
||||
|
@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||
|
||||
return self._account_data_id_gen.get_current_token()
|
||||
|
||||
def _update_max_stream_id(self, next_id: int):
|
||||
async def _update_max_stream_id(self, next_id: int) -> None:
|
||||
"""Update the max stream_id
|
||||
|
||||
Args:
|
||||
|
@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||
)
|
||||
txn.execute(update_max_id_sql, (next_id, next_id))
|
||||
|
||||
return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
|
||||
await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
|
||||
|
|
|
@ -34,13 +34,15 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
def get_e2e_device_keys_for_federation_query(self, user_id: str):
|
||||
async def get_e2e_device_keys_for_federation_query(
|
||||
self, user_id: str
|
||||
) -> Tuple[int, List[JsonDict]]:
|
||||
"""Get all devices (with any device keys) for a user
|
||||
|
||||
Returns:
|
||||
Deferred which resolves to (stream_id, devices)
|
||||
(stream_id, devices)
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_e2e_device_keys_for_federation_query",
|
||||
self._get_e2e_device_keys_for_federation_query_txn,
|
||||
user_id,
|
||||
|
@ -292,10 +294,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def count_e2e_one_time_keys(self, user_id, device_id):
|
||||
async def count_e2e_one_time_keys(
|
||||
self, user_id: str, device_id: str
|
||||
) -> Dict[str, int]:
|
||||
""" Count the number of one time keys the server has for a device
|
||||
Returns:
|
||||
Dict mapping from algorithm to number of keys for that algorithm.
|
||||
A mapping from algorithm to number of keys for that algorithm.
|
||||
"""
|
||||
|
||||
def _count_e2e_one_time_keys(txn):
|
||||
|
@ -310,7 +314,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
result[algorithm] = key_count
|
||||
return result
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||
)
|
||||
|
||||
|
@ -348,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
list_name="user_ids",
|
||||
num_args=1,
|
||||
)
|
||||
def _get_bare_e2e_cross_signing_keys_bulk(
|
||||
async def _get_bare_e2e_cross_signing_keys_bulk(
|
||||
self, user_ids: List[str]
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Returns the cross-signing keys for a set of users. The output of this
|
||||
|
@ -356,16 +360,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
the signatures for the calling user need to be fetched.
|
||||
|
||||
Args:
|
||||
user_ids (list[str]): the users whose keys are being requested
|
||||
user_ids: the users whose keys are being requested
|
||||
|
||||
Returns:
|
||||
dict[str, dict[str, dict]]: mapping from user ID to key type to key
|
||||
data. If a user's cross-signing keys were not found, either
|
||||
their user ID will not be in the dict, or their user ID will map
|
||||
to None.
|
||||
A mapping from user ID to key type to key data. If a user's cross-signing
|
||||
keys were not found, either their user ID will not be in the dict, or
|
||||
their user ID will map to None.
|
||||
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_bare_e2e_cross_signing_keys_bulk",
|
||||
self._get_bare_e2e_cross_signing_keys_bulk_txn,
|
||||
user_ids,
|
||||
|
@ -588,7 +591,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
|
||||
|
||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
||||
async def set_e2e_device_keys(
|
||||
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
|
||||
) -> bool:
|
||||
"""Stores device keys for a device. Returns whether there was a change
|
||||
or the keys were already in the database.
|
||||
"""
|
||||
|
@ -624,12 +629,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
log_kv({"message": "Device keys stored."})
|
||||
return True
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"set_e2e_device_keys", _set_e2e_device_keys_txn
|
||||
)
|
||||
|
||||
def claim_e2e_one_time_keys(self, query_list):
|
||||
"""Take a list of one time keys out of the database"""
|
||||
async def claim_e2e_one_time_keys(
|
||||
self, query_list: Iterable[Tuple[str, str, str]]
|
||||
) -> Dict[str, Dict[str, Dict[str, bytes]]]:
|
||||
"""Take a list of one time keys out of the database.
|
||||
|
||||
Args:
|
||||
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
|
||||
Returns:
|
||||
A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||
"""
|
||||
|
||||
@trace
|
||||
def _claim_e2e_one_time_keys(txn):
|
||||
|
@ -665,11 +679,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
)
|
||||
return result
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
||||
)
|
||||
|
||||
def delete_e2e_keys_by_device(self, user_id, device_id):
|
||||
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
|
||||
def delete_e2e_keys_by_device_txn(txn):
|
||||
log_kv(
|
||||
{
|
||||
|
@ -692,7 +706,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||
)
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool
|
||||
|
@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
desc="mark_local_media_as_safe",
|
||||
)
|
||||
|
||||
def get_url_cache(self, url, ts):
|
||||
async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
|
||||
"""Get the media_id and ts for a cached URL as of the given timestamp
|
||||
Returns:
|
||||
None if the URL isn't cached.
|
||||
|
@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
|
||||
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
|
||||
|
||||
async def store_url_cache(
|
||||
self, url, response_code, etag, expires_ts, og, media_id, download_ts
|
||||
|
@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
desc="store_cached_remote_media",
|
||||
)
|
||||
|
||||
def update_cached_last_access_time(self, local_media, remote_media, time_ms):
|
||||
async def update_cached_last_access_time(
|
||||
self,
|
||||
local_media: Iterable[str],
|
||||
remote_media: Iterable[Tuple[str, str]],
|
||||
time_ms: int,
|
||||
):
|
||||
"""Updates the last access time of the given media
|
||||
|
||||
Args:
|
||||
local_media (iterable[str]): Set of media_ids
|
||||
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
|
||||
local_media: Set of media_ids
|
||||
remote_media: Set of (server_name, media_id)
|
||||
time_ms: Current time in milliseconds
|
||||
"""
|
||||
|
||||
|
@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"update_cached_last_access_time", update_cache_txn
|
||||
)
|
||||
|
||||
|
@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
|
||||
)
|
||||
|
||||
def delete_remote_media(self, media_origin, media_id):
|
||||
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
|
||||
def delete_remote_media_txn(txn):
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
|
@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
keyvalues={"media_origin": media_origin, "media_id": media_id},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_remote_media", delete_remote_media_txn
|
||||
)
|
||||
|
||||
def get_expired_url_cache(self, now_ts):
|
||||
async def get_expired_url_cache(self, now_ts: int) -> List[str]:
|
||||
sql = (
|
||||
"SELECT media_id FROM local_media_repository_url_cache"
|
||||
" WHERE expires_ts < ?"
|
||||
|
@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
txn.execute(sql, (now_ts,))
|
||||
return [row[0] for row in txn]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_expired_url_cache", _get_expired_url_cache_txn
|
||||
)
|
||||
|
||||
|
@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"delete_url_cache", _delete_url_cache_txn
|
||||
)
|
||||
|
||||
def get_url_cache_media_before(self, before_ts):
|
||||
async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
|
||||
sql = (
|
||||
"SELECT media_id FROM local_media_repository"
|
||||
" WHERE created_ts < ? AND url_cache IS NOT NULL"
|
||||
|
@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
txn.execute(sql, (before_ts,))
|
||||
return [row[0] for row in txn]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_url_cache_media_before", _get_url_cache_media_before_txn
|
||||
)
|
||||
|
||||
|
|
|
@ -16,9 +16,10 @@
|
|||
import logging
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
|
@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
|
|||
"count": count,
|
||||
}
|
||||
|
||||
def _find_highlights_in_postgres(self, search_query, events):
|
||||
async def _find_highlights_in_postgres(
|
||||
self, search_query: str, events: List[EventBase]
|
||||
) -> Set[str]:
|
||||
"""Given a list of events and a search term, return a list of words
|
||||
that match from the content of the event.
|
||||
|
||||
|
@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
|
|||
highlight the matching parts.
|
||||
|
||||
Args:
|
||||
search_query (str)
|
||||
events (list): A list of events
|
||||
search_query
|
||||
events: A list of events
|
||||
|
||||
Returns:
|
||||
deferred : A set of strings.
|
||||
A set of strings.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
|
@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
|||
|
||||
return highlight_words
|
||||
|
||||
return self.db_pool.runInteraction("_find_highlights", f)
|
||||
return await self.db_pool.runInteraction("_find_highlights", f)
|
||||
|
||||
|
||||
def _to_postgres_options(options_dict):
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules
|
||||
from synapse.storage.database import DatabasePool
|
||||
|
@ -365,7 +365,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
|
||||
return False
|
||||
|
||||
def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
|
||||
async def update_profile_in_user_dir(
|
||||
self, user_id: str, display_name: str, avatar_url: str
|
||||
) -> None:
|
||||
"""
|
||||
Update or add a user's profile in the user directory.
|
||||
"""
|
||||
|
@ -458,17 +460,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
|
||||
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
|
||||
)
|
||||
|
||||
def add_users_who_share_private_room(self, room_id, user_id_tuples):
|
||||
async def add_users_who_share_private_room(
|
||||
self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
|
||||
) -> None:
|
||||
"""Insert entries into the users_who_share_private_rooms table. The first
|
||||
user should be a local user.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
|
||||
room_id
|
||||
user_id_tuples: iterable of 2-tuple of user IDs.
|
||||
"""
|
||||
|
||||
def _add_users_who_share_room_txn(txn):
|
||||
|
@ -484,17 +488,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
value_values=None,
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"add_users_who_share_room", _add_users_who_share_room_txn
|
||||
)
|
||||
|
||||
def add_users_in_public_rooms(self, room_id, user_ids):
|
||||
async def add_users_in_public_rooms(
|
||||
self, room_id: str, user_ids: Iterable[str]
|
||||
) -> None:
|
||||
"""Insert entries into the users_who_share_private_rooms table. The first
|
||||
user should be a local user.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
user_ids (list[str])
|
||||
room_id
|
||||
user_ids
|
||||
"""
|
||||
|
||||
def _add_users_in_public_rooms_txn(txn):
|
||||
|
@ -508,11 +514,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
value_values=None,
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
|
||||
)
|
||||
|
||||
def delete_all_from_user_dir(self):
|
||||
async def delete_all_from_user_dir(self) -> None:
|
||||
"""Delete the entire user directory
|
||||
"""
|
||||
|
||||
|
@ -523,7 +529,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
txn.execute("DELETE FROM users_who_share_private_rooms")
|
||||
txn.call_after(self.get_user_in_directory.invalidate_all)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
||||
)
|
||||
|
||||
|
@ -555,7 +561,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
|||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
def remove_from_user_dir(self, user_id):
|
||||
async def remove_from_user_dir(self, user_id: str) -> None:
|
||||
def _remove_from_user_dir_txn(txn):
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="user_directory", keyvalues={"user_id": user_id}
|
||||
|
@ -578,7 +584,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
|||
)
|
||||
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"remove_from_user_dir", _remove_from_user_dir_txn
|
||||
)
|
||||
|
||||
|
@ -605,14 +611,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
|||
|
||||
return user_ids
|
||||
|
||||
def remove_user_who_share_room(self, user_id, room_id):
|
||||
async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
|
||||
"""
|
||||
Deletes entries in the users_who_share_*_rooms table. The first
|
||||
user should be a local user.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
room_id (str)
|
||||
user_id
|
||||
room_id
|
||||
"""
|
||||
|
||||
def _remove_user_who_share_room_txn(txn):
|
||||
|
@ -632,7 +638,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
|||
keyvalues={"user_id": user_id, "room_id": room_id},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"remove_user_who_share_room", _remove_user_who_share_room_txn
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in a new issue