0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 20:33:53 +01:00

Merge branch 'develop' into email_login

This commit is contained in:
David Baker 2015-08-20 10:16:01 +01:00
commit c50ad14bae
60 changed files with 1759 additions and 714 deletions

View file

@ -11,7 +11,9 @@ if [ -f $PID_FILE ]; then
exit 1 exit 1
fi fi
find "$DIR" -name "*.log" -delete for port in 8080 8081 8082; do
find "$DIR" -name "*.db" -delete rm -rf $DIR/$port
rm -rf $DIR/media_store.$port
done
rm -rf $DIR/etc rm -rf $DIR/etc

View file

@ -8,14 +8,6 @@ cd "$DIR/.."
mkdir -p demo/etc mkdir -p demo/etc
# Check the --no-rate-limit param
PARAMS=""
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
fi
fi
export PYTHONPATH=$(readlink -f $(pwd)) export PYTHONPATH=$(readlink -f $(pwd))
@ -31,10 +23,20 @@ for port in 8080 8081 8082; do
#rm $DIR/etc/$port.config #rm $DIR/etc/$port.config
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--generate-config \ --generate-config \
--enable_registration \
-H "localhost:$https_port" \ -H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \
# Check script parameters
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
# Set high limits in config file to disable rate limiting
perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
fi
fi
perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \
-D \ -D \

View file

@ -16,3 +16,6 @@ ignore =
docs/* docs/*
pylint.cfg pylint.cfg
tox.ini tox.ini
[flake8]
max-line-length = 90

View file

@ -48,7 +48,7 @@ setup(
description="Reference Synapse Home Server", description="Reference Synapse Home Server",
install_requires=dependencies['requirements'](include_conditional=True).keys(), install_requires=dependencies['requirements'](include_conditional=True).keys(),
setup_requires=[ setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0 "Twisted>=15.1.0", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial", "setuptools_trial",
"mock" "mock"
], ],

View file

@ -44,6 +44,11 @@ class Auth(object):
def check(self, event, auth_events): def check(self, event, auth_events):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns: Returns:
True if the auth checks pass. True if the auth checks pass.
""" """
@ -319,7 +324,7 @@ class Auth(object):
Returns: Returns:
tuple : of UserID and device string: tuple : of UserID and device string:
User ID object of the user making the request User ID object of the user making the request
Client ID object of the client instance the user is using ClientInfo object of the client instance the user is using
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
@ -347,12 +352,14 @@ class Auth(object):
if not user_id: if not user_id:
raise KeyError raise KeyError
request.authenticated_entity = user_id
defer.returnValue( defer.returnValue(
(UserID.from_string(user_id), ClientInfo("", "")) (UserID.from_string(user_id), ClientInfo("", ""))
) )
return return
except KeyError: except KeyError:
pass # normal users won't have this query parameter set pass # normal users won't have the user_id query parameter set.
user_info = yield self.get_user_by_token(access_token) user_info = yield self.get_user_by_token(access_token)
user = user_info["user"] user = user_info["user"]
@ -420,6 +427,7 @@ class Auth(object):
"Unrecognised access token.", "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
request.authenticated_entity = service.sender
defer.returnValue(service) defer.returnValue(service)
except KeyError: except KeyError:
raise AuthError( raise AuthError(
@ -521,7 +529,6 @@ class Auth(object):
# Check state_key # Check state_key
if hasattr(event, "state_key"): if hasattr(event, "state_key"):
if not event.state_key.startswith("_"):
if event.state_key.startswith("@"): if event.state_key.startswith("@"):
if event.state_key != event.user_id: if event.state_key != event.user_id:
raise AuthError( raise AuthError(

View file

@ -657,6 +657,7 @@ def run(hs):
if hs.config.daemonize: if hs.config.daemonize:
if hs.config.print_pidfile:
print hs.config.pid_file print hs.config.pid_file
daemon = Daemonize( daemon = Daemonize(

View file

@ -138,12 +138,19 @@ class Config(object):
action="store_true", action="store_true",
help="Generate a config file for the server name" help="Generate a config file for the server name"
) )
config_parser.add_argument(
"--generate-keys",
action="store_true",
help="Generate any missing key files then exit"
)
config_parser.add_argument( config_parser.add_argument(
"-H", "--server-name", "-H", "--server-name",
help="The server name to generate a config file for" help="The server name to generate a config file for"
) )
config_args, remaining_args = config_parser.parse_known_args(argv) config_args, remaining_args = config_parser.parse_known_args(argv)
generate_keys = config_args.generate_keys
if config_args.generate_config: if config_args.generate_config:
if not config_args.config_path: if not config_args.config_path:
config_parser.error( config_parser.error(
@ -151,34 +158,17 @@ class Config(object):
" generated using \"--generate-config -H SERVER_NAME" " generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )
(config_path,) = config_args.config_path
config_dir_path = os.path.dirname(config_args.config_path[0]) if not os.path.exists(config_path):
config_dir_path = os.path.dirname(config_path)
config_dir_path = os.path.abspath(config_dir_path) config_dir_path = os.path.abspath(config_dir_path)
server_name = config_args.server_name server_name = config_args.server_name
if not server_name: if not server_name:
print "Must specify a server_name to a generate config for." print "Must specify a server_name to a generate config for."
sys.exit(1) sys.exit(1)
(config_path,) = config_args.config_path
if not os.path.exists(config_dir_path): if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
if os.path.exists(config_path):
print "Config file %r already exists" % (config_path,)
yaml_config = cls.read_config_file(config_path)
yaml_name = yaml_config["server_name"]
if server_name != yaml_name:
print (
"Config file %r has a different server_name: "
" %r != %r" % (config_path, server_name, yaml_name)
)
sys.exit(1)
config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
config.update(yaml_config)
print "Generating any missing keys for %r" % (server_name,)
obj.invoke_all("generate_files", config)
sys.exit(0)
with open(config_path, "wb") as config_file: with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config( config_bytes, config = obj.generate_config(
config_dir_path, server_name config_dir_path, server_name
@ -186,16 +176,22 @@ class Config(object):
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
config_file.write(config_bytes) config_file.write(config_bytes)
print ( print (
"A config file has been generated in %s for server name" "A config file has been generated in %r for server name"
" '%s' with corresponding SSL keys and self-signed" " %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it to" " certificates. Please review this file and customise it"
" your needs." " to your needs."
) % (config_path, server_name) ) % (config_path, server_name)
print ( print (
"If this server name is incorrect, you will need to regenerate" "If this server name is incorrect, you will need to"
" the SSL certificates" " regenerate the SSL certificates"
) )
sys.exit(0) sys.exit(0)
else:
print (
"Config file %r already exists. Generating any missing key"
" files."
) % (config_path,)
generate_keys = True
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
parents=[config_parser], parents=[config_parser],
@ -213,7 +209,7 @@ class Config(object):
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )
config_dir_path = os.path.dirname(config_args.config_path[0]) config_dir_path = os.path.dirname(config_args.config_path[-1])
config_dir_path = os.path.abspath(config_dir_path) config_dir_path = os.path.abspath(config_dir_path)
specified_config = {} specified_config = {}
@ -226,6 +222,10 @@ class Config(object):
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)
if generate_keys:
obj.invoke_all("generate_files", config)
sys.exit(0)
obj.invoke_all("read_config", config) obj.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args) obj.invoke_all("read_arguments", args)

View file

@ -24,6 +24,7 @@ class ServerConfig(Config):
self.web_client = config["web_client"] self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"] self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize") self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
self.use_frozen_dicts = config.get("use_frozen_dicts", True) self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.listeners = config.get("listeners", []) self.listeners = config.get("listeners", [])
@ -208,12 +209,18 @@ class ServerConfig(Config):
self.manhole = args.manhole self.manhole = args.manhole
if args.daemonize is not None: if args.daemonize is not None:
self.daemonize = args.daemonize self.daemonize = args.daemonize
if args.print_pidfile is not None:
self.print_pidfile = args.print_pidfile
def add_arguments(self, parser): def add_arguments(self, parser):
server_group = parser.add_argument_group("server") server_group = parser.add_argument_group("server")
server_group.add_argument("-D", "--daemonize", action='store_true', server_group.add_argument("-D", "--daemonize", action='store_true',
default=None, default=None,
help="Daemonize the home server") help="Daemonize the home server")
server_group.add_argument("--print-pidfile", action='store_true',
default=None,
help="Print the path to the pidfile just"
" before daemonizing")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole", server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int, type=int,
help="Turn on the twisted telnet manhole" help="Turn on the twisted telnet manhole"

View file

@ -23,7 +23,7 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
@ -134,6 +134,36 @@ class FederationClient(FederationBase):
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
) )
@log_function
def query_client_keys(self, destination, content):
"""Query device keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_device_keys")
return self.transport_layer.query_client_keys(destination, content)
@log_function
def claim_client_keys(self, destination, content):
"""Claims one-time keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_one_time_keys")
return self.transport_layer.claim_client_keys(destination, content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def backfill(self, dest, context, limit, extremities): def backfill(self, dest, context, limit, extremities):

View file

@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
import simplejson as json
import logging import logging
@ -312,6 +313,48 @@ class FederationServer(FederationBase):
(200, send_content) (200, send_content)
) )
@defer.inlineCallbacks
@log_function
def on_query_client_keys(self, origin, content):
query = []
for user_id, device_ids in content.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
defer.returnValue({"device_keys": json_result})
@defer.inlineCallbacks
@log_function
def on_claim_client_keys(self, origin, content):
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
results = yield self.store.claim_e2e_one_time_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
defer.returnValue({"one_time_keys": json_result})
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_get_missing_events(self, origin, room_id, earliest_events, def on_get_missing_events(self, origin, room_id, earliest_events,

View file

@ -222,6 +222,76 @@ class TransportLayerClient(object):
defer.returnValue(content) defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def query_client_keys(self, destination, query_content):
"""Query the device keys for a list of user ids hosted on a remote
server.
Request:
{
"device_keys": {
"<user_id>": ["<device_id>"]
} }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {...}
} } }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/keys/query"
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def claim_client_keys(self, destination, query_content):
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": "<algorithm>"
} } }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
} } } }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the one-time keys.
"""
path = PREFIX + "/user/keys/claim"
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_missing_events(self, destination, room_id, earliest_events, def get_missing_events(self, destination, room_id, earliest_events,

View file

@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
response = yield self.handler.on_query_client_keys(origin, content)
defer.returnValue((200, response))
class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim"
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
response = yield self.handler.on_claim_client_keys(origin, content)
defer.returnValue((200, response))
class FederationQueryAuthServlet(BaseFederationServlet): class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/([^/]*)/([^/]*)" PATH = "/query_auth/([^/]*)/([^/]*)"
@ -373,4 +391,6 @@ SERVLET_CLASSES = (
FederationQueryAuthServlet, FederationQueryAuthServlet,
FederationGetMissingEventsServlet, FederationGetMissingEventsServlet,
FederationEventAuthServlet, FederationEventAuthServlet,
FederationClientKeysQueryServlet,
FederationClientKeysClaimServlet,
) )

View file

@ -22,7 +22,6 @@ from .room import (
from .message import MessageHandler from .message import MessageHandler
from .events import EventStreamHandler, EventHandler from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler from .federation import FederationHandler
from .login import LoginHandler
from .profile import ProfileHandler from .profile import ProfileHandler
from .presence import PresenceHandler from .presence import PresenceHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
@ -54,7 +53,6 @@ class Handlers(object):
self.profile_handler = ProfileHandler(hs) self.profile_handler = ProfileHandler(hs)
self.presence_handler = PresenceHandler(hs) self.presence_handler = PresenceHandler(hs)
self.room_list_handler = RoomListHandler(hs) self.room_list_handler = RoomListHandler(hs)
self.login_handler = LoginHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)

View file

@ -47,17 +47,24 @@ class AuthHandler(BaseHandler):
self.sessions = {} self.sessions = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip=None): def check_auth(self, flows, clientdict, clientip):
""" """
Takes a dictionary sent by the client in the login / registration Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow. protocol and handles the login flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
Args: Args:
flows: list of list of stages flows (list): A list of login flows. Each flow is an ordered list of
authdict: The dictionary from the client root level, not the strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent. 'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
Returns: Returns:
A tuple of authed, dict, dict where authed is true if the client A tuple of (authed, dict, dict) where authed is true if the client
has successfully completed an auth flow. If it is true, the first has successfully completed an auth flow. If it is true, the first
dict contains the authenticated credentials of each stage. dict contains the authenticated credentials of each stage.
@ -75,7 +82,7 @@ class AuthHandler(BaseHandler):
del clientdict['auth'] del clientdict['auth']
if 'session' in authdict: if 'session' in authdict:
sid = authdict['session'] sid = authdict['session']
sess = self._get_session_info(sid) session = self._get_session_info(sid)
if len(clientdict) > 0: if len(clientdict) > 0:
# This was designed to allow the client to omit the parameters # This was designed to allow the client to omit the parameters
@ -87,20 +94,19 @@ class AuthHandler(BaseHandler):
# on a home server. # on a home server.
# Revisit: Assumimg the REST APIs do sensible validation, the data # Revisit: Assumimg the REST APIs do sensible validation, the data
# isn't arbintrary. # isn't arbintrary.
sess['clientdict'] = clientdict session['clientdict'] = clientdict
self._save_session(sess) self._save_session(session)
pass elif 'clientdict' in session:
elif 'clientdict' in sess: clientdict = session['clientdict']
clientdict = sess['clientdict']
if not authdict: if not authdict:
defer.returnValue( defer.returnValue(
(False, self._auth_dict_for_flows(flows, sess), clientdict) (False, self._auth_dict_for_flows(flows, session), clientdict)
) )
if 'creds' not in sess: if 'creds' not in session:
sess['creds'] = {} session['creds'] = {}
creds = sess['creds'] creds = session['creds']
# check auth type currently being presented # check auth type currently being presented
if 'type' in authdict: if 'type' in authdict:
@ -109,15 +115,15 @@ class AuthHandler(BaseHandler):
result = yield self.checkers[authdict['type']](authdict, clientip) result = yield self.checkers[authdict['type']](authdict, clientip)
if result: if result:
creds[authdict['type']] = result creds[authdict['type']] = result
self._save_session(sess) self._save_session(session)
for f in flows: for f in flows:
if len(set(f) - set(creds.keys())) == 0: if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds) logger.info("Auth completed with creds: %r", creds)
self._remove_session(sess) self._remove_session(session)
defer.returnValue((True, creds, clientdict)) defer.returnValue((True, creds, clientdict))
ret = self._auth_dict_for_flows(flows, sess) ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys() ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict)) defer.returnValue((False, ret, clientdict))
@ -151,22 +157,13 @@ class AuthHandler(BaseHandler):
if "user" not in authdict or "password" not in authdict: if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM) raise LoginError(400, "", Codes.MISSING_PARAM)
user = authdict["user"] user_id = authdict["user"]
password = authdict["password"] password = authdict["password"]
if not user.startswith('@'): if not user_id.startswith('@'):
user = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user_id, self.hs.hostname).to_string()
user_info = yield self.store.get_user_by_id(user_id=user) self._check_password(user_id, password)
if not user_info: defer.returnValue(user_id)
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash):
defer.returnValue(user)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip): def _check_recaptcha(self, authdict, clientip):
@ -270,6 +267,59 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id] return self.sessions[session_id]
@defer.inlineCallbacks
def login_with_password(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
user_id (str): User ID
password (str): Password
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
yield self._check_password(user_id, password)
reg_handler = self.hs.get_handlers().registration_handler
access_token = reg_handler.generate_token(user_id)
logger.info("Logging in user %s", user_id)
yield self.store.add_access_token_to_user(user_id, access_token)
defer.returnValue(access_token)
@defer.inlineCallbacks
def _check_password(self, user_id, password):
"""Checks that user_id has passed password, raises LoginError if not."""
user_info = yield self.store.get_user_by_id(user_id=user_id)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info["password_hash"]
if not bcrypt.checkpw(password, stored_hash):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(user_id)
yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
yield self.store.flush_user(user_id)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
)
def _save_session(self, session): def _save_session(self, session):
# TODO: Persistent storage # TODO: Persistent storage
logger.debug("Saving session %s", session) logger.debug("Saving session %s", session)

View file

@ -70,6 +70,14 @@ class EventStreamHandler(BaseHandler):
self._streams_per_user[auth_user] += 1 self._streams_per_user[auth_user] += 1
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
auth_user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user) room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
if timeout: if timeout:

View file

@ -229,15 +229,15 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events): def _filter_events_for_server(self, server_name, room_id, events):
states = yield self.store.get_state_for_events( event_to_state = yield self.store.get_state_for_events(
room_id, [e.event_id for e in events], room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
) )
events_and_states = zip(events, states) def redact_disallowed(event, state):
def redact_disallowed(event_and_state):
event, state = event_and_state
if not state: if not state:
return event return event
@ -271,11 +271,10 @@ class FederationHandler(BaseHandler):
return event return event
res = map(redact_disallowed, events_and_states) defer.returnValue([
redact_disallowed(e, event_to_state[e.event_id])
logger.info("_filter_events_for_server %r", res) for e in events
])
defer.returnValue(res)
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
@ -503,7 +502,7 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys()) event_ids = list(extremities.keys())
states = yield defer.gatherResults([ states = yield defer.gatherResults([
self.state_handler.resolve_state_groups([e]) self.state_handler.resolve_state_groups(room_id, [e])
for e in event_ids for e in event_ids
]) ])
states = dict(zip(event_ids, [s[1] for s in states])) states = dict(zip(event_ids, [s[1] for s in states]))

View file

@ -1,83 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes
import bcrypt
import logging
logger = logging.getLogger(__name__)
class LoginHandler(BaseHandler):
def __init__(self, hs):
super(LoginHandler, self).__init__(hs)
self.hs = hs
@defer.inlineCallbacks
def login(self, user, password):
"""Login as the specified user with the specified password.
Args:
user (str): The user ID.
password (str): The password.
Returns:
The newly allocated access token.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
# TODO do this better, it can't go in __init__ else it cyclic loops
if not hasattr(self, "reg_handler"):
self.reg_handler = self.hs.get_handlers().registration_handler
# pull out the hash for this user if they exist
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash):
# generate an access token and store it.
token = self.reg_handler._generate_token(user)
logger.info("Adding token %s for user %s", token, user)
yield self.store.add_access_token_to_user(user, token)
defer.returnValue(token)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, token_id=None):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
user_id, token_id
)
yield self.store.flush_user(user_id)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
)

View file

@ -137,15 +137,15 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events): def _filter_events_for_client(self, user_id, room_id, events):
states = yield self.store.get_state_for_events( event_id_to_state = yield self.store.get_state_for_events(
room_id, [e.event_id for e in events], room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
) )
events_and_states = zip(events, states) def allowed(event, state):
def allowed(event_and_state):
event, state = event_and_state
if event.type == EventTypes.RoomHistoryVisibility: if event.type == EventTypes.RoomHistoryVisibility:
return True return True
@ -175,10 +175,10 @@ class MessageHandler(BaseHandler):
return True return True
events_and_states = filter(allowed, events_and_states)
defer.returnValue([ defer.returnValue([
ev event
for ev, _ in events_and_states for event in events
if allowed(event, event_id_to_state[event.event_id])
]) ])
@defer.inlineCallbacks @defer.inlineCallbacks
@ -401,8 +401,12 @@ class MessageHandler(BaseHandler):
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
# Only do N rooms at once
n = 5
d_list = [handle_room(e) for e in room_list]
for i in range(0, len(d_list), n):
yield defer.gatherResults( yield defer.gatherResults(
[handle_room(e) for e in room_list], d_list[i:i + n],
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -456,20 +460,14 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence(): def get_presence():
presence_defs = yield defer.DeferredList( states = yield presence_handler.get_states(
[ target_users=[UserID.from_string(m.user_id) for m in room_members],
presence_handler.get_state(
target_user=UserID.from_string(m.user_id),
auth_user=auth_user, auth_user=auth_user,
as_event=True, as_event=True,
check_auth=False, check_auth=False,
) )
for m in room_members
],
consumeErrors=True,
)
defer.returnValue([p for success, p in presence_defs if success]) defer.returnValue(states.values())
receipts_handler = self.hs.get_handlers().receipts_handler receipts_handler = self.hs.get_handlers().receipts_handler

View file

@ -192,6 +192,20 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state(self, target_user, auth_user, as_event=False, check_auth=True): def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
"""Get the current presence state of the given user.
Args:
target_user (UserID): The user whose presence we want
auth_user (UserID): The user requesting the presence, used for
checking if said user is allowed to see the persence of the
`target_user`
as_event (bool): Format the return as an event or not?
check_auth (bool): Perform the auth checks or not?
Returns:
dict: The presence state of the `target_user`, whose format depends
on the `as_event` argument.
"""
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
if check_auth: if check_auth:
visible = yield self.is_presence_visible( visible = yield self.is_presence_visible(
@ -232,6 +246,81 @@ class PresenceHandler(BaseHandler):
else: else:
defer.returnValue(state) defer.returnValue(state)
@defer.inlineCallbacks
def get_states(self, target_users, auth_user, as_event=False, check_auth=True):
"""A batched version of the `get_state` method that accepts a list of
`target_users`
Args:
target_users (list): The list of UserID's whose presence we want
auth_user (UserID): The user requesting the presence, used for
checking if said user is allowed to see the persence of the
`target_users`
as_event (bool): Format the return as an event or not?
check_auth (bool): Perform the auth checks or not?
Returns:
dict: A mapping from user -> presence_state
"""
local_users, remote_users = partitionbool(
target_users,
lambda u: self.hs.is_mine(u)
)
if check_auth:
for user in local_users:
visible = yield self.is_presence_visible(
observer_user=auth_user,
observed_user=user
)
if not visible:
raise SynapseError(404, "Presence information not visible")
results = {}
if local_users:
for user in local_users:
if user in self._user_cachemap:
results[user] = self._user_cachemap[user].get_state()
local_to_user = {u.localpart: u for u in local_users}
states = yield self.store.get_presence_states(
[u.localpart for u in local_users if u not in results]
)
for local_part, state in states.items():
if state is None:
continue
res = {"presence": state["state"]}
if "status_msg" in state and state["status_msg"]:
res["status_msg"] = state["status_msg"]
results[local_to_user[local_part]] = res
for user in remote_users:
# TODO(paul): Have remote server send us permissions set
results[user] = self._get_or_offline_usercache(user).get_state()
for state in results.values():
if "last_active" in state:
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
if as_event:
for user, state in results.items():
content = state
content["user_id"] = user.to_string()
if "last_active" in content:
content["last_active_ago"] = int(
self._clock.time_msec() - content.pop("last_active")
)
results[user] = {"type": "m.presence", "content": content}
defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def set_state(self, target_user, auth_user, state): def set_state(self, target_user, auth_user, state):

View file

@ -171,7 +171,6 @@ class ReceiptEventSource(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events_for_user(self, user, from_key, limit):
defer.returnValue(([], from_key))
from_key = int(from_key) from_key = int(from_key)
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
@ -194,7 +193,6 @@ class ReceiptEventSource(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_pagination_rows(self, user, config, key): def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key) to_key = int(config.from_key)
defer.returnValue(([], to_key))
if config.to_key: if config.to_key:
from_key = int(config.to_key) from_key = int(config.to_key)

View file

@ -91,7 +91,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self._generate_token(user_id) token = self.generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -111,7 +111,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self.generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -161,7 +161,7 @@ class RegistrationHandler(BaseHandler):
400, "Invalid user localpart for this application service.", 400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
token = self._generate_token(user_id) token = self.generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -208,7 +208,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self.generate_token(user_id)
try: try:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -273,7 +273,7 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
def _generate_token(self, user_id): def generate_token(self, user_id):
# urlsafe variant uses _ and - so use . as the separator and replace # urlsafe variant uses _ and - so use . as the separator and replace
# all =s with .s so http clients don't quote =s when it is used as # all =s with .s so http clients don't quote =s when it is used as
# query params. # query params.

View file

@ -557,12 +557,6 @@ class RoomMemberHandler(BaseHandler):
"""Returns a list of roomids that the user has any of the given """Returns a list of roomids that the user has any of the given
membership states in.""" membership states in."""
app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
else:
rooms = yield self.store.get_rooms_for_user( rooms = yield self.store.get_rooms_for_user(
user.to_string(), user.to_string(),
) )

View file

@ -96,9 +96,18 @@ class SyncHandler(BaseHandler):
return self.current_sync_for_user(sync_config, since_token) return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user( room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user sync_config.user
) )
result = yield self.notifier.wait_for_events( result = yield self.notifier.wait_for_events(
sync_config.user, room_ids, sync_config.user, room_ids,
sync_config.filter, timeout, current_sync_callback sync_config.filter, timeout, current_sync_callback
@ -229,7 +238,16 @@ class SyncHandler(BaseHandler):
logger.debug("Typing %r", typing_by_room) logger.debug("Typing %r", typing_by_room)
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(sync_config.user) app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
# TODO (mjark): Does public mean "published"? # TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True) published_rooms = yield self.store.get_rooms(is_public=True)
@ -294,15 +312,15 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events): def _filter_events_for_client(self, user_id, room_id, events):
states = yield self.store.get_state_for_events( event_id_to_state = yield self.store.get_state_for_events(
room_id, [e.event_id for e in events], room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
) )
events_and_states = zip(events, states) def allowed(event, state):
def allowed(event_and_state):
event, state = event_and_state
if event.type == EventTypes.RoomHistoryVisibility: if event.type == EventTypes.RoomHistoryVisibility:
return True return True
@ -331,10 +349,11 @@ class SyncHandler(BaseHandler):
return membership == Membership.INVITE return membership == Membership.INVITE
return True return True
events_and_states = filter(allowed, events_and_states)
defer.returnValue([ defer.returnValue([
ev event
for ev, _ in events_and_states for event in events
if allowed(event, event_id_to_state[event.event_id])
]) ])
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -16,7 +16,7 @@
from twisted.internet import defer, reactor, protocol from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool from twisted.web.client import readBody, HTTPConnectionPool, Agent
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
@ -55,41 +55,17 @@ incoming_responses_counter = metrics.register_counter(
) )
class MatrixFederationHttpAgent(_AgentBase): class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.tls_context_factory = hs.tls_context_factory
def __init__(self, reactor, pool=None): def endpointForURI(self, uri):
_AgentBase.__init__(self, reactor, pool) destination = uri.netloc
def request(self, destination, endpoint, method, path, params, query, return matrix_federation_endpoint(
headers, body_producer): reactor, destination, timeout=10,
ssl_context_factory=self.tls_context_factory
outgoing_requests_counter.inc(method) )
host = b""
port = 0
fragment = b""
parsed_URI = _URI(b"http", destination, host, port, path, params,
query, fragment)
# Set the connection pool key to be the destination.
key = destination
d = self._requestWithEndpoint(key, endpoint, method, parsed_URI,
headers, body_producer,
parsed_URI.originForm)
def _cb(response):
incoming_responses_counter.inc(method, response.code)
return response
def _eb(failure):
incoming_responses_counter.inc(method, "ERR")
return failure
d.addCallbacks(_cb, _eb)
return d
class MatrixFederationHttpClient(object): class MatrixFederationHttpClient(object):
@ -107,12 +83,18 @@ class MatrixFederationHttpClient(object):
self.server_name = hs.hostname self.server_name = hs.hostname
pool = HTTPConnectionPool(reactor) pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10 pool.maxPersistentPerHost = 10
self.agent = MatrixFederationHttpAgent(reactor, pool=pool) self.agent = Agent.usingEndpointFactory(
reactor, MatrixFederationEndpointFactory(hs), pool=pool
)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.version_string = hs.version_string self.version_string = hs.version_string
self._next_id = 1 self._next_id = 1
def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
return urlparse.urlunparse(
("matrix", destination, path_bytes, param_bytes, query_bytes, "")
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes, def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"", body_callback, headers_dict={}, param_bytes=b"",
@ -123,8 +105,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"User-Agent"] = [self.version_string] headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination] headers_dict[b"Host"] = [destination]
url_bytes = urlparse.urlunparse( url_bytes = self._create_url(
("", "", path_bytes, param_bytes, query_bytes, "",) destination, path_bytes, param_bytes, query_bytes
) )
txn_id = "%s-O-%s" % (method, self._next_id) txn_id = "%s-O-%s" % (method, self._next_id)
@ -139,8 +121,8 @@ class MatrixFederationHttpClient(object):
# (once we have reliable transactions in place) # (once we have reliable transactions in place)
retries_left = 5 retries_left = 5
endpoint = preserve_context_over_fn( http_url_bytes = urlparse.urlunparse(
self._getEndpoint, reactor, destination ("", "", path_bytes, param_bytes, query_bytes, "")
) )
log_result = None log_result = None
@ -148,17 +130,14 @@ class MatrixFederationHttpClient(object):
while True: while True:
producer = None producer = None
if body_callback: if body_callback:
producer = body_callback(method, url_bytes, headers_dict) producer = body_callback(method, http_url_bytes, headers_dict)
try: try:
def send_request(): def send_request():
request_deferred = self.agent.request( request_deferred = preserve_context_over_fn(
destination, self.agent.request,
endpoint,
method, method,
path_bytes, url_bytes,
param_bytes,
query_bytes,
Headers(headers_dict), Headers(headers_dict),
producer producer
) )
@ -452,12 +431,6 @@ class MatrixFederationHttpClient(object):
defer.returnValue((length, headers)) defer.returnValue((length, headers))
def _getEndpoint(self, reactor, destination):
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.hs.tls_context_factory
)
class _ReadBodyToFileProtocol(protocol.Protocol): class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size): def __init__(self, stream, deferred, max_size):

View file

@ -18,8 +18,12 @@ from __future__ import absolute_import
import logging import logging
from resource import getrusage, getpagesize, RUSAGE_SELF from resource import getrusage, getpagesize, RUSAGE_SELF
import functools
import os import os
import stat import stat
import time
from twisted.internet import reactor
from .metric import ( from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
@ -144,3 +148,50 @@ def _process_fds():
return counts return counts
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"]) get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
reactor_metrics = get_metrics_for("reactor")
tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
def runUntilCurrentTimer(func):
@functools.wraps(func)
def f(*args, **kwargs):
now = reactor.seconds()
num_pending = 0
# _newTimedCalls is one long list of *all* pending calls. Below loop
# is based off of impl of reactor.runUntilCurrent
for delayed_call in reactor._newTimedCalls:
if delayed_call.time > now:
break
if delayed_call.delayed_time > 0:
continue
num_pending += 1
num_pending += len(reactor.threadCallQueue)
start = time.time() * 1000
ret = func(*args, **kwargs)
end = time.time() * 1000
tick_time.inc_by(end - start)
pending_calls_metric.inc_by(num_pending)
return ret
return f
try:
# Ensure the reactor has all the attributes we expect
reactor.runUntilCurrent
reactor._newTimedCalls
reactor.threadCallQueue
# runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
except AttributeError:
pass

View file

@ -94,17 +94,14 @@ class PusherPool:
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user_access_token(self, user_id, not_access_token_id): def remove_pushers_by_user(self, user_id):
all = yield self.store.get_all_pushers() all = yield self.store.get_all_pushers()
logger.info( logger.info(
"Removing all pushers for user %s except access token %s", "Removing all pushers for user %s",
user_id, not_access_token_id user_id,
) )
for p in all: for p in all:
if ( if p['user_name'] == user_id:
p['user_name'] == user_id and
p['access_token'] != not_access_token_id
):
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']

View file

@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"syutil>=0.0.7": ["syutil>=0.0.7"], "syutil>=0.0.7": ["syutil>=0.0.7"],
"Twisted==14.0.2": ["twisted==14.0.2"], "Twisted>=15.1.0": ["twisted>=15.1.0"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],

View file

@ -85,9 +85,8 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create( user_id = UserID.create(
user_id, self.hs.hostname).to_string() user_id, self.hs.hostname).to_string()
handler = self.handlers.login_handler token = yield self.handlers.auth_handler.login_with_password(
token = yield handler.login( user_id=user_id
user=user_id,
password=login_submission["password"]) password=login_submission["password"])
result = { result = {

View file

@ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_handlers().auth_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet):
authed, result, params = yield self.auth_handler.check_auth([ authed, result, params = yield self.auth_handler.check_auth([
[LoginType.PASSWORD], [LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
], body) ], body, self.hs.get_ip_from_request(request))
if not authed: if not authed:
defer.returnValue((401, result)) defer.returnValue((401, result))
@ -79,7 +78,7 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(400, "", Codes.MISSING_PARAM) raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password'] new_password = params['new_password']
yield self.login_handler.set_password( yield self.auth_handler.set_password(
user_id, new_password, None user_id, new_password, None
) )
@ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ThreepidRestServlet, self).__init__() super(ThreepidRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.login_handler = hs.get_handlers().login_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet):
logger.warn("Couldn't add 3pid: invalid response from ID sevrer") logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
raise SynapseError(500, "Invalid response from ID Server") raise SynapseError(500, "Invalid response from ID Server")
yield self.login_handler.add_threepid( yield self.auth_handler.add_threepid(
auth_user.to_string(), auth_user.to_string(),
threepid['medium'], threepid['medium'],
threepid['address'], threepid['address'],

View file

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import UserID
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
from ._base import client_v2_pattern from ._base import client_v2_pattern
@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet):
super(KeyQueryServlet, self).__init__() super(KeyQueryServlet, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id): def on_POST(self, request, user_id, device_id):
logger.debug("onPOST")
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
try: try:
body = json.loads(request.content.read()) body = json.loads(request.content.read())
except: except:
raise SynapseError(400, "Invalid key JSON") raise SynapseError(400, "Invalid key JSON")
query = [] result = yield self.handle_request(body)
for user_id, device_ids in body.get("device_keys", {}).items(): defer.returnValue(result)
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
defer.returnValue(self.json_result(request, results))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id): def on_GET(self, request, user_id, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request) auth_user, client_info = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string() auth_user_id = auth_user.to_string()
if not user_id: user_id = user_id if user_id else auth_user_id
user_id = auth_user_id device_ids = [device_id] if device_id else []
if not device_id: result = yield self.handle_request(
device_id = None {"device_keys": {user_id: device_ids}}
# Returns a map of user_id->device_id->json_bytes. )
results = yield self.store.get_e2e_device_keys([(user_id, device_id)]) defer.returnValue(result)
defer.returnValue(self.json_result(request, results))
@defer.inlineCallbacks
def handle_request(self, body):
local_query = []
remote_queries = {}
for user_id, device_ids in body.get("device_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
else:
remote_queries.setdefault(user.domain, {})[user_id] = list(
device_ids
)
results = yield self.store.get_e2e_device_keys(local_query)
def json_result(self, request, results):
json_result = {} json_result = {}
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items(): for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads( json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes json_bytes
) )
return (200, {"device_keys": json_result})
for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.query_client_keys(
destination, {"device_keys": device_keys}
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"device_keys": json_result}))
class OneTimeKeyServlet(RestServlet): class OneTimeKeyServlet(RestServlet):
@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm): def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
results = yield self.store.claim_e2e_one_time_keys( result = yield self.handle_request(
[(user_id, device_id, algorithm)] {"one_time_keys": {user_id: {device_id: algorithm}}}
) )
defer.returnValue(self.json_result(request, results)) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm): def on_POST(self, request, user_id, device_id, algorithm):
@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet):
body = json.loads(request.content.read()) body = json.loads(request.content.read())
except: except:
raise SynapseError(400, "Invalid key JSON") raise SynapseError(400, "Invalid key JSON")
query = [] result = yield self.handle_request(body)
for user_id, device_keys in body.get("one_time_keys", {}).items(): defer.returnValue(result)
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm)) @defer.inlineCallbacks
results = yield self.store.claim_e2e_one_time_keys(query) def handle_request(self, body):
defer.returnValue(self.json_result(request, results)) local_query = []
remote_queries = {}
for user_id, device_keys in body.get("one_time_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
remote_queries.setdefault(user.domain, {})[user_id] = (
device_keys
)
results = yield self.store.claim_e2e_one_time_keys(local_query)
def json_result(self, request, results):
json_result = {} json_result = {}
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, keys in device_keys.items(): for device_id, keys in device_keys.items():
@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet):
json_result.setdefault(user_id, {})[device_id] = { json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes) key_id: json.loads(json_bytes)
} }
return (200, {"one_time_keys": json_result})
for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"one_time_keys": json_result}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):

View file

@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/register*") PATTERN = client_v2_pattern("/register")
def __init__(self, hs): def __init__(self, hs):
super(RegisterRestServlet, self).__init__() super(RegisterRestServlet, self).__init__()
@ -50,7 +50,6 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -148,7 +147,7 @@ class RegisterRestServlet(RestServlet):
if reqd not in threepid: if reqd not in threepid:
logger.info("Can't add incomplete 3pid") logger.info("Can't add incomplete 3pid")
else: else:
yield self.login_handler.add_threepid( yield self.auth_handler.add_threepid(
user_id, user_id,
threepid['medium'], threepid['medium'],
threepid['address'], threepid['address'],
@ -224,6 +223,9 @@ class RegisterRestServlet(RestServlet):
if k not in body: if k not in body:
absent.append(k) absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email'] 'email', body['email']
) )
@ -231,9 +233,6 @@ class RegisterRestServlet(RestServlet):
if existingUid is not None: if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
ret = yield self.identity_handler.requestEmailToken(**body) ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret)) defer.returnValue((200, ret))

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes from synapse.api.auth import AuthEventTypes
@ -96,7 +96,7 @@ class StateHandler(object):
cache.ts = self.clock.time_msec() cache.ts = self.clock.time_msec()
state = cache.state state = cache.state
else: else:
res = yield self.resolve_state_groups(event_ids) res = yield self.resolve_state_groups(room_id, event_ids)
state = res[1] state = res[1]
if event_type: if event_type:
@ -155,13 +155,13 @@ class StateHandler(object):
if event.is_state(): if event.is_state():
ret = yield self.resolve_state_groups( ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
event_type=event.type, event_type=event.type,
state_key=event.state_key, state_key=event.state_key,
) )
else: else:
ret = yield self.resolve_state_groups( ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
group, curr_state, prev_state = ret group, curr_state, prev_state = ret
@ -180,7 +180,7 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def resolve_state_groups(self, event_ids, event_type=None, state_key=""): def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
@ -205,7 +205,7 @@ class StateHandler(object):
) )
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
event_ids room_id, event_ids
) )
logger.debug( logger.debug(

View file

@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
key = (user.to_string(), access_token, device_id, ip) key = (user.to_string(), access_token, device_id, ip)
try: try:
last_seen = self.client_ip_last_seen.get(*key) last_seen = self.client_ip_last_seen.get(key)
except KeyError: except KeyError:
last_seen = None last_seen = None
@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None) defer.returnValue(None)
self.client_ip_last_seen.prefill(*key + (now,)) self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint, # It's safe not to lock here: a) no unique constraint,
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
@ -354,6 +354,11 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
) )
logger.debug("Running script %s", relative_path) logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine) module.run_upgrade(cur, database_engine)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
pass
elif ext == ".sql": elif ext == ".sql":
# A plain old .sql file, just read and execute it # A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path) logger.debug("Applying schema %s", relative_path)

View file

@ -17,21 +17,20 @@ import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics import synapse.metrics
from util.id_generators import IdGenerator, StreamIdGenerator from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple
import functools
import sys import sys
import time import time
import threading import threading
DEBUG_CACHES = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,159 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
caches_by_name = {}
cache_counter = metrics.register_cache(
"cache",
lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
labels=["name"],
)
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=False):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name
self.keylen = keylen
self.sequence = 0
self.thread = None
caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, *keyargs):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
if keyargs in self.cache:
cache_counter.inc_hits(self.name)
return self.cache[keyargs]
cache_counter.inc_misses(self.name)
raise KeyError()
def update(self, sequence, *args):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(*args)
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[keyargs] = value
def invalidate(self, *keyargs):
self.check_thread()
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(keyargs, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
"invalidate". This can be used to remove individual entries from the cache.
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
self.orig = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
def __get__(self, obj, objtype=None):
cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
@functools.wraps(self.orig)
@defer.inlineCallbacks
def wrapped(*keyargs):
try:
cached_result = cache.get(*keyargs[:self.num_args])
if DEBUG_CACHES:
actual_result = yield self.orig(obj, *keyargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = yield self.orig(obj, *keyargs)
cache.update(sequence, *keyargs[:self.num_args] + (ret,))
defer.returnValue(ret)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
def cached(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
class LoggingTransaction(object): class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
@ -321,6 +167,8 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache("*stateGroupCache*", 100000)
self._event_fetch_lock = threading.Condition() self._event_fetch_lock = threading.Condition()
self._event_fetch_list = [] self._event_fetch_list = []
self._event_fetch_ongoing = 0 self._event_fetch_ongoing = 0

View file

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
@ -104,7 +105,7 @@ class DirectoryStore(SQLBaseStore):
}, },
desc="create_room_alias_association", desc="create_room_alias_association",
) )
self.get_aliases_for_room.invalidate(room_id) self.get_aliases_for_room.invalidate((room_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_room_alias(self, room_alias): def delete_room_alias(self, room_alias):
@ -114,7 +115,7 @@ class DirectoryStore(SQLBaseStore):
room_alias, room_alias,
) )
self.get_aliases_for_room.invalidate(room_id) self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id) defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias): def _delete_room_alias_txn(self, txn, room_alias):

View file

@ -15,7 +15,8 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from syutil.base64util import encode_base64 from syutil.base64util import encode_base64
import logging import logging
@ -362,7 +363,7 @@ class EventFederationStore(SQLBaseStore):
for room_id in events_by_room: for room_id in events_by_room:
txn.call_after( txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id self.get_latest_event_ids_in_room.invalidate, (room_id,)
) )
def get_backfill_events(self, room_id, event_list, limit): def get_backfill_events(self, room_id, event_list, limit):
@ -505,4 +506,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?" query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,)) txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))

View file

@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore):
if current_state: if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, event.room_id) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id) txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn( self._simple_delete_txn(
@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore):
if not context.rejected: if not context.rejected:
txn.call_after( txn.call_after(
self.get_current_state_for_key.invalidate, self.get_current_state_for_key.invalidate,
event.room_id, event.type, event.state_key (event.room_id, event.type, event.state_key,)
) )
if event.type in [EventTypes.Name, EventTypes.Aliases]: if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn.call_after( txn.call_after(
self.get_room_name_and_aliases.invalidate, self.get_room_name_and_aliases.invalidate,
event.room_id (event.room_id,)
) )
self._simple_upsert_txn( self._simple_upsert_txn(
@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id): def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True): for check_redacted in (False, True):
for get_prev_content in (False, True): for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted, self._get_event_cache.invalidate(
get_prev_content) (event_id, check_redacted, get_prev_content)
)
def _get_event_txn(self, txn, event_id, check_redacted=True, def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False): get_prev_content=False, allow_rejected=False):
@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore):
for event_id in events: for event_id in events:
try: try:
ret = self._get_event_cache.get( ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content (event_id, check_redacted, get_prev_content,)
) )
if allow_rejected or not ret.rejected_reason: if allow_rejected or not ret.rejected_reason:
@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill( self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev (ev.event_id, check_redacted, get_prev_content), ev
) )
defer.returnValue(ev) defer.returnValue(ev)
@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill( self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev (ev.event_id, check_redacted, get_prev_content), ev
) )
return ev return ev

View file

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore, cached from _base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -71,8 +72,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate", desc="store_server_certificate",
) )
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_all_server_verify_keys(self, server_name): def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="server_signature_keys", table="server_signature_keys",
@ -132,7 +132,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key", desc="store_server_verify_key",
) )
self.get_all_server_verify_keys.invalidate(server_name) self.get_all_server_verify_keys.invalidate((server_name,))
def store_server_keys_json(self, server_name, key_id, from_server, def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes): ts_now_ms, ts_expires_ms, key_json_bytes):

View file

@ -13,19 +13,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
from twisted.internet import defer from twisted.internet import defer
class PresenceStore(SQLBaseStore): class PresenceStore(SQLBaseStore):
def create_presence(self, user_localpart): def create_presence(self, user_localpart):
return self._simple_insert( res = self._simple_insert(
table="presence", table="presence",
values={"user_id": user_localpart}, values={"user_id": user_localpart},
desc="create_presence", desc="create_presence",
) )
self.get_presence_state.invalidate((user_localpart,))
return res
def has_presence_state(self, user_localpart): def has_presence_state(self, user_localpart):
return self._simple_select_one( return self._simple_select_one(
table="presence", table="presence",
@ -35,6 +39,7 @@ class PresenceStore(SQLBaseStore):
desc="has_presence_state", desc="has_presence_state",
) )
@cached(max_entries=2000)
def get_presence_state(self, user_localpart): def get_presence_state(self, user_localpart):
return self._simple_select_one( return self._simple_select_one(
table="presence", table="presence",
@ -43,8 +48,27 @@ class PresenceStore(SQLBaseStore):
desc="get_presence_state", desc="get_presence_state",
) )
@cachedList(get_presence_state.cache, list_name="user_localparts")
def get_presence_states(self, user_localparts):
def f(txn):
results = {}
for user_localpart in user_localparts:
res = self._simple_select_one_txn(
txn,
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["state", "status_msg", "mtime"],
allow_none=True,
)
if res:
results[user_localpart] = res
return results
return self.runInteraction("get_presence_states", f)
def set_presence_state(self, user_localpart, new_state): def set_presence_state(self, user_localpart, new_state):
return self._simple_update_one( res = self._simple_update_one(
table="presence", table="presence",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
updatevalues={"state": new_state["state"], updatevalues={"state": new_state["state"],
@ -53,6 +77,9 @@ class PresenceStore(SQLBaseStore):
desc="set_presence_state", desc="set_presence_state",
) )
self.get_presence_state.invalidate((user_localpart,))
return res
def allow_presence_visible(self, observed_localpart, observer_userid): def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert( return self._simple_insert(
table="presence_allow_inbound", table="presence_allow_inbound",
@ -98,7 +125,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"accepted": True}, updatevalues={"accepted": True},
desc="set_presence_list_accepted", desc="set_presence_list_accepted",
) )
self.get_presence_list_accepted.invalidate(observer_localpart) self.get_presence_list_accepted.invalidate((observer_localpart,))
defer.returnValue(result) defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
@ -133,4 +160,4 @@ class PresenceStore(SQLBaseStore):
"observed_user_id": observed_userid}, "observed_user_id": observed_userid},
desc="del_presence_list", desc="del_presence_list",
) )
self.get_presence_list_accepted.invalidate(observer_localpart) self.get_presence_list_accepted.invalidate((observer_localpart,))

View file

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -23,8 +24,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_push_rules_for_user(self, user_name): def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table=PushRuleTable.table_name, table=PushRuleTable.table_name,
@ -41,8 +41,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name): def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name, table=PushRuleEnableTable.table_name,
@ -153,11 +152,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -189,10 +188,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority'] = new_prio new_rule['priority'] = new_prio
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -218,8 +217,8 @@ class PushRuleStore(SQLBaseStore):
desc="delete_push_rule", desc="delete_push_rule",
) )
self.get_push_rules_for_user.invalidate(user_name) self.get_push_rules_for_user.invalidate((user_name,))
self.get_push_rules_enabled_for_user.invalidate(user_name) self.get_push_rules_enabled_for_user.invalidate((user_name,))
@defer.inlineCallbacks @defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled): def set_push_rule_enabled(self, user_name, rule_id, enabled):
@ -240,10 +239,10 @@ class PushRuleStore(SQLBaseStore):
{'id': new_id}, {'id': new_id},
) )
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )

View file

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches import cache_counter, caches_by_name
from twisted.internet import defer from twisted.internet import defer
from synapse.util import unwrapFirstError
from blist import sorteddict from blist import sorteddict
import logging import logging
import ujson as json import ujson as json
@ -53,19 +53,13 @@ class ReceiptsStore(SQLBaseStore):
self, room_ids, from_key self, room_ids, from_key
) )
results = yield defer.gatherResults( results = yield self._get_linearized_receipts_for_rooms(
[ room_ids, to_key, from_key=from_key
self.get_linearized_receipts_for_room(
room_id, to_key, from_key=from_key
) )
for room_id in room_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue([ev for res in results for ev in res]) defer.returnValue([ev for res in results.values() for ev in res])
@defer.inlineCallbacks @cachedInlineCallbacks(num_args=3, max_entries=5000)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients. """Get receipts for a single room for sending to clients.
@ -125,11 +119,70 @@ class ReceiptsStore(SQLBaseStore):
"content": content, "content": content,
}]) }])
@cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids",
num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
defer.returnValue({})
def f(txn):
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
) % (
",".join(["?"] * len(room_ids))
)
args = list(room_ids)
args.extend([from_key, to_key])
txn.execute(sql, args)
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id <= ?"
) % (
",".join(["?"] * len(room_ids))
)
args = list(room_ids)
args.append(to_key)
txn.execute(sql, args)
return self.cursor_to_dict(txn)
txn_results = yield self.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
results = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(row["room_id"], {
"type": "m.receipt",
"room_id": row["room_id"],
"content": {},
})
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = json.loads(row["data"])
results = {
room_id: [results[room_id]] if room_id in results else []
for room_id in room_ids
}
defer.returnValue(results)
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self) return self._receipts_id_gen.get_max_token(self)
@cached @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_graph_receipts_for_room(self, room_id): def get_graph_receipts_for_room(self, room_id):
"""Get receipts for sending to remote servers. """Get receipts for sending to remote servers.
""" """
@ -305,6 +358,8 @@ class _RoomStreamChangeCache(object):
self._room_to_key = {} self._room_to_key = {}
self._cache = sorteddict() self._cache = sorteddict()
self._earliest_key = None self._earliest_key = None
self.name = "ReceiptsRoomChangeCache"
caches_by_name[self.name] = self._cache
@defer.inlineCallbacks @defer.inlineCallbacks
def get_rooms_changed(self, store, room_ids, key): def get_rooms_changed(self, store, room_ids, key):
@ -318,8 +373,11 @@ class _RoomStreamChangeCache(object):
result = set( result = set(
self._cache[k] for k in keys[i:] self._cache[k] for k in keys[i:]
).intersection(room_ids) ).intersection(room_ids)
cache_counter.inc_hits(self.name)
else: else:
result = room_ids result = room_ids
cache_counter.inc_misses(self.name)
defer.returnValue(result) defer.returnValue(result)

View file

@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
class RegistrationStore(SQLBaseStore): class RegistrationStore(SQLBaseStore):
@ -111,16 +112,16 @@ class RegistrationStore(SQLBaseStore):
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
def user_delete_access_tokens_apart_from(self, user_id, token_id): def user_delete_access_tokens(self, user_id):
yield self.runInteraction( yield self.runInteraction(
"user_delete_access_tokens_apart_from", "user_delete_access_tokens",
self._user_delete_access_tokens_apart_from, user_id, token_id self._user_delete_access_tokens, user_id
) )
def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id): def _user_delete_access_tokens(self, txn, user_id):
txn.execute( txn.execute(
"DELETE FROM access_tokens WHERE user_id = ? AND id != ?", "DELETE FROM access_tokens WHERE user_id = ?",
(user_id, token_id) (user_id, )
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -131,7 +132,7 @@ class RegistrationStore(SQLBaseStore):
user_id user_id
) )
for r in rows: for r in rows:
self.get_user_by_token.invalidate(r) self.get_user_by_token.invalidate((r,))
@cached() @cached()
def get_user_by_token(self, token): def get_user_by_token(self, token):

View file

@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
import collections import collections
import logging import logging
@ -186,8 +187,7 @@ class RoomStore(SQLBaseStore):
} }
) )
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id): def get_room_name_and_aliases(self, room_id):
def f(txn): def f(txn):
sql = ( sql = (

View file

@ -17,7 +17,8 @@ from twisted.internet import defer
from collections import namedtuple from collections import namedtuple
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.types import UserID from synapse.types import UserID
@ -54,9 +55,9 @@ class RoomMemberStore(SQLBaseStore):
) )
for event in events: for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, event.state_key) txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, event.room_id) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.
@ -78,7 +79,7 @@ class RoomMemberStore(SQLBaseStore):
lambda events: events[0] if events else None lambda events: events[0] if events else None
) )
@cached() @cached(max_entries=5000)
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
def f(txn): def f(txn):
@ -154,7 +155,7 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn) RoomsForUser(**r) for r in self.cursor_to_dict(txn)
] ]
@cached() @cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id): def get_joined_hosts_for_room(self, room_id):
return self.runInteraction( return self.runInteraction(
"get_joined_hosts_for_room", "get_joined_hosts_for_room",

View file

@ -0,0 +1,18 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX receipts_linearized_room_stream ON receipts_linearized(
room_id, stream_id
);

View file

@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import (
cached, cachedInlineCallbacks, cachedList
)
from twisted.internet import defer from twisted.internet import defer
@ -44,60 +47,25 @@ class StateStore(SQLBaseStore):
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_groups(self, event_ids): def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids """ Get the state groups for the given list of event_ids
The return value is a dict mapping group names to lists of events. The return value is a dict mapping group names to lists of events.
""" """
if not event_ids:
defer.returnValue({})
def f(txn): event_to_groups = yield self._get_state_group_for_events(
groups = set() room_id, event_ids,
for event_id in event_ids:
group = self._simple_select_one_onecol_txn(
txn,
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
allow_none=True,
)
if group:
groups.add(group)
res = {}
for group in groups:
state_ids = self._simple_select_onecol_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": group},
retcol="event_id",
) )
res[group] = state_ids groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups)
return res defer.returnValue({
group: state_map.values()
states = yield self.runInteraction( for group, state_map in group_to_state.items()
"get_state_groups", })
f,
)
state_list = yield defer.gatherResults(
[
self._fetch_events_for_group(group, vals)
for group, vals in states.items()
],
consumeErrors=True,
)
defer.returnValue(dict(state_list))
@cached(num_args=1)
def _fetch_events_for_group(self, key, events):
return self._get_events(
events, get_prev_content=False
).addCallback(
lambda evs: (key, evs)
)
def _store_state_groups_txn(self, txn, event, context): def _store_state_groups_txn(self, txn, event, context):
return self._store_mult_state_groups_txn(txn, [(event, context)]) return self._store_mult_state_groups_txn(txn, [(event, context)])
@ -189,8 +157,7 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False) events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
@cached(num_args=3) @cachedInlineCallbacks(num_args=3)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key): def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn): def f(txn):
sql = ( sql = (
@ -206,64 +173,254 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False) events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
@defer.inlineCallbacks def _get_state_groups_from_groups(self, groups_and_types):
def get_state_for_events(self, room_id, event_ids): """Returns dictionary state_group -> state event ids
Args:
groups_and_types (list): list of 2-tuple (`group`, `types`)
"""
def f(txn): def f(txn):
groups = set() results = {}
event_to_group = {} for group, types in groups_and_types:
for event_id in event_ids: if types is not None:
# TODO: Remove this loop. where_clause = "AND (%s)" % (
group = self._simple_select_one_onecol_txn( " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
txn,
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
allow_none=True,
) )
if group: else:
event_to_group[event_id] = group where_clause = ""
groups.add(group)
group_to_state_ids = {} sql = (
for group in groups: "SELECT event_id FROM state_groups_state WHERE"
state_ids = self._simple_select_onecol_txn( " state_group = ? %s"
txn, ) % (where_clause,)
table="state_groups_state",
keyvalues={"state_group": group},
retcol="event_id",
)
group_to_state_ids[group] = state_ids args = [group]
if types is not None:
args.extend([i for typ in types for i in typ])
return event_to_group, group_to_state_ids txn.execute(sql, args)
res = yield self.runInteraction( results[group] = [r[0] for r in txn.fetchall()]
"annotate_events_with_state_groups",
return results
return self.runInteraction(
"_get_state_groups_from_groups",
f, f,
) )
event_to_group, group_to_state_ids = res @defer.inlineCallbacks
def get_state_for_events(self, room_id, event_ids, types):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event. The state dicts will only have the type/state_keys
that are in the `types` list.
state_list = yield defer.gatherResults( Args:
[ room_id (str)
self._fetch_events_for_group(group, vals) event_ids (list)
for group, vals in group_to_state_ids.items() types (list): List of (type, state_key) tuples which are used to
], filter the state fetched. `state_key` may be None, which matches
consumeErrors=True, any `state_key`
Returns:
deferred: A list of dicts corresponding to the event_ids given.
The dicts are mappings from (type, state_key) -> state_events
"""
event_to_groups = yield self._get_state_group_for_events(
room_id, event_ids,
) )
state_dict = { groups = set(event_to_groups.values())
group: { group_to_state = yield self._get_state_for_groups(groups, types)
(ev.type, ev.state_key): ev
for ev in state event_to_state = {
} event_id: group_to_state[group]
for group, state in state_list for event_id, group in event_to_groups.items()
} }
defer.returnValue([ defer.returnValue({event: event_to_state[event] for event in event_ids})
state_dict.get(event_to_group.get(event, None), None)
for event in event_ids @cached(num_args=2, lru=True, max_entries=10000)
]) def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={
"event_id": event_id,
},
retcol="state_group",
allow_none=True,
desc="_get_state_group_for_event",
)
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
num_args=2)
def _get_state_group_for_events(self, room_id, event_ids):
"""Returns mapping event_id -> state_group
"""
def f(txn):
results = {}
for event_id in event_ids:
results[event_id] = self._simple_select_one_onecol_txn(
txn,
table="event_to_state_groups",
keyvalues={
"event_id": event_id,
},
retcol="state_group",
allow_none=True,
)
return results
return self.runInteraction("_get_state_group_for_events", f)
def _get_some_state_from_cache(self, group, types):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
`missing_types` is the list of types that aren't in the cache for that
group. `got_all` is a bool indicating if we successfully retrieved all
requests state from the cache, if False we need to query the DB for the
missing state.
Args:
group: The state group to lookup
types (list): List of 2-tuples of the form (`type`, `state_key`),
where a `state_key` of `None` matches all state_keys for the
`type`.
"""
is_all, state_dict = self._state_group_cache.get(group)
type_to_key = {}
missing_types = set()
for typ, state_key in types:
if state_key is None:
type_to_key[typ] = None
missing_types.add((typ, state_key))
else:
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
if (typ, state_key) not in state_dict:
missing_types.add((typ, state_key))
sentinel = object()
def include(typ, state_key):
valid_state_keys = type_to_key.get(typ, sentinel)
if valid_state_keys is sentinel:
return False
if valid_state_keys is None:
return True
if state_key in valid_state_keys:
return True
return False
got_all = not (missing_types or types is None)
return {
k: v for k, v in state_dict.items()
if include(k[0], k[1])
}, missing_types, got_all
def _get_all_state_from_cache(self, group):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
indicating if we successfully retrieved all requests state from the
cache, if False we need to query the DB for the missing state.
Args:
group: The state group to lookup
"""
is_all, state_dict = self._state_group_cache.get(group)
return state_dict, is_all
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
"""Given list of groups returns dict of group -> list of state events
with matching types. `types` is a list of `(type, state_key)`, where
a `state_key` of None matches all state_keys. If `types` is None then
all events are returned.
"""
results = {}
missing_groups_and_types = []
if types is not None:
for group in set(groups):
state_dict, missing_types, got_all = self._get_some_state_from_cache(
group, types
)
results[group] = state_dict
if not got_all:
missing_groups_and_types.append((group, missing_types))
else:
for group in set(groups):
state_dict, got_all = self._get_all_state_from_cache(
group
)
results[group] = state_dict
if not got_all:
missing_groups_and_types.append((group, None))
if not missing_groups_and_types:
defer.returnValue({
group: {
type_tuple: event
for type_tuple, event in state.items()
if event
}
for group, state in results.items()
})
# Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence
group_state_dict = yield self._get_state_groups_from_groups(
missing_groups_and_types
)
state_events = yield self._get_events(
[e_id for l in group_state_dict.values() for e_id in l],
get_prev_content=False
)
state_events = {e.event_id: e for e in state_events}
# Now we want to update the cache with all the things we fetched
# from the database.
for group, state_ids in group_state_dict.items():
if types:
# We delibrately put key -> None mappings into the cache to
# cache absence of the key, on the assumption that if we've
# explicitly asked for some types then we will probably ask
# for them again.
state_dict = {key: None for key in types}
state_dict.update(results[group])
results[group] = state_dict
else:
state_dict = results[group]
for event_id in state_ids:
state_event = state_events[event_id]
state_dict[(state_event.type, state_event.state_key)] = state_event
self._state_group_cache.update(
cache_seq_num,
key=group,
value=state_dict,
full=(types is None),
)
# Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache.
for group, state_dict in results.items():
results[group] = {
key: event for key, event in state_dict.items() if event
}
defer.returnValue(results)
def _make_group_id(clock): def _make_group_id(clock):

View file

@ -36,6 +36,7 @@ what sort order was used:
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -299,9 +300,8 @@ class StreamStore(SQLBaseStore):
defer.returnValue((events, token)) defer.returnValue((events, token))
@defer.inlineCallbacks @cachedInlineCallbacks(num_args=4)
def get_recent_events_for_room(self, room_id, limit, end_token, def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
with_feedback=False, from_token=None):
# TODO (erikj): Handle compressed feedback # TODO (erikj): Handle compressed feedback
end_token = RoomStreamToken.parse_stream_token(end_token) end_token = RoomStreamToken.parse_stream_token(end_token)

View file

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from collections import namedtuple from collections import namedtuple

View file

@ -178,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
Live tokens start with an "s" followed by the "stream_ordering" id of the Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, follewed by "-", "topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after. followed by the "stream_ordering" id of the event it comes after.
""" """
__slots__ = [] __slots__ = []
@ -211,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# token_id is the primary key ID of the access token, not the access token itself.
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))

View file

@ -51,7 +51,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_observers", set()) object.__setattr__(self, "_observers", set())
def callback(r): def callback(r):
self._result = (True, r) object.__setattr__(self, "_result", (True, r))
while self._observers: while self._observers:
try: try:
self._observers.pop().callback(r) self._observers.pop().callback(r)
@ -60,7 +60,7 @@ class ObservableDeferred(object):
return r return r
def errback(f): def errback(f):
self._result = (False, f) object.__setattr__(self, "_result", (False, f))
while self._observers: while self._observers:
try: try:
self._observers.pop().errback(f) self._observers.pop().errback(f)
@ -97,3 +97,8 @@ class ObservableDeferred(object):
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(self._deferred, name, value) setattr(self._deferred, name, value)
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)

View file

@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.metrics
DEBUG_CACHES = False
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
caches_by_name = {}
cache_counter = metrics.register_cache(
"cache",
lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
labels=["name"],
)

View file

@ -0,0 +1,377 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from . import caches_by_name, DEBUG_CACHES, cache_counter
from twisted.internet import defer
from collections import OrderedDict
import functools
import inspect
import threading
logger = logging.getLogger(__name__)
_CacheSentinel = object()
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name
self.keylen = keylen
self.sequence = 0
self.thread = None
caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
return val
cache_counter.inc_misses(self.name)
if default is _CacheSentinel:
raise KeyError()
else:
return default
def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value)
def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[key] = value
def invalidate(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
"invalidate". This can be used to remove individual entries from the cache.
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
inlineCallbacks=False):
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
self.cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
cached_result_d = self.cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@defer.inlineCallbacks
def check_result(cached_result):
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, cache_key,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = self.cache.sequence
ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
def onErr(f):
self.cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
return ret.observe()
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
class CacheListDescriptor(object):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped fucntion.
"""
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
"""
Args:
orig (function)
cache (Cache)
list_name (str): Name of the argument which is the bulk lookup list
num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should
be wrapped by defer.inlineCallbacks
"""
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.num_args = num_args
self.list_name = list_name
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
self.sentinel = object()
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
% (self.list_name, cache.name,)
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
cached = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
try:
res = self.cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
except KeyError:
missing.append(arg)
if missing:
sequence = self.cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
self.function_to_call,
**args_to_call
)
ret_d = ObservableDeferred(ret_d)
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
key = list(keyargs)
key[self.list_pos] = arg
self.cache.update(sequence, tuple(key), observer)
def invalidate(f, key):
self.cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
return defer.gatherResults(
cached.values(),
consumeErrors=True,
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
def cached(max_entries=1000, num_args=1, lru=True):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
inlineCallbacks=True,
)
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
is specified as a list that is iterated through to lookup keys in the
original cache. A new list consisting of the keys that weren't in the cache
get passed to the original function, the result of which is stored in the
cache.
Args:
cache (Cache): The underlying cache to use.
list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache.
inlineCallbacks (bool): Should the function be wrapped in an
`defer.inlineCallbacks`?
Example:
class Example(object):
@cached(num_args=2)
def do_something(self, first_arg):
...
@cachedList(do_something.cache, list_name="second_args", num_args=2)
def batch_do_something(self, first_arg, second_args):
...
"""
return lambda orig: CacheListDescriptor(
orig,
cache=cache,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
)

View file

@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.caches.lrucache import LruCache
from collections import namedtuple
from . import caches_by_name, cache_counter
import threading
import logging
logger = logging.getLogger(__name__)
DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
class DictionaryCache(object):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries)
self.name = name
self.sequence = 0
self.thread = None
# caches_by_name[name] = self.cache
class Sentinel(object):
__slots__ = []
self.sentinel = Sentinel()
caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, key, dict_keys=None):
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
cache_counter.inc_hits(self.name)
if dict_keys is None:
return DictionaryEntry(entry.full, dict(entry.value))
else:
return DictionaryEntry(entry.full, {
k: entry.value[k]
for k in dict_keys
if k in entry.value
})
cache_counter.inc_misses(self.name)
return DictionaryEntry(False, {})
def invalidate(self, key):
self.check_thread()
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
def update(self, sequence, key, value, full=False):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
if full:
self._insert(key, value)
else:
self._update_or_insert(key, value)
def _update_or_insert(self, key, value):
entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
entry.value.update(value)
def _insert(self, key, value):
self.cache[key] = DictionaryEntry(True, value)

View file

@ -17,7 +17,9 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import Cache, cached from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached
class CacheTestCase(unittest.TestCase): class CacheTestCase(unittest.TestCase):
@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
self.assertEquals(self.cache.get("foo"), 123) self.assertEquals(self.cache.get("foo"), 123)
def test_invalidate(self): def test_invalidate(self):
self.cache.prefill("foo", 123) self.cache.prefill(("foo",), 123)
self.cache.invalidate("foo") self.cache.invalidate(("foo",))
failed = False failed = False
try: try:
self.cache.get("foo") self.cache.get(("foo",))
except KeyError: except KeyError:
failed = True failed = True
@ -139,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
a.func.invalidate("foo") a.func.invalidate(("foo",))
yield a.func("foo") yield a.func("foo")
@ -151,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
def func(self, key): def func(self, key):
return key return key
A().func.invalidate("what") A().func.invalidate(("what",))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_max_entries(self): def test_max_entries(self):
@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertTrue(callcount[0] >= 14, self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0])) msg="Expected callcount >= 14, got %d" % (callcount[0]))
@defer.inlineCallbacks
def test_prefill(self): def test_prefill(self):
callcount = [0] callcount = [0]
d = defer.succeed(123)
class A(object): class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return d
a = A() a = A()
a.func.prefill("foo", 123) a.func.prefill(("foo",), ObservableDeferred(d))
self.assertEquals((yield a.func("foo")), 123) self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)

View file

@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase):
yield d yield d
self.assertTrue(d.called) self.assertTrue(d.called)
observers[0].assert_called_once("Go") observers[0].assert_called_once_with("Go")
observers[1].assert_called_once("Go") observers[1].assert_called_once_with("Go")
self.assertEquals(mock_logger.warning.call_count, 1) self.assertEquals(mock_logger.warning.call_count, 1)
self.assertIsInstance(mock_logger.warning.call_args[0][0], self.assertIsInstance(mock_logger.warning.call_args[0][0],

View file

@ -69,7 +69,7 @@ class StateGroupStore(object):
self._next_group = 1 self._next_group = 1
def get_state_groups(self, event_ids): def get_state_groups(self, room_id, event_ids):
groups = {} groups = {}
for event_id in event_ids: for event_id in event_ids:
group = self._event_to_state_group.get(event_id) group = self._event_to_state_group.get(event_id)

View file

@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from tests import unittest
from synapse.util.caches.dictionary_cache import DictionaryCache
class DictCacheTestCase(unittest.TestCase):
def setUp(self):
self.cache = DictionaryCache("foobar")
def test_simple_cache_hit_full(self):
key = "test_simple_cache_hit_full"
v = self.cache.get(key)
self.assertEqual((False, {}), v)
seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"}
self.cache.update(seq, key, test_value, full=True)
c = self.cache.get(key)
self.assertEqual(test_value, c.value)
def test_simple_cache_hit_partial(self):
key = "test_simple_cache_hit_partial"
seq = self.cache.sequence
test_value = {
"test": "test_simple_cache_hit_partial"
}
self.cache.update(seq, key, test_value, full=True)
c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value)
def test_simple_cache_miss_partial(self):
key = "test_simple_cache_miss_partial"
seq = self.cache.sequence
test_value = {
"test": "test_simple_cache_miss_partial"
}
self.cache.update(seq, key, test_value, full=True)
c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value)
def test_simple_cache_hit_miss_partial(self):
key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence
test_value = {
"test": "test_simple_cache_hit_miss_partial",
"test2": "test_simple_cache_hit_miss_partial2",
"test3": "test_simple_cache_hit_miss_partial3",
}
self.cache.update(seq, key, test_value, full=True)
c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
def test_multi_insert(self):
key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence
test_value_1 = {
"test": "test_simple_cache_hit_miss_partial",
}
self.cache.update(seq, key, test_value_1, full=False)
seq = self.cache.sequence
test_value_2 = {
"test2": "test_simple_cache_hit_miss_partial2",
}
self.cache.update(seq, key, test_value_2, full=False)
c = self.cache.get(key)
self.assertEqual(
{
"test": "test_simple_cache_hit_miss_partial",
"test2": "test_simple_cache_hit_miss_partial2",
},
c.value
)

View file

@ -16,7 +16,7 @@
from .. import unittest from .. import unittest
from synapse.util.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
class LruCacheTestCase(unittest.TestCase): class LruCacheTestCase(unittest.TestCase):
@ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1 cache["key"] = 1
self.assertEquals(cache.pop("key"), 1) self.assertEquals(cache.pop("key"), 1)
self.assertEquals(cache.pop("key"), None) self.assertEquals(cache.pop("key"), None)