0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-13 16:18:56 +02:00

Convert additional database code to async/await. (#8195)

This commit is contained in:
Patrick Cloke 2020-08-28 07:54:27 -04:00 committed by GitHub
parent d5e73cb6aa
commit 5c03134d0f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 246 additions and 175 deletions

1
changelog.d/8195.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -14,11 +14,16 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING
from synapse.api.constants import EventTypes
from synapse.appservice.api import ApplicationServiceApi
from synapse.types import GroupID, get_domain_from_id
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@ -35,19 +40,19 @@ class AppServiceTransaction(object):
self.id = id
self.events = events
def send(self, as_api):
async def send(self, as_api: ApplicationServiceApi) -> bool:
"""Sends this transaction using the provided AS API interface.
Args:
as_api(ApplicationServiceApi): The API to use to send.
as_api: The API to use to send.
Returns:
An Awaitable which resolves to True if the transaction was sent.
True if the transaction was sent.
"""
return as_api.push_bulk(
return await as_api.push_bulk(
service=self.service, events=self.events, txn_id=self.id
)
def complete(self, store):
async def complete(self, store: "DataStore") -> None:
"""Completes this transaction as successful.
Marks this transaction ID on the application service and removes the
@ -55,10 +60,8 @@ class AppServiceTransaction(object):
Args:
store: The database store to operate on.
Returns:
A Deferred which resolves to True if the transaction was completed.
"""
return store.complete_appservice_txn(service=self.service, txn_id=self.id)
await store.complete_appservice_txn(service=self.service, txn_id=self.id)
class ApplicationService(object):

View file

@ -20,6 +20,7 @@ These actions are mostly only used by the :py:mod:`.replication` module.
"""
import logging
from typing import Optional, Tuple
from synapse.federation.units import Transaction
from synapse.logging.utils import log_function
@ -36,25 +37,27 @@ class TransactionActions(object):
self.store = datastore
@log_function
def have_responded(self, origin, transaction):
""" Have we already responded to a transaction with the same id and
async def have_responded(
self, origin: str, transaction: Transaction
) -> Optional[Tuple[int, JsonDict]]:
"""Have we already responded to a transaction with the same id and
origin?
Returns:
Deferred: Results in `None` if we have not previously responded to
this transaction or a 2-tuple of `(int, dict)` representing the
response code and response body.
`None` if we have not previously responded to this transaction or a
2-tuple of `(int, dict)` representing the response code and response body.
"""
if not transaction.transaction_id:
transaction_id = transaction.transaction_id # type: ignore
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id")
return self.store.get_received_txn_response(transaction.transaction_id, origin)
return await self.store.get_received_txn_response(transaction_id, origin)
@log_function
async def set_response(
self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None:
""" Persist how we responded to a transaction.
"""Persist how we responded to a transaction.
"""
transaction_id = transaction.transaction_id # type: ignore
if not transaction_id:

View file

@ -1879,8 +1879,8 @@ class FederationHandler(BaseHandler):
else:
return None
def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context)
async def get_min_depth_for_context(self, context):
return await self.store.get_min_depth(context)
async def _handle_new_event(
self, origin, event, state=None, auth_events=None, backfilled=False

View file

@ -172,7 +172,7 @@ class ApplicationServiceTransactionWorkerStore(
"application_services_state", {"as_id": service.id}, {"state": state}
)
def create_appservice_txn(self, service, events):
async def create_appservice_txn(self, service, events):
"""Atomically creates a new transaction for this application service
with the given list of events.
@ -209,20 +209,17 @@ class ApplicationServiceTransactionWorkerStore(
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"create_appservice_txn", _create_appservice_txn
)
def complete_appservice_txn(self, txn_id, service):
async def complete_appservice_txn(self, txn_id, service) -> None:
"""Completes an application service transaction.
Args:
txn_id(str): The transaction ID being completed.
service(ApplicationService): The application service which was sent
this transaction.
Returns:
A Deferred which resolves if this transaction was stored
successfully.
"""
txn_id = int(txn_id)
@ -258,7 +255,7 @@ class ApplicationServiceTransactionWorkerStore(
{"txn_id": txn_id, "as_id": service.id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"complete_appservice_txn", _complete_appservice_txn
)
@ -312,13 +309,13 @@ class ApplicationServiceTransactionWorkerStore(
else:
return int(last_txn_id[0]) # select 'last_txn' col
def set_appservice_last_pos(self, pos):
async def set_appservice_last_pos(self, pos) -> None:
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)

View file

@ -190,15 +190,15 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
@trace
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
async def delete_device_msgs_for_remote(
self, destination: str, up_to_stream_id: int
) -> None:
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
destination(str): The destination server_name
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
destination: The destination server_name
up_to_stream_id: Where to delete messages up to.
"""
def delete_messages_for_remote_destination_txn(txn):
@ -209,7 +209,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
txn.execute(sql, (destination, up_to_stream_id))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)

View file

@ -151,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return sessions
def get_e2e_room_keys_multi(self, user_id, version, room_keys):
async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
"""Get multiple room keys at a time. The difference between this function and
get_e2e_room_keys is that this function can be used to retrieve
multiple specific keys at a time, whereas get_e2e_room_keys is used for
@ -166,10 +166,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
that we want to query
Returns:
Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
@ -283,7 +283,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
raise StoreError(404, "No current backup version")
return row[0]
def get_e2e_room_keys_version_info(self, user_id, version=None):
async def get_e2e_room_keys_version_info(self, user_id, version=None):
"""Get info metadata about a version of our room_keys backup.
Args:
@ -293,7 +293,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present
Returns:
A deferred dict giving the info metadata for this backup version, with
A dict giving the info metadata for this backup version, with
fields including:
version(str)
algorithm(str)
@ -324,12 +324,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
result["etag"] = 0
return result
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
@trace
def create_e2e_room_keys_version(self, user_id, info):
async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
"""Atomically creates a new version of this user's e2e_room_keys store
with the given version info.
@ -338,7 +338,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
info(dict): the info about the backup version to be created
Returns:
A deferred string for the newly created version ID
The newly created version ID
"""
def _create_e2e_room_keys_version_txn(txn):
@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return new_version
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
@ -403,13 +403,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
def delete_e2e_room_keys_version(self, user_id, version=None):
async def delete_e2e_room_keys_version(
self, user_id: str, version: Optional[str] = None
) -> None:
"""Delete a given backup version of the user's room keys.
Doesn't delete their actual key data.
Args:
user_id(str): the user whose backup version we're deleting
version(str): Optional. the version ID of the backup version we're deleting
user_id: the user whose backup version we're deleting
version: Optional. the version ID of the backup version we're deleting
If missing, we delete the current backup version info.
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present,
@ -430,13 +432,13 @@ class EndToEndRoomKeyStore(SQLBaseStore):
keyvalues={"user_id": user_id, "version": this_version},
)
return self.db_pool.simple_update_one_txn(
self.db_pool.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)

View file

@ -59,7 +59,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
include_given: include the given events in result
Returns:
list of event_ids
An awaitable which resolve to a list of event_ids
"""
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
@ -95,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@ -104,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
chain.
Returns:
Deferred[Set[str]]
The set of the difference in auth chains.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
@ -252,8 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
def get_oldest_events_with_depth_in_room(self, room_id):
return self.db_pool.runInteraction(
async def get_oldest_events_with_depth_in_room(self, room_id):
return await self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
@ -293,7 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
return max(row["depth"] for row in rows)
def get_prev_events_for_room(self, room_id: str):
async def get_prev_events_for_room(self, room_id: str) -> List[str]:
"""
Gets a subset of the current forward extremities in the given room.
@ -301,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
events which refer to hundreds of prev_events.
Args:
room_id (str): room_id
room_id: room_id
Returns:
Deferred[List[str]]: the event ids of the forward extremites
The event ids of the forward extremities.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
)
@ -328,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return [row[0] for row in txn]
def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
async def get_rooms_with_many_extremities(
self, min_count: int, limit: int, room_id_filter: Iterable[str]
) -> List[str]:
"""Get the top rooms with at least N extremities.
Args:
min_count (int): The minimum number of extremities
limit (int): The maximum number of rooms to return.
room_id_filter (iterable[str]): room_ids to exclude from the results
min_count: The minimum number of extremities
limit: The maximum number of rooms to return.
room_id_filter: room_ids to exclude from the results
Returns:
Deferred[list]: At most `limit` room IDs that have at least
`min_count` extremities, sorted by extremity count.
At most `limit` room IDs that have at least `min_count` extremities,
sorted by extremity count.
"""
def _get_rooms_with_many_extremities_txn(txn):
@ -363,7 +365,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, query_args)
return [room_id for room_id, in txn]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
)
@ -376,10 +378,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
desc="get_latest_event_ids_in_room",
)
def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it.
async def get_min_depth(self, room_id: str) -> int:
"""For the given room, get the minimum depth we have seen for it.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
@ -394,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return int(min_depth) if min_depth is not None else None
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
async def get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int
) -> List[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@ -402,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
stream_orderings from that point.
Args:
room_id (str):
stream_ordering (int):
room_id:
stream_ordering:
Returns:
deferred, which resolves to a list of event_ids
A list of event_ids
"""
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
@ -422,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
if last_change > self.stream_ordering_month_ago:
stream_ordering = min(last_change, stream_ordering)
return self._get_forward_extremeties_for_room(room_id, stream_ordering)
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2)
def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@ -450,19 +454,18 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
async def get_backfill_events(self, room_id, event_list, limit):
async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
Args:
txn
room_id (str)
event_list (list)
limit (int)
room_id
event_list
limit
"""
event_ids = await self.db_pool.runInteraction(
"get_backfill_events",
@ -631,8 +634,8 @@ class EventFederationStore(EventFederationWorkerStore):
_delete_old_forward_extrem_cache_txn,
)
def clean_room_for_join(self, room_id):
return self.db_pool.runInteraction(
async def clean_room_for_join(self, room_id):
return await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)

View file

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
@ -70,7 +70,9 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_invited_users_in_group",
)
def get_rooms_in_group(self, group_id: str, include_private: bool = False):
async def get_rooms_in_group(
self, group_id: str, include_private: bool = False
) -> List[Dict[str, Union[str, bool]]]:
"""Retrieve the rooms that belong to a given group. Does not return rooms that
lack members.
@ -79,8 +81,7 @@ class GroupServerWorkerStore(SQLBaseStore):
include_private: Whether to return private rooms in results
Returns:
Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
form of:
A list of dictionaries, each in the form of:
{
"room_id": "!a_room_id:example.com", # The ID of the room
@ -117,13 +118,13 @@ class GroupServerWorkerStore(SQLBaseStore):
for room_id, is_public in txn
]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_rooms_in_group", _get_rooms_in_group_txn
)
def get_rooms_for_summary_by_category(
async def get_rooms_for_summary_by_category(
self, group_id: str, include_private: bool = False,
):
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""Get the rooms and categories that should be included in a summary request
Args:
@ -131,7 +132,7 @@ class GroupServerWorkerStore(SQLBaseStore):
include_private: Whether to return private rooms in results
Returns:
Deferred[Tuple[List, Dict]]: A tuple containing:
A tuple containing:
* A list of dictionaries with the keys:
* "room_id": str, the room ID
@ -207,7 +208,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return rooms, categories
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
@ -281,10 +282,11 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_local_groups_for_room",
)
def get_users_for_summary_by_role(self, group_id, include_private=False):
async def get_users_for_summary_by_role(self, group_id, include_private=False):
"""Get the users and roles that should be included in a summary request
Returns ([users], [roles])
Returns:
([users], [roles])
"""
def _get_users_for_summary_txn(txn):
@ -338,7 +340,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return users, roles
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_users_for_summary_by_role", _get_users_for_summary_txn
)
@ -376,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True,
)
def get_users_membership_info_in_group(self, group_id, user_id):
async def get_users_membership_info_in_group(self, group_id, user_id):
"""Get a dict describing the membership of a user in a group.
Example if joined:
@ -387,7 +389,8 @@ class GroupServerWorkerStore(SQLBaseStore):
"is_privileged": False,
}
Returns an empty dict if the user is not join/invite/etc
Returns:
An empty dict if the user is not join/invite/etc
"""
def _get_users_membership_in_group_txn(txn):
@ -419,7 +422,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return {}
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
)
@ -433,7 +436,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_publicised_groups_for_user",
)
def get_attestations_need_renewals(self, valid_until_ms):
async def get_attestations_need_renewals(self, valid_until_ms):
"""Get all attestations that need to be renewed until givent time
"""
@ -445,7 +448,7 @@ class GroupServerWorkerStore(SQLBaseStore):
txn.execute(sql, (valid_until_ms,))
return self.db_pool.cursor_to_dict(txn)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
@ -475,7 +478,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups",
)
def get_all_groups_for_user(self, user_id, now_token):
async def get_all_groups_for_user(self, user_id, now_token):
def _get_all_groups_for_user_txn(txn):
sql = """
SELECT group_id, type, membership, u.content
@ -495,7 +498,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in txn
]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
@ -600,8 +603,27 @@ class GroupServerStore(GroupServerWorkerStore):
desc="set_group_join_policy",
)
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
return self.db_pool.runInteraction(
async def add_room_to_summary(
self,
group_id: str,
room_id: str,
category_id: str,
order: int,
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
Args:
group_id
room_id
category_id: If not None then adds the category to the end of
the summary if its not already there.
order: If not None inserts the room at that position, e.g. an order
of 1 will put the room first. Otherwise, the room gets added to
the end.
is_public
"""
await self.db_pool.runInteraction(
"add_room_to_summary",
self._add_room_to_summary_txn,
group_id,
@ -612,18 +634,26 @@ class GroupServerStore(GroupServerWorkerStore):
)
def _add_room_to_summary_txn(
self, txn, group_id, room_id, category_id, order, is_public
):
self,
txn,
group_id: str,
room_id: str,
category_id: str,
order: int,
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
Args:
group_id (str)
room_id (str)
category_id (str): If not None then adds the category to the end of
the summary if its not already there. [Optional]
order (int): If not None inserts the room at that position, e.g.
an order of 1 will put the room first. Otherwise, the room gets
added to the end.
txn
group_id
room_id
category_id: If not None then adds the category to the end of
the summary if its not already there.
order: If not None inserts the room at that position, e.g. an order
of 1 will put the room first. Otherwise, the room gets added to
the end.
is_public
"""
room_in_group = self.db_pool.simple_select_one_onecol_txn(
txn,
@ -818,8 +848,27 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_group_role",
)
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
return self.db_pool.runInteraction(
async def add_user_to_summary(
self,
group_id: str,
user_id: str,
role_id: str,
order: int,
is_public: Optional[bool],
) -> None:
"""Add (or update) user's entry in summary.
Args:
group_id
user_id
role_id: If not None then adds the role to the end of the summary if
its not already there.
order: If not None inserts the user at that position, e.g. an order
of 1 will put the user first. Otherwise, the user gets added to
the end.
is_public
"""
await self.db_pool.runInteraction(
"add_user_to_summary",
self._add_user_to_summary_txn,
group_id,
@ -830,18 +879,26 @@ class GroupServerStore(GroupServerWorkerStore):
)
def _add_user_to_summary_txn(
self, txn, group_id, user_id, role_id, order, is_public
self,
txn,
group_id: str,
user_id: str,
role_id: str,
order: int,
is_public: Optional[bool],
):
"""Add (or update) user's entry in summary.
Args:
group_id (str)
user_id (str)
role_id (str): If not None then adds the role to the end of
the summary if its not already there. [Optional]
order (int): If not None inserts the user at that position, e.g.
an order of 1 will put the user first. Otherwise, the user gets
added to the end.
txn
group_id
user_id
role_id: If not None then adds the role to the end of the summary if
its not already there.
order: If not None inserts the user at that position, e.g. an order
of 1 will put the user first. Otherwise, the user gets added to
the end.
is_public
"""
user_in_group = self.db_pool.simple_select_one_onecol_txn(
txn,
@ -963,27 +1020,26 @@ class GroupServerStore(GroupServerWorkerStore):
desc="add_group_invite",
)
def add_user_to_group(
async def add_user_to_group(
self,
group_id,
user_id,
is_admin=False,
is_public=True,
local_attestation=None,
remote_attestation=None,
):
group_id: str,
user_id: str,
is_admin: bool = False,
is_public: bool = True,
local_attestation: dict = None,
remote_attestation: dict = None,
) -> None:
"""Add a user to the group server.
Args:
group_id (str)
user_id (str)
is_admin (bool)
is_public (bool)
local_attestation (dict): The attestation the GS created to give
to the remote server. Optional if the user and group are on the
same server
remote_attestation (dict): The attestation given to GS by remote
group_id
user_id
is_admin
is_public
local_attestation: The attestation the GS created to give to the remote
server. Optional if the user and group are on the same server
remote_attestation: The attestation given to GS by remote server.
Optional if the user and group are on the same server
"""
def _add_user_to_group_txn(txn):
@ -1026,9 +1082,9 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
def remove_user_from_group(self, group_id, user_id):
async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
def _remove_user_from_group_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@ -1056,7 +1112,7 @@ class GroupServerStore(GroupServerWorkerStore):
keyvalues={"group_id": group_id, "user_id": user_id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn
)
@ -1079,7 +1135,7 @@ class GroupServerStore(GroupServerWorkerStore):
desc="update_room_in_group_visibility",
)
def remove_room_from_group(self, group_id, room_id):
async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
def _remove_room_from_group_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@ -1093,7 +1149,7 @@ class GroupServerStore(GroupServerWorkerStore):
keyvalues={"group_id": group_id, "room_id": room_id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn
)
@ -1286,14 +1342,11 @@ class GroupServerStore(GroupServerWorkerStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
def delete_group(self, group_id):
async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database.
Args:
group_id (str)
Returns:
Deferred
group_id: The group ID to delete.
"""
def _delete_group_txn(txn):
@ -1317,4 +1370,4 @@ class GroupServerStore(GroupServerWorkerStore):
txn, table=table, keyvalues={"group_id": group_id}
)
return self.db_pool.runInteraction("delete_group", _delete_group_txn)
await self.db_pool.runInteraction("delete_group", _delete_group_txn)

View file

@ -16,7 +16,7 @@
import itertools
import logging
from typing import Iterable, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
@ -42,16 +42,17 @@ class KeyStore(SQLBaseStore):
@cachedList(
cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
)
def get_server_verify_keys(self, server_name_and_key_ids):
async def get_server_verify_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
"""
Args:
server_name_and_key_ids (iterable[Tuple[str, str]]):
server_name_and_key_ids:
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
map from (server_name, key_id) -> FetchKeyResult, or None if the key is
unknown
A map from (server_name, key_id) -> FetchKeyResult, or None if the
key is unknown
"""
keys = {}
@ -87,7 +88,7 @@ class KeyStore(SQLBaseStore):
_get_keys(txn, batch)
return keys
return self.db_pool.runInteraction("get_server_verify_keys", _txn)
return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
async def store_server_verify_keys(
self,
@ -179,7 +180,9 @@ class KeyStore(SQLBaseStore):
desc="store_server_keys_json",
)
def get_server_keys_json(self, server_keys):
async def get_server_keys_json(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
"""Retrive the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
@ -188,8 +191,7 @@ class KeyStore(SQLBaseStore):
Args:
server_keys (list): List of (server_name, key_id, source) triplets.
Returns:
Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
Dict mapping (server_name, key_id, source) triplets to lists of dicts
A mapping from (server_name, key_id, source) triplets to a list of dicts
"""
def _get_server_keys_json_txn(txn):
@ -215,6 +217,6 @@ class KeyStore(SQLBaseStore):
results[(server_name, key_id, from_server)] = rows
return results
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
)

View file

@ -15,6 +15,7 @@
import logging
from collections import namedtuple
from typing import Optional, Tuple
from canonicaljson import encode_canonical_json
@ -56,21 +57,23 @@ class TransactionStore(SQLBaseStore):
expiry_ms=5 * 60 * 1000,
)
def get_received_txn_response(self, transaction_id, origin):
async def get_received_txn_response(
self, transaction_id: str, origin: str
) -> Optional[Tuple[int, JsonDict]]:
"""For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response
body (as a dict).
Args:
transaction_id (str)
origin(str)
transaction_id
origin
Returns:
tuple: None if we have not previously responded to
this transaction or a 2-tuple of (int, dict)
None if we have not previously responded to this transaction or a
2-tuple of (int, dict)
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_received_txn_response",
self._get_received_txn_response,
transaction_id,
@ -166,21 +169,25 @@ class TransactionStore(SQLBaseStore):
else:
return None
def set_destination_retry_timings(
self, destination, failure_ts, retry_last_ts, retry_interval
):
async def set_destination_retry_timings(
self,
destination: str,
failure_ts: Optional[int],
retry_last_ts: int,
retry_interval: int,
) -> None:
"""Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occuring.
Args:
destination (str)
failure_ts (int|None) - when the server started failing (ms since epoch)
retry_last_ts (int) - time of last retry attempt in unix epoch ms
retry_interval (int) - how long until next retry in ms
destination
failure_ts: when the server started failing (ms since epoch)
retry_last_ts: time of last retry attempt in unix epoch ms
retry_interval: how long until next retry in ms
"""
self._destination_retry_cache.pop(destination, None)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
@ -256,13 +263,13 @@ class TransactionStore(SQLBaseStore):
"cleanup_transactions", self._cleanup_transactions
)
def _cleanup_transactions(self):
async def _cleanup_transactions(self) -> None:
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn
)