0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-19 00:22:05 +01:00

Merge branch 'develop' into matthew/brand-from-header

This commit is contained in:
Matthew Hodgson 2016-06-03 12:14:18 +01:00
commit 8d740132f4
35 changed files with 691 additions and 207 deletions

View file

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains classes for authenticating the user."""
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
@ -42,13 +41,20 @@ AuthEventTypes = (
class Auth(object):
"""
FIXME: This class contains a mix of functions for authenticating users
of our client-server API and authenticating events added to room graphs.
"""
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to
# delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ",
"guest = ",
@ -525,7 +531,7 @@ class Auth(object):
return default
@defer.inlineCallbacks
def get_user_by_req(self, request, allow_guest=False):
def get_user_by_req(self, request, allow_guest=False, rights="access"):
""" Get a registered user's ID.
Args:
@ -547,7 +553,7 @@ class Auth(object):
)
access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_access_token(access_token)
user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
@ -608,7 +614,7 @@ class Auth(object):
defer.returnValue(user_id)
@defer.inlineCallbacks
def get_user_by_access_token(self, token):
def get_user_by_access_token(self, token, rights="access"):
""" Get a registered user's ID.
Args:
@ -619,7 +625,7 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid.
"""
try:
ret = yield self.get_user_from_macaroon(token)
ret = yield self.get_user_from_macaroon(token, rights)
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
@ -627,11 +633,11 @@ class Auth(object):
defer.returnValue(ret)
@defer.inlineCallbacks
def get_user_from_macaroon(self, macaroon_str):
def get_user_from_macaroon(self, macaroon_str, rights="access"):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
self.validate_macaroon(macaroon, rights, self.hs.config.expire_access_token)
user_prefix = "user_id = "
user = None
@ -654,6 +660,13 @@ class Auth(object):
"is_guest": True,
"token_id": None,
}
elif rights == "delete_pusher":
# We don't store these tokens in the database
ret = {
"user": user,
"is_guest": False,
"token_id": None,
}
else:
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
@ -685,7 +698,8 @@ class Auth(object):
Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token this is (e.g. "access", "refresh")
type_string(str): The kind of token required (e.g. "access", "refresh",
"delete_pusher")
verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet.

View file

@ -21,6 +21,7 @@ from synapse.config._base import ConfigError
from synapse.config.database import DatabaseConfig
from synapse.config.logger import LoggingConfig
from synapse.config.emailconfig import EmailConfig
from synapse.config.key import KeyConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.storage.roommember import RoomMemberStore
@ -63,6 +64,26 @@ class SlaveConfig(DatabaseConfig):
self.pid_file = self.abspath(config.get("pid_file"))
self.public_baseurl = config["public_baseurl"]
# some things used by the auth handler but not actually used in the
# pusher codebase
self.bcrypt_rounds = None
self.ldap_enabled = None
self.ldap_server = None
self.ldap_port = None
self.ldap_tls = None
self.ldap_search_base = None
self.ldap_search_property = None
self.ldap_email_property = None
self.ldap_full_name_property = None
# We would otherwise try to use the registration shared secret as the
# macaroon shared secret if there was no macaroon_shared_secret, but
# that means pulling in RegistrationConfig too. We don't need to be
# backwards compaitible in the pusher codebase so just make people set
# macaroon_shared_secret. We set this to None to prevent it referencing
# an undefined key.
self.registration_shared_secret = None
def default_config(self, server_name, **kwargs):
pid_file = self.abspath("pusher.pid")
return """\
@ -95,7 +116,7 @@ class SlaveConfig(DatabaseConfig):
""" % locals()
class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig):
class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig, KeyConfig):
pass

View file

@ -529,6 +529,11 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)

View file

@ -26,9 +26,9 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
)
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.async import concurrently_execute, run_on_reactor
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
from synapse.util.logcontext import preserve_fn
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@ -908,13 +908,16 @@ class MessageHandler(BaseHandler):
"Failed to get destination from event %s", s.event_id
)
with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners.
@defer.inlineCallbacks
def _notify():
yield run_on_reactor()
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
preserve_fn(_notify)()
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)

View file

@ -198,9 +198,8 @@ class SyncHandler(object):
@defer.inlineCallbacks
def push_rules_for_user(self, user):
user_id = user.to_string()
rawrules = yield self.store.get_push_rules_for_user(user_id)
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
rules = format_push_rules_for_user(user, rawrules, enabled_map)
rules = yield self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules)
defer.returnValue(rules)
@defer.inlineCallbacks

View file

@ -33,11 +33,7 @@ from .metric import (
logger = logging.getLogger(__name__)
# We'll keep all the available metrics in a single toplevel dict, one shared
# for the entire process. We don't currently support per-HomeServer instances
# of metrics, because in practice any one python VM will host only one
# HomeServer anyway. This makes a lot of implementation neater
all_metrics = {}
all_metrics = []
class Metrics(object):
@ -53,7 +49,7 @@ class Metrics(object):
metric = metric_class(full_name, *args, **kwargs)
all_metrics[full_name] = metric
all_metrics.append(metric)
return metric
def register_counter(self, *args, **kwargs):
@ -84,12 +80,12 @@ def render_all():
# TODO(paul): Internal hack
update_resource_metrics()
for name in sorted(all_metrics.keys()):
for metric in all_metrics:
try:
strs += all_metrics[name].render()
strs += metric.render()
except Exception:
strs += ["# FAILED to render %s" % name]
logger.exception("Failed to render %s metric", name)
strs += ["# FAILED to render"]
logger.exception("Failed to render metric")
strs.append("") # to generate a final CRLF

View file

@ -47,9 +47,6 @@ class BaseMetric(object):
for k, v in zip(self.labels, values)])
)
def render(self):
return map_concat(self.render_item, sorted(self.counts.keys()))
class CounterMetric(BaseMetric):
"""The simplest kind of metric; one that stores a monotonically-increasing
@ -83,6 +80,9 @@ class CounterMetric(BaseMetric):
def render_item(self, k):
return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
def render(self):
return map_concat(self.render_item, sorted(self.counts.keys()))
class CallbackMetric(BaseMetric):
"""A metric that returns the numeric value returned by a callback whenever
@ -126,30 +126,30 @@ class DistributionMetric(object):
class CacheMetric(object):
"""A combination of two CounterMetrics, one to count cache hits and one to
count a total, and a callback metric to yield the current size.
__slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
This metric generates standard metric name pairs, so that monitoring rules
can easily be applied to measure hit ratio."""
def __init__(self, name, size_callback, labels=[]):
def __init__(self, name, size_callback, cache_name):
self.name = name
self.cache_name = cache_name
self.hits = CounterMetric(name + ":hits", labels=labels)
self.total = CounterMetric(name + ":total", labels=labels)
self.hits = 0
self.misses = 0
self.size = CallbackMetric(
name + ":size",
callback=size_callback,
labels=labels,
)
self.size_callback = size_callback
def inc_hits(self, *values):
self.hits.inc(*values)
self.total.inc(*values)
def inc_hits(self):
self.hits += 1
def inc_misses(self, *values):
self.total.inc(*values)
def inc_misses(self):
self.misses += 1
def render(self):
return self.hits.render() + self.total.render() + self.size.render()
size = self.size_callback()
hits = self.hits
total = self.misses + self.hits
return [
"""%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
]

View file

@ -40,7 +40,7 @@ class ActionGenerator:
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "handle_push_actions_for_event"):
bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store
event, self.hs, self.store, context.current_state
)
actions_by_user = yield bulk_evaluator.action_for_event_by_user(

View file

@ -18,10 +18,9 @@ import ujson as json
from twisted.internet import defer
from .baserules import list_with_base_rules
from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes
from synapse.api.constants import EventTypes, Membership
from synapse.visibility import filter_events_for_clients
@ -38,62 +37,41 @@ def decode_rule_json(rule):
@defer.inlineCallbacks
def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
rules_by_user = {
uid: list_with_base_rules([
decode_rule_json(rule_list)
for rule_list in rules_by_user.get(uid, [])
])
for uid in user_ids
}
# We apply the rules-enabled map here: bulk_get_push_rules doesn't
# fetch disabled rules, but this won't account for any server default
# rules the user has disabled, so we need to do this too.
for uid in user_ids:
user_enabled_map = rules_enabled_by_user.get(uid)
if not user_enabled_map:
continue
for i, rule in enumerate(rules_by_user[uid]):
rule_id = rule['rule_id']
if rule_id in user_enabled_map:
if rule.get('enabled', True) != bool(user_enabled_map[rule_id]):
# Rules are cached across users.
rule = dict(rule)
rule['enabled'] = bool(user_enabled_map[rule_id])
rules_by_user[uid][i] = rule
defer.returnValue(rules_by_user)
@defer.inlineCallbacks
def evaluator_for_event(event, hs, store):
def evaluator_for_event(event, hs, store, current_state):
room_id = event.room_id
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
users_with_pushers = yield store.get_users_with_pushers_in_room(room_id)
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
all_in_room = yield store.get_users_in_room(room_id)
all_in_room = set(all_in_room)
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and hs.is_mine_id(e.state_key)
)
receipts = yield store.get_receipts_for_room(room_id, "m.read")
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield store.get_if_users_have_pushers(
local_users_in_room
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
# any users with pushers must be ours: they have pushers
user_ids = set(users_with_pushers)
for r in receipts:
if hs.is_mine_id(r['user_id']) and r['user_id'] in all_in_room:
user_ids.add(r['user_id'])
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
@ -104,8 +82,6 @@ def evaluator_for_event(event, hs, store):
if has_pusher:
user_ids.add(invited_user)
user_ids = list(user_ids)
rules_by_user = yield _get_rules(room_id, user_ids, store)
defer.returnValue(BulkPushRuleEvaluator(
@ -143,7 +119,10 @@ class BulkPushRuleEvaluator:
self.store, user_tuples, [event], {event.event_id: current_state}
)
room_members = yield self.store.get_users_in_room(self.room_id)
room_members = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
)
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))

View file

@ -23,10 +23,7 @@ import copy
import simplejson as json
def format_push_rules_for_user(user, rawrules, enabled_map):
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
def load_rules_for_user(user, rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
@ -35,7 +32,26 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
rules = list(list_with_base_rules(ruleslist))
for i, rule in enumerate(rules):
rule_id = rule['rule_id']
if rule_id in enabled_map:
if rule.get('enabled', True) != bool(enabled_map[rule_id]):
# Rules are cached across users.
rule = dict(rule)
rule['enabled'] = bool(enabled_map[rule_id])
rules[i] = rule
return rules
def format_push_rules_for_user(user, ruleslist):
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist)
rules = {'global': {}, 'device': {}}
@ -60,9 +76,7 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
template_rule = _rule_to_template(r)
if template_rule:
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
if 'enabled' in r:
template_rule['enabled'] = r['enabled']
else:
template_rule['enabled'] = True

View file

@ -279,5 +279,5 @@ class EmailPusher(object):
logger.info("Sending notif email for user %r", self.user_id)
yield self.mailer.send_notification_mail(
self.user_id, self.email, push_actions, reason
self.app_id, self.user_id, self.email, push_actions, reason
)

View file

@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
"in the %s room..."
"in the %(room)s room..."
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
@ -81,6 +81,7 @@ class Mailer(object):
def __init__(self, hs, app_name):
self.hs = hs
self.store = self.hs.get_datastore()
self.auth_handler = self.hs.get_auth_handler()
self.state_handler = self.hs.get_state_handler()
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name
@ -96,7 +97,8 @@ class Mailer(object):
)
@defer.inlineCallbacks
def send_notification_mail(self, user_id, email_address, push_actions, reason):
def send_notification_mail(self, app_id, user_id, email_address,
push_actions, reason):
try:
from_string = self.hs.config.email_notif_from % {
"app": self.app_name
@ -167,7 +169,9 @@ class Mailer(object):
template_vars = {
"user_display_name": user_display_name,
"unsubscribe_link": self.make_unsubscribe_link(),
"unsubscribe_link": self.make_unsubscribe_link(
user_id, app_id, email_address
),
"summary_text": summary_text,
"app_name": self.app_name,
"rooms": rooms,
@ -433,9 +437,18 @@ class Mailer(object):
notif['room_id'], notif['event_id']
)
def make_unsubscribe_link(self):
# XXX: matrix.to
return "https://vector.im/#/settings"
def make_unsubscribe_link(self, user_id, app_id, email_address):
params = {
"access_token": self.auth_handler.generate_delete_pusher_token(user_id),
"app_id": app_id,
"pushkey": email_address,
}
# XXX: make r0 once API is stable
return "%s_matrix/client/unstable/pushers/remove?%s" % (
self.hs.config.public_baseurl,
urllib.urlencode(params),
)
def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":

View file

@ -15,7 +15,10 @@
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.account_data import AccountDataStore
from synapse.storage.tags import TagsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedAccountDataStore(BaseSlavedStore):
@ -25,6 +28,14 @@ class SlavedAccountDataStore(BaseSlavedStore):
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id",
)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_current_token(),
)
get_account_data_for_user = (
AccountDataStore.__dict__["get_account_data_for_user"]
)
get_global_account_data_by_type_for_users = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
@ -34,6 +45,16 @@ class SlavedAccountDataStore(BaseSlavedStore):
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
)
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
get_updated_tags = DataStore.get_updated_tags.__func__
get_updated_account_data_for_user = (
DataStore.get_updated_account_data_for_user.__func__
)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
@ -47,15 +68,33 @@ class SlavedAccountDataStore(BaseSlavedStore):
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
user_id, data_type = row[1:3]
position, user_id, data_type = row[:3]
self.get_global_account_data_by_type_for_user.invalidate(
(data_type, user_id,)
)
self.get_account_data_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
stream = result.get("room_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_account_data_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
stream = result.get("tag_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_tags_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedAccountDataStore, self).process_replication(result)

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.config.appservice import load_appservices
class SlavedApplicationServiceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
self.services_cache = load_appservices(
hs.config.server_name,
hs.config.app_service_config_files
)
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__

View file

@ -23,6 +23,7 @@ from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.state import StateStore
from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json
@ -57,6 +58,9 @@ class SlavedEventStore(BaseSlavedStore):
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
@ -87,6 +91,9 @@ class SlavedEventStore(BaseSlavedStore):
_get_state_group_from_group = (
StateStore.__dict__["_get_state_group_from_group"]
)
get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
get_unread_push_actions_for_user_in_range = (
DataStore.get_unread_push_actions_for_user_in_range.__func__
@ -109,10 +116,16 @@ class SlavedEventStore(BaseSlavedStore):
DataStore.get_room_events_stream_for_room.__func__
)
get_events_around = DataStore.get_events_around.__func__
get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
)
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = DataStore._set_before_and_after
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
_get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__
@ -220,9 +233,9 @@ class SlavedEventStore(BaseSlavedStore):
self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,))
# self._membership_stream_cache.entity_has_changed(
# event.state_key, event.internal_metadata.stream_ordering
# )
self._membership_stream_cache.entity_has_changed(
event.state_key, event.internal_metadata.stream_ordering
)
self.get_invited_rooms_for_user.invalidate((event.state_key,))
if not event.is_state():

View file

@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage.filtering import FilteringStore
class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedFilteringStore, self).__init__(db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
get_user_filter = FilteringStore.__dict__["get_user_filter"]

View file

@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage import DataStore
class SlavedPresenceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPresenceStore, self).__init__(db_conn, hs)
self._presence_id_gen = SlavedIdTracker(
db_conn, "presence_stream", "stream_id",
)
self._presence_on_startup = self._get_active_presence(db_conn)
self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
_get_active_presence = DataStore._get_active_presence.__func__
take_presence_startup_info = DataStore.take_presence_startup_info.__func__
get_presence_for_users = DataStore.get_presence_for_users.__func__
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPresenceStore, self).stream_positions()
position = self._presence_id_gen.get_current_token()
result["presence"] = position
return result
def process_replication(self, result):
stream = result.get("presence")
if stream:
self._presence_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.presence_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedPresenceStore, self).process_replication(result)

View file

@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.push_rule import PushRuleStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedPushRuleStore(SlavedEventStore):
def __init__(self, db_conn, hs):
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache",
self._push_rules_stream_id_gen.get_current_token(),
)
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
get_push_rules_enabled_for_user = (
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
)
have_push_rules_changed_for_user = (
DataStore.have_push_rules_changed_for_user.__func__
)
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
self._stream_id_gen.get_current_token(),
)
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("push_rules")
if stream:
for row in stream["rows"]:
position = row[0]
user_id = row[2]
self.get_push_rules_for_user.invalidate((user_id,))
self.get_push_rules_enabled_for_user.invalidate((user_id,))
self.push_rules_stream_cache.entity_has_changed(
user_id, position
)
self._push_rules_stream_id_gen.advance(int(stream["position"]))
return super(SlavedPushRuleStore, self).process_replication(result)

View file

@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.receipts import ReceiptsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
@ -37,11 +38,28 @@ class SlavedReceiptsStore(BaseSlavedStore):
db_conn, "receipts_linearized", "stream_id"
)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
get_linearized_receipts_for_room = (
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
)
_get_linearized_receipts_for_rooms = (
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
)
get_last_receipt_event_id_for_user = (
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
)
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
get_linearized_receipts_for_rooms = (
DataStore.get_linearized_receipts_for_rooms.__func__
)
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
@ -52,10 +70,15 @@ class SlavedReceiptsStore(BaseSlavedStore):
if stream:
self._receipts_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
room_id, receipt_type, user_id = row[1:4]
position, room_id, receipt_type, user_id = row[:4]
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
self._receipts_stream_cache.entity_has_changed(room_id, position)
return super(SlavedReceiptsStore, self).process_replication(result)
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.registration import RegistrationStore
class SlavedRegistrationStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
].orig
_query_for_auth = DataStore._query_for_auth.__func__

View file

@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
rawrules = yield self.store.get_push_rules_for_user(user_id)
rules = yield self.store.get_push_rules_for_user(user_id)
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
rules = format_push_rules_for_user(requester.user, rules)
path = request.postpath[1:]

View file

@ -17,7 +17,11 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import (
parse_json_object_from_request, parse_string, RestServlet
)
from synapse.http.server import finish_request
from synapse.api.errors import StoreError
from .base import ClientV1RestServlet, client_path_patterns
@ -136,6 +140,57 @@ class PushersSetRestServlet(ClientV1RestServlet):
return 200, {}
class PushersRemoveRestServlet(RestServlet):
"""
To allow pusher to be delete by clicking a link (ie. GET request)
"""
PATTERNS = client_path_patterns("/pushers/remove$")
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
super(RestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_v1auth()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True)
pusher_pool = self.hs.get_pusherpool()
try:
yield pusher_pool.remove_pusher(
app_id=app_id,
pushkey=pushkey,
user_id=user.to_string(),
)
except StoreError as se:
if se.code != 404:
# This is fine: they're already unsubscribed
raise
self.notifier.on_new_replication_data()
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (
len(PushersRemoveRestServlet.SUCCESS_HTML),
))
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
finish_request(request)
defer.returnValue(None)
def on_OPTIONS(self, _):
return 200, {}
def register_servlets(hs, http_server):
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
PushersRemoveRestServlet(hs).register(http_server)

View file

@ -149,7 +149,7 @@ class DataStore(RoomMemberStore, RoomStore,
"AccountDataAndTagsChangeCache", account_max,
)
self.__presence_on_startup = self._get_active_presence(db_conn)
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn, "presence_stream",
@ -190,8 +190,8 @@ class DataStore(RoomMemberStore, RoomStore,
super(DataStore, self).__init__(hs)
def take_presence_startup_info(self):
active_on_startup = self.__presence_on_startup
self.__presence_on_startup = None
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
return active_on_startup
def _get_active_presence(self, db_conn):

View file

@ -342,9 +342,6 @@ class EventsStore(SQLBaseStore):
txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
self.get_users_with_pushers_in_room.invalidate, (event.room_id,)
)
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))

View file

@ -15,6 +15,7 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules
from twisted.internet import defer
import logging
@ -23,6 +24,29 @@ import simplejson as json
logger = logging.getLogger(__name__)
def _load_rules(rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
rules = list(list_with_base_rules(ruleslist))
for i, rule in enumerate(rules):
rule_id = rule['rule_id']
if rule_id in enabled_map:
if rule.get('enabled', True) != bool(enabled_map[rule_id]):
# Rules are cached across users.
rule = dict(rule)
rule['enabled'] = bool(enabled_map[rule_id])
rules[i] = rule
return rules
class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks(lru=True)
def get_push_rules_for_user(self, user_id):
@ -42,7 +66,11 @@ class PushRuleStore(SQLBaseStore):
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
)
defer.returnValue(rows)
enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
rules = _load_rules(rows, enabled_map)
defer.returnValue(rules)
@cachedInlineCallbacks(lru=True)
def get_push_rules_enabled_for_user(self, user_id):
@ -85,6 +113,14 @@ class PushRuleStore(SQLBaseStore):
for row in rows:
results.setdefault(row['user_name'], []).append(row)
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
results[user_id] = _load_rules(
rules, enabled_map_by_user.get(user_id, {})
)
defer.returnValue(results)
@cachedList(cached_method_name="get_push_rules_enabled_for_user",

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from canonicaljson import encode_canonical_json
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
import logging
import simplejson as json
@ -135,19 +135,35 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn
)
@cachedInlineCallbacks(num_args=1)
def get_users_with_pushers_in_room(self, room_id):
users = yield self.get_users_in_room(room_id)
@cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id):
result = yield self._simple_select_many_batch(
table='pushers',
column='user_name',
iterable=users,
retcols=['user_name'],
desc='get_users_with_pushers_in_room'
keyvalues={
'user_name': 'user_id',
},
retcol='user_name',
desc='get_if_user_has_pusher',
allow_none=True,
)
defer.returnValue([r['user_name'] for r in result])
defer.returnValue(bool(result))
@cachedList(cached_method_name="get_if_user_has_pusher",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def get_if_users_have_pushers(self, user_ids):
rows = yield self._simple_select_many_batch(
table='pushers',
column='user_name',
iterable=user_ids,
retcols=['user_name'],
desc='get_if_users_have_pushers'
)
result = {user_id: False for user_id in user_ids}
result.update({r['user_name']: True for r in rows})
defer.returnValue(result)
@defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id,
@ -178,16 +194,16 @@ class PusherStore(SQLBaseStore):
},
)
if newly_inserted:
# get_users_with_pushers_in_room only cares if the user has
# get_if_user_has_pusher only cares if the user has
# at least *one* pusher.
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
yield self.runInteraction("add_pusher", f)
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
def delete_pusher_txn(txn, stream_id):
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
self._simple_delete_one_txn(
txn,

View file

@ -34,6 +34,26 @@ class ReceiptsStore(SQLBaseStore):
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
defer.returnValue(set(r['user_id'] for r in receipts))
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
user_id):
if receipt_type != "m.read":
return
# Returns an ObservableDeferred
res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
if res and res.called and user_id in res.result:
# We'd only be adding to the set, so no point invalidating if the
# user is already there
return
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
@ -228,6 +248,10 @@ class ReceiptsStore(SQLBaseStore):
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
room_id, receipt_type, user_id,
)
txn.call_after(
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
@ -373,6 +397,10 @@ class ReceiptsStore(SQLBaseStore):
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
room_id, receipt_type, user_id,
)
txn.call_after(
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)

View file

@ -58,9 +58,6 @@ class RoomMemberStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
self.get_users_with_pushers_in_room.invalidate, (event.room_id,)
)
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
@ -241,23 +238,10 @@ class RoomMemberStore(SQLBaseStore):
return results
@cached(max_entries=5000)
@cachedInlineCallbacks(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):
return self.runInteraction(
"get_joined_hosts_for_room",
self._get_joined_hosts_for_room_txn,
room_id,
)
def _get_joined_hosts_for_room_txn(self, txn, room_id):
rows = self._get_members_rows_txn(
txn,
room_id, membership=Membership.JOIN
)
joined_domains = set(get_domain_from_id(r["user_id"]) for r in rows)
return joined_domains
user_ids = yield self.get_users_in_room(room_id)
defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))
def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
rows = self._get_members_rows_txn(

View file

@ -102,6 +102,15 @@ class ObservableDeferred(object):
def observers(self):
return self._observers
def has_called(self):
return self._result is not None
def has_succeeded(self):
return self._result is not None and self._result[0] is True
def get_result(self):
return self._result[1]
def __getattr__(self, name):
return getattr(self._deferred, name)

View file

@ -24,11 +24,21 @@ DEBUG_CACHES = False
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
caches_by_name = {}
cache_counter = metrics.register_cache(
"cache",
lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
labels=["name"],
)
# cache_counter = metrics.register_cache(
# "cache",
# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
# labels=["name"],
# )
def register_cache(name, cache):
caches_by_name[name] = cache
return metrics.register_cache(
"cache",
lambda: len(cache),
name,
)
_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
caches_by_name["string_cache"] = _string_cache

View file

@ -22,7 +22,7 @@ from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
from . import caches_by_name, DEBUG_CACHES, cache_counter
from . import DEBUG_CACHES, register_cache
from twisted.internet import defer
@ -33,6 +33,7 @@ import functools
import inspect
import threading
logger = logging.getLogger(__name__)
@ -43,6 +44,15 @@ CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class Cache(object):
__slots__ = (
"cache",
"max_entries",
"name",
"keylen",
"sequence",
"thread",
"metrics",
)
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
if lru:
@ -59,7 +69,7 @@ class Cache(object):
self.keylen = keylen
self.sequence = 0
self.thread = None
caches_by_name[name] = self.cache
self.metrics = register_cache(name, self.cache)
def check_thread(self):
expected_thread = self.thread
@ -74,10 +84,10 @@ class Cache(object):
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
self.metrics.inc_hits()
return val
cache_counter.inc_misses(self.name)
self.metrics.inc_misses()
if default is _CacheSentinel:
raise KeyError()
@ -293,16 +303,21 @@ class CacheListDescriptor(object):
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
cached = {}
results = {}
cached_defers = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
try:
res = cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
res = cache.get(tuple(key))
if not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached_defers[arg] = res
else:
results[arg] = res.get_result()
except KeyError:
missing.append(arg)
@ -340,12 +355,21 @@ class CacheListDescriptor(object):
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
cached_defers[arg] = res
return preserve_context_over_deferred(defer.gatherResults(
cached.values(),
consumeErrors=True,
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
if cached_defers:
def update_results_dict(res):
results.update(res)
return results
return preserve_context_over_deferred(defer.gatherResults(
cached_defers.values(),
consumeErrors=True,
).addCallback(update_results_dict).addErrback(
unwrapFirstError
))
else:
return results
obj.__dict__[self.orig.__name__] = wrapped

View file

@ -15,7 +15,7 @@
from synapse.util.caches.lrucache import LruCache
from collections import namedtuple
from . import caches_by_name, cache_counter
from . import register_cache
import threading
import logging
@ -43,7 +43,7 @@ class DictionaryCache(object):
__slots__ = []
self.sentinel = Sentinel()
caches_by_name[name] = self.cache
self.metrics = register_cache(name, self.cache)
def check_thread(self):
expected_thread = self.thread
@ -58,7 +58,7 @@ class DictionaryCache(object):
def get(self, key, dict_keys=None):
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
cache_counter.inc_hits(self.name)
self.metrics.inc_hits()
if dict_keys is None:
return DictionaryEntry(entry.full, dict(entry.value))
@ -69,7 +69,7 @@ class DictionaryCache(object):
if k in entry.value
})
cache_counter.inc_misses(self.name)
self.metrics.inc_misses()
return DictionaryEntry(False, {})
def invalidate(self, key):

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.caches import cache_counter, caches_by_name
from synapse.util.caches import register_cache
import logging
@ -49,7 +49,7 @@ class ExpiringCache(object):
self._cache = {}
caches_by_name[cache_name] = self._cache
self.metrics = register_cache(cache_name, self._cache)
def start(self):
if not self._expiry_ms:
@ -78,9 +78,9 @@ class ExpiringCache(object):
def __getitem__(self, key):
try:
entry = self._cache[key]
cache_counter.inc_hits(self._cache_name)
self.metrics.inc_hits()
except KeyError:
cache_counter.inc_misses(self._cache_name)
self.metrics.inc_misses()
raise
if self._reset_expiry_on_get:

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.caches import cache_counter, caches_by_name
from synapse.util.caches import register_cache
from blist import sorteddict
@ -42,7 +42,7 @@ class StreamChangeCache(object):
self._cache = sorteddict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
caches_by_name[self.name] = self._cache
self.metrics = register_cache(self.name, self._cache)
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
@ -53,19 +53,19 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos < self._earliest_known_stream_pos:
cache_counter.inc_misses(self.name)
self.metrics.inc_misses()
return True
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
cache_counter.inc_hits(self.name)
self.metrics.inc_hits()
return False
if stream_pos < latest_entity_change_pos:
cache_counter.inc_misses(self.name)
self.metrics.inc_misses()
return True
cache_counter.inc_hits(self.name)
self.metrics.inc_hits()
return False
def get_entities_changed(self, entities, stream_pos):
@ -82,10 +82,10 @@ class StreamChangeCache(object):
self._cache[k] for k in keys[i:]
).intersection(entities)
cache_counter.inc_hits(self.name)
self.metrics.inc_hits()
else:
result = entities
cache_counter.inc_misses(self.name)
self.metrics.inc_misses()
return result

View file

@ -61,9 +61,6 @@ class CounterMetricTestCase(unittest.TestCase):
'vector{method="PUT"} 1',
])
# Check that passing too few values errors
self.assertRaises(ValueError, counter.inc)
class CallbackMetricTestCase(unittest.TestCase):
@ -138,27 +135,27 @@ class CacheMetricTestCase(unittest.TestCase):
def test_cache(self):
d = dict()
metric = CacheMetric("cache", lambda: len(d))
metric = CacheMetric("cache", lambda: len(d), "cache_name")
self.assertEquals(metric.render(), [
'cache:hits 0',
'cache:total 0',
'cache:size 0',
'cache:hits{name="cache_name"} 0',
'cache:total{name="cache_name"} 0',
'cache:size{name="cache_name"} 0',
])
metric.inc_misses()
d["key"] = "value"
self.assertEquals(metric.render(), [
'cache:hits 0',
'cache:total 1',
'cache:size 1',
'cache:hits{name="cache_name"} 0',
'cache:total{name="cache_name"} 1',
'cache:size{name="cache_name"} 1',
])
metric.inc_hits()
self.assertEquals(metric.render(), [
'cache:hits 1',
'cache:total 2',
'cache:size 1',
'cache:hits{name="cache_name"} 1',
'cache:total{name="cache_name"} 2',
'cache:size{name="cache_name"} 1',
])