mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 12:43:50 +01:00
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/receipts
This commit is contained in:
commit
5989637f37
23 changed files with 1330 additions and 222 deletions
|
@ -42,3 +42,6 @@ Ivan Shapovalov <intelfx100 at gmail.com>
|
||||||
Eric Myhre <hash at exultant.us>
|
Eric Myhre <hash at exultant.us>
|
||||||
* Fix bug where ``media_store_path`` config option was ignored by v0 content
|
* Fix bug where ``media_store_path`` config option was ignored by v0 content
|
||||||
repository API.
|
repository API.
|
||||||
|
|
||||||
|
Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
|
||||||
|
* Add SAML2 support for registration and logins.
|
||||||
|
|
|
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
AuthEventTypes = (
|
AuthEventTypes = (
|
||||||
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
|
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
|
||||||
EventTypes.JoinRules,
|
EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -575,6 +575,7 @@ class Auth(object):
|
||||||
levels_to_check = [
|
levels_to_check = [
|
||||||
("users_default", []),
|
("users_default", []),
|
||||||
("events_default", []),
|
("events_default", []),
|
||||||
|
("state_default", []),
|
||||||
("ban", []),
|
("ban", []),
|
||||||
("redact", []),
|
("redact", []),
|
||||||
("kick", []),
|
("kick", []),
|
||||||
|
|
|
@ -75,6 +75,8 @@ class EventTypes(object):
|
||||||
Redaction = "m.room.redaction"
|
Redaction = "m.room.redaction"
|
||||||
Feedback = "m.room.message.feedback"
|
Feedback = "m.room.message.feedback"
|
||||||
|
|
||||||
|
RoomHistoryVisibility = "m.room.history_visibility"
|
||||||
|
|
||||||
# These are used for validation
|
# These are used for validation
|
||||||
Message = "m.room.message"
|
Message = "m.room.message"
|
||||||
Topic = "m.room.topic"
|
Topic = "m.room.topic"
|
||||||
|
|
|
@ -25,12 +25,13 @@ from .registration import RegistrationConfig
|
||||||
from .metrics import MetricsConfig
|
from .metrics import MetricsConfig
|
||||||
from .appservice import AppServiceConfig
|
from .appservice import AppServiceConfig
|
||||||
from .key import KeyConfig
|
from .key import KeyConfig
|
||||||
|
from .saml2 import SAML2Config
|
||||||
|
|
||||||
|
|
||||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||||
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
|
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
|
||||||
VoipConfig, RegistrationConfig,
|
VoipConfig, RegistrationConfig, MetricsConfig,
|
||||||
MetricsConfig, AppServiceConfig, KeyConfig,):
|
AppServiceConfig, KeyConfig, SAML2Config, ):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
54
synapse/config/saml2.py
Normal file
54
synapse/config/saml2.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 Ericsson
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import Config
|
||||||
|
|
||||||
|
|
||||||
|
class SAML2Config(Config):
|
||||||
|
"""SAML2 Configuration
|
||||||
|
Synapse uses pysaml2 libraries for providing SAML2 support
|
||||||
|
|
||||||
|
config_path: Path to the sp_conf.py configuration file
|
||||||
|
idp_redirect_url: Identity provider URL which will redirect
|
||||||
|
the user back to /login/saml2 with proper info.
|
||||||
|
|
||||||
|
sp_conf.py file is something like:
|
||||||
|
https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
|
||||||
|
|
||||||
|
More information: https://pythonhosted.org/pysaml2/howto/config.html
|
||||||
|
"""
|
||||||
|
|
||||||
|
def read_config(self, config):
|
||||||
|
saml2_config = config.get("saml2_config", None)
|
||||||
|
if saml2_config:
|
||||||
|
self.saml2_enabled = True
|
||||||
|
self.saml2_config_path = saml2_config["config_path"]
|
||||||
|
self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
|
||||||
|
else:
|
||||||
|
self.saml2_enabled = False
|
||||||
|
self.saml2_config_path = None
|
||||||
|
self.saml2_idp_redirect_url = None
|
||||||
|
|
||||||
|
def default_config(self, config_dir_path, server_name):
|
||||||
|
return """
|
||||||
|
# Enable SAML2 for registration and login. Uses pysaml2
|
||||||
|
# config_path: Path to the sp_conf.py configuration file
|
||||||
|
# idp_redirect_url: Identity provider URL which will redirect
|
||||||
|
# the user back to /login/saml2 with proper info.
|
||||||
|
# See pysaml2 docs for format of config.
|
||||||
|
#saml2_config:
|
||||||
|
# config_path: "%s/sp_conf.py"
|
||||||
|
# idp_redirect_url: "http://%s/idp"
|
||||||
|
""" % (config_dir_path, server_name)
|
|
@ -25,11 +25,13 @@ from syutil.base64util import decode_base64, encode_base64
|
||||||
from synapse.api.errors import SynapseError, Codes
|
from synapse.api.errors import SynapseError, Codes
|
||||||
|
|
||||||
from synapse.util.retryutils import get_retry_limiter
|
from synapse.util.retryutils import get_retry_limiter
|
||||||
|
from synapse.util import unwrapFirstError
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
|
|
||||||
from OpenSSL import crypto
|
from OpenSSL import crypto
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
import urllib
|
import urllib
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
@ -38,6 +40,9 @@ import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
|
||||||
|
|
||||||
|
|
||||||
class Keyring(object):
|
class Keyring(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
@ -49,141 +54,325 @@ class Keyring(object):
|
||||||
|
|
||||||
self.key_downloads = {}
|
self.key_downloads = {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def verify_json_for_server(self, server_name, json_object):
|
def verify_json_for_server(self, server_name, json_object):
|
||||||
logger.debug("Verifying for %s", server_name)
|
return self.verify_json_objects_for_server(
|
||||||
key_ids = signature_ids(json_object, server_name)
|
[(server_name, json_object)]
|
||||||
if not key_ids:
|
)[0]
|
||||||
raise SynapseError(
|
|
||||||
400,
|
|
||||||
"Not signed with a supported algorithm",
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
verify_key = yield self.get_server_verify_key(server_name, key_ids)
|
|
||||||
except IOError as e:
|
|
||||||
logger.warn(
|
|
||||||
"Got IOError when downloading keys for %s: %s %s",
|
|
||||||
server_name, type(e).__name__, str(e.message),
|
|
||||||
)
|
|
||||||
raise SynapseError(
|
|
||||||
502,
|
|
||||||
"Error downloading keys for %s" % (server_name,),
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warn(
|
|
||||||
"Got Exception when downloading keys for %s: %s %s",
|
|
||||||
server_name, type(e).__name__, str(e.message),
|
|
||||||
)
|
|
||||||
raise SynapseError(
|
|
||||||
401,
|
|
||||||
"No key for %s with id %s" % (server_name, key_ids),
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
def verify_json_objects_for_server(self, server_and_json):
|
||||||
verify_signed_json(json_object, server_name, verify_key)
|
"""Bulk verfies signatures of json objects, bulk fetching keys as
|
||||||
except:
|
necessary.
|
||||||
raise SynapseError(
|
|
||||||
401,
|
|
||||||
"Invalid signature for server %s with key %s:%s" % (
|
|
||||||
server_name, verify_key.alg, verify_key.version
|
|
||||||
),
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_server_verify_key(self, server_name, key_ids):
|
|
||||||
"""Finds a verification key for the server with one of the key ids.
|
|
||||||
Trys to fetch the key from a trusted perspective server first.
|
|
||||||
Args:
|
Args:
|
||||||
server_name(str): The name of the server to fetch a key for.
|
server_and_json (list): List of pairs of (server_name, json_object)
|
||||||
keys_ids (list of str): The key_ids to check for.
|
|
||||||
|
Returns:
|
||||||
|
list of deferreds indicating success or failure to verify each
|
||||||
|
json object's signature for the given server_name.
|
||||||
"""
|
"""
|
||||||
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
|
group_id_to_json = {}
|
||||||
|
group_id_to_group = {}
|
||||||
|
group_ids = []
|
||||||
|
|
||||||
if cached:
|
next_group_id = 0
|
||||||
defer.returnValue(cached[0])
|
deferreds = {}
|
||||||
return
|
|
||||||
|
|
||||||
download = self.key_downloads.get(server_name)
|
for server_name, json_object in server_and_json:
|
||||||
|
logger.debug("Verifying for %s", server_name)
|
||||||
|
group_id = next_group_id
|
||||||
|
next_group_id += 1
|
||||||
|
group_ids.append(group_id)
|
||||||
|
|
||||||
if download is None:
|
key_ids = signature_ids(json_object, server_name)
|
||||||
download = self._get_server_verify_key_impl(server_name, key_ids)
|
if not key_ids:
|
||||||
download = ObservableDeferred(
|
deferreds[group_id] = defer.fail(SynapseError(
|
||||||
download,
|
400,
|
||||||
consumeErrors=True
|
"Not signed with a supported algorithm",
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
deferreds[group_id] = defer.Deferred()
|
||||||
|
|
||||||
|
group = KeyGroup(server_name, group_id, key_ids)
|
||||||
|
|
||||||
|
group_id_to_group[group_id] = group
|
||||||
|
group_id_to_json[group_id] = json_object
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def handle_key_deferred(group, deferred):
|
||||||
|
server_name = group.server_name
|
||||||
|
try:
|
||||||
|
_, _, key_id, verify_key = yield deferred
|
||||||
|
except IOError as e:
|
||||||
|
logger.warn(
|
||||||
|
"Got IOError when downloading keys for %s: %s %s",
|
||||||
|
server_name, type(e).__name__, str(e.message),
|
||||||
|
)
|
||||||
|
raise SynapseError(
|
||||||
|
502,
|
||||||
|
"Error downloading keys for %s" % (server_name,),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
"Got Exception when downloading keys for %s: %s %s",
|
||||||
|
server_name, type(e).__name__, str(e.message),
|
||||||
|
)
|
||||||
|
raise SynapseError(
|
||||||
|
401,
|
||||||
|
"No key for %s with id %s" % (server_name, key_ids),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
|
||||||
|
json_object = group_id_to_json[group.group_id]
|
||||||
|
|
||||||
|
try:
|
||||||
|
verify_signed_json(json_object, server_name, verify_key)
|
||||||
|
except:
|
||||||
|
raise SynapseError(
|
||||||
|
401,
|
||||||
|
"Invalid signature for server %s with key %s:%s" % (
|
||||||
|
server_name, verify_key.alg, verify_key.version
|
||||||
|
),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
|
||||||
|
server_to_deferred = {
|
||||||
|
server_name: defer.Deferred()
|
||||||
|
for server_name, _ in server_and_json
|
||||||
|
}
|
||||||
|
|
||||||
|
# We want to wait for any previous lookups to complete before
|
||||||
|
# proceeding.
|
||||||
|
wait_on_deferred = self.wait_for_previous_lookups(
|
||||||
|
[server_name for server_name, _ in server_and_json],
|
||||||
|
server_to_deferred,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Actually start fetching keys.
|
||||||
|
wait_on_deferred.addBoth(
|
||||||
|
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
||||||
|
)
|
||||||
|
|
||||||
|
# When we've finished fetching all the keys for a given server_name,
|
||||||
|
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||||
|
# any lookups waiting will proceed.
|
||||||
|
server_to_gids = {}
|
||||||
|
|
||||||
|
def remove_deferreds(res, server_name, group_id):
|
||||||
|
server_to_gids[server_name].discard(group_id)
|
||||||
|
if not server_to_gids[server_name]:
|
||||||
|
server_to_deferred.pop(server_name).callback(None)
|
||||||
|
return res
|
||||||
|
|
||||||
|
for g_id, deferred in deferreds.items():
|
||||||
|
server_name = group_id_to_group[g_id].server_name
|
||||||
|
server_to_gids.setdefault(server_name, set()).add(g_id)
|
||||||
|
deferred.addBoth(remove_deferreds, server_name, g_id)
|
||||||
|
|
||||||
|
# Pass those keys to handle_key_deferred so that the json object
|
||||||
|
# signatures can be verified
|
||||||
|
return [
|
||||||
|
handle_key_deferred(
|
||||||
|
group_id_to_group[g_id],
|
||||||
|
deferreds[g_id],
|
||||||
)
|
)
|
||||||
self.key_downloads[server_name] = download
|
for g_id in group_ids
|
||||||
|
]
|
||||||
@download.addBoth
|
|
||||||
def callback(ret):
|
|
||||||
del self.key_downloads[server_name]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
r = yield download.observe()
|
|
||||||
defer.returnValue(r)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_server_verify_key_impl(self, server_name, key_ids):
|
def wait_for_previous_lookups(self, server_names, server_to_deferred):
|
||||||
keys = None
|
"""Waits for any previous key lookups for the given servers to finish.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_names (list): list of server_names we want to lookup
|
||||||
|
server_to_deferred (dict): server_name to deferred which gets
|
||||||
|
resolved once we've finished looking up keys for that server
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
wait_on = [
|
||||||
|
self.key_downloads[server_name]
|
||||||
|
for server_name in server_names
|
||||||
|
if server_name in self.key_downloads
|
||||||
|
]
|
||||||
|
if wait_on:
|
||||||
|
yield defer.DeferredList(wait_on)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
for server_name, deferred in server_to_deferred:
|
||||||
|
self.key_downloads[server_name] = ObservableDeferred(deferred)
|
||||||
|
|
||||||
|
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
|
||||||
|
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||||
|
each group.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# These are functions that produce keys given a list of key ids
|
||||||
|
key_fetch_fns = (
|
||||||
|
self.get_keys_from_store, # First try the local store
|
||||||
|
self.get_keys_from_perspectives, # Then try via perspectives
|
||||||
|
self.get_keys_from_server, # Then try directly
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_iterations():
|
||||||
|
merged_results = {}
|
||||||
|
|
||||||
|
missing_keys = {
|
||||||
|
group.server_name: key_id
|
||||||
|
for group in group_id_to_group.values()
|
||||||
|
for key_id in group.key_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
for fn in key_fetch_fns:
|
||||||
|
results = yield fn(missing_keys.items())
|
||||||
|
merged_results.update(results)
|
||||||
|
|
||||||
|
# We now need to figure out which groups we have keys for
|
||||||
|
# and which we don't
|
||||||
|
missing_groups = {}
|
||||||
|
for group in group_id_to_group.values():
|
||||||
|
for key_id in group.key_ids:
|
||||||
|
if key_id in merged_results[group.server_name]:
|
||||||
|
group_id_to_deferred[group.group_id].callback((
|
||||||
|
group.group_id,
|
||||||
|
group.server_name,
|
||||||
|
key_id,
|
||||||
|
merged_results[group.server_name][key_id],
|
||||||
|
))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
missing_groups.setdefault(
|
||||||
|
group.server_name, []
|
||||||
|
).append(group)
|
||||||
|
|
||||||
|
if not missing_groups:
|
||||||
|
break
|
||||||
|
|
||||||
|
missing_keys = {
|
||||||
|
server_name: set(
|
||||||
|
key_id for group in groups for key_id in group.key_ids
|
||||||
|
)
|
||||||
|
for server_name, groups in missing_groups.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
for group in missing_groups.values():
|
||||||
|
group_id_to_deferred[group.group_id].errback(SynapseError(
|
||||||
|
401,
|
||||||
|
"No key for %s with id %s" % (
|
||||||
|
group.server_name, group.key_ids,
|
||||||
|
),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
))
|
||||||
|
|
||||||
|
def on_err(err):
|
||||||
|
for deferred in group_id_to_deferred.values():
|
||||||
|
if not deferred.called:
|
||||||
|
deferred.errback(err)
|
||||||
|
|
||||||
|
do_iterations().addErrback(on_err)
|
||||||
|
|
||||||
|
return group_id_to_deferred
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_keys_from_store(self, server_name_and_key_ids):
|
||||||
|
res = yield defer.gatherResults(
|
||||||
|
[
|
||||||
|
self.store.get_server_verify_keys(server_name, key_ids)
|
||||||
|
for server_name, key_ids in server_name_and_key_ids
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
defer.returnValue(dict(zip(
|
||||||
|
[server_name for server_name, _ in server_name_and_key_ids],
|
||||||
|
res
|
||||||
|
)))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_keys_from_perspectives(self, server_name_and_key_ids):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_key(perspective_name, perspective_keys):
|
def get_key(perspective_name, perspective_keys):
|
||||||
try:
|
try:
|
||||||
result = yield self.get_server_verify_key_v2_indirect(
|
result = yield self.get_server_verify_key_v2_indirect(
|
||||||
server_name, key_ids, perspective_name, perspective_keys
|
server_name_and_key_ids, perspective_name, perspective_keys
|
||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.info(
|
logger.exception(
|
||||||
"Unable to getting key %r for %r from %r: %s %s",
|
"Unable to get key from %r: %s %s",
|
||||||
key_ids, server_name, perspective_name,
|
perspective_name,
|
||||||
type(e).__name__, str(e.message),
|
type(e).__name__, str(e.message),
|
||||||
)
|
)
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
perspective_results = yield defer.gatherResults([
|
results = yield defer.gatherResults(
|
||||||
get_key(p_name, p_keys)
|
[
|
||||||
for p_name, p_keys in self.perspective_servers.items()
|
get_key(p_name, p_keys)
|
||||||
])
|
for p_name, p_keys in self.perspective_servers.items()
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
for results in perspective_results:
|
union_of_keys = {}
|
||||||
if results is not None:
|
for result in results:
|
||||||
keys = results
|
for server_name, keys in result.items():
|
||||||
|
union_of_keys.setdefault(server_name, {}).update(keys)
|
||||||
|
|
||||||
limiter = yield get_retry_limiter(
|
defer.returnValue(union_of_keys)
|
||||||
server_name,
|
|
||||||
self.clock,
|
|
||||||
self.store,
|
|
||||||
)
|
|
||||||
|
|
||||||
with limiter:
|
@defer.inlineCallbacks
|
||||||
if not keys:
|
def get_keys_from_server(self, server_name_and_key_ids):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_key(server_name, key_ids):
|
||||||
|
limiter = yield get_retry_limiter(
|
||||||
|
server_name,
|
||||||
|
self.clock,
|
||||||
|
self.store,
|
||||||
|
)
|
||||||
|
with limiter:
|
||||||
|
keys = None
|
||||||
try:
|
try:
|
||||||
keys = yield self.get_server_verify_key_v2_direct(
|
keys = yield self.get_server_verify_key_v2_direct(
|
||||||
server_name, key_ids
|
server_name, key_ids
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.info(
|
logger.info(
|
||||||
"Unable to getting key %r for %r directly: %s %s",
|
"Unable to getting key %r for %r directly: %s %s",
|
||||||
key_ids, server_name,
|
key_ids, server_name,
|
||||||
type(e).__name__, str(e.message),
|
type(e).__name__, str(e.message),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not keys:
|
if not keys:
|
||||||
keys = yield self.get_server_verify_key_v1_direct(
|
keys = yield self.get_server_verify_key_v1_direct(
|
||||||
server_name, key_ids
|
server_name, key_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
for key_id in key_ids:
|
keys = {server_name: keys}
|
||||||
if key_id in keys:
|
|
||||||
defer.returnValue(keys[key_id])
|
defer.returnValue(keys)
|
||||||
return
|
|
||||||
raise ValueError("No verification key found for given key ids")
|
results = yield defer.gatherResults(
|
||||||
|
[
|
||||||
|
get_key(server_name, key_ids)
|
||||||
|
for server_name, key_ids in server_name_and_key_ids
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
merged = {}
|
||||||
|
for result in results:
|
||||||
|
merged.update(result)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
server_name: keys
|
||||||
|
for server_name, keys in merged.items()
|
||||||
|
if keys
|
||||||
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_server_verify_key_v2_indirect(self, server_name, key_ids,
|
def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
|
||||||
perspective_name,
|
perspective_name,
|
||||||
perspective_keys):
|
perspective_keys):
|
||||||
limiter = yield get_retry_limiter(
|
limiter = yield get_retry_limiter(
|
||||||
|
@ -204,6 +393,7 @@ class Keyring(object):
|
||||||
u"minimum_valid_until_ts": 0
|
u"minimum_valid_until_ts": 0
|
||||||
} for key_id in key_ids
|
} for key_id in key_ids
|
||||||
}
|
}
|
||||||
|
for server_name, key_ids in server_names_and_key_ids
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -243,23 +433,29 @@ class Keyring(object):
|
||||||
" server %r" % (perspective_name,)
|
" server %r" % (perspective_name,)
|
||||||
)
|
)
|
||||||
|
|
||||||
response_keys = yield self.process_v2_response(
|
processed_response = yield self.process_v2_response(
|
||||||
server_name, perspective_name, response
|
perspective_name, response
|
||||||
)
|
)
|
||||||
|
|
||||||
keys.update(response_keys)
|
for server_name, response_keys in processed_response.items():
|
||||||
|
keys.setdefault(server_name, {}).update(response_keys)
|
||||||
|
|
||||||
yield self.store_keys(
|
yield defer.gatherResults(
|
||||||
server_name=server_name,
|
[
|
||||||
from_server=perspective_name,
|
self.store_keys(
|
||||||
verify_keys=keys,
|
server_name=server_name,
|
||||||
)
|
from_server=perspective_name,
|
||||||
|
verify_keys=response_keys,
|
||||||
|
)
|
||||||
|
for server_name, response_keys in keys.items()
|
||||||
|
],
|
||||||
|
consumeErrors=True
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(keys)
|
defer.returnValue(keys)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
||||||
|
|
||||||
keys = {}
|
keys = {}
|
||||||
|
|
||||||
for requested_key_id in key_ids:
|
for requested_key_id in key_ids:
|
||||||
|
@ -295,25 +491,30 @@ class Keyring(object):
|
||||||
raise ValueError("TLS certificate not allowed by fingerprints")
|
raise ValueError("TLS certificate not allowed by fingerprints")
|
||||||
|
|
||||||
response_keys = yield self.process_v2_response(
|
response_keys = yield self.process_v2_response(
|
||||||
server_name=server_name,
|
|
||||||
from_server=server_name,
|
from_server=server_name,
|
||||||
requested_id=requested_key_id,
|
requested_ids=[requested_key_id],
|
||||||
response_json=response,
|
response_json=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
keys.update(response_keys)
|
keys.update(response_keys)
|
||||||
|
|
||||||
yield self.store_keys(
|
yield defer.gatherResults(
|
||||||
server_name=server_name,
|
[
|
||||||
from_server=server_name,
|
self.store_keys(
|
||||||
verify_keys=keys,
|
server_name=key_server_name,
|
||||||
)
|
from_server=server_name,
|
||||||
|
verify_keys=verify_keys,
|
||||||
|
)
|
||||||
|
for key_server_name, verify_keys in keys.items()
|
||||||
|
],
|
||||||
|
consumeErrors=True
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(keys)
|
defer.returnValue(keys)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def process_v2_response(self, server_name, from_server, response_json,
|
def process_v2_response(self, from_server, response_json,
|
||||||
requested_id=None):
|
requested_ids=[]):
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
response_keys = {}
|
response_keys = {}
|
||||||
verify_keys = {}
|
verify_keys = {}
|
||||||
|
@ -335,6 +536,8 @@ class Keyring(object):
|
||||||
verify_key.time_added = time_now_ms
|
verify_key.time_added = time_now_ms
|
||||||
old_verify_keys[key_id] = verify_key
|
old_verify_keys[key_id] = verify_key
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
server_name = response_json["server_name"]
|
||||||
for key_id in response_json["signatures"].get(server_name, {}):
|
for key_id in response_json["signatures"].get(server_name, {}):
|
||||||
if key_id not in response_json["verify_keys"]:
|
if key_id not in response_json["verify_keys"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -357,28 +560,31 @@ class Keyring(object):
|
||||||
signed_key_json_bytes = encode_canonical_json(signed_key_json)
|
signed_key_json_bytes = encode_canonical_json(signed_key_json)
|
||||||
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
|
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
|
||||||
|
|
||||||
updated_key_ids = set()
|
updated_key_ids = set(requested_ids)
|
||||||
if requested_id is not None:
|
|
||||||
updated_key_ids.add(requested_id)
|
|
||||||
updated_key_ids.update(verify_keys)
|
updated_key_ids.update(verify_keys)
|
||||||
updated_key_ids.update(old_verify_keys)
|
updated_key_ids.update(old_verify_keys)
|
||||||
|
|
||||||
response_keys.update(verify_keys)
|
response_keys.update(verify_keys)
|
||||||
response_keys.update(old_verify_keys)
|
response_keys.update(old_verify_keys)
|
||||||
|
|
||||||
for key_id in updated_key_ids:
|
yield defer.gatherResults(
|
||||||
yield self.store.store_server_keys_json(
|
[
|
||||||
server_name=server_name,
|
self.store.store_server_keys_json(
|
||||||
key_id=key_id,
|
server_name=server_name,
|
||||||
from_server=server_name,
|
key_id=key_id,
|
||||||
ts_now_ms=time_now_ms,
|
from_server=server_name,
|
||||||
ts_expires_ms=ts_valid_until_ms,
|
ts_now_ms=time_now_ms,
|
||||||
key_json_bytes=signed_key_json_bytes,
|
ts_expires_ms=ts_valid_until_ms,
|
||||||
)
|
key_json_bytes=signed_key_json_bytes,
|
||||||
|
)
|
||||||
|
for key_id in updated_key_ids
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(response_keys)
|
results[server_name] = response_keys
|
||||||
|
|
||||||
raise ValueError("No verification key found for given key ids")
|
defer.returnValue(results)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_server_verify_key_v1_direct(self, server_name, key_ids):
|
def get_server_verify_key_v1_direct(self, server_name, key_ids):
|
||||||
|
@ -462,8 +668,13 @@ class Keyring(object):
|
||||||
Returns:
|
Returns:
|
||||||
A deferred that completes when the keys are stored.
|
A deferred that completes when the keys are stored.
|
||||||
"""
|
"""
|
||||||
for key_id, key in verify_keys.items():
|
# TODO(markjh): Store whether the keys have expired.
|
||||||
# TODO(markjh): Store whether the keys have expired.
|
yield defer.gatherResults(
|
||||||
yield self.store.store_server_verify_key(
|
[
|
||||||
server_name, server_name, key.time_added, key
|
self.store.store_server_verify_key(
|
||||||
)
|
server_name, server_name, key.time_added, key
|
||||||
|
)
|
||||||
|
for key_id, key in verify_keys.items()
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
|
@ -74,6 +74,8 @@ def prune_event(event):
|
||||||
)
|
)
|
||||||
elif event_type == EventTypes.Aliases:
|
elif event_type == EventTypes.Aliases:
|
||||||
add_fields("aliases")
|
add_fields("aliases")
|
||||||
|
elif event_type == EventTypes.RoomHistoryVisibility:
|
||||||
|
add_fields("history_visibility")
|
||||||
|
|
||||||
allowed_fields = {
|
allowed_fields = {
|
||||||
k: v
|
k: v
|
||||||
|
|
|
@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class FederationBase(object):
|
class FederationBase(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
|
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||||
|
include_none=False):
|
||||||
"""Takes a list of PDUs and checks the signatures and hashs of each
|
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||||
one. If a PDU fails its signature check then we check if we have it in
|
one. If a PDU fails its signature check then we check if we have it in
|
||||||
the database and if not then request if from the originating server of
|
the database and if not then request if from the originating server of
|
||||||
|
@ -50,84 +51,108 @@ class FederationBase(object):
|
||||||
Returns:
|
Returns:
|
||||||
Deferred : A list of PDUs that have valid signatures and hashes.
|
Deferred : A list of PDUs that have valid signatures and hashes.
|
||||||
"""
|
"""
|
||||||
|
deferreds = self._check_sigs_and_hashes(pdus)
|
||||||
|
|
||||||
signed_pdus = []
|
def callback(pdu):
|
||||||
|
return pdu
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def errback(failure, pdu):
|
||||||
def do(pdu):
|
failure.trap(SynapseError)
|
||||||
try:
|
return None
|
||||||
new_pdu = yield self._check_sigs_and_hash(pdu)
|
|
||||||
signed_pdus.append(new_pdu)
|
|
||||||
except SynapseError:
|
|
||||||
# FIXME: We should handle signature failures more gracefully.
|
|
||||||
|
|
||||||
|
def try_local_db(res, pdu):
|
||||||
|
if not res:
|
||||||
# Check local db.
|
# Check local db.
|
||||||
new_pdu = yield self.store.get_event(
|
return self.store.get_event(
|
||||||
pdu.event_id,
|
pdu.event_id,
|
||||||
allow_rejected=True,
|
allow_rejected=True,
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
if new_pdu:
|
return res
|
||||||
signed_pdus.append(new_pdu)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check pdu.origin
|
def try_remote(res, pdu):
|
||||||
if pdu.origin != origin:
|
if not res and pdu.origin != origin:
|
||||||
try:
|
return self.get_pdu(
|
||||||
new_pdu = yield self.get_pdu(
|
destinations=[pdu.origin],
|
||||||
destinations=[pdu.origin],
|
event_id=pdu.event_id,
|
||||||
event_id=pdu.event_id,
|
outlier=outlier,
|
||||||
outlier=outlier,
|
timeout=10000,
|
||||||
timeout=10000,
|
).addErrback(lambda e: None)
|
||||||
)
|
return res
|
||||||
|
|
||||||
if new_pdu:
|
|
||||||
signed_pdus.append(new_pdu)
|
|
||||||
return
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
def warn(res, pdu):
|
||||||
|
if not res:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Failed to find copy of %s with valid signature",
|
"Failed to find copy of %s with valid signature",
|
||||||
pdu.event_id,
|
pdu.event_id,
|
||||||
)
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
yield defer.gatherResults(
|
for pdu, deferred in zip(pdus, deferreds):
|
||||||
[do(pdu) for pdu in pdus],
|
deferred.addCallbacks(
|
||||||
|
callback, errback, errbackArgs=[pdu]
|
||||||
|
).addCallback(
|
||||||
|
try_local_db, pdu
|
||||||
|
).addCallback(
|
||||||
|
try_remote, pdu
|
||||||
|
).addCallback(
|
||||||
|
warn, pdu
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_pdus = yield defer.gatherResults(
|
||||||
|
deferreds,
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(signed_pdus)
|
if include_none:
|
||||||
|
defer.returnValue(valid_pdus)
|
||||||
|
else:
|
||||||
|
defer.returnValue([p for p in valid_pdus if p])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _check_sigs_and_hash(self, pdu):
|
def _check_sigs_and_hash(self, pdu):
|
||||||
"""Throws a SynapseError if the PDU does not have the correct
|
return self._check_sigs_and_hashes([pdu])[0]
|
||||||
|
|
||||||
|
def _check_sigs_and_hashes(self, pdus):
|
||||||
|
"""Throws a SynapseError if a PDU does not have the correct
|
||||||
signatures.
|
signatures.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FrozenEvent: Either the given event or it redacted if it failed the
|
FrozenEvent: Either the given event or it redacted if it failed the
|
||||||
content hash check.
|
content hash check.
|
||||||
"""
|
"""
|
||||||
# Check signatures are correct.
|
|
||||||
redacted_event = prune_event(pdu)
|
|
||||||
redacted_pdu_json = redacted_event.get_pdu_json()
|
|
||||||
|
|
||||||
try:
|
redacted_pdus = [
|
||||||
yield self.keyring.verify_json_for_server(
|
prune_event(pdu)
|
||||||
pdu.origin, redacted_pdu_json
|
for pdu in pdus
|
||||||
)
|
]
|
||||||
except SynapseError:
|
|
||||||
|
deferreds = self.keyring.verify_json_objects_for_server([
|
||||||
|
(p.origin, p.get_pdu_json())
|
||||||
|
for p in redacted_pdus
|
||||||
|
])
|
||||||
|
|
||||||
|
def callback(_, pdu, redacted):
|
||||||
|
if not check_event_content_hash(pdu):
|
||||||
|
logger.warn(
|
||||||
|
"Event content has been tampered, redacting %s: %s",
|
||||||
|
pdu.event_id, pdu.get_pdu_json()
|
||||||
|
)
|
||||||
|
return redacted
|
||||||
|
return pdu
|
||||||
|
|
||||||
|
def errback(failure, pdu):
|
||||||
|
failure.trap(SynapseError)
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Signature check failed for %s",
|
"Signature check failed for %s",
|
||||||
pdu.event_id,
|
pdu.event_id,
|
||||||
)
|
)
|
||||||
raise
|
return failure
|
||||||
|
|
||||||
if not check_event_content_hash(pdu):
|
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
|
||||||
logger.warn(
|
deferred.addCallbacks(
|
||||||
"Event content has been tampered, redacting.",
|
callback, errback,
|
||||||
pdu.event_id,
|
callbackArgs=[pdu, redacted],
|
||||||
|
errbackArgs=[pdu],
|
||||||
)
|
)
|
||||||
defer.returnValue(redacted_event)
|
|
||||||
|
|
||||||
defer.returnValue(pdu)
|
return deferreds
|
||||||
|
|
|
@ -30,6 +30,7 @@ import synapse.metrics
|
||||||
|
|
||||||
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
||||||
|
|
||||||
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
@ -167,7 +168,7 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
# FIXME: We should handle signature failures more gracefully.
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
pdus[:] = yield defer.gatherResults(
|
pdus[:] = yield defer.gatherResults(
|
||||||
[self._check_sigs_and_hash(pdu) for pdu in pdus],
|
self._check_sigs_and_hashes(pdus),
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
@ -230,7 +231,7 @@ class FederationClient(FederationBase):
|
||||||
pdu = pdu_list[0]
|
pdu = pdu_list[0]
|
||||||
|
|
||||||
# Check signatures are correct.
|
# Check signatures are correct.
|
||||||
pdu = yield self._check_sigs_and_hash(pdu)
|
pdu = yield self._check_sigs_and_hashes([pdu])[0]
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -327,6 +328,9 @@ class FederationClient(FederationBase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def make_join(self, destinations, room_id, user_id):
|
def make_join(self, destinations, room_id, user_id):
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
|
if destination == self.server_name:
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ret = yield self.transport_layer.make_join(
|
ret = yield self.transport_layer.make_join(
|
||||||
destination, room_id, user_id
|
destination, room_id, user_id
|
||||||
|
@ -353,6 +357,9 @@ class FederationClient(FederationBase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_join(self, destinations, pdu):
|
def send_join(self, destinations, pdu):
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
|
if destination == self.server_name:
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
_, content = yield self.transport_layer.send_join(
|
_, content = yield self.transport_layer.send_join(
|
||||||
|
@ -374,17 +381,39 @@ class FederationClient(FederationBase):
|
||||||
for p in content.get("auth_chain", [])
|
for p in content.get("auth_chain", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
signed_state, signed_auth = yield defer.gatherResults(
|
pdus = {
|
||||||
[
|
p.event_id: p
|
||||||
self._check_sigs_and_hash_and_fetch(
|
for p in itertools.chain(state, auth_chain)
|
||||||
destination, state, outlier=True
|
}
|
||||||
),
|
|
||||||
self._check_sigs_and_hash_and_fetch(
|
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, auth_chain, outlier=True
|
destination, pdus.values(),
|
||||||
)
|
outlier=True,
|
||||||
],
|
)
|
||||||
consumeErrors=True
|
|
||||||
).addErrback(unwrapFirstError)
|
valid_pdus_map = {
|
||||||
|
p.event_id: p
|
||||||
|
for p in valid_pdus
|
||||||
|
}
|
||||||
|
|
||||||
|
# NB: We *need* to copy to ensure that we don't have multiple
|
||||||
|
# references being passed on, as that causes... issues.
|
||||||
|
signed_state = [
|
||||||
|
copy.copy(valid_pdus_map[p.event_id])
|
||||||
|
for p in state
|
||||||
|
if p.event_id in valid_pdus_map
|
||||||
|
]
|
||||||
|
|
||||||
|
signed_auth = [
|
||||||
|
valid_pdus_map[p.event_id]
|
||||||
|
for p in auth_chain
|
||||||
|
if p.event_id in valid_pdus_map
|
||||||
|
]
|
||||||
|
|
||||||
|
# NB: We *need* to copy to ensure that we don't have multiple
|
||||||
|
# references being passed on, as that causes... issues.
|
||||||
|
for s in signed_state:
|
||||||
|
s.internal_metadata = copy.deepcopy(s.internal_metadata)
|
||||||
|
|
||||||
auth_chain.sort(key=lambda e: e.depth)
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
|
@ -396,7 +425,7 @@ class FederationClient(FederationBase):
|
||||||
except CodeMessageException:
|
except CodeMessageException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn(
|
logger.exception(
|
||||||
"Failed to send_join via %s: %s",
|
"Failed to send_join via %s: %s",
|
||||||
destination, e.message
|
destination, e.message
|
||||||
)
|
)
|
||||||
|
|
|
@ -31,6 +31,8 @@ from synapse.crypto.event_signing import (
|
||||||
)
|
)
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
from synapse.events.utils import prune_event
|
||||||
|
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
@ -222,6 +224,56 @@ class FederationHandler(BaseHandler):
|
||||||
"user_joined_room", user=user, room_id=event.room_id
|
"user_joined_room", user=user, room_id=event.room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _filter_events_for_server(self, server_name, room_id, events):
|
||||||
|
states = yield self.store.get_state_for_events(
|
||||||
|
room_id, [e.event_id for e in events],
|
||||||
|
)
|
||||||
|
|
||||||
|
events_and_states = zip(events, states)
|
||||||
|
|
||||||
|
def redact_disallowed(event_and_state):
|
||||||
|
event, state = event_and_state
|
||||||
|
|
||||||
|
if not state:
|
||||||
|
return event
|
||||||
|
|
||||||
|
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||||
|
if history:
|
||||||
|
visibility = history.content.get("history_visibility", "shared")
|
||||||
|
if visibility in ["invited", "joined"]:
|
||||||
|
# We now loop through all state events looking for
|
||||||
|
# membership states for the requesting server to determine
|
||||||
|
# if the server is either in the room or has been invited
|
||||||
|
# into the room.
|
||||||
|
for ev in state.values():
|
||||||
|
if ev.type != EventTypes.Member:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
domain = UserID.from_string(ev.state_key).domain
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if domain != server_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
memtype = ev.membership
|
||||||
|
if memtype == Membership.JOIN:
|
||||||
|
return event
|
||||||
|
elif memtype == Membership.INVITE:
|
||||||
|
if visibility == "invited":
|
||||||
|
return event
|
||||||
|
else:
|
||||||
|
return prune_event(event)
|
||||||
|
|
||||||
|
return event
|
||||||
|
|
||||||
|
res = map(redact_disallowed, events_and_states)
|
||||||
|
|
||||||
|
logger.info("_filter_events_for_server %r", res)
|
||||||
|
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def backfill(self, dest, room_id, limit, extremities=[]):
|
def backfill(self, dest, room_id, limit, extremities=[]):
|
||||||
|
@ -882,6 +934,8 @@ class FederationHandler(BaseHandler):
|
||||||
limit
|
limit
|
||||||
)
|
)
|
||||||
|
|
||||||
|
events = yield self._filter_events_for_server(origin, room_id, events)
|
||||||
|
|
||||||
defer.returnValue(events)
|
defer.returnValue(events)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -113,11 +113,21 @@ class MessageHandler(BaseHandler):
|
||||||
"room_key", next_key
|
"room_key", next_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not events:
|
||||||
|
defer.returnValue({
|
||||||
|
"chunk": [],
|
||||||
|
"start": pagin_config.from_token.to_string(),
|
||||||
|
"end": next_token.to_string(),
|
||||||
|
})
|
||||||
|
|
||||||
|
events = yield self._filter_events_for_client(user_id, room_id, events)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
chunk = {
|
chunk = {
|
||||||
"chunk": [
|
"chunk": [
|
||||||
serialize_event(e, time_now, as_client_event) for e in events
|
serialize_event(e, time_now, as_client_event)
|
||||||
|
for e in events
|
||||||
],
|
],
|
||||||
"start": pagin_config.from_token.to_string(),
|
"start": pagin_config.from_token.to_string(),
|
||||||
"end": next_token.to_string(),
|
"end": next_token.to_string(),
|
||||||
|
@ -125,6 +135,52 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
defer.returnValue(chunk)
|
defer.returnValue(chunk)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _filter_events_for_client(self, user_id, room_id, events):
|
||||||
|
states = yield self.store.get_state_for_events(
|
||||||
|
room_id, [e.event_id for e in events],
|
||||||
|
)
|
||||||
|
|
||||||
|
events_and_states = zip(events, states)
|
||||||
|
|
||||||
|
def allowed(event_and_state):
|
||||||
|
event, state = event_and_state
|
||||||
|
|
||||||
|
if event.type == EventTypes.RoomHistoryVisibility:
|
||||||
|
return True
|
||||||
|
|
||||||
|
membership_ev = state.get((EventTypes.Member, user_id), None)
|
||||||
|
if membership_ev:
|
||||||
|
membership = membership_ev.membership
|
||||||
|
else:
|
||||||
|
membership = Membership.LEAVE
|
||||||
|
|
||||||
|
if membership == Membership.JOIN:
|
||||||
|
return True
|
||||||
|
|
||||||
|
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||||
|
if history:
|
||||||
|
visibility = history.content.get("history_visibility", "shared")
|
||||||
|
else:
|
||||||
|
visibility = "shared"
|
||||||
|
|
||||||
|
if visibility == "public":
|
||||||
|
return True
|
||||||
|
elif visibility == "shared":
|
||||||
|
return True
|
||||||
|
elif visibility == "joined":
|
||||||
|
return membership == Membership.JOIN
|
||||||
|
elif visibility == "invited":
|
||||||
|
return membership == Membership.INVITE
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
events_and_states = filter(allowed, events_and_states)
|
||||||
|
defer.returnValue([
|
||||||
|
ev
|
||||||
|
for ev, _ in events_and_states
|
||||||
|
])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_and_send_event(self, event_dict, ratelimit=True,
|
def create_and_send_event(self, event_dict, ratelimit=True,
|
||||||
client=None, txn_id=None):
|
client=None, txn_id=None):
|
||||||
|
@ -321,6 +377,10 @@ class MessageHandler(BaseHandler):
|
||||||
]
|
]
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
messages = yield self._filter_events_for_client(
|
||||||
|
user_id, event.room_id, messages
|
||||||
|
)
|
||||||
|
|
||||||
start_token = now_token.copy_and_replace("room_key", token[0])
|
start_token = now_token.copy_and_replace("room_key", token[0])
|
||||||
end_token = now_token.copy_and_replace("room_key", token[1])
|
end_token = now_token.copy_and_replace("room_key", token[1])
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
@ -426,6 +486,10 @@ class MessageHandler(BaseHandler):
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
messages = yield self._filter_events_for_client(
|
||||||
|
user_id, room_id, messages
|
||||||
|
)
|
||||||
|
|
||||||
start_token = now_token.copy_and_replace("room_key", token[0])
|
start_token = now_token.copy_and_replace("room_key", token[0])
|
||||||
end_token = now_token.copy_and_replace("room_key", token[1])
|
end_token = now_token.copy_and_replace("room_key", token[1])
|
||||||
|
|
||||||
|
|
|
@ -192,6 +192,35 @@ class RegistrationHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
logger.info("Valid captcha entered from %s", ip)
|
logger.info("Valid captcha entered from %s", ip)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def register_saml2(self, localpart):
|
||||||
|
"""
|
||||||
|
Registers email_id as SAML2 Based Auth.
|
||||||
|
"""
|
||||||
|
if urllib.quote(localpart) != localpart:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"User ID must only contain characters which do not"
|
||||||
|
" require URL encoding."
|
||||||
|
)
|
||||||
|
user = UserID(localpart, self.hs.hostname)
|
||||||
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
yield self.check_user_id_is_valid(user_id)
|
||||||
|
token = self._generate_token(user_id)
|
||||||
|
try:
|
||||||
|
yield self.store.register(
|
||||||
|
user_id=user_id,
|
||||||
|
token=token,
|
||||||
|
password_hash=None
|
||||||
|
)
|
||||||
|
yield self.distributor.fire("registered_user", user)
|
||||||
|
except Exception, e:
|
||||||
|
yield self.store.add_access_token_to_user(user_id, token)
|
||||||
|
# Ignore Registration errors
|
||||||
|
logger.exception(e)
|
||||||
|
defer.returnValue((user_id, token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def register_email(self, threepidCreds):
|
def register_email(self, threepidCreds):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -213,6 +213,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
"events": {
|
"events": {
|
||||||
EventTypes.Name: 100,
|
EventTypes.Name: 100,
|
||||||
EventTypes.PowerLevels: 100,
|
EventTypes.PowerLevels: 100,
|
||||||
|
EventTypes.RoomHistoryVisibility: 100,
|
||||||
},
|
},
|
||||||
"events_default": 0,
|
"events_default": 0,
|
||||||
"state_default": 50,
|
"state_default": 50,
|
||||||
|
|
|
@ -292,6 +292,51 @@ class SyncHandler(BaseHandler):
|
||||||
next_batch=now_token,
|
next_batch=now_token,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _filter_events_for_client(self, user_id, room_id, events):
|
||||||
|
states = yield self.store.get_state_for_events(
|
||||||
|
room_id, [e.event_id for e in events],
|
||||||
|
)
|
||||||
|
|
||||||
|
events_and_states = zip(events, states)
|
||||||
|
|
||||||
|
def allowed(event_and_state):
|
||||||
|
event, state = event_and_state
|
||||||
|
|
||||||
|
if event.type == EventTypes.RoomHistoryVisibility:
|
||||||
|
return True
|
||||||
|
|
||||||
|
membership_ev = state.get((EventTypes.Member, user_id), None)
|
||||||
|
if membership_ev:
|
||||||
|
membership = membership_ev.membership
|
||||||
|
else:
|
||||||
|
membership = Membership.LEAVE
|
||||||
|
|
||||||
|
if membership == Membership.JOIN:
|
||||||
|
return True
|
||||||
|
|
||||||
|
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||||
|
if history:
|
||||||
|
visibility = history.content.get("history_visibility", "shared")
|
||||||
|
else:
|
||||||
|
visibility = "shared"
|
||||||
|
|
||||||
|
if visibility == "public":
|
||||||
|
return True
|
||||||
|
elif visibility == "shared":
|
||||||
|
return True
|
||||||
|
elif visibility == "joined":
|
||||||
|
return membership == Membership.JOIN
|
||||||
|
elif visibility == "invited":
|
||||||
|
return membership == Membership.INVITE
|
||||||
|
|
||||||
|
return True
|
||||||
|
events_and_states = filter(allowed, events_and_states)
|
||||||
|
defer.returnValue([
|
||||||
|
ev
|
||||||
|
for ev, _ in events_and_states
|
||||||
|
])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def load_filtered_recents(self, room_id, sync_config, now_token,
|
def load_filtered_recents(self, room_id, sync_config, now_token,
|
||||||
since_token=None):
|
since_token=None):
|
||||||
|
@ -313,6 +358,9 @@ class SyncHandler(BaseHandler):
|
||||||
(room_key, _) = keys
|
(room_key, _) = keys
|
||||||
end_key = "s" + room_key.split('-')[-1]
|
end_key = "s" + room_key.split('-')[-1]
|
||||||
loaded_recents = sync_config.filter.filter_room_events(events)
|
loaded_recents = sync_config.filter.filter_room_events(events)
|
||||||
|
loaded_recents = yield self._filter_events_for_client(
|
||||||
|
sync_config.user.to_string(), room_id, loaded_recents,
|
||||||
|
)
|
||||||
loaded_recents.extend(recents)
|
loaded_recents.extend(recents)
|
||||||
recents = loaded_recents
|
recents = loaded_recents
|
||||||
if len(events) <= load_limit:
|
if len(events) <= load_limit:
|
||||||
|
|
|
@ -32,6 +32,7 @@ REQUIREMENTS = {
|
||||||
"pydenticon": ["pydenticon"],
|
"pydenticon": ["pydenticon"],
|
||||||
"ujson": ["ujson"],
|
"ujson": ["ujson"],
|
||||||
"blist": ["blist"],
|
"blist": ["blist"],
|
||||||
|
"pysaml2": ["saml2"],
|
||||||
}
|
}
|
||||||
CONDITIONAL_REQUIREMENTS = {
|
CONDITIONAL_REQUIREMENTS = {
|
||||||
"web_client": {
|
"web_client": {
|
||||||
|
|
|
@ -20,14 +20,32 @@ from synapse.types import UserID
|
||||||
from base import ClientV1RestServlet, client_path_pattern
|
from base import ClientV1RestServlet, client_path_pattern
|
||||||
|
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
|
import urllib
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from saml2 import BINDING_HTTP_POST
|
||||||
|
from saml2 import config
|
||||||
|
from saml2.client import Saml2Client
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LoginRestServlet(ClientV1RestServlet):
|
class LoginRestServlet(ClientV1RestServlet):
|
||||||
PATTERN = client_path_pattern("/login$")
|
PATTERN = client_path_pattern("/login$")
|
||||||
PASS_TYPE = "m.login.password"
|
PASS_TYPE = "m.login.password"
|
||||||
|
SAML2_TYPE = "m.login.saml2"
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(LoginRestServlet, self).__init__(hs)
|
||||||
|
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
|
||||||
|
self.saml2_enabled = hs.config.saml2_enabled
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]})
|
flows = [{"type": LoginRestServlet.PASS_TYPE}]
|
||||||
|
if self.saml2_enabled:
|
||||||
|
flows.append({"type": LoginRestServlet.SAML2_TYPE})
|
||||||
|
return (200, {"flows": flows})
|
||||||
|
|
||||||
def on_OPTIONS(self, request):
|
def on_OPTIONS(self, request):
|
||||||
return (200, {})
|
return (200, {})
|
||||||
|
@ -39,6 +57,16 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
|
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
|
||||||
result = yield self.do_password_login(login_submission)
|
result = yield self.do_password_login(login_submission)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
elif self.saml2_enabled and (login_submission["type"] ==
|
||||||
|
LoginRestServlet.SAML2_TYPE):
|
||||||
|
relay_state = ""
|
||||||
|
if "relay_state" in login_submission:
|
||||||
|
relay_state = "&RelayState="+urllib.quote(
|
||||||
|
login_submission["relay_state"])
|
||||||
|
result = {
|
||||||
|
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
|
||||||
|
}
|
||||||
|
defer.returnValue((200, result))
|
||||||
else:
|
else:
|
||||||
raise SynapseError(400, "Bad login type.")
|
raise SynapseError(400, "Bad login type.")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -94,6 +122,49 @@ class PasswordResetRestServlet(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SAML2RestServlet(ClientV1RestServlet):
|
||||||
|
PATTERN = client_path_pattern("/login/saml2")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(SAML2RestServlet, self).__init__(hs)
|
||||||
|
self.sp_config = hs.config.saml2_config_path
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
saml2_auth = None
|
||||||
|
try:
|
||||||
|
conf = config.SPConfig()
|
||||||
|
conf.load_file(self.sp_config)
|
||||||
|
SP = Saml2Client(conf)
|
||||||
|
saml2_auth = SP.parse_authn_request_response(
|
||||||
|
request.args['SAMLResponse'][0], BINDING_HTTP_POST)
|
||||||
|
except Exception, e: # Not authenticated
|
||||||
|
logger.exception(e)
|
||||||
|
if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
|
||||||
|
username = saml2_auth.name_id.text
|
||||||
|
handler = self.handlers.registration_handler
|
||||||
|
(user_id, token) = yield handler.register_saml2(username)
|
||||||
|
# Forward to the RelayState callback along with ava
|
||||||
|
if 'RelayState' in request.args:
|
||||||
|
request.redirect(urllib.unquote(
|
||||||
|
request.args['RelayState'][0]) +
|
||||||
|
'?status=authenticated&access_token=' +
|
||||||
|
token + '&user_id=' + user_id + '&ava=' +
|
||||||
|
urllib.quote(json.dumps(saml2_auth.ava)))
|
||||||
|
request.finish()
|
||||||
|
defer.returnValue(None)
|
||||||
|
defer.returnValue((200, {"status": "authenticated",
|
||||||
|
"user_id": user_id, "token": token,
|
||||||
|
"ava": saml2_auth.ava}))
|
||||||
|
elif 'RelayState' in request.args:
|
||||||
|
request.redirect(urllib.unquote(
|
||||||
|
request.args['RelayState'][0]) +
|
||||||
|
'?status=not_authenticated')
|
||||||
|
request.finish()
|
||||||
|
defer.returnValue(None)
|
||||||
|
defer.returnValue((200, {"status": "not_authenticated"}))
|
||||||
|
|
||||||
|
|
||||||
def _parse_json(request):
|
def _parse_json(request):
|
||||||
try:
|
try:
|
||||||
content = json.loads(request.content.read())
|
content = json.loads(request.content.read())
|
||||||
|
@ -106,4 +177,6 @@ def _parse_json(request):
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
LoginRestServlet(hs).register(http_server)
|
LoginRestServlet(hs).register(http_server)
|
||||||
|
if hs.config.saml2_enabled:
|
||||||
|
SAML2RestServlet(hs).register(http_server)
|
||||||
# TODO PasswordResetRestServlet(hs).register(http_server)
|
# TODO PasswordResetRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -20,6 +20,7 @@ from . import (
|
||||||
register,
|
register,
|
||||||
auth,
|
auth,
|
||||||
receipts,
|
receipts,
|
||||||
|
keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
|
@ -40,3 +41,4 @@ class ClientV2AlphaRestResource(JsonResource):
|
||||||
register.register_servlets(hs, client_resource)
|
register.register_servlets(hs, client_resource)
|
||||||
auth.register_servlets(hs, client_resource)
|
auth.register_servlets(hs, client_resource)
|
||||||
receipts.register_servlets(hs, client_resource)
|
receipts.register_servlets(hs, client_resource)
|
||||||
|
keys.register_servlets(hs, client_resource)
|
||||||
|
|
276
synapse/rest/client/v2_alpha/keys.py
Normal file
276
synapse/rest/client/v2_alpha/keys.py
Normal file
|
@ -0,0 +1,276 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.http.servlet import RestServlet
|
||||||
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
|
||||||
|
from ._base import client_v2_pattern
|
||||||
|
|
||||||
|
import simplejson as json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyUploadServlet(RestServlet):
|
||||||
|
"""
|
||||||
|
POST /keys/upload/<device_id> HTTP/1.1
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"user_id": "<user_id>",
|
||||||
|
"device_id": "<device_id>",
|
||||||
|
"valid_until_ts": <millisecond_timestamp>,
|
||||||
|
"algorithms": [
|
||||||
|
"m.olm.curve25519-aes-sha256",
|
||||||
|
]
|
||||||
|
"keys": {
|
||||||
|
"<algorithm>:<device_id>": "<key_base64>",
|
||||||
|
},
|
||||||
|
"signatures:" {
|
||||||
|
"<user_id>" {
|
||||||
|
"<algorithm>:<device_id>": "<signature_base64>"
|
||||||
|
} } },
|
||||||
|
"one_time_keys": {
|
||||||
|
"<algorithm>:<key_id>": "<key_base64>"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(KeyUploadServlet, self).__init__()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, device_id):
|
||||||
|
auth_user, client_info = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = auth_user.to_string()
|
||||||
|
# TODO: Check that the device_id matches that in the authentication
|
||||||
|
# or derive the device_id from the authentication instead.
|
||||||
|
try:
|
||||||
|
body = json.loads(request.content.read())
|
||||||
|
except:
|
||||||
|
raise SynapseError(400, "Invalid key JSON")
|
||||||
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
|
# TODO: Validate the JSON to make sure it has the right keys.
|
||||||
|
device_keys = body.get("device_keys", None)
|
||||||
|
if device_keys:
|
||||||
|
logger.info(
|
||||||
|
"Updating device_keys for device %r for user %r at %d",
|
||||||
|
device_id, auth_user, time_now
|
||||||
|
)
|
||||||
|
# TODO: Sign the JSON with the server key
|
||||||
|
yield self.store.set_e2e_device_keys(
|
||||||
|
user_id, device_id, time_now,
|
||||||
|
encode_canonical_json(device_keys)
|
||||||
|
)
|
||||||
|
|
||||||
|
one_time_keys = body.get("one_time_keys", None)
|
||||||
|
if one_time_keys:
|
||||||
|
logger.info(
|
||||||
|
"Adding %d one_time_keys for device %r for user %r at %d",
|
||||||
|
len(one_time_keys), device_id, user_id, time_now
|
||||||
|
)
|
||||||
|
key_list = []
|
||||||
|
for key_id, key_json in one_time_keys.items():
|
||||||
|
algorithm, key_id = key_id.split(":")
|
||||||
|
key_list.append((
|
||||||
|
algorithm, key_id, encode_canonical_json(key_json)
|
||||||
|
))
|
||||||
|
|
||||||
|
yield self.store.add_e2e_one_time_keys(
|
||||||
|
user_id, device_id, time_now, key_list
|
||||||
|
)
|
||||||
|
|
||||||
|
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||||
|
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, device_id):
|
||||||
|
auth_user, client_info = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = auth_user.to_string()
|
||||||
|
|
||||||
|
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||||
|
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||||
|
|
||||||
|
|
||||||
|
class KeyQueryServlet(RestServlet):
|
||||||
|
"""
|
||||||
|
GET /keys/query/<user_id> HTTP/1.1
|
||||||
|
|
||||||
|
GET /keys/query/<user_id>/<device_id> HTTP/1.1
|
||||||
|
|
||||||
|
POST /keys/query HTTP/1.1
|
||||||
|
Content-Type: application/json
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": ["<device_id>"]
|
||||||
|
} }
|
||||||
|
|
||||||
|
HTTP/1.1 200 OK
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": {
|
||||||
|
"user_id": "<user_id>", // Duplicated to be signed
|
||||||
|
"device_id": "<device_id>", // Duplicated to be signed
|
||||||
|
"valid_until_ts": <millisecond_timestamp>,
|
||||||
|
"algorithms": [ // List of supported algorithms
|
||||||
|
"m.olm.curve25519-aes-sha256",
|
||||||
|
],
|
||||||
|
"keys": { // Must include a ed25519 signing key
|
||||||
|
"<algorithm>:<key_id>": "<key_base64>",
|
||||||
|
},
|
||||||
|
"signatures:" {
|
||||||
|
// Must be signed with device's ed25519 key
|
||||||
|
"<user_id>/<device_id>": {
|
||||||
|
"<algorithm>:<key_id>": "<signature_base64>"
|
||||||
|
}
|
||||||
|
// Must be signed by this server.
|
||||||
|
"<server_name>": {
|
||||||
|
"<algorithm>:<key_id>": "<signature_base64>"
|
||||||
|
} } } } } }
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERN = client_v2_pattern(
|
||||||
|
"/keys/query(?:"
|
||||||
|
"/(?P<user_id>[^/]*)(?:"
|
||||||
|
"/(?P<device_id>[^/]*)"
|
||||||
|
")?"
|
||||||
|
")?"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(KeyQueryServlet, self).__init__()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, user_id, device_id):
|
||||||
|
logger.debug("onPOST")
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
try:
|
||||||
|
body = json.loads(request.content.read())
|
||||||
|
except:
|
||||||
|
raise SynapseError(400, "Invalid key JSON")
|
||||||
|
query = []
|
||||||
|
for user_id, device_ids in body.get("device_keys", {}).items():
|
||||||
|
if not device_ids:
|
||||||
|
query.append((user_id, None))
|
||||||
|
else:
|
||||||
|
for device_id in device_ids:
|
||||||
|
query.append((user_id, device_id))
|
||||||
|
results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
|
||||||
|
defer.returnValue(self.json_result(request, results))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, user_id, device_id):
|
||||||
|
auth_user, client_info = yield self.auth.get_user_by_req(request)
|
||||||
|
auth_user_id = auth_user.to_string()
|
||||||
|
if not user_id:
|
||||||
|
user_id = auth_user_id
|
||||||
|
if not device_id:
|
||||||
|
device_id = None
|
||||||
|
# Returns a map of user_id->device_id->json_bytes.
|
||||||
|
results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
|
||||||
|
defer.returnValue(self.json_result(request, results))
|
||||||
|
|
||||||
|
def json_result(self, request, results):
|
||||||
|
json_result = {}
|
||||||
|
for user_id, device_keys in results.items():
|
||||||
|
for device_id, json_bytes in device_keys.items():
|
||||||
|
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
||||||
|
json_bytes
|
||||||
|
)
|
||||||
|
return (200, {"device_keys": json_result})
|
||||||
|
|
||||||
|
|
||||||
|
class OneTimeKeyServlet(RestServlet):
|
||||||
|
"""
|
||||||
|
GET /keys/take/<user-id>/<device-id>/<algorithm> HTTP/1.1
|
||||||
|
|
||||||
|
POST /keys/take HTTP/1.1
|
||||||
|
{
|
||||||
|
"one_time_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": "<algorithm>"
|
||||||
|
} } }
|
||||||
|
|
||||||
|
HTTP/1.1 200 OK
|
||||||
|
{
|
||||||
|
"one_time_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": {
|
||||||
|
"<algorithm>:<key_id>": "<key_base64>"
|
||||||
|
} } } }
|
||||||
|
|
||||||
|
"""
|
||||||
|
PATTERN = client_v2_pattern(
|
||||||
|
"/keys/take(?:/?|(?:/"
|
||||||
|
"(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
|
||||||
|
")?)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(OneTimeKeyServlet, self).__init__()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, user_id, device_id, algorithm):
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
results = yield self.store.take_e2e_one_time_keys(
|
||||||
|
[(user_id, device_id, algorithm)]
|
||||||
|
)
|
||||||
|
defer.returnValue(self.json_result(request, results))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, user_id, device_id, algorithm):
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
try:
|
||||||
|
body = json.loads(request.content.read())
|
||||||
|
except:
|
||||||
|
raise SynapseError(400, "Invalid key JSON")
|
||||||
|
query = []
|
||||||
|
for user_id, device_keys in body.get("one_time_keys", {}).items():
|
||||||
|
for device_id, algorithm in device_keys.items():
|
||||||
|
query.append((user_id, device_id, algorithm))
|
||||||
|
results = yield self.store.take_e2e_one_time_keys(query)
|
||||||
|
defer.returnValue(self.json_result(request, results))
|
||||||
|
|
||||||
|
def json_result(self, request, results):
|
||||||
|
json_result = {}
|
||||||
|
for user_id, device_keys in results.items():
|
||||||
|
for device_id, keys in device_keys.items():
|
||||||
|
for key_id, json_bytes in keys.items():
|
||||||
|
json_result.setdefault(user_id, {})[device_id] = {
|
||||||
|
key_id: json.loads(json_bytes)
|
||||||
|
}
|
||||||
|
return (200, {"one_time_keys": json_result})
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
KeyUploadServlet(hs).register(http_server)
|
||||||
|
KeyQueryServlet(hs).register(http_server)
|
||||||
|
OneTimeKeyServlet(hs).register(http_server)
|
|
@ -37,6 +37,7 @@ from .rejections import RejectionsStore
|
||||||
from .state import StateStore
|
from .state import StateStore
|
||||||
from .signatures import SignatureStore
|
from .signatures import SignatureStore
|
||||||
from .filtering import FilteringStore
|
from .filtering import FilteringStore
|
||||||
|
from .end_to_end_keys import EndToEndKeyStore
|
||||||
|
|
||||||
from .receipts import ReceiptsStore
|
from .receipts import ReceiptsStore
|
||||||
|
|
||||||
|
@ -77,6 +78,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
ApplicationServiceTransactionStore,
|
ApplicationServiceTransactionStore,
|
||||||
EventsStore,
|
EventsStore,
|
||||||
ReceiptsStore,
|
ReceiptsStore,
|
||||||
|
EndToEndKeyStore,
|
||||||
):
|
):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
|
125
synapse/storage/end_to_end_keys.py
Normal file
125
synapse/storage/end_to_end_keys.py
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from _base import SQLBaseStore
|
||||||
|
|
||||||
|
|
||||||
|
class EndToEndKeyStore(SQLBaseStore):
|
||||||
|
def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
|
||||||
|
return self._simple_upsert(
|
||||||
|
table="e2e_device_keys_json",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"ts_added_ms": time_now,
|
||||||
|
"key_json": json_bytes,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_e2e_device_keys(self, query_list):
|
||||||
|
"""Fetch a list of device keys.
|
||||||
|
Args:
|
||||||
|
query_list(list): List of pairs of user_ids and device_ids.
|
||||||
|
Returns:
|
||||||
|
Dict mapping from user-id to dict mapping from device_id to
|
||||||
|
key json byte strings.
|
||||||
|
"""
|
||||||
|
def _get_e2e_device_keys(txn):
|
||||||
|
result = {}
|
||||||
|
for user_id, device_id in query_list:
|
||||||
|
user_result = result.setdefault(user_id, {})
|
||||||
|
keyvalues = {"user_id": user_id}
|
||||||
|
if device_id:
|
||||||
|
keyvalues["device_id"] = device_id
|
||||||
|
rows = self._simple_select_list_txn(
|
||||||
|
txn, table="e2e_device_keys_json",
|
||||||
|
keyvalues=keyvalues,
|
||||||
|
retcols=["device_id", "key_json"]
|
||||||
|
)
|
||||||
|
for row in rows:
|
||||||
|
user_result[row["device_id"]] = row["key_json"]
|
||||||
|
return result
|
||||||
|
return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
|
||||||
|
|
||||||
|
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
||||||
|
def _add_e2e_one_time_keys(txn):
|
||||||
|
for (algorithm, key_id, json_bytes) in key_list:
|
||||||
|
self._simple_upsert_txn(
|
||||||
|
txn, table="e2e_one_time_keys_json",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
"algorithm": algorithm,
|
||||||
|
"key_id": key_id,
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"ts_added_ms": time_now,
|
||||||
|
"key_json": json_bytes,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return self.runInteraction(
|
||||||
|
"add_e2e_one_time_keys", _add_e2e_one_time_keys
|
||||||
|
)
|
||||||
|
|
||||||
|
def count_e2e_one_time_keys(self, user_id, device_id):
|
||||||
|
""" Count the number of one time keys the server has for a device
|
||||||
|
Returns:
|
||||||
|
Dict mapping from algorithm to number of keys for that algorithm.
|
||||||
|
"""
|
||||||
|
def _count_e2e_one_time_keys(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
|
||||||
|
" WHERE user_id = ? AND device_id = ?"
|
||||||
|
" GROUP BY algorithm"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id, device_id))
|
||||||
|
result = {}
|
||||||
|
for algorithm, key_count in txn.fetchall():
|
||||||
|
result[algorithm] = key_count
|
||||||
|
return result
|
||||||
|
return self.runInteraction(
|
||||||
|
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||||
|
)
|
||||||
|
|
||||||
|
def take_e2e_one_time_keys(self, query_list):
|
||||||
|
"""Take a list of one time keys out of the database"""
|
||||||
|
def _take_e2e_one_time_keys(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT key_id, key_json FROM e2e_one_time_keys_json"
|
||||||
|
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||||
|
" LIMIT 1"
|
||||||
|
)
|
||||||
|
result = {}
|
||||||
|
delete = []
|
||||||
|
for user_id, device_id, algorithm in query_list:
|
||||||
|
user_result = result.setdefault(user_id, {})
|
||||||
|
device_result = user_result.setdefault(device_id, {})
|
||||||
|
txn.execute(sql, (user_id, device_id, algorithm))
|
||||||
|
for key_id, key_json in txn.fetchall():
|
||||||
|
device_result[algorithm + ":" + key_id] = key_json
|
||||||
|
delete.append((user_id, device_id, algorithm, key_id))
|
||||||
|
sql = (
|
||||||
|
"DELETE FROM e2e_one_time_keys_json"
|
||||||
|
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||||
|
" AND key_id = ?"
|
||||||
|
)
|
||||||
|
for user_id, device_id, algorithm, key_id in delete:
|
||||||
|
txn.execute(sql, (user_id, device_id, algorithm, key_id))
|
||||||
|
return result
|
||||||
|
return self.runInteraction(
|
||||||
|
"take_e2e_one_time_keys", _take_e2e_one_time_keys
|
||||||
|
)
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from _base import SQLBaseStore
|
from _base import SQLBaseStore, cached
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -71,6 +71,25 @@ class KeyStore(SQLBaseStore):
|
||||||
desc="store_server_certificate",
|
desc="store_server_certificate",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_all_server_verify_keys(self, server_name):
|
||||||
|
rows = yield self._simple_select_list(
|
||||||
|
table="server_signature_keys",
|
||||||
|
keyvalues={
|
||||||
|
"server_name": server_name,
|
||||||
|
},
|
||||||
|
retcols=["key_id", "verify_key"],
|
||||||
|
desc="get_all_server_verify_keys",
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
row["key_id"]: decode_verify_key_bytes(
|
||||||
|
row["key_id"], str(row["verify_key"])
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_server_verify_keys(self, server_name, key_ids):
|
def get_server_verify_keys(self, server_name, key_ids):
|
||||||
"""Retrieve the NACL verification key for a given server for the given
|
"""Retrieve the NACL verification key for a given server for the given
|
||||||
|
@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
(list of VerifyKey): The verification keys.
|
(list of VerifyKey): The verification keys.
|
||||||
"""
|
"""
|
||||||
sql = (
|
keys = yield self.get_all_server_verify_keys(server_name)
|
||||||
"SELECT key_id, verify_key FROM server_signature_keys"
|
defer.returnValue({
|
||||||
" WHERE server_name = ?"
|
k: keys[k]
|
||||||
" AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
|
for k in key_ids
|
||||||
)
|
if k in keys and keys[k]
|
||||||
|
})
|
||||||
rows = yield self._execute_and_decode(
|
|
||||||
"get_server_verify_keys", sql, server_name, *key_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
keys = []
|
|
||||||
for row in rows:
|
|
||||||
key_id = row["key_id"]
|
|
||||||
key_bytes = row["verify_key"]
|
|
||||||
key = decode_verify_key_bytes(key_id, str(key_bytes))
|
|
||||||
keys.append(key)
|
|
||||||
defer.returnValue(keys)
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def store_server_verify_key(self, server_name, from_server, time_now_ms,
|
def store_server_verify_key(self, server_name, from_server, time_now_ms,
|
||||||
verify_key):
|
verify_key):
|
||||||
"""Stores a NACL verification key for the given server.
|
"""Stores a NACL verification key for the given server.
|
||||||
|
@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore):
|
||||||
ts_now_ms (int): The time now in milliseconds
|
ts_now_ms (int): The time now in milliseconds
|
||||||
verification_key (VerifyKey): The NACL verify key.
|
verification_key (VerifyKey): The NACL verify key.
|
||||||
"""
|
"""
|
||||||
return self._simple_upsert(
|
yield self._simple_upsert(
|
||||||
table="server_signature_keys",
|
table="server_signature_keys",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"server_name": server_name,
|
"server_name": server_name,
|
||||||
|
@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore):
|
||||||
desc="store_server_verify_key",
|
desc="store_server_verify_key",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.get_all_server_verify_keys.invalidate(server_name)
|
||||||
|
|
||||||
def store_server_keys_json(self, server_name, key_id, from_server,
|
def store_server_keys_json(self, server_name, key_id, from_server,
|
||||||
ts_now_ms, ts_expires_ms, key_json_bytes):
|
ts_now_ms, ts_expires_ms, key_json_bytes):
|
||||||
"""Stores the JSON bytes for a set of keys from a server
|
"""Stores the JSON bytes for a set of keys from a server
|
||||||
|
@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore):
|
||||||
"ts_valid_until_ms": ts_expires_ms,
|
"ts_valid_until_ms": ts_expires_ms,
|
||||||
"key_json": buffer(key_json_bytes),
|
"key_json": buffer(key_json_bytes),
|
||||||
},
|
},
|
||||||
|
desc="store_server_keys_json",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_server_keys_json(self, server_keys):
|
def get_server_keys_json(self, server_keys):
|
||||||
|
|
34
synapse/storage/schema/delta/21/end_to_end_keys.sql
Normal file
34
synapse/storage/schema/delta/21/end_to_end_keys.sql
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
/* Copyright 2015 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS e2e_device_keys_json (
|
||||||
|
user_id TEXT NOT NULL, -- The user these keys are for.
|
||||||
|
device_id TEXT NOT NULL, -- Which of the user's devices these keys are for.
|
||||||
|
ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded.
|
||||||
|
key_json TEXT NOT NULL, -- The keys for the device as a JSON blob.
|
||||||
|
CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json (
|
||||||
|
user_id TEXT NOT NULL, -- The user this one-time key is for.
|
||||||
|
device_id TEXT NOT NULL, -- The device this one-time key is for.
|
||||||
|
algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for.
|
||||||
|
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
|
||||||
|
ts_added_ms BIGINT NOT NULL, -- When this key was uploaded.
|
||||||
|
key_json TEXT NOT NULL, -- The key as a JSON blob.
|
||||||
|
CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id)
|
||||||
|
);
|
|
@ -92,11 +92,11 @@ class StateStore(SQLBaseStore):
|
||||||
defer.returnValue(dict(state_list))
|
defer.returnValue(dict(state_list))
|
||||||
|
|
||||||
@cached(num_args=1)
|
@cached(num_args=1)
|
||||||
def _fetch_events_for_group(self, state_group, events):
|
def _fetch_events_for_group(self, key, events):
|
||||||
return self._get_events(
|
return self._get_events(
|
||||||
events, get_prev_content=False
|
events, get_prev_content=False
|
||||||
).addCallback(
|
).addCallback(
|
||||||
lambda evs: (state_group, evs)
|
lambda evs: (key, evs)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _store_state_groups_txn(self, txn, event, context):
|
def _store_state_groups_txn(self, txn, event, context):
|
||||||
|
@ -194,6 +194,65 @@ class StateStore(SQLBaseStore):
|
||||||
events = yield self._get_events(event_ids, get_prev_content=False)
|
events = yield self._get_events(event_ids, get_prev_content=False)
|
||||||
defer.returnValue(events)
|
defer.returnValue(events)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_for_events(self, room_id, event_ids):
|
||||||
|
def f(txn):
|
||||||
|
groups = set()
|
||||||
|
event_to_group = {}
|
||||||
|
for event_id in event_ids:
|
||||||
|
# TODO: Remove this loop.
|
||||||
|
group = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="event_to_state_groups",
|
||||||
|
keyvalues={"event_id": event_id},
|
||||||
|
retcol="state_group",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
if group:
|
||||||
|
event_to_group[event_id] = group
|
||||||
|
groups.add(group)
|
||||||
|
|
||||||
|
group_to_state_ids = {}
|
||||||
|
for group in groups:
|
||||||
|
state_ids = self._simple_select_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="state_groups_state",
|
||||||
|
keyvalues={"state_group": group},
|
||||||
|
retcol="event_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
group_to_state_ids[group] = state_ids
|
||||||
|
|
||||||
|
return event_to_group, group_to_state_ids
|
||||||
|
|
||||||
|
res = yield self.runInteraction(
|
||||||
|
"annotate_events_with_state_groups",
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
event_to_group, group_to_state_ids = res
|
||||||
|
|
||||||
|
state_list = yield defer.gatherResults(
|
||||||
|
[
|
||||||
|
self._fetch_events_for_group(group, vals)
|
||||||
|
for group, vals in group_to_state_ids.items()
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
state_dict = {
|
||||||
|
group: {
|
||||||
|
(ev.type, ev.state_key): ev
|
||||||
|
for ev in state
|
||||||
|
}
|
||||||
|
for group, state in state_list
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue([
|
||||||
|
state_dict.get(event_to_group.get(event, None), None)
|
||||||
|
for event in event_ids
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
def _make_group_id(clock):
|
def _make_group_id(clock):
|
||||||
return str(int(clock.time_msec())) + random_string(5)
|
return str(int(clock.time_msec())) + random_string(5)
|
||||||
|
|
Loading…
Reference in a new issue