Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2019-05-23 15:54:13 +01:00
commit 3cd598135f
68 changed files with 2035 additions and 442 deletions

View file

@ -1,3 +1,54 @@
Synapse 0.99.5.1 (2019-05-22)
=============================
No significant changes.
Synapse 0.99.5 (2019-05-22)
===========================
No significant changes.
Synapse 0.99.5rc1 (2019-05-21)
==============================
Features
--------
- Add ability to blacklist IP ranges for the federation client. ([\#5043](https://github.com/matrix-org/synapse/issues/5043))
- Ratelimiting configuration for clients sending messages and the federation server has been altered to match login ratelimiting. The old configuration names will continue working. Check the sample config for details of the new names. ([\#5181](https://github.com/matrix-org/synapse/issues/5181))
- Drop support for the undocumented /_matrix/client/v2_alpha API prefix. ([\#5190](https://github.com/matrix-org/synapse/issues/5190))
- Add an option to disable per-room profiles. ([\#5196](https://github.com/matrix-org/synapse/issues/5196))
- Stick an expiration date to any registered user missing one at startup if account validity is enabled. ([\#5204](https://github.com/matrix-org/synapse/issues/5204))
- Add experimental support for relations (aka reactions and edits). ([\#5209](https://github.com/matrix-org/synapse/issues/5209), [\#5211](https://github.com/matrix-org/synapse/issues/5211), [\#5203](https://github.com/matrix-org/synapse/issues/5203), [\#5212](https://github.com/matrix-org/synapse/issues/5212))
- Add a room version 4 which uses a new event ID format, as per [MSC2002](https://github.com/matrix-org/matrix-doc/pull/2002). ([\#5210](https://github.com/matrix-org/synapse/issues/5210), [\#5217](https://github.com/matrix-org/synapse/issues/5217))
Bugfixes
--------
- Fix image orientation when generating thumbnails (needs pillow>=4.3.0). Contributed by Pau Rodriguez-Estivill. ([\#5039](https://github.com/matrix-org/synapse/issues/5039))
- Exclude soft-failed events from forward-extremity candidates: fixes "No forward extremities left!" error. ([\#5146](https://github.com/matrix-org/synapse/issues/5146))
- Re-order stages in registration flows such that msisdn and email verification are done last. ([\#5174](https://github.com/matrix-org/synapse/issues/5174))
- Fix 3pid guest invites. ([\#5177](https://github.com/matrix-org/synapse/issues/5177))
- Fix a bug where the register endpoint would fail with M_THREEPID_IN_USE instead of returning an account previously registered in the same session. ([\#5187](https://github.com/matrix-org/synapse/issues/5187))
- Prevent registration for user ids that are too long to fit into a state key. Contributed by Reid Anderson. ([\#5198](https://github.com/matrix-org/synapse/issues/5198))
- Fix incompatibility between ACME support and Python 3.5.2. ([\#5218](https://github.com/matrix-org/synapse/issues/5218))
- Fix error handling for rooms whose versions are unknown. ([\#5219](https://github.com/matrix-org/synapse/issues/5219))
Internal Changes
----------------
- Make /sync attempt to return device updates for both joined and invited users. Note that this doesn't currently work correctly due to other bugs. ([\#3484](https://github.com/matrix-org/synapse/issues/3484))
- Update tests to consistently be configured via the same code that is used when loading from configuration files. ([\#5171](https://github.com/matrix-org/synapse/issues/5171), [\#5185](https://github.com/matrix-org/synapse/issues/5185))
- Allow client event serialization to be async. ([\#5183](https://github.com/matrix-org/synapse/issues/5183))
- Expose DataStore._get_events as get_events_as_list. ([\#5184](https://github.com/matrix-org/synapse/issues/5184))
- Make generating SQL bounds for pagination generic. ([\#5191](https://github.com/matrix-org/synapse/issues/5191))
- Stop telling people to install the optional dependencies by default. ([\#5197](https://github.com/matrix-org/synapse/issues/5197))
Synapse 0.99.4 (2019-05-15)
===========================

View file

@ -1 +0,0 @@
Make /sync attempt to return device updates for both joined and invited users. Note that this doesn't currently work correctly due to other bugs.

1
changelog.d/4338.feature Normal file
View file

@ -0,0 +1 @@
Synapse now more efficiently collates room statistics.

View file

@ -1 +0,0 @@
Fix image orientation when generating thumbnails (needs pillow>=4.3.0). Contributed by Pau Rodriguez-Estivill.

View file

@ -1 +0,0 @@
Add ability to blacklist IP ranges for the federation client.

View file

@ -1 +0,0 @@
Exclude soft-failed events from forward-extremity candidates: fixes "No forward extremities left!" error.

View file

@ -1 +0,0 @@
Update tests to consistently be configured via the same code that is used when loading from configuration files.

View file

@ -1 +0,0 @@
Re-order stages in registration flows such that msisdn and email verification are done last.

View file

@ -1 +0,0 @@
Fix 3pid guest invites.

View file

@ -1 +0,0 @@
Ratelimiting configuration for clients sending messages and the federation server has been altered to match login ratelimiting. The old configuration names will continue working. Check the sample config for details of the new names.

View file

@ -1 +0,0 @@
Allow client event serialization to be async.

View file

@ -1 +0,0 @@
Expose DataStore._get_events as get_events_as_list.

View file

@ -1 +0,0 @@
Update tests to consistently be configured via the same code that is used when loading from configuration files.

View file

@ -1 +0,0 @@
Fix a bug where the register endpoint would fail with M_THREEPID_IN_USE instead of returning an account previously registered in the same session.

View file

@ -1 +0,0 @@
Drop support for the undocumented /_matrix/client/v2_alpha API prefix.

View file

@ -1 +0,0 @@
Make generating SQL bounds for pagination generic.

View file

@ -1 +0,0 @@
Add an option to disable per-room profiles.

View file

@ -1 +0,0 @@
Stop telling people to install the optional dependencies by default.

View file

@ -1 +0,0 @@
Prevent registration for user ids that are to long to fit into a state key. Contributed by Reid Anderson.

1
changelog.d/5200.bugfix Normal file
View file

@ -0,0 +1 @@
Fix worker registration bug caused by ClientReaderSlavedStore being unable to see get_profileinfo.

View file

@ -1 +0,0 @@
Stick an expiration date to any registered user missing one at startup if account validity is enabled.

View file

@ -1 +0,0 @@
Add experimental support for relations (aka reactions and edits).

View file

@ -1 +0,0 @@
Add a room version 4 which uses a new event ID format, as per [MSC2002](https://github.com/matrix-org/matrix-doc/pull/2002).

View file

@ -1 +0,0 @@
Add experimental support for relations (aka reactions and edits).

View file

@ -1 +0,0 @@
Add a room version 4 which uses a new event ID format, as per [MSC2002](https://github.com/matrix-org/matrix-doc/pull/2002).

View file

@ -1 +0,0 @@
Fix incompatibility between ACME support and Python 3.5.2.

View file

@ -1 +0,0 @@
Fix error handling for rooms whose versions are unknown.

1
changelog.d/5223.feature Normal file
View file

@ -0,0 +1 @@
Ability to configure default room version.

1
changelog.d/5227.misc Normal file
View file

@ -0,0 +1 @@
Simplifications and comments in do_auth.

1
changelog.d/5230.misc Normal file
View file

@ -0,0 +1 @@
Remove urllib3 pin as requests 2.22.0 has been released supporting urllib3 1.25.2.

1
changelog.d/5232.misc Normal file
View file

@ -0,0 +1 @@
Run black on synapse.crypto.keyring.

1
changelog.d/5234.misc Normal file
View file

@ -0,0 +1 @@
Rewrite store_server_verify_key to store several keys at once.

1
changelog.d/5235.misc Normal file
View file

@ -0,0 +1 @@
Remove unused VerifyKey.expired and .time_added fields.

1
changelog.d/5236.misc Normal file
View file

@ -0,0 +1 @@
Simplify Keyring.process_v2_response.

1
changelog.d/5237.misc Normal file
View file

@ -0,0 +1 @@
Store key validity time in the storage layer.

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (0.99.5.1) stable; urgency=medium
* New synapse release 0.99.5.1.
-- Synapse Packaging team <packages@matrix.org> Wed, 22 May 2019 16:22:24 +0000
matrix-synapse-py3 (0.99.4) stable; urgency=medium
[ Christoph Müller ]

View file

@ -161,7 +161,7 @@ specify values for `SYNAPSE_CONFIG_PATH`, `SYNAPSE_SERVER_NAME` and
example:
```
docker run -it --rm
docker run -it --rm \
--mount type=volume,src=synapse-data,dst=/data \
-e SYNAPSE_CONFIG_PATH=/data/homeserver.yaml \
-e SYNAPSE_SERVER_NAME=my.matrix.host \

View file

@ -83,6 +83,15 @@ pid_file: DATADIR/homeserver.pid
#
#restrict_public_rooms_to_local_users: true
# The default room version for newly created rooms.
#
# Known room versions are listed here:
# https://matrix.org/docs/spec/#complete-list-of-room-versions
#
# For example, for room version 1, default_room_version should be set
# to "1".
#default_room_version: "1"
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
#
#gc_thresholds: [700, 10, 10]
@ -1153,6 +1162,22 @@ password_config:
#
# Local statistics collection. Used in populating the room directory.
#
# 'bucket_size' controls how large each statistics timeslice is. It can
# be defined in a human readable short form -- e.g. "1d", "1y".
#
# 'retention' controls how long historical statistics will be kept for.
# It can be defined in a human readable short form -- e.g. "1d", "1y".
#
#
#stats:
# enabled: true
# bucket_size: 1d
# retention: 1y
# Server Notices room configuration
#
# Uncomment this section to enable a room which can be used to send notices

View file

@ -27,4 +27,4 @@ try:
except ImportError:
pass
__version__ = "0.99.4"
__version__ = "0.99.5.1"

View file

@ -79,6 +79,7 @@ class EventTypes(object):
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
Encryption = "m.room.encryption"
RoomAvatar = "m.room.avatar"
RoomEncryption = "m.room.encryption"
GuestAccess = "m.room.guest_access"

View file

@ -85,10 +85,6 @@ class RoomVersions(object):
)
# the version we will give rooms which are created on this server
DEFAULT_ROOM_VERSION = RoomVersions.V1
KNOWN_ROOM_VERSIONS = {
v.identifier: v for v in (
RoomVersions.V1,

View file

@ -38,6 +38,7 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
@ -81,6 +82,7 @@ class ClientReaderSlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedTransactionStore,
SlavedProfileStore,
SlavedClientIpStore,
BaseSlavedStore,
):

View file

@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .api import ApiConfig
from .appservice import AppServiceConfig
from .captcha import CaptchaConfig
@ -36,20 +37,41 @@ from .saml2_config import SAML2Config
from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
from .spam_checker import SpamCheckerConfig
from .stats import StatsConfig
from .tls import TlsConfig
from .user_directory import UserDirectoryConfig
from .voip import VoipConfig
from .workers import WorkerConfig
class HomeServerConfig(ServerConfig, TlsConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig,
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,
ConsentConfig,
ServerNoticesConfig, RoomDirectoryConfig,
):
class HomeServerConfig(
ServerConfig,
TlsConfig,
DatabaseConfig,
LoggingConfig,
RatelimitConfig,
ContentRepositoryConfig,
CaptchaConfig,
VoipConfig,
RegistrationConfig,
MetricsConfig,
ApiConfig,
AppServiceConfig,
KeyConfig,
SAML2Config,
CasConfig,
JWTConfig,
PasswordConfig,
EmailConfig,
WorkerConfig,
PasswordAuthProviderConfig,
PushConfig,
SpamCheckerConfig,
GroupsConfig,
UserDirectoryConfig,
ConsentConfig,
StatsConfig,
ServerNoticesConfig,
RoomDirectoryConfig,
):
pass

View file

@ -20,6 +20,7 @@ import os.path
from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.python_dependencies import DependencyException, check_requirements
@ -35,6 +36,8 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0']
DEFAULT_ROOM_VERSION = "1"
class ServerConfig(Config):
@ -88,6 +91,22 @@ class ServerConfig(Config):
"restrict_public_rooms_to_local_users", False,
)
default_room_version = config.get(
"default_room_version", DEFAULT_ROOM_VERSION,
)
# Ensure room version is a str
default_room_version = str(default_room_version)
if default_room_version not in KNOWN_ROOM_VERSIONS:
raise ConfigError(
"Unknown default_room_version: %s, known room versions: %s" %
(default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
)
# Get the actual room version object rather than just the identifier
self.default_room_version = KNOWN_ROOM_VERSIONS[default_room_version]
# whether to enable search. If disabled, new entries will not be inserted
# into the search tables and they will not be indexed. Users will receive
# errors when attempting to search for messages.
@ -310,6 +329,10 @@ class ServerConfig(Config):
unsecure_port = 8008
pid_file = os.path.join(data_dir_path, "homeserver.pid")
# Bring DEFAULT_ROOM_VERSION into the local-scope for use in the
# default config string
default_room_version = DEFAULT_ROOM_VERSION
return """\
## Server ##
@ -384,6 +407,15 @@ class ServerConfig(Config):
#
#restrict_public_rooms_to_local_users: true
# The default room version for newly created rooms.
#
# Known room versions are listed here:
# https://matrix.org/docs/spec/#complete-list-of-room-versions
#
# For example, for room version 1, default_room_version should be set
# to "1".
#default_room_version: "%(default_room_version)s"
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
#
#gc_thresholds: [700, 10, 10]

60
synapse/config/stats.py Normal file
View file

@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import sys
from ._base import Config
class StatsConfig(Config):
"""Stats Configuration
Configuration for the behaviour of synapse's stats engine
"""
def read_config(self, config):
self.stats_enabled = True
self.stats_bucket_size = 86400
self.stats_retention = sys.maxsize
stats_config = config.get("stats", None)
if stats_config:
self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
self.stats_bucket_size = (
self.parse_duration(stats_config.get("bucket_size", "1d")) / 1000
)
self.stats_retention = (
self.parse_duration(
stats_config.get("retention", "%ds" % (sys.maxsize,))
)
/ 1000
)
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Local statistics collection. Used in populating the room directory.
#
# 'bucket_size' controls how large each statistics timeslice is. It can
# be defined in a human readable short form -- e.g. "1d", "1y".
#
# 'retention' controls how long historical statistics will be kept for.
# It can be defined in a human readable short form -- e.g. "1d", "1y".
#
#
#stats:
# enabled: true
# bucket_size: 1d
# retention: 1y
"""

View file

@ -20,7 +20,6 @@ from collections import namedtuple
from six import raise_from
from six.moves import urllib
import nacl.signing
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
@ -43,6 +42,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError
from synapse.util.logcontext import (
LoggingContext,
@ -56,9 +56,9 @@ from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
VerifyKeyRequest = namedtuple("VerifyRequest", (
"server_name", "key_ids", "json_object", "deferred"
))
VerifyKeyRequest = namedtuple(
"VerifyRequest", ("server_name", "key_ids", "json_object", "deferred")
)
"""
A request for a verify key to verify a JSON object.
@ -96,9 +96,7 @@ class Keyring(object):
def verify_json_for_server(self, server_name, json_object):
return logcontext.make_deferred_yieldable(
self.verify_json_objects_for_server(
[(server_name, json_object)]
)[0]
self.verify_json_objects_for_server([(server_name, json_object)])[0]
)
def verify_json_objects_for_server(self, server_and_json):
@ -130,18 +128,15 @@ class Keyring(object):
if not key_ids:
return defer.fail(
SynapseError(
400,
"Not signed by %s" % (server_name,),
Codes.UNAUTHORIZED,
400, "Not signed by %s" % (server_name,), Codes.UNAUTHORIZED
)
)
logger.debug("Verifying for %s with key_ids %s",
server_name, key_ids)
logger.debug("Verifying for %s with key_ids %s", server_name, key_ids)
# add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, defer.Deferred(),
server_name, key_ids, json_object, defer.Deferred()
)
verify_requests.append(verify_request)
@ -179,15 +174,13 @@ class Keyring(object):
# any other lookups until we have finished.
# The deferreds are called with no logcontext.
server_to_deferred = {
rq.server_name: defer.Deferred()
for rq in verify_requests
rq.server_name: defer.Deferred() for rq in verify_requests
}
# We want to wait for any previous lookups to complete before
# proceeding.
yield self.wait_for_previous_lookups(
[rq.server_name for rq in verify_requests],
server_to_deferred,
[rq.server_name for rq in verify_requests], server_to_deferred
)
# Actually start fetching keys.
@ -216,9 +209,7 @@ class Keyring(object):
return res
for verify_request in verify_requests:
verify_request.deferred.addBoth(
remove_deferreds, verify_request,
)
verify_request.deferred.addBoth(remove_deferreds, verify_request)
except Exception:
logger.exception("Error starting key lookups")
@ -248,7 +239,8 @@ class Keyring(object):
break
logger.info(
"Waiting for existing lookups for %s to complete [loop %i]",
[w[0] for w in wait_on], loop_count,
[w[0] for w in wait_on],
loop_count,
)
with PreserveLoggingContext():
yield defer.DeferredList((w[1] for w in wait_on))
@ -315,11 +307,15 @@ class Keyring(object):
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
key = result_keys.get(key_id)
if key:
fetch_key_result = result_keys.get(key_id)
if fetch_key_result:
with PreserveLoggingContext():
verify_request.deferred.callback(
(server_name, key_id, key)
(
server_name,
key_id,
fetch_key_result.verify_key,
)
)
break
else:
@ -335,13 +331,14 @@ class Keyring(object):
with PreserveLoggingContext():
for verify_request in requests_missing_keys:
verify_request.deferred.errback(SynapseError(
401,
"No key for %s with id %s" % (
verify_request.server_name, verify_request.key_ids,
),
Codes.UNAUTHORIZED,
))
verify_request.deferred.errback(
SynapseError(
401,
"No key for %s with id %s"
% (verify_request.server_name, verify_request.key_ids),
Codes.UNAUTHORIZED,
)
)
def on_err(err):
with PreserveLoggingContext():
@ -355,12 +352,12 @@ class Keyring(object):
def get_keys_from_store(self, server_name_and_key_ids):
"""
Args:
server_name_and_key_ids (iterable(Tuple[str, iterable[str]]):
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
Returns:
Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from
server_name -> key_id -> VerifyKey
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
map from server_name -> key_id -> FetchKeyResult
"""
keys_to_fetch = (
(server_name, key_id)
@ -383,25 +380,26 @@ class Keyring(object):
)
defer.returnValue(result)
except KeyLookupError as e:
logger.warning(
"Key lookup failed from %r: %s", perspective_name, e,
)
logger.warning("Key lookup failed from %r: %s", perspective_name, e)
except Exception as e:
logger.exception(
"Unable to get key from %r: %s %s",
perspective_name,
type(e).__name__, str(e),
type(e).__name__,
str(e),
)
defer.returnValue({})
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
run_in_background(get_key, p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError))
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(get_key, p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
union_of_keys = {}
for result in results:
@ -412,32 +410,42 @@ class Keyring(object):
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct,
server_name,
key_ids,
)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError))
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids
)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
merged = {}
for result in results:
merged.update(result)
defer.returnValue({
server_name: keys
for server_name, keys in merged.items()
if keys
})
defer.returnValue(
{server_name: keys for server_name, keys in merged.items() if keys}
)
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name,
perspective_keys):
def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys
):
"""
Args:
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
perspective_name (str): name of the notary server to query for the keys
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
from server_name -> key_id -> FetchKeyResult
"""
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
@ -448,9 +456,7 @@ class Keyring(object):
data={
u"server_keys": {
server_name: {
key_id: {
u"minimum_valid_until_ts": 0
} for key_id in key_ids
key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
}
for server_name, key_ids in server_names_and_key_ids
}
@ -458,21 +464,20 @@ class Keyring(object):
long_retries=True,
)
except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(
KeyLookupError("Failed to connect to remote server"), e,
)
raise_from(KeyLookupError("Failed to connect to remote server"), e)
except HttpResponseException as e:
raise_from(
KeyLookupError("Remote server returned an error"), e,
)
raise_from(KeyLookupError("Remote server returned an error"), e)
keys = {}
added_keys = []
responses = query_response["server_keys"]
time_now_ms = self.clock.time_msec()
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
for response in query_response["server_keys"]:
if (
u"signatures" not in response
or perspective_name not in response[u"signatures"]
):
raise KeyLookupError(
"Key response not signed by perspective server"
" %r" % (perspective_name,)
@ -482,9 +487,7 @@ class Keyring(object):
for key_id in response[u"signatures"][perspective_name]:
if key_id in perspective_keys:
verify_signed_json(
response,
perspective_name,
perspective_keys[key_id]
response, perspective_name, perspective_keys[key_id]
)
verified = True
@ -494,7 +497,7 @@ class Keyring(object):
" known key, signed with: %r, known keys: %r",
perspective_name,
list(response[u"signatures"][perspective_name]),
list(perspective_keys)
list(perspective_keys),
)
raise KeyLookupError(
"Response not signed with a known key for perspective"
@ -502,77 +505,72 @@ class Keyring(object):
)
processed_response = yield self.process_v2_response(
perspective_name, response
perspective_name, response, time_added_ms=time_now_ms
)
server_name = response["server_name"]
added_keys.extend(
(server_name, key_id, key) for key_id, key in processed_response.items()
)
keys.setdefault(server_name, {}).update(processed_response)
yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
run_in_background(
self.store_keys,
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
)
for server_name, response_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError))
yield self.store.store_server_verify_keys(
perspective_name, time_now_ms, added_keys
)
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} # type: dict[str, nacl.signing.VerifyKey]
keys = {} # type: dict[str, FetchKeyResult]
for requested_key_id in key_ids:
if requested_key_id in keys:
continue
time_now_ms = self.clock.time_msec()
try:
response = yield self.client.get_json(
destination=server_name,
path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id),
ignore_backoff=True,
)
except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(
KeyLookupError("Failed to connect to remote server"), e,
)
raise_from(KeyLookupError("Failed to connect to remote server"), e)
except HttpResponseException as e:
raise_from(
KeyLookupError("Remote server returned an error"), e,
)
raise_from(KeyLookupError("Remote server returned an error"), e)
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
if (
u"signatures" not in response
or server_name not in response[u"signatures"]
):
raise KeyLookupError("Key response not signed by remote server")
if response["server_name"] != server_name:
raise KeyLookupError("Expected a response for server %r not %r" % (
server_name, response["server_name"]
))
raise KeyLookupError(
"Expected a response for server %r not %r"
% (server_name, response["server_name"])
)
response_keys = yield self.process_v2_response(
from_server=server_name,
requested_ids=[requested_key_id],
response_json=response,
time_added_ms=time_now_ms,
)
yield self.store.store_server_verify_keys(
server_name,
time_now_ms,
((server_name, key_id, key) for key_id, key in response_keys.items()),
)
keys.update(response_keys)
yield self.store_keys(
server_name=server_name,
from_server=server_name,
verify_keys=keys,
)
defer.returnValue({server_name: keys})
@defer.inlineCallbacks
def process_v2_response(
self, from_server, response_json, requested_ids=[],
self, from_server, response_json, time_added_ms, requested_ids=[]
):
"""Parse a 'Server Keys' structure from the result of a /key request
@ -594,103 +592,82 @@ class Keyring(object):
response_json (dict): the json-decoded Server Keys response object
time_added_ms (int): the timestamp to record in server_keys_json
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns:
Deferred[dict[str, nacl.signing.VerifyKey]]:
map from key_id to key object
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
time_now_ms = self.clock.time_msec()
response_keys = {}
ts_valid_until_ms = response_json[u"valid_until_ts"]
# start by extracting the keys from the response, since they may be required
# to validate the signature on the response.
verify_keys = {}
for key_id, key_data in response_json["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=ts_valid_until_ms
)
# TODO: improve this signature checking
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in verify_keys:
raise KeyLookupError(
"Key response must include verification keys for all signatures"
)
verify_signed_json(
response_json, server_name, verify_keys[key_id].verify_key
)
old_verify_keys = {}
for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.expired = key_data["expired_ts"]
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
if key_id in verify_keys:
verify_signed_json(
response_json,
server_name,
verify_keys[key_id]
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
# re-sign the json with our own key, so that it is ready if we are asked to
# give it out as a notary server
signed_key_json = sign_json(
response_json,
self.config.server_name,
self.config.signing_key[0],
response_json, self.config.server_name, self.config.signing_key[0]
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
# for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got.
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_now_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError))
defer.returnValue(response_keys)
def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
server_name(str): The name of the server the keys are for.
from_server(str): The server the keys were downloaded from.
verify_keys(dict): A mapping of key_id to VerifyKey.
Returns:
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
return logcontext.make_deferred_yieldable(defer.gatherResults(
[
run_in_background(
self.store.store_server_verify_key,
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError))
defer.returnValue(verify_keys)
@defer.inlineCallbacks
@ -713,17 +690,19 @@ def _handle_key_deferred(verify_request):
except KeyLookupError as e:
logger.warn(
"Failed to download keys for %s: %s %s",
server_name, type(e).__name__, str(e),
server_name,
type(e).__name__,
str(e),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
502, "Error downloading keys for %s" % (server_name,), Codes.UNAUTHORIZED
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e),
server_name,
type(e).__name__,
str(e),
)
raise SynapseError(
401,
@ -733,22 +712,24 @@ def _handle_key_deferred(verify_request):
json_object = verify_request.json_object
logger.debug("Got key %s %s:%s for server %s, verifying" % (
key_id, verify_key.alg, verify_key.version, server_name,
))
logger.debug(
"Got key %s %s:%s for server %s, verifying"
% (key_id, verify_key.alg, verify_key.version, server_name)
)
try:
verify_signed_json(json_object, server_name, verify_key)
except SignatureVerifyException as e:
logger.debug(
"Error verifying signature for %s:%s:%s with key %s: %s",
server_name, verify_key.alg, verify_key.version,
server_name,
verify_key.alg,
verify_key.version,
encode_verify_key_base64(verify_key),
str(e),
)
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s: %s" % (
server_name, verify_key.alg, verify_key.version, str(e),
),
"Invalid signature for server %s with key %s:%s: %s"
% (server_name, verify_key.alg, verify_key.version, str(e)),
Codes.UNAUTHORIZED,
)

View file

@ -2013,15 +2013,44 @@ class FederationHandler(BaseHandler):
Args:
origin (str):
event (synapse.events.FrozenEvent):
event (synapse.events.EventBase):
context (synapse.events.snapshot.EventContext):
auth_events (dict[(str, str)->str]):
auth_events (dict[(str, str)->synapse.events.EventBase]):
Map from (event_type, state_key) to event
What we expect the event's auth_events to be, based on the event's
position in the dag. I think? maybe??
Also NB that this function adds entries to it.
Returns:
defer.Deferred[None]
"""
room_version = yield self.store.get_room_version(event.room_id)
yield self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events
)
try:
self.auth.check(room_version, event, auth_events=auth_events)
except AuthError as e:
logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
@defer.inlineCallbacks
def _update_auth_events_and_context_for_auth(
self, origin, event, context, auth_events
):
"""Helper for do_auth. See there for docs.
Args:
origin (str):
event (synapse.events.EventBase):
context (synapse.events.snapshot.EventContext):
auth_events (dict[(str, str)->synapse.events.EventBase]):
Returns:
defer.Deferred[None]
"""
# Check if we have all the auth events.
current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(event.auth_event_ids())
if event.is_state():
@ -2029,11 +2058,21 @@ class FederationHandler(BaseHandler):
else:
event_key = None
if event_auth_events - current_state:
# if the event's auth_events refers to events which are not in our
# calculated auth_events, we need to fetch those events from somewhere.
#
# we start by fetching them from the store, and then try calling /event_auth/.
missing_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
if missing_auth:
# TODO: can we use store.have_seen_events here instead?
have_events = yield self.store.get_seen_events_with_rejections(
event_auth_events - current_state
missing_auth
)
logger.debug("Got events %s from store", have_events)
missing_auth.difference_update(have_events.keys())
else:
have_events = {}
@ -2042,13 +2081,12 @@ class FederationHandler(BaseHandler):
for e in auth_events.values()
})
seen_events = set(have_events.keys())
missing_auth = event_auth_events - seen_events - current_state
if missing_auth:
logger.info("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them.
logger.info(
"auth_events contains unknown events: %s",
missing_auth,
)
try:
remote_auth_chain = yield self.federation_client.get_event_auth(
origin, event.room_id, event.event_id
@ -2089,145 +2127,168 @@ class FederationHandler(BaseHandler):
have_events = yield self.store.get_seen_events_with_rejections(
event.auth_event_ids()
)
seen_events = set(have_events.keys())
except Exception:
# FIXME:
logger.exception("Failed to get auth chain")
if event.internal_metadata.is_outlier():
logger.info("Skipping auth_event fetch for outlier")
return
# FIXME: Assumes we have and stored all the state for all the
# prev_events
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
if not different_auth:
return
logger.info(
"auth_events refers to events which are not in our calculated auth "
"chain: %s",
different_auth,
)
room_version = yield self.store.get_room_version(event.room_id)
if different_auth and not event.internal_metadata.is_outlier():
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
different_events = yield logcontext.make_deferred_yieldable(
defer.gatherResults([
logcontext.run_in_background(
self.store.get_event,
d,
allow_none=True,
allow_rejected=False,
)
for d in different_auth
if d in have_events and not have_events[d]
], consumeErrors=True)
).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)
remote_view = dict(auth_events)
remote_view.update({
(d.type, d.state_key): d for d in different_events if d
})
new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())],
event
different_events = yield logcontext.make_deferred_yieldable(
defer.gatherResults([
logcontext.run_in_background(
self.store.get_event,
d,
allow_none=True,
allow_rejected=False,
)
for d in different_auth
if d in have_events and not have_events[d]
], consumeErrors=True)
).addErrback(unwrapFirstError)
auth_events.update(new_state)
if different_events:
local_view = dict(auth_events)
remote_view = dict(auth_events)
remote_view.update({
(d.type, d.state_key): d for d in different_events if d
})
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())],
event
)
yield self._update_context_for_auth_events(
event, context, auth_events, event_key,
)
logger.info(
"After state res: updating auth_events with new state %s",
{
(d.type, d.state_key): d.event_id for d in new_state.values()
if auth_events.get((d.type, d.state_key)) != d
},
)
if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth)
auth_events.update(new_state)
# Only do auth resolution if we have something new to say.
# We can't rove an auth failure.
do_resolution = False
different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
provable = [
RejectedReason.NOT_ANCESTOR, RejectedReason.NOT_ANCESTOR,
]
yield self._update_context_for_auth_events(
event, context, auth_events, event_key,
)
for e_id in different_auth:
if e_id in have_events:
if have_events[e_id] in provable:
do_resolution = True
break
if not different_auth:
# we're done
return
if do_resolution:
prev_state_ids = yield context.get_prev_state_ids(self.store)
# 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events(
event, prev_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True
)
logger.info(
"auth_events still refers to events which are not in the calculated auth "
"chain after state resolution: %s",
different_auth,
)
try:
# 2. Get remote difference.
result = yield self.federation_client.query_auth(
origin,
event.room_id,
event.event_id,
local_auth_chain,
)
# Only do auth resolution if we have something new to say.
# We can't prove an auth failure.
do_resolution = False
seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in result["auth_chain"]]
)
for e_id in different_auth:
if e_id in have_events:
if have_events[e_id] == RejectedReason.NOT_ANCESTOR:
do_resolution = True
break
# 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]:
if ev.event_id in seen_remotes:
continue
if not do_resolution:
logger.info(
"Skipping auth resolution due to lack of provable rejection reasons"
)
return
if ev.event_id == event.event_id:
continue
logger.info("Doing auth resolution")
try:
auth_ids = ev.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in result["auth_chain"]
if e.event_id in auth_ids
or event.type == EventTypes.Create
}
ev.internal_metadata.outlier = True
prev_state_ids = yield context.get_prev_state_ids(self.store)
logger.debug(
"do_auth %s different_auth: %s",
event.event_id, e.event_id
)
yield self._handle_new_event(
origin, ev, auth_events=auth
)
if ev.event_id in event_auth_events:
auth_events[(ev.type, ev.state_key)] = ev
except AuthError:
pass
except Exception:
# FIXME:
logger.exception("Failed to query auth chain")
# 4. Look at rejects and their proofs.
# TODO.
yield self._update_context_for_auth_events(
event, context, auth_events, event_key,
)
# 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events(
event, prev_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True
)
try:
self.auth.check(room_version, event, auth_events=auth_events)
except AuthError as e:
logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
# 2. Get remote difference.
result = yield self.federation_client.query_auth(
origin,
event.room_id,
event.event_id,
local_auth_chain,
)
seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in result["auth_chain"]]
)
# 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]:
if ev.event_id in seen_remotes:
continue
if ev.event_id == event.event_id:
continue
try:
auth_ids = ev.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in result["auth_chain"]
if e.event_id in auth_ids
or event.type == EventTypes.Create
}
ev.internal_metadata.outlier = True
logger.debug(
"do_auth %s different_auth: %s",
event.event_id, e.event_id
)
yield self._handle_new_event(
origin, ev, auth_events=auth
)
if ev.event_id in event_auth_events:
auth_events[(ev.type, ev.state_key)] = ev
except AuthError:
pass
except Exception:
# FIXME:
logger.exception("Failed to query auth chain")
# 4. Look at rejects and their proofs.
# TODO.
yield self._update_context_for_auth_events(
event, context, auth_events, event_key,
)
@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events,

View file

@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.api.errors import (
AuthError,
Codes,
@ -601,6 +601,20 @@ class EventCreationHandler(object):
self.validator.validate_new(event)
# If this event is an annotation then we check that that the sender
# can't annotate the same way twice (e.g. stops users from liking an
# event multiple times).
relation = event.content.get("m.relates_to", {})
if relation.get("rel_type") == RelationTypes.ANNOTATION:
relates_to = relation["event_id"]
aggregation_key = relation["key"]
already_exists = yield self.store.has_user_annotated_event(
relates_to, event.type, aggregation_key, event.sender,
)
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")
logger.debug(
"Created event %s",
event.event_id,

View file

@ -27,7 +27,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
@ -70,6 +70,7 @@ class RoomCreationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.config = hs.config
# linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
@ -475,7 +476,11 @@ class RoomCreationHandler(BaseHandler):
if ratelimit:
yield self.ratelimit(requester)
room_version = config.get("room_version", DEFAULT_ROOM_VERSION.identifier)
room_version = config.get(
"room_version",
self.config.default_room_version.identifier,
)
if not isinstance(room_version, string_types):
raise SynapseError(
400,

325
synapse/handlers/stats.py Normal file
View file

@ -0,0 +1,325 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
class StatsHandler(StateDeltasHandler):
"""Handles keeping the *_stats tables updated with a simple time-series of
information about the users, rooms and media on the server, such that admins
have some idea of who is consuming their resources.
Heavily derived from UserDirectoryHandler
"""
def __init__(self, hs):
super(StatsHandler, self).__init__(hs)
self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.stats_bucket_size = hs.config.stats_bucket_size
# The current position in the current_state_delta stream
self.pos = None
# Guard to ensure we only process deltas one at a time
self._is_processing = False
if hs.config.stats_enabled:
self.notifier.add_replication_callback(self.notify_new_event)
# We kick this off so that we don't have to wait for a change before
# we start populating stats
self.clock.call_later(0, self.notify_new_event)
def notify_new_event(self):
"""Called when there may be more deltas to process
"""
if not self.hs.config.stats_enabled:
return
if self._is_processing:
return
@defer.inlineCallbacks
def process():
try:
yield self._unsafe_process()
finally:
self._is_processing = False
self._is_processing = True
run_as_background_process("stats.notify_new_event", process)
@defer.inlineCallbacks
def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = yield self.store.get_stats_stream_pos()
# If still None then the initial background update hasn't happened yet
if self.pos is None:
defer.returnValue(None)
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "stats_delta"):
deltas = yield self.store.get_current_state_deltas(self.pos)
if not deltas:
return
logger.info("Handling %d state deltas", len(deltas))
yield self._handle_deltas(deltas)
self.pos = deltas[-1]["stream_id"]
yield self.store.update_stats_stream_pos(self.pos)
event_processing_positions.labels("stats").set(self.pos)
@defer.inlineCallbacks
def _handle_deltas(self, deltas):
"""
Called with the state deltas to process
"""
for delta in deltas:
typ = delta["type"]
state_key = delta["state_key"]
room_id = delta["room_id"]
event_id = delta["event_id"]
stream_id = delta["stream_id"]
prev_event_id = delta["prev_event_id"]
logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
token = yield self.store.get_earliest_token_for_room_stats(room_id)
# If the earliest token to begin from is larger than our current
# stream ID, skip processing this delta.
if token is not None and token >= stream_id:
logger.debug(
"Ignoring: %s as earlier than this room's initial ingestion event",
event_id,
)
continue
if event_id is None and prev_event_id is None:
# Errr...
continue
event_content = {}
if event_id is not None:
event_content = (yield self.store.get_event(event_id)).content or {}
# quantise time to the nearest bucket
now = yield self.store.get_received_ts(event_id)
now = (now // 1000 // self.stats_bucket_size) * self.stats_bucket_size
if typ == EventTypes.Member:
# we could use _get_key_change here but it's a bit inefficient
# given we're not testing for a specific result; might as well
# just grab the prev_membership and membership strings and
# compare them.
prev_event_content = {}
if prev_event_id is not None:
prev_event_content = (
yield self.store.get_event(prev_event_id)
).content
membership = event_content.get("membership", Membership.LEAVE)
prev_membership = prev_event_content.get("membership", Membership.LEAVE)
if prev_membership == membership:
continue
if prev_membership == Membership.JOIN:
yield self.store.update_stats_delta(
now, "room", room_id, "joined_members", -1
)
elif prev_membership == Membership.INVITE:
yield self.store.update_stats_delta(
now, "room", room_id, "invited_members", -1
)
elif prev_membership == Membership.LEAVE:
yield self.store.update_stats_delta(
now, "room", room_id, "left_members", -1
)
elif prev_membership == Membership.BAN:
yield self.store.update_stats_delta(
now, "room", room_id, "banned_members", -1
)
else:
err = "%s is not a valid prev_membership" % (repr(prev_membership),)
logger.error(err)
raise ValueError(err)
if membership == Membership.JOIN:
yield self.store.update_stats_delta(
now, "room", room_id, "joined_members", +1
)
elif membership == Membership.INVITE:
yield self.store.update_stats_delta(
now, "room", room_id, "invited_members", +1
)
elif membership == Membership.LEAVE:
yield self.store.update_stats_delta(
now, "room", room_id, "left_members", +1
)
elif membership == Membership.BAN:
yield self.store.update_stats_delta(
now, "room", room_id, "banned_members", +1
)
else:
err = "%s is not a valid membership" % (repr(membership),)
logger.error(err)
raise ValueError(err)
user_id = state_key
if self.is_mine_id(user_id):
# update user_stats as it's one of our users
public = yield self._is_public_room(room_id)
if membership == Membership.LEAVE:
yield self.store.update_stats_delta(
now,
"user",
user_id,
"public_rooms" if public else "private_rooms",
-1,
)
elif membership == Membership.JOIN:
yield self.store.update_stats_delta(
now,
"user",
user_id,
"public_rooms" if public else "private_rooms",
+1,
)
elif typ == EventTypes.Create:
# Newly created room. Add it with all blank portions.
yield self.store.update_room_state(
room_id,
{
"join_rules": None,
"history_visibility": None,
"encryption": None,
"name": None,
"topic": None,
"avatar": None,
"canonical_alias": None,
},
)
elif typ == EventTypes.JoinRules:
yield self.store.update_room_state(
room_id, {"join_rules": event_content.get("join_rule")}
)
is_public = yield self._get_key_change(
prev_event_id, event_id, "join_rule", JoinRules.PUBLIC
)
if is_public is not None:
yield self.update_public_room_stats(now, room_id, is_public)
elif typ == EventTypes.RoomHistoryVisibility:
yield self.store.update_room_state(
room_id,
{"history_visibility": event_content.get("history_visibility")},
)
is_public = yield self._get_key_change(
prev_event_id, event_id, "history_visibility", "world_readable"
)
if is_public is not None:
yield self.update_public_room_stats(now, room_id, is_public)
elif typ == EventTypes.Encryption:
yield self.store.update_room_state(
room_id, {"encryption": event_content.get("algorithm")}
)
elif typ == EventTypes.Name:
yield self.store.update_room_state(
room_id, {"name": event_content.get("name")}
)
elif typ == EventTypes.Topic:
yield self.store.update_room_state(
room_id, {"topic": event_content.get("topic")}
)
elif typ == EventTypes.RoomAvatar:
yield self.store.update_room_state(
room_id, {"avatar": event_content.get("url")}
)
elif typ == EventTypes.CanonicalAlias:
yield self.store.update_room_state(
room_id, {"canonical_alias": event_content.get("alias")}
)
@defer.inlineCallbacks
def update_public_room_stats(self, ts, room_id, is_public):
"""
Increment/decrement a user's number of public rooms when a room they are
in changes to/from public visibility.
Args:
ts (int): Timestamp in seconds
room_id (str)
is_public (bool)
"""
# For now, blindly iterate over all local users in the room so that
# we can handle the whole problem of copying buckets over as needed
user_ids = yield self.store.get_users_in_room(room_id)
for user_id in user_ids:
if self.hs.is_mine(UserID.from_string(user_id)):
yield self.store.update_stats_delta(
ts, "user", user_id, "public_rooms", +1 if is_public else -1
)
yield self.store.update_stats_delta(
ts, "user", user_id, "private_rooms", -1 if is_public else +1
)
@defer.inlineCallbacks
def _is_public_room(self, room_id):
join_rules = yield self.state.get_current_state(room_id, EventTypes.JoinRules)
history_visibility = yield self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility
)
if (join_rules and join_rules.content.get("join_rule") == JoinRules.PUBLIC) or (
(
history_visibility
and history_visibility.content.get("history_visibility")
== "world_readable"
)
):
defer.returnValue(True)
else:
defer.returnValue(False)

View file

@ -74,14 +74,6 @@ REQUIREMENTS = [
"attrs>=17.4.0",
"netaddr>=0.7.18",
# requests is a transitive dep of treq, and urlib3 is a transitive dep
# of requests, as well as of sentry-sdk.
#
# As of requests 2.21, requests does not yet support urllib3 1.25.
# (If we do not pin it here, pip will give us the latest urllib3
# due to the dep via sentry-sdk.)
"urllib3<1.25",
]
CONDITIONAL_REQUIREMENTS = {

View file

@ -16,7 +16,7 @@ import logging
from twisted.internet import defer
from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
@ -36,6 +36,7 @@ class CapabilitiesRestServlet(RestServlet):
"""
super(CapabilitiesRestServlet, self).__init__()
self.hs = hs
self.config = hs.config
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@ -48,7 +49,7 @@ class CapabilitiesRestServlet(RestServlet):
response = {
"capabilities": {
"m.room_versions": {
"default": DEFAULT_ROOM_VERSION.identifier,
"default": self.config.default_room_version.identifier,
"available": {
v.identifier: v.disposition
for v in KNOWN_ROOM_VERSIONS.values()

View file

@ -72,6 +72,7 @@ from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.user_directory import UserDirectoryHandler
@ -139,6 +140,7 @@ class HomeServer(object):
'acme_handler',
'auth_handler',
'device_handler',
'stats_handler',
'e2e_keys_handler',
'e2e_room_keys_handler',
'event_handler',
@ -191,6 +193,7 @@ class HomeServer(object):
REQUIRED_ON_MASTER_STARTUP = [
"user_directory_handler",
"stats_handler"
]
# This is overridden in derived application classes
@ -474,6 +477,9 @@ class HomeServer(object):
def build_secrets(self):
return Secrets()
def build_stats_handler(self):
return StatsHandler(self)
def build_spam_checker(self):
return SpamChecker(self)

View file

@ -55,6 +55,7 @@ from .roommember import RoomMemberStore
from .search import SearchStore
from .signatures import SignatureStore
from .state import StateStore
from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
@ -100,6 +101,7 @@ class DataStore(
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
StatsStore,
RelationsStore,
):
def __init__(self, db_conn, hs):

View file

@ -610,4 +610,28 @@ class EventsWorkerStore(SQLBaseStore):
return res
return self.runInteraction("get_rejection_reasons", f)
return self.runInteraction("get_seen_events_with_rejections", f)
def _get_total_state_event_counts_txn(self, txn, room_id):
"""
See get_state_event_counts.
"""
sql = "SELECT COUNT(*) FROM state_events WHERE room_id=?"
txn.execute(sql, (room_id,))
row = txn.fetchone()
return row[0] if row else 0
def get_total_state_event_counts(self, room_id):
"""
Gets the total number of state events in a room.
Args:
room_id (str)
Returns:
Deferred[int]
"""
return self.runInteraction(
"get_total_state_event_counts",
self._get_total_state_event_counts_txn, room_id
)

View file

@ -19,6 +19,7 @@ import logging
import six
import attr
from signedjson.key import decode_verify_key_bytes
from synapse.util import batch_iter
@ -36,6 +37,12 @@ else:
db_binary_type = memoryview
@attr.s(slots=True, frozen=True)
class FetchKeyResult(object):
verify_key = attr.ib() # VerifyKey: the key itself
valid_until_ts = attr.ib() # int: how long we can use this key for
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys
"""
@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore):
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]:
map from (server_name, key_id) -> VerifyKey, or None if the key is
Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
map from (server_name, key_id) -> FetchKeyResult, or None if the key is
unknown
"""
keys = {}
@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore):
# batch_iter always returns tuples so it's safe to do len(batch)
sql = (
"SELECT server_name, key_id, verify_key FROM server_signature_keys "
"WHERE 1=0"
"SELECT server_name, key_id, verify_key, ts_valid_until_ms "
"FROM server_signature_keys WHERE 1=0"
) + " OR (server_name=? AND key_id=?)" * len(batch)
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
for row in txn:
server_name, key_id, key_bytes = row
keys[(server_name, key_id)] = decode_verify_key_bytes(
key_id, bytes(key_bytes)
server_name, key_id, key_bytes, ts_valid_until_ms = row
res = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
)
keys[(server_name, key_id)] = res
def _txn(txn):
for batch in batch_iter(server_name_and_key_ids, 50):
@ -84,38 +93,53 @@ class KeyStore(SQLBaseStore):
return self.runInteraction("get_server_verify_keys", _txn)
def store_server_verify_key(
self, server_name, from_server, time_now_ms, verify_key
):
"""Stores a NACL verification key for the given server.
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
"""Stores NACL verification keys for remote servers.
Args:
server_name (str): The name of the server.
from_server (str): Where the verification key was looked up
time_now_ms (int): The time now in milliseconds
verify_key (nacl.signing.VerifyKey): The NACL verify key.
from_server (str): Where the verification keys were looked up
ts_added_ms (int): The time to record that the key was added
verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
# XXX fix this to not need a lock (#3819)
def _txn(txn):
self._simple_upsert_txn(
txn,
table="server_signature_keys",
keyvalues={"server_name": server_name, "key_id": key_id},
values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
"verify_key": db_binary_type(verify_key.encode()),
},
key_values = []
value_values = []
invalidations = []
for server_name, key_id, fetch_result in verify_keys:
key_values.append((server_name, key_id))
value_values.append(
(
from_server,
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
)
)
# invalidate takes a tuple corresponding to the params of
# _get_server_verify_key. _get_server_verify_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
txn.call_after(
self._get_server_verify_key.invalidate, ((server_name, key_id),)
)
invalidations.append((server_name, key_id))
return self.runInteraction("store_server_verify_key", _txn)
def _invalidate(res):
f = self._get_server_verify_key.invalidate
for i in invalidations:
f((i, ))
return res
return self.runInteraction(
"store_server_verify_keys",
self._simple_upsert_many_txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
value_names=(
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"verify_key",
),
value_values=value_values,
).addCallback(_invalidate)
def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes

View file

@ -280,7 +280,7 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause = ""
sql = """
SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering)
SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE {where_clause}
@ -350,9 +350,7 @@ class RelationsWorkerStore(SQLBaseStore):
"""
def _get_applicable_edit_txn(txn):
txn.execute(
sql, (event_id, RelationTypes.REPLACE,)
)
txn.execute(sql, (event_id, RelationTypes.REPLACE))
row = txn.fetchone()
if row:
return row[0]
@ -367,6 +365,50 @@ class RelationsWorkerStore(SQLBaseStore):
edit_event = yield self.get_event(edit_id, allow_none=True)
defer.returnValue(edit_event)
def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
"""Check if a user has already annotated an event with the same key
(e.g. already liked an event).
Args:
parent_id (str): The event being annotated
event_type (str): The event type of the annotation
aggregation_key (str): The aggregation key of the annotation
sender (str): The sender of the annotation
Returns:
Deferred[bool]
"""
sql = """
SELECT 1 FROM event_relations
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND relation_type = ?
AND type = ?
AND sender = ?
AND aggregation_key = ?
LIMIT 1;
"""
def _get_if_user_has_annotated_event(txn):
txn.execute(
sql,
(
parent_id,
RelationTypes.ANNOTATION,
event_type,
sender,
aggregation_key,
),
)
return bool(txn.fetchone())
return self.runInteraction(
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
class RelationsStore(RelationsWorkerStore):
def _handle_event_relations(self, txn, event):

View file

@ -142,6 +142,38 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self.runInteraction("get_room_summary", _get_room_summary_txn)
def _get_user_count_in_room_txn(self, txn, room_id, membership):
"""
See get_user_count_in_room.
"""
sql = (
"SELECT count(*) FROM room_memberships as m"
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id "
" AND m.room_id = c.room_id "
" AND m.user_id = c.state_key"
" WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
)
txn.execute(sql, (room_id, membership))
row = txn.fetchone()
return row[0]
def get_user_count_in_room(self, room_id, membership):
"""
Get the user count in a room with a particular membership.
Args:
room_id (str)
membership (Membership)
Returns:
Deferred[int]
"""
return self.runInteraction(
"get_users_in_room", self._get_user_count_in_room_txn, room_id, membership
)
@cached()
def get_invited_rooms_for_user(self, user_id):
""" Get all the rooms the user is invited to

View file

@ -0,0 +1,23 @@
/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* When we can use this key until, before we have to refresh it. */
ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT;
UPDATE server_signature_keys SET ts_valid_until_ms = (
SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE
skj.server_name = server_signature_keys.server_name AND
skj.key_id = server_signature_keys.key_id
);

View file

@ -0,0 +1,80 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE stats_stream_pos (
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_id BIGINT,
CHECK (Lock='X')
);
INSERT INTO stats_stream_pos (stream_id) VALUES (null);
CREATE TABLE user_stats (
user_id TEXT NOT NULL,
ts BIGINT NOT NULL,
bucket_size INT NOT NULL,
public_rooms INT NOT NULL,
private_rooms INT NOT NULL
);
CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts);
CREATE TABLE room_stats (
room_id TEXT NOT NULL,
ts BIGINT NOT NULL,
bucket_size INT NOT NULL,
current_state_events INT NOT NULL,
joined_members INT NOT NULL,
invited_members INT NOT NULL,
left_members INT NOT NULL,
banned_members INT NOT NULL,
state_events INT NOT NULL
);
CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts);
-- cache of current room state; useful for the publicRooms list
CREATE TABLE room_state (
room_id TEXT NOT NULL,
join_rules TEXT,
history_visibility TEXT,
encryption TEXT,
name TEXT,
topic TEXT,
avatar TEXT,
canonical_alias TEXT
-- get aliases straight from the right table
);
CREATE UNIQUE INDEX room_state_room ON room_state(room_id);
CREATE TABLE room_stats_earliest_token (
room_id TEXT NOT NULL,
token BIGINT NOT NULL
);
CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id);
-- Set up staging tables
INSERT INTO background_updates (update_name, progress_json) VALUES
('populate_stats_createtables', '{}');
-- Run through each room and update stats
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_process_rooms', '{}', 'populate_stats_createtables');
-- Clean up staging tables
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_cleanup', '{}', 'populate_stats_process_rooms');

View file

@ -84,10 +84,16 @@ class StateDeltasStore(SQLBaseStore):
"get_current_state_deltas", get_current_state_deltas_txn
)
def get_max_stream_id_in_current_state_deltas(self):
return self._simple_select_one_onecol(
def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
return self._simple_select_one_onecol_txn(
txn,
table="current_state_delta_stream",
keyvalues={},
retcol="COALESCE(MAX(stream_id), -1)",
desc="get_max_stream_id_in_current_state_deltas",
)
def get_max_stream_id_in_current_state_deltas(self):
return self.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)

450
synapse/storage/stats.py Normal file
View file

@ -0,0 +1,450 @@
# -*- coding: utf-8 -*-
# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.storage.state_deltas import StateDeltasStore
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
# these fields track absolutes (e.g. total number of rooms on the server)
ABSOLUTE_STATS_FIELDS = {
"room": (
"current_state_events",
"joined_members",
"invited_members",
"left_members",
"banned_members",
"state_events",
),
"user": ("public_rooms", "private_rooms"),
}
TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")}
TEMP_TABLE = "_temp_populate_stats"
class StatsStore(StateDeltasStore):
def __init__(self, db_conn, hs):
super(StatsStore, self).__init__(db_conn, hs)
self.server_name = hs.hostname
self.clock = self.hs.get_clock()
self.stats_enabled = hs.config.stats_enabled
self.stats_bucket_size = hs.config.stats_bucket_size
self.register_background_update_handler(
"populate_stats_createtables", self._populate_stats_createtables
)
self.register_background_update_handler(
"populate_stats_process_rooms", self._populate_stats_process_rooms
)
self.register_background_update_handler(
"populate_stats_cleanup", self._populate_stats_cleanup
)
@defer.inlineCallbacks
def _populate_stats_createtables(self, progress, batch_size):
if not self.stats_enabled:
yield self._end_background_update("populate_stats_createtables")
defer.returnValue(1)
# Get all the rooms that we want to process.
def _make_staging_area(txn):
sql = (
"CREATE TABLE IF NOT EXISTS "
+ TEMP_TABLE
+ "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)"
)
txn.execute(sql)
sql = (
"CREATE TABLE IF NOT EXISTS "
+ TEMP_TABLE
+ "_position(position TEXT NOT NULL)"
)
txn.execute(sql)
# Get rooms we want to process from the database
sql = """
SELECT room_id, count(*) FROM current_state_events
GROUP BY room_id
"""
txn.execute(sql)
rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
del rooms
new_pos = yield self.get_max_stream_id_in_current_state_deltas()
yield self.runInteraction("populate_stats_temp_build", _make_staging_area)
yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
self.get_earliest_token_for_room_stats.invalidate_all()
yield self._end_background_update("populate_stats_createtables")
defer.returnValue(1)
@defer.inlineCallbacks
def _populate_stats_cleanup(self, progress, batch_size):
"""
Update the user directory stream position, then clean up the old tables.
"""
if not self.stats_enabled:
yield self._end_background_update("populate_stats_cleanup")
defer.returnValue(1)
position = yield self._simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
)
yield self.update_stats_stream_pos(position)
def _delete_staging_area(txn):
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
yield self.runInteraction("populate_stats_cleanup", _delete_staging_area)
yield self._end_background_update("populate_stats_cleanup")
defer.returnValue(1)
@defer.inlineCallbacks
def _populate_stats_process_rooms(self, progress, batch_size):
if not self.stats_enabled:
yield self._end_background_update("populate_stats_process_rooms")
defer.returnValue(1)
# If we don't have progress filed, delete everything.
if not progress:
yield self.delete_all_stats()
def _get_next_batch(txn):
# Only fetch 250 rooms, so we don't fetch too many at once, even
# if those 250 rooms have less than batch_size state events.
sql = """
SELECT room_id, events FROM %s_rooms
ORDER BY events DESC
LIMIT 250
""" % (
TEMP_TABLE,
)
txn.execute(sql)
rooms_to_work_on = txn.fetchall()
if not rooms_to_work_on:
return None
# Get how many are left to process, so we can give status on how
# far we are in processing
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
progress["remaining"] = txn.fetchone()[0]
return rooms_to_work_on
rooms_to_work_on = yield self.runInteraction(
"populate_stats_temp_read", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
yield self._end_background_update("populate_stats_process_rooms")
defer.returnValue(1)
logger.info(
"Processing the next %d rooms of %d remaining",
(len(rooms_to_work_on), progress["remaining"]),
)
# Number of state events we've processed by going through each room
processed_event_count = 0
for room_id, event_count in rooms_to_work_on:
current_state_ids = yield self.get_current_state_ids(room_id)
join_rules = yield self.get_event(
current_state_ids.get((EventTypes.JoinRules, "")), allow_none=True
)
history_visibility = yield self.get_event(
current_state_ids.get((EventTypes.RoomHistoryVisibility, "")),
allow_none=True,
)
encryption = yield self.get_event(
current_state_ids.get((EventTypes.RoomEncryption, "")), allow_none=True
)
name = yield self.get_event(
current_state_ids.get((EventTypes.Name, "")), allow_none=True
)
topic = yield self.get_event(
current_state_ids.get((EventTypes.Topic, "")), allow_none=True
)
avatar = yield self.get_event(
current_state_ids.get((EventTypes.RoomAvatar, "")), allow_none=True
)
canonical_alias = yield self.get_event(
current_state_ids.get((EventTypes.CanonicalAlias, "")), allow_none=True
)
def _or_none(x, arg):
if x:
return x.content.get(arg)
return None
yield self.update_room_state(
room_id,
{
"join_rules": _or_none(join_rules, "join_rule"),
"history_visibility": _or_none(
history_visibility, "history_visibility"
),
"encryption": _or_none(encryption, "algorithm"),
"name": _or_none(name, "name"),
"topic": _or_none(topic, "topic"),
"avatar": _or_none(avatar, "url"),
"canonical_alias": _or_none(canonical_alias, "alias"),
},
)
now = self.hs.get_reactor().seconds()
# quantise time to the nearest bucket
now = (now // self.stats_bucket_size) * self.stats_bucket_size
def _fetch_data(txn):
# Get the current token of the room
current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
current_state_events = len(current_state_ids)
joined_members = self._get_user_count_in_room_txn(
txn, room_id, Membership.JOIN
)
invited_members = self._get_user_count_in_room_txn(
txn, room_id, Membership.INVITE
)
left_members = self._get_user_count_in_room_txn(
txn, room_id, Membership.LEAVE
)
banned_members = self._get_user_count_in_room_txn(
txn, room_id, Membership.BAN
)
total_state_events = self._get_total_state_event_counts_txn(
txn, room_id
)
self._update_stats_txn(
txn,
"room",
room_id,
now,
{
"bucket_size": self.stats_bucket_size,
"current_state_events": current_state_events,
"joined_members": joined_members,
"invited_members": invited_members,
"left_members": left_members,
"banned_members": banned_members,
"state_events": total_state_events,
},
)
self._simple_insert_txn(
txn,
"room_stats_earliest_token",
{"room_id": room_id, "token": current_token},
)
yield self.runInteraction("update_room_stats", _fetch_data)
# We've finished a room. Delete it from the table.
yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
# Update the remaining counter.
progress["remaining"] -= 1
yield self.runInteraction(
"populate_stats",
self._background_update_progress_txn,
"populate_stats_process_rooms",
progress,
)
processed_event_count += event_count
if processed_event_count > batch_size:
# Don't process any more rooms, we've hit our batch size.
defer.returnValue(processed_event_count)
defer.returnValue(processed_event_count)
def delete_all_stats(self):
"""
Delete all statistics records.
"""
def _delete_all_stats_txn(txn):
txn.execute("DELETE FROM room_state")
txn.execute("DELETE FROM room_stats")
txn.execute("DELETE FROM room_stats_earliest_token")
txn.execute("DELETE FROM user_stats")
return self.runInteraction("delete_all_stats", _delete_all_stats_txn)
def get_stats_stream_pos(self):
return self._simple_select_one_onecol(
table="stats_stream_pos",
keyvalues={},
retcol="stream_id",
desc="stats_stream_pos",
)
def update_stats_stream_pos(self, stream_id):
return self._simple_update_one(
table="stats_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
desc="update_stats_stream_pos",
)
def update_room_state(self, room_id, fields):
"""
Args:
room_id (str)
fields (dict[str:Any])
"""
return self._simple_upsert(
table="room_state",
keyvalues={"room_id": room_id},
values=fields,
desc="update_room_state",
)
def get_deltas_for_room(self, room_id, start, size=100):
"""
Get statistics deltas for a given room.
Args:
room_id (str)
start (int): Pagination start. Number of entries, not timestamp.
size (int): How many entries to return.
Returns:
Deferred[list[dict]], where the dict has the keys of
ABSOLUTE_STATS_FIELDS["room"] and "ts".
"""
return self._simple_select_list_paginate(
"room_stats",
{"room_id": room_id},
"ts",
start,
size,
retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]),
order_direction="DESC",
)
def get_all_room_state(self):
return self._simple_select_list(
"room_state", None, retcols=("name", "topic", "canonical_alias")
)
@cached()
def get_earliest_token_for_room_stats(self, room_id):
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
start of the background task and any particular room's stats
being calculated.
Returns:
Deferred[int]
"""
return self._simple_select_one_onecol(
"room_stats_earliest_token",
{"room_id": room_id},
retcol="token",
allow_none=True,
)
def update_stats(self, stats_type, stats_id, ts, fields):
table, id_col = TYPE_TO_ROOM[stats_type]
return self._simple_upsert(
table=table,
keyvalues={id_col: stats_id, "ts": ts},
values=fields,
desc="update_stats",
)
def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields):
table, id_col = TYPE_TO_ROOM[stats_type]
return self._simple_upsert_txn(
txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields
)
def update_stats_delta(self, ts, stats_type, stats_id, field, value):
def _update_stats_delta(txn):
table, id_col = TYPE_TO_ROOM[stats_type]
sql = (
"SELECT * FROM %s"
" WHERE %s=? and ts=("
" SELECT MAX(ts) FROM %s"
" WHERE %s=?"
")"
) % (table, id_col, table, id_col)
txn.execute(sql, (stats_id, stats_id))
rows = self.cursor_to_dict(txn)
if len(rows) == 0:
# silently skip as we don't have anything to apply a delta to yet.
# this tries to minimise any race between the initial sync and
# subsequent deltas arriving.
return
current_ts = ts
latest_ts = rows[0]["ts"]
if current_ts < latest_ts:
# This one is in the past, but we're just encountering it now.
# Mark it as part of the current bucket.
current_ts = latest_ts
elif ts != latest_ts:
# we have to copy our absolute counters over to the new entry.
values = {
key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type]
}
values[id_col] = stats_id
values["ts"] = ts
values["bucket_size"] = self.stats_bucket_size
self._simple_insert_txn(txn, table=table, values=values)
# actually update the new value
if stats_type in ABSOLUTE_STATS_FIELDS[stats_type]:
self._simple_update_txn(
txn,
table=table,
keyvalues={id_col: stats_id, "ts": current_ts},
updatevalues={field: value},
)
else:
sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % (
table,
field,
field,
id_col,
)
txn.execute(sql, (value, stats_id, current_ts))
return self.runInteraction("update_stats_delta", _update_stats_delta)

View file

@ -25,6 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
from synapse.crypto.keyring import KeyLookupError
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
@ -192,8 +193,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
r = self.hs.datastore.store_server_verify_key(
"server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
key1_id = "%s:%s" % (key1.alg, key1.version)
r = self.hs.datastore.store_server_verify_keys(
"server9",
time.time() * 1000,
[
(
"server9",
key1_id,
FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
),
],
)
self.get_success(r)
json1 = {}
@ -241,9 +252,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k, testverifykey)
self.assertEqual(k.alg, "ed25519")
self.assertEqual(k.version, "ver1")
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
self.assertEqual(k.verify_key.alg, "ed25519")
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@ -311,9 +323,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k, testverifykey)
self.assertEqual(k.alg, "ed25519")
self.assertEqual(k.version, "ver1")
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
self.assertEqual(k.verify_key.alg, "ed25519")
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@ -336,7 +349,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx"):
with LoggingContext("testctx") as ctx:
# we set the "request" prop to make it easier to follow what's going on in the
# logs.
ctx.request = "testctx"
rv = yield f(*args, **kwargs)
defer.returnValue(rv)

View file

@ -0,0 +1,251 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests import unittest
class StatsRoomTests(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.handler = self.hs.get_stats_handler()
def _add_background_updates(self):
"""
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
self.store._all_done = False
self.get_success(
self.store._simple_insert(
"background_updates",
{"update_name": "populate_stats_createtables", "progress_json": "{}"},
)
)
self.get_success(
self.store._simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_rooms",
"progress_json": "{}",
"depends_on": "populate_stats_createtables",
},
)
)
self.get_success(
self.store._simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
"progress_json": "{}",
"depends_on": "populate_stats_process_rooms",
},
)
)
def test_initial_room(self):
"""
The background updates will build the table from scratch.
"""
r = self.get_success(self.store.get_all_room_state())
self.assertEqual(len(r), 0)
# Disable stats
self.hs.config.stats_enabled = False
self.handler.stats_enabled = False
u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass")
room_1 = self.helper.create_room_as(u1, tok=u1_token)
self.helper.send_state(
room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
)
# Stats disabled, shouldn't have done anything
r = self.get_success(self.store.get_all_room_state())
self.assertEqual(len(r), 0)
# Enable stats
self.hs.config.stats_enabled = True
self.handler.stats_enabled = True
# Do the initial population of the user directory via the background update
self._add_background_updates()
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
r = self.get_success(self.store.get_all_room_state())
self.assertEqual(len(r), 1)
self.assertEqual(r[0]["topic"], "foo")
def test_initial_earliest_token(self):
"""
Ingestion via notify_new_event will ignore tokens that the background
update have already processed.
"""
self.reactor.advance(86401)
self.hs.config.stats_enabled = False
self.handler.stats_enabled = False
u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass")
u2 = self.register_user("u2", "pass")
u2_token = self.login("u2", "pass")
u3 = self.register_user("u3", "pass")
u3_token = self.login("u3", "pass")
room_1 = self.helper.create_room_as(u1, tok=u1_token)
self.helper.send_state(
room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
)
# Begin the ingestion by creating the temp tables. This will also store
# the position that the deltas should begin at, once they take over.
self.hs.config.stats_enabled = True
self.handler.stats_enabled = True
self.store._all_done = False
self.get_success(self.store.update_stats_stream_pos(None))
self.get_success(
self.store._simple_insert(
"background_updates",
{"update_name": "populate_stats_createtables", "progress_json": "{}"},
)
)
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
# Now, before the table is actually ingested, add some more events.
self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
self.helper.join(room=room_1, user=u2, tok=u2_token)
# Now do the initial ingestion.
self.get_success(
self.store._simple_insert(
"background_updates",
{"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
)
)
self.get_success(
self.store._simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
"progress_json": "{}",
"depends_on": "populate_stats_process_rooms",
},
)
)
self.store._all_done = False
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
self.reactor.advance(86401)
# Now add some more events, triggering ingestion. Because of the stream
# position being set to before the events sent in the middle, a simpler
# implementation would reprocess those events, and say there were four
# users, not three.
self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token)
self.helper.join(room=room_1, user=u3, tok=u3_token)
# Get the deltas! There should be two -- day 1, and day 2.
r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
# The oldest has 2 joined members
self.assertEqual(r[-1]["joined_members"], 2)
# The newest has 3
self.assertEqual(r[0]["joined_members"], 3)
def test_incorrect_state_transition(self):
"""
If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to
(JOIN, INVITE, LEAVE, BAN), an error is raised.
"""
events = {
"a1": {"membership": Membership.LEAVE},
"a2": {"membership": "not a real thing"},
}
def get_event(event_id):
m = Mock()
m.content = events[event_id]
d = defer.Deferred()
self.reactor.callLater(0.0, d.callback, m)
return d
def get_received_ts(event_id):
return defer.succeed(1)
self.store.get_received_ts = get_received_ts
self.store.get_event = get_event
deltas = [
{
"type": EventTypes.Member,
"state_key": "some_user",
"room_id": "room",
"event_id": "a1",
"prev_event_id": "a2",
"stream_id": "bleb",
}
]
f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
self.assertEqual(
f.value.args[0], "'not a real thing' is not a valid prev_membership"
)
# And the other way...
deltas = [
{
"type": EventTypes.Member,
"state_key": "some_user",
"room_id": "room",
"event_id": "a2",
"prev_event_id": "a1",
"stream_id": "bleb",
}
]
f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
self.assertEqual(
f.value.args[0], "'not a real thing' is not a valid membership"
)

View file

@ -127,3 +127,20 @@ class RestHelper(object):
)
return channel.json_body
def send_state(self, room_id, event_type, body, tok, expect_code=200):
path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type)
if tok:
path = path + "?access_token=%s" % tok
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(body).encode('utf8')
)
render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
% (expect_code, int(channel.result["code"]), channel.result["body"])
)
return channel.json_body

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.rest.admin
from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import capabilities
@ -32,6 +32,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.url = b"/_matrix/client/r0/capabilities"
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
self.config = hs.config
return hs
def test_check_auth_required(self):
@ -51,8 +52,10 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
for room_version in capabilities['m.room_versions']['available'].keys():
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
self.assertEqual(
DEFAULT_ROOM_VERSION.identifier, capabilities['m.room_versions']['default']
self.config.default_room_version.identifier,
capabilities['m.room_versions']['default'],
)
def test_get_change_password_capabilities(self):

View file

@ -90,6 +90,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
self.assertEquals(400, channel.code, channel.json_body)
def test_deny_double_react(self):
"""Test that we deny relations on membership events
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(400, channel.code, channel.json_body)
def test_basic_paginate_relations(self):
"""Tests that calling pagination API corectly the latest relations.
"""
@ -234,14 +243,30 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"""Test that we can paginate within an annotation group.
"""
# We need to create ten separate users to send each reaction.
access_tokens = [self.user_token, self.user2_token]
idx = 0
while len(access_tokens) < 10:
user_id, token = self._create_user("test" + str(idx))
idx += 1
self.helper.join(self.room, user=user_id, tok=token)
access_tokens.append(token)
idx = 0
expected_event_ids = []
for _ in range(10):
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", key=u"👍"
RelationTypes.ANNOTATION,
"m.reaction",
key=u"👍",
access_token=access_tokens[idx],
)
self.assertEquals(200, channel.code, channel.json_body)
expected_event_ids.append(channel.json_body["event_id"])
idx += 1
# Also send a different type of reaction so that we test we don't see it
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)

View file

@ -17,6 +17,8 @@ import signedjson.key
from twisted.internet.defer import Deferred
from synapse.storage.keys import FetchKeyResult
import tests.unittest
KEY_1 = signedjson.key.decode_verify_key_base64(
@ -31,23 +33,34 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_server_verify_keys(self):
store = self.hs.get_datastore()
d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
self.get_success(d)
d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:KEY_ID_2"
d = store.store_server_verify_keys(
"from_server",
10,
[
("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
],
)
self.get_success(d)
d = store.get_server_verify_keys(
[
("server1", "ed25519:key1"),
("server1", "ed25519:key2"),
("server1", "ed25519:key3"),
]
[("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
)
res = self.get_success(d)
self.assertEqual(len(res.keys()), 3)
self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
res1 = res[("server1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.verify_key.version, "key1")
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("server1", key_id_2)]
self.assertEqual(res2.verify_key, KEY_2)
# version comes from the ID it was stored with
self.assertEqual(res2.verify_key.version, "KEY_ID_2")
self.assertEqual(res2.valid_until_ts, 200)
# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
@ -60,32 +73,51 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"
d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
self.get_success(d)
d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
d = store.store_server_verify_keys(
"from_server",
0,
[
("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
],
)
self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
self.assertEqual(res[("srv1", key_id_1)], KEY_1)
self.assertEqual(res[("srv1", key_id_2)], KEY_2)
res1 = res[("srv1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("srv1", key_id_2)]
self.assertEqual(res2.verify_key, KEY_2)
self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit
res = store.get_server_verify_keys([("srv1", key_id_1)])
if isinstance(res, Deferred):
res = self.successResultOf(res)
self.assertEqual(len(res.keys()), 1)
self.assertEqual(res[("srv1", key_id_1)], KEY_1)
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
new_key_2 = signedjson.key.get_verify_key(
signedjson.key.generate_signing_key("key2")
)
d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
d = store.store_server_verify_keys(
"from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
)
self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
self.assertEqual(res[("srv1", key_id_1)], KEY_1)
self.assertEqual(res[("srv1", key_id_2)], new_key_2)
res1 = res[("srv1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("srv1", key_id_2)]
self.assertEqual(res2.verify_key, new_key_2)
self.assertEqual(res2.valid_until_ts, 300)