Improve typing in user_directory files (#10891)

* Improve typing in user_directory files

This makes storage's `user_directory.py` pass most of mypy's
checks (including `no-untyped-defs`). Unfortunately, that file is caught in the
tangled web of Store class inheritance, so it doesn't fully pass mypy at the moment.

The handlers directory is already type-checked with mypy.

Co-authored-by: reivilibre <olivier@librepush.net>
This commit is contained in:
David Robertson 2021-09-24 10:38:22 +01:00 committed by GitHub
parent e704cc2a48
commit 7f3352743e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 95 additions and 37 deletions

1
changelog.d/10891.misc Normal file
View file

@ -0,0 +1 @@
Improve type hinting in the user directory code.

View file

@ -85,9 +85,11 @@ files =
tests/handlers/test_room_summary.py,
tests/handlers/test_send_email.py,
tests/handlers/test_sync.py,
tests/handlers/test_user_directory.py,
tests/rest/client/test_login.py,
tests/rest/client/test_auth.py,
tests/storage/test_state.py,
tests/storage/test_user_directory.py,
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

View file

@ -14,14 +14,28 @@
import logging
import re
from typing import Any, Dict, Iterable, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
cast,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
from synapse.storage.types import Connection
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(
self,
database: DatabasePool,
db_conn: Connection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@ -57,10 +76,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)
async def _populate_user_directory_createtables(self, progress, batch_size):
async def _populate_user_directory_createtables(
self, progress: JsonDict, batch_size: int
) -> int:
# Get all the rooms that we want to process.
def _make_staging_area(txn):
def _make_staging_area(txn: LoggingTransaction) -> None:
sql = (
"CREATE TABLE IF NOT EXISTS "
+ TEMP_TABLE
@ -110,16 +131,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
return 1
async def _populate_user_directory_cleanup(self, progress, batch_size):
async def _populate_user_directory_cleanup(
self,
progress: JsonDict,
batch_size: int,
) -> int:
"""
Update the user directory stream position, then clean up the old tables.
"""
position = await self.db_pool.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
TEMP_TABLE + "_position", {}, "position"
)
await self.update_user_directory_stream_pos(position)
def _delete_staging_area(txn):
def _delete_staging_area(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
@ -133,18 +158,32 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
return 1
async def _populate_user_directory_process_rooms(self, progress, batch_size):
async def _populate_user_directory_process_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Rescan the state of all rooms so we can track
- who's in a public room;
- which local users share a private room with other users (local
and remote); and
- who should be in the user_directory.
Args:
progress (dict)
batch_size (int): Maximum number of state events to process
per cycle.
Returns:
number of events processed.
"""
# If we don't have progress filed, delete everything.
if not progress:
await self.delete_all_from_user_dir()
def _get_next_batch(txn):
def _get_next_batch(
txn: LoggingTransaction,
) -> Optional[Sequence[Tuple[str, int]]]:
# Only fetch 250 rooms, so we don't fetch too many at once, even
# if those 250 rooms have less than batch_size state events.
sql = """
@ -155,7 +194,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
TEMP_TABLE + "_rooms",
)
txn.execute(sql)
rooms_to_work_on = txn.fetchall()
rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())
if not rooms_to_work_on:
return None
@ -163,7 +202,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# Get how many are left to process, so we can give status on how
# far we are in processing
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
progress["remaining"] = txn.fetchone()[0]
result = txn.fetchone()
assert result is not None
progress["remaining"] = result[0]
return rooms_to_work_on
@ -261,29 +302,33 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return processed_event_count
async def _populate_user_directory_process_users(self, progress, batch_size):
async def _populate_user_directory_process_users(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Add all local users to the user directory.
"""
def _get_next_batch(txn):
def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
sql = "SELECT user_id FROM %s LIMIT %s" % (
TEMP_TABLE + "_users",
str(batch_size),
)
txn.execute(sql)
users_to_work_on = txn.fetchall()
user_result = cast(List[Tuple[str]], txn.fetchall())
if not users_to_work_on:
if not user_result:
return None
users_to_work_on = [x[0] for x in users_to_work_on]
users_to_work_on = [x[0] for x in user_result]
# Get how many are left to process, so we can give status on how
# far we are in processing
sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
txn.execute(sql)
progress["remaining"] = txn.fetchone()[0]
count_result = txn.fetchone()
assert count_result is not None
progress["remaining"] = count_result[0]
return users_to_work_on
@ -324,7 +369,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on)
async def is_room_world_readable_or_publicly_joinable(self, room_id):
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publicly joinable"""
# Create a state filter that only queries join and history state event
@ -368,7 +413,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
if not isinstance(avatar_url, str):
avatar_url = None
def _update_profile_in_user_dir_txn(txn):
def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="user_directory",
@ -435,7 +480,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
for user_id, other_user_id in user_id_tuples
],
value_names=(),
value_values=None,
value_values=(),
desc="add_users_who_share_room",
)
@ -454,14 +499,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
key_names=["user_id", "room_id"],
key_values=[(user_id, room_id) for user_id in user_ids],
value_names=(),
value_values=None,
value_values=(),
desc="add_users_in_public_rooms",
)
async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory"""
def _delete_all_from_user_dir_txn(txn):
def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
@ -473,7 +518,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
@ -497,7 +542,12 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(
self,
database: DatabasePool,
db_conn: Connection,
hs: "HomeServer",
) -> None:
super().__init__(database, db_conn, hs)
self._prefer_local_users_in_search = (
@ -506,7 +556,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
self._server_name = hs.config.server.server_name
async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
@ -532,7 +582,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_from_user_dir", _remove_from_user_dir_txn
)
async def get_users_in_dir_due_to_room(self, room_id):
async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]:
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
@ -565,7 +615,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
room_id
"""
def _remove_user_who_share_room_txn(txn):
def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
@ -586,7 +636,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
async def get_user_dir_rooms_user_is_in(self, user_id):
async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]:
"""
Returns the rooms that a user is in.
@ -628,7 +678,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
A set of room IDs that the users share.
"""
def _get_shared_rooms_for_users_txn(txn):
def _get_shared_rooms_for_users_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
txn.execute(
"""
SELECT p1.room_id
@ -669,7 +721,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
desc="get_user_directory_stream_pos",
)
async def search_user_dir(self, user_id, search_term, limit):
async def search_user_dir(
self, user_id: str, search_term: str, limit: int
) -> JsonDict:
"""Searches for users in directory
Returns:
@ -705,7 +759,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# We allow manipulating the ranking algorithm by injecting statements
# based on config options.
additional_ordering_statements = []
ordering_arguments = ()
ordering_arguments: Tuple[str, ...] = ()
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@ -811,7 +865,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return {"limited": limited, "results": results}
def _parse_query_sqlite(search_term):
def _parse_query_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to the database.
We use this so that we can add prefix matching, which isn't something
@ -826,7 +880,7 @@ def _parse_query_sqlite(search_term):
return " & ".join("(%s* OR %s)" % (result, result) for result in results)
def _parse_query_postgres(search_term):
def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to the database.
We use this so that we can add prefix matching, which isn't something

View file

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from unittest.mock import Mock
from urllib.parse import quote
@ -325,7 +326,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
r.add((i["user_id"], i["other_user_id"], i["room_id"]))
return r
def get_users_in_public_rooms(self):
def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
r = self.get_success(
self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
@ -336,7 +337,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
retval.append((i["user_id"], i["room_id"]))
return retval
def get_users_who_share_private_rooms(self):
def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]:
return self.get_success(
self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",