Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
This commit is contained in:
commit
d351be1567
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -12,6 +12,7 @@ dbs/
|
|||
dist/
|
||||
docs/build/
|
||||
*.egg-info
|
||||
pip-wheel-metadata/
|
||||
|
||||
cmdclient_config.json
|
||||
homeserver*.db
|
||||
|
|
|
@ -15,6 +15,7 @@ recursive-include docs *
|
|||
recursive-include scripts *
|
||||
recursive-include scripts-dev *
|
||||
recursive-include synapse *.pyi
|
||||
recursive-include tests *.pem
|
||||
recursive-include tests *.py
|
||||
|
||||
recursive-include synapse/res *
|
||||
|
|
1
changelog.d/4408.feature
Normal file
1
changelog.d/4408.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
|
@ -1 +0,0 @@
|
|||
Refactor 'sign_request' as 'build_auth_headers'
|
1
changelog.d/4409.feature
Normal file
1
changelog.d/4409.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
|
@ -1 +0,0 @@
|
|||
Remove redundant federation connection wrapping code
|
1
changelog.d/4426.feature
Normal file
1
changelog.d/4426.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
|
@ -1 +0,0 @@
|
|||
Remove redundant SynapseKeyClientProtocol magic
|
1
changelog.d/4427.feature
Normal file
1
changelog.d/4427.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
|
@ -1 +0,0 @@
|
|||
Refactor and cleanup for SRV record lookup
|
1
changelog.d/4428.feature
Normal file
1
changelog.d/4428.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
|
@ -1 +0,0 @@
|
|||
Move SRV logic into the Agent layer
|
1
changelog.d/4464.feature
Normal file
1
changelog.d/4464.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
|
@ -1 +0,0 @@
|
|||
Move SRV logic into the Agent layer
|
1
changelog.d/4468.feature
Normal file
1
changelog.d/4468.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
|
@ -1 +0,0 @@
|
|||
Move SRV logic into the Agent layer
|
1
changelog.d/4481.misc
Normal file
1
changelog.d/4481.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add infrastructure to support different event formats
|
1
changelog.d/4483.feature
Normal file
1
changelog.d/4483.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add support for room version 3
|
1
changelog.d/4486.bugfix
Normal file
1
changelog.d/4486.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Workaround for login error when using both LDAP and internal authentication.
|
1
changelog.d/4487.feature
Normal file
1
changelog.d/4487.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
|
@ -1 +0,0 @@
|
|||
Fix idna and ipv6 literal handling in MatrixFederationAgent
|
1
changelog.d/4489.feature
Normal file
1
changelog.d/4489.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
1
changelog.d/4494.misc
Normal file
1
changelog.d/4494.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add infrastructure to support different event formats
|
1
changelog.d/4495.feature
Normal file
1
changelog.d/4495.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Synapse will now reload TLS certificates from disk upon SIGHUP.
|
1
changelog.d/4496.misc
Normal file
1
changelog.d/4496.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add infrastructure to support different event formats
|
1
changelog.d/4498.misc
Normal file
1
changelog.d/4498.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Clarify documentation for the `public_baseurl` config param
|
1
changelog.d/4499.feature
Normal file
1
changelog.d/4499.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add support for room version 3
|
1
changelog.d/4506.misc
Normal file
1
changelog.d/4506.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Make it possible to set the log level for tests via an environment variable
|
1
changelog.d/4507.misc
Normal file
1
changelog.d/4507.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Reduce the log level of linearizer lock acquirement to DEBUG.
|
1
changelog.d/4509.removal
Normal file
1
changelog.d/4509.removal
Normal file
|
@ -0,0 +1 @@
|
|||
Synapse no longer generates self-signed TLS certificates when generating a configuration file.
|
1
changelog.d/4510.misc
Normal file
1
changelog.d/4510.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add infrastructure to support different event formats
|
1
changelog.d/4511.feature
Normal file
1
changelog.d/4511.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
1
changelog.d/4512.bugfix
Normal file
1
changelog.d/4512.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix a bug where setting a relative consent directory path would cause a crash.
|
1
changelog.d/4514.misc
Normal file
1
changelog.d/4514.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add infrastructure to support different event formats
|
1
changelog.d/4515.feature
Normal file
1
changelog.d/4515.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add support for room version 3
|
1
changelog.d/4516.feature
Normal file
1
changelog.d/4516.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
1
changelog.d/4519.misc
Normal file
1
changelog.d/4519.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix code to comply with linting in PyFlakes 3.7.1.
|
1
changelog.d/4520.feature
Normal file
1
changelog.d/4520.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
1
changelog.d/4521.feature
Normal file
1
changelog.d/4521.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement MSC1708 (.well-known routing for server-server federation)
|
|
@ -46,7 +46,7 @@ def request_registration(
|
|||
# Get the nonce
|
||||
r = requests.get(url, verify=False)
|
||||
|
||||
if r.status_code is not 200:
|
||||
if r.status_code != 200:
|
||||
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
|
||||
if 400 <= r.status_code < 500:
|
||||
try:
|
||||
|
@ -84,7 +84,7 @@ def request_registration(
|
|||
_print("Sending registration request...")
|
||||
r = requests.post(url, json=data, verify=False)
|
||||
|
||||
if r.status_code is not 200:
|
||||
if r.status_code != 200:
|
||||
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
|
||||
if 400 <= r.status_code < 500:
|
||||
try:
|
||||
|
|
|
@ -550,17 +550,6 @@ class Auth(object):
|
|||
"""
|
||||
return self.store.is_server_admin(user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_auth_events(self, builder, context):
|
||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||
auth_ids = yield self.compute_auth_events(builder, prev_state_ids)
|
||||
|
||||
auth_events_entries = yield self.store.add_event_hashes(
|
||||
auth_ids
|
||||
)
|
||||
|
||||
builder.auth_events = auth_events_entries
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def compute_auth_events(self, event, current_state_ids, for_verification=False):
|
||||
if event.type == EventTypes.Create:
|
||||
|
@ -577,7 +566,7 @@ class Auth(object):
|
|||
key = (EventTypes.JoinRules, "", )
|
||||
join_rule_event_id = current_state_ids.get(key)
|
||||
|
||||
key = (EventTypes.Member, event.user_id, )
|
||||
key = (EventTypes.Member, event.sender, )
|
||||
member_event_id = current_state_ids.get(key)
|
||||
|
||||
key = (EventTypes.Create, "", )
|
||||
|
@ -627,7 +616,7 @@ class Auth(object):
|
|||
|
||||
defer.returnValue(auth_ids)
|
||||
|
||||
def check_redaction(self, event, auth_events):
|
||||
def check_redaction(self, room_version, event, auth_events):
|
||||
"""Check whether the event sender is allowed to redact the target event.
|
||||
|
||||
Returns:
|
||||
|
@ -640,7 +629,7 @@ class Auth(object):
|
|||
AuthError if the event sender is definitely not allowed to redact
|
||||
the target event.
|
||||
"""
|
||||
return event_auth.check_redaction(event, auth_events)
|
||||
return event_auth.check_redaction(room_version, event, auth_events)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_can_change_room_list(self, room_id, user):
|
||||
|
|
|
@ -104,7 +104,7 @@ class ThirdPartyEntityKind(object):
|
|||
class RoomVersions(object):
|
||||
V1 = "1"
|
||||
V2 = "2"
|
||||
VDH_TEST = "vdh-test-version"
|
||||
V3 = "3"
|
||||
STATE_V2_TEST = "state-v2-test"
|
||||
|
||||
|
||||
|
@ -116,7 +116,7 @@ DEFAULT_ROOM_VERSION = RoomVersions.V1
|
|||
KNOWN_ROOM_VERSIONS = {
|
||||
RoomVersions.V1,
|
||||
RoomVersions.V2,
|
||||
RoomVersions.VDH_TEST,
|
||||
RoomVersions.V3,
|
||||
RoomVersions.STATE_V2_TEST,
|
||||
}
|
||||
|
||||
|
@ -126,10 +126,12 @@ class EventFormatVersions(object):
|
|||
independently from the room version.
|
||||
"""
|
||||
V1 = 1
|
||||
V2 = 2
|
||||
|
||||
|
||||
KNOWN_EVENT_FORMAT_VERSIONS = {
|
||||
EventFormatVersions.V1,
|
||||
EventFormatVersions.V2,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -143,6 +143,9 @@ def listen_metrics(bind_addresses, port):
|
|||
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
|
||||
"""
|
||||
Create a TCP socket for a port and several addresses
|
||||
|
||||
Returns:
|
||||
list (empty)
|
||||
"""
|
||||
for address in bind_addresses:
|
||||
try:
|
||||
|
@ -155,25 +158,37 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
|
|||
except error.CannotListenError as e:
|
||||
check_bind_error(e, address, bind_addresses)
|
||||
|
||||
logger.info("Synapse now listening on TCP port %d", port)
|
||||
return []
|
||||
|
||||
|
||||
def listen_ssl(
|
||||
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
|
||||
):
|
||||
"""
|
||||
Create an SSL socket for a port and several addresses
|
||||
Create an TLS-over-TCP socket for a port and several addresses
|
||||
|
||||
Returns:
|
||||
list of twisted.internet.tcp.Port listening for TLS connections
|
||||
"""
|
||||
r = []
|
||||
for address in bind_addresses:
|
||||
try:
|
||||
reactor.listenSSL(
|
||||
port,
|
||||
factory,
|
||||
context_factory,
|
||||
backlog,
|
||||
address
|
||||
r.append(
|
||||
reactor.listenSSL(
|
||||
port,
|
||||
factory,
|
||||
context_factory,
|
||||
backlog,
|
||||
address
|
||||
)
|
||||
)
|
||||
except error.CannotListenError as e:
|
||||
check_bind_error(e, address, bind_addresses)
|
||||
|
||||
logger.info("Synapse now listening on port %d (TLS)", port)
|
||||
return r
|
||||
|
||||
|
||||
def check_bind_error(e, address, bind_addresses):
|
||||
"""
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
import gc
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
|
@ -27,6 +28,7 @@ from prometheus_client import Gauge
|
|||
|
||||
from twisted.application import service
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
from twisted.web.resource import EncodingResourceWrapper, NoResource
|
||||
from twisted.web.server import GzipEncoderFactory
|
||||
from twisted.web.static import File
|
||||
|
@ -84,6 +86,7 @@ def gz_wrap(r):
|
|||
|
||||
class SynapseHomeServer(HomeServer):
|
||||
DATASTORE_CLASS = DataStore
|
||||
_listening_services = []
|
||||
|
||||
def _listener_http(self, config, listener_config):
|
||||
port = listener_config["port"]
|
||||
|
@ -121,7 +124,7 @@ class SynapseHomeServer(HomeServer):
|
|||
root_resource = create_resource_tree(resources, root_resource)
|
||||
|
||||
if tls:
|
||||
listen_ssl(
|
||||
return listen_ssl(
|
||||
bind_addresses,
|
||||
port,
|
||||
SynapseSite(
|
||||
|
@ -135,7 +138,7 @@ class SynapseHomeServer(HomeServer):
|
|||
)
|
||||
|
||||
else:
|
||||
listen_tcp(
|
||||
return listen_tcp(
|
||||
bind_addresses,
|
||||
port,
|
||||
SynapseSite(
|
||||
|
@ -146,7 +149,6 @@ class SynapseHomeServer(HomeServer):
|
|||
self.version_string,
|
||||
)
|
||||
)
|
||||
logger.info("Synapse now listening on port %d", port)
|
||||
|
||||
def _configure_named_resource(self, name, compress=False):
|
||||
"""Build a resource map for a named resource
|
||||
|
@ -242,7 +244,9 @@ class SynapseHomeServer(HomeServer):
|
|||
|
||||
for listener in config.listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listener_http(config, listener)
|
||||
self._listening_services.extend(
|
||||
self._listener_http(config, listener)
|
||||
)
|
||||
elif listener["type"] == "manhole":
|
||||
listen_tcp(
|
||||
listener["bind_addresses"],
|
||||
|
@ -322,7 +326,19 @@ def setup(config_options):
|
|||
# generating config files and shouldn't try to continue.
|
||||
sys.exit(0)
|
||||
|
||||
synapse.config.logger.setup_logging(config, use_worker_options=False)
|
||||
sighup_callbacks = []
|
||||
synapse.config.logger.setup_logging(
|
||||
config,
|
||||
use_worker_options=False,
|
||||
register_sighup=sighup_callbacks.append
|
||||
)
|
||||
|
||||
def handle_sighup(*args, **kwargs):
|
||||
for i in sighup_callbacks:
|
||||
i(*args, **kwargs)
|
||||
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, handle_sighup)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
|
@ -359,6 +375,31 @@ def setup(config_options):
|
|||
|
||||
hs.setup()
|
||||
|
||||
def refresh_certificate(*args):
|
||||
"""
|
||||
Refresh the TLS certificates that Synapse is using by re-reading them
|
||||
from disk and updating the TLS context factories to use them.
|
||||
"""
|
||||
logging.info("Reloading certificate from disk...")
|
||||
hs.config.read_certificate_from_disk()
|
||||
hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
|
||||
config
|
||||
)
|
||||
logging.info("Certificate reloaded.")
|
||||
|
||||
logging.info("Updating context factories...")
|
||||
for i in hs._listening_services:
|
||||
if isinstance(i.factory, TLSMemoryBIOFactory):
|
||||
i.factory = TLSMemoryBIOFactory(
|
||||
hs.tls_server_context_factory,
|
||||
False,
|
||||
i.factory.wrappedFactory
|
||||
)
|
||||
logging.info("Context factories updated.")
|
||||
|
||||
sighup_callbacks.append(refresh_certificate)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start():
|
||||
try:
|
||||
|
|
|
@ -13,6 +13,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from os import path
|
||||
|
||||
from synapse.config import ConfigError
|
||||
|
||||
from ._base import Config
|
||||
|
||||
DEFAULT_CONFIG = """\
|
||||
|
@ -85,7 +89,15 @@ class ConsentConfig(Config):
|
|||
if consent_config is None:
|
||||
return
|
||||
self.user_consent_version = str(consent_config["version"])
|
||||
self.user_consent_template_dir = consent_config["template_dir"]
|
||||
self.user_consent_template_dir = self.abspath(
|
||||
consent_config["template_dir"]
|
||||
)
|
||||
if not path.isdir(self.user_consent_template_dir):
|
||||
raise ConfigError(
|
||||
"Could not find template directory '%s'" % (
|
||||
self.user_consent_template_dir,
|
||||
),
|
||||
)
|
||||
self.user_consent_server_notice_content = consent_config.get(
|
||||
"server_notice_content",
|
||||
)
|
||||
|
|
|
@ -127,7 +127,7 @@ class LoggingConfig(Config):
|
|||
)
|
||||
|
||||
|
||||
def setup_logging(config, use_worker_options=False):
|
||||
def setup_logging(config, use_worker_options=False, register_sighup=None):
|
||||
""" Set up python logging
|
||||
|
||||
Args:
|
||||
|
@ -136,7 +136,16 @@ def setup_logging(config, use_worker_options=False):
|
|||
|
||||
use_worker_options (bool): True to use 'worker_log_config' and
|
||||
'worker_log_file' options instead of 'log_config' and 'log_file'.
|
||||
|
||||
register_sighup (func | None): Function to call to register a
|
||||
sighup handler.
|
||||
"""
|
||||
if not register_sighup:
|
||||
if getattr(signal, "SIGHUP"):
|
||||
register_sighup = lambda x: signal.signal(signal.SIGHUP, x)
|
||||
else:
|
||||
register_sighup = lambda x: None
|
||||
|
||||
log_config = (config.worker_log_config if use_worker_options
|
||||
else config.log_config)
|
||||
log_file = (config.worker_log_file if use_worker_options
|
||||
|
@ -198,13 +207,7 @@ def setup_logging(config, use_worker_options=False):
|
|||
|
||||
load_log_config()
|
||||
|
||||
# TODO(paul): obviously this is a terrible mechanism for
|
||||
# stealing SIGHUP, because it means no other part of synapse
|
||||
# can use it instead. If we want to catch SIGHUP anywhere
|
||||
# else as well, I'd suggest we find a nicer way to broadcast
|
||||
# it around.
|
||||
if getattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, sighup)
|
||||
register_sighup(sighup)
|
||||
|
||||
# make sure that the first thing we log is a thing we can grep backwards
|
||||
# for
|
||||
|
|
|
@ -261,7 +261,7 @@ class ServerConfig(Config):
|
|||
# enter into the 'custom HS URL' field on their client. If you
|
||||
# use synapse with a reverse proxy, this should be the URL to reach
|
||||
# synapse via the proxy.
|
||||
# public_baseurl: https://example.com:8448/
|
||||
# public_baseurl: https://example.com/
|
||||
|
||||
# Set the soft limit on the number of file descriptors synapse can use
|
||||
# Zero is used to indicate synapse should set the soft limit to the
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
|
||||
|
@ -39,8 +40,8 @@ class TlsConfig(Config):
|
|||
self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"])
|
||||
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
|
||||
|
||||
self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path"))
|
||||
self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path"))
|
||||
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
|
||||
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
|
||||
self._original_tls_fingerprints = config["tls_fingerprints"]
|
||||
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
||||
self.no_tls = config.get("no_tls", False)
|
||||
|
@ -94,6 +95,16 @@ class TlsConfig(Config):
|
|||
"""
|
||||
self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file)
|
||||
|
||||
# Check if it is self-signed, and issue a warning if so.
|
||||
if self.tls_certificate.get_issuer() == self.tls_certificate.get_subject():
|
||||
warnings.warn(
|
||||
(
|
||||
"Self-signed TLS certificates will not be accepted by Synapse 1.0. "
|
||||
"Please either provide a valid certificate, or use Synapse's ACME "
|
||||
"support to provision one."
|
||||
)
|
||||
)
|
||||
|
||||
if not self.no_tls:
|
||||
self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file)
|
||||
|
||||
|
@ -118,10 +129,11 @@ class TlsConfig(Config):
|
|||
return (
|
||||
"""\
|
||||
# PEM encoded X509 certificate for TLS.
|
||||
# You can replace the self-signed certificate that synapse
|
||||
# autogenerates on launch with your own SSL certificate + key pair
|
||||
# if you like. Any required intermediary certificates can be
|
||||
# appended after the primary certificate in hierarchical order.
|
||||
# This certificate, as of Synapse 1.0, will need to be a valid
|
||||
# and verifiable certificate, with a root that is available in
|
||||
# the root store of other servers you wish to federate to. Any
|
||||
# required intermediary certificates can be appended after the
|
||||
# primary certificate in hierarchical order.
|
||||
tls_certificate_path: "%(tls_certificate_path)s"
|
||||
|
||||
# PEM encoded private key for TLS
|
||||
|
@ -183,40 +195,3 @@ class TlsConfig(Config):
|
|||
def read_tls_private_key(self, private_key_path):
|
||||
private_key_pem = self.read_file(private_key_path, "tls_private_key")
|
||||
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
|
||||
|
||||
def generate_files(self, config):
|
||||
tls_certificate_path = config["tls_certificate_path"]
|
||||
tls_private_key_path = config["tls_private_key_path"]
|
||||
|
||||
if not self.path_exists(tls_private_key_path):
|
||||
with open(tls_private_key_path, "wb") as private_key_file:
|
||||
tls_private_key = crypto.PKey()
|
||||
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
|
||||
private_key_pem = crypto.dump_privatekey(
|
||||
crypto.FILETYPE_PEM, tls_private_key
|
||||
)
|
||||
private_key_file.write(private_key_pem)
|
||||
else:
|
||||
with open(tls_private_key_path) as private_key_file:
|
||||
private_key_pem = private_key_file.read()
|
||||
tls_private_key = crypto.load_privatekey(
|
||||
crypto.FILETYPE_PEM, private_key_pem
|
||||
)
|
||||
|
||||
if not self.path_exists(tls_certificate_path):
|
||||
with open(tls_certificate_path, "wb") as certificate_file:
|
||||
cert = crypto.X509()
|
||||
subject = cert.get_subject()
|
||||
subject.CN = config["server_name"]
|
||||
|
||||
cert.set_serial_number(1000)
|
||||
cert.gmtime_adj_notBefore(0)
|
||||
cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
|
||||
cert.set_issuer(cert.get_subject())
|
||||
cert.set_pubkey(tls_private_key)
|
||||
|
||||
cert.sign(tls_private_key, 'sha256')
|
||||
|
||||
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
|
||||
|
||||
certificate_file.write(cert_pem)
|
||||
|
|
|
@ -131,12 +131,12 @@ def compute_event_signature(event_dict, signature_name, signing_key):
|
|||
return redact_json["signatures"]
|
||||
|
||||
|
||||
def add_hashes_and_signatures(event, signature_name, signing_key,
|
||||
def add_hashes_and_signatures(event_dict, signature_name, signing_key,
|
||||
hash_algorithm=hashlib.sha256):
|
||||
"""Add content hash and sign the event
|
||||
|
||||
Args:
|
||||
event_dict (EventBuilder): The event to add hashes to and sign
|
||||
event_dict (dict): The event to add hashes to and sign
|
||||
signature_name (str): The name of the entity signing the event
|
||||
(typically the server's hostname).
|
||||
signing_key (syutil.crypto.SigningKey): The key to sign with
|
||||
|
@ -144,16 +144,12 @@ def add_hashes_and_signatures(event, signature_name, signing_key,
|
|||
to hash the event
|
||||
"""
|
||||
|
||||
name, digest = compute_content_hash(
|
||||
event.get_pdu_json(), hash_algorithm=hash_algorithm,
|
||||
)
|
||||
name, digest = compute_content_hash(event_dict, hash_algorithm=hash_algorithm)
|
||||
|
||||
if not hasattr(event, "hashes"):
|
||||
event.hashes = {}
|
||||
event.hashes[name] = encode_base64(digest)
|
||||
event_dict.setdefault("hashes", {})[name] = encode_base64(digest)
|
||||
|
||||
event.signatures = compute_event_signature(
|
||||
event.get_pdu_json(),
|
||||
event_dict["signatures"] = compute_event_signature(
|
||||
event_dict,
|
||||
signature_name=signature_name,
|
||||
signing_key=signing_key,
|
||||
)
|
||||
|
|
|
@ -20,7 +20,14 @@ from signedjson.key import decode_verify_key_bytes
|
|||
from signedjson.sign import SignatureVerifyException, verify_signed_json
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, JoinRules, Membership
|
||||
from synapse.api.constants import (
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
EventFormatVersions,
|
||||
EventTypes,
|
||||
JoinRules,
|
||||
Membership,
|
||||
RoomVersions,
|
||||
)
|
||||
from synapse.api.errors import AuthError, EventSizeError, SynapseError
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
|
||||
|
@ -49,7 +56,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
|
|||
|
||||
if do_sig_check:
|
||||
sender_domain = get_domain_from_id(event.sender)
|
||||
event_id_domain = get_domain_from_id(event.event_id)
|
||||
|
||||
is_invite_via_3pid = (
|
||||
event.type == EventTypes.Member
|
||||
|
@ -66,9 +72,13 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
|
|||
if not is_invite_via_3pid:
|
||||
raise AuthError(403, "Event not signed by sender's server")
|
||||
|
||||
# Check the event_id's domain has signed the event
|
||||
if not event.signatures.get(event_id_domain):
|
||||
raise AuthError(403, "Event not signed by sending server")
|
||||
if event.format_version in (EventFormatVersions.V1,):
|
||||
# Only older room versions have event IDs to check.
|
||||
event_id_domain = get_domain_from_id(event.event_id)
|
||||
|
||||
# Check the origin domain has signed the event
|
||||
if not event.signatures.get(event_id_domain):
|
||||
raise AuthError(403, "Event not signed by sending server")
|
||||
|
||||
if auth_events is None:
|
||||
# Oh, we don't know what the state of the room was, so we
|
||||
|
@ -168,7 +178,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
|
|||
_check_power_levels(event, auth_events)
|
||||
|
||||
if event.type == EventTypes.Redaction:
|
||||
check_redaction(event, auth_events)
|
||||
check_redaction(room_version, event, auth_events)
|
||||
|
||||
logger.debug("Allowing! %s", event)
|
||||
|
||||
|
@ -422,7 +432,7 @@ def _can_send_event(event, auth_events):
|
|||
return True
|
||||
|
||||
|
||||
def check_redaction(event, auth_events):
|
||||
def check_redaction(room_version, event, auth_events):
|
||||
"""Check whether the event sender is allowed to redact the target event.
|
||||
|
||||
Returns:
|
||||
|
@ -442,10 +452,16 @@ def check_redaction(event, auth_events):
|
|||
if user_level >= redact_level:
|
||||
return False
|
||||
|
||||
redacter_domain = get_domain_from_id(event.event_id)
|
||||
redactee_domain = get_domain_from_id(event.redacts)
|
||||
if redacter_domain == redactee_domain:
|
||||
if room_version in (RoomVersions.V1, RoomVersions.V2,):
|
||||
redacter_domain = get_domain_from_id(event.event_id)
|
||||
redactee_domain = get_domain_from_id(event.redacts)
|
||||
if redacter_domain == redactee_domain:
|
||||
return True
|
||||
elif room_version == RoomVersions.V3:
|
||||
event.internal_metadata.recheck_redaction = True
|
||||
return True
|
||||
else:
|
||||
raise RuntimeError("Unrecognized room version %r" % (room_version,))
|
||||
|
||||
raise AuthError(
|
||||
403,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -18,11 +19,9 @@ from distutils.util import strtobool
|
|||
|
||||
import six
|
||||
|
||||
from synapse.api.constants import (
|
||||
KNOWN_EVENT_FORMAT_VERSIONS,
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
EventFormatVersions,
|
||||
)
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersions
|
||||
from synapse.util.caches import intern_dict
|
||||
from synapse.util.frozenutils import freeze
|
||||
|
||||
|
@ -63,6 +62,21 @@ class _EventInternalMetadata(object):
|
|||
"""
|
||||
return getattr(self, "send_on_behalf_of", None)
|
||||
|
||||
def need_to_check_redaction(self):
|
||||
"""Whether the redaction event needs to be rechecked when fetching
|
||||
from the database.
|
||||
|
||||
Starting in room v3 redaction events are accepted up front, and later
|
||||
checked to see if the redacter and redactee's domains match.
|
||||
|
||||
If the sender of the redaction event is allowed to redact any event
|
||||
due to auth rules, then this will always return false.
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
return getattr(self, "recheck_redaction", False)
|
||||
|
||||
|
||||
def _event_dict_property(key):
|
||||
# We want to be able to use hasattr with the event dict properties.
|
||||
|
@ -225,16 +239,6 @@ class FrozenEvent(EventBase):
|
|||
rejected_reason=rejected_reason,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_event(event):
|
||||
e = FrozenEvent(
|
||||
event.get_pdu_json()
|
||||
)
|
||||
|
||||
e.internal_metadata = event.internal_metadata
|
||||
|
||||
return e
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
|
@ -246,6 +250,85 @@ class FrozenEvent(EventBase):
|
|||
)
|
||||
|
||||
|
||||
class FrozenEventV2(EventBase):
|
||||
format_version = EventFormatVersions.V2 # All events of this type are V2
|
||||
|
||||
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
|
||||
event_dict = dict(event_dict)
|
||||
|
||||
# Signatures is a dict of dicts, and this is faster than doing a
|
||||
# copy.deepcopy
|
||||
signatures = {
|
||||
name: {sig_id: sig for sig_id, sig in sigs.items()}
|
||||
for name, sigs in event_dict.pop("signatures", {}).items()
|
||||
}
|
||||
|
||||
assert "event_id" not in event_dict
|
||||
|
||||
unsigned = dict(event_dict.pop("unsigned", {}))
|
||||
|
||||
# We intern these strings because they turn up a lot (especially when
|
||||
# caching).
|
||||
event_dict = intern_dict(event_dict)
|
||||
|
||||
if USE_FROZEN_DICTS:
|
||||
frozen_dict = freeze(event_dict)
|
||||
else:
|
||||
frozen_dict = event_dict
|
||||
|
||||
self._event_id = None
|
||||
self.type = event_dict["type"]
|
||||
if "state_key" in event_dict:
|
||||
self.state_key = event_dict["state_key"]
|
||||
|
||||
super(FrozenEventV2, self).__init__(
|
||||
frozen_dict,
|
||||
signatures=signatures,
|
||||
unsigned=unsigned,
|
||||
internal_metadata_dict=internal_metadata_dict,
|
||||
rejected_reason=rejected_reason,
|
||||
)
|
||||
|
||||
@property
|
||||
def event_id(self):
|
||||
# We have to import this here as otherwise we get an import loop which
|
||||
# is hard to break.
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
|
||||
if self._event_id:
|
||||
return self._event_id
|
||||
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
|
||||
return self._event_id
|
||||
|
||||
def prev_event_ids(self):
|
||||
"""Returns the list of prev event IDs. The order matches the order
|
||||
specified in the event, though there is no meaning to it.
|
||||
|
||||
Returns:
|
||||
list[str]: The list of event IDs of this event's prev_events
|
||||
"""
|
||||
return self.prev_events
|
||||
|
||||
def auth_event_ids(self):
|
||||
"""Returns the list of auth event IDs. The order matches the order
|
||||
specified in the event, though there is no meaning to it.
|
||||
|
||||
Returns:
|
||||
list[str]: The list of event IDs of this event's auth_events
|
||||
"""
|
||||
return self.auth_events
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def __repr__(self):
|
||||
return "<FrozenEventV2 event_id='%s', type='%s', state_key='%s'>" % (
|
||||
self.event_id,
|
||||
self.get("type", None),
|
||||
self.get("state_key", None),
|
||||
)
|
||||
|
||||
|
||||
def room_version_to_event_format(room_version):
|
||||
"""Converts a room version string to the event format
|
||||
|
||||
|
@ -259,7 +342,14 @@ def room_version_to_event_format(room_version):
|
|||
# We should have already checked version, so this should not happen
|
||||
raise RuntimeError("Unrecognized room version %s" % (room_version,))
|
||||
|
||||
return EventFormatVersions.V1
|
||||
if room_version in (
|
||||
RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
|
||||
):
|
||||
return EventFormatVersions.V1
|
||||
elif room_version in (RoomVersions.V3,):
|
||||
return EventFormatVersions.V2
|
||||
else:
|
||||
raise RuntimeError("Unrecognized room version %s" % (room_version,))
|
||||
|
||||
|
||||
def event_type_from_format_version(format_version):
|
||||
|
@ -273,8 +363,12 @@ def event_type_from_format_version(format_version):
|
|||
type: A type that can be initialized as per the initializer of
|
||||
`FrozenEvent`
|
||||
"""
|
||||
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
|
||||
|
||||
if format_version == EventFormatVersions.V1:
|
||||
return FrozenEvent
|
||||
elif format_version == EventFormatVersions.V2:
|
||||
return FrozenEventV2
|
||||
else:
|
||||
raise Exception(
|
||||
"No event format %r" % (format_version,)
|
||||
)
|
||||
return FrozenEvent
|
||||
|
|
|
@ -13,79 +13,161 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import RoomVersions
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import (
|
||||
KNOWN_EVENT_FORMAT_VERSIONS,
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
MAX_DEPTH,
|
||||
EventFormatVersions,
|
||||
)
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.types import EventID
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from . import EventBase, FrozenEvent, _event_dict_property
|
||||
from . import (
|
||||
_EventInternalMetadata,
|
||||
event_type_from_format_version,
|
||||
room_version_to_event_format,
|
||||
)
|
||||
|
||||
|
||||
def get_event_builder(room_version, key_values={}, internal_metadata_dict={}):
|
||||
"""Generate an event builder appropriate for the given room version
|
||||
@attr.s(slots=True, cmp=False, frozen=True)
|
||||
class EventBuilder(object):
|
||||
"""A format independent event builder used to build up the event content
|
||||
before signing the event.
|
||||
|
||||
Args:
|
||||
room_version (str): Version of the room that we're creating an
|
||||
event builder for
|
||||
key_values (dict): Fields used as the basis of the new event
|
||||
internal_metadata_dict (dict): Used to create the `_EventInternalMetadata`
|
||||
object.
|
||||
(Note that while objects of this class are frozen, the
|
||||
content/unsigned/internal_metadata fields are still mutable)
|
||||
|
||||
Returns:
|
||||
EventBuilder
|
||||
Attributes:
|
||||
format_version (int): Event format version
|
||||
room_id (str)
|
||||
type (str)
|
||||
sender (str)
|
||||
content (dict)
|
||||
unsigned (dict)
|
||||
internal_metadata (_EventInternalMetadata)
|
||||
|
||||
_state (StateHandler)
|
||||
_auth (synapse.api.Auth)
|
||||
_store (DataStore)
|
||||
_clock (Clock)
|
||||
_hostname (str): The hostname of the server creating the event
|
||||
_signing_key: The signing key to use to sign the event as the server
|
||||
"""
|
||||
if room_version in {
|
||||
RoomVersions.V1,
|
||||
RoomVersions.V2,
|
||||
RoomVersions.VDH_TEST,
|
||||
RoomVersions.STATE_V2_TEST,
|
||||
}:
|
||||
return EventBuilder(key_values, internal_metadata_dict)
|
||||
else:
|
||||
raise Exception(
|
||||
"No event format defined for version %r" % (room_version,)
|
||||
|
||||
_state = attr.ib()
|
||||
_auth = attr.ib()
|
||||
_store = attr.ib()
|
||||
_clock = attr.ib()
|
||||
_hostname = attr.ib()
|
||||
_signing_key = attr.ib()
|
||||
|
||||
format_version = attr.ib()
|
||||
|
||||
room_id = attr.ib()
|
||||
type = attr.ib()
|
||||
sender = attr.ib()
|
||||
|
||||
content = attr.ib(default=attr.Factory(dict))
|
||||
unsigned = attr.ib(default=attr.Factory(dict))
|
||||
|
||||
# These only exist on a subset of events, so they raise AttributeError if
|
||||
# someone tries to get them when they don't exist.
|
||||
_state_key = attr.ib(default=None)
|
||||
_redacts = attr.ib(default=None)
|
||||
|
||||
internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
|
||||
|
||||
@property
|
||||
def state_key(self):
|
||||
if self._state_key is not None:
|
||||
return self._state_key
|
||||
|
||||
raise AttributeError("state_key")
|
||||
|
||||
def is_state(self):
|
||||
return self._state_key is not None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def build(self, prev_event_ids):
|
||||
"""Transform into a fully signed and hashed event
|
||||
|
||||
Args:
|
||||
prev_event_ids (list[str]): The event IDs to use as the prev events
|
||||
|
||||
Returns:
|
||||
Deferred[FrozenEvent]
|
||||
"""
|
||||
|
||||
state_ids = yield self._state.get_current_state_ids(
|
||||
self.room_id, prev_event_ids,
|
||||
)
|
||||
auth_ids = yield self._auth.compute_auth_events(
|
||||
self, state_ids,
|
||||
)
|
||||
|
||||
if self.format_version == EventFormatVersions.V1:
|
||||
auth_events = yield self._store.add_event_hashes(auth_ids)
|
||||
prev_events = yield self._store.add_event_hashes(prev_event_ids)
|
||||
else:
|
||||
auth_events = auth_ids
|
||||
prev_events = prev_event_ids
|
||||
|
||||
class EventBuilder(EventBase):
|
||||
def __init__(self, key_values={}, internal_metadata_dict={}):
|
||||
signatures = copy.deepcopy(key_values.pop("signatures", {}))
|
||||
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
|
||||
|
||||
super(EventBuilder, self).__init__(
|
||||
key_values,
|
||||
signatures=signatures,
|
||||
unsigned=unsigned,
|
||||
internal_metadata_dict=internal_metadata_dict,
|
||||
old_depth = yield self._store.get_max_depth_of(
|
||||
prev_event_ids,
|
||||
)
|
||||
depth = old_depth + 1
|
||||
|
||||
event_id = _event_dict_property("event_id")
|
||||
state_key = _event_dict_property("state_key")
|
||||
type = _event_dict_property("type")
|
||||
# we cap depth of generated events, to ensure that they are not
|
||||
# rejected by other servers (and so that they can be persisted in
|
||||
# the db)
|
||||
depth = min(depth, MAX_DEPTH)
|
||||
|
||||
def build(self):
|
||||
return FrozenEvent.from_event(self)
|
||||
event_dict = {
|
||||
"auth_events": auth_events,
|
||||
"prev_events": prev_events,
|
||||
"type": self.type,
|
||||
"room_id": self.room_id,
|
||||
"sender": self.sender,
|
||||
"content": self.content,
|
||||
"unsigned": self.unsigned,
|
||||
"depth": depth,
|
||||
"prev_state": [],
|
||||
}
|
||||
|
||||
if self.is_state():
|
||||
event_dict["state_key"] = self._state_key
|
||||
|
||||
if self._redacts is not None:
|
||||
event_dict["redacts"] = self._redacts
|
||||
|
||||
defer.returnValue(
|
||||
create_local_event_from_event_dict(
|
||||
clock=self._clock,
|
||||
hostname=self._hostname,
|
||||
signing_key=self._signing_key,
|
||||
format_version=self.format_version,
|
||||
event_dict=event_dict,
|
||||
internal_metadata_dict=self.internal_metadata.get_dict(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class EventBuilderFactory(object):
|
||||
def __init__(self, clock, hostname):
|
||||
self.clock = clock
|
||||
self.hostname = hostname
|
||||
def __init__(self, hs):
|
||||
self.clock = hs.get_clock()
|
||||
self.hostname = hs.hostname
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
|
||||
self.event_id_count = 0
|
||||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
def create_event_id(self):
|
||||
i = str(self.event_id_count)
|
||||
self.event_id_count += 1
|
||||
|
||||
local_part = str(int(self.clock.time())) + i + random_string(5)
|
||||
|
||||
e_id = EventID(local_part, self.hostname)
|
||||
|
||||
return e_id.to_string()
|
||||
|
||||
def new(self, room_version, key_values={}):
|
||||
def new(self, room_version, key_values):
|
||||
"""Generate an event builder appropriate for the given room version
|
||||
|
||||
Args:
|
||||
|
@ -98,27 +180,103 @@ class EventBuilderFactory(object):
|
|||
"""
|
||||
|
||||
# There's currently only the one event version defined
|
||||
if room_version not in {
|
||||
RoomVersions.V1,
|
||||
RoomVersions.V2,
|
||||
RoomVersions.VDH_TEST,
|
||||
RoomVersions.STATE_V2_TEST,
|
||||
}:
|
||||
if room_version not in KNOWN_ROOM_VERSIONS:
|
||||
raise Exception(
|
||||
"No event format defined for version %r" % (room_version,)
|
||||
)
|
||||
|
||||
key_values["event_id"] = self.create_event_id()
|
||||
return EventBuilder(
|
||||
store=self.store,
|
||||
state=self.state,
|
||||
auth=self.auth,
|
||||
clock=self.clock,
|
||||
hostname=self.hostname,
|
||||
signing_key=self.signing_key,
|
||||
format_version=room_version_to_event_format(room_version),
|
||||
type=key_values["type"],
|
||||
state_key=key_values.get("state_key"),
|
||||
room_id=key_values["room_id"],
|
||||
sender=key_values["sender"],
|
||||
content=key_values.get("content", {}),
|
||||
unsigned=key_values.get("unsigned", {}),
|
||||
redacts=key_values.get("redacts", None),
|
||||
)
|
||||
|
||||
time_now = int(self.clock.time_msec())
|
||||
|
||||
key_values.setdefault("origin", self.hostname)
|
||||
key_values.setdefault("origin_server_ts", time_now)
|
||||
def create_local_event_from_event_dict(clock, hostname, signing_key,
|
||||
format_version, event_dict,
|
||||
internal_metadata_dict=None):
|
||||
"""Takes a fully formed event dict, ensuring that fields like `origin`
|
||||
and `origin_server_ts` have correct values for a locally produced event,
|
||||
then signs and hashes it.
|
||||
|
||||
key_values.setdefault("unsigned", {})
|
||||
age = key_values["unsigned"].pop("age", 0)
|
||||
key_values["unsigned"].setdefault("age_ts", time_now - age)
|
||||
Args:
|
||||
clock (Clock)
|
||||
hostname (str)
|
||||
signing_key
|
||||
format_version (int)
|
||||
event_dict (dict)
|
||||
internal_metadata_dict (dict|None)
|
||||
|
||||
key_values["signatures"] = {}
|
||||
Returns:
|
||||
FrozenEvent
|
||||
"""
|
||||
|
||||
return EventBuilder(key_values=key_values,)
|
||||
# There's currently only the one event version defined
|
||||
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
|
||||
raise Exception(
|
||||
"No event format defined for version %r" % (format_version,)
|
||||
)
|
||||
|
||||
if internal_metadata_dict is None:
|
||||
internal_metadata_dict = {}
|
||||
|
||||
time_now = int(clock.time_msec())
|
||||
|
||||
if format_version == EventFormatVersions.V1:
|
||||
event_dict["event_id"] = _create_event_id(clock, hostname)
|
||||
|
||||
event_dict["origin"] = hostname
|
||||
event_dict["origin_server_ts"] = time_now
|
||||
|
||||
event_dict.setdefault("unsigned", {})
|
||||
age = event_dict["unsigned"].pop("age", 0)
|
||||
event_dict["unsigned"].setdefault("age_ts", time_now - age)
|
||||
|
||||
event_dict.setdefault("signatures", {})
|
||||
|
||||
add_hashes_and_signatures(
|
||||
event_dict,
|
||||
hostname,
|
||||
signing_key,
|
||||
)
|
||||
return event_type_from_format_version(format_version)(
|
||||
event_dict, internal_metadata_dict=internal_metadata_dict,
|
||||
)
|
||||
|
||||
|
||||
# A counter used when generating new event IDs
|
||||
_event_id_counter = 0
|
||||
|
||||
|
||||
def _create_event_id(clock, hostname):
|
||||
"""Create a new event ID
|
||||
|
||||
Args:
|
||||
clock (Clock)
|
||||
hostname (str): The server name for the event ID
|
||||
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
|
||||
global _event_id_counter
|
||||
|
||||
i = str(_event_id_counter)
|
||||
_event_id_counter += 1
|
||||
|
||||
local_part = str(int(clock.time())) + i + random_string(5)
|
||||
|
||||
e_id = EventID(local_part, hostname)
|
||||
|
||||
return e_id.to_string()
|
||||
|
|
|
@ -267,6 +267,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
|
|||
Returns:
|
||||
dict
|
||||
"""
|
||||
|
||||
# FIXME(erikj): To handle the case of presence events and the like
|
||||
if not isinstance(e, EventBase):
|
||||
return e
|
||||
|
@ -276,6 +277,8 @@ def serialize_event(e, time_now_ms, as_client_event=True,
|
|||
# Should this strip out None's?
|
||||
d = {k: v for k, v in e.get_dict().items()}
|
||||
|
||||
d["event_id"] = e.event_id
|
||||
|
||||
if "age_ts" in d["unsigned"]:
|
||||
d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"]
|
||||
del d["unsigned"]["age_ts"]
|
||||
|
|
|
@ -15,23 +15,29 @@
|
|||
|
||||
from six import string_types
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.constants import EventFormatVersions, EventTypes, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import EventID, RoomID, UserID
|
||||
|
||||
|
||||
class EventValidator(object):
|
||||
def validate_new(self, event):
|
||||
"""Validates the event has roughly the right format
|
||||
|
||||
def validate(self, event):
|
||||
EventID.from_string(event.event_id)
|
||||
RoomID.from_string(event.room_id)
|
||||
Args:
|
||||
event (FrozenEvent)
|
||||
"""
|
||||
self.validate_builder(event)
|
||||
|
||||
if event.format_version == EventFormatVersions.V1:
|
||||
EventID.from_string(event.event_id)
|
||||
|
||||
required = [
|
||||
# "auth_events",
|
||||
"auth_events",
|
||||
"content",
|
||||
# "hashes",
|
||||
"hashes",
|
||||
"origin",
|
||||
# "prev_events",
|
||||
"prev_events",
|
||||
"sender",
|
||||
"type",
|
||||
]
|
||||
|
@ -41,8 +47,25 @@ class EventValidator(object):
|
|||
raise SynapseError(400, "Event does not have key %s" % (k,))
|
||||
|
||||
# Check that the following keys have string values
|
||||
strings = [
|
||||
event_strings = [
|
||||
"origin",
|
||||
]
|
||||
|
||||
for s in event_strings:
|
||||
if not isinstance(getattr(event, s), string_types):
|
||||
raise SynapseError(400, "'%s' not a string type" % (s,))
|
||||
|
||||
def validate_builder(self, event):
|
||||
"""Validates that the builder/event has roughly the right format. Only
|
||||
checks values that we expect a proto event to have, rather than all the
|
||||
fields an event would have
|
||||
|
||||
Args:
|
||||
event (EventBuilder|FrozenEvent)
|
||||
"""
|
||||
|
||||
strings = [
|
||||
"room_id",
|
||||
"sender",
|
||||
"type",
|
||||
]
|
||||
|
@ -54,22 +77,7 @@ class EventValidator(object):
|
|||
if not isinstance(getattr(event, s), string_types):
|
||||
raise SynapseError(400, "Not '%s' a string type" % (s,))
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if "membership" not in event.content:
|
||||
raise SynapseError(400, "Content has not membership key")
|
||||
|
||||
if event.content["membership"] not in Membership.LIST:
|
||||
raise SynapseError(400, "Invalid membership key")
|
||||
|
||||
# Check that the following keys have dictionary values
|
||||
# TODO
|
||||
|
||||
# Check that the following keys have the correct format for DAGs
|
||||
# TODO
|
||||
|
||||
def validate_new(self, event):
|
||||
self.validate(event)
|
||||
|
||||
RoomID.from_string(event.room_id)
|
||||
UserID.from_string(event.sender)
|
||||
|
||||
if event.type == EventTypes.Message:
|
||||
|
@ -86,9 +94,16 @@ class EventValidator(object):
|
|||
elif event.type == EventTypes.Name:
|
||||
self._ensure_strings(event.content, ["name"])
|
||||
|
||||
elif event.type == EventTypes.Member:
|
||||
if "membership" not in event.content:
|
||||
raise SynapseError(400, "Content has not membership key")
|
||||
|
||||
if event.content["membership"] not in Membership.LIST:
|
||||
raise SynapseError(400, "Invalid membership key")
|
||||
|
||||
def _ensure_strings(self, d, keys):
|
||||
for s in keys:
|
||||
if s not in d:
|
||||
raise SynapseError(400, "'%s' not in content" % (s,))
|
||||
if not isinstance(d[s], string_types):
|
||||
raise SynapseError(400, "Not '%s' a string type" % (s,))
|
||||
raise SynapseError(400, "'%s' not a string type" % (s,))
|
||||
|
|
|
@ -20,7 +20,7 @@ import six
|
|||
from twisted.internet import defer
|
||||
from twisted.internet.defer import DeferredList
|
||||
|
||||
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
|
||||
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
from synapse.events import event_type_from_format_version
|
||||
|
@ -66,7 +66,7 @@ class FederationBase(object):
|
|||
Returns:
|
||||
Deferred : A list of PDUs that have valid signatures and hashes.
|
||||
"""
|
||||
deferreds = self._check_sigs_and_hashes(pdus)
|
||||
deferreds = self._check_sigs_and_hashes(room_version, pdus)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_check_result(pdu, deferred):
|
||||
|
@ -121,16 +121,17 @@ class FederationBase(object):
|
|||
else:
|
||||
defer.returnValue([p for p in valid_pdus if p])
|
||||
|
||||
def _check_sigs_and_hash(self, pdu):
|
||||
def _check_sigs_and_hash(self, room_version, pdu):
|
||||
return logcontext.make_deferred_yieldable(
|
||||
self._check_sigs_and_hashes([pdu])[0],
|
||||
self._check_sigs_and_hashes(room_version, [pdu])[0],
|
||||
)
|
||||
|
||||
def _check_sigs_and_hashes(self, pdus):
|
||||
def _check_sigs_and_hashes(self, room_version, pdus):
|
||||
"""Checks that each of the received events is correctly signed by the
|
||||
sending server.
|
||||
|
||||
Args:
|
||||
room_version (str): The room version of the PDUs
|
||||
pdus (list[FrozenEvent]): the events to be checked
|
||||
|
||||
Returns:
|
||||
|
@ -141,7 +142,7 @@ class FederationBase(object):
|
|||
* throws a SynapseError if the signature check failed.
|
||||
The deferreds run their callbacks in the sentinel logcontext.
|
||||
"""
|
||||
deferreds = _check_sigs_on_pdus(self.keyring, pdus)
|
||||
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
|
||||
|
||||
ctx = logcontext.LoggingContext.current_context()
|
||||
|
||||
|
@ -203,16 +204,17 @@ class FederationBase(object):
|
|||
|
||||
|
||||
class PduToCheckSig(namedtuple("PduToCheckSig", [
|
||||
"pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds",
|
||||
"pdu", "redacted_pdu_json", "sender_domain", "deferreds",
|
||||
])):
|
||||
pass
|
||||
|
||||
|
||||
def _check_sigs_on_pdus(keyring, pdus):
|
||||
def _check_sigs_on_pdus(keyring, room_version, pdus):
|
||||
"""Check that the given events are correctly signed
|
||||
|
||||
Args:
|
||||
keyring (synapse.crypto.Keyring): keyring object to do the checks
|
||||
room_version (str): the room version of the PDUs
|
||||
pdus (Collection[EventBase]): the events to be checked
|
||||
|
||||
Returns:
|
||||
|
@ -225,9 +227,7 @@ def _check_sigs_on_pdus(keyring, pdus):
|
|||
|
||||
# we want to check that the event is signed by:
|
||||
#
|
||||
# (a) the server which created the event_id
|
||||
#
|
||||
# (b) the sender's server.
|
||||
# (a) the sender's server
|
||||
#
|
||||
# - except in the case of invites created from a 3pid invite, which are exempt
|
||||
# from this check, because the sender has to match that of the original 3pid
|
||||
|
@ -241,34 +241,26 @@ def _check_sigs_on_pdus(keyring, pdus):
|
|||
# and signatures are *supposed* to be valid whether or not an event has been
|
||||
# redacted. But this isn't the worst of the ways that 3pid invites are broken.
|
||||
#
|
||||
# (b) for V1 and V2 rooms, the server which created the event_id
|
||||
#
|
||||
# let's start by getting the domain for each pdu, and flattening the event back
|
||||
# to JSON.
|
||||
|
||||
pdus_to_check = [
|
||||
PduToCheckSig(
|
||||
pdu=p,
|
||||
redacted_pdu_json=prune_event(p).get_pdu_json(),
|
||||
event_id_domain=get_domain_from_id(p.event_id),
|
||||
sender_domain=get_domain_from_id(p.sender),
|
||||
deferreds=[],
|
||||
)
|
||||
for p in pdus
|
||||
]
|
||||
|
||||
# first make sure that the event is signed by the event_id's domain
|
||||
deferreds = keyring.verify_json_objects_for_server([
|
||||
(p.event_id_domain, p.redacted_pdu_json)
|
||||
for p in pdus_to_check
|
||||
])
|
||||
|
||||
for p, d in zip(pdus_to_check, deferreds):
|
||||
p.deferreds.append(d)
|
||||
|
||||
# now let's look for events where the sender's domain is different to the
|
||||
# event id's domain (normally only the case for joins/leaves), and add additional
|
||||
# checks.
|
||||
# First we check that the sender event is signed by the sender's domain
|
||||
# (except if its a 3pid invite, in which case it may be sent by any server)
|
||||
pdus_to_check_sender = [
|
||||
p for p in pdus_to_check
|
||||
if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu)
|
||||
if not _is_invite_via_3pid(p.pdu)
|
||||
]
|
||||
|
||||
more_deferreds = keyring.verify_json_objects_for_server([
|
||||
|
@ -279,19 +271,43 @@ def _check_sigs_on_pdus(keyring, pdus):
|
|||
for p, d in zip(pdus_to_check_sender, more_deferreds):
|
||||
p.deferreds.append(d)
|
||||
|
||||
# now let's look for events where the sender's domain is different to the
|
||||
# event id's domain (normally only the case for joins/leaves), and add additional
|
||||
# checks. Only do this if the room version has a concept of event ID domain
|
||||
if room_version in (
|
||||
RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
|
||||
):
|
||||
pdus_to_check_event_id = [
|
||||
p for p in pdus_to_check
|
||||
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
|
||||
]
|
||||
|
||||
more_deferreds = keyring.verify_json_objects_for_server([
|
||||
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
|
||||
for p in pdus_to_check_event_id
|
||||
])
|
||||
|
||||
for p, d in zip(pdus_to_check_event_id, more_deferreds):
|
||||
p.deferreds.append(d)
|
||||
elif room_version in (RoomVersions.V3,):
|
||||
pass # No further checks needed, as event IDs are hashes here
|
||||
else:
|
||||
raise RuntimeError("Unrecognized room version %s" % (room_version,))
|
||||
|
||||
# replace lists of deferreds with single Deferreds
|
||||
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
|
||||
|
||||
|
||||
def _flatten_deferred_list(deferreds):
|
||||
"""Given a list of one or more deferreds, either return the single deferred, or
|
||||
combine into a DeferredList.
|
||||
"""Given a list of deferreds, either return the single deferred,
|
||||
combine into a DeferredList, or return an already resolved deferred.
|
||||
"""
|
||||
if len(deferreds) > 1:
|
||||
return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
|
||||
else:
|
||||
assert len(deferreds) == 1
|
||||
elif len(deferreds) == 1:
|
||||
return deferreds[0]
|
||||
else:
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
def _is_invite_via_3pid(event):
|
||||
|
@ -319,7 +335,7 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
|
|||
"""
|
||||
# we could probably enforce a bunch of other fields here (room_id, sender,
|
||||
# origin, etc etc)
|
||||
assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
|
||||
assert_params_in_dict(pdu_json, ('type', 'depth'))
|
||||
|
||||
depth = pdu_json['depth']
|
||||
if not isinstance(depth, six.integer_types):
|
||||
|
|
|
@ -37,8 +37,7 @@ from synapse.api.errors import (
|
|||
HttpResponseException,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.events import room_version_to_event_format
|
||||
from synapse.events import builder, room_version_to_event_format
|
||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
@ -72,7 +71,8 @@ class FederationClient(FederationBase):
|
|||
self.state = hs.get_state_handler()
|
||||
self.transport_layer = hs.get_federation_transport_client()
|
||||
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.hostname = hs.hostname
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
|
||||
self._get_pdu_cache = ExpiringCache(
|
||||
cache_name="get_pdu_cache",
|
||||
|
@ -205,7 +205,7 @@ class FederationClient(FederationBase):
|
|||
|
||||
# FIXME: We should handle signature failures more gracefully.
|
||||
pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
self._check_sigs_and_hashes(pdus),
|
||||
self._check_sigs_and_hashes(room_version, pdus),
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError))
|
||||
|
||||
|
@ -268,7 +268,7 @@ class FederationClient(FederationBase):
|
|||
pdu = pdu_list[0]
|
||||
|
||||
# Check signatures are correct.
|
||||
signed_pdu = yield self._check_sigs_and_hash(pdu)
|
||||
signed_pdu = yield self._check_sigs_and_hash(room_version, pdu)
|
||||
|
||||
break
|
||||
|
||||
|
@ -608,18 +608,10 @@ class FederationClient(FederationBase):
|
|||
if "prev_state" not in pdu_dict:
|
||||
pdu_dict["prev_state"] = []
|
||||
|
||||
# Strip off the fields that we want to clobber.
|
||||
pdu_dict.pop("origin", None)
|
||||
pdu_dict.pop("origin_server_ts", None)
|
||||
pdu_dict.pop("unsigned", None)
|
||||
|
||||
builder = self.event_builder_factory.new(room_version, pdu_dict)
|
||||
add_hashes_and_signatures(
|
||||
builder,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
ev = builder.create_local_event_from_event_dict(
|
||||
self._clock, self.hostname, self.signing_key,
|
||||
format_version=event_format, event_dict=pdu_dict,
|
||||
)
|
||||
ev = builder.build()
|
||||
|
||||
defer.returnValue(
|
||||
(destination, ev, event_format)
|
||||
|
@ -751,18 +743,9 @@ class FederationClient(FederationBase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def send_invite(self, destination, room_id, event_id, pdu):
|
||||
time_now = self._clock.time_msec()
|
||||
try:
|
||||
code, content = yield self.transport_layer.send_invite(
|
||||
destination=destination,
|
||||
room_id=room_id,
|
||||
event_id=event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
if e.code == 403:
|
||||
raise e.to_synapse_error()
|
||||
raise
|
||||
room_version = yield self.store.get_room_version(room_id)
|
||||
|
||||
content = yield self._do_send_invite(destination, pdu, room_version)
|
||||
|
||||
pdu_dict = content["event"]
|
||||
|
||||
|
@ -774,12 +757,61 @@ class FederationClient(FederationBase):
|
|||
pdu = event_from_pdu_json(pdu_dict, format_ver)
|
||||
|
||||
# Check signatures are correct.
|
||||
pdu = yield self._check_sigs_and_hash(pdu)
|
||||
pdu = yield self._check_sigs_and_hash(room_version, pdu)
|
||||
|
||||
# FIXME: We should handle signature failures more gracefully.
|
||||
|
||||
defer.returnValue(pdu)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_send_invite(self, destination, pdu, room_version):
|
||||
"""Actually sends the invite, first trying v2 API and falling back to
|
||||
v1 API if necessary.
|
||||
|
||||
Args:
|
||||
destination (str): Target server
|
||||
pdu (FrozenEvent)
|
||||
room_version (str)
|
||||
|
||||
Returns:
|
||||
dict: The event as a dict as returned by the remote server
|
||||
"""
|
||||
time_now = self._clock.time_msec()
|
||||
|
||||
try:
|
||||
content = yield self.transport_layer.send_invite_v2(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
content={
|
||||
"event": pdu.get_pdu_json(time_now),
|
||||
"room_version": room_version,
|
||||
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
|
||||
},
|
||||
)
|
||||
defer.returnValue(content)
|
||||
except HttpResponseException as e:
|
||||
if e.code in [400, 404]:
|
||||
if room_version in (RoomVersions.V1, RoomVersions.V2):
|
||||
pass # We'll fall through
|
||||
else:
|
||||
raise Exception("Remote server is too old")
|
||||
elif e.code == 403:
|
||||
raise e.to_synapse_error()
|
||||
else:
|
||||
raise
|
||||
|
||||
# Didn't work, try v1 API.
|
||||
# Note the v1 API returns a tuple of `(200, content)`
|
||||
|
||||
_, content = yield self.transport_layer.send_invite_v1(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
defer.returnValue(content)
|
||||
|
||||
def send_leave(self, destinations, pdu):
|
||||
"""Sends a leave event to one of a list of homeservers.
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ from twisted.internet import defer
|
|||
from twisted.internet.abstract import isIPAddress
|
||||
from twisted.python import failure
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
FederationError,
|
||||
|
@ -322,7 +322,7 @@ class FederationServer(FederationBase):
|
|||
if self.hs.is_mine_id(event.event_id):
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
event.get_pdu_json(),
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
|
@ -620,16 +620,19 @@ class FederationServer(FederationBase):
|
|||
"""
|
||||
# check that it's actually being sent from a valid destination to
|
||||
# workaround bug #1753 in 0.18.5 and 0.18.6
|
||||
if origin != get_domain_from_id(pdu.event_id):
|
||||
if origin != get_domain_from_id(pdu.sender):
|
||||
# We continue to accept join events from any server; this is
|
||||
# necessary for the federation join dance to work correctly.
|
||||
# (When we join over federation, the "helper" server is
|
||||
# responsible for sending out the join event, rather than the
|
||||
# origin. See bug #1893).
|
||||
# origin. See bug #1893. This is also true for some third party
|
||||
# invites).
|
||||
if not (
|
||||
pdu.type == 'm.room.member' and
|
||||
pdu.content and
|
||||
pdu.content.get("membership", None) == 'join'
|
||||
pdu.content.get("membership", None) in (
|
||||
Membership.JOIN, Membership.INVITE,
|
||||
)
|
||||
):
|
||||
logger.info(
|
||||
"Discarding PDU %s from invalid origin %s",
|
||||
|
@ -642,9 +645,12 @@ class FederationServer(FederationBase):
|
|||
pdu.event_id, origin
|
||||
)
|
||||
|
||||
# We've already checked that we know the room version by this point
|
||||
room_version = yield self.store.get_room_version(pdu.room_id)
|
||||
|
||||
# Check signature.
|
||||
try:
|
||||
pdu = yield self._check_sigs_and_hash(pdu)
|
||||
pdu = yield self._check_sigs_and_hash(room_version, pdu)
|
||||
except SynapseError as e:
|
||||
raise FederationError(
|
||||
"ERROR",
|
||||
|
|
|
@ -175,7 +175,7 @@ class TransactionQueue(object):
|
|||
def handle_event(event):
|
||||
# Only send events for this server.
|
||||
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
|
||||
is_mine = self.is_mine_id(event.event_id)
|
||||
is_mine = self.is_mine_id(event.sender)
|
||||
if not is_mine and send_on_behalf_of is None:
|
||||
return
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from six.moves import urllib
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.urls import FEDERATION_V1_PREFIX
|
||||
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -289,7 +289,7 @@ class TransportLayerClient(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_invite(self, destination, room_id, event_id, content):
|
||||
def send_invite_v1(self, destination, room_id, event_id, content):
|
||||
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
|
@ -301,6 +301,20 @@ class TransportLayerClient(object):
|
|||
|
||||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_invite_v2(self, destination, room_id, event_id, content):
|
||||
path = _create_v2_path("/invite/%s/%s", room_id, event_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_public_rooms(self, remote_server, limit, since_token,
|
||||
|
@ -958,3 +972,24 @@ def _create_v1_path(path, *args):
|
|||
FEDERATION_V1_PREFIX
|
||||
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
|
||||
)
|
||||
|
||||
|
||||
def _create_v2_path(path, *args):
|
||||
"""Creates a path against V2 federation API from the path template and
|
||||
args. Ensures that all args are url encoded.
|
||||
|
||||
Example:
|
||||
|
||||
_create_v2_path("/event/%s/", event_id)
|
||||
|
||||
Args:
|
||||
path (str): String template for the path
|
||||
args: ([str]): Args to insert into path. Each arg will be url encoded
|
||||
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
return (
|
||||
FEDERATION_V2_PREFIX
|
||||
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
|
||||
)
|
||||
|
|
|
@ -57,8 +57,8 @@ class DirectoryHandler(BaseHandler):
|
|||
# general association creation for both human users and app services
|
||||
|
||||
for wchar in string.whitespace:
|
||||
if wchar in room_alias.localpart:
|
||||
raise SynapseError(400, "Invalid characters in room alias")
|
||||
if wchar in room_alias.localpart:
|
||||
raise SynapseError(400, "Invalid characters in room alias")
|
||||
|
||||
if not self.hs.is_mine(room_alias):
|
||||
raise SynapseError(400, "Room alias must be local")
|
||||
|
|
|
@ -102,7 +102,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
self.hs = hs
|
||||
|
||||
self.store = hs.get_datastore() # type: synapse.storage.DataStore
|
||||
self.store = hs.get_datastore()
|
||||
self.federation_client = hs.get_federation_client()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.server_name = hs.hostname
|
||||
|
@ -1300,7 +1300,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
event.get_pdu_json(),
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
|
@ -2282,7 +2282,7 @@ class FederationHandler(BaseHandler):
|
|||
room_version = yield self.store.get_room_version(room_id)
|
||||
builder = self.event_builder_factory.new(room_version, event_dict)
|
||||
|
||||
EventValidator().validate_new(builder)
|
||||
EventValidator().validate_builder(builder)
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder=builder
|
||||
)
|
||||
|
@ -2291,6 +2291,12 @@ class FederationHandler(BaseHandler):
|
|||
room_version, event_dict, event, context
|
||||
)
|
||||
|
||||
EventValidator().validate_new(event)
|
||||
|
||||
# We need to tell the transaction queue to send this out, even
|
||||
# though the sender isn't a local user.
|
||||
event.internal_metadata.send_on_behalf_of = self.hs.hostname
|
||||
|
||||
try:
|
||||
yield self.auth.check_from_context(room_version, event, context)
|
||||
except AuthError as e:
|
||||
|
@ -2340,6 +2346,10 @@ class FederationHandler(BaseHandler):
|
|||
raise e
|
||||
yield self._check_signature(event, context)
|
||||
|
||||
# We need to tell the transaction queue to send this out, even
|
||||
# though the sender isn't a local user.
|
||||
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
|
||||
|
||||
# XXX we send the invite here, but send_membership_event also sends it,
|
||||
# so we end up making two requests. I think this is redundant.
|
||||
returned_invite = yield self.send_invite(origin, event)
|
||||
|
@ -2376,10 +2386,11 @@ class FederationHandler(BaseHandler):
|
|||
# auth check code will explode appropriately.
|
||||
|
||||
builder = self.event_builder_factory.new(room_version, event_dict)
|
||||
EventValidator().validate_new(builder)
|
||||
EventValidator().validate_builder(builder)
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
EventValidator().validate_new(event)
|
||||
defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
|
|||
from twisted.internet import defer
|
||||
from twisted.internet.defer import succeed
|
||||
|
||||
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
|
||||
from synapse.api.constants import EventTypes, Membership, RoomVersions
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
|
@ -31,7 +31,6 @@ from synapse.api.errors import (
|
|||
SynapseError,
|
||||
)
|
||||
from synapse.api.urls import ConsentURIBuilder
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||
|
@ -288,7 +287,7 @@ class EventCreationHandler(object):
|
|||
|
||||
builder = self.event_builder_factory.new(room_version, event_dict)
|
||||
|
||||
self.validator.validate_new(builder)
|
||||
self.validator.validate_builder(builder)
|
||||
|
||||
if builder.type == EventTypes.Member:
|
||||
membership = builder.content.get("membership", None)
|
||||
|
@ -326,6 +325,8 @@ class EventCreationHandler(object):
|
|||
prev_events_and_hashes=prev_events_and_hashes,
|
||||
)
|
||||
|
||||
self.validator.validate_new(event)
|
||||
|
||||
defer.returnValue((event, context))
|
||||
|
||||
def _is_exempt_from_privacy_policy(self, builder, requester):
|
||||
|
@ -543,40 +544,19 @@ class EventCreationHandler(object):
|
|||
prev_events_and_hashes = \
|
||||
yield self.store.get_prev_events_for_room(builder.room_id)
|
||||
|
||||
if prev_events_and_hashes:
|
||||
depth = max([d for _, _, d in prev_events_and_hashes]) + 1
|
||||
# we cap depth of generated events, to ensure that they are not
|
||||
# rejected by other servers (and so that they can be persisted in
|
||||
# the db)
|
||||
depth = min(depth, MAX_DEPTH)
|
||||
else:
|
||||
depth = 1
|
||||
|
||||
prev_events = [
|
||||
(event_id, prev_hashes)
|
||||
for event_id, prev_hashes, _ in prev_events_and_hashes
|
||||
]
|
||||
|
||||
builder.prev_events = prev_events
|
||||
builder.depth = depth
|
||||
|
||||
context = yield self.state.compute_event_context(builder)
|
||||
event = yield builder.build(
|
||||
prev_event_ids=[p for p, _ in prev_events],
|
||||
)
|
||||
context = yield self.state.compute_event_context(event)
|
||||
if requester:
|
||||
context.app_service = requester.app_service
|
||||
|
||||
if builder.is_state():
|
||||
builder.prev_state = yield self.store.add_event_hashes(
|
||||
context.prev_state_events
|
||||
)
|
||||
|
||||
yield self.auth.add_auth_events(builder, context)
|
||||
|
||||
signing_key = self.hs.config.signing_key[0]
|
||||
add_hashes_and_signatures(
|
||||
builder, self.server_name, signing_key
|
||||
)
|
||||
|
||||
event = builder.build()
|
||||
self.validator.validate_new(event)
|
||||
|
||||
logger.debug(
|
||||
"Created event %s",
|
||||
|
@ -765,7 +745,8 @@ class EventCreationHandler(object):
|
|||
auth_events = {
|
||||
(e.type, e.state_key): e for e in auth_events.values()
|
||||
}
|
||||
if self.auth.check_redaction(event, auth_events=auth_events):
|
||||
room_version = yield self.store.get_room_version(event.room_id)
|
||||
if self.auth.check_redaction(room_version, event, auth_events=auth_events):
|
||||
original_event = yield self.store.get_event(
|
||||
event.redacts,
|
||||
check_redacted=False,
|
||||
|
@ -779,6 +760,9 @@ class EventCreationHandler(object):
|
|||
"You don't have permission to redact events"
|
||||
)
|
||||
|
||||
# We've already checked.
|
||||
event.internal_metadata.recheck_redaction = False
|
||||
|
||||
if event.type == EventTypes.Create:
|
||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||
if prev_state_ids:
|
||||
|
|
|
@ -960,7 +960,8 @@ class RoomMemberHandler(object):
|
|||
# first member event?
|
||||
create_event_id = current_state_ids.get(("m.room.create", ""))
|
||||
if len(current_state_ids) == 1 and create_event_id:
|
||||
defer.returnValue(self.hs.is_mine_id(create_event_id))
|
||||
# We can only get here if we're in the process of creating the room
|
||||
defer.returnValue(True)
|
||||
|
||||
for etype, state_key in current_state_ids:
|
||||
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
|
||||
|
|
|
@ -12,7 +12,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
|
||||
import attr
|
||||
from netaddr import IPAddress
|
||||
|
@ -20,14 +23,30 @@ from zope.interface import implementer
|
|||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.web.client import URI, Agent, HTTPConnectionPool
|
||||
from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody
|
||||
from twisted.web.http import stringToDatetime
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import IAgent
|
||||
|
||||
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
|
||||
from synapse.util.caches.ttlcache import TTLCache
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
# period to cache .well-known results for by default
|
||||
WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
|
||||
|
||||
# jitter to add to the .well-known default cache ttl
|
||||
WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
|
||||
|
||||
# period to cache failure to fetch .well-known for
|
||||
WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
|
||||
|
||||
# cap for .well-known cache period
|
||||
WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
well_known_cache = TTLCache('well-known')
|
||||
|
||||
|
||||
@implementer(IAgent)
|
||||
|
@ -43,13 +62,20 @@ class MatrixFederationAgent(object):
|
|||
tls_client_options_factory (ClientTLSOptionsFactory|None):
|
||||
factory to use for fetching client tls options, or none to disable TLS.
|
||||
|
||||
_well_known_tls_policy (IPolicyForHTTPS|None):
|
||||
TLS policy to use for fetching .well-known files. None to use a default
|
||||
(browser-like) implementation.
|
||||
|
||||
srv_resolver (SrvResolver|None):
|
||||
SRVResolver impl to use for looking up SRV records. None to use a default
|
||||
implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, reactor, tls_client_options_factory, _srv_resolver=None,
|
||||
self, reactor, tls_client_options_factory,
|
||||
_well_known_tls_policy=None,
|
||||
_srv_resolver=None,
|
||||
_well_known_cache=well_known_cache,
|
||||
):
|
||||
self._reactor = reactor
|
||||
self._tls_client_options_factory = tls_client_options_factory
|
||||
|
@ -62,6 +88,18 @@ class MatrixFederationAgent(object):
|
|||
self._pool.maxPersistentPerHost = 5
|
||||
self._pool.cachedConnectionTimeout = 2 * 60
|
||||
|
||||
agent_args = {}
|
||||
if _well_known_tls_policy is not None:
|
||||
# the param is called 'contextFactory', but actually passing a
|
||||
# contextfactory is deprecated, and it expects an IPolicyForHTTPS.
|
||||
agent_args['contextFactory'] = _well_known_tls_policy
|
||||
_well_known_agent = RedirectAgent(
|
||||
Agent(self._reactor, pool=self._pool, **agent_args),
|
||||
)
|
||||
self._well_known_agent = _well_known_agent
|
||||
|
||||
self._well_known_cache = _well_known_cache
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def request(self, method, uri, headers=None, bodyProducer=None):
|
||||
"""
|
||||
|
@ -114,7 +152,11 @@ class MatrixFederationAgent(object):
|
|||
class EndpointFactory(object):
|
||||
@staticmethod
|
||||
def endpointForURI(_uri):
|
||||
logger.info("Connecting to %s:%s", res.target_host, res.target_port)
|
||||
logger.info(
|
||||
"Connecting to %s:%i",
|
||||
res.target_host.decode("ascii"),
|
||||
res.target_port,
|
||||
)
|
||||
ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port)
|
||||
if tls_options is not None:
|
||||
ep = wrapClientTLS(tls_options, ep)
|
||||
|
@ -127,7 +169,7 @@ class MatrixFederationAgent(object):
|
|||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _route_matrix_uri(self, parsed_uri):
|
||||
def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
|
||||
"""Helper for `request`: determine the routing for a Matrix URI
|
||||
|
||||
Args:
|
||||
|
@ -135,6 +177,9 @@ class MatrixFederationAgent(object):
|
|||
parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
|
||||
if there is no explicit port given.
|
||||
|
||||
lookup_well_known (bool): True if we should look up the .well-known file if
|
||||
there is no SRV record.
|
||||
|
||||
Returns:
|
||||
Deferred[_RoutingResult]
|
||||
"""
|
||||
|
@ -169,6 +214,42 @@ class MatrixFederationAgent(object):
|
|||
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
|
||||
server_list = yield self._srv_resolver.resolve_service(service_name)
|
||||
|
||||
if not server_list and lookup_well_known:
|
||||
# try a .well-known lookup
|
||||
well_known_server = yield self._get_well_known(parsed_uri.host)
|
||||
|
||||
if well_known_server:
|
||||
# if we found a .well-known, start again, but don't do another
|
||||
# .well-known lookup.
|
||||
|
||||
# parse the server name in the .well-known response into host/port.
|
||||
# (This code is lifted from twisted.web.client.URI.fromBytes).
|
||||
if b':' in well_known_server:
|
||||
well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
|
||||
try:
|
||||
well_known_port = int(well_known_port)
|
||||
except ValueError:
|
||||
# the part after the colon could not be parsed as an int
|
||||
# - we assume it is an IPv6 literal with no port (the closing
|
||||
# ']' stops it being parsed as an int)
|
||||
well_known_host, well_known_port = well_known_server, -1
|
||||
else:
|
||||
well_known_host, well_known_port = well_known_server, -1
|
||||
|
||||
new_uri = URI(
|
||||
scheme=parsed_uri.scheme,
|
||||
netloc=well_known_server,
|
||||
host=well_known_host,
|
||||
port=well_known_port,
|
||||
path=parsed_uri.path,
|
||||
params=parsed_uri.params,
|
||||
query=parsed_uri.query,
|
||||
fragment=parsed_uri.fragment,
|
||||
)
|
||||
|
||||
res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
|
||||
defer.returnValue(res)
|
||||
|
||||
if not server_list:
|
||||
target_host = parsed_uri.host
|
||||
port = 8448
|
||||
|
@ -190,6 +271,114 @@ class MatrixFederationAgent(object):
|
|||
target_port=port,
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_well_known(self, server_name):
|
||||
"""Attempt to fetch and parse a .well-known file for the given server
|
||||
|
||||
Args:
|
||||
server_name (bytes): name of the server, from the requested url
|
||||
|
||||
Returns:
|
||||
Deferred[bytes|None]: either the new server name, from the .well-known, or
|
||||
None if there was no .well-known file.
|
||||
"""
|
||||
try:
|
||||
cached = self._well_known_cache[server_name]
|
||||
defer.returnValue(cached)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# TODO: should we linearise so that we don't end up doing two .well-known requests
|
||||
# for the same server in parallel?
|
||||
|
||||
uri = b"https://%s/.well-known/matrix/server" % (server_name, )
|
||||
uri_str = uri.decode("ascii")
|
||||
logger.info("Fetching %s", uri_str)
|
||||
try:
|
||||
response = yield make_deferred_yieldable(
|
||||
self._well_known_agent.request(b"GET", uri),
|
||||
)
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
if response.code != 200:
|
||||
raise Exception("Non-200 response %s" % (response.code, ))
|
||||
except Exception as e:
|
||||
logger.info("Error fetching %s: %s", uri_str, e)
|
||||
|
||||
# add some randomness to the TTL to avoid a stampeding herd every hour
|
||||
# after startup
|
||||
cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
|
||||
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
|
||||
|
||||
self._well_known_cache.set(server_name, None, cache_period)
|
||||
defer.returnValue(None)
|
||||
|
||||
try:
|
||||
parsed_body = json.loads(body.decode('utf-8'))
|
||||
logger.info("Response from .well-known: %s", parsed_body)
|
||||
if not isinstance(parsed_body, dict):
|
||||
raise Exception("not a dict")
|
||||
if "m.server" not in parsed_body:
|
||||
raise Exception("Missing key 'm.server'")
|
||||
except Exception as e:
|
||||
raise Exception("invalid .well-known response from %s: %s" % (uri_str, e,))
|
||||
|
||||
result = parsed_body["m.server"].encode("ascii")
|
||||
|
||||
cache_period = _cache_period_from_headers(
|
||||
response.headers,
|
||||
time_now=self._reactor.seconds,
|
||||
)
|
||||
if cache_period is None:
|
||||
cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
|
||||
# add some randomness to the TTL to avoid a stampeding herd every 24 hours
|
||||
# after startup
|
||||
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
|
||||
else:
|
||||
cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
|
||||
|
||||
if cache_period > 0:
|
||||
self._well_known_cache.set(server_name, result, cache_period)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
|
||||
def _cache_period_from_headers(headers, time_now=time.time):
|
||||
cache_controls = _parse_cache_control(headers)
|
||||
|
||||
if b'no-store' in cache_controls:
|
||||
return 0
|
||||
|
||||
if b'max-age' in cache_controls:
|
||||
try:
|
||||
max_age = int(cache_controls[b'max-age'])
|
||||
return max_age
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
expires = headers.getRawHeaders(b'expires')
|
||||
if expires is not None:
|
||||
try:
|
||||
expires_date = stringToDatetime(expires[-1])
|
||||
return expires_date - time_now()
|
||||
except ValueError:
|
||||
# RFC7234 says 'A cache recipient MUST interpret invalid date formats,
|
||||
# especially the value "0", as representing a time in the past (i.e.,
|
||||
# "already expired").
|
||||
return 0
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _parse_cache_control(headers):
|
||||
cache_controls = {}
|
||||
for hdr in headers.getRawHeaders(b'cache-control', []):
|
||||
for directive in hdr.split(b','):
|
||||
splits = [x.strip() for x in directive.split(b'=', 1)]
|
||||
k = splits[0].lower()
|
||||
v = splits[1] if len(splits) > 1 else None
|
||||
cache_controls[k] = v
|
||||
return cache_controls
|
||||
|
||||
|
||||
@attr.s
|
||||
class _RoutingResult(object):
|
||||
|
|
|
@ -84,7 +84,7 @@ def _rule_to_template(rule):
|
|||
templaterule["pattern"] = thecond["pattern"]
|
||||
|
||||
if unscoped_rule_id:
|
||||
templaterule['rule_id'] = unscoped_rule_id
|
||||
templaterule['rule_id'] = unscoped_rule_id
|
||||
if 'default' in rule:
|
||||
templaterule['default'] = rule['default']
|
||||
return templaterule
|
||||
|
|
|
@ -95,7 +95,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
event_and_contexts = []
|
||||
for event_payload in event_payloads:
|
||||
event_dict = event_payload["event"]
|
||||
format_ver = content["event_format_version"]
|
||||
format_ver = event_payload["event_format_version"]
|
||||
internal_metadata = event_payload["internal_metadata"]
|
||||
rejected_reason = event_payload["rejected_reason"]
|
||||
|
||||
|
|
|
@ -101,16 +101,7 @@ class ConsentResource(Resource):
|
|||
"missing in config file.",
|
||||
)
|
||||
|
||||
# daemonize changes the cwd to /, so make the path absolute now.
|
||||
consent_template_directory = path.abspath(
|
||||
hs.config.user_consent_template_dir,
|
||||
)
|
||||
if not path.isdir(consent_template_directory):
|
||||
raise ConfigError(
|
||||
"Could not find template directory '%s'" % (
|
||||
consent_template_directory,
|
||||
),
|
||||
)
|
||||
consent_template_directory = hs.config.user_consent_template_dir
|
||||
|
||||
loader = jinja2.FileSystemLoader(consent_template_directory)
|
||||
self._jinja_env = jinja2.Environment(
|
||||
|
|
|
@ -355,10 +355,7 @@ class HomeServer(object):
|
|||
return Keyring(self)
|
||||
|
||||
def build_event_builder_factory(self):
|
||||
return EventBuilderFactory(
|
||||
clock=self.get_clock(),
|
||||
hostname=self.hostname,
|
||||
)
|
||||
return EventBuilderFactory(self)
|
||||
|
||||
def build_filtering(self):
|
||||
return Filtering(self)
|
||||
|
|
|
@ -608,7 +608,7 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
|
|||
state_sets, event_map, state_res_store.get_events,
|
||||
)
|
||||
elif room_version in (
|
||||
RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST, RoomVersions.V2,
|
||||
RoomVersions.STATE_V2_TEST, RoomVersions.V2, RoomVersions.V3,
|
||||
):
|
||||
return v2.resolve_events_with_store(
|
||||
room_version, state_sets, event_map, state_res_store,
|
||||
|
|
|
@ -317,7 +317,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
thirty_days_ago_in_secs))
|
||||
|
||||
for row in txn:
|
||||
if row[0] is 'unknown':
|
||||
if row[0] == 'unknown':
|
||||
pass
|
||||
results[row[0]] = row[1]
|
||||
|
||||
|
|
|
@ -125,6 +125,29 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
|||
|
||||
return dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_max_depth_of(self, event_ids):
|
||||
"""Returns the max depth of a set of event IDs
|
||||
|
||||
Args:
|
||||
event_ids (list[str])
|
||||
|
||||
Returns
|
||||
Deferred[int]
|
||||
"""
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="events",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
retcols=("depth",),
|
||||
desc="get_max_depth_of",
|
||||
)
|
||||
|
||||
if not rows:
|
||||
defer.returnValue(0)
|
||||
else:
|
||||
defer.returnValue(max(row["depth"] for row in rows))
|
||||
|
||||
def _get_oldest_events_in_room_txn(self, txn, room_id):
|
||||
return self._simple_select_onecol_txn(
|
||||
txn,
|
||||
|
|
|
@ -904,105 +904,105 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
|
|||
|
||||
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
|
||||
for room_id, current_state_tuple in iteritems(state_delta_by_room):
|
||||
to_delete, to_insert = current_state_tuple
|
||||
to_delete, to_insert = current_state_tuple
|
||||
|
||||
# First we add entries to the current_state_delta_stream. We
|
||||
# do this before updating the current_state_events table so
|
||||
# that we can use it to calculate the `prev_event_id`. (This
|
||||
# allows us to not have to pull out the existing state
|
||||
# unnecessarily).
|
||||
sql = """
|
||||
INSERT INTO current_state_delta_stream
|
||||
(stream_id, room_id, type, state_key, event_id, prev_event_id)
|
||||
SELECT ?, ?, ?, ?, ?, (
|
||||
SELECT event_id FROM current_state_events
|
||||
WHERE room_id = ? AND type = ? AND state_key = ?
|
||||
)
|
||||
"""
|
||||
txn.executemany(sql, (
|
||||
(
|
||||
max_stream_order, room_id, etype, state_key, None,
|
||||
room_id, etype, state_key,
|
||||
)
|
||||
for etype, state_key in to_delete
|
||||
# We sanity check that we're deleting rather than updating
|
||||
if (etype, state_key) not in to_insert
|
||||
))
|
||||
txn.executemany(sql, (
|
||||
(
|
||||
max_stream_order, room_id, etype, state_key, ev_id,
|
||||
room_id, etype, state_key,
|
||||
)
|
||||
for (etype, state_key), ev_id in iteritems(to_insert)
|
||||
))
|
||||
|
||||
# Now we actually update the current_state_events table
|
||||
|
||||
txn.executemany(
|
||||
"DELETE FROM current_state_events"
|
||||
" WHERE room_id = ? AND type = ? AND state_key = ?",
|
||||
(
|
||||
(room_id, etype, state_key)
|
||||
for etype, state_key in itertools.chain(to_delete, to_insert)
|
||||
),
|
||||
# First we add entries to the current_state_delta_stream. We
|
||||
# do this before updating the current_state_events table so
|
||||
# that we can use it to calculate the `prev_event_id`. (This
|
||||
# allows us to not have to pull out the existing state
|
||||
# unnecessarily).
|
||||
sql = """
|
||||
INSERT INTO current_state_delta_stream
|
||||
(stream_id, room_id, type, state_key, event_id, prev_event_id)
|
||||
SELECT ?, ?, ?, ?, ?, (
|
||||
SELECT event_id FROM current_state_events
|
||||
WHERE room_id = ? AND type = ? AND state_key = ?
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="current_state_events",
|
||||
values=[
|
||||
{
|
||||
"event_id": ev_id,
|
||||
"room_id": room_id,
|
||||
"type": key[0],
|
||||
"state_key": key[1],
|
||||
}
|
||||
for key, ev_id in iteritems(to_insert)
|
||||
],
|
||||
"""
|
||||
txn.executemany(sql, (
|
||||
(
|
||||
max_stream_order, room_id, etype, state_key, None,
|
||||
room_id, etype, state_key,
|
||||
)
|
||||
|
||||
txn.call_after(
|
||||
self._curr_state_delta_stream_cache.entity_has_changed,
|
||||
room_id, max_stream_order,
|
||||
for etype, state_key in to_delete
|
||||
# We sanity check that we're deleting rather than updating
|
||||
if (etype, state_key) not in to_insert
|
||||
))
|
||||
txn.executemany(sql, (
|
||||
(
|
||||
max_stream_order, room_id, etype, state_key, ev_id,
|
||||
room_id, etype, state_key,
|
||||
)
|
||||
for (etype, state_key), ev_id in iteritems(to_insert)
|
||||
))
|
||||
|
||||
# Invalidate the various caches
|
||||
# Now we actually update the current_state_events table
|
||||
|
||||
# Figure out the changes of membership to invalidate the
|
||||
# `get_rooms_for_user` cache.
|
||||
# We find out which membership events we may have deleted
|
||||
# and which we have added, then we invlidate the caches for all
|
||||
# those users.
|
||||
members_changed = set(
|
||||
state_key
|
||||
for ev_type, state_key in itertools.chain(to_delete, to_insert)
|
||||
if ev_type == EventTypes.Member
|
||||
)
|
||||
txn.executemany(
|
||||
"DELETE FROM current_state_events"
|
||||
" WHERE room_id = ? AND type = ? AND state_key = ?",
|
||||
(
|
||||
(room_id, etype, state_key)
|
||||
for etype, state_key in itertools.chain(to_delete, to_insert)
|
||||
),
|
||||
)
|
||||
|
||||
for member in members_changed:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_rooms_for_user_with_stream_ordering, (member,)
|
||||
)
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="current_state_events",
|
||||
values=[
|
||||
{
|
||||
"event_id": ev_id,
|
||||
"room_id": room_id,
|
||||
"type": key[0],
|
||||
"state_key": key[1],
|
||||
}
|
||||
for key, ev_id in iteritems(to_insert)
|
||||
],
|
||||
)
|
||||
|
||||
for host in set(get_domain_from_id(u) for u in members_changed):
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.is_host_joined, (room_id, host)
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.was_host_joined, (room_id, host)
|
||||
)
|
||||
txn.call_after(
|
||||
self._curr_state_delta_stream_cache.entity_has_changed,
|
||||
room_id, max_stream_order,
|
||||
)
|
||||
|
||||
# Invalidate the various caches
|
||||
|
||||
# Figure out the changes of membership to invalidate the
|
||||
# `get_rooms_for_user` cache.
|
||||
# We find out which membership events we may have deleted
|
||||
# and which we have added, then we invlidate the caches for all
|
||||
# those users.
|
||||
members_changed = set(
|
||||
state_key
|
||||
for ev_type, state_key in itertools.chain(to_delete, to_insert)
|
||||
if ev_type == EventTypes.Member
|
||||
)
|
||||
|
||||
for member in members_changed:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_users_in_room, (room_id,)
|
||||
txn, self.get_rooms_for_user_with_stream_ordering, (member,)
|
||||
)
|
||||
|
||||
for host in set(get_domain_from_id(u) for u in members_changed):
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_room_summary, (room_id,)
|
||||
txn, self.is_host_joined, (room_id, host)
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.was_host_joined, (room_id, host)
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_current_state_ids, (room_id,)
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_users_in_room, (room_id,)
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_room_summary, (room_id,)
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_current_state_ids, (room_id,)
|
||||
)
|
||||
|
||||
def _update_forward_extremities_txn(self, txn, new_forward_extremities,
|
||||
max_stream_order):
|
||||
|
|
|
@ -21,13 +21,14 @@ from canonicaljson import json
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventFormatVersions
|
||||
from synapse.api.constants import EventFormatVersions, EventTypes
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
|
||||
# these are only included to make the type annotations work
|
||||
from synapse.events.snapshot import EventContext # noqa: F401
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.logcontext import (
|
||||
LoggingContext,
|
||||
PreserveLoggingContext,
|
||||
|
@ -162,7 +163,6 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
missing_events = yield self._enqueue_events(
|
||||
missing_events_ids,
|
||||
check_redacted=check_redacted,
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
|
@ -174,6 +174,29 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
if not entry:
|
||||
continue
|
||||
|
||||
# Starting in room version v3, some redactions need to be rechecked if we
|
||||
# didn't have the redacted event at the time, so we recheck on read
|
||||
# instead.
|
||||
if not allow_rejected and entry.event.type == EventTypes.Redaction:
|
||||
if entry.event.internal_metadata.need_to_check_redaction():
|
||||
orig = yield self.get_event(
|
||||
entry.event.redacts,
|
||||
allow_none=True,
|
||||
allow_rejected=True,
|
||||
get_prev_content=False,
|
||||
)
|
||||
expected_domain = get_domain_from_id(entry.event.sender)
|
||||
if orig and get_domain_from_id(orig.sender) == expected_domain:
|
||||
# This redaction event is allowed. Mark as not needing a
|
||||
# recheck.
|
||||
entry.event.internal_metadata.recheck_redaction = False
|
||||
else:
|
||||
# We don't have the event that is being redacted, so we
|
||||
# assume that the event isn't authorized for now. (If we
|
||||
# later receive the event, then we will always redact
|
||||
# it anyway, since we have this redaction)
|
||||
continue
|
||||
|
||||
if allow_rejected or not entry.event.rejected_reason:
|
||||
if check_redacted and entry.redacted_event:
|
||||
event = entry.redacted_event
|
||||
|
@ -197,7 +220,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
defer.returnValue(events)
|
||||
|
||||
def _invalidate_get_event_cache(self, event_id):
|
||||
self._get_event_cache.invalidate((event_id,))
|
||||
self._get_event_cache.invalidate((event_id,))
|
||||
|
||||
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
|
||||
"""Fetch events from the caches
|
||||
|
@ -310,7 +333,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
self.hs.get_reactor().callFromThread(fire, event_list, e)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
|
||||
def _enqueue_events(self, events, allow_rejected=False):
|
||||
"""Fetches events from the database using the _event_fetch_list. This
|
||||
allows batch and bulk fetching of events - it allows us to fetch events
|
||||
without having to create a new transaction for each request for events.
|
||||
|
@ -443,6 +466,19 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
# will serialise this field correctly
|
||||
redacted_event.unsigned["redacted_because"] = because
|
||||
|
||||
# Starting in room version v3, some redactions need to be
|
||||
# rechecked if we didn't have the redacted event at the
|
||||
# time, so we recheck on read instead.
|
||||
if because.internal_metadata.need_to_check_redaction():
|
||||
expected_domain = get_domain_from_id(original_ev.sender)
|
||||
if get_domain_from_id(because.sender) == expected_domain:
|
||||
# This redaction event is allowed. Mark as not needing a
|
||||
# recheck.
|
||||
because.internal_metadata.recheck_redaction = False
|
||||
else:
|
||||
# Senders don't match, so the event isn't actually redacted
|
||||
redacted_event = None
|
||||
|
||||
cache_entry = _EventCacheEntry(
|
||||
event=original_ev,
|
||||
redacted_event=redacted_event,
|
||||
|
|
|
@ -201,7 +201,7 @@ class Linearizer(object):
|
|||
if entry[0] >= self.max_count:
|
||||
res = self._await_lock(key)
|
||||
else:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"Acquired uncontended linearizer lock %r for key %r", self.name, key,
|
||||
)
|
||||
entry[0] += 1
|
||||
|
@ -215,7 +215,7 @@ class Linearizer(object):
|
|||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("Releasing linearizer lock %r for key %r", self.name, key)
|
||||
logger.debug("Releasing linearizer lock %r for key %r", self.name, key)
|
||||
|
||||
# We've finished executing so check if there are any things
|
||||
# blocked waiting to execute and start one of them
|
||||
|
@ -247,7 +247,7 @@ class Linearizer(object):
|
|||
"""
|
||||
entry = self.key_to_defer[key]
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
|
||||
)
|
||||
|
||||
|
@ -255,7 +255,7 @@ class Linearizer(object):
|
|||
entry[1][new_defer] = 1
|
||||
|
||||
def cb(_r):
|
||||
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
|
||||
logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
|
||||
entry[0] += 1
|
||||
|
||||
# if the code holding the lock completes synchronously, then it
|
||||
|
@ -273,7 +273,7 @@ class Linearizer(object):
|
|||
def eb(e):
|
||||
logger.info("defer %r got err %r", new_defer, e)
|
||||
if isinstance(e, CancelledError):
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"Cancelling wait for linearizer lock %r for key %r",
|
||||
self.name, key,
|
||||
)
|
||||
|
|
161
synapse/util/caches/ttlcache.py
Normal file
161
synapse/util/caches/ttlcache.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import attr
|
||||
from sortedcontainers import SortedList
|
||||
|
||||
from synapse.util.caches import register_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SENTINEL = object()
|
||||
|
||||
|
||||
class TTLCache(object):
|
||||
"""A key/value cache implementation where each entry has its own TTL"""
|
||||
|
||||
def __init__(self, cache_name, timer=time.time):
|
||||
# map from key to _CacheEntry
|
||||
self._data = {}
|
||||
|
||||
# the _CacheEntries, sorted by expiry time
|
||||
self._expiry_list = SortedList()
|
||||
|
||||
self._timer = timer
|
||||
|
||||
self._metrics = register_cache("ttl", cache_name, self)
|
||||
|
||||
def set(self, key, value, ttl):
|
||||
"""Add/update an entry in the cache
|
||||
|
||||
Args:
|
||||
key: key for this entry
|
||||
value: value for this entry
|
||||
ttl (float): TTL for this entry, in seconds
|
||||
"""
|
||||
expiry = self._timer() + ttl
|
||||
|
||||
self.expire()
|
||||
e = self._data.pop(key, SENTINEL)
|
||||
if e != SENTINEL:
|
||||
self._expiry_list.remove(e)
|
||||
|
||||
entry = _CacheEntry(expiry_time=expiry, key=key, value=value)
|
||||
self._data[key] = entry
|
||||
self._expiry_list.add(entry)
|
||||
|
||||
def get(self, key, default=SENTINEL):
|
||||
"""Get a value from the cache
|
||||
|
||||
Args:
|
||||
key: key to look up
|
||||
default: default value to return, if key is not found. If not set, and the
|
||||
key is not found, a KeyError will be raised
|
||||
|
||||
Returns:
|
||||
value from the cache, or the default
|
||||
"""
|
||||
self.expire()
|
||||
e = self._data.get(key, SENTINEL)
|
||||
if e == SENTINEL:
|
||||
self._metrics.inc_misses()
|
||||
if default == SENTINEL:
|
||||
raise KeyError(key)
|
||||
return default
|
||||
self._metrics.inc_hits()
|
||||
return e.value
|
||||
|
||||
def get_with_expiry(self, key):
|
||||
"""Get a value, and its expiry time, from the cache
|
||||
|
||||
Args:
|
||||
key: key to look up
|
||||
|
||||
Returns:
|
||||
Tuple[Any, float]: the value from the cache, and the expiry time
|
||||
|
||||
Raises:
|
||||
KeyError if the entry is not found
|
||||
"""
|
||||
self.expire()
|
||||
try:
|
||||
e = self._data[key]
|
||||
except KeyError:
|
||||
self._metrics.inc_misses()
|
||||
raise
|
||||
self._metrics.inc_hits()
|
||||
return e.value, e.expiry_time
|
||||
|
||||
def pop(self, key, default=SENTINEL):
|
||||
"""Remove a value from the cache
|
||||
|
||||
If key is in the cache, remove it and return its value, else return default.
|
||||
If default is not given and key is not in the cache, a KeyError is raised.
|
||||
|
||||
Args:
|
||||
key: key to look up
|
||||
default: default value to return, if key is not found. If not set, and the
|
||||
key is not found, a KeyError will be raised
|
||||
|
||||
Returns:
|
||||
value from the cache, or the default
|
||||
"""
|
||||
self.expire()
|
||||
e = self._data.pop(key, SENTINEL)
|
||||
if e == SENTINEL:
|
||||
self._metrics.inc_misses()
|
||||
if default == SENTINEL:
|
||||
raise KeyError(key)
|
||||
return default
|
||||
self._expiry_list.remove(e)
|
||||
self._metrics.inc_hits()
|
||||
return e.value
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.get(key)
|
||||
|
||||
def __delitem__(self, key):
|
||||
self.pop(key)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._data
|
||||
|
||||
def __len__(self):
|
||||
self.expire()
|
||||
return len(self._data)
|
||||
|
||||
def expire(self):
|
||||
"""Run the expiry on the cache. Any entries whose expiry times are due will
|
||||
be removed
|
||||
"""
|
||||
now = self._timer()
|
||||
while self._expiry_list:
|
||||
first_entry = self._expiry_list[0]
|
||||
if first_entry.expiry_time - now > 0.0:
|
||||
break
|
||||
del self._data[first_entry.key]
|
||||
del self._expiry_list[0]
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True)
|
||||
class _CacheEntry(object):
|
||||
"""TTLCache entry"""
|
||||
# expiry_time is the first attribute, so that entries are sorted by expiry.
|
||||
expiry_time = attr.ib()
|
||||
key = attr.ib()
|
||||
value = attr.ib()
|
|
@ -285,7 +285,10 @@ class LoggingContext(object):
|
|||
self.alive = False
|
||||
|
||||
# if we have a parent, pass our CPU usage stats on
|
||||
if self.parent_context is not None:
|
||||
if (
|
||||
self.parent_context is not None
|
||||
and hasattr(self.parent_context, '_resource_usage')
|
||||
):
|
||||
self.parent_context._resource_usage += self._resource_usage
|
||||
|
||||
# reset them in case we get entered again
|
||||
|
|
|
@ -50,8 +50,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
|
|||
"homeserver.yaml",
|
||||
"lemurs.win.log.config",
|
||||
"lemurs.win.signing.key",
|
||||
"lemurs.win.tls.crt",
|
||||
"lemurs.win.tls.key",
|
||||
]
|
||||
),
|
||||
set(os.listdir(self.dir)),
|
||||
|
|
75
tests/config/test_tls.py
Normal file
75
tests/config/test_tls.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from synapse.config.tls import TlsConfig
|
||||
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
class TLSConfigTests(TestCase):
|
||||
|
||||
def test_warn_self_signed(self):
|
||||
"""
|
||||
Synapse will give a warning when it loads a self-signed certificate.
|
||||
"""
|
||||
config_dir = self.mktemp()
|
||||
os.mkdir(config_dir)
|
||||
with open(os.path.join(config_dir, "cert.pem"), 'w') as f:
|
||||
f.write("""-----BEGIN CERTIFICATE-----
|
||||
MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN
|
||||
BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExEjAQBgNVBAMMCWxv
|
||||
Y2FsaG9zdDEcMBoGA1UECgwTVHdpc3RlZCBNYXRyaXggTGFiczEkMCIGA1UECwwb
|
||||
QXV0b21hdGVkIFRlc3RpbmcgQXV0aG9yaXR5MSkwJwYJKoZIhvcNAQkBFhpzZWN1
|
||||
cml0eUB0d2lzdGVkbWF0cml4LmNvbTAgFw0xNzA3MTIxNDAxNTNaGA8yMTE3MDYx
|
||||
ODE0MDE1M1owgbcxCzAJBgNVBAYTAlRSMQ8wDQYDVQQIDAbDh29ydW0xFDASBgNV
|
||||
BAcMC0JhxZ9tYWvDp8SxMRIwEAYDVQQDDAlsb2NhbGhvc3QxHDAaBgNVBAoME1R3
|
||||
aXN0ZWQgTWF0cml4IExhYnMxJDAiBgNVBAsMG0F1dG9tYXRlZCBUZXN0aW5nIEF1
|
||||
dGhvcml0eTEpMCcGCSqGSIb3DQEJARYac2VjdXJpdHlAdHdpc3RlZG1hdHJpeC5j
|
||||
b20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDwT6kbqtMUI0sMkx4h
|
||||
I+L780dA59KfksZCqJGmOsMD6hte9EguasfkZzvCF3dk3NhwCjFSOvKx6rCwiteo
|
||||
WtYkVfo+rSuVNmt7bEsOUDtuTcaxTzIFB+yHOYwAaoz3zQkyVW0c4pzioiLCGCmf
|
||||
FLdiDBQGGp74tb+7a0V6kC3vMLFoM3L6QWq5uYRB5+xLzlPJ734ltyvfZHL3Us6p
|
||||
cUbK+3WTWvb4ER0W2RqArAj6Bc/ERQKIAPFEiZi9bIYTwvBH27OKHRz+KoY/G8zY
|
||||
+l+WZoJqDhupRAQAuh7O7V/y6bSP+KNxJRie9QkZvw1PSaGSXtGJI3WWdO12/Ulg
|
||||
epJpAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAJXEq5P9xwvP9aDkXIqzcD0L8sf8
|
||||
ewlhlxTQdeqt2Nace0Yk18lIo2oj1t86Y8jNbpAnZJeI813Rr5M7FbHCXoRc/SZG
|
||||
I8OtG1xGwcok53lyDuuUUDexnK4O5BkjKiVlNPg4HPim5Kuj2hRNFfNt/F2BVIlj
|
||||
iZupikC5MT1LQaRwidkSNxCku1TfAyueiBwhLnFwTmIGNnhuDCutEVAD9kFmcJN2
|
||||
SznugAcPk4doX2+rL+ila+ThqgPzIkwTUHtnmjI0TI6xsDUlXz5S3UyudrE2Qsfz
|
||||
s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
|
||||
-----END CERTIFICATE-----""")
|
||||
|
||||
config = {
|
||||
"tls_certificate_path": os.path.join(config_dir, "cert.pem"),
|
||||
"no_tls": True,
|
||||
"tls_fingerprints": []
|
||||
}
|
||||
|
||||
t = TlsConfig()
|
||||
t.read_config(config)
|
||||
t.read_certificate_from_disk()
|
||||
|
||||
warnings = self.flushWarnings()
|
||||
self.assertEqual(len(warnings), 1)
|
||||
self.assertEqual(
|
||||
warnings[0]["message"],
|
||||
(
|
||||
"Self-signed TLS certificates will not be accepted by "
|
||||
"Synapse 1.0. Please either provide a valid certificate, "
|
||||
"or use Synapse's ACME support to provision one."
|
||||
)
|
||||
)
|
|
@ -18,7 +18,7 @@ import nacl.signing
|
|||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.events.builder import EventBuilder
|
||||
from synapse.events import FrozenEvent
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
@ -40,20 +40,18 @@ class EventSigningTestCase(unittest.TestCase):
|
|||
self.signing_key.version = KEY_VER
|
||||
|
||||
def test_sign_minimal(self):
|
||||
builder = EventBuilder(
|
||||
{
|
||||
'event_id': "$0:domain",
|
||||
'origin': "domain",
|
||||
'origin_server_ts': 1000000,
|
||||
'signatures': {},
|
||||
'type': "X",
|
||||
'unsigned': {'age_ts': 1000000},
|
||||
}
|
||||
)
|
||||
event_dict = {
|
||||
'event_id': "$0:domain",
|
||||
'origin': "domain",
|
||||
'origin_server_ts': 1000000,
|
||||
'signatures': {},
|
||||
'type': "X",
|
||||
'unsigned': {'age_ts': 1000000},
|
||||
}
|
||||
|
||||
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
|
||||
add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
|
||||
|
||||
event = builder.build()
|
||||
event = FrozenEvent(event_dict)
|
||||
|
||||
self.assertTrue(hasattr(event, 'hashes'))
|
||||
self.assertIn('sha256', event.hashes)
|
||||
|
@ -71,23 +69,21 @@ class EventSigningTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
def test_sign_message(self):
|
||||
builder = EventBuilder(
|
||||
{
|
||||
'content': {'body': "Here is the message content"},
|
||||
'event_id': "$0:domain",
|
||||
'origin': "domain",
|
||||
'origin_server_ts': 1000000,
|
||||
'type': "m.room.message",
|
||||
'room_id': "!r:domain",
|
||||
'sender': "@u:domain",
|
||||
'signatures': {},
|
||||
'unsigned': {'age_ts': 1000000},
|
||||
}
|
||||
)
|
||||
event_dict = {
|
||||
'content': {'body': "Here is the message content"},
|
||||
'event_id': "$0:domain",
|
||||
'origin': "domain",
|
||||
'origin_server_ts': 1000000,
|
||||
'type': "m.room.message",
|
||||
'room_id': "!r:domain",
|
||||
'sender': "@u:domain",
|
||||
'signatures': {},
|
||||
'unsigned': {'age_ts': 1000000},
|
||||
}
|
||||
|
||||
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
|
||||
add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
|
||||
|
||||
event = builder.build()
|
||||
event = FrozenEvent(event_dict)
|
||||
|
||||
self.assertTrue(hasattr(event, 'hashes'))
|
||||
self.assertIn('sha256', event.hashes)
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os.path
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
|
||||
def get_test_cert_file():
|
||||
"""get the path to the test cert"""
|
||||
|
||||
# the cert file itself is made with:
|
||||
#
|
||||
# openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \
|
||||
# -nodes -subj '/CN=testserv'
|
||||
return os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'server.pem',
|
||||
)
|
||||
|
||||
|
||||
class ServerTLSContext(object):
|
||||
"""A TLS Context which presents our test cert."""
|
||||
def __init__(self):
|
||||
self.filename = get_test_cert_file()
|
||||
|
||||
def getContext(self):
|
||||
ctx = SSL.Context(SSL.TLSv1_METHOD)
|
||||
ctx.use_certificate_file(self.filename)
|
||||
ctx.use_privatekey_file(self.filename)
|
||||
return ctx
|
|
@ -17,18 +17,26 @@ import logging
|
|||
from mock import Mock
|
||||
|
||||
import treq
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
from twisted.test.ssl_helpers import ServerTLSContext
|
||||
from twisted.web.http import HTTPChannel
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import IPolicyForHTTPS
|
||||
|
||||
from synapse.crypto.context_factory import ClientTLSOptionsFactory
|
||||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||
from synapse.http.federation.matrix_federation_agent import (
|
||||
MatrixFederationAgent,
|
||||
_cache_period_from_headers,
|
||||
)
|
||||
from synapse.http.federation.srv_resolver import Server
|
||||
from synapse.util.caches.ttlcache import TTLCache
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
from tests.http import ServerTLSContext
|
||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
@ -41,10 +49,14 @@ class MatrixFederationAgentTests(TestCase):
|
|||
|
||||
self.mock_resolver = Mock()
|
||||
|
||||
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
|
||||
|
||||
self.agent = MatrixFederationAgent(
|
||||
reactor=self.reactor,
|
||||
tls_client_options_factory=ClientTLSOptionsFactory(None),
|
||||
_well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
|
||||
_srv_resolver=self.mock_resolver,
|
||||
_well_known_cache=self.well_known_cache,
|
||||
)
|
||||
|
||||
def _make_connection(self, client_factory, expected_sni):
|
||||
|
@ -65,10 +77,14 @@ class MatrixFederationAgentTests(TestCase):
|
|||
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||
# stubbing that out here.
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor))
|
||||
client_protocol.makeConnection(
|
||||
FakeTransport(server_tls_protocol, self.reactor, client_protocol),
|
||||
)
|
||||
|
||||
# tell the server tls protocol to send its stuff back to the client, too
|
||||
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
|
||||
server_tls_protocol.makeConnection(
|
||||
FakeTransport(client_protocol, self.reactor, server_tls_protocol),
|
||||
)
|
||||
|
||||
# give the reactor a pump to get the TLS juices flowing.
|
||||
self.reactor.pump((0.1,))
|
||||
|
@ -101,9 +117,55 @@ class MatrixFederationAgentTests(TestCase):
|
|||
try:
|
||||
fetch_res = yield fetch_d
|
||||
defer.returnValue(fetch_res)
|
||||
except Exception as e:
|
||||
logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
|
||||
raise
|
||||
finally:
|
||||
_check_logcontext(context)
|
||||
|
||||
def _handle_well_known_connection(
|
||||
self, client_factory, expected_sni, target_server, response_headers={},
|
||||
):
|
||||
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
|
||||
request is for a .well-known, and send the response.
|
||||
|
||||
Args:
|
||||
client_factory (IProtocolFactory): outgoing connection
|
||||
expected_sni (bytes): SNI that we expect the outgoing connection to send
|
||||
target_server (bytes): target server that we should redirect to in the
|
||||
.well-known response.
|
||||
Returns:
|
||||
HTTPChannel: server impl
|
||||
"""
|
||||
# make the connection for .well-known
|
||||
well_known_server = self._make_connection(
|
||||
client_factory,
|
||||
expected_sni=expected_sni,
|
||||
)
|
||||
# check the .well-known request and send a response
|
||||
self.assertEqual(len(well_known_server.requests), 1)
|
||||
request = well_known_server.requests[0]
|
||||
self._send_well_known_response(request, target_server, headers=response_headers)
|
||||
return well_known_server
|
||||
|
||||
def _send_well_known_response(self, request, target_server, headers={}):
|
||||
"""Check that an incoming request looks like a valid .well-known request, and
|
||||
send back the response.
|
||||
"""
|
||||
self.assertEqual(request.method, b'GET')
|
||||
self.assertEqual(request.path, b'/.well-known/matrix/server')
|
||||
self.assertEqual(
|
||||
request.requestHeaders.getRawHeaders(b'host'),
|
||||
[b'testserv'],
|
||||
)
|
||||
# send back a response
|
||||
for k, v in headers.items():
|
||||
request.setHeader(k, v)
|
||||
request.write(b'{ "m.server": "%s" }' % (target_server,))
|
||||
request.finish()
|
||||
|
||||
self.reactor.pump((0.1, ))
|
||||
|
||||
def test_get(self):
|
||||
"""
|
||||
happy-path test of a GET request with an explicit port
|
||||
|
@ -283,9 +345,9 @@ class MatrixFederationAgentTests(TestCase):
|
|||
self.reactor.pump((0.1,))
|
||||
self.successResultOf(test_d)
|
||||
|
||||
def test_get_hostname_no_srv(self):
|
||||
def test_get_no_srv_no_well_known(self):
|
||||
"""
|
||||
Test the behaviour when the server name has no port, and no SRV record
|
||||
Test the behaviour when the server name has no port, no SRV, and no well-known
|
||||
"""
|
||||
|
||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
||||
|
@ -300,11 +362,24 @@ class MatrixFederationAgentTests(TestCase):
|
|||
b"_matrix._tcp.testserv",
|
||||
)
|
||||
|
||||
# Make sure treq is trying to connect
|
||||
# there should be an attempt to connect on port 443 for the .well-known
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 443)
|
||||
|
||||
# fonx the connection
|
||||
client_factory.clientConnectionFailed(None, Exception("nope"))
|
||||
|
||||
# attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
|
||||
# .well-known request fails.
|
||||
self.reactor.pump((0.4,))
|
||||
|
||||
# we should fall back to a direct connection
|
||||
self.assertEqual(len(clients), 2)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 8448)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
|
@ -327,6 +402,171 @@ class MatrixFederationAgentTests(TestCase):
|
|||
self.reactor.pump((0.1,))
|
||||
self.successResultOf(test_d)
|
||||
|
||||
def test_get_well_known(self):
|
||||
"""Test the behaviour when the server name has no port and no SRV record, but
|
||||
the .well-known redirects elsewhere
|
||||
"""
|
||||
|
||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
self.reactor.lookups["target-server"] = "1::f"
|
||||
|
||||
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||
b"_matrix._tcp.testserv",
|
||||
)
|
||||
self.mock_resolver.resolve_service.reset_mock()
|
||||
|
||||
# there should be an attempt to connect on port 443 for the .well-known
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 443)
|
||||
|
||||
self._handle_well_known_connection(
|
||||
client_factory, expected_sni=b"testserv", target_server=b"target-server",
|
||||
)
|
||||
|
||||
# there should be another SRV lookup
|
||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||
b"_matrix._tcp.target-server",
|
||||
)
|
||||
|
||||
# now we should get a connection to the target server
|
||||
self.assertEqual(len(clients), 2)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
|
||||
self.assertEqual(host, '1::f')
|
||||
self.assertEqual(port, 8448)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(
|
||||
client_factory,
|
||||
expected_sni=b'target-server',
|
||||
)
|
||||
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b'GET')
|
||||
self.assertEqual(request.path, b'/foo/bar')
|
||||
self.assertEqual(
|
||||
request.requestHeaders.getRawHeaders(b'host'),
|
||||
[b'target-server'],
|
||||
)
|
||||
|
||||
# finish the request
|
||||
request.finish()
|
||||
self.reactor.pump((0.1,))
|
||||
self.successResultOf(test_d)
|
||||
|
||||
self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
|
||||
|
||||
# check the cache expires
|
||||
self.reactor.pump((25 * 3600,))
|
||||
self.well_known_cache.expire()
|
||||
self.assertNotIn(b"testserv", self.well_known_cache)
|
||||
|
||||
def test_get_well_known_redirect(self):
|
||||
"""Test the behaviour when the server name has no port and no SRV record, but
|
||||
the .well-known has a 300 redirect
|
||||
"""
|
||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
self.reactor.lookups["target-server"] = "1::f"
|
||||
|
||||
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||
b"_matrix._tcp.testserv",
|
||||
)
|
||||
self.mock_resolver.resolve_service.reset_mock()
|
||||
|
||||
# there should be an attempt to connect on port 443 for the .well-known
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 443)
|
||||
|
||||
redirect_server = self._make_connection(
|
||||
client_factory,
|
||||
expected_sni=b"testserv",
|
||||
)
|
||||
|
||||
# send a 302 redirect
|
||||
self.assertEqual(len(redirect_server.requests), 1)
|
||||
request = redirect_server.requests[0]
|
||||
request.redirect(b'https://testserv/even_better_known')
|
||||
request.finish()
|
||||
|
||||
self.reactor.pump((0.1, ))
|
||||
|
||||
# now there should be another connection
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 443)
|
||||
|
||||
well_known_server = self._make_connection(
|
||||
client_factory,
|
||||
expected_sni=b"testserv",
|
||||
)
|
||||
|
||||
self.assertEqual(len(well_known_server.requests), 1, "No request after 302")
|
||||
request = well_known_server.requests[0]
|
||||
self.assertEqual(request.method, b'GET')
|
||||
self.assertEqual(request.path, b'/even_better_known')
|
||||
request.write(b'{ "m.server": "target-server" }')
|
||||
request.finish()
|
||||
|
||||
self.reactor.pump((0.1, ))
|
||||
|
||||
# there should be another SRV lookup
|
||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||
b"_matrix._tcp.target-server",
|
||||
)
|
||||
|
||||
# now we should get a connection to the target server
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1::f')
|
||||
self.assertEqual(port, 8448)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(
|
||||
client_factory,
|
||||
expected_sni=b'target-server',
|
||||
)
|
||||
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b'GET')
|
||||
self.assertEqual(request.path, b'/foo/bar')
|
||||
self.assertEqual(
|
||||
request.requestHeaders.getRawHeaders(b'host'),
|
||||
[b'target-server'],
|
||||
)
|
||||
|
||||
# finish the request
|
||||
request.finish()
|
||||
self.reactor.pump((0.1,))
|
||||
self.successResultOf(test_d)
|
||||
|
||||
self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
|
||||
|
||||
# check the cache expires
|
||||
self.reactor.pump((25 * 3600,))
|
||||
self.well_known_cache.expire()
|
||||
self.assertNotIn(b"testserv", self.well_known_cache)
|
||||
|
||||
def test_get_hostname_srv(self):
|
||||
"""
|
||||
Test the behaviour when there is a single SRV record
|
||||
|
@ -372,6 +612,71 @@ class MatrixFederationAgentTests(TestCase):
|
|||
self.reactor.pump((0.1,))
|
||||
self.successResultOf(test_d)
|
||||
|
||||
def test_get_well_known_srv(self):
|
||||
"""Test the behaviour when the server name has no port and no SRV record, but
|
||||
the .well-known redirects to a place where there is a SRV.
|
||||
"""
|
||||
|
||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
self.reactor.lookups["srvtarget"] = "5.6.7.8"
|
||||
|
||||
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||
b"_matrix._tcp.testserv",
|
||||
)
|
||||
self.mock_resolver.resolve_service.reset_mock()
|
||||
|
||||
# there should be an attempt to connect on port 443 for the .well-known
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 443)
|
||||
|
||||
self.mock_resolver.resolve_service.side_effect = lambda _: [
|
||||
Server(host=b"srvtarget", port=8443),
|
||||
]
|
||||
|
||||
self._handle_well_known_connection(
|
||||
client_factory, expected_sni=b"testserv", target_server=b"target-server",
|
||||
)
|
||||
|
||||
# there should be another SRV lookup
|
||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||
b"_matrix._tcp.target-server",
|
||||
)
|
||||
|
||||
# now we should get a connection to the target of the SRV record
|
||||
self.assertEqual(len(clients), 2)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
|
||||
self.assertEqual(host, '5.6.7.8')
|
||||
self.assertEqual(port, 8443)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(
|
||||
client_factory,
|
||||
expected_sni=b'target-server',
|
||||
)
|
||||
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b'GET')
|
||||
self.assertEqual(request.path, b'/foo/bar')
|
||||
self.assertEqual(
|
||||
request.requestHeaders.getRawHeaders(b'host'),
|
||||
[b'target-server'],
|
||||
)
|
||||
|
||||
# finish the request
|
||||
request.finish()
|
||||
self.reactor.pump((0.1,))
|
||||
self.successResultOf(test_d)
|
||||
|
||||
def test_idna_servername(self):
|
||||
"""test the behaviour when the server name has idna chars in"""
|
||||
|
||||
|
@ -390,11 +695,25 @@ class MatrixFederationAgentTests(TestCase):
|
|||
b"_matrix._tcp.xn--bcher-kva.com",
|
||||
)
|
||||
|
||||
# Make sure treq is trying to connect
|
||||
# there should be an attempt to connect on port 443 for the .well-known
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 443)
|
||||
|
||||
# fonx the connection
|
||||
client_factory.clientConnectionFailed(None, Exception("nope"))
|
||||
|
||||
# attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
|
||||
# .well-known request fails.
|
||||
self.reactor.pump((0.4,))
|
||||
|
||||
# We should fall back to port 8448
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 2)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 8448)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
|
@ -461,6 +780,126 @@ class MatrixFederationAgentTests(TestCase):
|
|||
self.reactor.pump((0.1,))
|
||||
self.successResultOf(test_d)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_get_well_known(self, serv):
|
||||
try:
|
||||
result = yield self.agent._get_well_known(serv)
|
||||
logger.info("Result from well-known fetch: %s", result)
|
||||
except Exception as e:
|
||||
logger.warning("Error fetching well-known: %s", e)
|
||||
raise
|
||||
defer.returnValue(result)
|
||||
|
||||
def test_well_known_cache(self):
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
|
||||
fetch_d = self.do_get_well_known(b'testserv')
|
||||
|
||||
# there should be an attempt to connect on port 443 for the .well-known
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 443)
|
||||
|
||||
well_known_server = self._handle_well_known_connection(
|
||||
client_factory,
|
||||
expected_sni=b"testserv",
|
||||
response_headers={b'Cache-Control': b'max-age=10'},
|
||||
target_server=b"target-server",
|
||||
)
|
||||
|
||||
r = self.successResultOf(fetch_d)
|
||||
self.assertEqual(r, b'target-server')
|
||||
|
||||
# close the tcp connection
|
||||
well_known_server.loseConnection()
|
||||
|
||||
# repeat the request: it should hit the cache
|
||||
fetch_d = self.do_get_well_known(b'testserv')
|
||||
r = self.successResultOf(fetch_d)
|
||||
self.assertEqual(r, b'target-server')
|
||||
|
||||
# expire the cache
|
||||
self.reactor.pump((10.0,))
|
||||
|
||||
# now it should connect again
|
||||
fetch_d = self.do_get_well_known(b'testserv')
|
||||
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 443)
|
||||
|
||||
self._handle_well_known_connection(
|
||||
client_factory,
|
||||
expected_sni=b"testserv",
|
||||
target_server=b"other-server",
|
||||
)
|
||||
|
||||
r = self.successResultOf(fetch_d)
|
||||
self.assertEqual(r, b'other-server')
|
||||
|
||||
|
||||
class TestCachePeriodFromHeaders(TestCase):
|
||||
def test_cache_control(self):
|
||||
# uppercase
|
||||
self.assertEqual(
|
||||
_cache_period_from_headers(
|
||||
Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}),
|
||||
), 100,
|
||||
)
|
||||
|
||||
# missing value
|
||||
self.assertIsNone(_cache_period_from_headers(
|
||||
Headers({b'Cache-Control': [b'max-age=, bar']}),
|
||||
))
|
||||
|
||||
# hackernews: bogus due to semicolon
|
||||
self.assertIsNone(_cache_period_from_headers(
|
||||
Headers({b'Cache-Control': [b'private; max-age=0']}),
|
||||
))
|
||||
|
||||
# github
|
||||
self.assertEqual(
|
||||
_cache_period_from_headers(
|
||||
Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}),
|
||||
), 0,
|
||||
)
|
||||
|
||||
# google
|
||||
self.assertEqual(
|
||||
_cache_period_from_headers(
|
||||
Headers({b'cache-control': [b'private, max-age=0']}),
|
||||
), 0,
|
||||
)
|
||||
|
||||
def test_expires(self):
|
||||
self.assertEqual(
|
||||
_cache_period_from_headers(
|
||||
Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
|
||||
time_now=lambda: 1548833700
|
||||
), 33,
|
||||
)
|
||||
|
||||
# cache-control overrides expires
|
||||
self.assertEqual(
|
||||
_cache_period_from_headers(
|
||||
Headers({
|
||||
b'cache-control': [b'max-age=10'],
|
||||
b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']
|
||||
}),
|
||||
time_now=lambda: 1548833700
|
||||
), 10,
|
||||
)
|
||||
|
||||
# invalid expires means immediate expiry
|
||||
self.assertEqual(
|
||||
_cache_period_from_headers(
|
||||
Headers({b'Expires': [b'0']}),
|
||||
), 0,
|
||||
)
|
||||
|
||||
|
||||
def _check_logcontext(context):
|
||||
current = LoggingContext.current_context()
|
||||
|
@ -492,3 +931,11 @@ def _build_test_server():
|
|||
def _log_request(request):
|
||||
"""Implements Factory.log, which is expected by Request.finish"""
|
||||
logger.info("Completed request %s", request)
|
||||
|
||||
|
||||
@implementer(IPolicyForHTTPS)
|
||||
class TrustingTLSPolicyForHTTPS(object):
|
||||
"""An IPolicyForHTTPS which doesn't do any certificate verification"""
|
||||
def creatorForNetloc(self, hostname, port):
|
||||
certificateOptions = OpenSSLCertificateOptions()
|
||||
return ClientTLSOptions(hostname, certificateOptions.getContext())
|
||||
|
|
81
tests/http/server.pem
Normal file
81
tests/http/server.pem
Normal file
|
@ -0,0 +1,81 @@
|
|||
-----BEGIN PRIVATE KEY-----
|
||||
MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCgF43/3lAgJ+p0
|
||||
x7Rn8UcL8a4fctvdkikvZrCngw96LkB34Evfq8YGWlOVjU+f9naUJLAKMatmAfEN
|
||||
r+rMX4VOXmpTwuu6iLtqwreUrRFMESyrmvQxa15p+y85gkY0CFmXMblv6ORbxHTG
|
||||
ncBGwST4WK4Poewcgt6jcISFCESTUKu1zc3cw1ANIDRyDLB5K44KwIe36dcKckyN
|
||||
Kdtv4BJ+3fcIZIkPJH62zqCypgFF1oiFt40uJzClxgHdJZlKYpgkfnDTckw4Y/Mx
|
||||
9k8BbE310KAzUNMV9H7I1eEolzrNr66FQj1eN64X/dqO8lTbwCqAd4diCT4sIUk0
|
||||
0SVsAUjNd3g8j651hx+Qb1t8fuOjrny8dmeMxtUgIBHoQcpcj76R55Fs7KZ9uar0
|
||||
8OFTyGIze51W1jG2K/7/5M1zxIqrA+7lsXu5OR81s7I+Ng/UUAhiHA/z+42/aiNa
|
||||
qEuk6tqj3rHfLctnCbtZ+JrRNqSSwEi8F0lMA021ivEd2eJV+284OyJjhXOmKHrX
|
||||
QADHrmS7Sh4syTZvRNm9n+qWID0KdDr2Sji/KnS3Enp44HDQ4xriT6/xhwEGsyuX
|
||||
oH5aAkdLznulbWkHBbyx1SUQSTLpOqzaioF9m1vRrLsFvrkrY3D253mPJ5eU9HM/
|
||||
dilduFcUgj4rz+6cdXUAh+KK/v95zwIDAQABAoICAFG5tJPaOa0ws0/KYx5s3YgL
|
||||
aIhFalhCNSQtmCDrlwsYcXDA3/rfBchYdDL0YKGYgBBAal3J3WXFt/j0xThvyu2m
|
||||
5UC9UPl4s7RckrsjXqEmY1d3UxGnbhtMT19cUdpeKN42VCP9EBaIw9Rg07dLAkSF
|
||||
gNYaIx6q8F0fI4eGIPvTQtUcqur4CfWpaxyNvckdovV6M85/YXfDwbCOnacPDGIX
|
||||
jfSK3i0MxGMuOHr6o8uzKR6aBUh6WStHWcw7VXXTvzdiFNbckmx3Gb93rf1b/LBw
|
||||
QFfx+tBKcC62gKroCOzXso/0sL9YTVeSD/DJZOiJwSiz3Dj/3u1IUMbVvfTU8wSi
|
||||
CYS7Z+jHxwSOCSSNTXm1wO/MtDsNKbI1+R0cohr/J9pOMQvrVh1+2zSDOFvXAQ1S
|
||||
yvjn+uqdmijRoV2VEGVHd+34C+ci7eJGAhL/f92PohuuFR2shUETgGWzpACZSJwg
|
||||
j1d90Hs81hj07vWRb+xCeDh00vimQngz9AD8vYvv/S4mqRGQ6TZdfjLoUwSTg0JD
|
||||
6sQgRXX026gQhLhn687vLKZfHwzQPZkpQdxOR0dTZ/ho/RyGGRJXH4kN4cA2tPr+
|
||||
AKYQ29YXGlEzGG7OqikaZcprNWG6UFgEpuXyBxCgp9r4ladZo3J+1Rhgus8ZYatd
|
||||
uO98q3WEBmP6CZ2n32mBAoIBAQDS/c/ybFTos0YpGHakwdmSfj5OOQJto2y8ywfG
|
||||
qDHwO0ebcpNnS1+MA+7XbKUQb/3Iq7iJljkkzJG2DIJ6rpKynYts1ViYpM7M/t0T
|
||||
W3V1gvUcUL62iqkgws4pnpWmubFkqV31cPSHcfIIclnzeQ1aOEGsGHNAvhty0ciC
|
||||
DnkJACbqApvopFLOR5f6UFTtKExE+hDH0WqgpsCAKJ1L4g6pBzZatI32/CN9JEVU
|
||||
tDbxLV75hHlFFjUrG7nT1rPyr/gI8Ceh9/2xeXPfjJUR0PrG3U1nwLqUCZkvFzO6
|
||||
XpN2+A+/v4v5xqMjKDKDFy1oq6SCMomwv/viw6wl/84TMbolAoIBAQDCPiMecnR8
|
||||
REik6tqVzQO/uSe9ZHjz6J15t5xdwaI6HpSwLlIkQPkLTjyXtFpemK5DOYRxrJvQ
|
||||
remfrZrN2qtLlb/DKpuGPWRsPOvWCrSuNEp48ivUehtclljrzxAFfy0sM+fWeJ48
|
||||
nTnR+td9KNhjNtZixzWdAy/mE+jdaMsXVnk66L73Uz+2WsnvVMW2R6cpCR0F2eP/
|
||||
B4zDWRqlT2w47sePAB81mFYSQLvPC6Xcgg1OqMubfiizJI49c8DO6Jt+FFYdsxhd
|
||||
kG52Eqa/Net6rN3ueiS6yXL5TU3Y6g96bPA2KyNCypucGcddcBfqaiVx/o4AH6yT
|
||||
NrdsrYtyvk/jAoIBAQDHUwKVeeRJJbvdbQAArCV4MI155n+1xhMe1AuXkCQFWGtQ
|
||||
nlBE4D72jmyf1UKnIbW2Uwv15xY6/ouVWYIWlj9+QDmMaozVP7Uiko+WDuwLRNl8
|
||||
k4dn+dzHV2HejbPBG2JLv3lFOx23q1zEwArcaXrExaq9Ayg2fKJ/uVHcFAIiD6Oz
|
||||
pR1XDY4w1A/uaN+iYFSVQUyDCQLbnEz1hej73CaPZoHh9Pq83vxD5/UbjVjuRTeZ
|
||||
L55FNzKpc/r89rNvTPBcuUwnxplDhYKDKVNWzn9rSXwrzTY2Tk8J3rh+k4RqevSd
|
||||
6D47jH1n5Dy7/TRn0ueKHGZZtTUnyEUkbOJo3ayFAoIBAHKDyZaQqaX9Z8p6fwWj
|
||||
yVsFoK0ih8BcWkLBAdmwZ6DWGJjJpjmjaG/G3ygc9s4gO1R8m12dAnuDnGE8KzDD
|
||||
gwtbrKM2Alyg4wyA2hTlWOH/CAzH0RlCJ9Fs/d1/xJVJBeuyajLiB3/6vXTS6qnq
|
||||
I7BSSxAPG8eGcn21LSsjNeB7ZZtaTgNnu/8ZBUYo9yrgkWc67TZe3/ChldYxOOlO
|
||||
qqHh/BqNWtjxB4VZTp/g4RbgQVInZ2ozdXEv0v/dt0UEk29ANAjsZif7F3RayJ2f
|
||||
/0TilzCaJ/9K9pKNhaClVRy7Dt8QjYg6BIWCGSw4ApF7pLnQ9gySn95mersCkVzD
|
||||
YDsCggEAb0E/TORjQhKfNQvahyLfQFm151e+HIoqBqa4WFyfFxe/IJUaLH/JSSFw
|
||||
VohbQqPdCmaAeuQ8ERL564DdkcY5BgKcax79fLLCOYP5bT11aQx6uFpfl2Dcm6Z9
|
||||
QdCRI4jzPftsd5fxLNH1XtGyC4t6vTic4Pji2O71WgWzx0j5v4aeDY4sZQeFxqCV
|
||||
/q7Ee8hem1Rn5RFHu14FV45RS4LAWl6wvf5pQtneSKzx8YL0GZIRRytOzdEfnGKr
|
||||
FeUlAj5uL+5/p0ZEgM7gPsEBwdm8scF79qSUn8UWSoXNeIauF9D4BDg8RZcFFxka
|
||||
KILVFsq3cQC+bEnoM4eVbjEQkGs1RQ==
|
||||
-----END PRIVATE KEY-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIE/jCCAuagAwIBAgIJANFtVaGvJWZlMA0GCSqGSIb3DQEBCwUAMBMxETAPBgNV
|
||||
BAMMCHRlc3RzZXJ2MCAXDTE5MDEyNzIyMDIzNloYDzIxMTkwMTAzMjIwMjM2WjAT
|
||||
MREwDwYDVQQDDAh0ZXN0c2VydjCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC
|
||||
ggIBAKAXjf/eUCAn6nTHtGfxRwvxrh9y292SKS9msKeDD3ouQHfgS9+rxgZaU5WN
|
||||
T5/2dpQksAoxq2YB8Q2v6sxfhU5ealPC67qIu2rCt5StEUwRLKua9DFrXmn7LzmC
|
||||
RjQIWZcxuW/o5FvEdMadwEbBJPhYrg+h7ByC3qNwhIUIRJNQq7XNzdzDUA0gNHIM
|
||||
sHkrjgrAh7fp1wpyTI0p22/gEn7d9whkiQ8kfrbOoLKmAUXWiIW3jS4nMKXGAd0l
|
||||
mUpimCR+cNNyTDhj8zH2TwFsTfXQoDNQ0xX0fsjV4SiXOs2vroVCPV43rhf92o7y
|
||||
VNvAKoB3h2IJPiwhSTTRJWwBSM13eDyPrnWHH5BvW3x+46OufLx2Z4zG1SAgEehB
|
||||
ylyPvpHnkWzspn25qvTw4VPIYjN7nVbWMbYr/v/kzXPEiqsD7uWxe7k5HzWzsj42
|
||||
D9RQCGIcD/P7jb9qI1qoS6Tq2qPesd8ty2cJu1n4mtE2pJLASLwXSUwDTbWK8R3Z
|
||||
4lX7bzg7ImOFc6YoetdAAMeuZLtKHizJNm9E2b2f6pYgPQp0OvZKOL8qdLcSenjg
|
||||
cNDjGuJPr/GHAQazK5egfloCR0vOe6VtaQcFvLHVJRBJMuk6rNqKgX2bW9GsuwW+
|
||||
uStjcPbneY8nl5T0cz92KV24VxSCPivP7px1dQCH4or+/3nPAgMBAAGjUzBRMB0G
|
||||
A1UdDgQWBBQcQZpzLzTk5KdS/Iz7sGCV7gTd/zAfBgNVHSMEGDAWgBQcQZpzLzTk
|
||||
5KdS/Iz7sGCV7gTd/zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC
|
||||
AQAr/Pgha57jqYsDDX1LyRrVdqoVBpLBeB7x/p9dKYm7S6tBTDFNMZ0SZyQP8VEG
|
||||
7UoC9/OQ9nCdEMoR7ZKpQsmipwcIqpXHS6l4YOkf5EEq5jpMgvlEesHmBJJeJew/
|
||||
FEPDl1bl8d0tSrmWaL3qepmwzA+2lwAAouWk2n+rLiP8CZ3jZeoTXFqYYrUlEqO9
|
||||
fHMvuWqTV4KCSyNY+GWCrnHetulgKHlg+W2J1mZnrCKcBhWf9C2DesTJO+JldIeM
|
||||
ornTFquSt21hZi+k3aySuMn2N3MWiNL8XsZVsAnPSs0zA+2fxjJkShls8Gc7cCvd
|
||||
a6XrNC+PY6pONguo7rEU4HiwbvnawSTngFFglmH/ImdA/HkaAekW6o82aI8/UxFx
|
||||
V9fFMO3iKDQdOrg77hI1bx9RlzKNZZinE2/Pu26fWd5d2zqDWCjl8ykGQRAfXgYN
|
||||
H3BjgyXLl+ao5/pOUYYtzm3ruTXTgRcy5hhL6hVTYhSrf9vYh4LNIeXNKnZ78tyG
|
||||
TX77/kU2qXhBGCFEUUMqUNV/+ITir2lmoxVjknt19M07aGr8C7SgYt6Rs+qDpMiy
|
||||
JurgvRh8LpVq4pHx1efxzxCFmo58DMrG40I0+CF3y/niNpOb1gp2wAqByRiORkds
|
||||
f0ytW6qZ0TpHbD6gOtQLYDnhx3ISuX+QYSekVwQUpffeWQ==
|
||||
-----END CERTIFICATE-----
|
|
@ -354,7 +354,13 @@ class FakeTransport(object):
|
|||
:type: twisted.internet.interfaces.IReactorTime
|
||||
"""
|
||||
|
||||
_protocol = attr.ib(default=None)
|
||||
"""The Protocol which is producing data for this transport. Optional, but if set
|
||||
will get called back for connectionLost() notifications etc.
|
||||
"""
|
||||
|
||||
disconnecting = False
|
||||
disconnected = False
|
||||
buffer = attr.ib(default=b'')
|
||||
producer = attr.ib(default=None)
|
||||
|
||||
|
@ -364,11 +370,17 @@ class FakeTransport(object):
|
|||
def getHost(self):
|
||||
return None
|
||||
|
||||
def loseConnection(self):
|
||||
self.disconnecting = True
|
||||
def loseConnection(self, reason=None):
|
||||
if not self.disconnecting:
|
||||
logger.info("FakeTransport: loseConnection(%s)", reason)
|
||||
self.disconnecting = True
|
||||
if self._protocol:
|
||||
self._protocol.connectionLost(reason)
|
||||
self.disconnected = True
|
||||
|
||||
def abortConnection(self):
|
||||
self.disconnecting = True
|
||||
logger.info("FakeTransport: abortConnection()")
|
||||
self.loseConnection()
|
||||
|
||||
def pauseProducing(self):
|
||||
if not self.producer:
|
||||
|
@ -407,9 +419,16 @@ class FakeTransport(object):
|
|||
# TLSMemoryBIOProtocol
|
||||
return
|
||||
|
||||
if self.disconnected:
|
||||
return
|
||||
logger.info("%s->%s: %s", self._protocol, self.other, self.buffer)
|
||||
|
||||
if getattr(self.other, "transport") is not None:
|
||||
self.other.dataReceived(self.buffer)
|
||||
self.buffer = b""
|
||||
try:
|
||||
self.other.dataReceived(self.buffer)
|
||||
self.buffer = b""
|
||||
except Exception as e:
|
||||
logger.warning("Exception writing to protocol: %s", e)
|
||||
return
|
||||
|
||||
self._reactor.callLater(0.0, _write)
|
||||
|
|
|
@ -11,7 +11,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
hs = yield setup_test_homeserver(
|
||||
self.addCleanup
|
||||
) # type: synapse.server.HomeServer
|
||||
)
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
|
|
|
@ -20,9 +20,6 @@ import tests.utils
|
|||
|
||||
|
||||
class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs)
|
||||
self.store = None # type: synapse.storage.DataStore
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
|
|
|
@ -22,9 +22,6 @@ import tests.utils
|
|||
|
||||
|
||||
class KeyStoreTestCase(tests.unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(KeyStoreTestCase, self).__init__(*args, **kwargs)
|
||||
self.store = None # type: synapse.storage.keys.KeyStore
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
|
|
|
@ -28,9 +28,6 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class StateStoreTestCase(tests.unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(StateStoreTestCase, self).__init__(*args, **kwargs)
|
||||
self.store = None # type: synapse.storage.DataStore
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
|
|
18
tests/test_utils/__init__.py
Normal file
18
tests/test_utils/__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Utilities for running the unit tests
|
||||
"""
|
54
tests/test_utils/logging_setup.py
Normal file
54
tests/test_utils/logging_setup.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
|
||||
import twisted.logger
|
||||
|
||||
from synapse.util.logcontext import LoggingContextFilter
|
||||
|
||||
|
||||
class ToTwistedHandler(logging.Handler):
|
||||
"""logging handler which sends the logs to the twisted log"""
|
||||
tx_log = twisted.logger.Logger()
|
||||
|
||||
def emit(self, record):
|
||||
log_entry = self.format(record)
|
||||
log_level = record.levelname.lower().replace('warning', 'warn')
|
||||
self.tx_log.emit(
|
||||
twisted.logger.LogLevel.levelWithName(log_level),
|
||||
log_entry.replace("{", r"(").replace("}", r")"),
|
||||
)
|
||||
|
||||
|
||||
def setup_logging():
|
||||
"""Configure the python logging appropriately for the tests.
|
||||
|
||||
(Logs will end up in _trial_temp.)
|
||||
"""
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
log_format = (
|
||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
|
||||
)
|
||||
|
||||
handler = ToTwistedHandler()
|
||||
formatter = logging.Formatter(log_format)
|
||||
handler.setFormatter(formatter)
|
||||
handler.addFilter(LoggingContextFilter(request=""))
|
||||
root_logger.addHandler(handler)
|
||||
|
||||
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
|
||||
root_logger.setLevel(log_level)
|
|
@ -166,7 +166,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def inject_message(self, user_id, content=None):
|
||||
if content is None:
|
||||
content = {"body": "testytest"}
|
||||
content = {"body": "testytest", "msgtype": "m.text"}
|
||||
builder = self.event_builder_factory.new(
|
||||
RoomVersions.V1,
|
||||
{
|
||||
|
|
|
@ -31,38 +31,14 @@ from synapse.http.server import JsonResource
|
|||
from synapse.http.site import SynapseRequest
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID, create_requester
|
||||
from synapse.util.logcontext import LoggingContext, LoggingContextFilter
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
from tests.server import get_clock, make_request, render, setup_test_homeserver
|
||||
from tests.test_utils.logging_setup import setup_logging
|
||||
from tests.utils import default_config, setupdb
|
||||
|
||||
setupdb()
|
||||
|
||||
# Set up putting Synapse's logs into Trial's.
|
||||
rootLogger = logging.getLogger()
|
||||
|
||||
log_format = (
|
||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
class ToTwistedHandler(logging.Handler):
|
||||
tx_log = twisted.logger.Logger()
|
||||
|
||||
def emit(self, record):
|
||||
log_entry = self.format(record)
|
||||
log_level = record.levelname.lower().replace('warning', 'warn')
|
||||
self.tx_log.emit(
|
||||
twisted.logger.LogLevel.levelWithName(log_level),
|
||||
log_entry.replace("{", r"(").replace("}", r")"),
|
||||
)
|
||||
|
||||
|
||||
handler = ToTwistedHandler()
|
||||
formatter = logging.Formatter(log_format)
|
||||
handler.setFormatter(formatter)
|
||||
handler.addFilter(LoggingContextFilter(request=""))
|
||||
rootLogger.addHandler(handler)
|
||||
setup_logging()
|
||||
|
||||
|
||||
def around(target):
|
||||
|
@ -96,7 +72,7 @@ class TestCase(unittest.TestCase):
|
|||
|
||||
method = getattr(self, methodName)
|
||||
|
||||
level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING))
|
||||
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
|
||||
|
||||
@around(self)
|
||||
def setUp(orig):
|
||||
|
@ -114,7 +90,7 @@ class TestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
old_level = logging.getLogger().level
|
||||
if old_level != level:
|
||||
if level is not None and old_level != level:
|
||||
|
||||
@around(self)
|
||||
def tearDown(orig):
|
||||
|
@ -122,7 +98,8 @@ class TestCase(unittest.TestCase):
|
|||
logging.getLogger().setLevel(old_level)
|
||||
return ret
|
||||
|
||||
logging.getLogger().setLevel(level)
|
||||
logging.getLogger().setLevel(level)
|
||||
|
||||
return orig()
|
||||
|
||||
@around(self)
|
||||
|
|
83
tests/util/caches/test_ttlcache.py
Normal file
83
tests/util/caches/test_ttlcache.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from synapse.util.caches.ttlcache import TTLCache
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class CacheTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.mock_timer = Mock(side_effect=lambda: 100.0)
|
||||
self.cache = TTLCache("test_cache", self.mock_timer)
|
||||
|
||||
def test_get(self):
|
||||
"""simple set/get tests"""
|
||||
self.cache.set('one', '1', 10)
|
||||
self.cache.set('two', '2', 20)
|
||||
self.cache.set('three', '3', 30)
|
||||
|
||||
self.assertEqual(len(self.cache), 3)
|
||||
|
||||
self.assertTrue('one' in self.cache)
|
||||
self.assertEqual(self.cache.get('one'), '1')
|
||||
self.assertEqual(self.cache['one'], '1')
|
||||
self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110))
|
||||
self.assertEqual(self.cache._metrics.hits, 3)
|
||||
self.assertEqual(self.cache._metrics.misses, 0)
|
||||
|
||||
self.cache.set('two', '2.5', 20)
|
||||
self.assertEqual(self.cache['two'], '2.5')
|
||||
self.assertEqual(self.cache._metrics.hits, 4)
|
||||
|
||||
# non-existent-item tests
|
||||
self.assertEqual(self.cache.get('four', '4'), '4')
|
||||
self.assertIs(self.cache.get('four', None), None)
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
self.cache['four']
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
self.cache.get('four')
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
self.cache.get_with_expiry('four')
|
||||
|
||||
self.assertEqual(self.cache._metrics.hits, 4)
|
||||
self.assertEqual(self.cache._metrics.misses, 5)
|
||||
|
||||
def test_expiry(self):
|
||||
self.cache.set('one', '1', 10)
|
||||
self.cache.set('two', '2', 20)
|
||||
self.cache.set('three', '3', 30)
|
||||
|
||||
self.assertEqual(len(self.cache), 3)
|
||||
self.assertEqual(self.cache['one'], '1')
|
||||
self.assertEqual(self.cache['two'], '2')
|
||||
|
||||
# enough for the first entry to expire, but not the rest
|
||||
self.mock_timer.side_effect = lambda: 110.0
|
||||
|
||||
self.assertEqual(len(self.cache), 2)
|
||||
self.assertFalse('one' in self.cache)
|
||||
self.assertEqual(self.cache['two'], '2')
|
||||
self.assertEqual(self.cache['three'], '3')
|
||||
|
||||
self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120))
|
||||
|
||||
self.assertEqual(self.cache._metrics.hits, 5)
|
||||
self.assertEqual(self.cache._metrics.misses, 0)
|
Loading…
Reference in a new issue