Merge remote-tracking branch 'origin/develop' into markjh/end-to-end-key-federation
Commit: c5966b2a97
35 changed files with 616 additions and 370 deletions
README.rst (92 changed lines)

@@ -101,25 +101,26 @@ header files for python C extensions.

 Installing prerequisites on Ubuntu or Debian::

-    $ sudo apt-get install build-essential python2.7-dev libffi-dev \
-                           python-pip python-setuptools sqlite3 \
-                           libssl-dev python-virtualenv libjpeg-dev
+    sudo apt-get install build-essential python2.7-dev libffi-dev \
+                         python-pip python-setuptools sqlite3 \
+                         libssl-dev python-virtualenv libjpeg-dev

 Installing prerequisites on ArchLinux::

-    $ sudo pacman -S base-devel python2 python-pip \
-            python-setuptools python-virtualenv sqlite3
+    sudo pacman -S base-devel python2 python-pip \
+        python-setuptools python-virtualenv sqlite3

 Installing prerequisites on Mac OS X::

-    $ xcode-select --install
-    $ sudo pip install virtualenv
+    xcode-select --install
+    sudo easy_install pip
+    sudo pip install virtualenv

 To install the synapse homeserver run::

-    $ virtualenv -p python2.7 ~/.synapse
-    $ source ~/.synapse/bin/activate
-    $ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
+    virtualenv -p python2.7 ~/.synapse
+    source ~/.synapse/bin/activate
+    pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master

 This installs synapse, along with the libraries it uses, into a virtual
 environment under ``~/.synapse``. Feel free to pick a different directory

@@ -132,8 +133,8 @@ above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.

 To set up your homeserver, run (in your virtualenv, as before)::

-    $ cd ~/.synapse
-    $ python -m synapse.app.homeserver \
+    cd ~/.synapse
+    python -m synapse.app.homeserver \
         --server-name machine.my.domain.name \
         --config-path homeserver.yaml \
         --generate-config

@@ -192,9 +193,9 @@ Running Synapse

 To actually run your new homeserver, pick a working directory for Synapse to run
 (e.g. ``~/.synapse``), and::

-    $ cd ~/.synapse
-    $ source ./bin/activate
-    $ synctl start
+    cd ~/.synapse
+    source ./bin/activate
+    synctl start

 Platform Specific Instructions
 ==============================

@@ -212,12 +213,12 @@ defaults to python 3, but synapse currently assumes python 2.7 by default:

 pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1)::

-    $ sudo pip2.7 install --upgrade pip
+    sudo pip2.7 install --upgrade pip

 You also may need to explicitly specify python 2.7 again during the install
 request::

-    $ pip2.7 install --process-dependency-links \
+    pip2.7 install --process-dependency-links \
         https://github.com/matrix-org/synapse/tarball/master

 If you encounter an error with lib bcrypt causing a Wrong ELF Class:

@@ -225,13 +226,13 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly

 compile it under the right architecture. (This should not be needed if
 installing under virtualenv)::

-    $ sudo pip2.7 uninstall py-bcrypt
-    $ sudo pip2.7 install py-bcrypt
+    sudo pip2.7 uninstall py-bcrypt
+    sudo pip2.7 install py-bcrypt

 During setup of Synapse you need to call python2.7 directly again::

-    $ cd ~/.synapse
-    $ python2.7 -m synapse.app.homeserver \
+    cd ~/.synapse
+    python2.7 -m synapse.app.homeserver \
         --server-name machine.my.domain.name \
         --config-path homeserver.yaml \
         --generate-config

@@ -279,22 +280,22 @@ Synapse requires pip 1.7 or later, so if your OS provides too old a version and

 you get errors about ``error: no such option: --process-dependency-links`` you
 may need to manually upgrade it::

-    $ sudo pip install --upgrade pip
+    sudo pip install --upgrade pip

 If pip crashes mid-installation for whatever reason (e.g. lost terminal), pip may
 refuse to run until you remove the temporary installation directory it
 created. To reset the installation::

-    $ rm -rf /tmp/pip_install_matrix
+    rm -rf /tmp/pip_install_matrix

 pip seems to leak *lots* of memory during installation. For instance, a Linux
 host with 512MB of RAM may run out of memory whilst installing Twisted. If this
 happens, you will have to individually install the dependencies which are
 failing, e.g.::

-    $ pip install twisted
+    pip install twisted

-On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
+On OS X, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
 will need to export CFLAGS=-Qunused-arguments.

 Troubleshooting Running

@@ -310,10 +311,11 @@ correctly, causing all tests to fail with errors about missing "sodium.h". To

 fix try re-installing from PyPI or directly from
 (https://github.com/pyca/pynacl)::

-    $ # Install from PyPI
-    $ pip install --user --upgrade --force pynacl
-    $ # Install from github
-    $ pip install --user https://github.com/pyca/pynacl/tarball/master
+    # Install from PyPI
+    pip install --user --upgrade --force pynacl
+
+    # Install from github
+    pip install --user https://github.com/pyca/pynacl/tarball/master

 ArchLinux
 ~~~~~~~~~

@@ -321,7 +323,7 @@ ArchLinux

 If running `$ synctl start` fails with 'returned non-zero exit status 1',
 you will need to explicitly call Python2.7 - either running as::

-    $ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
+    python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml

 ...or by editing synctl with the correct python executable.

@@ -331,16 +333,16 @@ Synapse Development

 To check out a synapse for development, clone the git repo into a working
 directory of your choice::

-    $ git clone https://github.com/matrix-org/synapse.git
-    $ cd synapse
+    git clone https://github.com/matrix-org/synapse.git
+    cd synapse

 Synapse has a number of external dependencies, that are easiest
 to install using pip and a virtualenv::

-    $ virtualenv env
-    $ source env/bin/activate
-    $ python synapse/python_dependencies.py | xargs -n1 pip install
-    $ pip install setuptools_trial mock
+    virtualenv env
+    source env/bin/activate
+    python synapse/python_dependencies.py | xargs -n1 pip install
+    pip install setuptools_trial mock

 This will run a process of downloading and installing all the needed
 dependencies into a virtual env.

@@ -348,7 +350,7 @@ dependencies into a virtual env.

 Once this is done, you may wish to run Synapse's unit tests, to
 check that everything is installed as it should be::

-    $ python setup.py test
+    python setup.py test

 This should end with a 'PASSED' result::

@@ -389,11 +391,11 @@ IDs:

 For the first form, simply pass the required hostname (of the machine) as the
 --server-name parameter::

-    $ python -m synapse.app.homeserver \
+    python -m synapse.app.homeserver \
         --server-name machine.my.domain.name \
         --config-path homeserver.yaml \
         --generate-config
-    $ python -m synapse.app.homeserver --config-path homeserver.yaml
+    python -m synapse.app.homeserver --config-path homeserver.yaml

 Alternatively, you can run ``synctl start`` to guide you through the process.

@@ -410,11 +412,11 @@ record would then look something like::

 At this point, you should then run the homeserver with the hostname of this
 SRV record, as that is the name other machines will expect it to have::

-    $ python -m synapse.app.homeserver \
+    python -m synapse.app.homeserver \
         --server-name YOURDOMAIN \
         --config-path homeserver.yaml \
         --generate-config
-    $ python -m synapse.app.homeserver --config-path homeserver.yaml
+    python -m synapse.app.homeserver --config-path homeserver.yaml

 You may additionally want to pass one or more "-v" options, in order to

@@ -428,7 +430,7 @@ private federation (``localhost:8080``, ``localhost:8081`` and

 ``localhost:8082``) which you can then access through the webclient running at
 http://localhost:8080. Simply run::

-    $ demo/start.sh
+    demo/start.sh

 This is mainly useful just for development purposes.

@@ -502,10 +504,10 @@ Building Internal API Documentation

 Before building internal API documentation install sphinx and
 sphinxcontrib-napoleon::

-    $ pip install sphinx
-    $ pip install sphinxcontrib-napoleon
+    pip install sphinx
+    pip install sphinxcontrib-napoleon

 Building internal API documentation::

-    $ python setup.py build_sphinx
+    python setup.py build_sphinx

@@ -11,7 +11,9 @@ if [ -f $PID_FILE ]; then
     exit 1
 fi

-find "$DIR" -name "*.log" -delete
-find "$DIR" -name "*.db" -delete
+for port in 8080 8081 8082; do
+    rm -rf $DIR/$port
+    rm -rf $DIR/media_store.$port
+done

 rm -rf $DIR/etc

@@ -8,14 +8,6 @@ cd "$DIR/.."

 mkdir -p demo/etc

-# Check the --no-rate-limit param
-PARAMS=""
-if [ $# -eq 1 ]; then
-    if [ $1 = "--no-rate-limit" ]; then
-        PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
-    fi
-fi
-
 export PYTHONPATH=$(readlink -f $(pwd))

@@ -31,10 +23,20 @@ for port in 8080 8081 8082; do
     #rm $DIR/etc/$port.config
     python -m synapse.app.homeserver \
         --generate-config \
-        --enable_registration \
         -H "localhost:$https_port" \
         --config-path "$DIR/etc/$port.config" \

+    # Check script parameters
+    if [ $# -eq 1 ]; then
+        if [ $1 = "--no-rate-limit" ]; then
+            # Set high limits in config file to disable rate limiting
+            perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
+            perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
+        fi
+    fi
+
+    perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
+
     python -m synapse.app.homeserver \
         --config-path "$DIR/etc/$port.config" \
         -D \

@@ -16,3 +16,6 @@ ignore =
     docs/*
     pylint.cfg
     tox.ini
+
+[flake8]
+max-line-length = 90

setup.py (2 changed lines)

@@ -48,7 +48,7 @@ setup(
     description="Reference Synapse Home Server",
     install_requires=dependencies['requirements'](include_conditional=True).keys(),
     setup_requires=[
-        "Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
+        "Twisted>=15.1.0", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
        "setuptools_trial",
        "mock"
     ],

@@ -44,6 +44,11 @@ class Auth(object):
     def check(self, event, auth_events):
         """ Checks if this event is correctly authed.

+        Args:
+            event: the event being checked.
+            auth_events (dict: event-key -> event): the existing room state.
+
         Returns:
             True if the auth checks pass.
         """

@@ -319,7 +324,7 @@ class Auth(object):
         Returns:
             tuple : of UserID and device string:
                 User ID object of the user making the request
-                Client ID object of the client instance the user is using
+                ClientInfo object of the client instance the user is using
         Raises:
             AuthError if no user by that token exists or the token is invalid.
         """

@@ -352,7 +357,7 @@ class Auth(object):
                 )
                 return
             except KeyError:
-                pass  # normal users won't have this query parameter set
+                pass  # normal users won't have the user_id query parameter set.

         user_info = yield self.get_user_by_token(access_token)
         user = user_info["user"]

@@ -521,23 +526,22 @@ class Auth(object):

         # Check state_key
         if hasattr(event, "state_key"):
-            if not event.state_key.startswith("_"):
-                if event.state_key.startswith("@"):
-                    if event.state_key != event.user_id:
-                        raise AuthError(
-                            403,
-                            "You are not allowed to set others state"
-                        )
-                else:
-                    sender_domain = UserID.from_string(
-                        event.user_id
-                    ).domain
-
-                    if sender_domain != event.state_key:
-                        raise AuthError(
-                            403,
-                            "You are not allowed to set others state"
-                        )
+            if event.state_key.startswith("@"):
+                if event.state_key != event.user_id:
+                    raise AuthError(
+                        403,
+                        "You are not allowed to set others state"
+                    )
+            else:
+                sender_domain = UserID.from_string(
+                    event.user_id
+                ).domain
+
+                if sender_domain != event.state_key:
+                    raise AuthError(
+                        403,
+                        "You are not allowed to set others state"
+                    )

         return True

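The rewrite above drops the old ``if not event.state_key.startswith("_")`` wrapper, so underscore-prefixed state keys are no longer exempt from the ownership checks. A standalone sketch of the resulting rule, using a hypothetical helper rather than Synapse's actual API::

    def may_set_state(state_key, user_id, sender_domain):
        # Keys beginning with "@" must be the sender's own user ID;
        # any other state_key must be the sender's server domain.
        if state_key.startswith("@"):
            return state_key == user_id
        return state_key == sender_domain

    assert may_set_state("@alice:example.com", "@alice:example.com", "example.com")
    assert not may_set_state("@bob:example.com", "@alice:example.com", "example.com")
    assert may_set_state("example.com", "@alice:example.com", "example.com")
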
@@ -657,7 +657,8 @@ def run(hs):

     if hs.config.daemonize:

-        print hs.config.pid_file
+        if hs.config.print_pidfile:
+            print hs.config.pid_file

         daemon = Daemonize(
             app="synapse-homeserver",

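For context, this is roughly how the ``Daemonize`` helper constructed above is driven; the pid path and ``action`` callback below are illustrative stand-ins, not Synapse's actual values::

    from daemonize import Daemonize

    def main():
        pass  # the long-running server loop goes here

    daemon = Daemonize(
        app="synapse-homeserver",
        pid="/tmp/synapse.pid",  # illustrative pid file path
        action=main,
    )
    daemon.start()  # forks, writes the pid file, then invokes main()
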
@@ -138,12 +138,19 @@ class Config(object):
             action="store_true",
             help="Generate a config file for the server name"
         )
+        config_parser.add_argument(
+            "--generate-keys",
+            action="store_true",
+            help="Generate any missing key files then exit"
+        )
         config_parser.add_argument(
             "-H", "--server-name",
             help="The server name to generate a config file for"
         )
         config_args, remaining_args = config_parser.parse_known_args(argv)

+        generate_keys = config_args.generate_keys
+
         if config_args.generate_config:
             if not config_args.config_path:
                 config_parser.error(

@@ -151,51 +158,40 @@ class Config(object):
                     " generated using \"--generate-config -H SERVER_NAME"
                     " -c CONFIG-FILE\""
                 )
-            config_dir_path = os.path.dirname(config_args.config_path[0])
-            config_dir_path = os.path.abspath(config_dir_path)
-
-            server_name = config_args.server_name
-            if not server_name:
-                print "Must specify a server_name to a generate config for."
-                sys.exit(1)
             (config_path,) = config_args.config_path
-            if not os.path.exists(config_dir_path):
-                os.makedirs(config_dir_path)
-            if os.path.exists(config_path):
-                print "Config file %r already exists" % (config_path,)
-                yaml_config = cls.read_config_file(config_path)
-                yaml_name = yaml_config["server_name"]
-                if server_name != yaml_name:
-                    print (
-                        "Config file %r has a different server_name: "
-                        " %r != %r" % (config_path, server_name, yaml_name)
-                    )
-                config_bytes, config = obj.generate_config(
-                    config_dir_path, server_name
-                )
-                config.update(yaml_config)
-                print "Generating any missing keys for %r" % (server_name,)
-                obj.invoke_all("generate_files", config)
-                sys.exit(0)
-            with open(config_path, "wb") as config_file:
-                config_bytes, config = obj.generate_config(
-                    config_dir_path, server_name
-                )
-                obj.invoke_all("generate_files", config)
-                config_file.write(config_bytes)
+            if not os.path.exists(config_path):
+                config_dir_path = os.path.dirname(config_path)
+                config_dir_path = os.path.abspath(config_dir_path)
+
+                server_name = config_args.server_name
+                if not server_name:
+                    print "Must specify a server_name to a generate config for."
+                    sys.exit(1)
+                if not os.path.exists(config_dir_path):
+                    os.makedirs(config_dir_path)
+                with open(config_path, "wb") as config_file:
+                    config_bytes, config = obj.generate_config(
+                        config_dir_path, server_name
+                    )
+                    obj.invoke_all("generate_files", config)
+                    config_file.write(config_bytes)
                 print (
-                    "A config file has been generated in %s for server name"
-                    " '%s' with corresponding SSL keys and self-signed"
-                    " certificates. Please review this file and customise it to"
-                    " your needs."
+                    "A config file has been generated in %r for server name"
+                    " %r with corresponding SSL keys and self-signed"
+                    " certificates. Please review this file and customise it"
+                    " to your needs."
                 ) % (config_path, server_name)
-            print (
-                "If this server name is incorrect, you will need to regenerate"
-                " the SSL certificates"
-            )
-            sys.exit(0)
+                print (
+                    "If this server name is incorrect, you will need to"
+                    " regenerate the SSL certificates"
+                )
+                sys.exit(0)
+            else:
+                print (
+                    "Config file %r already exists. Generating any missing key"
+                    " files."
+                ) % (config_path,)
+                generate_keys = True

         parser = argparse.ArgumentParser(
             parents=[config_parser],

@@ -213,7 +209,7 @@ class Config(object):
                     " -c CONFIG-FILE\""
                 )

-        config_dir_path = os.path.dirname(config_args.config_path[0])
+        config_dir_path = os.path.dirname(config_args.config_path[-1])
         config_dir_path = os.path.abspath(config_dir_path)

         specified_config = {}

@@ -226,6 +222,10 @@ class Config(object):
             config.pop("log_config")
         config.update(specified_config)

+        if generate_keys:
+            obj.invoke_all("generate_files", config)
+            sys.exit(0)
+
         obj.invoke_all("read_config", config)

         obj.invoke_all("read_arguments", args)

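A condensed sketch of the control flow these hunks introduce: ``--generate-keys`` is parsed next to ``--generate-config``, and an existing config file now switches the run into key-generation mode instead of rewriting the file. ``write_new_config`` and ``generate_missing_keys`` are hypothetical stand-ins for ``obj.generate_config`` / ``obj.invoke_all("generate_files", ...)``::

    import argparse
    import os
    import sys

    def sketch(argv, write_new_config, generate_missing_keys):
        parser = argparse.ArgumentParser()
        parser.add_argument("--generate-config", action="store_true")
        parser.add_argument("--generate-keys", action="store_true",
                            help="Generate any missing key files then exit")
        parser.add_argument("-c", "--config-path")
        args = parser.parse_args(argv)

        generate_keys = args.generate_keys
        if args.generate_config:
            if os.path.exists(args.config_path):
                # New behaviour: keep the config, only create missing keys.
                generate_keys = True
            else:
                write_new_config(args.config_path)
                sys.exit(0)

        if generate_keys:
            generate_missing_keys()
            sys.exit(0)
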
@@ -24,6 +24,7 @@ class ServerConfig(Config):
         self.web_client = config["web_client"]
         self.soft_file_limit = config["soft_file_limit"]
         self.daemonize = config.get("daemonize")
+        self.print_pidfile = config.get("print_pidfile")
         self.use_frozen_dicts = config.get("use_frozen_dicts", True)

         self.listeners = config.get("listeners", [])

@@ -208,12 +209,18 @@ class ServerConfig(Config):
             self.manhole = args.manhole
         if args.daemonize is not None:
             self.daemonize = args.daemonize
+        if args.print_pidfile is not None:
+            self.print_pidfile = args.print_pidfile

     def add_arguments(self, parser):
         server_group = parser.add_argument_group("server")
         server_group.add_argument("-D", "--daemonize", action='store_true',
                                   default=None,
                                   help="Daemonize the home server")
+        server_group.add_argument("--print-pidfile", action='store_true',
+                                  default=None,
+                                  help="Print the path to the pidfile just"
+                                  " before daemonizing")
         server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
                                   type=int,
                                   help="Turn on the twisted telnet manhole"

@@ -44,7 +44,7 @@ class IdentityHandler(BaseHandler):
         http_client = SimpleHttpClient(self.hs)
         # XXX: make this configurable!
         # trustedIdServers = ['matrix.org', 'localhost:8090']
-        trustedIdServers = ['matrix.org']
+        trustedIdServers = ['matrix.org', 'vector.im']

         if 'id_server' in creds:
             id_server = creds['id_server']

@@ -73,7 +73,8 @@ class RegistrationHandler(BaseHandler):
             localpart : The local part of the user ID to register. If None,
               one will be randomly generated.
             password (str) : The password to assign to this user so they can
-              login again.
+              login again. This can be None which means they cannot login again
+              via a password (e.g. the user is an application service user).
         Returns:
             A tuple of (user_id, access_token).
         Raises:

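A short sketch of calling code for the two modes the amended docstring describes; the handler argument and localparts below are illustrative::

    from twisted.internet import defer

    @defer.inlineCallbacks
    def make_users(registration_handler):
        # Normal user: the password is stored so they can log in again.
        user_id, token = yield registration_handler.register(
            localpart="alice", password="s3cret",
        )
        # Application-service style user: password=None means no password login.
        bot_id, bot_token = yield registration_handler.register(
            localpart="bridge_bot", password=None,
        )
        defer.returnValue((user_id, bot_id))
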
@@ -16,7 +16,7 @@

 from twisted.internet import defer, reactor, protocol
 from twisted.internet.error import DNSLookupError
-from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
+from twisted.web.client import readBody, HTTPConnectionPool, Agent
 from twisted.web.http_headers import Headers
 from twisted.web._newclient import ResponseDone

@@ -55,41 +55,17 @@ incoming_responses_counter = metrics.register_counter(
 )


-class MatrixFederationHttpAgent(_AgentBase):
-
-    def __init__(self, reactor, pool=None):
-        _AgentBase.__init__(self, reactor, pool)
-
-    def request(self, destination, endpoint, method, path, params, query,
-                headers, body_producer):
-
-        outgoing_requests_counter.inc(method)
-
-        host = b""
-        port = 0
-        fragment = b""
-
-        parsed_URI = _URI(b"http", destination, host, port, path, params,
-                          query, fragment)
-
-        # Set the connection pool key to be the destination.
-        key = destination
-
-        d = self._requestWithEndpoint(key, endpoint, method, parsed_URI,
-                                      headers, body_producer,
-                                      parsed_URI.originForm)
-
-        def _cb(response):
-            incoming_responses_counter.inc(method, response.code)
-            return response
-
-        def _eb(failure):
-            incoming_responses_counter.inc(method, "ERR")
-            return failure
-
-        d.addCallbacks(_cb, _eb)
-
-        return d
+class MatrixFederationEndpointFactory(object):
+    def __init__(self, hs):
+        self.tls_context_factory = hs.tls_context_factory
+
+    def endpointForURI(self, uri):
+        destination = uri.netloc
+
+        return matrix_federation_endpoint(
+            reactor, destination, timeout=10,
+            ssl_context_factory=self.tls_context_factory
+        )


 class MatrixFederationHttpClient(object):

@@ -107,12 +83,18 @@ class MatrixFederationHttpClient(object):
         self.server_name = hs.hostname
         pool = HTTPConnectionPool(reactor)
         pool.maxPersistentPerHost = 10
-        self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
+        self.agent = Agent.usingEndpointFactory(
+            reactor, MatrixFederationEndpointFactory(hs), pool=pool
+        )
         self.clock = hs.get_clock()
         self.version_string = hs.version_string

         self._next_id = 1

+    def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
+        return urlparse.urlunparse(
+            ("matrix", destination, path_bytes, param_bytes, query_bytes, "")
+        )
+
     @defer.inlineCallbacks
     def _create_request(self, destination, method, path_bytes,
                         body_callback, headers_dict={}, param_bytes=b"",

@@ -123,8 +105,8 @@ class MatrixFederationHttpClient(object):
         headers_dict[b"User-Agent"] = [self.version_string]
         headers_dict[b"Host"] = [destination]

-        url_bytes = urlparse.urlunparse(
-            ("", "", path_bytes, param_bytes, query_bytes, "",)
+        url_bytes = self._create_url(
+            destination, path_bytes, param_bytes, query_bytes
         )

         txn_id = "%s-O-%s" % (method, self._next_id)

@@ -139,8 +121,8 @@ class MatrixFederationHttpClient(object):
         # (once we have reliable transactions in place)
         retries_left = 5

-        endpoint = preserve_context_over_fn(
-            self._getEndpoint, reactor, destination
+        http_url_bytes = urlparse.urlunparse(
+            ("", "", path_bytes, param_bytes, query_bytes, "")
         )

         log_result = None

@@ -148,17 +130,14 @@ class MatrixFederationHttpClient(object):
         while True:
             producer = None
             if body_callback:
-                producer = body_callback(method, url_bytes, headers_dict)
+                producer = body_callback(method, http_url_bytes, headers_dict)

             try:
                 def send_request():
-                    request_deferred = self.agent.request(
-                        destination,
-                        endpoint,
+                    request_deferred = preserve_context_over_fn(
+                        self.agent.request,
                         method,
-                        path_bytes,
-                        param_bytes,
-                        query_bytes,
+                        url_bytes,
                         Headers(headers_dict),
                         producer
                     )

@@ -452,12 +431,6 @@ class MatrixFederationHttpClient(object):

         defer.returnValue((length, headers))

-    def _getEndpoint(self, reactor, destination):
-        return matrix_federation_endpoint(
-            reactor, destination, timeout=10,
-            ssl_context_factory=self.hs.tls_context_factory
-        )
-

 class _ReadBodyToFileProtocol(protocol.Protocol):
     def __init__(self, stream, deferred, max_size):

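The replacement relies on a public Twisted API: ``Agent.usingEndpointFactory`` accepts any object with an ``endpointForURI(uri)`` method, removing the need to subclass the private ``_AgentBase``. A minimal sketch — the factory below routes everything to one TCP endpoint, standing in for ``matrix_federation_endpoint``::

    from twisted.internet import reactor
    from twisted.internet.endpoints import TCP4ClientEndpoint
    from twisted.web.client import Agent, HTTPConnectionPool

    class FixedEndpointFactory(object):
        # Resolve every request URI to one TCP endpoint (illustrative only).
        def endpointForURI(self, uri):
            return TCP4ClientEndpoint(reactor, uri.host, uri.port or 80)

    pool = HTTPConnectionPool(reactor)
    agent = Agent.usingEndpointFactory(reactor, FixedEndpointFactory(), pool=pool)
    d = agent.request(b"GET", b"http://example.com/_matrix/key/v2/server")
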
@@ -18,8 +18,12 @@ from __future__ import absolute_import

 import logging
 from resource import getrusage, getpagesize, RUSAGE_SELF
+import functools
 import os
 import stat
+import time
+
+from twisted.internet import reactor

 from .metric import (
     CounterMetric, CallbackMetric, DistributionMetric, CacheMetric

@@ -144,3 +148,28 @@ def _process_fds():
     return counts

 get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
+
+reactor_metrics = get_metrics_for("reactor")
+tick_time = reactor_metrics.register_distribution("tick_time")
+pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
+
+
+def runUntilCurrentTimer(func):
+
+    @functools.wraps(func)
+    def f(*args, **kwargs):
+        pending_calls = len(reactor.getDelayedCalls())
+        start = time.time() * 1000
+        ret = func(*args, **kwargs)
+        end = time.time() * 1000
+        tick_time.inc_by(end - start)
+        pending_calls_metric.inc_by(pending_calls)
+        return ret
+
+    return f
+
+
+if hasattr(reactor, "runUntilCurrent"):
+    # runUntilCurrent is called when we have pending calls. It is called once
+    # per iteration after fd polling.
+    reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)

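The same wrap-and-replace timing pattern in isolation, outside the synapse.metrics machinery (all names below are illustrative)::

    import functools
    import time

    def timed(func, record):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.time() * 1000
            try:
                return func(*args, **kwargs)
            finally:
                record(time.time() * 1000 - start)  # duration in milliseconds
        return wrapper

    durations = []

    def tick():
        time.sleep(0.01)  # stand-in for reactor.runUntilCurrent's work

    tick = timed(tick, durations.append)
    tick()
    print("last tick took %.1f ms" % durations[-1])
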
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)

 REQUIREMENTS = {
     "syutil>=0.0.7": ["syutil>=0.0.7"],
-    "Twisted==14.0.2": ["twisted==14.0.2"],
+    "Twisted>=15.1.0": ["twisted>=15.1.0"],
     "service_identity>=1.0.0": ["service_identity>=1.0.0"],
     "pyopenssl>=0.14": ["OpenSSL>=0.14"],
     "pyyaml": ["yaml"],

@@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError, Codes
 from synapse.http.servlet import RestServlet

-from ._base import client_v2_pattern, parse_request_allow_empty
+from ._base import client_v2_pattern, parse_json_dict_from_request

 import logging
 import hmac

@@ -55,30 +55,55 @@ class RegisterRestServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request):
         yield run_on_reactor()
+        body = parse_json_dict_from_request(request)

-        body = parse_request_allow_empty(request)
-        # we do basic sanity checks here because the auth
-        # layer will store these in sessions
+        # we do basic sanity checks here because the auth layer will store these
+        # in sessions. Pull out the username/password provided to us.
         desired_password = None
         if 'password' in body:
-            if ((not isinstance(body['password'], str) and
-                    not isinstance(body['password'], unicode)) or
+            if (not isinstance(body['password'], basestring) or
                     len(body['password']) > 512):
                 raise SynapseError(400, "Invalid password")
             desired_password = body["password"]

         desired_username = None
         if 'username' in body:
-            if ((not isinstance(body['username'], str) and
-                    not isinstance(body['username'], unicode)) or
+            if (not isinstance(body['username'], basestring) or
                     len(body['username']) > 512):
                 raise SynapseError(400, "Invalid username")
             desired_username = body['username']
-            yield self.registration_handler.check_username(desired_username)

-        is_using_shared_secret = False
-        is_application_server = False
-
-        service = None
+        appservice = None
         if 'access_token' in request.args:
-            service = yield self.auth.get_appservice_by_req(request)
+            appservice = yield self.auth.get_appservice_by_req(request)

+        # fork off as soon as possible for ASes and shared secret auth which
+        # have completely different registration flows to normal users
+
+        # == Application Service Registration ==
+        if appservice:
+            result = yield self._do_appservice_registration(
+                desired_username, request.args["access_token"][0]
+            )
+            defer.returnValue((200, result))  # we throw for non 200 responses
+            return
+
+        # == Shared Secret Registration == (e.g. create new user scripts)
+        if 'mac' in body:
+            # FIXME: Should we really be determining if this is shared secret
+            # auth based purely on the 'mac' key?
+            result = yield self._do_shared_secret_registration(
+                desired_username, desired_password, body["mac"]
+            )
+            defer.returnValue((200, result))  # we throw for non 200 responses
+            return
+
+        # == Normal User Registration == (everyone else)
+        if self.hs.config.disable_registration:
+            raise SynapseError(403, "Registration has been disabled")
+
+        if desired_username is not None:
+            yield self.registration_handler.check_username(desired_username)

         if self.hs.config.enable_registration_captcha:
             flows = [

@@ -91,39 +116,20 @@ class RegisterRestServlet(RestServlet):
                 [LoginType.EMAIL_IDENTITY]
             ]

-        result = None
-        if service:
-            is_application_server = True
-            params = body
-        elif 'mac' in body:
-            # Check registration-specific shared secret auth
-            if 'username' not in body:
-                raise SynapseError(400, "", Codes.MISSING_PARAM)
-            self._check_shared_secret_auth(
-                body['username'], body['mac']
-            )
-            is_using_shared_secret = True
-            params = body
-        else:
-            authed, result, params = yield self.auth_handler.check_auth(
-                flows, body, self.hs.get_ip_from_request(request)
-            )
-
-            if not authed:
-                defer.returnValue((401, result))
-
-        can_register = (
-            not self.hs.config.disable_registration
-            or is_application_server
-            or is_using_shared_secret
+        authed, result, params = yield self.auth_handler.check_auth(
+            flows, body, self.hs.get_ip_from_request(request)
         )
-        if not can_register:
-            raise SynapseError(403, "Registration has been disabled")

+        if not authed:
+            defer.returnValue((401, result))
+            return
+
+        # NB: This may be from the auth handler and NOT from the POST
         if 'password' not in params:
-            raise SynapseError(400, "", Codes.MISSING_PARAM)
-        desired_username = params['username'] if 'username' in params else None
-        new_password = params['password']
+            raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
+
+        desired_username = params.get("username", None)
+        new_password = params.get("password", None)

         (user_id, token) = yield self.registration_handler.register(
             localpart=desired_username,

@@ -156,18 +162,21 @@ class RegisterRestServlet(RestServlet):
         else:
             logger.info("bind_email not specified: not binding email")

-        result = {
-            "user_id": user_id,
-            "access_token": token,
-            "home_server": self.hs.hostname,
-        }
-
+        result = self._create_registration_details(user_id, token)
         defer.returnValue((200, result))

     def on_OPTIONS(self, _):
         return 200, {}

-    def _check_shared_secret_auth(self, username, mac):
+    @defer.inlineCallbacks
+    def _do_appservice_registration(self, username, as_token):
+        (user_id, token) = yield self.registration_handler.appservice_register(
+            username, as_token
+        )
+        defer.returnValue(self._create_registration_details(user_id, token))
+
+    @defer.inlineCallbacks
+    def _do_shared_secret_registration(self, username, password, mac):
         if not self.hs.config.registration_shared_secret:
             raise SynapseError(400, "Shared secret registration is not enabled")

@@ -183,13 +192,23 @@ class RegisterRestServlet(RestServlet):
             digestmod=sha1,
         ).hexdigest()

-        if compare_digest(want_mac, got_mac):
-            return True
-        else:
+        if not compare_digest(want_mac, got_mac):
             raise SynapseError(
                 403, "HMAC incorrect",
             )

+        (user_id, token) = yield self.registration_handler.register(
+            localpart=username, password=password
+        )
+        defer.returnValue(self._create_registration_details(user_id, token))
+
+    def _create_registration_details(self, user_id, token):
+        return {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname,
+        }
+

 def register_servlets(hs, http_server):
     RegisterRestServlet(hs).register(http_server)

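For reference, a sketch of the shared-secret scheme ``_do_shared_secret_registration`` checks: the client HMACs the desired username with the homeserver's ``registration_shared_secret`` (SHA-1), and the server recomputes the digest and compares in constant time. The values below are illustrative::

    import hmac
    from hashlib import sha1

    shared_secret = b"the registration_shared_secret from homeserver.yaml"
    username = b"alice"

    # Client side: compute the mac sent in the request body.
    got_mac = hmac.new(shared_secret, username, digestmod=sha1).hexdigest()

    # Server side: recompute and compare with a constant-time comparison,
    # as the servlet does before registering the user.
    want_mac = hmac.new(shared_secret, username, digestmod=sha1).hexdigest()
    assert hmac.compare_digest(want_mac, got_mac)
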
@@ -244,43 +244,52 @@ class BaseMediaResource(Resource):
             )
             return

-        scales = set()
-        crops = set()
-        for r_width, r_height, r_method, r_type in requirements:
-            if r_method == "scale":
-                t_width, t_height = thumbnailer.aspect(r_width, r_height)
-                scales.add((
-                    min(m_width, t_width), min(m_height, t_height), r_type,
-                ))
-            elif r_method == "crop":
-                crops.add((r_width, r_height, r_type))
-
-        for t_width, t_height, t_type in scales:
-            t_method = "scale"
-            t_path = self.filepaths.local_media_thumbnail(
-                media_id, t_width, t_height, t_type, t_method
-            )
-            self._makedirs(t_path)
-            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-            yield self.store.store_local_thumbnail(
-                media_id, t_width, t_height, t_type, t_method, t_len
-            )
-
-        for t_width, t_height, t_type in crops:
-            if (t_width, t_height, t_type) in scales:
-                # If the aspect ratio of the cropped thumbnail matches a purely
-                # scaled one then there is no point in calculating a separate
-                # thumbnail.
-                continue
-            t_method = "crop"
-            t_path = self.filepaths.local_media_thumbnail(
-                media_id, t_width, t_height, t_type, t_method
-            )
-            self._makedirs(t_path)
-            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
-            yield self.store.store_local_thumbnail(
-                media_id, t_width, t_height, t_type, t_method, t_len
-            )
+        local_thumbnails = []
+
+        def generate_thumbnails():
+            scales = set()
+            crops = set()
+            for r_width, r_height, r_method, r_type in requirements:
+                if r_method == "scale":
+                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
+                    scales.add((
+                        min(m_width, t_width), min(m_height, t_height), r_type,
+                    ))
+                elif r_method == "crop":
+                    crops.add((r_width, r_height, r_type))
+
+            for t_width, t_height, t_type in scales:
+                t_method = "scale"
+                t_path = self.filepaths.local_media_thumbnail(
+                    media_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+                local_thumbnails.append((
+                    media_id, t_width, t_height, t_type, t_method, t_len
+                ))
+
+            for t_width, t_height, t_type in crops:
+                if (t_width, t_height, t_type) in scales:
+                    # If the aspect ratio of the cropped thumbnail matches a purely
+                    # scaled one then there is no point in calculating a separate
+                    # thumbnail.
+                    continue
+                t_method = "crop"
+                t_path = self.filepaths.local_media_thumbnail(
+                    media_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+                local_thumbnails.append((
+                    media_id, t_width, t_height, t_type, t_method, t_len
+                ))
+
+        yield threads.deferToThread(generate_thumbnails)
+
+        for l in local_thumbnails:
+            yield self.store.store_local_thumbnail(*l)

         defer.returnValue({
             "width": m_width,

|
|||
t_method = info["thumbnail_method"]
|
||||
if t_method == "scale" or t_method == "crop":
|
||||
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
||||
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
|
||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||
type_quality = desired_type != info["thumbnail_type"]
|
||||
length_quality = info["thumbnail_length"]
|
||||
info_list.append((
|
||||
aspect_quality, size_quality, type_quality,
|
||||
aspect_quality, min_quality, size_quality, type_quality,
|
||||
length_quality, info
|
||||
))
|
||||
if info_list:
|
||||
|
|
|
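Why the new ``min_quality`` field matters: candidate thumbnails are ranked by comparing these tuples lexicographically, so a 0/1 flag that prefers thumbnails at least as large as the request outranks the raw size difference. A self-contained illustration with made-up dimensions::

    desired_w, desired_h = 320, 240
    candidates = [(320, 240), (64, 48), (640, 480)]

    def quality_key(dims):
        t_w, t_h = dims
        aspect_quality = abs(desired_w * t_h - desired_h * t_w)
        min_quality = 0 if desired_w <= t_w and desired_h <= t_h else 1
        size_quality = abs((desired_w - t_w) * (desired_h - t_h))
        return (aspect_quality, min_quality, size_quality)

    best = min(candidates, key=quality_key)
    print(best)  # (320, 240); (64, 48) now loses on min_quality despite equal aspect
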
@@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
         key = (user.to_string(), access_token, device_id, ip)

         try:
-            last_seen = self.client_ip_last_seen.get(*key)
+            last_seen = self.client_ip_last_seen.get(key)
         except KeyError:
             last_seen = None

@@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
         if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
             defer.returnValue(None)

-        self.client_ip_last_seen.prefill(*key + (now,))
+        self.client_ip_last_seen.prefill(key, now)

         # It's safe not to lock here: a) no unique constraint,
         # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely

@@ -354,6 +354,11 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
                 )
                 logger.debug("Running script %s", relative_path)
                 module.run_upgrade(cur, database_engine)
+            elif ext == ".pyc":
+                # Sometimes .pyc files turn up anyway even though we've
+                # disabled their generation; e.g. from distribution package
+                # installers. Silently skip it
+                pass
             elif ext == ".sql":
                 # A plain old .sql file, just read and execute it
                 logger.debug("Applying schema %s", relative_path)

@@ -15,6 +15,7 @@
 import logging

 from synapse.api.errors import StoreError
+from synapse.util.async import ObservableDeferred
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache

@@ -27,6 +28,7 @@ from twisted.internet import defer
 from collections import namedtuple, OrderedDict

 import functools
+import inspect
 import sys
 import time
 import threading

@@ -55,9 +57,12 @@ cache_counter = metrics.register_cache(
 )


+_CacheSentinel = object()
+
+
 class Cache(object):

-    def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+    def __init__(self, name, max_entries=1000, keylen=1, lru=True):
         if lru:
             self.cache = LruCache(max_size=max_entries)
             self.max_entries = None

@@ -81,45 +86,44 @@ class Cache(object):
                 "Cache objects can only be accessed from the main thread"
             )

-    def get(self, *keyargs):
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
-
-        if keyargs in self.cache:
+    def get(self, key, default=_CacheSentinel):
+        val = self.cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
             cache_counter.inc_hits(self.name)
-            return self.cache[keyargs]
+            return val

         cache_counter.inc_misses(self.name)
-        raise KeyError()
-
-    def update(self, sequence, *args):
+        if default is _CacheSentinel:
+            raise KeyError()
+        else:
+            return default
+
+    def update(self, sequence, key, value):
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(*args)
-
-    def prefill(self, *args):  # because I can't *keyargs, value
-        keyargs = args[:-1]
-        value = args[-1]
-
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+            self.prefill(key, value)

+    def prefill(self, key, value):
         if self.max_entries is not None:
             while len(self.cache) >= self.max_entries:
                 self.cache.popitem(last=False)

-        self.cache[keyargs] = value
+        self.cache[key] = value

-    def invalidate(self, *keyargs):
+    def invalidate(self, key):
         self.check_thread()
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+        if not isinstance(key, tuple):
+            raise TypeError(
+                "The cache key must be a tuple not %r" % (type(key),)
+            )

         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
-        self.cache.pop(keyargs, None)
+        self.cache.pop(key, None)

     def invalidate_all(self):
         self.check_thread()

@@ -130,6 +134,9 @@ class Cache(object):

 class CacheDescriptor(object):
     """ A method decorator that applies a memoizing cache around the function.

+    This caches deferreds, rather than the results themselves. Deferreds that
+    fail are removed from the cache.
+
     The function is presumed to take zero or more arguments, which are used in
     a tuple as the key for the cache. Hits are served directly from the cache;
     misses use the function body to generate the value.

@@ -141,58 +148,92 @@ class CacheDescriptor(object):
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+                 inlineCallbacks=False):
         self.orig = orig

+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
         self.max_entries = max_entries
         self.num_args = num_args
         self.lru = lru

-    def __get__(self, obj, objtype=None):
-        cache = Cache(
+        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+        if len(self.arg_names) < self.num_args:
+            raise Exception(
+                "Not enough explicit positional arguments to key off of for %r."
+                " (@cached cannot key off of *args or **kwargs)"
+                % (orig.__name__,)
+            )
+
+        self.cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
             lru=self.lru,
         )

+    def __get__(self, obj, objtype=None):
+
         @functools.wraps(self.orig)
-        @defer.inlineCallbacks
-        def wrapped(*keyargs):
+        def wrapped(*args, **kwargs):
+            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
             try:
-                cached_result = cache.get(*keyargs[:self.num_args])
+                cached_result_d = self.cache.get(cache_key)
+
+                observer = cached_result_d.observe()
                 if DEBUG_CACHES:
-                    actual_result = yield self.orig(obj, *keyargs)
-                    if actual_result != cached_result:
-                        logger.error(
-                            "Stale cache entry %s%r: cached: %r, actual %r",
-                            self.orig.__name__, keyargs,
-                            cached_result, actual_result,
-                        )
-                        raise ValueError("Stale cache entry")
-                defer.returnValue(cached_result)
+                    @defer.inlineCallbacks
+                    def check_result(cached_result):
+                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
+                        if actual_result != cached_result:
+                            logger.error(
+                                "Stale cache entry %s%r: cached: %r, actual %r",
+                                self.orig.__name__, cache_key,
+                                cached_result, actual_result,
+                            )
+                            raise ValueError("Stale cache entry")
+                        defer.returnValue(cached_result)
+                    observer.addCallback(check_result)
+
+                return observer
             except KeyError:
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
                 # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
+                sequence = self.cache.sequence

-                ret = yield self.orig(obj, *keyargs)
+                ret = defer.maybeDeferred(
+                    self.function_to_call,
+                    obj, *args, **kwargs
+                )

-                cache.update(sequence, *keyargs[:self.num_args] + (ret,))
+                def onErr(f):
+                    self.cache.invalidate(cache_key)
+                    return f

-                defer.returnValue(ret)
+                ret.addErrback(onErr)

-        wrapped.invalidate = cache.invalidate
-        wrapped.invalidate_all = cache.invalidate_all
-        wrapped.prefill = cache.prefill
+                ret = ObservableDeferred(ret, consumeErrors=True)
+                self.cache.update(sequence, cache_key, ret)
+
+                return ret.observe()
+
+        wrapped.invalidate = self.cache.invalidate
+        wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.prefill = self.cache.prefill

         obj.__dict__[self.orig.__name__] = wrapped

         return wrapped


-def cached(max_entries=1000, num_args=1, lru=False):
+def cached(max_entries=1000, num_args=1, lru=True):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,

@@ -201,6 +242,16 @@ def cached(max_entries=1000, num_args=1, lru=True):
     )


+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+    return lambda orig: CacheDescriptor(
+        orig,
+        max_entries=max_entries,
+        num_args=num_args,
+        lru=lru,
+        inlineCallbacks=True,
+    )
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()

|
|||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
self.get_aliases_for_room.invalidate(room_id)
|
||||
self.get_aliases_for_room.invalidate((room_id,))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_room_alias(self, room_alias):
|
||||
|
@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
|
|||
room_alias,
|
||||
)
|
||||
|
||||
self.get_aliases_for_room.invalidate(room_id)
|
||||
self.get_aliases_for_room.invalidate((room_id,))
|
||||
defer.returnValue(room_id)
|
||||
|
||||
def _delete_room_alias_txn(self, txn, room_alias):
|
||||
|
|
|
@@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore):

         for room_id in events_by_room:
             txn.call_after(
-                self.get_latest_event_ids_in_room.invalidate, room_id
+                self.get_latest_event_ids_in_room.invalidate, (room_id,)
             )

     def get_backfill_events(self, room_id, event_list, limit):

@@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
         query = "DELETE FROM event_forward_extremities WHERE room_id = ?"

         txn.execute(query, (room_id,))
-        txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
+        txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))

@@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore):
         if current_state:
             txn.call_after(self.get_current_state_for_key.invalidate_all)
             txn.call_after(self.get_rooms_for_user.invalidate_all)
-            txn.call_after(self.get_users_in_room.invalidate, event.room_id)
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
+            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
+            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
             txn.call_after(self.get_room_name_and_aliases, event.room_id)

             self._simple_delete_txn(

@@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore):
         if not context.rejected:
             txn.call_after(
                 self.get_current_state_for_key.invalidate,
-                event.room_id, event.type, event.state_key
-            )
+                (event.room_id, event.type, event.state_key,)
+            )

             if event.type in [EventTypes.Name, EventTypes.Aliases]:
                 txn.call_after(
                     self.get_room_name_and_aliases.invalidate,
-                    event.room_id
+                    (event.room_id,)
                 )

             self._simple_upsert_txn(

@@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore):
     def _invalidate_get_event_cache(self, event_id):
         for check_redacted in (False, True):
             for get_prev_content in (False, True):
-                self._get_event_cache.invalidate(event_id, check_redacted,
-                                                 get_prev_content)
+                self._get_event_cache.invalidate(
+                    (event_id, check_redacted, get_prev_content)
+                )

     def _get_event_txn(self, txn, event_id, check_redacted=True,
                        get_prev_content=False, allow_rejected=False):

@@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore):
         for event_id in events:
             try:
                 ret = self._get_event_cache.get(
-                    event_id, check_redacted, get_prev_content
+                    (event_id, check_redacted, get_prev_content,)
                 )

                 if allow_rejected or not ret.rejected_reason:

@@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore):
                 ev.unsigned["prev_content"] = prev.get_dict()["content"]

         self._get_event_cache.prefill(
-            ev.event_id, check_redacted, get_prev_content, ev
+            (ev.event_id, check_redacted, get_prev_content), ev
         )

         defer.returnValue(ev)

@@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore):
                 ev.unsigned["prev_content"] = prev.get_dict()["content"]

         self._get_event_cache.prefill(
-            ev.event_id, check_redacted, get_prev_content, ev
+            (ev.event_id, check_redacted, get_prev_content), ev
         )

         return ev

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from _base import SQLBaseStore, cached
+from _base import SQLBaseStore, cachedInlineCallbacks

 from twisted.internet import defer

@@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore):
             desc="store_server_certificate",
         )

-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_all_server_verify_keys(self, server_name):
         rows = yield self._simple_select_list(
             table="server_signature_keys",

@@ -132,7 +131,7 @@ class KeyStore(SQLBaseStore):
             desc="store_server_verify_key",
         )

-        self.get_all_server_verify_keys.invalidate(server_name)
+        self.get_all_server_verify_keys.invalidate((server_name,))

     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):

@@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
             updatevalues={"accepted": True},
             desc="set_presence_list_accepted",
         )
-        self.get_presence_list_accepted.invalidate(observer_localpart)
+        self.get_presence_list_accepted.invalidate((observer_localpart,))
         defer.returnValue(result)

     def get_presence_list(self, observer_localpart, accepted=None):

@@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
              "observed_user_id": observed_userid},
             desc="del_presence_list",
         )
-        self.get_presence_list_accepted.invalidate(observer_localpart)
+        self.get_presence_list_accepted.invalidate((observer_localpart,))

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
 from twisted.internet import defer

 import logging

@@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)


 class PushRuleStore(SQLBaseStore):
-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_push_rules_for_user(self, user_name):
         rows = yield self._simple_select_list(
             table=PushRuleTable.table_name,

@@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):

         defer.returnValue(rows)

-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_push_rules_enabled_for_user(self, user_name):
         results = yield self._simple_select_list(
             table=PushRuleEnableTable.table_name,

@@ -153,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
             txn.execute(sql, (user_name, priority_class, new_rule_priority))

             txn.call_after(
-                self.get_push_rules_for_user.invalidate, user_name
+                self.get_push_rules_for_user.invalidate, (user_name,)
             )

             txn.call_after(
-                self.get_push_rules_enabled_for_user.invalidate, user_name
+                self.get_push_rules_enabled_for_user.invalidate, (user_name,)
             )

             self._simple_insert_txn(

@@ -189,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
             new_rule['priority'] = new_prio

             txn.call_after(
-                self.get_push_rules_for_user.invalidate, user_name
+                self.get_push_rules_for_user.invalidate, (user_name,)
             )
             txn.call_after(
-                self.get_push_rules_enabled_for_user.invalidate, user_name
+                self.get_push_rules_enabled_for_user.invalidate, (user_name,)
             )

             self._simple_insert_txn(

@@ -218,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
             desc="delete_push_rule",
         )

-        self.get_push_rules_for_user.invalidate(user_name)
-        self.get_push_rules_enabled_for_user.invalidate(user_name)
+        self.get_push_rules_for_user.invalidate((user_name,))
+        self.get_push_rules_enabled_for_user.invalidate((user_name,))

     @defer.inlineCallbacks
     def set_push_rule_enabled(self, user_name, rule_id, enabled):

@@ -240,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
             {'id': new_id},
         )
         txn.call_after(
-            self.get_push_rules_for_user.invalidate, user_name
+            self.get_push_rules_for_user.invalidate, (user_name,)
         )
         txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, user_name
+            self.get_push_rules_enabled_for_user.invalidate, (user_name,)
         )

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks

 from twisted.internet import defer

@@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore):
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_max_token(self)

-    @cached
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_graph_receipts_for_room(self, room_id):
         """Get receipts for sending to remote servers.
         """

@@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
             user_id
         )
         for r in rows:
-            self.get_user_by_token.invalidate(r)
+            self.get_user_by_token.invalidate((r,))

     @cached()
     def get_user_by_token(self, token):

@@ -17,7 +17,7 @@ from twisted.internet import defer

 from synapse.api.errors import StoreError

-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks

 import collections
 import logging

@@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
             }
         )

-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_room_name_and_aliases(self, room_id):
         def f(txn):
             sql = (
@@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore):
         )

         for event in events:
-            txn.call_after(self.get_rooms_for_user.invalidate, event.state_key)
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
-            txn.call_after(self.get_users_in_room.invalidate, event.room_id)
+            txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
+            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
+            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))

     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
@@ -78,7 +78,7 @@ class RoomMemberStore(SQLBaseStore):
             lambda events: events[0] if events else None
         )

-    @cached()
+    @cached(max_entries=5000)
     def get_users_in_room(self, room_id):
         def f(txn):

@@ -154,7 +154,7 @@ class RoomMemberStore(SQLBaseStore):
             RoomsForUser(**r) for r in self.cursor_to_dict(txn)
         ]

-    @cached()
+    @cached(max_entries=5000)
     def get_joined_hosts_for_room(self, room_id):
         return self.runInteraction(
             "get_joined_hosts_for_room",
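The two hunks above put a 5000-entry bound on these per-room caches. Synapse's ``Cache`` has its own eviction logic; purely as an illustration of the ``max_entries`` idea (not the actual implementation), a size-bounded cache can be sketched with an ordered dict::

    # Illustrative sketch of a max_entries-style bound: once the cache is
    # full, the oldest entry is evicted to make room for the new one.
    from collections import OrderedDict

    class BoundedCache(object):
        def __init__(self, max_entries):
            self.max_entries = max_entries
            self.data = OrderedDict()

        def set(self, key, value):
            self.data.pop(key, None)
            self.data[key] = value
            if len(self.data) > self.max_entries:
                self.data.popitem(last=False)  # drop the oldest entry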
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks

 from twisted.internet import defer

@@ -91,7 +91,6 @@ class StateStore(SQLBaseStore):

         defer.returnValue(dict(state_list))

-    @cached(num_args=1)
     def _fetch_events_for_group(self, key, events):
         return self._get_events(
             events, get_prev_content=False
@@ -189,8 +188,7 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)

-    @cached(num_args=3)
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks(num_args=3)
     def get_current_state_for_key(self, room_id, event_type, state_key):
         def f(txn):
             sql = (
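Note that ``num_args=3`` survives the decorator swap: it tells the cache how many leading positional arguments make up the key, so ``get_current_state_for_key`` is cached on ``(room_id, event_type, state_key)``. A minimal stand-in for the idea (illustrative only; the real logic lives in synapse's cache descriptor)::

    # Sketch: memoise a method on a tuple of its first num_args arguments.
    def cached_sketch(num_args):
        def decorator(fn):
            cache = {}
            def wrapper(self, *args):
                key = args[:num_args]
                if key not in cache:
                    cache[key] = fn(self, *args)
                return cache[key]
            wrapper.invalidate = lambda key: cache.pop(key, None)
            return wrapper
        return decorator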
@@ -178,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):

     Live tokens start with an "s" followed by the "stream_ordering" id of the
     event it comes after. Historic tokens start with a "t" followed by the
-    "topological_ordering" id of the event it comes after, follewed by "-",
+    "topological_ordering" id of the event it comes after, followed by "-",
     followed by the "stream_ordering" id of the event it comes after.
     """
     __slots__ = []

@@ -211,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
         return "s%d" % (self.stream,)


+# token_id is the primary key ID of the access token, not the access token itself.
 ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
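The corrected docstring describes the two token shapes; for example (values invented for illustration)::

    live_token = "s3452"          # after the event with stream_ordering 3452
    historic_token = "t312-3452"  # topological_ordering 312, then stream_ordering 3452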
@@ -51,7 +51,7 @@ class ObservableDeferred(object):
         object.__setattr__(self, "_observers", set())

         def callback(r):
-            self._result = (True, r)
+            object.__setattr__(self, "_result", (True, r))
             while self._observers:
                 try:
                     self._observers.pop().callback(r)

@@ -60,7 +60,7 @@ class ObservableDeferred(object):
             return r

         def errback(f):
-            self._result = (False, f)
+            object.__setattr__(self, "_result", (False, f))
             while self._observers:
                 try:
                     self._observers.pop().errback(f)

@@ -97,3 +97,8 @@ class ObservableDeferred(object):

     def __setattr__(self, name, value):
         setattr(self._deferred, name, value)
+
+    def __repr__(self):
+        return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
+            id(self), self._result, self._deferred,
+        )
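The switch to ``object.__setattr__`` matters because, as the last hunk shows, ``ObservableDeferred.__setattr__`` forwards every attribute write to the wrapped deferred; a plain ``self._result = ...`` would therefore be proxied onto ``self._deferred`` instead of recorded on the wrapper itself. A minimal sketch of the pattern (not the full class)::

    # Sketch of the attribute-proxy pattern used by ObservableDeferred.
    class Proxy(object):
        def __init__(self, target):
            # bypass our own __setattr__ to store state on the proxy itself
            object.__setattr__(self, "_target", target)
            object.__setattr__(self, "_result", None)

        def __setattr__(self, name, value):
            # all other writes are forwarded to the wrapped object
            setattr(self._target, name, value)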
tests/rest/client/v2_alpha/test_register.py (new file, 134 lines)

@@ -0,0 +1,134 @@
+from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.api.errors import SynapseError
+from twisted.internet import defer
+from mock import Mock, MagicMock
+from tests import unittest
+import json
+
+
+class RegisterRestServletTestCase(unittest.TestCase):
+
+    def setUp(self):
+        # do the dance to hook up request data to self.request_data
+        self.request_data = ""
+        self.request = Mock(
+            content=Mock(read=Mock(side_effect=lambda: self.request_data)),
+        )
+        self.request.args = {}
+
+        self.appservice = None
+        self.auth = Mock(get_appservice_by_req=Mock(
+            side_effect=lambda x: defer.succeed(self.appservice))
+        )
+
+        self.auth_result = (False, None, None)
+        self.auth_handler = Mock(
+            check_auth=Mock(side_effect=lambda x, y, z: self.auth_result)
+        )
+        self.registration_handler = Mock()
+        self.identity_handler = Mock()
+        self.login_handler = Mock()
+
+        # do the dance to hook it up to the hs global
+        self.handlers = Mock(
+            auth_handler=self.auth_handler,
+            registration_handler=self.registration_handler,
+            identity_handler=self.identity_handler,
+            login_handler=self.login_handler
+        )
+        self.hs = Mock()
+        self.hs.hostname = "superbig~testing~thing.com"
+        self.hs.get_auth = Mock(return_value=self.auth)
+        self.hs.get_handlers = Mock(return_value=self.handlers)
+        self.hs.config.disable_registration = False
+
+        # init the thing we're testing
+        self.servlet = RegisterRestServlet(self.hs)
+
+    @defer.inlineCallbacks
+    def test_POST_appservice_registration_valid(self):
+        user_id = "@kermit:muppet"
+        token = "kermits_access_token"
+        self.request.args = {
+            "access_token": "i_am_an_app_service"
+        }
+        self.request_data = json.dumps({
+            "username": "kermit"
+        })
+        self.appservice = {
+            "id": "1234"
+        }
+        self.registration_handler.appservice_register = Mock(
+            return_value=(user_id, token)
+        )
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (200, {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname
+        }))
+
+    @defer.inlineCallbacks
+    def test_POST_appservice_registration_invalid(self):
+        self.request.args = {
+            "access_token": "i_am_an_app_service"
+        }
+        self.request_data = json.dumps({
+            "username": "kermit"
+        })
+        self.appservice = None  # no application service exists
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (401, None))
+
+    def test_POST_bad_password(self):
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": 666
+        })
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
+
+    def test_POST_bad_username(self):
+        self.request_data = json.dumps({
+            "username": 777,
+            "password": "monkey"
+        })
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
+
+    @defer.inlineCallbacks
+    def test_POST_user_valid(self):
+        user_id = "@kermit:muppet"
+        token = "kermits_access_token"
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.check_username = Mock(return_value=True)
+        self.auth_result = (True, None, {
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.register = Mock(return_value=(user_id, token))
+
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (200, {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname
+        }))
+
+    def test_POST_disabled_registration(self):
+        self.hs.config.disable_registration = True
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.check_username = Mock(return_value=True)
+        self.auth_result = (True, None, {
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.register = Mock(return_value=("@user:id", "t"))
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
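The ``setUp`` above relies on ``Mock(side_effect=...)`` so that the request body is read lazily: each test just reassigns ``self.request_data`` before calling ``on_POST``. The trick in isolation (a standalone sketch, not part of the diff)::

    # side_effect defers the read until call time, so reassigning the state
    # between tests changes what request.content.read() returns.
    from mock import Mock

    state = {"body": ""}
    request = Mock()
    request.content.read = Mock(side_effect=lambda: state["body"])

    state["body"] = '{"username": "kermit"}'
    assert request.content.read() == '{"username": "kermit"}'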
@@ -17,6 +17,8 @@
 from tests import unittest
 from twisted.internet import defer

+from synapse.util.async import ObservableDeferred
+
 from synapse.storage._base import Cache, cached
@@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
         self.assertEquals(self.cache.get("foo"), 123)

     def test_invalidate(self):
-        self.cache.prefill("foo", 123)
-        self.cache.invalidate("foo")
+        self.cache.prefill(("foo",), 123)
+        self.cache.invalidate(("foo",))

         failed = False
         try:
-            self.cache.get("foo")
+            self.cache.get(("foo",))
         except KeyError:
             failed = True
@@ -139,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase):

         self.assertEquals(callcount[0], 1)

-        a.func.invalidate("foo")
+        a.func.invalidate(("foo",))

         yield a.func("foo")

@@ -151,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
            def func(self, key):
                return key

-        A().func.invalidate("what")
+        A().func.invalidate(("what",))

     @defer.inlineCallbacks
     def test_max_entries(self):
@@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
         self.assertTrue(callcount[0] >= 14,
                         msg="Expected callcount >= 14, got %d" % (callcount[0]))

-    @defer.inlineCallbacks
     def test_prefill(self):
         callcount = [0]

+        d = defer.succeed(123)
+
         class A(object):
             @cached()
             def func(self, key):
                 callcount[0] += 1
-                return key
+                return d

         a = A()

-        a.func.prefill("foo", 123)
+        a.func.prefill(("foo",), ObservableDeferred(d))

-        self.assertEquals((yield a.func("foo")), 123)
+        self.assertEquals(a.func("foo").result, d.result)
         self.assertEquals(callcount[0], 0)
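``prefill`` now takes an ``ObservableDeferred``, matching what the cache stores internally, and the rewritten assertion reads ``.result`` directly because an already-fired deferred exposes its value synchronously::

    # Standalone illustration of the .result access used in the test above.
    from twisted.internet import defer

    d = defer.succeed(123)
    assert d.result == 123  # a fired deferred exposes .result without yielding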
@@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase):
         yield d
         self.assertTrue(d.called)

-        observers[0].assert_called_once("Go")
-        observers[1].assert_called_once("Go")
+        observers[0].assert_called_once_with("Go")
+        observers[1].assert_called_once_with("Go")

         self.assertEquals(mock_logger.warning.call_count, 1)
         self.assertIsInstance(mock_logger.warning.call_args[0][0],
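This last hunk fixes assertions that never actually asserted: with the ``mock`` package of that era, accessing any undefined attribute on a ``Mock`` (including the misspelled ``assert_called_once``) returned a fresh child mock, so calling it always passed silently. ``assert_called_once_with`` is a real assertion method::

    from mock import Mock

    m = Mock()
    m("Go")
    m.assert_called_once_with("Go")  # genuine check: verifies call count and args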