mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 16:33:53 +01:00
Convert simple_update* and simple_select* to async (#8173)
This commit is contained in:
parent
a466b67972
commit
4a739c73b4
19 changed files with 164 additions and 133 deletions
1
changelog.d/8173.misc
Normal file
1
changelog.d/8173.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -51,7 +51,7 @@ from synapse.types import (
|
||||||
create_requester,
|
create_requester,
|
||||||
)
|
)
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
from synapse.util.async_helpers import Linearizer, maybe_awaitable
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
|
@ -1329,9 +1329,7 @@ class RoomShutdownHandler(object):
|
||||||
ratelimit=False,
|
ratelimit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
aliases_for_room = await maybe_awaitable(
|
aliases_for_room = await self.store.get_aliases_for_room(room_id)
|
||||||
self.store.get_aliases_for_room(room_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.store.update_aliases_for_room(
|
await self.store.update_aliases_for_room(
|
||||||
room_id, new_room_id, requester_user_id
|
room_id, new_room_id, requester_user_id
|
||||||
|
|
|
@ -1132,13 +1132,13 @@ class DatabasePool(object):
|
||||||
|
|
||||||
return [r[0] for r in txn]
|
return [r[0] for r in txn]
|
||||||
|
|
||||||
def simple_select_onecol(
|
async def simple_select_onecol(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Optional[Dict[str, Any]],
|
keyvalues: Optional[Dict[str, Any]],
|
||||||
retcol: str,
|
retcol: str,
|
||||||
desc: str = "simple_select_onecol",
|
desc: str = "simple_select_onecol",
|
||||||
) -> defer.Deferred:
|
) -> List[Any]:
|
||||||
"""Executes a SELECT query on the named table, which returns a list
|
"""Executes a SELECT query on the named table, which returns a list
|
||||||
comprising of the values of the named column from the selected rows.
|
comprising of the values of the named column from the selected rows.
|
||||||
|
|
||||||
|
@ -1148,19 +1148,19 @@ class DatabasePool(object):
|
||||||
retcol: column whos value we wish to retrieve.
|
retcol: column whos value we wish to retrieve.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Results in a list
|
Results in a list
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return await self.runInteraction(
|
||||||
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
|
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
|
||||||
)
|
)
|
||||||
|
|
||||||
def simple_select_list(
|
async def simple_select_list(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Optional[Dict[str, Any]],
|
keyvalues: Optional[Dict[str, Any]],
|
||||||
retcols: Iterable[str],
|
retcols: Iterable[str],
|
||||||
desc: str = "simple_select_list",
|
desc: str = "simple_select_list",
|
||||||
) -> defer.Deferred:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows, returning the result as a list of dicts.
|
||||||
|
|
||||||
|
@ -1170,10 +1170,11 @@ class DatabasePool(object):
|
||||||
column names and values to select the rows with, or None to not
|
column names and values to select the rows with, or None to not
|
||||||
apply a WHERE clause.
|
apply a WHERE clause.
|
||||||
retcols: the names of the columns to return
|
retcols: the names of the columns to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: resolves to list[dict[str, Any]]
|
A list of dictionaries.
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return await self.runInteraction(
|
||||||
desc, self.simple_select_list_txn, table, keyvalues, retcols
|
desc, self.simple_select_list_txn, table, keyvalues, retcols
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1299,14 +1300,14 @@ class DatabasePool(object):
|
||||||
txn.execute(sql, values)
|
txn.execute(sql, values)
|
||||||
return cls.cursor_to_dict(txn)
|
return cls.cursor_to_dict(txn)
|
||||||
|
|
||||||
def simple_update(
|
async def simple_update(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
updatevalues: Dict[str, Any],
|
updatevalues: Dict[str, Any],
|
||||||
desc: str,
|
desc: str,
|
||||||
) -> defer.Deferred:
|
) -> int:
|
||||||
return self.runInteraction(
|
return await self.runInteraction(
|
||||||
desc, self.simple_update_txn, table, keyvalues, updatevalues
|
desc, self.simple_update_txn, table, keyvalues, updatevalues
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1332,13 +1333,13 @@ class DatabasePool(object):
|
||||||
|
|
||||||
return txn.rowcount
|
return txn.rowcount
|
||||||
|
|
||||||
def simple_update_one(
|
async def simple_update_one(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
updatevalues: Dict[str, Any],
|
updatevalues: Dict[str, Any],
|
||||||
desc: str = "simple_update_one",
|
desc: str = "simple_update_one",
|
||||||
) -> defer.Deferred:
|
) -> None:
|
||||||
"""Executes an UPDATE query on the named table, setting new values for
|
"""Executes an UPDATE query on the named table, setting new values for
|
||||||
columns in a row matching the key values.
|
columns in a row matching the key values.
|
||||||
|
|
||||||
|
@ -1347,7 +1348,7 @@ class DatabasePool(object):
|
||||||
keyvalues: dict of column names and values to select the row with
|
keyvalues: dict of column names and values to select the row with
|
||||||
updatevalues: dict giving column names and values to update
|
updatevalues: dict giving column names and values to update
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
await self.runInteraction(
|
||||||
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
|
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
import calendar
|
import calendar
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from synapse.api.constants import PresenceState
|
from synapse.api.constants import PresenceState
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
@ -476,14 +477,13 @@ class DataStore(
|
||||||
"generate_user_daily_visits", _generate_user_daily_visits
|
"generate_user_daily_visits", _generate_user_daily_visits
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_users(self):
|
async def get_users(self) -> List[Dict[str, Any]]:
|
||||||
"""Function to retrieve a list of users in users table.
|
"""Function to retrieve a list of users in users table.
|
||||||
|
|
||||||
Args:
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: resolves to list[dict[str, Any]]
|
A list of dictionaries representing users.
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_select_list(
|
return await self.db_pool.simple_select_list(
|
||||||
table="users",
|
table="users",
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=[
|
retcols=[
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Iterable, Optional
|
from typing import Iterable, List, Optional
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
@ -68,8 +68,8 @@ class DirectoryWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(max_entries=5000)
|
@cached(max_entries=5000)
|
||||||
def get_aliases_for_room(self, room_id):
|
async def get_aliases_for_room(self, room_id: str) -> List[str]:
|
||||||
return self.db_pool.simple_select_onecol(
|
return await self.db_pool.simple_select_onecol(
|
||||||
"room_aliases",
|
"room_aliases",
|
||||||
{"room_id": room_id},
|
{"room_id": room_id},
|
||||||
"room_alias",
|
"room_alias",
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.logging.opentracing import log_kv, trace
|
from synapse.logging.opentracing import log_kv, trace
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
|
@ -368,18 +370,22 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
def update_e2e_room_keys_version(
|
async def update_e2e_room_keys_version(
|
||||||
self, user_id, version, info=None, version_etag=None
|
self,
|
||||||
):
|
user_id: str,
|
||||||
|
version: str,
|
||||||
|
info: Optional[dict] = None,
|
||||||
|
version_etag: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
"""Update a given backup version
|
"""Update a given backup version
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): the user whose backup version we're updating
|
user_id: the user whose backup version we're updating
|
||||||
version(str): the version ID of the backup version we're updating
|
version: the version ID of the backup version we're updating
|
||||||
info (dict): the new backup version info to store. If None, then
|
info: the new backup version info to store. If None, then the backup
|
||||||
the backup version info is not updated
|
version info is not updated.
|
||||||
version_etag (Optional[int]): etag of the keys in the backup. If
|
version_etag: etag of the keys in the backup. If None, then the etag
|
||||||
None, then the etag is not updated
|
is not updated.
|
||||||
"""
|
"""
|
||||||
updatevalues = {}
|
updatevalues = {}
|
||||||
|
|
||||||
|
@ -389,7 +395,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||||
updatevalues["etag"] = version_etag
|
updatevalues["etag"] = version_etag
|
||||||
|
|
||||||
if updatevalues:
|
if updatevalues:
|
||||||
return self.db_pool.simple_update(
|
await self.db_pool.simple_update(
|
||||||
table="e2e_room_keys_versions",
|
table="e2e_room_keys_versions",
|
||||||
keyvalues={"user_id": user_id, "version": version},
|
keyvalues={"user_id": user_id, "version": version},
|
||||||
updatevalues=updatevalues,
|
updatevalues=updatevalues,
|
||||||
|
|
|
@ -368,8 +368,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(max_entries=5000, iterable=True)
|
@cached(max_entries=5000, iterable=True)
|
||||||
def get_latest_event_ids_in_room(self, room_id):
|
async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
|
||||||
return self.db_pool.simple_select_onecol(
|
return await self.db_pool.simple_select_onecol(
|
||||||
table="event_forward_extremities",
|
table="event_forward_extremities",
|
||||||
keyvalues={"room_id": room_id},
|
keyvalues={"room_id": room_id},
|
||||||
retcol="event_id",
|
retcol="event_id",
|
||||||
|
|
|
@ -44,24 +44,26 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
desc="get_group",
|
desc="get_group",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_users_in_group(self, group_id, include_private=False):
|
async def get_users_in_group(
|
||||||
|
self, group_id: str, include_private: bool = False
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
# TODO: Pagination
|
# TODO: Pagination
|
||||||
|
|
||||||
keyvalues = {"group_id": group_id}
|
keyvalues = {"group_id": group_id}
|
||||||
if not include_private:
|
if not include_private:
|
||||||
keyvalues["is_public"] = True
|
keyvalues["is_public"] = True
|
||||||
|
|
||||||
return self.db_pool.simple_select_list(
|
return await self.db_pool.simple_select_list(
|
||||||
table="group_users",
|
table="group_users",
|
||||||
keyvalues=keyvalues,
|
keyvalues=keyvalues,
|
||||||
retcols=("user_id", "is_public", "is_admin"),
|
retcols=("user_id", "is_public", "is_admin"),
|
||||||
desc="get_users_in_group",
|
desc="get_users_in_group",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_invited_users_in_group(self, group_id):
|
async def get_invited_users_in_group(self, group_id: str) -> List[str]:
|
||||||
# TODO: Pagination
|
# TODO: Pagination
|
||||||
|
|
||||||
return self.db_pool.simple_select_onecol(
|
return await self.db_pool.simple_select_onecol(
|
||||||
table="group_invites",
|
table="group_invites",
|
||||||
keyvalues={"group_id": group_id},
|
keyvalues={"group_id": group_id},
|
||||||
retcol="user_id",
|
retcol="user_id",
|
||||||
|
@ -265,15 +267,14 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return role
|
return role
|
||||||
|
|
||||||
def get_local_groups_for_room(self, room_id):
|
async def get_local_groups_for_room(self, room_id: str) -> List[str]:
|
||||||
"""Get all of the local group that contain a given room
|
"""Get all of the local group that contain a given room
|
||||||
Args:
|
Args:
|
||||||
room_id (str): The ID of a room
|
room_id: The ID of a room
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[list[str]]: A twisted.Deferred containing a list of group ids
|
A list of group ids containing this room
|
||||||
containing this room
|
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_select_onecol(
|
return await self.db_pool.simple_select_onecol(
|
||||||
table="group_rooms",
|
table="group_rooms",
|
||||||
keyvalues={"room_id": room_id},
|
keyvalues={"room_id": room_id},
|
||||||
retcol="group_id",
|
retcol="group_id",
|
||||||
|
@ -422,10 +423,10 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
|
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_publicised_groups_for_user(self, user_id):
|
async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
|
||||||
"""Get all groups a user is publicising
|
"""Get all groups a user is publicising
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_select_onecol(
|
return await self.db_pool.simple_select_onecol(
|
||||||
table="local_group_membership",
|
table="local_group_membership",
|
||||||
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
|
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
|
||||||
retcol="group_id",
|
retcol="group_id",
|
||||||
|
@ -466,8 +467,8 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_joined_groups(self, user_id):
|
async def get_joined_groups(self, user_id: str) -> List[str]:
|
||||||
return self.db_pool.simple_select_onecol(
|
return await self.db_pool.simple_select_onecol(
|
||||||
table="local_group_membership",
|
table="local_group_membership",
|
||||||
keyvalues={"user_id": user_id, "membership": "join"},
|
keyvalues={"user_id": user_id, "membership": "join"},
|
||||||
retcol="group_id",
|
retcol="group_id",
|
||||||
|
@ -585,14 +586,14 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
|
|
||||||
class GroupServerStore(GroupServerWorkerStore):
|
class GroupServerStore(GroupServerWorkerStore):
|
||||||
def set_group_join_policy(self, group_id, join_policy):
|
async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
|
||||||
"""Set the join policy of a group.
|
"""Set the join policy of a group.
|
||||||
|
|
||||||
join_policy can be one of:
|
join_policy can be one of:
|
||||||
* "invite"
|
* "invite"
|
||||||
* "open"
|
* "open"
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="groups",
|
table="groups",
|
||||||
keyvalues={"group_id": group_id},
|
keyvalues={"group_id": group_id},
|
||||||
updatevalues={"join_policy": join_policy},
|
updatevalues={"join_policy": join_policy},
|
||||||
|
@ -1050,8 +1051,10 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
desc="add_room_to_group",
|
desc="add_room_to_group",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_room_in_group_visibility(self, group_id, room_id, is_public):
|
async def update_room_in_group_visibility(
|
||||||
return self.db_pool.simple_update(
|
self, group_id: str, room_id: str, is_public: bool
|
||||||
|
) -> int:
|
||||||
|
return await self.db_pool.simple_update(
|
||||||
table="group_rooms",
|
table="group_rooms",
|
||||||
keyvalues={"group_id": group_id, "room_id": room_id},
|
keyvalues={"group_id": group_id, "room_id": room_id},
|
||||||
updatevalues={"is_public": is_public},
|
updatevalues={"is_public": is_public},
|
||||||
|
@ -1076,10 +1079,12 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
"remove_room_from_group", _remove_room_from_group_txn
|
"remove_room_from_group", _remove_room_from_group_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_group_publicity(self, group_id, user_id, publicise):
|
async def update_group_publicity(
|
||||||
|
self, group_id: str, user_id: str, publicise: bool
|
||||||
|
) -> None:
|
||||||
"""Update whether the user is publicising their membership of the group
|
"""Update whether the user is publicising their membership of the group
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="local_group_membership",
|
table="local_group_membership",
|
||||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||||
updatevalues={"is_publicised": publicise},
|
updatevalues={"is_publicised": publicise},
|
||||||
|
@ -1218,20 +1223,24 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
desc="update_group_profile",
|
desc="update_group_profile",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_attestation_renewal(self, group_id, user_id, attestation):
|
async def update_attestation_renewal(
|
||||||
|
self, group_id: str, user_id: str, attestation: dict
|
||||||
|
) -> None:
|
||||||
"""Update an attestation that we have renewed
|
"""Update an attestation that we have renewed
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="group_attestations_renewals",
|
table="group_attestations_renewals",
|
||||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||||
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
|
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
|
||||||
desc="update_attestation_renewal",
|
desc="update_attestation_renewal",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_remote_attestion(self, group_id, user_id, attestation):
|
async def update_remote_attestion(
|
||||||
|
self, group_id: str, user_id: str, attestation: dict
|
||||||
|
) -> None:
|
||||||
"""Update an attestation that a remote has renewed
|
"""Update an attestation that a remote has renewed
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="group_attestations_remote",
|
table="group_attestations_remote",
|
||||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||||
updatevalues={
|
updatevalues={
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
|
@ -84,9 +84,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
desc="store_local_media",
|
desc="store_local_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def mark_local_media_as_safe(self, media_id: str):
|
async def mark_local_media_as_safe(self, media_id: str) -> None:
|
||||||
"""Mark a local media as safe from quarantining."""
|
"""Mark a local media as safe from quarantining."""
|
||||||
return self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="local_media_repository",
|
table="local_media_repository",
|
||||||
keyvalues={"media_id": media_id},
|
keyvalues={"media_id": media_id},
|
||||||
updatevalues={"safe_from_quarantine": True},
|
updatevalues={"safe_from_quarantine": True},
|
||||||
|
@ -158,8 +158,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
desc="store_url_cache",
|
desc="store_url_cache",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_local_media_thumbnails(self, media_id):
|
async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
|
||||||
return self.db_pool.simple_select_list(
|
return await self.db_pool.simple_select_list(
|
||||||
"local_media_repository_thumbnails",
|
"local_media_repository_thumbnails",
|
||||||
{"media_id": media_id},
|
{"media_id": media_id},
|
||||||
(
|
(
|
||||||
|
@ -271,8 +271,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
"update_cached_last_access_time", update_cache_txn
|
"update_cached_last_access_time", update_cache_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_remote_media_thumbnails(self, origin, media_id):
|
async def get_remote_media_thumbnails(
|
||||||
return self.db_pool.simple_select_list(
|
self, origin: str, media_id: str
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
return await self.db_pool.simple_select_list(
|
||||||
"remote_media_cache_thumbnails",
|
"remote_media_cache_thumbnails",
|
||||||
{"media_origin": origin, "media_id": media_id},
|
{"media_origin": origin, "media_id": media_id},
|
||||||
(
|
(
|
||||||
|
|
|
@ -71,16 +71,20 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||||
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
|
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_profile_displayname(self, user_localpart, new_displayname):
|
async def set_profile_displayname(
|
||||||
return self.db_pool.simple_update_one(
|
self, user_localpart: str, new_displayname: str
|
||||||
|
) -> None:
|
||||||
|
await self.db_pool.simple_update_one(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
updatevalues={"displayname": new_displayname},
|
updatevalues={"displayname": new_displayname},
|
||||||
desc="set_profile_displayname",
|
desc="set_profile_displayname",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
|
async def set_profile_avatar_url(
|
||||||
return self.db_pool.simple_update_one(
|
self, user_localpart: str, new_avatar_url: str
|
||||||
|
) -> None:
|
||||||
|
await self.db_pool.simple_update_one(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
updatevalues={"avatar_url": new_avatar_url},
|
updatevalues={"avatar_url": new_avatar_url},
|
||||||
|
@ -106,8 +110,10 @@ class ProfileStore(ProfileWorkerStore):
|
||||||
desc="add_remote_profile_cache",
|
desc="add_remote_profile_cache",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
|
async def update_remote_profile_cache(
|
||||||
return self.db_pool.simple_update(
|
self, user_id: str, displayname: str, avatar_url: str
|
||||||
|
) -> int:
|
||||||
|
return await self.db_pool.simple_update(
|
||||||
table="remote_profile_cache",
|
table="remote_profile_cache",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
updatevalues={
|
updatevalues={
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -62,8 +62,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
return {r["user_id"] for r in receipts}
|
return {r["user_id"] for r in receipts}
|
||||||
|
|
||||||
@cached(num_args=2)
|
@cached(num_args=2)
|
||||||
def get_receipts_for_room(self, room_id, receipt_type):
|
async def get_receipts_for_room(
|
||||||
return self.db_pool.simple_select_list(
|
self, room_id: str, receipt_type: str
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
return await self.db_pool.simple_select_list(
|
||||||
table="receipts_linearized",
|
table="receipts_linearized",
|
||||||
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
|
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
|
||||||
retcols=("user_id", "event_id"),
|
retcols=("user_id", "event_id"),
|
||||||
|
|
|
@ -578,20 +578,20 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
desc="add_user_bound_threepid",
|
desc="add_user_bound_threepid",
|
||||||
)
|
)
|
||||||
|
|
||||||
def user_get_bound_threepids(self, user_id):
|
async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
|
||||||
"""Get the threepids that a user has bound to an identity server through the homeserver
|
"""Get the threepids that a user has bound to an identity server through the homeserver
|
||||||
The homeserver remembers where binds to an identity server occurred. Using this
|
The homeserver remembers where binds to an identity server occurred. Using this
|
||||||
method can retrieve those threepids.
|
method can retrieve those threepids.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The ID of the user to retrieve threepids for
|
user_id: The ID of the user to retrieve threepids for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[list[dict]]: List of dictionaries containing the following:
|
List of dictionaries containing the following keys:
|
||||||
medium (str): The medium of the threepid (e.g "email")
|
medium (str): The medium of the threepid (e.g "email")
|
||||||
address (str): The address of the threepid (e.g "bob@example.com")
|
address (str): The address of the threepid (e.g "bob@example.com")
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_select_list(
|
return await self.db_pool.simple_select_list(
|
||||||
table="user_threepid_id_server",
|
table="user_threepid_id_server",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
retcols=["medium", "address"],
|
retcols=["medium", "address"],
|
||||||
|
@ -623,19 +623,21 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
desc="remove_user_bound_threepid",
|
desc="remove_user_bound_threepid",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_id_servers_user_bound(self, user_id, medium, address):
|
async def get_id_servers_user_bound(
|
||||||
|
self, user_id: str, medium: str, address: str
|
||||||
|
) -> List[str]:
|
||||||
"""Get the list of identity servers that the server proxied bind
|
"""Get the list of identity servers that the server proxied bind
|
||||||
requests to for given user and threepid
|
requests to for given user and threepid
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str)
|
user_id: The user to query for identity servers.
|
||||||
medium (str)
|
medium: The medium to query for identity servers.
|
||||||
address (str)
|
address: The address to query for identity servers.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[list[str]]: Resolves to a list of identity servers
|
A list of identity servers
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_select_onecol(
|
return await self.db_pool.simple_select_onecol(
|
||||||
table="user_threepid_id_server",
|
table="user_threepid_id_server",
|
||||||
keyvalues={"user_id": user_id, "medium": medium, "address": address},
|
keyvalues={"user_id": user_id, "medium": medium, "address": address},
|
||||||
retcol="id_server",
|
retcol="id_server",
|
||||||
|
|
|
@ -125,8 +125,8 @@ class RoomWorkerStore(SQLBaseStore):
|
||||||
"get_room_with_stats", get_room_with_stats_txn, room_id
|
"get_room_with_stats", get_room_with_stats_txn, room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_public_room_ids(self):
|
async def get_public_room_ids(self) -> List[str]:
|
||||||
return self.db_pool.simple_select_onecol(
|
return await self.db_pool.simple_select_onecol(
|
||||||
table="rooms",
|
table="rooms",
|
||||||
keyvalues={"is_public": True},
|
keyvalues={"is_public": True},
|
||||||
retcol="room_id",
|
retcol="room_id",
|
||||||
|
|
|
@ -537,8 +537,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
desc="get_user_in_directory",
|
desc="get_user_in_directory",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_user_directory_stream_pos(self, stream_id):
|
async def update_user_directory_stream_pos(self, stream_id: str) -> None:
|
||||||
return self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="user_directory_stream_pos",
|
table="user_directory_stream_pos",
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
updatevalues={"stream_id": stream_id},
|
updatevalues={"stream_id": stream_id},
|
||||||
|
|
|
@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_all_room_state(self):
|
async def get_all_room_state(self):
|
||||||
return self.store.db_pool.simple_select_list(
|
return await self.store.db_pool.simple_select_list(
|
||||||
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
|
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -148,9 +148,11 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
|
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
|
||||||
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
|
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
|
||||||
|
|
||||||
ret = yield self.datastore.db_pool.simple_select_list(
|
ret = yield defer.ensureDeferred(
|
||||||
|
self.datastore.db_pool.simple_select_list(
|
||||||
table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
|
table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
|
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
|
||||||
self.mock_txn.execute.assert_called_with(
|
self.mock_txn.execute.assert_called_with(
|
||||||
|
@ -161,11 +163,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
def test_update_one_1col(self):
|
def test_update_one_1col(self):
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
|
|
||||||
yield self.datastore.db_pool.simple_update_one(
|
yield defer.ensureDeferred(
|
||||||
|
self.datastore.db_pool.simple_update_one(
|
||||||
table="tablename",
|
table="tablename",
|
||||||
keyvalues={"keycol": "TheKey"},
|
keyvalues={"keycol": "TheKey"},
|
||||||
updatevalues={"columnname": "New Value"},
|
updatevalues={"columnname": "New Value"},
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.mock_txn.execute.assert_called_with(
|
self.mock_txn.execute.assert_called_with(
|
||||||
"UPDATE tablename SET columnname = ? WHERE keycol = ?",
|
"UPDATE tablename SET columnname = ? WHERE keycol = ?",
|
||||||
|
@ -176,11 +180,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
def test_update_one_4cols(self):
|
def test_update_one_4cols(self):
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
|
|
||||||
yield self.datastore.db_pool.simple_update_one(
|
yield defer.ensureDeferred(
|
||||||
|
self.datastore.db_pool.simple_update_one(
|
||||||
table="tablename",
|
table="tablename",
|
||||||
keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
|
keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
|
||||||
updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
|
updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.mock_txn.execute.assert_called_with(
|
self.mock_txn.execute.assert_called_with(
|
||||||
"UPDATE tablename SET colC = ?, colD = ? WHERE" " colA = ? AND colB = ?",
|
"UPDATE tablename SET colC = ?, colD = ? WHERE" " colA = ? AND colB = ?",
|
||||||
|
|
|
@ -42,7 +42,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
["#my-room:test"],
|
["#my-room:test"],
|
||||||
(yield self.store.get_aliases_for_room(self.room.to_string())),
|
(
|
||||||
|
yield defer.ensureDeferred(
|
||||||
|
self.store.get_aliases_for_room(self.room.to_string())
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase):
|
||||||
def test_get_users_paginate(self):
|
def test_get_users_paginate(self):
|
||||||
yield self.store.register_user(self.user.to_string(), "pass")
|
yield self.store.register_user(self.user.to_string(), "pass")
|
||||||
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
|
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
|
||||||
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
|
yield defer.ensureDeferred(
|
||||||
|
self.store.set_profile_displayname(self.user.localpart, self.displayname)
|
||||||
|
)
|
||||||
|
|
||||||
users, total = yield self.store.get_users_paginate(
|
users, total = yield self.store.get_users_paginate(
|
||||||
0, 10, name="bc", guests=False
|
0, 10, name="bc", guests=False
|
||||||
|
|
|
@ -15,8 +15,9 @@
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
|
from twisted.internet.defer import succeed
|
||||||
|
|
||||||
|
from synapse.api.errors import FederationError
|
||||||
from synapse.events import make_event_from_dict
|
from synapse.events import make_event_from_dict
|
||||||
from synapse.logging.context import LoggingContext
|
from synapse.logging.context import LoggingContext
|
||||||
from synapse.types import Requester, UserID
|
from synapse.types import Requester, UserID
|
||||||
|
@ -44,22 +45,17 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
user_id = UserID("us", "test")
|
user_id = UserID("us", "test")
|
||||||
our_user = Requester(user_id, None, False, False, None, None)
|
our_user = Requester(user_id, None, False, False, None, None)
|
||||||
room_creator = self.homeserver.get_room_creation_handler()
|
room_creator = self.homeserver.get_room_creation_handler()
|
||||||
room_deferred = ensureDeferred(
|
self.room_id = self.get_success(
|
||||||
room_creator.create_room(
|
room_creator.create_room(
|
||||||
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
|
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
|
||||||
)
|
)
|
||||||
)
|
)[0]["room_id"]
|
||||||
self.reactor.advance(0.1)
|
|
||||||
self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
|
|
||||||
|
|
||||||
self.store = self.homeserver.get_datastore()
|
self.store = self.homeserver.get_datastore()
|
||||||
|
|
||||||
# Figure out what the most recent event is
|
# Figure out what the most recent event is
|
||||||
most_recent = self.successResultOf(
|
most_recent = self.get_success(
|
||||||
maybeDeferred(
|
self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id)
|
||||||
self.homeserver.get_datastore().get_latest_event_ids_in_room,
|
|
||||||
self.room_id,
|
|
||||||
)
|
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
join_event = make_event_from_dict(
|
join_event = make_event_from_dict(
|
||||||
|
@ -89,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send the join, it should return None (which is not an error)
|
# Send the join, it should return None (which is not an error)
|
||||||
d = ensureDeferred(
|
self.assertEqual(
|
||||||
|
self.get_success(
|
||||||
self.handler.on_receive_pdu(
|
self.handler.on_receive_pdu(
|
||||||
"test.serv", join_event, sent_to_us_directly=True
|
"test.serv", join_event, sent_to_us_directly=True
|
||||||
)
|
)
|
||||||
|
),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
self.reactor.advance(1)
|
|
||||||
self.assertEqual(self.successResultOf(d), None)
|
|
||||||
|
|
||||||
# Make sure we actually joined the room
|
# Make sure we actually joined the room
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.successResultOf(
|
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
|
||||||
maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
|
|
||||||
)[0],
|
|
||||||
"$join:test.serv",
|
"$join:test.serv",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -119,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
self.http_client.post_json = post_json
|
self.http_client.post_json = post_json
|
||||||
|
|
||||||
# Figure out what the most recent event is
|
# Figure out what the most recent event is
|
||||||
most_recent = self.successResultOf(
|
most_recent = self.get_success(
|
||||||
maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
|
self.store.get_latest_event_ids_in_room(self.room_id)
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# Now lie about an event
|
# Now lie about an event
|
||||||
|
@ -140,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
with LoggingContext(request="lying_event"):
|
with LoggingContext(request="lying_event"):
|
||||||
d = ensureDeferred(
|
failure = self.get_failure(
|
||||||
self.handler.on_receive_pdu(
|
self.handler.on_receive_pdu(
|
||||||
"test.serv", lying_event, sent_to_us_directly=True
|
"test.serv", lying_event, sent_to_us_directly=True
|
||||||
|
),
|
||||||
|
FederationError,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Step the reactor, so the database fetches come back
|
|
||||||
self.reactor.advance(1)
|
|
||||||
|
|
||||||
# on_receive_pdu should throw an error
|
# on_receive_pdu should throw an error
|
||||||
failure = self.failureResultOf(d)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
failure.value.args[0],
|
failure.value.args[0],
|
||||||
(
|
(
|
||||||
|
@ -160,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure the invalid event isn't there
|
# Make sure the invalid event isn't there
|
||||||
extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
|
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
|
||||||
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
|
self.assertEqual(extrem[0], "$join:test.serv")
|
||||||
|
|
||||||
def test_retry_device_list_resync(self):
|
def test_retry_device_list_resync(self):
|
||||||
"""Tests that device lists are marked as stale if they couldn't be synced, and
|
"""Tests that device lists are marked as stale if they couldn't be synced, and
|
||||||
|
|
Loading…
Reference in a new issue