0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-16 06:51:46 +01:00

Merge pull request #2723 from matrix-org/matthew/search-all-local-users

Add all local users to the user_directory and optionally search them
This commit is contained in:
Matthew Hodgson 2017-12-05 11:09:47 +00:00 committed by GitHub
commit 33cb7ef0b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 234 additions and 37 deletions

View file

@ -36,6 +36,7 @@ from .workers import WorkerConfig
from .push import PushConfig from .push import PushConfig
from .spam_checker import SpamCheckerConfig from .spam_checker import SpamCheckerConfig
from .groups import GroupsConfig from .groups import GroupsConfig
from .user_directory import UserDirectoryConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@ -44,7 +45,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig, JWTConfig, PasswordConfig, EmailConfig,
WorkerConfig, PasswordAuthProviderConfig, PushConfig, WorkerConfig, PasswordAuthProviderConfig, PushConfig,
SpamCheckerConfig, GroupsConfig,): SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,):
pass pass

View file

@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class UserDirectoryConfig(Config):
"""User Directory Configuration
Configuration for the behaviour of the /user_directory API
"""
def read_config(self, config):
self.user_directory_search_all_users = False
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_all_users = (
user_directory_config.get("search_all_users", False)
)
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# User Directory configuration
#
# 'search_all_users' defines whether to search all users visible to your HS
# when searching the user directory, rather than limiting to users visible
# in public rooms. Defaults to false. If you set it True, you'll have to run
# UPDATE user_directory_stream_pos SET stream_id = NULL;
# on your database to tell it to rebuild the user_directory search indexes.
#
#user_directory:
# search_all_users: false
"""

View file

@ -36,6 +36,8 @@ class ProfileHandler(BaseHandler):
"profile", self.on_profile_query "profile", self.on_profile_query
) )
self.user_directory_handler = hs.get_user_directory_handler()
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS) self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -139,6 +141,12 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname target_user.localpart, new_displayname
) )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user) yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -183,6 +191,12 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url target_user.localpart, new_avatar_url
) )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user) yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -38,6 +38,7 @@ class RegistrationHandler(BaseHandler):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None self._next_generated_user_id = None
@ -165,6 +166,13 @@ class RegistrationHandler(BaseHandler):
), ),
admin=admin, admin=admin,
) )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(localpart)
yield self.user_directory_handler.handle_local_profile_change(
user_id, profile
)
else: else:
# autogen a sequential user ID # autogen a sequential user ID
attempts = 0 attempts = 0

View file

@ -20,12 +20,13 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.types import get_localpart_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UserDirectoyHandler(object): class UserDirectoryHandler(object):
"""Handles querying of and keeping updated the user_directory. """Handles querying of and keeping updated the user_directory.
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
@ -41,9 +42,10 @@ class UserDirectoyHandler(object):
one public room. one public room.
""" """
INITIAL_SLEEP_MS = 50 INITIAL_ROOM_SLEEP_MS = 50
INITIAL_SLEEP_COUNT = 100 INITIAL_ROOM_SLEEP_COUNT = 100
INITIAL_BATCH_SIZE = 100 INITIAL_ROOM_BATCH_SIZE = 100
INITIAL_USER_SLEEP_MS = 10
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -53,6 +55,7 @@ class UserDirectoyHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.update_user_directory self.update_user_directory = hs.config.update_user_directory
self.search_all_users = hs.config.user_directory_search_all_users
# When start up for the first time we need to populate the user_directory. # When start up for the first time we need to populate the user_directory.
# This is a set of user_id's we've inserted already # This is a set of user_id's we've inserted already
@ -110,6 +113,15 @@ class UserDirectoyHandler(object):
finally: finally:
self._is_processing = False self._is_processing = False
@defer.inlineCallbacks
def handle_local_profile_change(self, user_id, profile):
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
yield self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url, None,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _unsafe_process(self): def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB # If self.pos is None then means we haven't fetched it from DB
@ -148,16 +160,30 @@ class UserDirectoyHandler(object):
room_ids = yield self.store.get_all_rooms() room_ids = yield self.store.get_all_rooms()
logger.info("Doing initial update of user directory. %d rooms", len(room_ids)) logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
num_processed_rooms = 1 num_processed_rooms = 0
for room_id in room_ids: for room_id in room_ids:
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids)) logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
yield self._handle_initial_room(room_id) yield self._handle_initial_room(room_id)
num_processed_rooms += 1 num_processed_rooms += 1
yield sleep(self.INITIAL_SLEEP_MS / 1000.) yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
logger.info("Processed all rooms.") logger.info("Processed all rooms.")
if self.search_all_users:
num_processed_users = 0
user_ids = yield self.store.get_all_local_users()
logger.info("Doing initial update of user directory. %d users", len(user_ids))
for user_id in user_ids:
# We add profiles for all users even if they don't match the
# include pattern, just in case we want to change it in future
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
yield self._handle_local_user(user_id)
num_processed_users += 1
yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
logger.info("Processed all users")
self.initially_handled_users = None self.initially_handled_users = None
self.initially_handled_users_in_public = None self.initially_handled_users_in_public = None
self.initially_handled_users_share = None self.initially_handled_users_share = None
@ -201,8 +227,8 @@ class UserDirectoyHandler(object):
to_update = set() to_update = set()
count = 0 count = 0
for user_id in user_ids: for user_id in user_ids:
if count % self.INITIAL_SLEEP_COUNT == 0: if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_SLEEP_MS / 1000.) yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
if not self.is_mine_id(user_id): if not self.is_mine_id(user_id):
count += 1 count += 1
@ -216,8 +242,8 @@ class UserDirectoyHandler(object):
if user_id == other_user_id: if user_id == other_user_id:
continue continue
if count % self.INITIAL_SLEEP_COUNT == 0: if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_SLEEP_MS / 1000.) yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
count += 1 count += 1
user_set = (user_id, other_user_id) user_set = (user_id, other_user_id)
@ -237,13 +263,13 @@ class UserDirectoyHandler(object):
else: else:
self.initially_handled_users_share_private_room.add(user_set) self.initially_handled_users_share_private_room.add(user_set)
if len(to_insert) > self.INITIAL_BATCH_SIZE: if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.add_users_who_share_room( yield self.store.add_users_who_share_room(
room_id, not is_public, to_insert, room_id, not is_public, to_insert,
) )
to_insert.clear() to_insert.clear()
if len(to_update) > self.INITIAL_BATCH_SIZE: if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.update_users_who_share_room( yield self.store.update_users_who_share_room(
room_id, not is_public, to_update, room_id, not is_public, to_update,
) )
@ -384,15 +410,29 @@ class UserDirectoyHandler(object):
for user_id in users: for user_id in users:
yield self._handle_remove_user(room_id, user_id) yield self._handle_remove_user(room_id, user_id)
@defer.inlineCallbacks
def _handle_local_user(self, user_id):
"""Adds a new local roomless user into the user_directory_search table.
Used to populate up the user index when we have an
user_directory_search_all_users specified.
"""
logger.debug("Adding new local user to dir, %r", user_id)
profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
row = yield self.store.get_user_in_directory(user_id)
if not row:
yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_user(self, room_id, user_id, profile): def _handle_new_user(self, room_id, user_id, profile):
"""Called when we might need to add user to directory """Called when we might need to add user to directory
Args: Args:
room_id (str): room_id that user joined or started being public that room_id (str): room_id that user joined or started being public
user_id (str) user_id (str)
""" """
logger.debug("Adding user to dir, %r", user_id) logger.debug("Adding new user to dir, %r", user_id)
row = yield self.store.get_user_in_directory(user_id) row = yield self.store.get_user_in_directory(user_id)
if not row: if not row:
@ -407,7 +447,7 @@ class UserDirectoyHandler(object):
if not row: if not row:
yield self.store.add_users_to_public_room(room_id, [user_id]) yield self.store.add_users_to_public_room(room_id, [user_id])
else: else:
logger.debug("Not adding user to public dir, %r", user_id) logger.debug("Not adding new user to public dir, %r", user_id)
# Now we update users who share rooms with users. We do this by getting # Now we update users who share rooms with users. We do this by getting
# all the current users in the room and seeing which aren't already # all the current users in the room and seeing which aren't already

View file

@ -52,7 +52,7 @@ from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoyHandler from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.profile import ProfileHandler from synapse.handlers.profile import ProfileHandler
from synapse.groups.groups_server import GroupsServerHandler from synapse.groups.groups_server import GroupsServerHandler
@ -339,7 +339,7 @@ class HomeServer(object):
return ActionGenerator(self) return ActionGenerator(self)
def build_user_directory_handler(self): def build_user_directory_handler(self):
return UserDirectoyHandler(self) return UserDirectoryHandler(self)
def build_groups_local_handler(self): def build_groups_local_handler(self):
return GroupsLocalHandler(self) return GroupsLocalHandler(self)

View file

@ -554,7 +554,7 @@ class SQLBaseStore(object):
def _simple_select_one(self, table, keyvalues, retcols, def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"): allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it. return a single row, returning multiple columns from it.
Args: Args:
table : string giving the table name table : string giving the table name

View file

@ -15,6 +15,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.storage.roommember import ProfileInfo
from synapse.api.errors import StoreError
from ._base import SQLBaseStore from ._base import SQLBaseStore
@ -26,6 +29,30 @@ class ProfileStore(SQLBaseStore):
desc="create_profile", desc="create_profile",
) )
@defer.inlineCallbacks
def get_profileinfo(self, user_localpart):
try:
profile = yield self._simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
)
except StoreError as e:
if e.code == 404:
# no match
defer.returnValue(ProfileInfo(None, None))
return
else:
raise
defer.returnValue(
ProfileInfo(
avatar_url=profile['avatar_url'],
display_name=profile['displayname'],
)
)
def get_profile_displayname(self, user_localpart): def get_profile_displayname(self, user_localpart):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="profiles", table="profiles",

View file

@ -0,0 +1,35 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- change the user_directory table to also cover global local user profiles
-- rather than just profiles within specific rooms.
CREATE TABLE user_directory2 (
user_id TEXT NOT NULL,
room_id TEXT,
display_name TEXT,
avatar_url TEXT
);
INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url)
SELECT user_id, room_id, display_name, avatar_url from user_directory;
DROP TABLE user_directory;
ALTER TABLE user_directory2 RENAME TO user_directory;
-- create indexes after doing the inserts because that's more efficient.
-- it also means we can give it the same name as the old one without renaming.
CREATE INDEX user_directory_room_idx ON user_directory(room_id);
CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);

View file

@ -164,7 +164,7 @@ class UserDirectoryStore(SQLBaseStore):
) )
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
# We weight the loclpart most highly, then display name and finally # We weight the localpart most highly, then display name and finally
# server name # server name
if new_entry: if new_entry:
sql = """ sql = """
@ -317,6 +317,16 @@ class UserDirectoryStore(SQLBaseStore):
rows = yield self._execute("get_all_rooms", None, sql) rows = yield self._execute("get_all_rooms", None, sql)
defer.returnValue([room_id for room_id, in rows]) defer.returnValue([room_id for room_id, in rows])
@defer.inlineCallbacks
def get_all_local_users(self):
"""Get all local users
"""
sql = """
SELECT name FROM users
"""
rows = yield self._execute("get_all_local_users", None, sql)
defer.returnValue([name for name, in rows])
def add_users_who_share_room(self, room_id, share_private, user_id_tuples): def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
"""Insert entries into the users_who_share_rooms table. The first """Insert entries into the users_who_share_rooms table. The first
user should be a local user. user should be a local user.
@ -629,6 +639,20 @@ class UserDirectoryStore(SQLBaseStore):
] ]
} }
""" """
if self.hs.config.user_directory_search_all_users:
join_clause = ""
where_clause = "?<>''" # naughty hack to keep the same number of binds
else:
join_clause = """
LEFT JOIN users_in_public_rooms AS p USING (user_id)
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
) AS s USING (user_id)
"""
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term) full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@ -641,13 +665,9 @@ class UserDirectoryStore(SQLBaseStore):
SELECT d.user_id, display_name, avatar_url SELECT d.user_id, display_name, avatar_url
FROM user_directory_search FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id) INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users_in_public_rooms AS p USING (user_id) %s
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
) AS s USING (user_id)
WHERE WHERE
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL) %s
AND vector @@ to_tsquery('english', ?) AND vector @@ to_tsquery('english', ?)
ORDER BY ORDER BY
(CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@ -671,7 +691,7 @@ class UserDirectoryStore(SQLBaseStore):
display_name IS NULL, display_name IS NULL,
avatar_url IS NULL avatar_url IS NULL
LIMIT ? LIMIT ?
""" """ % (join_clause, where_clause)
args = (user_id, full_query, exact_query, prefix_query, limit + 1,) args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term) search_query = _parse_query_sqlite(search_term)
@ -680,20 +700,16 @@ class UserDirectoryStore(SQLBaseStore):
SELECT d.user_id, display_name, avatar_url SELECT d.user_id, display_name, avatar_url
FROM user_directory_search FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id) INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users_in_public_rooms AS p USING (user_id) %s
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
) AS s USING (user_id)
WHERE WHERE
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL) %s
AND value MATCH ? AND value MATCH ?
ORDER BY ORDER BY
rank(matchinfo(user_directory_search)) DESC, rank(matchinfo(user_directory_search)) DESC,
display_name IS NULL, display_name IS NULL,
avatar_url IS NULL avatar_url IS NULL
LIMIT ? LIMIT ?
""" """ % (join_clause, where_clause)
args = (user_id, search_query, limit + 1) args = (user_id, search_query, limit + 1)
else: else:
# This should be unreachable. # This should be unreachable.
@ -723,7 +739,7 @@ def _parse_query_sqlite(search_term):
# Pull out the individual words, discarding any non-word characters. # Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
return " & ".join("(%s* | %s)" % (result, result,) for result in results) return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
def _parse_query_postgres(search_term): def _parse_query_postgres(search_term):

View file

@ -58,7 +58,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.mock_federation_resource = MockHttpResource() self.mock_federation_resource = MockHttpResource()
mock_notifier = Mock(spec=["on_new_event"]) mock_notifier = Mock()
self.on_new_event = mock_notifier.on_new_event self.on_new_event = mock_notifier.on_new_event
self.auth = Mock(spec=[]) self.auth = Mock(spec=[])
@ -76,6 +76,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings", "get_destination_retry_timings",
"get_devices_by_remote", "get_devices_by_remote",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
]), ]),
state_handler=self.state_handler, state_handler=self.state_handler,
handlers=None, handlers=None,
@ -122,6 +125,15 @@ class TypingNotificationsTestCase(unittest.TestCase):
return set(str(u) for u in self.room_members) return set(str(u) for u in self.room_members)
self.state_handler.get_current_user_in_room = get_current_user_in_room self.state_handler.get_current_user_in_room = get_current_user_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
defer.succeed(1)
)
self.datastore.get_current_state_deltas.return_value = (
None
)
self.auth.check_joined_room = check_joined_room self.auth.check_joined_room = check_joined_room
self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_to_device_stream_token = lambda: 0