diff --git a/changelog.d/11342.misc b/changelog.d/11342.misc new file mode 100644 index 000000000..86594a332 --- /dev/null +++ b/changelog.d/11342.misc @@ -0,0 +1 @@ +Add type hints to storage classes. diff --git a/mypy.ini b/mypy.ini index 710b1f3a4..b2953974e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -38,7 +38,6 @@ exclude = (?x) |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py |synapse/storage/databases/main/presence.py - |synapse/storage/databases/main/profile.py |synapse/storage/databases/main/purge_events.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/receipts.py @@ -182,6 +181,9 @@ disallow_untyped_defs = True [mypy-synapse.storage.databases.main.room_batch] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.profile] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.state_deltas] disallow_untyped_defs = True @@ -284,6 +286,9 @@ disallow_untyped_defs = True [mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True +[mypy-tests.storage.test_profile] +disallow_untyped_defs = True + [mypy-tests.storage.test_user_directory] disallow_untyped_defs = True diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index dd8e27e22..e197b7203 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo @@ -104,7 +105,7 @@ class ProfileWorkerStore(SQLBaseStore): desc="update_remote_profile_cache", ) - async def maybe_delete_remote_profile_cache(self, user_id): + async def maybe_delete_remote_profile_cache(self, user_id: str) -> None: """Check if we still care about the remote user's profile, and if we don't then remove their profile from the cache """ @@ -116,9 +117,9 @@ class ProfileWorkerStore(SQLBaseStore): desc="delete_remote_profile_cache", ) - async def is_subscribed_remote_profile_for_user(self, user_id): + async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool: """Check whether we are interested in a remote user's profile.""" - res = await self.db_pool.simple_select_one_onecol( + res: Optional[str] = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -139,13 +140,16 @@ class ProfileWorkerStore(SQLBaseStore): if res: return True + return False async def get_remote_profile_cache_entries_that_expire( self, last_checked: int ) -> List[Dict[str, str]]: """Get all users who haven't been checked since `last_checked`""" - def _get_remote_profile_cache_entries_that_expire_txn(txn): + def _get_remote_profile_cache_entries_that_expire_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, str]]: sql = """ SELECT user_id, displayname, avatar_url FROM remote_profile_cache diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index a1ba99ff1..d37736edf 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -11,19 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest class ProfileStoreTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") - def test_displayname(self): + def test_displayname(self) -> None: self.get_success(self.store.create_profile(self.u_frank.localpart)) self.get_success( @@ -48,7 +51,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) ) - def test_avatar_url(self): + def test_avatar_url(self) -> None: self.get_success(self.store.create_profile(self.u_frank.localpart)) self.get_success(