0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-07-02 09:28:20 +02:00

Merge branch 'release-v0.9.2'

This commit is contained in:
Erik Johnston 2015-06-12 11:53:03 +01:00
commit 405f8c4796
33 changed files with 552 additions and 286 deletions

View file

@ -1,3 +1,26 @@
Changes in synapse v0.9.2 (2015-06-12)
======================================
General:
* Use ultrajson for json (de)serialisation when a canonical encoding is not
required. Ultrajson is significantly faster than simplejson in certain
circumstances.
* Use connection pools for outgoing HTTP connections.
* Process thumbnails on separate threads.
Configuration:
* Add option, ``gzip_responses``, to disable HTTP response compression.
Federation:
* Improve resilience of backfill by ensuring we fetch any missing auth events.
* Improve performance of backfill and joining remote rooms by removing
unnecessary computations. This included handling events we'd previously
handled as well as attempting to compute the current state for outliers.
Changes in synapse v0.9.1 (2015-05-26) Changes in synapse v0.9.1 (2015-05-26)
====================================== ======================================

View file

@ -21,3 +21,5 @@ handlers:
root: root:
level: INFO level: INFO
handlers: [journal] handlers: [journal]
disable_existing_loggers: False

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.9.1" __version__ = "0.9.2"

View file

@ -54,6 +54,8 @@ from synapse.rest.client.v1 import ClientV1RestResource
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse import events
from daemonize import Daemonize from daemonize import Daemonize
import twisted.manhole.telnet import twisted.manhole.telnet
@ -85,10 +87,16 @@ class SynapseHomeServer(HomeServer):
return MatrixFederationHttpClient(self) return MatrixFederationHttpClient(self)
def build_resource_for_client(self): def build_resource_for_client(self):
return gz_wrap(ClientV1RestResource(self)) res = ClientV1RestResource(self)
if self.config.gzip_responses:
res = gz_wrap(res)
return res
def build_resource_for_client_v2_alpha(self): def build_resource_for_client_v2_alpha(self):
return gz_wrap(ClientV2AlphaRestResource(self)) res = ClientV2AlphaRestResource(self)
if self.config.gzip_responses:
res = gz_wrap(res)
return res
def build_resource_for_federation(self): def build_resource_for_federation(self):
return JsonResource(self) return JsonResource(self)
@ -415,6 +423,8 @@ def setup(config_options):
logger.info("Server hostname: %s", config.server_name) logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string) logger.info("Server version: %s", version_string)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
if re.search(":[0-9]+$", config.server_name): if re.search(":[0-9]+$", config.server_name):
domain_with_port = config.server_name domain_with_port = config.server_name
else: else:

View file

@ -26,6 +26,7 @@ class CaptchaConfig(Config):
config["captcha_ip_origin_is_x_forwarded"] config["captcha_ip_origin_is_x_forwarded"]
) )
self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name):
return """\ return """\
@ -48,4 +49,7 @@ class CaptchaConfig(Config):
# A secret key used to bypass the captcha test entirely. # A secret key used to bypass the captcha test entirely.
#captcha_bypass_secret: "YOUR_SECRET_HERE" #captcha_bypass_secret: "YOUR_SECRET_HERE"
# The API endpoint to use for verifying m.login.recaptcha responses.
recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
""" """

View file

@ -39,7 +39,7 @@ class RegistrationConfig(Config):
## Registration ## ## Registration ##
# Enable registration for new users. # Enable registration for new users.
enable_registration: True enable_registration: False
# If set, allows registration by anyone who also has the shared # If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled. # secret, even if registration is otherwise disabled.

View file

@ -28,6 +28,8 @@ class ServerConfig(Config):
self.web_client = config["web_client"] self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"] self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize") self.daemonize = config.get("daemonize")
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.gzip_responses = config["gzip_responses"]
# Attempt to guess the content_addr for the v0 content repostitory # Attempt to guess the content_addr for the v0 content repostitory
content_addr = config.get("content_addr") content_addr = config.get("content_addr")
@ -85,6 +87,11 @@ class ServerConfig(Config):
# Turn on the twisted telnet manhole service on localhost on the given # Turn on the twisted telnet manhole service on localhost on the given
# port. # port.
#manhole: 9000 #manhole: 9000
# Should synapse compress HTTP responses to clients that support it?
# This should be disabled if running synapse behind a load balancer
# that can do automatic compression.
gzip_responses: True
""" % locals() """ % locals()
def read_arguments(self, args): def read_arguments(self, args):

View file

@ -16,6 +16,12 @@
from synapse.util.frozenutils import freeze from synapse.util.frozenutils import freeze
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting
# a dict to frozen_dicts is expensive.
USE_FROZEN_DICTS = True
class _EventInternalMetadata(object): class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict): def __init__(self, internal_metadata_dict):
self.__dict__ = dict(internal_metadata_dict) self.__dict__ = dict(internal_metadata_dict)
@ -122,7 +128,10 @@ class FrozenEvent(EventBase):
unsigned = dict(event_dict.pop("unsigned", {})) unsigned = dict(event_dict.pop("unsigned", {}))
frozen_dict = freeze(event_dict) if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict)
else:
frozen_dict = event_dict
super(FrozenEvent, self).__init__( super(FrozenEvent, self).__init__(
frozen_dict, frozen_dict,

View file

@ -18,8 +18,6 @@ from twisted.internet import defer
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
@ -120,16 +118,15 @@ class FederationBase(object):
) )
except SynapseError: except SynapseError:
logger.warn( logger.warn(
"Signature check failed for %s redacted to %s", "Signature check failed for %s",
encode_canonical_json(pdu.get_pdu_json()), pdu.event_id,
encode_canonical_json(redacted_pdu_json),
) )
raise raise
if not check_event_content_hash(pdu): if not check_event_content_hash(pdu):
logger.warn( logger.warn(
"Event content has been tampered, redacting %s, %s", "Event content has been tampered, redacting.",
pdu.event_id, encode_canonical_json(pdu.get_dict()) pdu.event_id,
) )
defer.returnValue(redacted_event) defer.returnValue(redacted_event)

View file

@ -93,6 +93,8 @@ class TransportLayerServer(object):
yield self.keyring.verify_json_for_server(origin, json_request) yield self.keyring.verify_json_for_server(origin, json_request)
logger.info("Request from %s", origin)
defer.returnValue((origin, content)) defer.returnValue((origin, content))
@log_function @log_function

View file

@ -78,7 +78,9 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder) context = yield state_handler.compute_event_context(builder)
if builder.is_state(): if builder.is_state():
builder.prev_state = context.prev_state_events builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context) yield self.auth.add_auth_events(builder, context)

View file

@ -187,8 +187,8 @@ class AuthHandler(BaseHandler):
# each request # each request
try: try:
client = SimpleHttpClient(self.hs) client = SimpleHttpClient(self.hs)
data = yield client.post_urlencoded_get_json( resp_body = yield client.post_urlencoded_get_json(
"https://www.google.com/recaptcha/api/siteverify", self.hs.config.recaptcha_siteverify_api,
args={ args={
'secret': self.hs.config.recaptcha_private_key, 'secret': self.hs.config.recaptcha_private_key,
'response': user_response, 'response': user_response,
@ -198,7 +198,8 @@ class AuthHandler(BaseHandler):
except PartialDownloadError as pde: except PartialDownloadError as pde:
# Twisted is silly # Twisted is silly
data = pde.response data = pde.response
resp_body = simplejson.loads(data) resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']: if 'success' in resp_body and resp_body['success']:
defer.returnValue(True) defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)

View file

@ -247,9 +247,15 @@ class FederationHandler(BaseHandler):
if set(e_id for e_id, _ in ev.prev_events) - event_ids if set(e_id for e_id, _ in ev.prev_events) - event_ids
] ]
logger.info(
"backfill: Got %d events with %d edges",
len(events), len(edges),
)
# For each edge get the current state. # For each edge get the current state.
auth_events = {} auth_events = {}
state_events = {}
events_to_state = {} events_to_state = {}
for e_id in edges: for e_id in edges:
state, auth = yield self.replication_layer.get_state_for_room( state, auth = yield self.replication_layer.get_state_for_room(
@ -258,12 +264,46 @@ class FederationHandler(BaseHandler):
event_id=e_id event_id=e_id
) )
auth_events.update({a.event_id: a for a in auth}) auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state events_to_state[e_id] = state
seen_events = yield self.store.have_events(
set(auth_events.keys()) | set(state_events.keys())
)
all_events = events + state_events.values() + auth_events.values()
required_auth = set(
a_id for event in all_events for a_id, _ in event.auth_events
)
missing_auth = required_auth - set(auth_events)
results = yield defer.gatherResults(
[
self.replication_layer.get_pdu(
[dest],
event_id,
outlier=True,
timeout=10000,
)
for event_id in missing_auth
],
consumeErrors=True
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results})
yield defer.gatherResults( yield defer.gatherResults(
[ [
self._handle_new_event(dest, a) self._handle_new_event(
dest, a,
auth_events={
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in a.auth_events
},
)
for a in auth_events.values() for a in auth_events.values()
if a.event_id not in seen_events
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -274,6 +314,11 @@ class FederationHandler(BaseHandler):
dest, event_map[e_id], dest, event_map[e_id],
state=events_to_state[e_id], state=events_to_state[e_id],
backfilled=True, backfilled=True,
auth_events={
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events
},
) )
for e_id in events_to_state for e_id in events_to_state
], ],
@ -900,8 +945,10 @@ class FederationHandler(BaseHandler):
event.event_id, event.signatures, event.event_id, event.signatures,
) )
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context( context = yield self.state_handler.compute_event_context(
event, old_state=state event, old_state=state, outlier=outlier,
) )
if not auth_events: if not auth_events:
@ -912,7 +959,7 @@ class FederationHandler(BaseHandler):
event.event_id, auth_events, event.event_id, auth_events,
) )
is_new_state = not event.internal_metadata.is_outlier() is_new_state = not outlier
# This is a hack to fix some old rooms where the initial join event # This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events. # didn't reference the create event in its auth events.

View file

@ -20,7 +20,8 @@ import synapse.metrics
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.client import ( from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError Agent, readBody, FileBodyProducer, PartialDownloadError,
HTTPConnectionPool,
) )
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
@ -55,7 +56,9 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is # The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation # BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser' # 'like a browser'
self.agent = Agent(reactor) pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = Agent(reactor, pool=pool)
self.version_string = hs.version_string self.version_string = hs.version_string
def request(self, method, *args, **kwargs): def request(self, method, *args, **kwargs):

View file

@ -16,7 +16,7 @@
from twisted.internet import defer, reactor, protocol from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, _AgentBase, _URI from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
@ -103,7 +103,9 @@ class MatrixFederationHttpClient(object):
self.hs = hs self.hs = hs
self.signing_key = hs.config.signing_key[0] self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname self.server_name = hs.hostname
self.agent = MatrixFederationHttpAgent(reactor) pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.version_string = hs.version_string self.version_string = hs.version_string

View file

@ -19,9 +19,10 @@ from synapse.api.errors import (
) )
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
import synapse.metrics import synapse.metrics
import synapse.events
from syutil.jsonutil import ( from syutil.jsonutil import (
encode_canonical_json, encode_pretty_printed_json encode_canonical_json, encode_pretty_printed_json, encode_json
) )
from twisted.internet import defer from twisted.internet import defer
@ -168,9 +169,10 @@ class JsonResource(HttpServer, resource.Resource):
_PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"]) _PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
def __init__(self, hs): def __init__(self, hs, canonical_json=True):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.canonical_json = canonical_json
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.path_regexs = {} self.path_regexs = {}
self.version_string = hs.version_string self.version_string = hs.version_string
@ -256,6 +258,7 @@ class JsonResource(HttpServer, resource.Resource):
response_code_message=response_code_message, response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request), pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string, version_string=self.version_string,
canonical_json=self.canonical_json,
) )
@ -277,11 +280,16 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False, def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False, response_code_message=None, pretty_print=False,
version_string=""): version_string="", canonical_json=True):
if pretty_print: if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n" json_bytes = encode_pretty_printed_json(json_object) + "\n"
else: else:
json_bytes = encode_canonical_json(json_object) if canonical_json:
json_bytes = encode_canonical_json(json_object)
else:
json_bytes = encode_json(
json_object, using_frozen_dicts=synapse.events.USE_FROZEN_DICTS
)
return respond_with_json_bytes( return respond_with_json_bytes(
request, code, json_bytes, request, code, json_bytes,

View file

@ -24,6 +24,7 @@ import baserules
import logging import logging
import simplejson as json import simplejson as json
import re import re
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -256,134 +257,154 @@ class Pusher(object):
logger.info("Pusher %s for user %s starting from token %s", logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token) self.pushkey, self.user_name, self.last_token)
wait = 0
while self.alive: while self.alive:
from_tok = StreamToken.from_string(self.last_token) try:
config = PaginationConfig(from_token=from_tok, limit='1') if wait > 0:
chunk = yield self.evStreamHandler.get_stream( yield synapse.util.async.sleep(wait)
self.user_name, config, yield self.get_and_dispatch()
timeout=100*365*24*60*60*1000, affect_presence=False wait = 0
) except:
if wait == 0:
# limiting to 1 may get 1 event plus 1 presence event, so wait = 1
# pick out the actual event else:
single_event = None wait = min(wait * 2, 1800)
for c in chunk['chunk']: logger.exception(
if 'event_id' in c: # Hmmm... "Exception in pusher loop for pushkey %s. Pausing for %ds",
single_event = c self.pushkey, wait
break
if not single_event:
self.last_token = chunk['end']
continue
if not self.alive:
continue
processed = False
actions = yield self._actions_for_event(single_event)
tweaks = _tweaks_for_actions(actions)
if len(actions) == 0:
logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions:
logger.debug(
"%s for %s: dont_notify",
single_event['event_id'], self.user_name
) )
@defer.inlineCallbacks
def get_and_dispatch(self):
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config,
timeout=timeout, affect_presence=False
)
# limiting to 1 may get 1 event plus 1 presence event, so
# pick out the actual event
single_event = None
for c in chunk['chunk']:
if 'event_id' in c: # Hmmm...
single_event = c
break
if not single_event:
self.last_token = chunk['end']
logger.debug("Event stream timeout for pushkey %s", self.pushkey)
return
if not self.alive:
return
processed = False
actions = yield self._actions_for_event(single_event)
tweaks = _tweaks_for_actions(actions)
if len(actions) == 0:
logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions:
logger.debug(
"%s for %s: dont_notify",
single_event['event_id'], self.user_name
)
processed = True
else:
rejected = yield self.dispatch_push(single_event, tweaks)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True processed = True
else: for pk in rejected:
rejected = yield self.dispatch_push(single_event, tweaks) if pk != self.pushkey:
self.has_unread = True # for sanity, we only remove the pushkey if it
if isinstance(rejected, list) or isinstance(rejected, tuple): # was the one we actually sent...
processed = True logger.warn(
for pk in rejected: ("Ignoring rejected pushkey %s because we"
if pk != self.pushkey: " didn't send it"), pk
# for sanity, we only remove the pushkey if it )
# was the one we actually sent... else:
logger.warn( logger.info(
("Ignoring rejected pushkey %s because we" "Pushkey %s was rejected: removing",
" didn't send it"), pk pk
) )
else: yield self.hs.get_pusherpool().remove_pusher(
logger.info( self.app_id, pk, self.user_name
"Pushkey %s was rejected: removing", )
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk, self.user_name
)
if not self.alive: if not self.alive:
continue return
if processed: if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end'] self.last_token = chunk['end']
self.store.update_pusher_last_token_and_success( self.store.update_pusher_last_token_and_success(
self.app_id,
self.pushkey,
self.user_name,
self.last_token,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_name,
self.last_token, self.failing_since)
self.clock.time_msec() else:
if not self.failing_since:
self.failing_since = self.clock.time_msec()
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since
)
if (self.failing_since and
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_name, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_name,
self.last_token
)
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since
) )
if self.failing_since:
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since)
else: else:
if not self.failing_since: logger.warn("Failed to dispatch push for user %s "
self.failing_since = self.clock.time_msec() "(failing for %dms)."
self.store.update_pusher_failing_since( "Trying again in %dms",
self.app_id, self.user_name,
self.pushkey, self.clock.time_msec() - self.failing_since,
self.user_name, self.backoff_delay)
self.failing_since yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
) self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
if (self.failing_since and self.backoff_delay = Pusher.MAX_BACKOFF
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_name, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_name,
self.last_token
)
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since
)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_name,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
self.backoff_delay = Pusher.MAX_BACKOFF
def stop(self): def stop(self):
self.alive = False self.alive = False

View file

@ -18,7 +18,7 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"syutil>=0.0.6": ["syutil>=0.0.6"], "syutil>=0.0.7": ["syutil>=0.0.7"],
"Twisted==14.0.2": ["twisted==14.0.2"], "Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
@ -30,6 +30,7 @@ REQUIREMENTS = {
"frozendict>=0.4": ["frozendict"], "frozendict>=0.4": ["frozendict"],
"pillow": ["PIL"], "pillow": ["PIL"],
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
"ujson": ["ujson"],
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
"web_client": { "web_client": {
@ -52,8 +53,8 @@ def github_link(project, version, egg):
DEPENDENCY_LINKS = [ DEPENDENCY_LINKS = [
github_link( github_link(
project="matrix-org/syutil", project="matrix-org/syutil",
version="v0.0.6", version="v0.0.7",
egg="syutil-0.0.6", egg="syutil-0.0.7",
), ),
github_link( github_link(
project="matrix-org/matrix-angular-sdk", project="matrix-org/matrix-angular-sdk",

View file

@ -25,7 +25,7 @@ class ClientV1RestResource(JsonResource):
"""A resource for version 1 of the matrix client API.""" """A resource for version 1 of the matrix client API."""
def __init__(self, hs): def __init__(self, hs):
JsonResource.__init__(self, hs) JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs) self.register_servlets(self, hs)
@staticmethod @staticmethod

View file

@ -28,7 +28,7 @@ class ClientV2AlphaRestResource(JsonResource):
"""A resource for version 2 alpha of the matrix client API.""" """A resource for version 2 alpha of the matrix client API."""
def __init__(self, hs): def __init__(self, hs):
JsonResource.__init__(self, hs) JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs) self.register_servlets(self, hs)
@staticmethod @staticmethod

View file

@ -15,13 +15,14 @@
from .thumbnailer import Thumbnailer from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.http.server import respond_with_json from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.api.errors import ( from synapse.api.errors import (
cs_error, Codes, SynapseError cs_error, Codes, SynapseError
) )
from twisted.internet import defer from twisted.internet import defer, threads
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
@ -52,7 +53,7 @@ class BaseMediaResource(Resource):
def __init__(self, hs, filepaths): def __init__(self, hs, filepaths):
Resource.__init__(self) Resource.__init__(self)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.client = hs.get_http_client() self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -273,57 +274,65 @@ class BaseMediaResource(Resource):
if not requirements: if not requirements:
return return
remote_thumbnails = []
input_path = self.filepaths.remote_media_filepath(server_name, file_id) input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path) thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width m_width = thumbnailer.width
m_height = thumbnailer.height m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels: def generate_thumbnails():
logger.info( if m_width * m_height >= self.max_image_pixels:
"Image too large to thumbnail %r x %r > %r", logger.info(
m_width, m_height, self.max_image_pixels "Image too large to thumbnail %r x %r > %r",
) m_width, m_height, self.max_image_pixels
return )
return
scales = set() scales = set()
crops = set() crops = set()
for r_width, r_height, r_method, r_type in requirements: for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale": if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height) t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add(( scales.add((
min(m_width, t_width), min(m_height, t_height), r_type, min(m_width, t_width), min(m_height, t_height), r_type,
)) ))
elif r_method == "crop": elif r_method == "crop":
crops.add((r_width, r_height, r_type)) crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales: for t_width, t_height, t_type in scales:
t_method = "scale" t_method = "scale"
t_path = self.filepaths.remote_media_thumbnail( t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method server_name, file_id, t_width, t_height, t_type, t_method
) )
self._makedirs(t_path) self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail( remote_thumbnails.append([
server_name, media_id, file_id, server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len t_width, t_height, t_type, t_method, t_len
) ])
for t_width, t_height, t_type in crops: for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales: if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely # If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate # scaled one then there is no point in calculating a separate
# thumbnail. # thumbnail.
continue continue
t_method = "crop" t_method = "crop"
t_path = self.filepaths.remote_media_thumbnail( t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method server_name, file_id, t_width, t_height, t_type, t_method
) )
self._makedirs(t_path) self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail( remote_thumbnails.append([
server_name, media_id, file_id, server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len t_width, t_height, t_type, t_method, t_len
) ])
yield threads.deferToThread(generate_thumbnails)
for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r)
defer.returnValue({ defer.returnValue({
"width": m_width, "width": m_width,

View file

@ -106,7 +106,7 @@ class StateHandler(object):
defer.returnValue(state) defer.returnValue(state)
@defer.inlineCallbacks @defer.inlineCallbacks
def compute_event_context(self, event, old_state=None): def compute_event_context(self, event, old_state=None, outlier=False):
""" Fills out the context with the `current state` of the graph. The """ Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph `current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event` just before the event - i.e. it never includes `event`
@ -119,9 +119,23 @@ class StateHandler(object):
Returns: Returns:
an EventContext an EventContext
""" """
yield run_on_reactor()
context = EventContext() context = EventContext()
yield run_on_reactor() if outlier:
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
if old_state:
context.current_state = {
(s.type, s.state_key): s for s in old_state
}
else:
context.current_state = {}
context.prev_state_events = []
context.state_group = None
defer.returnValue(context)
if old_state: if old_state:
context.current_state = { context.current_state = {
@ -155,10 +169,6 @@ class StateHandler(object):
context.current_state = curr_state context.current_state = curr_state
context.state_group = group if not event.is_state() else None context.state_group = group if not event.is_state() else None
prev_state = yield self.store.add_event_hashes(
prev_state
)
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state: if key in context.current_state:

View file

@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 19 SCHEMA_VERSION = 20
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -348,7 +348,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file module_name, absolute_path, python_file
) )
logger.debug("Running script %s", relative_path) logger.debug("Running script %s", relative_path)
module.run_upgrade(cur) module.run_upgrade(cur, database_engine)
elif ext == ".sql": elif ext == ".sql":
# A plain old .sql file, just read and execute it # A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path) logger.debug("Applying schema %s", relative_path)

View file

@ -127,7 +127,7 @@ class Cache(object):
self.cache.clear() self.cache.clear()
def cached(max_entries=1000, num_args=1, lru=False): class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in The function is presumed to take zero or more arguments, which are used in
@ -141,25 +141,32 @@ def cached(max_entries=1000, num_args=1, lru=False):
which can be used to insert values into the cache specifically, without which can be used to insert values into the cache specifically, without
calling the calculation function. calling the calculation function.
""" """
def wrap(orig): def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
self.orig = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
def __get__(self, obj, objtype=None):
cache = Cache( cache = Cache(
name=orig.__name__, name=self.orig.__name__,
max_entries=max_entries, max_entries=self.max_entries,
keylen=num_args, keylen=self.num_args,
lru=lru, lru=self.lru,
) )
@functools.wraps(orig) @functools.wraps(self.orig)
@defer.inlineCallbacks @defer.inlineCallbacks
def wrapped(self, *keyargs): def wrapped(*keyargs):
try: try:
cached_result = cache.get(*keyargs) cached_result = cache.get(*keyargs[:self.num_args])
if DEBUG_CACHES: if DEBUG_CACHES:
actual_result = yield orig(self, *keyargs) actual_result = yield self.orig(obj, *keyargs)
if actual_result != cached_result: if actual_result != cached_result:
logger.error( logger.error(
"Stale cache entry %s%r: cached: %r, actual %r", "Stale cache entry %s%r: cached: %r, actual %r",
orig.__name__, keyargs, self.orig.__name__, keyargs,
cached_result, actual_result, cached_result, actual_result,
) )
raise ValueError("Stale cache entry") raise ValueError("Stale cache entry")
@ -170,18 +177,28 @@ def cached(max_entries=1000, num_args=1, lru=False):
# while the SELECT is executing (SYN-369) # while the SELECT is executing (SYN-369)
sequence = cache.sequence sequence = cache.sequence
ret = yield orig(self, *keyargs) ret = yield self.orig(obj, *keyargs)
cache.update(sequence, *keyargs + (ret,)) cache.update(sequence, *keyargs[:self.num_args] + (ret,))
defer.returnValue(ret) defer.returnValue(ret)
wrapped.invalidate = cache.invalidate wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill wrapped.prefill = cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped return wrapped
return wrap
def cached(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
class LoggingTransaction(object): class LoggingTransaction(object):

View file

@ -17,7 +17,7 @@ from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.events import FrozenEvent from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logcontext import preserve_context_over_deferred
@ -26,11 +26,11 @@ from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64 from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_json
from contextlib import contextmanager from contextlib import contextmanager
import logging import logging
import simplejson as json import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -166,8 +166,9 @@ class EventsStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
metadata_json = encode_canonical_json( metadata_json = encode_json(
event.internal_metadata.get_dict() event.internal_metadata.get_dict(),
using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8") ).decode("UTF-8")
# If we have already persisted this event, we don't need to do any # If we have already persisted this event, we don't need to do any
@ -235,12 +236,14 @@ class EventsStore(SQLBaseStore):
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
"internal_metadata": metadata_json, "internal_metadata": metadata_json,
"json": encode_canonical_json(event_dict).decode("UTF-8"), "json": encode_json(
event_dict, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8"),
}, },
) )
content = encode_canonical_json( content = encode_json(
event.content event.content, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8") ).decode("UTF-8")
vals = { vals = {
@ -266,8 +269,8 @@ class EventsStore(SQLBaseStore):
] ]
} }
vals["unrecognized_keys"] = encode_canonical_json( vals["unrecognized_keys"] = encode_json(
unrec unrec, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8") ).decode("UTF-8")
sql = ( sql = (
@ -733,7 +736,8 @@ class EventsStore(SQLBaseStore):
because = yield self.get_event( because = yield self.get_event(
redaction_id, redaction_id,
check_redacted=False check_redacted=False,
allow_none=True,
) )
if because: if because:
@ -743,6 +747,7 @@ class EventsStore(SQLBaseStore):
prev = yield self.get_event( prev = yield self.get_event(
ev.unsigned["replaces_state"], ev.unsigned["replaces_state"],
get_prev_content=False, get_prev_content=False,
allow_none=True,
) )
if prev: if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]

View file

@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_upgrade(cur): def run_upgrade(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex") cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall(): for row in cur.fetchall():
try: try:

View file

@ -0,0 +1,76 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Main purpose of this upgrade is to change the unique key on the
pushers table again (it was missed when the v16 full schema was
made) but this also changes the pushkey and data columns to text.
When selecting a bytea column into a text column, postgres inserts
the hex encoded data, and there's no portable way of getting the
UTF-8 bytes, so we have to do it in Python.
"""
import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...")
cur.execute("""
CREATE TABLE IF NOT EXISTS pushers2 (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
access_token BIGINT DEFAULT NULL,
profile_tag VARCHAR(32) NOT NULL,
kind VARCHAR(8) NOT NULL,
app_id VARCHAR(64) NOT NULL,
app_display_name VARCHAR(64) NOT NULL,
device_display_name VARCHAR(128) NOT NULL,
pushkey TEXT NOT NULL,
ts BIGINT NOT NULL,
lang VARCHAR(8),
data TEXT,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
UNIQUE (app_id, pushkey, user_name)
)
""")
cur.execute("""SELECT
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
FROM pushers
""")
count = 0
for row in cur.fetchall():
row = list(row)
row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8")
cur.execute(database_engine.convert_param_style("""
INSERT into pushers2 (
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
row
)
count += 1
cur.execute("DROP TABLE pushers")
cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
logger.info("Moved %d pushers to new table", count)

View file

@ -81,19 +81,23 @@ class StateStore(SQLBaseStore):
f, f,
) )
@defer.inlineCallbacks state_list = yield defer.gatherResults(
def c(vals):
vals[:] = yield self._get_events(vals, get_prev_content=False)
yield defer.gatherResults(
[ [
c(vals) self._fetch_events_for_group(group, vals)
for vals in states.values() for group, vals in states.items()
], ],
consumeErrors=True, consumeErrors=True,
) )
defer.returnValue(states) defer.returnValue(dict(state_list))
@cached(num_args=1)
def _fetch_events_for_group(self, state_group, events):
return self._get_events(
events, get_prev_content=False
).addCallback(
lambda evs: (state_group, evs)
)
def _store_state_groups_txn(self, txn, event, context): def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None: if context.current_state is None:

View file

@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
class JsonEncodedObject(object): class JsonEncodedObject(object):
""" A common base class for defining protocol units that are represented """ A common base class for defining protocol units that are represented
@ -76,15 +74,7 @@ class JsonEncodedObject(object):
if k in self.valid_keys and k not in self.internal_keys if k in self.valid_keys and k not in self.internal_keys
} }
d.update(self.unrecognized_keys) d.update(self.unrecognized_keys)
return copy.deepcopy(d) return d
def get_full_dict(self):
d = {
k: _encode(v) for (k, v) in self.__dict__.items()
if k in self.valid_keys or k in self.internal_keys
}
d.update(self.unrecognized_keys)
return copy.deepcopy(d)
def __str__(self): def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))

View file

@ -100,7 +100,7 @@ class FederationTestCase(unittest.TestCase):
return defer.succeed({}) return defer.succeed({})
self.datastore.have_events.side_effect = have_events self.datastore.have_events.side_effect = have_events
def annotate(ev, old_state=None): def annotate(ev, old_state=None, outlier=False):
context = Mock() context = Mock()
context.current_state = {} context.current_state = {}
context.auth_events = {} context.auth_events = {}
@ -120,7 +120,7 @@ class FederationTestCase(unittest.TestCase):
) )
self.state_handler.compute_event_context.assert_called_once_with( self.state_handler.compute_event_context.assert_called_once_with(
ANY, old_state=None, ANY, old_state=None, outlier=False
) )
self.auth.check.assert_called_once_with(ANY, auth_events={}) self.auth.check.assert_called_once_with(ANY, auth_events={})

View file

@ -42,6 +42,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"get_room", "get_room",
"store_room", "store_room",
"get_latest_events_in_room", "get_latest_events_in_room",
"add_event_hashes",
]), ]),
resource_for_federation=NonCallableMock(), resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]), http_client=NonCallableMock(spec_set=[]),
@ -88,6 +89,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
self.datastore.persist_event.return_value = (1,1) self.datastore.persist_event.return_value = (1,1)
self.datastore.add_event_hashes.return_value = []
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite(self): def test_invite(self):

View file

@ -96,73 +96,84 @@ class CacheDecoratorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_passthrough(self): def test_passthrough(self):
@cached() class A(object):
def func(self, key): @cached()
return key def func(self, key):
return key
self.assertEquals((yield func(self, "foo")), "foo") a = A()
self.assertEquals((yield func(self, "bar")), "bar")
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals((yield a.func("bar")), "bar")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_hit(self): def test_hit(self):
callcount = [0] callcount = [0]
@cached() class A(object):
def func(self, key): @cached()
callcount[0] += 1 def func(self, key):
return key callcount[0] += 1
return key
yield func(self, "foo") a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
self.assertEquals((yield func(self, "foo")), "foo") self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invalidate(self): def test_invalidate(self):
callcount = [0] callcount = [0]
@cached() class A(object):
def func(self, key): @cached()
callcount[0] += 1 def func(self, key):
return key callcount[0] += 1
return key
yield func(self, "foo") a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
func.invalidate("foo") a.func.invalidate("foo")
yield func(self, "foo") yield a.func("foo")
self.assertEquals(callcount[0], 2) self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self): def test_invalidate_missing(self):
@cached() class A(object):
def func(self, key): @cached()
return key def func(self, key):
return key
func.invalidate("what") A().func.invalidate("what")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_max_entries(self): def test_max_entries(self):
callcount = [0] callcount = [0]
@cached(max_entries=10) class A(object):
def func(self, key): @cached(max_entries=10)
callcount[0] += 1 def func(self, key):
return key callcount[0] += 1
return key
for k in range(0,12): a = A()
yield func(self, k)
for k in range(0, 12):
yield a.func(k)
self.assertEquals(callcount[0], 12) self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate # There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times # all 12 values again, we must get called at least 2 more times
for k in range(0,12): for k in range(0,12):
yield func(self, k) yield a.func(k)
self.assertTrue(callcount[0] >= 14, self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0])) msg="Expected callcount >= 14, got %d" % (callcount[0]))
@ -171,12 +182,15 @@ class CacheDecoratorTestCase(unittest.TestCase):
def test_prefill(self): def test_prefill(self):
callcount = [0] callcount = [0]
@cached() class A(object):
def func(self, key): @cached()
callcount[0] += 1 def func(self, key):
return key callcount[0] += 1
return key
func.prefill("foo", 123) a = A()
self.assertEquals((yield func(self, "foo")), 123) a.func.prefill("foo", 123)
self.assertEquals((yield a.func("foo")), 123)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)

View file

@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
(yield self.store.get_user_by_id(self.user_id)) (yield self.store.get_user_by_id(self.user_id))
) )
result = yield self.store.get_user_by_token(self.tokens[1]) result = yield self.store.get_user_by_token(self.tokens[0])
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {