forked from MirrorHub/synapse
Merge remote-tracking branch 'origin/develop' into dbkr/email_notifs
This commit is contained in:
commit
acded821c4
42 changed files with 1372 additions and 455 deletions
|
@ -25,7 +25,9 @@ rm .coverage* || echo "No coverage files to remove"
|
||||||
tox --notest -e py27
|
tox --notest -e py27
|
||||||
|
|
||||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||||
|
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||||
$TOX_BIN/pip install psycopg2
|
$TOX_BIN/pip install psycopg2
|
||||||
|
$TOX_BIN/pip install lxml
|
||||||
|
|
||||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,8 @@ rm .coverage* || echo "No coverage files to remove"
|
||||||
|
|
||||||
tox --notest -e py27
|
tox --notest -e py27
|
||||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||||
|
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||||
|
$TOX_BIN/pip install lxml
|
||||||
|
|
||||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
||||||
|
|
||||||
|
|
86
jenkins.sh
86
jenkins.sh
|
@ -1,86 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -eux
|
|
||||||
|
|
||||||
: ${WORKSPACE:="$(pwd)"}
|
|
||||||
|
|
||||||
export PYTHONDONTWRITEBYTECODE=yep
|
|
||||||
export SYNAPSE_CACHE_FACTOR=1
|
|
||||||
|
|
||||||
# Output test results as junit xml
|
|
||||||
export TRIAL_FLAGS="--reporter=subunit"
|
|
||||||
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
|
|
||||||
# Write coverage reports to a separate file for each process
|
|
||||||
export COVERAGE_OPTS="-p"
|
|
||||||
export DUMP_COVERAGE_COMMAND="coverage help"
|
|
||||||
|
|
||||||
# Output flake8 violations to violations.flake8.log
|
|
||||||
# Don't exit with non-0 status code on Jenkins,
|
|
||||||
# so that the build steps continue and a later step can decided whether to
|
|
||||||
# UNSTABLE or FAILURE this build.
|
|
||||||
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
|
|
||||||
|
|
||||||
rm .coverage* || echo "No coverage files to remove"
|
|
||||||
|
|
||||||
tox
|
|
||||||
|
|
||||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
|
||||||
|
|
||||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
|
||||||
|
|
||||||
if [[ ! -e .sytest-base ]]; then
|
|
||||||
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
|
|
||||||
else
|
|
||||||
(cd .sytest-base; git fetch -p)
|
|
||||||
fi
|
|
||||||
|
|
||||||
rm -rf sytest
|
|
||||||
git clone .sytest-base sytest --shared
|
|
||||||
cd sytest
|
|
||||||
|
|
||||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
|
||||||
|
|
||||||
: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5}
|
|
||||||
: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5}
|
|
||||||
: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5}
|
|
||||||
export PERL5LIB PERL_MB_OPT PERL_MM_OPT
|
|
||||||
|
|
||||||
./install-deps.pl
|
|
||||||
|
|
||||||
: ${PORT_BASE:=8000}
|
|
||||||
|
|
||||||
echo >&2 "Running sytest with SQLite3";
|
|
||||||
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
|
|
||||||
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-sqlite3.tap
|
|
||||||
|
|
||||||
RUN_POSTGRES=""
|
|
||||||
|
|
||||||
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
|
|
||||||
if psql synapse_jenkins_$port <<< ""; then
|
|
||||||
RUN_POSTGRES="$RUN_POSTGRES:$port"
|
|
||||||
cat > localhost-$port/database.yaml << EOF
|
|
||||||
name: psycopg2
|
|
||||||
args:
|
|
||||||
database: synapse_jenkins_$port
|
|
||||||
EOF
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
# Run if both postgresql databases exist
|
|
||||||
if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
|
|
||||||
echo >&2 "Running sytest with PostgreSQL";
|
|
||||||
$TOX_BIN/pip install psycopg2
|
|
||||||
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
|
|
||||||
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-postgresql.tap
|
|
||||||
else
|
|
||||||
echo >&2 "Skipping running sytest with PostgreSQL, $RUN_POSTGRES"
|
|
||||||
fi
|
|
||||||
|
|
||||||
cd ..
|
|
||||||
cp sytest/.coverage.* .
|
|
||||||
|
|
||||||
# Combine the coverage reports
|
|
||||||
echo "Combining:" .coverage.*
|
|
||||||
$TOX_BIN/python -m coverage combine
|
|
||||||
# Output coverage to coverage.xml
|
|
||||||
$TOX_BIN/coverage xml -o coverage.xml
|
|
|
@ -214,6 +214,10 @@ class Porter(object):
|
||||||
|
|
||||||
self.progress.add_table(table, postgres_size, table_size)
|
self.progress.add_table(table, postgres_size, table_size)
|
||||||
|
|
||||||
|
if table == "event_search":
|
||||||
|
yield self.handle_search_table(postgres_size, table_size, next_chunk)
|
||||||
|
return
|
||||||
|
|
||||||
select = (
|
select = (
|
||||||
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
|
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
|
||||||
% (table,)
|
% (table,)
|
||||||
|
@ -232,38 +236,6 @@ class Porter(object):
|
||||||
if rows:
|
if rows:
|
||||||
next_chunk = rows[-1][0] + 1
|
next_chunk = rows[-1][0] + 1
|
||||||
|
|
||||||
if table == "event_search":
|
|
||||||
# We have to treat event_search differently since it has a
|
|
||||||
# different structure in the two different databases.
|
|
||||||
def insert(txn):
|
|
||||||
sql = (
|
|
||||||
"INSERT INTO event_search (event_id, room_id, key, sender, vector)"
|
|
||||||
" VALUES (?,?,?,?,to_tsvector('english', ?))"
|
|
||||||
)
|
|
||||||
|
|
||||||
rows_dict = [
|
|
||||||
dict(zip(headers, row))
|
|
||||||
for row in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
txn.executemany(sql, [
|
|
||||||
(
|
|
||||||
row["event_id"],
|
|
||||||
row["room_id"],
|
|
||||||
row["key"],
|
|
||||||
row["sender"],
|
|
||||||
row["value"],
|
|
||||||
)
|
|
||||||
for row in rows_dict
|
|
||||||
])
|
|
||||||
|
|
||||||
self.postgres_store._simple_update_one_txn(
|
|
||||||
txn,
|
|
||||||
table="port_from_sqlite3",
|
|
||||||
keyvalues={"table_name": table},
|
|
||||||
updatevalues={"rowid": next_chunk},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._convert_rows(table, headers, rows)
|
self._convert_rows(table, headers, rows)
|
||||||
|
|
||||||
def insert(txn):
|
def insert(txn):
|
||||||
|
@ -286,6 +258,73 @@ class Porter(object):
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def handle_search_table(self, postgres_size, table_size, next_chunk):
|
||||||
|
select = (
|
||||||
|
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
|
||||||
|
" FROM event_search as es"
|
||||||
|
" INNER JOIN events AS e USING (event_id, room_id)"
|
||||||
|
" WHERE es.rowid >= ?"
|
||||||
|
" ORDER BY es.rowid LIMIT ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
def r(txn):
|
||||||
|
txn.execute(select, (next_chunk, self.batch_size,))
|
||||||
|
rows = txn.fetchall()
|
||||||
|
headers = [column[0] for column in txn.description]
|
||||||
|
|
||||||
|
return headers, rows
|
||||||
|
|
||||||
|
headers, rows = yield self.sqlite_store.runInteraction("select", r)
|
||||||
|
|
||||||
|
if rows:
|
||||||
|
next_chunk = rows[-1][0] + 1
|
||||||
|
|
||||||
|
# We have to treat event_search differently since it has a
|
||||||
|
# different structure in the two different databases.
|
||||||
|
def insert(txn):
|
||||||
|
sql = (
|
||||||
|
"INSERT INTO event_search (event_id, room_id, key,"
|
||||||
|
" sender, vector, origin_server_ts, stream_ordering)"
|
||||||
|
" VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
|
||||||
|
)
|
||||||
|
|
||||||
|
rows_dict = [
|
||||||
|
dict(zip(headers, row))
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
txn.executemany(sql, [
|
||||||
|
(
|
||||||
|
row["event_id"],
|
||||||
|
row["room_id"],
|
||||||
|
row["key"],
|
||||||
|
row["sender"],
|
||||||
|
row["value"],
|
||||||
|
row["origin_server_ts"],
|
||||||
|
row["stream_ordering"],
|
||||||
|
)
|
||||||
|
for row in rows_dict
|
||||||
|
])
|
||||||
|
|
||||||
|
self.postgres_store._simple_update_one_txn(
|
||||||
|
txn,
|
||||||
|
table="port_from_sqlite3",
|
||||||
|
keyvalues={"table_name": "event_search"},
|
||||||
|
updatevalues={"rowid": next_chunk},
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.postgres_store.execute(insert)
|
||||||
|
|
||||||
|
postgres_size += len(rows)
|
||||||
|
|
||||||
|
self.progress.update("event_search", postgres_size)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def setup_db(self, db_config, database_engine):
|
def setup_db(self, db_config, database_engine):
|
||||||
db_conn = database_engine.module.connect(
|
db_conn = database_engine.module.connect(
|
||||||
**{
|
**{
|
||||||
|
|
|
@ -16,12 +16,9 @@
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
from synapse.config._base import ConfigError
|
from synapse.config._base import ConfigError
|
||||||
|
|
||||||
from synapse.python_dependencies import (
|
from synapse.python_dependencies import (
|
||||||
|
@ -35,18 +32,11 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d
|
||||||
|
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
|
||||||
from twisted.conch.manhole import ColoredManhole
|
|
||||||
from twisted.conch.insults import insults
|
|
||||||
from twisted.conch import manhole_ssh
|
|
||||||
from twisted.cred import checkers, portal
|
|
||||||
|
|
||||||
|
|
||||||
from twisted.internet import reactor, task, defer
|
from twisted.internet import reactor, task, defer
|
||||||
from twisted.application import service
|
from twisted.application import service
|
||||||
from twisted.web.resource import Resource, EncodingResourceWrapper
|
from twisted.web.resource import Resource, EncodingResourceWrapper
|
||||||
from twisted.web.static import File
|
from twisted.web.static import File
|
||||||
from twisted.web.server import Site, GzipEncoderFactory, Request
|
from twisted.web.server import GzipEncoderFactory
|
||||||
from synapse.http.server import RootRedirect
|
from synapse.http.server import RootRedirect
|
||||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||||
|
@ -66,6 +56,10 @@ from synapse.federation.transport.server import TransportLayerServer
|
||||||
|
|
||||||
from synapse.util.rlimit import change_resource_limit
|
from synapse.util.rlimit import change_resource_limit
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
|
from synapse.util.manhole import manhole
|
||||||
|
|
||||||
|
from synapse.http.site import SynapseSite
|
||||||
|
|
||||||
from synapse import events
|
from synapse import events
|
||||||
|
|
||||||
|
@ -74,9 +68,6 @@ from daemonize import Daemonize
|
||||||
logger = logging.getLogger("synapse.app.homeserver")
|
logger = logging.getLogger("synapse.app.homeserver")
|
||||||
|
|
||||||
|
|
||||||
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
|
|
||||||
|
|
||||||
|
|
||||||
def gz_wrap(r):
|
def gz_wrap(r):
|
||||||
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
|
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
|
||||||
|
|
||||||
|
@ -174,7 +165,12 @@ class SynapseHomeServer(HomeServer):
|
||||||
if name == "replication":
|
if name == "replication":
|
||||||
resources[REPLICATION_PREFIX] = ReplicationResource(self)
|
resources[REPLICATION_PREFIX] = ReplicationResource(self)
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources)
|
if WEB_CLIENT_PREFIX in resources:
|
||||||
|
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
|
||||||
|
else:
|
||||||
|
root_resource = Resource()
|
||||||
|
|
||||||
|
root_resource = create_resource_tree(resources, root_resource)
|
||||||
if tls:
|
if tls:
|
||||||
reactor.listenSSL(
|
reactor.listenSSL(
|
||||||
port,
|
port,
|
||||||
|
@ -207,24 +203,13 @@ class SynapseHomeServer(HomeServer):
|
||||||
if listener["type"] == "http":
|
if listener["type"] == "http":
|
||||||
self._listener_http(config, listener)
|
self._listener_http(config, listener)
|
||||||
elif listener["type"] == "manhole":
|
elif listener["type"] == "manhole":
|
||||||
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
|
|
||||||
matrix="rabbithole"
|
|
||||||
)
|
|
||||||
|
|
||||||
rlm = manhole_ssh.TerminalRealm()
|
|
||||||
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
|
|
||||||
ColoredManhole,
|
|
||||||
{
|
|
||||||
"__name__": "__console__",
|
|
||||||
"hs": self,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
f = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
|
|
||||||
|
|
||||||
reactor.listenTCP(
|
reactor.listenTCP(
|
||||||
listener["port"],
|
listener["port"],
|
||||||
f,
|
manhole(
|
||||||
|
username="matrix",
|
||||||
|
password="rabbithole",
|
||||||
|
globals={"hs": self},
|
||||||
|
),
|
||||||
interface=listener.get("bind_address", '127.0.0.1')
|
interface=listener.get("bind_address", '127.0.0.1')
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -371,210 +356,6 @@ class SynapseService(service.Service):
|
||||||
return self._port.stopListening()
|
return self._port.stopListening()
|
||||||
|
|
||||||
|
|
||||||
class SynapseRequest(Request):
|
|
||||||
def __init__(self, site, *args, **kw):
|
|
||||||
Request.__init__(self, *args, **kw)
|
|
||||||
self.site = site
|
|
||||||
self.authenticated_entity = None
|
|
||||||
self.start_time = 0
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
# We overwrite this so that we don't log ``access_token``
|
|
||||||
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
|
|
||||||
self.__class__.__name__,
|
|
||||||
id(self),
|
|
||||||
self.method,
|
|
||||||
self.get_redacted_uri(),
|
|
||||||
self.clientproto,
|
|
||||||
self.site.site_tag,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_redacted_uri(self):
|
|
||||||
return ACCESS_TOKEN_RE.sub(
|
|
||||||
r'\1<redacted>\3',
|
|
||||||
self.uri
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_user_agent(self):
|
|
||||||
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
|
|
||||||
|
|
||||||
def started_processing(self):
|
|
||||||
self.site.access_logger.info(
|
|
||||||
"%s - %s - Received request: %s %s",
|
|
||||||
self.getClientIP(),
|
|
||||||
self.site.site_tag,
|
|
||||||
self.method,
|
|
||||||
self.get_redacted_uri()
|
|
||||||
)
|
|
||||||
self.start_time = int(time.time() * 1000)
|
|
||||||
|
|
||||||
def finished_processing(self):
|
|
||||||
|
|
||||||
try:
|
|
||||||
context = LoggingContext.current_context()
|
|
||||||
ru_utime, ru_stime = context.get_resource_usage()
|
|
||||||
db_txn_count = context.db_txn_count
|
|
||||||
db_txn_duration = context.db_txn_duration
|
|
||||||
except:
|
|
||||||
ru_utime, ru_stime = (0, 0)
|
|
||||||
db_txn_count, db_txn_duration = (0, 0)
|
|
||||||
|
|
||||||
self.site.access_logger.info(
|
|
||||||
"%s - %s - {%s}"
|
|
||||||
" Processed request: %dms (%dms, %dms) (%dms/%d)"
|
|
||||||
" %sB %s \"%s %s %s\" \"%s\"",
|
|
||||||
self.getClientIP(),
|
|
||||||
self.site.site_tag,
|
|
||||||
self.authenticated_entity,
|
|
||||||
int(time.time() * 1000) - self.start_time,
|
|
||||||
int(ru_utime * 1000),
|
|
||||||
int(ru_stime * 1000),
|
|
||||||
int(db_txn_duration * 1000),
|
|
||||||
int(db_txn_count),
|
|
||||||
self.sentLength,
|
|
||||||
self.code,
|
|
||||||
self.method,
|
|
||||||
self.get_redacted_uri(),
|
|
||||||
self.clientproto,
|
|
||||||
self.get_user_agent(),
|
|
||||||
)
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def processing(self):
|
|
||||||
self.started_processing()
|
|
||||||
yield
|
|
||||||
self.finished_processing()
|
|
||||||
|
|
||||||
|
|
||||||
class XForwardedForRequest(SynapseRequest):
|
|
||||||
def __init__(self, *args, **kw):
|
|
||||||
SynapseRequest.__init__(self, *args, **kw)
|
|
||||||
|
|
||||||
"""
|
|
||||||
Add a layer on top of another request that only uses the value of an
|
|
||||||
X-Forwarded-For header as the result of C{getClientIP}.
|
|
||||||
"""
|
|
||||||
def getClientIP(self):
|
|
||||||
"""
|
|
||||||
@return: The client address (the first address) in the value of the
|
|
||||||
I{X-Forwarded-For header}. If the header is not present, return
|
|
||||||
C{b"-"}.
|
|
||||||
"""
|
|
||||||
return self.requestHeaders.getRawHeaders(
|
|
||||||
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
|
|
||||||
|
|
||||||
|
|
||||||
class SynapseRequestFactory(object):
|
|
||||||
def __init__(self, site, x_forwarded_for):
|
|
||||||
self.site = site
|
|
||||||
self.x_forwarded_for = x_forwarded_for
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
if self.x_forwarded_for:
|
|
||||||
return XForwardedForRequest(self.site, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
return SynapseRequest(self.site, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class SynapseSite(Site):
|
|
||||||
"""
|
|
||||||
Subclass of a twisted http Site that does access logging with python's
|
|
||||||
standard logging
|
|
||||||
"""
|
|
||||||
def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
|
|
||||||
Site.__init__(self, resource, *args, **kwargs)
|
|
||||||
|
|
||||||
self.site_tag = site_tag
|
|
||||||
|
|
||||||
proxied = config.get("x_forwarded", False)
|
|
||||||
self.requestFactory = SynapseRequestFactory(self, proxied)
|
|
||||||
self.access_logger = logging.getLogger(logger_name)
|
|
||||||
|
|
||||||
def log(self, request):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
|
|
||||||
"""Create the resource tree for this Home Server.
|
|
||||||
|
|
||||||
This in unduly complicated because Twisted does not support putting
|
|
||||||
child resources more than 1 level deep at a time.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
web_client (bool): True to enable the web client.
|
|
||||||
redirect_root_to_web_client (bool): True to redirect '/' to the
|
|
||||||
location of the web client. This does nothing if web_client is not
|
|
||||||
True.
|
|
||||||
"""
|
|
||||||
if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
|
|
||||||
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
|
|
||||||
else:
|
|
||||||
root_resource = Resource()
|
|
||||||
|
|
||||||
# ideally we'd just use getChild and putChild but getChild doesn't work
|
|
||||||
# unless you give it a Request object IN ADDITION to the name :/ So
|
|
||||||
# instead, we'll store a copy of this mapping so we can actually add
|
|
||||||
# extra resources to existing nodes. See self._resource_id for the key.
|
|
||||||
resource_mappings = {}
|
|
||||||
for full_path, res in desired_tree.items():
|
|
||||||
logger.info("Attaching %s to path %s", res, full_path)
|
|
||||||
last_resource = root_resource
|
|
||||||
for path_seg in full_path.split('/')[1:-1]:
|
|
||||||
if path_seg not in last_resource.listNames():
|
|
||||||
# resource doesn't exist, so make a "dummy resource"
|
|
||||||
child_resource = Resource()
|
|
||||||
last_resource.putChild(path_seg, child_resource)
|
|
||||||
res_id = _resource_id(last_resource, path_seg)
|
|
||||||
resource_mappings[res_id] = child_resource
|
|
||||||
last_resource = child_resource
|
|
||||||
else:
|
|
||||||
# we have an existing Resource, use that instead.
|
|
||||||
res_id = _resource_id(last_resource, path_seg)
|
|
||||||
last_resource = resource_mappings[res_id]
|
|
||||||
|
|
||||||
# ===========================
|
|
||||||
# now attach the actual desired resource
|
|
||||||
last_path_seg = full_path.split('/')[-1]
|
|
||||||
|
|
||||||
# if there is already a resource here, thieve its children and
|
|
||||||
# replace it
|
|
||||||
res_id = _resource_id(last_resource, last_path_seg)
|
|
||||||
if res_id in resource_mappings:
|
|
||||||
# there is a dummy resource at this path already, which needs
|
|
||||||
# to be replaced with the desired resource.
|
|
||||||
existing_dummy_resource = resource_mappings[res_id]
|
|
||||||
for child_name in existing_dummy_resource.listNames():
|
|
||||||
child_res_id = _resource_id(
|
|
||||||
existing_dummy_resource, child_name
|
|
||||||
)
|
|
||||||
child_resource = resource_mappings[child_res_id]
|
|
||||||
# steal the children
|
|
||||||
res.putChild(child_name, child_resource)
|
|
||||||
|
|
||||||
# finally, insert the desired resource in the right place
|
|
||||||
last_resource.putChild(last_path_seg, res)
|
|
||||||
res_id = _resource_id(last_resource, last_path_seg)
|
|
||||||
resource_mappings[res_id] = res
|
|
||||||
|
|
||||||
return root_resource
|
|
||||||
|
|
||||||
|
|
||||||
def _resource_id(resource, path_seg):
|
|
||||||
"""Construct an arbitrary resource ID so you can retrieve the mapping
|
|
||||||
later.
|
|
||||||
|
|
||||||
If you want to represent resource A putChild resource B with path C,
|
|
||||||
the mapping should looks like _resource_id(A,C) = B.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
resource (Resource): The *parent* Resourceb
|
|
||||||
path_seg (str): The name of the child Resource to be attached.
|
|
||||||
Returns:
|
|
||||||
str: A unique string which can be a key to the child Resource.
|
|
||||||
"""
|
|
||||||
return "%s-%s" % (resource, path_seg)
|
|
||||||
|
|
||||||
|
|
||||||
def run(hs):
|
def run(hs):
|
||||||
PROFILE_SYNAPSE = False
|
PROFILE_SYNAPSE = False
|
||||||
if PROFILE_SYNAPSE:
|
if PROFILE_SYNAPSE:
|
||||||
|
|
315
synapse/app/pusher.py
Normal file
315
synapse/app/pusher.py
Normal file
|
@ -0,0 +1,315 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.config._base import ConfigError
|
||||||
|
from synapse.config.database import DatabaseConfig
|
||||||
|
from synapse.config.logger import LoggingConfig
|
||||||
|
from synapse.http.site import SynapseSite
|
||||||
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
|
from synapse.replication.slave.storage.pushers import SlavedPusherStore
|
||||||
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
|
from synapse.storage.engines import create_engine
|
||||||
|
from synapse.storage import DataStore
|
||||||
|
from synapse.util.async import sleep
|
||||||
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
|
from synapse.util.logcontext import LoggingContext, preserve_fn
|
||||||
|
from synapse.util.manhole import manhole
|
||||||
|
from synapse.util.rlimit import change_resource_limit
|
||||||
|
from synapse.util.versionstring import get_version_string
|
||||||
|
|
||||||
|
from twisted.internet import reactor, defer
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
from daemonize import Daemonize
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger("synapse.app.pusher")
|
||||||
|
|
||||||
|
|
||||||
|
class SlaveConfig(DatabaseConfig):
|
||||||
|
def read_config(self, config):
|
||||||
|
self.replication_url = config["replication_url"]
|
||||||
|
self.server_name = config["server_name"]
|
||||||
|
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
|
||||||
|
"use_insecure_ssl_client_just_for_testing_do_not_use", False
|
||||||
|
)
|
||||||
|
self.user_agent_suffix = None
|
||||||
|
self.start_pushers = True
|
||||||
|
self.listeners = config["listeners"]
|
||||||
|
self.soft_file_limit = config.get("soft_file_limit")
|
||||||
|
self.daemonize = config.get("daemonize")
|
||||||
|
self.pid_file = self.abspath(config.get("pid_file"))
|
||||||
|
|
||||||
|
def default_config(self, server_name, **kwargs):
|
||||||
|
pid_file = self.abspath("pusher.pid")
|
||||||
|
return """\
|
||||||
|
# Slave configuration
|
||||||
|
|
||||||
|
# The replication listener on the synapse to talk to.
|
||||||
|
#replication_url: https://localhost:{replication_port}/_synapse/replication
|
||||||
|
|
||||||
|
server_name: "%(server_name)s"
|
||||||
|
|
||||||
|
listeners: []
|
||||||
|
# Enable a ssh manhole listener on the pusher.
|
||||||
|
# - type: manhole
|
||||||
|
# port: {manhole_port}
|
||||||
|
# bind_address: 127.0.0.1
|
||||||
|
# Enable a metric listener on the pusher.
|
||||||
|
# - type: http
|
||||||
|
# port: {metrics_port}
|
||||||
|
# bind_address: 127.0.0.1
|
||||||
|
# resources:
|
||||||
|
# - names: ["metrics"]
|
||||||
|
# compress: False
|
||||||
|
|
||||||
|
report_stats: False
|
||||||
|
|
||||||
|
daemonize: False
|
||||||
|
|
||||||
|
pid_file: %(pid_file)s
|
||||||
|
|
||||||
|
""" % locals()
|
||||||
|
|
||||||
|
|
||||||
|
class PusherSlaveConfig(SlaveConfig, LoggingConfig):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PusherSlaveStore(
|
||||||
|
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore
|
||||||
|
):
|
||||||
|
update_pusher_last_stream_ordering_and_success = (
|
||||||
|
DataStore.update_pusher_last_stream_ordering_and_success.__func__
|
||||||
|
)
|
||||||
|
|
||||||
|
update_pusher_failing_since = (
|
||||||
|
DataStore.update_pusher_failing_since.__func__
|
||||||
|
)
|
||||||
|
|
||||||
|
update_pusher_last_stream_ordering = (
|
||||||
|
DataStore.update_pusher_last_stream_ordering.__func__
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PusherServer(HomeServer):
|
||||||
|
|
||||||
|
def get_db_conn(self, run_new_connection=True):
|
||||||
|
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||||
|
# not be passed to the database engine.
|
||||||
|
db_params = {
|
||||||
|
k: v for k, v in self.db_config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = self.database_engine.module.connect(**db_params)
|
||||||
|
|
||||||
|
if run_new_connection:
|
||||||
|
self.database_engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
logger.info("Setting up.")
|
||||||
|
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
|
||||||
|
logger.info("Finished setting up.")
|
||||||
|
|
||||||
|
def remove_pusher(self, app_id, push_key, user_id):
|
||||||
|
http_client = self.get_simple_http_client()
|
||||||
|
replication_url = self.config.replication_url
|
||||||
|
url = replication_url + "/remove_pushers"
|
||||||
|
return http_client.post_json_get_json(url, {
|
||||||
|
"remove": [{
|
||||||
|
"app_id": app_id,
|
||||||
|
"push_key": push_key,
|
||||||
|
"user_id": user_id,
|
||||||
|
}]
|
||||||
|
})
|
||||||
|
|
||||||
|
def _listen_http(self, listener_config):
|
||||||
|
port = listener_config["port"]
|
||||||
|
bind_address = listener_config.get("bind_address", "")
|
||||||
|
site_tag = listener_config.get("tag", port)
|
||||||
|
resources = {}
|
||||||
|
for res in listener_config["resources"]:
|
||||||
|
for name in res["names"]:
|
||||||
|
if name == "metrics":
|
||||||
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
|
||||||
|
root_resource = create_resource_tree(resources, Resource())
|
||||||
|
reactor.listenTCP(
|
||||||
|
port,
|
||||||
|
SynapseSite(
|
||||||
|
"synapse.access.http.%s" % (site_tag,),
|
||||||
|
site_tag,
|
||||||
|
listener_config,
|
||||||
|
root_resource,
|
||||||
|
),
|
||||||
|
interface=bind_address
|
||||||
|
)
|
||||||
|
logger.info("Synapse pusher now listening on port %d", port)
|
||||||
|
|
||||||
|
def start_listening(self):
|
||||||
|
for listener in self.config.listeners:
|
||||||
|
if listener["type"] == "http":
|
||||||
|
self._listen_http(listener)
|
||||||
|
elif listener["type"] == "manhole":
|
||||||
|
reactor.listenTCP(
|
||||||
|
listener["port"],
|
||||||
|
manhole(
|
||||||
|
username="matrix",
|
||||||
|
password="rabbithole",
|
||||||
|
globals={"hs": self},
|
||||||
|
),
|
||||||
|
interface=listener.get("bind_address", '127.0.0.1')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def replicate(self):
|
||||||
|
http_client = self.get_simple_http_client()
|
||||||
|
store = self.get_datastore()
|
||||||
|
replication_url = self.config.replication_url
|
||||||
|
pusher_pool = self.get_pusherpool()
|
||||||
|
|
||||||
|
def stop_pusher(user_id, app_id, pushkey):
|
||||||
|
key = "%s:%s" % (app_id, pushkey)
|
||||||
|
pushers_for_user = pusher_pool.pushers.get(user_id, {})
|
||||||
|
pusher = pushers_for_user.pop(key, None)
|
||||||
|
if pusher is None:
|
||||||
|
return
|
||||||
|
logger.info("Stopping pusher %r / %r", user_id, key)
|
||||||
|
pusher.on_stop()
|
||||||
|
|
||||||
|
def start_pusher(user_id, app_id, pushkey):
|
||||||
|
key = "%s:%s" % (app_id, pushkey)
|
||||||
|
logger.info("Starting pusher %r / %r", user_id, key)
|
||||||
|
return pusher_pool._refresh_pusher(app_id, pushkey, user_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def poke_pushers(results):
|
||||||
|
pushers_rows = set(
|
||||||
|
map(tuple, results.get("pushers", {}).get("rows", []))
|
||||||
|
)
|
||||||
|
deleted_pushers_rows = set(
|
||||||
|
map(tuple, results.get("deleted_pushers", {}).get("rows", []))
|
||||||
|
)
|
||||||
|
for row in sorted(pushers_rows | deleted_pushers_rows):
|
||||||
|
if row in deleted_pushers_rows:
|
||||||
|
user_id, app_id, pushkey = row[1:4]
|
||||||
|
stop_pusher(user_id, app_id, pushkey)
|
||||||
|
elif row in pushers_rows:
|
||||||
|
user_id = row[1]
|
||||||
|
app_id = row[5]
|
||||||
|
pushkey = row[8]
|
||||||
|
yield start_pusher(user_id, app_id, pushkey)
|
||||||
|
|
||||||
|
stream = results.get("events")
|
||||||
|
if stream:
|
||||||
|
min_stream_id = stream["rows"][0][0]
|
||||||
|
max_stream_id = stream["position"]
|
||||||
|
preserve_fn(pusher_pool.on_new_notifications)(
|
||||||
|
min_stream_id, max_stream_id
|
||||||
|
)
|
||||||
|
|
||||||
|
stream = results.get("receipts")
|
||||||
|
if stream:
|
||||||
|
rows = stream["rows"]
|
||||||
|
affected_room_ids = set(row[1] for row in rows)
|
||||||
|
min_stream_id = rows[0][0]
|
||||||
|
max_stream_id = stream["position"]
|
||||||
|
preserve_fn(pusher_pool.on_new_receipts)(
|
||||||
|
min_stream_id, max_stream_id, affected_room_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
args = store.stream_positions()
|
||||||
|
args["timeout"] = 30000
|
||||||
|
result = yield http_client.get_json(replication_url, args=args)
|
||||||
|
yield store.process_replication(result)
|
||||||
|
poke_pushers(result)
|
||||||
|
except:
|
||||||
|
logger.exception("Error replicating from %r", replication_url)
|
||||||
|
sleep(30)
|
||||||
|
|
||||||
|
|
||||||
|
def setup(config_options):
|
||||||
|
try:
|
||||||
|
config = PusherSlaveConfig.load_config(
|
||||||
|
"Synapse pusher", config_options
|
||||||
|
)
|
||||||
|
except ConfigError as e:
|
||||||
|
sys.stderr.write("\n" + e.message + "\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
config.setup_logging()
|
||||||
|
|
||||||
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
|
ps = PusherServer(
|
||||||
|
config.server_name,
|
||||||
|
db_config=config.database_config,
|
||||||
|
config=config,
|
||||||
|
version_string=get_version_string("Synapse", synapse),
|
||||||
|
database_engine=database_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
ps.setup()
|
||||||
|
ps.start_listening()
|
||||||
|
|
||||||
|
change_resource_limit(ps.config.soft_file_limit)
|
||||||
|
|
||||||
|
def start():
|
||||||
|
ps.replicate()
|
||||||
|
ps.get_pusherpool().start()
|
||||||
|
ps.get_datastore().start_profiling()
|
||||||
|
|
||||||
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
|
return ps
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with LoggingContext("main"):
|
||||||
|
ps = setup(sys.argv[1:])
|
||||||
|
|
||||||
|
if ps.config.daemonize:
|
||||||
|
def run():
|
||||||
|
with LoggingContext("run"):
|
||||||
|
change_resource_limit(ps.config.soft_file_limit)
|
||||||
|
reactor.run()
|
||||||
|
|
||||||
|
daemon = Daemonize(
|
||||||
|
app="synapse-pusher",
|
||||||
|
pid=ps.config.pid_file,
|
||||||
|
action=run,
|
||||||
|
auto_close_fds=False,
|
||||||
|
verbose=True,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
daemon.start()
|
||||||
|
else:
|
||||||
|
reactor.run()
|
|
@ -13,7 +13,16 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import Config
|
from ._base import Config, ConfigError
|
||||||
|
|
||||||
|
|
||||||
|
MISSING_JWT = (
|
||||||
|
"""Missing jwt library. This is required for jwt login.
|
||||||
|
|
||||||
|
Install by running:
|
||||||
|
pip install pyjwt
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class JWTConfig(Config):
|
class JWTConfig(Config):
|
||||||
|
@ -23,6 +32,12 @@ class JWTConfig(Config):
|
||||||
self.jwt_enabled = jwt_config.get("enabled", False)
|
self.jwt_enabled = jwt_config.get("enabled", False)
|
||||||
self.jwt_secret = jwt_config["secret"]
|
self.jwt_secret = jwt_config["secret"]
|
||||||
self.jwt_algorithm = jwt_config["algorithm"]
|
self.jwt_algorithm = jwt_config["algorithm"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jwt
|
||||||
|
jwt # To stop unused lint.
|
||||||
|
except ImportError:
|
||||||
|
raise ConfigError(MISSING_JWT)
|
||||||
else:
|
else:
|
||||||
self.jwt_enabled = False
|
self.jwt_enabled = False
|
||||||
self.jwt_secret = None
|
self.jwt_secret = None
|
||||||
|
@ -30,6 +45,8 @@ class JWTConfig(Config):
|
||||||
|
|
||||||
def default_config(self, **kwargs):
|
def default_config(self, **kwargs):
|
||||||
return """\
|
return """\
|
||||||
|
# The JWT needs to contain a globally unique "sub" (subject) claim.
|
||||||
|
#
|
||||||
# jwt_config:
|
# jwt_config:
|
||||||
# enabled: true
|
# enabled: true
|
||||||
# secret: "a secret"
|
# secret: "a secret"
|
||||||
|
|
|
@ -33,6 +33,7 @@ class ServerConfig(Config):
|
||||||
if self.public_baseurl is not None:
|
if self.public_baseurl is not None:
|
||||||
if self.public_baseurl[-1] != '/':
|
if self.public_baseurl[-1] != '/':
|
||||||
self.public_baseurl += '/'
|
self.public_baseurl += '/'
|
||||||
|
self.start_pushers = config.get("start_pushers", True)
|
||||||
|
|
||||||
self.listeners = config.get("listeners", [])
|
self.listeners = config.get("listeners", [])
|
||||||
|
|
||||||
|
|
|
@ -232,7 +232,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
if old_membership == "ban" and action != "unban":
|
if old_membership == "ban" and action != "unban":
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403,
|
403,
|
||||||
"Cannot %s user who was is banned" % (action,),
|
"Cannot %s user who was banned" % (action,),
|
||||||
errcode=Codes.BAD_STATE
|
errcode=Codes.BAD_STATE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -462,5 +462,8 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
||||||
self._context = SSL.Context(SSL.SSLv23_METHOD)
|
self._context = SSL.Context(SSL.SSLv23_METHOD)
|
||||||
self._context.set_verify(VERIFY_NONE, lambda *_: None)
|
self._context.set_verify(VERIFY_NONE, lambda *_: None)
|
||||||
|
|
||||||
def getContext(self, hostname, port):
|
def getContext(self, hostname=None, port=None):
|
||||||
return self._context
|
return self._context
|
||||||
|
|
||||||
|
def creatorForNetloc(self, hostname, port):
|
||||||
|
return self
|
||||||
|
|
|
@ -74,7 +74,12 @@ response_db_txn_duration = metrics.register_distribution(
|
||||||
_next_request_id = 0
|
_next_request_id = 0
|
||||||
|
|
||||||
|
|
||||||
def request_handler(request_handler):
|
def request_handler(report_metrics=True):
|
||||||
|
"""Decorator for ``wrap_request_handler``"""
|
||||||
|
return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_request_handler(request_handler, report_metrics):
|
||||||
"""Wraps a method that acts as a request handler with the necessary logging
|
"""Wraps a method that acts as a request handler with the necessary logging
|
||||||
and exception handling.
|
and exception handling.
|
||||||
|
|
||||||
|
@ -96,7 +101,12 @@ def request_handler(request_handler):
|
||||||
global _next_request_id
|
global _next_request_id
|
||||||
request_id = "%s-%s" % (request.method, _next_request_id)
|
request_id = "%s-%s" % (request.method, _next_request_id)
|
||||||
_next_request_id += 1
|
_next_request_id += 1
|
||||||
|
|
||||||
with LoggingContext(request_id) as request_context:
|
with LoggingContext(request_id) as request_context:
|
||||||
|
if report_metrics:
|
||||||
|
request_metrics = RequestMetrics()
|
||||||
|
request_metrics.start(self.clock)
|
||||||
|
|
||||||
request_context.request = request_id
|
request_context.request = request_id
|
||||||
with request.processing():
|
with request.processing():
|
||||||
try:
|
try:
|
||||||
|
@ -133,6 +143,14 @@ def request_handler(request_handler):
|
||||||
},
|
},
|
||||||
send_cors=True
|
send_cors=True
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
if report_metrics:
|
||||||
|
request_metrics.stop(
|
||||||
|
self.clock, request, self.__class__.__name__
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
return wrapped_request_handler
|
return wrapped_request_handler
|
||||||
|
|
||||||
|
|
||||||
|
@ -197,19 +215,23 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
self._async_render(request)
|
self._async_render(request)
|
||||||
return server.NOT_DONE_YET
|
return server.NOT_DONE_YET
|
||||||
|
|
||||||
@request_handler
|
# Disable metric reporting because _async_render does its own metrics.
|
||||||
|
# It does its own metric reporting because _async_render dispatches to
|
||||||
|
# a callback and it's the class name of that callback we want to report
|
||||||
|
# against rather than the JsonResource itself.
|
||||||
|
@request_handler(report_metrics=False)
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render(self, request):
|
def _async_render(self, request):
|
||||||
""" This gets called from render() every time someone sends us a request.
|
""" This gets called from render() every time someone sends us a request.
|
||||||
This checks if anyone has registered a callback for that method and
|
This checks if anyone has registered a callback for that method and
|
||||||
path.
|
path.
|
||||||
"""
|
"""
|
||||||
start = self.clock.time_msec()
|
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
self._send_response(request, 200, {})
|
self._send_response(request, 200, {})
|
||||||
return
|
return
|
||||||
|
|
||||||
start_context = LoggingContext.current_context()
|
request_metrics = RequestMetrics()
|
||||||
|
request_metrics.start(self.clock)
|
||||||
|
|
||||||
# Loop through all the registered callbacks to check if the method
|
# Loop through all the registered callbacks to check if the method
|
||||||
# and path regex match
|
# and path regex match
|
||||||
|
@ -241,40 +263,7 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
self._send_response(request, code, response)
|
self._send_response(request, code, response)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context = LoggingContext.current_context()
|
request_metrics.stop(self.clock, request, servlet_classname)
|
||||||
|
|
||||||
tag = ""
|
|
||||||
if context:
|
|
||||||
tag = context.tag
|
|
||||||
|
|
||||||
if context != start_context:
|
|
||||||
logger.warn(
|
|
||||||
"Context have unexpectedly changed %r, %r",
|
|
||||||
context, self.start_context
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
incoming_requests_counter.inc(request.method, servlet_classname, tag)
|
|
||||||
|
|
||||||
response_timer.inc_by(
|
|
||||||
self.clock.time_msec() - start, request.method,
|
|
||||||
servlet_classname, tag
|
|
||||||
)
|
|
||||||
|
|
||||||
ru_utime, ru_stime = context.get_resource_usage()
|
|
||||||
|
|
||||||
response_ru_utime.inc_by(
|
|
||||||
ru_utime, request.method, servlet_classname, tag
|
|
||||||
)
|
|
||||||
response_ru_stime.inc_by(
|
|
||||||
ru_stime, request.method, servlet_classname, tag
|
|
||||||
)
|
|
||||||
response_db_txn_count.inc_by(
|
|
||||||
context.db_txn_count, request.method, servlet_classname, tag
|
|
||||||
)
|
|
||||||
response_db_txn_duration.inc_by(
|
|
||||||
context.db_txn_duration, request.method, servlet_classname, tag
|
|
||||||
)
|
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -307,6 +296,48 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RequestMetrics(object):
|
||||||
|
def start(self, clock):
|
||||||
|
self.start = clock.time_msec()
|
||||||
|
self.start_context = LoggingContext.current_context()
|
||||||
|
|
||||||
|
def stop(self, clock, request, servlet_classname):
|
||||||
|
context = LoggingContext.current_context()
|
||||||
|
|
||||||
|
tag = ""
|
||||||
|
if context:
|
||||||
|
tag = context.tag
|
||||||
|
|
||||||
|
if context != self.start_context:
|
||||||
|
logger.warn(
|
||||||
|
"Context have unexpectedly changed %r, %r",
|
||||||
|
context, self.start_context
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
incoming_requests_counter.inc(request.method, servlet_classname, tag)
|
||||||
|
|
||||||
|
response_timer.inc_by(
|
||||||
|
clock.time_msec() - self.start, request.method,
|
||||||
|
servlet_classname, tag
|
||||||
|
)
|
||||||
|
|
||||||
|
ru_utime, ru_stime = context.get_resource_usage()
|
||||||
|
|
||||||
|
response_ru_utime.inc_by(
|
||||||
|
ru_utime, request.method, servlet_classname, tag
|
||||||
|
)
|
||||||
|
response_ru_stime.inc_by(
|
||||||
|
ru_stime, request.method, servlet_classname, tag
|
||||||
|
)
|
||||||
|
response_db_txn_count.inc_by(
|
||||||
|
context.db_txn_count, request.method, servlet_classname, tag
|
||||||
|
)
|
||||||
|
response_db_txn_duration.inc_by(
|
||||||
|
context.db_txn_duration, request.method, servlet_classname, tag
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RootRedirect(resource.Resource):
|
class RootRedirect(resource.Resource):
|
||||||
"""Redirects the root '/' path to another path."""
|
"""Redirects the root '/' path to another path."""
|
||||||
|
|
||||||
|
|
146
synapse/http/site.py
Normal file
146
synapse/http/site.py
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
from twisted.web.server import Site, Request
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
|
||||||
|
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
|
||||||
|
|
||||||
|
|
||||||
|
class SynapseRequest(Request):
|
||||||
|
def __init__(self, site, *args, **kw):
|
||||||
|
Request.__init__(self, *args, **kw)
|
||||||
|
self.site = site
|
||||||
|
self.authenticated_entity = None
|
||||||
|
self.start_time = 0
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
# We overwrite this so that we don't log ``access_token``
|
||||||
|
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
|
||||||
|
self.__class__.__name__,
|
||||||
|
id(self),
|
||||||
|
self.method,
|
||||||
|
self.get_redacted_uri(),
|
||||||
|
self.clientproto,
|
||||||
|
self.site.site_tag,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_redacted_uri(self):
|
||||||
|
return ACCESS_TOKEN_RE.sub(
|
||||||
|
r'\1<redacted>\3',
|
||||||
|
self.uri
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_user_agent(self):
|
||||||
|
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
|
||||||
|
|
||||||
|
def started_processing(self):
|
||||||
|
self.site.access_logger.info(
|
||||||
|
"%s - %s - Received request: %s %s",
|
||||||
|
self.getClientIP(),
|
||||||
|
self.site.site_tag,
|
||||||
|
self.method,
|
||||||
|
self.get_redacted_uri()
|
||||||
|
)
|
||||||
|
self.start_time = int(time.time() * 1000)
|
||||||
|
|
||||||
|
def finished_processing(self):
|
||||||
|
|
||||||
|
try:
|
||||||
|
context = LoggingContext.current_context()
|
||||||
|
ru_utime, ru_stime = context.get_resource_usage()
|
||||||
|
db_txn_count = context.db_txn_count
|
||||||
|
db_txn_duration = context.db_txn_duration
|
||||||
|
except:
|
||||||
|
ru_utime, ru_stime = (0, 0)
|
||||||
|
db_txn_count, db_txn_duration = (0, 0)
|
||||||
|
|
||||||
|
self.site.access_logger.info(
|
||||||
|
"%s - %s - {%s}"
|
||||||
|
" Processed request: %dms (%dms, %dms) (%dms/%d)"
|
||||||
|
" %sB %s \"%s %s %s\" \"%s\"",
|
||||||
|
self.getClientIP(),
|
||||||
|
self.site.site_tag,
|
||||||
|
self.authenticated_entity,
|
||||||
|
int(time.time() * 1000) - self.start_time,
|
||||||
|
int(ru_utime * 1000),
|
||||||
|
int(ru_stime * 1000),
|
||||||
|
int(db_txn_duration * 1000),
|
||||||
|
int(db_txn_count),
|
||||||
|
self.sentLength,
|
||||||
|
self.code,
|
||||||
|
self.method,
|
||||||
|
self.get_redacted_uri(),
|
||||||
|
self.clientproto,
|
||||||
|
self.get_user_agent(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def processing(self):
|
||||||
|
self.started_processing()
|
||||||
|
yield
|
||||||
|
self.finished_processing()
|
||||||
|
|
||||||
|
|
||||||
|
class XForwardedForRequest(SynapseRequest):
|
||||||
|
def __init__(self, *args, **kw):
|
||||||
|
SynapseRequest.__init__(self, *args, **kw)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Add a layer on top of another request that only uses the value of an
|
||||||
|
X-Forwarded-For header as the result of C{getClientIP}.
|
||||||
|
"""
|
||||||
|
def getClientIP(self):
|
||||||
|
"""
|
||||||
|
@return: The client address (the first address) in the value of the
|
||||||
|
I{X-Forwarded-For header}. If the header is not present, return
|
||||||
|
C{b"-"}.
|
||||||
|
"""
|
||||||
|
return self.requestHeaders.getRawHeaders(
|
||||||
|
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
|
||||||
|
|
||||||
|
|
||||||
|
class SynapseRequestFactory(object):
|
||||||
|
def __init__(self, site, x_forwarded_for):
|
||||||
|
self.site = site
|
||||||
|
self.x_forwarded_for = x_forwarded_for
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
if self.x_forwarded_for:
|
||||||
|
return XForwardedForRequest(self.site, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
return SynapseRequest(self.site, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class SynapseSite(Site):
|
||||||
|
"""
|
||||||
|
Subclass of a twisted http Site that does access logging with python's
|
||||||
|
standard logging
|
||||||
|
"""
|
||||||
|
def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
|
||||||
|
Site.__init__(self, resource, *args, **kwargs)
|
||||||
|
|
||||||
|
self.site_tag = site_tag
|
||||||
|
|
||||||
|
proxied = config.get("x_forwarded", False)
|
||||||
|
self.requestFactory = SynapseRequestFactory(self, proxied)
|
||||||
|
self.access_logger = logging.getLogger(logger_name)
|
||||||
|
|
||||||
|
def log(self, request):
|
||||||
|
pass
|
|
@ -230,7 +230,7 @@ class HttpPusher(object):
|
||||||
"Pushkey %s was rejected: removing",
|
"Pushkey %s was rejected: removing",
|
||||||
pk
|
pk
|
||||||
)
|
)
|
||||||
yield self.hs.get_pusherpool().remove_pusher(
|
yield self.hs.remove_pusher(
|
||||||
self.app_id, pk, self.user_id
|
self.app_id, pk, self.user_id
|
||||||
)
|
)
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
|
|
@ -29,6 +29,7 @@ logger = logging.getLogger(__name__)
|
||||||
class PusherPool:
|
class PusherPool:
|
||||||
def __init__(self, _hs):
|
def __init__(self, _hs):
|
||||||
self.hs = _hs
|
self.hs = _hs
|
||||||
|
self.start_pushers = _hs.config.start_pushers
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
self.pushers = {}
|
self.pushers = {}
|
||||||
|
@ -177,6 +178,9 @@ class PusherPool:
|
||||||
self._start_pushers([p])
|
self._start_pushers([p])
|
||||||
|
|
||||||
def _start_pushers(self, pushers):
|
def _start_pushers(self, pushers):
|
||||||
|
if not self.start_pushers:
|
||||||
|
logger.info("Not starting pushers because they are disabled in the config")
|
||||||
|
return
|
||||||
logger.info("Starting %d pushers", len(pushers))
|
logger.info("Starting %d pushers", len(pushers))
|
||||||
for pusherdict in pushers:
|
for pusherdict in pushers:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -36,7 +36,6 @@ REQUIREMENTS = {
|
||||||
"blist": ["blist"],
|
"blist": ["blist"],
|
||||||
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
|
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
|
||||||
"pymacaroons-pynacl": ["pymacaroons"],
|
"pymacaroons-pynacl": ["pymacaroons"],
|
||||||
"pyjwt": ["jwt"],
|
|
||||||
}
|
}
|
||||||
CONDITIONAL_REQUIREMENTS = {
|
CONDITIONAL_REQUIREMENTS = {
|
||||||
"web_client": {
|
"web_client": {
|
||||||
|
|
54
synapse/replication/pusher_resource.py
Normal file
54
synapse/replication/pusher_resource.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.http.server import respond_with_json_bytes, request_handler
|
||||||
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
from twisted.web.server import NOT_DONE_YET
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
|
||||||
|
class PusherResource(Resource):
|
||||||
|
"""
|
||||||
|
HTTP endpoint for deleting rejected pushers
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
Resource.__init__(self) # Resource is old-style, so no super()
|
||||||
|
|
||||||
|
self.version_string = hs.version_string
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.notifier = hs.get_notifier()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
def render_POST(self, request):
|
||||||
|
self._async_render_POST(request)
|
||||||
|
return NOT_DONE_YET
|
||||||
|
|
||||||
|
@request_handler()
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _async_render_POST(self, request):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
for remove in content["remove"]:
|
||||||
|
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||||
|
remove["app_id"],
|
||||||
|
remove["push_key"],
|
||||||
|
remove["user_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.notifier.on_new_replication_data()
|
||||||
|
|
||||||
|
respond_with_json_bytes(request, 200, "{}")
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
from synapse.http.servlet import parse_integer, parse_string
|
from synapse.http.servlet import parse_integer, parse_string
|
||||||
from synapse.http.server import request_handler, finish_request
|
from synapse.http.server import request_handler, finish_request
|
||||||
|
from synapse.replication.pusher_resource import PusherResource
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from twisted.web.server import NOT_DONE_YET
|
from twisted.web.server import NOT_DONE_YET
|
||||||
|
@ -102,8 +103,6 @@ class ReplicationResource(Resource):
|
||||||
long-polling this replication API for new data on those streams.
|
long-polling this replication API for new data on those streams.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
isLeaf = True
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
Resource.__init__(self) # Resource is old-style, so no super()
|
Resource.__init__(self) # Resource is old-style, so no super()
|
||||||
|
|
||||||
|
@ -113,6 +112,9 @@ class ReplicationResource(Resource):
|
||||||
self.presence_handler = hs.get_handlers().presence_handler
|
self.presence_handler = hs.get_handlers().presence_handler
|
||||||
self.typing_handler = hs.get_handlers().typing_notification_handler
|
self.typing_handler = hs.get_handlers().typing_notification_handler
|
||||||
self.notifier = hs.notifier
|
self.notifier = hs.notifier
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
self.putChild("remove_pushers", PusherResource(hs))
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
self._async_render_GET(request)
|
self._async_render_GET(request)
|
||||||
|
@ -138,7 +140,7 @@ class ReplicationResource(Resource):
|
||||||
state_token,
|
state_token,
|
||||||
))
|
))
|
||||||
|
|
||||||
@request_handler
|
@request_handler()
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render_GET(self, request):
|
def _async_render_GET(self, request):
|
||||||
limit = parse_integer(request, "limit", 100)
|
limit = parse_integer(request, "limit", 100)
|
||||||
|
@ -343,7 +345,7 @@ class ReplicationResource(Resource):
|
||||||
"app_id", "app_display_name", "device_display_name", "pushkey",
|
"app_id", "app_display_name", "device_display_name", "pushkey",
|
||||||
"ts", "lang", "data"
|
"ts", "lang", "data"
|
||||||
))
|
))
|
||||||
writer.write_header_and_rows("deleted", deleted, (
|
writer.write_header_and_rows("deleted_pushers", deleted, (
|
||||||
"position", "user_id", "app_id", "pushkey"
|
"position", "user_id", "app_id", "pushkey"
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@ -381,7 +383,7 @@ class _Writer(object):
|
||||||
position = rows[-1][0]
|
position = rows[-1][0]
|
||||||
|
|
||||||
self.streams[name] = {
|
self.streams[name] = {
|
||||||
"position": str(position),
|
"position": position if type(position) is int else str(position),
|
||||||
"field_names": fields,
|
"field_names": fields,
|
||||||
"rows": rows,
|
"rows": rows,
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ from synapse.storage import DataStore
|
||||||
from synapse.storage.room import RoomStore
|
from synapse.storage.room import RoomStore
|
||||||
from synapse.storage.roommember import RoomMemberStore
|
from synapse.storage.roommember import RoomMemberStore
|
||||||
from synapse.storage.event_federation import EventFederationStore
|
from synapse.storage.event_federation import EventFederationStore
|
||||||
|
from synapse.storage.event_push_actions import EventPushActionsStore
|
||||||
from synapse.storage.state import StateStore
|
from synapse.storage.state import StateStore
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
|
@ -68,7 +69,19 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
_get_current_state_for_key = StateStore.__dict__[
|
_get_current_state_for_key = StateStore.__dict__[
|
||||||
"_get_current_state_for_key"
|
"_get_current_state_for_key"
|
||||||
]
|
]
|
||||||
|
get_invited_rooms_for_user = RoomMemberStore.__dict__[
|
||||||
|
"get_invited_rooms_for_user"
|
||||||
|
]
|
||||||
|
get_unread_event_push_actions_by_room_for_user = (
|
||||||
|
EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
|
||||||
|
)
|
||||||
|
|
||||||
|
get_unread_push_actions_for_user_in_range = (
|
||||||
|
DataStore.get_unread_push_actions_for_user_in_range.__func__
|
||||||
|
)
|
||||||
|
get_push_action_users_in_range = (
|
||||||
|
DataStore.get_push_action_users_in_range.__func__
|
||||||
|
)
|
||||||
get_event = DataStore.get_event.__func__
|
get_event = DataStore.get_event.__func__
|
||||||
get_current_state = DataStore.get_current_state.__func__
|
get_current_state = DataStore.get_current_state.__func__
|
||||||
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
|
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
|
||||||
|
@ -82,6 +95,7 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
get_room_events_stream_for_room = (
|
get_room_events_stream_for_room = (
|
||||||
DataStore.get_room_events_stream_for_room.__func__
|
DataStore.get_room_events_stream_for_room.__func__
|
||||||
)
|
)
|
||||||
|
|
||||||
_set_before_and_after = DataStore._set_before_and_after
|
_set_before_and_after = DataStore._set_before_and_after
|
||||||
|
|
||||||
_get_events = DataStore._get_events.__func__
|
_get_events = DataStore._get_events.__func__
|
||||||
|
@ -104,7 +118,7 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedEventStore, self).stream_positions()
|
result = super(SlavedEventStore, self).stream_positions()
|
||||||
result["events"] = self._stream_id_gen.get_current_token()
|
result["events"] = self._stream_id_gen.get_current_token()
|
||||||
result["backfill"] = self._backfill_id_gen.get_current_token()
|
result["backfill"] = -self._backfill_id_gen.get_current_token()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication(self, result):
|
||||||
|
@ -122,7 +136,7 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
|
|
||||||
stream = result.get("backfill")
|
stream = result.get("backfill")
|
||||||
if stream:
|
if stream:
|
||||||
self._backfill_id_gen.advance(stream["position"])
|
self._backfill_id_gen.advance(-stream["position"])
|
||||||
for row in stream["rows"]:
|
for row in stream["rows"]:
|
||||||
self._process_replication_row(
|
self._process_replication_row(
|
||||||
row, backfilled=True, state_resets=state_resets
|
row, backfilled=True, state_resets=state_resets
|
||||||
|
@ -147,11 +161,11 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
internal = json.loads(row[1])
|
internal = json.loads(row[1])
|
||||||
event_json = json.loads(row[2])
|
event_json = json.loads(row[2])
|
||||||
event = FrozenEvent(event_json, internal_metadata_dict=internal)
|
event = FrozenEvent(event_json, internal_metadata_dict=internal)
|
||||||
self._invalidate_caches_for_event(
|
self.invalidate_caches_for_event(
|
||||||
event, backfilled, reset_state=position in state_resets
|
event, backfilled, reset_state=position in state_resets
|
||||||
)
|
)
|
||||||
|
|
||||||
def _invalidate_caches_for_event(self, event, backfilled, reset_state):
|
def invalidate_caches_for_event(self, event, backfilled, reset_state):
|
||||||
if reset_state:
|
if reset_state:
|
||||||
self._get_current_state_for_key.invalidate_all()
|
self._get_current_state_for_key.invalidate_all()
|
||||||
self.get_rooms_for_user.invalidate_all()
|
self.get_rooms_for_user.invalidate_all()
|
||||||
|
@ -163,6 +177,10 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
|
|
||||||
self.get_latest_event_ids_in_room.invalidate((event.room_id,))
|
self.get_latest_event_ids_in_room.invalidate((event.room_id,))
|
||||||
|
|
||||||
|
self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
|
||||||
|
(event.room_id,)
|
||||||
|
)
|
||||||
|
|
||||||
if not backfilled:
|
if not backfilled:
|
||||||
self._events_stream_cache.entity_has_changed(
|
self._events_stream_cache.entity_has_changed(
|
||||||
event.room_id, event.internal_metadata.stream_ordering
|
event.room_id, event.internal_metadata.stream_ordering
|
||||||
|
@ -182,6 +200,7 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
# self._membership_stream_cache.entity_has_changed(
|
# self._membership_stream_cache.entity_has_changed(
|
||||||
# event.state_key, event.internal_metadata.stream_ordering
|
# event.state_key, event.internal_metadata.stream_ordering
|
||||||
# )
|
# )
|
||||||
|
self.get_invited_rooms_for_user.invalidate((event.state_key,))
|
||||||
|
|
||||||
if not event.is_state():
|
if not event.is_state():
|
||||||
return
|
return
|
||||||
|
|
52
synapse/replication/slave/storage/pushers.py
Normal file
52
synapse/replication/slave/storage/pushers.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import BaseSlavedStore
|
||||||
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
|
from synapse.storage import DataStore
|
||||||
|
|
||||||
|
|
||||||
|
class SlavedPusherStore(BaseSlavedStore):
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(SlavedPusherStore, self).__init__(db_conn, hs)
|
||||||
|
self._pushers_id_gen = SlavedIdTracker(
|
||||||
|
db_conn, "pushers", "id",
|
||||||
|
extra_tables=[("deleted_pushers", "stream_id")],
|
||||||
|
)
|
||||||
|
|
||||||
|
get_all_pushers = DataStore.get_all_pushers.__func__
|
||||||
|
get_pushers_by = DataStore.get_pushers_by.__func__
|
||||||
|
get_pushers_by_app_id_and_pushkey = (
|
||||||
|
DataStore.get_pushers_by_app_id_and_pushkey.__func__
|
||||||
|
)
|
||||||
|
_decode_pushers_rows = DataStore._decode_pushers_rows.__func__
|
||||||
|
|
||||||
|
def stream_positions(self):
|
||||||
|
result = super(SlavedPusherStore, self).stream_positions()
|
||||||
|
result["pushers"] = self._pushers_id_gen.get_current_token()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def process_replication(self, result):
|
||||||
|
stream = result.get("pushers")
|
||||||
|
if stream:
|
||||||
|
self._pushers_id_gen.advance(stream["position"])
|
||||||
|
|
||||||
|
stream = result.get("deleted_pushers")
|
||||||
|
if stream:
|
||||||
|
self._pushers_id_gen.advance(stream["position"])
|
||||||
|
|
||||||
|
return super(SlavedPusherStore, self).process_replication(result)
|
61
synapse/replication/slave/storage/receipts.py
Normal file
61
synapse/replication/slave/storage/receipts.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import BaseSlavedStore
|
||||||
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
|
from synapse.storage import DataStore
|
||||||
|
from synapse.storage.receipts import ReceiptsStore
|
||||||
|
|
||||||
|
# So, um, we want to borrow a load of functions intended for reading from
|
||||||
|
# a DataStore, but we don't want to take functions that either write to the
|
||||||
|
# DataStore or are cached and don't have cache invalidation logic.
|
||||||
|
#
|
||||||
|
# Rather than write duplicate versions of those functions, or lift them to
|
||||||
|
# a common base class, we going to grab the underlying __func__ object from
|
||||||
|
# the method descriptor on the DataStore and chuck them into our class.
|
||||||
|
|
||||||
|
|
||||||
|
class SlavedReceiptsStore(BaseSlavedStore):
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(SlavedReceiptsStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
self._receipts_id_gen = SlavedIdTracker(
|
||||||
|
db_conn, "receipts_linearized", "stream_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
|
||||||
|
|
||||||
|
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
|
||||||
|
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
|
||||||
|
|
||||||
|
def stream_positions(self):
|
||||||
|
result = super(SlavedReceiptsStore, self).stream_positions()
|
||||||
|
result["receipts"] = self._receipts_id_gen.get_current_token()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def process_replication(self, result):
|
||||||
|
stream = result.get("receipts")
|
||||||
|
if stream:
|
||||||
|
self._receipts_id_gen.advance(stream["position"])
|
||||||
|
for row in stream["rows"]:
|
||||||
|
room_id, receipt_type, user_id = row[1:4]
|
||||||
|
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
|
||||||
|
|
||||||
|
return super(SlavedReceiptsStore, self).process_replication(result)
|
||||||
|
|
||||||
|
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
||||||
|
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
|
@ -33,9 +33,6 @@ from saml2.client import Saml2Client
|
||||||
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
import jwt
|
|
||||||
from jwt.exceptions import InvalidTokenError
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -224,16 +221,24 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_jwt_login(self, login_submission):
|
def do_jwt_login(self, login_submission):
|
||||||
token = login_submission['token']
|
token = login_submission.get("token", None)
|
||||||
if token is None:
|
if token is None:
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(
|
||||||
|
401, "Token field for JWT is missing",
|
||||||
|
errcode=Codes.UNAUTHORIZED
|
||||||
|
)
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
from jwt.exceptions import InvalidTokenError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
|
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
|
||||||
except InvalidTokenError:
|
except InvalidTokenError:
|
||||||
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
user = payload['user']
|
user = payload.get("sub", None)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,6 @@ class LocalKey(Resource):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.hs = hs
|
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
self.response_body = encode_canonical_json(
|
self.response_body = encode_canonical_json(
|
||||||
self.response_json_object(hs.config)
|
self.response_json_object(hs.config)
|
||||||
|
|
|
@ -97,7 +97,7 @@ class RemoteKey(Resource):
|
||||||
self.async_render_GET(request)
|
self.async_render_GET(request)
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
|
||||||
@request_handler
|
@request_handler()
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def async_render_GET(self, request):
|
def async_render_GET(self, request):
|
||||||
if len(request.postpath) == 1:
|
if len(request.postpath) == 1:
|
||||||
|
@ -122,7 +122,7 @@ class RemoteKey(Resource):
|
||||||
self.async_render_POST(request)
|
self.async_render_POST(request)
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
|
||||||
@request_handler
|
@request_handler()
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def async_render_POST(self, request):
|
def async_render_POST(self, request):
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
|
@ -36,12 +36,13 @@ class DownloadResource(Resource):
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
self._async_render_GET(request)
|
self._async_render_GET(request)
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
|
||||||
@request_handler
|
@request_handler()
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render_GET(self, request):
|
def _async_render_GET(self, request):
|
||||||
server_name, media_id, name = parse_media_id(request)
|
server_name, media_id, name = parse_media_id(request)
|
||||||
|
|
|
@ -45,7 +45,17 @@ class PreviewUrlResource(Resource):
|
||||||
|
|
||||||
def __init__(self, hs, media_repo):
|
def __init__(self, hs, media_repo):
|
||||||
Resource.__init__(self)
|
Resource.__init__(self)
|
||||||
|
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.version_string = hs.version_string
|
||||||
|
self.filepaths = media_repo.filepaths
|
||||||
|
self.max_spider_size = hs.config.max_spider_size
|
||||||
|
self.server_name = hs.hostname
|
||||||
|
self.store = hs.get_datastore()
|
||||||
self.client = SpiderHttpClient(hs)
|
self.client = SpiderHttpClient(hs)
|
||||||
|
self.media_repo = media_repo
|
||||||
|
|
||||||
if hasattr(hs.config, "url_preview_url_blacklist"):
|
if hasattr(hs.config, "url_preview_url_blacklist"):
|
||||||
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
||||||
|
|
||||||
|
@ -60,18 +70,11 @@ class PreviewUrlResource(Resource):
|
||||||
|
|
||||||
self.downloads = {}
|
self.downloads = {}
|
||||||
|
|
||||||
self.auth = hs.get_auth()
|
|
||||||
self.clock = hs.get_clock()
|
|
||||||
self.version_string = hs.version_string
|
|
||||||
self.filepaths = media_repo.filepaths
|
|
||||||
self.max_spider_size = hs.config.max_spider_size
|
|
||||||
self.server_name = hs.hostname
|
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
self._async_render_GET(request)
|
self._async_render_GET(request)
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
|
||||||
@request_handler
|
@request_handler()
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render_GET(self, request):
|
def _async_render_GET(self, request):
|
||||||
|
|
||||||
|
@ -368,7 +371,7 @@ class PreviewUrlResource(Resource):
|
||||||
file_id = random_string(24)
|
file_id = random_string(24)
|
||||||
|
|
||||||
fname = self.filepaths.local_media_filepath(file_id)
|
fname = self.filepaths.local_media_filepath(file_id)
|
||||||
self._makedirs(fname)
|
self.media_repo._makedirs(fname)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(fname, "wb") as f:
|
with open(fname, "wb") as f:
|
||||||
|
|
|
@ -39,12 +39,13 @@ class ThumbnailResource(Resource):
|
||||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
self._async_render_GET(request)
|
self._async_render_GET(request)
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
|
||||||
@request_handler
|
@request_handler()
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render_GET(self, request):
|
def _async_render_GET(self, request):
|
||||||
server_name, media_id, _ = parse_media_id(request)
|
server_name, media_id, _ = parse_media_id(request)
|
||||||
|
|
|
@ -41,6 +41,7 @@ class UploadResource(Resource):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.max_upload_size = hs.config.max_upload_size
|
self.max_upload_size = hs.config.max_upload_size
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
def render_POST(self, request):
|
def render_POST(self, request):
|
||||||
self._async_render_POST(request)
|
self._async_render_POST(request)
|
||||||
|
@ -50,7 +51,7 @@ class UploadResource(Resource):
|
||||||
respond_with_json(request, 200, {}, send_cors=True)
|
respond_with_json(request, 200, {}, send_cors=True)
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
|
||||||
@request_handler
|
@request_handler()
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render_POST(self, request):
|
def _async_render_POST(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
|
@ -193,6 +193,9 @@ class HomeServer(object):
|
||||||
**self.db_config.get("args", {})
|
**self.db_config.get("args", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def remove_pusher(self, app_id, push_key, user_id):
|
||||||
|
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||||
|
|
||||||
|
|
||||||
def _make_dependency_method(depname):
|
def _make_dependency_method(depname):
|
||||||
def _get(hs):
|
def _get(hs):
|
||||||
|
|
|
@ -214,7 +214,7 @@ class StateHandler(object):
|
||||||
|
|
||||||
if self._state_cache is not None:
|
if self._state_cache is not None:
|
||||||
cache = self._state_cache.get(group_names, None)
|
cache = self._state_cache.get(group_names, None)
|
||||||
if cache and cache.state_group:
|
if cache:
|
||||||
cache.ts = self.clock.time_msec()
|
cache.ts = self.clock.time_msec()
|
||||||
|
|
||||||
event_dict = yield self.store.get_events(cache.state.values())
|
event_dict = yield self.store.get_events(cache.state.values())
|
||||||
|
@ -230,22 +230,34 @@ class StateHandler(object):
|
||||||
(cache.state_group, state, prev_states)
|
(cache.state_group, state, prev_states)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info("Resolving state for %s with %d groups", room_id, len(state_groups))
|
||||||
|
|
||||||
new_state, prev_states = self._resolve_events(
|
new_state, prev_states = self._resolve_events(
|
||||||
state_groups.values(), event_type, state_key
|
state_groups.values(), event_type, state_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
state_group = None
|
||||||
|
new_state_event_ids = frozenset(e.event_id for e in new_state.values())
|
||||||
|
for sg, events in state_groups.items():
|
||||||
|
if new_state_event_ids == frozenset(e.event_id for e in events):
|
||||||
|
state_group = sg
|
||||||
|
break
|
||||||
|
|
||||||
if self._state_cache is not None:
|
if self._state_cache is not None:
|
||||||
cache = _StateCacheEntry(
|
cache = _StateCacheEntry(
|
||||||
state={key: event.event_id for key, event in new_state.items()},
|
state={key: event.event_id for key, event in new_state.items()},
|
||||||
state_group=None,
|
state_group=state_group,
|
||||||
ts=self.clock.time_msec()
|
ts=self.clock.time_msec()
|
||||||
)
|
)
|
||||||
|
|
||||||
self._state_cache[group_names] = cache
|
self._state_cache[group_names] = cache
|
||||||
|
|
||||||
defer.returnValue((None, new_state, prev_states))
|
defer.returnValue((state_group, new_state, prev_states))
|
||||||
|
|
||||||
def resolve_events(self, state_sets, event):
|
def resolve_events(self, state_sets, event):
|
||||||
|
logger.info(
|
||||||
|
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
|
||||||
|
)
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
return self._resolve_events(
|
return self._resolve_events(
|
||||||
state_sets, event.type, event.state_key
|
state_sets, event.type, event.state_key
|
||||||
|
|
|
@ -173,11 +173,12 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Updating %r. Updated %r items in %rms."
|
"Updating %r. Updated %r items in %rms."
|
||||||
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r)",
|
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
|
||||||
update_name, items_updated, duration_ms,
|
update_name, items_updated, duration_ms,
|
||||||
performance.total_items_per_ms(),
|
performance.total_items_per_ms(),
|
||||||
performance.average_items_per_ms(),
|
performance.average_items_per_ms(),
|
||||||
performance.total_item_count,
|
performance.total_item_count,
|
||||||
|
batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
performance.update(items_updated, duration_ms)
|
performance.update(items_updated, duration_ms)
|
||||||
|
|
|
@ -1145,6 +1145,12 @@ class EventsStore(SQLBaseStore):
|
||||||
current_backfill_id, current_forward_id, limit):
|
current_backfill_id, current_forward_id, limit):
|
||||||
"""Get all the new events that have arrived at the server either as
|
"""Get all the new events that have arrived at the server either as
|
||||||
new events or as backfilled events"""
|
new events or as backfilled events"""
|
||||||
|
have_backfill_events = last_backfill_id != current_backfill_id
|
||||||
|
have_forward_events = last_forward_id != current_forward_id
|
||||||
|
|
||||||
|
if not have_backfill_events and not have_forward_events:
|
||||||
|
return defer.succeed(AllNewEventsResult([], [], [], [], []))
|
||||||
|
|
||||||
def get_all_new_events_txn(txn):
|
def get_all_new_events_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
|
"SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
|
||||||
|
@ -1157,7 +1163,7 @@ class EventsStore(SQLBaseStore):
|
||||||
" ORDER BY e.stream_ordering ASC"
|
" ORDER BY e.stream_ordering ASC"
|
||||||
" LIMIT ?"
|
" LIMIT ?"
|
||||||
)
|
)
|
||||||
if last_forward_id != current_forward_id:
|
if have_forward_events:
|
||||||
txn.execute(sql, (last_forward_id, current_forward_id, limit))
|
txn.execute(sql, (last_forward_id, current_forward_id, limit))
|
||||||
new_forward_events = txn.fetchall()
|
new_forward_events = txn.fetchall()
|
||||||
|
|
||||||
|
@ -1201,7 +1207,7 @@ class EventsStore(SQLBaseStore):
|
||||||
" ORDER BY e.stream_ordering DESC"
|
" ORDER BY e.stream_ordering DESC"
|
||||||
" LIMIT ?"
|
" LIMIT ?"
|
||||||
)
|
)
|
||||||
if last_backfill_id != current_backfill_id:
|
if have_backfill_events:
|
||||||
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
|
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
|
||||||
new_backfill_events = txn.fetchall()
|
new_backfill_events = txn.fetchall()
|
||||||
|
|
||||||
|
|
|
@ -106,6 +106,9 @@ class PusherStore(SQLBaseStore):
|
||||||
return self._pushers_id_gen.get_current_token()
|
return self._pushers_id_gen.get_current_token()
|
||||||
|
|
||||||
def get_all_updated_pushers(self, last_id, current_id, limit):
|
def get_all_updated_pushers(self, last_id, current_id, limit):
|
||||||
|
if last_id == current_id:
|
||||||
|
return defer.succeed(([], []))
|
||||||
|
|
||||||
def get_all_updated_pushers_txn(txn):
|
def get_all_updated_pushers_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT id, user_name, access_token, profile_tag, kind,"
|
"SELECT id, user_name, access_token, profile_tag, kind,"
|
||||||
|
|
|
@ -391,6 +391,9 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_all_updated_receipts(self, last_id, current_id, limit=None):
|
def get_all_updated_receipts(self, last_id, current_id, limit=None):
|
||||||
|
if last_id == current_id:
|
||||||
|
return defer.succeed([])
|
||||||
|
|
||||||
def get_all_updated_receipts_txn(txn):
|
def get_all_updated_receipts_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
|
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
|
||||||
|
|
|
@ -169,20 +169,28 @@ class RoomStore(SQLBaseStore):
|
||||||
def _store_event_search_txn(self, txn, event, key, value):
|
def _store_event_search_txn(self, txn, event, key, value):
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
sql = (
|
sql = (
|
||||||
"INSERT INTO event_search (event_id, room_id, key, vector)"
|
"INSERT INTO event_search"
|
||||||
" VALUES (?,?,?,to_tsvector('english', ?))"
|
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
|
||||||
|
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
|
||||||
|
)
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(
|
||||||
|
event.event_id, event.room_id, key, value,
|
||||||
|
event.internal_metadata.stream_ordering,
|
||||||
|
event.origin_server_ts,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
sql = (
|
sql = (
|
||||||
"INSERT INTO event_search (event_id, room_id, key, value)"
|
"INSERT INTO event_search (event_id, room_id, key, value)"
|
||||||
" VALUES (?,?,?,?)"
|
" VALUES (?,?,?,?)"
|
||||||
)
|
)
|
||||||
|
txn.execute(sql, (event.event_id, event.room_id, key, value,))
|
||||||
else:
|
else:
|
||||||
# This should be unreachable.
|
# This should be unreachable.
|
||||||
raise Exception("Unrecognized database engine")
|
raise Exception("Unrecognized database engine")
|
||||||
|
|
||||||
txn.execute(sql, (event.event_id, event.room_id, key, value,))
|
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cachedInlineCallbacks()
|
||||||
def get_room_name_and_aliases(self, room_id):
|
def get_room_name_and_aliases(self, room_id):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
|
65
synapse/storage/schema/delta/31/search_update.py
Normal file
65
synapse/storage/schema/delta/31/search_update.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.storage.engines import PostgresEngine
|
||||||
|
from synapse.storage.prepare_database import get_statements
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import ujson
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ALTER_TABLE = """
|
||||||
|
ALTER TABLE event_search ADD COLUMN origin_server_ts BIGINT;
|
||||||
|
ALTER TABLE event_search ADD COLUMN stream_ordering BIGINT;
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def run_create(cur, database_engine, *args, **kwargs):
|
||||||
|
if not isinstance(database_engine, PostgresEngine):
|
||||||
|
return
|
||||||
|
|
||||||
|
for statement in get_statements(ALTER_TABLE.splitlines()):
|
||||||
|
cur.execute(statement)
|
||||||
|
|
||||||
|
cur.execute("SELECT MIN(stream_ordering) FROM events")
|
||||||
|
rows = cur.fetchall()
|
||||||
|
min_stream_id = rows[0][0]
|
||||||
|
|
||||||
|
cur.execute("SELECT MAX(stream_ordering) FROM events")
|
||||||
|
rows = cur.fetchall()
|
||||||
|
max_stream_id = rows[0][0]
|
||||||
|
|
||||||
|
if min_stream_id is not None and max_stream_id is not None:
|
||||||
|
progress = {
|
||||||
|
"target_min_stream_id_inclusive": min_stream_id,
|
||||||
|
"max_stream_id_exclusive": max_stream_id + 1,
|
||||||
|
"rows_inserted": 0,
|
||||||
|
"have_added_indexes": False,
|
||||||
|
}
|
||||||
|
progress_json = ujson.dumps(progress)
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"INSERT into background_updates (update_name, progress_json)"
|
||||||
|
" VALUES (?, ?)"
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = database_engine.convert_param_style(sql)
|
||||||
|
|
||||||
|
cur.execute(sql, ("event_search_order", progress_json))
|
||||||
|
|
||||||
|
|
||||||
|
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||||
|
pass
|
|
@ -29,12 +29,17 @@ logger = logging.getLogger(__name__)
|
||||||
class SearchStore(BackgroundUpdateStore):
|
class SearchStore(BackgroundUpdateStore):
|
||||||
|
|
||||||
EVENT_SEARCH_UPDATE_NAME = "event_search"
|
EVENT_SEARCH_UPDATE_NAME = "event_search"
|
||||||
|
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(SearchStore, self).__init__(hs)
|
super(SearchStore, self).__init__(hs)
|
||||||
self.register_background_update_handler(
|
self.register_background_update_handler(
|
||||||
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
|
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
|
||||||
)
|
)
|
||||||
|
self.register_background_update_handler(
|
||||||
|
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
|
||||||
|
self._background_reindex_search_order
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _background_reindex_search(self, progress, batch_size):
|
def _background_reindex_search(self, progress, batch_size):
|
||||||
|
@ -131,6 +136,82 @@ class SearchStore(BackgroundUpdateStore):
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _background_reindex_search_order(self, progress, batch_size):
|
||||||
|
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||||
|
max_stream_id = progress["max_stream_id_exclusive"]
|
||||||
|
rows_inserted = progress.get("rows_inserted", 0)
|
||||||
|
have_added_index = progress['have_added_indexes']
|
||||||
|
|
||||||
|
if not have_added_index:
|
||||||
|
def create_index(conn):
|
||||||
|
conn.rollback()
|
||||||
|
conn.set_session(autocommit=True)
|
||||||
|
c = conn.cursor()
|
||||||
|
|
||||||
|
# We create with NULLS FIRST so that when we search *backwards*
|
||||||
|
# we get the ones with non null origin_server_ts *first*
|
||||||
|
c.execute(
|
||||||
|
"CREATE INDEX CONCURRENTLY event_search_room_order ON event_search("
|
||||||
|
"room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
|
||||||
|
)
|
||||||
|
c.execute(
|
||||||
|
"CREATE INDEX CONCURRENTLY event_search_order ON event_search("
|
||||||
|
"origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
|
||||||
|
)
|
||||||
|
conn.set_session(autocommit=False)
|
||||||
|
|
||||||
|
yield self.runWithConnection(create_index)
|
||||||
|
|
||||||
|
pg = dict(progress)
|
||||||
|
pg["have_added_indexes"] = True
|
||||||
|
|
||||||
|
yield self.runInteraction(
|
||||||
|
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
|
||||||
|
self._background_update_progress_txn,
|
||||||
|
self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reindex_search_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
|
||||||
|
" origin_server_ts = e.origin_server_ts"
|
||||||
|
" FROM events AS e"
|
||||||
|
" WHERE e.event_id = es.event_id"
|
||||||
|
" AND ? <= e.stream_ordering AND e.stream_ordering < ?"
|
||||||
|
" RETURNING es.stream_ordering"
|
||||||
|
)
|
||||||
|
|
||||||
|
min_stream_id = max_stream_id - batch_size
|
||||||
|
txn.execute(sql, (min_stream_id, max_stream_id))
|
||||||
|
rows = txn.fetchall()
|
||||||
|
|
||||||
|
if min_stream_id < target_min_stream_id:
|
||||||
|
# We've recached the end.
|
||||||
|
return len(rows), False
|
||||||
|
|
||||||
|
progress = {
|
||||||
|
"target_min_stream_id_inclusive": target_min_stream_id,
|
||||||
|
"max_stream_id_exclusive": min_stream_id,
|
||||||
|
"rows_inserted": rows_inserted + len(rows),
|
||||||
|
"have_added_indexes": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
self._background_update_progress_txn(
|
||||||
|
txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
|
||||||
|
)
|
||||||
|
|
||||||
|
return len(rows), True
|
||||||
|
|
||||||
|
num_rows, finished = yield self.runInteraction(
|
||||||
|
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
if not finished:
|
||||||
|
yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
|
||||||
|
|
||||||
|
defer.returnValue(num_rows)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def search_msgs(self, room_ids, search_term, keys):
|
def search_msgs(self, room_ids, search_term, keys):
|
||||||
"""Performs a full text search over events with given keys.
|
"""Performs a full text search over events with given keys.
|
||||||
|
@ -310,7 +391,6 @@ class SearchStore(BackgroundUpdateStore):
|
||||||
"SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
|
"SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
|
||||||
" origin_server_ts, stream_ordering, room_id, event_id"
|
" origin_server_ts, stream_ordering, room_id, event_id"
|
||||||
" FROM event_search"
|
" FROM event_search"
|
||||||
" NATURAL JOIN events"
|
|
||||||
" WHERE vector @@ to_tsquery('english', ?) AND "
|
" WHERE vector @@ to_tsquery('english', ?) AND "
|
||||||
)
|
)
|
||||||
args = [search_query, search_query] + args
|
args = [search_query, search_query] + args
|
||||||
|
@ -355,7 +435,15 @@ class SearchStore(BackgroundUpdateStore):
|
||||||
|
|
||||||
# We add an arbitrary limit here to ensure we don't try to pull the
|
# We add an arbitrary limit here to ensure we don't try to pull the
|
||||||
# entire table from the database.
|
# entire table from the database.
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
sql += (
|
||||||
|
" ORDER BY origin_server_ts DESC NULLS LAST,"
|
||||||
|
" stream_ordering DESC NULLS LAST LIMIT ?"
|
||||||
|
)
|
||||||
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
|
sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
|
||||||
|
else:
|
||||||
|
raise Exception("Unrecognized database engine")
|
||||||
|
|
||||||
args.append(limit)
|
args.append(limit)
|
||||||
|
|
||||||
|
|
|
@ -174,6 +174,12 @@ class StateStore(SQLBaseStore):
|
||||||
return [r[0] for r in results]
|
return [r[0] for r in results]
|
||||||
return self.runInteraction("get_current_state_for_key", f)
|
return self.runInteraction("get_current_state_for_key", f)
|
||||||
|
|
||||||
|
@cached(num_args=2, lru=True, max_entries=1000)
|
||||||
|
def _get_state_group_from_group(self, group, types):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@cachedList(cached_method_name="_get_state_group_from_group",
|
||||||
|
list_name="groups", num_args=2, inlineCallbacks=True)
|
||||||
def _get_state_groups_from_groups(self, groups, types):
|
def _get_state_groups_from_groups(self, groups, types):
|
||||||
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
|
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
|
||||||
"""
|
"""
|
||||||
|
@ -201,18 +207,23 @@ class StateStore(SQLBaseStore):
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
rows = self.cursor_to_dict(txn)
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
results = {}
|
results = {group: {} for group in groups}
|
||||||
for row in rows:
|
for row in rows:
|
||||||
key = (row["type"], row["state_key"])
|
key = (row["type"], row["state_key"])
|
||||||
results.setdefault(row["state_group"], {})[key] = row["event_id"]
|
results[row["state_group"]][key] = row["event_id"]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
|
||||||
chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
|
chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
return self.runInteraction(
|
res = yield self.runInteraction(
|
||||||
"_get_state_groups_from_groups",
|
"_get_state_groups_from_groups",
|
||||||
f, chunk
|
f, chunk
|
||||||
)
|
)
|
||||||
|
results.update(res)
|
||||||
|
|
||||||
|
defer.returnValue(results)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_for_events(self, event_ids, types):
|
def get_state_for_events(self, event_ids, types):
|
||||||
|
@ -359,6 +370,8 @@ class StateStore(SQLBaseStore):
|
||||||
a `state_key` of None matches all state_keys. If `types` is None then
|
a `state_key` of None matches all state_keys. If `types` is None then
|
||||||
all events are returned.
|
all events are returned.
|
||||||
"""
|
"""
|
||||||
|
if types:
|
||||||
|
types = frozenset(types)
|
||||||
results = {}
|
results = {}
|
||||||
missing_groups = []
|
missing_groups = []
|
||||||
if types is not None:
|
if types is not None:
|
||||||
|
|
98
synapse/util/httpresourcetree.py
Normal file
98
synapse/util/httpresourcetree.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_resource_tree(desired_tree, root_resource):
|
||||||
|
"""Create the resource tree for this Home Server.
|
||||||
|
|
||||||
|
This in unduly complicated because Twisted does not support putting
|
||||||
|
child resources more than 1 level deep at a time.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
web_client (bool): True to enable the web client.
|
||||||
|
root_resource (twisted.web.resource.Resource): The root
|
||||||
|
resource to add the tree to.
|
||||||
|
Returns:
|
||||||
|
twisted.web.resource.Resource: the ``root_resource`` with a tree of
|
||||||
|
child resources added to it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ideally we'd just use getChild and putChild but getChild doesn't work
|
||||||
|
# unless you give it a Request object IN ADDITION to the name :/ So
|
||||||
|
# instead, we'll store a copy of this mapping so we can actually add
|
||||||
|
# extra resources to existing nodes. See self._resource_id for the key.
|
||||||
|
resource_mappings = {}
|
||||||
|
for full_path, res in desired_tree.items():
|
||||||
|
logger.info("Attaching %s to path %s", res, full_path)
|
||||||
|
last_resource = root_resource
|
||||||
|
for path_seg in full_path.split('/')[1:-1]:
|
||||||
|
if path_seg not in last_resource.listNames():
|
||||||
|
# resource doesn't exist, so make a "dummy resource"
|
||||||
|
child_resource = Resource()
|
||||||
|
last_resource.putChild(path_seg, child_resource)
|
||||||
|
res_id = _resource_id(last_resource, path_seg)
|
||||||
|
resource_mappings[res_id] = child_resource
|
||||||
|
last_resource = child_resource
|
||||||
|
else:
|
||||||
|
# we have an existing Resource, use that instead.
|
||||||
|
res_id = _resource_id(last_resource, path_seg)
|
||||||
|
last_resource = resource_mappings[res_id]
|
||||||
|
|
||||||
|
# ===========================
|
||||||
|
# now attach the actual desired resource
|
||||||
|
last_path_seg = full_path.split('/')[-1]
|
||||||
|
|
||||||
|
# if there is already a resource here, thieve its children and
|
||||||
|
# replace it
|
||||||
|
res_id = _resource_id(last_resource, last_path_seg)
|
||||||
|
if res_id in resource_mappings:
|
||||||
|
# there is a dummy resource at this path already, which needs
|
||||||
|
# to be replaced with the desired resource.
|
||||||
|
existing_dummy_resource = resource_mappings[res_id]
|
||||||
|
for child_name in existing_dummy_resource.listNames():
|
||||||
|
child_res_id = _resource_id(
|
||||||
|
existing_dummy_resource, child_name
|
||||||
|
)
|
||||||
|
child_resource = resource_mappings[child_res_id]
|
||||||
|
# steal the children
|
||||||
|
res.putChild(child_name, child_resource)
|
||||||
|
|
||||||
|
# finally, insert the desired resource in the right place
|
||||||
|
last_resource.putChild(last_path_seg, res)
|
||||||
|
res_id = _resource_id(last_resource, last_path_seg)
|
||||||
|
resource_mappings[res_id] = res
|
||||||
|
|
||||||
|
return root_resource
|
||||||
|
|
||||||
|
|
||||||
|
def _resource_id(resource, path_seg):
|
||||||
|
"""Construct an arbitrary resource ID so you can retrieve the mapping
|
||||||
|
later.
|
||||||
|
|
||||||
|
If you want to represent resource A putChild resource B with path C,
|
||||||
|
the mapping should looks like _resource_id(A,C) = B.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resource (Resource): The *parent* Resourceb
|
||||||
|
path_seg (str): The name of the child Resource to be attached.
|
||||||
|
Returns:
|
||||||
|
str: A unique string which can be a key to the child Resource.
|
||||||
|
"""
|
||||||
|
return "%s-%s" % (resource, path_seg)
|
70
synapse/util/manhole.py
Normal file
70
synapse/util/manhole.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.conch.manhole import ColoredManhole
|
||||||
|
from twisted.conch.insults import insults
|
||||||
|
from twisted.conch import manhole_ssh
|
||||||
|
from twisted.cred import checkers, portal
|
||||||
|
from twisted.conch.ssh.keys import Key
|
||||||
|
|
||||||
|
PUBLIC_KEY = (
|
||||||
|
"ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az"
|
||||||
|
"64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJS"
|
||||||
|
"kbh/C+BR3utDS555mV"
|
||||||
|
)
|
||||||
|
|
||||||
|
PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY-----
|
||||||
|
MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
|
||||||
|
4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
|
||||||
|
vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
|
||||||
|
Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
|
||||||
|
xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
|
||||||
|
PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
|
||||||
|
gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
|
||||||
|
DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
|
||||||
|
pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
|
||||||
|
EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
|
||||||
|
-----END RSA PRIVATE KEY-----"""
|
||||||
|
|
||||||
|
|
||||||
|
def manhole(username, password, globals):
|
||||||
|
"""Starts a ssh listener with password authentication using
|
||||||
|
the given username and password. Clients connecting to the ssh
|
||||||
|
listener will find themselves in a colored python shell with
|
||||||
|
the supplied globals.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username(str): The username ssh clients should auth with.
|
||||||
|
password(str): The password ssh clients should auth with.
|
||||||
|
globals(dict): The variables to expose in the shell.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
|
||||||
|
"""
|
||||||
|
|
||||||
|
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
|
||||||
|
**{username: password}
|
||||||
|
)
|
||||||
|
|
||||||
|
rlm = manhole_ssh.TerminalRealm()
|
||||||
|
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
|
||||||
|
ColoredManhole,
|
||||||
|
dict(globals, __name__="__console__")
|
||||||
|
)
|
||||||
|
|
||||||
|
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
|
||||||
|
factory.publicKeys['ssh-rsa'] = Key.fromString(PUBLIC_KEY)
|
||||||
|
factory.privateKeys['ssh-rsa'] = Key.fromString(PRIVATE_KEY)
|
||||||
|
|
||||||
|
return factory
|
|
@ -15,8 +15,6 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock, NonCallableMock
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
from synapse.replication.resource import ReplicationResource
|
from synapse.replication.resource import ReplicationResource
|
||||||
|
@ -38,7 +36,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
|
||||||
self.replication = ReplicationResource(self.hs)
|
self.replication = ReplicationResource(self.hs)
|
||||||
|
|
||||||
self.master_store = self.hs.get_datastore()
|
self.master_store = self.hs.get_datastore()
|
||||||
self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs)
|
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
|
||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -16,6 +16,7 @@ from ._base import BaseSlavedStoreTestCase
|
||||||
|
|
||||||
from synapse.events import FrozenEvent, _EventInternalMetadata
|
from synapse.events import FrozenEvent, _EventInternalMetadata
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
from synapse.storage.roommember import RoomsForUser
|
from synapse.storage.roommember import RoomsForUser
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
@ -43,6 +44,8 @@ def patch__eq__(cls):
|
||||||
|
|
||||||
class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
|
|
||||||
|
STORE_TYPE = SlavedEventStore
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Patch up the equality operator for events so that we can check
|
# Patch up the equality operator for events so that we can check
|
||||||
# whether lists of events match using assertEquals
|
# whether lists of events match using assertEquals
|
||||||
|
@ -251,6 +254,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
|
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
|
||||||
yield self.check("get_event", [msg.event_id], redacted)
|
yield self.check("get_event", [msg.event_id], redacted)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_invites(self):
|
||||||
|
yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
|
||||||
|
event = yield self.persist(
|
||||||
|
type="m.room.member", key=USER_ID_2, membership="invite"
|
||||||
|
)
|
||||||
|
yield self.replicate()
|
||||||
|
yield self.check("get_invited_rooms_for_user", [USER_ID_2], [RoomsForUser(
|
||||||
|
ROOM_ID, USER_ID, "invite", event.event_id,
|
||||||
|
event.internal_metadata.stream_ordering
|
||||||
|
)])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_push_actions_for_user(self):
|
||||||
|
yield self.persist(type="m.room.create", creator=USER_ID)
|
||||||
|
yield self.persist(type="m.room.join", key=USER_ID, membership="join")
|
||||||
|
yield self.persist(
|
||||||
|
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
|
||||||
|
)
|
||||||
|
event1 = yield self.persist(
|
||||||
|
type="m.room.message", msgtype="m.text", body="hello"
|
||||||
|
)
|
||||||
|
yield self.replicate()
|
||||||
|
yield self.check(
|
||||||
|
"get_unread_event_push_actions_by_room_for_user",
|
||||||
|
[ROOM_ID, USER_ID_2, event1.event_id],
|
||||||
|
{"highlight_count": 0, "notify_count": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.persist(
|
||||||
|
type="m.room.message", msgtype="m.text", body="world",
|
||||||
|
push_actions=[(USER_ID_2, ["notify"])],
|
||||||
|
)
|
||||||
|
yield self.replicate()
|
||||||
|
yield self.check(
|
||||||
|
"get_unread_event_push_actions_by_room_for_user",
|
||||||
|
[ROOM_ID, USER_ID_2, event1.event_id],
|
||||||
|
{"highlight_count": 0, "notify_count": 1}
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.persist(
|
||||||
|
type="m.room.message", msgtype="m.text", body="world",
|
||||||
|
push_actions=[(USER_ID_2, [
|
||||||
|
"notify", {"set_tweak": "highlight", "value": True}
|
||||||
|
])],
|
||||||
|
)
|
||||||
|
yield self.replicate()
|
||||||
|
yield self.check(
|
||||||
|
"get_unread_event_push_actions_by_room_for_user",
|
||||||
|
[ROOM_ID, USER_ID_2, event1.event_id],
|
||||||
|
{"highlight_count": 1, "notify_count": 2}
|
||||||
|
)
|
||||||
|
|
||||||
event_id = 0
|
event_id = 0
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -258,6 +314,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
|
self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
|
||||||
state=None, reset_state=False, backfill=False,
|
state=None, reset_state=False, backfill=False,
|
||||||
depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
|
depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
|
||||||
|
push_actions=[],
|
||||||
**content
|
**content
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -290,6 +347,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
self.event_id += 1
|
self.event_id += 1
|
||||||
|
|
||||||
context = EventContext(current_state=state)
|
context = EventContext(current_state=state)
|
||||||
|
context.push_actions = push_actions
|
||||||
|
|
||||||
ordering = None
|
ordering = None
|
||||||
if backfill:
|
if backfill:
|
||||||
|
|
39
tests/replication/slave/storage/test_receipts.py
Normal file
39
tests/replication/slave/storage/test_receipts.py
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import BaseSlavedStoreTestCase
|
||||||
|
|
||||||
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
USER_ID = "@feeling:blue"
|
||||||
|
ROOM_ID = "!room:blue"
|
||||||
|
EVENT_ID = "$event:blue"
|
||||||
|
|
||||||
|
|
||||||
|
class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
|
||||||
|
|
||||||
|
STORE_TYPE = SlavedReceiptsStore
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_receipt(self):
|
||||||
|
yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
|
||||||
|
yield self.master_store.insert_receipt(
|
||||||
|
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
|
||||||
|
)
|
||||||
|
yield self.replicate()
|
||||||
|
yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {
|
||||||
|
ROOM_ID: EVENT_ID
|
||||||
|
})
|
Loading…
Reference in a new issue