Merge branch 'develop' of github.com:matrix-org/synapse into erikj/state_fixup

This commit is contained in:
Erik Johnston 2017-06-07 11:05:23 +01:00
commit 09e4bc0501
16 changed files with 1175 additions and 12 deletions

View file

@ -0,0 +1,429 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.storage.roommember import ProfileInfo
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
class UserDirectoyHandler(object):
"""Handles querying of and keeping updated the user_directory.
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
The user directory is filled with users who this server can see are joined to a
world_readable or publically joinable room. We keep a database table up to date
by streaming changes of the current state and recalculating whether users should
be in the directory or not when necessary.
For each user in the directory we also store a room_id which is public and that the
user is joined to. This allows us to ignore history_visibility and join_rules changes
for that user in all other public rooms, as we know they'll still be in at least
one public room.
"""
def __init__(self, hs):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.server_name = hs.hostname
self.clock = hs.get_clock()
# When start up for the first time we need to populate the user_directory.
# This is a set of user_id's we've inserted already
self.initially_handled_users = set()
self.initially_handled_users_in_public = set()
# The current position in the current_state_delta stream
self.pos = None
# Guard to ensure we only process deltas one at a time
self._is_processing = False
# We kick this off so that we don't have to wait for a change before
# we start populating the user directory
self.clock.call_later(0, self.notify_new_event)
def search_users(self, search_term, limit):
"""Searches for users in directory
Returns:
dict of the form::
{
"limited": <bool>, # whether there were more results or not
"results": [ # Ordered by best match first
{
"user_id": <user_id>,
"display_name": <display_name>,
"avatar_url": <avatar_url>
}
]
}
"""
return self.store.search_user_dir(search_term, limit)
@defer.inlineCallbacks
def notify_new_event(self):
"""Called when there may be more deltas to process
"""
if self._is_processing:
return
self._is_processing = True
try:
yield self._unsafe_process()
finally:
self._is_processing = False
@defer.inlineCallbacks
def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = yield self.store.get_user_directory_stream_pos()
# If still None then we need to do the initial fill of directory
if self.pos is None:
yield self._do_initial_spam()
self.pos = yield self.store.get_user_directory_stream_pos()
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
deltas = yield self.store.get_current_state_deltas(self.pos)
if not deltas:
return
yield self._handle_deltas(deltas)
self.pos = deltas[-1]["stream_id"]
yield self.store.update_user_directory_stream_pos(self.pos)
@defer.inlineCallbacks
def _do_initial_spam(self):
"""Populates the user_directory from the current state of the DB, used
when synapse first starts with user_directory support
"""
new_pos = yield self.store.get_max_stream_id_in_current_state_deltas()
# Delete any existing entries just in case there are any
yield self.store.delete_all_from_user_dir()
# We process by going through each existing room at a time.
room_ids = yield self.store.get_all_rooms()
for room_id in room_ids:
yield self._handle_intial_room(room_id)
self.initially_handled_users = None
yield self.store.update_user_directory_stream_pos(new_pos)
@defer.inlineCallbacks
def _handle_intial_room(self, room_id):
"""Called when we initially fill out user_directory one room at a time
"""
is_in_room = yield self.state.get_is_host_in_room(room_id, self.server_name)
if not is_in_room:
return
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(room_id)
users_with_profile = yield self.state.get_current_user_in_room(room_id)
unhandled_users = set(users_with_profile) - self.initially_handled_users
yield self.store.add_profiles_to_user_dir(
room_id, {
user_id: users_with_profile[user_id] for user_id in unhandled_users
}
)
self.initially_handled_users |= unhandled_users
if is_public:
yield self.store.add_users_to_public_room(
room_id,
user_ids=unhandled_users - self.initially_handled_users_in_public
)
self.initially_handled_users_in_public != unhandled_users
@defer.inlineCallbacks
def _handle_deltas(self, deltas):
"""Called with the state deltas to process
"""
for delta in deltas:
typ = delta["type"]
state_key = delta["state_key"]
room_id = delta["room_id"]
event_id = delta["event_id"]
prev_event_id = delta["prev_event_id"]
logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
# For join rule and visibility changes we need to check if the room
# may have become public or not and add/remove the users in said room
if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules):
yield self._handle_room_publicity_change(
room_id, prev_event_id, event_id, typ,
)
elif typ == EventTypes.Member:
change = yield self._get_key_change(
prev_event_id, event_id,
key_name="membership",
public_value=Membership.JOIN,
)
if change is None:
# Handle any profile changes
yield self._handle_profile_change(state_key, prev_event_id, event_id)
continue
if not change:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
is_in_room = yield self.state.get_is_host_in_room(
room_id, self.server_name,
)
if not is_in_room:
logger.debug("Server left room: %r", room_id)
# Fetch all the users that we marked as being in user
# directory due to being in the room and then check if
# need to remove those users or not
user_ids = yield self.store.get_users_in_dir_due_to_room(room_id)
for user_id in user_ids:
yield self._handle_remove_user(room_id, user_id)
return
else:
logger.debug("Server is still in room: %r", room_id)
if change: # The user joined
event = yield self.store.get_event(event_id)
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),
)
yield self._handle_new_user(room_id, state_key, profile)
else: # The user left
yield self._handle_remove_user(room_id, state_key)
else:
logger.debug("Ignoring irrelevant type: %r", typ)
@defer.inlineCallbacks
def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ):
"""Handle a room having potentially changed from/to world_readable/publically
joinable.
Args:
room_id (str)
prev_event_id (str|None): The previous event before the state change
event_id (str|None): The new event after the state change
typ (str): Type of the event
"""
logger.debug("Handling change for %s", typ)
if typ == EventTypes.RoomHistoryVisibility:
change = yield self._get_key_change(
prev_event_id, event_id,
key_name="history_visibility",
public_value="world_readable",
)
elif typ == EventTypes.JoinRules:
change = yield self._get_key_change(
prev_event_id, event_id,
key_name="join_rule",
public_value=JoinRules.PUBLIC,
)
else:
raise Exception("Invalid event type")
# If change is None, no change. True => become world_readable/public,
# False => was world_readable/public
if change is None:
logger.debug("No change")
return
# There's been a change to or from being world readable.
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
logger.debug("Change: %r, is_public: %r", change, is_public)
if change and not is_public:
# If we became world readable but room isn't currently public then
# we ignore the change
return
elif not change and is_public:
# If we stopped being world readable but are still public,
# ignore the change
return
if change:
users_with_profile = yield self.state.get_current_user_in_room(room_id)
for user_id, profile in users_with_profile.iteritems():
yield self._handle_new_user(room_id, user_id, profile)
else:
users = yield self.store.get_users_in_public_due_to_room(room_id)
for user_id in users:
yield self._handle_remove_user(room_id, user_id)
@defer.inlineCallbacks
def _handle_new_user(self, room_id, user_id, profile):
"""Called when we might need to add user to directory
Args:
room_id (str): room_id that user joined or started being public that
user_id (str)
"""
logger.debug("Adding user to dir, %r", user_id)
row = yield self.store.get_user_in_directory(user_id)
if not row:
yield self.store.add_profiles_to_user_dir(room_id, {user_id: profile})
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
if not is_public:
return
row = yield self.store.get_user_in_public_room(user_id)
if not row:
yield self.store.add_users_to_public_room(room_id, [user_id])
@defer.inlineCallbacks
def _handle_remove_user(self, room_id, user_id):
"""Called when we might need to remove user to directory
Args:
room_id (str): room_id that user left or stopped being public that
user_id (str)
"""
logger.debug("Maybe removing user %r", user_id)
row = yield self.store.get_user_in_directory(user_id)
update_user_dir = row and row["room_id"] == room_id
row = yield self.store.get_user_in_public_room(user_id)
update_user_in_public = row and row["room_id"] == room_id
if not update_user_in_public and not update_user_dir:
return
# XXX: Make this faster?
rooms = yield self.store.get_rooms_for_user(user_id)
for j_room_id in rooms:
if not update_user_in_public and not update_user_dir:
break
is_in_room = yield self.state.get_is_host_in_room(
j_room_id, self.server_name,
)
if not is_in_room:
continue
if update_user_dir:
update_user_dir = False
yield self.store.update_user_in_user_dir(user_id, j_room_id)
if update_user_in_public:
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
j_room_id
)
if is_public:
yield self.store.update_user_in_public_user_list(user_id, j_room_id)
update_user_in_public = False
if update_user_dir:
yield self.store.remove_from_user_dir(user_id)
elif update_user_in_public:
yield self.store.remove_from_user_in_public_room(user_id)
@defer.inlineCallbacks
def _handle_profile_change(self, user_id, prev_event_id, event_id):
"""Check member event changes for any profile changes and update the
database if there are.
"""
if not prev_event_id or not event_id:
return
prev_event = yield self.store.get_event(prev_event_id)
event = yield self.store.get_event(event_id)
if event.membership != Membership.JOIN:
return
prev_name = prev_event.content.get("displayname")
new_name = event.content.get("displayname")
prev_avatar = prev_event.content.get("avatar_url")
new_avatar = event.content.get("avatar_url")
if prev_name != new_name or prev_avatar != new_avatar:
yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
@defer.inlineCallbacks
def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.
For example, check if `history_visibility` (`key_name`) changed from
`shared` to `world_readable` (`public_value`).
Returns:
None if the field in the events either both match `public_value`
or if neither do, i.e. there has been no change.
True if it didnt match `public_value` but now does
False if it did match `public_value` but now doesn't
"""
prev_event = None
event = None
if prev_event_id:
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if event_id:
event = yield self.store.get_event(event_id, allow_none=True)
if not event and not prev_event:
logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
defer.returnValue(None)
prev_value = None
value = None
if prev_event:
prev_value = prev_event.content.get(key_name)
if event:
value = event.content.get(key_name)
logger.debug("prev_value: %r -> value: %r", prev_value, value)
if value == public_value and prev_value != public_value:
defer.returnValue(True)
elif value != public_value and prev_value == public_value:
defer.returnValue(False)
else:
defer.returnValue(None)

View file

@ -167,6 +167,7 @@ class Notifier(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
self.user_directory_handler = hs.get_user_directory_handler()
if hs.should_send_federation(): if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
@ -251,7 +252,10 @@ class Notifier(object):
"""Notify any user streams that are interested in this room event""" """Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
preserve_fn(self.appservice_handler.notify_interested_services)( preserve_fn(self.appservice_handler.notify_interested_services)(
room_stream_id) room_stream_id
)
preserve_fn(self.user_directory_handler.notify_new_event)()
if self.federation_sender: if self.federation_sender:
preserve_fn(self.federation_sender.notify_new_events)( preserve_fn(self.federation_sender.notify_new_events)(

View file

@ -57,8 +57,8 @@ class PusherFactory(object):
logger.info("found pusher") logger.info("found pusher")
return self.pusher_types[pusherdict['kind']](self.hs, pusherdict) return self.pusher_types[pusherdict['kind']](self.hs, pusherdict)
def _create_email_pusher(self, pusherdict): def _create_email_pusher(self, _hs, pusherdict):
app_name = self._brand_from_pusherdict app_name = self._app_name_from_pusherdict(pusherdict)
mailer = self.mailers.get(app_name) mailer = self.mailers.get(app_name)
if not mailer: if not mailer:
mailer = Mailer( mailer = Mailer(

View file

@ -51,6 +51,7 @@ from synapse.rest.client.v2_alpha import (
devices, devices,
thirdparty, thirdparty,
sendtodevice, sendtodevice,
user_directory,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -100,3 +101,4 @@ class ClientRestResource(JsonResource):
devices.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource) thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource)
user_directory.register_servlets(hs, client_resource)

View file

@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user_directory/search$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(UserDirectorySearchRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
@defer.inlineCallbacks
def on_POST(self, request):
"""Searches for users in directory
Returns:
dict of the form::
{
"limited": <bool>, # whether there were more results or not
"results": [ # Ordered by best match first
{
"user_id": <user_id>,
"display_name": <display_name>,
"avatar_url": <avatar_url>
}
]
}
"""
yield self.auth.get_user_by_req(request, allow_guest=False)
body = parse_json_object_from_request(request)
limit = body.get("limit", 10)
limit = min(limit, 50)
try:
search_term = body["search_term"]
except:
raise SynapseError(400, "`search_term` is required field")
results = yield self.user_directory_handler.search_users(search_term, limit)
defer.returnValue((200, results))
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)

View file

@ -49,6 +49,7 @@ from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoyHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier from synapse.notifier import Notifier
@ -137,6 +138,7 @@ class HomeServer(object):
'tcp_replication', 'tcp_replication',
'read_marker_handler', 'read_marker_handler',
'action_generator', 'action_generator',
'user_directory_handler',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -304,6 +306,9 @@ class HomeServer(object):
def build_action_generator(self): def build_action_generator(self):
return ActionGenerator(self) return ActionGenerator(self)
def build_user_directory_handler(self):
return UserDirectoyHandler(self)
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View file

@ -182,6 +182,17 @@ class StateHandler(object):
joined_hosts = yield self.store.get_joined_hosts(room_id, entry) joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
defer.returnValue(joined_hosts) defer.returnValue(joined_hosts)
@defer.inlineCallbacks
def get_is_host_in_room(self, room_id, host, latest_event_ids=None):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_is_host_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
is_host_joined = yield self.store.is_host_joined(
room_id, host, entry.state_id, entry.state
)
defer.returnValue(is_host_joined)
@defer.inlineCallbacks @defer.inlineCallbacks
def compute_event_context(self, event, old_state=None): def compute_event_context(self, event, old_state=None):
"""Build an EventContext structure for the event. """Build an EventContext structure for the event.

View file

@ -49,6 +49,7 @@ from .tags import TagsStore
from .account_data import AccountDataStore from .account_data import AccountDataStore
from .openid import OpenIdStore from .openid import OpenIdStore
from .client_ips import ClientIpStore from .client_ips import ClientIpStore
from .user_directory import UserDirectoryStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from .engines import PostgresEngine from .engines import PostgresEngine
@ -86,6 +87,7 @@ class DataStore(RoomMemberStore, RoomStore,
ClientIpStore, ClientIpStore,
DeviceStore, DeviceStore,
DeviceInboxStore, DeviceInboxStore,
UserDirectoryStore,
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
@ -221,6 +223,18 @@ class DataStore(RoomMemberStore, RoomStore,
"DeviceListFederationStreamChangeCache", device_list_max, "DeviceListFederationStreamChangeCache", device_list_max,
) )
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
db_conn, "current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
"_curr_state_delta_stream_cache", min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
cur = LoggingTransaction( cur = LoggingTransaction(
db_conn.cursor(), db_conn.cursor(),
name="_find_stream_orderings_for_times_txn", name="_find_stream_orderings_for_times_txn",

View file

@ -425,6 +425,11 @@ class SQLBaseStore(object):
txn.execute(sql, vals) txn.execute(sql, vals)
def _simple_insert_many(self, table, values, desc):
return self.runInteraction(
desc, self._simple_insert_many_txn, table, values
)
@staticmethod @staticmethod
def _simple_insert_many_txn(txn, table, values): def _simple_insert_many_txn(txn, table, values):
if not values: if not values:

View file

@ -20,6 +20,8 @@ from twisted.internet import defer
from ._base import Cache from ._base import Cache
from . import background_updates from . import background_updates
import os
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller # Number of msec of granularity to store the user IP 'last seen' time. Smaller
@ -28,12 +30,15 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000 LAST_SEEN_GRANULARITY = 120 * 1000
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class ClientIpStore(background_updates.BackgroundUpdateStore): class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs): def __init__(self, hs):
self.client_ip_last_seen = Cache( self.client_ip_last_seen = Cache(
name="client_ip_last_seen", name="client_ip_last_seen",
keylen=4, keylen=4,
max_entries=5000, max_entries=50000 * CACHE_SIZE_FACTOR,
) )
super(ClientIpStore, self).__init__(hs) super(ClientIpStore, self).__init__(hs)

View file

@ -648,9 +648,10 @@ class EventsStore(SQLBaseStore):
list of the event ids which are the forward extremities. list of the event ids which are the forward extremities.
""" """
self._update_current_state_txn(txn, current_state_for_room)
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
self._update_forward_extremities_txn( self._update_forward_extremities_txn(
txn, txn,
new_forward_extremities=new_forward_extremeties, new_forward_extremities=new_forward_extremeties,
@ -713,7 +714,7 @@ class EventsStore(SQLBaseStore):
backfilled=backfilled, backfilled=backfilled,
) )
def _update_current_state_txn(self, txn, state_delta_by_room): def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in state_delta_by_room.iteritems(): for room_id, current_state_tuple in state_delta_by_room.iteritems():
to_delete, to_insert, _ = current_state_tuple to_delete, to_insert, _ = current_state_tuple
txn.executemany( txn.executemany(
@ -735,6 +736,29 @@ class EventsStore(SQLBaseStore):
], ],
) )
state_deltas = {key: None for key in to_delete}
state_deltas.update(to_insert)
self._simple_insert_many_txn(
txn,
table="current_state_delta_stream",
values=[
{
"stream_id": max_stream_order,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": ev_id,
"prev_event_id": to_delete.get(key, None),
}
for key, ev_id in state_deltas.iteritems()
]
)
self._curr_state_delta_stream_cache.entity_has_changed(
room_id, max_stream_order,
)
# Invalidate the various caches # Invalidate the various caches
# Figure out the changes of membership to invalidate the # Figure out the changes of membership to invalidate the
@ -743,11 +767,7 @@ class EventsStore(SQLBaseStore):
# and which we have added, then we invlidate the caches for all # and which we have added, then we invlidate the caches for all
# those users. # those users.
members_changed = set( members_changed = set(
state_key for ev_type, state_key in to_delete.iterkeys() state_key for ev_type, state_key in state_deltas
if ev_type == EventTypes.Member
)
members_changed.update(
state_key for ev_type, state_key in to_insert.iterkeys()
if ev_type == EventTypes.Member if ev_type == EventTypes.Member
) )

View file

@ -0,0 +1,26 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE current_state_delta_stream (
stream_id BIGINT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
event_id TEXT, -- Is null if the key was removed
prev_event_id TEXT -- Is null if the key was added
);
CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id);

View file

@ -0,0 +1,84 @@
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
logger = logging.getLogger(__name__)
BOTH_TABLES = """
CREATE TABLE user_directory_stream_pos (
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_id BIGINT,
CHECK (Lock='X')
);
INSERT INTO user_directory_stream_pos (stream_id) VALUES (null);
CREATE TABLE user_directory (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL, -- A room_id that we know the user is joined to
display_name TEXT,
avatar_url TEXT
);
CREATE INDEX user_directory_room_idx ON user_directory(room_id);
CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
CREATE TABLE users_in_pubic_room (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL -- A room_id that we know is public
);
CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id);
CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id);
"""
POSTGRES_TABLE = """
CREATE TABLE user_directory_search (
user_id TEXT NOT NULL,
vector tsvector
);
CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector);
CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id);
"""
SQLITE_TABLE = """
CREATE VIRTUAL TABLE user_directory_search
USING fts4 ( user_id, value );
"""
def run_create(cur, database_engine, *args, **kwargs):
for statement in get_statements(BOTH_TABLES.splitlines()):
cur.execute(statement)
if isinstance(database_engine, PostgresEngine):
for statement in get_statements(POSTGRES_TABLE.splitlines()):
cur.execute(statement)
elif isinstance(database_engine, Sqlite3Engine):
for statement in get_statements(SQLITE_TABLE.splitlines()):
cur.execute(statement)
else:
raise Exception("Unrecognized database engine")
def run_upgrade(*args, **kwargs):
pass

View file

@ -0,0 +1,461 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
import re
class UserDirectoryStore(SQLBaseStore):
@cachedInlineCallbacks(cache_context=True)
def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
"""Check if the room is either world_readable or publically joinable
"""
current_state_ids = yield self.get_current_state_ids(
room_id, on_invalidate=cache_context.invalidate
)
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id:
join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
defer.returnValue(True)
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
if hist_vis_ev.content.get("history_visibility") == "world_readable":
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def add_users_to_public_room(self, room_id, user_ids):
"""Add user to the list of users in public rooms
Args:
room_id (str): A room_id that all users are in that is world_readable
or publically joinable
user_ids (list(str)): Users to add
"""
yield self._simple_insert_many(
table="users_in_pubic_room",
values=[
{
"user_id": user_id,
"room_id": room_id,
}
for user_id in user_ids
],
desc="add_users_to_public_room"
)
for user_id in user_ids:
self.get_user_in_public_room.invalidate((user_id,))
def add_profiles_to_user_dir(self, room_id, users_with_profile):
"""Add profiles to the user directory
Args:
room_id (str): A room_id that all users are joined to
users_with_profile (dict): Users to add to directory in the form of
mapping of user_id -> ProfileInfo
"""
if isinstance(self.database_engine, PostgresEngine):
# We weight the loclpart most highly, then display name and finally
# server name
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
)
"""
args = (
(
user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
profile.display_name,
)
for user_id, profile in users_with_profile.iteritems()
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = """
INSERT INTO user_directory_search(user_id, value)
VALUES (?,?)
"""
args = (
(
user_id,
"%s %s" % (user_id, p.display_name,) if p.display_name else user_id
)
for user_id, p in users_with_profile.iteritems()
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
def _add_profiles_to_user_dir_txn(txn):
txn.executemany(sql, args)
self._simple_insert_many_txn(
txn,
table="user_directory",
values=[
{
"user_id": user_id,
"room_id": room_id,
"display_name": profile.display_name,
"avatar_url": profile.avatar_url,
}
for user_id, profile in users_with_profile.iteritems()
]
)
for user_id in users_with_profile:
txn.call_after(
self.get_user_in_directory.invalidate, (user_id,)
)
return self.runInteraction(
"add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
)
@defer.inlineCallbacks
def update_user_in_user_dir(self, user_id, room_id):
yield self._simple_update_one(
table="user_directory",
keyvalues={"user_id": user_id},
updatevalues={"room_id": room_id},
desc="update_user_in_user_dir",
)
self.get_user_in_directory.invalidate((user_id,))
def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
def _update_profile_in_user_dir_txn(txn):
self._simple_update_one_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
updatevalues={"display_name": display_name, "avatar_url": avatar_url},
)
if isinstance(self.database_engine, PostgresEngine):
# We weight the loclpart most highly, then display name and finally
# server name
sql = """
UPDATE user_directory_search
SET vector = setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
WHERE user_id = ?
"""
args = (
get_localpart_from_id(user_id), get_domain_from_id(user_id),
display_name,
user_id,
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = """
UPDATE user_directory_search
set value = ?
WHERE user_id = ?
"""
args = (
"%s %s" % (user_id, display_name,) if display_name else user_id,
user_id,
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
txn.execute(sql, args)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
@defer.inlineCallbacks
def update_user_in_public_user_list(self, user_id, room_id):
yield self._simple_update_one(
table="users_in_pubic_room",
keyvalues={"user_id": user_id},
updatevalues={"room_id": room_id},
desc="update_user_in_public_user_list",
)
self.get_user_in_public_room.invalidate((user_id,))
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self._simple_delete_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
)
self._simple_delete_txn(
txn,
table="user_directory_search",
keyvalues={"user_id": user_id},
)
self._simple_delete_txn(
txn,
table="users_in_pubic_room",
keyvalues={"user_id": user_id},
)
txn.call_after(
self.get_user_in_directory.invalidate, (user_id,)
)
txn.call_after(
self.get_user_in_public_room.invalidate, (user_id,)
)
return self.runInteraction(
"remove_from_user_dir", _remove_from_user_dir_txn,
)
@defer.inlineCallbacks
def remove_from_user_in_public_room(self, user_id):
yield self._simple_delete(
table="users_in_pubic_room",
keyvalues={"user_id": user_id},
desc="remove_from_user_in_public_room",
)
self.get_user_in_public_room.invalidate((user_id,))
def get_users_in_public_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory becuase they're
in the given room_id
"""
return self._simple_select_onecol(
table="users_in_pubic_room",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_public_due_to_room",
)
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory becuase they're
in the given room_id
"""
return self._simple_select_onecol(
table="user_directory",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
def get_all_rooms(self):
"""Get all room_ids we've ever known about
"""
return self._simple_select_onecol(
table="current_state_events",
keyvalues={},
retcol="DISTINCT room_id",
desc="get_all_rooms",
)
def delete_all_from_user_dir(self):
"""Delete the entire user directory
"""
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_pubic_room")
txn.call_after(self.get_user_in_directory.invalidate_all)
txn.call_after(self.get_user_in_public_room.invalidate_all)
return self.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
def get_user_in_directory(self, user_id):
return self._simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("room_id", "display_name", "avatar_url",),
allow_none=True,
desc="get_user_in_directory",
)
@cached()
def get_user_in_public_room(self, user_id):
return self._simple_select_one(
table="users_in_pubic_room",
keyvalues={"user_id": user_id},
retcols=("room_id",),
allow_none=True,
desc="get_user_in_public_room",
)
def get_user_directory_stream_pos(self):
return self._simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
desc="get_user_directory_stream_pos",
)
def update_user_directory_stream_pos(self, stream_id):
return self._simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
desc="update_user_directory_stream_pos",
)
def get_current_state_deltas(self, prev_stream_id):
prev_stream_id = int(prev_stream_id)
if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
return []
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
# N results.
# We arbitarily limit to 100 stream_id entries to ensure we don't
# select toooo many.
sql = """
SELECT stream_id, count(*)
FROM current_state_delta_stream
WHERE stream_id > ?
GROUP BY stream_id
ORDER BY stream_id ASC
LIMIT 100
"""
txn.execute(sql, (prev_stream_id,))
total = 0
max_stream_id = prev_stream_id
for max_stream_id, count in txn:
total += count
if total > 100:
# We arbitarily limit to 100 entries to ensure we don't
# select toooo many.
break
# Now actually get the deltas
sql = """
SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, max_stream_id,))
return self.cursor_to_dict(txn)
return self.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)
def get_max_stream_id_in_current_state_deltas(self):
return self._simple_select_one_onecol(
table="current_state_delta_stream",
keyvalues={},
retcol="COALESCE(MAX(stream_id), -1)",
desc="get_max_stream_id_in_current_state_deltas",
)
@defer.inlineCallbacks
def search_user_dir(self, search_term, limit):
"""Searches for users in directory
Returns:
dict of the form::
{
"limited": <bool>, # whether there were more results or not
"results": [ # Ordered by best match first
{
"user_id": <user_id>,
"display_name": <display_name>,
"avatar_url": <avatar_url>
}
]
}
"""
search_query = _parse_query(self.database_engine, search_term)
if isinstance(self.database_engine, PostgresEngine):
# We order by rank and then if they have profile info
sql = """
SELECT user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory USING (user_id)
INNER JOIN users_in_pubic_room USING (user_id)
WHERE vector @@ to_tsquery('english', ?)
ORDER BY
ts_rank_cd(vector, to_tsquery('english', ?), 1) DESC,
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
"""
args = (search_query, search_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = """
SELECT user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory USING (user_id)
INNER JOIN users_in_pubic_room USING (user_id)
WHERE value MATCH ?
ORDER BY
rank(matchinfo(user_directory_search)) DESC,
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
"""
args = (search_query, limit + 1)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
results = yield self._execute(
"search_user_dir", self.cursor_to_dict, sql, *args
)
limited = len(results) > limit
defer.returnValue({
"limited": limited,
"results": results,
})
def _parse_query(database_engine, search_term):
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
that is supported by default.
We specifically add both a prefix and non prefix matching term so that
exact matches get ranked higher.
"""
# Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
if isinstance(database_engine, PostgresEngine):
return " & ".join("(%s:* | %s)" % (result, result,) for result in results)
elif isinstance(database_engine, Sqlite3Engine):
return " & ".join("(%s* | %s)" % (result, result,) for result in results)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")

View file

@ -62,6 +62,13 @@ def get_domain_from_id(string):
return string[idx + 1:] return string[idx + 1:]
def get_localpart_from_id(string):
idx = string.find(":")
if idx == -1:
raise SynapseError(400, "Invalid ID: %r" % (string,))
return string[1:idx]
class DomainSpecificString( class DomainSpecificString(
namedtuple("DomainSpecificString", ("localpart", "domain")) namedtuple("DomainSpecificString", ("localpart", "domain"))
): ):

View file

@ -89,6 +89,21 @@ class StreamChangeCache(object):
return result return result
def has_any_entity_changed(self, stream_pos):
"""Returns if any entity has changed
"""
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
self.metrics.inc_hits()
if stream_pos >= max(self._cache):
return False
else:
return True
else:
self.metrics.inc_misses()
return True
def get_all_entities_changed(self, stream_pos): def get_all_entities_changed(self, stream_pos):
"""Returns all entites that have had new things since the given """Returns all entites that have had new things since the given
position. If the position is too old it will return None. position. If the position is too old it will return None.