
Merge branch 'release-v0.18.2' of github.com:matrix-org/synapse

Erik Johnston, 2016-11-01 13:14:04 +00:00
commit 4a9055d446
57 changed files with 1563 additions and 804 deletions

.gitignore

@@ -24,10 +24,10 @@ homeserver*.yaml
 .coverage
 htmlcov
-demo/*.db
-demo/*.log
-demo/*.log.*
-demo/*.pid
+demo/*/*.db
+demo/*/*.log
+demo/*/*.log.*
+demo/*/*.pid
 demo/media_store.*
 demo/etc

CHANGES.rst

@@ -1,4 +1,80 @@
-Changes in synapse v0.18.1 (2016-10-05)
+Changes in synapse v0.18.2 (2016-11-01)
+=======================================
+
+No changes since v0.18.2-rc5
+
+
+Changes in synapse v0.18.2-rc5 (2016-10-28)
+===========================================
+
+Bug fixes:
+
+* Fix prometheus process metrics in worker processes (PR #1184)
+
+
+Changes in synapse v0.18.2-rc4 (2016-10-27)
+===========================================
+
+Bug fixes:
+
+* Fix ``user_threepids`` schema delta, which in some instances prevented
+  startup after upgrade (PR #1183)
+
+
+Changes in synapse v0.18.2-rc3 (2016-10-27)
+===========================================
+
+Changes:
+
+* Allow clients to supply access tokens as headers (PR #1098)
+* Clarify error codes for GET /filter/, thanks to Alexander Maznev (PR #1164)
+* Make password reset email field case insensitive (PR #1170)
+* Reduce redundant database work in email pusher (PR #1174)
+* Allow configurable rate limiting per AS (PR #1175)
+* Check whether to ratelimit sooner to avoid work (PR #1176)
+* Standardise prometheus metrics (PR #1177)
+
+
+Bug fixes:
+
+* Fix incredibly slow back pagination query (PR #1178)
+* Fix infinite typing bug (PR #1179)
+
+
+Changes in synapse v0.18.2-rc2 (2016-10-25)
+===========================================
+
+(This release did not include the changes advertised and was identical to RC1)
+
+
+Changes in synapse v0.18.2-rc1 (2016-10-17)
+===========================================
+
+Changes:
+
+* Remove redundant event_auth index (PR #1113)
+* Reduce DB hits for replication (PR #1141)
+* Implement pluggable password auth (PR #1155)
+* Remove rate limiting from app service senders and fix get_or_create_user
+  requester, thanks to Patrik Oldsberg (PR #1157)
+* window.postmessage for Interactive Auth fallback (PR #1159)
+* Use sys.executable instead of hardcoded python, thanks to Pedro Larroy
+  (PR #1162)
+* Add config option for adding additional TLS fingerprints (PR #1167)
+* User-interactive auth on delete device (PR #1168)
+
+
+Bug fixes:
+
+* Fix not being allowed to set your own state_key, thanks to Patrik Oldsberg
+  (PR #1150)
+* Fix interactive auth to return 401 for incorrect password (PR #1160,
+  #1166)
+* Fix email push notifs being dropped (PR #1169)
+
+
+Changes in synapse v0.18.1 (2016-10-05)
 ======================================
 
 No changes since v0.18.1-rc1

docs/metrics-howto.rst

@@ -15,36 +15,45 @@ How to monitor Synapse metrics using Prometheus
    Restart synapse
 
-3: Check out synapse-prometheus-config
-   https://github.com/matrix-org/synapse-prometheus-config
-
-4: Add ``synapse.html`` and ``synapse.rules``
-   The ``.html`` file needs to appear in prometheus's ``consoles`` directory,
-   and the ``.rules`` file needs to be invoked somewhere in the main config
-   file. A symlink to each from the git checkout into the prometheus directory
-   might be easiest to ensure ``git pull`` keeps it updated.
-
-5: Add a prometheus target for synapse
-   This is easiest if prometheus runs on the same machine as synapse, as it can
-   then just use localhost::
-
-    global: {
-      rule_file: "synapse.rules"
-    }
-
-    job: {
-      name: "synapse"
-
-      target_group: {
-        target: "http://localhost:9092/"
-      }
-    }
-
-6: Start prometheus::
-
-   ./prometheus -config.file=prometheus.conf
-
-7: Wait a few seconds for it to start and perform the first scrape,
-   then visit the console:
-
-    http://server-where-prometheus-runs:9090/consoles/synapse.html
+3: Add a prometheus target for synapse. It needs to set the ``metrics_path``
+   to a non-default value::
+
+    - job_name: "synapse"
+      metrics_path: "/_synapse/metrics"
+      static_configs:
+        - targets:
+           "my.server.here:9092"
+
+Standard Metric Names
+---------------------
+
+As of synapse version 0.18.2, the format of the process-wide metrics has been
+changed to fit prometheus standard naming conventions. Additionally the units
+have been changed to seconds, from milliseconds.
+
+================================== =============================
+New name                           Old name
+---------------------------------- -----------------------------
+process_cpu_user_seconds_total     process_resource_utime / 1000
+process_cpu_system_seconds_total   process_resource_stime / 1000
+process_open_fds (no 'type' label) process_fds
+================================== =============================
+
+The python-specific counts of garbage collector performance have been renamed.
+
+=========================== ======================
+New name                    Old name
+--------------------------- ----------------------
+python_gc_time              reactor_gc_time
+python_gc_unreachable_total reactor_gc_unreachable
+python_gc_counts            reactor_gc_counts
+=========================== ======================
+
+The twisted-specific reactor metrics have been renamed.
+
+==================================== =====================
+New name                             Old name
+------------------------------------ ---------------------
+python_twisted_reactor_pending_calls reactor_pending_calls
+python_twisted_reactor_tick_time     reactor_tick_time
+==================================== =====================
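A quick way to confirm the renamed metrics are actually being exported is to
fetch the metrics page and look for the new names. This is a hedged sketch: it
assumes a metrics listener reachable at ``my.server.here:9092`` as in the
scrape config above, and Python 2 (synapse's runtime at this version)::

    import urllib2

    body = urllib2.urlopen("http://my.server.here:9092/_synapse/metrics").read()
    for name in (
        "process_cpu_user_seconds_total",
        "process_cpu_system_seconds_total",
        "process_open_fds",
    ):
        print("%s: %s" % (name, "present" if name in body else "missing"))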

res/templates/notif_mail.html

@@ -18,7 +18,9 @@
                 <div class="summarytext">{{ summary_text }}</div>
             </td>
             <td class="logo">
-                {% if app_name == "Vector" %}
+                {% if app_name == "Riot" %}
+                    <img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
+                {% elif app_name == "Vector" %}
                     <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
                 {% else %}
                     <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>

synapse/__init__.py

@@ -16,4 +16,4 @@
 
 """ This is a reference implementation of a Matrix home server.
 """
-__version__ = "0.18.1"
+__version__ = "0.18.2"

synapse/api/auth.py

@@ -603,10 +603,12 @@ class Auth(object):
         """
         # Can optionally look elsewhere in the request (e.g. headers)
         try:
-            user_id = yield self._get_appservice_user_id(request)
+            user_id, app_service = yield self._get_appservice_user_id(request)
             if user_id:
                 request.authenticated_entity = user_id
-                defer.returnValue(synapse.types.create_requester(user_id))
+                defer.returnValue(
+                    synapse.types.create_requester(user_id, app_service=app_service)
+                )
 
             access_token = get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
@@ -644,7 +646,8 @@ class Auth(object):
             request.authenticated_entity = user.to_string()
 
             defer.returnValue(synapse.types.create_requester(
-                user, token_id, is_guest, device_id))
+                user, token_id, is_guest, device_id, app_service=app_service)
+            )
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -653,20 +656,20 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def _get_appservice_user_id(self, request):
-        app_service = yield self.store.get_app_service_by_token(
+        app_service = self.store.get_app_service_by_token(
             get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
             )
         )
         if app_service is None:
-            defer.returnValue(None)
+            defer.returnValue((None, None))
 
         if "user_id" not in request.args:
-            defer.returnValue(app_service.sender)
+            defer.returnValue((app_service.sender, app_service))
 
         user_id = request.args["user_id"][0]
         if app_service.sender == user_id:
-            defer.returnValue(app_service.sender)
+            defer.returnValue((app_service.sender, app_service))
 
         if not app_service.is_interested_in_user(user_id):
             raise AuthError(
@@ -678,7 +681,7 @@ class Auth(object):
                 403,
                 "Application service has not registered this user"
             )
-        defer.returnValue(user_id)
+        defer.returnValue((user_id, app_service))
 
     @defer.inlineCallbacks
     def get_user_by_access_token(self, token, rights="access"):
@@ -855,13 +858,12 @@ class Auth(object):
         }
         defer.returnValue(user_info)
 
-    @defer.inlineCallbacks
     def get_appservice_by_req(self, request):
         try:
             token = get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
             )
-            service = yield self.store.get_app_service_by_token(token)
+            service = self.store.get_app_service_by_token(token)
             if not service:
                 logger.warn("Unrecognised appservice access token: %s" % (token,))
                 raise AuthError(
@@ -870,7 +872,7 @@ class Auth(object):
                     errcode=Codes.UNKNOWN_TOKEN
                 )
             request.authenticated_entity = service.sender
-            defer.returnValue(service)
+            return defer.succeed(service)
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
@@ -1002,16 +1004,6 @@ class Auth(object):
                     403,
                     "You are not allowed to set others state"
                 )
-            else:
-                sender_domain = UserID.from_string(
-                    event.user_id
-                ).domain
-
-                if sender_domain != event.state_key:
-                    raise AuthError(
-                        403,
-                        "You are not allowed to set others state"
-                    )
 
         return True
 
@@ -1178,7 +1170,8 @@ def has_access_token(request):
         bool: False if no access_token was given, True otherwise.
     """
     query_params = request.args.get("access_token")
-    return bool(query_params)
+    auth_headers = request.requestHeaders.getRawHeaders("Authorization")
+    return bool(query_params) or bool(auth_headers)
 
 
 def get_access_token_from_request(request, token_not_found_http_status=401):
@@ -1196,7 +1189,34 @@ def get_access_token_from_request(request, token_not_found_http_status=401):
     Raises:
         AuthError: If there isn't an access_token in the request.
     """
+    auth_headers = request.requestHeaders.getRawHeaders("Authorization")
     query_params = request.args.get("access_token")
+    if auth_headers:
+        # Try to get the access_token from an "Authorization: Bearer"
+        # header
+        if query_params is not None:
+            raise AuthError(
+                token_not_found_http_status,
+                "Mixing Authorization headers and access_token query parameters.",
+                errcode=Codes.MISSING_TOKEN,
+            )
+        if len(auth_headers) > 1:
+            raise AuthError(
+                token_not_found_http_status,
+                "Too many Authorization headers.",
+                errcode=Codes.MISSING_TOKEN,
+            )
+        parts = auth_headers[0].split(" ")
+        if parts[0] == "Bearer" and len(parts) == 2:
+            return parts[1]
+        else:
+            raise AuthError(
+                token_not_found_http_status,
+                "Invalid Authorization header.",
+                errcode=Codes.MISSING_TOKEN,
+            )
+    else:
+        # Try to get the access_token from the query params.
+        if not query_params:
+            raise AuthError(
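The net effect of this change (PR #1098) is that a client may now supply its
token either way, but not both at once. A minimal sketch, assuming the
third-party ``requests`` library, a hypothetical homeserver at
``https://example.com`` and a placeholder token::

    import requests

    HOMESERVER = "https://example.com"   # hypothetical
    ACCESS_TOKEN = "MDAxsyt0b2tlbg"      # hypothetical

    # New style: "Authorization: Bearer" header.
    r = requests.get(
        HOMESERVER + "/_matrix/client/r0/sync",
        headers={"Authorization": "Bearer " + ACCESS_TOKEN},
        params={"timeout": 0},
    )

    # Old style: query parameter. Mixing both styles in one request is
    # rejected with a 401 by the code above.
    r = requests.get(
        HOMESERVER + "/_matrix/client/r0/sync",
        params={"access_token": ACCESS_TOKEN, "timeout": 0},
    )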

synapse/api/ratelimiting.py

@@ -23,7 +23,7 @@ class Ratelimiter(object):
     def __init__(self):
         self.message_counts = collections.OrderedDict()
 
-    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
+    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
         """Can the user send a message?
         Args:
             user_id: The user sending a message.
@@ -32,12 +32,15 @@ class Ratelimiter(object):
                 second.
             burst_count: How many messages the user can send before being
                 limited.
+            update (bool): Whether to update the message rates or not. This is
+                useful to check if a message would be allowed to be sent before
+                it is ready to be actually sent.
         Returns:
             A pair of a bool indicating if they can send a message now and a
             time in seconds of when they can next send a message.
         """
         self.prune_message_counts(time_now_s)
-        message_count, time_start, _ignored = self.message_counts.pop(
+        message_count, time_start, _ignored = self.message_counts.get(
             user_id, (0., time_now_s, None),
         )
         time_delta = time_now_s - time_start
@@ -52,6 +55,7 @@ class Ratelimiter(object):
             allowed = True
             message_count += 1
 
-        self.message_counts[user_id] = (
-            message_count, time_start, msg_rate_hz
-        )
+        if update:
+            self.message_counts[user_id] = (
+                message_count, time_start, msg_rate_hz
+            )
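A sketch of how the new ``update`` flag is meant to be used, mirroring what
the message handler does later in this commit: probe first without consuming
rate-limit budget, then charge the limiter when the message is actually sent.
The user ID and limits here are made up::

    import time

    from synapse.api.ratelimiting import Ratelimiter

    limiter = Ratelimiter()
    now = time.time()

    # Probe only: with update=False, message_counts is left untouched.
    allowed, retry_at = limiter.send_message(
        "@alice:example.com", now, msg_rate_hz=0.2, burst_count=10,
        update=False,
    )

    if allowed:
        # ... do the expensive event-building work, then record the send.
        limiter.send_message(
            "@alice:example.com", time.time(), msg_rate_hz=0.2, burst_count=10,
        )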

synapse/app/pusher.py

@@ -197,7 +197,7 @@ class PusherServer(HomeServer):
                     yield start_pusher(user_id, app_id, pushkey)
 
             stream = results.get("events")
-            if stream:
+            if stream and stream["rows"]:
                 min_stream_id = stream["rows"][0][0]
                 max_stream_id = stream["position"]
                 preserve_fn(pusher_pool.on_new_notifications)(
@@ -205,7 +205,7 @@ class PusherServer(HomeServer):
                 )
 
             stream = results.get("receipts")
-            if stream:
+            if stream and stream["rows"]:
                 rows = stream["rows"]
                 affected_room_ids = set(row[1] for row in rows)
                 min_stream_id = rows[0][0]

synapse/app/synctl.py

@@ -24,7 +24,7 @@ import subprocess
 import sys
 import yaml
 
-SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
+SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
 
 GREEN = "\x1b[1;32m"
 RED = "\x1b[1;31m"
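The point of the change (PR #1162): ``"python"`` is resolved via ``$PATH``,
while ``sys.executable`` is the absolute path of the interpreter running
synctl itself, so the homeserver child process always starts under the same
interpreter, e.g. the virtualenv synctl was installed into. A tiny
illustration::

    import subprocess
    import sys

    print(sys.executable)  # e.g. /home/user/.synapse/bin/python
    subprocess.check_call([
        sys.executable, "-c", "import sys; print(sys.version)",
    ])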

synapse/appservice/__init__.py

@@ -81,7 +81,7 @@ class ApplicationService(object):
     NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
 
     def __init__(self, token, url=None, namespaces=None, hs_token=None,
-                 sender=None, id=None, protocols=None):
+                 sender=None, id=None, protocols=None, rate_limited=True):
         self.token = token
         self.url = url
         self.hs_token = hs_token
@@ -95,6 +95,8 @@ class ApplicationService(object):
         else:
             self.protocols = set()
 
+        self.rate_limited = rate_limited
+
     def _check_namespaces(self, namespaces):
         # Sanity check that it is of the form:
         # {
@@ -234,5 +236,8 @@ class ApplicationService(object):
     def is_exclusive_room(self, room_id):
         return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
 
+    def is_rate_limited(self):
+        return self.rate_limited
+
     def __str__(self):
         return "ApplicationService: %s" % (self.__dict__,)

synapse/config/appservice.py

@@ -110,6 +110,11 @@ def _load_appservice(hostname, as_info, config_filename):
     user = UserID(localpart, hostname)
     user_id = user.to_string()
 
+    # Rate limiting for users of this AS is on by default (excludes sender)
+    rate_limited = True
+    if isinstance(as_info.get("rate_limited"), bool):
+        rate_limited = as_info.get("rate_limited")
+
     # namespace checks
     if not isinstance(as_info.get("namespaces"), dict):
         raise KeyError("Requires 'namespaces' object.")
@@ -155,4 +160,5 @@ def _load_appservice(hostname, as_info, config_filename):
         sender=user_id,
         id=as_info["id"],
         protocols=protocols,
+        rate_limited=rate_limited
     )
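Putting the two halves together, an AS registration can now opt its users out
of message ratelimiting (PR #1175). A hedged sketch, with ``as_info`` standing
in for a parsed registration file and every value hypothetical except the new
``rate_limited`` key::

    as_info = {
        "id": "example-bridge",          # hypothetical values throughout
        "url": "http://localhost:9000",
        "as_token": "<as_token>",
        "hs_token": "<hs_token>",
        "sender_localpart": "examplebot",
        "namespaces": {"users": [], "aliases": [], "rooms": []},
        "rate_limited": False,  # opt this AS's users out of ratelimiting
    }

    # Mirrors the check above: anything other than a bool keeps the default.
    rate_limited = True
    if isinstance(as_info.get("rate_limited"), bool):
        rate_limited = as_info.get("rate_limited")

    print(rate_limited)  # False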

synapse/config/homeserver.py

@@ -30,7 +30,7 @@ from .saml2 import SAML2Config
 from .cas import CasConfig
 from .password import PasswordConfig
 from .jwt import JWTConfig
-from .ldap import LDAPConfig
+from .password_auth_providers import PasswordAuthProviderConfig
 from .emailconfig import EmailConfig
 from .workers import WorkerConfig
 
@@ -39,8 +39,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                        RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
                        VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
                        AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
-                       JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,
-                       WorkerConfig,):
+                       JWTConfig, PasswordConfig, EmailConfig,
+                       WorkerConfig, PasswordAuthProviderConfig,):
     pass

synapse/config/ldap.py (deleted)

@@ -1,100 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015 Niklas Riekenbrauck
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ._base import Config, ConfigError
-
-
-MISSING_LDAP3 = (
-    "Missing ldap3 library. This is required for LDAP Authentication."
-)
-
-
-class LDAPMode(object):
-    SIMPLE = "simple",
-    SEARCH = "search",
-
-    LIST = (SIMPLE, SEARCH)
-
-
-class LDAPConfig(Config):
-    def read_config(self, config):
-        ldap_config = config.get("ldap_config", {})
-
-        self.ldap_enabled = ldap_config.get("enabled", False)
-
-        if self.ldap_enabled:
-            # verify dependencies are available
-            try:
-                import ldap3
-                ldap3  # to stop unused lint
-            except ImportError:
-                raise ConfigError(MISSING_LDAP3)
-
-            self.ldap_mode = LDAPMode.SIMPLE
-
-            # verify config sanity
-            self.require_keys(ldap_config, [
-                "uri",
-                "base",
-                "attributes",
-            ])
-
-            self.ldap_uri = ldap_config["uri"]
-            self.ldap_start_tls = ldap_config.get("start_tls", False)
-            self.ldap_base = ldap_config["base"]
-            self.ldap_attributes = ldap_config["attributes"]
-
-            if "bind_dn" in ldap_config:
-                self.ldap_mode = LDAPMode.SEARCH
-                self.require_keys(ldap_config, [
-                    "bind_dn",
-                    "bind_password",
-                ])
-
-                self.ldap_bind_dn = ldap_config["bind_dn"]
-                self.ldap_bind_password = ldap_config["bind_password"]
-                self.ldap_filter = ldap_config.get("filter", None)
-
-            # verify attribute lookup
-            self.require_keys(ldap_config['attributes'], [
-                "uid",
-                "name",
-                "mail",
-            ])
-
-    def require_keys(self, config, required):
-        missing = [key for key in required if key not in config]
-        if missing:
-            raise ConfigError(
-                "LDAP enabled but missing required config values: {}".format(
-                    ", ".join(missing)
-                )
-            )
-
-    def default_config(self, **kwargs):
-        return """\
-        # ldap_config:
-        #   enabled: true
-        #   uri: "ldap://ldap.example.com:389"
-        #   start_tls: true
-        #   base: "ou=users,dc=example,dc=com"
-        #   attributes:
-        #      uid: "cn"
-        #      mail: "email"
-        #      name: "givenName"
-        #   #bind_dn:
-        #   #bind_password:
-        #   #filter: "(objectClass=posixAccount)"
-        """

synapse/config/password_auth_providers.py (new)

@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 Openmarket
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+import importlib
+
+
+class PasswordAuthProviderConfig(Config):
+    def read_config(self, config):
+        self.password_providers = []
+
+        # We want to be backwards compatible with the old `ldap_config`
+        # param.
+        ldap_config = config.get("ldap_config", {})
+        self.ldap_enabled = ldap_config.get("enabled", False)
+        if self.ldap_enabled:
+            from synapse.util.ldap_auth_provider import LdapAuthProvider
+            parsed_config = LdapAuthProvider.parse_config(ldap_config)
+            self.password_providers.append((LdapAuthProvider, parsed_config))
+
+        providers = config.get("password_providers", [])
+        for provider in providers:
+            # We need to import the module, and then pick the class out of
+            # that, so we split based on the last dot.
+            module, clz = provider['module'].rsplit(".", 1)
+            module = importlib.import_module(module)
+            provider_class = getattr(module, clz)
+
+            provider_config = provider_class.parse_config(provider["config"])
+            self.password_providers.append((provider_class, provider_config))
+
+    def default_config(self, **kwargs):
+        return """\
+        # password_providers:
+        #     - module: "synapse.util.ldap_auth_provider.LdapAuthProvider"
+        #       config:
+        #         enabled: true
+        #         uri: "ldap://ldap.example.com:389"
+        #         start_tls: true
+        #         base: "ou=users,dc=example,dc=com"
+        #         attributes:
+        #            uid: "cn"
+        #            mail: "email"
+        #            name: "givenName"
+        #         #bind_dn:
+        #         #bind_password:
+        #         #filter: "(objectClass=posixAccount)"
+        """

synapse/config/tls.py

@@ -19,6 +19,9 @@ from OpenSSL import crypto
 import subprocess
 import os
 
+from hashlib import sha256
+from unpaddedbase64 import encode_base64
+
 GENERATE_DH_PARAMS = False
 
@@ -42,6 +45,19 @@ class TlsConfig(Config):
             config.get("tls_dh_params_path"), "tls_dh_params"
         )
 
+        self.tls_fingerprints = config["tls_fingerprints"]
+
+        # Check that our own certificate is included in the list of fingerprints
+        # and include it if it is not.
+        x509_certificate_bytes = crypto.dump_certificate(
+            crypto.FILETYPE_ASN1,
+            self.tls_certificate
+        )
+        sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
+        sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
+        if sha256_fingerprint not in sha256_fingerprints:
+            self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
+
         # This config option applies to non-federation HTTP clients
         # (e.g. for talking to recaptcha, identity servers, and such)
         # It should never be used in production, and is intended for
@@ -73,6 +89,28 @@ class TlsConfig(Config):
         # Don't bind to the https port
         no_tls: False
 
+        # List of allowed TLS fingerprints for this server to publish along
+        # with the signing keys for this server. Other matrix servers that
+        # make HTTPS requests to this server will check that the TLS
+        # certificates returned by this server match one of the fingerprints.
+        #
+        # Synapse automatically adds the fingerprint of its own certificate
+        # to the list. So if federation traffic is handled directly by synapse
+        # then no modification to the list is required.
+        #
+        # If synapse is run behind a load balancer that handles the TLS then
+        # it will be necessary to add the fingerprints of the certificates used
+        # by the load balancers to this list if they are different to the one
+        # synapse is using.
+        #
+        # Homeservers are permitted to cache the list of TLS fingerprints
+        # returned in the key responses up to the "valid_until_ts" returned in
+        # key. It may be necessary to publish the fingerprints of a new
+        # certificate and wait until the "valid_until_ts" of the previous key
+        # responses have passed before deploying it.
+        tls_fingerprints: []
+        # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
+
         """ % locals()
 
     def read_tls_certificate(self, cert_path):
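When TLS is terminated in front of synapse, the fingerprint of the load
balancer's certificate has to be computed the same way as above: SHA-256 over
the DER-encoded certificate, unpadded base64. A hedged sketch, assuming a PEM
certificate at ``./tls.crt``::

    from hashlib import sha256

    from OpenSSL import crypto
    from unpaddedbase64 import encode_base64

    with open("tls.crt") as f:
        cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())

    # DER-encode the certificate, hash it, and emit an entry suitable for
    # the tls_fingerprints list above.
    der_bytes = crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
    print({"sha256": encode_base64(sha256(der_bytes).digest())})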

synapse/handlers/_base.py

@@ -55,8 +55,20 @@ class BaseHandler(object):
 
     def ratelimit(self, requester):
         time_now = self.clock.time()
+        user_id = requester.user.to_string()
+
+        # The AS user itself is never rate limited.
+        app_service = self.store.get_app_service_by_user_id(user_id)
+        if app_service is not None:
+            return  # do not ratelimit app service senders
+
+        # Disable rate limiting of users belonging to any AS that is configured
+        # not to be rate limited in its registration file (rate_limited: true|false).
+        if requester.app_service and not requester.app_service.is_rate_limited():
+            return
+
         allowed, time_allowed = self.ratelimiter.send_message(
-            requester.user.to_string(), time_now,
+            user_id, time_now,
             msg_rate_hz=self.hs.config.rc_messages_per_second,
             burst_count=self.hs.config.rc_message_burst_count,
         )

synapse/handlers/appservice.py

@@ -59,7 +59,7 @@ class ApplicationServicesHandler(object):
         Args:
             current_id(int): The current maximum ID.
         """
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         if not services or not self.notify_appservices:
             return
 
@@ -142,7 +142,7 @@ class ApplicationServicesHandler(object):
             association can be found.
         """
         room_alias_str = room_alias.to_string()
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         alias_query_services = [
             s for s in services if (
                 s.is_interested_in_alias(room_alias_str)
@@ -177,7 +177,7 @@ class ApplicationServicesHandler(object):
 
     @defer.inlineCallbacks
     def get_3pe_protocols(self, only_protocol=None):
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         protocols = {}
 
         # Collect up all the individual protocol responses out of the ASes
@@ -224,7 +224,7 @@ class ApplicationServicesHandler(object):
             list<ApplicationService>: A list of services interested in this
                 event based on the service regex.
         """
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         interested_list = [
             s for s in services if (
                 yield s.is_interested(event, self.store)
             )
         ]
         defer.returnValue(interested_list)
 
-    @defer.inlineCallbacks
     def _get_services_for_user(self, user_id):
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         interested_list = [
             s for s in services if (
                 s.is_interested_in_user(user_id)
             )
         ]
-        defer.returnValue(interested_list)
+        return defer.succeed(interested_list)
 
-    @defer.inlineCallbacks
     def _get_services_for_3pn(self, protocol):
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         interested_list = [
             s for s in services if s.is_interested_in_protocol(protocol)
         ]
-        defer.returnValue(interested_list)
+        return defer.succeed(interested_list)
 
     @defer.inlineCallbacks
     def _is_unknown_user(self, user_id):
@@ -264,7 +262,7 @@ class ApplicationServicesHandler(object):
             return
 
         # user not found; could be the AS though, so check.
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         service_list = [s for s in services if s.sender == user_id]
         defer.returnValue(len(service_list) == 0)
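The pattern in this file recurs through the rest of the commit (PR #1141):
``get_app_services()`` is now a plain synchronous accessor, and
formerly-decorated methods return ``defer.succeed(...)`` so that callers which
``yield`` them keep working unchanged. A small illustration of why that is
safe::

    from twisted.internet import defer

    def plain_version():
        # Already-fired Deferred; no @defer.inlineCallbacks needed.
        return defer.succeed(42)

    @defer.inlineCallbacks
    def caller():
        value = yield plain_version()  # yielding a fired Deferred just works
        defer.returnValue(value * 2)

    def show(result):
        print("got %d" % result)  # prints "got 84"

    caller().addCallback(show)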

synapse/handlers/auth.py

@@ -20,7 +20,6 @@ from synapse.api.constants import LoginType
 from synapse.types import UserID
 from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
 from synapse.util.async import run_on_reactor
-from synapse.config.ldap import LDAPMode
 
 from twisted.web.client import PartialDownloadError
 
@@ -29,13 +28,6 @@ import bcrypt
 import pymacaroons
 import simplejson
 
-try:
-    import ldap3
-    import ldap3.core.exceptions
-except ImportError:
-    ldap3 = None
-    pass
-
 import synapse.util.stringutils as stringutils
@@ -59,23 +51,15 @@ class AuthHandler(BaseHandler):
         }
         self.bcrypt_rounds = hs.config.bcrypt_rounds
         self.sessions = {}
+        self.INVALID_TOKEN_HTTP_STATUS = 401
 
-        self.ldap_enabled = hs.config.ldap_enabled
-        if self.ldap_enabled:
-            if not ldap3:
-                raise RuntimeError(
-                    'Missing ldap3 library. This is required for LDAP Authentication.'
-                )
-            self.ldap_mode = hs.config.ldap_mode
-            self.ldap_uri = hs.config.ldap_uri
-            self.ldap_start_tls = hs.config.ldap_start_tls
-            self.ldap_base = hs.config.ldap_base
-            self.ldap_attributes = hs.config.ldap_attributes
-            if self.ldap_mode == LDAPMode.SEARCH:
-                self.ldap_bind_dn = hs.config.ldap_bind_dn
-                self.ldap_bind_password = hs.config.ldap_bind_password
-                self.ldap_filter = hs.config.ldap_filter
+        account_handler = _AccountHandler(
+            hs, check_user_exists=self.check_user_exists
+        )
+
+        self.password_providers = [
+            module(config=config, account_handler=account_handler)
+            for module, config in hs.config.password_providers
+        ]
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.device_handler = hs.get_device_handler()
@@ -149,13 +133,30 @@ class AuthHandler(BaseHandler):
         creds = session['creds']
 
         # check auth type currently being presented
+        errordict = {}
         if 'type' in authdict:
-            if authdict['type'] not in self.checkers:
+            login_type = authdict['type']
+            if login_type not in self.checkers:
                 raise LoginError(400, "", Codes.UNRECOGNIZED)
-            result = yield self.checkers[authdict['type']](authdict, clientip)
-            if result:
-                creds[authdict['type']] = result
-                self._save_session(session)
+            try:
+                result = yield self.checkers[login_type](authdict, clientip)
+                if result:
+                    creds[login_type] = result
+                    self._save_session(session)
+            except LoginError, e:
+                if login_type == LoginType.EMAIL_IDENTITY:
+                    # riot used to have a bug where it would request a new
+                    # validation token (thus sending a new email) each time it
+                    # got a 401 with a 'flows' field.
+                    # (https://github.com/vector-im/vector-web/issues/2447).
+                    #
+                    # Grandfather in the old behaviour for now to avoid
+                    # breaking old riot deployments.
+                    raise e
+
+                # this step failed. Merge the error dict into the response
+                # so that the client can have another go.
+                errordict = e.error_dict()
 
         for f in flows:
             if len(set(f) - set(creds.keys())) == 0:
@@ -164,6 +165,7 @@ class AuthHandler(BaseHandler):
 
         ret = self._auth_dict_for_flows(flows, session)
         ret['completed'] = creds.keys()
+        ret.update(errordict)
         defer.returnValue((False, ret, clientdict, session['id']))
 
     @defer.inlineCallbacks
@@ -431,37 +433,40 @@ class AuthHandler(BaseHandler):
             defer.Deferred: (str) canonical_user_id, or None if zero or
                 multiple matches
         """
-        try:
-            res = yield self._find_user_id_and_pwd_hash(user_id)
-            defer.returnValue(res[0])
-        except LoginError:
-            defer.returnValue(None)
+        res = yield self._find_user_id_and_pwd_hash(user_id)
+        if res is not None:
+            defer.returnValue(res[0])
+        defer.returnValue(None)
 
     @defer.inlineCallbacks
     def _find_user_id_and_pwd_hash(self, user_id):
         """Checks to see if a user with the given id exists. Will check case
-        insensitively, but will throw if there are multiple inexact matches.
+        insensitively, but will return None if there are multiple inexact
+        matches.
 
         Returns:
             tuple: A 2-tuple of `(canonical_user_id, password_hash)`
+            None: if there is not exactly one match
         """
         user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
+
+        result = None
         if not user_infos:
             logger.warn("Attempted to login as %s but they do not exist", user_id)
-            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
-        if len(user_infos) > 1:
-            if user_id not in user_infos:
-                logger.warn(
-                    "Attempted to login as %s but it matches more than one user "
-                    "inexactly: %r",
-                    user_id, user_infos.keys()
-                )
-                raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-            defer.returnValue((user_id, user_infos[user_id]))
-        else:
-            defer.returnValue(user_infos.popitem())
+        elif len(user_infos) == 1:
+            # a single match (possibly not exact)
+            result = user_infos.popitem()
+        elif user_id in user_infos:
+            # multiple matches, but one is exact
+            result = (user_id, user_infos[user_id])
+        else:
+            # multiple matches, none of them exact
+            logger.warn(
+                "Attempted to login as %s but it matches more than one user "
+                "inexactly: %r",
+                user_id, user_infos.keys()
+            )
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def _check_password(self, user_id, password):
@@ -475,305 +480,48 @@ class AuthHandler(BaseHandler):
         Returns:
             (str) the canonical_user_id
         Raises:
-            LoginError if the password was incorrect
+            LoginError if login fails
         """
-        valid_ldap = yield self._check_ldap_password(user_id, password)
-        if valid_ldap:
-            defer.returnValue(user_id)
-
-        result = yield self._check_local_password(user_id, password)
-        defer.returnValue(result)
+        for provider in self.password_providers:
+            is_valid = yield provider.check_password(user_id, password)
+            if is_valid:
+                defer.returnValue(user_id)
+
+        canonical_user_id = yield self._check_local_password(user_id, password)
+
+        if canonical_user_id:
+            defer.returnValue(canonical_user_id)
+
+        # unknown username or invalid password. We raise a 403 here, but note
+        # that if we're doing user-interactive login, it turns all LoginErrors
+        # into a 401 anyway.
+        raise LoginError(
+            403, "Invalid password",
+            errcode=Codes.FORBIDDEN
+        )
 
     @defer.inlineCallbacks
     def _check_local_password(self, user_id, password):
         """Authenticate a user against the local password database.
 
-        user_id is checked case insensitively, but will throw if there are
-        multiple inexact matches.
+        user_id is checked case insensitively, but will return None if there
+        are multiple inexact matches.
 
         Args:
             user_id (str): complete @user:id
         Returns:
-            (str) the canonical_user_id
-        Raises:
-            LoginError if the password was incorrect
+            (str) the canonical_user_id, or None if unknown user / bad password
         """
-        user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+        lookupres = yield self._find_user_id_and_pwd_hash(user_id)
+        if not lookupres:
+            defer.returnValue(None)
+        (user_id, password_hash) = lookupres
         result = self.validate_hash(password, password_hash)
         if not result:
             logger.warn("Failed password login for user %s", user_id)
-            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+            defer.returnValue(None)
         defer.returnValue(user_id)
-    def _ldap_simple_bind(self, server, localpart, password):
-        """ Attempt a simple bind with the credentials
-            given by the user against the LDAP server.
-
-            Returns True, LDAP3Connection
-                if the bind was successful
-            Returns False, None
-                if an error occured
-        """
-
-        try:
-            # bind with the the local users ldap credentials
-            bind_dn = "{prop}={value},{base}".format(
-                prop=self.ldap_attributes['uid'],
-                value=localpart,
-                base=self.ldap_base
-            )
-            conn = ldap3.Connection(server, bind_dn, password)
-            logger.debug(
-                "Established LDAP connection in simple bind mode: %s",
-                conn
-            )
-
-            if self.ldap_start_tls:
-                conn.start_tls()
-                logger.debug(
-                    "Upgraded LDAP connection in simple bind mode through StartTLS: %s",
-                    conn
-                )
-
-            if conn.bind():
-                # GOOD: bind okay
-                logger.debug("LDAP Bind successful in simple bind mode.")
-                return True, conn
-
-            # BAD: bind failed
-            logger.info(
-                "Binding against LDAP failed for '%s' failed: %s",
-                localpart, conn.result['description']
-            )
-            conn.unbind()
-            return False, None
-
-        except ldap3.core.exceptions.LDAPException as e:
-            logger.warn("Error during LDAP authentication: %s", e)
-            return False, None
-
-    def _ldap_authenticated_search(self, server, localpart, password):
-        """ Attempt to login with the preconfigured bind_dn
-            and then continue searching and filtering within
-            the base_dn
-
-            Returns (True, LDAP3Connection)
-                if a single matching DN within the base was found
-                that matched the filter expression, and with which
-                a successful bind was achieved
-
-                The LDAP3Connection returned is the instance that was used to
-                verify the password not the one using the configured bind_dn.
-            Returns (False, None)
-                if an error occured
-        """
-
-        try:
-            conn = ldap3.Connection(
-                server,
-                self.ldap_bind_dn,
-                self.ldap_bind_password
-            )
-            logger.debug(
-                "Established LDAP connection in search mode: %s",
-                conn
-            )
-
-            if self.ldap_start_tls:
-                conn.start_tls()
-                logger.debug(
-                    "Upgraded LDAP connection in search mode through StartTLS: %s",
-                    conn
-                )
-
-            if not conn.bind():
-                logger.warn(
-                    "Binding against LDAP with `bind_dn` failed: %s",
-                    conn.result['description']
-                )
-                conn.unbind()
-                return False, None
-
-            # construct search_filter like (uid=localpart)
-            query = "({prop}={value})".format(
-                prop=self.ldap_attributes['uid'],
-                value=localpart
-            )
-            if self.ldap_filter:
-                # combine with the AND expression
-                query = "(&{query}{filter})".format(
-                    query=query,
-                    filter=self.ldap_filter
-                )
-            logger.debug(
-                "LDAP search filter: %s",
-                query
-            )
-            conn.search(
-                search_base=self.ldap_base,
-                search_filter=query
-            )
-
-            if len(conn.response) == 1:
-                # GOOD: found exactly one result
-                user_dn = conn.response[0]['dn']
-                logger.debug('LDAP search found dn: %s', user_dn)
-
-                # unbind and simple bind with user_dn to verify the password
-                # Note: do not use rebind(), for some reason it did not verify
-                #       the password for me!
-                conn.unbind()
-                return self._ldap_simple_bind(server, localpart, password)
-            else:
-                # BAD: found 0 or > 1 results, abort!
-                if len(conn.response) == 0:
-                    logger.info(
-                        "LDAP search returned no results for '%s'",
-                        localpart
-                    )
-                else:
-                    logger.info(
-                        "LDAP search returned too many (%s) results for '%s'",
-                        len(conn.response), localpart
-                    )
-                conn.unbind()
-                return False, None
-
-        except ldap3.core.exceptions.LDAPException as e:
-            logger.warn("Error during LDAP authentication: %s", e)
-            return False, None
-
-    @defer.inlineCallbacks
-    def _check_ldap_password(self, user_id, password):
-        """ Attempt to authenticate a user against an LDAP Server
-            and register an account if none exists.
-
-            Returns:
-                True if authentication against LDAP was successful
-        """
-
-        if not ldap3 or not self.ldap_enabled:
-            defer.returnValue(False)
-
-        localpart = UserID.from_string(user_id).localpart
-
-        try:
-            server = ldap3.Server(self.ldap_uri)
-            logger.debug(
-                "Attempting LDAP connection with %s",
-                self.ldap_uri
-            )
-
-            if self.ldap_mode == LDAPMode.SIMPLE:
-                result, conn = self._ldap_simple_bind(
-                    server=server, localpart=localpart, password=password
-                )
-                logger.debug(
-                    'LDAP authentication method simple bind returned: %s (conn: %s)',
-                    result,
-                    conn
-                )
-                if not result:
-                    defer.returnValue(False)
-            elif self.ldap_mode == LDAPMode.SEARCH:
-                result, conn = self._ldap_authenticated_search(
-                    server=server, localpart=localpart, password=password
-                )
-                logger.debug(
-                    'LDAP auth method authenticated search returned: %s (conn: %s)',
-                    result,
-                    conn
-                )
-                if not result:
-                    defer.returnValue(False)
-            else:
-                raise RuntimeError(
-                    'Invalid LDAP mode specified: {mode}'.format(
-                        mode=self.ldap_mode
-                    )
-                )
-
-            try:
-                logger.info(
-                    "User authenticated against LDAP server: %s",
-                    conn
-                )
-            except NameError:
-                logger.warn("Authentication method yielded no LDAP connection, aborting!")
-                defer.returnValue(False)
-
-            # check if user with user_id exists
-            if (yield self.check_user_exists(user_id)):
-                # exists, authentication complete
-                conn.unbind()
-                defer.returnValue(True)
-
-            else:
-                # does not exist, fetch metadata for account creation from
-                # existing ldap connection
-                query = "({prop}={value})".format(
-                    prop=self.ldap_attributes['uid'],
-                    value=localpart
-                )
-
-                if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
-                    query = "(&{filter}{user_filter})".format(
-                        filter=query,
-                        user_filter=self.ldap_filter
-                    )
-                logger.debug(
-                    "ldap registration filter: %s",
-                    query
-                )
-
-                conn.search(
-                    search_base=self.ldap_base,
-                    search_filter=query,
-                    attributes=[
-                        self.ldap_attributes['name'],
-                        self.ldap_attributes['mail']
-                    ]
-                )
-
-                if len(conn.response) == 1:
-                    attrs = conn.response[0]['attributes']
-                    mail = attrs[self.ldap_attributes['mail']][0]
-                    name = attrs[self.ldap_attributes['name']][0]
-
-                    # create account
-                    registration_handler = self.hs.get_handlers().registration_handler
-                    user_id, access_token = (
-                        yield registration_handler.register(localpart=localpart)
-                    )
-
-                    # TODO: bind email, set displayname with data from ldap directory
-
-                    logger.info(
-                        "Registration based on LDAP data was successful: %d: %s (%s, %)",
-                        user_id,
-                        localpart,
-                        name,
-                        mail
-                    )
-
-                    defer.returnValue(True)
-                else:
-                    if len(conn.response) == 0:
-                        logger.warn("LDAP registration failed, no result.")
-                    else:
-                        logger.warn(
-                            "LDAP registration failed, too many results (%s)",
-                            len(conn.response)
-                        )
-
-                    defer.returnValue(False)
-
-            defer.returnValue(False)
-
-        except ldap3.core.exceptions.LDAPException as e:
-            logger.warn("Error during ldap authentication: %s", e)
-            defer.returnValue(False)
 
     @defer.inlineCallbacks
     def issue_access_token(self, user_id, device_id=None):
         access_token = self.generate_access_token(user_id)
@@ -863,6 +611,18 @@ class AuthHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def add_threepid(self, user_id, medium, address, validated_at):
+        # 'Canonicalise' email addresses down to lower case.
+        # We're now moving towards the Home Server being the entity that
+        # is responsible for validating threepids used for resetting passwords
+        # on accounts, so in future Synapse will gain knowledge of specific
+        # types (mediums) of threepid. For now, we still use the existing
+        # infrastructure, but this is the start of synapse gaining knowledge
+        # of specific types of threepid (and fixes the fact that checking
+        # for the presence of an email address during password reset was
+        # case sensitive).
+        if medium == 'email':
+            address = address.lower()
+
         yield self.store.user_add_threepid(
             user_id, medium, address, validated_at,
             self.hs.get_clock().time_msec()
@@ -911,3 +671,30 @@ class AuthHandler(BaseHandler):
                 stored_hash.encode('utf-8')) == stored_hash
         else:
             return False
+
+
+class _AccountHandler(object):
+    """A proxy object that gets passed to password auth providers so they
+    can register new users etc if necessary.
+    """
+    def __init__(self, hs, check_user_exists):
+        self.hs = hs
+        self._check_user_exists = check_user_exists
+
+    def check_user_exists(self, user_id):
+        """Check if user exists.
+
+        Returns:
+            Deferred(bool)
+        """
+        return self._check_user_exists(user_id)
+
+    def register(self, localpart):
+        """Registers a new user with given localpart
+
+        Returns:
+            Deferred: a 2-tuple of (user_id, access_token)
+        """
+        reg = self.hs.get_handlers().registration_handler
+        return reg.register(localpart=localpart)

synapse/handlers/directory.py

@@ -288,13 +288,12 @@ class DirectoryHandler(BaseHandler):
         result = yield as_handler.query_room_alias_exists(room_alias)
         defer.returnValue(result)
 
-    @defer.inlineCallbacks
     def can_modify_alias(self, alias, user_id=None):
         # Any application service "interested" in an alias they are regexing on
         # can modify the alias.
         # Users can only modify the alias if ALL the interested services have
         # non-exclusive locks on the alias (or there are no interested services)
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         interested_services = [
             s for s in services if s.is_interested_in_alias(alias.to_string())
         ]
@@ -302,14 +301,12 @@ class DirectoryHandler(BaseHandler):
 
         for service in interested_services:
             if user_id == service.sender:
                 # this user IS the app service so they can do whatever they like
-                defer.returnValue(True)
-                return
+                return defer.succeed(True)
             elif service.is_exclusive_alias(alias.to_string()):
                 # another service has an exclusive lock on this alias.
-                defer.returnValue(False)
-                return
+                return defer.succeed(False)
 
         # either no interested services, or no service with an exclusive lock
-        defer.returnValue(True)
+        return defer.succeed(True)
 
     @defer.inlineCallbacks
     def _user_can_delete_alias(self, alias, user_id):

synapse/handlers/message.py

@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError
 from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
@@ -82,8 +82,8 @@ class MessageHandler(BaseHandler):
                 room_token = pagin_config.from_token.room_key
             else:
                 pagin_config.from_token = (
-                    yield self.hs.get_event_sources().get_current_token(
-                        direction='b'
+                    yield self.hs.get_event_sources().get_current_token_for_room(
+                        room_id=room_id
                     )
                 )
                 room_token = pagin_config.from_token.room_key
@@ -239,6 +239,21 @@ class MessageHandler(BaseHandler):
                 "Tried to send member event through non-member codepath"
             )
 
+        # We check here if we are currently being rate limited, so that we
+        # don't do unnecessary work. We check again just before we actually
+        # send the event.
+        time_now = self.clock.time()
+        allowed, time_allowed = self.ratelimiter.send_message(
+            event.sender, time_now,
+            msg_rate_hz=self.hs.config.rc_messages_per_second,
+            burst_count=self.hs.config.rc_message_burst_count,
+            update=False,
+        )
+        if not allowed:
+            raise LimitExceededError(
+                retry_after_ms=int(1000 * (time_allowed - time_now)),
+            )
+
         user = UserID.from_string(event.sender)
 
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)

synapse/handlers/profile.py

@@ -65,13 +65,13 @@ class ProfileHandler(BaseHandler):
         defer.returnValue(result["displayname"])
 
     @defer.inlineCallbacks
-    def set_displayname(self, target_user, requester, new_displayname):
+    def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
         """target_user is the user whose displayname is to be changed;
         auth_user is the user attempting to make this change."""
 
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this Home Server")
 
-        if target_user != requester.user:
+        if not by_admin and target_user != requester.user:
             raise AuthError(400, "Cannot set another user's displayname")
 
         if new_displayname == '':
@@ -111,13 +111,13 @@ class ProfileHandler(BaseHandler):
         defer.returnValue(result["avatar_url"])
 
     @defer.inlineCallbacks
-    def set_avatar_url(self, target_user, requester, new_avatar_url):
+    def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
         """target_user is the user whose avatar_url is to be changed;
         auth_user is the user attempting to make this change."""
 
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this Home Server")
 
-        if target_user != requester.user:
+        if not by_admin and target_user != requester.user:
             raise AuthError(400, "Cannot set another user's avatar_url")
 
         yield self.store.set_profile_avatar_url(

synapse/handlers/register.py

@@ -19,7 +19,6 @@ import urllib
 
 from twisted.internet import defer
 
-import synapse.types
 from synapse.api.errors import (
     AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
 )
@@ -194,7 +193,7 @@ class RegistrationHandler(BaseHandler):
     def appservice_register(self, user_localpart, as_token):
         user = UserID(user_localpart, self.hs.hostname)
         user_id = user.to_string()
-        service = yield self.store.get_app_service_by_token(as_token)
+        service = self.store.get_app_service_by_token(as_token)
         if not service:
             raise AuthError(403, "Invalid application service token.")
         if not service.is_interested_in_user(user_id):
@@ -305,11 +304,10 @@ class RegistrationHandler(BaseHandler):
         # XXX: This should be a deferred list, shouldn't it?
         yield identity_handler.bind_threepid(c, user_id)
 
-    @defer.inlineCallbacks
     def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
         # valid user IDs must not clash with any user ID namespaces claimed by
         # application services.
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         interested_services = [
             s for s in services
             if s.is_interested_in_user(user_id)
@@ -371,7 +369,7 @@ class RegistrationHandler(BaseHandler):
         defer.returnValue(data)
 
     @defer.inlineCallbacks
-    def get_or_create_user(self, localpart, displayname, duration_in_ms,
+    def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
                            password_hash=None):
         """Creates a new user if the user does not exist,
         else revokes all previous access tokens and generates a new one.
@@ -418,9 +416,8 @@ class RegistrationHandler(BaseHandler):
         if displayname is not None:
             logger.info("setting user display name: %s -> %s", user_id, displayname)
             profile_handler = self.hs.get_handlers().profile_handler
-            requester = synapse.types.create_requester(user)
             yield profile_handler.set_displayname(
-                user, requester, displayname
+                user, requester, displayname, by_admin=True,
             )
 
         defer.returnValue((user_id, token))

synapse/handlers/room.py

@@ -437,7 +437,7 @@ class RoomEventSource(object):
             logger.warn("Stream has topological part!!!! %r", from_key)
             from_key = "s%s" % (from_token.stream,)
 
-        app_service = yield self.store.get_app_service_by_user_id(
+        app_service = self.store.get_app_service_by_user_id(
             user.to_string()
         )
         if app_service:
@@ -475,8 +475,11 @@ class RoomEventSource(object):
 
         defer.returnValue((events, end_key))
 
-    def get_current_key(self, direction='f'):
-        return self.store.get_room_events_max_id(direction)
+    def get_current_key(self):
+        return self.store.get_room_events_max_id()
+
+    def get_current_key_for_room(self, room_id):
+        return self.store.get_room_events_max_id(room_id)
 
     @defer.inlineCallbacks
     def get_pagination_rows(self, user, config, key):

synapse/handlers/sync.py

@@ -788,7 +788,7 @@ class SyncHandler(object):
 
         assert since_token
 
-        app_service = yield self.store.get_app_service_by_user_id(user_id)
+        app_service = self.store.get_app_service_by_user_id(user_id)
         if app_service:
             rooms = yield self.store.get_app_service_rooms(app_service)
             joined_room_ids = set(r.room_id for r in rooms)


@@ -88,7 +88,7 @@ class TypingHandler(object):
                 continue
 
             until = self._member_typing_until.get(member, None)
-            if not until or until < now:
+            if not until or until <= now:
                 logger.info("Timing out typing for: %s", member.user_id)
                 preserve_fn(self._stopped_typing)(member)
                 continue
@@ -97,12 +97,20 @@ class TypingHandler(object):
             # user.
             if self.hs.is_mine_id(member.user_id):
                 last_fed_poke = self._member_last_federation_poke.get(member, None)
-                if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now:
+                if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
                     preserve_fn(self._push_remote)(
                         member=member,
                         typing=True
                     )
 
+            # Add a paranoia timer to ensure that we always have a timer for
+            # each person typing.
+            self.wheel_timer.insert(
+                now=now,
+                obj=member,
+                then=now + 60 * 1000,
+            )
+
     def is_typing(self, member):
         return member.user_id in self._room_typing.get(member.room_id, [])
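
Note: the ``<`` to ``<=`` change matters exactly at the boundary tick, which is
part of the infinite-typing fix. A standalone illustration (the 40s interval
value is an assumption made for the example)::

    FEDERATION_PING_INTERVAL = 40 * 1000  # ms; value assumed for illustration

    def should_repoke(last_fed_poke, now):
        # mirrors the corrected comparison above
        return not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now

    now = 1000000
    assert should_repoke(now - FEDERATION_PING_INTERVAL, now)        # boundary now fires
    assert not should_repoke(now - FEDERATION_PING_INTERVAL + 1, now)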


@@ -13,14 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Because otherwise 'resource' collides with synapse.metrics.resource
-from __future__ import absolute_import
-
 import logging
-from resource import getrusage, RUSAGE_SELF
 import functools
-import os
-import stat
 import time
 import gc
 
@@ -30,12 +24,14 @@ from .metric import (
     CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
     MemoryUsageMetric,
 )
+from .process_collector import register_process_collector
 
 
 logger = logging.getLogger(__name__)
 
 
 all_metrics = []
+all_collectors = []
 
 
 class Metrics(object):
@@ -46,6 +42,12 @@ class Metrics(object):
     def __init__(self, name):
         self.name_prefix = name
 
+    def make_subspace(self, name):
+        return Metrics("%s_%s" % (self.name_prefix, name))
+
+    def register_collector(self, func):
+        all_collectors.append(func)
+
     def _register(self, metric_class, name, *args, **kwargs):
         full_name = "%s_%s" % (self.name_prefix, name)
 
@@ -94,8 +96,8 @@ def get_metrics_for(pkg_name):
 def render_all():
     strs = []
 
-    # TODO(paul): Internal hack
-    update_resource_metrics()
+    for collector in all_collectors:
+        collector()
 
     for metric in all_metrics:
         try:
@@ -109,62 +111,6 @@ def render_all():
 
     return "\n".join(strs)
 
-
-# Now register some standard process-wide state metrics, to give indications of
-# process resource usage
-
-rusage = None
-
-
-def update_resource_metrics():
-    global rusage
-    rusage = getrusage(RUSAGE_SELF)
-
-resource_metrics = get_metrics_for("process.resource")
-
-# msecs
-resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000)
-resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
-
-# kilobytes
-resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024)
-
-TYPES = {
-    stat.S_IFSOCK: "SOCK",
-    stat.S_IFLNK: "LNK",
-    stat.S_IFREG: "REG",
-    stat.S_IFBLK: "BLK",
-    stat.S_IFDIR: "DIR",
-    stat.S_IFCHR: "CHR",
-    stat.S_IFIFO: "FIFO",
-}
-
-
-def _process_fds():
-    counts = {(k,): 0 for k in TYPES.values()}
-    counts[("other",)] = 0
-
-    # Not every OS will have a /proc/self/fd directory
-    if not os.path.exists("/proc/self/fd"):
-        return counts
-
-    for fd in os.listdir("/proc/self/fd"):
-        try:
-            s = os.stat("/proc/self/fd/%s" % (fd))
-            fmt = stat.S_IFMT(s.st_mode)
-            if fmt in TYPES:
-                t = TYPES[fmt]
-            else:
-                t = "other"
-
-            counts[(t,)] += 1
-        except OSError:
-            # the dirh itself used by listdir() is usually missing by now
-            pass
-
-    return counts
-
-get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
-
 reactor_metrics = get_metrics_for("reactor")
 tick_time = reactor_metrics.register_distribution("tick_time")
 pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
@@ -176,6 +122,8 @@ reactor_metrics.register_callback(
     "gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"]
 )
 
+register_process_collector(get_metrics_for("process"))
+
 
 def runUntilCurrentTimer(func):
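
Note: the new collector hook lets a module refresh derived state once per
scrape instead of on every callback. A hedged usage sketch (the metric
namespace and names here are invented for the example)::

    from synapse import metrics

    example_metrics = metrics.get_metrics_for("example")
    cache = {"size": 0}

    def _refresh():
        cache["size"] += 1  # stand-in for re-reading real process state

    # runs once per render_all(), before the callbacks are evaluated
    example_metrics.register_collector(_refresh)
    example_metrics.register_callback("cache_size", lambda: cache["size"])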


@@ -98,9 +98,9 @@ class CallbackMetric(BaseMetric):
         value = self.callback()
 
         if self.is_scalar():
-            return ["%s %d" % (self.name, value)]
+            return ["%s %.12g" % (self.name, value)]
 
-        return ["%s%s %d" % (self.name, self._render_key(k), value[k])
+        return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
                 for k in sorted(value.keys())]
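
Note: ``%d`` silently truncated non-integer samples, whereas ``%.12g`` keeps
precision while still printing integers cleanly. A quick standalone check::

    value = 12345678.9
    print("%d" % value)     # -> 12345678   (old: fractional part lost)
    print("%.12g" % value)  # -> 12345678.9 (new: up to 12 significant digits)
    print("%.12g" % 42)     # -> 42         (integers stay tidy)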


@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Because otherwise 'resource' collides with synapse.metrics.resource
+from __future__ import absolute_import
+
+import os
+import stat
+from resource import getrusage, RUSAGE_SELF
+
+
+TICKS_PER_SEC = 100
+BYTES_PER_PAGE = 4096
+
+HAVE_PROC_STAT = os.path.exists("/proc/stat")
+HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
+HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits")
+HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd")
+
+TYPES = {
+    stat.S_IFSOCK: "SOCK",
+    stat.S_IFLNK: "LNK",
+    stat.S_IFREG: "REG",
+    stat.S_IFBLK: "BLK",
+    stat.S_IFDIR: "DIR",
+    stat.S_IFCHR: "CHR",
+    stat.S_IFIFO: "FIFO",
+}
+
+# Field indexes from /proc/self/stat, taken from the proc(5) manpage
+STAT_FIELDS = {
+    "utime": 14,
+    "stime": 15,
+    "starttime": 22,
+    "vsize": 23,
+    "rss": 24,
+}
+
+rusage = None
+stats = {}
+fd_counts = None
+
+# In order to report process_start_time_seconds we need to know the
+# machine's boot time, because the value in /proc/self/stat is relative to
+# this
+boot_time = None
+if HAVE_PROC_STAT:
+    with open("/proc/stat") as _procstat:
+        for line in _procstat:
+            if line.startswith("btime "):
+                boot_time = int(line.split()[1])
+
+
+def update_resource_metrics():
+    global rusage
+    rusage = getrusage(RUSAGE_SELF)
+
+    if HAVE_PROC_SELF_STAT:
+        global stats
+        with open("/proc/self/stat") as s:
+            line = s.read()
+            # line is PID (command) more stats go here ...
+            raw_stats = line.split(") ", 1)[1].split(" ")
+
+            for (name, index) in STAT_FIELDS.iteritems():
+                # subtract 3 from the index, because proc(5) is 1-based, and
+                # we've lost the first two fields in PID and COMMAND above
+                stats[name] = int(raw_stats[index - 3])
+
+    global fd_counts
+    fd_counts = _process_fds()
+
+
+def _process_fds():
+    counts = {(k,): 0 for k in TYPES.values()}
+    counts[("other",)] = 0
+
+    # Not every OS will have a /proc/self/fd directory
+    if not HAVE_PROC_SELF_FD:
+        return counts
+
+    for fd in os.listdir("/proc/self/fd"):
+        try:
+            s = os.stat("/proc/self/fd/%s" % (fd))
+            fmt = stat.S_IFMT(s.st_mode)
+            if fmt in TYPES:
+                t = TYPES[fmt]
+            else:
+                t = "other"
+
+            counts[(t,)] += 1
+        except OSError:
+            # the dirh itself used by listdir() is usually missing by now
+            pass
+
+    return counts
+
+
+def register_process_collector(process_metrics):
+    # Legacy synapse-invented metric names
+    resource_metrics = process_metrics.make_subspace("resource")
+
+    resource_metrics.register_collector(update_resource_metrics)
+
+    # msecs
+    resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000)
+    resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
+
+    # kilobytes
+    resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024)
+
+    process_metrics.register_callback("fds", _process_fds, labels=["type"])
+
+    # New prometheus-standard metric names
+    if HAVE_PROC_SELF_STAT:
+        process_metrics.register_callback(
+            "cpu_user_seconds_total",
+            lambda: float(stats["utime"]) / TICKS_PER_SEC
+        )
+        process_metrics.register_callback(
+            "cpu_system_seconds_total",
+            lambda: float(stats["stime"]) / TICKS_PER_SEC
+        )
+        process_metrics.register_callback(
+            "cpu_seconds_total",
+            lambda: (float(stats["utime"] + stats["stime"])) / TICKS_PER_SEC
+        )
+        process_metrics.register_callback(
+            "virtual_memory_bytes",
+            lambda: int(stats["vsize"])
+        )
+        process_metrics.register_callback(
+            "resident_memory_bytes",
+            lambda: int(stats["rss"]) * BYTES_PER_PAGE
+        )
+        process_metrics.register_callback(
+            "start_time_seconds",
+            lambda: boot_time + int(stats["starttime"]) / TICKS_PER_SEC
+        )
+
+    if HAVE_PROC_SELF_FD:
+        process_metrics.register_callback(
+            "open_fds",
+            lambda: sum(fd_counts.values())
+        )
+
+    if HAVE_PROC_SELF_LIMITS:
+        def _get_max_fds():
+            with open("/proc/self/limits") as limits:
+                for line in limits:
+                    if not line.startswith("Max open files "):
+                        continue
+
+                    # Line is      Max open files  $SOFT  $HARD
+                    return int(line.split()[3])
+            return None
+
+        process_metrics.register_callback(
+            "max_fds",
+            lambda: _get_max_fds()
+        )
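
Note: a worked example of the ``start_time_seconds`` arithmetic above, with
invented numbers: ``starttime`` is in jiffies since boot, so dividing by
``TICKS_PER_SEC`` and adding the machine's ``btime`` gives an absolute epoch
timestamp::

    TICKS_PER_SEC = 100

    boot_time = 1478000000   # machine's btime from /proc/stat (example value)
    starttime = 12345        # jiffies, field 22 of /proc/self/stat (example value)

    start_time_seconds = boot_time + starttime / float(TICKS_PER_SEC)
    print(start_time_seconds)  # -> 1478000123.45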


@@ -150,6 +150,10 @@ class EmailPusher(object):
 
         soonest_due_at = None
 
+        if not unprocessed:
+            yield self.save_last_stream_ordering_and_success(self.max_stream_ordering)
+            return
+
         for push_action in unprocessed:
             received_at = push_action['received_ts']
             if received_at is None:


@@ -328,7 +328,7 @@ class Mailer(object):
         return messagevars
 
     @defer.inlineCallbacks
-    def make_summary_text(self, notifs_by_room, state_by_room,
+    def make_summary_text(self, notifs_by_room, room_state_ids,
                           notif_events, user_id, reason):
         if len(notifs_by_room) == 1:
             # Only one room has new stuff
@@ -338,14 +338,18 @@ class Mailer(object):
             # want the generated-from-names one here otherwise we'll
             # end up with, "new message from Bob in the Bob room"
             room_name = yield calculate_room_name(
-                self.store, state_by_room[room_id], user_id, fallback_to_members=False
+                self.store, room_state_ids[room_id], user_id, fallback_to_members=False
             )
 
-            my_member_event = state_by_room[room_id][("m.room.member", user_id)]
+            my_member_event_id = room_state_ids[room_id][("m.room.member", user_id)]
+            my_member_event = yield self.store.get_event(my_member_event_id)
             if my_member_event.content["membership"] == "invite":
-                inviter_member_event = state_by_room[room_id][
+                inviter_member_event_id = room_state_ids[room_id][
                     ("m.room.member", my_member_event.sender)
                 ]
+                inviter_member_event = yield self.store.get_event(
+                    inviter_member_event_id
+                )
                 inviter_name = name_from_member_event(inviter_member_event)
 
                 if room_name is None:
@@ -364,8 +368,11 @@ class Mailer(object):
             if len(notifs_by_room[room_id]) == 1:
                 # There is just the one notification, so give some detail
                 event = notif_events[notifs_by_room[room_id][0]["event_id"]]
-                if ("m.room.member", event.sender) in state_by_room[room_id]:
-                    state_event = state_by_room[room_id][("m.room.member", event.sender)]
+                if ("m.room.member", event.sender) in room_state_ids[room_id]:
+                    state_event_id = room_state_ids[room_id][
+                        ("m.room.member", event.sender)
+                    ]
+                    state_event = yield self.store.get_event(state_event_id)
                     sender_name = name_from_member_event(state_event)
 
                 if sender_name is not None and room_name is not None:
@@ -395,11 +402,13 @@ class Mailer(object):
                     for n in notifs_by_room[room_id]
                 ]))
 
-                defer.returnValue(MESSAGES_FROM_PERSON % {
-                    "person": descriptor_from_member_events([
-                        state_by_room[room_id][("m.room.member", s)]
-                        for s in sender_ids
-                    ]),
+                member_events = yield self.store.get_events([
+                    room_state_ids[room_id][("m.room.member", s)]
+                    for s in sender_ids
+                ])
+
+                defer.returnValue(MESSAGES_FROM_PERSON % {
+                    "person": descriptor_from_member_events(member_events.values()),
                     "app": self.app_name,
                 })
         else:
@@ -419,11 +428,13 @@ class Mailer(object):
                     for n in notifs_by_room[reason['room_id']]
                 ]))
 
-                defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
-                    "person": descriptor_from_member_events([
-                        state_by_room[reason['room_id']][("m.room.member", s)]
-                        for s in sender_ids
-                    ]),
+                member_events = yield self.store.get_events([
+                    room_state_ids[room_id][("m.room.member", s)]
+                    for s in sender_ids
+                ])
+
+                defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
+                    "person": descriptor_from_member_events(member_events.values()),
                     "app": self.app_name,
                 })
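
Note: the refactor means the mailer now carries state event *ids* and only
fetches full events when a name is actually rendered. A hedged sketch of that
pattern (``store.get_events`` as used above; the ``presentable_names`` module
path is an assumption based on how the mailer imports these helpers)::

    from twisted.internet import defer
    from synapse.push.presentable_names import name_from_member_event

    @defer.inlineCallbacks
    def sender_names(store, room_state_ids, room_id, sender_ids):
        # resolve member event ids to events in a single batched store call
        member_events = yield store.get_events([
            room_state_ids[room_id][("m.room.member", s)]
            for s in sender_ids
        ])
        defer.returnValue(
            [name_from_member_event(ev) for ev in member_events.values()]
        )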


@@ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string
 from synapse.http.server import request_handler, finish_request
 from synapse.replication.pusher_resource import PusherResource
 from synapse.replication.presence_resource import PresenceResource
+from synapse.api.errors import SynapseError
 
 from twisted.web.resource import Resource
 from twisted.web.server import NOT_DONE_YET
 
@@ -166,7 +167,8 @@ class ReplicationResource(Resource):
         def replicate():
             return self.replicate(request_streams, limit)
 
-        result = yield self.notifier.wait_for_replication(replicate, timeout)
+        writer = yield self.notifier.wait_for_replication(replicate, timeout)
+        result = writer.finish()
 
         for stream_name, stream_content in result.items():
             logger.info(
@@ -186,6 +188,9 @@ class ReplicationResource(Resource):
         current_token = yield self.current_replication_token()
         logger.debug("Replicating up to %r", current_token)
 
+        if limit == 0:
+            raise SynapseError(400, "Limit cannot be 0")
+
         yield self.account_data(writer, current_token, limit, request_streams)
         yield self.events(writer, current_token, limit, request_streams)
         # TODO: implement limit
@@ -200,7 +205,7 @@ class ReplicationResource(Resource):
         self.streams(writer, current_token, request_streams)
 
         logger.debug("Replicated %d rows", writer.total)
-        defer.returnValue(writer.finish())
+        defer.returnValue(writer)
 
     def streams(self, writer, current_token, request_streams):
         request_token = request_streams.get("streams")
@@ -237,27 +242,48 @@ class ReplicationResource(Resource):
             request_events = current_token.events
         if request_backfill is None:
             request_backfill = current_token.backfill
+
+        no_new_tokens = (
+            request_events == current_token.events
+            and request_backfill == current_token.backfill
+        )
+        if no_new_tokens:
+            return
+
         res = yield self.store.get_all_new_events(
             request_backfill, request_events,
             current_token.backfill, current_token.events,
             limit
         )
-        writer.write_header_and_rows("events", res.new_forward_events, (
-            "position", "internal", "json", "state_group"
-        ))
-        writer.write_header_and_rows("backfill", res.new_backfill_events, (
-            "position", "internal", "json", "state_group"
-        ))
+
+        upto_events_token = _position_from_rows(
+            res.new_forward_events, current_token.events
+        )
+        upto_backfill_token = _position_from_rows(
+            res.new_backfill_events, current_token.backfill
+        )
+
+        if request_events != upto_events_token:
+            writer.write_header_and_rows("events", res.new_forward_events, (
+                "position", "internal", "json", "state_group"
+            ), position=upto_events_token)
+
+        if request_backfill != upto_backfill_token:
+            writer.write_header_and_rows("backfill", res.new_backfill_events, (
+                "position", "internal", "json", "state_group",
+            ), position=upto_backfill_token)
+
         writer.write_header_and_rows(
             "forward_ex_outliers", res.forward_ex_outliers,
-            ("position", "event_id", "state_group")
+            ("position", "event_id", "state_group"),
         )
         writer.write_header_and_rows(
             "backward_ex_outliers", res.backward_ex_outliers,
-            ("position", "event_id", "state_group")
+            ("position", "event_id", "state_group"),
         )
         writer.write_header_and_rows(
-            "state_resets", res.state_resets, ("position",)
+            "state_resets", res.state_resets, ("position",),
         )
 
     @defer.inlineCallbacks
@@ -266,15 +292,16 @@ class ReplicationResource(Resource):
 
         request_presence = request_streams.get("presence")
 
-        if request_presence is not None:
+        if request_presence is not None and request_presence != current_position:
             presence_rows = yield self.presence_handler.get_all_presence_updates(
                 request_presence, current_position
             )
+            upto_token = _position_from_rows(presence_rows, current_position)
             writer.write_header_and_rows("presence", presence_rows, (
                 "position", "user_id", "state", "last_active_ts",
                 "last_federation_update_ts", "last_user_sync_ts",
                 "status_msg", "currently_active",
-            ))
+            ), position=upto_token)
 
     @defer.inlineCallbacks
     def typing(self, writer, current_token, request_streams):
@@ -282,7 +309,7 @@ class ReplicationResource(Resource):
 
         request_typing = request_streams.get("typing")
 
-        if request_typing is not None:
+        if request_typing is not None and request_typing != current_position:
             # If they have a higher token than current max, we can assume that
             # they had been talking to a previous instance of the master. Since
             # we reset the token on restart, the best (but hacky) thing we can
@@ -293,9 +320,10 @@ class ReplicationResource(Resource):
             typing_rows = yield self.typing_handler.get_all_typing_updates(
                 request_typing, current_position
             )
+            upto_token = _position_from_rows(typing_rows, current_position)
             writer.write_header_and_rows("typing", typing_rows, (
                 "position", "room_id", "typing"
-            ))
+            ), position=upto_token)
 
     @defer.inlineCallbacks
     def receipts(self, writer, current_token, limit, request_streams):
@@ -303,13 +331,14 @@ class ReplicationResource(Resource):
 
         request_receipts = request_streams.get("receipts")
 
-        if request_receipts is not None:
+        if request_receipts is not None and request_receipts != current_position:
             receipts_rows = yield self.store.get_all_updated_receipts(
                 request_receipts, current_position, limit
            )
+            upto_token = _position_from_rows(receipts_rows, current_position)
             writer.write_header_and_rows("receipts", receipts_rows, (
                 "position", "room_id", "receipt_type", "user_id", "event_id", "data"
-            ))
+            ), position=upto_token)
 
     @defer.inlineCallbacks
     def account_data(self, writer, current_token, limit, request_streams):
@@ -324,23 +353,36 @@ class ReplicationResource(Resource):
                 user_account_data = current_position
             if room_account_data is None:
                 room_account_data = current_position
+
+            no_new_tokens = (
+                user_account_data == current_position
+                and room_account_data == current_position
+            )
+            if no_new_tokens:
+                return
+
             user_rows, room_rows = yield self.store.get_all_updated_account_data(
                 user_account_data, room_account_data, current_position, limit
             )
+
+            upto_users_token = _position_from_rows(user_rows, current_position)
+            upto_rooms_token = _position_from_rows(room_rows, current_position)
+
             writer.write_header_and_rows("user_account_data", user_rows, (
                 "position", "user_id", "type", "content"
-            ))
+            ), position=upto_users_token)
             writer.write_header_and_rows("room_account_data", room_rows, (
                 "position", "user_id", "room_id", "type", "content"
-            ))
+            ), position=upto_rooms_token)
+
         if tag_account_data is not None:
             tag_rows = yield self.store.get_all_updated_tags(
                 tag_account_data, current_position, limit
             )
+            upto_tag_token = _position_from_rows(tag_rows, current_position)
             writer.write_header_and_rows("tag_account_data", tag_rows, (
                 "position", "user_id", "room_id", "tags"
-            ))
+            ), position=upto_tag_token)
 
     @defer.inlineCallbacks
     def push_rules(self, writer, current_token, limit, request_streams):
@@ -348,14 +390,15 @@ class ReplicationResource(Resource):
 
         push_rules = request_streams.get("push_rules")
 
-        if push_rules is not None:
+        if push_rules is not None and push_rules != current_position:
             rows = yield self.store.get_all_push_rule_updates(
                 push_rules, current_position, limit
             )
+            upto_token = _position_from_rows(rows, current_position)
             writer.write_header_and_rows("push_rules", rows, (
                 "position", "event_stream_ordering", "user_id", "rule_id", "op",
                 "priority_class", "priority", "conditions", "actions"
-            ))
+            ), position=upto_token)
 
     @defer.inlineCallbacks
     def pushers(self, writer, current_token, limit, request_streams):
@@ -363,18 +406,19 @@ class ReplicationResource(Resource):
 
         pushers = request_streams.get("pushers")
 
-        if pushers is not None:
+        if pushers is not None and pushers != current_position:
             updated, deleted = yield self.store.get_all_updated_pushers(
                 pushers, current_position, limit
             )
+            upto_token = _position_from_rows(updated, current_position)
             writer.write_header_and_rows("pushers", updated, (
                 "position", "user_id", "access_token", "profile_tag", "kind",
                 "app_id", "app_display_name", "device_display_name", "pushkey",
                 "ts", "lang", "data"
-            ))
+            ), position=upto_token)
             writer.write_header_and_rows("deleted_pushers", deleted, (
                 "position", "user_id", "app_id", "pushkey"
-            ))
+            ), position=upto_token)
 
     @defer.inlineCallbacks
     def caches(self, writer, current_token, limit, request_streams):
@@ -382,13 +426,14 @@ class ReplicationResource(Resource):
 
         caches = request_streams.get("caches")
 
-        if caches is not None:
+        if caches is not None and caches != current_position:
             updated_caches = yield self.store.get_all_updated_caches(
                 caches, current_position, limit
             )
+            upto_token = _position_from_rows(updated_caches, current_position)
             writer.write_header_and_rows("caches", updated_caches, (
                 "position", "cache_func", "keys", "invalidation_ts"
-            ))
+            ), position=upto_token)
 
     @defer.inlineCallbacks
     def to_device(self, writer, current_token, limit, request_streams):
@@ -396,13 +441,14 @@ class ReplicationResource(Resource):
 
         to_device = request_streams.get("to_device")
 
-        if to_device is not None:
+        if to_device is not None and to_device != current_position:
             to_device_rows = yield self.store.get_all_new_device_messages(
                 to_device, current_position, limit
             )
+            upto_token = _position_from_rows(to_device_rows, current_position)
             writer.write_header_and_rows("to_device", to_device_rows, (
                 "position", "user_id", "device_id", "message_json"
-            ))
+            ), position=upto_token)
 
     @defer.inlineCallbacks
     def public_rooms(self, writer, current_token, limit, request_streams):
@@ -410,13 +456,14 @@ class ReplicationResource(Resource):
 
         public_rooms = request_streams.get("public_rooms")
 
-        if public_rooms is not None:
+        if public_rooms is not None and public_rooms != current_position:
             public_rooms_rows = yield self.store.get_all_new_public_rooms(
                 public_rooms, current_position, limit
             )
+            upto_token = _position_from_rows(public_rooms_rows, current_position)
             writer.write_header_and_rows("public_rooms", public_rooms_rows, (
                 "position", "room_id", "visibility"
-            ))
+            ), position=upto_token)
 
 
 class _Writer(object):
@@ -426,11 +473,11 @@ class _Writer(object):
         self.total = 0
 
     def write_header_and_rows(self, name, rows, fields, position=None):
-        if not rows:
-            return
-
         if position is None:
-            position = rows[-1][0]
+            if rows:
+                position = rows[-1][0]
+            else:
+                return
 
         self.streams[name] = {
             "position": position if type(position) is int else str(position),
@@ -440,6 +487,9 @@ class _Writer(object):
 
         self.total += len(rows)
 
+    def __nonzero__(self):
+        return bool(self.total)
+
     def finish(self):
         return self.streams
 
@@ -461,3 +511,20 @@ class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
 
     def __str__(self):
         return "_".join(str(value) for value in self)
+
+
+def _position_from_rows(rows, current_position):
+    """Calculates a position to return for a stream. Ideally we want to return the
+    position of the last row, as that will be the most correct. However, if there
+    are no rows we fall back to using the current position to stop us from
+    repeatedly hitting the storage layer unnecessarily thinking there are updates.
+    (Not all advances of the token correspond to an actual update)
+
+    We can't just always return the current position, as we often limit the
+    number of rows we replicate, and so the stream may lag. The assumption is
+    that if the storage layer returns no new rows then we are not lagging and
+    we are at the `current_position`.
+    """
+    if rows:
+        return rows[-1][0]
+    return current_position
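
Note: the docstring above is easiest to see with concrete values. A
pure-Python restatement, runnable without synapse::

    def position_from_rows(rows, current_position):
        # same logic as the helper above
        if rows:
            return rows[-1][0]
        return current_position

    assert position_from_rows([(5, "a"), (7, "b")], 10) == 7  # stream lags at 7
    assert position_from_rows([], 10) == 10                   # no rows: caught up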

View file

@@ -22,6 +22,7 @@ from synapse.api.auth import get_access_token_from_request
 from .base import ClientV1RestServlet, client_path_patterns
 import synapse.util.stringutils as stringutils
 from synapse.http.servlet import parse_json_object_from_request
+from synapse.types import create_requester
 
 from synapse.util.async import run_on_reactor
 
@@ -391,15 +392,16 @@ class CreateUserRestServlet(ClientV1RestServlet):
         user_json = parse_json_object_from_request(request)
 
         access_token = get_access_token_from_request(request)
-        app_service = yield self.store.get_app_service_by_token(
+        app_service = self.store.get_app_service_by_token(
             access_token
         )
         if not app_service:
             raise SynapseError(403, "Invalid application service token.")
 
-        logger.debug("creating user: %s", user_json)
+        requester = create_requester(app_service.sender)
 
-        response = yield self._do_create(user_json)
+        logger.debug("creating user: %s", user_json)
+
+        response = yield self._do_create(requester, user_json)
 
         defer.returnValue((200, response))
 
@@ -407,7 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
         return 403, {}
 
     @defer.inlineCallbacks
-    def _do_create(self, user_json):
+    def _do_create(self, requester, user_json):
         yield run_on_reactor()
 
         if "localpart" not in user_json:
@@ -433,6 +435,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
 
         handler = self.handlers.registration_handler
         user_id, token = yield handler.get_or_create_user(
+            requester=requester,
             localpart=localpart,
             displayname=displayname,
             duration_in_ms=(duration_seconds * 1000),


@@ -77,8 +77,10 @@ SUCCESS_TEMPLATE = """
 user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
 <link rel="stylesheet" href="/_matrix/static/client/register/style.css">
 <script>
-if (window.onAuthDone != undefined) {
+if (window.onAuthDone) {
     window.onAuthDone();
+} else if (window.opener && window.opener.postMessage) {
+    window.opener.postMessage("authDone", "*");
 }
 </script>
 </head>


@@ -17,6 +17,7 @@ import logging
 
 from twisted.internet import defer
 
+from synapse.api import constants, errors
 from synapse.http import servlet
 from ._base import client_v2_patterns
 
@@ -58,6 +59,7 @@ class DeviceRestServlet(servlet.RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
+        self.auth_handler = hs.get_auth_handler()
 
     @defer.inlineCallbacks
     def on_GET(self, request, device_id):
@@ -70,11 +72,24 @@ class DeviceRestServlet(servlet.RestServlet):
     @defer.inlineCallbacks
     def on_DELETE(self, request, device_id):
-        # XXX: it's not completely obvious we want to expose this endpoint.
-        # It allows the client to delete access tokens, which feels like a
-        # thing which merits extra auth. But if we want to do the interactive-
-        # auth dance, we should really make it possible to delete more than one
-        # device at a time.
+        try:
+            body = servlet.parse_json_object_from_request(request)
+
+        except errors.SynapseError as e:
+            if e.errcode == errors.Codes.NOT_JSON:
+                # deal with older clients which didn't pass a JSON dict
+                # the same as those that pass an empty dict
+                body = {}
+            else:
+                raise
+
+        authed, result, params, _ = yield self.auth_handler.check_auth([
+            [constants.LoginType.PASSWORD],
+        ], body, self.hs.get_ip_from_request(request))
+
+        if not authed:
+            defer.returnValue((401, result))
+
         requester = yield self.auth.get_user_by_req(request)
         yield self.device_handler.delete_device(
             requester.user.to_string(),
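
Note: from the client's side the new flow is: DELETE, receive a 401 carrying
the UIA session, then retry with an ``m.login.password`` auth dict. A hedged
sketch only; the endpoint path and field names follow the client-server API
of the time, and the server name and credentials are invented::

    import requests

    base = "https://matrix.example.com/_matrix/client/r0"
    headers = {"Authorization": "Bearer MDAxSomeAccessToken"}

    r = requests.delete(base + "/devices/ABCDEFG", headers=headers, json={})
    if r.status_code == 401:
        session = r.json()["session"]
        r = requests.delete(base + "/devices/ABCDEFG", headers=headers, json={
            "auth": {
                "type": "m.login.password",
                "user": "@alice:example.com",
                "password": "s3cret",
                "session": session,
            },
        })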


@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import AuthError, SynapseError, StoreError, Codes
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.types import UserID
 
@@ -45,7 +45,7 @@ class GetFilterRestServlet(RestServlet):
             raise AuthError(403, "Cannot get filters for other users")
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only get filters for local users")
+            raise AuthError(403, "Can only get filters for local users")
 
         try:
             filter_id = int(filter_id)
@@ -59,8 +59,8 @@ class GetFilterRestServlet(RestServlet):
             )
 
             defer.returnValue((200, filter.get_filter_json()))
-        except KeyError:
-            raise SynapseError(400, "No such filter")
+        except (KeyError, StoreError):
+            raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND)
 
 
 class CreateFilterRestServlet(RestServlet):
@@ -74,6 +74,7 @@ class CreateFilterRestServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id):
+
         target_user = UserID.from_string(user_id)
         requester = yield self.auth.get_user_by_req(request)
 
@@ -81,10 +82,9 @@ class CreateFilterRestServlet(RestServlet):
             raise AuthError(403, "Cannot create filters for other users")
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only create filters for local users")
+            raise AuthError(403, "Can only create filters for local users")
 
         content = parse_json_object_from_request(request)
-
         filter_id = yield self.filtering.add_user_filter(
             user_localpart=target_user.localpart,
             user_filter=content,


@@ -19,8 +19,6 @@ from synapse.http.server import respond_with_json_bytes
 from signedjson.sign import sign_json
 from unpaddedbase64 import encode_base64
 from canonicaljson import encode_canonical_json
-from hashlib import sha256
-from OpenSSL import crypto
 
 import logging
 
@@ -48,8 +46,12 @@ class LocalKey(Resource):
                 "expired_ts": # integer posix timestamp when the key expired.
                 "key": # base64 encoded NACL verification key.
             }
-        }
-        "tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
+        },
+        "tls_fingerprints": [ # Fingerprints of the TLS certs this server uses.
+            {
+                "sha256": # base64 encoded sha256 fingerprint of the X509 cert
+            },
+        ],
         "signatures": {
             "this.server.example.com": {
                 "algorithm:version": # NACL signature for this server
@@ -90,21 +92,14 @@ class LocalKey(Resource):
                 u"expired_ts": key.expired,
             }
 
-        x509_certificate_bytes = crypto.dump_certificate(
-            crypto.FILETYPE_ASN1,
-            self.config.tls_certificate
-        )
-
-        sha256_fingerprint = sha256(x509_certificate_bytes).digest()
+        tls_fingerprints = self.config.tls_fingerprints
 
         json_object = {
             u"valid_until_ts": self.valid_until_ts,
             u"server_name": self.config.server_name,
             u"verify_keys": verify_keys,
             u"old_verify_keys": old_verify_keys,
-            u"tls_fingerprints": [{
-                u"sha256": encode_base64(sha256_fingerprint),
-            }]
+            u"tls_fingerprints": tls_fingerprints,
         }
         for key in self.config.signing_key:
             json_object = sign_json(
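
Note: the removed inline computation is what the configured
``tls_fingerprints`` list now supplies up front. For reference, the digest it
produced amounted to this (a standalone restatement of the dropped code)::

    from hashlib import sha256
    from unpaddedbase64 import encode_base64

    def fingerprint_of_der(cert_der_bytes):
        # same {"sha256": <unpadded-base64 digest>} shape the key response uses
        return {"sha256": encode_base64(sha256(cert_der_bytes).digest())}

    print(fingerprint_of_der(b"example DER bytes"))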


@@ -85,7 +85,6 @@ class LoggingTransaction(object):
         sql_logger.debug("[SQL] {%s} %s", self.name, sql)
 
         sql = self.database_engine.convert_param_style(sql)
-
         if args:
             try:
                 sql_logger.debug(


@@ -37,7 +37,7 @@ class ApplicationServiceStore(SQLBaseStore):
         )
 
     def get_app_services(self):
-        return defer.succeed(self.services_cache)
+        return self.services_cache
 
     def get_app_service_by_user_id(self, user_id):
         """Retrieve an application service from their user ID.
@@ -54,8 +54,8 @@ class ApplicationServiceStore(SQLBaseStore):
         """
         for service in self.services_cache:
             if service.sender == user_id:
-                return defer.succeed(service)
-        return defer.succeed(None)
+                return service
+        return None
 
     def get_app_service_by_token(self, token):
         """Get the application service with the given appservice token.
@@ -67,8 +67,8 @@ class ApplicationServiceStore(SQLBaseStore):
         """
         for service in self.services_cache:
             if service.token == token:
-                return defer.succeed(service)
-        return defer.succeed(None)
+                return service
+        return None
 
     def get_app_service_rooms(self, service):
         """Get a list of RoomsForUser for this application service.
@@ -163,7 +163,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
             ["as_id"]
         )
 
         # NB: This assumes this class is linked with ApplicationServiceStore
-        as_list = yield self.get_app_services()
+        as_list = self.get_app_services()
 
         services = []
         for res in results:


@@ -603,7 +603,6 @@ class EventsStore(SQLBaseStore):
             "rejections",
             "redactions",
             "room_memberships",
-            "state_events",
             "topics"
         ):
             txn.executemany(


@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 36
+SCHEMA_VERSION = 37
 
 dir_path = os.path.abspath(os.path.dirname(__file__))


@@ -320,6 +320,9 @@ class RoomStore(SQLBaseStore):
             txn.execute(sql, (prev_id, current_id, limit,))
             return txn.fetchall()
 
+        if prev_id == current_id:
+            return defer.succeed([])
+
         return self.runInteraction(
             "get_all_new_public_rooms", get_all_new_public_rooms
         )


@@ -0,0 +1,81 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+DROP_INDICES = """
+-- We only ever query based on event_id
+DROP INDEX IF EXISTS state_events_room_id;
+DROP INDEX IF EXISTS state_events_type;
+DROP INDEX IF EXISTS state_events_state_key;
+
+-- room_id is indexed elsewhere
+DROP INDEX IF EXISTS current_state_events_room_id;
+DROP INDEX IF EXISTS current_state_events_state_key;
+DROP INDEX IF EXISTS current_state_events_type;
+
+DROP INDEX IF EXISTS transactions_have_ref;
+
+-- (topological_ordering, stream_ordering, room_id) seems like a strange index,
+-- and is used incredibly rarely.
+DROP INDEX IF EXISTS events_order_topo_stream_room;
+
+DROP INDEX IF EXISTS event_search_ev_idx;
+"""
+
+POSTGRES_DROP_CONSTRAINT = """
+ALTER TABLE event_auth DROP CONSTRAINT IF EXISTS event_auth_event_id_auth_id_room_id_key;
+"""
+
+SQLITE_DROP_CONSTRAINT = """
+DROP INDEX IF EXISTS evauth_edges_id;
+
+CREATE TABLE IF NOT EXISTS event_auth_new(
+    event_id TEXT NOT NULL,
+    auth_id TEXT NOT NULL,
+    room_id TEXT NOT NULL
+);
+
+INSERT INTO event_auth_new
+    SELECT event_id, auth_id, room_id
+    FROM event_auth;
+
+DROP TABLE event_auth;
+
+ALTER TABLE event_auth_new RENAME TO event_auth;
+
+CREATE INDEX evauth_edges_id ON event_auth(event_id);
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    for statement in get_statements(DROP_INDICES.splitlines()):
+        cur.execute(statement)
+
+    if isinstance(database_engine, PostgresEngine):
+        drop_constraint = POSTGRES_DROP_CONSTRAINT
+    else:
+        drop_constraint = SQLITE_DROP_CONSTRAINT
+
+    for statement in get_statements(drop_constraint.splitlines()):
+        cur.execute(statement)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+    pass


@@ -0,0 +1,52 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Update any email addresses that were stored with mixed case into all
+ * lowercase
+ */
+
+-- There may be "duplicate" emails (with different case) already in the table,
+-- so we find them and move all but the most recently used account.
+UPDATE user_threepids
+SET medium = 'email_old'
+WHERE medium = 'email'
+AND address IN (
+    -- We select all the addresses that are linked to the user_id that is NOT
+    -- the most recently created.
+    SELECT u.address
+    FROM
+        user_threepids AS u,
+        -- `duplicate_addresses` is a table of all the email addresses that
+        -- appear multiple times and when the binding was created
+        (
+            SELECT lower(u1.address) AS address, max(u1.added_at) AS max_ts
+            FROM user_threepids AS u1
+            INNER JOIN user_threepids AS u2 ON
+                u1.medium = u2.medium AND
+                lower(u1.address) = lower(u2.address) AND
+                u1.address != u2.address
+            WHERE u1.medium = 'email' AND u2.medium = 'email'
+            GROUP BY lower(u1.address)
+        ) AS duplicate_addresses
+    WHERE
+        lower(u.address) = duplicate_addresses.address
+        AND u.added_at != max_ts  -- NOT the most recently created
+);
+
+-- This update is now safe since we've removed the duplicate addresses.
+UPDATE user_threepids SET address = LOWER(address) WHERE medium = 'email';
+
+/* Add an index for the select we do on password reset */
+CREATE INDEX user_threepids_medium_address on user_threepids (medium, address);


@@ -521,13 +521,20 @@ class StreamStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_room_events_max_id(self, direction='f'):
+    def get_room_events_max_id(self, room_id=None):
+        """Returns the current token for rooms stream.
+
+        By default, it returns the current global stream token. Specifying a
+        `room_id` causes it to return the current room specific topological
+        token.
+        """
         token = yield self._stream_id_gen.get_current_token()
-        if direction != 'b':
+        if room_id is None:
             defer.returnValue("s%d" % (token,))
         else:
             topo = yield self.runInteraction(
-                "_get_max_topological_txn", self._get_max_topological_txn
+                "_get_max_topological_txn", self._get_max_topological_txn,
+                room_id,
             )
             defer.returnValue("t%d-%d" % (topo, token))
 
@@ -579,11 +586,11 @@ class StreamStore(SQLBaseStore):
             lambda r: r[0][0] if r else 0
         )
 
-    def _get_max_topological_txn(self, txn):
+    def _get_max_topological_txn(self, txn, room_id):
         txn.execute(
             "SELECT MAX(topological_ordering) FROM events"
-            " WHERE outlier = ?",
-            (False,)
+            " WHERE room_id = ?",
+            (room_id,)
         )
 
         rows = txn.fetchall()
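
Note: the two token shapes produced above, shown with made-up values::

    token = 3421  # stream ordering
    topo = 87     # per-room topological ordering

    print("s%d" % (token,))          # -> "s3421"    (global stream token)
    print("t%d-%d" % (topo, token))  # -> "t87-3421" (room-scoped token)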


@@ -41,13 +41,39 @@ class EventSources(object):
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
-    def get_current_token(self, direction='f'):
+    def get_current_token(self):
         push_rules_key, _ = self.store.get_push_rules_stream_token()
         to_device_key = self.store.get_to_device_stream_token()
 
         token = StreamToken(
             room_key=(
-                yield self.sources["room"].get_current_key(direction)
+                yield self.sources["room"].get_current_key()
+            ),
+            presence_key=(
+                yield self.sources["presence"].get_current_key()
+            ),
+            typing_key=(
+                yield self.sources["typing"].get_current_key()
+            ),
+            receipt_key=(
+                yield self.sources["receipt"].get_current_key()
+            ),
+            account_data_key=(
+                yield self.sources["account_data"].get_current_key()
+            ),
+            push_rules_key=push_rules_key,
+            to_device_key=to_device_key,
+        )
+        defer.returnValue(token)
+
+    @defer.inlineCallbacks
+    def get_current_token_for_room(self, room_id):
+        push_rules_key, _ = self.store.get_push_rules_stream_token()
+        to_device_key = self.store.get_to_device_stream_token()
+
+        token = StreamToken(
+            room_key=(
+                yield self.sources["room"].get_current_key_for_room(room_id)
             ),
             presence_key=(
                 yield self.sources["presence"].get_current_key()


@@ -18,8 +18,9 @@ from synapse.api.errors import SynapseError
 
 from collections import namedtuple
 
-Requester = namedtuple("Requester",
-                       ["user", "access_token_id", "is_guest", "device_id"])
+Requester = namedtuple("Requester", [
+    "user", "access_token_id", "is_guest", "device_id", "app_service",
+])
 """
 Represents the user making a request
 
@@ -29,11 +30,12 @@ Attributes:
         request, or None if it came via the appservice API or similar
     is_guest (bool):  True if the user making this request is a guest user
     device_id (str|None):  device_id which was set at authentication time
+    app_service (ApplicationService|None):  the AS requesting on behalf of the user
 """
 
 
 def create_requester(user_id, access_token_id=None, is_guest=False,
-                     device_id=None):
+                     device_id=None, app_service=None):
     """
     Create a new ``Requester`` object
 
@@ -43,13 +45,14 @@ def create_requester(user_id, access_token_id=None, is_guest=False,
            request, or None if it came via the appservice API or similar
        is_guest (bool):  True if the user making this request is a guest user
        device_id (str|None):  device_id which was set at authentication time
+        app_service (ApplicationService|None):  the AS requesting on behalf of the user
 
     Returns:
         Requester
     """
     if not isinstance(user_id, UserID):
         user_id = UserID.from_string(user_id)
-    return Requester(user_id, access_token_id, is_guest, device_id)
+    return Requester(user_id, access_token_id, is_guest, device_id, app_service)
 
 
 def get_domain_from_id(string):
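
Note: the widened tuple stays backwards compatible because ``app_service``
defaults to ``None``; a quick check::

    from synapse.types import create_requester

    requester = create_requester("@alice:example.com")
    assert requester.app_service is None  # old call sites are unchanged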


@ -0,0 +1,368 @@
from twisted.internet import defer
from synapse.config._base import ConfigError
from synapse.types import UserID
import ldap3
import ldap3.core.exceptions
import logging
try:
import ldap3
import ldap3.core.exceptions
except ImportError:
ldap3 = None
pass
logger = logging.getLogger(__name__)
class LDAPMode(object):
SIMPLE = "simple",
SEARCH = "search",
LIST = (SIMPLE, SEARCH)
class LdapAuthProvider(object):
__version__ = "0.1"
def __init__(self, config, account_handler):
self.account_handler = account_handler
if not ldap3:
raise RuntimeError(
'Missing ldap3 library. This is required for LDAP Authentication.'
)
self.ldap_mode = config.mode
self.ldap_uri = config.uri
self.ldap_start_tls = config.start_tls
self.ldap_base = config.base
self.ldap_attributes = config.attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = config.bind_dn
self.ldap_bind_password = config.bind_password
self.ldap_filter = config.filter
@defer.inlineCallbacks
def check_password(self, user_id, password):
""" Attempt to authenticate a user against an LDAP Server
and register an account if none exists.
Returns:
True if authentication against LDAP was successful
"""
localpart = UserID.from_string(user_id).localpart
try:
server = ldap3.Server(self.ldap_uri)
logger.debug(
"Attempting LDAP connection with %s",
self.ldap_uri
)
if self.ldap_mode == LDAPMode.SIMPLE:
result, conn = self._ldap_simple_bind(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP authentication method simple bind returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
elif self.ldap_mode == LDAPMode.SEARCH:
result, conn = self._ldap_authenticated_search(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP auth method authenticated search returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
else:
raise RuntimeError(
'Invalid LDAP mode specified: {mode}'.format(
mode=self.ldap_mode
)
)
try:
logger.info(
"User authenticated against LDAP server: %s",
conn
)
except NameError:
logger.warn(
"Authentication method yielded no LDAP connection, aborting!"
)
defer.returnValue(False)
# check if user with user_id exists
if (yield self.account_handler.check_user_exists(user_id)):
# exists, authentication complete
conn.unbind()
defer.returnValue(True)
else:
# does not exist, fetch metadata for account creation from
# existing ldap connection
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
query = "(&{filter}{user_filter})".format(
filter=query,
user_filter=self.ldap_filter
)
logger.debug(
"ldap registration filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query,
attributes=[
self.ldap_attributes['name'],
self.ldap_attributes['mail']
]
)
if len(conn.response) == 1:
attrs = conn.response[0]['attributes']
mail = attrs[self.ldap_attributes['mail']][0]
name = attrs[self.ldap_attributes['name']][0]
# create account
user_id, access_token = (
yield self.account_handler.register(localpart=localpart)
)
# TODO: bind email, set displayname with data from ldap directory
logger.info(
"Registration based on LDAP data was successful: %d: %s (%s, %)",
user_id,
localpart,
name,
mail
)
defer.returnValue(True)
else:
if len(conn.response) == 0:
logger.warn("LDAP registration failed, no result.")
else:
logger.warn(
"LDAP registration failed, too many results (%s)",
len(conn.response)
)
defer.returnValue(False)
defer.returnValue(False)
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
@staticmethod
def parse_config(config):
class _LdapConfig(object):
pass
ldap_config = _LdapConfig()
ldap_config.enabled = config.get("enabled", False)
ldap_config.mode = LDAPMode.SIMPLE
# verify config sanity
_require_keys(config, [
"uri",
"base",
"attributes",
])
ldap_config.uri = config["uri"]
ldap_config.start_tls = config.get("start_tls", False)
ldap_config.base = config["base"]
ldap_config.attributes = config["attributes"]
if "bind_dn" in config:
ldap_config.mode = LDAPMode.SEARCH
_require_keys(config, [
"bind_dn",
"bind_password",
])
ldap_config.bind_dn = config["bind_dn"]
ldap_config.bind_password = config["bind_password"]
ldap_config.filter = config.get("filter", None)
# verify attribute lookup
_require_keys(config['attributes'], [
"uid",
"name",
"mail",
])
return ldap_config
    def _ldap_simple_bind(self, server, localpart, password):
        """ Attempt a simple bind with the credentials
        given by the user against the LDAP server.

        Returns True, LDAP3Connection
            if the bind was successful
        Returns False, None
            if the bind failed or an error occurred
        """
        try:
            # bind with the local user's LDAP credentials
            bind_dn = "{prop}={value},{base}".format(
                prop=self.ldap_attributes['uid'],
                value=localpart,
                base=self.ldap_base
            )
            conn = ldap3.Connection(server, bind_dn, password)
            logger.debug(
                "Established LDAP connection in simple bind mode: %s",
                conn
            )

            if self.ldap_start_tls:
                conn.start_tls()
                logger.debug(
                    "Upgraded LDAP connection in simple bind mode through StartTLS: %s",
                    conn
                )

            if conn.bind():
                # GOOD: bind okay
                logger.debug("LDAP Bind successful in simple bind mode.")
                return True, conn

            # BAD: bind failed
            logger.info(
                "Binding against LDAP for '%s' failed: %s",
                localpart, conn.result['description']
            )
            conn.unbind()
            return False, None

        except ldap3.core.exceptions.LDAPException as e:
            logger.warn("Error during LDAP authentication: %s", e)
            return False, None
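
    # For illustration (values assumed, not from the source): with
    # ldap_attributes['uid'] == "cn", localpart == "alice" and
    # ldap_base == "ou=users,dc=example,dc=com", the bind_dn built above is
    #
    #     "cn=alice,ou=users,dc=example,dc=com"
    #
    # so simple bind mode assumes every Matrix localpart maps to an entry
    # directly under the configured base DN.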
    def _ldap_authenticated_search(self, server, localpart, password):
        """ Attempt to login with the preconfigured bind_dn
        and then continue searching and filtering within
        the base_dn.

        Returns (True, LDAP3Connection)
            if a single matching DN within the base was found
            that matched the filter expression, and with which
            a successful bind was achieved.

            The LDAP3Connection returned is the instance that was used to
            verify the password, not the one using the configured bind_dn.
        Returns (False, None)
            if the search found no single match, the bind failed,
            or an error occurred
        """
        try:
            conn = ldap3.Connection(
                server,
                self.ldap_bind_dn,
                self.ldap_bind_password
            )
            logger.debug(
                "Established LDAP connection in search mode: %s",
                conn
            )

            if self.ldap_start_tls:
                conn.start_tls()
                logger.debug(
                    "Upgraded LDAP connection in search mode through StartTLS: %s",
                    conn
                )

            if not conn.bind():
                logger.warn(
                    "Binding against LDAP with `bind_dn` failed: %s",
                    conn.result['description']
                )
                conn.unbind()
                return False, None

            # construct search_filter like (uid=localpart)
            query = "({prop}={value})".format(
                prop=self.ldap_attributes['uid'],
                value=localpart
            )
            if self.ldap_filter:
                # combine with the AND expression
                query = "(&{query}{filter})".format(
                    query=query,
                    filter=self.ldap_filter
                )
            logger.debug(
                "LDAP search filter: %s",
                query
            )
            conn.search(
                search_base=self.ldap_base,
                search_filter=query
            )

            if len(conn.response) == 1:
                # GOOD: found exactly one result
                user_dn = conn.response[0]['dn']
                logger.debug('LDAP search found dn: %s', user_dn)

                # unbind and simple bind with user_dn to verify the password
                # Note: do not use rebind(), for some reason it did not verify
                #       the password for me!
                conn.unbind()
                return self._ldap_simple_bind(server, localpart, password)
            else:
                # BAD: found 0 or > 1 results, abort!
                if len(conn.response) == 0:
                    logger.info(
                        "LDAP search returned no results for '%s'",
                        localpart
                    )
                else:
                    logger.info(
                        "LDAP search returned too many (%s) results for '%s'",
                        len(conn.response), localpart
                    )
                conn.unbind()
                return False, None

        except ldap3.core.exceptions.LDAPException as e:
            logger.warn("Error during LDAP authentication: %s", e)
            return False, None
def _require_keys(config, required):
    missing = [key for key in required if key not in config]
    if missing:
        raise ConfigError(
            "LDAP enabled but missing required config values: {}".format(
                ", ".join(missing)
            )
        )
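
# A minimal sketch of how a provider like this is wired up through the
# pluggable password auth added in this release (PR #1155). The module path
# below is an assumption for illustration and must match wherever the
# provider class actually lives:
#
#     # homeserver.yaml
#     password_providers:
#         - module: "ldap_auth_provider.LdapAuthProvider"
#           config:
#               enabled: true
#               uri: "ldap://ldap.example.com:389"
#               base: "ou=users,dc=example,dc=com"
#               attributes:
#                   uid: "cn"
#                   name: "givenName"
#                   mail: "mail"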

View file

@@ -20,7 +20,7 @@ from mock import Mock

 from synapse.api.auth import Auth
 from synapse.api.errors import AuthError
 from synapse.types import UserID
-from tests.utils import setup_test_homeserver
+from tests.utils import setup_test_homeserver, mock_getRawHeaders

 import pymacaroons
@@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
@@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -74,7 +74,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=user_info)

         request = Mock(args={})
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
@@ -96,7 +96,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -106,7 +106,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=None)

         request = Mock(args={})
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.args["user_id"] = [masquerading_user_id]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), masquerading_user_id)
@@ -135,7 +135,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.args["user_id"] = [masquerading_user_id]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)

View file

@@ -17,7 +17,7 @@ from twisted.internet import defer
 from .. import unittest

 from synapse.handlers.register import RegistrationHandler
-from synapse.types import UserID
+from synapse.types import UserID, create_requester

 from tests.utils import setup_test_homeserver
@@ -57,8 +57,9 @@ class RegistrationTestCase(unittest.TestCase):
         local_part = "someone"
         display_name = "someone"
         user_id = "@someone:test"
+        requester = create_requester("@as:test")
         result_user_id, result_token = yield self.handler.get_or_create_user(
-            local_part, display_name, duration_ms)
+            requester, local_part, display_name, duration_ms)
         self.assertEquals(result_user_id, user_id)
         self.assertEquals(result_token, 'secret')
@@ -74,7 +75,8 @@ class RegistrationTestCase(unittest.TestCase):
         local_part = "frank"
         display_name = "Frank"
         user_id = "@frank:test"
+        requester = create_requester("@as:test")
         result_user_id, result_token = yield self.handler.get_or_create_user(
-            local_part, display_name, duration_ms)
+            requester, local_part, display_name, duration_ms)
         self.assertEquals(result_user_id, user_id)
         self.assertEquals(result_token, 'secret')

View file

@@ -219,7 +219,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
                     "user_id": self.u_onion.to_string(),
                     "typing": True,
                 }
-            )
+            ),
+            federation_auth=True,
         )

         self.on_new_event.assert_has_calls([

View file

@@ -42,7 +42,8 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def replicate(self):
         streams = self.slaved_store.stream_positions()
-        result = yield self.replication.replicate(streams, 100)
+        writer = yield self.replication.replicate(streams, 100)
+        result = writer.finish()
         yield self.slaved_store.process_replication(result)

     @defer.inlineCallbacks

View file

@@ -120,7 +120,7 @@ class ReplicationResourceCase(unittest.TestCase):
             self.hs.clock.advance_time_msec(1)
             code, body = yield get
             self.assertEquals(code, 200)
-            self.assertEquals(body, {})
+            self.assertEquals(body.get("rows", []), [])
         test_timeout.__name__ = "test_timeout_%s" % (stream)
         return test_timeout
@@ -195,7 +195,6 @@ class ReplicationResourceCase(unittest.TestCase):
         self.assertIn("field_names", stream)
         field_names = stream["field_names"]
         self.assertIn("rows", stream)
-        self.assertTrue(stream["rows"])
         for row in stream["rows"]:
             self.assertEquals(
                 len(row), len(field_names),

View file

@@ -17,6 +17,7 @@ from synapse.rest.client.v1.register import CreateUserRestServlet
 from twisted.internet import defer
 from mock import Mock
 from tests import unittest
+from tests.utils import mock_getRawHeaders
 import json
@@ -30,34 +31,23 @@ class CreateUserServletTestCase(unittest.TestCase):
             path='/_matrix/client/api/v1/createUser'
         )
         self.request.args = {}
+        self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()

-        self.appservice = None
-        self.auth = Mock(get_appservice_by_req=Mock(
-            side_effect=lambda x: defer.succeed(self.appservice))
-        )
-
-        self.auth_result = (False, None, None, None)
-        self.auth_handler = Mock(
-            check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
-            get_session_data=Mock(return_value=None)
-        )
         self.registration_handler = Mock()
-        self.identity_handler = Mock()
-        self.login_handler = Mock()

-        # do the dance to hook it up to the hs global
-        self.handlers = Mock(
-            auth_handler=self.auth_handler,
+        self.appservice = Mock(sender="@as:test")
+        self.datastore = Mock(
+            get_app_service_by_token=Mock(return_value=self.appservice)
+        )
+
+        # do the dance to hook things up to the hs global
+        handlers = Mock(
             registration_handler=self.registration_handler,
-            identity_handler=self.identity_handler,
-            login_handler=self.login_handler
         )
         self.hs = Mock()
-        self.hs.hostname = "supergbig~testing~thing.com"
-        self.hs.get_auth = Mock(return_value=self.auth)
-        self.hs.get_handlers = Mock(return_value=self.handlers)
+        self.hs.hostname = "superbig~testing~thing.com"
+        self.hs.get_datastore = Mock(return_value=self.datastore)
+        self.hs.get_handlers = Mock(return_value=handlers)
+        self.hs.config.enable_registration = True

         # init the thing we're testing
         self.servlet = CreateUserRestServlet(self.hs)

     @defer.inlineCallbacks

View file

@@ -15,78 +15,125 @@
 from twisted.internet import defer

-from . import V2AlphaRestTestCase
+from tests import unittest
 from synapse.rest.client.v2_alpha import filter

-from synapse.api.errors import StoreError
+from synapse.api.errors import Codes
+import synapse.types
+
+from synapse.types import UserID
+
+from ....utils import MockHttpResource, setup_test_homeserver
+
+PATH_PREFIX = "/_matrix/client/v2_alpha"

-class FilterTestCase(V2AlphaRestTestCase):
+class FilterTestCase(unittest.TestCase):
     USER_ID = "@apple:test"
+    EXAMPLE_FILTER = {"type": ["m.*"]}
+    EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}'
     TO_REGISTER = [filter]

-    def make_datastore_mock(self):
-        datastore = super(FilterTestCase, self).make_datastore_mock()
-
-        self._user_filters = {}
-
-        def add_user_filter(user_localpart, definition):
-            filters = self._user_filters.setdefault(user_localpart, [])
-            filter_id = len(filters)
-            filters.append(definition)
-            return defer.succeed(filter_id)
-        datastore.add_user_filter = add_user_filter
-
-        def get_user_filter(user_localpart, filter_id):
-            if user_localpart not in self._user_filters:
-                raise StoreError(404, "No user")
-            filters = self._user_filters[user_localpart]
-            if filter_id >= len(filters):
-                raise StoreError(404, "No filter")
-            return defer.succeed(filters[filter_id])
-        datastore.get_user_filter = get_user_filter
-
-        return datastore
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
+
+        self.hs = yield setup_test_homeserver(
+            http_client=None,
+            resource_for_client=self.mock_resource,
+            resource_for_federation=self.mock_resource,
+        )
+
+        self.auth = self.hs.get_auth()
+
+        def get_user_by_access_token(token=None, allow_guest=False):
+            return {
+                "user": UserID.from_string(self.USER_ID),
+                "token_id": 1,
+                "is_guest": False,
+            }
+
+        def get_user_by_req(request, allow_guest=False, rights="access"):
+            return synapse.types.create_requester(
+                UserID.from_string(self.USER_ID), 1, False, None)
+
+        self.auth.get_user_by_access_token = get_user_by_access_token
+        self.auth.get_user_by_req = get_user_by_req
+
+        self.store = self.hs.get_datastore()
+        self.filtering = self.hs.get_filtering()
+
+        for r in self.TO_REGISTER:
+            r.register_servlets(self.hs, self.mock_resource)

     @defer.inlineCallbacks
     def test_add_filter(self):
         (code, response) = yield self.mock_resource.trigger(
-            "POST", "/user/%s/filter" % (self.USER_ID), '{"type": ["m.*"]}'
+            "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
         )
         self.assertEquals(200, code)
         self.assertEquals({"filter_id": "0"}, response)
+        filter = yield self.store.get_user_filter(
+            user_localpart='apple',
+            filter_id=0,
+        )
+        self.assertEquals(filter, self.EXAMPLE_FILTER)

-        self.assertIn("apple", self._user_filters)
-        self.assertEquals(len(self._user_filters["apple"]), 1)
-        self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0])
+    @defer.inlineCallbacks
+    def test_add_filter_for_other_user(self):
+        (code, response) = yield self.mock_resource.trigger(
+            "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON
+        )
+        self.assertEquals(403, code)
+        self.assertEquals(response['errcode'], Codes.FORBIDDEN)
+
+    @defer.inlineCallbacks
+    def test_add_filter_non_local_user(self):
+        _is_mine = self.hs.is_mine
+        self.hs.is_mine = lambda target_user: False
+        (code, response) = yield self.mock_resource.trigger(
+            "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
+        )
+        self.hs.is_mine = _is_mine
+        self.assertEquals(403, code)
+        self.assertEquals(response['errcode'], Codes.FORBIDDEN)

     @defer.inlineCallbacks
     def test_get_filter(self):
-        self._user_filters["apple"] = [
-            {"type": ["m.*"]}
-        ]
+        filter_id = yield self.filtering.add_user_filter(
+            user_localpart='apple',
+            user_filter=self.EXAMPLE_FILTER
+        )
         (code, response) = yield self.mock_resource.trigger_get(
-            "/user/%s/filter/0" % (self.USER_ID)
+            "/user/%s/filter/%s" % (self.USER_ID, filter_id)
         )
         self.assertEquals(200, code)
-        self.assertEquals({"type": ["m.*"]}, response)
+        self.assertEquals(self.EXAMPLE_FILTER, response)
+
+    @defer.inlineCallbacks
+    def test_get_filter_non_existant(self):
+        (code, response) = yield self.mock_resource.trigger_get(
+            "/user/%s/filter/12382148321" % (self.USER_ID)
+        )
+        self.assertEquals(400, code)
+        self.assertEquals(response['errcode'], Codes.NOT_FOUND)
+
+    # Currently invalid params do not have an appropriate errcode
+    # in errors.py
+    @defer.inlineCallbacks
+    def test_get_filter_invalid_id(self):
+        (code, response) = yield self.mock_resource.trigger_get(
+            "/user/%s/filter/foobar" % (self.USER_ID)
+        )
+        self.assertEquals(400, code)

+    # No ID also returns an invalid_id error
     @defer.inlineCallbacks
     def test_get_filter_no_id(self):
-        self._user_filters["apple"] = [
-            {"type": ["m.*"]}
-        ]
         (code, response) = yield self.mock_resource.trigger_get(
-            "/user/%s/filter/2" % (self.USER_ID)
+            "/user/%s/filter/" % (self.USER_ID)
         )
-        self.assertEquals(404, code)
-
-    @defer.inlineCallbacks
-    def test_get_filter_no_user(self):
-        (code, response) = yield self.mock_resource.trigger_get(
-            "/user/%s/filter/0" % (self.USER_ID)
-        )
-        self.assertEquals(404, code)
+        self.assertEquals(400, code)

View file

@@ -3,6 +3,7 @@ from synapse.api.errors import SynapseError
 from twisted.internet import defer
 from mock import Mock
 from tests import unittest
+from tests.utils import mock_getRawHeaders
 import json
@@ -16,10 +17,11 @@ class RegisterRestServletTestCase(unittest.TestCase):
             path='/_matrix/api/v2_alpha/register'
         )
         self.request.args = {}
+        self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()

         self.appservice = None
         self.auth = Mock(get_appservice_by_req=Mock(
-            side_effect=lambda x: defer.succeed(self.appservice))
+            side_effect=lambda x: self.appservice)
         )

         self.auth_result = (False, None, None, None)

View file

@@ -37,6 +37,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         config = Mock(
             app_service_config_files=self.as_yaml_files,
             event_cache_size=1,
+            password_providers=[],
         )
         hs = yield setup_test_homeserver(config=config)
@@ -71,14 +72,12 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
             outfile.write(yaml.dump(as_yaml))
         self.as_yaml_files.append(as_token)

-    @defer.inlineCallbacks
     def test_retrieve_unknown_service_token(self):
-        service = yield self.store.get_app_service_by_token("invalid_token")
+        service = self.store.get_app_service_by_token("invalid_token")
         self.assertEquals(service, None)

-    @defer.inlineCallbacks
     def test_retrieval_of_service(self):
-        stored_service = yield self.store.get_app_service_by_token(
+        stored_service = self.store.get_app_service_by_token(
             self.as_token
         )
         self.assertEquals(stored_service.token, self.as_token)
@@ -97,9 +96,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
             []
         )

-    @defer.inlineCallbacks
     def test_retrieval_of_all_services(self):
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         self.assertEquals(len(services), 3)
@@ -112,6 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         config = Mock(
             app_service_config_files=self.as_yaml_files,
             event_cache_size=1,
+            password_providers=[],
         )
         hs = yield setup_test_homeserver(config=config)
         self.db_pool = hs.get_db_pool()
@@ -440,7 +439,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f1 = self._write_config(suffix="1")
         f2 = self._write_config(suffix="2")

-        config = Mock(app_service_config_files=[f1, f2], event_cache_size=1)
+        config = Mock(
+            app_service_config_files=[f1, f2], event_cache_size=1,
+            password_providers=[]
+        )
         hs = yield setup_test_homeserver(config=config, datastore=Mock())

         ApplicationServiceStore(hs)
@@ -450,7 +452,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f1 = self._write_config(id="id", suffix="1")
         f2 = self._write_config(id="id", suffix="2")

-        config = Mock(app_service_config_files=[f1, f2], event_cache_size=1)
+        config = Mock(
+            app_service_config_files=[f1, f2], event_cache_size=1,
+            password_providers=[]
+        )
         hs = yield setup_test_homeserver(config=config, datastore=Mock())

         with self.assertRaises(ConfigError) as cm:
@@ -466,7 +471,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f1 = self._write_config(as_token="as_token", suffix="1")
         f2 = self._write_config(as_token="as_token", suffix="2")

-        config = Mock(app_service_config_files=[f1, f2], event_cache_size=1)
+        config = Mock(
+            app_service_config_files=[f1, f2], event_cache_size=1,
+            password_providers=[]
+        )
         hs = yield setup_test_homeserver(config=config, datastore=Mock())

         with self.assertRaises(ConfigError) as cm:

View file

@@ -52,6 +52,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.server_name = name
         config.trusted_third_party_id_servers = []
         config.room_invite_state_types = []
+        config.password_providers = []

         config.use_frozen_dicts = True
         config.database_config = {"name": "sqlite3"}
@@ -115,6 +116,15 @@ def get_mock_call_args(pattern_func, mock_func):
     return getcallargs(pattern_func, *invoked_args, **invoked_kargs)


+def mock_getRawHeaders(headers=None):
+    headers = headers if headers is not None else {}
+
+    def getRawHeaders(name, default=None):
+        return headers.get(name, default)
+
+    return getRawHeaders
+
+
 # This is a mock /resource/ not an entire server
 class MockHttpResource(HttpServer):
@@ -127,7 +137,7 @@ class MockHttpResource(HttpServer):

     @patch('twisted.web.http.Request')
     @defer.inlineCallbacks
-    def trigger(self, http_method, path, content, mock_request):
+    def trigger(self, http_method, path, content, mock_request, federation_auth=False):
         """ Fire an HTTP event.

         Args:
@@ -155,9 +165,10 @@ class MockHttpResource(HttpServer):
         mock_request.getClientIP.return_value = "-"

-        mock_request.requestHeaders.getRawHeaders.return_value = [
-            "X-Matrix origin=test,key=,sig="
-        ]
+        headers = {}
+        if federation_auth:
+            headers["Authorization"] = ["X-Matrix origin=test,key=,sig="]
+        mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)

         # return the right path if the event requires it
         mock_request.path = path
@@ -188,7 +199,7 @@ class MockHttpResource(HttpServer):
             )
             defer.returnValue((code, response))
         except CodeMessageException as e:
-            defer.returnValue((e.code, cs_error(e.msg)))
+            defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))

         raise KeyError("No event can handle %s" % path)
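
Taken together, these test changes replace the ad-hoc ``Mock(return_value=[""])``
header stubs with the shared ``mock_getRawHeaders`` helper, and let
``MockHttpResource.trigger`` opt in to federation auth headers. A minimal usage
sketch (the request object and token value are illustrative, not taken from the
diff)::

    from mock import Mock
    from tests.utils import mock_getRawHeaders

    request = Mock(args={})
    request.args["access_token"] = ["some_token"]

    # Default: no headers present, so lookups fall through to the default.
    request.requestHeaders.getRawHeaders = mock_getRawHeaders()

    # Simulating federation auth, as MockHttpResource.trigger now does when
    # called with federation_auth=True:
    request.requestHeaders.getRawHeaders = mock_getRawHeaders({
        "Authorization": ["X-Matrix origin=test,key=,sig="],
    })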