Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.21.0

This commit is contained in:
Erik Johnston 2017-05-17 11:25:23 +01:00
commit ac08316548
39 changed files with 668 additions and 147 deletions

View file

@ -28,6 +28,15 @@ running:
git pull git pull
# Update the versions of synapse's python dependencies. # Update the versions of synapse's python dependencies.
python synapse/python_dependencies.py | xargs -n1 pip install --upgrade python synapse/python_dependencies.py | xargs -n1 pip install --upgrade
To check whether your update was sucessfull, run:
.. code:: bash
# replace your.server.domain with ther domain of your synaspe homeserver
curl https://<your.server.domain>/_matrix/federation/v1/version
So for the Matrix.org HS server the URL would be: https://matrix.org/_matrix/federation/v1/version.
Upgrading to v0.15.0 Upgrading to v0.15.0

View file

@ -35,6 +35,8 @@ class ServerConfig(Config):
# "disable" federation # "disable" federation
self.send_federation = config.get("send_federation", True) self.send_federation = config.get("send_federation", True)
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
if self.public_baseurl is not None: if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/': if self.public_baseurl[-1] != '/':
self.public_baseurl += '/' self.public_baseurl += '/'
@ -161,6 +163,10 @@ class ServerConfig(Config):
# The GC threshold parameters to pass to `gc.set_threshold`, if defined # The GC threshold parameters to pass to `gc.set_threshold`, if defined
# gc_thresholds: [700, 10, 10] # gc_thresholds: [700, 10, 10]
# Set the limit on the returned events in the timeline in the get
# and sync operations. The default value is -1, means no upper limit.
# filter_timeline_limit: 5000
# List of ports that Synapse should listen on, their purpose and their # List of ports that Synapse should listen on, their purpose and their
# configuration. # configuration.
listeners: listeners:

View file

@ -440,6 +440,16 @@ class FederationServer(FederationBase):
key_id: json.loads(json_bytes) key_id: json.loads(json_bytes)
} }
logger.info(
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in json_result.iteritems()
for device_id, device_keys in user_keys.iteritems()
for key_id, _ in device_keys.iteritems()
)),
)
defer.returnValue({"one_time_keys": json_result}) defer.returnValue({"one_time_keys": json_result})
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -285,7 +285,7 @@ class TransactionQueue(object):
Args: Args:
states (list(UserPresenceState)) states (list(UserPresenceState))
""" """
hosts_and_states = yield get_interested_remotes(self.store, states) hosts_and_states = yield get_interested_remotes(self.store, states, self.state)
for destinations, states in hosts_and_states: for destinations, states in hosts_and_states:
for destination in destinations: for destination in destinations:

View file

@ -24,6 +24,7 @@ from synapse.http.servlet import (
) )
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from synapse.util.logcontext import preserve_fn
from synapse.types import ThirdPartyInstanceID from synapse.types import ThirdPartyInstanceID
import functools import functools
@ -79,6 +80,7 @@ class Authenticator(object):
def __init__(self, hs): def __init__(self, hs):
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore()
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
@ -138,6 +140,13 @@ class Authenticator(object):
logger.info("Request from %s", origin) logger.info("Request from %s", origin)
request.authenticated_entity = origin request.authenticated_entity = origin
# If we get a valid signed request from the other side, its probably
# alive
retry_timings = yield self.store.get_destination_retry_timings(origin)
if retry_timings and retry_timings["retry_last_ts"]:
logger.info("Marking origin %r as up", origin)
preserve_fn(self.store.set_destination_retry_timings)(origin, 0, 0)
defer.returnValue(origin) defer.returnValue(origin)

View file

@ -53,7 +53,20 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
def ratelimit(self, requester): @defer.inlineCallbacks
def ratelimit(self, requester, update=True):
"""Ratelimits requests.
Args:
requester (Requester)
update (bool): Whether to record that a request is being processed.
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
Raises:
LimitExceededError if the request should be ratelimited
"""
time_now = self.clock.time() time_now = self.clock.time()
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -67,10 +80,25 @@ class BaseHandler(object):
if requester.app_service and not requester.app_service.is_rate_limited(): if requester.app_service and not requester.app_service.is_rate_limited():
return return
# Check if there is a per user override in the DB.
override = yield self.store.get_ratelimit_for_user(user_id)
if override:
# If overriden with a null Hz then ratelimiting has been entirely
# disabled for the user
if not override.messages_per_second:
return
messages_per_second = override.messages_per_second
burst_count = override.burst_count
else:
messages_per_second = self.hs.config.rc_messages_per_second
burst_count = self.hs.config.rc_message_burst_count
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(
user_id, time_now, user_id, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second, msg_rate_hz=messages_per_second,
burst_count=self.hs.config.rc_message_burst_count, burst_count=burst_count,
update=update,
) )
if not allowed: if not allowed:
raise LimitExceededError( raise LimitExceededError(

View file

@ -17,6 +17,7 @@ from synapse.api.constants import EventTypes
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id, RoomStreamToken from synapse.types import get_domain_from_id, RoomStreamToken
from twisted.internet import defer from twisted.internet import defer
@ -425,12 +426,38 @@ class DeviceListEduUpdater(object):
# This can happen since we batch updates # This can happen since we batch updates
return return
# Given a list of updates we check if we need to resync. This
# happens if we've missed updates.
resync = yield self._need_to_do_resync(user_id, pending_updates) resync = yield self._need_to_do_resync(user_id, pending_updates)
if resync: if resync:
# Fetch all devices for the user. # Fetch all devices for the user.
origin = get_domain_from_id(user_id) origin = get_domain_from_id(user_id)
result = yield self.federation.query_user_devices(origin, user_id) try:
result = yield self.federation.query_user_devices(origin, user_id)
except NotRetryingDestination:
# TODO: Remember that we are now out of sync and try again
# later
logger.warn(
"Failed to handle device list update for %s,"
" we're not retrying the remote",
user_id,
)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
# next time we get a device list update for this user_id.
# This makes it more likely that the device lists will
# eventually become consistent.
return
except Exception:
# TODO: Remember that we are now out of sync and try again
# later
logger.exception(
"Failed to handle device list update for %s", user_id
)
return
stream_id = result["stream_id"] stream_id = result["stream_id"]
devices = result["devices"] devices = result["devices"]
yield self.store.update_remote_device_list_cache( yield self.store.update_remote_device_list_cache(

View file

@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -145,7 +145,7 @@ class E2eKeysHandler(object):
"status": 503, "message": e.message "status": 503, "message": e.message
} }
yield preserve_context_over_deferred(defer.gatherResults([ yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(do_remote_query)(destination) preserve_fn(do_remote_query)(destination)
for destination in remote_queries_not_in_cache for destination in remote_queries_not_in_cache
])) ]))
@ -257,11 +257,21 @@ class E2eKeysHandler(object):
"status": 503, "message": e.message "status": 503, "message": e.message
} }
yield preserve_context_over_deferred(defer.gatherResults([ yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(claim_client_keys)(destination) preserve_fn(claim_client_keys)(destination)
for destination in remote_queries for destination in remote_queries
])) ]))
logger.info(
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in json_result.iteritems()
for device_id, device_keys in user_keys.iteritems()
for key_id, _ in device_keys.iteritems()
)),
)
defer.returnValue({ defer.returnValue({
"one_time_keys": json_result, "one_time_keys": json_result,
"failures": failures "failures": failures
@ -288,19 +298,8 @@ class E2eKeysHandler(object):
one_time_keys = keys.get("one_time_keys", None) one_time_keys = keys.get("one_time_keys", None)
if one_time_keys: if one_time_keys:
logger.info( yield self._upload_one_time_keys_for_user(
"Adding %d one_time_keys for device %r for user %r at %d", user_id, device_id, time_now, one_time_keys,
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
) )
# the device should have been registered already, but it may have been # the device should have been registered already, but it may have been
@ -313,3 +312,58 @@ class E2eKeysHandler(object):
result = yield self.store.count_e2e_one_time_keys(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue({"one_time_key_counts": result}) defer.returnValue({"one_time_key_counts": result})
@defer.inlineCallbacks
def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
one_time_keys):
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(), device_id, user_id, time_now,
)
# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, key_obj
))
# First we check if we have already persisted any of the keys.
existing_key_map = yield self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list]
)
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
for algorithm, key_id, key in key_list:
ex_json = existing_key_map.get((algorithm, key_id), None)
if ex_json:
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
("One time key %s:%s already exists. "
"Old key: %s; new key: %r") %
(algorithm, key_id, ex_json, key)
)
else:
new_keys.append((algorithm, key_id, encode_canonical_json(key)))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys
)
def _one_time_keys_match(old_key_json, new_key):
old_key = json.loads(old_key_json)
# if either is a string rather than an object, they must match exactly
if not isinstance(old_key, dict) or not isinstance(new_key, dict):
return old_key == new_key
# otherwise, we strip off the 'signatures' if any, because it's legitimate
# for different upload attempts to have different signatures.
old_key.pop("signatures", None)
new_key_copy = dict(new_key)
new_key_copy.pop("signatures", None)
return old_key == new_key_copy

View file

@ -380,13 +380,6 @@ class FederationHandler(BaseHandler):
affected=event.event_id, affected=event.event_id,
) )
# if we're receiving valid events from an origin,
# it's probably a good idea to mark it as not in retry-state
# for sending (although this is a bit of a leap)
retry_timings = yield self.store.get_destination_retry_timings(origin)
if retry_timings and retry_timings["retry_last_ts"]:
self.store.set_destination_retry_timings(origin, 0, 0)
room = yield self.store.get_room(event.room_id) room = yield self.store.get_room(event.room_id)
if not room: if not room:

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -254,17 +254,7 @@ class MessageHandler(BaseHandler):
# We check here if we are currently being rate limited, so that we # We check here if we are currently being rate limited, so that we
# don't do unnecessary work. We check again just before we actually # don't do unnecessary work. We check again just before we actually
# send the event. # send the event.
time_now = self.clock.time() yield self.ratelimit(requester, update=False)
allowed, time_allowed = self.ratelimiter.send_message(
event.sender, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count,
update=False,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
user = UserID.from_string(event.sender) user = UserID.from_string(event.sender)
@ -499,7 +489,7 @@ class MessageHandler(BaseHandler):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if ratelimit: if ratelimit:
self.ratelimit(requester) yield self.ratelimit(requester)
try: try:
yield self.auth.check_from_context(event, context) yield self.auth.check_from_context(event, context)

View file

@ -780,12 +780,12 @@ class PresenceHandler(object):
# don't need to send to local clients here, as that is done as part # don't need to send to local clients here, as that is done as part
# of the event stream/sync. # of the event stream/sync.
# TODO: Only send to servers not already in the room. # TODO: Only send to servers not already in the room.
user_ids = yield self.store.get_users_in_room(room_id)
if self.is_mine(user): if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string()) state = yield self.current_state_for_user(user.to_string())
self._push_to_remotes([state]) self._push_to_remotes([state])
else: else:
user_ids = yield self.store.get_users_in_room(room_id)
user_ids = filter(self.is_mine_id, user_ids) user_ids = filter(self.is_mine_id, user_ids)
states = yield self.current_state_for_users(user_ids) states = yield self.current_state_for_users(user_ids)
@ -1322,7 +1322,7 @@ def get_interested_parties(store, states):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_interested_remotes(store, states): def get_interested_remotes(store, states, state_handler):
"""Given a list of presence states figure out which remote servers """Given a list of presence states figure out which remote servers
should be sent which. should be sent which.
@ -1345,7 +1345,7 @@ def get_interested_remotes(store, states):
room_ids_to_states, users_to_states = yield get_interested_parties(store, states) room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
for room_id, states in room_ids_to_states.iteritems(): for room_id, states in room_ids_to_states.iteritems():
hosts = yield store.get_hosts_in_room(room_id) hosts = yield state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states)) hosts_and_states.append((hosts, states))
for user_id, states in users_to_states.iteritems(): for user_id, states in users_to_states.iteritems():

View file

@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
return return
self.ratelimit(requester) yield self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user( room_ids = yield self.store.get_rooms_for_user(
user.to_string(), user.to_string(),

View file

@ -54,6 +54,13 @@ class RegistrationHandler(BaseHandler):
Codes.INVALID_USERNAME Codes.INVALID_USERNAME
) )
if not localpart:
raise SynapseError(
400,
"User ID cannot be empty",
Codes.INVALID_USERNAME
)
if localpart[0] == '_': if localpart[0] == '_':
raise SynapseError( raise SynapseError(
400, 400,

View file

@ -75,7 +75,7 @@ class RoomCreationHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
self.ratelimit(requester) yield self.ratelimit(requester)
if "room_alias_name" in config: if "room_alias_name" in config:
for wchar in string.whitespace: for wchar in string.whitespace:

View file

@ -739,10 +739,11 @@ class RoomMemberHandler(BaseHandler):
if len(current_state_ids) == 1 and create_event_id: if len(current_state_ids) == 1 and create_event_id:
defer.returnValue(self.hs.is_mine_id(create_event_id)) defer.returnValue(self.hs.is_mine_id(create_event_id))
for (etype, state_key), event_id in current_state_ids.items(): for etype, state_key in current_state_ids:
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
continue continue
event_id = current_state_ids[(etype, state_key)]
event = yield self.store.get_event(event_id, allow_none=True) event = yield self.store.get_event(event_id, allow_none=True)
if not event: if not event:
continue continue

View file

@ -20,6 +20,7 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.visibility import filter_events_for_clients_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -66,6 +67,17 @@ class BulkPushRuleEvaluator:
def action_for_event_by_user(self, event, context): def action_for_event_by_user(self, event, context):
actions_by_user = {} actions_by_user = {}
# None of these users can be peeking since this list of users comes
# from the set of users in the room, so we know for sure they're all
# actually in the room.
user_tuples = [
(u, False) for u in self.rules_by_user.keys()
]
filtered_by_user = yield filter_events_for_clients_context(
self.store, user_tuples, [event], {event.event_id: context}
)
room_members = yield self.store.get_joined_users_from_context( room_members = yield self.store.get_joined_users_from_context(
event, context event, context
) )
@ -75,14 +87,6 @@ class BulkPushRuleEvaluator:
condition_cache = {} condition_cache = {}
for uid, rules in self.rules_by_user.items(): for uid, rules in self.rules_by_user.items():
if event.sender == uid:
continue
if not event.is_state():
is_ignored = yield self.store.is_ignored_by(event.sender, uid)
if is_ignored:
continue
display_name = None display_name = None
profile_info = room_members.get(uid) profile_info = room_members.get(uid)
if profile_info: if profile_info:
@ -94,6 +98,13 @@ class BulkPushRuleEvaluator:
if event.type == EventTypes.Member and event.state_key == uid: if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None) display_name = event.content.get("displayname", None)
filtered = filtered_by_user[uid]
if len(filtered) == 0:
continue
if filtered[0].sender == uid:
continue
for rule in rules: for rule in rules:
if 'enabled' in rule and not rule['enabled']: if 'enabled' in rule and not rule['enabled']:
continue continue

View file

@ -47,3 +47,13 @@ def client_v2_patterns(path_regex, releases=(0,),
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex)) patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns return patterns
def set_timeline_upper_limit(filter_json, filter_timeline_limit):
if filter_timeline_limit < 0:
return # no upper limits
timeline = filter_json.get('room', {}).get('timeline', {})
if 'limit' in timeline:
filter_json['room']['timeline']["limit"] = min(
filter_json['room']['timeline']['limit'],
filter_timeline_limit)

View file

@ -20,6 +20,7 @@ from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID from synapse.types import UserID
from ._base import client_v2_patterns from ._base import client_v2_patterns
from ._base import set_timeline_upper_limit
import logging import logging
@ -85,6 +86,11 @@ class CreateFilterRestServlet(RestServlet):
raise AuthError(403, "Can only create filters for local users") raise AuthError(403, "Can only create filters for local users")
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
set_timeline_upper_limit(
content,
self.hs.config.filter_timeline_limit
)
filter_id = yield self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_localpart=target_user.localpart,
user_filter=content, user_filter=content,

View file

@ -21,7 +21,7 @@ from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
) )
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
@ -147,10 +147,9 @@ class UsernameAvailabilityRestServlet(RestServlet):
with self.ratelimiter.ratelimit(ip) as wait_deferred: with self.ratelimiter.ratelimit(ip) as wait_deferred:
yield wait_deferred yield wait_deferred
body = parse_json_object_from_request(request) username = parse_string(request, "username", required=True)
assert_params_in_request(body, ['username'])
yield self.registration_handler.check_username(body['username']) yield self.registration_handler.check_username(username)
defer.returnValue((200, {"available": True})) defer.returnValue((200, {"available": True}))

View file

@ -28,6 +28,7 @@ from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from ._base import client_v2_patterns from ._base import client_v2_patterns
from ._base import set_timeline_upper_limit
import itertools import itertools
import logging import logging
@ -78,6 +79,7 @@ class SyncRestServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(SyncRestServlet, self).__init__() super(SyncRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.sync_handler = hs.get_sync_handler() self.sync_handler = hs.get_sync_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -121,6 +123,8 @@ class SyncRestServlet(RestServlet):
if filter_id.startswith('{'): if filter_id.startswith('{'):
try: try:
filter_object = json.loads(filter_id) filter_object = json.loads(filter_id)
set_timeline_upper_limit(filter_object,
self.hs.config.filter_timeline_limit)
except: except:
raise SynapseError(400, "Invalid filter JSON") raise SynapseError(400, "Invalid filter JSON")
self.filtering.check_valid_filter(filter_object) self.filtering.check_valid_filter(filter_object)

View file

@ -34,6 +34,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.retryutils import NotRetryingDestination
import os import os
import errno import errno
@ -181,7 +182,8 @@ class MediaRepository(object):
logger.exception("Failed to fetch remote media %s/%s", logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id) server_name, media_id)
raise raise
except NotRetryingDestination:
logger.warn("Not retrying destination %r", server_name)
except Exception: except Exception:
logger.exception("Failed to fetch remote media %s/%s", logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id) server_name, media_id)

View file

@ -60,12 +60,12 @@ class LoggingTransaction(object):
object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks) object.__setattr__(self, "after_callbacks", after_callbacks)
def call_after(self, callback, *args): def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the """Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the transaction has finished. Used to invalidate the caches on the
correct thread. correct thread.
""" """
self.after_callbacks.append((callback, args)) self.after_callbacks.append((callback, args, kwargs))
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.txn, name) return getattr(self.txn, name)
@ -319,8 +319,8 @@ class SQLBaseStore(object):
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
finally: finally:
for after_callback, after_args in after_callbacks: for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args) after_callback(*after_args, **after_kwargs)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -308,16 +308,3 @@ class AccountDataStore(SQLBaseStore):
" WHERE stream_id < ?" " WHERE stream_id < ?"
) )
txn.execute(update_max_id_sql, (next_id, next_id)) txn.execute(update_max_id_sql, (next_id, next_id))
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
"m.ignored_user_list", ignorer_user_id,
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
defer.returnValue(False)
defer.returnValue(
ignored_user_id in ignored_account_data.get("ignored_users", {})
)

View file

@ -210,7 +210,9 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_handlers[update_name] = update_handler self._background_update_handlers[update_name] = update_handler
def register_background_index_update(self, update_name, index_name, def register_background_index_update(self, update_name, index_name,
table, columns, where_clause=None): table, columns, where_clause=None,
unique=False,
psql_only=False):
"""Helper for store classes to do a background index addition """Helper for store classes to do a background index addition
To use: To use:
@ -226,6 +228,9 @@ class BackgroundUpdateStore(SQLBaseStore):
index_name (str): name of index to add index_name (str): name of index to add
table (str): table to add index to table (str): table to add index to
columns (list[str]): columns/expressions to include in index columns (list[str]): columns/expressions to include in index
unique (bool): true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables)
""" """
def create_index_psql(conn): def create_index_psql(conn):
@ -245,9 +250,11 @@ class BackgroundUpdateStore(SQLBaseStore):
c.execute(sql) c.execute(sql)
sql = ( sql = (
"CREATE INDEX CONCURRENTLY %(name)s ON %(table)s" "CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
" ON %(table)s"
" (%(columns)s) %(where_clause)s" " (%(columns)s) %(where_clause)s"
) % { ) % {
"unique": "UNIQUE" if unique else "",
"name": index_name, "name": index_name,
"table": table, "table": table,
"columns": ", ".join(columns), "columns": ", ".join(columns),
@ -270,9 +277,10 @@ class BackgroundUpdateStore(SQLBaseStore):
# down at the wrong moment - hance we use IF NOT EXISTS. (SQLite # down at the wrong moment - hance we use IF NOT EXISTS. (SQLite
# has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.) # has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.)
sql = ( sql = (
"CREATE INDEX IF NOT EXISTS %(name)s ON %(table)s" "CREATE %(unique)s INDEX IF NOT EXISTS %(name)s ON %(table)s"
" (%(columns)s)" " (%(columns)s)"
) % { ) % {
"unique": "UNIQUE" if unique else "",
"name": index_name, "name": index_name,
"table": table, "table": table,
"columns": ", ".join(columns), "columns": ", ".join(columns),
@ -284,13 +292,16 @@ class BackgroundUpdateStore(SQLBaseStore):
if isinstance(self.database_engine, engines.PostgresEngine): if isinstance(self.database_engine, engines.PostgresEngine):
runner = create_index_psql runner = create_index_psql
elif psql_only:
runner = None
else: else:
runner = create_index_sqlite runner = create_index_sqlite
@defer.inlineCallbacks @defer.inlineCallbacks
def updater(progress, batch_size): def updater(progress, batch_size):
logger.info("Adding index %s to %s", index_name, table) if runner is not None:
yield self.runWithConnection(runner) logger.info("Adding index %s to %s", index_name, table)
yield self.runWithConnection(runner)
yield self._end_background_update(update_name) yield self._end_background_update(update_name)
defer.returnValue(1) defer.returnValue(1)

View file

@ -33,6 +33,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self.client_ip_last_seen = Cache( self.client_ip_last_seen = Cache(
name="client_ip_last_seen", name="client_ip_last_seen",
keylen=4, keylen=4,
max_entries=5000,
) )
super(ClientIpStore, self).__init__(hs) super(ClientIpStore, self).__init__(hs)

View file

@ -18,7 +18,7 @@ import ujson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from ._base import SQLBaseStore from ._base import SQLBaseStore, Cache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
@ -29,6 +29,14 @@ class DeviceStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
super(DeviceStore, self).__init__(hs) super(DeviceStore, self).__init__(hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
name="device_id_exists",
keylen=2,
max_entries=10000,
)
self._clock.looping_call( self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000 self._prune_old_outbound_device_pokes, 60 * 60 * 1000
) )
@ -54,6 +62,10 @@ class DeviceStore(SQLBaseStore):
defer.Deferred: boolean whether the device was inserted or an defer.Deferred: boolean whether the device was inserted or an
existing device existed with that ID. existing device existed with that ID.
""" """
key = (user_id, device_id)
if self.device_id_exists_cache.get(key, None):
defer.returnValue(False)
try: try:
inserted = yield self._simple_insert( inserted = yield self._simple_insert(
"devices", "devices",
@ -65,6 +77,7 @@ class DeviceStore(SQLBaseStore):
desc="store_device", desc="store_device",
or_ignore=True, or_ignore=True,
) )
self.device_id_exists_cache.prefill(key, True)
defer.returnValue(inserted) defer.returnValue(inserted)
except Exception as e: except Exception as e:
logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
@ -93,6 +106,7 @@ class DeviceStore(SQLBaseStore):
desc="get_device", desc="get_device",
) )
@defer.inlineCallbacks
def delete_device(self, user_id, device_id): def delete_device(self, user_id, device_id):
"""Delete a device. """Delete a device.
@ -102,12 +116,15 @@ class DeviceStore(SQLBaseStore):
Returns: Returns:
defer.Deferred defer.Deferred
""" """
return self._simple_delete_one( yield self._simple_delete_one(
table="devices", table="devices",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_device", desc="delete_device",
) )
self.device_id_exists_cache.invalidate((user_id, device_id))
@defer.inlineCallbacks
def delete_devices(self, user_id, device_ids): def delete_devices(self, user_id, device_ids):
"""Deletes several devices. """Deletes several devices.
@ -117,13 +134,15 @@ class DeviceStore(SQLBaseStore):
Returns: Returns:
defer.Deferred defer.Deferred
""" """
return self._simple_delete_many( yield self._simple_delete_many(
table="devices", table="devices",
column="device_id", column="device_id",
iterable=device_ids, iterable=device_ids,
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
desc="delete_devices", desc="delete_devices",
) )
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
def update_device(self, user_id, device_id, new_display_name=None): def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device. """Update a device.

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.util.caches.descriptors import cached
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
import ujson as json import ujson as json
@ -123,18 +123,24 @@ class EndToEndKeyStore(SQLBaseStore):
return result return result
@defer.inlineCallbacks @defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
"""Insert some new one time keys for a device. """Retrieve a number of one-time keys for a user
Checks if any of the keys are already inserted, if they are then check Args:
if they match. If they don't then we raise an error. user_id(str): id of user to get keys for
device_id(str): id of device to get keys for
key_ids(list[str]): list of key ids (excluding algorithm) to
retrieve
Returns:
deferred resolving to Dict[(str, str), str]: map from (algorithm,
key_id) to json string for key
""" """
# First we check if we have already persisted any of the keys.
rows = yield self._simple_select_many_batch( rows = yield self._simple_select_many_batch(
table="e2e_one_time_keys_json", table="e2e_one_time_keys_json",
column="key_id", column="key_id",
iterable=[key_id for _, key_id, _ in key_list], iterable=key_ids,
retcols=("algorithm", "key_id", "key_json",), retcols=("algorithm", "key_id", "key_json",),
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
@ -143,20 +149,22 @@ class EndToEndKeyStore(SQLBaseStore):
desc="add_e2e_one_time_keys_check", desc="add_e2e_one_time_keys_check",
) )
existing_key_map = { defer.returnValue({
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
} })
new_keys = [] # Keys that we need to insert @defer.inlineCallbacks
for algorithm, key_id, json_bytes in key_list: def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
ex_bytes = existing_key_map.get((algorithm, key_id), None) """Insert some new one time keys for a device. Errors if any of the
if ex_bytes: keys already exist.
if json_bytes != ex_bytes:
raise SynapseError( Args:
400, "One time key with key_id %r already exists" % (key_id,) user_id(str): id of user to get keys for
) device_id(str): id of device to get keys for
else: time_now(long): insertion time to record (ms since epoch)
new_keys.append((algorithm, key_id, json_bytes)) new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
(algorithm, key_id, key json)
"""
def _add_e2e_one_time_keys(txn): def _add_e2e_one_time_keys(txn):
# We are protected from race between lookup and insertion due to # We are protected from race between lookup and insertion due to
@ -177,10 +185,14 @@ class EndToEndKeyStore(SQLBaseStore):
for algorithm, key_id, json_bytes in new_keys for algorithm, key_id, json_bytes in new_keys
], ],
) )
txn.call_after(
self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
)
yield self.runInteraction( yield self.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
) )
@cached(max_entries=10000)
def count_e2e_one_time_keys(self, user_id, device_id): def count_e2e_one_time_keys(self, user_id, device_id):
""" Count the number of one time keys the server has for a device """ Count the number of one time keys the server has for a device
Returns: Returns:
@ -225,6 +237,9 @@ class EndToEndKeyStore(SQLBaseStore):
) )
for user_id, device_id, algorithm, key_id in delete: for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id)) txn.execute(sql, (user_id, device_id, algorithm, key_id))
txn.call_after(
self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
)
return result return result
return self.runInteraction( return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
@ -242,3 +257,4 @@ class EndToEndKeyStore(SQLBaseStore):
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_e2e_one_time_keys_by_device" desc="delete_e2e_one_time_keys_by_device"
) )
self.count_e2e_one_time_keys.invalidate((user_id, device_id,))

View file

@ -207,6 +207,18 @@ class EventsStore(SQLBaseStore):
where_clause="contains_url = true AND outlier = false", where_clause="contains_url = true AND outlier = false",
) )
# an event_id index on event_search is useful for the purge_history
# api. Plus it means we get to enforce some integrity with a UNIQUE
# clause
self.register_background_index_update(
"event_search_event_id_idx",
index_name="event_search_event_id_idx",
table="event_search",
columns=["event_id"],
unique=True,
psql_only=True,
)
self._event_persist_queue = _EventPeristenceQueue() self._event_persist_queue = _EventPeristenceQueue()
def persist_events(self, events_and_contexts, backfilled=False): def persist_events(self, events_and_contexts, backfilled=False):
@ -387,6 +399,11 @@ class EventsStore(SQLBaseStore):
event_counter.inc(event.type, origin_type, origin_entity) event_counter.inc(event.type, origin_type, origin_entity)
for room_id, (_, _, new_state) in current_state_for_room.iteritems():
self.get_current_state_ids.prefill(
(room_id, ), new_state
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
"""Calculates the new forward extremeties for a room given events to """Calculates the new forward extremeties for a room given events to
@ -435,10 +452,10 @@ class EventsStore(SQLBaseStore):
Assumes that we are only persisting events for one room at a time. Assumes that we are only persisting events for one room at a time.
Returns: Returns:
2-tuple (to_delete, to_insert) where both are state dicts, i.e. 3-tuple (to_delete, to_insert, new_state) where both are state dicts,
(type, state_key) -> event_id. `to_delete` are the entries to i.e. (type, state_key) -> event_id. `to_delete` are the entries to
first be deleted from current_state_events, `to_insert` are entries first be deleted from current_state_events, `to_insert` are entries
to insert. to insert. `new_state` is the full set of state.
May return None if there are no changes to be applied. May return None if there are no changes to be applied.
""" """
# Now we need to work out the different state sets for # Now we need to work out the different state sets for
@ -545,7 +562,7 @@ class EventsStore(SQLBaseStore):
if ev_id in events_to_insert if ev_id in events_to_insert
} }
defer.returnValue((to_delete, to_insert)) defer.returnValue((to_delete, to_insert, current_state))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True, def get_event(self, event_id, check_redacted=True,
@ -698,7 +715,7 @@ class EventsStore(SQLBaseStore):
def _update_current_state_txn(self, txn, state_delta_by_room): def _update_current_state_txn(self, txn, state_delta_by_room):
for room_id, current_state_tuple in state_delta_by_room.iteritems(): for room_id, current_state_tuple in state_delta_by_room.iteritems():
to_delete, to_insert = current_state_tuple to_delete, to_insert, _ = current_state_tuple
txn.executemany( txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?", "DELETE FROM current_state_events WHERE event_id = ?",
[(ev_id,) for ev_id in to_delete.itervalues()], [(ev_id,) for ev_id in to_delete.itervalues()],
@ -1343,11 +1360,26 @@ class EventsStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id): def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,)) self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(self, events, allow_rejected): def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
"""Fetch events from the caches
Args:
events (list(str)): list of event_ids to fetch
allow_rejected (bool): Whether to teturn events that were rejected
update_metrics (bool): Whether to update the cache hit ratio metrics
Returns:
dict of event_id -> _EventCacheEntry for each event_id in cache. If
allow_rejected is `False` then there will still be an entry but it
will be `None`
"""
event_map = {} event_map = {}
for event_id in events: for event_id in events:
ret = self._get_event_cache.get((event_id,), None) ret = self._get_event_cache.get(
(event_id,), None,
update_metrics=update_metrics,
)
if not ret: if not ret:
continue continue
@ -2007,6 +2039,8 @@ class EventsStore(SQLBaseStore):
400, "topological_ordering is greater than forward extremeties" 400, "topological_ordering is greater than forward extremeties"
) )
logger.debug("[purge] looking for events to delete")
txn.execute( txn.execute(
"SELECT event_id, state_key FROM events" "SELECT event_id, state_key FROM events"
" LEFT JOIN state_events USING (room_id, event_id)" " LEFT JOIN state_events USING (room_id, event_id)"
@ -2015,9 +2049,19 @@ class EventsStore(SQLBaseStore):
) )
event_rows = txn.fetchall() event_rows = txn.fetchall()
to_delete = [
(event_id,) for event_id, state_key in event_rows
if state_key is None and not self.hs.is_mine_id(event_id)
]
logger.info(
"[purge] found %i events before cutoff, of which %i are remote"
" non-state events to delete", len(event_rows), len(to_delete))
for event_id, state_key in event_rows: for event_id, state_key in event_rows:
txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
logger.debug("[purge] Finding new backward extremities")
# We calculate the new entries for the backward extremeties by finding # We calculate the new entries for the backward extremeties by finding
# all events that point to events that are to be purged # all events that point to events that are to be purged
txn.execute( txn.execute(
@ -2030,6 +2074,8 @@ class EventsStore(SQLBaseStore):
) )
new_backwards_extrems = txn.fetchall() new_backwards_extrems = txn.fetchall()
logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
txn.execute( txn.execute(
"DELETE FROM event_backward_extremities WHERE room_id = ?", "DELETE FROM event_backward_extremities WHERE room_id = ?",
(room_id,) (room_id,)
@ -2044,6 +2090,8 @@ class EventsStore(SQLBaseStore):
] ]
) )
logger.debug("[purge] finding redundant state groups")
# Get all state groups that are only referenced by events that are # Get all state groups that are only referenced by events that are
# to be deleted. # to be deleted.
txn.execute( txn.execute(
@ -2059,15 +2107,20 @@ class EventsStore(SQLBaseStore):
) )
state_rows = txn.fetchall() state_rows = txn.fetchall()
state_groups_to_delete = [sg for sg, in state_rows] logger.debug("[purge] found %i redundant state groups", len(state_rows))
# make a set of the redundant state groups, so that we can look them up
# efficiently
state_groups_to_delete = set([sg for sg, in state_rows])
# Now we get all the state groups that rely on these state groups # Now we get all the state groups that rely on these state groups
new_state_edges = [] logger.debug("[purge] finding state groups which depend on redundant"
chunks = [ " state groups")
state_groups_to_delete[i:i + 100] remaining_state_groups = []
for i in xrange(0, len(state_groups_to_delete), 100) for i in xrange(0, len(state_rows), 100):
] chunk = [sg for sg, in state_rows[i:i + 100]]
for chunk in chunks: # look for state groups whose prev_state_group is one we are about
# to delete
rows = self._simple_select_many_txn( rows = self._simple_select_many_txn(
txn, txn,
table="state_group_edges", table="state_group_edges",
@ -2076,21 +2129,28 @@ class EventsStore(SQLBaseStore):
retcols=["state_group"], retcols=["state_group"],
keyvalues={}, keyvalues={},
) )
new_state_edges.extend(row["state_group"] for row in rows) remaining_state_groups.extend(
row["state_group"] for row in rows
# Now we turn the state groups that reference to-be-deleted state groups # exclude state groups we are about to delete: no point in
# to non delta versions. # updating them
for new_state_edge in new_state_edges: if row["state_group"] not in state_groups_to_delete
curr_state = self._get_state_groups_from_groups_txn(
txn, [new_state_edge], types=None
) )
curr_state = curr_state[new_state_edge]
# Now we turn the state groups that reference to-be-deleted state
# groups to non delta versions.
for sg in remaining_state_groups:
logger.debug("[purge] de-delta-ing remaining state group %s", sg)
curr_state = self._get_state_groups_from_groups_txn(
txn, [sg], types=None
)
curr_state = curr_state[sg]
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="state_groups_state", table="state_groups_state",
keyvalues={ keyvalues={
"state_group": new_state_edge, "state_group": sg,
} }
) )
@ -2098,7 +2158,7 @@ class EventsStore(SQLBaseStore):
txn, txn,
table="state_group_edges", table="state_group_edges",
keyvalues={ keyvalues={
"state_group": new_state_edge, "state_group": sg,
} }
) )
@ -2107,7 +2167,7 @@ class EventsStore(SQLBaseStore):
table="state_groups_state", table="state_groups_state",
values=[ values=[
{ {
"state_group": new_state_edge, "state_group": sg,
"room_id": room_id, "room_id": room_id,
"type": key[0], "type": key[0],
"state_key": key[1], "state_key": key[1],
@ -2117,6 +2177,7 @@ class EventsStore(SQLBaseStore):
], ],
) )
logger.debug("[purge] removing redundant state groups")
txn.executemany( txn.executemany(
"DELETE FROM state_groups_state WHERE state_group = ?", "DELETE FROM state_groups_state WHERE state_group = ?",
state_rows state_rows
@ -2125,22 +2186,21 @@ class EventsStore(SQLBaseStore):
"DELETE FROM state_groups WHERE id = ?", "DELETE FROM state_groups WHERE id = ?",
state_rows state_rows
) )
# Delete all non-state # Delete all non-state
logger.debug("[purge] removing events from event_to_state_groups")
txn.executemany( txn.executemany(
"DELETE FROM event_to_state_groups WHERE event_id = ?", "DELETE FROM event_to_state_groups WHERE event_id = ?",
[(event_id,) for event_id, _ in event_rows] [(event_id,) for event_id, _ in event_rows]
) )
logger.debug("[purge] updating room_depth")
txn.execute( txn.execute(
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?", "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
(topological_ordering, room_id,) (topological_ordering, room_id,)
) )
# Delete all remote non-state events # Delete all remote non-state events
to_delete = [
(event_id,) for event_id, state_key in event_rows
if state_key is None and not self.hs.is_mine_id(event_id)
]
for table in ( for table in (
"events", "events",
"event_json", "event_json",
@ -2156,16 +2216,15 @@ class EventsStore(SQLBaseStore):
"event_signatures", "event_signatures",
"rejections", "rejections",
): ):
logger.debug("[purge] removing remote non-state events from %s", table)
txn.executemany( txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,), "DELETE FROM %s WHERE event_id = ?" % (table,),
to_delete to_delete
) )
txn.executemany(
"DELETE FROM events WHERE event_id = ?",
to_delete
)
# Mark all state and own events as outliers # Mark all state and own events as outliers
logger.debug("[purge] marking remaining events as outliers")
txn.executemany( txn.executemany(
"UPDATE events SET outlier = ?" "UPDATE events SET outlier = ?"
" WHERE event_id = ?", " WHERE event_id = ?",
@ -2175,6 +2234,8 @@ class EventsStore(SQLBaseStore):
] ]
) )
logger.info("[purge] done")
@defer.inlineCallbacks @defer.inlineCallbacks
def is_event_after(self, event_id1, event_id2): def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream """Returns True if event_id1 is after event_id2 in the stream

View file

@ -16,6 +16,7 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -184,11 +185,23 @@ class PushRuleStore(SQLBaseStore):
if uid in local_users_in_room: if uid in local_users_in_room:
user_ids.add(uid) user_ids.add(uid)
forgotten = yield self.who_forgot_in_room(
event.room_id, on_invalidate=cache_context.invalidate,
)
for row in forgotten:
user_id = row["user_id"]
event_id = row["event_id"]
mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
if event_id == mem_id:
user_ids.discard(user_id)
rules_by_user = yield self.bulk_get_push_rules( rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate, user_ids, on_invalidate=cache_context.invalidate,
) )
rules_by_user = {k: v for k, v in rules_by_user.iteritems() if v is not None} rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
defer.returnValue(rules_by_user) defer.returnValue(rules_by_user)
@ -398,8 +411,7 @@ class PushRuleStore(SQLBaseStore):
with self._push_rules_stream_id_gen.get_next() as ids: with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids stream_id, event_stream_ordering = ids
yield self.runInteraction( yield self.runInteraction(
"delete_push_rule", delete_push_rule_txn, stream_id, "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering
event_stream_ordering,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from ._base import SQLBaseStore from ._base import SQLBaseStore
from .engines import PostgresEngine, Sqlite3Engine from .engines import PostgresEngine, Sqlite3Engine
@ -33,6 +33,11 @@ OpsLevel = collections.namedtuple(
("ban_level", "kick_level", "redact_level",) ("ban_level", "kick_level", "redact_level",)
) )
RatelimitOverride = collections.namedtuple(
"RatelimitOverride",
("messages_per_second", "burst_count",)
)
class RoomStore(SQLBaseStore): class RoomStore(SQLBaseStore):
@ -473,3 +478,32 @@ class RoomStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms "get_all_new_public_rooms", get_all_new_public_rooms
) )
@cachedInlineCallbacks(max_entries=10000)
def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
user
Args:
user_id (str)
Returns:
RatelimitOverride if there is an override, else None. If the contents
of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely.
"""
row = yield self._simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
allow_none=True,
desc="get_ratelimit_for_user",
)
if row:
defer.returnValue(RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
))
else:
defer.returnValue(None)

View file

@ -421,9 +421,13 @@ class RoomMemberStore(SQLBaseStore):
# We check if we have any of the member event ids in the event cache # We check if we have any of the member event ids in the event cache
# before we ask the DB # before we ask the DB
# We don't update the event cache hit ratio as it completely throws off
# the hit ratio counts. After all, we don't populate the cache if we
# miss it here
event_map = self._get_events_from_cache( event_map = self._get_events_from_cache(
member_event_ids, member_event_ids,
allow_rejected=False, allow_rejected=False,
update_metrics=False,
) )
missing_member_event_ids = [] missing_member_event_ids = []
@ -530,7 +534,7 @@ class RoomMemberStore(SQLBaseStore):
assert state_group is not None assert state_group is not None
joined_hosts = set() joined_hosts = set()
for (etype, state_key), event_id in current_state_ids.items(): for etype, state_key in current_state_ids:
if etype == EventTypes.Member: if etype == EventTypes.Member:
try: try:
host = get_domain_from_id(state_key) host = get_domain_from_id(state_key)
@ -541,6 +545,7 @@ class RoomMemberStore(SQLBaseStore):
if host in joined_hosts: if host in joined_hosts:
continue continue
event_id = current_state_ids[(etype, state_key)]
event = yield self.get_event(event_id, allow_none=True) event = yield self.get_event(event_id, allow_none=True)
if event and event.content["membership"] == Membership.JOIN: if event and event.content["membership"] == Membership.JOIN:
joined_hosts.add(intern_string(host)) joined_hosts.add(intern_string(host))

View file

@ -36,6 +36,10 @@ DROP INDEX IF EXISTS transactions_have_ref;
-- and is used incredibly rarely. -- and is used incredibly rarely.
DROP INDEX IF EXISTS events_order_topo_stream_room; DROP INDEX IF EXISTS events_order_topo_stream_room;
-- an equivalent index to this actually gets re-created in delta 41, because it
-- turned out that deleting it wasn't a great plan :/. In any case, let's
-- delete it here, and delta 41 will create a new one with an added UNIQUE
-- constraint
DROP INDEX IF EXISTS event_search_ev_idx; DROP INDEX IF EXISTS event_search_ev_idx;
""" """

View file

@ -0,0 +1,17 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT into background_updates (update_name, progress_json)
VALUES ('event_search_event_id_idx', '{}');

View file

@ -0,0 +1,22 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE ratelimit_override (
user_id TEXT NOT NULL,
messages_per_second BIGINT,
burst_count BIGINT
);
CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id);

View file

@ -227,6 +227,18 @@ class StateStore(SQLBaseStore):
], ],
) )
# Prefill the state group cache with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=context.state_group,
value=dict(context.current_state_ids),
full=True,
)
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
table="event_to_state_groups", table="event_to_state_groups",

View file

@ -56,10 +56,10 @@ def create_requester(user_id, access_token_id=None, is_guest=False,
def get_domain_from_id(string): def get_domain_from_id(string):
try: idx = string.find(":")
return string.split(":", 1)[1] if idx == -1:
except IndexError:
raise SynapseError(400, "Invalid ID: %r" % (string,)) raise SynapseError(400, "Invalid ID: %r" % (string,))
return string[idx + 1:]
class DomainSpecificString( class DomainSpecificString(

View file

@ -96,7 +96,7 @@ class Cache(object):
"Cache objects can only be accessed from the main thread" "Cache objects can only be accessed from the main thread"
) )
def get(self, key, default=_CacheSentinel, callback=None): def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
"""Looks the key up in the caches. """Looks the key up in the caches.
Args: Args:
@ -104,6 +104,7 @@ class Cache(object):
default: What is returned if key is not in the caches. If not default: What is returned if key is not in the caches. If not
specified then function throws KeyError instead specified then function throws KeyError instead
callback(fn): Gets called when the entry in the cache is invalidated callback(fn): Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns: Returns:
Either a Deferred or the raw result Either a Deferred or the raw result
@ -113,7 +114,8 @@ class Cache(object):
if val is not _CacheSentinel: if val is not _CacheSentinel:
if val.sequence == self.sequence: if val.sequence == self.sequence:
val.callbacks.update(callbacks) val.callbacks.update(callbacks)
self.metrics.inc_hits() if update_metrics:
self.metrics.inc_hits()
return val.deferred return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
@ -121,7 +123,8 @@ class Cache(object):
self.metrics.inc_hits() self.metrics.inc_hits()
return val return val
self.metrics.inc_misses() if update_metrics:
self.metrics.inc_misses()
if default is _CacheSentinel: if default is _CacheSentinel:
raise KeyError() raise KeyError()

View file

@ -188,6 +188,25 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
}) })
@defer.inlineCallbacks
def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
user_ids = set(u[0] for u in user_tuples)
event_id_to_state = {}
for event_id, context in event_id_to_context.items():
state = yield store.get_events([
e_id
for key, e_id in context.current_state_ids.iteritems()
if key == (EventTypes.RoomHistoryVisibility, "")
or (key[0] == EventTypes.Member and key[1] in user_ids)
])
event_id_to_state[event_id] = state
res = yield filter_events_for_clients(
store, user_tuples, events, event_id_to_state
)
defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
def filter_events_for_client(store, user_id, events, is_peeking=False): def filter_events_for_client(store, user_id, events, is_peeking=False):
""" """

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import mock import mock
from synapse.api import errors
from twisted.internet import defer from twisted.internet import defer
import synapse.api.errors import synapse.api.errors
@ -44,3 +45,134 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
res = yield self.handler.query_local_devices({local_user: None}) res = yield self.handler.query_local_devices({local_user: None})
self.assertDictEqual(res, {local_user: {}}) self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
def test_reupload_one_time_keys(self):
"""we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {
"alg1:k1": "key1",
"alg2:k2": {
"key": "key2",
"signatures": {"k1": "sig1"}
},
"alg2:k3": {
"key": "key3",
},
}
res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys},
)
self.assertDictEqual(res, {
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})
# we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys},
)
self.assertDictEqual(res, {
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})
@defer.inlineCallbacks
def test_change_one_time_keys(self):
"""attempts to change one-time-keys should be rejected"""
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {
"alg1:k1": "key1",
"alg2:k2": {
"key": "key2",
"signatures": {"k1": "sig1"}
},
"alg2:k3": {
"key": "key3",
},
}
res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys},
)
self.assertDictEqual(res, {
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})
try:
yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}},
)
self.fail("No error when changing string key")
except errors.SynapseError:
pass
try:
yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}},
)
self.fail("No error when replacing dict key with string")
except errors.SynapseError:
pass
try:
yield self.handler.upload_keys_for_user(
local_user, device_id, {
"one_time_keys": {"alg1:k1": {"key": "key"}}
},
)
self.fail("No error when replacing string key with dict")
except errors.SynapseError:
pass
try:
yield self.handler.upload_keys_for_user(
local_user, device_id, {
"one_time_keys": {
"alg2:k2": {
"key": "key3",
"signatures": {"k1": "sig1"},
}
},
},
)
self.fail("No error when replacing dict key")
except errors.SynapseError:
pass
@unittest.DEBUG
@defer.inlineCallbacks
def test_claim_one_time_key(self):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {
"alg1:k1": "key1",
}
res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys},
)
self.assertDictEqual(res, {
"one_time_key_counts": {"alg1": 1}
})
res2 = yield self.handler.claim_one_time_keys({
"one_time_keys": {
local_user: {
device_id: "alg1"
}
}
}, timeout=None)
self.assertEqual(res2, {
"failures": {},
"one_time_keys": {
local_user: {
device_id: {
"alg1:k1": "key1"
}
}
}
})