0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-29 07:58:19 +02:00

Merge branch 'develop' into dbkr/notifications_api

This commit is contained in:
Matthew Hodgson 2016-08-20 00:16:18 +01:00
commit 6e80c03d45
24 changed files with 750 additions and 91 deletions

View file

@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
- Python 2.7
- At least 512 MB RAM.
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
Synapse is written in python but some of the libraries is uses are written in
C. So before we can install synapse itself we need a working C compiler and the

97
docs/workers.rst Normal file
View file

@ -0,0 +1,97 @@
Scaling synapse via workers
---------------------------
Synapse has experimental support for splitting out functionality into
multiple separate python processes, helping greatly with scalability. These
processes are called 'workers', and are (eventually) intended to scale
horizontally independently.
All processes continue to share the same database instance, and as such, workers
only work with postgres based synapse deployments (sharing a single sqlite
across multiple processes is a recipe for disaster, plus you should be using
postgres anyway if you care about scalability).
The workers communicate with the master synapse process via a synapse-specific
HTTP protocol called 'replication' - analogous to MySQL or Postgres style
database replication; feeding a stream of relevant data to the workers so they
can be kept in sync with the main synapse process and database state.
To enable workers, you need to add a replication listener to the master synapse, e.g.::
listeners:
- port: 9092
bind_address: '127.0.0.1'
type: http
tls: false
x_forwarded: false
resources:
- names: [replication]
compress: false
Under **no circumstances** should this replication API listener be exposed to the
public internet; it currently implements no authentication whatsoever and is
unencrypted HTTP.
You then create a set of configs for the various worker processes. These should be
worker configuration files should be stored in a dedicated subdirectory, to allow
synctl to manipulate them.
The current available worker applications are:
* synapse.app.pusher - handles sending push notifications to sygnal and email
* synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
* synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository.
Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker,
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (worker_app) and the replication
endpoint that it's talking to on the main synapse process (worker_replication_url).
For instance::
worker_app: synapse.app.synchrotron
# The replication listener on the synapse to talk to.
worker_replication_url: http://127.0.0.1:9092/_synapse/replication
worker_listeners:
- type: http
port: 8083
resources:
- names:
- client
worker_daemonize: True
worker_pid_file: /home/matrix/synapse/synchrotron.pid
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
...is a full configuration for a synchrotron worker instance, which will expose a
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
by the main synapse.
Obviously you should configure your loadbalancer to route the /sync endpoint to
the synchrotron instance(s) in this instance.
Finally, to actually run your worker-based synapse, you must pass synctl the -a
commandline option to tell it to operate on all the worker configurations found
in the given directory, e.g.::
synctl -a $CONFIG/workers start
Currently one should always restart all workers when restarting or upgrading
synapse, unless you explicitly know it's safe not to. For instance, restarting
synapse without restarting all the synchrotrons may result in broken typing
notifications.
To manipulate a specific worker, you pass the -w option to synctl::
synctl -w $CONFIG/workers/synchrotron.yaml restart
All of the above is highly experimental and subject to change as Synapse evolves,
but documenting it here to help folks needing highly scalable Synapses similar
to the one running matrix.org!

View file

@ -81,13 +81,17 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, id=None):
sender=None, id=None, protocols=None):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
self.namespaces = self._check_namespaces(namespaces)
self.id = id
if protocols:
self.protocols = set(protocols)
else:
self.protocols = set()
def _check_namespaces(self, namespaces):
# Sanity check that it is of the form:
@ -219,6 +223,9 @@ class ApplicationService(object):
or user_id == self.sender
)
def is_interested_in_protocol(self, protocol):
return protocol in self.protocols
def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)

View file

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
from synapse.types import ThirdPartyEntityKind
import logging
import urllib
@ -24,6 +25,28 @@ import urllib
logger = logging.getLogger(__name__)
def _is_valid_3pe_result(r, field):
if not isinstance(r, dict):
return False
for k in (field, "protocol"):
if k not in r:
return False
if not isinstance(r[k], str):
return False
if "fields" not in r:
return False
fields = r["fields"]
if not isinstance(fields, dict):
return False
for k in fields.keys():
if not isinstance(fields[k], str):
return False
return True
class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and
pushing.
@ -71,6 +94,43 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_alias to %s threw exception %s", uri, ex)
defer.returnValue(False)
@defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
required_field = "alias"
else:
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
)
try:
response = yield self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r",
uri, response
)
defer.returnValue([])
ret = []
for r in response:
if _is_valid_3pe_result(r, field=required_field):
ret.append(r)
else:
logger.warning(
"query_3pe to %s returned an invalid result %r",
uri, r
)
defer.returnValue(ret)
except Exception as ex:
logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([])
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events)

View file

@ -123,6 +123,15 @@ def _load_appservice(hostname, as_info, config_filename):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
# protocols check
protocols = as_info.get("protocols")
if protocols:
# Because strings are lists in python
if isinstance(protocols, str) or not isinstance(protocols, list):
raise KeyError("Optional 'protocols' must be a list if present.")
for p in protocols:
if not isinstance(p, str):
raise KeyError("Bad value for 'protocols' item")
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
@ -130,4 +139,5 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
protocols=protocols,
)

View file

@ -159,6 +159,22 @@ class ApplicationServicesHandler(object):
)
defer.returnValue(result)
@defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield defer.DeferredList([
self.appservice_api.query_3pe(service, kind, protocol, fields)
for service in services
], consumeErrors=True)
ret = []
for (success, result) in results:
if success:
ret.extend(result)
defer.returnValue(ret)
@defer.inlineCallbacks
def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event.
@ -187,6 +203,14 @@ class ApplicationServicesHandler(object):
]
defer.returnValue(interested_list)
@defer.inlineCallbacks
def _get_services_for_3pn(self, protocol):
services = yield self.store.get_app_services()
interested_list = [
s for s in services if s.is_interested_in_protocol(protocol)
]
defer.returnValue(interested_list)
@defer.inlineCallbacks
def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id):

View file

@ -40,7 +40,7 @@ class ActionGenerator:
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store, context.current_state
event, self.hs, self.store, context.state_group, context.current_state
)
with Measure(self.clock, "action_for_event_by_user"):

View file

@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks
def evaluator_for_event(event, hs, store, current_state):
room_id = event.room_id
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and hs.is_mine_id(e.state_key)
def evaluator_for_event(event, hs, store, state_group, current_state):
rules_by_user = yield store.bulk_get_push_rules_for_room(
event.room_id, state_group, current_state
)
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield store.get_if_users_have_pushers(
local_users_in_room
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state):
if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher:
user_ids.add(invited_user)
rules_by_user = yield _get_rules(room_id, user_ids, store)
rules_by_user[invited_user] = yield store.get_push_rules_for_user(
invited_user
)
defer.returnValue(BulkPushRuleEvaluator(
room_id, rules_by_user, user_ids, store
event.room_id, rules_by_user, store
))
@ -90,10 +66,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562)
"""
def __init__(self, room_id, rules_by_user, users_in_room, store):
def __init__(self, room_id, rules_by_user, store):
self.room_id = room_id
self.rules_by_user = rules_by_user
self.users_in_room = users_in_room
self.store = store
@defer.inlineCallbacks

View file

@ -48,6 +48,7 @@ from synapse.rest.client.v2_alpha import (
openid,
notifications,
devices,
thirdparty,
)
from synapse.http.server import JsonResource
@ -94,3 +95,4 @@ class ClientRestResource(JsonResource):
openid.register_servlets(hs, client_resource)
notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource)

View file

@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
releases=())
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
)
defer.returnValue((200, results))
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
releases=())
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
)
defer.returnValue((200, results))
def register_servlets(hs, http_server):
ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server)

View file

@ -166,7 +166,7 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters()
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache(

View file

@ -366,8 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
def get_new_events_for_appservice_txn(txn):
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e, appservice_stream_position AS a"
" WHERE a.stream_ordering < e.stream_ordering AND e.stream_ordering <= ?"
" FROM events AS e"
" WHERE"
" (SELECT stream_ordering FROM appservice_stream_position)"
" < e.stream_ordering"
" AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
" LIMIT ?"
)

View file

@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore):
)
self._simple_insert_many_txn(txn, "event_push_actions", values)
@cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):

View file

@ -16,6 +16,7 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes, Membership
from twisted.internet import defer
import logging
@ -48,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks(lru=True)
@cachedInlineCallbacks()
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
@ -72,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules)
@cachedInlineCallbacks(lru=True)
@cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results)
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
cache_context):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and self.hs.is_mine_id(e.state_key)
)
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, on_invalidate=cache_context.invalidate,
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, on_invalidate=cache_context.invalidate,
)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate,
)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
defer.returnValue(rules_by_user)
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules_enabled(self, user_ids):

View file

@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn
)
@cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
@cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id):
result = yield self._simple_select_many_batch(
table='pushers',

View file

@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res])
@cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
@cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.

View file

@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes"""
@cached(lru=True)
@cached()
def get_event_reference_hash(self, event_id):
return self._get_event_reference_hashes_txn(event_id)

View file

@ -174,7 +174,7 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, lru=True, max_entries=1000)
@cached(num_args=2, max_entries=1000)
def _get_state_group_from_group(self, group, types):
raise NotImplementedError()
@ -272,7 +272,7 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@cached(num_args=2, lru=True, max_entries=10000)
@cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",

View file

@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
# Some arbitrary constants used for internal API enumerations. Don't rely on
# exact values; always pass or compare symbolically
class ThirdPartyEntityKind(object):
USER = 'user'
LOCATION = 'location'

View file

@ -25,8 +25,7 @@ from synapse.util.logcontext import (
from . import DEBUG_CACHES, register_cache
from twisted.internet import defer
from collections import OrderedDict
from collections import namedtuple
import os
import functools
@ -54,16 +53,11 @@ class Cache(object):
"metrics",
)
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
if lru:
cache_type = TreeCache if tree else dict
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type
)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
def __init__(self, name, max_entries=1000, keylen=1, tree=False):
cache_type = TreeCache if tree else dict
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type
)
self.name = name
self.keylen = keylen
@ -81,8 +75,8 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel, callback=callback)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
@ -94,19 +88,15 @@ class Cache(object):
else:
return default
def update(self, sequence, key, value):
def update(self, sequence, key, value, callback=None):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value)
self.prefill(key, value, callback=callback)
def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[key] = value
def prefill(self, key, value, callback=None):
self.cache.set(key, value, callback=callback)
def invalidate(self, key):
self.check_thread()
@ -151,9 +141,21 @@ class CacheDescriptor(object):
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
Cached functions can be "chained" (i.e. a cached function can call other cached
functions and get appropriately invalidated when they called caches are
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example::
@cachedInlineCallbacks(cache_context=True)
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2)
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
inlineCallbacks=False):
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
inlineCallbacks=False, cache_context=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@ -165,15 +167,33 @@ class CacheDescriptor(object):
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
self.tree = tree
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
if "cache_context" in all_args.args:
if not cache_context:
raise ValueError(
"Cannot have a 'cache_context' arg without setting"
" cache_context=True"
)
try:
self.arg_names.remove("cache_context")
except ValueError:
pass
elif cache_context:
raise ValueError(
"Cannot have cache_context=True without having an arg"
" named `cache_context`"
)
self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
" (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,)
)
@ -182,16 +202,29 @@ class CacheDescriptor(object):
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
tree=self.tree,
)
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
# Add temp cache_context so inspect.getcallargs doesn't explode
if self.add_cache_context:
kwargs["cache_context"] = None
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext(cache, cache_key)
try:
cached_result_d = cache.get(cache_key)
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@ -228,7 +261,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret)
cache.update(sequence, cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
@ -297,6 +330,10 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
@ -311,7 +348,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg
try:
res = cache.get(tuple(key))
res = cache.get(tuple(key), callback=invalidate_callback)
if not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
@ -345,7 +382,10 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
cache.update(sequence, tuple(key), observer)
cache.update(
sequence, tuple(key), observer,
callback=invalidate_callback
)
def invalidate(f, key):
cache.invalidate(key)
@ -376,24 +416,29 @@ class CacheListDescriptor(object):
return wrapped
def cached(max_entries=1000, num_args=1, lru=True, tree=False):
class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
def invalidate(self):
self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
tree=tree,
cache_context=cache_context,
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
tree=tree,
inlineCallbacks=True,
cache_context=cache_context,
)

View file

@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
class _Node(object):
__slots__ = ["prev_node", "next_node", "key", "value"]
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
def __init__(self, prev_node, next_node, key, value):
def __init__(self, prev_node, next_node, key, value, callbacks=set()):
self.prev_node = prev_node
self.next_node = next_node
self.key = key
self.value = value
self.callbacks = callbacks
class LruCache(object):
@ -44,6 +45,9 @@ class LruCache(object):
Least-recently-used cache.
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type()
@ -62,10 +66,10 @@ class LruCache(object):
return inner
def add_node(key, value):
def add_node(key, value, callbacks=set()):
prev_node = list_root
next_node = prev_node.next_node
node = _Node(prev_node, next_node, key, value)
node = _Node(prev_node, next_node, key, value, callbacks)
prev_node.next_node = node
next_node.prev_node = node
cache[key] = node
@ -88,23 +92,41 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
for cb in node.callbacks:
cb()
node.callbacks.clear()
@synchronized
def cache_get(key, default=None):
def cache_get(key, default=None, callback=None):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
if callback:
node.callbacks.add(callback)
return node.value
else:
return default
@synchronized
def cache_set(key, value):
def cache_set(key, value, callback=None):
node = cache.get(key, None)
if node is not None:
if value != node.value:
for cb in node.callbacks:
cb()
node.callbacks.clear()
if callback:
node.callbacks.add(callback)
move_node_to_front(node)
node.value = value
else:
add_node(key, value)
if callback:
callbacks = set([callback])
else:
callbacks = set()
add_node(key, value, callbacks)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
@ -148,6 +170,9 @@ class LruCache(object):
def cache_clear():
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
for cb in node.callbacks:
cb()
cache.clear()
@synchronized

View file

@ -64,6 +64,9 @@ class TreeCache(object):
self.size -= cnt
return popped
def values(self):
return [e.value for e in self.root.values()]
def __len__(self):
return self.size

View file

@ -17,6 +17,8 @@
from tests import unittest
from twisted.internet import defer
from mock import Mock
from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached
@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3)
def test_eviction_lru(self):
cache = Cache("test", max_entries=2, lru=True)
cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
def test_invalidate_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 1)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
@defer.inlineCallbacks
def test_eviction_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
yield a.func2("foo2")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 4)
self.assertEquals(callcount2[0], 3)
@defer.inlineCallbacks
def test_double_get(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
yield a.func2("foo")
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 2)
a.func.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 3)

View file

@ -19,6 +19,8 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
from mock import Mock
class LruCacheTestCase(unittest.TestCase):
@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.setdefault("key", 2), 1)
self.assertEquals(cache.get("key"), 1)
cache["key"] = 2 # Make sure overriding works.
self.assertEquals(cache.get("key"), 2)
def test_pop(self):
cache = LruCache(1)
@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1
cache.clear()
self.assertEquals(len(cache), 0)
class LruCacheCallbacksTestCase(unittest.TestCase):
def test_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_multi_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_set(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.set("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_pop(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.pop("key")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
cache.pop("key")
self.assertEquals(m.call_count, 1)
def test_del_multi(self):
m1 = Mock()
m2 = Mock()
m3 = Mock()
m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache)
cache.set(("a", "1"), "value", m1)
cache.set(("a", "2"), "value", m2)
cache.set(("b", "1"), "value", m3)
cache.set(("b", "2"), "value", m4)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
cache.del_multi(("a",))
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
def test_clear(self):
m1 = Mock()
m2 = Mock()
cache = LruCache(5)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
cache.clear()
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
def test_eviction(self):
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
cache = LruCache(2)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value", m3)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.get("key2")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key1", "value", m1)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 1)