0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 20:43:52 +01:00

Merge remote-tracking branch 'origin/develop' into markjh/end-to-end-key-federation

This commit is contained in:
Mark Haines 2015-08-13 17:27:53 +01:00
commit c5966b2a97
35 changed files with 616 additions and 370 deletions

View file

@ -101,25 +101,26 @@ header files for python C extensions.
Installing prerequisites on Ubuntu or Debian::
$ sudo apt-get install build-essential python2.7-dev libffi-dev \
sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev
Installing prerequisites on ArchLinux::
$ sudo pacman -S base-devel python2 python-pip \
sudo pacman -S base-devel python2 python-pip \
python-setuptools python-virtualenv sqlite3
Installing prerequisites on Mac OS X::
$ xcode-select --install
$ sudo pip install virtualenv
xcode-select --install
sudo easy_install pip
sudo pip install virtualenv
To install the synapse homeserver run::
$ virtualenv -p python2.7 ~/.synapse
$ source ~/.synapse/bin/activate
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
virtualenv -p python2.7 ~/.synapse
source ~/.synapse/bin/activate
pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
This installs synapse, along with the libraries it uses, into a virtual
environment under ``~/.synapse``. Feel free to pick a different directory
@ -132,8 +133,8 @@ above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
To set up your homeserver, run (in your virtualenv, as before)::
$ cd ~/.synapse
$ python -m synapse.app.homeserver \
cd ~/.synapse
python -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
@ -192,9 +193,9 @@ Running Synapse
To actually run your new homeserver, pick a working directory for Synapse to run
(e.g. ``~/.synapse``), and::
$ cd ~/.synapse
$ source ./bin/activate
$ synctl start
cd ~/.synapse
source ./bin/activate
synctl start
Platform Specific Instructions
==============================
@ -212,12 +213,12 @@ defaults to python 3, but synapse currently assumes python 2.7 by default:
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
$ sudo pip2.7 install --upgrade pip
sudo pip2.7 install --upgrade pip
You also may need to explicitly specify python 2.7 again during the install
request::
$ pip2.7 install --process-dependency-links \
pip2.7 install --process-dependency-links \
https://github.com/matrix-org/synapse/tarball/master
If you encounter an error with lib bcrypt causing an Wrong ELF Class:
@ -225,13 +226,13 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
compile it under the right architecture. (This should not be needed if
installing under virtualenv)::
$ sudo pip2.7 uninstall py-bcrypt
$ sudo pip2.7 install py-bcrypt
sudo pip2.7 uninstall py-bcrypt
sudo pip2.7 install py-bcrypt
During setup of Synapse you need to call python2.7 directly again::
$ cd ~/.synapse
$ python2.7 -m synapse.app.homeserver \
cd ~/.synapse
python2.7 -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
@ -279,22 +280,22 @@ Synapse requires pip 1.7 or later, so if your OS provides too old a version and
you get errors about ``error: no such option: --process-dependency-links`` you
may need to manually upgrade it::
$ sudo pip install --upgrade pip
sudo pip install --upgrade pip
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it
created. To reset the installation::
$ rm -rf /tmp/pip_install_matrix
rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.::
$ pip install twisted
pip install twisted
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
On OS X, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments.
Troubleshooting Running
@ -310,10 +311,11 @@ correctly, causing all tests to fail with errors about missing "sodium.h". To
fix try re-installing from PyPI or directly from
(https://github.com/pyca/pynacl)::
$ # Install from PyPI
$ pip install --user --upgrade --force pynacl
$ # Install from github
$ pip install --user https://github.com/pyca/pynacl/tarball/master
# Install from PyPI
pip install --user --upgrade --force pynacl
# Install from github
pip install --user https://github.com/pyca/pynacl/tarball/master
ArchLinux
~~~~~~~~~
@ -321,7 +323,7 @@ ArchLinux
If running `$ synctl start` fails with 'returned non-zero exit status 1',
you will need to explicitly call Python2.7 - either running as::
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
...or by editing synctl with the correct python executable.
@ -331,16 +333,16 @@ Synapse Development
To check out a synapse for development, clone the git repo into a working
directory of your choice::
$ git clone https://github.com/matrix-org/synapse.git
$ cd synapse
git clone https://github.com/matrix-org/synapse.git
cd synapse
Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv::
$ virtualenv env
$ source env/bin/activate
$ python synapse/python_dependencies.py | xargs -n1 pip install
$ pip install setuptools_trial mock
virtualenv env
source env/bin/activate
python synapse/python_dependencies.py | xargs -n1 pip install
pip install setuptools_trial mock
This will run a process of downloading and installing all the needed
dependencies into a virtual env.
@ -348,7 +350,7 @@ dependencies into a virtual env.
Once this is done, you may wish to run Synapse's unit tests, to
check that everything is installed as it should be::
$ python setup.py test
python setup.py test
This should end with a 'PASSED' result::
@ -389,11 +391,11 @@ IDs:
For the first form, simply pass the required hostname (of the machine) as the
--server-name parameter::
$ python -m synapse.app.homeserver \
python -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml
python -m synapse.app.homeserver --config-path homeserver.yaml
Alternatively, you can run ``synctl start`` to guide you through the process.
@ -410,11 +412,11 @@ record would then look something like::
At this point, you should then run the homeserver with the hostname of this
SRV record, as that is the name other machines will expect it to have::
$ python -m synapse.app.homeserver \
python -m synapse.app.homeserver \
--server-name YOURDOMAIN \
--config-path homeserver.yaml \
--generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml
python -m synapse.app.homeserver --config-path homeserver.yaml
You may additionally want to pass one or more "-v" options, in order to
@ -428,7 +430,7 @@ private federation (``localhost:8080``, ``localhost:8081`` and
``localhost:8082``) which you can then access through the webclient running at
http://localhost:8080. Simply run::
$ demo/start.sh
demo/start.sh
This is mainly useful just for development purposes.
@ -502,10 +504,10 @@ Building Internal API Documentation
Before building internal API documentation install sphinx and
sphinxcontrib-napoleon::
$ pip install sphinx
$ pip install sphinxcontrib-napoleon
pip install sphinx
pip install sphinxcontrib-napoleon
Building internal API documentation::
$ python setup.py build_sphinx
python setup.py build_sphinx

View file

@ -11,7 +11,9 @@ if [ -f $PID_FILE ]; then
exit 1
fi
find "$DIR" -name "*.log" -delete
find "$DIR" -name "*.db" -delete
for port in 8080 8081 8082; do
rm -rf $DIR/$port
rm -rf $DIR/media_store.$port
done
rm -rf $DIR/etc

View file

@ -8,14 +8,6 @@ cd "$DIR/.."
mkdir -p demo/etc
# Check the --no-rate-limit param
PARAMS=""
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
fi
fi
export PYTHONPATH=$(readlink -f $(pwd))
@ -31,10 +23,20 @@ for port in 8080 8081 8082; do
#rm $DIR/etc/$port.config
python -m synapse.app.homeserver \
--generate-config \
--enable_registration \
-H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \
# Check script parameters
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
# Set high limits in config file to disable rate limiting
perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
fi
fi
perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
python -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \
-D \

View file

@ -16,3 +16,6 @@ ignore =
docs/*
pylint.cfg
tox.ini
[flake8]
max-line-length = 90

View file

@ -48,7 +48,7 @@ setup(
description="Reference Synapse Home Server",
install_requires=dependencies['requirements'](include_conditional=True).keys(),
setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"Twisted>=15.1.0", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial",
"mock"
],

View file

@ -44,6 +44,11 @@ class Auth(object):
def check(self, event, auth_events):
""" Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns:
True if the auth checks pass.
"""
@ -319,7 +324,7 @@ class Auth(object):
Returns:
tuple : of UserID and device string:
User ID object of the user making the request
Client ID object of the client instance the user is using
ClientInfo object of the client instance the user is using
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@ -352,7 +357,7 @@ class Auth(object):
)
return
except KeyError:
pass # normal users won't have this query parameter set
pass # normal users won't have the user_id query parameter set.
user_info = yield self.get_user_by_token(access_token)
user = user_info["user"]
@ -521,7 +526,6 @@ class Auth(object):
# Check state_key
if hasattr(event, "state_key"):
if not event.state_key.startswith("_"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(

View file

@ -657,6 +657,7 @@ def run(hs):
if hs.config.daemonize:
if hs.config.print_pidfile:
print hs.config.pid_file
daemon = Daemonize(

View file

@ -138,12 +138,19 @@ class Config(object):
action="store_true",
help="Generate a config file for the server name"
)
config_parser.add_argument(
"--generate-keys",
action="store_true",
help="Generate any missing key files then exit"
)
config_parser.add_argument(
"-H", "--server-name",
help="The server name to generate a config file for"
)
config_args, remaining_args = config_parser.parse_known_args(argv)
generate_keys = config_args.generate_keys
if config_args.generate_config:
if not config_args.config_path:
config_parser.error(
@ -151,34 +158,17 @@ class Config(object):
" generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\""
)
config_dir_path = os.path.dirname(config_args.config_path[0])
(config_path,) = config_args.config_path
if not os.path.exists(config_path):
config_dir_path = os.path.dirname(config_path)
config_dir_path = os.path.abspath(config_dir_path)
server_name = config_args.server_name
if not server_name:
print "Must specify a server_name to a generate config for."
sys.exit(1)
(config_path,) = config_args.config_path
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
if os.path.exists(config_path):
print "Config file %r already exists" % (config_path,)
yaml_config = cls.read_config_file(config_path)
yaml_name = yaml_config["server_name"]
if server_name != yaml_name:
print (
"Config file %r has a different server_name: "
" %r != %r" % (config_path, server_name, yaml_name)
)
sys.exit(1)
config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
config.update(yaml_config)
print "Generating any missing keys for %r" % (server_name,)
obj.invoke_all("generate_files", config)
sys.exit(0)
with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config(
config_dir_path, server_name
@ -186,16 +176,22 @@ class Config(object):
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
print (
"A config file has been generated in %s for server name"
" '%s' with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it to"
" your needs."
"A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it"
" to your needs."
) % (config_path, server_name)
print (
"If this server name is incorrect, you will need to regenerate"
" the SSL certificates"
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
sys.exit(0)
else:
print (
"Config file %r already exists. Generating any missing key"
" files."
) % (config_path,)
generate_keys = True
parser = argparse.ArgumentParser(
parents=[config_parser],
@ -213,7 +209,7 @@ class Config(object):
" -c CONFIG-FILE\""
)
config_dir_path = os.path.dirname(config_args.config_path[0])
config_dir_path = os.path.dirname(config_args.config_path[-1])
config_dir_path = os.path.abspath(config_dir_path)
specified_config = {}
@ -226,6 +222,10 @@ class Config(object):
config.pop("log_config")
config.update(specified_config)
if generate_keys:
obj.invoke_all("generate_files", config)
sys.exit(0)
obj.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args)

View file

@ -24,6 +24,7 @@ class ServerConfig(Config):
self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.listeners = config.get("listeners", [])
@ -208,12 +209,18 @@ class ServerConfig(Config):
self.manhole = args.manhole
if args.daemonize is not None:
self.daemonize = args.daemonize
if args.print_pidfile is not None:
self.print_pidfile = args.print_pidfile
def add_arguments(self, parser):
server_group = parser.add_argument_group("server")
server_group.add_argument("-D", "--daemonize", action='store_true',
default=None,
help="Daemonize the home server")
server_group.add_argument("--print-pidfile", action='store_true',
default=None,
help="Print the path to the pidfile just"
" before daemonizing")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int,
help="Turn on the twisted telnet manhole"

View file

@ -44,7 +44,7 @@ class IdentityHandler(BaseHandler):
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org']
trustedIdServers = ['matrix.org', 'vector.im']
if 'id_server' in creds:
id_server = creds['id_server']

View file

@ -73,7 +73,8 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
password (str) : The password to assign to this user so they can
login again.
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
Returns:
A tuple of (user_id, access_token).
Raises:

View file

@ -16,7 +16,7 @@
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
from twisted.web.client import readBody, HTTPConnectionPool, Agent
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
@ -55,41 +55,17 @@ incoming_responses_counter = metrics.register_counter(
)
class MatrixFederationHttpAgent(_AgentBase):
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.tls_context_factory = hs.tls_context_factory
def __init__(self, reactor, pool=None):
_AgentBase.__init__(self, reactor, pool)
def endpointForURI(self, uri):
destination = uri.netloc
def request(self, destination, endpoint, method, path, params, query,
headers, body_producer):
outgoing_requests_counter.inc(method)
host = b""
port = 0
fragment = b""
parsed_URI = _URI(b"http", destination, host, port, path, params,
query, fragment)
# Set the connection pool key to be the destination.
key = destination
d = self._requestWithEndpoint(key, endpoint, method, parsed_URI,
headers, body_producer,
parsed_URI.originForm)
def _cb(response):
incoming_responses_counter.inc(method, response.code)
return response
def _eb(failure):
incoming_responses_counter.inc(method, "ERR")
return failure
d.addCallbacks(_cb, _eb)
return d
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.tls_context_factory
)
class MatrixFederationHttpClient(object):
@ -107,12 +83,18 @@ class MatrixFederationHttpClient(object):
self.server_name = hs.hostname
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
self.agent = Agent.usingEndpointFactory(
reactor, MatrixFederationEndpointFactory(hs), pool=pool
)
self.clock = hs.get_clock()
self.version_string = hs.version_string
self._next_id = 1
def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
return urlparse.urlunparse(
("matrix", destination, path_bytes, param_bytes, query_bytes, "")
)
@defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"",
@ -123,8 +105,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination]
url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "",)
url_bytes = self._create_url(
destination, path_bytes, param_bytes, query_bytes
)
txn_id = "%s-O-%s" % (method, self._next_id)
@ -139,8 +121,8 @@ class MatrixFederationHttpClient(object):
# (once we have reliable transactions in place)
retries_left = 5
endpoint = preserve_context_over_fn(
self._getEndpoint, reactor, destination
http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "")
)
log_result = None
@ -148,17 +130,14 @@ class MatrixFederationHttpClient(object):
while True:
producer = None
if body_callback:
producer = body_callback(method, url_bytes, headers_dict)
producer = body_callback(method, http_url_bytes, headers_dict)
try:
def send_request():
request_deferred = self.agent.request(
destination,
endpoint,
request_deferred = preserve_context_over_fn(
self.agent.request,
method,
path_bytes,
param_bytes,
query_bytes,
url_bytes,
Headers(headers_dict),
producer
)
@ -452,12 +431,6 @@ class MatrixFederationHttpClient(object):
defer.returnValue((length, headers))
def _getEndpoint(self, reactor, destination):
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.hs.tls_context_factory
)
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):

View file

@ -18,8 +18,12 @@ from __future__ import absolute_import
import logging
from resource import getrusage, getpagesize, RUSAGE_SELF
import functools
import os
import stat
import time
from twisted.internet import reactor
from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
@ -144,3 +148,28 @@ def _process_fds():
return counts
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
reactor_metrics = get_metrics_for("reactor")
tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
def runUntilCurrentTimer(func):
@functools.wraps(func)
def f(*args, **kwargs):
pending_calls = len(reactor.getDelayedCalls())
start = time.time() * 1000
ret = func(*args, **kwargs)
end = time.time() * 1000
tick_time.inc_by(end - start)
pending_calls_metric.inc_by(pending_calls)
return ret
return f
if hasattr(reactor, "runUntilCurrent"):
# runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)

View file

@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = {
"syutil>=0.0.7": ["syutil>=0.0.7"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"],

View file

@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_request_allow_empty
from ._base import client_v2_pattern, parse_json_dict_from_request
import logging
import hmac
@ -55,30 +55,55 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_dict_from_request(request)
body = parse_request_allow_empty(request)
# we do basic sanity checks here because the auth
# layer will store these in sessions
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
desired_password = None
if 'password' in body:
if ((not isinstance(body['password'], str) and
not isinstance(body['password'], unicode)) or
if (not isinstance(body['password'], basestring) or
len(body['password']) > 512):
raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
if 'username' in body:
if ((not isinstance(body['username'], str) and
not isinstance(body['username'], unicode)) or
if (not isinstance(body['username'], basestring) or
len(body['username']) > 512):
raise SynapseError(400, "Invalid username")
desired_username = body['username']
yield self.registration_handler.check_username(desired_username)
is_using_shared_secret = False
is_application_server = False
service = None
appservice = None
if 'access_token' in request.args:
service = yield self.auth.get_appservice_by_req(request)
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
# have completely different registration flows to normal users
# == Application Service Registration ==
if appservice:
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0]
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
# == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body:
# FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration(
desired_username, desired_password, body["mac"]
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
# == Normal User Registration == (everyone else)
if self.hs.config.disable_registration:
raise SynapseError(403, "Registration has been disabled")
if desired_username is not None:
yield self.registration_handler.check_username(desired_username)
if self.hs.config.enable_registration_captcha:
flows = [
@ -91,39 +116,20 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY]
]
result = None
if service:
is_application_server = True
params = body
elif 'mac' in body:
# Check registration-specific shared secret auth
if 'username' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
self._check_shared_secret_auth(
body['username'], body['mac']
)
is_using_shared_secret = True
params = body
else:
authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
defer.returnValue((401, result))
return
can_register = (
not self.hs.config.disable_registration
or is_application_server
or is_using_shared_secret
)
if not can_register:
raise SynapseError(403, "Registration has been disabled")
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
desired_username = params['username'] if 'username' in params else None
new_password = params['password']
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
desired_username = params.get("username", None)
new_password = params.get("password", None)
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
@ -156,18 +162,21 @@ class RegisterRestServlet(RestServlet):
else:
logger.info("bind_email not specified: not binding email")
result = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
result = self._create_registration_details(user_id, token)
defer.returnValue((200, result))
def on_OPTIONS(self, _):
return 200, {}
def _check_shared_secret_auth(self, username, mac):
@defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token):
(user_id, token) = yield self.registration_handler.appservice_register(
username, as_token
)
defer.returnValue(self._create_registration_details(user_id, token))
@defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, mac):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@ -183,13 +192,23 @@ class RegisterRestServlet(RestServlet):
digestmod=sha1,
).hexdigest()
if compare_digest(want_mac, got_mac):
return True
else:
if not compare_digest(want_mac, got_mac):
raise SynapseError(
403, "HMAC incorrect",
)
(user_id, token) = yield self.registration_handler.register(
localpart=username, password=password
)
defer.returnValue(self._create_registration_details(user_id, token))
def _create_registration_details(self, user_id, token):
return {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)

View file

@ -244,6 +244,9 @@ class BaseMediaResource(Resource):
)
return
local_thumbnails = []
def generate_thumbnails():
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
@ -262,9 +265,10 @@ class BaseMediaResource(Resource):
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail(
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
)
))
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
@ -278,9 +282,14 @@ class BaseMediaResource(Resource):
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail(
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
)
))
yield threads.deferToThread(generate_thumbnails)
for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l)
defer.returnValue({
"width": m_width,

View file

@ -162,11 +162,12 @@ class ThumbnailResource(BaseMediaResource):
t_method = info["thumbnail_method"]
if t_method == "scale" or t_method == "crop":
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
info_list.append((
aspect_quality, size_quality, type_quality,
aspect_quality, min_quality, size_quality, type_quality,
length_quality, info
))
if info_list:

View file

@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
key = (user.to_string(), access_token, device_id, ip)
try:
last_seen = self.client_ip_last_seen.get(*key)
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None)
self.client_ip_last_seen.prefill(*key + (now,))
self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint,
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
@ -354,6 +354,11 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
)
logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
pass
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path)

View file

@ -15,6 +15,7 @@
import logging
from synapse.api.errors import StoreError
from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
@ -27,6 +28,7 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict
import functools
import inspect
import sys
import time
import threading
@ -55,9 +57,12 @@ cache_counter = metrics.register_cache(
)
_CacheSentinel = object()
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=False):
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
@ -81,45 +86,44 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
def get(self, *keyargs):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
if keyargs in self.cache:
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
return self.cache[keyargs]
return val
cache_counter.inc_misses(self.name)
raise KeyError()
def update(self, sequence, *args):
if default is _CacheSentinel:
raise KeyError()
else:
return default
def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(*args)
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
self.prefill(key, value)
def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[keyargs] = value
self.cache[key] = value
def invalidate(self, *keyargs):
def invalidate(self, key):
self.check_thread()
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
if not isinstance(key, tuple):
raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(keyargs, None)
self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
@ -130,6 +134,9 @@ class Cache(object):
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
@ -141,58 +148,92 @@ class CacheDescriptor(object):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
inlineCallbacks=False):
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
def __get__(self, obj, objtype=None):
cache = Cache(
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
self.cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
@defer.inlineCallbacks
def wrapped(*keyargs):
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
cached_result = cache.get(*keyargs[:self.num_args])
cached_result_d = self.cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
actual_result = yield self.orig(obj, *keyargs)
@defer.inlineCallbacks
def check_result(cached_result):
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, keyargs,
self.orig.__name__, cache_key,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
sequence = self.cache.sequence
ret = yield self.orig(obj, *keyargs)
ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
cache.update(sequence, *keyargs[:self.num_args] + (ret,))
def onErr(f):
self.cache.invalidate(cache_key)
return f
defer.returnValue(ret)
ret.addErrback(onErr)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
return ret.observe()
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
def cached(max_entries=1000, num_args=1, lru=False):
def cached(max_entries=1000, num_args=1, lru=True):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@ -201,6 +242,16 @@ def cached(max_entries=1000, num_args=1, lru=False):
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
inlineCallbacks=True,
)
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()

View file

@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
},
desc="create_room_alias_association",
)
self.get_aliases_for_room.invalidate(room_id)
self.get_aliases_for_room.invalidate((room_id,))
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
room_alias,
)
self.get_aliases_for_room.invalidate(room_id)
self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias):

View file

@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore):
for room_id in events_by_room:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
def get_backfill_events(self, room_id, event_list, limit):
@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))

View file

@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore):
if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn(
@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore):
if not context.rejected:
txn.call_after(
self.get_current_state_for_key.invalidate,
event.room_id, event.type, event.state_key
(event.room_id, event.type, event.state_key,)
)
if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn.call_after(
self.get_room_name_and_aliases.invalidate,
event.room_id
(event.room_id,)
)
self._simple_upsert_txn(
@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted,
get_prev_content)
self._get_event_cache.invalidate(
(event_id, check_redacted, get_prev_content)
)
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore):
for event_id in events:
try:
ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content
(event_id, check_redacted, get_prev_content,)
)
if allow_rejected or not ret.rejected_reason:
@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev
(ev.event_id, check_redacted, get_prev_content), ev
)
defer.returnValue(ev)
@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev
(ev.event_id, check_redacted, get_prev_content), ev
)
return ev

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from _base import SQLBaseStore, cached
from _base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate",
)
@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list(
table="server_signature_keys",
@ -132,7 +131,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
self.get_all_server_verify_keys.invalidate(server_name)
self.get_all_server_verify_keys.invalidate((server_name,))
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):

View file

@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"accepted": True},
desc="set_presence_list_accepted",
)
self.get_presence_list_accepted.invalidate(observer_localpart)
self.get_presence_list_accepted.invalidate((observer_localpart,))
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
"observed_user_id": observed_userid},
desc="del_presence_list",
)
self.get_presence_list_accepted.invalidate(observer_localpart)
self.get_presence_list_accepted.invalidate((observer_localpart,))

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
import logging
@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list(
table=PushRuleTable.table_name,
@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows)
@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name,
@ -153,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
self._simple_insert_txn(
@ -189,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority'] = new_prio
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
self._simple_insert_txn(
@ -218,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
desc="delete_push_rule",
)
self.get_push_rules_for_user.invalidate(user_name)
self.get_push_rules_enabled_for_user.invalidate(user_name)
self.get_push_rules_for_user.invalidate((user_name,))
self.get_push_rules_enabled_for_user.invalidate((user_name,))
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
@ -240,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
{'id': new_id},
)
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self)
@cached
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_graph_receipts_for_room(self, room_id):
"""Get receipts for sending to remote servers.
"""

View file

@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
user_id
)
for r in rows:
self.get_user_by_token.invalidate(r)
self.get_user_by_token.invalidate((r,))
@cached()
def get_user_by_token(self, token):

View file

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cachedInlineCallbacks
import collections
import logging
@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
}
)
@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id):
def f(txn):
sql = (

View file

@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore):
)
for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, event.state_key)
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@ -78,7 +78,7 @@ class RoomMemberStore(SQLBaseStore):
lambda events: events[0] if events else None
)
@cached()
@cached(max_entries=5000)
def get_users_in_room(self, room_id):
def f(txn):
@ -154,7 +154,7 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
@cached()
@cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):
return self.runInteraction(
"get_joined_hosts_for_room",

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@ -91,7 +91,6 @@ class StateStore(SQLBaseStore):
defer.returnValue(dict(state_list))
@cached(num_args=1)
def _fetch_events_for_group(self, key, events):
return self._get_events(
events, get_prev_content=False
@ -189,8 +188,7 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@cached(num_args=3)
@defer.inlineCallbacks
@cachedInlineCallbacks(num_args=3)
def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (

View file

@ -178,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, follewed by "-",
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = []
@ -211,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "s%d" % (self.stream,)
# token_id is the primary key ID of the access token, not the access token itself.
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))

View file

@ -51,7 +51,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_observers", set())
def callback(r):
self._result = (True, r)
object.__setattr__(self, "_result", (True, r))
while self._observers:
try:
self._observers.pop().callback(r)
@ -60,7 +60,7 @@ class ObservableDeferred(object):
return r
def errback(f):
self._result = (False, f)
object.__setattr__(self, "_result", (False, f))
while self._observers:
try:
self._observers.pop().errback(f)
@ -97,3 +97,8 @@ class ObservableDeferred(object):
def __setattr__(self, name, value):
setattr(self._deferred, name, value)
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)

View file

@ -0,0 +1,134 @@
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.api.errors import SynapseError
from twisted.internet import defer
from mock import Mock, MagicMock
from tests import unittest
import json
class RegisterRestServletTestCase(unittest.TestCase):
def setUp(self):
# do the dance to hook up request data to self.request_data
self.request_data = ""
self.request = Mock(
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
)
self.request.args = {}
self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock(
side_effect=lambda x: defer.succeed(self.appservice))
)
self.auth_result = (False, None, None)
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x,y,z: self.auth_result)
)
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
# do the dance to hook it up to the hs global
self.handlers = Mock(
auth_handler=self.auth_handler,
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler
)
self.hs = Mock()
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.config.disable_registration = False
# init the thing we're testing
self.servlet = RegisterRestServlet(self.hs)
@defer.inlineCallbacks
def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"username": "kermit"
})
self.appservice = {
"id": "1234"
}
self.registration_handler.appservice_register = Mock(
return_value=(user_id, token)
)
result = yield self.servlet.on_POST(self.request)
self.assertEquals(result, (200, {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname
}))
@defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self):
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"username": "kermit"
})
self.appservice = None # no application service exists
result = yield self.servlet.on_POST(self.request)
self.assertEquals(result, (401, None))
def test_POST_bad_password(self):
self.request_data = json.dumps({
"username": "kermit",
"password": 666
})
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError)
def test_POST_bad_username(self):
self.request_data = json.dumps({
"username": 777,
"password": "monkey"
})
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError)
@defer.inlineCallbacks
def test_POST_user_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
self.request_data = json.dumps({
"username": "kermit",
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
})
self.registration_handler.register = Mock(return_value=(user_id, token))
result = yield self.servlet.on_POST(self.request)
self.assertEquals(result, (200, {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname
}))
def test_POST_disabled_registration(self):
self.hs.config.disable_registration = True
self.request_data = json.dumps({
"username": "kermit",
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
})
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError)

View file

@ -17,6 +17,8 @@
from tests import unittest
from twisted.internet import defer
from synapse.util.async import ObservableDeferred
from synapse.storage._base import Cache, cached
@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
self.assertEquals(self.cache.get("foo"), 123)
def test_invalidate(self):
self.cache.prefill("foo", 123)
self.cache.invalidate("foo")
self.cache.prefill(("foo",), 123)
self.cache.invalidate(("foo",))
failed = False
try:
self.cache.get("foo")
self.cache.get(("foo",))
except KeyError:
failed = True
@ -139,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 1)
a.func.invalidate("foo")
a.func.invalidate(("foo",))
yield a.func("foo")
@ -151,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
def func(self, key):
return key
A().func.invalidate("what")
A().func.invalidate(("what",))
@defer.inlineCallbacks
def test_max_entries(self):
@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0]))
@defer.inlineCallbacks
def test_prefill(self):
callcount = [0]
d = defer.succeed(123)
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
return d
a = A()
a.func.prefill("foo", 123)
a.func.prefill(("foo",), ObservableDeferred(d))
self.assertEquals((yield a.func("foo")), 123)
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)

View file

@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase):
yield d
self.assertTrue(d.called)
observers[0].assert_called_once("Go")
observers[1].assert_called_once("Go")
observers[0].assert_called_once_with("Go")
observers[1].assert_called_once_with("Go")
self.assertEquals(mock_logger.warning.call_count, 1)
self.assertIsInstance(mock_logger.warning.call_args[0][0],