0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-09 11:32:01 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.12.1

This commit is contained in:
Erik Johnston 2016-01-29 13:52:12 +00:00
commit fd142c29d9
40 changed files with 939 additions and 1034 deletions

View file

@ -26,7 +26,7 @@ TOX_BIN=$WORKSPACE/.tox/py27/bin
if [[ ! -e .sytest-base ]]; then if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else else
(cd .sytest-base; git fetch) (cd .sytest-base; git fetch -p)
fi fi
rm -rf sytest rm -rf sytest
@ -52,7 +52,7 @@ RUN_POSTGRES=""
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
if psql synapse_jenkins_$port <<< ""; then if psql synapse_jenkins_$port <<< ""; then
RUN_POSTGRES=$RUN_POSTGRES:$port RUN_POSTGRES="$RUN_POSTGRES:$port"
cat > localhost-$port/database.yaml << EOF cat > localhost-$port/database.yaml << EOF
name: psycopg2 name: psycopg2
args: args:
@ -62,7 +62,7 @@ EOF
done done
# Run if both postgresql databases exist # Run if both postgresql databases exist
if test $RUN_POSTGRES = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
echo >&2 "Running sytest with PostgreSQL"; echo >&2 "Running sytest with PostgreSQL";
$TOX_BIN/pip install psycopg2 $TOX_BIN/pip install psycopg2
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \ ./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \

View file

@ -15,6 +15,8 @@
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
import ujson as json
class Filtering(object): class Filtering(object):
@ -149,6 +151,9 @@ class FilterCollection(object):
"include_leave", False "include_leave", False
) )
def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
def get_filter_json(self): def get_filter_json(self):
return self._filter_json return self._filter_json

View file

@ -50,16 +50,14 @@ from twisted.cred import checkers, portal
from twisted.internet import reactor, task, defer from twisted.internet import reactor, task, defer
from twisted.application import service from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory, Request from twisted.web.server import Site, GzipEncoderFactory, Request
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX, SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
@ -69,6 +67,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse import events from synapse import events
@ -95,19 +94,8 @@ def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()]) return EncodingResourceWrapper(r, [GzipEncoderFactory()])
class SynapseHomeServer(HomeServer): def build_resource_for_web_client(hs):
webclient_path = hs.get_config().web_client_location
def build_http_client(self):
return MatrixFederationHttpClient(self)
def build_client_resource(self):
return ClientRestResource(self)
def build_resource_for_federation(self):
return JsonResource(self)
def build_resource_for_web_client(self):
webclient_path = self.get_config().web_client_location
if not webclient_path: if not webclient_path:
try: try:
import syweb import syweb
@ -135,40 +123,8 @@ class SynapseHomeServer(HomeServer):
# return GzipFile(webclient_path) # TODO configurable? # return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable?
def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip
return File(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
def build_resource_for_content_repo(self):
return ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
)
def build_resource_for_media_repository(self):
return MediaRepositoryResource(self)
def build_resource_for_server_key(self):
return LocalKey(self)
def build_resource_for_server_key_v2(self):
return KeyApiV2Resource(self)
def build_resource_for_metrics(self):
if self.get_config().enable_metrics:
return MetricsResource(self)
else:
return None
def build_db_pool(self):
name = self.db_config["name"]
return adbapi.ConnectionPool(
name,
**self.db_config.get("args", {})
)
class SynapseHomeServer(HomeServer):
def _listener_http(self, config, listener_config): def _listener_http(self, config, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_address = listener_config.get("bind_address", "")
@ -178,13 +134,11 @@ class SynapseHomeServer(HomeServer):
if tls and config.no_tls: if tls and config.no_tls:
return return
metrics_resource = self.get_resource_for_metrics()
resources = {} resources = {}
for res in listener_config["resources"]: for res in listener_config["resources"]:
for name in res["names"]: for name in res["names"]:
if name == "client": if name == "client":
client_resource = self.get_client_resource() client_resource = ClientRestResource(self)
if res["compress"]: if res["compress"]:
client_resource = gz_wrap(client_resource) client_resource = gz_wrap(client_resource)
@ -198,31 +152,35 @@ class SynapseHomeServer(HomeServer):
if name == "federation": if name == "federation":
resources.update({ resources.update({
FEDERATION_PREFIX: self.get_resource_for_federation(), FEDERATION_PREFIX: TransportLayerServer(self),
}) })
if name in ["static", "client"]: if name in ["static", "client"]:
resources.update({ resources.update({
STATIC_PREFIX: self.get_resource_for_static_content(), STATIC_PREFIX: File(
os.path.join(os.path.dirname(synapse.__file__), "static")
),
}) })
if name in ["media", "federation", "client"]: if name in ["media", "federation", "client"]:
resources.update({ resources.update({
MEDIA_PREFIX: self.get_resource_for_media_repository(), MEDIA_PREFIX: MediaRepositoryResource(self),
CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(), CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
),
}) })
if name in ["keys", "federation"]: if name in ["keys", "federation"]:
resources.update({ resources.update({
SERVER_KEY_PREFIX: self.get_resource_for_server_key(), SERVER_KEY_PREFIX: LocalKey(self),
SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(), SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
}) })
if name == "webclient": if name == "webclient":
resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client() resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
if name == "metrics" and metrics_resource: if name == "metrics" and self.get_config().enable_metrics:
resources[METRICS_PREFIX] = metrics_resource resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources) root_resource = create_resource_tree(resources)
if tls: if tls:
@ -296,6 +254,18 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e: except IncorrectDatabaseSetup as e:
quit_with_error(e.message) quit_with_error(e.message)
def get_db_conn(self):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
self.database_engine.on_new_connection(db_conn)
return db_conn
def quit_with_error(error_string): def quit_with_error(error_string):
message_lines = error_string.split("\n") message_lines = error_string.split("\n")
@ -432,13 +402,7 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name']) logger.info("Preparing database: %s...", config.database_config['name'])
try: try:
db_conn = database_engine.module.connect( db_conn = hs.get_db_conn()
**{
k: v for k, v in config.database_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
database_engine.prepare_database(db_conn) database_engine.prepare_database(db_conn)
hs.run_startup_checks(db_conn, database_engine) hs.run_startup_checks(db_conn, database_engine)
@ -453,14 +417,18 @@ def setup(config_options):
logger.info("Database prepared in %s.", config.database_config['name']) logger.info("Database prepared in %s.", config.database_config['name'])
hs.setup()
hs.start_listening() hs.start_listening()
def start():
hs.get_pusherpool().start() hs.get_pusherpool().start()
hs.get_state_handler().start_caching() hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling() hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates() hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache() hs.get_replication_layer().start_get_pdu_cache()
reactor.callWhenRunning(start)
return hs return hs
@ -675,7 +643,7 @@ def _resource_id(resource, path_seg):
the mapping should looks like _resource_id(A,C) = B. the mapping should looks like _resource_id(A,C) = B.
Args: Args:
resource (Resource): The *parent* Resource resource (Resource): The *parent* Resourceb
path_seg (str): The name of the child Resource to be attached. path_seg (str): The name of the child Resource to be attached.
Returns: Returns:
str: A unique string which can be a key to the child Resource. str: A unique string which can be a key to the child Resource.

View file

@ -17,15 +17,10 @@
""" """
from .replication import ReplicationLayer from .replication import ReplicationLayer
from .transport import TransportLayer from .transport.client import TransportLayerClient
def initialize_http_replication(homeserver): def initialize_http_replication(homeserver):
transport = TransportLayer( transport = TransportLayerClient(homeserver)
homeserver,
homeserver.hostname,
server=homeserver.get_resource_for_federation(),
client=homeserver.get_http_client()
)
return ReplicationLayer(homeserver, transport) return ReplicationLayer(homeserver, transport)

View file

@ -54,8 +54,6 @@ class ReplicationLayer(FederationClient, FederationServer):
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.transport_layer = transport_layer self.transport_layer = transport_layer
self.transport_layer.register_received_handler(self)
self.transport_layer.register_request_handler(self)
self.federation_client = self self.federation_client = self

View file

@ -20,55 +20,3 @@ By default this is done over HTTPS (and all home servers are required to
support HTTPS), however individual pairings of servers may decide to support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol. communicate over a different (albeit still reliable) protocol.
""" """
from .server import TransportLayerServer
from .client import TransportLayerClient
from synapse.util.ratelimitutils import FederationRateLimiter
class TransportLayer(TransportLayerServer, TransportLayerClient):
"""This is a basic implementation of the transport layer that translates
transactions and other requests to/from HTTP.
Attributes:
server_name (str): Local home server host
server (synapse.http.server.HttpServer): the http server to
register listeners on
client (synapse.http.client.HttpClient): the http client used to
send requests
request_handler (TransportRequestHandler): The handler to fire when we
receive requests for data.
received_handler (TransportReceivedHandler): The handler to fire when
we receive data.
"""
def __init__(self, homeserver, server_name, server, client):
"""
Args:
server_name (str): Local home server host
server (synapse.protocol.http.HttpServer): the http server to
register listeners on
client (synapse.protocol.http.HttpClient): the http client used to
send requests
"""
self.keyring = homeserver.get_keyring()
self.clock = homeserver.get_clock()
self.server_name = server_name
self.server = server
self.client = client
self.request_handler = None
self.received_handler = None
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=homeserver.config.federation_rc_window_size,
sleep_limit=homeserver.config.federation_rc_sleep_limit,
sleep_msec=homeserver.config.federation_rc_sleep_delay,
reject_limit=homeserver.config.federation_rc_reject_limit,
concurrent_requests=homeserver.config.federation_rc_concurrent,
)

View file

@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class TransportLayerClient(object): class TransportLayerClient(object):
"""Sends federation HTTP requests to other servers""" """Sends federation HTTP requests to other servers"""
def __init__(self, hs):
self.server_name = hs.hostname
self.client = hs.get_http_client()
@log_function @log_function
def get_room_state(self, destination, room_id, event_id): def get_room_state(self, destination, room_id, event_id):
""" Requests all state for a given room from the given server at the """ Requests all state for a given room from the given server at the

View file

@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function from synapse.http.server import JsonResource
from synapse.util.ratelimitutils import FederationRateLimiter
import functools import functools
import logging import logging
@ -28,9 +29,41 @@ import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TransportLayerServer(object): class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests""" """Handles incoming federation HTTP requests"""
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
super(TransportLayerServer, self).__init__(hs)
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=hs.config.federation_rc_window_size,
sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent,
)
self.register_servlets()
def register_servlets(self):
register_servlets(
self.hs,
resource=self,
ratelimiter=self.ratelimiter,
authenticator=self.authenticator,
)
class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
def authenticate_request(self, request): def authenticate_request(self, request):
@ -98,37 +131,9 @@ class TransportLayerServer(object):
defer.returnValue((origin, content)) defer.returnValue((origin, content))
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
Args:
handler (TransportReceivedHandler)
"""
FederationSendServlet(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
server_name=self.server_name,
).register(self.server)
@log_function
def register_request_handler(self, handler):
""" Register a handler that will be fired when we get asked for data.
Args:
handler (TransportRequestHandler)
"""
for servletclass in SERVLET_CLASSES:
servletclass(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
).register(self.server)
class BaseFederationServlet(object): class BaseFederationServlet(object):
def __init__(self, handler, authenticator, ratelimiter): def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler self.handler = handler
self.authenticator = authenticator self.authenticator = authenticator
self.ratelimiter = ratelimiter self.ratelimiter = ratelimiter
@ -172,7 +177,9 @@ class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/" PATH = "/send/([^/]*)/"
def __init__(self, handler, server_name, **kwargs): def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__(handler, **kwargs) super(FederationSendServlet, self).__init__(
handler, server_name=server_name, **kwargs
)
self.server_name = server_name self.server_name = server_name
# This is when someone is trying to send us a bunch of data. # This is when someone is trying to send us a bunch of data.
@ -432,6 +439,7 @@ class On3pidBindServlet(BaseFederationServlet):
SERVLET_CLASSES = ( SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet, FederationPullServlet,
FederationEventServlet, FederationEventServlet,
FederationStateServlet, FederationStateServlet,
@ -451,3 +459,13 @@ SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,
) )
def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in SERVLET_CLASSES:
servletclass(
handler=hs.get_replication_layer(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)

View file

@ -1186,7 +1186,13 @@ class FederationHandler(BaseHandler):
try: try:
self.auth.check(e, auth_events=auth_for_e) self.auth.check(e, auth_events=auth_for_e)
except AuthError as err: except SynapseError as err:
# we may get SynapseErrors here as well as AuthErrors. For
# instance, there are a couple of (ancient) events in some
# rooms whose senders do not have the correct sigil; these
# cause SynapseErrors in auth.check. We don't want to give up
# the attempt to federate altogether in such cases.
logger.warn( logger.warn(
"Rejecting %s because %s", "Rejecting %s because %s",
e.event_id, err.msg e.event_id, err.msg

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError, AuthError, Codes from synapse.api.errors import AuthError, Codes
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -105,8 +105,6 @@ class MessageHandler(BaseHandler):
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(room_token) room_token = RoomStreamToken.parse(room_token)
if room_token.topological is None:
raise SynapseError(400, "Invalid token")
pagin_config.from_token = pagin_config.from_token.copy_and_replace( pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token) "room_key", str(room_token)
@ -117,26 +115,30 @@ class MessageHandler(BaseHandler):
membership, member_event_id = yield self._check_in_room_or_world_readable( membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id room_id, user_id
) )
if source_config.direction == 'b':
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
max_topo = room_token.topological
else:
max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
room_id, room_token.stream
)
if membership == Membership.LEAVE: if membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before # If they have left the room then clamp the token to be before
# they left the room. # they left the room, to save the effort of loading from the
# database.
leave_token = yield self.store.get_topological_token_for_event( leave_token = yield self.store.get_topological_token_for_event(
member_event_id member_event_id
) )
leave_token = RoomStreamToken.parse(leave_token) leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological: if leave_token.topological < max_topo:
source_config.from_key = str(leave_token) source_config.from_key = str(leave_token)
if source_config.direction == "f":
if source_config.to_key is None:
source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill( yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological room_id, max_topo
) )
events, next_key = yield data_source.get_pagination_rows( events, next_key = yield data_source.get_pagination_rows(

View file

@ -72,7 +72,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
) )
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
"room_id", # str "room_id", # str
"timeline", # TimelineBatch "timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent] "state", # dict[(str, str), FrozenEvent]
@ -298,46 +298,19 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token room_id, sync_config, now_token, since_token=timeline_since_token
) )
notifs = yield self.unread_notifs_for_room_id( room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, ephemeral_by_room room_id, sync_config,
now_token=now_token,
since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
all_ephemeral_by_room=ephemeral_by_room,
batch=batch,
full_state=True,
) )
unread_notifications = {} defer.returnValue(room_sync)
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
current_state = yield self.get_state_at(room_id, now_token)
current_state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
current_state.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
defer.returnValue(JoinedSyncResult(
room_id=room_id,
timeline=batch,
state=current_state,
ephemeral=ephemeral,
account_data=account_data,
unread_notifications=unread_notifications,
))
def account_data_for_user(self, account_data): def account_data_for_user(self, account_data):
account_data_events = [] account_data_events = []
@ -429,44 +402,20 @@ class SyncHandler(BaseHandler):
defer.returnValue((now_token, ephemeral_by_room)) defer.returnValue((now_token, ephemeral_by_room))
@defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config, def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token, leave_event_id, leave_token,
timeline_since_token, tags_by_room, timeline_since_token, tags_by_room,
account_data_by_room): account_data_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred ArchivedSyncResult.
""" """
batch = yield self.load_filtered_recents( return self.incremental_sync_for_archived_room(
room_id, sync_config, leave_token, since_token=timeline_since_token sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room,
account_data_by_room, full_state=True, leave_token=leave_token,
) )
leave_state = yield self.store.get_state_for_event(leave_event_id)
leave_state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
leave_state.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
defer.returnValue(ArchivedSyncResult(
room_id=room_id,
timeline=batch,
state=leave_state,
account_data=account_data,
))
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_with_gap(self, sync_config, since_token): def incremental_sync_with_gap(self, sync_config, since_token):
""" Get the incremental delta needed to bring the client up to """ Get the incremental delta needed to bring the client up to
@ -512,154 +461,127 @@ class SyncHandler(BaseHandler):
sync_config.user sync_config.user
) )
user_id = sync_config.user.to_string()
timeline_limit = sync_config.filter_collection.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
room_events, _ = yield self.store.get_room_events_stream(
sync_config.user.to_string(),
from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
)
tags_by_room = yield self.store.get_updated_tags( tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(), user_id,
since_token.account_data_key, since_token.account_data_key,
) )
account_data, account_data_by_room = ( account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user( yield self.store.get_updated_account_data_for_user(
sync_config.user.to_string(), user_id,
since_token.account_data_key, since_token.account_data_key,
) )
) )
joined = [] # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_room_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
mem_change_events_by_room_id = {}
for event in rooms_changed:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
newly_joined_rooms = []
archived = [] archived = []
if len(room_events) <= timeline_limit: invited = []
# There is no gap in any of the rooms. Therefore we can just for room_id, events in mem_change_events_by_room_id.items():
# partition the new events by room and return them. non_joins = [e for e in events if e.membership != Membership.JOIN]
logger.debug("Got %i events for incremental sync - not limited", has_join = len(non_joins) != len(events)
len(room_events))
invite_events = [] # We want to figure out if we joined the room at some point since
leave_events = [] # the last sync (even if we have since left). This is to make sure
events_by_room_id = {} # we do send down the room, and with full state, where necessary
for event in room_events: if room_id in joined_room_ids or has_join:
events_by_room_id.setdefault(event.room_id, []).append(event) old_state = yield self.get_state_at(room_id, since_token)
if event.room_id not in joined_room_ids: old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
if (event.type == EventTypes.Member if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
and event.state_key == sync_config.user.to_string()): newly_joined_rooms.append(room_id)
if event.membership == Membership.INVITE:
invite_events.append(event)
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_events.append(event)
for room_id in joined_room_ids: if room_id in joined_room_ids:
recents = events_by_room_id.get(room_id, []) continue
logger.debug("Events for room %s: %r", room_id, recents)
state = {
(event.type, event.state_key): event
for event in recents if event.is_state()}
limited = False
if recents: if not non_joins:
prev_batch = now_token.copy_and_replace( continue
"room_key", recents[0].internal_metadata.before
)
else:
prev_batch = now_token
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
logger.debug("User has just joined %s: needs full state",
room_id)
state = yield self.get_state_at(room_id, now_token)
# the timeline is inherently limited if we've just joined
limited = True
recents = sync_config.filter_collection.filter_room_timeline(recents)
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
state.values()
)
}
acc_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
acc_data = sync_config.filter_collection.filter_room_account_data(
acc_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=TimelineBatch(
events=recents,
prev_batch=prev_batch,
limited=limited,
),
state=state,
ephemeral=ephemeral,
account_data=acc_data,
unread_notifications={},
)
logger.debug("Result for room %s: %r", room_id, room_sync)
# Only bother if we're still currently invited
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
if room_sync: if room_sync:
notifs = yield self.unread_notifs_for_room_id( invited.append(room_sync)
room_id, sync_config, all_ephemeral_by_room
)
if notifs is not None: # Always include leave/ban events. Just take the last one.
notif_dict = room_sync.unread_notifications # TODO: How do we handle ban -> leave in same batch?
notif_dict["notification_count"] = len(notifs) leave_events = [
notif_dict["highlight_count"] = len([ e for e in non_joins
1 for notif in notifs if e.membership in (Membership.LEAVE, Membership.BAN)
if _action_has_highlight(notif["actions"]) ]
])
joined.append(room_sync) if leave_events:
leave_event = leave_events[-1]
else:
logger.debug("Got %i events for incremental sync - hit limit",
len(room_events))
invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string()
)
leave_events = yield self.store.get_leave_and_ban_events_for_user(
sync_config.user.to_string()
)
for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token,
ephemeral_by_room, tags_by_room, account_data_by_room,
all_ephemeral_by_room=all_ephemeral_by_room,
)
if room_sync:
joined.append(room_sync)
for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room( room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token, tags_by_room, sync_config, room_id, leave_event.event_id, since_token,
account_data_by_room tags_by_room, account_data_by_room,
full_state=room_id in newly_joined_rooms
) )
if room_sync: if room_sync:
archived.append(room_sync) archived.append(room_sync)
invited = [ # Get all events for rooms we're currently joined to.
InvitedSyncResult(room_id=event.room_id, invite=event) room_to_events = yield self.store.get_room_events_stream_for_rooms(
for event in invite_events room_ids=joined_room_ids,
] from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
)
joined = []
# We loop through all room ids, even if there are no new events, in case
# there are non room events taht we need to notify about.
for room_id in joined_room_ids:
room_entry = room_to_events.get(room_id, None)
if room_entry:
events, start_key = room_entry
prev_batch_token = now_token.copy_and_replace("room_key", start_key)
newly_joined_room = room_id in newly_joined_rooms
full_state = newly_joined_room
batch = yield self.load_filtered_recents(
room_id, sync_config, prev_batch_token,
since_token=since_token,
recents=events,
newly_joined_room=newly_joined_room,
)
else:
batch = TimelineBatch(
events=[],
prev_batch=since_token,
limited=False,
)
full_state = False
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id=room_id,
sync_config=sync_config,
since_token=since_token,
now_token=now_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
all_ephemeral_by_room=all_ephemeral_by_room,
batch=batch,
full_state=full_state,
)
if room_sync:
joined.append(room_sync)
account_data_for_user = sync_config.filter_collection.filter_account_data( account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data) self.account_data_for_user(account_data)
@ -680,28 +602,40 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token, def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None): since_token=None, recents=None, newly_joined_room=False):
""" """
:returns a Deferred TimelineBatch :returns a Deferred TimelineBatch
""" """
limited = True
recents = []
filtering_factor = 2 filtering_factor = 2
timeline_limit = sync_config.filter_collection.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
load_limit = max(timeline_limit * filtering_factor, 100) load_limit = max(timeline_limit * filtering_factor, 10)
max_repeat = 3 # Only try a few times per room, otherwise max_repeat = 5 # Only try a few times per room, otherwise
room_key = now_token.room_key room_key = now_token.room_key
end_key = room_key end_key = room_key
limited = recents is None or newly_joined_room or timeline_limit < len(recents)
if recents is not None:
recents = sync_config.filter_collection.filter_room_timeline(recents)
recents = yield self._filter_events_for_client(
sync_config.user.to_string(),
recents,
is_peeking=sync_config.is_guest,
)
else:
recents = []
since_key = None
if since_token and not newly_joined_room:
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat: while limited and len(recents) < timeline_limit and max_repeat:
events, keys = yield self.store.get_recent_events_for_room( events, end_key = yield self.store.get_room_events_stream_for_room(
room_id, room_id,
limit=load_limit + 1, limit=load_limit + 1,
from_token=since_token.room_key if since_token else None, from_key=since_key,
end_token=end_key, to_key=end_key,
) )
room_key, _ = keys
end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter_collection.filter_room_timeline(events) loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
loaded_recents = yield self._filter_events_for_client( loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), sync_config.user.to_string(),
@ -710,8 +644,10 @@ class SyncHandler(BaseHandler):
) )
loaded_recents.extend(recents) loaded_recents.extend(recents)
recents = loaded_recents recents = loaded_recents
if len(events) <= load_limit: if len(events) <= load_limit:
limited = False limited = False
break
max_repeat -= 1 max_repeat -= 1
if len(recents) > timeline_limit: if len(recents) > timeline_limit:
@ -724,7 +660,9 @@ class SyncHandler(BaseHandler):
) )
defer.returnValue(TimelineBatch( defer.returnValue(TimelineBatch(
events=recents, prev_batch=prev_batch_token, limited=limited events=recents,
prev_batch=prev_batch_token,
limited=limited or newly_joined_room
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -732,25 +670,12 @@ class SyncHandler(BaseHandler):
since_token, now_token, since_token, now_token,
ephemeral_by_room, tags_by_room, ephemeral_by_room, tags_by_room,
account_data_by_room, account_data_by_room,
all_ephemeral_by_room): all_ephemeral_by_room,
""" Get the incremental delta needed to bring the client up to date for batch, full_state=False):
the room. Gives the client the most recent events and the changes to if full_state:
state. state = yield self.get_state_at(room_id, now_token)
Returns:
A Deferred JoinedSyncResult
"""
logger.debug("Doing incremental sync for room %s between %s and %s",
room_id, since_token, now_token)
# TODO(mjark): Check for redactions we might have missed. elif batch.limited:
batch = yield self.load_filtered_recents(
room_id, sync_config, now_token, since_token,
)
logger.debug("Recents %r", batch)
if batch.limited:
current_state = yield self.get_state_at(room_id, now_token) current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
@ -772,17 +697,6 @@ class SyncHandler(BaseHandler):
if just_joined: if just_joined:
state = yield self.get_state_at(room_id, now_token) state = yield self.get_state_at(room_id, now_token)
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room
)
unread_notifications = {}
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
state = { state = {
(e.type, e.state_key): e (e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values()) for e in sync_config.filter_collection.filter_room_state(state.values())
@ -800,6 +714,7 @@ class SyncHandler(BaseHandler):
ephemeral_by_room.get(room_id, []) ephemeral_by_room.get(room_id, [])
) )
unread_notifications = {}
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
@ -809,41 +724,55 @@ class SyncHandler(BaseHandler):
unread_notifications=unread_notifications, unread_notifications=unread_notifications,
) )
if room_sync:
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room
)
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
logger.debug("Room sync: %r", room_sync) logger.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event, def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id,
since_token, tags_by_room, since_token, tags_by_room,
account_data_by_room): account_data_by_room, full_state,
leave_token=None):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the archived room. the archived room.
Returns: Returns:
A Deferred ArchivedSyncResult A Deferred ArchivedSyncResult
""" """
if not leave_token:
stream_token = yield self.store.get_stream_token_for_event( stream_token = yield self.store.get_stream_token_for_event(
leave_event.event_id leave_event_id
) )
leave_token = since_token.copy_and_replace("room_key", stream_token) leave_token = since_token.copy_and_replace("room_key", stream_token)
if since_token.is_after(leave_token): if since_token and since_token.is_after(leave_token):
defer.returnValue(None) defer.returnValue(None)
batch = yield self.load_filtered_recents( batch = yield self.load_filtered_recents(
leave_event.room_id, sync_config, leave_token, since_token, room_id, sync_config, leave_token, since_token,
) )
logger.debug("Recents %r", batch) logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event( state_events_at_leave = yield self.store.get_state_for_event(
leave_event.event_id leave_event_id
) )
if not full_state:
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
leave_event.room_id, stream_position=since_token room_id, stream_position=since_token
) )
state_events_delta = yield self.compute_state_delta( state_events_delta = yield self.compute_state_delta(
@ -851,6 +780,8 @@ class SyncHandler(BaseHandler):
previous_state=state_at_previous_sync, previous_state=state_at_previous_sync,
current_state=state_events_at_leave, current_state=state_events_at_leave,
) )
else:
state_events_delta = state_events_at_leave
state_events_delta = { state_events_delta = {
(e.type, e.state_key): e (e.type, e.state_key): e
@ -860,7 +791,7 @@ class SyncHandler(BaseHandler):
} }
account_data = self.account_data_for_room( account_data = self.account_data_for_room(
leave_event.room_id, tags_by_room, account_data_by_room room_id, tags_by_room, account_data_by_room
) )
account_data = sync_config.filter_collection.filter_room_account_data( account_data = sync_config.filter_collection.filter_room_account_data(
@ -868,7 +799,7 @@ class SyncHandler(BaseHandler):
) )
room_sync = ArchivedSyncResult( room_sync = ArchivedSyncResult(
room_id=leave_event.room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
account_data=account_data, account_data=account_data,

View file

@ -22,7 +22,7 @@ REQUIREMENTS = {
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"], "unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl>=0.3.0": ["nacl>=0.3.0", "nacl.bindings"], "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"], "Twisted>=15.1.0": ["twisted>=15.1.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],

View file

@ -66,11 +66,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
before = request.args.get("before", None) before = request.args.get("before", None)
if before and len(before): if before:
before = before[0] before = _namespaced_rule_id(spec, before[0])
after = request.args.get("after", None) after = request.args.get("after", None)
if after and len(after): if after:
after = after[0] after = _namespaced_rule_id(spec, after[0])
try: try:
yield self.hs.get_datastore().add_push_rule( yield self.hs.get_datastore().add_push_rule(
@ -452,11 +453,15 @@ def _strip_device_condition(rule):
def _namespaced_rule_id_from_spec(spec): def _namespaced_rule_id_from_spec(spec):
return _namespaced_rule_id(spec, spec['rule_id'])
def _namespaced_rule_id(spec, rule_id):
if spec['scope'] == 'global': if spec['scope'] == 'global':
scope = 'global' scope = 'global'
else: else:
scope = 'device/%s' % (spec['profile_tag']) scope = 'device/%s' % (spec['profile_tag'])
return "%s/%s/%s" % (scope, spec['template'], spec['rule_id']) return "%s/%s/%s" % (scope, spec['template'], rule_id)
def _rule_id_from_namespaced(in_rule_id): def _rule_id_from_namespaced(in_rule_id):

View file

@ -20,6 +20,8 @@
# Imports required for the default HomeServer() implementation # Imports required for the default HomeServer() implementation
from twisted.web.client import BrowserLikePolicyForHTTPS from twisted.web.client import BrowserLikePolicyForHTTPS
from twisted.enterprise import adbapi
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier from synapse.notifier import Notifier
@ -36,8 +38,15 @@ from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering from synapse.api.filtering import Filtering
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
class BaseHomeServer(object): import logging
logger = logging.getLogger(__name__)
class HomeServer(object):
"""A basic homeserver object without lazy component builders. """A basic homeserver object without lazy component builders.
This will need all of the components it requires to either be passed as This will need all of the components it requires to either be passed as
@ -98,39 +107,18 @@ class BaseHomeServer(object):
self.hostname = hostname self.hostname = hostname
self._building = {} self._building = {}
self.clock = Clock()
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
# Other kwargs are explicit dependencies # Other kwargs are explicit dependencies
for depname in kwargs: for depname in kwargs:
setattr(self, depname, kwargs[depname]) setattr(self, depname, kwargs[depname])
@classmethod def setup(self):
def _make_dependency_method(cls, depname): logger.info("Setting up.")
def _get(self): self.datastore = DataStore(self.get_db_conn(), self)
if hasattr(self, depname): logger.info("Finished setting up.")
return getattr(self, depname)
if hasattr(self, "build_%s" % (depname)):
# Prevent cyclic dependencies from deadlocking
if depname in self._building:
raise ValueError("Cyclic dependency while building %s" % (
depname,
))
self._building[depname] = 1
builder = getattr(self, "build_%s" % (depname))
dep = builder()
setattr(self, depname, dep)
del self._building[depname]
return dep
raise NotImplementedError(
"%s has no %s nor a builder for it" % (
type(self).__name__, depname,
)
)
setattr(BaseHomeServer, "get_%s" % (depname), _get)
def get_ip_from_request(self, request): def get_ip_from_request(self, request):
# X-Forwarded-For is handled by our custom request type. # X-Forwarded-For is handled by our custom request type.
@ -142,33 +130,9 @@ class BaseHomeServer(object):
def is_mine_id(self, string): def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname return string.split(":", 1)[1] == self.hostname
# Build magic accessors for every dependency
for depname in BaseHomeServer.DEPENDENCIES:
BaseHomeServer._make_dependency_method(depname)
class HomeServer(BaseHomeServer):
"""A homeserver object that will construct most of its dependencies as
required.
It still requires the following to be specified by the caller:
resource_for_client
resource_for_web_client
resource_for_federation
resource_for_content_repo
http_client
db_pool
"""
def build_clock(self):
return Clock()
def build_replication_layer(self): def build_replication_layer(self):
return initialize_http_replication(self) return initialize_http_replication(self)
def build_datastore(self):
return DataStore(self)
def build_handlers(self): def build_handlers(self):
return Handlers(self) return Handlers(self)
@ -179,10 +143,9 @@ class HomeServer(BaseHomeServer):
return Auth(self) return Auth(self)
def build_http_client_context_factory(self): def build_http_client_context_factory(self):
config = self.get_config()
return ( return (
InsecureInterceptableContextFactory() InsecureInterceptableContextFactory()
if config.use_insecure_ssl_client_just_for_testing_do_not_use if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
else BrowserLikePolicyForHTTPS() else BrowserLikePolicyForHTTPS()
) )
@ -201,15 +164,9 @@ class HomeServer(BaseHomeServer):
def build_state_handler(self): def build_state_handler(self):
return StateHandler(self) return StateHandler(self)
def build_distributor(self):
return Distributor()
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)
def build_ratelimiter(self):
return Ratelimiter()
def build_keyring(self): def build_keyring(self):
return Keyring(self) return Keyring(self)
@ -224,3 +181,55 @@ class HomeServer(BaseHomeServer):
def build_pusherpool(self): def build_pusherpool(self):
return PusherPool(self) return PusherPool(self)
def build_http_client(self):
return MatrixFederationHttpClient(self)
def build_db_pool(self):
name = self.db_config["name"]
return adbapi.ConnectionPool(
name,
**self.db_config.get("args", {})
)
def _make_dependency_method(depname):
def _get(hs):
try:
return getattr(hs, depname)
except AttributeError:
pass
try:
builder = getattr(hs, "build_%s" % (depname))
except AttributeError:
builder = None
if builder:
# Prevent cyclic dependencies from deadlocking
if depname in hs._building:
raise ValueError("Cyclic dependency while building %s" % (
depname,
))
hs._building[depname] = 1
dep = builder()
setattr(hs, depname, dep)
del hs._building[depname]
return dep
raise NotImplementedError(
"%s has no %s nor a builder for it" % (
type(hs).__name__, depname,
)
)
setattr(HomeServer, "get_%s" % (depname), _get)
# Build magic accessors for every dependency
for depname in HomeServer.DEPENDENCIES:
_make_dependency_method(depname)

View file

@ -46,6 +46,9 @@ from .tags import TagsStore
from .account_data import AccountDataStore from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator
import logging import logging
@ -79,18 +82,43 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore EventPushActionsStore
): ):
def __init__(self, hs): def __init__(self, db_conn, hs):
super(DataStore, self).__init__(hs)
self.hs = hs self.hs = hs
self.min_token_deferred = self._get_min_token() cur = db_conn.cursor()
self.min_token = None try:
cur.execute("SELECT MIN(stream_ordering) FROM events",)
rows = cur.fetchall()
self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
self.min_stream_token = min(self.min_stream_token, -1)
finally:
cur.close()
self.client_ip_last_seen = Cache( self.client_ip_last_seen = Cache(
name="client_ip_last_seen", name="client_ip_last_seen",
keylen=4, keylen=4,
) )
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering"
)
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
super(DataStore, self).__init__(hs)
@defer.inlineCallbacks @defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent): def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())

View file

@ -15,13 +15,11 @@
import logging import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
import synapse.metrics import synapse.metrics
from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer from twisted.internet import defer
@ -175,16 +173,6 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
def start_profiling(self): def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec() self._previous_loop_ts = self._clock.time_msec()
@ -345,7 +333,8 @@ class SQLBaseStore(object):
defer.returnValue(result) defer.returnValue(result)
def cursor_to_dict(self, cursor): @staticmethod
def cursor_to_dict(cursor):
"""Converts a SQL cursor into an list of dicts. """Converts a SQL cursor into an list of dicts.
Args: Args:
@ -402,8 +391,8 @@ class SQLBaseStore(object):
if not or_ignore: if not or_ignore:
raise raise
@log_function @staticmethod
def _simple_insert_txn(self, txn, table, values): def _simple_insert_txn(txn, table, values):
keys, vals = zip(*values.items()) keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % ( sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@ -414,7 +403,8 @@ class SQLBaseStore(object):
txn.execute(sql, vals) txn.execute(sql, vals)
def _simple_insert_many_txn(self, txn, table, values): @staticmethod
def _simple_insert_many_txn(txn, table, values):
if not values: if not values:
return return
@ -537,9 +527,10 @@ class SQLBaseStore(object):
table, keyvalues, retcol, allow_none=allow_none, table, keyvalues, retcol, allow_none=allow_none,
) )
def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol, @classmethod
def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
allow_none=False): allow_none=False):
ret = self._simple_select_onecol_txn( ret = cls._simple_select_onecol_txn(
txn, txn,
table=table, table=table,
keyvalues=keyvalues, keyvalues=keyvalues,
@ -554,7 +545,8 @@ class SQLBaseStore(object):
else: else:
raise StoreError(404, "No row found") raise StoreError(404, "No row found")
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol): @staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
sql = ( sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s" "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
) % { ) % {
@ -603,7 +595,8 @@ class SQLBaseStore(object):
table, keyvalues, retcols table, keyvalues, retcols
) )
def _simple_select_list_txn(self, txn, table, keyvalues, retcols): @classmethod
def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
@ -627,7 +620,7 @@ class SQLBaseStore(object):
) )
txn.execute(sql) txn.execute(sql)
return self.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
@defer.inlineCallbacks @defer.inlineCallbacks
def _simple_select_many_batch(self, table, column, iterable, retcols, def _simple_select_many_batch(self, table, column, iterable, retcols,
@ -662,7 +655,8 @@ class SQLBaseStore(object):
defer.returnValue(results) defer.returnValue(results)
def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols): @classmethod
def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
@ -699,7 +693,7 @@ class SQLBaseStore(object):
) )
txn.execute(sql, values) txn.execute(sql, values)
return self.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues, def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"): desc="_simple_update_one"):
@ -726,7 +720,8 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues, table, keyvalues, updatevalues,
) )
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues): @staticmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
update_sql = "UPDATE %s SET %s WHERE %s" % ( update_sql = "UPDATE %s SET %s WHERE %s" % (
table, table,
", ".join("%s = ?" % (k,) for k in updatevalues), ", ".join("%s = ?" % (k,) for k in updatevalues),
@ -743,7 +738,8 @@ class SQLBaseStore(object):
if txn.rowcount > 1: if txn.rowcount > 1:
raise StoreError(500, "More than one row matched") raise StoreError(500, "More than one row matched")
def _simple_select_one_txn(self, txn, table, keyvalues, retcols, @staticmethod
def _simple_select_one_txn(txn, table, keyvalues, retcols,
allow_none=False): allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % ( select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols), ", ".join(retcols),
@ -784,7 +780,8 @@ class SQLBaseStore(object):
raise StoreError(500, "more than one row matched") raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func) return self.runInteraction(desc, func)
def _simple_delete_txn(self, txn, table, keyvalues): @staticmethod
def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % ( sql = "DELETE FROM %s WHERE %s" % (
table, table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer from twisted.internet import defer
import ujson as json import ujson as json
@ -23,6 +24,14 @@ logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore): class AccountDataStore(SQLBaseStore):
def __init__(self, hs):
super(AccountDataStore, self).__init__(hs)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_max_token(None),
max_size=10000,
)
def get_account_data_for_user(self, user_id): def get_account_data_for_user(self, user_id):
"""Get all the client account_data for a user. """Get all the client account_data for a user.
@ -83,7 +92,7 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn "get_account_data_for_room", get_account_data_for_room_txn
) )
def get_updated_account_data_for_user(self, user_id, stream_id): def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None):
"""Get all the client account_data for a that's changed. """Get all the client account_data for a that's changed.
Args: Args:
@ -120,6 +129,12 @@ class AccountDataStore(SQLBaseStore):
return (global_account_data, account_data_by_room) return (global_account_data, account_data_by_room)
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
return ({}, {})
return self.runInteraction( return self.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
) )
@ -186,6 +201,10 @@ class AccountDataStore(SQLBaseStore):
"content": content_json, "content": content_json,
} }
) )
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
user_id, next_id,
)
self._update_max_stream_id(txn, next_id) self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id: with (yield self._account_data_id_gen.get_next(self)) as next_id:

View file

@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
return return
if backfilled: if backfilled:
if not self.min_token_deferred.called: start = self.min_stream_token - 1
yield self.min_token_deferred self.min_stream_token -= len(events_and_contexts) + 1
start = self.min_token - 1 stream_orderings = range(start, self.min_stream_token, -1)
self.min_token -= len(events_and_contexts) + 1
stream_orderings = range(start, self.min_token, -1)
@contextmanager @contextmanager
def stream_ordering_manager(): def stream_ordering_manager():
@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
is_new_state=True, current_state=None): is_new_state=True, current_state=None):
stream_ordering = None stream_ordering = None
if backfilled: if backfilled:
if not self.min_token_deferred.called: self.min_stream_token -= 1
yield self.min_token_deferred stream_ordering = self.min_stream_token
self.min_token -= 1
stream_ordering = self.min_token
if stream_ordering is None: if stream_ordering is None:
stream_ordering_manager = yield self._stream_id_gen.get_next(self) stream_ordering_manager = yield self._stream_id_gen.get_next(self)
@ -214,6 +210,12 @@ class EventsStore(SQLBaseStore):
for event, _ in events_and_contexts: for event, _ in events_and_contexts:
txn.call_after(self._invalidate_get_event_cache, event.event_id) txn.call_after(self._invalidate_get_event_cache, event.event_id)
if not backfilled:
txn.call_after(
self._events_stream_cache.entity_has_changed,
event.room_id, event.internal_metadata.stream_ordering,
)
depth_updates = {} depth_updates = {}
for event, _ in events_and_contexts: for event, _ in events_and_contexts:
if event.internal_metadata.is_outlier(): if event.internal_metadata.is_outlier():

View file

@ -16,12 +16,13 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
import simplejson as json import simplejson as json
class FilteringStore(SQLBaseStore): class FilteringStore(SQLBaseStore):
@defer.inlineCallbacks @cachedInlineCallbacks(num_args=2)
def get_user_filter(self, user_localpart, filter_id): def get_user_filter(self, user_localpart, filter_id):
def_json = yield self._simple_select_one_onecol( def_json = yield self._simple_select_one_onecol(
table="user_filters", table="user_filters",

View file

@ -130,7 +130,8 @@ class PushRuleStore(SQLBaseStore):
def _add_push_rule_relative_txn(self, txn, user_id, **kwargs): def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
after = kwargs.pop("after", None) after = kwargs.pop("after", None)
relative_to_rule = kwargs.pop("before", after) before = kwargs.pop("before", None)
relative_to_rule = before or after
res = self._simple_select_one_txn( res = self._simple_select_one_txn(
txn, txn,

View file

@ -15,11 +15,10 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
from synapse.util.caches import cache_counter, caches_by_name from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer from twisted.internet import defer
from blist import sorteddict
import logging import logging
import ujson as json import ujson as json
@ -31,7 +30,9 @@ class ReceiptsStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs) super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = _RoomStreamChangeCache() self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
)
@cached(num_args=2) @cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type): def get_receipts_for_room(self, room_id, receipt_type):
@ -76,8 +77,8 @@ class ReceiptsStore(SQLBaseStore):
room_ids = set(room_ids) room_ids = set(room_ids)
if from_key: if from_key:
room_ids = yield self._receipts_stream_cache.get_rooms_changed( room_ids = yield self._receipts_stream_cache.get_entities_changed(
self, room_ids, from_key room_ids, from_key
) )
results = yield self._get_linearized_receipts_for_rooms( results = yield self._get_linearized_receipts_for_rooms(
@ -220,6 +221,11 @@ class ReceiptsStore(SQLBaseStore):
# FIXME: This shouldn't invalidate the whole cache # FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all) txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
txn.call_after(
self._receipts_stream_cache.entity_has_changed,
room_id, stream_id
)
# We don't want to clobber receipts for more recent events, so we # We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts # have to compare orderings of existing receipts
sql = ( sql = (
@ -307,9 +313,6 @@ class ReceiptsStore(SQLBaseStore):
stream_id_manager = yield self._receipts_id_gen.get_next(self) stream_id_manager = yield self._receipts_id_gen.get_next(self)
with stream_id_manager as stream_id: with stream_id_manager as stream_id:
yield self._receipts_stream_cache.room_has_changed(
self, room_id, stream_id
)
have_persisted = yield self.runInteraction( have_persisted = yield self.runInteraction(
"insert_linearized_receipt", "insert_linearized_receipt",
self.insert_linearized_receipt_txn, self.insert_linearized_receipt_txn,
@ -368,63 +371,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data), "data": json.dumps(data),
} }
) )
class _RoomStreamChangeCache(object):
"""Keeps track of the stream_id of the latest change in rooms.
Given a list of rooms and stream key, it will give a subset of rooms that
may have changed since that key. If the key is too old then the cache
will simply return all rooms.
"""
def __init__(self, size_of_cache=10000):
self._size_of_cache = size_of_cache
self._room_to_key = {}
self._cache = sorteddict()
self._earliest_key = None
self.name = "ReceiptsRoomChangeCache"
caches_by_name[self.name] = self._cache
@defer.inlineCallbacks
def get_rooms_changed(self, store, room_ids, key):
"""Returns subset of room ids that have had new receipts since the
given key. If the key is too old it will just return the given list.
"""
if key > (yield self._get_earliest_key(store)):
keys = self._cache.keys()
i = keys.bisect_right(key)
result = set(
self._cache[k] for k in keys[i:]
).intersection(room_ids)
cache_counter.inc_hits(self.name)
else:
result = room_ids
cache_counter.inc_misses(self.name)
defer.returnValue(result)
@defer.inlineCallbacks
def room_has_changed(self, store, room_id, key):
"""Informs the cache that the room has been changed at the given key.
"""
if key > (yield self._get_earliest_key(store)):
old_key = self._room_to_key.get(room_id, None)
if old_key:
key = max(key, old_key)
self._cache.pop(old_key, None)
self._cache[key] = room_id
while len(self._cache) > self._size_of_cache:
k, r = self._cache.popitem()
self._earliest_key = max(k, self._earliest_key)
self._room_to_key.pop(r, None)
@defer.inlineCallbacks
def _get_earliest_key(self, store):
if self._earliest_key is None:
self._earliest_key = yield store.get_max_receipt_stream_id()
self._earliest_key = int(self._earliest_key)
defer.returnValue(self._earliest_key)

View file

@ -241,7 +241,7 @@ class RoomMemberStore(SQLBaseStore):
return rows return rows
@cached() @cached(max_entries=5000)
def get_rooms_for_user(self, user_id): def get_rooms_for_user(self, user_id):
return self.get_rooms_for_user_where_membership_is( return self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN], user_id, membership_list=[Membership.JOIN],

View file

@ -0,0 +1,16 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX events_room_stream on events(room_id, stream_ordering);

View file

@ -37,6 +37,7 @@ from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -77,6 +78,12 @@ def upper_bound(token):
class StreamStore(SQLBaseStore): class StreamStore(SQLBaseStore):
def __init__(self, hs):
super(StreamStore, self).__init__(hs)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0): def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
@ -157,6 +164,135 @@ class StreamStore(SQLBaseStore):
results = yield self.runInteraction("get_appservice_room_stream", f) results = yield self.runInteraction("get_appservice_room_stream", f)
defer.returnValue(results) defer.returnValue(results)
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
room_ids = yield self._events_stream_cache.get_entities_changed(
room_ids, from_id
)
if not room_ids:
defer.returnValue({})
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)):
res = yield defer.gatherResults([
self.get_room_events_stream_for_room(
room_id, from_key, to_key, limit
).addCallback(lambda r, rm: (rm, r), room_id)
for room_id in room_ids
])
results.update(dict(res))
defer.returnValue(results)
@defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0):
if from_key is not None:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
else:
from_id = None
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
defer.returnValue(([], from_key))
if from_id:
has_changed = yield self._events_stream_cache.has_entity_changed(
room_id, from_id
)
if not has_changed:
defer.returnValue(([], from_key))
def f(txn):
if from_id is not None:
sql = (
"SELECT event_id, stream_ordering FROM events WHERE"
" room_id = ?"
" AND not outlier"
" AND stream_ordering > ? AND stream_ordering <= ?"
" ORDER BY stream_ordering DESC LIMIT ?"
)
txn.execute(sql, (room_id, from_id, to_id, limit))
else:
sql = (
"SELECT event_id, stream_ordering FROM events WHERE"
" room_id = ?"
" AND not outlier"
" AND stream_ordering <= ?"
" ORDER BY stream_ordering DESC LIMIT ?"
)
txn.execute(sql, (room_id, to_id, limit))
rows = self.cursor_to_dict(txn)
ret = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=False)
ret.reverse()
if rows:
key = "s%d" % min(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
key = from_key
return ret, key
res = yield self.runInteraction("get_room_events_stream_for_room", f)
defer.returnValue(res)
def get_room_changes_for_user(self, user_id, from_key, to_key):
if from_key is not None:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
else:
from_id = None
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
return defer.succeed([])
def f(txn):
if from_id is not None:
sql = (
"SELECT m.event_id, stream_ordering FROM events AS e,"
" room_memberships AS m"
" WHERE e.event_id = m.event_id"
" AND m.user_id = ?"
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
)
txn.execute(sql, (user_id, from_id, to_id,))
else:
sql = (
"SELECT m.event_id, stream_ordering FROM events AS e,"
" room_memberships AS m"
" WHERE e.event_id = m.event_id"
" AND m.user_id = ?"
" AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
)
txn.execute(sql, (user_id, to_id,))
rows = self.cursor_to_dict(txn)
ret = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
return ret
return self.runInteraction("get_room_changes_for_user", f)
@log_function @log_function
def get_room_events_stream( def get_room_events_stream(
self, self,
@ -174,7 +310,8 @@ class StreamStore(SQLBaseStore):
"SELECT c.room_id FROM history_visibility AS h" "SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c" " INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id" " ON h.event_id = c.event_id"
" WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % ( " WHERE c.room_id IN (%s)"
" AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids)) ",".join(map(lambda _: "?", room_ids))
) )
) )
@ -434,6 +571,18 @@ class StreamStore(SQLBaseStore):
row["topological_ordering"], row["stream_ordering"],) row["topological_ordering"], row["stream_ordering"],)
) )
def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
sql = (
"SELECT max(topological_ordering) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
return self._execute(
"get_max_topological_token_for_stream_and_room", None,
sql, room_id, stream_key,
).addCallback(
lambda r: r[0][0] if r else 0
)
def _get_max_topological_txn(self, txn): def _get_max_topological_txn(self, txn):
txn.execute( txn.execute(
"SELECT MAX(topological_ordering) FROM events" "SELECT MAX(topological_ordering) FROM events"
@ -444,24 +593,14 @@ class StreamStore(SQLBaseStore):
rows = txn.fetchall() rows = txn.fetchall()
return rows[0][0] if rows else 0 return rows[0][0] if rows else 0
@defer.inlineCallbacks
def _get_min_token(self):
row = yield self._execute(
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
)
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
self.min_token = min(self.min_token, -1)
logger.debug("min_token is: %s", self.min_token)
defer.returnValue(self.min_token)
@staticmethod @staticmethod
def _set_before_and_after(events, rows): def _set_before_and_after(events, rows, topo_order=True):
for event, row in zip(events, rows): for event, row in zip(events, rows):
stream = row["stream_ordering"] stream = row["stream_ordering"]
if topo_order:
topo = event.depth topo = event.depth
else:
topo = None
internal = event.internal_metadata internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1)) internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream)) internal.after = str(RoomStreamToken(topo, stream))

View file

@ -16,7 +16,6 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from twisted.internet import defer from twisted.internet import defer
from .util.id_generators import StreamIdGenerator
import ujson as json import ujson as json
import logging import logging
@ -25,13 +24,6 @@ logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore): class TagsStore(SQLBaseStore):
def __init__(self, hs):
super(TagsStore, self).__init__(hs)
self._account_data_id_gen = StreamIdGenerator(
"account_data_max_stream_id", "stream_id"
)
def get_max_account_data_stream_id(self): def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream """Get the current max stream id for the private user data stream
@ -87,6 +79,12 @@ class TagsStore(SQLBaseStore):
room_ids = [row[0] for row in txn.fetchall()] room_ids = [row[0] for row in txn.fetchall()]
return room_ids return room_ids
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
defer.returnValue({})
room_ids = yield self.runInteraction( room_ids = yield self.runInteraction(
"get_updated_tags", get_updated_tags_txn "get_updated_tags", get_updated_tags_txn
) )
@ -184,6 +182,11 @@ class TagsStore(SQLBaseStore):
next_id(int): The the revision to advance to. next_id(int): The the revision to advance to.
""" """
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
user_id, next_id
)
update_max_id_sql = ( update_max_id_sql = (
"UPDATE account_data_max_stream_id" "UPDATE account_data_max_stream_id"
" SET stream_id = ?" " SET stream_id = ?"

View file

@ -72,28 +72,24 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id: with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ... # ... persist event ...
""" """
def __init__(self, table, column): def __init__(self, db_conn, table, column):
self.table = table self.table = table
self.column = column self.column = column
self._lock = threading.Lock() self._lock = threading.Lock()
self._current_max = None cur = db_conn.cursor()
self._current_max = self._get_or_compute_current_max(cur)
cur.close()
self._unfinished_ids = deque() self._unfinished_ids = deque()
@defer.inlineCallbacks
def get_next(self, store): def get_next(self, store):
""" """
Usage: Usage:
with yield stream_id_gen.get_next as stream_id: with yield stream_id_gen.get_next as stream_id:
# ... persist event ... # ... persist event ...
""" """
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock: with self._lock:
self._current_max += 1 self._current_max += 1
next_id = self._current_max next_id = self._current_max
@ -108,21 +104,14 @@ class StreamIdGenerator(object):
with self._lock: with self._lock:
self._unfinished_ids.remove(next_id) self._unfinished_ids.remove(next_id)
defer.returnValue(manager()) return manager()
@defer.inlineCallbacks
def get_next_mult(self, store, n): def get_next_mult(self, store, n):
""" """
Usage: Usage:
with yield stream_id_gen.get_next(store, n) as stream_ids: with yield stream_id_gen.get_next(store, n) as stream_ids:
# ... persist events ... # ... persist events ...
""" """
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock: with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1) next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n self._current_max += n
@ -139,24 +128,17 @@ class StreamIdGenerator(object):
for next_id in next_ids: for next_id in next_ids:
self._unfinished_ids.remove(next_id) self._unfinished_ids.remove(next_id)
defer.returnValue(manager()) return manager()
@defer.inlineCallbacks
def get_max_token(self, store): def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted. equal to it have been successfully persisted.
""" """
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock: with self._lock:
if self._unfinished_ids: if self._unfinished_ids:
defer.returnValue(self._unfinished_ids[0] - 1) return self._unfinished_ids[0] - 1
defer.returnValue(self._current_max) return self._current_max
def _get_or_compute_current_max(self, txn): def _get_or_compute_current_max(self, txn):
with self._lock: with self._lock:

View file

@ -37,7 +37,7 @@ class LruCache(object):
""" """
def __init__(self, max_size, keylen=1, cache_type=dict): def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type() cache = cache_type()
self.size = 0 self.cache = cache # Used for introspection.
list_root = [] list_root = []
list_root[:] = [list_root, list_root, None, None] list_root[:] = [list_root, list_root, None, None]
@ -60,7 +60,6 @@ class LruCache(object):
prev_node[NEXT] = node prev_node[NEXT] = node
next_node[PREV] = node next_node[PREV] = node
cache[key] = node cache[key] = node
self.size += 1
def move_node_to_front(node): def move_node_to_front(node):
prev_node = node[PREV] prev_node = node[PREV]
@ -79,7 +78,6 @@ class LruCache(object):
next_node = node[NEXT] next_node = node[NEXT]
prev_node[NEXT] = next_node prev_node[NEXT] = next_node
next_node[PREV] = prev_node next_node[PREV] = prev_node
self.size -= 1
@synchronized @synchronized
def cache_get(key, default=None): def cache_get(key, default=None):
@ -98,7 +96,7 @@ class LruCache(object):
node[VALUE] = value node[VALUE] = value
else: else:
add_node(key, value) add_node(key, value)
if self.size > max_size: if len(cache) > max_size:
todelete = list_root[PREV] todelete = list_root[PREV]
delete_node(todelete) delete_node(todelete)
cache.pop(todelete[KEY], None) cache.pop(todelete[KEY], None)
@ -110,7 +108,7 @@ class LruCache(object):
return node[VALUE] return node[VALUE]
else: else:
add_node(key, value) add_node(key, value)
if self.size > max_size: if len(cache) > max_size:
todelete = list_root[PREV] todelete = list_root[PREV]
delete_node(todelete) delete_node(todelete)
cache.pop(todelete[KEY], None) cache.pop(todelete[KEY], None)
@ -145,7 +143,7 @@ class LruCache(object):
@synchronized @synchronized
def cache_len(): def cache_len():
return self.size return len(cache)
@synchronized @synchronized
def cache_contains(key): def cache_contains(key):

View file

@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.caches import cache_counter, caches_by_name
from blist import sorteddict
import logging
logger = logging.getLogger(__name__)
class StreamChangeCache(object):
"""Keeps track of the stream positions of the latest change in a set of entities.
Typically the entity will be a room or user id.
Given a list of entities and a stream position, it will give a subset of
entities that may have changed since that position. If position key is too
old then the cache will simply return all given entities.
"""
def __init__(self, name, current_stream_pos, max_size=10000):
self._max_size = max_size
self._entity_to_key = {}
self._cache = sorteddict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
caches_by_name[self.name] = self._cache
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
"""
assert type(stream_pos) is int
if stream_pos < self._earliest_known_stream_pos:
cache_counter.inc_misses(self.name)
return True
if stream_pos == self._earliest_known_stream_pos:
# If the same as the earliest key, assume nothing has changed.
cache_counter.inc_hits(self.name)
return False
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
cache_counter.inc_misses(self.name)
return True
if stream_pos < latest_entity_change_pos:
cache_counter.inc_misses(self.name)
return True
cache_counter.inc_hits(self.name)
return False
def get_entities_changed(self, entities, stream_pos):
"""Returns subset of entities that have had new things since the
given position. If the position is too old it will just return the given list.
"""
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
keys = self._cache.keys()
i = keys.bisect_right(stream_pos)
result = set(
self._cache[k] for k in keys[i:]
).intersection(entities)
cache_counter.inc_hits(self.name)
else:
result = entities
cache_counter.inc_misses(self.name)
return result
def entity_has_changed(self, entity, stream_pos):
"""Informs the cache that the entity has been changed at the given
position.
"""
assert type(stream_pos) is int
if stream_pos > self._earliest_known_stream_pos:
old_pos = self._entity_to_key.get(entity, None)
if old_pos:
stream_pos = max(stream_pos, old_pos)
self._cache.pop(old_pos, None)
self._cache[stream_pos] = entity
self._entity_to_key[entity] = stream_pos
while len(self._cache) > self._max_size:
k, r = self._cache.popitem()
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
self._entity_to_key.pop(r, None)

View file

@ -8,6 +8,7 @@ class TreeCache(object):
Keys must be tuples. Keys must be tuples.
""" """
def __init__(self): def __init__(self):
self.size = 0
self.root = {} self.root = {}
def __setitem__(self, key, value): def __setitem__(self, key, value):
@ -20,7 +21,8 @@ class TreeCache(object):
node = self.root node = self.root
for k in key[:-1]: for k in key[:-1]:
node = node.setdefault(k, {}) node = node.setdefault(k, {})
node[key[-1]] = value node[key[-1]] = _Entry(value)
self.size += 1
def get(self, key, default=None): def get(self, key, default=None):
node = self.root node = self.root
@ -28,9 +30,10 @@ class TreeCache(object):
node = node.get(k, None) node = node.get(k, None)
if node is None: if node is None:
return default return default
return node.get(key[-1], default) return node.get(key[-1], _Entry(default)).value
def clear(self): def clear(self):
self.size = 0
self.root = {} self.root = {}
def pop(self, key, default=None): def pop(self, key, default=None):
@ -57,4 +60,33 @@ class TreeCache(object):
break break
node_and_keys[i+1][0].pop(k) node_and_keys[i+1][0].pop(k)
popped, cnt = _strip_and_count_entires(popped)
self.size -= cnt
return popped return popped
def __len__(self):
return self.size
class _Entry(object):
__slots__ = ["value"]
def __init__(self, value):
self.value = value
def _strip_and_count_entires(d):
"""Takes an _Entry or dict with leaves of _Entry's, and either returns the
value or a dictionary with _Entry's replaced by their values.
Also returns the count of _Entry's
"""
if isinstance(d, dict):
cnt = 0
for key, value in d.items():
v, n = _strip_and_count_entires(value)
d[key] = v
cnt += n
return d, cnt
else:
return d.value, 1

View file

@ -382,19 +382,20 @@ class FilteringTestCase(unittest.TestCase):
"types": ["m.*"] "types": ["m.*"]
} }
} }
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart + "2",
user_filter=user_filter_json, user_filter=user_filter_json,
) )
event = MockEvent( event = MockEvent(
event_id="$asdasd:localhost",
sender="@foo:bar", sender="@foo:bar",
type="custom.avatar.3d.crazy", type="custom.avatar.3d.crazy",
) )
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart + "2",
filter_id=filter_id, filter_id=filter_id,
) )

View file

@ -1,303 +0,0 @@
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# trial imports
from twisted.internet import defer
from tests import unittest
# python imports
from mock import Mock, ANY
from ..utils import MockHttpResource, MockClock, setup_test_homeserver
from synapse.federation import initialize_http_replication
from synapse.events import FrozenEvent
def make_pdu(prev_pdus=[], **kwargs):
"""Provide some default fields for making a PduTuple."""
pdu_fields = {
"state_key": None,
"prev_events": prev_pdus,
}
pdu_fields.update(kwargs)
return FrozenEvent(pdu_fields)
class FederationTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource()
self.mock_http_client = Mock(spec=[
"get_json",
"put_json",
])
self.mock_persistence = Mock(spec=[
"prep_send_transaction",
"delivered_txn",
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
"get_auth_chain",
])
self.mock_persistence.get_received_txn_response.return_value = (
defer.succeed(None)
)
retry_timings_res = {
"destination": "",
"retry_last_ts": 0,
"retry_interval": 0,
}
self.mock_persistence.get_destination_retry_timings.return_value = (
defer.succeed(retry_timings_res)
)
self.mock_persistence.get_auth_chain.return_value = []
self.clock = MockClock()
hs = yield setup_test_homeserver(
resource_for_federation=self.mock_resource,
http_client=self.mock_http_client,
datastore=self.mock_persistence,
clock=self.clock,
keyring=Mock(),
)
self.federation = initialize_http_replication(hs)
self.distributor = hs.get_distributor()
@defer.inlineCallbacks
def test_get_state(self):
mock_handler = Mock(spec=[
"get_state_for_pdu",
])
self.federation.set_handler(mock_handler)
mock_handler.get_state_for_pdu.return_value = defer.succeed([])
# Empty context initially
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/state/my-context/",
None
)
self.assertEquals(200, code)
self.assertFalse(response["pdus"])
# Now lets give the context some state
mock_handler.get_state_for_pdu.return_value = (
defer.succeed([
make_pdu(
event_id="the-pdu-id",
origin="red",
user_id="@a:red",
room_id="my-context",
type="m.topic",
origin_server_ts=123456789000,
depth=1,
content={"topic": "The topic"},
state_key="",
power_level=1000,
prev_state="last-pdu-id",
),
])
)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/state/my-context/",
None
)
self.assertEquals(200, code)
self.assertEquals(1, len(response["pdus"]))
@defer.inlineCallbacks
def test_get_pdu(self):
mock_handler = Mock(spec=[
"get_persisted_pdu",
])
self.federation.set_handler(mock_handler)
mock_handler.get_persisted_pdu.return_value = (
defer.succeed(None)
)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/event/abc123def456/",
None
)
self.assertEquals(404, code)
# Now insert such a PDU
mock_handler.get_persisted_pdu.return_value = (
defer.succeed(
make_pdu(
event_id="abc123def456",
origin="red",
user_id="@a:red",
room_id="my-context",
type="m.text",
origin_server_ts=123456789001,
depth=1,
content={"text": "Here is the message"},
)
)
)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/event/abc123def456/",
None
)
self.assertEquals(200, code)
self.assertEquals(1, len(response["pdus"]))
self.assertEquals("m.text", response["pdus"][0]["type"])
@defer.inlineCallbacks
def test_send_pdu(self):
self.mock_http_client.put_json.return_value = defer.succeed(
(200, "OK")
)
pdu = make_pdu(
event_id="abc123def456",
origin="red",
user_id="@a:red",
room_id="my-context",
type="m.text",
origin_server_ts=123456789001,
depth=1,
content={"text": "Here is the message"},
)
yield self.federation.send_pdu(pdu, ["remote"])
self.mock_http_client.put_json.assert_called_with(
"remote",
path="/_matrix/federation/v1/send/1000000/",
data={
"origin_server_ts": 1000000,
"origin": "test",
"pdus": [
pdu.get_pdu_json(),
],
'pdu_failures': [],
},
json_data_callback=ANY,
long_retries=True,
)
@defer.inlineCallbacks
def test_send_edu(self):
self.mock_http_client.put_json.return_value = defer.succeed(
(200, "OK")
)
yield self.federation.send_edu(
destination="remote",
edu_type="m.test",
content={"testing": "content here"},
)
# MockClock ensures we can guess these timestamps
self.mock_http_client.put_json.assert_called_with(
"remote",
path="/_matrix/federation/v1/send/1000000/",
data={
"origin": "test",
"origin_server_ts": 1000000,
"pdus": [],
"edus": [
{
"edu_type": "m.test",
"content": {"testing": "content here"},
}
],
'pdu_failures': [],
},
json_data_callback=ANY,
long_retries=True,
)
@defer.inlineCallbacks
def test_recv_edu(self):
recv_observer = Mock()
recv_observer.return_value = defer.succeed(())
self.federation.register_edu_handler("m.test", recv_observer)
yield self.mock_resource.trigger(
"PUT",
"/_matrix/federation/v1/send/1001000/",
"""{
"origin": "remote",
"origin_server_ts": 1001000,
"pdus": [],
"edus": [
{
"origin": "remote",
"destination": "test",
"edu_type": "m.test",
"content": {"testing": "reply here"}
}
]
}"""
)
recv_observer.assert_called_with(
"remote", {"testing": "reply here"}
)
@defer.inlineCallbacks
def test_send_query(self):
self.mock_http_client.get_json.return_value = defer.succeed(
{"your": "response"}
)
response = yield self.federation.make_query(
destination="remote",
query_type="a-question",
args={"one": "1", "two": "2"},
)
self.assertEquals({"your": "response"}, response)
self.mock_http_client.get_json.assert_called_with(
destination="remote",
path="/_matrix/federation/v1/query/a-question",
args={"one": "1", "two": "2"},
retry_on_dns_fail=True,
)
@defer.inlineCallbacks
def test_recv_query(self):
recv_handler = Mock()
recv_handler.return_value = defer.succeed({"another": "response"})
self.federation.register_query_handler("a-question", recv_handler)
code, response = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/query/a-question?three=3&four=4",
None
)
self.assertEquals(200, code)
self.assertEquals({"another": "response"}, response)
recv_handler.assert_called_with(
{"three": "3", "four": "4"}
)

View file

@ -280,6 +280,15 @@ class PresenceEventStreamTestCase(unittest.TestCase):
} }
EventSources.SOURCE_TYPES["presence"] = PresenceEventSource EventSources.SOURCE_TYPES["presence"] = PresenceEventSource
clock = Mock(spec=[
"call_later",
"cancel_call_later",
"time_msec",
"looping_call",
])
clock.time_msec.return_value = 1000000
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
http_client=None, http_client=None,
resource_for_client=self.mock_resource, resource_for_client=self.mock_resource,
@ -289,16 +298,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
"get_presence_list", "get_presence_list",
"get_rooms_for_user", "get_rooms_for_user",
]), ]),
clock=Mock(spec=[ clock=clock,
"call_later",
"cancel_call_later",
"time_msec",
"looping_call",
]),
) )
hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None, allow_guest=False): def _get_user_by_req(req=None, allow_guest=False):
return Requester(UserID.from_string(myid), "", False) return Requester(UserID.from_string(myid), "", False)

View file

@ -1045,8 +1045,13 @@ class RoomMessageListTestCase(RestTestCase):
self.assertTrue("end" in response) self.assertTrue("end" in response)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stream_token_is_rejected(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=s0_0_0_0" % "/rooms/%s/messages?access_token=x&from=%s" %
self.room_id) (self.room_id, token))
self.assertEquals(400, code) self.assertEquals(200, code)
self.assertTrue("start" in response)
self.assertEquals(token, response['start'])
self.assertTrue("chunk" in response)
self.assertTrue("end" in response)

View file

@ -439,7 +439,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(suffix="2") f2 = self._write_config(suffix="2")
config = Mock(app_service_config_files=[f1, f2]) config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config) hs = yield setup_test_homeserver(config=config, datastore=Mock())
ApplicationServiceStore(hs) ApplicationServiceStore(hs)
@ -449,7 +449,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(id="id", suffix="2") f2 = self._write_config(id="id", suffix="2")
config = Mock(app_service_config_files=[f1, f2]) config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config) hs = yield setup_test_homeserver(config=config, datastore=Mock())
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs) ApplicationServiceStore(hs)
@ -465,7 +465,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(as_token="as_token", suffix="2") f2 = self._write_config(as_token="as_token", suffix="2")
config = Mock(app_service_config_files=[f1, f2]) config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config) hs = yield setup_test_homeserver(config=config, datastore=Mock())
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs) ApplicationServiceStore(hs)

View file

@ -18,7 +18,6 @@ from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage.registration import RegistrationStore
from synapse.util import stringutils from synapse.util import stringutils
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -31,7 +30,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
hs = yield setup_test_homeserver() hs = yield setup_test_homeserver()
self.db_pool = hs.get_db_pool() self.db_pool = hs.get_db_pool()
self.store = RegistrationStore(hs) self.store = hs.get_datastore()
self.user_id = "@my-user:test" self.user_id = "@my-user:test"
self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",

View file

@ -16,10 +16,10 @@
from tests import unittest from tests import unittest
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.server import BaseHomeServer from synapse.server import HomeServer
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias
mock_homeserver = BaseHomeServer(hostname="my.domain") mock_homeserver = HomeServer(hostname="my.domain")
class UserIDTestCase(unittest.TestCase): class UserIDTestCase(unittest.TestCase):
@ -34,7 +34,6 @@ class UserIDTestCase(unittest.TestCase):
with self.assertRaises(SynapseError): with self.assertRaises(SynapseError):
UserID.from_string("") UserID.from_string("")
def test_build(self): def test_build(self):
user = UserID("5678efgh", "my.domain") user = UserID("5678efgh", "my.domain")

View file

@ -19,6 +19,7 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
class LruCacheTestCase(unittest.TestCase): class LruCacheTestCase(unittest.TestCase):
def test_get_set(self): def test_get_set(self):
@ -72,3 +73,9 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("vehicles", "car")), "vroom") self.assertEquals(cache.get(("vehicles", "car")), "vroom")
self.assertEquals(cache.get(("vehicles", "train")), "chuff") self.assertEquals(cache.get(("vehicles", "train")), "chuff")
# Man from del_multi say "Yes". # Man from del_multi say "Yes".
def test_clear(self):
cache = LruCache(1)
cache["key"] = 1
cache.clear()
self.assertEquals(len(cache), 0)

View file

@ -25,6 +25,7 @@ class TreeCacheTestCase(unittest.TestCase):
cache[("b",)] = "B" cache[("b",)] = "B"
self.assertEquals(cache.get(("a",)), "A") self.assertEquals(cache.get(("a",)), "A")
self.assertEquals(cache.get(("b",)), "B") self.assertEquals(cache.get(("b",)), "B")
self.assertEquals(len(cache), 2)
def test_pop_onelevel(self): def test_pop_onelevel(self):
cache = TreeCache() cache = TreeCache()
@ -33,6 +34,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.pop(("a",)), "A") self.assertEquals(cache.pop(("a",)), "A")
self.assertEquals(cache.pop(("a",)), None) self.assertEquals(cache.pop(("a",)), None)
self.assertEquals(cache.get(("b",)), "B") self.assertEquals(cache.get(("b",)), "B")
self.assertEquals(len(cache), 1)
def test_get_set_twolevel(self): def test_get_set_twolevel(self):
cache = TreeCache() cache = TreeCache()
@ -42,6 +44,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "a")), "AA") self.assertEquals(cache.get(("a", "a")), "AA")
self.assertEquals(cache.get(("a", "b")), "AB") self.assertEquals(cache.get(("a", "b")), "AB")
self.assertEquals(cache.get(("b", "a")), "BA") self.assertEquals(cache.get(("b", "a")), "BA")
self.assertEquals(len(cache), 3)
def test_pop_twolevel(self): def test_pop_twolevel(self):
cache = TreeCache() cache = TreeCache()
@ -53,6 +56,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "b")), "AB") self.assertEquals(cache.get(("a", "b")), "AB")
self.assertEquals(cache.pop(("b", "a")), "BA") self.assertEquals(cache.pop(("b", "a")), "BA")
self.assertEquals(cache.pop(("b", "a")), None) self.assertEquals(cache.pop(("b", "a")), None)
self.assertEquals(len(cache), 1)
def test_pop_mixedlevel(self): def test_pop_mixedlevel(self):
cache = TreeCache() cache = TreeCache()
@ -64,3 +68,11 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "a")), None) self.assertEquals(cache.get(("a", "a")), None)
self.assertEquals(cache.get(("a", "b")), None) self.assertEquals(cache.get(("a", "b")), None)
self.assertEquals(cache.get(("b", "a")), "BA") self.assertEquals(cache.get(("b", "a")), "BA")
self.assertEquals(len(cache), 1)
def test_clear(self):
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
cache.clear()
self.assertEquals(len(cache), 0)

View file

@ -19,6 +19,8 @@ from synapse.api.constants import EventTypes
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.federation.transport import server
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -58,8 +60,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
name, db_pool=db_pool, config=config, name, db_pool=db_pool, config=config,
version_string="Synapse/tests", version_string="Synapse/tests",
database_engine=create_engine("sqlite3"), database_engine=create_engine("sqlite3"),
get_db_conn=db_pool.get_db_conn,
**kargs **kargs
) )
hs.setup()
else: else:
hs = HomeServer( hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config, name, db_pool=None, datastore=datastore, config=config,
@ -80,6 +84,22 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers) hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
fed = kargs.get("resource_for_federation", None)
if fed:
server.register_servlets(
hs,
resource=fed,
authenticator=server.Authenticator(hs),
ratelimiter=FederationRateLimiter(
hs.get_clock(),
window_size=hs.config.federation_rc_window_size,
sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent
),
)
defer.returnValue(hs) defer.returnValue(hs)
@ -262,6 +282,12 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
lambda conn: prepare_database(conn, engine) lambda conn: prepare_database(conn, engine)
) )
def get_db_conn(self):
conn = self.connect()
engine = create_engine("sqlite3")
prepare_database(conn, engine)
return conn
class MemoryDataStore(object): class MemoryDataStore(object):