0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-18 07:52:56 +01:00

Merge branch 'develop' into dbkr/notifications_api

This commit is contained in:
Matthew Hodgson 2016-08-20 00:16:18 +01:00
commit 6e80c03d45
24 changed files with 750 additions and 91 deletions

View file

@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
System requirements: System requirements:
- POSIX-compliant system (tested on Linux & OS X) - POSIX-compliant system (tested on Linux & OS X)
- Python 2.7 - Python 2.7
- At least 512 MB RAM. - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
Synapse is written in python but some of the libraries is uses are written in Synapse is written in python but some of the libraries is uses are written in
C. So before we can install synapse itself we need a working C compiler and the C. So before we can install synapse itself we need a working C compiler and the

97
docs/workers.rst Normal file
View file

@ -0,0 +1,97 @@
Scaling synapse via workers
---------------------------
Synapse has experimental support for splitting out functionality into
multiple separate python processes, helping greatly with scalability. These
processes are called 'workers', and are (eventually) intended to scale
horizontally independently.
All processes continue to share the same database instance, and as such, workers
only work with postgres based synapse deployments (sharing a single sqlite
across multiple processes is a recipe for disaster, plus you should be using
postgres anyway if you care about scalability).
The workers communicate with the master synapse process via a synapse-specific
HTTP protocol called 'replication' - analogous to MySQL or Postgres style
database replication; feeding a stream of relevant data to the workers so they
can be kept in sync with the main synapse process and database state.
To enable workers, you need to add a replication listener to the master synapse, e.g.::
listeners:
- port: 9092
bind_address: '127.0.0.1'
type: http
tls: false
x_forwarded: false
resources:
- names: [replication]
compress: false
Under **no circumstances** should this replication API listener be exposed to the
public internet; it currently implements no authentication whatsoever and is
unencrypted HTTP.
You then create a set of configs for the various worker processes. These should be
worker configuration files should be stored in a dedicated subdirectory, to allow
synctl to manipulate them.
The current available worker applications are:
* synapse.app.pusher - handles sending push notifications to sygnal and email
* synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
* synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository.
Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker,
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (worker_app) and the replication
endpoint that it's talking to on the main synapse process (worker_replication_url).
For instance::
worker_app: synapse.app.synchrotron
# The replication listener on the synapse to talk to.
worker_replication_url: http://127.0.0.1:9092/_synapse/replication
worker_listeners:
- type: http
port: 8083
resources:
- names:
- client
worker_daemonize: True
worker_pid_file: /home/matrix/synapse/synchrotron.pid
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
...is a full configuration for a synchrotron worker instance, which will expose a
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
by the main synapse.
Obviously you should configure your loadbalancer to route the /sync endpoint to
the synchrotron instance(s) in this instance.
Finally, to actually run your worker-based synapse, you must pass synctl the -a
commandline option to tell it to operate on all the worker configurations found
in the given directory, e.g.::
synctl -a $CONFIG/workers start
Currently one should always restart all workers when restarting or upgrading
synapse, unless you explicitly know it's safe not to. For instance, restarting
synapse without restarting all the synchrotrons may result in broken typing
notifications.
To manipulate a specific worker, you pass the -w option to synctl::
synctl -w $CONFIG/workers/synchrotron.yaml restart
All of the above is highly experimental and subject to change as Synapse evolves,
but documenting it here to help folks needing highly scalable Synapses similar
to the one running matrix.org!

View file

@ -81,13 +81,17 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None, def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, id=None): sender=None, id=None, protocols=None):
self.token = token self.token = token
self.url = url self.url = url
self.hs_token = hs_token self.hs_token = hs_token
self.sender = sender self.sender = sender
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
if protocols:
self.protocols = set(protocols)
else:
self.protocols = set()
def _check_namespaces(self, namespaces): def _check_namespaces(self, namespaces):
# Sanity check that it is of the form: # Sanity check that it is of the form:
@ -219,6 +223,9 @@ class ApplicationService(object):
or user_id == self.sender or user_id == self.sender
) )
def is_interested_in_protocol(self, protocol):
return protocol in self.protocols
def is_exclusive_alias(self, alias): def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias) return self._is_exclusive(ApplicationService.NS_ALIASES, alias)

View file

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.types import ThirdPartyEntityKind
import logging import logging
import urllib import urllib
@ -24,6 +25,28 @@ import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _is_valid_3pe_result(r, field):
if not isinstance(r, dict):
return False
for k in (field, "protocol"):
if k not in r:
return False
if not isinstance(r[k], str):
return False
if "fields" not in r:
return False
fields = r["fields"]
if not isinstance(fields, dict):
return False
for k in fields.keys():
if not isinstance(fields[k], str):
return False
return True
class ApplicationServiceApi(SimpleHttpClient): class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and """This class manages HS -> AS communications, including querying and
pushing. pushing.
@ -71,6 +94,43 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_alias to %s threw exception %s", uri, ex) logger.warning("query_alias to %s threw exception %s", uri, ex)
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
required_field = "alias"
else:
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
)
try:
response = yield self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r",
uri, response
)
defer.returnValue([])
ret = []
for r in response:
if _is_valid_3pe_result(r, field=required_field):
ret.append(r)
else:
logger.warning(
"query_3pe to %s returned an invalid result %r",
uri, r
)
defer.returnValue(ret)
except Exception as ex:
logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([])
@defer.inlineCallbacks @defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None): def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events) events = self._serialize(events)

View file

@ -123,6 +123,15 @@ def _load_appservice(hostname, as_info, config_filename):
raise ValueError( raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj "Missing/bad type 'exclusive' key in %s", regex_obj
) )
# protocols check
protocols = as_info.get("protocols")
if protocols:
# Because strings are lists in python
if isinstance(protocols, str) or not isinstance(protocols, list):
raise KeyError("Optional 'protocols' must be a list if present.")
for p in protocols:
if not isinstance(p, str):
raise KeyError("Bad value for 'protocols' item")
return ApplicationService( return ApplicationService(
token=as_info["as_token"], token=as_info["as_token"],
url=as_info["url"], url=as_info["url"],
@ -130,4 +139,5 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"], hs_token=as_info["hs_token"],
sender=user_id, sender=user_id,
id=as_info["id"], id=as_info["id"],
protocols=protocols,
) )

View file

@ -159,6 +159,22 @@ class ApplicationServicesHandler(object):
) )
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield defer.DeferredList([
self.appservice_api.query_3pe(service, kind, protocol, fields)
for service in services
], consumeErrors=True)
ret = []
for (success, result) in results:
if success:
ret.extend(result)
defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_services_for_event(self, event): def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event. """Retrieve a list of application services interested in this event.
@ -187,6 +203,14 @@ class ApplicationServicesHandler(object):
] ]
defer.returnValue(interested_list) defer.returnValue(interested_list)
@defer.inlineCallbacks
def _get_services_for_3pn(self, protocol):
services = yield self.store.get_app_services()
interested_list = [
s for s in services if s.is_interested_in_protocol(protocol)
]
defer.returnValue(interested_list)
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_unknown_user(self, user_id): def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id): if not self.is_mine_id(user_id):

View file

@ -40,7 +40,7 @@ class ActionGenerator:
def handle_push_actions_for_event(self, event, context): def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "evaluator_for_event"): with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event( bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store, context.current_state event, self.hs, self.store, context.state_group, context.current_state
) )
with Measure(self.clock, "action_for_event_by_user"): with Measure(self.clock, "action_for_event_by_user"):

View file

@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_event(event, hs, store, current_state): def evaluator_for_event(event, hs, store, state_group, current_state):
room_id = event.room_id rules_by_user = yield store.bulk_get_push_rules_for_room(
# We also will want to generate notifs for other people in the room so event.room_id, state_group, current_state
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and hs.is_mine_id(e.state_key)
) )
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield store.get_if_users_have_pushers(
local_users_in_room
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
# if this event is an invite event, we may need to run rules for the user # if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited # who's been invited, otherwise they won't get told they've been invited
if event.type == 'm.room.member' and event.content['membership'] == 'invite': if event.type == 'm.room.member' and event.content['membership'] == 'invite':
@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state):
if invited_user and hs.is_mine_id(invited_user): if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user) has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher: if has_pusher:
user_ids.add(invited_user) rules_by_user[invited_user] = yield store.get_push_rules_for_user(
invited_user
rules_by_user = yield _get_rules(room_id, user_ids, store) )
defer.returnValue(BulkPushRuleEvaluator( defer.returnValue(BulkPushRuleEvaluator(
room_id, rules_by_user, user_ids, store event.room_id, rules_by_user, store
)) ))
@ -90,10 +66,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562) (see https://matrix.org/jira/browse/SYN-562)
""" """
def __init__(self, room_id, rules_by_user, users_in_room, store): def __init__(self, room_id, rules_by_user, store):
self.room_id = room_id self.room_id = room_id
self.rules_by_user = rules_by_user self.rules_by_user = rules_by_user
self.users_in_room = users_in_room
self.store = store self.store = store
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -48,6 +48,7 @@ from synapse.rest.client.v2_alpha import (
openid, openid,
notifications, notifications,
devices, devices,
thirdparty,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -94,3 +95,4 @@ class ClientRestResource(JsonResource):
openid.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource)
notifications.register_servlets(hs, client_resource) notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource)

View file

@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
releases=())
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
)
defer.returnValue((200, results))
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
releases=())
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
)
defer.returnValue((200, results))
def register_servlets(hs, http_server):
ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server)

View file

@ -166,7 +166,7 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters() self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters() self._get_event_counters = PerformanceCounters()
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache( self._state_group_cache = DictionaryCache(

View file

@ -366,8 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
def get_new_events_for_appservice_txn(txn): def get_new_events_for_appservice_txn(txn):
sql = ( sql = (
"SELECT e.stream_ordering, e.event_id" "SELECT e.stream_ordering, e.event_id"
" FROM events AS e, appservice_stream_position AS a" " FROM events AS e"
" WHERE a.stream_ordering < e.stream_ordering AND e.stream_ordering <= ?" " WHERE"
" (SELECT stream_ordering FROM appservice_stream_position)"
" < e.stream_ordering"
" AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC" " ORDER BY e.stream_ordering ASC"
" LIMIT ?" " LIMIT ?"
) )

View file

@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore):
) )
self._simple_insert_many_txn(txn, "event_push_actions", values) self._simple_insert_many_txn(txn, "event_push_actions", values)
@cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000) @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user( def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id self, room_id, user_id, last_read_event_id
): ):

View file

@ -16,6 +16,7 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes, Membership
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -48,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks(lru=True) @cachedInlineCallbacks()
def get_push_rules_for_user(self, user_id): def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="push_rules", table="push_rules",
@ -72,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules) defer.returnValue(rules)
@cachedInlineCallbacks(lru=True) @cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_id): def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table="push_rules_enable", table="push_rules_enable",
@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
cache_context):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and self.hs.is_mine_id(e.state_key)
)
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, on_invalidate=cache_context.invalidate,
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, on_invalidate=cache_context.invalidate,
)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate,
)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
defer.returnValue(rules_by_user)
@cachedList(cached_method_name="get_push_rules_enabled_for_user", @cachedList(cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True) list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules_enabled(self, user_ids): def bulk_get_push_rules_enabled(self, user_ids):

View file

@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn "get_all_updated_pushers", get_all_updated_pushers_txn
) )
@cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000) @cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id): def get_if_user_has_pusher(self, user_id):
result = yield self._simple_select_many_batch( result = yield self._simple_select_many_batch(
table='pushers', table='pushers',

View file

@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res]) defer.returnValue([ev for res in results.values() for ev in res])
@cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True) @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients. """Get receipts for a single room for sending to clients.

View file

@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureStore(SQLBaseStore): class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes""" """Persistence for event signatures and hashes"""
@cached(lru=True) @cached()
def get_event_reference_hash(self, event_id): def get_event_reference_hash(self, event_id):
return self._get_event_reference_hashes_txn(event_id) return self._get_event_reference_hashes_txn(event_id)

View file

@ -174,7 +174,7 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results] return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f) return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, lru=True, max_entries=1000) @cached(num_args=2, max_entries=1000)
def _get_state_group_from_group(self, group, types): def _get_state_group_from_group(self, group, types):
raise NotImplementedError() raise NotImplementedError()
@ -272,7 +272,7 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types) state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@cached(num_args=2, lru=True, max_entries=10000) @cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id): def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="event_to_state_groups", table="event_to_state_groups",

View file

@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream) return "t%d-%d" % (self.topological, self.stream)
else: else:
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# Some arbitrary constants used for internal API enumerations. Don't rely on
# exact values; always pass or compare symbolically
class ThirdPartyEntityKind(object):
USER = 'user'
LOCATION = 'location'

View file

@ -25,8 +25,7 @@ from synapse.util.logcontext import (
from . import DEBUG_CACHES, register_cache from . import DEBUG_CACHES, register_cache
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple
from collections import OrderedDict
import os import os
import functools import functools
@ -54,16 +53,11 @@ class Cache(object):
"metrics", "metrics",
) )
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): def __init__(self, name, max_entries=1000, keylen=1, tree=False):
if lru:
cache_type = TreeCache if tree else dict cache_type = TreeCache if tree else dict
self.cache = LruCache( self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type max_size=max_entries, keylen=keylen, cache_type=cache_type
) )
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name self.name = name
self.keylen = keylen self.keylen = keylen
@ -81,8 +75,8 @@ class Cache(object):
"Cache objects can only be accessed from the main thread" "Cache objects can only be accessed from the main thread"
) )
def get(self, key, default=_CacheSentinel): def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel) val = self.cache.get(key, _CacheSentinel, callback=callback)
if val is not _CacheSentinel: if val is not _CacheSentinel:
self.metrics.inc_hits() self.metrics.inc_hits()
return val return val
@ -94,19 +88,15 @@ class Cache(object):
else: else:
return default return default
def update(self, sequence, key, value): def update(self, sequence, key, value, callback=None):
self.check_thread() self.check_thread()
if self.sequence == sequence: if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the # Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369) # number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value) self.prefill(key, value, callback=callback)
def prefill(self, key, value): def prefill(self, key, value, callback=None):
if self.max_entries is not None: self.cache.set(key, value, callback=callback)
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[key] = value
def invalidate(self, key): def invalidate(self, key):
self.check_thread() self.check_thread()
@ -151,9 +141,21 @@ class CacheDescriptor(object):
The wrapped function has another additional callable, called "prefill", The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without which can be used to insert values into the cache specifically, without
calling the calculation function. calling the calculation function.
Cached functions can be "chained" (i.e. a cached function can call other cached
functions and get appropriately invalidated when they called caches are
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example::
@cachedInlineCallbacks(cache_context=True)
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2)
""" """
def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
inlineCallbacks=False): inlineCallbacks=False, cache_context=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR) max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig self.orig = orig
@ -165,15 +167,33 @@ class CacheDescriptor(object):
self.max_entries = max_entries self.max_entries = max_entries
self.num_args = num_args self.num_args = num_args
self.lru = lru
self.tree = tree self.tree = tree
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
if "cache_context" in all_args.args:
if not cache_context:
raise ValueError(
"Cannot have a 'cache_context' arg without setting"
" cache_context=True"
)
try:
self.arg_names.remove("cache_context")
except ValueError:
pass
elif cache_context:
raise ValueError(
"Cannot have cache_context=True without having an arg"
" named `cache_context`"
)
self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args: if len(self.arg_names) < self.num_args:
raise Exception( raise Exception(
"Not enough explicit positional arguments to key off of for %r." "Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)" " (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,) % (orig.__name__,)
) )
@ -182,16 +202,29 @@ class CacheDescriptor(object):
name=self.orig.__name__, name=self.orig.__name__,
max_entries=self.max_entries, max_entries=self.max_entries,
keylen=self.num_args, keylen=self.num_args,
lru=self.lru,
tree=self.tree, tree=self.tree,
) )
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
# Add temp cache_context so inspect.getcallargs doesn't explode
if self.add_cache_context:
kwargs["cache_context"] = None
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext(cache, cache_key)
try: try:
cached_result_d = cache.get(cache_key) cached_result_d = cache.get(cache_key, callback=invalidate_callback)
observer = cached_result_d.observe() observer = cached_result_d.observe()
if DEBUG_CACHES: if DEBUG_CACHES:
@ -228,7 +261,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True) ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret) cache.update(sequence, cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe()) return preserve_context_over_deferred(ret.observe())
@ -297,6 +330,10 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name] list_args = arg_dict[self.list_name]
@ -311,7 +348,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg key[self.list_pos] = arg
try: try:
res = cache.get(tuple(key)) res = cache.get(tuple(key), callback=invalidate_callback)
if not res.has_succeeded(): if not res.has_succeeded():
res = res.observe() res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg) res.addCallback(lambda r, arg: (arg, r), arg)
@ -345,7 +382,10 @@ class CacheListDescriptor(object):
key = list(keyargs) key = list(keyargs)
key[self.list_pos] = arg key[self.list_pos] = arg
cache.update(sequence, tuple(key), observer) cache.update(
sequence, tuple(key), observer,
callback=invalidate_callback
)
def invalidate(f, key): def invalidate(f, key):
cache.invalidate(key) cache.invalidate(key)
@ -376,24 +416,29 @@ class CacheListDescriptor(object):
return wrapped return wrapped
def cached(max_entries=1000, num_args=1, lru=True, tree=False): class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
def invalidate(self):
self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
num_args=num_args, num_args=num_args,
lru=lru,
tree=tree, tree=tree,
cache_context=cache_context,
) )
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
num_args=num_args, num_args=num_args,
lru=lru,
tree=tree, tree=tree,
inlineCallbacks=True, inlineCallbacks=True,
cache_context=cache_context,
) )

View file

@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
class _Node(object): class _Node(object):
__slots__ = ["prev_node", "next_node", "key", "value"] __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
def __init__(self, prev_node, next_node, key, value): def __init__(self, prev_node, next_node, key, value, callbacks=set()):
self.prev_node = prev_node self.prev_node = prev_node
self.next_node = next_node self.next_node = next_node
self.key = key self.key = key
self.value = value self.value = value
self.callbacks = callbacks
class LruCache(object): class LruCache(object):
@ -44,6 +45,9 @@ class LruCache(object):
Least-recently-used cache. Least-recently-used cache.
Supports del_multi only if cache_type=TreeCache Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples. If cache_type=TreeCache, all keys must be tuples.
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
""" """
def __init__(self, max_size, keylen=1, cache_type=dict): def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type() cache = cache_type()
@ -62,10 +66,10 @@ class LruCache(object):
return inner return inner
def add_node(key, value): def add_node(key, value, callbacks=set()):
prev_node = list_root prev_node = list_root
next_node = prev_node.next_node next_node = prev_node.next_node
node = _Node(prev_node, next_node, key, value) node = _Node(prev_node, next_node, key, value, callbacks)
prev_node.next_node = node prev_node.next_node = node
next_node.prev_node = node next_node.prev_node = node
cache[key] = node cache[key] = node
@ -88,23 +92,41 @@ class LruCache(object):
prev_node.next_node = next_node prev_node.next_node = next_node
next_node.prev_node = prev_node next_node.prev_node = prev_node
for cb in node.callbacks:
cb()
node.callbacks.clear()
@synchronized @synchronized
def cache_get(key, default=None): def cache_get(key, default=None, callback=None):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
move_node_to_front(node) move_node_to_front(node)
if callback:
node.callbacks.add(callback)
return node.value return node.value
else: else:
return default return default
@synchronized @synchronized
def cache_set(key, value): def cache_set(key, value, callback=None):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
if value != node.value:
for cb in node.callbacks:
cb()
node.callbacks.clear()
if callback:
node.callbacks.add(callback)
move_node_to_front(node) move_node_to_front(node)
node.value = value node.value = value
else: else:
add_node(key, value) if callback:
callbacks = set([callback])
else:
callbacks = set()
add_node(key, value, callbacks)
if len(cache) > max_size: if len(cache) > max_size:
todelete = list_root.prev_node todelete = list_root.prev_node
delete_node(todelete) delete_node(todelete)
@ -148,6 +170,9 @@ class LruCache(object):
def cache_clear(): def cache_clear():
list_root.next_node = list_root list_root.next_node = list_root
list_root.prev_node = list_root list_root.prev_node = list_root
for node in cache.values():
for cb in node.callbacks:
cb()
cache.clear() cache.clear()
@synchronized @synchronized

View file

@ -64,6 +64,9 @@ class TreeCache(object):
self.size -= cnt self.size -= cnt
return popped return popped
def values(self):
return [e.value for e in self.root.values()]
def __len__(self): def __len__(self):
return self.size return self.size

View file

@ -17,6 +17,8 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached from synapse.util.caches.descriptors import Cache, cached
@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3) cache.get(3)
def test_eviction_lru(self): def test_eviction_lru(self):
cache = Cache("test", max_entries=2, lru=True) cache = Cache("test", max_entries=2)
cache.prefill(1, "one") cache.prefill(1, "one")
cache.prefill(2, "two") cache.prefill(2, "two")
@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(a.func("foo").result, d.result) self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
def test_invalidate_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 1)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
@defer.inlineCallbacks
def test_eviction_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
yield a.func2("foo2")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 4)
self.assertEquals(callcount2[0], 3)
@defer.inlineCallbacks
def test_double_get(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
yield a.func2("foo")
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 2)
a.func.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 3)

View file

@ -19,6 +19,8 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
from mock import Mock
class LruCacheTestCase(unittest.TestCase): class LruCacheTestCase(unittest.TestCase):
@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get("key"), 1) self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.setdefault("key", 2), 1) self.assertEquals(cache.setdefault("key", 2), 1)
self.assertEquals(cache.get("key"), 1) self.assertEquals(cache.get("key"), 1)
cache["key"] = 2 # Make sure overriding works.
self.assertEquals(cache.get("key"), 2)
def test_pop(self): def test_pop(self):
cache = LruCache(1) cache = LruCache(1)
@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1 cache["key"] = 1
cache.clear() cache.clear()
self.assertEquals(len(cache), 0) self.assertEquals(len(cache), 0)
class LruCacheCallbacksTestCase(unittest.TestCase):
def test_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_multi_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_set(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.set("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_pop(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.pop("key")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
cache.pop("key")
self.assertEquals(m.call_count, 1)
def test_del_multi(self):
m1 = Mock()
m2 = Mock()
m3 = Mock()
m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache)
cache.set(("a", "1"), "value", m1)
cache.set(("a", "2"), "value", m2)
cache.set(("b", "1"), "value", m3)
cache.set(("b", "2"), "value", m4)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
cache.del_multi(("a",))
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
def test_clear(self):
m1 = Mock()
m2 = Mock()
cache = LruCache(5)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
cache.clear()
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
def test_eviction(self):
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
cache = LruCache(2)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value", m3)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.get("key2")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key1", "value", m1)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 1)