0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-17 23:42:33 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into erikj/fix_email_notifs

This commit is contained in:
Erik Johnston 2016-10-13 14:23:48 +01:00
commit 6f7540ada4
18 changed files with 733 additions and 507 deletions

8
.gitignore vendored
View file

@ -24,10 +24,10 @@ homeserver*.yaml
.coverage .coverage
htmlcov htmlcov
demo/*.db demo/*/*.db
demo/*.log demo/*/*.log
demo/*.log.* demo/*/*.log.*
demo/*.pid demo/*/*.pid
demo/media_store.* demo/media_store.*
demo/etc demo/etc

View file

@ -1001,16 +1001,6 @@ class Auth(object):
403, 403,
"You are not allowed to set others state" "You are not allowed to set others state"
) )
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True return True

View file

@ -197,7 +197,7 @@ class PusherServer(HomeServer):
yield start_pusher(user_id, app_id, pushkey) yield start_pusher(user_id, app_id, pushkey)
stream = results.get("events") stream = results.get("events")
if stream: if stream and stream["rows"]:
min_stream_id = stream["rows"][0][0] min_stream_id = stream["rows"][0][0]
max_stream_id = stream["position"] max_stream_id = stream["position"]
preserve_fn(pusher_pool.on_new_notifications)( preserve_fn(pusher_pool.on_new_notifications)(
@ -205,7 +205,7 @@ class PusherServer(HomeServer):
) )
stream = results.get("receipts") stream = results.get("receipts")
if stream: if stream and stream["rows"]:
rows = stream["rows"] rows = stream["rows"]
affected_room_ids = set(row[1] for row in rows) affected_room_ids = set(row[1] for row in rows)
min_stream_id = rows[0][0] min_stream_id = rows[0][0]

View file

@ -30,7 +30,7 @@ from .saml2 import SAML2Config
from .cas import CasConfig from .cas import CasConfig
from .password import PasswordConfig from .password import PasswordConfig
from .jwt import JWTConfig from .jwt import JWTConfig
from .ldap import LDAPConfig from .password_auth_providers import PasswordAuthProviderConfig
from .emailconfig import EmailConfig from .emailconfig import EmailConfig
from .workers import WorkerConfig from .workers import WorkerConfig
@ -39,8 +39,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, LDAPConfig, PasswordConfig, EmailConfig, JWTConfig, PasswordConfig, EmailConfig,
WorkerConfig,): WorkerConfig, PasswordAuthProviderConfig,):
pass pass

View file

@ -1,100 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Niklas Riekenbrauck
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config, ConfigError
MISSING_LDAP3 = (
"Missing ldap3 library. This is required for LDAP Authentication."
)
class LDAPMode(object):
SIMPLE = "simple",
SEARCH = "search",
LIST = (SIMPLE, SEARCH)
class LDAPConfig(Config):
def read_config(self, config):
ldap_config = config.get("ldap_config", {})
self.ldap_enabled = ldap_config.get("enabled", False)
if self.ldap_enabled:
# verify dependencies are available
try:
import ldap3
ldap3 # to stop unused lint
except ImportError:
raise ConfigError(MISSING_LDAP3)
self.ldap_mode = LDAPMode.SIMPLE
# verify config sanity
self.require_keys(ldap_config, [
"uri",
"base",
"attributes",
])
self.ldap_uri = ldap_config["uri"]
self.ldap_start_tls = ldap_config.get("start_tls", False)
self.ldap_base = ldap_config["base"]
self.ldap_attributes = ldap_config["attributes"]
if "bind_dn" in ldap_config:
self.ldap_mode = LDAPMode.SEARCH
self.require_keys(ldap_config, [
"bind_dn",
"bind_password",
])
self.ldap_bind_dn = ldap_config["bind_dn"]
self.ldap_bind_password = ldap_config["bind_password"]
self.ldap_filter = ldap_config.get("filter", None)
# verify attribute lookup
self.require_keys(ldap_config['attributes'], [
"uid",
"name",
"mail",
])
def require_keys(self, config, required):
missing = [key for key in required if key not in config]
if missing:
raise ConfigError(
"LDAP enabled but missing required config values: {}".format(
", ".join(missing)
)
)
def default_config(self, **kwargs):
return """\
# ldap_config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
# start_tls: true
# base: "ou=users,dc=example,dc=com"
# attributes:
# uid: "cn"
# mail: "email"
# name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
"""

View file

@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
# Copyright 2016 Openmarket
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
import importlib
class PasswordAuthProviderConfig(Config):
def read_config(self, config):
self.password_providers = []
# We want to be backwards compatible with the old `ldap_config`
# param.
ldap_config = config.get("ldap_config", {})
self.ldap_enabled = ldap_config.get("enabled", False)
if self.ldap_enabled:
from synapse.util.ldap_auth_provider import LdapAuthProvider
parsed_config = LdapAuthProvider.parse_config(ldap_config)
self.password_providers.append((LdapAuthProvider, parsed_config))
providers = config.get("password_providers", [])
for provider in providers:
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
module, clz = provider['module'].rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
provider_config = provider_class.parse_config(provider["config"])
self.password_providers.append((provider_class, provider_config))
def default_config(self, **kwargs):
return """\
# password_providers:
# - module: "synapse.util.ldap_auth_provider.LdapAuthProvider"
# config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
# start_tls: true
# base: "ou=users,dc=example,dc=com"
# attributes:
# uid: "cn"
# mail: "email"
# name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
"""

View file

@ -19,6 +19,9 @@ from OpenSSL import crypto
import subprocess import subprocess
import os import os
from hashlib import sha256
from unpaddedbase64 import encode_base64
GENERATE_DH_PARAMS = False GENERATE_DH_PARAMS = False
@ -42,6 +45,19 @@ class TlsConfig(Config):
config.get("tls_dh_params_path"), "tls_dh_params" config.get("tls_dh_params_path"), "tls_dh_params"
) )
self.tls_fingerprints = config["tls_fingerprints"]
# Check that our own certificate is included in the list of fingerprints
# and include it if it is not.
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1,
self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
# This config option applies to non-federation HTTP clients # This config option applies to non-federation HTTP clients
# (e.g. for talking to recaptcha, identity servers, and such) # (e.g. for talking to recaptcha, identity servers, and such)
# It should never be used in production, and is intended for # It should never be used in production, and is intended for
@ -73,6 +89,28 @@ class TlsConfig(Config):
# Don't bind to the https port # Don't bind to the https port
no_tls: False no_tls: False
# List of allowed TLS fingerprints for this server to publish along
# with the signing keys for this server. Other matrix servers that
# make HTTPS requests to this server will check that the TLS
# certificates returned by this server match one of the fingerprints.
#
# Synapse automatically adds its the fingerprint of its own certificate
# to the list. So if federation traffic is handle directly by synapse
# then no modification to the list is required.
#
# If synapse is run behind a load balancer that handles the TLS then it
# will be necessary to add the fingerprints of the certificates used by
# the loadbalancers to this list if they are different to the one
# synapse is using.
#
# Homeservers are permitted to cache the list of TLS fingerprints
# returned in the key responses up to the "valid_until_ts" returned in
# key. It may be necessary to publish the fingerprints of a new
# certificate and wait until the "valid_until_ts" of the previous key
# responses have passed before deploying it.
tls_fingerprints: []
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
""" % locals() """ % locals()
def read_tls_certificate(self, cert_path): def read_tls_certificate(self, cert_path):

View file

@ -20,7 +20,6 @@ from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.config.ldap import LDAPMode
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -29,13 +28,6 @@ import bcrypt
import pymacaroons import pymacaroons
import simplejson import simplejson
try:
import ldap3
import ldap3.core.exceptions
except ImportError:
ldap3 = None
pass
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@ -59,23 +51,15 @@ class AuthHandler(BaseHandler):
} }
self.bcrypt_rounds = hs.config.bcrypt_rounds self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {} self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled account_handler = _AccountHandler(
if self.ldap_enabled: hs, check_user_exists=self.check_user_exists
if not ldap3:
raise RuntimeError(
'Missing ldap3 library. This is required for LDAP Authentication.'
) )
self.ldap_mode = hs.config.ldap_mode
self.ldap_uri = hs.config.ldap_uri self.password_providers = [
self.ldap_start_tls = hs.config.ldap_start_tls module(config=config, account_handler=account_handler)
self.ldap_base = hs.config.ldap_base for module, config in hs.config.password_providers
self.ldap_attributes = hs.config.ldap_attributes ]
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password
self.ldap_filter = hs.config.ldap_filter
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@ -149,13 +133,30 @@ class AuthHandler(BaseHandler):
creds = session['creds'] creds = session['creds']
# check auth type currently being presented # check auth type currently being presented
errordict = {}
if 'type' in authdict: if 'type' in authdict:
if authdict['type'] not in self.checkers: login_type = authdict['type']
if login_type not in self.checkers:
raise LoginError(400, "", Codes.UNRECOGNIZED) raise LoginError(400, "", Codes.UNRECOGNIZED)
result = yield self.checkers[authdict['type']](authdict, clientip) try:
result = yield self.checkers[login_type](authdict, clientip)
if result: if result:
creds[authdict['type']] = result creds[login_type] = result
self._save_session(session) self._save_session(session)
except LoginError, e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
# validation token (thus sending a new email) each time it
# got a 401 with a 'flows' field.
# (https://github.com/vector-im/vector-web/issues/2447).
#
# Grandfather in the old behaviour for now to avoid
# breaking old riot deployments.
raise e
# this step failed. Merge the error dict into the response
# so that the client can have another go.
errordict = e.error_dict()
for f in flows: for f in flows:
if len(set(f) - set(creds.keys())) == 0: if len(set(f) - set(creds.keys())) == 0:
@ -164,6 +165,7 @@ class AuthHandler(BaseHandler):
ret = self._auth_dict_for_flows(flows, session) ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys() ret['completed'] = creds.keys()
ret.update(errordict)
defer.returnValue((False, ret, clientdict, session['id'])) defer.returnValue((False, ret, clientdict, session['id']))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -431,37 +433,40 @@ class AuthHandler(BaseHandler):
defer.Deferred: (str) canonical_user_id, or None if zero or defer.Deferred: (str) canonical_user_id, or None if zero or
multiple matches multiple matches
""" """
try:
res = yield self._find_user_id_and_pwd_hash(user_id) res = yield self._find_user_id_and_pwd_hash(user_id)
if res is not None:
defer.returnValue(res[0]) defer.returnValue(res[0])
except LoginError:
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id): def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case """Checks to see if a user with the given id exists. Will check case
insensitively, but will throw if there are multiple inexact matches. insensitively, but will return None if there are multiple inexact
matches.
Returns: Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)` tuple: A 2-tuple of `(canonical_user_id, password_hash)`
None: if there is not exactly one match
""" """
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
result = None
if not user_infos: if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id) logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) elif len(user_infos) == 1:
# a single match (possibly not exact)
if len(user_infos) > 1: result = user_infos.popitem()
if user_id not in user_infos: elif user_id in user_infos:
# multiple matches, but one is exact
result = (user_id, user_infos[user_id])
else:
# multiple matches, none of them exact
logger.warn( logger.warn(
"Attempted to login as %s but it matches more than one user " "Attempted to login as %s but it matches more than one user "
"inexactly: %r", "inexactly: %r",
user_id, user_infos.keys() user_id, user_infos.keys()
) )
raise LoginError(403, "", errcode=Codes.FORBIDDEN) defer.returnValue(result)
defer.returnValue((user_id, user_infos[user_id]))
else:
defer.returnValue(user_infos.popitem())
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _check_password(self, user_id, password):
@ -475,305 +480,48 @@ class AuthHandler(BaseHandler):
Returns: Returns:
(str) the canonical_user_id (str) the canonical_user_id
Raises: Raises:
LoginError if the password was incorrect LoginError if login fails
""" """
valid_ldap = yield self._check_ldap_password(user_id, password) for provider in self.password_providers:
if valid_ldap: is_valid = yield provider.check_password(user_id, password)
if is_valid:
defer.returnValue(user_id) defer.returnValue(user_id)
result = yield self._check_local_password(user_id, password) canonical_user_id = yield self._check_local_password(user_id, password)
defer.returnValue(result)
if canonical_user_id:
defer.returnValue(canonical_user_id)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
# into a 401 anyway.
raise LoginError(
403, "Invalid password",
errcode=Codes.FORBIDDEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_local_password(self, user_id, password): def _check_local_password(self, user_id, password):
"""Authenticate a user against the local password database. """Authenticate a user against the local password database.
user_id is checked case insensitively, but will throw if there are user_id is checked case insensitively, but will return None if there are
multiple inexact matches. multiple inexact matches.
Args: Args:
user_id (str): complete @user:id user_id (str): complete @user:id
Returns: Returns:
(str) the canonical_user_id (str) the canonical_user_id, or None if unknown user / bad password
Raises:
LoginError if the password was incorrect
""" """
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
defer.returnValue(None)
(user_id, password_hash) = lookupres
result = self.validate_hash(password, password_hash) result = self.validate_hash(password, password_hash)
if not result: if not result:
logger.warn("Failed password login for user %s", user_id) logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) defer.returnValue(None)
defer.returnValue(user_id) defer.returnValue(user_id)
def _ldap_simple_bind(self, server, localpart, password):
""" Attempt a simple bind with the credentials
given by the user against the LDAP server.
Returns True, LDAP3Connection
if the bind was successful
Returns False, None
if an error occured
"""
try:
# bind with the the local users ldap credentials
bind_dn = "{prop}={value},{base}".format(
prop=self.ldap_attributes['uid'],
value=localpart,
base=self.ldap_base
)
conn = ldap3.Connection(server, bind_dn, password)
logger.debug(
"Established LDAP connection in simple bind mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in simple bind mode through StartTLS: %s",
conn
)
if conn.bind():
# GOOD: bind okay
logger.debug("LDAP Bind successful in simple bind mode.")
return True, conn
# BAD: bind failed
logger.info(
"Binding against LDAP failed for '%s' failed: %s",
localpart, conn.result['description']
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
def _ldap_authenticated_search(self, server, localpart, password):
""" Attempt to login with the preconfigured bind_dn
and then continue searching and filtering within
the base_dn
Returns (True, LDAP3Connection)
if a single matching DN within the base was found
that matched the filter expression, and with which
a successful bind was achieved
The LDAP3Connection returned is the instance that was used to
verify the password not the one using the configured bind_dn.
Returns (False, None)
if an error occured
"""
try:
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
)
logger.debug(
"Established LDAP connection in search mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in search mode through StartTLS: %s",
conn
)
if not conn.bind():
logger.warn(
"Binding against LDAP with `bind_dn` failed: %s",
conn.result['description']
)
conn.unbind()
return False, None
# construct search_filter like (uid=localpart)
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
# combine with the AND expression
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug(
"LDAP search filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query
)
if len(conn.response) == 1:
# GOOD: found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('LDAP search found dn: %s', user_dn)
# unbind and simple bind with user_dn to verify the password
# Note: do not use rebind(), for some reason it did not verify
# the password for me!
conn.unbind()
return self._ldap_simple_bind(server, localpart, password)
else:
# BAD: found 0 or > 1 results, abort!
if len(conn.response) == 0:
logger.info(
"LDAP search returned no results for '%s'",
localpart
)
else:
logger.info(
"LDAP search returned too many (%s) results for '%s'",
len(conn.response), localpart
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
""" Attempt to authenticate a user against an LDAP Server
and register an account if none exists.
Returns:
True if authentication against LDAP was successful
"""
if not ldap3 or not self.ldap_enabled:
defer.returnValue(False)
localpart = UserID.from_string(user_id).localpart
try:
server = ldap3.Server(self.ldap_uri)
logger.debug(
"Attempting LDAP connection with %s",
self.ldap_uri
)
if self.ldap_mode == LDAPMode.SIMPLE:
result, conn = self._ldap_simple_bind(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP authentication method simple bind returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
elif self.ldap_mode == LDAPMode.SEARCH:
result, conn = self._ldap_authenticated_search(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP auth method authenticated search returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
else:
raise RuntimeError(
'Invalid LDAP mode specified: {mode}'.format(
mode=self.ldap_mode
)
)
try:
logger.info(
"User authenticated against LDAP server: %s",
conn
)
except NameError:
logger.warn("Authentication method yielded no LDAP connection, aborting!")
defer.returnValue(False)
# check if user with user_id exists
if (yield self.check_user_exists(user_id)):
# exists, authentication complete
conn.unbind()
defer.returnValue(True)
else:
# does not exist, fetch metadata for account creation from
# existing ldap connection
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
query = "(&{filter}{user_filter})".format(
filter=query,
user_filter=self.ldap_filter
)
logger.debug(
"ldap registration filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query,
attributes=[
self.ldap_attributes['name'],
self.ldap_attributes['mail']
]
)
if len(conn.response) == 1:
attrs = conn.response[0]['attributes']
mail = attrs[self.ldap_attributes['mail']][0]
name = attrs[self.ldap_attributes['name']][0]
# create account
registration_handler = self.hs.get_handlers().registration_handler
user_id, access_token = (
yield registration_handler.register(localpart=localpart)
)
# TODO: bind email, set displayname with data from ldap directory
logger.info(
"Registration based on LDAP data was successful: %d: %s (%s, %)",
user_id,
localpart,
name,
mail
)
defer.returnValue(True)
else:
if len(conn.response) == 0:
logger.warn("LDAP registration failed, no result.")
else:
logger.warn(
"LDAP registration failed, too many results (%s)",
len(conn.response)
)
defer.returnValue(False)
defer.returnValue(False)
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None): def issue_access_token(self, user_id, device_id=None):
access_token = self.generate_access_token(user_id) access_token = self.generate_access_token(user_id)
@ -911,3 +659,30 @@ class AuthHandler(BaseHandler):
stored_hash.encode('utf-8')) == stored_hash stored_hash.encode('utf-8')) == stored_hash
else: else:
return False return False
class _AccountHandler(object):
"""A proxy object that gets passed to password auth providers so they
can register new users etc if necessary.
"""
def __init__(self, hs, check_user_exists):
self.hs = hs
self._check_user_exists = check_user_exists
def check_user_exists(self, user_id):
"""Check if user exissts.
Returns:
Deferred(bool)
"""
return self._check_user_exists(user_id)
def register(self, localpart):
"""Registers a new user with given localpart
Returns:
Deferred: a 2-tuple of (user_id, access_token)
"""
reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart)

View file

@ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request from synapse.http.server import request_handler, finish_request
from synapse.replication.pusher_resource import PusherResource from synapse.replication.pusher_resource import PusherResource
from synapse.replication.presence_resource import PresenceResource from synapse.replication.presence_resource import PresenceResource
from synapse.api.errors import SynapseError
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
@ -166,7 +167,8 @@ class ReplicationResource(Resource):
def replicate(): def replicate():
return self.replicate(request_streams, limit) return self.replicate(request_streams, limit)
result = yield self.notifier.wait_for_replication(replicate, timeout) writer = yield self.notifier.wait_for_replication(replicate, timeout)
result = writer.finish()
for stream_name, stream_content in result.items(): for stream_name, stream_content in result.items():
logger.info( logger.info(
@ -186,6 +188,9 @@ class ReplicationResource(Resource):
current_token = yield self.current_replication_token() current_token = yield self.current_replication_token()
logger.debug("Replicating up to %r", current_token) logger.debug("Replicating up to %r", current_token)
if limit == 0:
raise SynapseError(400, "Limit cannot be 0")
yield self.account_data(writer, current_token, limit, request_streams) yield self.account_data(writer, current_token, limit, request_streams)
yield self.events(writer, current_token, limit, request_streams) yield self.events(writer, current_token, limit, request_streams)
# TODO: implement limit # TODO: implement limit
@ -200,7 +205,7 @@ class ReplicationResource(Resource):
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
logger.debug("Replicated %d rows", writer.total) logger.debug("Replicated %d rows", writer.total)
defer.returnValue(writer.finish()) defer.returnValue(writer)
def streams(self, writer, current_token, request_streams): def streams(self, writer, current_token, request_streams):
request_token = request_streams.get("streams") request_token = request_streams.get("streams")
@ -237,27 +242,48 @@ class ReplicationResource(Resource):
request_events = current_token.events request_events = current_token.events
if request_backfill is None: if request_backfill is None:
request_backfill = current_token.backfill request_backfill = current_token.backfill
no_new_tokens = (
request_events == current_token.events
and request_backfill == current_token.backfill
)
if no_new_tokens:
return
res = yield self.store.get_all_new_events( res = yield self.store.get_all_new_events(
request_backfill, request_events, request_backfill, request_events,
current_token.backfill, current_token.events, current_token.backfill, current_token.events,
limit limit
) )
upto_events_token = _position_from_rows(
res.new_forward_events, current_token.events
)
upto_backfill_token = _position_from_rows(
res.new_backfill_events, current_token.backfill
)
if request_events != upto_events_token:
writer.write_header_and_rows("events", res.new_forward_events, ( writer.write_header_and_rows("events", res.new_forward_events, (
"position", "internal", "json", "state_group" "position", "internal", "json", "state_group"
)) ), position=upto_events_token)
if request_backfill != upto_backfill_token:
writer.write_header_and_rows("backfill", res.new_backfill_events, ( writer.write_header_and_rows("backfill", res.new_backfill_events, (
"position", "internal", "json", "state_group" "position", "internal", "json", "state_group",
)) ), position=upto_backfill_token)
writer.write_header_and_rows( writer.write_header_and_rows(
"forward_ex_outliers", res.forward_ex_outliers, "forward_ex_outliers", res.forward_ex_outliers,
("position", "event_id", "state_group") ("position", "event_id", "state_group"),
) )
writer.write_header_and_rows( writer.write_header_and_rows(
"backward_ex_outliers", res.backward_ex_outliers, "backward_ex_outliers", res.backward_ex_outliers,
("position", "event_id", "state_group") ("position", "event_id", "state_group"),
) )
writer.write_header_and_rows( writer.write_header_and_rows(
"state_resets", res.state_resets, ("position",) "state_resets", res.state_resets, ("position",),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -266,15 +292,16 @@ class ReplicationResource(Resource):
request_presence = request_streams.get("presence") request_presence = request_streams.get("presence")
if request_presence is not None: if request_presence is not None and request_presence != current_position:
presence_rows = yield self.presence_handler.get_all_presence_updates( presence_rows = yield self.presence_handler.get_all_presence_updates(
request_presence, current_position request_presence, current_position
) )
upto_token = _position_from_rows(presence_rows, current_position)
writer.write_header_and_rows("presence", presence_rows, ( writer.write_header_and_rows("presence", presence_rows, (
"position", "user_id", "state", "last_active_ts", "position", "user_id", "state", "last_active_ts",
"last_federation_update_ts", "last_user_sync_ts", "last_federation_update_ts", "last_user_sync_ts",
"status_msg", "currently_active", "status_msg", "currently_active",
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def typing(self, writer, current_token, request_streams): def typing(self, writer, current_token, request_streams):
@ -282,7 +309,7 @@ class ReplicationResource(Resource):
request_typing = request_streams.get("typing") request_typing = request_streams.get("typing")
if request_typing is not None: if request_typing is not None and request_typing != current_position:
# If they have a higher token than current max, we can assume that # If they have a higher token than current max, we can assume that
# they had been talking to a previous instance of the master. Since # they had been talking to a previous instance of the master. Since
# we reset the token on restart, the best (but hacky) thing we can # we reset the token on restart, the best (but hacky) thing we can
@ -293,9 +320,10 @@ class ReplicationResource(Resource):
typing_rows = yield self.typing_handler.get_all_typing_updates( typing_rows = yield self.typing_handler.get_all_typing_updates(
request_typing, current_position request_typing, current_position
) )
upto_token = _position_from_rows(typing_rows, current_position)
writer.write_header_and_rows("typing", typing_rows, ( writer.write_header_and_rows("typing", typing_rows, (
"position", "room_id", "typing" "position", "room_id", "typing"
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def receipts(self, writer, current_token, limit, request_streams): def receipts(self, writer, current_token, limit, request_streams):
@ -303,13 +331,14 @@ class ReplicationResource(Resource):
request_receipts = request_streams.get("receipts") request_receipts = request_streams.get("receipts")
if request_receipts is not None: if request_receipts is not None and request_receipts != current_position:
receipts_rows = yield self.store.get_all_updated_receipts( receipts_rows = yield self.store.get_all_updated_receipts(
request_receipts, current_position, limit request_receipts, current_position, limit
) )
upto_token = _position_from_rows(receipts_rows, current_position)
writer.write_header_and_rows("receipts", receipts_rows, ( writer.write_header_and_rows("receipts", receipts_rows, (
"position", "room_id", "receipt_type", "user_id", "event_id", "data" "position", "room_id", "receipt_type", "user_id", "event_id", "data"
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def account_data(self, writer, current_token, limit, request_streams): def account_data(self, writer, current_token, limit, request_streams):
@ -324,23 +353,36 @@ class ReplicationResource(Resource):
user_account_data = current_position user_account_data = current_position
if room_account_data is None: if room_account_data is None:
room_account_data = current_position room_account_data = current_position
no_new_tokens = (
user_account_data == current_position
and room_account_data == current_position
)
if no_new_tokens:
return
user_rows, room_rows = yield self.store.get_all_updated_account_data( user_rows, room_rows = yield self.store.get_all_updated_account_data(
user_account_data, room_account_data, current_position, limit user_account_data, room_account_data, current_position, limit
) )
upto_users_token = _position_from_rows(user_rows, current_position)
upto_rooms_token = _position_from_rows(room_rows, current_position)
writer.write_header_and_rows("user_account_data", user_rows, ( writer.write_header_and_rows("user_account_data", user_rows, (
"position", "user_id", "type", "content" "position", "user_id", "type", "content"
)) ), position=upto_users_token)
writer.write_header_and_rows("room_account_data", room_rows, ( writer.write_header_and_rows("room_account_data", room_rows, (
"position", "user_id", "room_id", "type", "content" "position", "user_id", "room_id", "type", "content"
)) ), position=upto_rooms_token)
if tag_account_data is not None: if tag_account_data is not None:
tag_rows = yield self.store.get_all_updated_tags( tag_rows = yield self.store.get_all_updated_tags(
tag_account_data, current_position, limit tag_account_data, current_position, limit
) )
upto_tag_token = _position_from_rows(tag_rows, current_position)
writer.write_header_and_rows("tag_account_data", tag_rows, ( writer.write_header_and_rows("tag_account_data", tag_rows, (
"position", "user_id", "room_id", "tags" "position", "user_id", "room_id", "tags"
)) ), position=upto_tag_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_rules(self, writer, current_token, limit, request_streams): def push_rules(self, writer, current_token, limit, request_streams):
@ -348,14 +390,15 @@ class ReplicationResource(Resource):
push_rules = request_streams.get("push_rules") push_rules = request_streams.get("push_rules")
if push_rules is not None: if push_rules is not None and push_rules != current_position:
rows = yield self.store.get_all_push_rule_updates( rows = yield self.store.get_all_push_rule_updates(
push_rules, current_position, limit push_rules, current_position, limit
) )
upto_token = _position_from_rows(rows, current_position)
writer.write_header_and_rows("push_rules", rows, ( writer.write_header_and_rows("push_rules", rows, (
"position", "event_stream_ordering", "user_id", "rule_id", "op", "position", "event_stream_ordering", "user_id", "rule_id", "op",
"priority_class", "priority", "conditions", "actions" "priority_class", "priority", "conditions", "actions"
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def pushers(self, writer, current_token, limit, request_streams): def pushers(self, writer, current_token, limit, request_streams):
@ -363,18 +406,19 @@ class ReplicationResource(Resource):
pushers = request_streams.get("pushers") pushers = request_streams.get("pushers")
if pushers is not None: if pushers is not None and pushers != current_position:
updated, deleted = yield self.store.get_all_updated_pushers( updated, deleted = yield self.store.get_all_updated_pushers(
pushers, current_position, limit pushers, current_position, limit
) )
upto_token = _position_from_rows(updated, current_position)
writer.write_header_and_rows("pushers", updated, ( writer.write_header_and_rows("pushers", updated, (
"position", "user_id", "access_token", "profile_tag", "kind", "position", "user_id", "access_token", "profile_tag", "kind",
"app_id", "app_display_name", "device_display_name", "pushkey", "app_id", "app_display_name", "device_display_name", "pushkey",
"ts", "lang", "data" "ts", "lang", "data"
)) ), position=upto_token)
writer.write_header_and_rows("deleted_pushers", deleted, ( writer.write_header_and_rows("deleted_pushers", deleted, (
"position", "user_id", "app_id", "pushkey" "position", "user_id", "app_id", "pushkey"
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams): def caches(self, writer, current_token, limit, request_streams):
@ -382,13 +426,14 @@ class ReplicationResource(Resource):
caches = request_streams.get("caches") caches = request_streams.get("caches")
if caches is not None: if caches is not None and caches != current_position:
updated_caches = yield self.store.get_all_updated_caches( updated_caches = yield self.store.get_all_updated_caches(
caches, current_position, limit caches, current_position, limit
) )
upto_token = _position_from_rows(updated_caches, current_position)
writer.write_header_and_rows("caches", updated_caches, ( writer.write_header_and_rows("caches", updated_caches, (
"position", "cache_func", "keys", "invalidation_ts" "position", "cache_func", "keys", "invalidation_ts"
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def to_device(self, writer, current_token, limit, request_streams): def to_device(self, writer, current_token, limit, request_streams):
@ -396,13 +441,14 @@ class ReplicationResource(Resource):
to_device = request_streams.get("to_device") to_device = request_streams.get("to_device")
if to_device is not None: if to_device is not None and to_device != current_position:
to_device_rows = yield self.store.get_all_new_device_messages( to_device_rows = yield self.store.get_all_new_device_messages(
to_device, current_position, limit to_device, current_position, limit
) )
upto_token = _position_from_rows(to_device_rows, current_position)
writer.write_header_and_rows("to_device", to_device_rows, ( writer.write_header_and_rows("to_device", to_device_rows, (
"position", "user_id", "device_id", "message_json" "position", "user_id", "device_id", "message_json"
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def public_rooms(self, writer, current_token, limit, request_streams): def public_rooms(self, writer, current_token, limit, request_streams):
@ -410,13 +456,14 @@ class ReplicationResource(Resource):
public_rooms = request_streams.get("public_rooms") public_rooms = request_streams.get("public_rooms")
if public_rooms is not None: if public_rooms is not None and public_rooms != current_position:
public_rooms_rows = yield self.store.get_all_new_public_rooms( public_rooms_rows = yield self.store.get_all_new_public_rooms(
public_rooms, current_position, limit public_rooms, current_position, limit
) )
upto_token = _position_from_rows(public_rooms_rows, current_position)
writer.write_header_and_rows("public_rooms", public_rooms_rows, ( writer.write_header_and_rows("public_rooms", public_rooms_rows, (
"position", "room_id", "visibility" "position", "room_id", "visibility"
)) ), position=upto_token)
class _Writer(object): class _Writer(object):
@ -426,11 +473,11 @@ class _Writer(object):
self.total = 0 self.total = 0
def write_header_and_rows(self, name, rows, fields, position=None): def write_header_and_rows(self, name, rows, fields, position=None):
if not rows:
return
if position is None: if position is None:
if rows:
position = rows[-1][0] position = rows[-1][0]
else:
return
self.streams[name] = { self.streams[name] = {
"position": position if type(position) is int else str(position), "position": position if type(position) is int else str(position),
@ -440,6 +487,9 @@ class _Writer(object):
self.total += len(rows) self.total += len(rows)
def __nonzero__(self):
return bool(self.total)
def finish(self): def finish(self):
return self.streams return self.streams
@ -461,3 +511,20 @@ class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
def __str__(self): def __str__(self):
return "_".join(str(value) for value in self) return "_".join(str(value) for value in self)
def _position_from_rows(rows, current_position):
"""Calculates a position to return for a stream. Ideally we want to return the
position of the last row, as that will be the most correct. However, if there
are no rows we fall back to using the current position to stop us from
repeatedly hitting the storage layer unncessarily thinking there are updates.
(Not all advances of the token correspond to an actual update)
We can't just always return the current position, as we often limit the
number of rows we replicate, and so the stream may lag. The assumption is
that if the storage layer returns no new rows then we are not lagging and
we are at the `current_position`.
"""
if rows:
return rows[-1][0]
return current_position

View file

@ -77,8 +77,10 @@ SUCCESS_TEMPLATE = """
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> <link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script> <script>
if (window.onAuthDone != undefined) { if (window.onAuthDone) {
window.onAuthDone(); window.onAuthDone();
} else if (window.opener && window.opener.postMessage) {
window.opener.postMessage("authDone", "*");
} }
</script> </script>
</head> </head>

View file

@ -17,6 +17,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api import constants, errors
from synapse.http import servlet from synapse.http import servlet
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -58,6 +59,7 @@ class DeviceRestServlet(servlet.RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, device_id): def on_GET(self, request, device_id):
@ -70,11 +72,24 @@ class DeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, device_id): def on_DELETE(self, request, device_id):
# XXX: it's not completely obvious we want to expose this endpoint. try:
# It allows the client to delete access tokens, which feels like a body = servlet.parse_json_object_from_request(request)
# thing which merits extra auth. But if we want to do the interactive-
# auth dance, we should really make it possible to delete more than one except errors.SynapseError as e:
# device at a time. if e.errcode == errors.Codes.NOT_JSON:
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
body = {}
else:
raise
authed, result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device( yield self.device_handler.delete_device(
requester.user.to_string(), requester.user.to_string(),

View file

@ -19,8 +19,6 @@ from synapse.http.server import respond_with_json_bytes
from signedjson.sign import sign_json from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from hashlib import sha256
from OpenSSL import crypto
import logging import logging
@ -48,8 +46,12 @@ class LocalKey(Resource):
"expired_ts": # integer posix timestamp when the key expired. "expired_ts": # integer posix timestamp when the key expired.
"key": # base64 encoded NACL verification key. "key": # base64 encoded NACL verification key.
} }
} },
"tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert. "tls_fingerprints": [ # Fingerprints of the TLS certs this server uses.
{
"sha256": # base64 encoded sha256 fingerprint of the X509 cert
},
],
"signatures": { "signatures": {
"this.server.example.com": { "this.server.example.com": {
"algorithm:version": # NACL signature for this server "algorithm:version": # NACL signature for this server
@ -90,21 +92,14 @@ class LocalKey(Resource):
u"expired_ts": key.expired, u"expired_ts": key.expired,
} }
x509_certificate_bytes = crypto.dump_certificate( tls_fingerprints = self.config.tls_fingerprints
crypto.FILETYPE_ASN1,
self.config.tls_certificate
)
sha256_fingerprint = sha256(x509_certificate_bytes).digest()
json_object = { json_object = {
u"valid_until_ts": self.valid_until_ts, u"valid_until_ts": self.valid_until_ts,
u"server_name": self.config.server_name, u"server_name": self.config.server_name,
u"verify_keys": verify_keys, u"verify_keys": verify_keys,
u"old_verify_keys": old_verify_keys, u"old_verify_keys": old_verify_keys,
u"tls_fingerprints": [{ u"tls_fingerprints": tls_fingerprints,
u"sha256": encode_base64(sha256_fingerprint),
}]
} }
for key in self.config.signing_key: for key in self.config.signing_key:
json_object = sign_json( json_object = sign_json(

View file

@ -320,6 +320,9 @@ class RoomStore(SQLBaseStore):
txn.execute(sql, (prev_id, current_id, limit,)) txn.execute(sql, (prev_id, current_id, limit,))
return txn.fetchall() return txn.fetchall()
if prev_id == current_id:
return defer.succeed([])
return self.runInteraction( return self.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms "get_all_new_public_rooms", get_all_new_public_rooms
) )

View file

@ -0,0 +1,368 @@
from twisted.internet import defer
from synapse.config._base import ConfigError
from synapse.types import UserID
import ldap3
import ldap3.core.exceptions
import logging
try:
import ldap3
import ldap3.core.exceptions
except ImportError:
ldap3 = None
pass
logger = logging.getLogger(__name__)
class LDAPMode(object):
SIMPLE = "simple",
SEARCH = "search",
LIST = (SIMPLE, SEARCH)
class LdapAuthProvider(object):
__version__ = "0.1"
def __init__(self, config, account_handler):
self.account_handler = account_handler
if not ldap3:
raise RuntimeError(
'Missing ldap3 library. This is required for LDAP Authentication.'
)
self.ldap_mode = config.mode
self.ldap_uri = config.uri
self.ldap_start_tls = config.start_tls
self.ldap_base = config.base
self.ldap_attributes = config.attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = config.bind_dn
self.ldap_bind_password = config.bind_password
self.ldap_filter = config.filter
@defer.inlineCallbacks
def check_password(self, user_id, password):
""" Attempt to authenticate a user against an LDAP Server
and register an account if none exists.
Returns:
True if authentication against LDAP was successful
"""
localpart = UserID.from_string(user_id).localpart
try:
server = ldap3.Server(self.ldap_uri)
logger.debug(
"Attempting LDAP connection with %s",
self.ldap_uri
)
if self.ldap_mode == LDAPMode.SIMPLE:
result, conn = self._ldap_simple_bind(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP authentication method simple bind returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
elif self.ldap_mode == LDAPMode.SEARCH:
result, conn = self._ldap_authenticated_search(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP auth method authenticated search returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
else:
raise RuntimeError(
'Invalid LDAP mode specified: {mode}'.format(
mode=self.ldap_mode
)
)
try:
logger.info(
"User authenticated against LDAP server: %s",
conn
)
except NameError:
logger.warn(
"Authentication method yielded no LDAP connection, aborting!"
)
defer.returnValue(False)
# check if user with user_id exists
if (yield self.account_handler.check_user_exists(user_id)):
# exists, authentication complete
conn.unbind()
defer.returnValue(True)
else:
# does not exist, fetch metadata for account creation from
# existing ldap connection
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
query = "(&{filter}{user_filter})".format(
filter=query,
user_filter=self.ldap_filter
)
logger.debug(
"ldap registration filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query,
attributes=[
self.ldap_attributes['name'],
self.ldap_attributes['mail']
]
)
if len(conn.response) == 1:
attrs = conn.response[0]['attributes']
mail = attrs[self.ldap_attributes['mail']][0]
name = attrs[self.ldap_attributes['name']][0]
# create account
user_id, access_token = (
yield self.account_handler.register(localpart=localpart)
)
# TODO: bind email, set displayname with data from ldap directory
logger.info(
"Registration based on LDAP data was successful: %d: %s (%s, %)",
user_id,
localpart,
name,
mail
)
defer.returnValue(True)
else:
if len(conn.response) == 0:
logger.warn("LDAP registration failed, no result.")
else:
logger.warn(
"LDAP registration failed, too many results (%s)",
len(conn.response)
)
defer.returnValue(False)
defer.returnValue(False)
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
@staticmethod
def parse_config(config):
class _LdapConfig(object):
pass
ldap_config = _LdapConfig()
ldap_config.enabled = config.get("enabled", False)
ldap_config.mode = LDAPMode.SIMPLE
# verify config sanity
_require_keys(config, [
"uri",
"base",
"attributes",
])
ldap_config.uri = config["uri"]
ldap_config.start_tls = config.get("start_tls", False)
ldap_config.base = config["base"]
ldap_config.attributes = config["attributes"]
if "bind_dn" in config:
ldap_config.mode = LDAPMode.SEARCH
_require_keys(config, [
"bind_dn",
"bind_password",
])
ldap_config.bind_dn = config["bind_dn"]
ldap_config.bind_password = config["bind_password"]
ldap_config.filter = config.get("filter", None)
# verify attribute lookup
_require_keys(config['attributes'], [
"uid",
"name",
"mail",
])
return ldap_config
def _ldap_simple_bind(self, server, localpart, password):
""" Attempt a simple bind with the credentials
given by the user against the LDAP server.
Returns True, LDAP3Connection
if the bind was successful
Returns False, None
if an error occured
"""
try:
# bind with the the local users ldap credentials
bind_dn = "{prop}={value},{base}".format(
prop=self.ldap_attributes['uid'],
value=localpart,
base=self.ldap_base
)
conn = ldap3.Connection(server, bind_dn, password)
logger.debug(
"Established LDAP connection in simple bind mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in simple bind mode through StartTLS: %s",
conn
)
if conn.bind():
# GOOD: bind okay
logger.debug("LDAP Bind successful in simple bind mode.")
return True, conn
# BAD: bind failed
logger.info(
"Binding against LDAP failed for '%s' failed: %s",
localpart, conn.result['description']
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
def _ldap_authenticated_search(self, server, localpart, password):
""" Attempt to login with the preconfigured bind_dn
and then continue searching and filtering within
the base_dn
Returns (True, LDAP3Connection)
if a single matching DN within the base was found
that matched the filter expression, and with which
a successful bind was achieved
The LDAP3Connection returned is the instance that was used to
verify the password not the one using the configured bind_dn.
Returns (False, None)
if an error occured
"""
try:
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
)
logger.debug(
"Established LDAP connection in search mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in search mode through StartTLS: %s",
conn
)
if not conn.bind():
logger.warn(
"Binding against LDAP with `bind_dn` failed: %s",
conn.result['description']
)
conn.unbind()
return False, None
# construct search_filter like (uid=localpart)
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
# combine with the AND expression
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug(
"LDAP search filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query
)
if len(conn.response) == 1:
# GOOD: found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('LDAP search found dn: %s', user_dn)
# unbind and simple bind with user_dn to verify the password
# Note: do not use rebind(), for some reason it did not verify
# the password for me!
conn.unbind()
return self._ldap_simple_bind(server, localpart, password)
else:
# BAD: found 0 or > 1 results, abort!
if len(conn.response) == 0:
logger.info(
"LDAP search returned no results for '%s'",
localpart
)
else:
logger.info(
"LDAP search returned too many (%s) results for '%s'",
len(conn.response), localpart
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
def _require_keys(config, required):
missing = [key for key in required if key not in config]
if missing:
raise ConfigError(
"LDAP enabled but missing required config values: {}".format(
", ".join(missing)
)
)

View file

@ -42,7 +42,8 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def replicate(self): def replicate(self):
streams = self.slaved_store.stream_positions() streams = self.slaved_store.stream_positions()
result = yield self.replication.replicate(streams, 100) writer = yield self.replication.replicate(streams, 100)
result = writer.finish()
yield self.slaved_store.process_replication(result) yield self.slaved_store.process_replication(result)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -120,7 +120,7 @@ class ReplicationResourceCase(unittest.TestCase):
self.hs.clock.advance_time_msec(1) self.hs.clock.advance_time_msec(1)
code, body = yield get code, body = yield get
self.assertEquals(code, 200) self.assertEquals(code, 200)
self.assertEquals(body, {}) self.assertEquals(body.get("rows", []), [])
test_timeout.__name__ = "test_timeout_%s" % (stream) test_timeout.__name__ = "test_timeout_%s" % (stream)
return test_timeout return test_timeout
@ -195,7 +195,6 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertIn("field_names", stream) self.assertIn("field_names", stream)
field_names = stream["field_names"] field_names = stream["field_names"]
self.assertIn("rows", stream) self.assertIn("rows", stream)
self.assertTrue(stream["rows"])
for row in stream["rows"]: for row in stream["rows"]:
self.assertEquals( self.assertEquals(
len(row), len(field_names), len(row), len(field_names),

View file

@ -37,6 +37,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
config = Mock( config = Mock(
app_service_config_files=self.as_yaml_files, app_service_config_files=self.as_yaml_files,
event_cache_size=1, event_cache_size=1,
password_providers=[],
) )
hs = yield setup_test_homeserver(config=config) hs = yield setup_test_homeserver(config=config)
@ -109,6 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
config = Mock( config = Mock(
app_service_config_files=self.as_yaml_files, app_service_config_files=self.as_yaml_files,
event_cache_size=1, event_cache_size=1,
password_providers=[],
) )
hs = yield setup_test_homeserver(config=config) hs = yield setup_test_homeserver(config=config)
self.db_pool = hs.get_db_pool() self.db_pool = hs.get_db_pool()
@ -437,7 +439,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(suffix="1") f1 = self._write_config(suffix="1")
f2 = self._write_config(suffix="2") f2 = self._write_config(suffix="2")
config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1,
password_providers=[]
)
hs = yield setup_test_homeserver(config=config, datastore=Mock()) hs = yield setup_test_homeserver(config=config, datastore=Mock())
ApplicationServiceStore(hs) ApplicationServiceStore(hs)
@ -447,7 +452,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(id="id", suffix="1") f1 = self._write_config(id="id", suffix="1")
f2 = self._write_config(id="id", suffix="2") f2 = self._write_config(id="id", suffix="2")
config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1,
password_providers=[]
)
hs = yield setup_test_homeserver(config=config, datastore=Mock()) hs = yield setup_test_homeserver(config=config, datastore=Mock())
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
@ -463,7 +471,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(as_token="as_token", suffix="1") f1 = self._write_config(as_token="as_token", suffix="1")
f2 = self._write_config(as_token="as_token", suffix="2") f2 = self._write_config(as_token="as_token", suffix="2")
config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1,
password_providers=[]
)
hs = yield setup_test_homeserver(config=config, datastore=Mock()) hs = yield setup_test_homeserver(config=config, datastore=Mock())
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:

View file

@ -52,6 +52,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.server_name = name config.server_name = name
config.trusted_third_party_id_servers = [] config.trusted_third_party_id_servers = []
config.room_invite_state_types = [] config.room_invite_state_types = []
config.password_providers = []
config.use_frozen_dicts = True config.use_frozen_dicts = True
config.database_config = {"name": "sqlite3"} config.database_config = {"name": "sqlite3"}