0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 08:33:48 +01:00

Merge branch 'develop' into rav/refactor_device_query

This commit is contained in:
Mark Haines 2016-08-03 11:12:47 +01:00
commit 921f17f938
22 changed files with 507 additions and 139 deletions

View file

@ -24,5 +24,7 @@ recursive-include synapse/static *.js
exclude jenkins.sh
exclude jenkins*.sh
exclude jenkins*
recursive-exclude jenkins *.sh
prune demo/etc

View file

@ -22,24 +22,10 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
./jenkins/prepare_synapse.sh
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install psycopg2
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
if [[ ! -e .dendron-base ]]; then
git clone https://github.com/matrix-org/dendron.git .dendron-base --mirror
else
(cd .dendron-base; git fetch -p)
fi
rm -rf dendron
git clone .dendron-base dendron --shared
cd dendron
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
: ${GOPATH:=${WORKSPACE}/.gopath}
if [[ "${GOPATH}" != *:* ]]; then
@ -48,40 +34,32 @@ if [[ "${GOPATH}" != *:* ]]; then
fi
export GOPATH
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
cd dendron
go get github.com/constabulary/gb/...
gb generate
gb build
cd ..
cd ../sytest
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8000}
: ${PORT_COUNT=20}
: ${PORT_BASE:=20000}
: ${PORT_COUNT=100}
export PORT_BASE
export PORT_COUNT
./jenkins/prep_sytest_for_postgres.sh
mkdir -p var
echo >&2 "Running sytest with PostgreSQL";
TOX_BIN=$WORKSPACE/.tox/py27/bin
./jenkins/install_and_run.sh --python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--dendron $WORKSPACE/dendron/bin/dendron \
--pusher \
--synchrotron \
--federation-reader \
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1))
cd ..

View file

@ -22,37 +22,26 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
./jenkins/prepare_synapse.sh
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install psycopg2
$TOX_BIN/pip install lxml
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
: ${PORT_BASE:=20000}
: ${PORT_COUNT=100}
export PORT_BASE
export PORT_COUNT
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8000}
: ${PORT_COUNT=20}
./jenkins/prep_sytest_for_postgres.sh
echo >&2 "Running sytest with PostgreSQL";
TOX_BIN=$WORKSPACE/.tox/py27/bin
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1))
cd ..
cp sytest/.coverage.* .

View file

@ -4,6 +4,7 @@ set -eux
: ${WORKSPACE:="$(pwd)"}
export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
@ -22,27 +23,18 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
./jenkins/prepare_synapse.sh
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
: ${PORT_BASE:=20000}
: ${PORT_COUNT=100}
export PORT_BASE
export PORT_COUNT
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_COUNT=20}
: ${PORT_BASE:=8000}
TOX_BIN=$WORKSPACE/.tox/py27/bin
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \

24
jenkins/clone.sh Executable file
View file

@ -0,0 +1,24 @@
#! /bin/bash
NAME=$1
PROJECT=$2
BASE=".$NAME-base"
# update our clone
if [ ! -d .$NAME-base ]; then
git clone $PROJECT $BASE --mirror
else
(cd $BASE; git fetch -p)
fi
rm -rf $NAME
git clone $BASE $NAME --shared
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
cd $NAME
# check out the relevant branch
git checkout "${GIT_BRANCH}" || (
echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
git checkout "origin/develop"
)
git clean -df .

19
jenkins/prepare_synapse.sh Executable file
View file

@ -0,0 +1,19 @@
#! /bin/bash
cd "`dirname $0`/.."
TOX_DIR=$WORKSPACE/.tox
mkdir -p $TOX_DIR
if ! [ $TOX_DIR -ef .tox ]; then
ln -s "$TOX_DIR" .tox
fi
# set up the virtualenv
tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2

View file

@ -116,11 +116,12 @@ def get_json(origin_name, origin_key, destination, path):
authorization_headers = []
for key, sig in signed_json["signatures"][origin_name].items():
authorization_headers.append(bytes(
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
origin_name, key, sig,
)
))
header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
origin_name, key, sig,
)
authorization_headers.append(bytes(header))
sys.stderr.write(header)
sys.stderr.write("\n")
result = requests.get(
lookup(destination, path),

View file

@ -0,0 +1,206 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import FEDERATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.federation_reader")
class FederationReaderSlavedStore(
SlavedEventStore,
SlavedKeyStore,
RoomStore,
DirectoryStore,
TransactionStore,
BaseSlavedStore,
):
pass
class FederationReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse federation reader now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse federation reader", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.federation_reader"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = FederationReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string=get_version_string("Synapse", synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-federation-reader",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -21,10 +21,11 @@ from .units import Transaction, Edu
from synapse.util.async import Linearizer
from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent
import synapse.metrics
from synapse.api.errors import FederationError, SynapseError
from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
@ -48,9 +49,15 @@ class FederationServer(FederationBase):
def __init__(self, hs):
super(FederationServer, self).__init__(hs)
self.auth = hs.get_auth()
self._room_pdu_linearizer = Linearizer()
self._server_linearizer = Linearizer()
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
@ -188,33 +195,50 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
with (yield self._server_linearizer.queue((origin, room_id))):
if event_id:
pdus = yield self.handler.get_state_for_pdu(
origin, room_id, event_id,
if not event_id:
raise NotImplementedError("Specify an event")
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
result = self._state_resp_cache.get((room_id, event_id))
if not result:
with (yield self._server_linearizer.queue((origin, room_id))):
resp = yield self._state_resp_cache.set(
(room_id, event_id),
self._on_context_state_request_compute(room_id, event_id)
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
else:
resp = yield result
defer.returnValue((200, resp))
@defer.inlineCallbacks
def _on_context_state_request_compute(self, room_id, event_id):
pdus = yield self.handler.get_state_for_pdu(
room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
for event in auth_chain:
# We sign these again because there was a bug where we
# incorrectly signed things the first time round
if self.hs.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
for event in auth_chain:
# We sign these again because there was a bug where we
# incorrectly signed things the first time round
if self.hs.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
else:
raise NotImplementedError("Specify an event")
defer.returnValue((200, {
defer.returnValue({
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}))
})
@defer.inlineCallbacks
@log_function

View file

@ -991,14 +991,9 @@ class FederationHandler(BaseHandler):
defer.returnValue(None)
@defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
def get_state_for_pdu(self, room_id, event_id):
yield run_on_reactor()
if do_auth:
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
state_groups = yield self.store.get_state_groups(
room_id, [event_id]
)

View file

@ -345,8 +345,8 @@ class RoomCreationHandler(BaseHandler):
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache()
self.remote_list_request_cache = ResponseCache()
self.response_cache = ResponseCache(hs)
self.remote_list_request_cache = ResponseCache(hs)
self.remote_list_cache = {}
self.fetch_looping_call = hs.get_clock().looping_call(
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL

View file

@ -138,7 +138,7 @@ class SyncHandler(object):
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache()
self.response_cache = ResponseCache(hs)
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):

View file

@ -14,6 +14,7 @@
# limitations under the License.
from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging
@ -92,7 +93,11 @@ class EmailPusher(object):
def on_stop(self):
if self.timed_call:
self.timed_call.cancel()
try:
self.timed_call.cancel()
except (AlreadyCalled, AlreadyCancelled):
pass
self.timed_call = None
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
@ -189,7 +194,10 @@ class EmailPusher(object):
soonest_due_at = should_notify_at
if self.timed_call is not None:
self.timed_call.cancel()
try:
self.timed_call.cancel()
except (AlreadyCalled, AlreadyCancelled):
pass
self.timed_call = None
if soonest_due_at is not None:

View file

@ -16,6 +16,7 @@
from synapse.push import PusherConfigException
from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging
import push_rule_evaluator
@ -109,7 +110,11 @@ class HttpPusher(object):
def on_stop(self):
if self.timed_call:
self.timed_call.cancel()
try:
self.timed_call.cancel()
except (AlreadyCalled, AlreadyCancelled):
pass
self.timed_call = None
@defer.inlineCallbacks
def _process(self):

View file

@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage.directory import DirectoryStore
class DirectoryStore(BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[
"get_aliases_for_room"
].orig

View file

@ -145,6 +145,15 @@ class SlavedEventStore(BaseSlavedStore):
_get_events_around_txn = DataStore._get_events_around_txn.__func__
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
get_backfill_events = DataStore.get_backfill_events.__func__
_get_backfill_events = DataStore._get_backfill_events.__func__
get_missing_events = DataStore.get_missing_events.__func__
_get_missing_events = DataStore._get_missing_events.__func__
get_auth_chain = DataStore.get_auth_chain.__func__
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()

View file

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.keys import KeyStore
class SlavedKeyStore(BaseSlavedStore):
_get_server_verify_key = KeyStore.__dict__[
"_get_server_verify_key"
]
get_server_verify_keys = DataStore.get_server_verify_keys.__func__
store_server_verify_key = DataStore.store_server_verify_key.__func__
get_server_certificate = DataStore.get_server_certificate.__func__
store_server_certificate = DataStore.store_server_certificate.__func__
get_server_keys_json = DataStore.get_server_keys_json.__func__
store_server_keys_json = DataStore.store_server_keys_json.__func__

View file

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
class RoomStore(BaseSlavedStore):
get_public_room_ids = DataStore.get_public_room_ids.__func__

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore
class TransactionStore(BaseSlavedStore):
get_destination_retry_timings = TransactionStore.__dict__[
"get_destination_retry_timings"
].orig
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
# For now, don't record the destination rety timings
def set_destination_retry_timings(*args, **kwargs):
return defer.succeed(None)

View file

@ -196,12 +196,12 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY]
]
authed, result, params, session_id = yield self.auth_handler.check_auth(
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
defer.returnValue((401, result))
defer.returnValue((401, auth_result))
return
if registered_user_id is not None:
@ -236,18 +236,18 @@ class RegisterRestServlet(RestServlet):
add_email = True
result = yield self._create_registration_details(
return_dict = yield self._create_registration_details(
registered_user_id, params
)
if add_email and result and LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result:
threepid = auth_result[LoginType.EMAIL_IDENTITY]
yield self._register_email_threepid(
registered_user_id, threepid, result["access_token"],
registered_user_id, threepid, return_dict["access_token"],
params.get("bind_email")
)
defer.returnValue((200, result))
defer.returnValue((200, return_dict))
def on_OPTIONS(self, _):
return 200, {}
@ -356,8 +356,6 @@ class RegisterRestServlet(RestServlet):
else:
logger.info("bind_email not specified: not binding email")
defer.returnValue()
@defer.inlineCallbacks
def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user

View file

@ -22,6 +22,10 @@ import OpenSSL
from signedjson.key import decode_verify_key_bytes
import hashlib
import logging
logger = logging.getLogger(__name__)
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys and tls X.509 certificates
@ -74,22 +78,22 @@ class KeyStore(SQLBaseStore):
)
@cachedInlineCallbacks()
def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list(
def _get_server_verify_key(self, server_name, key_id):
verify_key_bytes = yield self._simple_select_one_onecol(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
"key_id": key_id,
},
retcols=["key_id", "verify_key"],
desc="get_all_server_verify_keys",
retcol="verify_key",
desc="_get_server_verify_key",
allow_none=True,
)
defer.returnValue({
row["key_id"]: decode_verify_key_bytes(
row["key_id"], str(row["verify_key"])
)
for row in rows
})
if verify_key_bytes:
defer.returnValue(decode_verify_key_bytes(
key_id, str(verify_key_bytes)
))
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):
@ -101,12 +105,12 @@ class KeyStore(SQLBaseStore):
Returns:
(list of VerifyKey): The verification keys.
"""
keys = yield self.get_all_server_verify_keys(server_name)
defer.returnValue({
k: keys[k]
for k in key_ids
if k in keys and keys[k]
})
keys = {}
for key_id in key_ids:
key = yield self._get_server_verify_key(server_name, key_id)
if key:
keys[key_id] = key
defer.returnValue(keys)
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
@ -133,8 +137,6 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
self.get_all_server_verify_keys.invalidate((server_name,))
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server

View file

@ -24,9 +24,12 @@ class ResponseCache(object):
used rather than trying to compute a new response.
"""
def __init__(self):
def __init__(self, hs, timeout_ms=0):
self.pending_result_cache = {} # Requests that haven't finished yet.
self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.
def get(self, key):
result = self.pending_result_cache.get(key)
if result is not None:
@ -39,7 +42,13 @@ class ResponseCache(object):
self.pending_result_cache[key] = result
def remove(r):
self.pending_result_cache.pop(key, None)
if self.timeout_sec:
self.clock.call_later(
self.timeout_sec,
self.pending_result_cache.pop, key, None,
)
else:
self.pending_result_cache.pop(key, None)
return r
result.addBoth(remove)