From b5f55a1d85386834b18828b196e1859a4410d98a Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 26 Jun 2015 09:52:24 +0100
Subject: [PATCH 01/23] Implement bulk verify_signed_json API

---
 synapse/crypto/keyring.py               | 426 ++++++++++++++++--------
 synapse/federation/federation_base.py   | 121 ++++---
 synapse/federation/federation_client.py |  57 +++-
 synapse/storage/keys.py                 |  50 +--
 4 files changed, 439 insertions(+), 215 deletions(-)

diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index aff69c5f8..873c9b40f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -25,11 +25,11 @@ from syutil.base64util import decode_base64, encode_base64

 from synapse.api.errors import SynapseError, Codes

 from synapse.util.retryutils import get_retry_limiter
-
-from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError

 from OpenSSL import crypto

+from collections import namedtuple
 import urllib
 import hashlib
 import logging
@@ -38,6 +38,9 @@ import logging
 logger = logging.getLogger(__name__)


+KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+
+
 class Keyring(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -49,141 +52,274 @@ class Keyring(object):

         self.key_downloads = {}

-    @defer.inlineCallbacks
     def verify_json_for_server(self, server_name, json_object):
-        logger.debug("Verifying for %s", server_name)
-        key_ids = signature_ids(json_object, server_name)
-        if not key_ids:
-            raise SynapseError(
-                400,
-                "Not signed with a supported algorithm",
-                Codes.UNAUTHORIZED,
-            )
-        try:
-            verify_key = yield self.get_server_verify_key(server_name, key_ids)
-        except IOError as e:
-            logger.warn(
-                "Got IOError when downloading keys for %s: %s %s",
-                server_name, type(e).__name__, str(e.message),
-            )
-            raise SynapseError(
-                502,
-                "Error downloading keys for %s" % (server_name,),
-                Codes.UNAUTHORIZED,
-            )
-        except Exception as e:
-            logger.warn(
-                "Got Exception when downloading keys for %s: %s %s",
-                server_name, type(e).__name__, str(e.message),
-            )
-            raise SynapseError(
-                401,
-                "No key for %s with id %s" % (server_name, key_ids),
-                Codes.UNAUTHORIZED,
-            )
+        return self.verify_json_objects_for_server(
+            [(server_name, json_object)]
+        )[0]

-        try:
-            verify_signed_json(json_object, server_name, verify_key)
-        except:
-            raise SynapseError(
-                401,
-                "Invalid signature for server %s with key %s:%s" % (
-                    server_name, verify_key.alg, verify_key.version
-                ),
-                Codes.UNAUTHORIZED,
-            )
+    def verify_json_objects_for_server(self, server_and_json):
+        """Bulk verifies signatures of json objects, bulk fetching keys as
+        necessary.

-    @defer.inlineCallbacks
-    def get_server_verify_key(self, server_name, key_ids):
-        """Finds a verification key for the server with one of the key ids.
-        Trys to fetch the key from a trusted perspective server first.
         Args:
-            server_name(str): The name of the server to fetch a key for.
-            keys_ids (list of str): The key_ids to check for.
+            server_and_json (list): List of pairs of (server_name, json_object)
+
+        Returns:
+            list of deferreds indicating success or failure to verify each
+            json object's signature for the given server_name.
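+
+        Example (illustrative sketch; the server name and json objects
+        are hypothetical):
+
+            deferreds = keyring.verify_json_objects_for_server([
+                ("example.org", signed_json_1),
+                ("example.org", signed_json_2),
+            ])
+            # Each deferred fires once the matching object has been
+            # verified, or errbacks with a SynapseError if no usable key
+            # could be fetched or the signature is invalid.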
""" - cached = yield self.store.get_server_verify_keys(server_name, key_ids) + group_id_to_json = {} + group_id_to_group = {} + group_ids = [] - if cached: - defer.returnValue(cached[0]) - return + next_group_id = 0 + deferreds = {} - download = self.key_downloads.get(server_name) + for server_name, json_object in server_and_json: + logger.debug("Verifying for %s", server_name) + group_id = next_group_id + next_group_id += 1 + group_ids.append(group_id) - if download is None: - download = self._get_server_verify_key_impl(server_name, key_ids) - download = ObservableDeferred( - download, - consumeErrors=True + key_ids = signature_ids(json_object, server_name) + if not key_ids: + deferreds[group_id] = defer.fail(SynapseError( + 400, + "Not signed with a supported algorithm", + Codes.UNAUTHORIZED, + )) + + group = KeyGroup(server_name, group_id, key_ids) + + group_id_to_group[group_id] = group + group_id_to_json[group_id] = json_object + + @defer.inlineCallbacks + def handle_key_deferred(group, deferred): + server_name = group.server_name + try: + _, _, key_id, verify_key = yield deferred + except IOError as e: + logger.warn( + "Got IOError when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e.message), + ) + raise SynapseError( + 502, + "Error downloading keys for %s" % (server_name,), + Codes.UNAUTHORIZED, + ) + except Exception as e: + logger.exception( + "Got Exception when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e.message), + ) + raise SynapseError( + 401, + "No key for %s with id %s" % (server_name, key_ids), + Codes.UNAUTHORIZED, + ) + + json_object = group_id_to_json[group.group_id] + + try: + verify_signed_json(json_object, server_name, verify_key) + except: + raise SynapseError( + 401, + "Invalid signature for server %s with key %s:%s" % ( + server_name, verify_key.alg, verify_key.version + ), + Codes.UNAUTHORIZED, + ) + + deferreds.update(self.get_server_verify_keys( + group_id_to_group + )) + + return [ + handle_key_deferred( + group_id_to_group[g_id], + deferreds[g_id], ) - self.key_downloads[server_name] = download + for g_id in group_ids + ] - @download.addBoth - def callback(ret): - del self.key_downloads[server_name] - return ret + def get_server_verify_keys(self, group_id_to_group): + """Takes a dict of KeyGroups and tries to find at least one key for + each group. 
+ """ - r = yield download.observe() - defer.returnValue(r) + # These are functions that produce keys given a list of key ids + key_fetch_fns = ( + self.get_keys_from_store, # First try the local store + self.get_keys_from_perspectives, # Then try via perspectives + self.get_keys_from_server, # Then try directly + ) + + group_deferreds = { + group_id: defer.Deferred() + for group_id in group_id_to_group + } + + @defer.inlineCallbacks + def do_iterations(): + merged_results = {} + + missing_keys = { + group.server_name: key_id + for group in group_id_to_group.values() + for key_id in group.key_ids + } + + for fn in key_fetch_fns: + results = yield fn(missing_keys.items()) + merged_results.update(results) + + # We now need to figure out which groups we have keys for + # and which we don't + missing_groups = {} + for group in group_id_to_group.values(): + for key_id in group.key_ids: + if key_id in merged_results[group.server_name]: + group_deferreds.pop(group.group_id).callback(( + group.group_id, + group.server_name, + key_id, + merged_results[group.server_name][key_id], + )) + break + else: + missing_groups.setdefault( + group.server_name, [] + ).append(group) + + if not missing_groups: + break + + missing_keys = { + server_name: set( + key_id for group in groups for key_id in group.key_ids + ) + for server_name, groups in missing_groups.items() + } + + for group in missing_groups.values(): + group_deferreds.pop(group.group_id).errback(SynapseError( + 401, + "No key for %s with id %s" % ( + group.server_name, group.key_ids, + ), + Codes.UNAUTHORIZED, + )) + + def on_err(err): + for deferred in group_deferreds.values(): + deferred.errback(err) + group_deferreds.clear() + + do_iterations().addErrback(on_err) + + return group_deferreds @defer.inlineCallbacks - def _get_server_verify_key_impl(self, server_name, key_ids): - keys = None + def get_keys_from_store(self, server_name_and_key_ids): + res = yield defer.gatherResults( + [ + self.store.get_server_verify_keys(server_name, key_ids) + for server_name, key_ids in server_name_and_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + defer.returnValue(dict(zip( + [server_name for server_name, _ in server_name_and_key_ids], + res + ))) + + @defer.inlineCallbacks + def get_keys_from_perspectives(self, server_name_and_key_ids): @defer.inlineCallbacks def get_key(perspective_name, perspective_keys): try: result = yield self.get_server_verify_key_v2_indirect( - server_name, key_ids, perspective_name, perspective_keys + server_name_and_key_ids, perspective_name, perspective_keys ) defer.returnValue(result) except Exception as e: - logging.info( - "Unable to getting key %r for %r from %r: %s %s", - key_ids, server_name, perspective_name, + logger.exception( + "Unable to get key from %r: %s %s", + perspective_name, type(e).__name__, str(e.message), ) + defer.returnValue({}) - perspective_results = yield defer.gatherResults([ - get_key(p_name, p_keys) - for p_name, p_keys in self.perspective_servers.items() - ]) + results = yield defer.gatherResults( + [ + get_key(p_name, p_keys) + for p_name, p_keys in self.perspective_servers.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) - for results in perspective_results: - if results is not None: - keys = results + union_of_keys = {} + for result in results: + for server_name, keys in result.items(): + union_of_keys.setdefault(server_name, {}).update(keys) - limiter = yield get_retry_limiter( - server_name, - self.clock, - self.store, - ) + defer.returnValue(union_of_keys) - 
with limiter: - if not keys: + @defer.inlineCallbacks + def get_keys_from_server(self, server_name_and_key_ids): + @defer.inlineCallbacks + def get_key(server_name, key_ids): + limiter = yield get_retry_limiter( + server_name, + self.clock, + self.store, + ) + with limiter: + keys = None try: keys = yield self.get_server_verify_key_v2_direct( server_name, key_ids ) except Exception as e: - logging.info( + logger.info( "Unable to getting key %r for %r directly: %s %s", key_ids, server_name, type(e).__name__, str(e.message), ) - if not keys: - keys = yield self.get_server_verify_key_v1_direct( - server_name, key_ids - ) + if not keys: + keys = yield self.get_server_verify_key_v1_direct( + server_name, key_ids + ) - for key_id in key_ids: - if key_id in keys: - defer.returnValue(keys[key_id]) - return - raise ValueError("No verification key found for given key ids") + keys = {server_name: keys} + + defer.returnValue(keys) + + results = yield defer.gatherResults( + [ + get_key(server_name, key_ids) + for server_name, key_ids in server_name_and_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + merged = {} + for result in results: + merged.update(result) + + defer.returnValue({ + server_name: keys + for server_name, keys in merged.items() + if keys + }) @defer.inlineCallbacks - def get_server_verify_key_v2_indirect(self, server_name, key_ids, + def get_server_verify_key_v2_indirect(self, server_names_and_key_ids, perspective_name, perspective_keys): limiter = yield get_retry_limiter( @@ -204,6 +340,7 @@ class Keyring(object): u"minimum_valid_until_ts": 0 } for key_id in key_ids } + for server_name, key_ids in server_names_and_key_ids } }, ) @@ -243,23 +380,29 @@ class Keyring(object): " server %r" % (perspective_name,) ) - response_keys = yield self.process_v2_response( - server_name, perspective_name, response + processed_response = yield self.process_v2_response( + perspective_name, response ) - keys.update(response_keys) + for server_name, response_keys in processed_response.items(): + keys.setdefault(server_name, {}).update(response_keys) - yield self.store_keys( - server_name=server_name, - from_server=perspective_name, - verify_keys=keys, - ) + yield defer.gatherResults( + [ + self.store_keys( + server_name=server_name, + from_server=perspective_name, + verify_keys=response_keys, + ) + for server_name, response_keys in keys.items() + ], + consumeErrors=True + ).addErrback(unwrapFirstError) defer.returnValue(keys) @defer.inlineCallbacks def get_server_verify_key_v2_direct(self, server_name, key_ids): - keys = {} for requested_key_id in key_ids: @@ -295,25 +438,30 @@ class Keyring(object): raise ValueError("TLS certificate not allowed by fingerprints") response_keys = yield self.process_v2_response( - server_name=server_name, from_server=server_name, - requested_id=requested_key_id, + requested_ids=[requested_key_id], response_json=response, ) keys.update(response_keys) - yield self.store_keys( - server_name=server_name, - from_server=server_name, - verify_keys=keys, - ) + yield defer.gatherResults( + [ + self.store_keys( + server_name=key_server_name, + from_server=server_name, + verify_keys=verify_keys, + ) + for key_server_name, verify_keys in keys.items() + ], + consumeErrors=True + ).addErrback(unwrapFirstError) defer.returnValue(keys) @defer.inlineCallbacks - def process_v2_response(self, server_name, from_server, response_json, - requested_id=None): + def process_v2_response(self, from_server, response_json, + requested_ids=[]): time_now_ms = 
self.clock.time_msec() response_keys = {} verify_keys = {} @@ -335,6 +483,8 @@ class Keyring(object): verify_key.time_added = time_now_ms old_verify_keys[key_id] = verify_key + results = {} + server_name = response_json["server_name"] for key_id in response_json["signatures"].get(server_name, {}): if key_id not in response_json["verify_keys"]: raise ValueError( @@ -357,28 +507,31 @@ class Keyring(object): signed_key_json_bytes = encode_canonical_json(signed_key_json) ts_valid_until_ms = signed_key_json[u"valid_until_ts"] - updated_key_ids = set() - if requested_id is not None: - updated_key_ids.add(requested_id) + updated_key_ids = set(requested_ids) updated_key_ids.update(verify_keys) updated_key_ids.update(old_verify_keys) response_keys.update(verify_keys) response_keys.update(old_verify_keys) - for key_id in updated_key_ids: - yield self.store.store_server_keys_json( - server_name=server_name, - key_id=key_id, - from_server=server_name, - ts_now_ms=time_now_ms, - ts_expires_ms=ts_valid_until_ms, - key_json_bytes=signed_key_json_bytes, - ) + yield defer.gatherResults( + [ + self.store.store_server_keys_json( + server_name=server_name, + key_id=key_id, + from_server=server_name, + ts_now_ms=time_now_ms, + ts_expires_ms=ts_valid_until_ms, + key_json_bytes=signed_key_json_bytes, + ) + for key_id in updated_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) - defer.returnValue(response_keys) + results[server_name] = response_keys - raise ValueError("No verification key found for given key ids") + defer.returnValue(results) @defer.inlineCallbacks def get_server_verify_key_v1_direct(self, server_name, key_ids): @@ -462,8 +615,13 @@ class Keyring(object): Returns: A deferred that completes when the keys are stored. """ - for key_id, key in verify_keys.items(): - # TODO(markjh): Store whether the keys have expired. - yield self.store.store_server_verify_key( - server_name, server_name, key.time_added, key - ) + # TODO(markjh): Store whether the keys have expired. + yield defer.gatherResults( + [ + self.store.store_server_verify_key( + server_name, server_name, key.time_added, key + ) + for key_id, key in verify_keys.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 299493af9..bdfa24760 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -32,7 +32,8 @@ logger = logging.getLogger(__name__) class FederationBase(object): @defer.inlineCallbacks - def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False): + def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, + include_none=False): """Takes a list of PDUs and checks the signatures and hashs of each one. If a PDU fails its signature check then we check if we have it in the database and if not then request if from the originating server of @@ -50,84 +51,108 @@ class FederationBase(object): Returns: Deferred : A list of PDUs that have valid signatures and hashes. """ + deferreds = self._check_sigs_and_hashes(pdus) - signed_pdus = [] + def callback(pdu): + return pdu - @defer.inlineCallbacks - def do(pdu): - try: - new_pdu = yield self._check_sigs_and_hash(pdu) - signed_pdus.append(new_pdu) - except SynapseError: - # FIXME: We should handle signature failures more gracefully. + def errback(failure, pdu): + failure.trap(SynapseError) + return None + def try_local_db(res, pdu): + if not res: # Check local db. 
- new_pdu = yield self.store.get_event( + return self.store.get_event( pdu.event_id, allow_rejected=True, allow_none=True, ) - if new_pdu: - signed_pdus.append(new_pdu) - return + return res - # Check pdu.origin - if pdu.origin != origin: - try: - new_pdu = yield self.get_pdu( - destinations=[pdu.origin], - event_id=pdu.event_id, - outlier=outlier, - timeout=10000, - ) - - if new_pdu: - signed_pdus.append(new_pdu) - return - except: - pass + def try_remote(res, pdu): + if not res and pdu.origin != origin: + return self.get_pdu( + destinations=[pdu.origin], + event_id=pdu.event_id, + outlier=outlier, + timeout=10000, + ).addErrback(lambda e: None) + return res + def warn(res, pdu): + if not res: logger.warn( "Failed to find copy of %s with valid signature", pdu.event_id, ) + return res - yield defer.gatherResults( - [do(pdu) for pdu in pdus], + for pdu, deferred in zip(pdus, deferreds): + deferred.addCallbacks( + callback, errback, errbackArgs=[pdu] + ).addCallback( + try_local_db, pdu + ).addCallback( + try_remote, pdu + ).addCallback( + warn, pdu + ) + + valid_pdus = yield defer.gatherResults( + deferreds, consumeErrors=True ).addErrback(unwrapFirstError) - defer.returnValue(signed_pdus) + if include_none: + defer.returnValue(valid_pdus) + else: + defer.returnValue([p for p in valid_pdus if p]) - @defer.inlineCallbacks def _check_sigs_and_hash(self, pdu): - """Throws a SynapseError if the PDU does not have the correct + return self._check_sigs_and_hashes([pdu])[0] + + def _check_sigs_and_hashes(self, pdus): + """Throws a SynapseError if a PDU does not have the correct signatures. Returns: FrozenEvent: Either the given event or it redacted if it failed the content hash check. """ - # Check signatures are correct. - redacted_event = prune_event(pdu) - redacted_pdu_json = redacted_event.get_pdu_json() - try: - yield self.keyring.verify_json_for_server( - pdu.origin, redacted_pdu_json - ) - except SynapseError: + redacted_pdus = [ + prune_event(pdu) + for pdu in pdus + ] + + deferreds = self.keyring.verify_json_objects_for_server([ + (p.origin, p.get_pdu_json()) + for p in redacted_pdus + ]) + + def callback(_, pdu, redacted): + if not check_event_content_hash(pdu): + logger.warn( + "Event content has been tampered, redacting %s: %s", + pdu.event_id, pdu.get_pdu_json() + ) + return redacted + return pdu + + def errback(failure, pdu): + failure.trap(SynapseError) logger.warn( "Signature check failed for %s", pdu.event_id, ) - raise + return failure - if not check_event_content_hash(pdu): - logger.warn( - "Event content has been tampered, redacting.", - pdu.event_id, + for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus): + deferred.addCallbacks( + callback, errback, + callbackArgs=[pdu, redacted], + errbackArgs=[pdu], ) - defer.returnValue(redacted_event) - defer.returnValue(pdu) + return deferreds diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d3b46b24c..7736d14fb 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -30,6 +30,7 @@ import synapse.metrics from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination +import copy import itertools import logging import random @@ -167,7 +168,7 @@ class FederationClient(FederationBase): # FIXME: We should handle signature failures more gracefully. 
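+        # (defer.gatherResults with consumeErrors=True wraps the first
+        # failure in a twisted FirstError; the unwrapFirstError errback
+        # re-raises the underlying failure, so callers see the original
+        # SynapseError rather than the wrapper.)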
pdus[:] = yield defer.gatherResults( - [self._check_sigs_and_hash(pdu) for pdu in pdus], + self._check_sigs_and_hashes(pdus), consumeErrors=True, ).addErrback(unwrapFirstError) @@ -230,7 +231,7 @@ class FederationClient(FederationBase): pdu = pdu_list[0] # Check signatures are correct. - pdu = yield self._check_sigs_and_hash(pdu) + pdu = yield self._check_sigs_and_hashes([pdu])[0] break @@ -327,6 +328,9 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def make_join(self, destinations, room_id, user_id): for destination in destinations: + if destination == self.server_name: + continue + try: ret = yield self.transport_layer.make_join( destination, room_id, user_id @@ -353,6 +357,9 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_join(self, destinations, pdu): for destination in destinations: + if destination == self.server_name: + continue + try: time_now = self._clock.time_msec() _, content = yield self.transport_layer.send_join( @@ -374,17 +381,39 @@ class FederationClient(FederationBase): for p in content.get("auth_chain", []) ] - signed_state, signed_auth = yield defer.gatherResults( - [ - self._check_sigs_and_hash_and_fetch( - destination, state, outlier=True - ), - self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True - ) - ], - consumeErrors=True - ).addErrback(unwrapFirstError) + pdus = { + p.event_id: p + for p in itertools.chain(state, auth_chain) + } + + valid_pdus = yield self._check_sigs_and_hash_and_fetch( + destination, pdus.values(), + outlier=True, + ) + + valid_pdus_map = { + p.event_id: p + for p in valid_pdus + } + + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + signed_state = [ + copy.copy(valid_pdus_map[p.event_id]) + for p in state + if p.event_id in valid_pdus_map + ] + + signed_auth = [ + valid_pdus_map[p.event_id] + for p in auth_chain + if p.event_id in valid_pdus_map + ] + + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + for s in signed_state: + s.internal_metadata = copy.deepcopy(s.internal_metadata) auth_chain.sort(key=lambda e: e.depth) @@ -396,7 +425,7 @@ class FederationClient(FederationBase): except CodeMessageException: raise except Exception as e: - logger.warn( + logger.exception( "Failed to send_join via %s: %s", destination, e.message ) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 5bdf497b9..940a5f7e0 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from _base import SQLBaseStore +from _base import SQLBaseStore, cached from twisted.internet import defer @@ -71,6 +71,25 @@ class KeyStore(SQLBaseStore): desc="store_server_certificate", ) + @cached() + @defer.inlineCallbacks + def get_all_server_verify_keys(self, server_name): + rows = yield self._simple_select_list( + table="server_signature_keys", + keyvalues={ + "server_name": server_name, + }, + retcols=["key_id", "verify_key"], + desc="get_all_server_verify_keys", + ) + + defer.returnValue({ + row["key_id"]: decode_verify_key_bytes( + row["key_id"], str(row["verify_key"]) + ) + for row in rows + }) + @defer.inlineCallbacks def get_server_verify_keys(self, server_name, key_ids): """Retrieve the NACL verification key for a given server for the given @@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore): Returns: (list of VerifyKey): The verification keys. """ - sql = ( - "SELECT key_id, verify_key FROM server_signature_keys" - " WHERE server_name = ?" - " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")" - ) - - rows = yield self._execute_and_decode( - "get_server_verify_keys", sql, server_name, *key_ids - ) - - keys = [] - for row in rows: - key_id = row["key_id"] - key_bytes = row["verify_key"] - key = decode_verify_key_bytes(key_id, str(key_bytes)) - keys.append(key) - defer.returnValue(keys) + keys = yield self.get_all_server_verify_keys(server_name) + defer.returnValue({ + k: keys[k] + for k in key_ids + if k in keys and keys[k] + }) + @defer.inlineCallbacks def store_server_verify_key(self, server_name, from_server, time_now_ms, verify_key): """Stores a NACL verification key for the given server. @@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore): ts_now_ms (int): The time now in milliseconds verification_key (VerifyKey): The NACL verify key. 
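+        Returns:
+            Deferred: completes once the key has been stored and the
+            cached key list for this server has been invalidated.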
""" - return self._simple_upsert( + yield self._simple_upsert( table="server_signature_keys", keyvalues={ "server_name": server_name, @@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) + self.get_all_server_verify_keys.invalidate(server_name) + def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): """Stores the JSON bytes for a set of keys from a server @@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore): "ts_valid_until_ms": ts_expires_ms, "key_json": buffer(key_json_bytes), }, + desc="store_server_keys_json", ) def get_server_keys_json(self, server_keys): From f0dd568e1681b878d9e85a12c21bfa0d5540ff73 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 26 Jun 2015 11:25:00 +0100 Subject: [PATCH 02/23] Wait for previous attempts at fetching keys for a given server before trying to fetch more --- synapse/crypto/keyring.py | 83 ++++++++++++++++++++++++++++++++------- 1 file changed, 68 insertions(+), 15 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 873c9b40f..aa74d4d0c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -27,6 +27,8 @@ from synapse.api.errors import SynapseError, Codes from synapse.util.retryutils import get_retry_limiter from synapse.util import unwrapFirstError +from synapse.util.async import ObservableDeferred + from OpenSSL import crypto from collections import namedtuple @@ -88,6 +90,8 @@ class Keyring(object): "Not signed with a supported algorithm", Codes.UNAUTHORIZED, )) + else: + deferreds[group_id] = defer.Deferred() group = KeyGroup(server_name, group_id, key_ids) @@ -133,10 +137,41 @@ class Keyring(object): Codes.UNAUTHORIZED, ) - deferreds.update(self.get_server_verify_keys( - group_id_to_group - )) + server_to_deferred = { + server_name: defer.Deferred() + for server_name, _ in server_and_json + } + # We want to wait for any previous lookups to complete before + # proceeding. + wait_on_deferred = self.wait_for_previous_lookups( + [server_name for server_name, _ in server_and_json], + server_to_deferred, + ) + + # Actually start fetching keys. + wait_on_deferred.addBoth( + lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) + ) + + # When we've finished fetching all the keys for a given server_name, + # resolve the deferred passed to `wait_for_previous_lookups` so that + # any lookups waiting will proceed. + server_to_gids = {} + + def remove_deferreds(res, server_name, group_id): + server_to_gids[server_name].discard(group_id) + if not server_to_gids[server_name]: + server_to_deferred.pop(server_name).callback(None) + return res + + for g_id, deferred in deferreds.items(): + server_name = group_id_to_group[g_id].server_name + server_to_gids.setdefault(server_name, set()).add(g_id) + deferred.addBoth(remove_deferreds, server_name, g_id) + + # Pass those keys to handle_key_deferred so that the json object + # signatures can be verified return [ handle_key_deferred( group_id_to_group[g_id], @@ -145,7 +180,30 @@ class Keyring(object): for g_id in group_ids ] - def get_server_verify_keys(self, group_id_to_group): + @defer.inlineCallbacks + def wait_for_previous_lookups(self, server_names, server_to_deferred): + """Waits for any previous key lookups for the given servers to finish. 
+
+        Args:
+            server_names (list): list of server_names we want to lookup
+            server_to_deferred (dict): server_name to deferred which gets
+                resolved once we've finished looking up keys for that server
+        """
+        while True:
+            wait_on = [
+                self.key_downloads[server_name]
+                for server_name in server_names
+                if server_name in self.key_downloads
+            ]
+            if wait_on:
+                yield defer.DeferredList(wait_on)
+            else:
+                break
+
+        for server_name, deferred in server_to_deferred.items():
+            self.key_downloads[server_name] = ObservableDeferred(deferred)
+
-    def get_server_verify_keys(self, group_id_to_group):
+    def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
         """Takes a dict of KeyGroups and tries to find at least one key for
         each group.
         """
@@ -157,11 +215,6 @@ class Keyring(object):
             self.get_keys_from_server,  # Then try directly
         )

-        group_deferreds = {
-            group_id: defer.Deferred()
-            for group_id in group_id_to_group
-        }
-
         @defer.inlineCallbacks
         def do_iterations():
             merged_results = {}
@@ -182,7 +235,7 @@ class Keyring(object):
             for group in group_id_to_group.values():
                 for key_id in group.key_ids:
                     if key_id in merged_results[group.server_name]:
-                        group_deferreds.pop(group.group_id).callback((
+                        group_id_to_deferred[group.group_id].callback((
                             group.group_id,
                             group.server_name,
                             key_id,
@@ -205,7 +258,7 @@ class Keyring(object):
             }

             for group in missing_groups.values():
-                group_deferreds.pop(group.group_id).errback(SynapseError(
+                group_id_to_deferred[group.group_id].errback(SynapseError(
                     401,
                     "No key for %s with id %s" % (
                         group.server_name, group.key_ids,
@@ -214,13 +267,13 @@ class Keyring(object):
                 ))

         def on_err(err):
-            for deferred in group_deferreds.values():
-                deferred.errback(err)
-            group_deferreds.clear()
+            for deferred in group_id_to_deferred.values():
+                if not deferred.called:
+                    deferred.errback(err)

         do_iterations().addErrback(on_err)

-        return group_deferreds
+        return group_id_to_deferred

     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):

From 1a605456260bfb46d8bb9cff2d40d19aec03daa4 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Thu, 2 Jul 2015 16:20:10 +0100
Subject: [PATCH 03/23] Add basic impl for room history ACL on GET /messages
 client API

---
 synapse/api/constants.py    |  2 ++
 synapse/handlers/message.py | 33 ++++++++++++++++++-
 synapse/storage/state.py    | 63 +++++++++++++++++++++++++++++++++++--
 3 files changed, 95 insertions(+), 3 deletions(-)

diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index d8a18ee87..3e15e8a9d 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -75,6 +75,8 @@ class EventTypes(object):
     Redaction = "m.room.redaction"
     Feedback = "m.room.message.feedback"

+    RoomHistoryVisibility = "m.room.history_visibility"
+
     # These are used for validation
     Message = "m.room.message"
     Topic = "m.room.topic"
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e324662f1..17c75f33c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -113,11 +113,42 @@ class MessageHandler(BaseHandler):
             "room_key", next_key
         )

+        if not events:
+            defer.returnValue({
+                "chunk": [],
+                "start": pagin_config.from_token.to_string(),
+                "end": next_token.to_string(),
+            })
+
+        states = yield self.store.get_state_for_events(
+            room_id, [e.event_id for e in events],
+        )
+
+        events_and_states = zip(events, states)
+
+        def allowed(event_and_state):
+            _, state = event_and_state
+
+            membership = state.get((EventTypes.Member, user_id), None)
+            if membership and membership.membership == Membership.JOIN:
+                return True
+
+            history = 
state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history and history.content["visibility"] == "after_join": + return False + + events_and_states = filter(allowed, events_and_states) + events = [ + ev + for ev, _ in events_and_states + ] + time_now = self.clock.time_msec() chunk = { "chunk": [ - serialize_event(e, time_now, as_client_event) for e in events + serialize_event(e, time_now, as_client_event) + for e in events ], "start": pagin_config.from_token.to_string(), "end": next_token.to_string(), diff --git a/synapse/storage/state.py b/synapse/storage/state.py index f2b17f29e..d7844edee 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -92,11 +92,11 @@ class StateStore(SQLBaseStore): defer.returnValue(dict(state_list)) @cached(num_args=1) - def _fetch_events_for_group(self, state_group, events): + def _fetch_events_for_group(self, key, events): return self._get_events( events, get_prev_content=False ).addCallback( - lambda evs: (state_group, evs) + lambda evs: (key, evs) ) def _store_state_groups_txn(self, txn, event, context): @@ -194,6 +194,65 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) + @defer.inlineCallbacks + def get_state_for_events(self, room_id, event_ids): + def f(txn): + groups = set() + event_to_group = {} + for event_id in event_ids: + # TODO: Remove this loop. + group = self._simple_select_one_onecol_txn( + txn, + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + ) + if group: + event_to_group[event_id] = group + groups.add(group) + + group_to_state_ids = {} + for group in groups: + state_ids = self._simple_select_onecol_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + + group_to_state_ids[group] = state_ids + + return event_to_group, group_to_state_ids + + res = yield self.runInteraction( + "annotate_events_with_state_groups", + f, + ) + + event_to_group, group_to_state_ids = res + + state_list = yield defer.gatherResults( + [ + self._fetch_events_for_group(group, vals) + for group, vals in group_to_state_ids.items() + ], + consumeErrors=True, + ) + + state_dict = { + group: { + (ev.type, ev.state_key): ev + for ev in state + } + for group, state in state_list + } + + defer.returnValue([ + state_dict.get(event_to_group.get(event, None), None) + for event in event_ids + ]) + def _make_group_id(clock): return str(int(clock.time_msec())) + random_string(5) From 41938afed884361451ad6e91eb44e805ebbdaeb0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Jul 2015 17:02:10 +0100 Subject: [PATCH 04/23] Make v1 initial syncs respect room history ACL --- synapse/handlers/message.py | 61 ++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 17c75f33c..00c7dbec8 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -120,28 +120,7 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) - states = yield self.store.get_state_for_events( - room_id, [e.event_id for e in events], - ) - - events_and_states = zip(events, states) - - def allowed(event_and_state): - _, state = event_and_state - - membership = state.get((EventTypes.Member, user_id), None) - if membership and membership.membership == Membership.JOIN: - return True - - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if 
history and history.content["visibility"] == "after_join": - return False - - events_and_states = filter(allowed, events_and_states) - events = [ - ev - for ev, _ in events_and_states - ] + events = yield self._filter_events_for_client(user_id, room_id, events) time_now = self.clock.time_msec() @@ -156,6 +135,36 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) + @defer.inlineCallbacks + def _filter_events_for_client(self, user_id, room_id, events): + states = yield self.store.get_state_for_events( + room_id, [e.event_id for e in events], + ) + + events_and_states = zip(events, states) + + def allowed(event_and_state): + event, state = event_and_state + + if event.type == EventTypes.RoomHistoryVisibility: + return True + + membership = state.get((EventTypes.Member, user_id), None) + if membership and membership.membership == Membership.JOIN: + return True + + history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history and history.content.get("visibility", None) == "after_join": + return False + + return True + + events_and_states = filter(allowed, events_and_states) + defer.returnValue([ + ev + for ev, _ in events_and_states + ]) + @defer.inlineCallbacks def create_and_send_event(self, event_dict, ratelimit=True, client=None, txn_id=None): @@ -347,6 +356,10 @@ class MessageHandler(BaseHandler): ] ).addErrback(unwrapFirstError) + messages = yield self._filter_events_for_client( + user_id, event.room_id, messages + ) + start_token = now_token.copy_and_replace("room_key", token[0]) end_token = now_token.copy_and_replace("room_key", token[1]) time_now = self.clock.time_msec() @@ -448,6 +461,10 @@ class MessageHandler(BaseHandler): consumeErrors=True, ).addErrback(unwrapFirstError) + messages = yield self._filter_events_for_client( + user_id, room_id, messages + ) + start_token = now_token.copy_and_replace("room_key", token[0]) end_token = now_token.copy_and_replace("room_key", token[1]) From 00ab882ed66d1ecdbd9de15c7ac591f4ee07b9f8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 3 Jul 2015 10:31:17 +0100 Subject: [PATCH 05/23] Add m.room.history_visibility to list of auth events --- synapse/api/auth.py | 2 +- synapse/events/utils.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 4da62e5d8..deca747f7 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) AuthEventTypes = ( EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels, - EventTypes.JoinRules, + EventTypes.JoinRules, EventTypes.RoomHistoryVisibility, ) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 1aa952150..4c82780f4 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -74,6 +74,8 @@ def prune_event(event): ) elif event_type == EventTypes.Aliases: add_fields("aliases") + elif event_type == EventTypes.RoomHistoryVisibility: + add_fields("visibility") allowed_fields = { k: v From 400894616d15a01c168b2356d950972b6e746496 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 3 Jul 2015 14:51:01 +0100 Subject: [PATCH 06/23] Respect m.room.history_visibility in v2_alpha sync API --- synapse/handlers/sync.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index bd8c60368..5078c4e45 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -292,6 +292,36 @@ class SyncHandler(BaseHandler): next_batch=now_token, )) + 
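+    # For reference, the state event consulted below looks like this
+    # (illustrative; the key is renamed to "history_visibility" by a
+    # later patch in this series):
+    #     {
+    #         "type": "m.room.history_visibility",
+    #         "state_key": "",
+    #         "content": {"visibility": "after_join"}
+    #     }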
@defer.inlineCallbacks + def _filter_events_for_client(self, user_id, room_id, events): + states = yield self.store.get_state_for_events( + room_id, [e.event_id for e in events], + ) + + events_and_states = zip(events, states) + + def allowed(event_and_state): + event, state = event_and_state + + if event.type == EventTypes.RoomHistoryVisibility: + return True + + membership = state.get((EventTypes.Member, user_id), None) + if membership and membership.membership == Membership.JOIN: + return True + + history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history and history.content.get("visibility", None) == "after_join": + return False + + return True + + events_and_states = filter(allowed, events_and_states) + defer.returnValue([ + ev + for ev, _ in events_and_states + ]) + @defer.inlineCallbacks def load_filtered_recents(self, room_id, sync_config, now_token, since_token=None): @@ -313,6 +343,9 @@ class SyncHandler(BaseHandler): (room_key, _) = keys end_key = "s" + room_key.split('-')[-1] loaded_recents = sync_config.filter.filter_room_events(events) + loaded_recents = yield self._filter_events_for_client( + sync_config.user.to_string(), room_id, loaded_recents, + ) loaded_recents.extend(recents) recents = loaded_recents if len(events) <= load_limit: From c3e2600c6727534d4ebf20dcd8219e248ca31461 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 3 Jul 2015 17:52:57 +0100 Subject: [PATCH 07/23] Filter and redact events that the other server doesn't have permission to see during backfill --- synapse/handlers/federation.py | 44 ++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b5d882fd6..663d05c63 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -31,6 +31,8 @@ from synapse.crypto.event_signing import ( ) from synapse.types import UserID +from synapse.events.utils import prune_event + from synapse.util.retryutils import NotRetryingDestination from twisted.internet import defer @@ -222,6 +224,46 @@ class FederationHandler(BaseHandler): "user_joined_room", user=user, room_id=event.room_id ) + @defer.inlineCallbacks + def _filter_events_for_server(self, server_name, room_id, events): + states = yield self.store.get_state_for_events( + room_id, [e.event_id for e in events], + ) + + events_and_states = zip(events, states) + + def redact_disallowed(event_and_state): + event, state = event_and_state + + if not state: + return event + + history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history and history.content.get("visibility", None) == "after_join": + for ev in state.values(): + if ev.type != EventTypes.Member: + continue + try: + domain = UserID.from_string(ev.state_key).domain + except: + continue + + if domain != server_name: + continue + + if ev.membership == Membership.JOIN: + return event + else: + return prune_event(event) + + return event + + res = map(redact_disallowed, events_and_states) + + logger.info("_filter_events_for_server %r", res) + + defer.returnValue(res) + @log_function @defer.inlineCallbacks def backfill(self, dest, room_id, limit, extremities=[]): @@ -882,6 +924,8 @@ class FederationHandler(BaseHandler): limit ) + events = yield self._filter_events_for_server(origin, room_id, events) + defer.returnValue(events) @defer.inlineCallbacks From fb47c3cfbe213c01b25e5605b81c998b764e2bf8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Jul 2015 13:05:52 +0100 Subject: [PATCH 08/23] Rename key and 
values for m.room.history_visibility. Support 'invited' value --- synapse/events/utils.py | 2 +- synapse/handlers/federation.py | 34 ++++++++++++++++++++-------------- synapse/handlers/message.py | 24 ++++++++++++++++++++---- synapse/handlers/sync.py | 25 ++++++++++++++++++++----- 4 files changed, 61 insertions(+), 24 deletions(-) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 4c82780f4..7bd78343f 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -75,7 +75,7 @@ def prune_event(event): elif event_type == EventTypes.Aliases: add_fields("aliases") elif event_type == EventTypes.RoomHistoryVisibility: - add_fields("visibility") + add_fields("history_visibility") allowed_fields = { k: v diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 663d05c63..cd3867ed9 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -239,22 +239,28 @@ class FederationHandler(BaseHandler): return event history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history and history.content.get("visibility", None) == "after_join": - for ev in state.values(): - if ev.type != EventTypes.Member: - continue - try: - domain = UserID.from_string(ev.state_key).domain - except: - continue + if history: + visibility = history.content.get("history_visibility", "shared") + if visibility in ["invited", "joined"]: + for ev in state.values(): + if ev.type != EventTypes.Member: + continue + try: + domain = UserID.from_string(ev.state_key).domain + except: + continue - if domain != server_name: - continue + if domain != server_name: + continue - if ev.membership == Membership.JOIN: - return event - else: - return prune_event(event) + memtype = ev.membership + if memtype == Membership.JOIN: + return event + elif memtype == Membership.INVITE: + if visibility == "invited": + return event + else: + return prune_event(event) return event diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 00c7dbec8..d8b117612 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -149,13 +149,29 @@ class MessageHandler(BaseHandler): if event.type == EventTypes.RoomHistoryVisibility: return True - membership = state.get((EventTypes.Member, user_id), None) - if membership and membership.membership == Membership.JOIN: + membership_ev = state.get((EventTypes.Member, user_id), None) + if membership_ev: + membership = membership_ev.membership + else: + membership = Membership.LEAVE + + if membership == Membership.JOIN: return True history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history and history.content.get("visibility", None) == "after_join": - return False + if history: + visibility = history.content.get("history_visibility", "shared") + else: + visibility = "shared" + + if visibility == "public": + return True + elif visibility == "shared": + return True + elif visibility == "joined": + return membership == Membership.JOIN + elif visibility == "invited": + return membership == Membership.INVITE return True diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5078c4e45..6cff6230c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -306,16 +306,31 @@ class SyncHandler(BaseHandler): if event.type == EventTypes.RoomHistoryVisibility: return True - membership = state.get((EventTypes.Member, user_id), None) - if membership and membership.membership == Membership.JOIN: + membership_ev = state.get((EventTypes.Member, user_id), None) + if 
membership_ev: + membership = membership_ev.membership + else: + membership = Membership.LEAVE + + if membership == Membership.JOIN: return True history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history and history.content.get("visibility", None) == "after_join": - return False + if history: + visibility = history.content.get("history_visibility", "shared") + else: + visibility = "shared" + + if visibility == "public": + return True + elif visibility == "shared": + return True + elif visibility == "joined": + return membership == Membership.JOIN + elif visibility == "invited": + return membership == Membership.INVITE return True - events_and_states = filter(allowed, events_and_states) defer.returnValue([ ev From 1a3255b507550c76f11251c890a43947b1f4e272 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Jul 2015 13:09:16 +0100 Subject: [PATCH 09/23] Add m.room.history_visibility to newly created rooms' m.room.power_levels --- synapse/api/auth.py | 1 + synapse/handlers/room.py | 1 + 2 files changed, 2 insertions(+) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index deca747f7..1a25bf108 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -575,6 +575,7 @@ class Auth(object): levels_to_check = [ ("users_default", []), ("events_default", []), + ("state_default", []), ("ban", []), ("redact", []), ("kick", []), diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 4bd027d9b..891707df4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -213,6 +213,7 @@ class RoomCreationHandler(BaseHandler): "events": { EventTypes.Name: 100, EventTypes.PowerLevels: 100, + EventTypes.RoomHistoryVisibility: 100, }, "events_default": 0, "state_default": 50, From b5770f89478509b5c4c1a610753989fbe29e35e7 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 6 Jul 2015 18:46:47 +0100 Subject: [PATCH 10/23] Add store for client end to end keys --- synapse/storage/__init__.py | 4 +- synapse/storage/end_to_end_keys.py | 133 ++++++++++++++++++ .../schema/delta/21/end_to_end_keys.sql | 35 +++++ 3 files changed, 171 insertions(+), 1 deletion(-) create mode 100644 synapse/storage/end_to_end_keys.py create mode 100644 synapse/storage/schema/delta/21/end_to_end_keys.sql diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index c137f4782..e089d8167 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -37,6 +37,7 @@ from .rejections import RejectionsStore from .state import StateStore from .signatures import SignatureStore from .filtering import FilteringStore +from .end_to_end_keys import EndToEndKeyStore import fnmatch @@ -51,7 +52,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 20 +SCHEMA_VERSION = 21 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -74,6 +75,7 @@ class DataStore(RoomMemberStore, RoomStore, PushRuleStore, ApplicationServiceTransactionStore, EventsStore, + EndToEndKeyStore, ): def __init__(self, hs): diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py new file mode 100644 index 000000000..936a64669 --- /dev/null +++ b/synapse/storage/end_to_end_keys.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _base import SQLBaseStore + + +class EndToEndKeyStore(SQLBaseStore): + def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes): + return self._simple_upsert( + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": json_bytes, + } + ) + + def get_e2e_device_keys(self, query_list): + """Fetch a list of device keys. + Args: + query_list(list): List of pairs of user_ids and device_ids. + Returns: + Dict mapping from user-id to dict mapping from device_id to + key json byte strings. + """ + def _get_e2e_device_keys(txn): + result = {} + for user_id, device_id in query_list: + user_result = result.setdefault(user_id, {}) + keyvalues = {"user_id": user_id} + if device_id: + keyvalues["device_id"] = device_id + rows = self._simple_select_list_txn( + txn, table="e2e_device_keys_json", + keyvalues=keyvalues, + retcols=["device_id", "key_json"] + ) + for row in rows: + user_result[row["device_id"]] = row["key_json"] + return result + return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) + + def add_e2e_one_time_keys(self, user_id, device_id, time_now, valid_until, + key_list): + def _add_e2e_one_time_keys(txn): + for (algorithm, key_id, json_bytes) in key_list: + self._simple_upsert_txn( + txn, table="e2e_one_time_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, + values={ + "ts_added_ms": time_now, + "valid_until_ms": valid_until, + "key_json": json_bytes, + } + ) + return self.runInteraction( + "add_e2e_one_time_keys", _add_e2e_one_time_keys + ) + + def count_e2e_one_time_keys(self, user_id, device_id, time_now): + """ Count the number of one time keys the server has for a device + Returns: + Dict mapping from algorithm to number of keys for that algorithm. + """ + def _count_e2e_one_time_keys(txn): + sql = ( + "DELETE FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?" + ) + txn.execute(sql, (user_id, device_id, time_now)) + sql = ( + "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ?" + " GROUP BY algorithm" + ) + txn.execute(sql, (user_id, device_id)) + result = {} + for algorithm, key_count in txn.fetchall(): + result[algorithm] = key_count + return result + return self.runInteraction( + "count_e2e_one_time_keys", _count_e2e_one_time_keys + ) + + def take_e2e_one_time_keys(self, query_list, time_now): + """Take a list of one time keys out of the database""" + def _take_e2e_one_time_keys(txn): + sql = ( + "SELECT key_id, key_json FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " AND valid_until_ms > ?" 
+ " LIMIT 1" + ) + result = {} + delete = [] + for user_id, device_id, algorithm in query_list: + user_result = result.setdefault(user_id, {}) + device_result = user_result.setdefault(device_id, {}) + txn.execute(sql, (user_id, device_id, algorithm, time_now)) + for key_id, key_json in txn.fetchall(): + device_result[algorithm + ":" + key_id] = key_json + delete.append((user_id, device_id, algorithm, key_id)) + sql = ( + "DELETE FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " AND key_id = ?" + ) + for user_id, device_id, algorithm, key_id in delete: + txn.execute(sql, (user_id, device_id, algorithm, key_id)) + return result + return self.runInteraction( + "take_e2e_one_time_keys", _take_e2e_one_time_keys + ) diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/schema/delta/21/end_to_end_keys.sql new file mode 100644 index 000000000..107d2e67c --- /dev/null +++ b/synapse/storage/schema/delta/21/end_to_end_keys.sql @@ -0,0 +1,35 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS e2e_device_keys_json ( + user_id TEXT NOT NULL, -- The user these keys are for. + device_id TEXT NOT NULL, -- Which of the user's devices these keys are for. + ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded. + key_json TEXT NOT NULL, -- The keys for the device as a JSON blob. + CONSTRAINT uniqueness UNIQUE (user_id, device_id) +); + + +CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json ( + user_id TEXT NOT NULL, -- The user this one-time key is for. + device_id TEXT NOT NULL, -- The device this one-time key is for. + algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + ts_added_ms BIGINT NOT NULL, -- When this key was uploaded. + valid_until_ms BIGINT NOT NULL, -- When this key is valid until. + key_json TEXT NOT NULL, -- The key as a JSON blob. + CONSTRAINT uniqueness UNIQUE (user_id, device_id, algorithm, key_id) +); From 2ef182ee9358bce24cdef7c09ae7289925d076ef Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 6 Jul 2015 18:47:57 +0100 Subject: [PATCH 11/23] Add client API for uploading and querying keys for end to end encryption --- synapse/rest/client/v2_alpha/__init__.py | 4 +- synapse/rest/client/v2_alpha/keys.py | 287 +++++++++++++++++++++++ 2 files changed, 290 insertions(+), 1 deletion(-) create mode 100644 synapse/rest/client/v2_alpha/keys.py diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index 7d1aff430..c3323d2a8 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -18,7 +18,8 @@ from . 
import (
     filter,
     account,
     register,
-    auth
+    auth,
+    keys,
 )

 from synapse.http.server import JsonResource

@@ -38,3 +39,4 @@ class ClientV2AlphaRestResource(JsonResource):
         account.register_servlets(hs, client_resource)
         register.register_servlets(hs, client_resource)
         auth.register_servlets(hs, client_resource)
+        keys.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
new file mode 100644
index 000000000..3bb4ad64f
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -0,0 +1,287 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet
+from syutil.jsonutil import encode_canonical_json
+
+from ._base import client_v2_pattern
+
+import simplejson as json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class KeyUploadServlet(RestServlet):
+    """
+    POST /keys/upload/<device_id> HTTP/1.1
+    Content-Type: application/json
+
+    {
+      "device_keys": {
+        "user_id": "<user_id>",
+        "device_id": "<device_id>",
+        "valid_until_ts": <millisecond_timestamp>,
+        "algorithms": [
+          "m.olm.curve25519-aes-sha256",
+        ]
+        "keys": {
+          "<algorithm>:<device_id>": "<key_base64>",
+        },
+        "signatures:" {
+          "<user_id>/<device_id>" {
+            "<algorithm>:<device_id>": "<signature_base64>"
+      } } },
+      "one_time_keys": {
+        "<algorithm>:<key_id>": "<key_base64>"
+      },
+      "one_time_keys_valid_for": <millisecond_duration>,
+    }
+    """
+    PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
+
+    def __init__(self, hs):
+        super(KeyUploadServlet, self).__init__()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self.auth = hs.get_auth()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, device_id):
+        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        user_id = auth_user.to_string()
+        # TODO: Check that the device_id matches that in the authentication
+        # or derive the device_id from the authentication instead.
+        try:
+            body = json.loads(request.content.read())
+        except:
+            raise SynapseError(400, "Invalid key JSON")
+        time_now = self.clock.time_msec()

+        # TODO: Validate the JSON to make sure it has the right keys.
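+        # (For instance, a validator would check that "device_keys", if
+        # present, carries the "user_id", "device_id", "algorithms", "keys"
+        # and "signatures" fields shown in the docstring above.)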
+        device_keys = body.get("device_keys", None)
+        if device_keys:
+            logger.info(
+                "Updating device_keys for device %r for user %r at %d",
+                device_id, auth_user, time_now
+            )
+            # TODO: Sign the JSON with the server key
+            yield self.store.set_e2e_device_keys(
+                user_id, device_id, time_now,
+                encode_canonical_json(device_keys)
+            )
+
+        one_time_keys = body.get("one_time_keys", None)
+        one_time_keys_valid_for = body.get("one_time_keys_valid_for", None)
+        if one_time_keys:
+            valid_until = int(one_time_keys_valid_for) + time_now
+            logger.info(
+                "Adding %d one_time_keys for device %r for user %r at %d"
+                " valid_until %d",
+                len(one_time_keys), device_id, user_id, time_now, valid_until
+            )
+            key_list = []
+            for key_id, key_json in one_time_keys.items():
+                algorithm, key_id = key_id.split(":")
+                key_list.append((
+                    algorithm, key_id, encode_canonical_json(key_json)
+                ))
+
+            yield self.store.add_e2e_one_time_keys(
+                user_id, device_id, time_now, valid_until, key_list
+            )
+
+        result = yield self.store.count_e2e_one_time_keys(
+            user_id, device_id, time_now
+        )
+        defer.returnValue((200, {"one_time_key_counts": result}))
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, device_id):
+        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        user_id = auth_user.to_string()
+        time_now = self.clock.time_msec()
+
+        result = yield self.store.count_e2e_one_time_keys(
+            user_id, device_id, time_now
+        )
+        defer.returnValue((200, {"one_time_key_counts": result}))
+
+
+class KeyQueryServlet(RestServlet):
+    """
+    GET /keys/query/<user_id> HTTP/1.1
+
+    GET /keys/query/<user_id>/<device_id> HTTP/1.1
+
+    POST /keys/query HTTP/1.1
+    Content-Type: application/json
+    {
+      "device_keys": {
+        "<user_id>": ["<device_id>"]
+    } }
+
+    HTTP/1.1 200 OK
+    {
+      "device_keys": {
+        "<user_id>": {
+          "<device_id>": {
+            "user_id": "<user_id>", // Duplicated to be signed
+            "device_id": "<device_id>", // Duplicated to be signed
+            "valid_until_ts": <millisecond_timestamp>,
+            "algorithms": [ // List of supported algorithms
+              "m.olm.curve25519-aes-sha256",
+            ],
+            "keys": { // Must include an ed25519 signing key
+              "<algorithm>:<key_id>": "<key_base64>",
+            },
+            "signatures": {
+              // Must be signed with device's ed25519 key
+              "<user_id>/<device_id>": {
+                "<algorithm>:<key_id>": "<signature_base64>"
+              }
+              // Must be signed by this server.
+              "<server_name>": {
+                "<algorithm>:<key_id>": "<signature_base64>"
+      } } } } } }
+    """
+
+    PATTERN = client_v2_pattern(
+        "/keys/query(?:"
+        "/(?P<user_id>[^/]*)(?:"
+        "/(?P<device_id>[^/]*)"
+        ")?"
+        ")?"
+    )
+
+    def __init__(self, hs):
+        super(KeyQueryServlet, self).__init__()
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, user_id, device_id):
+        logger.debug("onPOST")
+        yield self.auth.get_user_by_req(request)
+        try:
+            body = json.loads(request.content.read())
+        except:
+            raise SynapseError(400, "Invalid key JSON")
+        query = []
+        for user_id, device_ids in body.get("device_keys", {}).items():
+            if not device_ids:
+                query.append((user_id, None))
+            else:
+                for device_id in device_ids:
+                    query.append((user_id, device_id))
+        results = yield self.store.get_e2e_device_keys(query)
+        defer.returnValue(self.json_result(request, results))
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id, device_id):
+        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        auth_user_id = auth_user.to_string()
+        if not user_id:
+            user_id = auth_user_id
+        if not device_id:
+            device_id = None
+        # Returns a map of user_id->device_id->json_bytes.
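+        # (With hypothetical identifiers, the result is shaped like
+        # {"@alice:example.com": {"ABCDEFGH": "<device_keys_json>"}};
+        # json_result() below parses the JSON blobs for the response.
+        # A device_id of None is passed through to the store, matching
+        # the (user_id, None) queries built in on_POST above.)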
+        results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
+        defer.returnValue(self.json_result(request, results))
+
+    def json_result(self, request, results):
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, json_bytes in device_keys.items():
+                json_result.setdefault(user_id, {})[device_id] = json.loads(
+                    json_bytes
+                )
+        return (200, {"device_keys": json_result})
+
+
+class OneTimeKeyServlet(RestServlet):
+    """
+    GET /keys/take/<user_id>/<device_id>/<algorithm> HTTP/1.1
+
+    POST /keys/take HTTP/1.1
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": "<algorithm>"
+    } } }
+
+    HTTP/1.1 200 OK
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": {
+            "<algorithm>:<key_id>": "<key_base64>"
+    } } } }
+
+    """
+    PATTERN = client_v2_pattern(
+        "/keys/take(?:/?|(?:/"
+        "(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
+        ")?)"
+    )
+
+    def __init__(self, hs):
+        super(OneTimeKeyServlet, self).__init__()
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id, device_id, algorithm):
+        yield self.auth.get_user_by_req(request)
+        time_now = self.clock.time_msec()
+        results = yield self.store.take_e2e_one_time_keys(
+            [(user_id, device_id, algorithm)], time_now
+        )
+        defer.returnValue(self.json_result(request, results))
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, user_id, device_id, algorithm):
+        yield self.auth.get_user_by_req(request)
+        try:
+            body = json.loads(request.content.read())
+        except:
+            raise SynapseError(400, "Invalid key JSON")
+        query = []
+        for user_id, device_keys in body.get("one_time_keys", {}).items():
+            for device_id, algorithm in device_keys.items():
+                query.append((user_id, device_id, algorithm))
+        time_now = self.clock.time_msec()
+        results = yield self.store.take_e2e_one_time_keys(query, time_now)
+        defer.returnValue(self.json_result(request, results))
+
+    def json_result(self, request, results):
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, keys in device_keys.items():
+                for key_id, json_bytes in keys.items():
+                    json_result.setdefault(user_id, {})[device_id] = {
+                        key_id: json.loads(json_bytes)
+                    }
+        return (200, {"one_time_keys": json_result})
+
+
+def register_servlets(hs, http_server):
+    KeyUploadServlet(hs).register(http_server)
+    KeyQueryServlet(hs).register(http_server)
+    OneTimeKeyServlet(hs).register(http_server)

From 81682d0f820a6209535267a45ee28b8f66ff7794 Mon Sep 17 00:00:00 2001
From: Muthu Subramanian
Date: Tue, 7 Jul 2015 17:40:30 +0530
Subject: [PATCH 12/23] Integrate SAML2 basic authentication - uses pysaml2

---
 synapse/config/homeserver.py | 6 ++--
 synapse/config/saml2.py | 27 ++++++++++++++
 synapse/handlers/register.py | 30 ++++++++++++++++
 synapse/python_dependencies.py | 1 +
 synapse/rest/client/v1/login.py | 62 ++++++++++++++++++++++++++++++++-
 5 files changed, 122 insertions(+), 4 deletions(-)
 create mode 100644 synapse/config/saml2.py

diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index fe0ccb6eb..5c655c537 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -25,12 +25,12 @@ from .registration import RegistrationConfig
 from .metrics import MetricsConfig
 from .appservice import AppServiceConfig
 from .key import KeyConfig
-
+from .saml2 import SAML2Config
 class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                        RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
-                       VoipConfig, RegistrationConfig,
-                       MetricsConfig, AppServiceConfig, KeyConfig,):
+                       VoipConfig, RegistrationConfig, MetricsConfig,
+
AppServiceConfig, KeyConfig, SAML2Config, ): pass diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py new file mode 100644 index 000000000..4f3a724e2 --- /dev/null +++ b/synapse/config/saml2.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 Ericsson +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + +class SAML2Config(Config): + def read_config(self, config): + self.saml2_config = config["saml2_config"] + + def default_config(self, config_dir_path, server_name): + return """ + saml2_config: + config_path: "%s/sp_conf.py" + idp_redirect_url: "http://%s/idp" + """%(config_dir_path, server_name) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7b68585a1..4c6c5e297 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -192,6 +192,36 @@ class RegistrationHandler(BaseHandler): else: logger.info("Valid captcha entered from %s", ip) + @defer.inlineCallbacks + def register_saml2(self, localpart): + """ + Registers email_id as SAML2 Based Auth. + """ + if urllib.quote(localpart) != localpart: + raise SynapseError( + 400, + "User ID must only contain characters which do not" + " require URL encoding." + ) + user = UserID(localpart, self.hs.hostname) + user_id = user.to_string() + + yield self.check_user_id_is_valid(user_id) + token = self._generate_token(user_id) + try: + yield self.store.register( + user_id=user_id, + token=token, + password_hash=None + ) + yield self.distributor.fire("registered_user", user) + except Exception, e: + yield self.store.add_access_token_to_user(user_id, token) + # Ignore Registration errors + logger.exception(e) + defer.returnValue((user_id, token)) + + @defer.inlineCallbacks def register_email(self, threepidCreds): """ diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index f9e59dd91..17587170c 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -31,6 +31,7 @@ REQUIREMENTS = { "pillow": ["PIL"], "pydenticon": ["pydenticon"], "ujson": ["ujson"], + "pysaml2": ["saml2"], } CONDITIONAL_REQUIREMENTS = { "web_client": { diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index b2257b749..dc7615c6f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -20,14 +20,32 @@ from synapse.types import UserID from base import ClientV1RestServlet, client_path_pattern import simplejson as json +import cgi +import urllib + +import logging +from saml2 import BINDING_HTTP_REDIRECT +from saml2 import BINDING_HTTP_POST +from saml2.metadata import create_metadata_string +from saml2 import config +from saml2.client import Saml2Client +from saml2.httputil import ServiceError +from saml2.samlp import Extensions +from saml2.extension.pefim import SPCertEnc +from saml2.s_utils import rndstr class LoginRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login$") PASS_TYPE = "m.login.password" + SAML2_TYPE = "m.login.saml2" + + def __init__(self, hs): + 
super(LoginRestServlet, self).__init__(hs) + self.idp_redirect_url = hs.config.saml2_config['idp_redirect_url'] def on_GET(self, request): - return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]}) + return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}, {"type": LoginRestServlet.SAML2_TYPE}]}) def on_OPTIONS(self, request): return (200, {}) @@ -39,6 +57,14 @@ class LoginRestServlet(ClientV1RestServlet): if login_submission["type"] == LoginRestServlet.PASS_TYPE: result = yield self.do_password_login(login_submission) defer.returnValue(result) + elif login_submission["type"] == LoginRestServlet.SAML2_TYPE: + relay_state = "" + if "relay_state" in login_submission: + relay_state = "&RelayState="+urllib.quote(login_submission["relay_state"]) + result = { + "uri": "%s%s"%(self.idp_redirect_url, relay_state) + } + defer.returnValue((200, result)) else: raise SynapseError(400, "Bad login type.") except KeyError: @@ -93,6 +119,39 @@ class PasswordResetRestServlet(ClientV1RestServlet): "Missing keys. Requires 'email' and 'user_id'." ) +class SAML2RestServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/login/saml2") + + def __init__(self, hs): + super(SAML2RestServlet, self).__init__(hs) + self.sp_config = hs.config.saml2_config['config_path'] + + @defer.inlineCallbacks + def on_POST(self, request): + saml2_auth = None + try: + conf = config.SPConfig() + conf.load_file(self.sp_config) + SP = Saml2Client(conf) + saml2_auth = SP.parse_authn_request_response(request.args['SAMLResponse'][0], BINDING_HTTP_POST) + except Exception, e: # Not authenticated + logger = logging.getLogger(__name__) + logger.exception(e) + if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed: + username = saml2_auth.name_id.text + handler = self.handlers.registration_handler + (user_id, token) = yield handler.register_saml2(username) + # Forward to the RelayState callback along with ava + if 'RelayState' in request.args: + request.redirect(urllib.unquote(request.args['RelayState'][0])+'?status=authenticated&access_token='+token+'&user_id='+user_id+'&ava='+urllib.quote(json.dumps(saml2_auth.ava))) + request.finish() + defer.returnValue(None) + defer.returnValue((200, {"status":"authenticated", "user_id": user_id, "token": token, "ava":saml2_auth.ava})) + elif 'RelayState' in request.args: + request.redirect(urllib.unquote(request.args['RelayState'][0])+'?status=not_authenticated') + request.finish() + defer.returnValue(None) + defer.returnValue((200, {"status":"not_authenticated"})) def _parse_json(request): try: @@ -106,4 +165,5 @@ def _parse_json(request): def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + SAML2RestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server) From 77c5db5977c3fb61d9d2906c6692ee502d477e18 Mon Sep 17 00:00:00 2001 From: Muthu Subramanian Date: Wed, 8 Jul 2015 16:05:20 +0530 Subject: [PATCH 13/23] code beautify --- synapse/rest/client/v1/login.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index dc7615c6f..b4c74c4c2 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -45,7 +45,8 @@ class LoginRestServlet(ClientV1RestServlet): self.idp_redirect_url = hs.config.saml2_config['idp_redirect_url'] def on_GET(self, request): - return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}, {"type": LoginRestServlet.SAML2_TYPE}]}) + return (200, 
{"flows": [{"type": LoginRestServlet.PASS_TYPE}, + {"type": LoginRestServlet.SAML2_TYPE}]}) def on_OPTIONS(self, request): return (200, {}) @@ -60,9 +61,10 @@ class LoginRestServlet(ClientV1RestServlet): elif login_submission["type"] == LoginRestServlet.SAML2_TYPE: relay_state = "" if "relay_state" in login_submission: - relay_state = "&RelayState="+urllib.quote(login_submission["relay_state"]) + relay_state = "&RelayState="+urllib.quote( + login_submission["relay_state"]) result = { - "uri": "%s%s"%(self.idp_redirect_url, relay_state) + "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) else: @@ -119,6 +121,7 @@ class PasswordResetRestServlet(ClientV1RestServlet): "Missing keys. Requires 'email' and 'user_id'." ) + class SAML2RestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login/saml2") @@ -133,25 +136,35 @@ class SAML2RestServlet(ClientV1RestServlet): conf = config.SPConfig() conf.load_file(self.sp_config) SP = Saml2Client(conf) - saml2_auth = SP.parse_authn_request_response(request.args['SAMLResponse'][0], BINDING_HTTP_POST) - except Exception, e: # Not authenticated + saml2_auth = SP.parse_authn_request_response( + request.args['SAMLResponse'][0], BINDING_HTTP_POST) + except Exception, e: # Not authenticated logger = logging.getLogger(__name__) logger.exception(e) - if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed: + if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed: username = saml2_auth.name_id.text handler = self.handlers.registration_handler (user_id, token) = yield handler.register_saml2(username) # Forward to the RelayState callback along with ava if 'RelayState' in request.args: - request.redirect(urllib.unquote(request.args['RelayState'][0])+'?status=authenticated&access_token='+token+'&user_id='+user_id+'&ava='+urllib.quote(json.dumps(saml2_auth.ava))) + request.redirect(urllib.unquote( + request.args['RelayState'][0]) + + '?status=authenticated&access_token=' + + token + '&user_id=' + user_id + '&ava=' + + urllib.quote(json.dumps(saml2_auth.ava))) request.finish() defer.returnValue(None) - defer.returnValue((200, {"status":"authenticated", "user_id": user_id, "token": token, "ava":saml2_auth.ava})) + defer.returnValue((200, {"status": "authenticated", + "user_id": user_id, "token": token, + "ava": saml2_auth.ava})) elif 'RelayState' in request.args: - request.redirect(urllib.unquote(request.args['RelayState'][0])+'?status=not_authenticated') + request.redirect(urllib.unquote( + request.args['RelayState'][0]) + + '?status=not_authenticated') request.finish() defer.returnValue(None) - defer.returnValue((200, {"status":"not_authenticated"})) + defer.returnValue((200, {"status": "not_authenticated"})) + def _parse_json(request): try: From f53bae0c1948a8c0a229e0b20f237f7ff4b1d84c Mon Sep 17 00:00:00 2001 From: Muthu Subramanian Date: Wed, 8 Jul 2015 16:05:46 +0530 Subject: [PATCH 14/23] code beautify --- synapse/config/homeserver.py | 1 + synapse/config/saml2.py | 3 ++- synapse/handlers/register.py | 1 - 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 5c655c537..d77f04540 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -27,6 +27,7 @@ from .appservice import AppServiceConfig from .key import KeyConfig from .saml2 import SAML2Config + class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, 
VoipConfig, RegistrationConfig, MetricsConfig, diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 4f3a724e2..d18d076a8 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -15,6 +15,7 @@ from ._base import Config + class SAML2Config(Config): def read_config(self, config): self.saml2_config = config["saml2_config"] @@ -24,4 +25,4 @@ class SAML2Config(Config): saml2_config: config_path: "%s/sp_conf.py" idp_redirect_url: "http://%s/idp" - """%(config_dir_path, server_name) + """ % (config_dir_path, server_name) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 4c6c5e297..a1288b425 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -220,7 +220,6 @@ class RegistrationHandler(BaseHandler): # Ignore Registration errors logger.exception(e) defer.returnValue((user_id, token)) - @defer.inlineCallbacks def register_email(self, threepidCreds): From 8fb79eeea4b1c388771785024b79e84b4206fc24 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 8 Jul 2015 17:04:29 +0100 Subject: [PATCH 15/23] Only remove one time keys when new one time keys are added --- synapse/storage/end_to_end_keys.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 936a64669..b3cede37e 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -58,6 +58,11 @@ class EndToEndKeyStore(SQLBaseStore): def add_e2e_one_time_keys(self, user_id, device_id, time_now, valid_until, key_list): def _add_e2e_one_time_keys(txn): + sql = ( + "DELETE FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?" + ) + txn.execute(sql, (user_id, device_id, time_now)) for (algorithm, key_id, json_bytes) in key_list: self._simple_upsert_txn( txn, table="e2e_one_time_keys_json", @@ -83,17 +88,12 @@ class EndToEndKeyStore(SQLBaseStore): Dict mapping from algorithm to number of keys for that algorithm. """ def _count_e2e_one_time_keys(txn): - sql = ( - "DELETE FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?" - ) - txn.execute(sql, (user_id, device_id, time_now)) sql = ( "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ?" + " WHERE user_id = ? AND device_id = ? AND valid_until_ms >= ?" 
" GROUP BY algorithm" ) - txn.execute(sql, (user_id, device_id)) + txn.execute(sql, (user_id, device_id, time_now)) result = {} for algorithm, key_count in txn.fetchall(): result[algorithm] = key_count From d2caa5351aece72b274f78fe81348f715389d421 Mon Sep 17 00:00:00 2001 From: Muthu Subramanian Date: Thu, 9 Jul 2015 12:58:15 +0530 Subject: [PATCH 16/23] code beautify --- synapse/rest/client/v1/login.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index b4c74c4c2..b4894497b 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -20,19 +20,15 @@ from synapse.types import UserID from base import ClientV1RestServlet, client_path_pattern import simplejson as json -import cgi import urllib import logging -from saml2 import BINDING_HTTP_REDIRECT from saml2 import BINDING_HTTP_POST -from saml2.metadata import create_metadata_string from saml2 import config from saml2.client import Saml2Client -from saml2.httputil import ServiceError -from saml2.samlp import Extensions -from saml2.extension.pefim import SPCertEnc -from saml2.s_utils import rndstr + + +logger = logging.getLogger(__name__) class LoginRestServlet(ClientV1RestServlet): @@ -137,9 +133,8 @@ class SAML2RestServlet(ClientV1RestServlet): conf.load_file(self.sp_config) SP = Saml2Client(conf) saml2_auth = SP.parse_authn_request_response( - request.args['SAMLResponse'][0], BINDING_HTTP_POST) + request.args['SAMLResponse'][0], BINDING_HTTP_POST) except Exception, e: # Not authenticated - logger = logging.getLogger(__name__) logger.exception(e) if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed: username = saml2_auth.name_id.text From 8cd34dfe955841d7ff3306b84a686e7138aec526 Mon Sep 17 00:00:00 2001 From: Muthu Subramanian Date: Thu, 9 Jul 2015 13:34:47 +0530 Subject: [PATCH 17/23] Make SAML2 optional and add some references/comments --- synapse/config/saml2.py | 14 ++++++++++++++ synapse/rest/client/v1/login.py | 13 +++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index d18d076a8..be5176db5 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -16,6 +16,19 @@ from ._base import Config +# +# SAML2 Configuration +# Synapse uses pysaml2 libraries for providing SAML2 support +# +# config_path: Path to the sp_conf.py configuration file +# idp_redirect_url: Identity provider URL which will redirect +# the user back to /login/saml2 with proper info. 
+# +# sp_conf.py file is something like: +# https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example +# +# More information: https://pythonhosted.org/pysaml2/howto/config.html +# class SAML2Config(Config): def read_config(self, config): self.saml2_config = config["saml2_config"] @@ -23,6 +36,7 @@ class SAML2Config(Config): def default_config(self, config_dir_path, server_name): return """ saml2_config: + enabled: false config_path: "%s/sp_conf.py" idp_redirect_url: "http://%s/idp" """ % (config_dir_path, server_name) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index b4894497b..f64f5e990 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -39,10 +39,13 @@ class LoginRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) self.idp_redirect_url = hs.config.saml2_config['idp_redirect_url'] + self.saml2_enabled = hs.config.saml2_config['enabled'] def on_GET(self, request): - return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}, - {"type": LoginRestServlet.SAML2_TYPE}]}) + flows = [{"type": LoginRestServlet.PASS_TYPE}] + if self.saml2_enabled: + flows.append({"type": LoginRestServlet.SAML2_TYPE}) + return (200, {"flows": flows}) def on_OPTIONS(self, request): return (200, {}) @@ -54,7 +57,8 @@ class LoginRestServlet(ClientV1RestServlet): if login_submission["type"] == LoginRestServlet.PASS_TYPE: result = yield self.do_password_login(login_submission) defer.returnValue(result) - elif login_submission["type"] == LoginRestServlet.SAML2_TYPE: + elif self.saml2_enabled and (login_submission["type"] == + LoginRestServlet.SAML2_TYPE): relay_state = "" if "relay_state" in login_submission: relay_state = "&RelayState="+urllib.quote( @@ -173,5 +177,6 @@ def _parse_json(request): def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) - SAML2RestServlet(hs).register(http_server) + if hs.config.saml2_config['enabled']: + SAML2RestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server) From bf0d59ed30b63c6a355e7b3f2a74a26181fd6893 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 9 Jul 2015 14:04:03 +0100 Subject: [PATCH 18/23] Don't bother with a timeout for one time keys on the server. 
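
One-time keys now simply persist until a /keys/take request claims
them, so the upload API drops the "one_time_keys_valid_for" field and
the storage layer drops its valid_until_ms bookkeeping. As a sketch,
an upload payload after this change looks like (hypothetical key id
and value):

    POST /keys/upload/<device_id>
    {
        "one_time_keys": {
            "curve25519:AAAAAQ": "<key_base64>"
        }
    }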
---
 synapse/rest/client/v2_alpha/keys.py | 25 ++++++-------------
 synapse/storage/end_to_end_keys.py | 20 +++++----------
 .../schema/delta/21/end_to_end_keys.sql | 1 -
 3 files changed, 13 insertions(+), 33 deletions(-)

diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 3bb4ad64f..4b617c251 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -50,7 +50,6 @@ class KeyUploadServlet(RestServlet):
       "one_time_keys": {
         "<algorithm>:<key_id>": "<key_base64>"
       },
-      "one_time_keys_valid_for": <millisecond_duration>,
     }
     """
     PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
@@ -87,13 +86,10 @@ class KeyUploadServlet(RestServlet):
         )
 
         one_time_keys = body.get("one_time_keys", None)
-        one_time_keys_valid_for = body.get("one_time_keys_valid_for", None)
         if one_time_keys:
-            valid_until = int(one_time_keys_valid_for) + time_now
             logger.info(
-                "Adding %d one_time_keys for device %r for user %r at %d"
-                " valid_until %d",
-                len(one_time_keys), device_id, user_id, time_now, valid_until
+                "Adding %d one_time_keys for device %r for user %r at %d",
+                len(one_time_keys), device_id, user_id, time_now
            )
             key_list = []
             for key_id, key_json in one_time_keys.items():
@@ -103,23 +99,18 @@ class KeyUploadServlet(RestServlet):
                 ))
 
             yield self.store.add_e2e_one_time_keys(
-                user_id, device_id, time_now, valid_until, key_list
+                user_id, device_id, time_now, key_list
             )
 
-        result = yield self.store.count_e2e_one_time_keys(
-            user_id, device_id, time_now
-        )
+        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
         defer.returnValue((200, {"one_time_key_counts": result}))
 
     @defer.inlineCallbacks
     def on_GET(self, request, device_id):
         auth_user, client_info = yield self.auth.get_user_by_req(request)
         user_id = auth_user.to_string()
-        time_now = self.clock.time_msec()
 
-        result = yield self.store.count_e2e_one_time_keys(
-            user_id, device_id, time_now
-        )
+        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
         defer.returnValue((200, {"one_time_key_counts": result}))
 
 
@@ -249,9 +240,8 @@ class OneTimeKeyServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
-        time_now = self.clock.time_msec()
         results = yield self.store.take_e2e_one_time_keys(
-            [(user_id, device_id, algorithm)], time_now
+            [(user_id, device_id, algorithm)]
         )
         defer.returnValue(self.json_result(request, results))
 
@@ -266,8 +256,7 @@ class OneTimeKeyServlet(RestServlet):
         for user_id, device_keys in body.get("one_time_keys", {}).items():
             for device_id, algorithm in device_keys.items():
                 query.append((user_id, device_id, algorithm))
-        time_now = self.clock.time_msec()
-        results = yield self.store.take_e2e_one_time_keys(query, time_now)
+        results = yield self.store.take_e2e_one_time_keys(query)
         defer.returnValue(self.json_result(request, results))
 
     def json_result(self, request, results):
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index b3cede37e..99dc864e4 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -55,14 +55,8 @@ class EndToEndKeyStore(SQLBaseStore):
             return result
         return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
 
-    def add_e2e_one_time_keys(self, user_id, device_id, time_now, valid_until,
-                              key_list):
+    def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
         def _add_e2e_one_time_keys(txn):
-            sql = (
-                "DELETE FROM e2e_one_time_keys_json"
-                " WHERE user_id = ? AND device_id = ?
AND valid_until_ms < ?" - ) - txn.execute(sql, (user_id, device_id, time_now)) for (algorithm, key_id, json_bytes) in key_list: self._simple_upsert_txn( txn, table="e2e_one_time_keys_json", @@ -74,7 +68,6 @@ class EndToEndKeyStore(SQLBaseStore): }, values={ "ts_added_ms": time_now, - "valid_until_ms": valid_until, "key_json": json_bytes, } ) @@ -82,7 +75,7 @@ class EndToEndKeyStore(SQLBaseStore): "add_e2e_one_time_keys", _add_e2e_one_time_keys ) - def count_e2e_one_time_keys(self, user_id, device_id, time_now): + def count_e2e_one_time_keys(self, user_id, device_id): """ Count the number of one time keys the server has for a device Returns: Dict mapping from algorithm to number of keys for that algorithm. @@ -90,10 +83,10 @@ class EndToEndKeyStore(SQLBaseStore): def _count_e2e_one_time_keys(txn): sql = ( "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND valid_until_ms >= ?" + " WHERE user_id = ? AND device_id = ?" " GROUP BY algorithm" ) - txn.execute(sql, (user_id, device_id, time_now)) + txn.execute(sql, (user_id, device_id)) result = {} for algorithm, key_count in txn.fetchall(): result[algorithm] = key_count @@ -102,13 +95,12 @@ class EndToEndKeyStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) - def take_e2e_one_time_keys(self, query_list, time_now): + def take_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" def _take_e2e_one_time_keys(txn): sql = ( "SELECT key_id, key_json FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " AND valid_until_ms > ?" " LIMIT 1" ) result = {} @@ -116,7 +108,7 @@ class EndToEndKeyStore(SQLBaseStore): for user_id, device_id, algorithm in query_list: user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) - txn.execute(sql, (user_id, device_id, algorithm, time_now)) + txn.execute(sql, (user_id, device_id, algorithm)) for key_id, key_json in txn.fetchall(): device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/schema/delta/21/end_to_end_keys.sql index 107d2e67c..95e27eb7e 100644 --- a/synapse/storage/schema/delta/21/end_to_end_keys.sql +++ b/synapse/storage/schema/delta/21/end_to_end_keys.sql @@ -29,7 +29,6 @@ CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json ( algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for. key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. ts_added_ms BIGINT NOT NULL, -- When this key was uploaded. - valid_until_ms BIGINT NOT NULL, -- When this key is valid until. key_json TEXT NOT NULL, -- The key as a JSON blob. 
CONSTRAINT uniqueness UNIQUE (user_id, device_id, algorithm, key_id) ); From b5f0d73ea3f6611c0980a03a0dfe57058071013e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 9 Jul 2015 17:09:26 +0100 Subject: [PATCH 19/23] Add comment --- synapse/handlers/federation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index cd3867ed9..d7f197f24 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -242,6 +242,10 @@ class FederationHandler(BaseHandler): if history: visibility = history.content.get("history_visibility", "shared") if visibility in ["invited", "joined"]: + # We now loop through all state events looking for + # membership states for the requesting server to determine + # if the server is either in the room or has been invited + # into the room. for ev in state.values(): if ev.type != EventTypes.Member: continue From a887efa07aa9757538cb36d0dfa2b79925eba4aa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 10 Jul 2015 10:37:24 +0100 Subject: [PATCH 20/23] Add Muthu Subramanian to AUTHORS --- AUTHORS.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/AUTHORS.rst b/AUTHORS.rst index d7224ff5d..54ced6700 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -42,3 +42,6 @@ Ivan Shapovalov Eric Myhre * Fix bug where ``media_store_path`` config option was ignored by v0 content repository API. + +Muthu Subramanian + * Add SAML2 support for registration and logins. From f3049d0b81ad626de7ca80330608b374e0ec8b5b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 10 Jul 2015 10:50:03 +0100 Subject: [PATCH 21/23] Small tweaks to SAML2 configuration. - Add saml2 config docs to default config. - Use existence of saml2 config to indicate if saml2 should be enabled. --- synapse/config/saml2.py | 48 ++++++++++++++++++++------------- synapse/rest/client/v1/login.py | 8 +++--- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index be5176db5..153203687 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -16,27 +16,39 @@ from ._base import Config -# -# SAML2 Configuration -# Synapse uses pysaml2 libraries for providing SAML2 support -# -# config_path: Path to the sp_conf.py configuration file -# idp_redirect_url: Identity provider URL which will redirect -# the user back to /login/saml2 with proper info. -# -# sp_conf.py file is something like: -# https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example -# -# More information: https://pythonhosted.org/pysaml2/howto/config.html -# class SAML2Config(Config): + """SAML2 Configuration + Synapse uses pysaml2 libraries for providing SAML2 support + + config_path: Path to the sp_conf.py configuration file + idp_redirect_url: Identity provider URL which will redirect + the user back to /login/saml2 with proper info. 
+
+    sp_conf.py file is something like:
+    https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
+
+    More information: https://pythonhosted.org/pysaml2/howto/config.html
+    """
+
     def read_config(self, config):
-        self.saml2_config = config["saml2_config"]
+        saml2_config = config.get("saml2_config", None)
+        if saml2_config:
+            self.saml2_enabled = True
+            self.saml2_config_path = saml2_config["config_path"]
+            self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
+        else:
+            self.saml2_enabled = False
+            self.saml2_config_path = None
+            self.saml2_idp_redirect_url = None
 
     def default_config(self, config_dir_path, server_name):
         return """
-        saml2_config:
-            enabled: false
-            config_path: "%s/sp_conf.py"
-            idp_redirect_url: "http://%s/idp"
+        # Enable SAML2 for registration and login. Uses pysaml2
+        # config_path:      Path to the sp_conf.py configuration file
+        # idp_redirect_url: Identity provider URL which will redirect
+        #                   the user back to /login/saml2 with proper info.
+        # See pysaml2 docs for format of config.
+        #saml2_config:
+        #   config_path: "%s/sp_conf.py"
+        #   idp_redirect_url: "http://%s/idp"
         """ % (config_dir_path, server_name)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index f64f5e990..998d4d44c 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -38,8 +38,8 @@ class LoginRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(LoginRestServlet, self).__init__(hs)
-        self.idp_redirect_url = hs.config.saml2_config['idp_redirect_url']
-        self.saml2_enabled = hs.config.saml2_config['enabled']
+        self.idp_redirect_url = hs.config.saml2_idp_redirect_url
+        self.saml2_enabled = hs.config.saml2_enabled
 
     def on_GET(self, request):
         flows = [{"type": LoginRestServlet.PASS_TYPE}]
@@ -127,7 +127,7 @@ class SAML2RestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(SAML2RestServlet, self).__init__(hs)
-        self.sp_config = hs.config.saml2_config['config_path']
+        self.sp_config = hs.config.saml2_config_path
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -177,6 +177,6 @@ def _parse_json(request):
 
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
-    if hs.config.saml2_config['enabled']:
+    if hs.config.saml2_enabled:
         SAML2RestServlet(hs).register(http_server)
     # TODO PasswordResetRestServlet(hs).register(http_server)

From a01097d60b9c711d71cd5d1c63cb4fb5b95a8a63 Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Fri, 10 Jul 2015 13:26:18 +0100
Subject: [PATCH 22/23] Assume that each device for a user has only one of each type of key

---
 synapse/rest/client/v2_alpha/keys.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 4b617c251..f03126775 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -41,11 +41,11 @@ class KeyUploadServlet(RestServlet):
           "m.olm.curve25519-aes-sha256",
         ],
         "keys": {
-          "<algorithm>:<key_id>": "<key_base64>",
+          "<algorithm>:<device_id>": "<key_base64>",
         },
         "signatures": {
-          "<user_id>/<device_id>": {
-            "<algorithm>:<key_id>": "<signature_base64>"
+          "<user_id>": {
+            "<algorithm>:<device_id>": "<signature_base64>"
       } } },
       "one_time_keys": {
         "<algorithm>:<key_id>": "<key_base64>"

From 0d7f0febf45575666d182cb2ad0cf451b11fbf07 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 10 Jul 2015 13:43:03 +0100
Subject: [PATCH 23/23] Uniquely name unique constraint

---
 synapse/storage/schema/delta/21/end_to_end_keys.sql | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/schema/delta/21/end_to_end_keys.sql
index
95e27eb7e..8b4a380d1 100644 --- a/synapse/storage/schema/delta/21/end_to_end_keys.sql +++ b/synapse/storage/schema/delta/21/end_to_end_keys.sql @@ -19,7 +19,7 @@ CREATE TABLE IF NOT EXISTS e2e_device_keys_json ( device_id TEXT NOT NULL, -- Which of the user's devices these keys are for. ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded. key_json TEXT NOT NULL, -- The keys for the device as a JSON blob. - CONSTRAINT uniqueness UNIQUE (user_id, device_id) + CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) ); @@ -30,5 +30,5 @@ CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json ( key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. ts_added_ms BIGINT NOT NULL, -- When this key was uploaded. key_json TEXT NOT NULL, -- The key as a JSON blob. - CONSTRAINT uniqueness UNIQUE (user_id, device_id, algorithm, key_id) + CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) );
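
A note on the final patch: a likely motivation is that in Postgres a
UNIQUE constraint is backed by an index of the same name, and index
names must be unique within a schema, so two tables cannot both name
their constraint "uniqueness". A hypothetical ad-hoc query to spot
such collisions on an existing database:

    SELECT conname, COUNT(*)
    FROM pg_constraint
    GROUP BY conname
    HAVING COUNT(*) > 1;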