0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-10 12:02:43 +01:00

Add column full_user_id to tables profiles and user_filters. (#15458)

This commit is contained in:
Shay 2023-04-26 16:03:26 -07:00 committed by GitHub
parent 247e6a8a78
commit 301b4156d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 186 additions and 74 deletions

1
changelog.d/15458.misc Normal file
View file

@ -0,0 +1 @@
Add column `full_user_id` to tables `profiles` and `user_filters`.

View file

@ -54,7 +54,7 @@ from synapse.logging.context import (
) )
from synapse.notifier import ReplicationNotifier from synapse.notifier import ReplicationNotifier
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import PushRuleStore from synapse.storage.databases.main import FilteringWorkerStore, PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
@ -69,6 +69,7 @@ from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore, MediaRepositoryBackgroundUpdateStore,
) )
from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore
from synapse.storage.databases.main.profile import ProfileWorkerStore
from synapse.storage.databases.main.pusher import ( from synapse.storage.databases.main.pusher import (
PusherBackgroundUpdatesStore, PusherBackgroundUpdatesStore,
PusherWorkerStore, PusherWorkerStore,
@ -229,6 +230,8 @@ class Store(
EndToEndRoomKeyBackgroundStore, EndToEndRoomKeyBackgroundStore,
StatsStore, StatsStore,
AccountDataWorkerStore, AccountDataWorkerStore,
FilteringWorkerStore,
ProfileWorkerStore,
PushRuleStore, PushRuleStore,
PusherWorkerStore, PusherWorkerStore,
PusherBackgroundUpdatesStore, PusherBackgroundUpdatesStore,

View file

@ -170,11 +170,9 @@ class Filtering:
result = await self.store.get_user_filter(user_localpart, filter_id) result = await self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(self._hs, result) return FilterCollection(self._hs, result)
def add_user_filter( def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> Awaitable[int]:
self, user_localpart: str, user_filter: JsonDict
) -> Awaitable[int]:
self.check_valid_filter(user_filter) self.check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter) return self.store.add_user_filter(user_id, user_filter)
# TODO(paul): surely we should probably add a delete_user_filter or # TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for # replace_user_filter at some point? There's no REST API specified for

View file

@ -178,9 +178,7 @@ class ProfileHandler:
authenticated_entity=requester.authenticated_entity, authenticated_entity=requester.authenticated_entity,
) )
await self.store.set_profile_displayname( await self.store.set_profile_displayname(target_user, displayname_to_set)
target_user.localpart, displayname_to_set
)
profile = await self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change( await self.user_directory_handler.handle_local_profile_change(
@ -272,9 +270,7 @@ class ProfileHandler:
target_user, authenticated_entity=requester.authenticated_entity target_user, authenticated_entity=requester.authenticated_entity
) )
await self.store.set_profile_avatar_url( await self.store.set_profile_avatar_url(target_user, avatar_url_to_set)
target_user.localpart, avatar_url_to_set
)
profile = await self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change( await self.user_directory_handler.handle_local_profile_change(

View file

@ -94,7 +94,7 @@ class CreateFilterRestServlet(RestServlet):
set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit) set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit)
filter_id = await self.filtering.add_user_filter( filter_id = await self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_filter=content user_id=target_user, user_filter=content
) )
return 200, {"filter_id": str(filter_id)} return 200, {"filter_id": str(filter_id)}

View file

@ -16,15 +16,38 @@
from typing import Optional, Tuple, Union, cast from typing import Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from typing_extensions import TYPE_CHECKING
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction from synapse.storage.database import (
from synapse.types import JsonDict DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict, UserID
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
class FilteringWorkerStore(SQLBaseStore): class FilteringWorkerStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"full_users_filters_unique_idx",
index_name="full_users_unique_idx",
table="user_filters",
columns=["full_user_id, filter_id"],
unique=True,
)
@cached(num_args=2) @cached(num_args=2)
async def get_user_filter( async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str] self, user_localpart: str, filter_id: Union[int, str]
@ -46,7 +69,7 @@ class FilteringWorkerStore(SQLBaseStore):
return db_to_json(def_json) return db_to_json(def_json)
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int: async def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> int:
def_json = encode_canonical_json(user_filter) def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then # Need an atomic transaction to SELECT the maximal ID so far then
@ -56,13 +79,13 @@ class FilteringWorkerStore(SQLBaseStore):
"SELECT filter_id FROM user_filters " "SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?" "WHERE user_id = ? AND filter_json = ?"
) )
txn.execute(sql, (user_localpart, bytearray(def_json))) txn.execute(sql, (user_id.localpart, bytearray(def_json)))
filter_id_response = txn.fetchone() filter_id_response = txn.fetchone()
if filter_id_response is not None: if filter_id_response is not None:
return filter_id_response[0] return filter_id_response[0]
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,)) txn.execute(sql, (user_id.localpart,))
max_id = cast(Tuple[Optional[int]], txn.fetchone())[0] max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
if max_id is None: if max_id is None:
filter_id = 0 filter_id = 0
@ -70,10 +93,18 @@ class FilteringWorkerStore(SQLBaseStore):
filter_id = max_id + 1 filter_id = max_id + 1
sql = ( sql = (
"INSERT INTO user_filters (user_id, filter_id, filter_json)" "INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json)"
"VALUES(?, ?, ?)" "VALUES(?, ?, ?, ?)"
)
txn.execute(
sql,
(
user_id.to_string(),
user_id.localpart,
filter_id,
bytearray(def_json),
),
) )
txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
return filter_id return filter_id

View file

@ -11,14 +11,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import TYPE_CHECKING, Optional
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.types import UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
class ProfileWorkerStore(SQLBaseStore): class ProfileWorkerStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"profiles_full_user_id_key_idx",
index_name="profiles_full_user_id_key",
table="profiles",
columns=["full_user_id"],
unique=True,
)
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try: try:
profile = await self.db_pool.simple_select_one( profile = await self.db_pool.simple_select_one(
@ -54,28 +74,36 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url", desc="get_profile_avatar_url",
) )
async def create_profile(self, user_localpart: str) -> None: async def create_profile(self, user_id: UserID) -> None:
user_localpart = user_id.localpart
await self.db_pool.simple_insert( await self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile" table="profiles",
values={"user_id": user_localpart, "full_user_id": user_id.to_string()},
desc="create_profile",
) )
async def set_profile_displayname( async def set_profile_displayname(
self, user_localpart: str, new_displayname: Optional[str] self, user_id: UserID, new_displayname: Optional[str]
) -> None: ) -> None:
user_localpart = user_id.localpart
await self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
values={"displayname": new_displayname}, values={
"displayname": new_displayname,
"full_user_id": user_id.to_string(),
},
desc="set_profile_displayname", desc="set_profile_displayname",
) )
async def set_profile_avatar_url( async def set_profile_avatar_url(
self, user_localpart: str, new_avatar_url: Optional[str] self, user_id: UserID, new_avatar_url: Optional[str]
) -> None: ) -> None:
user_localpart = user_id.localpart
await self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
values={"avatar_url": new_avatar_url}, values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()},
desc="set_profile_avatar_url", desc="set_profile_avatar_url",
) )

View file

@ -2414,8 +2414,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# *obviously* the 'profiles' table uses localpart for user_id # *obviously* the 'profiles' table uses localpart for user_id
# while everything else uses the full mxid. # while everything else uses the full mxid.
txn.execute( txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)", "INSERT INTO profiles(full_user_id, user_id, displayname) VALUES (?,?,?)",
(user_id_obj.localpart, create_profile_with_displayname), (user_id, user_id_obj.localpart, create_profile_with_displayname),
) )
if self.hs.config.stats.stats_enabled: if self.hs.config.stats.stats_enabled:

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
SCHEMA_VERSION = 75 # remember to update the list below when updating SCHEMA_VERSION = 76 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema """Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the This should be incremented whenever the codebase changes its requirements on the
@ -97,6 +97,9 @@ Changes in SCHEMA_VERSION = 75:
`local_current_membership` & `room_memberships`) is now being populated for new `local_current_membership` & `room_memberships`) is now being populated for new
rows. When the background job to populate historical rows lands this will rows. When the background job to populate historical rows lands this will
become the compat schema version. become the compat schema version.
Changes in SCHEMA_VERSION = 76:
- Adds a full_user_id column to tables profiles and user_filters.
""" """

View file

@ -0,0 +1,20 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE profiles ADD COLUMN full_user_id TEXT;
-- Make sure the column has a unique constraint, mirroring the `profiles_user_id_key`
-- constraint.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (7501, 'profiles_full_user_id_key_idx', '{}');

View file

@ -0,0 +1,20 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE user_filters ADD COLUMN full_user_id TEXT;
-- Add a unique index on the new column, mirroring the `user_filters_unique` unique
-- index.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (7502, 'full_users_filters_unique_idx', '{}');

View file

@ -26,13 +26,15 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict, UserID
from synapse.util import Clock from synapse.util import Clock
from synapse.util.frozenutils import freeze from synapse.util.frozenutils import freeze
from tests import unittest from tests import unittest
from tests.events.test_utils import MockEvent from tests.events.test_utils import MockEvent
user_id = UserID.from_string("@test_user:test")
user2_id = UserID.from_string("@test_user2:test")
user_localpart = "test_user" user_localpart = "test_user"
@ -437,7 +439,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
user_filter_json = {"presence": {"senders": ["@foo:bar"]}} user_filter_json = {"presence": {"senders": ["@foo:bar"]}}
filter_id = self.get_success( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json user_id=user_id, user_filter=user_filter_json
) )
) )
presence_states = [ presence_states = [
@ -467,7 +469,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
filter_id = self.get_success( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_filter=user_filter_json user_id=user2_id, user_filter=user_filter_json
) )
) )
presence_states = [ presence_states = [
@ -495,7 +497,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json user_id=user_id, user_filter=user_filter_json
) )
) )
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
@ -514,7 +516,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json user_id=user_id, user_filter=user_filter_json
) )
) )
event = MockEvent( event = MockEvent(
@ -598,7 +600,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
filter_id = self.get_success( filter_id = self.get_success(
self.filtering.add_user_filter( self.filtering.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json user_id=user_id, user_filter=user_filter_json
) )
) )
@ -619,7 +621,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
filter_id = self.get_success( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json user_id=user_id, user_filter=user_filter_json
) )
) )

View file

@ -66,9 +66,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_profile_handler() self.handler = hs.get_profile_handler()
def test_get_my_name(self) -> None: def test_get_my_name(self) -> None:
self.get_success( self.get_success(self.store.set_profile_displayname(self.frank, "Frank"))
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
displayname = self.get_success(self.handler.get_displayname(self.frank)) displayname = self.get_success(self.handler.get_displayname(self.frank))
@ -121,9 +119,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.hs.config.registration.enable_set_displayname = False self.hs.config.registration.enable_set_displayname = False
# Setting displayname for the first time is allowed # Setting displayname for the first time is allowed
self.get_success( self.get_success(self.store.set_profile_displayname(self.frank, "Frank"))
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
self.assertEqual( self.assertEqual(
( (
@ -166,8 +162,14 @@ class ProfileTestCase(unittest.HomeserverTestCase):
) )
def test_incoming_fed_query(self) -> None: def test_incoming_fed_query(self) -> None:
self.get_success(self.store.create_profile("caroline")) self.get_success(
self.get_success(self.store.set_profile_displayname("caroline", "Caroline")) self.store.create_profile(UserID.from_string("@caroline:test"))
)
self.get_success(
self.store.set_profile_displayname(
UserID.from_string("@caroline:test"), "Caroline"
)
)
response = self.get_success( response = self.get_success(
self.query_handlers["profile"]( self.query_handlers["profile"](
@ -183,9 +185,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def test_get_my_avatar(self) -> None: def test_get_my_avatar(self) -> None:
self.get_success( self.get_success(
self.store.set_profile_avatar_url( self.store.set_profile_avatar_url(self.frank, "http://my.server/me.png")
self.frank.localpart, "http://my.server/me.png"
)
) )
avatar_url = self.get_success(self.handler.get_avatar_url(self.frank)) avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
@ -237,9 +237,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# Setting displayname for the first time is allowed # Setting displayname for the first time is allowed
self.get_success( self.get_success(
self.store.set_profile_avatar_url( self.store.set_profile_avatar_url(self.frank, "http://my.server/me.png")
self.frank.localpart, "http://my.server/me.png"
)
) )
self.assertEqual( self.assertEqual(

View file

@ -802,9 +802,21 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# Set avatar URL to all users, that no user has a NULL value to avoid # Set avatar URL to all users, that no user has a NULL value to avoid
# different sort order between SQlite and PostreSQL # different sort order between SQlite and PostreSQL
self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3")) self.get_success(
self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2")) self.store.set_profile_avatar_url(
self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1")) UserID.from_string("@user1:test"), "mxc://url3"
)
)
self.get_success(
self.store.set_profile_avatar_url(
UserID.from_string("@user2:test"), "mxc://url2"
)
)
self.get_success(
self.store.set_profile_avatar_url(
UserID.from_string("@admin:test"), "mxc://url1"
)
)
# order by default (name) # order by default (name)
self._order_test([self.admin_user, user1, user2], None) self._order_test([self.admin_user, user1, user2], None)
@ -1127,7 +1139,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
# set attributes for user # set attributes for user
self.get_success( self.get_success(
self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") self.store.set_profile_avatar_url(
UserID.from_string("@user:test"), "mxc://servername/mediaid"
)
) )
self.get_success( self.get_success(
self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
@ -1257,7 +1271,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
Reproduces #12257. Reproduces #12257.
""" """
# Patch `self.other_user` to have an empty string as their avatar. # Patch `self.other_user` to have an empty string as their avatar.
self.get_success(self.store.set_profile_avatar_url("user", "")) self.get_success(
self.store.set_profile_avatar_url(UserID.from_string("@user:test"), "")
)
# Check we can still erase them. # Check we can still erase them.
channel = self.make_request( channel = self.make_request(
@ -2311,7 +2327,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# set attributes for user # set attributes for user
self.get_success( self.get_success(
self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") self.store.set_profile_avatar_url(
UserID.from_string("@user:test"), "mxc://servername/mediaid"
)
) )
self.get_success( self.get_success(
self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)

View file

@ -17,6 +17,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest.client import filter from synapse.rest.client import filter
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -76,7 +77,8 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_get_filter(self) -> None: def test_get_filter(self) -> None:
filter_id = self.get_success( filter_id = self.get_success(
self.filtering.add_user_filter( self.filtering.add_user_filter(
user_localpart="apple", user_filter=self.EXAMPLE_FILTER user_id=UserID.from_string("@apple:test"),
user_filter=self.EXAMPLE_FILTER,
) )
) )
self.reactor.advance(1) self.reactor.advance(1)

View file

@ -29,9 +29,9 @@ class DataStoreTestCase(unittest.HomeserverTestCase):
def test_get_users_paginate(self) -> None: def test_get_users_paginate(self) -> None:
self.get_success(self.store.register_user(self.user.to_string(), "pass")) self.get_success(self.store.register_user(self.user.to_string(), "pass"))
self.get_success(self.store.create_profile(self.user.localpart)) self.get_success(self.store.create_profile(self.user))
self.get_success( self.get_success(
self.store.set_profile_displayname(self.user.localpart, self.displayname) self.store.set_profile_displayname(self.user, self.displayname)
) )
users, total = self.get_success( users, total = self.get_success(

View file

@ -27,11 +27,9 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
self.u_frank = UserID.from_string("@frank:test") self.u_frank = UserID.from_string("@frank:test")
def test_displayname(self) -> None: def test_displayname(self) -> None:
self.get_success(self.store.create_profile(self.u_frank.localpart)) self.get_success(self.store.create_profile(self.u_frank))
self.get_success( self.get_success(self.store.set_profile_displayname(self.u_frank, "Frank"))
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)
self.assertEqual( self.assertEqual(
"Frank", "Frank",
@ -43,21 +41,17 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
) )
# test set to None # test set to None
self.get_success( self.get_success(self.store.set_profile_displayname(self.u_frank, None))
self.store.set_profile_displayname(self.u_frank.localpart, None)
)
self.assertIsNone( self.assertIsNone(
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
) )
def test_avatar_url(self) -> None: def test_avatar_url(self) -> None:
self.get_success(self.store.create_profile(self.u_frank.localpart)) self.get_success(self.store.create_profile(self.u_frank))
self.get_success( self.get_success(
self.store.set_profile_avatar_url( self.store.set_profile_avatar_url(self.u_frank, "http://my.site/here")
self.u_frank.localpart, "http://my.site/here"
)
) )
self.assertEqual( self.assertEqual(
@ -70,9 +64,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
) )
# test set to None # test set to None
self.get_success( self.get_success(self.store.set_profile_avatar_url(self.u_frank, None))
self.store.set_profile_avatar_url(self.u_frank.localpart, None)
)
self.assertIsNone( self.assertIsNone(
self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart)) self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))