mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-09 11:32:01 +01:00
Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.12.1
This commit is contained in:
commit
fd142c29d9
40 changed files with 939 additions and 1034 deletions
|
@ -26,7 +26,7 @@ TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||||
if [[ ! -e .sytest-base ]]; then
|
if [[ ! -e .sytest-base ]]; then
|
||||||
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
|
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
|
||||||
else
|
else
|
||||||
(cd .sytest-base; git fetch)
|
(cd .sytest-base; git fetch -p)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
rm -rf sytest
|
rm -rf sytest
|
||||||
|
@ -52,7 +52,7 @@ RUN_POSTGRES=""
|
||||||
|
|
||||||
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
|
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
|
||||||
if psql synapse_jenkins_$port <<< ""; then
|
if psql synapse_jenkins_$port <<< ""; then
|
||||||
RUN_POSTGRES=$RUN_POSTGRES:$port
|
RUN_POSTGRES="$RUN_POSTGRES:$port"
|
||||||
cat > localhost-$port/database.yaml << EOF
|
cat > localhost-$port/database.yaml << EOF
|
||||||
name: psycopg2
|
name: psycopg2
|
||||||
args:
|
args:
|
||||||
|
@ -62,7 +62,7 @@ EOF
|
||||||
done
|
done
|
||||||
|
|
||||||
# Run if both postgresql databases exist
|
# Run if both postgresql databases exist
|
||||||
if test $RUN_POSTGRES = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
|
if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
|
||||||
echo >&2 "Running sytest with PostgreSQL";
|
echo >&2 "Running sytest with PostgreSQL";
|
||||||
$TOX_BIN/pip install psycopg2
|
$TOX_BIN/pip install psycopg2
|
||||||
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
|
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.types import UserID, RoomID
|
from synapse.types import UserID, RoomID
|
||||||
|
|
||||||
|
import ujson as json
|
||||||
|
|
||||||
|
|
||||||
class Filtering(object):
|
class Filtering(object):
|
||||||
|
|
||||||
|
@ -149,6 +151,9 @@ class FilterCollection(object):
|
||||||
"include_leave", False
|
"include_leave", False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
|
||||||
|
|
||||||
def get_filter_json(self):
|
def get_filter_json(self):
|
||||||
return self._filter_json
|
return self._filter_json
|
||||||
|
|
||||||
|
|
|
@ -50,16 +50,14 @@ from twisted.cred import checkers, portal
|
||||||
|
|
||||||
from twisted.internet import reactor, task, defer
|
from twisted.internet import reactor, task, defer
|
||||||
from twisted.application import service
|
from twisted.application import service
|
||||||
from twisted.enterprise import adbapi
|
|
||||||
from twisted.web.resource import Resource, EncodingResourceWrapper
|
from twisted.web.resource import Resource, EncodingResourceWrapper
|
||||||
from twisted.web.static import File
|
from twisted.web.static import File
|
||||||
from twisted.web.server import Site, GzipEncoderFactory, Request
|
from twisted.web.server import Site, GzipEncoderFactory, Request
|
||||||
from synapse.http.server import JsonResource, RootRedirect
|
from synapse.http.server import RootRedirect
|
||||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||||
from synapse.rest.key.v1.server_key_resource import LocalKey
|
from synapse.rest.key.v1.server_key_resource import LocalKey
|
||||||
from synapse.rest.key.v2 import KeyApiV2Resource
|
from synapse.rest.key.v2 import KeyApiV2Resource
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
|
||||||
from synapse.api.urls import (
|
from synapse.api.urls import (
|
||||||
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
|
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
|
||||||
SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
|
SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
|
||||||
|
@ -69,6 +67,7 @@ from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.crypto import context_factory
|
from synapse.crypto import context_factory
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
|
from synapse.federation.transport.server import TransportLayerServer
|
||||||
|
|
||||||
from synapse import events
|
from synapse import events
|
||||||
|
|
||||||
|
@ -95,19 +94,8 @@ def gz_wrap(r):
|
||||||
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
|
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
|
||||||
|
|
||||||
|
|
||||||
class SynapseHomeServer(HomeServer):
|
def build_resource_for_web_client(hs):
|
||||||
|
webclient_path = hs.get_config().web_client_location
|
||||||
def build_http_client(self):
|
|
||||||
return MatrixFederationHttpClient(self)
|
|
||||||
|
|
||||||
def build_client_resource(self):
|
|
||||||
return ClientRestResource(self)
|
|
||||||
|
|
||||||
def build_resource_for_federation(self):
|
|
||||||
return JsonResource(self)
|
|
||||||
|
|
||||||
def build_resource_for_web_client(self):
|
|
||||||
webclient_path = self.get_config().web_client_location
|
|
||||||
if not webclient_path:
|
if not webclient_path:
|
||||||
try:
|
try:
|
||||||
import syweb
|
import syweb
|
||||||
|
@ -135,40 +123,8 @@ class SynapseHomeServer(HomeServer):
|
||||||
# return GzipFile(webclient_path) # TODO configurable?
|
# return GzipFile(webclient_path) # TODO configurable?
|
||||||
return File(webclient_path) # TODO configurable?
|
return File(webclient_path) # TODO configurable?
|
||||||
|
|
||||||
def build_resource_for_static_content(self):
|
|
||||||
# This is old and should go away: not going to bother adding gzip
|
|
||||||
return File(
|
|
||||||
os.path.join(os.path.dirname(synapse.__file__), "static")
|
|
||||||
)
|
|
||||||
|
|
||||||
def build_resource_for_content_repo(self):
|
|
||||||
return ContentRepoResource(
|
|
||||||
self, self.config.uploads_path, self.auth, self.content_addr
|
|
||||||
)
|
|
||||||
|
|
||||||
def build_resource_for_media_repository(self):
|
|
||||||
return MediaRepositoryResource(self)
|
|
||||||
|
|
||||||
def build_resource_for_server_key(self):
|
|
||||||
return LocalKey(self)
|
|
||||||
|
|
||||||
def build_resource_for_server_key_v2(self):
|
|
||||||
return KeyApiV2Resource(self)
|
|
||||||
|
|
||||||
def build_resource_for_metrics(self):
|
|
||||||
if self.get_config().enable_metrics:
|
|
||||||
return MetricsResource(self)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def build_db_pool(self):
|
|
||||||
name = self.db_config["name"]
|
|
||||||
|
|
||||||
return adbapi.ConnectionPool(
|
|
||||||
name,
|
|
||||||
**self.db_config.get("args", {})
|
|
||||||
)
|
|
||||||
|
|
||||||
|
class SynapseHomeServer(HomeServer):
|
||||||
def _listener_http(self, config, listener_config):
|
def _listener_http(self, config, listener_config):
|
||||||
port = listener_config["port"]
|
port = listener_config["port"]
|
||||||
bind_address = listener_config.get("bind_address", "")
|
bind_address = listener_config.get("bind_address", "")
|
||||||
|
@ -178,13 +134,11 @@ class SynapseHomeServer(HomeServer):
|
||||||
if tls and config.no_tls:
|
if tls and config.no_tls:
|
||||||
return
|
return
|
||||||
|
|
||||||
metrics_resource = self.get_resource_for_metrics()
|
|
||||||
|
|
||||||
resources = {}
|
resources = {}
|
||||||
for res in listener_config["resources"]:
|
for res in listener_config["resources"]:
|
||||||
for name in res["names"]:
|
for name in res["names"]:
|
||||||
if name == "client":
|
if name == "client":
|
||||||
client_resource = self.get_client_resource()
|
client_resource = ClientRestResource(self)
|
||||||
if res["compress"]:
|
if res["compress"]:
|
||||||
client_resource = gz_wrap(client_resource)
|
client_resource = gz_wrap(client_resource)
|
||||||
|
|
||||||
|
@ -198,31 +152,35 @@ class SynapseHomeServer(HomeServer):
|
||||||
|
|
||||||
if name == "federation":
|
if name == "federation":
|
||||||
resources.update({
|
resources.update({
|
||||||
FEDERATION_PREFIX: self.get_resource_for_federation(),
|
FEDERATION_PREFIX: TransportLayerServer(self),
|
||||||
})
|
})
|
||||||
|
|
||||||
if name in ["static", "client"]:
|
if name in ["static", "client"]:
|
||||||
resources.update({
|
resources.update({
|
||||||
STATIC_PREFIX: self.get_resource_for_static_content(),
|
STATIC_PREFIX: File(
|
||||||
|
os.path.join(os.path.dirname(synapse.__file__), "static")
|
||||||
|
),
|
||||||
})
|
})
|
||||||
|
|
||||||
if name in ["media", "federation", "client"]:
|
if name in ["media", "federation", "client"]:
|
||||||
resources.update({
|
resources.update({
|
||||||
MEDIA_PREFIX: self.get_resource_for_media_repository(),
|
MEDIA_PREFIX: MediaRepositoryResource(self),
|
||||||
CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(),
|
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||||
|
self, self.config.uploads_path, self.auth, self.content_addr
|
||||||
|
),
|
||||||
})
|
})
|
||||||
|
|
||||||
if name in ["keys", "federation"]:
|
if name in ["keys", "federation"]:
|
||||||
resources.update({
|
resources.update({
|
||||||
SERVER_KEY_PREFIX: self.get_resource_for_server_key(),
|
SERVER_KEY_PREFIX: LocalKey(self),
|
||||||
SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(),
|
SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
|
||||||
})
|
})
|
||||||
|
|
||||||
if name == "webclient":
|
if name == "webclient":
|
||||||
resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client()
|
resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
|
||||||
|
|
||||||
if name == "metrics" and metrics_resource:
|
if name == "metrics" and self.get_config().enable_metrics:
|
||||||
resources[METRICS_PREFIX] = metrics_resource
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources)
|
root_resource = create_resource_tree(resources)
|
||||||
if tls:
|
if tls:
|
||||||
|
@ -296,6 +254,18 @@ class SynapseHomeServer(HomeServer):
|
||||||
except IncorrectDatabaseSetup as e:
|
except IncorrectDatabaseSetup as e:
|
||||||
quit_with_error(e.message)
|
quit_with_error(e.message)
|
||||||
|
|
||||||
|
def get_db_conn(self):
|
||||||
|
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||||
|
# not be passed to the database engine.
|
||||||
|
db_params = {
|
||||||
|
k: v for k, v in self.db_config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = self.database_engine.module.connect(**db_params)
|
||||||
|
|
||||||
|
self.database_engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
|
|
||||||
def quit_with_error(error_string):
|
def quit_with_error(error_string):
|
||||||
message_lines = error_string.split("\n")
|
message_lines = error_string.split("\n")
|
||||||
|
@ -432,13 +402,7 @@ def setup(config_options):
|
||||||
logger.info("Preparing database: %s...", config.database_config['name'])
|
logger.info("Preparing database: %s...", config.database_config['name'])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db_conn = database_engine.module.connect(
|
db_conn = hs.get_db_conn()
|
||||||
**{
|
|
||||||
k: v for k, v in config.database_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
database_engine.prepare_database(db_conn)
|
database_engine.prepare_database(db_conn)
|
||||||
hs.run_startup_checks(db_conn, database_engine)
|
hs.run_startup_checks(db_conn, database_engine)
|
||||||
|
|
||||||
|
@ -453,14 +417,18 @@ def setup(config_options):
|
||||||
|
|
||||||
logger.info("Database prepared in %s.", config.database_config['name'])
|
logger.info("Database prepared in %s.", config.database_config['name'])
|
||||||
|
|
||||||
|
hs.setup()
|
||||||
hs.start_listening()
|
hs.start_listening()
|
||||||
|
|
||||||
|
def start():
|
||||||
hs.get_pusherpool().start()
|
hs.get_pusherpool().start()
|
||||||
hs.get_state_handler().start_caching()
|
hs.get_state_handler().start_caching()
|
||||||
hs.get_datastore().start_profiling()
|
hs.get_datastore().start_profiling()
|
||||||
hs.get_datastore().start_doing_background_updates()
|
hs.get_datastore().start_doing_background_updates()
|
||||||
hs.get_replication_layer().start_get_pdu_cache()
|
hs.get_replication_layer().start_get_pdu_cache()
|
||||||
|
|
||||||
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
|
||||||
|
@ -675,7 +643,7 @@ def _resource_id(resource, path_seg):
|
||||||
the mapping should looks like _resource_id(A,C) = B.
|
the mapping should looks like _resource_id(A,C) = B.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
resource (Resource): The *parent* Resource
|
resource (Resource): The *parent* Resourceb
|
||||||
path_seg (str): The name of the child Resource to be attached.
|
path_seg (str): The name of the child Resource to be attached.
|
||||||
Returns:
|
Returns:
|
||||||
str: A unique string which can be a key to the child Resource.
|
str: A unique string which can be a key to the child Resource.
|
||||||
|
|
|
@ -17,15 +17,10 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .replication import ReplicationLayer
|
from .replication import ReplicationLayer
|
||||||
from .transport import TransportLayer
|
from .transport.client import TransportLayerClient
|
||||||
|
|
||||||
|
|
||||||
def initialize_http_replication(homeserver):
|
def initialize_http_replication(homeserver):
|
||||||
transport = TransportLayer(
|
transport = TransportLayerClient(homeserver)
|
||||||
homeserver,
|
|
||||||
homeserver.hostname,
|
|
||||||
server=homeserver.get_resource_for_federation(),
|
|
||||||
client=homeserver.get_http_client()
|
|
||||||
)
|
|
||||||
|
|
||||||
return ReplicationLayer(homeserver, transport)
|
return ReplicationLayer(homeserver, transport)
|
||||||
|
|
|
@ -54,8 +54,6 @@ class ReplicationLayer(FederationClient, FederationServer):
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
|
|
||||||
self.transport_layer = transport_layer
|
self.transport_layer = transport_layer
|
||||||
self.transport_layer.register_received_handler(self)
|
|
||||||
self.transport_layer.register_request_handler(self)
|
|
||||||
|
|
||||||
self.federation_client = self
|
self.federation_client = self
|
||||||
|
|
||||||
|
|
|
@ -20,55 +20,3 @@ By default this is done over HTTPS (and all home servers are required to
|
||||||
support HTTPS), however individual pairings of servers may decide to
|
support HTTPS), however individual pairings of servers may decide to
|
||||||
communicate over a different (albeit still reliable) protocol.
|
communicate over a different (albeit still reliable) protocol.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .server import TransportLayerServer
|
|
||||||
from .client import TransportLayerClient
|
|
||||||
|
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
|
||||||
|
|
||||||
|
|
||||||
class TransportLayer(TransportLayerServer, TransportLayerClient):
|
|
||||||
"""This is a basic implementation of the transport layer that translates
|
|
||||||
transactions and other requests to/from HTTP.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
server_name (str): Local home server host
|
|
||||||
|
|
||||||
server (synapse.http.server.HttpServer): the http server to
|
|
||||||
register listeners on
|
|
||||||
|
|
||||||
client (synapse.http.client.HttpClient): the http client used to
|
|
||||||
send requests
|
|
||||||
|
|
||||||
request_handler (TransportRequestHandler): The handler to fire when we
|
|
||||||
receive requests for data.
|
|
||||||
|
|
||||||
received_handler (TransportReceivedHandler): The handler to fire when
|
|
||||||
we receive data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, homeserver, server_name, server, client):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
server_name (str): Local home server host
|
|
||||||
server (synapse.protocol.http.HttpServer): the http server to
|
|
||||||
register listeners on
|
|
||||||
client (synapse.protocol.http.HttpClient): the http client used to
|
|
||||||
send requests
|
|
||||||
"""
|
|
||||||
self.keyring = homeserver.get_keyring()
|
|
||||||
self.clock = homeserver.get_clock()
|
|
||||||
self.server_name = server_name
|
|
||||||
self.server = server
|
|
||||||
self.client = client
|
|
||||||
self.request_handler = None
|
|
||||||
self.received_handler = None
|
|
||||||
|
|
||||||
self.ratelimiter = FederationRateLimiter(
|
|
||||||
self.clock,
|
|
||||||
window_size=homeserver.config.federation_rc_window_size,
|
|
||||||
sleep_limit=homeserver.config.federation_rc_sleep_limit,
|
|
||||||
sleep_msec=homeserver.config.federation_rc_sleep_delay,
|
|
||||||
reject_limit=homeserver.config.federation_rc_reject_limit,
|
|
||||||
concurrent_requests=homeserver.config.federation_rc_concurrent,
|
|
||||||
)
|
|
||||||
|
|
|
@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
|
||||||
class TransportLayerClient(object):
|
class TransportLayerClient(object):
|
||||||
"""Sends federation HTTP requests to other servers"""
|
"""Sends federation HTTP requests to other servers"""
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.server_name = hs.hostname
|
||||||
|
self.client = hs.get_http_client()
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_room_state(self, destination, room_id, event_id):
|
def get_room_state(self, destination, room_id, event_id):
|
||||||
""" Requests all state for a given room from the given server at the
|
""" Requests all state for a given room from the given server at the
|
||||||
|
|
|
@ -17,7 +17,8 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.util.logutils import log_function
|
from synapse.http.server import JsonResource
|
||||||
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
@ -28,9 +29,41 @@ import re
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TransportLayerServer(object):
|
class TransportLayerServer(JsonResource):
|
||||||
"""Handles incoming federation HTTP requests"""
|
"""Handles incoming federation HTTP requests"""
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
super(TransportLayerServer, self).__init__(hs)
|
||||||
|
|
||||||
|
self.authenticator = Authenticator(hs)
|
||||||
|
self.ratelimiter = FederationRateLimiter(
|
||||||
|
self.clock,
|
||||||
|
window_size=hs.config.federation_rc_window_size,
|
||||||
|
sleep_limit=hs.config.federation_rc_sleep_limit,
|
||||||
|
sleep_msec=hs.config.federation_rc_sleep_delay,
|
||||||
|
reject_limit=hs.config.federation_rc_reject_limit,
|
||||||
|
concurrent_requests=hs.config.federation_rc_concurrent,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_servlets()
|
||||||
|
|
||||||
|
def register_servlets(self):
|
||||||
|
register_servlets(
|
||||||
|
self.hs,
|
||||||
|
resource=self,
|
||||||
|
ratelimiter=self.ratelimiter,
|
||||||
|
authenticator=self.authenticator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Authenticator(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.keyring = hs.get_keyring()
|
||||||
|
self.server_name = hs.hostname
|
||||||
|
|
||||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def authenticate_request(self, request):
|
def authenticate_request(self, request):
|
||||||
|
@ -98,37 +131,9 @@ class TransportLayerServer(object):
|
||||||
|
|
||||||
defer.returnValue((origin, content))
|
defer.returnValue((origin, content))
|
||||||
|
|
||||||
@log_function
|
|
||||||
def register_received_handler(self, handler):
|
|
||||||
""" Register a handler that will be fired when we receive data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
handler (TransportReceivedHandler)
|
|
||||||
"""
|
|
||||||
FederationSendServlet(
|
|
||||||
handler,
|
|
||||||
authenticator=self,
|
|
||||||
ratelimiter=self.ratelimiter,
|
|
||||||
server_name=self.server_name,
|
|
||||||
).register(self.server)
|
|
||||||
|
|
||||||
@log_function
|
|
||||||
def register_request_handler(self, handler):
|
|
||||||
""" Register a handler that will be fired when we get asked for data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
handler (TransportRequestHandler)
|
|
||||||
"""
|
|
||||||
for servletclass in SERVLET_CLASSES:
|
|
||||||
servletclass(
|
|
||||||
handler,
|
|
||||||
authenticator=self,
|
|
||||||
ratelimiter=self.ratelimiter,
|
|
||||||
).register(self.server)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFederationServlet(object):
|
class BaseFederationServlet(object):
|
||||||
def __init__(self, handler, authenticator, ratelimiter):
|
def __init__(self, handler, authenticator, ratelimiter, server_name):
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
self.authenticator = authenticator
|
self.authenticator = authenticator
|
||||||
self.ratelimiter = ratelimiter
|
self.ratelimiter = ratelimiter
|
||||||
|
@ -172,7 +177,9 @@ class FederationSendServlet(BaseFederationServlet):
|
||||||
PATH = "/send/([^/]*)/"
|
PATH = "/send/([^/]*)/"
|
||||||
|
|
||||||
def __init__(self, handler, server_name, **kwargs):
|
def __init__(self, handler, server_name, **kwargs):
|
||||||
super(FederationSendServlet, self).__init__(handler, **kwargs)
|
super(FederationSendServlet, self).__init__(
|
||||||
|
handler, server_name=server_name, **kwargs
|
||||||
|
)
|
||||||
self.server_name = server_name
|
self.server_name = server_name
|
||||||
|
|
||||||
# This is when someone is trying to send us a bunch of data.
|
# This is when someone is trying to send us a bunch of data.
|
||||||
|
@ -432,6 +439,7 @@ class On3pidBindServlet(BaseFederationServlet):
|
||||||
|
|
||||||
|
|
||||||
SERVLET_CLASSES = (
|
SERVLET_CLASSES = (
|
||||||
|
FederationSendServlet,
|
||||||
FederationPullServlet,
|
FederationPullServlet,
|
||||||
FederationEventServlet,
|
FederationEventServlet,
|
||||||
FederationStateServlet,
|
FederationStateServlet,
|
||||||
|
@ -451,3 +459,13 @@ SERVLET_CLASSES = (
|
||||||
FederationThirdPartyInviteExchangeServlet,
|
FederationThirdPartyInviteExchangeServlet,
|
||||||
On3pidBindServlet,
|
On3pidBindServlet,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, resource, authenticator, ratelimiter):
|
||||||
|
for servletclass in SERVLET_CLASSES:
|
||||||
|
servletclass(
|
||||||
|
handler=hs.get_replication_layer(),
|
||||||
|
authenticator=authenticator,
|
||||||
|
ratelimiter=ratelimiter,
|
||||||
|
server_name=hs.hostname,
|
||||||
|
).register(resource)
|
||||||
|
|
|
@ -1186,7 +1186,13 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(e, auth_events=auth_for_e)
|
self.auth.check(e, auth_events=auth_for_e)
|
||||||
except AuthError as err:
|
except SynapseError as err:
|
||||||
|
# we may get SynapseErrors here as well as AuthErrors. For
|
||||||
|
# instance, there are a couple of (ancient) events in some
|
||||||
|
# rooms whose senders do not have the correct sigil; these
|
||||||
|
# cause SynapseErrors in auth.check. We don't want to give up
|
||||||
|
# the attempt to federate altogether in such cases.
|
||||||
|
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Rejecting %s because %s",
|
"Rejecting %s because %s",
|
||||||
e.event_id, err.msg
|
e.event_id, err.msg
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import SynapseError, AuthError, Codes
|
from synapse.api.errors import AuthError, Codes
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
|
@ -105,8 +105,6 @@ class MessageHandler(BaseHandler):
|
||||||
room_token = pagin_config.from_token.room_key
|
room_token = pagin_config.from_token.room_key
|
||||||
|
|
||||||
room_token = RoomStreamToken.parse(room_token)
|
room_token = RoomStreamToken.parse(room_token)
|
||||||
if room_token.topological is None:
|
|
||||||
raise SynapseError(400, "Invalid token")
|
|
||||||
|
|
||||||
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
|
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
|
||||||
"room_key", str(room_token)
|
"room_key", str(room_token)
|
||||||
|
@ -117,26 +115,30 @@ class MessageHandler(BaseHandler):
|
||||||
membership, member_event_id = yield self._check_in_room_or_world_readable(
|
membership, member_event_id = yield self._check_in_room_or_world_readable(
|
||||||
room_id, user_id
|
room_id, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if source_config.direction == 'b':
|
||||||
|
# if we're going backwards, we might need to backfill. This
|
||||||
|
# requires that we have a topo token.
|
||||||
|
if room_token.topological:
|
||||||
|
max_topo = room_token.topological
|
||||||
|
else:
|
||||||
|
max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
|
||||||
|
room_id, room_token.stream
|
||||||
|
)
|
||||||
|
|
||||||
if membership == Membership.LEAVE:
|
if membership == Membership.LEAVE:
|
||||||
# If they have left the room then clamp the token to be before
|
# If they have left the room then clamp the token to be before
|
||||||
# they left the room.
|
# they left the room, to save the effort of loading from the
|
||||||
|
# database.
|
||||||
leave_token = yield self.store.get_topological_token_for_event(
|
leave_token = yield self.store.get_topological_token_for_event(
|
||||||
member_event_id
|
member_event_id
|
||||||
)
|
)
|
||||||
leave_token = RoomStreamToken.parse(leave_token)
|
leave_token = RoomStreamToken.parse(leave_token)
|
||||||
if leave_token.topological < room_token.topological:
|
if leave_token.topological < max_topo:
|
||||||
source_config.from_key = str(leave_token)
|
source_config.from_key = str(leave_token)
|
||||||
|
|
||||||
if source_config.direction == "f":
|
|
||||||
if source_config.to_key is None:
|
|
||||||
source_config.to_key = str(leave_token)
|
|
||||||
else:
|
|
||||||
to_token = RoomStreamToken.parse(source_config.to_key)
|
|
||||||
if leave_token.topological < to_token.topological:
|
|
||||||
source_config.to_key = str(leave_token)
|
|
||||||
|
|
||||||
yield self.hs.get_handlers().federation_handler.maybe_backfill(
|
yield self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||||
room_id, room_token.topological
|
room_id, max_topo
|
||||||
)
|
)
|
||||||
|
|
||||||
events, next_key = yield data_source.get_pagination_rows(
|
events, next_key = yield data_source.get_pagination_rows(
|
||||||
|
|
|
@ -72,7 +72,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
|
class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
|
||||||
"room_id", # str
|
"room_id", # str
|
||||||
"timeline", # TimelineBatch
|
"timeline", # TimelineBatch
|
||||||
"state", # dict[(str, str), FrozenEvent]
|
"state", # dict[(str, str), FrozenEvent]
|
||||||
|
@ -298,46 +298,19 @@ class SyncHandler(BaseHandler):
|
||||||
room_id, sync_config, now_token, since_token=timeline_since_token
|
room_id, sync_config, now_token, since_token=timeline_since_token
|
||||||
)
|
)
|
||||||
|
|
||||||
notifs = yield self.unread_notifs_for_room_id(
|
room_sync = yield self.incremental_sync_with_gap_for_room(
|
||||||
room_id, sync_config, ephemeral_by_room
|
room_id, sync_config,
|
||||||
|
now_token=now_token,
|
||||||
|
since_token=timeline_since_token,
|
||||||
|
ephemeral_by_room=ephemeral_by_room,
|
||||||
|
tags_by_room=tags_by_room,
|
||||||
|
account_data_by_room=account_data_by_room,
|
||||||
|
all_ephemeral_by_room=ephemeral_by_room,
|
||||||
|
batch=batch,
|
||||||
|
full_state=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
unread_notifications = {}
|
defer.returnValue(room_sync)
|
||||||
if notifs is not None:
|
|
||||||
unread_notifications["notification_count"] = len(notifs)
|
|
||||||
unread_notifications["highlight_count"] = len([
|
|
||||||
1 for notif in notifs if _action_has_highlight(notif["actions"])
|
|
||||||
])
|
|
||||||
|
|
||||||
current_state = yield self.get_state_at(room_id, now_token)
|
|
||||||
|
|
||||||
current_state = {
|
|
||||||
(e.type, e.state_key): e
|
|
||||||
for e in sync_config.filter_collection.filter_room_state(
|
|
||||||
current_state.values()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
account_data = self.account_data_for_room(
|
|
||||||
room_id, tags_by_room, account_data_by_room
|
|
||||||
)
|
|
||||||
|
|
||||||
account_data = sync_config.filter_collection.filter_room_account_data(
|
|
||||||
account_data
|
|
||||||
)
|
|
||||||
|
|
||||||
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
|
|
||||||
ephemeral_by_room.get(room_id, [])
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(JoinedSyncResult(
|
|
||||||
room_id=room_id,
|
|
||||||
timeline=batch,
|
|
||||||
state=current_state,
|
|
||||||
ephemeral=ephemeral,
|
|
||||||
account_data=account_data,
|
|
||||||
unread_notifications=unread_notifications,
|
|
||||||
))
|
|
||||||
|
|
||||||
def account_data_for_user(self, account_data):
|
def account_data_for_user(self, account_data):
|
||||||
account_data_events = []
|
account_data_events = []
|
||||||
|
@ -429,44 +402,20 @@ class SyncHandler(BaseHandler):
|
||||||
|
|
||||||
defer.returnValue((now_token, ephemeral_by_room))
|
defer.returnValue((now_token, ephemeral_by_room))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def full_state_sync_for_archived_room(self, room_id, sync_config,
|
def full_state_sync_for_archived_room(self, room_id, sync_config,
|
||||||
leave_event_id, leave_token,
|
leave_event_id, leave_token,
|
||||||
timeline_since_token, tags_by_room,
|
timeline_since_token, tags_by_room,
|
||||||
account_data_by_room):
|
account_data_by_room):
|
||||||
"""Sync a room for a client which is starting without any state
|
"""Sync a room for a client which is starting without any state
|
||||||
Returns:
|
Returns:
|
||||||
A Deferred JoinedSyncResult.
|
A Deferred ArchivedSyncResult.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
batch = yield self.load_filtered_recents(
|
return self.incremental_sync_for_archived_room(
|
||||||
room_id, sync_config, leave_token, since_token=timeline_since_token
|
sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room,
|
||||||
|
account_data_by_room, full_state=True, leave_token=leave_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
leave_state = yield self.store.get_state_for_event(leave_event_id)
|
|
||||||
|
|
||||||
leave_state = {
|
|
||||||
(e.type, e.state_key): e
|
|
||||||
for e in sync_config.filter_collection.filter_room_state(
|
|
||||||
leave_state.values()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
account_data = self.account_data_for_room(
|
|
||||||
room_id, tags_by_room, account_data_by_room
|
|
||||||
)
|
|
||||||
|
|
||||||
account_data = sync_config.filter_collection.filter_room_account_data(
|
|
||||||
account_data
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(ArchivedSyncResult(
|
|
||||||
room_id=room_id,
|
|
||||||
timeline=batch,
|
|
||||||
state=leave_state,
|
|
||||||
account_data=account_data,
|
|
||||||
))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def incremental_sync_with_gap(self, sync_config, since_token):
|
def incremental_sync_with_gap(self, sync_config, since_token):
|
||||||
""" Get the incremental delta needed to bring the client up to
|
""" Get the incremental delta needed to bring the client up to
|
||||||
|
@ -512,154 +461,127 @@ class SyncHandler(BaseHandler):
|
||||||
sync_config.user
|
sync_config.user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
user_id = sync_config.user.to_string()
|
||||||
|
|
||||||
timeline_limit = sync_config.filter_collection.timeline_limit()
|
timeline_limit = sync_config.filter_collection.timeline_limit()
|
||||||
|
|
||||||
room_events, _ = yield self.store.get_room_events_stream(
|
|
||||||
sync_config.user.to_string(),
|
|
||||||
from_key=since_token.room_key,
|
|
||||||
to_key=now_token.room_key,
|
|
||||||
limit=timeline_limit + 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
tags_by_room = yield self.store.get_updated_tags(
|
tags_by_room = yield self.store.get_updated_tags(
|
||||||
sync_config.user.to_string(),
|
user_id,
|
||||||
since_token.account_data_key,
|
since_token.account_data_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
account_data, account_data_by_room = (
|
account_data, account_data_by_room = (
|
||||||
yield self.store.get_updated_account_data_for_user(
|
yield self.store.get_updated_account_data_for_user(
|
||||||
sync_config.user.to_string(),
|
user_id,
|
||||||
since_token.account_data_key,
|
since_token.account_data_key,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
joined = []
|
# Get a list of membership change events that have happened.
|
||||||
|
rooms_changed = yield self.store.get_room_changes_for_user(
|
||||||
|
user_id, since_token.room_key, now_token.room_key
|
||||||
|
)
|
||||||
|
|
||||||
|
mem_change_events_by_room_id = {}
|
||||||
|
for event in rooms_changed:
|
||||||
|
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
|
||||||
|
|
||||||
|
newly_joined_rooms = []
|
||||||
archived = []
|
archived = []
|
||||||
if len(room_events) <= timeline_limit:
|
invited = []
|
||||||
# There is no gap in any of the rooms. Therefore we can just
|
for room_id, events in mem_change_events_by_room_id.items():
|
||||||
# partition the new events by room and return them.
|
non_joins = [e for e in events if e.membership != Membership.JOIN]
|
||||||
logger.debug("Got %i events for incremental sync - not limited",
|
has_join = len(non_joins) != len(events)
|
||||||
len(room_events))
|
|
||||||
|
|
||||||
invite_events = []
|
# We want to figure out if we joined the room at some point since
|
||||||
leave_events = []
|
# the last sync (even if we have since left). This is to make sure
|
||||||
events_by_room_id = {}
|
# we do send down the room, and with full state, where necessary
|
||||||
for event in room_events:
|
if room_id in joined_room_ids or has_join:
|
||||||
events_by_room_id.setdefault(event.room_id, []).append(event)
|
old_state = yield self.get_state_at(room_id, since_token)
|
||||||
if event.room_id not in joined_room_ids:
|
old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
|
||||||
if (event.type == EventTypes.Member
|
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
|
||||||
and event.state_key == sync_config.user.to_string()):
|
newly_joined_rooms.append(room_id)
|
||||||
if event.membership == Membership.INVITE:
|
|
||||||
invite_events.append(event)
|
|
||||||
elif event.membership in (Membership.LEAVE, Membership.BAN):
|
|
||||||
leave_events.append(event)
|
|
||||||
|
|
||||||
for room_id in joined_room_ids:
|
if room_id in joined_room_ids:
|
||||||
recents = events_by_room_id.get(room_id, [])
|
continue
|
||||||
logger.debug("Events for room %s: %r", room_id, recents)
|
|
||||||
state = {
|
|
||||||
(event.type, event.state_key): event
|
|
||||||
for event in recents if event.is_state()}
|
|
||||||
limited = False
|
|
||||||
|
|
||||||
if recents:
|
if not non_joins:
|
||||||
prev_batch = now_token.copy_and_replace(
|
continue
|
||||||
"room_key", recents[0].internal_metadata.before
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prev_batch = now_token
|
|
||||||
|
|
||||||
just_joined = yield self.check_joined_room(sync_config, state)
|
|
||||||
if just_joined:
|
|
||||||
logger.debug("User has just joined %s: needs full state",
|
|
||||||
room_id)
|
|
||||||
state = yield self.get_state_at(room_id, now_token)
|
|
||||||
# the timeline is inherently limited if we've just joined
|
|
||||||
limited = True
|
|
||||||
|
|
||||||
recents = sync_config.filter_collection.filter_room_timeline(recents)
|
|
||||||
|
|
||||||
state = {
|
|
||||||
(e.type, e.state_key): e
|
|
||||||
for e in sync_config.filter_collection.filter_room_state(
|
|
||||||
state.values()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
acc_data = self.account_data_for_room(
|
|
||||||
room_id, tags_by_room, account_data_by_room
|
|
||||||
)
|
|
||||||
|
|
||||||
acc_data = sync_config.filter_collection.filter_room_account_data(
|
|
||||||
acc_data
|
|
||||||
)
|
|
||||||
|
|
||||||
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
|
|
||||||
ephemeral_by_room.get(room_id, [])
|
|
||||||
)
|
|
||||||
|
|
||||||
room_sync = JoinedSyncResult(
|
|
||||||
room_id=room_id,
|
|
||||||
timeline=TimelineBatch(
|
|
||||||
events=recents,
|
|
||||||
prev_batch=prev_batch,
|
|
||||||
limited=limited,
|
|
||||||
),
|
|
||||||
state=state,
|
|
||||||
ephemeral=ephemeral,
|
|
||||||
account_data=acc_data,
|
|
||||||
unread_notifications={},
|
|
||||||
)
|
|
||||||
logger.debug("Result for room %s: %r", room_id, room_sync)
|
|
||||||
|
|
||||||
|
# Only bother if we're still currently invited
|
||||||
|
should_invite = non_joins[-1].membership == Membership.INVITE
|
||||||
|
if should_invite:
|
||||||
|
room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
|
||||||
if room_sync:
|
if room_sync:
|
||||||
notifs = yield self.unread_notifs_for_room_id(
|
invited.append(room_sync)
|
||||||
room_id, sync_config, all_ephemeral_by_room
|
|
||||||
)
|
|
||||||
|
|
||||||
if notifs is not None:
|
# Always include leave/ban events. Just take the last one.
|
||||||
notif_dict = room_sync.unread_notifications
|
# TODO: How do we handle ban -> leave in same batch?
|
||||||
notif_dict["notification_count"] = len(notifs)
|
leave_events = [
|
||||||
notif_dict["highlight_count"] = len([
|
e for e in non_joins
|
||||||
1 for notif in notifs
|
if e.membership in (Membership.LEAVE, Membership.BAN)
|
||||||
if _action_has_highlight(notif["actions"])
|
]
|
||||||
])
|
|
||||||
|
|
||||||
joined.append(room_sync)
|
if leave_events:
|
||||||
|
leave_event = leave_events[-1]
|
||||||
else:
|
|
||||||
logger.debug("Got %i events for incremental sync - hit limit",
|
|
||||||
len(room_events))
|
|
||||||
|
|
||||||
invite_events = yield self.store.get_invites_for_user(
|
|
||||||
sync_config.user.to_string()
|
|
||||||
)
|
|
||||||
|
|
||||||
leave_events = yield self.store.get_leave_and_ban_events_for_user(
|
|
||||||
sync_config.user.to_string()
|
|
||||||
)
|
|
||||||
|
|
||||||
for room_id in joined_room_ids:
|
|
||||||
room_sync = yield self.incremental_sync_with_gap_for_room(
|
|
||||||
room_id, sync_config, since_token, now_token,
|
|
||||||
ephemeral_by_room, tags_by_room, account_data_by_room,
|
|
||||||
all_ephemeral_by_room=all_ephemeral_by_room,
|
|
||||||
)
|
|
||||||
if room_sync:
|
|
||||||
joined.append(room_sync)
|
|
||||||
|
|
||||||
for leave_event in leave_events:
|
|
||||||
room_sync = yield self.incremental_sync_for_archived_room(
|
room_sync = yield self.incremental_sync_for_archived_room(
|
||||||
sync_config, leave_event, since_token, tags_by_room,
|
sync_config, room_id, leave_event.event_id, since_token,
|
||||||
account_data_by_room
|
tags_by_room, account_data_by_room,
|
||||||
|
full_state=room_id in newly_joined_rooms
|
||||||
)
|
)
|
||||||
if room_sync:
|
if room_sync:
|
||||||
archived.append(room_sync)
|
archived.append(room_sync)
|
||||||
|
|
||||||
invited = [
|
# Get all events for rooms we're currently joined to.
|
||||||
InvitedSyncResult(room_id=event.room_id, invite=event)
|
room_to_events = yield self.store.get_room_events_stream_for_rooms(
|
||||||
for event in invite_events
|
room_ids=joined_room_ids,
|
||||||
]
|
from_key=since_token.room_key,
|
||||||
|
to_key=now_token.room_key,
|
||||||
|
limit=timeline_limit + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
joined = []
|
||||||
|
# We loop through all room ids, even if there are no new events, in case
|
||||||
|
# there are non room events taht we need to notify about.
|
||||||
|
for room_id in joined_room_ids:
|
||||||
|
room_entry = room_to_events.get(room_id, None)
|
||||||
|
|
||||||
|
if room_entry:
|
||||||
|
events, start_key = room_entry
|
||||||
|
|
||||||
|
prev_batch_token = now_token.copy_and_replace("room_key", start_key)
|
||||||
|
|
||||||
|
newly_joined_room = room_id in newly_joined_rooms
|
||||||
|
full_state = newly_joined_room
|
||||||
|
|
||||||
|
batch = yield self.load_filtered_recents(
|
||||||
|
room_id, sync_config, prev_batch_token,
|
||||||
|
since_token=since_token,
|
||||||
|
recents=events,
|
||||||
|
newly_joined_room=newly_joined_room,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch = TimelineBatch(
|
||||||
|
events=[],
|
||||||
|
prev_batch=since_token,
|
||||||
|
limited=False,
|
||||||
|
)
|
||||||
|
full_state = False
|
||||||
|
|
||||||
|
room_sync = yield self.incremental_sync_with_gap_for_room(
|
||||||
|
room_id=room_id,
|
||||||
|
sync_config=sync_config,
|
||||||
|
since_token=since_token,
|
||||||
|
now_token=now_token,
|
||||||
|
ephemeral_by_room=ephemeral_by_room,
|
||||||
|
tags_by_room=tags_by_room,
|
||||||
|
account_data_by_room=account_data_by_room,
|
||||||
|
all_ephemeral_by_room=all_ephemeral_by_room,
|
||||||
|
batch=batch,
|
||||||
|
full_state=full_state,
|
||||||
|
)
|
||||||
|
if room_sync:
|
||||||
|
joined.append(room_sync)
|
||||||
|
|
||||||
account_data_for_user = sync_config.filter_collection.filter_account_data(
|
account_data_for_user = sync_config.filter_collection.filter_account_data(
|
||||||
self.account_data_for_user(account_data)
|
self.account_data_for_user(account_data)
|
||||||
|
@ -680,28 +602,40 @@ class SyncHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def load_filtered_recents(self, room_id, sync_config, now_token,
|
def load_filtered_recents(self, room_id, sync_config, now_token,
|
||||||
since_token=None):
|
since_token=None, recents=None, newly_joined_room=False):
|
||||||
"""
|
"""
|
||||||
:returns a Deferred TimelineBatch
|
:returns a Deferred TimelineBatch
|
||||||
"""
|
"""
|
||||||
limited = True
|
|
||||||
recents = []
|
|
||||||
filtering_factor = 2
|
filtering_factor = 2
|
||||||
timeline_limit = sync_config.filter_collection.timeline_limit()
|
timeline_limit = sync_config.filter_collection.timeline_limit()
|
||||||
load_limit = max(timeline_limit * filtering_factor, 100)
|
load_limit = max(timeline_limit * filtering_factor, 10)
|
||||||
max_repeat = 3 # Only try a few times per room, otherwise
|
max_repeat = 5 # Only try a few times per room, otherwise
|
||||||
room_key = now_token.room_key
|
room_key = now_token.room_key
|
||||||
end_key = room_key
|
end_key = room_key
|
||||||
|
|
||||||
|
limited = recents is None or newly_joined_room or timeline_limit < len(recents)
|
||||||
|
|
||||||
|
if recents is not None:
|
||||||
|
recents = sync_config.filter_collection.filter_room_timeline(recents)
|
||||||
|
recents = yield self._filter_events_for_client(
|
||||||
|
sync_config.user.to_string(),
|
||||||
|
recents,
|
||||||
|
is_peeking=sync_config.is_guest,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
recents = []
|
||||||
|
|
||||||
|
since_key = None
|
||||||
|
if since_token and not newly_joined_room:
|
||||||
|
since_key = since_token.room_key
|
||||||
|
|
||||||
while limited and len(recents) < timeline_limit and max_repeat:
|
while limited and len(recents) < timeline_limit and max_repeat:
|
||||||
events, keys = yield self.store.get_recent_events_for_room(
|
events, end_key = yield self.store.get_room_events_stream_for_room(
|
||||||
room_id,
|
room_id,
|
||||||
limit=load_limit + 1,
|
limit=load_limit + 1,
|
||||||
from_token=since_token.room_key if since_token else None,
|
from_key=since_key,
|
||||||
end_token=end_key,
|
to_key=end_key,
|
||||||
)
|
)
|
||||||
room_key, _ = keys
|
|
||||||
end_key = "s" + room_key.split('-')[-1]
|
|
||||||
loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
|
loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
|
||||||
loaded_recents = yield self._filter_events_for_client(
|
loaded_recents = yield self._filter_events_for_client(
|
||||||
sync_config.user.to_string(),
|
sync_config.user.to_string(),
|
||||||
|
@ -710,8 +644,10 @@ class SyncHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
loaded_recents.extend(recents)
|
loaded_recents.extend(recents)
|
||||||
recents = loaded_recents
|
recents = loaded_recents
|
||||||
|
|
||||||
if len(events) <= load_limit:
|
if len(events) <= load_limit:
|
||||||
limited = False
|
limited = False
|
||||||
|
break
|
||||||
max_repeat -= 1
|
max_repeat -= 1
|
||||||
|
|
||||||
if len(recents) > timeline_limit:
|
if len(recents) > timeline_limit:
|
||||||
|
@ -724,7 +660,9 @@ class SyncHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(TimelineBatch(
|
defer.returnValue(TimelineBatch(
|
||||||
events=recents, prev_batch=prev_batch_token, limited=limited
|
events=recents,
|
||||||
|
prev_batch=prev_batch_token,
|
||||||
|
limited=limited or newly_joined_room
|
||||||
))
|
))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -732,25 +670,12 @@ class SyncHandler(BaseHandler):
|
||||||
since_token, now_token,
|
since_token, now_token,
|
||||||
ephemeral_by_room, tags_by_room,
|
ephemeral_by_room, tags_by_room,
|
||||||
account_data_by_room,
|
account_data_by_room,
|
||||||
all_ephemeral_by_room):
|
all_ephemeral_by_room,
|
||||||
""" Get the incremental delta needed to bring the client up to date for
|
batch, full_state=False):
|
||||||
the room. Gives the client the most recent events and the changes to
|
if full_state:
|
||||||
state.
|
state = yield self.get_state_at(room_id, now_token)
|
||||||
Returns:
|
|
||||||
A Deferred JoinedSyncResult
|
|
||||||
"""
|
|
||||||
logger.debug("Doing incremental sync for room %s between %s and %s",
|
|
||||||
room_id, since_token, now_token)
|
|
||||||
|
|
||||||
# TODO(mjark): Check for redactions we might have missed.
|
elif batch.limited:
|
||||||
|
|
||||||
batch = yield self.load_filtered_recents(
|
|
||||||
room_id, sync_config, now_token, since_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Recents %r", batch)
|
|
||||||
|
|
||||||
if batch.limited:
|
|
||||||
current_state = yield self.get_state_at(room_id, now_token)
|
current_state = yield self.get_state_at(room_id, now_token)
|
||||||
|
|
||||||
state_at_previous_sync = yield self.get_state_at(
|
state_at_previous_sync = yield self.get_state_at(
|
||||||
|
@ -772,17 +697,6 @@ class SyncHandler(BaseHandler):
|
||||||
if just_joined:
|
if just_joined:
|
||||||
state = yield self.get_state_at(room_id, now_token)
|
state = yield self.get_state_at(room_id, now_token)
|
||||||
|
|
||||||
notifs = yield self.unread_notifs_for_room_id(
|
|
||||||
room_id, sync_config, all_ephemeral_by_room
|
|
||||||
)
|
|
||||||
|
|
||||||
unread_notifications = {}
|
|
||||||
if notifs is not None:
|
|
||||||
unread_notifications["notification_count"] = len(notifs)
|
|
||||||
unread_notifications["highlight_count"] = len([
|
|
||||||
1 for notif in notifs if _action_has_highlight(notif["actions"])
|
|
||||||
])
|
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
(e.type, e.state_key): e
|
(e.type, e.state_key): e
|
||||||
for e in sync_config.filter_collection.filter_room_state(state.values())
|
for e in sync_config.filter_collection.filter_room_state(state.values())
|
||||||
|
@ -800,6 +714,7 @@ class SyncHandler(BaseHandler):
|
||||||
ephemeral_by_room.get(room_id, [])
|
ephemeral_by_room.get(room_id, [])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
unread_notifications = {}
|
||||||
room_sync = JoinedSyncResult(
|
room_sync = JoinedSyncResult(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
timeline=batch,
|
timeline=batch,
|
||||||
|
@ -809,41 +724,55 @@ class SyncHandler(BaseHandler):
|
||||||
unread_notifications=unread_notifications,
|
unread_notifications=unread_notifications,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if room_sync:
|
||||||
|
notifs = yield self.unread_notifs_for_room_id(
|
||||||
|
room_id, sync_config, all_ephemeral_by_room
|
||||||
|
)
|
||||||
|
|
||||||
|
if notifs is not None:
|
||||||
|
unread_notifications["notification_count"] = len(notifs)
|
||||||
|
unread_notifications["highlight_count"] = len([
|
||||||
|
1 for notif in notifs if _action_has_highlight(notif["actions"])
|
||||||
|
])
|
||||||
|
|
||||||
logger.debug("Room sync: %r", room_sync)
|
logger.debug("Room sync: %r", room_sync)
|
||||||
|
|
||||||
defer.returnValue(room_sync)
|
defer.returnValue(room_sync)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def incremental_sync_for_archived_room(self, sync_config, leave_event,
|
def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id,
|
||||||
since_token, tags_by_room,
|
since_token, tags_by_room,
|
||||||
account_data_by_room):
|
account_data_by_room, full_state,
|
||||||
|
leave_token=None):
|
||||||
""" Get the incremental delta needed to bring the client up to date for
|
""" Get the incremental delta needed to bring the client up to date for
|
||||||
the archived room.
|
the archived room.
|
||||||
Returns:
|
Returns:
|
||||||
A Deferred ArchivedSyncResult
|
A Deferred ArchivedSyncResult
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not leave_token:
|
||||||
stream_token = yield self.store.get_stream_token_for_event(
|
stream_token = yield self.store.get_stream_token_for_event(
|
||||||
leave_event.event_id
|
leave_event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
leave_token = since_token.copy_and_replace("room_key", stream_token)
|
leave_token = since_token.copy_and_replace("room_key", stream_token)
|
||||||
|
|
||||||
if since_token.is_after(leave_token):
|
if since_token and since_token.is_after(leave_token):
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
batch = yield self.load_filtered_recents(
|
batch = yield self.load_filtered_recents(
|
||||||
leave_event.room_id, sync_config, leave_token, since_token,
|
room_id, sync_config, leave_token, since_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Recents %r", batch)
|
logger.debug("Recents %r", batch)
|
||||||
|
|
||||||
state_events_at_leave = yield self.store.get_state_for_event(
|
state_events_at_leave = yield self.store.get_state_for_event(
|
||||||
leave_event.event_id
|
leave_event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not full_state:
|
||||||
state_at_previous_sync = yield self.get_state_at(
|
state_at_previous_sync = yield self.get_state_at(
|
||||||
leave_event.room_id, stream_position=since_token
|
room_id, stream_position=since_token
|
||||||
)
|
)
|
||||||
|
|
||||||
state_events_delta = yield self.compute_state_delta(
|
state_events_delta = yield self.compute_state_delta(
|
||||||
|
@ -851,6 +780,8 @@ class SyncHandler(BaseHandler):
|
||||||
previous_state=state_at_previous_sync,
|
previous_state=state_at_previous_sync,
|
||||||
current_state=state_events_at_leave,
|
current_state=state_events_at_leave,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
state_events_delta = state_events_at_leave
|
||||||
|
|
||||||
state_events_delta = {
|
state_events_delta = {
|
||||||
(e.type, e.state_key): e
|
(e.type, e.state_key): e
|
||||||
|
@ -860,7 +791,7 @@ class SyncHandler(BaseHandler):
|
||||||
}
|
}
|
||||||
|
|
||||||
account_data = self.account_data_for_room(
|
account_data = self.account_data_for_room(
|
||||||
leave_event.room_id, tags_by_room, account_data_by_room
|
room_id, tags_by_room, account_data_by_room
|
||||||
)
|
)
|
||||||
|
|
||||||
account_data = sync_config.filter_collection.filter_room_account_data(
|
account_data = sync_config.filter_collection.filter_room_account_data(
|
||||||
|
@ -868,7 +799,7 @@ class SyncHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
room_sync = ArchivedSyncResult(
|
room_sync = ArchivedSyncResult(
|
||||||
room_id=leave_event.room_id,
|
room_id=room_id,
|
||||||
timeline=batch,
|
timeline=batch,
|
||||||
state=state_events_delta,
|
state=state_events_delta,
|
||||||
account_data=account_data,
|
account_data=account_data,
|
||||||
|
|
|
@ -22,7 +22,7 @@ REQUIREMENTS = {
|
||||||
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
|
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
|
||||||
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
|
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
|
||||||
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
|
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
|
||||||
"pynacl>=0.3.0": ["nacl>=0.3.0", "nacl.bindings"],
|
"pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
|
||||||
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
||||||
"Twisted>=15.1.0": ["twisted>=15.1.0"],
|
"Twisted>=15.1.0": ["twisted>=15.1.0"],
|
||||||
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
||||||
|
|
|
@ -66,11 +66,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
raise SynapseError(400, e.message)
|
raise SynapseError(400, e.message)
|
||||||
|
|
||||||
before = request.args.get("before", None)
|
before = request.args.get("before", None)
|
||||||
if before and len(before):
|
if before:
|
||||||
before = before[0]
|
before = _namespaced_rule_id(spec, before[0])
|
||||||
|
|
||||||
after = request.args.get("after", None)
|
after = request.args.get("after", None)
|
||||||
if after and len(after):
|
if after:
|
||||||
after = after[0]
|
after = _namespaced_rule_id(spec, after[0])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.hs.get_datastore().add_push_rule(
|
yield self.hs.get_datastore().add_push_rule(
|
||||||
|
@ -452,11 +453,15 @@ def _strip_device_condition(rule):
|
||||||
|
|
||||||
|
|
||||||
def _namespaced_rule_id_from_spec(spec):
|
def _namespaced_rule_id_from_spec(spec):
|
||||||
|
return _namespaced_rule_id(spec, spec['rule_id'])
|
||||||
|
|
||||||
|
|
||||||
|
def _namespaced_rule_id(spec, rule_id):
|
||||||
if spec['scope'] == 'global':
|
if spec['scope'] == 'global':
|
||||||
scope = 'global'
|
scope = 'global'
|
||||||
else:
|
else:
|
||||||
scope = 'device/%s' % (spec['profile_tag'])
|
scope = 'device/%s' % (spec['profile_tag'])
|
||||||
return "%s/%s/%s" % (scope, spec['template'], spec['rule_id'])
|
return "%s/%s/%s" % (scope, spec['template'], rule_id)
|
||||||
|
|
||||||
|
|
||||||
def _rule_id_from_namespaced(in_rule_id):
|
def _rule_id_from_namespaced(in_rule_id):
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
# Imports required for the default HomeServer() implementation
|
# Imports required for the default HomeServer() implementation
|
||||||
from twisted.web.client import BrowserLikePolicyForHTTPS
|
from twisted.web.client import BrowserLikePolicyForHTTPS
|
||||||
|
from twisted.enterprise import adbapi
|
||||||
|
|
||||||
from synapse.federation import initialize_http_replication
|
from synapse.federation import initialize_http_replication
|
||||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||||
from synapse.notifier import Notifier
|
from synapse.notifier import Notifier
|
||||||
|
@ -36,8 +38,15 @@ from synapse.push.pusherpool import PusherPool
|
||||||
from synapse.events.builder import EventBuilderFactory
|
from synapse.events.builder import EventBuilderFactory
|
||||||
from synapse.api.filtering import Filtering
|
from synapse.api.filtering import Filtering
|
||||||
|
|
||||||
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
|
|
||||||
class BaseHomeServer(object):
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HomeServer(object):
|
||||||
"""A basic homeserver object without lazy component builders.
|
"""A basic homeserver object without lazy component builders.
|
||||||
|
|
||||||
This will need all of the components it requires to either be passed as
|
This will need all of the components it requires to either be passed as
|
||||||
|
@ -98,39 +107,18 @@ class BaseHomeServer(object):
|
||||||
self.hostname = hostname
|
self.hostname = hostname
|
||||||
self._building = {}
|
self._building = {}
|
||||||
|
|
||||||
|
self.clock = Clock()
|
||||||
|
self.distributor = Distributor()
|
||||||
|
self.ratelimiter = Ratelimiter()
|
||||||
|
|
||||||
# Other kwargs are explicit dependencies
|
# Other kwargs are explicit dependencies
|
||||||
for depname in kwargs:
|
for depname in kwargs:
|
||||||
setattr(self, depname, kwargs[depname])
|
setattr(self, depname, kwargs[depname])
|
||||||
|
|
||||||
@classmethod
|
def setup(self):
|
||||||
def _make_dependency_method(cls, depname):
|
logger.info("Setting up.")
|
||||||
def _get(self):
|
self.datastore = DataStore(self.get_db_conn(), self)
|
||||||
if hasattr(self, depname):
|
logger.info("Finished setting up.")
|
||||||
return getattr(self, depname)
|
|
||||||
|
|
||||||
if hasattr(self, "build_%s" % (depname)):
|
|
||||||
# Prevent cyclic dependencies from deadlocking
|
|
||||||
if depname in self._building:
|
|
||||||
raise ValueError("Cyclic dependency while building %s" % (
|
|
||||||
depname,
|
|
||||||
))
|
|
||||||
self._building[depname] = 1
|
|
||||||
|
|
||||||
builder = getattr(self, "build_%s" % (depname))
|
|
||||||
dep = builder()
|
|
||||||
setattr(self, depname, dep)
|
|
||||||
|
|
||||||
del self._building[depname]
|
|
||||||
|
|
||||||
return dep
|
|
||||||
|
|
||||||
raise NotImplementedError(
|
|
||||||
"%s has no %s nor a builder for it" % (
|
|
||||||
type(self).__name__, depname,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
setattr(BaseHomeServer, "get_%s" % (depname), _get)
|
|
||||||
|
|
||||||
def get_ip_from_request(self, request):
|
def get_ip_from_request(self, request):
|
||||||
# X-Forwarded-For is handled by our custom request type.
|
# X-Forwarded-For is handled by our custom request type.
|
||||||
|
@ -142,33 +130,9 @@ class BaseHomeServer(object):
|
||||||
def is_mine_id(self, string):
|
def is_mine_id(self, string):
|
||||||
return string.split(":", 1)[1] == self.hostname
|
return string.split(":", 1)[1] == self.hostname
|
||||||
|
|
||||||
# Build magic accessors for every dependency
|
|
||||||
for depname in BaseHomeServer.DEPENDENCIES:
|
|
||||||
BaseHomeServer._make_dependency_method(depname)
|
|
||||||
|
|
||||||
|
|
||||||
class HomeServer(BaseHomeServer):
|
|
||||||
"""A homeserver object that will construct most of its dependencies as
|
|
||||||
required.
|
|
||||||
|
|
||||||
It still requires the following to be specified by the caller:
|
|
||||||
resource_for_client
|
|
||||||
resource_for_web_client
|
|
||||||
resource_for_federation
|
|
||||||
resource_for_content_repo
|
|
||||||
http_client
|
|
||||||
db_pool
|
|
||||||
"""
|
|
||||||
|
|
||||||
def build_clock(self):
|
|
||||||
return Clock()
|
|
||||||
|
|
||||||
def build_replication_layer(self):
|
def build_replication_layer(self):
|
||||||
return initialize_http_replication(self)
|
return initialize_http_replication(self)
|
||||||
|
|
||||||
def build_datastore(self):
|
|
||||||
return DataStore(self)
|
|
||||||
|
|
||||||
def build_handlers(self):
|
def build_handlers(self):
|
||||||
return Handlers(self)
|
return Handlers(self)
|
||||||
|
|
||||||
|
@ -179,10 +143,9 @@ class HomeServer(BaseHomeServer):
|
||||||
return Auth(self)
|
return Auth(self)
|
||||||
|
|
||||||
def build_http_client_context_factory(self):
|
def build_http_client_context_factory(self):
|
||||||
config = self.get_config()
|
|
||||||
return (
|
return (
|
||||||
InsecureInterceptableContextFactory()
|
InsecureInterceptableContextFactory()
|
||||||
if config.use_insecure_ssl_client_just_for_testing_do_not_use
|
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||||
else BrowserLikePolicyForHTTPS()
|
else BrowserLikePolicyForHTTPS()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -201,15 +164,9 @@ class HomeServer(BaseHomeServer):
|
||||||
def build_state_handler(self):
|
def build_state_handler(self):
|
||||||
return StateHandler(self)
|
return StateHandler(self)
|
||||||
|
|
||||||
def build_distributor(self):
|
|
||||||
return Distributor()
|
|
||||||
|
|
||||||
def build_event_sources(self):
|
def build_event_sources(self):
|
||||||
return EventSources(self)
|
return EventSources(self)
|
||||||
|
|
||||||
def build_ratelimiter(self):
|
|
||||||
return Ratelimiter()
|
|
||||||
|
|
||||||
def build_keyring(self):
|
def build_keyring(self):
|
||||||
return Keyring(self)
|
return Keyring(self)
|
||||||
|
|
||||||
|
@ -224,3 +181,55 @@ class HomeServer(BaseHomeServer):
|
||||||
|
|
||||||
def build_pusherpool(self):
|
def build_pusherpool(self):
|
||||||
return PusherPool(self)
|
return PusherPool(self)
|
||||||
|
|
||||||
|
def build_http_client(self):
|
||||||
|
return MatrixFederationHttpClient(self)
|
||||||
|
|
||||||
|
def build_db_pool(self):
|
||||||
|
name = self.db_config["name"]
|
||||||
|
|
||||||
|
return adbapi.ConnectionPool(
|
||||||
|
name,
|
||||||
|
**self.db_config.get("args", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_dependency_method(depname):
|
||||||
|
def _get(hs):
|
||||||
|
try:
|
||||||
|
return getattr(hs, depname)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
builder = getattr(hs, "build_%s" % (depname))
|
||||||
|
except AttributeError:
|
||||||
|
builder = None
|
||||||
|
|
||||||
|
if builder:
|
||||||
|
# Prevent cyclic dependencies from deadlocking
|
||||||
|
if depname in hs._building:
|
||||||
|
raise ValueError("Cyclic dependency while building %s" % (
|
||||||
|
depname,
|
||||||
|
))
|
||||||
|
hs._building[depname] = 1
|
||||||
|
|
||||||
|
dep = builder()
|
||||||
|
setattr(hs, depname, dep)
|
||||||
|
|
||||||
|
del hs._building[depname]
|
||||||
|
|
||||||
|
return dep
|
||||||
|
|
||||||
|
raise NotImplementedError(
|
||||||
|
"%s has no %s nor a builder for it" % (
|
||||||
|
type(hs).__name__, depname,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(HomeServer, "get_%s" % (depname), _get)
|
||||||
|
|
||||||
|
|
||||||
|
# Build magic accessors for every dependency
|
||||||
|
for depname in HomeServer.DEPENDENCIES:
|
||||||
|
_make_dependency_method(depname)
|
||||||
|
|
|
@ -46,6 +46,9 @@ from .tags import TagsStore
|
||||||
from .account_data import AccountDataStore
|
from .account_data import AccountDataStore
|
||||||
|
|
||||||
|
|
||||||
|
from util.id_generators import IdGenerator, StreamIdGenerator
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,18 +82,43 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
EventPushActionsStore
|
EventPushActionsStore
|
||||||
):
|
):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(DataStore, self).__init__(hs)
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
self.min_token_deferred = self._get_min_token()
|
cur = db_conn.cursor()
|
||||||
self.min_token = None
|
try:
|
||||||
|
cur.execute("SELECT MIN(stream_ordering) FROM events",)
|
||||||
|
rows = cur.fetchall()
|
||||||
|
self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
|
||||||
|
self.min_stream_token = min(self.min_stream_token, -1)
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
|
||||||
self.client_ip_last_seen = Cache(
|
self.client_ip_last_seen = Cache(
|
||||||
name="client_ip_last_seen",
|
name="client_ip_last_seen",
|
||||||
keylen=4,
|
keylen=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._stream_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "events", "stream_ordering"
|
||||||
|
)
|
||||||
|
self._receipts_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "receipts_linearized", "stream_id"
|
||||||
|
)
|
||||||
|
self._account_data_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "account_data_max_stream_id", "stream_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
||||||
|
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
||||||
|
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
||||||
|
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
|
||||||
|
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
||||||
|
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
||||||
|
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
||||||
|
|
||||||
|
super(DataStore, self).__init__(hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def insert_client_ip(self, user, access_token, ip, user_agent):
|
def insert_client_ip(self, user, access_token, ip, user_agent):
|
||||||
now = int(self._clock.time_msec())
|
now = int(self._clock.time_msec())
|
||||||
|
|
|
@ -15,13 +15,11 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.util.logutils import log_function
|
|
||||||
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
||||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||||
from synapse.util.caches.descriptors import Cache
|
from synapse.util.caches.descriptors import Cache
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
from util.id_generators import IdGenerator, StreamIdGenerator
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -175,16 +173,6 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
self.database_engine = hs.database_engine
|
self.database_engine = hs.database_engine
|
||||||
|
|
||||||
self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
|
|
||||||
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
|
||||||
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
|
||||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
|
||||||
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
|
|
||||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
|
||||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
|
||||||
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
|
||||||
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
|
|
||||||
|
|
||||||
def start_profiling(self):
|
def start_profiling(self):
|
||||||
self._previous_loop_ts = self._clock.time_msec()
|
self._previous_loop_ts = self._clock.time_msec()
|
||||||
|
|
||||||
|
@ -345,7 +333,8 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
def cursor_to_dict(self, cursor):
|
@staticmethod
|
||||||
|
def cursor_to_dict(cursor):
|
||||||
"""Converts a SQL cursor into an list of dicts.
|
"""Converts a SQL cursor into an list of dicts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -402,8 +391,8 @@ class SQLBaseStore(object):
|
||||||
if not or_ignore:
|
if not or_ignore:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@log_function
|
@staticmethod
|
||||||
def _simple_insert_txn(self, txn, table, values):
|
def _simple_insert_txn(txn, table, values):
|
||||||
keys, vals = zip(*values.items())
|
keys, vals = zip(*values.items())
|
||||||
|
|
||||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||||
|
@ -414,7 +403,8 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
txn.execute(sql, vals)
|
txn.execute(sql, vals)
|
||||||
|
|
||||||
def _simple_insert_many_txn(self, txn, table, values):
|
@staticmethod
|
||||||
|
def _simple_insert_many_txn(txn, table, values):
|
||||||
if not values:
|
if not values:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -537,9 +527,10 @@ class SQLBaseStore(object):
|
||||||
table, keyvalues, retcol, allow_none=allow_none,
|
table, keyvalues, retcol, allow_none=allow_none,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
|
@classmethod
|
||||||
|
def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
|
||||||
allow_none=False):
|
allow_none=False):
|
||||||
ret = self._simple_select_onecol_txn(
|
ret = cls._simple_select_onecol_txn(
|
||||||
txn,
|
txn,
|
||||||
table=table,
|
table=table,
|
||||||
keyvalues=keyvalues,
|
keyvalues=keyvalues,
|
||||||
|
@ -554,7 +545,8 @@ class SQLBaseStore(object):
|
||||||
else:
|
else:
|
||||||
raise StoreError(404, "No row found")
|
raise StoreError(404, "No row found")
|
||||||
|
|
||||||
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
|
@staticmethod
|
||||||
|
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
|
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
|
||||||
) % {
|
) % {
|
||||||
|
@ -603,7 +595,8 @@ class SQLBaseStore(object):
|
||||||
table, keyvalues, retcols
|
table, keyvalues, retcols
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
|
@classmethod
|
||||||
|
def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows, returning the result as a list of dicts.
|
||||||
|
|
||||||
|
@ -627,7 +620,7 @@ class SQLBaseStore(object):
|
||||||
)
|
)
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
|
|
||||||
return self.cursor_to_dict(txn)
|
return cls.cursor_to_dict(txn)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _simple_select_many_batch(self, table, column, iterable, retcols,
|
def _simple_select_many_batch(self, table, column, iterable, retcols,
|
||||||
|
@ -662,7 +655,8 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
|
@classmethod
|
||||||
|
def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows, returning the result as a list of dicts.
|
||||||
|
|
||||||
|
@ -699,7 +693,7 @@ class SQLBaseStore(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(sql, values)
|
txn.execute(sql, values)
|
||||||
return self.cursor_to_dict(txn)
|
return cls.cursor_to_dict(txn)
|
||||||
|
|
||||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||||
desc="_simple_update_one"):
|
desc="_simple_update_one"):
|
||||||
|
@ -726,7 +720,8 @@ class SQLBaseStore(object):
|
||||||
table, keyvalues, updatevalues,
|
table, keyvalues, updatevalues,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
|
@staticmethod
|
||||||
|
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
||||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||||
table,
|
table,
|
||||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||||
|
@ -743,7 +738,8 @@ class SQLBaseStore(object):
|
||||||
if txn.rowcount > 1:
|
if txn.rowcount > 1:
|
||||||
raise StoreError(500, "More than one row matched")
|
raise StoreError(500, "More than one row matched")
|
||||||
|
|
||||||
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
|
@staticmethod
|
||||||
|
def _simple_select_one_txn(txn, table, keyvalues, retcols,
|
||||||
allow_none=False):
|
allow_none=False):
|
||||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||||
", ".join(retcols),
|
", ".join(retcols),
|
||||||
|
@ -784,7 +780,8 @@ class SQLBaseStore(object):
|
||||||
raise StoreError(500, "more than one row matched")
|
raise StoreError(500, "more than one row matched")
|
||||||
return self.runInteraction(desc, func)
|
return self.runInteraction(desc, func)
|
||||||
|
|
||||||
def _simple_delete_txn(self, txn, table, keyvalues):
|
@staticmethod
|
||||||
|
def _simple_delete_txn(txn, table, keyvalues):
|
||||||
sql = "DELETE FROM %s WHERE %s" % (
|
sql = "DELETE FROM %s WHERE %s" % (
|
||||||
table,
|
table,
|
||||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
@ -23,6 +24,14 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AccountDataStore(SQLBaseStore):
|
class AccountDataStore(SQLBaseStore):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(AccountDataStore, self).__init__(hs)
|
||||||
|
|
||||||
|
self._account_data_stream_cache = StreamChangeCache(
|
||||||
|
"AccountDataAndTagsChangeCache",
|
||||||
|
self._account_data_id_gen.get_max_token(None),
|
||||||
|
max_size=10000,
|
||||||
|
)
|
||||||
|
|
||||||
def get_account_data_for_user(self, user_id):
|
def get_account_data_for_user(self, user_id):
|
||||||
"""Get all the client account_data for a user.
|
"""Get all the client account_data for a user.
|
||||||
|
@ -83,7 +92,7 @@ class AccountDataStore(SQLBaseStore):
|
||||||
"get_account_data_for_room", get_account_data_for_room_txn
|
"get_account_data_for_room", get_account_data_for_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_updated_account_data_for_user(self, user_id, stream_id):
|
def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None):
|
||||||
"""Get all the client account_data for a that's changed.
|
"""Get all the client account_data for a that's changed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -120,6 +129,12 @@ class AccountDataStore(SQLBaseStore):
|
||||||
|
|
||||||
return (global_account_data, account_data_by_room)
|
return (global_account_data, account_data_by_room)
|
||||||
|
|
||||||
|
changed = self._account_data_stream_cache.has_entity_changed(
|
||||||
|
user_id, int(stream_id)
|
||||||
|
)
|
||||||
|
if not changed:
|
||||||
|
return ({}, {})
|
||||||
|
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
|
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
|
||||||
)
|
)
|
||||||
|
@ -186,6 +201,10 @@ class AccountDataStore(SQLBaseStore):
|
||||||
"content": content_json,
|
"content": content_json,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self._account_data_stream_cache.entity_has_changed,
|
||||||
|
user_id, next_id,
|
||||||
|
)
|
||||||
self._update_max_stream_id(txn, next_id)
|
self._update_max_stream_id(txn, next_id)
|
||||||
|
|
||||||
with (yield self._account_data_id_gen.get_next(self)) as next_id:
|
with (yield self._account_data_id_gen.get_next(self)) as next_id:
|
||||||
|
|
|
@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
|
||||||
return
|
return
|
||||||
|
|
||||||
if backfilled:
|
if backfilled:
|
||||||
if not self.min_token_deferred.called:
|
start = self.min_stream_token - 1
|
||||||
yield self.min_token_deferred
|
self.min_stream_token -= len(events_and_contexts) + 1
|
||||||
start = self.min_token - 1
|
stream_orderings = range(start, self.min_stream_token, -1)
|
||||||
self.min_token -= len(events_and_contexts) + 1
|
|
||||||
stream_orderings = range(start, self.min_token, -1)
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def stream_ordering_manager():
|
def stream_ordering_manager():
|
||||||
|
@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
|
||||||
is_new_state=True, current_state=None):
|
is_new_state=True, current_state=None):
|
||||||
stream_ordering = None
|
stream_ordering = None
|
||||||
if backfilled:
|
if backfilled:
|
||||||
if not self.min_token_deferred.called:
|
self.min_stream_token -= 1
|
||||||
yield self.min_token_deferred
|
stream_ordering = self.min_stream_token
|
||||||
self.min_token -= 1
|
|
||||||
stream_ordering = self.min_token
|
|
||||||
|
|
||||||
if stream_ordering is None:
|
if stream_ordering is None:
|
||||||
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
|
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
|
||||||
|
@ -214,6 +210,12 @@ class EventsStore(SQLBaseStore):
|
||||||
for event, _ in events_and_contexts:
|
for event, _ in events_and_contexts:
|
||||||
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
||||||
|
|
||||||
|
if not backfilled:
|
||||||
|
txn.call_after(
|
||||||
|
self._events_stream_cache.entity_has_changed,
|
||||||
|
event.room_id, event.internal_metadata.stream_ordering,
|
||||||
|
)
|
||||||
|
|
||||||
depth_updates = {}
|
depth_updates = {}
|
||||||
for event, _ in events_and_contexts:
|
for event, _ in events_and_contexts:
|
||||||
if event.internal_metadata.is_outlier():
|
if event.internal_metadata.is_outlier():
|
||||||
|
|
|
@ -16,12 +16,13 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
|
|
||||||
|
|
||||||
class FilteringStore(SQLBaseStore):
|
class FilteringStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
@cachedInlineCallbacks(num_args=2)
|
||||||
def get_user_filter(self, user_localpart, filter_id):
|
def get_user_filter(self, user_localpart, filter_id):
|
||||||
def_json = yield self._simple_select_one_onecol(
|
def_json = yield self._simple_select_one_onecol(
|
||||||
table="user_filters",
|
table="user_filters",
|
||||||
|
|
|
@ -130,7 +130,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
|
|
||||||
def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
|
def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
|
||||||
after = kwargs.pop("after", None)
|
after = kwargs.pop("after", None)
|
||||||
relative_to_rule = kwargs.pop("before", after)
|
before = kwargs.pop("before", None)
|
||||||
|
relative_to_rule = before or after
|
||||||
|
|
||||||
res = self._simple_select_one_txn(
|
res = self._simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
|
|
|
@ -15,11 +15,10 @@
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
|
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
|
||||||
from synapse.util.caches import cache_counter, caches_by_name
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from blist import sorteddict
|
|
||||||
import logging
|
import logging
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
|
||||||
|
@ -31,7 +30,9 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ReceiptsStore, self).__init__(hs)
|
super(ReceiptsStore, self).__init__(hs)
|
||||||
|
|
||||||
self._receipts_stream_cache = _RoomStreamChangeCache()
|
self._receipts_stream_cache = StreamChangeCache(
|
||||||
|
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
|
||||||
|
)
|
||||||
|
|
||||||
@cached(num_args=2)
|
@cached(num_args=2)
|
||||||
def get_receipts_for_room(self, room_id, receipt_type):
|
def get_receipts_for_room(self, room_id, receipt_type):
|
||||||
|
@ -76,8 +77,8 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
room_ids = set(room_ids)
|
room_ids = set(room_ids)
|
||||||
|
|
||||||
if from_key:
|
if from_key:
|
||||||
room_ids = yield self._receipts_stream_cache.get_rooms_changed(
|
room_ids = yield self._receipts_stream_cache.get_entities_changed(
|
||||||
self, room_ids, from_key
|
room_ids, from_key
|
||||||
)
|
)
|
||||||
|
|
||||||
results = yield self._get_linearized_receipts_for_rooms(
|
results = yield self._get_linearized_receipts_for_rooms(
|
||||||
|
@ -220,6 +221,11 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
# FIXME: This shouldn't invalidate the whole cache
|
# FIXME: This shouldn't invalidate the whole cache
|
||||||
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
|
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
|
||||||
|
|
||||||
|
txn.call_after(
|
||||||
|
self._receipts_stream_cache.entity_has_changed,
|
||||||
|
room_id, stream_id
|
||||||
|
)
|
||||||
|
|
||||||
# We don't want to clobber receipts for more recent events, so we
|
# We don't want to clobber receipts for more recent events, so we
|
||||||
# have to compare orderings of existing receipts
|
# have to compare orderings of existing receipts
|
||||||
sql = (
|
sql = (
|
||||||
|
@ -307,9 +313,6 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
|
|
||||||
stream_id_manager = yield self._receipts_id_gen.get_next(self)
|
stream_id_manager = yield self._receipts_id_gen.get_next(self)
|
||||||
with stream_id_manager as stream_id:
|
with stream_id_manager as stream_id:
|
||||||
yield self._receipts_stream_cache.room_has_changed(
|
|
||||||
self, room_id, stream_id
|
|
||||||
)
|
|
||||||
have_persisted = yield self.runInteraction(
|
have_persisted = yield self.runInteraction(
|
||||||
"insert_linearized_receipt",
|
"insert_linearized_receipt",
|
||||||
self.insert_linearized_receipt_txn,
|
self.insert_linearized_receipt_txn,
|
||||||
|
@ -368,63 +371,3 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
"data": json.dumps(data),
|
"data": json.dumps(data),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class _RoomStreamChangeCache(object):
|
|
||||||
"""Keeps track of the stream_id of the latest change in rooms.
|
|
||||||
|
|
||||||
Given a list of rooms and stream key, it will give a subset of rooms that
|
|
||||||
may have changed since that key. If the key is too old then the cache
|
|
||||||
will simply return all rooms.
|
|
||||||
"""
|
|
||||||
def __init__(self, size_of_cache=10000):
|
|
||||||
self._size_of_cache = size_of_cache
|
|
||||||
self._room_to_key = {}
|
|
||||||
self._cache = sorteddict()
|
|
||||||
self._earliest_key = None
|
|
||||||
self.name = "ReceiptsRoomChangeCache"
|
|
||||||
caches_by_name[self.name] = self._cache
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_rooms_changed(self, store, room_ids, key):
|
|
||||||
"""Returns subset of room ids that have had new receipts since the
|
|
||||||
given key. If the key is too old it will just return the given list.
|
|
||||||
"""
|
|
||||||
if key > (yield self._get_earliest_key(store)):
|
|
||||||
keys = self._cache.keys()
|
|
||||||
i = keys.bisect_right(key)
|
|
||||||
|
|
||||||
result = set(
|
|
||||||
self._cache[k] for k in keys[i:]
|
|
||||||
).intersection(room_ids)
|
|
||||||
|
|
||||||
cache_counter.inc_hits(self.name)
|
|
||||||
else:
|
|
||||||
result = room_ids
|
|
||||||
cache_counter.inc_misses(self.name)
|
|
||||||
|
|
||||||
defer.returnValue(result)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def room_has_changed(self, store, room_id, key):
|
|
||||||
"""Informs the cache that the room has been changed at the given key.
|
|
||||||
"""
|
|
||||||
if key > (yield self._get_earliest_key(store)):
|
|
||||||
old_key = self._room_to_key.get(room_id, None)
|
|
||||||
if old_key:
|
|
||||||
key = max(key, old_key)
|
|
||||||
self._cache.pop(old_key, None)
|
|
||||||
self._cache[key] = room_id
|
|
||||||
|
|
||||||
while len(self._cache) > self._size_of_cache:
|
|
||||||
k, r = self._cache.popitem()
|
|
||||||
self._earliest_key = max(k, self._earliest_key)
|
|
||||||
self._room_to_key.pop(r, None)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _get_earliest_key(self, store):
|
|
||||||
if self._earliest_key is None:
|
|
||||||
self._earliest_key = yield store.get_max_receipt_stream_id()
|
|
||||||
self._earliest_key = int(self._earliest_key)
|
|
||||||
|
|
||||||
defer.returnValue(self._earliest_key)
|
|
||||||
|
|
|
@ -241,7 +241,7 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
@cached()
|
@cached(max_entries=5000)
|
||||||
def get_rooms_for_user(self, user_id):
|
def get_rooms_for_user(self, user_id):
|
||||||
return self.get_rooms_for_user_where_membership_is(
|
return self.get_rooms_for_user_where_membership_is(
|
||||||
user_id, membership_list=[Membership.JOIN],
|
user_id, membership_list=[Membership.JOIN],
|
||||||
|
|
16
synapse/storage/schema/delta/28/events_room_stream.sql
Normal file
16
synapse/storage/schema/delta/28/events_room_stream.sql
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE INDEX events_room_stream on events(room_id, stream_ordering);
|
|
@ -37,6 +37,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
@ -77,6 +78,12 @@ def upper_bound(token):
|
||||||
|
|
||||||
|
|
||||||
class StreamStore(SQLBaseStore):
|
class StreamStore(SQLBaseStore):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(StreamStore, self).__init__(hs)
|
||||||
|
|
||||||
|
self._events_stream_cache = StreamChangeCache(
|
||||||
|
"EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
|
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
|
||||||
|
@ -157,6 +164,135 @@ class StreamStore(SQLBaseStore):
|
||||||
results = yield self.runInteraction("get_appservice_room_stream", f)
|
results = yield self.runInteraction("get_appservice_room_stream", f)
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0):
|
||||||
|
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||||
|
|
||||||
|
room_ids = yield self._events_stream_cache.get_entities_changed(
|
||||||
|
room_ids, from_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not room_ids:
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
room_ids = list(room_ids)
|
||||||
|
for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)):
|
||||||
|
res = yield defer.gatherResults([
|
||||||
|
self.get_room_events_stream_for_room(
|
||||||
|
room_id, from_key, to_key, limit
|
||||||
|
).addCallback(lambda r, rm: (rm, r), room_id)
|
||||||
|
for room_id in room_ids
|
||||||
|
])
|
||||||
|
results.update(dict(res))
|
||||||
|
|
||||||
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0):
|
||||||
|
if from_key is not None:
|
||||||
|
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||||
|
else:
|
||||||
|
from_id = None
|
||||||
|
to_id = RoomStreamToken.parse_stream_token(to_key).stream
|
||||||
|
|
||||||
|
if from_key == to_key:
|
||||||
|
defer.returnValue(([], from_key))
|
||||||
|
|
||||||
|
if from_id:
|
||||||
|
has_changed = yield self._events_stream_cache.has_entity_changed(
|
||||||
|
room_id, from_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_changed:
|
||||||
|
defer.returnValue(([], from_key))
|
||||||
|
|
||||||
|
def f(txn):
|
||||||
|
if from_id is not None:
|
||||||
|
sql = (
|
||||||
|
"SELECT event_id, stream_ordering FROM events WHERE"
|
||||||
|
" room_id = ?"
|
||||||
|
" AND not outlier"
|
||||||
|
" AND stream_ordering > ? AND stream_ordering <= ?"
|
||||||
|
" ORDER BY stream_ordering DESC LIMIT ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (room_id, from_id, to_id, limit))
|
||||||
|
else:
|
||||||
|
sql = (
|
||||||
|
"SELECT event_id, stream_ordering FROM events WHERE"
|
||||||
|
" room_id = ?"
|
||||||
|
" AND not outlier"
|
||||||
|
" AND stream_ordering <= ?"
|
||||||
|
" ORDER BY stream_ordering DESC LIMIT ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (room_id, to_id, limit))
|
||||||
|
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
ret = self._get_events_txn(
|
||||||
|
txn,
|
||||||
|
[r["event_id"] for r in rows],
|
||||||
|
get_prev_content=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self._set_before_and_after(ret, rows, topo_order=False)
|
||||||
|
|
||||||
|
ret.reverse()
|
||||||
|
|
||||||
|
if rows:
|
||||||
|
key = "s%d" % min(r["stream_ordering"] for r in rows)
|
||||||
|
else:
|
||||||
|
# Assume we didn't get anything because there was nothing to
|
||||||
|
# get.
|
||||||
|
key = from_key
|
||||||
|
|
||||||
|
return ret, key
|
||||||
|
res = yield self.runInteraction("get_room_events_stream_for_room", f)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
def get_room_changes_for_user(self, user_id, from_key, to_key):
|
||||||
|
if from_key is not None:
|
||||||
|
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||||
|
else:
|
||||||
|
from_id = None
|
||||||
|
to_id = RoomStreamToken.parse_stream_token(to_key).stream
|
||||||
|
|
||||||
|
if from_key == to_key:
|
||||||
|
return defer.succeed([])
|
||||||
|
|
||||||
|
def f(txn):
|
||||||
|
if from_id is not None:
|
||||||
|
sql = (
|
||||||
|
"SELECT m.event_id, stream_ordering FROM events AS e,"
|
||||||
|
" room_memberships AS m"
|
||||||
|
" WHERE e.event_id = m.event_id"
|
||||||
|
" AND m.user_id = ?"
|
||||||
|
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
|
||||||
|
" ORDER BY e.stream_ordering ASC"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id, from_id, to_id,))
|
||||||
|
else:
|
||||||
|
sql = (
|
||||||
|
"SELECT m.event_id, stream_ordering FROM events AS e,"
|
||||||
|
" room_memberships AS m"
|
||||||
|
" WHERE e.event_id = m.event_id"
|
||||||
|
" AND m.user_id = ?"
|
||||||
|
" AND stream_ordering <= ?"
|
||||||
|
" ORDER BY stream_ordering ASC"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id, to_id,))
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
ret = self._get_events_txn(
|
||||||
|
txn,
|
||||||
|
[r["event_id"] for r in rows],
|
||||||
|
get_prev_content=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
return self.runInteraction("get_room_changes_for_user", f)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_room_events_stream(
|
def get_room_events_stream(
|
||||||
self,
|
self,
|
||||||
|
@ -174,7 +310,8 @@ class StreamStore(SQLBaseStore):
|
||||||
"SELECT c.room_id FROM history_visibility AS h"
|
"SELECT c.room_id FROM history_visibility AS h"
|
||||||
" INNER JOIN current_state_events AS c"
|
" INNER JOIN current_state_events AS c"
|
||||||
" ON h.event_id = c.event_id"
|
" ON h.event_id = c.event_id"
|
||||||
" WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
|
" WHERE c.room_id IN (%s)"
|
||||||
|
" AND h.history_visibility = 'world_readable'" % (
|
||||||
",".join(map(lambda _: "?", room_ids))
|
",".join(map(lambda _: "?", room_ids))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -434,6 +571,18 @@ class StreamStore(SQLBaseStore):
|
||||||
row["topological_ordering"], row["stream_ordering"],)
|
row["topological_ordering"], row["stream_ordering"],)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
|
||||||
|
sql = (
|
||||||
|
"SELECT max(topological_ordering) FROM events"
|
||||||
|
" WHERE room_id = ? AND stream_ordering < ?"
|
||||||
|
)
|
||||||
|
return self._execute(
|
||||||
|
"get_max_topological_token_for_stream_and_room", None,
|
||||||
|
sql, room_id, stream_key,
|
||||||
|
).addCallback(
|
||||||
|
lambda r: r[0][0] if r else 0
|
||||||
|
)
|
||||||
|
|
||||||
def _get_max_topological_txn(self, txn):
|
def _get_max_topological_txn(self, txn):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT MAX(topological_ordering) FROM events"
|
"SELECT MAX(topological_ordering) FROM events"
|
||||||
|
@ -444,24 +593,14 @@ class StreamStore(SQLBaseStore):
|
||||||
rows = txn.fetchall()
|
rows = txn.fetchall()
|
||||||
return rows[0][0] if rows else 0
|
return rows[0][0] if rows else 0
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _get_min_token(self):
|
|
||||||
row = yield self._execute(
|
|
||||||
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
|
|
||||||
self.min_token = min(self.min_token, -1)
|
|
||||||
|
|
||||||
logger.debug("min_token is: %s", self.min_token)
|
|
||||||
|
|
||||||
defer.returnValue(self.min_token)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _set_before_and_after(events, rows):
|
def _set_before_and_after(events, rows, topo_order=True):
|
||||||
for event, row in zip(events, rows):
|
for event, row in zip(events, rows):
|
||||||
stream = row["stream_ordering"]
|
stream = row["stream_ordering"]
|
||||||
|
if topo_order:
|
||||||
topo = event.depth
|
topo = event.depth
|
||||||
|
else:
|
||||||
|
topo = None
|
||||||
internal = event.internal_metadata
|
internal = event.internal_metadata
|
||||||
internal.before = str(RoomStreamToken(topo, stream - 1))
|
internal.before = str(RoomStreamToken(topo, stream - 1))
|
||||||
internal.after = str(RoomStreamToken(topo, stream))
|
internal.after = str(RoomStreamToken(topo, stream))
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from .util.id_generators import StreamIdGenerator
|
|
||||||
|
|
||||||
import ujson as json
|
import ujson as json
|
||||||
import logging
|
import logging
|
||||||
|
@ -25,13 +24,6 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TagsStore(SQLBaseStore):
|
class TagsStore(SQLBaseStore):
|
||||||
def __init__(self, hs):
|
|
||||||
super(TagsStore, self).__init__(hs)
|
|
||||||
|
|
||||||
self._account_data_id_gen = StreamIdGenerator(
|
|
||||||
"account_data_max_stream_id", "stream_id"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_max_account_data_stream_id(self):
|
def get_max_account_data_stream_id(self):
|
||||||
"""Get the current max stream id for the private user data stream
|
"""Get the current max stream id for the private user data stream
|
||||||
|
|
||||||
|
@ -87,6 +79,12 @@ class TagsStore(SQLBaseStore):
|
||||||
room_ids = [row[0] for row in txn.fetchall()]
|
room_ids = [row[0] for row in txn.fetchall()]
|
||||||
return room_ids
|
return room_ids
|
||||||
|
|
||||||
|
changed = self._account_data_stream_cache.has_entity_changed(
|
||||||
|
user_id, int(stream_id)
|
||||||
|
)
|
||||||
|
if not changed:
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
room_ids = yield self.runInteraction(
|
room_ids = yield self.runInteraction(
|
||||||
"get_updated_tags", get_updated_tags_txn
|
"get_updated_tags", get_updated_tags_txn
|
||||||
)
|
)
|
||||||
|
@ -184,6 +182,11 @@ class TagsStore(SQLBaseStore):
|
||||||
next_id(int): The the revision to advance to.
|
next_id(int): The the revision to advance to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
txn.call_after(
|
||||||
|
self._account_data_stream_cache.entity_has_changed,
|
||||||
|
user_id, next_id
|
||||||
|
)
|
||||||
|
|
||||||
update_max_id_sql = (
|
update_max_id_sql = (
|
||||||
"UPDATE account_data_max_stream_id"
|
"UPDATE account_data_max_stream_id"
|
||||||
" SET stream_id = ?"
|
" SET stream_id = ?"
|
||||||
|
|
|
@ -72,28 +72,24 @@ class StreamIdGenerator(object):
|
||||||
with stream_id_gen.get_next_txn(txn) as stream_id:
|
with stream_id_gen.get_next_txn(txn) as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
def __init__(self, table, column):
|
def __init__(self, db_conn, table, column):
|
||||||
self.table = table
|
self.table = table
|
||||||
self.column = column
|
self.column = column
|
||||||
|
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
self._current_max = None
|
cur = db_conn.cursor()
|
||||||
|
self._current_max = self._get_or_compute_current_max(cur)
|
||||||
|
cur.close()
|
||||||
|
|
||||||
self._unfinished_ids = deque()
|
self._unfinished_ids = deque()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_next(self, store):
|
def get_next(self, store):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with yield stream_id_gen.get_next as stream_id:
|
with yield stream_id_gen.get_next as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
if not self._current_max:
|
|
||||||
yield store.runInteraction(
|
|
||||||
"_compute_current_max",
|
|
||||||
self._get_or_compute_current_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._current_max += 1
|
self._current_max += 1
|
||||||
next_id = self._current_max
|
next_id = self._current_max
|
||||||
|
@ -108,21 +104,14 @@ class StreamIdGenerator(object):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._unfinished_ids.remove(next_id)
|
self._unfinished_ids.remove(next_id)
|
||||||
|
|
||||||
defer.returnValue(manager())
|
return manager()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_next_mult(self, store, n):
|
def get_next_mult(self, store, n):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with yield stream_id_gen.get_next(store, n) as stream_ids:
|
with yield stream_id_gen.get_next(store, n) as stream_ids:
|
||||||
# ... persist events ...
|
# ... persist events ...
|
||||||
"""
|
"""
|
||||||
if not self._current_max:
|
|
||||||
yield store.runInteraction(
|
|
||||||
"_compute_current_max",
|
|
||||||
self._get_or_compute_current_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
next_ids = range(self._current_max + 1, self._current_max + n + 1)
|
next_ids = range(self._current_max + 1, self._current_max + n + 1)
|
||||||
self._current_max += n
|
self._current_max += n
|
||||||
|
@ -139,24 +128,17 @@ class StreamIdGenerator(object):
|
||||||
for next_id in next_ids:
|
for next_id in next_ids:
|
||||||
self._unfinished_ids.remove(next_id)
|
self._unfinished_ids.remove(next_id)
|
||||||
|
|
||||||
defer.returnValue(manager())
|
return manager()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_max_token(self, store):
|
def get_max_token(self, store):
|
||||||
"""Returns the maximum stream id such that all stream ids less than or
|
"""Returns the maximum stream id such that all stream ids less than or
|
||||||
equal to it have been successfully persisted.
|
equal to it have been successfully persisted.
|
||||||
"""
|
"""
|
||||||
if not self._current_max:
|
|
||||||
yield store.runInteraction(
|
|
||||||
"_compute_current_max",
|
|
||||||
self._get_or_compute_current_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._unfinished_ids:
|
if self._unfinished_ids:
|
||||||
defer.returnValue(self._unfinished_ids[0] - 1)
|
return self._unfinished_ids[0] - 1
|
||||||
|
|
||||||
defer.returnValue(self._current_max)
|
return self._current_max
|
||||||
|
|
||||||
def _get_or_compute_current_max(self, txn):
|
def _get_or_compute_current_max(self, txn):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
|
|
@ -37,7 +37,7 @@ class LruCache(object):
|
||||||
"""
|
"""
|
||||||
def __init__(self, max_size, keylen=1, cache_type=dict):
|
def __init__(self, max_size, keylen=1, cache_type=dict):
|
||||||
cache = cache_type()
|
cache = cache_type()
|
||||||
self.size = 0
|
self.cache = cache # Used for introspection.
|
||||||
list_root = []
|
list_root = []
|
||||||
list_root[:] = [list_root, list_root, None, None]
|
list_root[:] = [list_root, list_root, None, None]
|
||||||
|
|
||||||
|
@ -60,7 +60,6 @@ class LruCache(object):
|
||||||
prev_node[NEXT] = node
|
prev_node[NEXT] = node
|
||||||
next_node[PREV] = node
|
next_node[PREV] = node
|
||||||
cache[key] = node
|
cache[key] = node
|
||||||
self.size += 1
|
|
||||||
|
|
||||||
def move_node_to_front(node):
|
def move_node_to_front(node):
|
||||||
prev_node = node[PREV]
|
prev_node = node[PREV]
|
||||||
|
@ -79,7 +78,6 @@ class LruCache(object):
|
||||||
next_node = node[NEXT]
|
next_node = node[NEXT]
|
||||||
prev_node[NEXT] = next_node
|
prev_node[NEXT] = next_node
|
||||||
next_node[PREV] = prev_node
|
next_node[PREV] = prev_node
|
||||||
self.size -= 1
|
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_get(key, default=None):
|
def cache_get(key, default=None):
|
||||||
|
@ -98,7 +96,7 @@ class LruCache(object):
|
||||||
node[VALUE] = value
|
node[VALUE] = value
|
||||||
else:
|
else:
|
||||||
add_node(key, value)
|
add_node(key, value)
|
||||||
if self.size > max_size:
|
if len(cache) > max_size:
|
||||||
todelete = list_root[PREV]
|
todelete = list_root[PREV]
|
||||||
delete_node(todelete)
|
delete_node(todelete)
|
||||||
cache.pop(todelete[KEY], None)
|
cache.pop(todelete[KEY], None)
|
||||||
|
@ -110,7 +108,7 @@ class LruCache(object):
|
||||||
return node[VALUE]
|
return node[VALUE]
|
||||||
else:
|
else:
|
||||||
add_node(key, value)
|
add_node(key, value)
|
||||||
if self.size > max_size:
|
if len(cache) > max_size:
|
||||||
todelete = list_root[PREV]
|
todelete = list_root[PREV]
|
||||||
delete_node(todelete)
|
delete_node(todelete)
|
||||||
cache.pop(todelete[KEY], None)
|
cache.pop(todelete[KEY], None)
|
||||||
|
@ -145,7 +143,7 @@ class LruCache(object):
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_len():
|
def cache_len():
|
||||||
return self.size
|
return len(cache)
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_contains(key):
|
def cache_contains(key):
|
||||||
|
|
107
synapse/util/caches/stream_change_cache.py
Normal file
107
synapse/util/caches/stream_change_cache.py
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.util.caches import cache_counter, caches_by_name
|
||||||
|
|
||||||
|
|
||||||
|
from blist import sorteddict
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamChangeCache(object):
|
||||||
|
"""Keeps track of the stream positions of the latest change in a set of entities.
|
||||||
|
|
||||||
|
Typically the entity will be a room or user id.
|
||||||
|
|
||||||
|
Given a list of entities and a stream position, it will give a subset of
|
||||||
|
entities that may have changed since that position. If position key is too
|
||||||
|
old then the cache will simply return all given entities.
|
||||||
|
"""
|
||||||
|
def __init__(self, name, current_stream_pos, max_size=10000):
|
||||||
|
self._max_size = max_size
|
||||||
|
self._entity_to_key = {}
|
||||||
|
self._cache = sorteddict()
|
||||||
|
self._earliest_known_stream_pos = current_stream_pos
|
||||||
|
self.name = name
|
||||||
|
caches_by_name[self.name] = self._cache
|
||||||
|
|
||||||
|
def has_entity_changed(self, entity, stream_pos):
|
||||||
|
"""Returns True if the entity may have been updated since stream_pos
|
||||||
|
"""
|
||||||
|
assert type(stream_pos) is int
|
||||||
|
|
||||||
|
if stream_pos < self._earliest_known_stream_pos:
|
||||||
|
cache_counter.inc_misses(self.name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if stream_pos == self._earliest_known_stream_pos:
|
||||||
|
# If the same as the earliest key, assume nothing has changed.
|
||||||
|
cache_counter.inc_hits(self.name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
latest_entity_change_pos = self._entity_to_key.get(entity, None)
|
||||||
|
if latest_entity_change_pos is None:
|
||||||
|
cache_counter.inc_misses(self.name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if stream_pos < latest_entity_change_pos:
|
||||||
|
cache_counter.inc_misses(self.name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
cache_counter.inc_hits(self.name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_entities_changed(self, entities, stream_pos):
|
||||||
|
"""Returns subset of entities that have had new things since the
|
||||||
|
given position. If the position is too old it will just return the given list.
|
||||||
|
"""
|
||||||
|
assert type(stream_pos) is int
|
||||||
|
|
||||||
|
if stream_pos >= self._earliest_known_stream_pos:
|
||||||
|
keys = self._cache.keys()
|
||||||
|
i = keys.bisect_right(stream_pos)
|
||||||
|
|
||||||
|
result = set(
|
||||||
|
self._cache[k] for k in keys[i:]
|
||||||
|
).intersection(entities)
|
||||||
|
|
||||||
|
cache_counter.inc_hits(self.name)
|
||||||
|
else:
|
||||||
|
result = entities
|
||||||
|
cache_counter.inc_misses(self.name)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def entity_has_changed(self, entity, stream_pos):
|
||||||
|
"""Informs the cache that the entity has been changed at the given
|
||||||
|
position.
|
||||||
|
"""
|
||||||
|
assert type(stream_pos) is int
|
||||||
|
|
||||||
|
if stream_pos > self._earliest_known_stream_pos:
|
||||||
|
old_pos = self._entity_to_key.get(entity, None)
|
||||||
|
if old_pos:
|
||||||
|
stream_pos = max(stream_pos, old_pos)
|
||||||
|
self._cache.pop(old_pos, None)
|
||||||
|
self._cache[stream_pos] = entity
|
||||||
|
self._entity_to_key[entity] = stream_pos
|
||||||
|
|
||||||
|
while len(self._cache) > self._max_size:
|
||||||
|
k, r = self._cache.popitem()
|
||||||
|
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
|
||||||
|
self._entity_to_key.pop(r, None)
|
|
@ -8,6 +8,7 @@ class TreeCache(object):
|
||||||
Keys must be tuples.
|
Keys must be tuples.
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
self.size = 0
|
||||||
self.root = {}
|
self.root = {}
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
|
@ -20,7 +21,8 @@ class TreeCache(object):
|
||||||
node = self.root
|
node = self.root
|
||||||
for k in key[:-1]:
|
for k in key[:-1]:
|
||||||
node = node.setdefault(k, {})
|
node = node.setdefault(k, {})
|
||||||
node[key[-1]] = value
|
node[key[-1]] = _Entry(value)
|
||||||
|
self.size += 1
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
node = self.root
|
node = self.root
|
||||||
|
@ -28,9 +30,10 @@ class TreeCache(object):
|
||||||
node = node.get(k, None)
|
node = node.get(k, None)
|
||||||
if node is None:
|
if node is None:
|
||||||
return default
|
return default
|
||||||
return node.get(key[-1], default)
|
return node.get(key[-1], _Entry(default)).value
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
|
self.size = 0
|
||||||
self.root = {}
|
self.root = {}
|
||||||
|
|
||||||
def pop(self, key, default=None):
|
def pop(self, key, default=None):
|
||||||
|
@ -57,4 +60,33 @@ class TreeCache(object):
|
||||||
break
|
break
|
||||||
node_and_keys[i+1][0].pop(k)
|
node_and_keys[i+1][0].pop(k)
|
||||||
|
|
||||||
|
popped, cnt = _strip_and_count_entires(popped)
|
||||||
|
self.size -= cnt
|
||||||
return popped
|
return popped
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.size
|
||||||
|
|
||||||
|
|
||||||
|
class _Entry(object):
|
||||||
|
__slots__ = ["value"]
|
||||||
|
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_and_count_entires(d):
|
||||||
|
"""Takes an _Entry or dict with leaves of _Entry's, and either returns the
|
||||||
|
value or a dictionary with _Entry's replaced by their values.
|
||||||
|
|
||||||
|
Also returns the count of _Entry's
|
||||||
|
"""
|
||||||
|
if isinstance(d, dict):
|
||||||
|
cnt = 0
|
||||||
|
for key, value in d.items():
|
||||||
|
v, n = _strip_and_count_entires(value)
|
||||||
|
d[key] = v
|
||||||
|
cnt += n
|
||||||
|
return d, cnt
|
||||||
|
else:
|
||||||
|
return d.value, 1
|
||||||
|
|
|
@ -382,19 +382,20 @@ class FilteringTestCase(unittest.TestCase):
|
||||||
"types": ["m.*"]
|
"types": ["m.*"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
user = UserID.from_string("@" + user_localpart + ":test")
|
|
||||||
filter_id = yield self.datastore.add_user_filter(
|
filter_id = yield self.datastore.add_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart + "2",
|
||||||
user_filter=user_filter_json,
|
user_filter=user_filter_json,
|
||||||
)
|
)
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
|
event_id="$asdasd:localhost",
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
type="custom.avatar.3d.crazy",
|
type="custom.avatar.3d.crazy",
|
||||||
)
|
)
|
||||||
events = [event]
|
events = [event]
|
||||||
|
|
||||||
user_filter = yield self.filtering.get_user_filter(
|
user_filter = yield self.filtering.get_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart + "2",
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,303 +0,0 @@
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# trial imports
|
|
||||||
from twisted.internet import defer
|
|
||||||
from tests import unittest
|
|
||||||
|
|
||||||
# python imports
|
|
||||||
from mock import Mock, ANY
|
|
||||||
|
|
||||||
from ..utils import MockHttpResource, MockClock, setup_test_homeserver
|
|
||||||
|
|
||||||
from synapse.federation import initialize_http_replication
|
|
||||||
from synapse.events import FrozenEvent
|
|
||||||
|
|
||||||
|
|
||||||
def make_pdu(prev_pdus=[], **kwargs):
|
|
||||||
"""Provide some default fields for making a PduTuple."""
|
|
||||||
pdu_fields = {
|
|
||||||
"state_key": None,
|
|
||||||
"prev_events": prev_pdus,
|
|
||||||
}
|
|
||||||
pdu_fields.update(kwargs)
|
|
||||||
|
|
||||||
return FrozenEvent(pdu_fields)
|
|
||||||
|
|
||||||
|
|
||||||
class FederationTestCase(unittest.TestCase):
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def setUp(self):
|
|
||||||
self.mock_resource = MockHttpResource()
|
|
||||||
self.mock_http_client = Mock(spec=[
|
|
||||||
"get_json",
|
|
||||||
"put_json",
|
|
||||||
])
|
|
||||||
self.mock_persistence = Mock(spec=[
|
|
||||||
"prep_send_transaction",
|
|
||||||
"delivered_txn",
|
|
||||||
"get_received_txn_response",
|
|
||||||
"set_received_txn_response",
|
|
||||||
"get_destination_retry_timings",
|
|
||||||
"get_auth_chain",
|
|
||||||
])
|
|
||||||
self.mock_persistence.get_received_txn_response.return_value = (
|
|
||||||
defer.succeed(None)
|
|
||||||
)
|
|
||||||
|
|
||||||
retry_timings_res = {
|
|
||||||
"destination": "",
|
|
||||||
"retry_last_ts": 0,
|
|
||||||
"retry_interval": 0,
|
|
||||||
}
|
|
||||||
self.mock_persistence.get_destination_retry_timings.return_value = (
|
|
||||||
defer.succeed(retry_timings_res)
|
|
||||||
)
|
|
||||||
self.mock_persistence.get_auth_chain.return_value = []
|
|
||||||
self.clock = MockClock()
|
|
||||||
hs = yield setup_test_homeserver(
|
|
||||||
resource_for_federation=self.mock_resource,
|
|
||||||
http_client=self.mock_http_client,
|
|
||||||
datastore=self.mock_persistence,
|
|
||||||
clock=self.clock,
|
|
||||||
keyring=Mock(),
|
|
||||||
)
|
|
||||||
self.federation = initialize_http_replication(hs)
|
|
||||||
self.distributor = hs.get_distributor()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_get_state(self):
|
|
||||||
mock_handler = Mock(spec=[
|
|
||||||
"get_state_for_pdu",
|
|
||||||
])
|
|
||||||
|
|
||||||
self.federation.set_handler(mock_handler)
|
|
||||||
|
|
||||||
mock_handler.get_state_for_pdu.return_value = defer.succeed([])
|
|
||||||
|
|
||||||
# Empty context initially
|
|
||||||
(code, response) = yield self.mock_resource.trigger(
|
|
||||||
"GET",
|
|
||||||
"/_matrix/federation/v1/state/my-context/",
|
|
||||||
None
|
|
||||||
)
|
|
||||||
self.assertEquals(200, code)
|
|
||||||
self.assertFalse(response["pdus"])
|
|
||||||
|
|
||||||
# Now lets give the context some state
|
|
||||||
mock_handler.get_state_for_pdu.return_value = (
|
|
||||||
defer.succeed([
|
|
||||||
make_pdu(
|
|
||||||
event_id="the-pdu-id",
|
|
||||||
origin="red",
|
|
||||||
user_id="@a:red",
|
|
||||||
room_id="my-context",
|
|
||||||
type="m.topic",
|
|
||||||
origin_server_ts=123456789000,
|
|
||||||
depth=1,
|
|
||||||
content={"topic": "The topic"},
|
|
||||||
state_key="",
|
|
||||||
power_level=1000,
|
|
||||||
prev_state="last-pdu-id",
|
|
||||||
),
|
|
||||||
])
|
|
||||||
)
|
|
||||||
|
|
||||||
(code, response) = yield self.mock_resource.trigger(
|
|
||||||
"GET",
|
|
||||||
"/_matrix/federation/v1/state/my-context/",
|
|
||||||
None
|
|
||||||
)
|
|
||||||
self.assertEquals(200, code)
|
|
||||||
self.assertEquals(1, len(response["pdus"]))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_get_pdu(self):
|
|
||||||
mock_handler = Mock(spec=[
|
|
||||||
"get_persisted_pdu",
|
|
||||||
])
|
|
||||||
|
|
||||||
self.federation.set_handler(mock_handler)
|
|
||||||
|
|
||||||
mock_handler.get_persisted_pdu.return_value = (
|
|
||||||
defer.succeed(None)
|
|
||||||
)
|
|
||||||
|
|
||||||
(code, response) = yield self.mock_resource.trigger(
|
|
||||||
"GET",
|
|
||||||
"/_matrix/federation/v1/event/abc123def456/",
|
|
||||||
None
|
|
||||||
)
|
|
||||||
self.assertEquals(404, code)
|
|
||||||
|
|
||||||
# Now insert such a PDU
|
|
||||||
mock_handler.get_persisted_pdu.return_value = (
|
|
||||||
defer.succeed(
|
|
||||||
make_pdu(
|
|
||||||
event_id="abc123def456",
|
|
||||||
origin="red",
|
|
||||||
user_id="@a:red",
|
|
||||||
room_id="my-context",
|
|
||||||
type="m.text",
|
|
||||||
origin_server_ts=123456789001,
|
|
||||||
depth=1,
|
|
||||||
content={"text": "Here is the message"},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
(code, response) = yield self.mock_resource.trigger(
|
|
||||||
"GET",
|
|
||||||
"/_matrix/federation/v1/event/abc123def456/",
|
|
||||||
None
|
|
||||||
)
|
|
||||||
self.assertEquals(200, code)
|
|
||||||
self.assertEquals(1, len(response["pdus"]))
|
|
||||||
self.assertEquals("m.text", response["pdus"][0]["type"])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_send_pdu(self):
|
|
||||||
self.mock_http_client.put_json.return_value = defer.succeed(
|
|
||||||
(200, "OK")
|
|
||||||
)
|
|
||||||
|
|
||||||
pdu = make_pdu(
|
|
||||||
event_id="abc123def456",
|
|
||||||
origin="red",
|
|
||||||
user_id="@a:red",
|
|
||||||
room_id="my-context",
|
|
||||||
type="m.text",
|
|
||||||
origin_server_ts=123456789001,
|
|
||||||
depth=1,
|
|
||||||
content={"text": "Here is the message"},
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self.federation.send_pdu(pdu, ["remote"])
|
|
||||||
|
|
||||||
self.mock_http_client.put_json.assert_called_with(
|
|
||||||
"remote",
|
|
||||||
path="/_matrix/federation/v1/send/1000000/",
|
|
||||||
data={
|
|
||||||
"origin_server_ts": 1000000,
|
|
||||||
"origin": "test",
|
|
||||||
"pdus": [
|
|
||||||
pdu.get_pdu_json(),
|
|
||||||
],
|
|
||||||
'pdu_failures': [],
|
|
||||||
},
|
|
||||||
json_data_callback=ANY,
|
|
||||||
long_retries=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_send_edu(self):
|
|
||||||
self.mock_http_client.put_json.return_value = defer.succeed(
|
|
||||||
(200, "OK")
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self.federation.send_edu(
|
|
||||||
destination="remote",
|
|
||||||
edu_type="m.test",
|
|
||||||
content={"testing": "content here"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# MockClock ensures we can guess these timestamps
|
|
||||||
self.mock_http_client.put_json.assert_called_with(
|
|
||||||
"remote",
|
|
||||||
path="/_matrix/federation/v1/send/1000000/",
|
|
||||||
data={
|
|
||||||
"origin": "test",
|
|
||||||
"origin_server_ts": 1000000,
|
|
||||||
"pdus": [],
|
|
||||||
"edus": [
|
|
||||||
{
|
|
||||||
"edu_type": "m.test",
|
|
||||||
"content": {"testing": "content here"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
'pdu_failures': [],
|
|
||||||
},
|
|
||||||
json_data_callback=ANY,
|
|
||||||
long_retries=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_recv_edu(self):
|
|
||||||
recv_observer = Mock()
|
|
||||||
recv_observer.return_value = defer.succeed(())
|
|
||||||
|
|
||||||
self.federation.register_edu_handler("m.test", recv_observer)
|
|
||||||
|
|
||||||
yield self.mock_resource.trigger(
|
|
||||||
"PUT",
|
|
||||||
"/_matrix/federation/v1/send/1001000/",
|
|
||||||
"""{
|
|
||||||
"origin": "remote",
|
|
||||||
"origin_server_ts": 1001000,
|
|
||||||
"pdus": [],
|
|
||||||
"edus": [
|
|
||||||
{
|
|
||||||
"origin": "remote",
|
|
||||||
"destination": "test",
|
|
||||||
"edu_type": "m.test",
|
|
||||||
"content": {"testing": "reply here"}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}"""
|
|
||||||
)
|
|
||||||
|
|
||||||
recv_observer.assert_called_with(
|
|
||||||
"remote", {"testing": "reply here"}
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_send_query(self):
|
|
||||||
self.mock_http_client.get_json.return_value = defer.succeed(
|
|
||||||
{"your": "response"}
|
|
||||||
)
|
|
||||||
|
|
||||||
response = yield self.federation.make_query(
|
|
||||||
destination="remote",
|
|
||||||
query_type="a-question",
|
|
||||||
args={"one": "1", "two": "2"},
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEquals({"your": "response"}, response)
|
|
||||||
|
|
||||||
self.mock_http_client.get_json.assert_called_with(
|
|
||||||
destination="remote",
|
|
||||||
path="/_matrix/federation/v1/query/a-question",
|
|
||||||
args={"one": "1", "two": "2"},
|
|
||||||
retry_on_dns_fail=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_recv_query(self):
|
|
||||||
recv_handler = Mock()
|
|
||||||
recv_handler.return_value = defer.succeed({"another": "response"})
|
|
||||||
|
|
||||||
self.federation.register_query_handler("a-question", recv_handler)
|
|
||||||
|
|
||||||
code, response = yield self.mock_resource.trigger(
|
|
||||||
"GET",
|
|
||||||
"/_matrix/federation/v1/query/a-question?three=3&four=4",
|
|
||||||
None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEquals(200, code)
|
|
||||||
self.assertEquals({"another": "response"}, response)
|
|
||||||
|
|
||||||
recv_handler.assert_called_with(
|
|
||||||
{"three": "3", "four": "4"}
|
|
||||||
)
|
|
|
@ -280,6 +280,15 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
||||||
}
|
}
|
||||||
EventSources.SOURCE_TYPES["presence"] = PresenceEventSource
|
EventSources.SOURCE_TYPES["presence"] = PresenceEventSource
|
||||||
|
|
||||||
|
clock = Mock(spec=[
|
||||||
|
"call_later",
|
||||||
|
"cancel_call_later",
|
||||||
|
"time_msec",
|
||||||
|
"looping_call",
|
||||||
|
])
|
||||||
|
|
||||||
|
clock.time_msec.return_value = 1000000
|
||||||
|
|
||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(
|
||||||
http_client=None,
|
http_client=None,
|
||||||
resource_for_client=self.mock_resource,
|
resource_for_client=self.mock_resource,
|
||||||
|
@ -289,16 +298,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
||||||
"get_presence_list",
|
"get_presence_list",
|
||||||
"get_rooms_for_user",
|
"get_rooms_for_user",
|
||||||
]),
|
]),
|
||||||
clock=Mock(spec=[
|
clock=clock,
|
||||||
"call_later",
|
|
||||||
"cancel_call_later",
|
|
||||||
"time_msec",
|
|
||||||
"looping_call",
|
|
||||||
]),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hs.get_clock().time_msec.return_value = 1000000
|
|
||||||
|
|
||||||
def _get_user_by_req(req=None, allow_guest=False):
|
def _get_user_by_req(req=None, allow_guest=False):
|
||||||
return Requester(UserID.from_string(myid), "", False)
|
return Requester(UserID.from_string(myid), "", False)
|
||||||
|
|
||||||
|
|
|
@ -1045,8 +1045,13 @@ class RoomMessageListTestCase(RestTestCase):
|
||||||
self.assertTrue("end" in response)
|
self.assertTrue("end" in response)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_stream_token_is_rejected(self):
|
def test_stream_token_is_accepted_for_fwd_pagianation(self):
|
||||||
|
token = "s0_0_0_0_0"
|
||||||
(code, response) = yield self.mock_resource.trigger_get(
|
(code, response) = yield self.mock_resource.trigger_get(
|
||||||
"/rooms/%s/messages?access_token=x&from=s0_0_0_0" %
|
"/rooms/%s/messages?access_token=x&from=%s" %
|
||||||
self.room_id)
|
(self.room_id, token))
|
||||||
self.assertEquals(400, code)
|
self.assertEquals(200, code)
|
||||||
|
self.assertTrue("start" in response)
|
||||||
|
self.assertEquals(token, response['start'])
|
||||||
|
self.assertTrue("chunk" in response)
|
||||||
|
self.assertTrue("end" in response)
|
||||||
|
|
|
@ -439,7 +439,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
||||||
f2 = self._write_config(suffix="2")
|
f2 = self._write_config(suffix="2")
|
||||||
|
|
||||||
config = Mock(app_service_config_files=[f1, f2])
|
config = Mock(app_service_config_files=[f1, f2])
|
||||||
hs = yield setup_test_homeserver(config=config)
|
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||||
|
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(hs)
|
||||||
|
|
||||||
|
@ -449,7 +449,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
||||||
f2 = self._write_config(id="id", suffix="2")
|
f2 = self._write_config(id="id", suffix="2")
|
||||||
|
|
||||||
config = Mock(app_service_config_files=[f1, f2])
|
config = Mock(app_service_config_files=[f1, f2])
|
||||||
hs = yield setup_test_homeserver(config=config)
|
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(hs)
|
||||||
|
@ -465,7 +465,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
||||||
f2 = self._write_config(as_token="as_token", suffix="2")
|
f2 = self._write_config(as_token="as_token", suffix="2")
|
||||||
|
|
||||||
config = Mock(app_service_config_files=[f1, f2])
|
config = Mock(app_service_config_files=[f1, f2])
|
||||||
hs = yield setup_test_homeserver(config=config)
|
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(hs)
|
||||||
|
|
|
@ -18,7 +18,6 @@ from tests import unittest
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.storage.registration import RegistrationStore
|
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
|
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
|
@ -31,7 +30,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||||
hs = yield setup_test_homeserver()
|
hs = yield setup_test_homeserver()
|
||||||
self.db_pool = hs.get_db_pool()
|
self.db_pool = hs.get_db_pool()
|
||||||
|
|
||||||
self.store = RegistrationStore(hs)
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
self.user_id = "@my-user:test"
|
self.user_id = "@my-user:test"
|
||||||
self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",
|
self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",
|
||||||
|
|
|
@ -16,10 +16,10 @@
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.server import BaseHomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import UserID, RoomAlias
|
from synapse.types import UserID, RoomAlias
|
||||||
|
|
||||||
mock_homeserver = BaseHomeServer(hostname="my.domain")
|
mock_homeserver = HomeServer(hostname="my.domain")
|
||||||
|
|
||||||
|
|
||||||
class UserIDTestCase(unittest.TestCase):
|
class UserIDTestCase(unittest.TestCase):
|
||||||
|
@ -34,7 +34,6 @@ class UserIDTestCase(unittest.TestCase):
|
||||||
with self.assertRaises(SynapseError):
|
with self.assertRaises(SynapseError):
|
||||||
UserID.from_string("")
|
UserID.from_string("")
|
||||||
|
|
||||||
|
|
||||||
def test_build(self):
|
def test_build(self):
|
||||||
user = UserID("5678efgh", "my.domain")
|
user = UserID("5678efgh", "my.domain")
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ from .. import unittest
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.treecache import TreeCache
|
from synapse.util.caches.treecache import TreeCache
|
||||||
|
|
||||||
|
|
||||||
class LruCacheTestCase(unittest.TestCase):
|
class LruCacheTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def test_get_set(self):
|
def test_get_set(self):
|
||||||
|
@ -72,3 +73,9 @@ class LruCacheTestCase(unittest.TestCase):
|
||||||
self.assertEquals(cache.get(("vehicles", "car")), "vroom")
|
self.assertEquals(cache.get(("vehicles", "car")), "vroom")
|
||||||
self.assertEquals(cache.get(("vehicles", "train")), "chuff")
|
self.assertEquals(cache.get(("vehicles", "train")), "chuff")
|
||||||
# Man from del_multi say "Yes".
|
# Man from del_multi say "Yes".
|
||||||
|
|
||||||
|
def test_clear(self):
|
||||||
|
cache = LruCache(1)
|
||||||
|
cache["key"] = 1
|
||||||
|
cache.clear()
|
||||||
|
self.assertEquals(len(cache), 0)
|
||||||
|
|
|
@ -25,6 +25,7 @@ class TreeCacheTestCase(unittest.TestCase):
|
||||||
cache[("b",)] = "B"
|
cache[("b",)] = "B"
|
||||||
self.assertEquals(cache.get(("a",)), "A")
|
self.assertEquals(cache.get(("a",)), "A")
|
||||||
self.assertEquals(cache.get(("b",)), "B")
|
self.assertEquals(cache.get(("b",)), "B")
|
||||||
|
self.assertEquals(len(cache), 2)
|
||||||
|
|
||||||
def test_pop_onelevel(self):
|
def test_pop_onelevel(self):
|
||||||
cache = TreeCache()
|
cache = TreeCache()
|
||||||
|
@ -33,6 +34,7 @@ class TreeCacheTestCase(unittest.TestCase):
|
||||||
self.assertEquals(cache.pop(("a",)), "A")
|
self.assertEquals(cache.pop(("a",)), "A")
|
||||||
self.assertEquals(cache.pop(("a",)), None)
|
self.assertEquals(cache.pop(("a",)), None)
|
||||||
self.assertEquals(cache.get(("b",)), "B")
|
self.assertEquals(cache.get(("b",)), "B")
|
||||||
|
self.assertEquals(len(cache), 1)
|
||||||
|
|
||||||
def test_get_set_twolevel(self):
|
def test_get_set_twolevel(self):
|
||||||
cache = TreeCache()
|
cache = TreeCache()
|
||||||
|
@ -42,6 +44,7 @@ class TreeCacheTestCase(unittest.TestCase):
|
||||||
self.assertEquals(cache.get(("a", "a")), "AA")
|
self.assertEquals(cache.get(("a", "a")), "AA")
|
||||||
self.assertEquals(cache.get(("a", "b")), "AB")
|
self.assertEquals(cache.get(("a", "b")), "AB")
|
||||||
self.assertEquals(cache.get(("b", "a")), "BA")
|
self.assertEquals(cache.get(("b", "a")), "BA")
|
||||||
|
self.assertEquals(len(cache), 3)
|
||||||
|
|
||||||
def test_pop_twolevel(self):
|
def test_pop_twolevel(self):
|
||||||
cache = TreeCache()
|
cache = TreeCache()
|
||||||
|
@ -53,6 +56,7 @@ class TreeCacheTestCase(unittest.TestCase):
|
||||||
self.assertEquals(cache.get(("a", "b")), "AB")
|
self.assertEquals(cache.get(("a", "b")), "AB")
|
||||||
self.assertEquals(cache.pop(("b", "a")), "BA")
|
self.assertEquals(cache.pop(("b", "a")), "BA")
|
||||||
self.assertEquals(cache.pop(("b", "a")), None)
|
self.assertEquals(cache.pop(("b", "a")), None)
|
||||||
|
self.assertEquals(len(cache), 1)
|
||||||
|
|
||||||
def test_pop_mixedlevel(self):
|
def test_pop_mixedlevel(self):
|
||||||
cache = TreeCache()
|
cache = TreeCache()
|
||||||
|
@ -64,3 +68,11 @@ class TreeCacheTestCase(unittest.TestCase):
|
||||||
self.assertEquals(cache.get(("a", "a")), None)
|
self.assertEquals(cache.get(("a", "a")), None)
|
||||||
self.assertEquals(cache.get(("a", "b")), None)
|
self.assertEquals(cache.get(("a", "b")), None)
|
||||||
self.assertEquals(cache.get(("b", "a")), "BA")
|
self.assertEquals(cache.get(("b", "a")), "BA")
|
||||||
|
self.assertEquals(len(cache), 1)
|
||||||
|
|
||||||
|
def test_clear(self):
|
||||||
|
cache = TreeCache()
|
||||||
|
cache[("a",)] = "A"
|
||||||
|
cache[("b",)] = "B"
|
||||||
|
cache.clear()
|
||||||
|
self.assertEquals(len(cache), 0)
|
||||||
|
|
|
@ -19,6 +19,8 @@ from synapse.api.constants import EventTypes
|
||||||
from synapse.storage.prepare_database import prepare_database
|
from synapse.storage.prepare_database import prepare_database
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.federation.transport import server
|
||||||
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
|
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
|
@ -58,8 +60,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
name, db_pool=db_pool, config=config,
|
name, db_pool=db_pool, config=config,
|
||||||
version_string="Synapse/tests",
|
version_string="Synapse/tests",
|
||||||
database_engine=create_engine("sqlite3"),
|
database_engine=create_engine("sqlite3"),
|
||||||
|
get_db_conn=db_pool.get_db_conn,
|
||||||
**kargs
|
**kargs
|
||||||
)
|
)
|
||||||
|
hs.setup()
|
||||||
else:
|
else:
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
name, db_pool=None, datastore=datastore, config=config,
|
name, db_pool=None, datastore=datastore, config=config,
|
||||||
|
@ -80,6 +84,22 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
|
|
||||||
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
|
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
|
||||||
|
|
||||||
|
fed = kargs.get("resource_for_federation", None)
|
||||||
|
if fed:
|
||||||
|
server.register_servlets(
|
||||||
|
hs,
|
||||||
|
resource=fed,
|
||||||
|
authenticator=server.Authenticator(hs),
|
||||||
|
ratelimiter=FederationRateLimiter(
|
||||||
|
hs.get_clock(),
|
||||||
|
window_size=hs.config.federation_rc_window_size,
|
||||||
|
sleep_limit=hs.config.federation_rc_sleep_limit,
|
||||||
|
sleep_msec=hs.config.federation_rc_sleep_delay,
|
||||||
|
reject_limit=hs.config.federation_rc_reject_limit,
|
||||||
|
concurrent_requests=hs.config.federation_rc_concurrent
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue(hs)
|
defer.returnValue(hs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -262,6 +282,12 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
|
||||||
lambda conn: prepare_database(conn, engine)
|
lambda conn: prepare_database(conn, engine)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_db_conn(self):
|
||||||
|
conn = self.connect()
|
||||||
|
engine = create_engine("sqlite3")
|
||||||
|
prepare_database(conn, engine)
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
class MemoryDataStore(object):
|
class MemoryDataStore(object):
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue