0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 18:13:59 +01:00

Merge branch 'release-v0.9.3' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2015-07-01 15:12:57 +01:00
commit 67362a9a03
25 changed files with 790 additions and 484 deletions

View file

@ -38,3 +38,7 @@ Brabo <brabo at riseup.net>
Ivan Shapovalov <intelfx100 at gmail.com> Ivan Shapovalov <intelfx100 at gmail.com>
* contrib/systemd: a sample systemd unit file and a logger configuration * contrib/systemd: a sample systemd unit file and a logger configuration
Eric Myhre <hash at exultant.us>
* Fix bug where ``media_store_path`` config option was ignored by v0 content
repository API.

View file

@ -1,3 +1,26 @@
Changes in synapse v0.9.3 (2015-07-01)
======================================
No changes from v0.9.3 Release Candidate 1.
Changes in synapse v0.9.3-rc1 (2015-06-23)
==========================================
General:
* Fix a memory leak in the notifier. (SYN-412)
* Improve performance of room initial sync. (SYN-418)
* General improvements to logging.
* Remove ``access_token`` query params from ``INFO`` level logging.
Configuration:
* Add support for specifying and configuring multiple listeners. (SYN-389)
Application services:
* Fix bug where synapse failed to send user queries to application services.
Changes in synapse v0.9.2-r2 (2015-06-15) Changes in synapse v0.9.2-r2 (2015-06-15)
========================================= =========================================

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.9.2-r2" __version__ = "0.9.3"

View file

@ -370,6 +370,8 @@ class Auth(object):
user_agent=user_agent user_agent=user_agent
) )
request.authenticated_entity = user.to_string()
defer.returnValue((user, ClientInfo(device_id, token_id))) defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError: except KeyError:
raise AuthError( raise AuthError(

View file

@ -34,8 +34,7 @@ from twisted.application import service
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory from twisted.web.server import Site, GzipEncoderFactory, Request
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
@ -61,11 +60,13 @@ import twisted.manhole.telnet
import synapse import synapse
import contextlib
import logging import logging
import os import os
import re import re
import resource import resource
import subprocess import subprocess
import time
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
@ -87,16 +88,10 @@ class SynapseHomeServer(HomeServer):
return MatrixFederationHttpClient(self) return MatrixFederationHttpClient(self)
def build_resource_for_client(self): def build_resource_for_client(self):
res = ClientV1RestResource(self) return ClientV1RestResource(self)
if self.config.gzip_responses:
res = gz_wrap(res)
return res
def build_resource_for_client_v2_alpha(self): def build_resource_for_client_v2_alpha(self):
res = ClientV2AlphaRestResource(self) return ClientV2AlphaRestResource(self)
if self.config.gzip_responses:
res = gz_wrap(res)
return res
def build_resource_for_federation(self): def build_resource_for_federation(self):
return JsonResource(self) return JsonResource(self)
@ -119,7 +114,7 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_content_repo(self): def build_resource_for_content_repo(self):
return ContentRepoResource( return ContentRepoResource(
self, self.upload_dir, self.auth, self.content_addr self, self.config.uploads_path, self.auth, self.content_addr
) )
def build_resource_for_media_repository(self): def build_resource_for_media_repository(self):
@ -145,152 +140,105 @@ class SynapseHomeServer(HomeServer):
**self.db_config.get("args", {}) **self.db_config.get("args", {})
) )
def create_resource_tree(self, redirect_root_to_web_client): def _listener_http(self, config, listener_config):
"""Create the resource tree for this Home Server. port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
tls = listener_config.get("tls", False)
site_tag = listener_config.get("tag", port)
This in unduly complicated because Twisted does not support putting if tls and config.no_tls:
child resources more than 1 level deep at a time. return
Args:
web_client (bool): True to enable the web client.
redirect_root_to_web_client (bool): True to redirect '/' to the
location of the web client. This does nothing if web_client is not
True.
"""
config = self.get_config()
web_client = config.web_client
# list containing (path_str, Resource) e.g:
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
desired_tree = [
(CLIENT_PREFIX, self.get_resource_for_client()),
(CLIENT_V2_ALPHA_PREFIX, self.get_resource_for_client_v2_alpha()),
(FEDERATION_PREFIX, self.get_resource_for_federation()),
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
(SERVER_KEY_V2_PREFIX, self.get_resource_for_server_key_v2()),
(MEDIA_PREFIX, self.get_resource_for_media_repository()),
(STATIC_PREFIX, self.get_resource_for_static_content()),
]
if web_client:
logger.info("Adding the web client.")
desired_tree.append((WEB_CLIENT_PREFIX,
self.get_resource_for_web_client()))
if web_client and redirect_root_to_web_client:
self.root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
self.root_resource = Resource()
metrics_resource = self.get_resource_for_metrics() metrics_resource = self.get_resource_for_metrics()
if config.metrics_port is None and metrics_resource is not None:
desired_tree.append((METRICS_PREFIX, metrics_resource))
# ideally we'd just use getChild and putChild but getChild doesn't work resources = {}
# unless you give it a Request object IN ADDITION to the name :/ So for res in listener_config["resources"]:
# instead, we'll store a copy of this mapping so we can actually add for name in res["names"]:
# extra resources to existing nodes. See self._resource_id for the key. if name == "client":
resource_mappings = {} if res["compress"]:
for full_path, res in desired_tree: client_v1 = gz_wrap(self.get_resource_for_client())
logger.info("Attaching %s to path %s", res, full_path) client_v2 = gz_wrap(self.get_resource_for_client_v2_alpha())
last_resource = self.root_resource else:
for path_seg in full_path.split('/')[1:-1]: client_v1 = self.get_resource_for_client()
if path_seg not in last_resource.listNames(): client_v2 = self.get_resource_for_client_v2_alpha()
# resource doesn't exist, so make a "dummy resource"
child_resource = Resource()
last_resource.putChild(path_seg, child_resource)
res_id = self._resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
last_resource = child_resource
else:
# we have an existing Resource, use that instead.
res_id = self._resource_id(last_resource, path_seg)
last_resource = resource_mappings[res_id]
# =========================== resources.update({
# now attach the actual desired resource CLIENT_PREFIX: client_v1,
last_path_seg = full_path.split('/')[-1] CLIENT_V2_ALPHA_PREFIX: client_v2,
})
# if there is already a resource here, thieve its children and if name == "federation":
# replace it resources.update({
res_id = self._resource_id(last_resource, last_path_seg) FEDERATION_PREFIX: self.get_resource_for_federation(),
if res_id in resource_mappings: })
# there is a dummy resource at this path already, which needs
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
child_res_id = self._resource_id(existing_dummy_resource,
child_name)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place if name in ["static", "client"]:
last_resource.putChild(last_path_seg, res) resources.update({
res_id = self._resource_id(last_resource, last_path_seg) STATIC_PREFIX: self.get_resource_for_static_content(),
resource_mappings[res_id] = res })
return self.root_resource if name in ["media", "federation", "client"]:
resources.update({
MEDIA_PREFIX: self.get_resource_for_media_repository(),
CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(),
})
def _resource_id(self, resource, path_seg): if name in ["keys", "federation"]:
"""Construct an arbitrary resource ID so you can retrieve the mapping resources.update({
later. SERVER_KEY_PREFIX: self.get_resource_for_server_key(),
SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(),
})
If you want to represent resource A putChild resource B with path C, if name == "webclient":
the mapping should looks like _resource_id(A,C) = B. resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client()
Args: if name == "metrics" and metrics_resource:
resource (Resource): The *parent* Resource resources[METRICS_PREFIX] = metrics_resource
path_seg (str): The name of the child Resource to be attached.
Returns: root_resource = create_resource_tree(resources)
str: A unique string which can be a key to the child Resource. if tls:
""" reactor.listenSSL(
return "%s-%s" % (resource, path_seg) port,
SynapseSite(
"synapse.access.https.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
self.tls_context_factory,
interface=bind_address
)
else:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse now listening on port %d", port)
def start_listening(self): def start_listening(self):
config = self.get_config() config = self.get_config()
if not config.no_tls and config.bind_port is not None: for listener in config.listeners:
reactor.listenSSL( if listener["type"] == "http":
config.bind_port, self._listener_http(config, listener)
SynapseSite( elif listener["type"] == "manhole":
"synapse.access.https", f = twisted.manhole.telnet.ShellFactory()
config, f.username = "matrix"
self.root_resource, f.password = "rabbithole"
), f.namespace['hs'] = self
self.tls_context_factory, reactor.listenTCP(
interface=config.bind_host listener["port"],
) f,
logger.info("Synapse now listening on port %d", config.bind_port) interface=listener.get("bind_address", '127.0.0.1')
)
if config.unsecure_port is not None: else:
reactor.listenTCP( logger.warn("Unrecognized listener type: %s", listener["type"])
config.unsecure_port,
SynapseSite(
"synapse.access.http",
config,
self.root_resource,
),
interface=config.bind_host
)
logger.info("Synapse now listening on port %d", config.unsecure_port)
metrics_resource = self.get_resource_for_metrics()
if metrics_resource and config.metrics_port is not None:
reactor.listenTCP(
config.metrics_port,
SynapseSite(
"synapse.access.metrics",
config,
metrics_resource,
),
interface=config.metrics_bind_host,
)
logger.info(
"Metrics now running on %s port %d",
config.metrics_bind_host, config.metrics_port,
)
def run_startup_checks(self, db_conn, database_engine): def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain( all_users_native = are_all_users_on_domain(
@ -425,11 +373,6 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts events.USE_FROZEN_DICTS = config.use_frozen_dicts
if re.search(":[0-9]+$", config.server_name):
domain_with_port = config.server_name
else:
domain_with_port = "%s:%s" % (config.server_name, config.bind_port)
tls_context_factory = context_factory.ServerContextFactory(config) tls_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config.database_config["name"]) database_engine = create_engine(config.database_config["name"])
@ -437,8 +380,6 @@ def setup(config_options):
hs = SynapseHomeServer( hs = SynapseHomeServer(
config.server_name, config.server_name,
domain_with_port=domain_with_port,
upload_dir=os.path.abspath("uploads"),
db_config=config.database_config, db_config=config.database_config,
tls_context_factory=tls_context_factory, tls_context_factory=tls_context_factory,
config=config, config=config,
@ -447,10 +388,6 @@ def setup(config_options):
database_engine=database_engine, database_engine=database_engine,
) )
hs.create_resource_tree(
redirect_root_to_web_client=True,
)
logger.info("Preparing database: %r...", config.database_config) logger.info("Preparing database: %r...", config.database_config)
try: try:
@ -475,13 +412,6 @@ def setup(config_options):
logger.info("Database prepared in %r.", config.database_config) logger.info("Database prepared in %r.", config.database_config)
if config.manhole:
f = twisted.manhole.telnet.ShellFactory()
f.username = "matrix"
f.password = "rabbithole"
f.namespace['hs'] = hs
reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
hs.start_listening() hs.start_listening()
hs.get_pusherpool().start() hs.get_pusherpool().start()
@ -507,22 +437,194 @@ class SynapseService(service.Service):
return self._port.stopListening() return self._port.stopListening()
class SynapseRequest(Request):
def __init__(self, site, *args, **kw):
Request.__init__(self, *args, **kw)
self.site = site
self.authenticated_entity = None
self.start_time = 0
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
self.__class__.__name__,
id(self),
self.method,
self.get_redacted_uri(),
self.clientproto,
self.site.site_tag,
)
def get_redacted_uri(self):
return re.sub(
r'(\?.*access_token=)[^&]*(.*)$',
r'\1<redacted>\2',
self.uri
)
def get_user_agent(self):
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
def started_processing(self):
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.method,
self.get_redacted_uri()
)
self.start_time = int(time.time() * 1000)
def finished_processing(self):
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %dms %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
self.authenticated_entity,
int(time.time() * 1000) - self.start_time,
self.sentLength,
self.code,
self.method,
self.get_redacted_uri(),
self.clientproto,
self.get_user_agent(),
)
@contextlib.contextmanager
def processing(self):
self.started_processing()
yield
self.finished_processing()
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
SynapseRequest.__init__(self, *args, **kw)
"""
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
class SynapseRequestFactory(object):
def __init__(self, site, x_forwarded_for):
self.site = site
self.x_forwarded_for = x_forwarded_for
def __call__(self, *args, **kwargs):
if self.x_forwarded_for:
return XForwardedForRequest(self.site, *args, **kwargs)
else:
return SynapseRequest(self.site, *args, **kwargs)
class SynapseSite(Site): class SynapseSite(Site):
""" """
Subclass of a twisted http Site that does access logging with python's Subclass of a twisted http Site that does access logging with python's
standard logging standard logging
""" """
def __init__(self, logger_name, config, resource, *args, **kwargs): def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs) Site.__init__(self, resource, *args, **kwargs)
if config.captcha_ip_origin_is_x_forwarded:
self._log_formatter = proxiedLogFormatter self.site_tag = site_tag
else:
self._log_formatter = combinedLogFormatter proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name) self.access_logger = logging.getLogger(logger_name)
def log(self, request): def log(self, request):
line = self._log_formatter(self._logDateTime, request) pass
self.access_logger.info(line)
def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
"""Create the resource tree for this Home Server.
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
Args:
web_client (bool): True to enable the web client.
redirect_root_to_web_client (bool): True to redirect '/' to the
location of the web client. This does nothing if web_client is not
True.
"""
if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
root_resource = Resource()
# ideally we'd just use getChild and putChild but getChild doesn't work
# unless you give it a Request object IN ADDITION to the name :/ So
# instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = Resource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
last_resource = child_resource
else:
# we have an existing Resource, use that instead.
res_id = _resource_id(last_resource, path_seg)
last_resource = resource_mappings[res_id]
# ===========================
# now attach the actual desired resource
last_path_seg = full_path.split('/')[-1]
# if there is already a resource here, thieve its children and
# replace it
res_id = _resource_id(last_resource, last_path_seg)
if res_id in resource_mappings:
# there is a dummy resource at this path already, which needs
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
child_res_id = _resource_id(
existing_dummy_resource, child_name
)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place
last_resource.putChild(last_path_seg, res)
res_id = _resource_id(last_resource, last_path_seg)
resource_mappings[res_id] = res
return root_resource
def _resource_id(resource, path_seg):
"""Construct an arbitrary resource ID so you can retrieve the mapping
later.
If you want to represent resource A putChild resource B with path C,
the mapping should looks like _resource_id(A,C) = B.
Args:
resource (Resource): The *parent* Resource
path_seg (str): The name of the child Resource to be attached.
Returns:
str: A unique string which can be a key to the child Resource.
"""
return "%s-%s" % (resource, path_seg)
def run(hs): def run(hs):

View file

@ -148,7 +148,7 @@ class Config(object):
if not config_args.config_path: if not config_args.config_path:
config_parser.error( config_parser.error(
"Must supply a config file.\nA config file can be automatically" "Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -h SERVER_NAME" " generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )
@ -209,7 +209,7 @@ class Config(object):
if not config_args.config_path: if not config_args.config_path:
config_parser.error( config_parser.error(
"Must supply a config file.\nA config file can be automatically" "Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -h SERVER_NAME" " generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )

View file

@ -21,10 +21,6 @@ class CaptchaConfig(Config):
self.recaptcha_private_key = config["recaptcha_private_key"] self.recaptcha_private_key = config["recaptcha_private_key"]
self.recaptcha_public_key = config["recaptcha_public_key"] self.recaptcha_public_key = config["recaptcha_public_key"]
self.enable_registration_captcha = config["enable_registration_captcha"] self.enable_registration_captcha = config["enable_registration_captcha"]
# XXX: This is used for more than just captcha
self.captcha_ip_origin_is_x_forwarded = (
config["captcha_ip_origin_is_x_forwarded"]
)
self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"] self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
@ -43,10 +39,6 @@ class CaptchaConfig(Config):
# public/private key. # public/private key.
enable_registration_captcha: False enable_registration_captcha: False
# When checking captchas, use the X-Forwarded-For (XFF) header
# as the client IP and not the actual client IP.
captcha_ip_origin_is_x_forwarded: False
# A secret key used to bypass the captcha test entirely. # A secret key used to bypass the captcha test entirely.
#captcha_bypass_secret: "YOUR_SECRET_HERE" #captcha_bypass_secret: "YOUR_SECRET_HERE"

View file

@ -28,10 +28,4 @@ class MetricsConfig(Config):
# Enable collection and rendering of performance metrics # Enable collection and rendering of performance metrics
enable_metrics: False enable_metrics: False
# Separate port to accept metrics requests on
# metrics_port: 8081
# Which host to bind the metric listener to
# metrics_bind_host: 127.0.0.1
""" """

View file

@ -21,13 +21,18 @@ class ContentRepositoryConfig(Config):
self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.media_store_path = self.ensure_directory(config["media_store_path"]) self.media_store_path = self.ensure_directory(config["media_store_path"])
self.uploads_path = self.ensure_directory(config["uploads_path"])
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name):
media_store = self.default_path("media_store") media_store = self.default_path("media_store")
uploads_path = self.default_path("uploads")
return """ return """
# Directory where uploaded images and attachments are stored. # Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s" media_store_path: "%(media_store)s"
# Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s"
# The largest allowed upload size in bytes # The largest allowed upload size in bytes
max_upload_size: "10M" max_upload_size: "10M"

View file

@ -20,26 +20,97 @@ class ServerConfig(Config):
def read_config(self, config): def read_config(self, config):
self.server_name = config["server_name"] self.server_name = config["server_name"]
self.bind_port = config["bind_port"]
self.bind_host = config["bind_host"]
self.unsecure_port = config["unsecure_port"]
self.manhole = config.get("manhole")
self.pid_file = self.abspath(config.get("pid_file")) self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"] self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"] self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize") self.daemonize = config.get("daemonize")
self.use_frozen_dicts = config.get("use_frozen_dicts", True) self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.gzip_responses = config["gzip_responses"]
self.listeners = config.get("listeners", [])
bind_port = config.get("bind_port")
if bind_port:
self.listeners = []
bind_host = config.get("bind_host", "")
gzip_responses = config.get("gzip_responses", True)
names = ["client", "webclient"] if self.web_client else ["client"]
self.listeners.append({
"port": bind_port,
"bind_address": bind_host,
"tls": True,
"type": "http",
"resources": [
{
"names": names,
"compress": gzip_responses,
},
{
"names": ["federation"],
"compress": False,
}
]
})
unsecure_port = config.get("unsecure_port", bind_port - 400)
if unsecure_port:
self.listeners.append({
"port": unsecure_port,
"bind_address": bind_host,
"tls": False,
"type": "http",
"resources": [
{
"names": names,
"compress": gzip_responses,
},
{
"names": ["federation"],
"compress": False,
}
]
})
manhole = config.get("manhole")
if manhole:
self.listeners.append({
"port": manhole,
"bind_address": "127.0.0.1",
"type": "manhole",
})
metrics_port = config.get("metrics_port")
if metrics_port:
self.listeners.append({
"port": metrics_port,
"bind_address": config.get("metrics_bind_host", "127.0.0.1"),
"tls": False,
"type": "http",
"resources": [
{
"names": ["metrics"],
"compress": False,
},
]
})
# Attempt to guess the content_addr for the v0 content repostitory # Attempt to guess the content_addr for the v0 content repostitory
content_addr = config.get("content_addr") content_addr = config.get("content_addr")
if not content_addr: if not content_addr:
for listener in self.listeners:
if listener["type"] == "http" and not listener.get("tls", False):
unsecure_port = listener["port"]
break
else:
raise RuntimeError("Could not determine 'content_addr'")
host = self.server_name host = self.server_name
if ':' not in host: if ':' not in host:
host = "%s:%d" % (host, self.unsecure_port) host = "%s:%d" % (host, unsecure_port)
else: else:
host = host.split(':')[0] host = host.split(':')[0]
host = "%s:%d" % (host, self.unsecure_port) host = "%s:%d" % (host, unsecure_port)
content_addr = "http://%s" % (host,) content_addr = "http://%s" % (host,)
self.content_addr = content_addr self.content_addr = content_addr
@ -61,18 +132,6 @@ class ServerConfig(Config):
# e.g. matrix.org, localhost:8080, etc. # e.g. matrix.org, localhost:8080, etc.
server_name: "%(server_name)s" server_name: "%(server_name)s"
# The port to listen for HTTPS requests on.
# For when matrix traffic is sent directly to synapse.
bind_port: %(bind_port)s
# The port to listen for HTTP requests on.
# For when matrix traffic passes through loadbalancer that unwraps TLS.
unsecure_port: %(unsecure_port)s
# Local interface to listen on.
# The empty string will cause synapse to listen on all interfaces.
bind_host: ""
# When running as a daemon, the file to store the pid in # When running as a daemon, the file to store the pid in
pid_file: %(pid_file)s pid_file: %(pid_file)s
@ -84,14 +143,64 @@ class ServerConfig(Config):
# hard limit. # hard limit.
soft_file_limit: 0 soft_file_limit: 0
# Turn on the twisted telnet manhole service on localhost on the given # List of ports that Synapse should listen on, their purpose and their
# port. # configuration.
#manhole: 9000 listeners:
# Main HTTPS listener
# For when matrix traffic is sent directly to synapse.
-
# The port to listen for HTTPS requests on.
port: %(bind_port)s
# Should synapse compress HTTP responses to clients that support it? # Local interface to listen on.
# This should be disabled if running synapse behind a load balancer # The empty string will cause synapse to listen on all interfaces.
# that can do automatic compression. bind_address: ''
gzip_responses: True
# This is a 'http' listener, allows us to specify 'resources'.
type: http
tls: true
# Use the X-Forwarded-For (XFF) header as the client IP and not the
# actual client IP.
x_forwarded: false
# List of HTTP resources to serve on this listener.
resources:
-
# List of resources to host on this listener.
names:
- client # The client-server APIs, both v1 and v2
- webclient # The bundled webclient.
# Should synapse compress HTTP responses to clients that support it?
# This should be disabled if running synapse behind a load balancer
# that can do automatic compression.
compress: true
- names: [federation] # Federation APIs
compress: false
# Unsecure HTTP listener,
# For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s
tls: false
bind_address: ''
type: http
x_forwarded: false
resources:
- names: [client, webclient]
compress: true
- names: [federation]
compress: false
# Turn on the twisted telnet manhole service on localhost on the given
# port.
# - port: 9000
# bind_address: 127.0.0.1
# type: manhole
""" % locals() """ % locals()
def read_arguments(self, args): def read_arguments(self, args):

View file

@ -94,6 +94,7 @@ class TransportLayerServer(object):
yield self.keyring.verify_json_for_server(origin, json_request) yield self.keyring.verify_json_for_server(origin, json_request)
logger.info("Request from %s", origin) logger.info("Request from %s", origin)
request.authenticated_entity = origin
defer.returnValue((origin, content)) defer.returnValue((origin, content))

View file

@ -177,7 +177,7 @@ class ApplicationServicesHandler(object):
return return
user_info = yield self.store.get_user_by_id(user_id) user_info = yield self.store.get_user_by_id(user_id)
if not user_info: if user_info:
defer.returnValue(False) defer.returnValue(False)
return return

View file

@ -380,15 +380,6 @@ class MessageHandler(BaseHandler):
if limit is None: if limit is None:
limit = 10 limit = 10
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=now_token.room_key,
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
room_members = [ room_members = [
m for m in current_state.values() m for m in current_state.values()
if m.type == EventTypes.Member if m.type == EventTypes.Member
@ -396,19 +387,38 @@ class MessageHandler(BaseHandler):
] ]
presence_handler = self.hs.get_handlers().presence_handler presence_handler = self.hs.get_handlers().presence_handler
presence = []
for m in room_members: @defer.inlineCallbacks
try: def get_presence():
member_presence = yield presence_handler.get_state( presence_defs = yield defer.DeferredList(
target_user=UserID.from_string(m.user_id), [
auth_user=auth_user, presence_handler.get_state(
as_event=True, target_user=UserID.from_string(m.user_id),
) auth_user=auth_user,
presence.append(member_presence) as_event=True,
except SynapseError: check_auth=False,
logger.exception( )
"Failed to get member presence of %r", m.user_id for m in room_members
],
consumeErrors=True,
)
defer.returnValue([p for success, p in presence_defs if success])
presence, (messages, token) = yield defer.gatherResults(
[
get_presence(),
self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=now_token.room_key,
) )
],
consumeErrors=True,
).addErrback(unwrapFirstError)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -191,24 +191,24 @@ class PresenceHandler(BaseHandler):
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state(self, target_user, auth_user, as_event=False): def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
visible = yield self.is_presence_visible( if check_auth:
observer_user=auth_user, visible = yield self.is_presence_visible(
observed_user=target_user observer_user=auth_user,
) observed_user=target_user
)
if not visible: if not visible:
raise SynapseError(404, "Presence information not visible") raise SynapseError(404, "Presence information not visible")
state = yield self.store.get_presence_state(target_user.localpart)
if "mtime" in state:
del state["mtime"]
state["presence"] = state.pop("state")
if target_user in self._user_cachemap: if target_user in self._user_cachemap:
cached_state = self._user_cachemap[target_user].get_state() state = self._user_cachemap[target_user].get_state()
if "last_active" in cached_state: else:
state["last_active"] = cached_state["last_active"] state = yield self.store.get_presence_state(target_user.localpart)
if "mtime" in state:
del state["mtime"]
state["presence"] = state.pop("state")
else: else:
# TODO(paul): Have remote server send us permissions set # TODO(paul): Have remote server send us permissions set
state = self._get_or_offline_usercache(target_user).get_state() state = self._get_or_offline_usercache(target_user).get_state()

View file

@ -61,21 +61,31 @@ class SimpleHttpClient(object):
self.agent = Agent(reactor, pool=pool) self.agent = Agent(reactor, pool=pool)
self.version_string = hs.version_string self.version_string = hs.version_string
def request(self, method, *args, **kwargs): def request(self, method, uri, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach # A small wrapper around self.agent.request() so we can easily attach
# counters to it # counters to it
outgoing_requests_counter.inc(method) outgoing_requests_counter.inc(method)
d = preserve_context_over_fn( d = preserve_context_over_fn(
self.agent.request, self.agent.request,
method, *args, **kwargs method, uri, *args, **kwargs
) )
logger.info("Sending request %s %s", method, uri)
def _cb(response): def _cb(response):
incoming_responses_counter.inc(method, response.code) incoming_responses_counter.inc(method, response.code)
logger.info(
"Received response to %s %s: %s",
method, uri, response.code
)
return response return response
def _eb(failure): def _eb(failure):
incoming_responses_counter.inc(method, "ERR") incoming_responses_counter.inc(method, "ERR")
logger.info(
"Error sending request to %s %s: %s %s",
method, uri, failure.type, failure.getErrorMessage()
)
return failure return failure
d.addCallbacks(_cb, _eb) d.addCallbacks(_cb, _eb)
@ -84,7 +94,9 @@ class SimpleHttpClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}): def post_urlencoded_get_json(self, uri, args={}):
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args) logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
response = yield self.request( response = yield self.request(
@ -97,7 +109,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(query_bytes)) bodyProducer=FileBodyProducer(StringIO(query_bytes))
) )
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -105,7 +117,7 @@ class SimpleHttpClient(object):
def post_json_get_json(self, uri, post_json): def post_json_get_json(self, uri, post_json):
json_str = encode_canonical_json(post_json) json_str = encode_canonical_json(post_json)
logger.info("HTTP POST %s -> %s", json_str, uri) logger.debug("HTTP POST %s -> %s", json_str, uri)
response = yield self.request( response = yield self.request(
"POST", "POST",
@ -116,7 +128,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(json_str)) bodyProducer=FileBodyProducer(StringIO(json_str))
) )
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -149,7 +161,7 @@ class SimpleHttpClient(object):
}) })
) )
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -192,7 +204,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(json_str)) bodyProducer=FileBodyProducer(StringIO(json_str))
) )
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -226,7 +238,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
) )
try: try:
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(body) defer.returnValue(body)
except PartialDownloadError as e: except PartialDownloadError as e:
# twisted dislikes google's response, no content length. # twisted dislikes google's response, no content length.

View file

@ -35,11 +35,13 @@ from syutil.crypto.jsonsign import sign_json
import simplejson as json import simplejson as json
import logging import logging
import sys
import urllib import urllib
import urlparse import urlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
outbound_logger = logging.getLogger("synapse.http.outbound")
metrics = synapse.metrics.get_metrics_for(__name__) metrics = synapse.metrics.get_metrics_for(__name__)
@ -109,6 +111,8 @@ class MatrixFederationHttpClient(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.version_string = hs.version_string self.version_string = hs.version_string
self._next_id = 1
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes, def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"", body_callback, headers_dict={}, param_bytes=b"",
@ -123,88 +127,98 @@ class MatrixFederationHttpClient(object):
("", "", path_bytes, param_bytes, query_bytes, "",) ("", "", path_bytes, param_bytes, query_bytes, "",)
) )
logger.info("Sending request to %s: %s %s", txn_id = "%s-O-%s" % (method, self._next_id)
destination, method, url_bytes) self._next_id = (self._next_id + 1) % (sys.maxint - 1)
logger.debug( outbound_logger.info(
"Types: %s", "{%s} [%s] Sending request: %s %s",
[ txn_id, destination, method, url_bytes
type(destination), type(method), type(path_bytes),
type(param_bytes),
type(query_bytes)
]
) )
# XXX: Would be much nicer to retry only at the transaction-layer # XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place) # (once we have reliable transactions in place)
retries_left = 5 retries_left = 5
endpoint = self._getEndpoint(reactor, destination) endpoint = preserve_context_over_fn(
self._getEndpoint, reactor, destination
while True:
producer = None
if body_callback:
producer = body_callback(method, url_bytes, headers_dict)
try:
request_deferred = preserve_context_over_fn(
self.agent.request,
destination,
endpoint,
method,
path_bytes,
param_bytes,
query_bytes,
Headers(headers_dict),
producer
)
response = yield self.clock.time_bound_deferred(
request_deferred,
time_out=timeout/1000. if timeout else 60,
)
logger.debug("Got response to %s", method)
break
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn(
"DNS Lookup failed to %s with %s",
destination,
e
)
raise
logger.warn(
"Sending request failed to %s: %s %s: %s - %s",
destination,
method,
url_bytes,
type(e).__name__,
_flatten_response_never_received(e),
)
if retries_left and not timeout:
yield sleep(2 ** (5 - retries_left))
retries_left -= 1
else:
raise
logger.info(
"Received response %d %s for %s: %s %s",
response.code,
response.phrase,
destination,
method,
url_bytes
) )
log_result = None
try:
while True:
producer = None
if body_callback:
producer = body_callback(method, url_bytes, headers_dict)
try:
def send_request():
request_deferred = self.agent.request(
destination,
endpoint,
method,
path_bytes,
param_bytes,
query_bytes,
Headers(headers_dict),
producer
)
return self.clock.time_bound_deferred(
request_deferred,
time_out=timeout/1000. if timeout else 60,
)
response = yield preserve_context_over_fn(
send_request,
)
log_result = "%d %s" % (response.code, response.phrase,)
break
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn(
"DNS Lookup failed to %s with %s",
destination,
e
)
log_result = "DNS Lookup failed to %s with %s" % (
destination, e
)
raise
logger.warn(
"{%s} Sending request failed to %s: %s %s: %s - %s",
txn_id,
destination,
method,
url_bytes,
type(e).__name__,
_flatten_response_never_received(e),
)
log_result = "%s - %s" % (
type(e).__name__, _flatten_response_never_received(e),
)
if retries_left and not timeout:
yield sleep(2 ** (5 - retries_left))
retries_left -= 1
else:
raise
finally:
outbound_logger.info(
"{%s} [%s] Result: %s",
txn_id,
destination,
log_result,
)
if 200 <= response.code < 300: if 200 <= response.code < 300:
pass pass
else: else:
# :'( # :'(
# Update transactions table? # Update transactions table?
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
raise HttpResponseException( raise HttpResponseException(
response.code, response.phrase, body response.code, response.phrase, body
) )
@ -284,10 +298,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json" "Content-Type not application/json"
) )
logger.debug("Getting resp body") body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -330,9 +341,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json" "Content-Type not application/json"
) )
logger.debug("Getting resp body") body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -390,9 +399,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json" "Content-Type not application/json"
) )
logger.debug("Getting resp body") body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -435,7 +442,10 @@ class MatrixFederationHttpClient(object):
headers = dict(response.headers.getAllRawHeaders()) headers = dict(response.headers.getAllRawHeaders())
try: try:
length = yield _readBodyToFile(response, output_stream, max_size) length = yield preserve_context_over_fn(
_readBodyToFile,
response, output_stream, max_size
)
except: except:
logger.exception("Failed to download body") logger.exception("Failed to download body")
raise raise

View file

@ -79,53 +79,39 @@ def request_handler(request_handler):
_next_request_id += 1 _next_request_id += 1
with LoggingContext(request_id) as request_context: with LoggingContext(request_id) as request_context:
request_context.request = request_id request_context.request = request_id
code = None with request.processing():
start = self.clock.time_msec() try:
try: d = request_handler(self, request)
logger.info( with PreserveLoggingContext():
"Received request: %s %s", yield d
request.method, request.path except CodeMessageException as e:
) code = e.code
d = request_handler(self, request) if isinstance(e, SynapseError):
with PreserveLoggingContext(): logger.info(
yield d "%s SynapseError: %s - %s", request, code, e.msg
code = request.code )
except CodeMessageException as e: else:
code = e.code logger.exception(e)
if isinstance(e, SynapseError): outgoing_responses_counter.inc(request.method, str(code))
logger.info( respond_with_json(
"%s SynapseError: %s - %s", request, code, e.msg request, code, cs_exception(e), send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
except:
logger.exception(
"Failed handle request %s.%s on %r: %r",
request_handler.__module__,
request_handler.__name__,
self,
request
)
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
) )
else:
logger.exception(e)
outgoing_responses_counter.inc(request.method, str(code))
respond_with_json(
request, code, cs_exception(e), send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
except:
code = 500
logger.exception(
"Failed handle request %s.%s on %r: %r",
request_handler.__module__,
request_handler.__name__,
self,
request
)
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
)
finally:
code = str(code) if code else "-"
end = self.clock.time_msec()
logger.info(
"Processed request: %dms %s %s %s",
end-start, code, request.method, request.path
)
return wrapped_request_handler return wrapped_request_handler

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor, ObservableDeferred
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
@ -45,21 +45,11 @@ class _NotificationListener(object):
The events stream handler will have yielded to the deferred, so to The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred. notify the handler it is sufficient to resolve the deferred.
""" """
__slots__ = ["deferred"]
def __init__(self, deferred): def __init__(self, deferred):
self.deferred = deferred self.deferred = deferred
def notified(self):
return self.deferred.called
def notify(self, token):
""" Inform whoever is listening about the new events.
"""
try:
self.deferred.callback(token)
except defer.AlreadyCalledError:
pass
class _NotifierUserStream(object): class _NotifierUserStream(object):
"""This represents a user connected to the event stream. """This represents a user connected to the event stream.
@ -75,11 +65,12 @@ class _NotifierUserStream(object):
appservice=None): appservice=None):
self.user = str(user) self.user = str(user)
self.appservice = appservice self.appservice = appservice
self.listeners = set()
self.rooms = set(rooms) self.rooms = set(rooms)
self.current_token = current_token self.current_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms): def notify(self, stream_key, stream_id, time_now_ms):
"""Notify any listeners for this user of a new event from an """Notify any listeners for this user of a new event from an
event source. event source.
@ -91,12 +82,10 @@ class _NotifierUserStream(object):
self.current_token = self.current_token.copy_and_advance( self.current_token = self.current_token.copy_and_advance(
stream_key, stream_id stream_key, stream_id
) )
if self.listeners: self.last_notified_ms = time_now_ms
self.last_notified_ms = time_now_ms noify_deferred = self.notify_deferred
listeners = self.listeners self.notify_deferred = ObservableDeferred(defer.Deferred())
self.listeners = set() noify_deferred.callback(self.current_token)
for listener in listeners:
listener.notify(self.current_token)
def remove(self, notifier): def remove(self, notifier):
""" Remove this listener from all the indexes in the Notifier """ Remove this listener from all the indexes in the Notifier
@ -114,6 +103,18 @@ class _NotifierUserStream(object):
self.appservice, set() self.appservice, set()
).discard(self) ).discard(self)
def count_listeners(self):
return len(self.notify_deferred.observers())
def new_listener(self, token):
"""Returns a deferred that is resolved when there is a new token
greater than the given token.
"""
if self.current_token.is_after(token):
return _NotificationListener(defer.succeed(self.current_token))
else:
return _NotificationListener(self.notify_deferred.observe())
class Notifier(object): class Notifier(object):
""" This class is responsible for notifying any listeners when there are """ This class is responsible for notifying any listeners when there are
@ -158,7 +159,7 @@ class Notifier(object):
for x in self.appservice_to_user_streams.values(): for x in self.appservice_to_user_streams.values():
all_user_streams |= x all_user_streams |= x
return sum(len(stream.listeners) for stream in all_user_streams) return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners) metrics.register_callback("listeners", count_listeners)
metrics.register_callback( metrics.register_callback(
@ -286,10 +287,6 @@ class Notifier(object):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
deferred = defer.Deferred()
time_now_ms = self.clock.time_msec()
user = str(user) user = str(user)
user_stream = self.user_to_user_stream.get(user) user_stream = self.user_to_user_stream.get(user)
if user_stream is None: if user_stream is None:
@ -302,55 +299,44 @@ class Notifier(object):
rooms=rooms, rooms=rooms,
appservice=appservice, appservice=appservice,
current_token=current_token, current_token=current_token,
time_now_ms=time_now_ms, time_now_ms=self.clock.time_msec(),
) )
self._register_with_keys(user_stream) self._register_with_keys(user_stream)
result = None
if timeout:
# Will be set to a _NotificationListener that we'll be waiting on.
# Allows us to cancel it.
listener = None
def timed_out():
if listener:
listener.deferred.cancel()
timer = self.clock.call_later(timeout/1000., timed_out)
prev_token = from_token
while not result:
try:
current_token = user_stream.current_token
result = yield callback(prev_token, current_token)
if result:
break
# Now we wait for the _NotifierUserStream to be told there
# is a new token.
# We need to supply the token we supplied to callback so
# that we don't miss any current_token updates.
prev_token = current_token
listener = user_stream.new_listener(prev_token)
yield listener.deferred
except defer.CancelledError:
break
self.clock.cancel_call_later(timer, ignore_errs=True)
else: else:
current_token = user_stream.current_token current_token = user_stream.current_token
listener = [_NotificationListener(deferred)]
if timeout and not current_token.is_after(from_token):
user_stream.listeners.add(listener[0])
if current_token.is_after(from_token):
result = yield callback(from_token, current_token) result = yield callback(from_token, current_token)
else:
result = None
timer = [None]
if result:
user_stream.listeners.discard(listener[0])
defer.returnValue(result)
return
if timeout:
timed_out = [False]
def _timeout_listener():
timed_out[0] = True
timer[0] = None
user_stream.listeners.discard(listener[0])
listener[0].notify(current_token)
# We create multiple notification listeners so we have to manage
# canceling the timeout ourselves.
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
while not result and not timed_out[0]:
new_token = yield deferred
deferred = defer.Deferred()
listener[0] = _NotificationListener(deferred)
user_stream.listeners.add(listener[0])
result = yield callback(current_token, new_token)
current_token = new_token
if timer[0] is not None:
try:
self.clock.cancel_call_later(timer[0])
except:
logger.exception("Failed to cancel notifer timer")
defer.returnValue(result) defer.returnValue(result)
@ -368,6 +354,9 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(before_token, after_token): def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
defer.returnValue(None)
events = [] events = []
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.items(): for name, source in self.event_sources.sources.items():
@ -402,7 +391,7 @@ class Notifier(object):
expired_streams = [] expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
for stream in self.user_to_user_stream.values(): for stream in self.user_to_user_stream.values():
if stream.listeners: if stream.count_listeners():
continue continue
if stream.last_notified_ms < expire_before_ts: if stream.last_notified_ms < expire_before_ts:
expired_streams.append(stream) expired_streams.append(stream)

View file

@ -39,10 +39,10 @@ class HttpTransactionStore(object):
A tuple of (HTTP response code, response content) or None. A tuple of (HTTP response code, response content) or None.
""" """
try: try:
logger.debug("get_response Key: %s TxnId: %s", key, txn_id) logger.debug("get_response TxnId: %s", txn_id)
(last_txn_id, response) = self.transactions[key] (last_txn_id, response) = self.transactions[key]
if txn_id == last_txn_id: if txn_id == last_txn_id:
logger.info("get_response: Returning a response for %s", key) logger.info("get_response: Returning a response for %s", txn_id)
return response return response
except KeyError: except KeyError:
pass pass
@ -58,7 +58,7 @@ class HttpTransactionStore(object):
txn_id (str): The transaction ID for this request. txn_id (str): The transaction ID for this request.
response (tuple): A tuple of (HTTP response code, response content) response (tuple): A tuple of (HTTP response code, response content)
""" """
logger.debug("store_response Key: %s TxnId: %s", key, txn_id) logger.debug("store_response TxnId: %s", txn_id)
self.transactions[key] = (txn_id, response) self.transactions[key] = (txn_id, response)
def store_client_transaction(self, request, txn_id, response): def store_client_transaction(self, request, txn_id, response):

View file

@ -132,16 +132,8 @@ class BaseHomeServer(object):
setattr(BaseHomeServer, "get_%s" % (depname), _get) setattr(BaseHomeServer, "get_%s" % (depname), _get)
def get_ip_from_request(self, request): def get_ip_from_request(self, request):
# May be an X-Forwarding-For header depending on config # X-Forwarded-For is handled by our custom request type.
ip_addr = request.getClientIP() return request.getClientIP()
if self.config.captcha_ip_origin_is_x_forwarded:
# use the header
if request.requestHeaders.hasHeader("X-Forwarded-For"):
ip_addr = request.requestHeaders.getRawHeaders(
"X-Forwarded-For"
)[0]
return ip_addr
def is_mine(self, domain_specific_string): def is_mine(self, domain_specific_string):
return domain_specific_string.domain == self.hostname return domain_specific_string.domain == self.hostname

View file

@ -91,8 +91,12 @@ class Clock(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback, *args, **kwargs) return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer): def cancel_call_later(self, timer, ignore_errs=False):
timer.cancel() try:
timer.cancel()
except:
if not ignore_errs:
raise
def time_bound_deferred(self, given_deferred, time_out): def time_bound_deferred(self, given_deferred, time_out):
if given_deferred.called: if given_deferred.called:

View file

@ -38,6 +38,9 @@ class ObservableDeferred(object):
deferred. deferred.
If consumeErrors is true errors will be captured from the origin deferred. If consumeErrors is true errors will be captured from the origin deferred.
Cancelling or otherwise resolving an observer will not affect the original
ObservableDeferred.
""" """
__slots__ = ["_deferred", "_observers", "_result"] __slots__ = ["_deferred", "_observers", "_result"]
@ -45,7 +48,7 @@ class ObservableDeferred(object):
def __init__(self, deferred, consumeErrors=False): def __init__(self, deferred, consumeErrors=False):
object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None) object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", []) object.__setattr__(self, "_observers", set())
def callback(r): def callback(r):
self._result = (True, r) self._result = (True, r)
@ -74,12 +77,21 @@ class ObservableDeferred(object):
def observe(self): def observe(self):
if not self._result: if not self._result:
d = defer.Deferred() d = defer.Deferred()
self._observers.append(d)
def remove(r):
self._observers.discard(d)
return r
d.addBoth(remove)
self._observers.add(d)
return d return d
else: else:
success, res = self._result success, res = self._result
return defer.succeed(res) if success else defer.fail(res) return defer.succeed(res) if success else defer.fail(res)
def observers(self):
return self._observers
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self._deferred, name) return getattr(self._deferred, name)

View file

@ -140,6 +140,37 @@ class PreserveLoggingContext(object):
) )
class _PreservingContextDeferred(defer.Deferred):
"""A deferred that ensures that all callbacks and errbacks are called with
the given logging context.
"""
def __init__(self, context):
self._log_context = context
defer.Deferred.__init__(self)
def addCallbacks(self, callback, errback=None,
callbackArgs=None, callbackKeywords=None,
errbackArgs=None, errbackKeywords=None):
callback = self._wrap_callback(callback)
errback = self._wrap_callback(errback)
return defer.Deferred.addCallbacks(
self, callback,
errback=errback,
callbackArgs=callbackArgs,
callbackKeywords=callbackKeywords,
errbackArgs=errbackArgs,
errbackKeywords=errbackKeywords,
)
def _wrap_callback(self, f):
def g(res, *args, **kwargs):
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = self._log_context
res = f(res, *args, **kwargs)
return res
return g
def preserve_context_over_fn(fn, *args, **kwargs): def preserve_context_over_fn(fn, *args, **kwargs):
"""Takes a function and invokes it with the given arguments, but removes """Takes a function and invokes it with the given arguments, but removes
and restores the current logging context while doing so. and restores the current logging context while doing so.
@ -160,24 +191,7 @@ def preserve_context_over_deferred(deferred):
"""Given a deferred wrap it such that any callbacks added later to it will """Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context. be invoked with the current context.
""" """
d = defer.Deferred()
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
d = _PreservingContextDeferred(current_context)
def cb(res): deferred.chainDeferred(d)
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context
res = d.callback(res)
return res
def eb(failure):
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context
res = d.errback(failure)
return res
if deferred.called:
return deferred
deferred.addCallbacks(cb, eb)
return d return d

View file

@ -57,6 +57,49 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service, event interested_service, event
) )
@defer.inlineCallbacks
def test_query_user_exists_unknown_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user = Mock(return_value=True)
self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_user_by_id = Mock(return_value=None)
event = Mock(
sender=user_id,
type="m.room.message",
room_id="!foo:bar"
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event)
self.mock_as_api.query_user.assert_called_once_with(
services[0], user_id
)
@defer.inlineCallbacks
def test_query_user_exists_known_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user = Mock(return_value=True)
self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_user_by_id = Mock(return_value={
"name": user_id
})
event = Mock(
sender=user_id,
type="m.room.message",
room_id="!foo:bar"
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event)
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been."
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_query_room_alias_exists(self): def test_query_room_alias_exists(self):
room_alias_str = "#foo:bar" room_alias_str = "#foo:bar"

View file

@ -114,6 +114,8 @@ class MockHttpResource(HttpServer):
mock_request.method = http_method mock_request.method = http_method
mock_request.uri = path mock_request.uri = path
mock_request.getClientIP.return_value = "-"
mock_request.requestHeaders.getRawHeaders.return_value=[ mock_request.requestHeaders.getRawHeaders.return_value=[
"X-Matrix origin=test,key=,sig=" "X-Matrix origin=test,key=,sig="
] ]