mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 18:33:23 +01:00
Remove remaining usage of cursor_to_dict. (#16564)
This commit is contained in:
parent
c0ba319b22
commit
cfb6d38c47
18 changed files with 300 additions and 157 deletions
1
changelog.d/16564.misc
Normal file
1
changelog.d/16564.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Improve type hints.
|
|
@ -283,7 +283,7 @@ class AdminHandler:
|
|||
start, limit, user_id
|
||||
)
|
||||
for media in media_ids:
|
||||
writer.write_media_id(media["media_id"], media)
|
||||
writer.write_media_id(media.media_id, attr.asdict(media))
|
||||
|
||||
logger.info(
|
||||
"[%s] Written %d media_ids of %s",
|
||||
|
|
|
@ -33,6 +33,7 @@ from synapse.api.errors import (
|
|||
RequestSendFailed,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.storage.databases.main.room import LargestRoomStats
|
||||
from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
@ -170,26 +171,24 @@ class RoomListHandler:
|
|||
ignore_non_federatable=from_federation,
|
||||
)
|
||||
|
||||
def build_room_entry(room: JsonDict) -> JsonDict:
|
||||
def build_room_entry(room: LargestRoomStats) -> JsonDict:
|
||||
entry = {
|
||||
"room_id": room["room_id"],
|
||||
"name": room["name"],
|
||||
"topic": room["topic"],
|
||||
"canonical_alias": room["canonical_alias"],
|
||||
"num_joined_members": room["joined_members"],
|
||||
"avatar_url": room["avatar"],
|
||||
"world_readable": room["history_visibility"]
|
||||
"room_id": room.room_id,
|
||||
"name": room.name,
|
||||
"topic": room.topic,
|
||||
"canonical_alias": room.canonical_alias,
|
||||
"num_joined_members": room.joined_members,
|
||||
"avatar_url": room.avatar,
|
||||
"world_readable": room.history_visibility
|
||||
== HistoryVisibility.WORLD_READABLE,
|
||||
"guest_can_join": room["guest_access"] == "can_join",
|
||||
"join_rule": room["join_rules"],
|
||||
"room_type": room["room_type"],
|
||||
"guest_can_join": room.guest_access == "can_join",
|
||||
"join_rule": room.join_rules,
|
||||
"room_type": room.room_type,
|
||||
}
|
||||
|
||||
# Filter out Nones – rather omit the field altogether
|
||||
return {k: v for k, v in entry.items() if v is not None}
|
||||
|
||||
results = [build_room_entry(r) for r in results]
|
||||
|
||||
response: JsonDict = {}
|
||||
num_results = len(results)
|
||||
if limit is not None:
|
||||
|
@ -212,33 +211,33 @@ class RoomListHandler:
|
|||
# If there was a token given then we assume that there
|
||||
# must be previous results.
|
||||
response["prev_batch"] = RoomListNextBatch(
|
||||
last_joined_members=initial_entry["num_joined_members"],
|
||||
last_room_id=initial_entry["room_id"],
|
||||
last_joined_members=initial_entry.joined_members,
|
||||
last_room_id=initial_entry.room_id,
|
||||
direction_is_forward=False,
|
||||
).to_token()
|
||||
|
||||
if more_to_come:
|
||||
response["next_batch"] = RoomListNextBatch(
|
||||
last_joined_members=final_entry["num_joined_members"],
|
||||
last_room_id=final_entry["room_id"],
|
||||
last_joined_members=final_entry.joined_members,
|
||||
last_room_id=final_entry.room_id,
|
||||
direction_is_forward=True,
|
||||
).to_token()
|
||||
else:
|
||||
if has_batch_token:
|
||||
response["next_batch"] = RoomListNextBatch(
|
||||
last_joined_members=final_entry["num_joined_members"],
|
||||
last_room_id=final_entry["room_id"],
|
||||
last_joined_members=final_entry.joined_members,
|
||||
last_room_id=final_entry.room_id,
|
||||
direction_is_forward=True,
|
||||
).to_token()
|
||||
|
||||
if more_to_come:
|
||||
response["prev_batch"] = RoomListNextBatch(
|
||||
last_joined_members=initial_entry["num_joined_members"],
|
||||
last_room_id=initial_entry["room_id"],
|
||||
last_joined_members=initial_entry.joined_members,
|
||||
last_room_id=initial_entry.room_id,
|
||||
direction_is_forward=False,
|
||||
).to_token()
|
||||
|
||||
response["chunk"] = results
|
||||
response["chunk"] = [build_room_entry(r) for r in results]
|
||||
|
||||
response["total_room_count_estimate"] = await self.store.count_public_rooms(
|
||||
network_tuple,
|
||||
|
|
|
@ -703,24 +703,24 @@ class RoomSummaryHandler:
|
|||
# there should always be an entry
|
||||
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
|
||||
|
||||
entry = {
|
||||
"room_id": stats["room_id"],
|
||||
"name": stats["name"],
|
||||
"topic": stats["topic"],
|
||||
"canonical_alias": stats["canonical_alias"],
|
||||
"num_joined_members": stats["joined_members"],
|
||||
"avatar_url": stats["avatar"],
|
||||
"join_rule": stats["join_rules"],
|
||||
entry: JsonDict = {
|
||||
"room_id": stats.room_id,
|
||||
"name": stats.name,
|
||||
"topic": stats.topic,
|
||||
"canonical_alias": stats.canonical_alias,
|
||||
"num_joined_members": stats.joined_members,
|
||||
"avatar_url": stats.avatar,
|
||||
"join_rule": stats.join_rules,
|
||||
"world_readable": (
|
||||
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
|
||||
stats.history_visibility == HistoryVisibility.WORLD_READABLE
|
||||
),
|
||||
"guest_can_join": stats["guest_access"] == "can_join",
|
||||
"room_type": stats["room_type"],
|
||||
"guest_can_join": stats.guest_access == "can_join",
|
||||
"room_type": stats.room_type,
|
||||
}
|
||||
|
||||
if self._msc3266_enabled:
|
||||
entry["im.nheko.summary.version"] = stats["version"]
|
||||
entry["im.nheko.summary.encryption"] = stats["encryption"]
|
||||
entry["im.nheko.summary.version"] = stats.version
|
||||
entry["im.nheko.summary.encryption"] = stats.encryption
|
||||
|
||||
# Federation requests need to provide additional information so the
|
||||
# requested server is able to filter the response appropriately.
|
||||
|
|
|
@ -17,6 +17,8 @@ import logging
|
|||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
|
@ -418,7 +420,7 @@ class UserMediaRestServlet(RestServlet):
|
|||
start, limit, user_id, order_by, direction
|
||||
)
|
||||
|
||||
ret = {"media": media, "total": total}
|
||||
ret = {"media": [attr.asdict(m) for m in media], "total": total}
|
||||
if (start + limit) < total:
|
||||
ret["next_token"] = start + len(media)
|
||||
|
||||
|
@ -477,7 +479,7 @@ class UserMediaRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
deleted_media, total = await self.media_repository.delete_local_media_ids(
|
||||
[row["media_id"] for row in media]
|
||||
[m.media_id for m in media]
|
||||
)
|
||||
|
||||
return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
|
||||
|
|
|
@ -77,7 +77,18 @@ class ListRegistrationTokensRestServlet(RestServlet):
|
|||
await assert_requester_is_admin(self.auth, request)
|
||||
valid = parse_boolean(request, "valid")
|
||||
token_list = await self.store.get_registration_tokens(valid)
|
||||
return HTTPStatus.OK, {"registration_tokens": token_list}
|
||||
return HTTPStatus.OK, {
|
||||
"registration_tokens": [
|
||||
{
|
||||
"token": t[0],
|
||||
"uses_allowed": t[1],
|
||||
"pending": t[2],
|
||||
"completed": t[3],
|
||||
"expiry_time": t[4],
|
||||
}
|
||||
for t in token_list
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class NewRegistrationTokenRestServlet(RestServlet):
|
||||
|
|
|
@ -16,6 +16,8 @@ from http import HTTPStatus
|
|||
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
||||
from urllib import parse as urlparse
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
|
@ -306,10 +308,13 @@ class RoomRestServlet(RestServlet):
|
|||
raise NotFoundError("Room not found")
|
||||
|
||||
members = await self.store.get_users_in_room(room_id)
|
||||
ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
|
||||
ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
|
||||
result = attr.asdict(ret)
|
||||
result["joined_local_devices"] = await self.store.count_devices_by_users(
|
||||
members
|
||||
)
|
||||
result["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
|
||||
|
||||
return HTTPStatus.OK, ret
|
||||
return HTTPStatus.OK, result
|
||||
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
|
|
|
@ -18,6 +18,8 @@ import secrets
|
|||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import Direction, UserTypes
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
|
@ -161,11 +163,13 @@ class UsersRestServletV2(RestServlet):
|
|||
)
|
||||
|
||||
# If support for MSC3866 is not enabled, don't show the approval flag.
|
||||
filter = None
|
||||
if not self._msc3866_enabled:
|
||||
for user in users:
|
||||
del user["approved"]
|
||||
|
||||
ret = {"users": users, "total": total}
|
||||
def _filter(a: attr.Attribute) -> bool:
|
||||
return a.name != "approved"
|
||||
|
||||
ret = {"users": [attr.asdict(u, filter=filter) for u in users], "total": total}
|
||||
if (start + limit) < total:
|
||||
ret["next_token"] = str(start + len(users))
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ from typing import (
|
|||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
@ -488,14 +489,14 @@ class BackgroundUpdater:
|
|||
True if we have finished running all the background updates, otherwise False
|
||||
"""
|
||||
|
||||
def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
|
||||
def get_background_updates_txn(txn: Cursor) -> List[Tuple[str, Optional[str]]]:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT update_name, depends_on FROM background_updates
|
||||
ORDER BY ordering, update_name
|
||||
"""
|
||||
)
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
return cast(List[Tuple[str, Optional[str]]], txn.fetchall())
|
||||
|
||||
if not self._current_background_update:
|
||||
all_pending_updates = await self.db_pool.runInteraction(
|
||||
|
@ -507,14 +508,13 @@ class BackgroundUpdater:
|
|||
return True
|
||||
|
||||
# find the first update which isn't dependent on another one in the queue.
|
||||
pending = {update["update_name"] for update in all_pending_updates}
|
||||
for upd in all_pending_updates:
|
||||
depends_on = upd["depends_on"]
|
||||
pending = {update_name for update_name, depends_on in all_pending_updates}
|
||||
for update_name, depends_on in all_pending_updates:
|
||||
if not depends_on or depends_on not in pending:
|
||||
break
|
||||
logger.info(
|
||||
"Not starting on bg update %s until %s is done",
|
||||
upd["update_name"],
|
||||
update_name,
|
||||
depends_on,
|
||||
)
|
||||
else:
|
||||
|
@ -524,7 +524,7 @@ class BackgroundUpdater:
|
|||
"another: dependency cycle?"
|
||||
)
|
||||
|
||||
self._current_background_update = upd["update_name"]
|
||||
self._current_background_update = update_name
|
||||
|
||||
# We have a background update to run, otherwise we would have returned
|
||||
# early.
|
||||
|
|
|
@ -18,7 +18,6 @@ import logging
|
|||
import time
|
||||
import types
|
||||
from collections import defaultdict
|
||||
from sys import intern
|
||||
from time import monotonic as monotonic_time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
|
@ -1042,20 +1041,6 @@ class DatabasePool:
|
|||
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
|
||||
"""Converts a SQL cursor into an list of dicts.
|
||||
|
||||
Args:
|
||||
cursor: The DBAPI cursor which has executed a query.
|
||||
Returns:
|
||||
A list of dicts where the key is the column header.
|
||||
"""
|
||||
assert cursor.description is not None, "cursor.description was None"
|
||||
col_headers = [intern(str(column[0])) for column in cursor.description]
|
||||
results = [dict(zip(col_headers, row)) for row in cursor]
|
||||
return results
|
||||
|
||||
async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
|
||||
"""Runs a single query for a result set.
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.storage._base import make_in_list_sql_clause
|
||||
|
@ -28,7 +30,7 @@ from synapse.storage.database import (
|
|||
from synapse.storage.databases.main.stats import UserSortOrder
|
||||
from synapse.storage.engines import BaseDatabaseEngine
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
from synapse.types import get_domain_from_id
|
||||
|
||||
from .account_data import AccountDataStore
|
||||
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
|
||||
|
@ -82,6 +84,25 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class UserPaginateResponse:
|
||||
"""This is very similar to UserInfo, but not quite the same."""
|
||||
|
||||
name: str
|
||||
user_type: Optional[str]
|
||||
is_guest: bool
|
||||
admin: bool
|
||||
deactivated: bool
|
||||
shadow_banned: bool
|
||||
displayname: Optional[str]
|
||||
avatar_url: Optional[str]
|
||||
creation_ts: Optional[int]
|
||||
approved: bool
|
||||
erased: bool
|
||||
last_seen_ts: int
|
||||
locked: bool
|
||||
|
||||
|
||||
class DataStore(
|
||||
EventsBackgroundUpdatesStore,
|
||||
ExperimentalFeaturesStore,
|
||||
|
@ -156,7 +177,7 @@ class DataStore(
|
|||
approved: bool = True,
|
||||
not_user_types: Optional[List[str]] = None,
|
||||
locked: bool = False,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
) -> Tuple[List[UserPaginateResponse], int]:
|
||||
"""Function to retrieve a paginated list of users from
|
||||
users list. This will return a json list of users and the
|
||||
total number of users matching the filter criteria.
|
||||
|
@ -182,7 +203,7 @@ class DataStore(
|
|||
|
||||
def get_users_paginate_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
) -> Tuple[List[UserPaginateResponse], int]:
|
||||
filters = []
|
||||
args: list = []
|
||||
|
||||
|
@ -282,13 +303,24 @@ class DataStore(
|
|||
"""
|
||||
args += [limit, start]
|
||||
txn.execute(sql, args)
|
||||
users = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
# some of those boolean values are returned as integers when we're on SQLite
|
||||
columns_to_boolify = ["erased"]
|
||||
for user in users:
|
||||
for column in columns_to_boolify:
|
||||
user[column] = bool(user[column])
|
||||
users = [
|
||||
UserPaginateResponse(
|
||||
name=row[0],
|
||||
user_type=row[1],
|
||||
is_guest=bool(row[2]),
|
||||
admin=bool(row[3]),
|
||||
deactivated=bool(row[4]),
|
||||
shadow_banned=bool(row[5]),
|
||||
displayname=row[6],
|
||||
avatar_url=row[7],
|
||||
creation_ts=row[8],
|
||||
approved=bool(row[9]),
|
||||
erased=bool(row[10]),
|
||||
last_seen_ts=row[11],
|
||||
locked=bool(row[12]),
|
||||
)
|
||||
for row in txn
|
||||
]
|
||||
|
||||
return users, count
|
||||
|
||||
|
|
|
@ -1620,7 +1620,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
#
|
||||
# For each duplicate, we delete all the existing rows and put one back.
|
||||
|
||||
KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
|
||||
last_row = progress.get(
|
||||
"last_row",
|
||||
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
|
||||
|
@ -1628,44 +1627,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
|
||||
def _txn(txn: LoggingTransaction) -> int:
|
||||
clause, args = make_tuple_comparison_clause(
|
||||
[(x, last_row[x]) for x in KEY_COLS]
|
||||
[
|
||||
("stream_id", last_row["stream_id"]),
|
||||
("destination", last_row["destination"]),
|
||||
("user_id", last_row["user_id"]),
|
||||
("device_id", last_row["device_id"]),
|
||||
]
|
||||
)
|
||||
sql = """
|
||||
sql = f"""
|
||||
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
|
||||
FROM device_lists_outbound_pokes
|
||||
WHERE %s
|
||||
GROUP BY %s
|
||||
WHERE {clause}
|
||||
GROUP BY stream_id, destination, user_id, device_id
|
||||
HAVING count(*) > 1
|
||||
ORDER BY %s
|
||||
ORDER BY stream_id, destination, user_id, device_id
|
||||
LIMIT ?
|
||||
""" % (
|
||||
clause, # WHERE
|
||||
",".join(KEY_COLS), # GROUP BY
|
||||
",".join(KEY_COLS), # ORDER BY
|
||||
)
|
||||
"""
|
||||
txn.execute(sql, args + [batch_size])
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
rows = txn.fetchall()
|
||||
|
||||
row = None
|
||||
for row in rows:
|
||||
stream_id, destination, user_id, device_id = None, None, None, None
|
||||
for stream_id, destination, user_id, device_id, _ in rows:
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
"device_lists_outbound_pokes",
|
||||
{x: row[x] for x in KEY_COLS},
|
||||
{
|
||||
"stream_id": stream_id,
|
||||
"destination": destination,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
)
|
||||
|
||||
row["sent"] = False
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
"device_lists_outbound_pokes",
|
||||
row,
|
||||
{
|
||||
"stream_id": stream_id,
|
||||
"destination": destination,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"sent": False,
|
||||
},
|
||||
)
|
||||
|
||||
if row:
|
||||
if rows:
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn,
|
||||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
|
||||
{"last_row": row},
|
||||
{
|
||||
"last_row": {
|
||||
"stream_id": stream_id,
|
||||
"destination": destination,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return len(rows)
|
||||
|
|
|
@ -26,6 +26,8 @@ from typing import (
|
|||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.media._base import ThumbnailInfo
|
||||
|
@ -45,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
|
|||
)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class LocalMedia:
|
||||
media_id: str
|
||||
media_type: str
|
||||
media_length: int
|
||||
upload_name: str
|
||||
created_ts: int
|
||||
last_access_ts: int
|
||||
quarantined_by: Optional[str]
|
||||
safe_from_quarantine: bool
|
||||
|
||||
|
||||
class MediaSortOrder(Enum):
|
||||
"""
|
||||
Enum to define the sorting method used when returning media with
|
||||
|
@ -180,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
user_id: str,
|
||||
order_by: str = MediaSortOrder.CREATED_TS.value,
|
||||
direction: Direction = Direction.FORWARDS,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
) -> Tuple[List[LocalMedia], int]:
|
||||
"""Get a paginated list of metadata for a local piece of media
|
||||
which an user_id has uploaded
|
||||
|
||||
|
@ -197,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
def get_local_media_by_user_paginate_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
) -> Tuple[List[LocalMedia], int]:
|
||||
# Set ordering
|
||||
order_by_column = MediaSortOrder(order_by).value
|
||||
|
||||
|
@ -217,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
sql = """
|
||||
SELECT
|
||||
"media_id",
|
||||
"media_type",
|
||||
"media_length",
|
||||
"upload_name",
|
||||
"created_ts",
|
||||
"last_access_ts",
|
||||
"quarantined_by",
|
||||
"safe_from_quarantine"
|
||||
media_id,
|
||||
media_type,
|
||||
media_length,
|
||||
upload_name,
|
||||
created_ts,
|
||||
last_access_ts,
|
||||
quarantined_by,
|
||||
safe_from_quarantine
|
||||
FROM local_media_repository
|
||||
WHERE user_id = ?
|
||||
ORDER BY {order_by_column} {order}, media_id ASC
|
||||
|
@ -236,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
args += [limit, start]
|
||||
txn.execute(sql, args)
|
||||
media = self.db_pool.cursor_to_dict(txn)
|
||||
media = [
|
||||
LocalMedia(
|
||||
media_id=row[0],
|
||||
media_type=row[1],
|
||||
media_length=row[2],
|
||||
upload_name=row[3],
|
||||
created_ts=row[4],
|
||||
last_access_ts=row[5],
|
||||
quarantined_by=row[6],
|
||||
safe_from_quarantine=bool(row[7]),
|
||||
)
|
||||
for row in txn
|
||||
]
|
||||
return media, count
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
|
|
@ -1517,7 +1517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
async def get_registration_tokens(
|
||||
self, valid: Optional[bool] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
|
||||
"""List all registration tokens. Used by the admin API.
|
||||
|
||||
Args:
|
||||
|
@ -1526,34 +1526,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
Default is None: return all tokens regardless of validity.
|
||||
|
||||
Returns:
|
||||
A list of dicts, each containing details of a token.
|
||||
A list of tuples containing:
|
||||
* The token
|
||||
* The number of users allowed (or None)
|
||||
* Whether it is pending
|
||||
* Whether it has been completed
|
||||
* An expiry time (or None if no expiry)
|
||||
"""
|
||||
|
||||
def select_registration_tokens_txn(
|
||||
txn: LoggingTransaction, now: int, valid: Optional[bool]
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
|
||||
if valid is None:
|
||||
# Return all tokens regardless of validity
|
||||
txn.execute("SELECT * FROM registration_tokens")
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT token, uses_allowed, pending, completed, expiry_time
|
||||
FROM registration_tokens
|
||||
"""
|
||||
)
|
||||
|
||||
elif valid:
|
||||
# Select valid tokens only
|
||||
sql = (
|
||||
"SELECT * FROM registration_tokens WHERE "
|
||||
"(uses_allowed > pending + completed OR uses_allowed IS NULL) "
|
||||
"AND (expiry_time > ? OR expiry_time IS NULL)"
|
||||
)
|
||||
sql = """
|
||||
SELECT token, uses_allowed, pending, completed, expiry_time
|
||||
FROM registration_tokens
|
||||
WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL)
|
||||
AND (expiry_time > ? OR expiry_time IS NULL)
|
||||
"""
|
||||
txn.execute(sql, [now])
|
||||
|
||||
else:
|
||||
# Select invalid tokens only
|
||||
sql = (
|
||||
"SELECT * FROM registration_tokens WHERE "
|
||||
"uses_allowed <= pending + completed OR expiry_time <= ?"
|
||||
)
|
||||
sql = """
|
||||
SELECT token, uses_allowed, pending, completed, expiry_time
|
||||
FROM registration_tokens
|
||||
WHERE uses_allowed <= pending + completed OR expiry_time <= ?
|
||||
"""
|
||||
txn.execute(sql, [now])
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
return cast(
|
||||
List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall()
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"select_registration_tokens",
|
||||
|
|
|
@ -78,6 +78,31 @@ class RatelimitOverride:
|
|||
burst_count: int
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class LargestRoomStats:
|
||||
room_id: str
|
||||
name: Optional[str]
|
||||
canonical_alias: Optional[str]
|
||||
joined_members: int
|
||||
join_rules: Optional[str]
|
||||
guest_access: Optional[str]
|
||||
history_visibility: Optional[str]
|
||||
state_events: int
|
||||
avatar: Optional[str]
|
||||
topic: Optional[str]
|
||||
room_type: Optional[str]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class RoomStats(LargestRoomStats):
|
||||
joined_local_members: int
|
||||
version: Optional[str]
|
||||
creator: Optional[str]
|
||||
encryption: Optional[str]
|
||||
federatable: bool
|
||||
public: bool
|
||||
|
||||
|
||||
class RoomSortOrder(Enum):
|
||||
"""
|
||||
Enum to define the sorting method used when returning rooms with get_rooms_paginate
|
||||
|
@ -204,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
allow_none=True,
|
||||
)
|
||||
|
||||
async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
|
||||
"""Retrieve room with statistics.
|
||||
|
||||
Args:
|
||||
|
@ -215,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
def get_room_with_stats_txn(
|
||||
txn: LoggingTransaction, room_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[RoomStats]:
|
||||
sql = """
|
||||
SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
|
||||
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
|
||||
|
@ -229,15 +254,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
WHERE room_id = ?
|
||||
"""
|
||||
txn.execute(sql, [room_id])
|
||||
# Catch error if sql returns empty result to return "None" instead of an error
|
||||
try:
|
||||
res = self.db_pool.cursor_to_dict(txn)[0]
|
||||
except IndexError:
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
res["federatable"] = bool(res["federatable"])
|
||||
res["public"] = bool(res["public"])
|
||||
return res
|
||||
return RoomStats(
|
||||
room_id=row[0],
|
||||
name=row[1],
|
||||
canonical_alias=row[2],
|
||||
joined_members=row[3],
|
||||
joined_local_members=row[4],
|
||||
version=row[5],
|
||||
creator=row[6],
|
||||
encryption=row[7],
|
||||
federatable=bool(row[8]),
|
||||
public=bool(row[9]),
|
||||
join_rules=row[10],
|
||||
guest_access=row[11],
|
||||
history_visibility=row[12],
|
||||
state_events=row[13],
|
||||
avatar=row[14],
|
||||
topic=row[15],
|
||||
room_type=row[16],
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_room_with_stats", get_room_with_stats_txn, room_id
|
||||
|
@ -368,7 +406,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
bounds: Optional[Tuple[int, str]],
|
||||
forwards: bool,
|
||||
ignore_non_federatable: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> List[LargestRoomStats]:
|
||||
"""Gets the largest public rooms (where largest is in terms of joined
|
||||
members, as tracked in the statistics table).
|
||||
|
||||
|
@ -505,20 +543,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
def _get_largest_public_rooms_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> List[LargestRoomStats]:
|
||||
txn.execute(sql, query_args)
|
||||
|
||||
results = self.db_pool.cursor_to_dict(txn)
|
||||
results = [
|
||||
LargestRoomStats(
|
||||
room_id=r[0],
|
||||
name=r[1],
|
||||
canonical_alias=r[3],
|
||||
joined_members=r[4],
|
||||
join_rules=r[8],
|
||||
guest_access=r[7],
|
||||
history_visibility=r[6],
|
||||
state_events=0,
|
||||
avatar=r[5],
|
||||
topic=r[2],
|
||||
room_type=r[9],
|
||||
)
|
||||
for r in txn
|
||||
]
|
||||
|
||||
if not forwards:
|
||||
results.reverse()
|
||||
|
||||
return results
|
||||
|
||||
ret_val = await self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_largest_public_rooms", _get_largest_public_rooms_txn
|
||||
)
|
||||
return ret_val
|
||||
|
||||
@cached(max_entries=10000)
|
||||
async def is_room_blocked(self, room_id: str) -> Optional[bool]:
|
||||
|
|
|
@ -342,10 +342,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
# Ensure the room is properly not federated.
|
||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
||||
assert room is not None
|
||||
self.assertFalse(room["federatable"])
|
||||
self.assertFalse(room["public"])
|
||||
self.assertEqual(room["join_rules"], "public")
|
||||
self.assertIsNone(room["guest_access"])
|
||||
self.assertFalse(room.federatable)
|
||||
self.assertFalse(room.public)
|
||||
self.assertEqual(room.join_rules, "public")
|
||||
self.assertIsNone(room.guest_access)
|
||||
|
||||
# The user should be in the room.
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||
|
@ -372,7 +372,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
# Ensure the room is properly a public room.
|
||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
||||
assert room is not None
|
||||
self.assertEqual(room["join_rules"], "public")
|
||||
self.assertEqual(room.join_rules, "public")
|
||||
|
||||
# Both users should be in the room.
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(inviter))
|
||||
|
@ -411,9 +411,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
# Ensure the room is properly a private room.
|
||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
||||
assert room is not None
|
||||
self.assertFalse(room["public"])
|
||||
self.assertEqual(room["join_rules"], "invite")
|
||||
self.assertEqual(room["guest_access"], "can_join")
|
||||
self.assertFalse(room.public)
|
||||
self.assertEqual(room.join_rules, "invite")
|
||||
self.assertEqual(room.guest_access, "can_join")
|
||||
|
||||
# Both users should be in the room.
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(inviter))
|
||||
|
@ -455,9 +455,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
# Ensure the room is properly a private room.
|
||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
||||
assert room is not None
|
||||
self.assertFalse(room["public"])
|
||||
self.assertEqual(room["join_rules"], "invite")
|
||||
self.assertEqual(room["guest_access"], "can_join")
|
||||
self.assertFalse(room.public)
|
||||
self.assertEqual(room.join_rules, "invite")
|
||||
self.assertEqual(room.guest_access, "can_join")
|
||||
|
||||
# Both users should be in the room.
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(inviter))
|
||||
|
|
|
@ -39,11 +39,11 @@ class DataStoreTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
self.assertEqual(1, total)
|
||||
self.assertEqual(self.displayname, users.pop()["displayname"])
|
||||
self.assertEqual(self.displayname, users.pop().displayname)
|
||||
|
||||
users, total = self.get_success(
|
||||
self.store.get_users_paginate(0, 10, name="BC", guests=False)
|
||||
)
|
||||
|
||||
self.assertEqual(1, total)
|
||||
self.assertEqual(self.displayname, users.pop()["displayname"])
|
||||
self.assertEqual(self.displayname, users.pop().displayname)
|
||||
|
|
|
@ -59,14 +59,9 @@ class RoomStoreTestCase(HomeserverTestCase):
|
|||
def test_get_room_with_stats(self) -> None:
|
||||
res = self.get_success(self.store.get_room_with_stats(self.room.to_string()))
|
||||
assert res is not None
|
||||
self.assertLessEqual(
|
||||
{
|
||||
"room_id": self.room.to_string(),
|
||||
"creator": self.u_creator.to_string(),
|
||||
"public": True,
|
||||
}.items(),
|
||||
res.items(),
|
||||
)
|
||||
self.assertEqual(res.room_id, self.room.to_string())
|
||||
self.assertEqual(res.creator, self.u_creator.to_string())
|
||||
self.assertTrue(res.public)
|
||||
|
||||
def test_get_room_with_stats_unknown_room(self) -> None:
|
||||
self.assertIsNone(
|
||||
|
|
Loading…
Reference in a new issue