0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-13 22:03:31 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into erikj/split_purge_history

This commit is contained in:
Erik Johnston 2019-10-31 15:19:26 +00:00
commit cd581338cf
94 changed files with 781 additions and 282 deletions

View file

@ -413,16 +413,18 @@ For a more detailed guide to configuring your server for federation, see
## Email ## Email
It is desirable for Synapse to have the capability to send email. For example, It is desirable for Synapse to have the capability to send email. This allows
this is required to support the 'password reset' feature. Synapse to send password reset emails, send verifications when an email address
is added to a user's account, and send email notifications to users when they
receive new messages.
To configure an SMTP server for Synapse, modify the configuration section To configure an SMTP server for Synapse, modify the configuration section
headed ``email``, and be sure to have at least the ``smtp_host``, ``smtp_port`` headed `email`, and be sure to have at least the `smtp_host`, `smtp_port`
and ``notif_from`` fields filled out. You may also need to set ``smtp_user``, and `notif_from` fields filled out. You may also need to set `smtp_user`,
``smtp_pass``, and ``require_transport_security``. `smtp_pass`, and `require_transport_security`.
If Synapse is not configured with an SMTP server, password reset via email will If email is not configured, password reset, registration and notifications via
be disabled by default. email will be disabled.
## Registering a user ## Registering a user

1
changelog.d/6259.misc Normal file
View file

@ -0,0 +1 @@
Expose some homeserver functionality to spam checkers.

1
changelog.d/6271.misc Normal file
View file

@ -0,0 +1 @@
Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated.

1
changelog.d/6272.doc Normal file
View file

@ -0,0 +1 @@
Update `INSTALL.md` Email section to talk about `account_threepid_delegates`.

1
changelog.d/6294.misc Normal file
View file

@ -0,0 +1 @@
Split out state storage into separate data store.

1
changelog.d/6300.misc Normal file
View file

@ -0,0 +1 @@
Move `persist_events` out from main data store.

1
changelog.d/6307.bugfix Normal file
View file

@ -0,0 +1 @@
Fix `/purge_room` admin API.

View file

@ -217,8 +217,9 @@ def main(args, environ):
# backwards-compatibility generate-a-config-on-the-fly mode # backwards-compatibility generate-a-config-on-the-fly mode
if "SYNAPSE_CONFIG_PATH" in environ: if "SYNAPSE_CONFIG_PATH" in environ:
error( error(
"SYNAPSE_SERVER_NAME and SYNAPSE_CONFIG_PATH are mutually exclusive " "SYNAPSE_SERVER_NAME can only be combined with SYNAPSE_CONFIG_PATH "
"except in `generate` or `migrate_config` mode." "in `generate` or `migrate_config` mode. To start synapse using a "
"config file, unset the SYNAPSE_SERVER_NAME environment variable."
) )
config_path = "/compiled/homeserver.yaml" config_path = "/compiled/homeserver.yaml"

View file

@ -72,7 +72,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
# check that the original exists # check that the original exists
original_file = src_paths.remote_media_filepath(origin_server, file_id) original_file = src_paths.remote_media_filepath(origin_server, file_id)
if not os.path.exists(original_file): if not os.path.exists(original_file):
logger.warn( logger.warning(
"Original for %s/%s (%s) does not exist", "Original for %s/%s (%s) does not exist",
origin_server, origin_server,
file_id, file_id,

View file

@ -157,7 +157,7 @@ class Store(
) )
except self.database_engine.module.DatabaseError as e: except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e): if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N) logger.warning("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
if i < N: if i < N:
i += 1 i += 1
conn.rollback() conn.rollback()
@ -432,7 +432,7 @@ class Porter(object):
for row in rows: for row in rows:
d = dict(zip(headers, row)) d = dict(zip(headers, row))
if "\0" in d['value']: if "\0" in d['value']:
logger.warn('dropping search row %s', d) logger.warning('dropping search row %s', d)
else: else:
rows_dict.append(d) rows_dict.append(d)
@ -647,7 +647,7 @@ class Porter(object):
if isinstance(col, bytes): if isinstance(col, bytes):
return bytearray(col) return bytearray(col)
elif isinstance(col, string_types) and "\0" in col: elif isinstance(col, string_types) and "\0" in col:
logger.warn( logger.warning(
"DROPPING ROW: NUL value in table %s col %s: %r", "DROPPING ROW: NUL value in table %s col %s: %r",
table, table,
headers[j], headers[j],

View file

@ -497,7 +497,7 @@ class Auth(object):
token = self.get_access_token_from_request(request) token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token) service = self.store.get_app_service_by_token(token)
if not service: if not service:
logger.warn("Unrecognised appservice access token.") logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError() raise InvalidClientTokenError()
request.authenticated_entity = service.sender request.authenticated_entity = service.sender
return defer.succeed(service) return defer.succeed(service)

View file

@ -44,6 +44,8 @@ def check_bind_error(e, address, bind_addresses):
bind_addresses (list): Addresses on which the service listens. bind_addresses (list): Addresses on which the service listens.
""" """
if address == "0.0.0.0" and "::" in bind_addresses: if address == "0.0.0.0" and "::" in bind_addresses:
logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]") logger.warning(
"Failed to listen on 0.0.0.0, continuing because listening on [::]"
)
else: else:
raise e raise e

View file

@ -94,7 +94,7 @@ class AppserviceServer(HomeServer):
) )
elif listener["type"] == "metrics": elif listener["type"] == "metrics":
if not self.get_config().enable_metrics: if not self.get_config().enable_metrics:
logger.warn( logger.warning(
( (
"Metrics listener configured, but " "Metrics listener configured, but "
"enable_metrics is not True!" "enable_metrics is not True!"
@ -103,7 +103,7 @@ class AppserviceServer(HomeServer):
else: else:
_base.listen_metrics(listener["bind_addresses"], listener["port"]) _base.listen_metrics(listener["bind_addresses"], listener["port"])
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)

View file

@ -153,7 +153,7 @@ class ClientReaderServer(HomeServer):
) )
elif listener["type"] == "metrics": elif listener["type"] == "metrics":
if not self.get_config().enable_metrics: if not self.get_config().enable_metrics:
logger.warn( logger.warning(
( (
"Metrics listener configured, but " "Metrics listener configured, but "
"enable_metrics is not True!" "enable_metrics is not True!"
@ -162,7 +162,7 @@ class ClientReaderServer(HomeServer):
else: else:
_base.listen_metrics(listener["bind_addresses"], listener["port"]) _base.listen_metrics(listener["bind_addresses"], listener["port"])
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)

View file

@ -147,7 +147,7 @@ class EventCreatorServer(HomeServer):
) )
elif listener["type"] == "metrics": elif listener["type"] == "metrics":
if not self.get_config().enable_metrics: if not self.get_config().enable_metrics:
logger.warn( logger.warning(
( (
"Metrics listener configured, but " "Metrics listener configured, but "
"enable_metrics is not True!" "enable_metrics is not True!"
@ -156,7 +156,7 @@ class EventCreatorServer(HomeServer):
else: else:
_base.listen_metrics(listener["bind_addresses"], listener["port"]) _base.listen_metrics(listener["bind_addresses"], listener["port"])
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)

View file

@ -132,7 +132,7 @@ class FederationReaderServer(HomeServer):
) )
elif listener["type"] == "metrics": elif listener["type"] == "metrics":
if not self.get_config().enable_metrics: if not self.get_config().enable_metrics:
logger.warn( logger.warning(
( (
"Metrics listener configured, but " "Metrics listener configured, but "
"enable_metrics is not True!" "enable_metrics is not True!"
@ -141,7 +141,7 @@ class FederationReaderServer(HomeServer):
else: else:
_base.listen_metrics(listener["bind_addresses"], listener["port"]) _base.listen_metrics(listener["bind_addresses"], listener["port"])
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)

View file

@ -123,7 +123,7 @@ class FederationSenderServer(HomeServer):
) )
elif listener["type"] == "metrics": elif listener["type"] == "metrics":
if not self.get_config().enable_metrics: if not self.get_config().enable_metrics:
logger.warn( logger.warning(
( (
"Metrics listener configured, but " "Metrics listener configured, but "
"enable_metrics is not True!" "enable_metrics is not True!"
@ -132,7 +132,7 @@ class FederationSenderServer(HomeServer):
else: else:
_base.listen_metrics(listener["bind_addresses"], listener["port"]) _base.listen_metrics(listener["bind_addresses"], listener["port"])
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)

View file

@ -204,7 +204,7 @@ class FrontendProxyServer(HomeServer):
) )
elif listener["type"] == "metrics": elif listener["type"] == "metrics":
if not self.get_config().enable_metrics: if not self.get_config().enable_metrics:
logger.warn( logger.warning(
( (
"Metrics listener configured, but " "Metrics listener configured, but "
"enable_metrics is not True!" "enable_metrics is not True!"
@ -213,7 +213,7 @@ class FrontendProxyServer(HomeServer):
else: else:
_base.listen_metrics(listener["bind_addresses"], listener["port"]) _base.listen_metrics(listener["bind_addresses"], listener["port"])
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)

View file

@ -282,7 +282,7 @@ class SynapseHomeServer(HomeServer):
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
elif listener["type"] == "metrics": elif listener["type"] == "metrics":
if not self.get_config().enable_metrics: if not self.get_config().enable_metrics:
logger.warn( logger.warning(
( (
"Metrics listener configured, but " "Metrics listener configured, but "
"enable_metrics is not True!" "enable_metrics is not True!"
@ -291,7 +291,7 @@ class SynapseHomeServer(HomeServer):
else: else:
_base.listen_metrics(listener["bind_addresses"], listener["port"]) _base.listen_metrics(listener["bind_addresses"], listener["port"])
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warning("Unrecognized listener type: %s", listener["type"])
def run_startup_checks(self, db_conn, database_engine): def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain( all_users_native = are_all_users_on_domain(
@ -569,7 +569,7 @@ def run(hs):
hs.config.report_stats_endpoint, stats hs.config.report_stats_endpoint, stats
) )
except Exception as e: except Exception as e:
logger.warn("Error reporting stats: %s", e) logger.warning("Error reporting stats: %s", e)
def performance_stats_init(): def performance_stats_init():
try: try:

View file

@ -120,7 +120,7 @@ class MediaRepositoryServer(HomeServer):
) )
elif listener["type"] == "metrics": elif listener["type"] == "metrics":
if not self.get_config().enable_metrics: if not self.get_config().enable_metrics:
logger.warn( logger.warning(
( (
"Metrics listener configured, but " "Metrics listener configured, but "
"enable_metrics is not True!" "enable_metrics is not True!"
@ -129,7 +129,7 @@ class MediaRepositoryServer(HomeServer):
else: else:
_base.listen_metrics(listener["bind_addresses"], listener["port"]) _base.listen_metrics(listener["bind_addresses"], listener["port"])
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)

View file

@ -114,7 +114,7 @@ class PusherServer(HomeServer):
) )
elif listener["type"] == "metrics": elif listener["type"] == "metrics":
if not self.get_config().enable_metrics: if not self.get_config().enable_metrics:
logger.warn( logger.warning(
( (
"Metrics listener configured, but " "Metrics listener configured, but "
"enable_metrics is not True!" "enable_metrics is not True!"
@ -123,7 +123,7 @@ class PusherServer(HomeServer):
else: else:
_base.listen_metrics(listener["bind_addresses"], listener["port"]) _base.listen_metrics(listener["bind_addresses"], listener["port"])
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)

View file

@ -326,7 +326,7 @@ class SynchrotronServer(HomeServer):
) )
elif listener["type"] == "metrics": elif listener["type"] == "metrics":
if not self.get_config().enable_metrics: if not self.get_config().enable_metrics:
logger.warn( logger.warning(
( (
"Metrics listener configured, but " "Metrics listener configured, but "
"enable_metrics is not True!" "enable_metrics is not True!"
@ -335,7 +335,7 @@ class SynchrotronServer(HomeServer):
else: else:
_base.listen_metrics(listener["bind_addresses"], listener["port"]) _base.listen_metrics(listener["bind_addresses"], listener["port"])
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)

View file

@ -150,7 +150,7 @@ class UserDirectoryServer(HomeServer):
) )
elif listener["type"] == "metrics": elif listener["type"] == "metrics":
if not self.get_config().enable_metrics: if not self.get_config().enable_metrics:
logger.warn( logger.warning(
( (
"Metrics listener configured, but " "Metrics listener configured, but "
"enable_metrics is not True!" "enable_metrics is not True!"
@ -159,7 +159,7 @@ class UserDirectoryServer(HomeServer):
else: else:
_base.listen_metrics(listener["bind_addresses"], listener["port"]) _base.listen_metrics(listener["bind_addresses"], listener["port"])
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)

View file

@ -125,7 +125,7 @@ class KeyConfig(Config):
# if neither trusted_key_servers nor perspectives are given, use the default. # if neither trusted_key_servers nor perspectives are given, use the default.
if "perspectives" not in config and "trusted_key_servers" not in config: if "perspectives" not in config and "trusted_key_servers" not in config:
logger.warn(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN) logger.warning(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN)
key_servers = [{"server_name": "matrix.org"}] key_servers = [{"server_name": "matrix.org"}]
else: else:
key_servers = config.get("trusted_key_servers", []) key_servers = config.get("trusted_key_servers", [])
@ -156,7 +156,7 @@ class KeyConfig(Config):
if not self.macaroon_secret_key: if not self.macaroon_secret_key:
# Unfortunately, there are people out there that don't have this # Unfortunately, there are people out there that don't have this
# set. Lets just be "nice" and derive one from their secret key. # set. Lets just be "nice" and derive one from their secret key.
logger.warn("Config is missing macaroon_secret_key") logger.warning("Config is missing macaroon_secret_key")
seed = bytes(self.signing_key[0]) seed = bytes(self.signing_key[0])
self.macaroon_secret_key = hashlib.sha256(seed).digest() self.macaroon_secret_key = hashlib.sha256(seed).digest()

View file

@ -182,7 +182,7 @@ def _reload_stdlib_logging(*args, log_config=None):
logger = logging.getLogger("") logger = logging.getLogger("")
if not log_config: if not log_config:
logger.warn("Reloaded a blank config?") logger.warning("Reloaded a blank config?")
logging.config.dictConfig(log_config) logging.config.dictConfig(log_config)

View file

@ -77,7 +77,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
if auth_events is None: if auth_events is None:
# Oh, we don't know what the state of the room was, so we # Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now) # are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id) logger.warning("Trusting event: %s", event.event_id)
return return
if event.type == EventTypes.Create: if event.type == EventTypes.Create:

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd # Copyright 2017 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,6 +14,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
from synapse.spam_checker_api import SpamCheckerApi
class SpamChecker(object): class SpamChecker(object):
def __init__(self, hs): def __init__(self, hs):
@ -26,7 +31,14 @@ class SpamChecker(object):
pass pass
if module is not None: if module is not None:
self.spam_checker = module(config=config) # Older spam checkers don't accept the `api` argument, so we
# try and detect support.
spam_args = inspect.getfullargspec(module)
if "api" in spam_args.args:
api = SpamCheckerApi(hs)
self.spam_checker = module(config=config, api=api)
else:
self.spam_checker = module(config=config)
def check_event_for_spam(self, event): def check_event_for_spam(self, event):
"""Checks if a given event is considered "spammy" by this server. """Checks if a given event is considered "spammy" by this server.

View file

@ -102,7 +102,7 @@ class FederationBase(object):
pass pass
if not res: if not res:
logger.warn( logger.warning(
"Failed to find copy of %s with valid signature", pdu.event_id "Failed to find copy of %s with valid signature", pdu.event_id
) )
@ -173,7 +173,7 @@ class FederationBase(object):
return redacted_event return redacted_event
if self.spam_checker.check_event_for_spam(pdu): if self.spam_checker.check_event_for_spam(pdu):
logger.warn( logger.warning(
"Event contains spam, redacting %s: %s", "Event contains spam, redacting %s: %s",
pdu.event_id, pdu.event_id,
pdu.get_pdu_json(), pdu.get_pdu_json(),
@ -185,7 +185,7 @@ class FederationBase(object):
def errback(failure, pdu): def errback(failure, pdu):
failure.trap(SynapseError) failure.trap(SynapseError)
with PreserveLoggingContext(ctx): with PreserveLoggingContext(ctx):
logger.warn( logger.warning(
"Signature check failed for %s: %s", "Signature check failed for %s: %s",
pdu.event_id, pdu.event_id,
failure.getErrorMessage(), failure.getErrorMessage(),

View file

@ -522,12 +522,12 @@ class FederationClient(FederationBase):
res = yield callback(destination) res = yield callback(destination)
return res return res
except InvalidResponseError as e: except InvalidResponseError as e:
logger.warn("Failed to %s via %s: %s", description, destination, e) logger.warning("Failed to %s via %s: %s", description, destination, e)
except HttpResponseException as e: except HttpResponseException as e:
if not 500 <= e.code < 600: if not 500 <= e.code < 600:
raise e.to_synapse_error() raise e.to_synapse_error()
else: else:
logger.warn( logger.warning(
"Failed to %s via %s: %i %s", "Failed to %s via %s: %i %s",
description, description,
destination, destination,
@ -535,7 +535,9 @@ class FederationClient(FederationBase):
e.args[0], e.args[0],
) )
except Exception: except Exception:
logger.warn("Failed to %s via %s", description, destination, exc_info=1) logger.warning(
"Failed to %s via %s", description, destination, exc_info=1
)
raise SynapseError(502, "Failed to %s via any server" % (description,)) raise SynapseError(502, "Failed to %s via any server" % (description,))

View file

@ -220,7 +220,7 @@ class FederationServer(FederationBase):
try: try:
await self.check_server_matches_acl(origin_host, room_id) await self.check_server_matches_acl(origin_host, room_id)
except AuthError as e: except AuthError as e:
logger.warn("Ignoring PDUs for room %s from banned server", room_id) logger.warning("Ignoring PDUs for room %s from banned server", room_id)
for pdu in pdus_by_room[room_id]: for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id event_id = pdu.event_id
pdu_results[event_id] = e.error_dict() pdu_results[event_id] = e.error_dict()
@ -233,7 +233,7 @@ class FederationServer(FederationBase):
await self._handle_received_pdu(origin, pdu) await self._handle_received_pdu(origin, pdu)
pdu_results[event_id] = {} pdu_results[event_id] = {}
except FederationError as e: except FederationError as e:
logger.warn("Error handling PDU %s: %s", event_id, e) logger.warning("Error handling PDU %s: %s", event_id, e)
pdu_results[event_id] = {"error": str(e)} pdu_results[event_id] = {"error": str(e)}
except Exception as e: except Exception as e:
f = failure.Failure() f = failure.Failure()
@ -333,7 +333,9 @@ class FederationServer(FederationBase):
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
if room_version not in supported_versions: if room_version not in supported_versions:
logger.warn("Room version %s not in %s", room_version, supported_versions) logger.warning(
"Room version %s not in %s", room_version, supported_versions
)
raise IncompatibleRoomVersionError(room_version=room_version) raise IncompatibleRoomVersionError(room_version=room_version)
pdu = await self.handler.on_make_join_request(origin, room_id, user_id) pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
@ -679,7 +681,7 @@ def server_matches_acl_event(server_name, acl_event):
# server name is a literal IP # server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True) allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool): if not isinstance(allow_ip_literals, bool):
logger.warn("Ignorning non-bool allow_ip_literals flag") logger.warning("Ignorning non-bool allow_ip_literals flag")
allow_ip_literals = True allow_ip_literals = True
if not allow_ip_literals: if not allow_ip_literals:
# check for ipv6 literals. These start with '['. # check for ipv6 literals. These start with '['.
@ -693,7 +695,7 @@ def server_matches_acl_event(server_name, acl_event):
# next, check the deny list # next, check the deny list
deny = acl_event.content.get("deny", []) deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)): if not isinstance(deny, (list, tuple)):
logger.warn("Ignorning non-list deny ACL %s", deny) logger.warning("Ignorning non-list deny ACL %s", deny)
deny = [] deny = []
for e in deny: for e in deny:
if _acl_entry_matches(server_name, e): if _acl_entry_matches(server_name, e):
@ -703,7 +705,7 @@ def server_matches_acl_event(server_name, acl_event):
# then the allow list. # then the allow list.
allow = acl_event.content.get("allow", []) allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)): if not isinstance(allow, (list, tuple)):
logger.warn("Ignorning non-list allow ACL %s", allow) logger.warning("Ignorning non-list allow ACL %s", allow)
allow = [] allow = []
for e in allow: for e in allow:
if _acl_entry_matches(server_name, e): if _acl_entry_matches(server_name, e):
@ -717,7 +719,7 @@ def server_matches_acl_event(server_name, acl_event):
def _acl_entry_matches(server_name, acl_entry): def _acl_entry_matches(server_name, acl_entry):
if not isinstance(acl_entry, six.string_types): if not isinstance(acl_entry, six.string_types):
logger.warn( logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
) )
return False return False
@ -772,7 +774,7 @@ class FederationHandlerRegistry(object):
async def on_edu(self, edu_type, origin, content): async def on_edu(self, edu_type, origin, content):
handler = self.edu_handlers.get(edu_type) handler = self.edu_handlers.get(edu_type)
if not handler: if not handler:
logger.warn("No handler registered for EDU type %s", edu_type) logger.warning("No handler registered for EDU type %s", edu_type)
with start_active_span_from_edu(content, "handle_edu"): with start_active_span_from_edu(content, "handle_edu"):
try: try:
@ -785,7 +787,7 @@ class FederationHandlerRegistry(object):
def on_query(self, query_type, args): def on_query(self, query_type, args):
handler = self.query_handlers.get(query_type) handler = self.query_handlers.get(query_type)
if not handler: if not handler:
logger.warn("No handler registered for query type %s", query_type) logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,)) raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args) return handler(args)

View file

@ -146,7 +146,7 @@ class TransactionManager(object):
if code == 200: if code == 200:
for e_id, r in response.get("pdus", {}).items(): for e_id, r in response.get("pdus", {}).items():
if "error" in r: if "error" in r:
logger.warn( logger.warning(
"TX [%s] {%s} Remote returned error for %s: %s", "TX [%s] {%s} Remote returned error for %s: %s",
destination, destination,
txn_id, txn_id,
@ -155,7 +155,7 @@ class TransactionManager(object):
) )
else: else:
for p in pdus: for p in pdus:
logger.warn( logger.warning(
"TX [%s] {%s} Failed to send event %s", "TX [%s] {%s} Failed to send event %s",
destination, destination,
txn_id, txn_id,

View file

@ -202,7 +202,7 @@ def _parse_auth_header(header_bytes):
sig = strip_quotes(param_dict["sig"]) sig = strip_quotes(param_dict["sig"])
return origin, key, sig return origin, key, sig
except Exception as e: except Exception as e:
logger.warn( logger.warning(
"Error parsing auth header '%s': %s", "Error parsing auth header '%s': %s",
header_bytes.decode("ascii", "replace"), header_bytes.decode("ascii", "replace"),
e, e,
@ -287,10 +287,12 @@ class BaseFederationServlet(object):
except NoAuthenticationError: except NoAuthenticationError:
origin = None origin = None
if self.REQUIRE_AUTH: if self.REQUIRE_AUTH:
logger.warn("authenticate_request failed: missing authentication") logger.warning(
"authenticate_request failed: missing authentication"
)
raise raise
except Exception as e: except Exception as e:
logger.warn("authenticate_request failed: %s", e) logger.warning("authenticate_request failed: %s", e)
raise raise
request_tags = { request_tags = {

View file

@ -181,7 +181,7 @@ class GroupAttestionRenewer(object):
elif not self.is_mine_id(user_id): elif not self.is_mine_id(user_id):
destination = get_domain_from_id(user_id) destination = get_domain_from_id(user_id)
else: else:
logger.warn( logger.warning(
"Incorrectly trying to do attestations for user: %r in %r", "Incorrectly trying to do attestations for user: %r in %r",
user_id, user_id,
group_id, group_id,

View file

@ -488,7 +488,7 @@ class GroupsServerHandler(object):
profile = yield self.profile_handler.get_profile_from_cache(user_id) profile = yield self.profile_handler.get_profile_from_cache(user_id)
user_profile.update(profile) user_profile.update(profile)
except Exception as e: except Exception as e:
logger.warn("Error getting profile for %s: %s", user_id, e) logger.warning("Error getting profile for %s: %s", user_id, e)
user_profiles.append(user_profile) user_profiles.append(user_profile)
return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}

View file

@ -30,6 +30,9 @@ class AdminHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(AdminHandler, self).__init__(hs) super(AdminHandler, self).__init__(hs)
self.storage = hs.get_storage()
self.state_store = self.storage.state
@defer.inlineCallbacks @defer.inlineCallbacks
def get_whois(self, user): def get_whois(self, user):
connections = [] connections = []
@ -205,7 +208,7 @@ class AdminHandler(BaseHandler):
from_key = events[-1].internal_metadata.after from_key = events[-1].internal_metadata.after
events = yield filter_events_for_client(self.store, user_id, events) events = yield filter_events_for_client(self.storage, user_id, events)
writer.write_events(room_id, events) writer.write_events(room_id, events)
@ -241,7 +244,7 @@ class AdminHandler(BaseHandler):
for event_id in extremities: for event_id in extremities:
if not event_to_unseen_prevs[event_id]: if not event_to_unseen_prevs[event_id]:
continue continue
state = yield self.store.get_state_for_event(event_id) state = yield self.state_store.get_state_for_event(event_id)
writer.write_state(room_id, event_id, state) writer.write_state(room_id, event_id, state)
return writer.finished() return writer.finished()

View file

@ -525,7 +525,7 @@ class AuthHandler(BaseHandler):
result = None result = None
if not user_infos: if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id) logger.warning("Attempted to login as %s but they do not exist", user_id)
elif len(user_infos) == 1: elif len(user_infos) == 1:
# a single match (possibly not exact) # a single match (possibly not exact)
result = user_infos.popitem() result = user_infos.popitem()
@ -534,7 +534,7 @@ class AuthHandler(BaseHandler):
result = (user_id, user_infos[user_id]) result = (user_id, user_infos[user_id])
else: else:
# multiple matches, none of them exact # multiple matches, none of them exact
logger.warn( logger.warning(
"Attempted to login as %s but it matches more than one user " "Attempted to login as %s but it matches more than one user "
"inexactly: %r", "inexactly: %r",
user_id, user_id,
@ -728,7 +728,7 @@ class AuthHandler(BaseHandler):
result = yield self.validate_hash(password, password_hash) result = yield self.validate_hash(password, password_hash)
if not result: if not result:
logger.warn("Failed password login for user %s", user_id) logger.warning("Failed password login for user %s", user_id)
return None return None
return user_id return user_id

View file

@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler):
self.hs = hs self.hs = hs
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.state_store = hs.get_storage().state
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
@trace @trace
@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler):
continue continue
# mapping from event_id -> state_dict # mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to # Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users. # the "possibly changed" users.
@ -656,7 +657,7 @@ class DeviceListUpdater(object):
except (NotRetryingDestination, RequestSendFailed, HttpResponseException): except (NotRetryingDestination, RequestSendFailed, HttpResponseException):
# TODO: Remember that we are now out of sync and try again # TODO: Remember that we are now out of sync and try again
# later # later
logger.warn("Failed to handle device list update for %s", user_id) logger.warning("Failed to handle device list update for %s", user_id)
# We abort on exceptions rather than accepting the update # We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list # as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync # is out of date. If we bail then we will retry the resync
@ -694,7 +695,7 @@ class DeviceListUpdater(object):
# up on storing the total list of devices and only handle the # up on storing the total list of devices and only handle the
# delta instead. # delta instead.
if len(devices) > 1000: if len(devices) > 1000:
logger.warn( logger.warning(
"Ignoring device list snapshot for %s as it has >1K devs (%d)", "Ignoring device list snapshot for %s as it has >1K devs (%d)",
user_id, user_id,
len(devices), len(devices),

View file

@ -52,7 +52,7 @@ class DeviceMessageHandler(object):
local_messages = {} local_messages = {}
sender_user_id = content["sender"] sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id): if origin != get_domain_from_id(sender_user_id):
logger.warn( logger.warning(
"Dropping device message from %r with spoofed sender %r", "Dropping device message from %r with spoofed sender %r",
origin, origin,
sender_user_id, sender_user_id,

View file

@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler): class EventHandler(BaseHandler):
def __init__(self, hs):
super(EventHandler, self).__init__(hs)
self.storage = hs.get_storage()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, user, room_id, event_id): def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event. """Retrieve a single specified event.
@ -172,7 +176,7 @@ class EventHandler(BaseHandler):
is_peeking = user.to_string() not in users is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client( filtered = yield filter_events_for_client(
self.store, user.to_string(), [event], is_peeking=is_peeking self.storage, user.to_string(), [event], is_peeking=is_peeking
) )
if not filtered: if not filtered:

View file

@ -110,6 +110,7 @@ class FederationHandler(BaseHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state
self.federation_client = hs.get_federation_client() self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
@ -181,7 +182,7 @@ class FederationHandler(BaseHandler):
try: try:
self._sanity_check_event(pdu) self._sanity_check_event(pdu)
except SynapseError as err: except SynapseError as err:
logger.warn( logger.warning(
"[%s %s] Received event failed sanity checks", room_id, event_id "[%s %s] Received event failed sanity checks", room_id, event_id
) )
raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
@ -302,7 +303,7 @@ class FederationHandler(BaseHandler):
# following. # following.
if sent_to_us_directly: if sent_to_us_directly:
logger.warn( logger.warning(
"[%s %s] Rejecting: failed to fetch %d prev events: %s", "[%s %s] Rejecting: failed to fetch %d prev events: %s",
room_id, room_id,
event_id, event_id,
@ -325,7 +326,7 @@ class FederationHandler(BaseHandler):
event_map = {event_id: pdu} event_map = {event_id: pdu}
try: try:
# Get the state of the events we know about # Get the state of the events we know about
ours = yield self.store.get_state_groups_ids(room_id, seen) ours = yield self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id # state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list( state_maps = list(
@ -406,7 +407,7 @@ class FederationHandler(BaseHandler):
state = [event_map[e] for e in six.itervalues(state_map)] state = [event_map[e] for e in six.itervalues(state_map)]
auth_chain = list(auth_chains) auth_chain = list(auth_chains)
except Exception: except Exception:
logger.warn( logger.warning(
"[%s %s] Error attempting to resolve state at missing " "[%s %s] Error attempting to resolve state at missing "
"prev_events", "prev_events",
room_id, room_id,
@ -519,7 +520,9 @@ class FederationHandler(BaseHandler):
# We failed to get the missing events, but since we need to handle # We failed to get the missing events, but since we need to handle
# the case of `get_missing_events` not returning the necessary # the case of `get_missing_events` not returning the necessary
# events anyway, it is safe to simply log the error and continue. # events anyway, it is safe to simply log the error and continue.
logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e) logger.warning(
"[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
)
return return
logger.info( logger.info(
@ -546,7 +549,7 @@ class FederationHandler(BaseHandler):
yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e: except FederationError as e:
if e.code == 403: if e.code == 403:
logger.warn( logger.warning(
"[%s %s] Received prev_event %s failed history check.", "[%s %s] Received prev_event %s failed history check.",
room_id, room_id,
event_id, event_id,
@ -889,7 +892,7 @@ class FederationHandler(BaseHandler):
# We set `check_history_visibility_only` as we might otherwise get false # We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased. # positives from users having been erased.
filtered_extremities = yield filter_events_for_server( filtered_extremities = yield filter_events_for_server(
self.store, self.storage,
self.server_name, self.server_name,
list(extremities_events.values()), list(extremities_events.values()),
redact=False, redact=False,
@ -1060,7 +1063,7 @@ class FederationHandler(BaseHandler):
SynapseError if the event does not pass muster SynapseError if the event does not pass muster
""" """
if len(ev.prev_event_ids()) > 20: if len(ev.prev_event_ids()) > 20:
logger.warn( logger.warning(
"Rejecting event %s which has %i prev_events", "Rejecting event %s which has %i prev_events",
ev.event_id, ev.event_id,
len(ev.prev_event_ids()), len(ev.prev_event_ids()),
@ -1068,7 +1071,7 @@ class FederationHandler(BaseHandler):
raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events")
if len(ev.auth_event_ids()) > 10: if len(ev.auth_event_ids()) > 10:
logger.warn( logger.warning(
"Rejecting event %s which has %i auth_events", "Rejecting event %s which has %i auth_events",
ev.event_id, ev.event_id,
len(ev.auth_event_ids()), len(ev.auth_event_ids()),
@ -1204,7 +1207,7 @@ class FederationHandler(BaseHandler):
with nested_logging_context(p.event_id): with nested_logging_context(p.event_id):
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e: except Exception as e:
logger.warn( logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
) )
@ -1251,7 +1254,7 @@ class FederationHandler(BaseHandler):
builder=builder builder=builder
) )
except AuthError as e: except AuthError as e:
logger.warn("Failed to create join to %s because %s", room_id, e) logger.warning("Failed to create join to %s because %s", room_id, e)
raise e raise e
event_allowed = yield self.third_party_event_rules.check_event_allowed( event_allowed = yield self.third_party_event_rules.check_event_allowed(
@ -1495,7 +1498,7 @@ class FederationHandler(BaseHandler):
room_version, event, context, do_sig_check=False room_version, event, context, do_sig_check=False
) )
except AuthError as e: except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e) logger.warning("Failed to create new leave %r because %s", event, e)
raise e raise e
return event return event
@ -1550,7 +1553,7 @@ class FederationHandler(BaseHandler):
event_id, allow_none=False, check_room_id=room_id event_id, allow_none=False, check_room_id=room_id
) )
state_groups = yield self.store.get_state_groups(room_id, [event_id]) state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
if state_groups: if state_groups:
_, state = list(iteritems(state_groups)).pop() _, state = list(iteritems(state_groups)).pop()
@ -1579,7 +1582,7 @@ class FederationHandler(BaseHandler):
event_id, allow_none=False, check_room_id=room_id event_id, allow_none=False, check_room_id=room_id
) )
state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups: if state_groups:
_, state = list(state_groups.items()).pop() _, state = list(state_groups.items()).pop()
@ -1607,7 +1610,7 @@ class FederationHandler(BaseHandler):
events = yield self.store.get_backfill_events(room_id, pdu_list, limit) events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
events = yield filter_events_for_server(self.store, origin, events) events = yield filter_events_for_server(self.storage, origin, events)
return events return events
@ -1637,7 +1640,7 @@ class FederationHandler(BaseHandler):
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
events = yield filter_events_for_server(self.store, origin, [event]) events = yield filter_events_for_server(self.storage, origin, [event])
event = events[0] event = events[0]
return event return event
else: else:
@ -1789,7 +1792,7 @@ class FederationHandler(BaseHandler):
# cause SynapseErrors in auth.check. We don't want to give up # cause SynapseErrors in auth.check. We don't want to give up
# the attempt to federate altogether in such cases. # the attempt to federate altogether in such cases.
logger.warn("Rejecting %s because %s", e.event_id, err.msg) logger.warning("Rejecting %s because %s", e.event_id, err.msg)
if e == event: if e == event:
raise raise
@ -1845,7 +1848,9 @@ class FederationHandler(BaseHandler):
try: try:
yield self.do_auth(origin, event, context, auth_events=auth_events) yield self.do_auth(origin, event, context, auth_events=auth_events)
except AuthError as e: except AuthError as e:
logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg) logger.warning(
"[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
)
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
@ -1903,7 +1908,7 @@ class FederationHandler(BaseHandler):
# given state at the event. This should correctly handle cases # given state at the event. This should correctly handle cases
# like bans, especially with state res v2. # like bans, especially with state res v2.
state_sets = yield self.store.get_state_groups( state_sets = yield self.state_store.get_state_groups(
event.room_id, extrem_ids event.room_id, extrem_ids
) )
state_sets = list(state_sets.values()) state_sets = list(state_sets.values())
@ -1939,7 +1944,7 @@ class FederationHandler(BaseHandler):
try: try:
event_auth.check(room_version, event, auth_events=current_auth_events) event_auth.check(room_version, event, auth_events=current_auth_events)
except AuthError as e: except AuthError as e:
logger.warn("Soft-failing %r because %s", event, e) logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True event.internal_metadata.soft_failed = True
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1994,7 +1999,7 @@ class FederationHandler(BaseHandler):
) )
missing_events = yield filter_events_for_server( missing_events = yield filter_events_for_server(
self.store, origin, missing_events self.storage, origin, missing_events
) )
return missing_events return missing_events
@ -2038,7 +2043,7 @@ class FederationHandler(BaseHandler):
try: try:
event_auth.check(room_version, event, auth_events=auth_events) event_auth.check(room_version, event, auth_events=auth_events)
except AuthError as e: except AuthError as e:
logger.warn("Failed auth resolution for %r because %s", event, e) logger.warning("Failed auth resolution for %r because %s", event, e)
raise e raise e
@defer.inlineCallbacks @defer.inlineCallbacks
@ -2235,7 +2240,7 @@ class FederationHandler(BaseHandler):
# create a new state group as a delta from the existing one. # create a new state group as a delta from the existing one.
prev_group = context.state_group prev_group = context.state_group
state_group = yield self.store.store_state_group( state_group = yield self.state_store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=prev_group, prev_group=prev_group,
@ -2432,7 +2437,7 @@ class FederationHandler(BaseHandler):
try: try:
yield self.auth.check_from_context(room_version, event, context) yield self.auth.check_from_context(room_version, event, context)
except AuthError as e: except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e) logger.warning("Denying new third party invite %r because %s", event, e)
raise e raise e
yield self._check_signature(event, context) yield self._check_signature(event, context)
@ -2488,7 +2493,7 @@ class FederationHandler(BaseHandler):
try: try:
yield self.auth.check_from_context(room_version, event, context) yield self.auth.check_from_context(room_version, event, context)
except AuthError as e: except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e) logger.warning("Denying third party invite %r because %s", event, e)
raise e raise e
yield self._check_signature(event, context) yield self._check_signature(event, context)

View file

@ -392,7 +392,7 @@ class GroupsLocalHandler(object):
try: try:
user_profile = yield self.profile_handler.get_profile(user_id) user_profile = yield self.profile_handler.get_profile(user_id)
except Exception as e: except Exception as e:
logger.warn("No profile for user %s: %s", user_id, e) logger.warning("No profile for user %s: %s", user_id, e)
user_profile = {} user_profile = {}
return {"state": "invite", "user_profile": user_profile} return {"state": "invite", "user_profile": user_profile}

View file

@ -272,7 +272,7 @@ class IdentityHandler(BaseHandler):
changed = False changed = False
if e.code in (400, 404, 501): if e.code in (400, 404, 501):
# The remote server probably doesn't support unbinding (yet) # The remote server probably doesn't support unbinding (yet)
logger.warn("Received %d response while unbinding threepid", e.code) logger.warning("Received %d response while unbinding threepid", e.code)
else: else:
logger.error("Failed to unbind threepid on identity server: %s", e) logger.error("Failed to unbind threepid on identity server: %s", e)
raise SynapseError(500, "Failed to contact identity server") raise SynapseError(500, "Failed to contact identity server")
@ -403,7 +403,7 @@ class IdentityHandler(BaseHandler):
if self.hs.config.using_identity_server_from_trusted_list: if self.hs.config.using_identity_server_from_trusted_list:
# Warn that a deprecated config option is in use # Warn that a deprecated config option is in use
logger.warn( logger.warning(
'The config option "trust_identity_server_for_password_resets" ' 'The config option "trust_identity_server_for_password_resets" '
'has been replaced by "account_threepid_delegate". ' 'has been replaced by "account_threepid_delegate". '
"Please consult the sample config at docs/sample_config.yaml for " "Please consult the sample config at docs/sample_config.yaml for "
@ -457,7 +457,7 @@ class IdentityHandler(BaseHandler):
if self.hs.config.using_identity_server_from_trusted_list: if self.hs.config.using_identity_server_from_trusted_list:
# Warn that a deprecated config option is in use # Warn that a deprecated config option is in use
logger.warn( logger.warning(
'The config option "trust_identity_server_for_password_resets" ' 'The config option "trust_identity_server_for_password_resets" '
'has been replaced by "account_threepid_delegate". ' 'has been replaced by "account_threepid_delegate". '
"Please consult the sample config at docs/sample_config.yaml for " "Please consult the sample config at docs/sample_config.yaml for "

View file

@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler):
self.validator = EventValidator() self.validator = EventValidator()
self.snapshot_cache = SnapshotCache() self.snapshot_cache = SnapshotCache()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
def snapshot_all_rooms( def snapshot_all_rooms(
self, self,
@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler):
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,) room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background( deferred_room_state = run_in_background(
self.store.get_state_for_events, [event.event_id] self.state_store.get_state_for_events, [event.event_id]
) )
deferred_room_state.addCallback( deferred_room_state.addCallback(
lambda states: states[event.event_id] lambda states: states[event.event_id]
@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler):
) )
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(self.store, user_id, messages) messages = yield filter_events_for_client(
self.storage, user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token) start_token = now_token.copy_and_replace("room_key", token)
end_token = now_token.copy_and_replace("room_key", room_end_token) end_token = now_token.copy_and_replace("room_key", room_end_token)
@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler):
def _room_initial_sync_parted( def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
): ):
room_state = yield self.store.get_state_for_events([member_event_id]) room_state = yield self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id] room_state = room_state[member_event_id]
@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler):
) )
messages = yield filter_events_for_client( messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking self.storage, user_id, messages, is_peeking=is_peeking
) )
start_token = StreamToken.START.copy_and_replace("room_key", token) start_token = StreamToken.START.copy_and_replace("room_key", token)
@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler):
) )
messages = yield filter_events_for_client( messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking self.storage, user_id, messages, is_peeking=is_peeking
) )
start_token = now_token.copy_and_replace("room_key", token) start_token = now_token.copy_and_replace("room_key", token)

View file

@ -59,6 +59,8 @@ class MessageHandler(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks @defer.inlineCallbacks
@ -82,7 +84,7 @@ class MessageHandler(object):
data = yield self.state.get_current_state(room_id, event_type, state_key) data = yield self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
key = (event_type, state_key) key = (event_type, state_key)
room_state = yield self.store.get_state_for_events( room_state = yield self.state_store.get_state_for_events(
[membership_event_id], StateFilter.from_types([key]) [membership_event_id], StateFilter.from_types([key])
) )
data = room_state[membership_event_id].get(key) data = room_state[membership_event_id].get(key)
@ -135,12 +137,12 @@ class MessageHandler(object):
raise NotFoundError("Can't find event for token %s" % (at_token,)) raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client( visible_events = yield filter_events_for_client(
self.store, user_id, last_events self.storage, user_id, last_events
) )
event = last_events[0] event = last_events[0]
if visible_events: if visible_events:
room_state = yield self.store.get_state_for_events( room_state = yield self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter [event.event_id], state_filter=state_filter
) )
room_state = room_state[event.event_id] room_state = room_state[event.event_id]
@ -161,7 +163,7 @@ class MessageHandler(object):
) )
room_state = yield self.store.get_events(state_ids.values()) room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events( room_state = yield self.state_store.get_state_for_events(
[membership_event_id], state_filter=state_filter [membership_event_id], state_filter=state_filter
) )
room_state = room_state[membership_event_id] room_state = room_state[membership_event_id]
@ -688,7 +690,7 @@ class EventCreationHandler(object):
try: try:
yield self.auth.check_from_context(room_version, event, context) yield self.auth.check_from_context(room_version, event, context)
except AuthError as err: except AuthError as err:
logger.warn("Denying new event %r because %s", event, err) logger.warning("Denying new event %r because %s", event, err)
raise err raise err
# Ensure that we can round trip before trying to persist in db # Ensure that we can round trip before trying to persist in db

View file

@ -70,6 +70,7 @@ class PaginationHandler(object):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._server_name = hs.hostname self._server_name = hs.hostname
@ -258,7 +259,7 @@ class PaginationHandler(object):
events = event_filter.filter(events) events = event_filter.filter(events)
events = yield filter_events_for_client( events = yield filter_events_for_client(
self.store, user_id, events, is_peeking=(member_event_id is None) self.storage, user_id, events, is_peeking=(member_event_id is None)
) )
if not events: if not events:
@ -277,7 +278,7 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events (EventTypes.Member, event.sender) for event in events
) )
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter events[0].event_id, state_filter=state_filter
) )

View file

@ -275,7 +275,7 @@ class BaseProfileHandler(BaseHandler):
ratelimit=False, # Try to hide that these events aren't atomic. ratelimit=False, # Try to hide that these events aren't atomic.
) )
except Exception as e: except Exception as e:
logger.warn( logger.warning(
"Failed to update join event for room %s - %s", room_id, str(e) "Failed to update join event for room %s - %s", room_id, str(e)
) )

View file

@ -822,6 +822,8 @@ class RoomContextHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, event_filter): def get_event_context(self, user, room_id, event_id, limit, event_filter):
@ -848,7 +850,7 @@ class RoomContextHandler(object):
def filter_evts(events): def filter_evts(events):
return filter_events_for_client( return filter_events_for_client(
self.store, user.to_string(), events, is_peeking=is_peeking self.storage, user.to_string(), events, is_peeking=is_peeking
) )
event = yield self.store.get_event( event = yield self.store.get_event(
@ -890,7 +892,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync? # first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687 # https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.store.get_state_for_events( state = yield self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter [last_event_id], state_filter=state_filter
) )
results["state"] = list(state[last_event_id].values()) results["state"] = list(state[last_event_id].values())
@ -922,7 +924,7 @@ class RoomEventSource(object):
from_token = RoomStreamToken.parse(from_key) from_token = RoomStreamToken.parse(from_key)
if from_token.topological: if from_token.topological:
logger.warn("Stream has topological part!!!! %r", from_key) logger.warning("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,) from_key = "s%s" % (from_token.stream,)
app_service = self.store.get_app_service_by_user_id(user.to_string()) app_service = self.store.get_app_service_by_user_id(user.to_string())

View file

@ -35,6 +35,8 @@ class SearchHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(SearchHandler, self).__init__(hs) super(SearchHandler, self).__init__(hs)
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@defer.inlineCallbacks @defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id): def get_old_rooms_from_upgraded_room(self, room_id):
@ -221,7 +223,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results]) filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client( events = yield filter_events_for_client(
self.store, user.to_string(), filtered_events self.storage, user.to_string(), filtered_events
) )
events.sort(key=lambda e: -rank_map[e.event_id]) events.sort(key=lambda e: -rank_map[e.event_id])
@ -271,7 +273,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results]) filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client( events = yield filter_events_for_client(
self.store, user.to_string(), filtered_events self.storage, user.to_string(), filtered_events
) )
room_events.extend(events) room_events.extend(events)
@ -340,11 +342,11 @@ class SearchHandler(BaseHandler):
) )
res["events_before"] = yield filter_events_for_client( res["events_before"] = yield filter_events_for_client(
self.store, user.to_string(), res["events_before"] self.storage, user.to_string(), res["events_before"]
) )
res["events_after"] = yield filter_events_for_client( res["events_after"] = yield filter_events_for_client(
self.store, user.to_string(), res["events_after"] self.storage, user.to_string(), res["events_after"]
) )
res["start"] = now_token.copy_and_replace( res["start"] = now_token.copy_and_replace(
@ -372,7 +374,7 @@ class SearchHandler(BaseHandler):
[(EventTypes.Member, sender) for sender in senders] [(EventTypes.Member, sender) for sender in senders]
) )
state = yield self.store.get_state_for_event( state = yield self.state_store.get_state_for_event(
last_event_id, state_filter last_event_id, state_filter
) )

View file

@ -230,6 +230,8 @@ class SyncHandler(object):
self.response_cache = ResponseCache(hs, "sync") self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.storage = hs.get_storage()
self.state_store = self.storage.state
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id) # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
self.lazy_loaded_members_cache = ExpiringCache( self.lazy_loaded_members_cache = ExpiringCache(
@ -417,7 +419,7 @@ class SyncHandler(object):
current_state_ids = frozenset(itervalues(current_state_ids)) current_state_ids = frozenset(itervalues(current_state_ids))
recents = yield filter_events_for_client( recents = yield filter_events_for_client(
self.store, self.storage,
sync_config.user.to_string(), sync_config.user.to_string(),
recents, recents,
always_include_ids=current_state_ids, always_include_ids=current_state_ids,
@ -470,7 +472,7 @@ class SyncHandler(object):
current_state_ids = frozenset(itervalues(current_state_ids)) current_state_ids = frozenset(itervalues(current_state_ids))
loaded_recents = yield filter_events_for_client( loaded_recents = yield filter_events_for_client(
self.store, self.storage,
sync_config.user.to_string(), sync_config.user.to_string(),
loaded_recents, loaded_recents,
always_include_ids=current_state_ids, always_include_ids=current_state_ids,
@ -509,7 +511,7 @@ class SyncHandler(object):
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
""" """
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter event.event_id, state_filter=state_filter
) )
if event.is_state(): if event.is_state():
@ -580,7 +582,7 @@ class SyncHandler(object):
return None return None
last_event = last_events[-1] last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.state_store.get_state_ids_for_event(
last_event.event_id, last_event.event_id,
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@ -757,11 +759,11 @@ class SyncHandler(object):
if full_state: if full_state:
if batch: if batch:
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter batch.events[-1].event_id, state_filter=state_filter
) )
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter batch.events[0].event_id, state_filter=state_filter
) )
@ -781,7 +783,7 @@ class SyncHandler(object):
) )
elif batch.limited: elif batch.limited:
if batch: if batch:
state_at_timeline_start = yield self.store.get_state_ids_for_event( state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter batch.events[0].event_id, state_filter=state_filter
) )
else: else:
@ -810,7 +812,7 @@ class SyncHandler(object):
) )
if batch: if batch:
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter batch.events[-1].event_id, state_filter=state_filter
) )
else: else:
@ -841,7 +843,7 @@ class SyncHandler(object):
# So we fish out all the member events corresponding to the # So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below. # timeline here, and then dedupe any redundant ones below.
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, batch.events[0].event_id,
# we only want members! # we only want members!
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(

View file

@ -535,7 +535,7 @@ class SimpleHttpClient(object):
b"Content-Length" in resp_headers b"Content-Length" in resp_headers
and int(resp_headers[b"Content-Length"][0]) > max_size and int(resp_headers[b"Content-Length"][0]) > max_size
): ):
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError( raise SynapseError(
502, 502,
"Requested file is too large > %r bytes" % (self.max_size,), "Requested file is too large > %r bytes" % (self.max_size,),
@ -543,7 +543,7 @@ class SimpleHttpClient(object):
) )
if response.code > 299: if response.code > 299:
logger.warn("Got %d when downloading %s" % (response.code, url)) logger.warning("Got %d when downloading %s" % (response.code, url))
raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
# TODO: if our Content-Type is HTML or something, just read the first # TODO: if our Content-Type is HTML or something, just read the first

View file

@ -148,7 +148,7 @@ class SrvResolver(object):
# Try something in the cache, else rereaise # Try something in the cache, else rereaise
cache_entry = self._cache.get(service_name, None) cache_entry = self._cache.get(service_name, None)
if cache_entry: if cache_entry:
logger.warn( logger.warning(
"Failed to resolve %r, falling back to cache. %r", service_name, e "Failed to resolve %r, falling back to cache. %r", service_name, e
) )
return list(cache_entry) return list(cache_entry)

View file

@ -149,7 +149,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
body = yield make_deferred_yieldable(d) body = yield make_deferred_yieldable(d)
except Exception as e: except Exception as e:
logger.warn( logger.warning(
"{%s} [%s] Error reading response: %s", "{%s} [%s] Error reading response: %s",
request.txn_id, request.txn_id,
request.destination, request.destination,
@ -457,7 +457,7 @@ class MatrixFederationHttpClient(object):
except Exception as e: except Exception as e:
# Eh, we're already going to raise an exception so lets # Eh, we're already going to raise an exception so lets
# ignore if this fails. # ignore if this fails.
logger.warn( logger.warning(
"{%s} [%s] Failed to get error response: %s %s: %s", "{%s} [%s] Failed to get error response: %s %s: %s",
request.txn_id, request.txn_id,
request.destination, request.destination,
@ -478,7 +478,7 @@ class MatrixFederationHttpClient(object):
break break
except RequestSendFailed as e: except RequestSendFailed as e:
logger.warn( logger.warning(
"{%s} [%s] Request failed: %s %s: %s", "{%s} [%s] Request failed: %s %s: %s",
request.txn_id, request.txn_id,
request.destination, request.destination,
@ -513,7 +513,7 @@ class MatrixFederationHttpClient(object):
raise raise
except Exception as e: except Exception as e:
logger.warn( logger.warning(
"{%s} [%s] Request failed: %s %s: %s", "{%s} [%s] Request failed: %s %s: %s",
request.txn_id, request.txn_id,
request.destination, request.destination,
@ -889,7 +889,7 @@ class MatrixFederationHttpClient(object):
d.addTimeout(self.default_timeout, self.reactor) d.addTimeout(self.default_timeout, self.reactor)
length = yield make_deferred_yieldable(d) length = yield make_deferred_yieldable(d)
except Exception as e: except Exception as e:
logger.warn( logger.warning(
"{%s} [%s] Error reading response: %s", "{%s} [%s] Error reading response: %s",
request.txn_id, request.txn_id,
request.destination, request.destination,

View file

@ -170,7 +170,7 @@ class RequestMetrics(object):
tag = context.tag tag = context.tag
if context != self.start_context: if context != self.start_context:
logger.warn( logger.warning(
"Context have unexpectedly changed %r, %r", "Context have unexpectedly changed %r, %r",
context, context,
self.start_context, self.start_context,

View file

@ -454,7 +454,7 @@ def respond_with_json(
# the Deferred fires, but since the flag is RIGHT THERE it seems like # the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste. # a waste.
if request._disconnected: if request._disconnected:
logger.warn( logger.warning(
"Not sending response to request %s, already disconnected.", request "Not sending response to request %s, already disconnected.", request
) )
return return

View file

@ -219,13 +219,13 @@ def parse_json_value_from_request(request, allow_empty_body=False):
try: try:
content_unicode = content_bytes.decode("utf8") content_unicode = content_bytes.decode("utf8")
except UnicodeDecodeError: except UnicodeDecodeError:
logger.warn("Unable to decode UTF-8") logger.warning("Unable to decode UTF-8")
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
try: try:
content = json.loads(content_unicode) content = json.loads(content_unicode)
except Exception as e: except Exception as e:
logger.warn("Unable to parse JSON: %s", e) logger.warning("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content return content

View file

@ -199,7 +199,7 @@ class SynapseRequest(Request):
# It's useful to log it here so that we can get an idea of when # It's useful to log it here so that we can get an idea of when
# the client disconnects. # the client disconnects.
with PreserveLoggingContext(self.logcontext): with PreserveLoggingContext(self.logcontext):
logger.warn( logger.warning(
"Error processing request %r: %s %s", self, reason.type, reason.value "Error processing request %r: %s %s", self, reason.type, reason.value
) )
@ -305,7 +305,7 @@ class SynapseRequest(Request):
try: try:
self.request_metrics.stop(self.finish_time, self.code, self.sentLength) self.request_metrics.stop(self.finish_time, self.code, self.sentLength)
except Exception as e: except Exception as e:
logger.warn("Failed to stop metrics: %r", e) logger.warning("Failed to stop metrics: %r", e)
class XForwardedForRequest(SynapseRequest): class XForwardedForRequest(SynapseRequest):

View file

@ -294,7 +294,7 @@ class LoggingContext(object):
"""Enters this logging context into thread local storage""" """Enters this logging context into thread local storage"""
old_context = self.set_current_context(self) old_context = self.set_current_context(self)
if self.previous_context != old_context: if self.previous_context != old_context:
logger.warn( logger.warning(
"Expected previous context %r, found %r", "Expected previous context %r, found %r",
self.previous_context, self.previous_context,
old_context, old_context,

View file

@ -159,6 +159,7 @@ class Notifier(object):
self.room_to_user_streams = {} self.room_to_user_streams = {}
self.hs = hs self.hs = hs
self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.pending_new_room_events = [] self.pending_new_room_events = []
@ -425,7 +426,10 @@ class Notifier(object):
if name == "room": if name == "room":
new_events = yield filter_events_for_client( new_events = yield filter_events_for_client(
self.store, user.to_string(), new_events, is_peeking=is_peeking self.storage,
user.to_string(),
new_events,
is_peeking=is_peeking,
) )
elif name == "presence": elif name == "presence":
now = self.clock.time_msec() now = self.clock.time_msec()

View file

@ -64,6 +64,7 @@ class HttpPusher(object):
def __init__(self, hs, pusherdict): def __init__(self, hs, pusherdict):
self.hs = hs self.hs = hs
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
self.user_id = pusherdict["user_name"] self.user_id = pusherdict["user_name"]
@ -246,7 +247,7 @@ class HttpPusher(object):
# we really only give up so that if the URL gets # we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load # fixed, we don't suddenly deliver a load
# of old notifications. # of old notifications.
logger.warn( logger.warning(
"Giving up on a notification to user %s, " "pushkey %s", "Giving up on a notification to user %s, " "pushkey %s",
self.user_id, self.user_id,
self.pushkey, self.pushkey,
@ -299,7 +300,7 @@ class HttpPusher(object):
if pk != self.pushkey: if pk != self.pushkey:
# for sanity, we only remove the pushkey if it # for sanity, we only remove the pushkey if it
# was the one we actually sent... # was the one we actually sent...
logger.warn( logger.warning(
("Ignoring rejected pushkey %s because we" " didn't send it"), ("Ignoring rejected pushkey %s because we" " didn't send it"),
pk, pk,
) )
@ -329,7 +330,7 @@ class HttpPusher(object):
return d return d
ctx = yield push_tools.get_context_for_event( ctx = yield push_tools.get_context_for_event(
self.store, self.state_handler, event, self.user_id self.storage, self.state_handler, event, self.user_id
) )
d = { d = {

View file

@ -119,6 +119,7 @@ class Mailer(object):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.macaroon_gen = self.hs.get_macaroon_generator() self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage()
self.app_name = app_name self.app_name = app_name
logger.info("Created Mailer for app_name %s" % app_name) logger.info("Created Mailer for app_name %s" % app_name)
@ -389,7 +390,7 @@ class Mailer(object):
} }
the_events = yield filter_events_for_client( the_events = yield filter_events_for_client(
self.store, user_id, results["events_before"] self.storage, user_id, results["events_before"]
) )
the_events.append(notif_event) the_events.append(notif_event)

View file

@ -117,7 +117,7 @@ class PushRuleEvaluatorForEvent(object):
pattern = UserID.from_string(user_id).localpart pattern = UserID.from_string(user_id).localpart
if not pattern: if not pattern:
logger.warn("event_match condition with no pattern") logger.warning("event_match condition with no pattern")
return False return False
# XXX: optimisation: cache our pattern regexps # XXX: optimisation: cache our pattern regexps
@ -173,7 +173,7 @@ def _glob_matches(glob, value, word_boundary=False):
regex_cache[(glob, word_boundary)] = r regex_cache[(glob, word_boundary)] = r
return r.search(value) return r.search(value)
except re.error: except re.error:
logger.warn("Failed to parse glob to regex: %r", glob) logger.warning("Failed to parse glob to regex: %r", glob)
return False return False

View file

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
@defer.inlineCallbacks @defer.inlineCallbacks
@ -43,22 +44,22 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_context_for_event(store, state_handler, ev, user_id): def get_context_for_event(storage: Storage, state_handler, ev, user_id):
ctx = {} ctx = {}
room_state_ids = yield store.get_state_ids_for_event(ev.event_id) room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the # we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or # human-readable name instead, be that m.room.name, an alias or
# a list of people in the room # a list of people in the room
name = yield calculate_room_name( name = yield calculate_room_name(
store, room_state_ids, user_id, fallback_to_single_member=False storage.main, room_state_ids, user_id, fallback_to_single_member=False
) )
if name: if name:
ctx["name"] = name ctx["name"] = name
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
sender_state_event = yield store.get_event(sender_state_event_id) sender_state_event = yield storage.main.get_event(sender_state_event_id)
ctx["sender_display_name"] = name_from_member_event(sender_state_event) ctx["sender_display_name"] = name_from_member_event(sender_state_event)
return ctx return ctx

View file

@ -180,7 +180,7 @@ class ReplicationEndpoint(object):
if e.code != 504 or not cls.RETRY_ON_TIMEOUT: if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
raise raise
logger.warn("%s request timed out", cls.NAME) logger.warning("%s request timed out", cls.NAME)
# If we timed out we probably don't need to worry about backing # If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway. # off too much, but lets just wait a little anyway.

View file

@ -144,7 +144,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
# The 'except' clause is very broad, but we need to # The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards # capture everything from DNS failures upwards
# #
logger.warn("Failed to reject invite: %s", e) logger.warning("Failed to reject invite: %s", e)
await self.store.locally_reject_invite(user_id, room_id) await self.store.locally_reject_invite(user_id, room_id)
ret = {} ret = {}

View file

@ -168,7 +168,7 @@ class ReplicationClientHandler(object):
if self.connection: if self.connection:
self.connection.send_command(cmd) self.connection.send_command(cmd)
else: else:
logger.warn("Queuing command as not connected: %r", cmd.NAME) logger.warning("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd) self.pending_commands.append(cmd)
def send_federation_ack(self, token): def send_federation_ack(self, token):

View file

@ -249,7 +249,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
return handler(cmd) return handler(cmd)
def close(self): def close(self):
logger.warn("[%s] Closing connection", self.id()) logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec() self.time_we_closed = self.clock.time_msec()
self.transport.loseConnection() self.transport.loseConnection()
self.on_connection_closed() self.on_connection_closed()

View file

@ -286,7 +286,7 @@ class PurgeHistoryRestServlet(RestServlet):
room_id, stream_ordering room_id, stream_ordering
) )
if not r: if not r:
logger.warn( logger.warning(
"[purge] purging events not possible: No event found " "[purge] purging events not possible: No event found "
"(received_ts %i => stream_ordering %i)", "(received_ts %i => stream_ordering %i)",
ts, ts,

View file

@ -221,7 +221,7 @@ class LoginRestServlet(RestServlet):
medium, address medium, address
) )
if not user_id: if not user_id:
logger.warn( logger.warning(
"unknown 3pid identifier medium %s, address %r", medium, address "unknown 3pid identifier medium %s, address %r", medium, address
) )
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)

View file

@ -71,7 +71,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warn( logger.warning(
"User password resets have been disabled due to lack of email config" "User password resets have been disabled due to lack of email config"
) )
raise SynapseError( raise SynapseError(
@ -162,7 +162,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
) )
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warn( logger.warning(
"Password reset emails have been disabled due to lack of an email config" "Password reset emails have been disabled due to lack of an email config"
) )
raise SynapseError( raise SynapseError(
@ -183,7 +183,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set # Perform a 302 redirect if next_link is set
if next_link: if next_link:
if next_link.startswith("file:///"): if next_link.startswith("file:///"):
logger.warn( logger.warning(
"Not redirecting to next_link as it is a local file: address" "Not redirecting to next_link as it is a local file: address"
) )
else: else:
@ -350,7 +350,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warn( logger.warning(
"Adding emails have been disabled due to lack of an email config" "Adding emails have been disabled due to lack of an email config"
) )
raise SynapseError( raise SynapseError(
@ -441,7 +441,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
if not self.hs.config.account_threepid_delegate_msisdn: if not self.hs.config.account_threepid_delegate_msisdn:
logger.warn( logger.warning(
"No upstream msisdn account_threepid_delegate configured on the server to " "No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request" "handle this request"
) )
@ -488,7 +488,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
def on_GET(self, request): def on_GET(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warn( logger.warning(
"Adding emails have been disabled due to lack of an email config" "Adding emails have been disabled due to lack of an email config"
) )
raise SynapseError( raise SynapseError(
@ -515,7 +515,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set # Perform a 302 redirect if next_link is set
if next_link: if next_link:
if next_link.startswith("file:///"): if next_link.startswith("file:///"):
logger.warn( logger.warning(
"Not redirecting to next_link as it is a local file: address" "Not redirecting to next_link as it is a local file: address"
) )
else: else:

View file

@ -106,7 +106,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config: if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
logger.warn( logger.warning(
"Email registration has been disabled due to lack of email config" "Email registration has been disabled due to lack of email config"
) )
raise SynapseError( raise SynapseError(
@ -207,7 +207,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
) )
if not self.hs.config.account_threepid_delegate_msisdn: if not self.hs.config.account_threepid_delegate_msisdn:
logger.warn( logger.warning(
"No upstream msisdn account_threepid_delegate configured on the server to " "No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request" "handle this request"
) )
@ -266,7 +266,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
) )
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warn( logger.warning(
"User registration via email has been disabled due to lack of email config" "User registration via email has been disabled due to lack of email config"
) )
raise SynapseError( raise SynapseError(
@ -287,7 +287,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set # Perform a 302 redirect if next_link is set
if next_link: if next_link:
if next_link.startswith("file:///"): if next_link.startswith("file:///"):
logger.warn( logger.warning(
"Not redirecting to next_link as it is a local file: address" "Not redirecting to next_link as it is a local file: address"
) )
else: else:
@ -480,7 +480,7 @@ class RegisterRestServlet(RestServlet):
# a password to work around a client bug where it sent # a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out # the 'initial_device_display_name' param alone, wiping out
# the original registration params # the original registration params
logger.warn("Ignoring initial_device_display_name without password") logger.warning("Ignoring initial_device_display_name without password")
del body["initial_device_display_name"] del body["initial_device_display_name"]
session_id = self.auth_handler.get_session_id(body) session_id = self.auth_handler.get_session_id(body)

View file

@ -394,7 +394,7 @@ class SyncRestServlet(RestServlet):
# We've had bug reports that events were coming down under the # We've had bug reports that events were coming down under the
# wrong room. # wrong room.
if event.room_id != room.room_id: if event.room_id != room.room_id:
logger.warn( logger.warning(
"Event %r is under room %r instead of %r", "Event %r is under room %r instead of %r",
event.event_id, event.event_id,
room.room_id, room.room_id,

View file

@ -363,7 +363,7 @@ class MediaRepository(object):
}, },
) )
except RequestSendFailed as e: except RequestSendFailed as e:
logger.warn( logger.warning(
"Request failed fetching remote media %s/%s: %r", "Request failed fetching remote media %s/%s: %r",
server_name, server_name,
media_id, media_id,
@ -372,7 +372,7 @@ class MediaRepository(object):
raise SynapseError(502, "Failed to fetch remote media") raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e: except HttpResponseException as e:
logger.warn( logger.warning(
"HTTP error fetching remote media %s/%s: %s", "HTTP error fetching remote media %s/%s: %s",
server_name, server_name,
media_id, media_id,
@ -383,10 +383,12 @@ class MediaRepository(object):
raise SynapseError(502, "Failed to fetch remote media") raise SynapseError(502, "Failed to fetch remote media")
except SynapseError: except SynapseError:
logger.warn("Failed to fetch remote media %s/%s", server_name, media_id) logger.warning(
"Failed to fetch remote media %s/%s", server_name, media_id
)
raise raise
except NotRetryingDestination: except NotRetryingDestination:
logger.warn("Not retrying destination %r", server_name) logger.warning("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media") raise SynapseError(502, "Failed to fetch remote media")
except Exception: except Exception:
logger.exception( logger.exception(
@ -691,7 +693,7 @@ class MediaRepository(object):
try: try:
os.remove(full_path) os.remove(full_path)
except OSError as e: except OSError as e:
logger.warn("Failed to remove file: %r", full_path) logger.warning("Failed to remove file: %r", full_path)
if e.errno == errno.ENOENT: if e.errno == errno.ENOENT:
pass pass
else: else:

View file

@ -136,7 +136,7 @@ class PreviewUrlResource(DirectServeResource):
match = False match = False
continue continue
if match: if match:
logger.warn("URL %s blocked by url_blacklist entry %s", url, entry) logger.warning("URL %s blocked by url_blacklist entry %s", url, entry)
raise SynapseError( raise SynapseError(
403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN
) )
@ -208,7 +208,7 @@ class PreviewUrlResource(DirectServeResource):
og["og:image:width"] = dims["width"] og["og:image:width"] = dims["width"]
og["og:image:height"] = dims["height"] og["og:image:height"] = dims["height"]
else: else:
logger.warn("Couldn't get dims for %s" % url) logger.warning("Couldn't get dims for %s" % url)
# define our OG response for this media # define our OG response for this media
elif _is_html(media_info["media_type"]): elif _is_html(media_info["media_type"]):
@ -256,7 +256,7 @@ class PreviewUrlResource(DirectServeResource):
og["og:image:width"] = dims["width"] og["og:image:width"] = dims["width"]
og["og:image:height"] = dims["height"] og["og:image:height"] = dims["height"]
else: else:
logger.warn("Couldn't get dims for %s", og["og:image"]) logger.warning("Couldn't get dims for %s", og["og:image"])
og["og:image"] = "mxc://%s/%s" % ( og["og:image"] = "mxc://%s/%s" % (
self.server_name, self.server_name,
@ -267,7 +267,7 @@ class PreviewUrlResource(DirectServeResource):
else: else:
del og["og:image"] del og["og:image"]
else: else:
logger.warn("Failed to find any OG data in %s", url) logger.warning("Failed to find any OG data in %s", url)
og = {} og = {}
logger.debug("Calculated OG for %s as %s", url, og) logger.debug("Calculated OG for %s as %s", url, og)
@ -319,7 +319,7 @@ class PreviewUrlResource(DirectServeResource):
) )
except Exception as e: except Exception as e:
# FIXME: pass through 404s and other error messages nicely # FIXME: pass through 404s and other error messages nicely
logger.warn("Error downloading %s: %r", url, e) logger.warning("Error downloading %s: %r", url, e)
raise SynapseError( raise SynapseError(
500, 500,
@ -400,7 +400,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e: except OSError as e:
# If the path doesn't exist, meh # If the path doesn't exist, meh
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
logger.warn("Failed to remove media: %r: %s", media_id, e) logger.warning("Failed to remove media: %r: %s", media_id, e)
continue continue
removed_media.append(media_id) removed_media.append(media_id)
@ -432,7 +432,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e: except OSError as e:
# If the path doesn't exist, meh # If the path doesn't exist, meh
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
logger.warn("Failed to remove media: %r: %s", media_id, e) logger.warning("Failed to remove media: %r: %s", media_id, e)
continue continue
try: try:
@ -448,7 +448,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e: except OSError as e:
# If the path doesn't exist, meh # If the path doesn't exist, meh
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
logger.warn("Failed to remove media: %r: %s", media_id, e) logger.warning("Failed to remove media: %r: %s", media_id, e)
continue continue
removed_media.append(media_id) removed_media.append(media_id)

View file

@ -182,7 +182,7 @@ class ThumbnailResource(DirectServeResource):
if file_path: if file_path:
yield respond_with_file(request, desired_type, file_path) yield respond_with_file(request, desired_type, file_path)
else: else:
logger.warn("Failed to generate thumbnail") logger.warning("Failed to generate thumbnail")
respond_404(request) respond_404(request)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -245,7 +245,7 @@ class ThumbnailResource(DirectServeResource):
if file_path: if file_path:
yield respond_with_file(request, desired_type, file_path) yield respond_with_file(request, desired_type, file_path)
else: else:
logger.warn("Failed to generate thumbnail") logger.warning("Failed to generate thumbnail")
respond_404(request) respond_404(request)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -83,7 +83,7 @@ class ResourceLimitsServerNotices(object):
room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id)
if not room_id: if not room_id:
logger.warn("Failed to get server notices room") logger.warning("Failed to get server notices room")
return return
yield self._check_and_set_tags(user_id, room_id) yield self._check_and_set_tags(user_id, room_id)

View file

@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.storage.state import StateFilter
logger = logging.getLogger(__name__)
class SpamCheckerApi(object):
"""A proxy object that gets passed to spam checkers so they can get
access to rooms and other relevant information.
"""
def __init__(self, hs):
self.hs = hs
self._store = hs.get_datastore()
@defer.inlineCallbacks
def get_state_events_in_room(self, room_id, types):
"""Gets state events for the given room.
Args:
room_id (string): The room ID to get state events in.
types (tuple): The event type and state key (using None
to represent 'any') of the room state to acquire.
Returns:
twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
The filtered state events in the room.
"""
state_ids = yield self._store.get_filtered_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
)
state = yield self._store.get_events(state_ids.values())
return state.values()

View file

@ -103,6 +103,7 @@ class StateHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state_store = hs.get_storage().state
self.hs = hs self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
@ -271,7 +272,7 @@ class StateHandler(object):
else: else:
current_state_ids = prev_state_ids current_state_ids = prev_state_ids
state_group = yield self.store.store_state_group( state_group = yield self.state_store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=None, prev_group=None,
@ -321,7 +322,7 @@ class StateHandler(object):
delta_ids = dict(entry.delta_ids) delta_ids = dict(entry.delta_ids)
delta_ids[key] = event.event_id delta_ids[key] = event.event_id
state_group = yield self.store.store_state_group( state_group = yield self.state_store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=prev_group, prev_group=prev_group,
@ -334,7 +335,7 @@ class StateHandler(object):
delta_ids = entry.delta_ids delta_ids = entry.delta_ids
if entry.state_group is None: if entry.state_group is None:
entry.state_group = yield self.store.store_state_group( entry.state_group = yield self.state_store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=entry.prev_group, prev_group=entry.prev_group,
@ -376,14 +377,16 @@ class StateHandler(object):
# map from state group id to the state in that state group (where # map from state group id to the state in that state group (where
# 'state' is a map from state key to event id) # 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]] # dict[int, dict[(str, str), str]]
state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) state_groups_ids = yield self.state_store.get_state_groups_ids(
room_id, event_ids
)
if len(state_groups_ids) == 0: if len(state_groups_ids) == 0:
return _StateCacheEntry(state={}, state_group=None) return _StateCacheEntry(state={}, state_group=None)
elif len(state_groups_ids) == 1: elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop() name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name) prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
return _StateCacheEntry( return _StateCacheEntry(
state=state_list, state=state_list,

View file

@ -31,6 +31,7 @@ from synapse.storage.data_stores import DataStores
from synapse.storage.data_stores.main import DataStore from synapse.storage.data_stores.main import DataStore
from synapse.storage.persist_events import EventsPersistenceStorage from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage
__all__ = ["DataStores", "DataStore"] __all__ = ["DataStores", "DataStore"]
@ -47,6 +48,7 @@ class Storage(object):
self.persistence = EventsPersistenceStorage(hs, stores) self.persistence = EventsPersistenceStorage(hs, stores)
self.purge_events = PurgeEventsStorage(hs, stores) self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
def are_all_users_on_domain(txn, database_engine, domain): def are_all_users_on_domain(txn, database_engine, domain):

View file

@ -494,7 +494,7 @@ class SQLBaseStore(object):
exception_callbacks = [] exception_callbacks = []
if LoggingContext.current_context() == LoggingContext.sentinel: if LoggingContext.current_context() == LoggingContext.sentinel:
logger.warn("Starting db txn '%s' from sentinel context", desc) logger.warning("Starting db txn '%s' from sentinel context", desc)
try: try:
result = yield self.runWithConnection( result = yield self.runWithConnection(
@ -532,7 +532,7 @@ class SQLBaseStore(object):
""" """
parent_context = LoggingContext.current_context() parent_context = LoggingContext.current_context()
if parent_context == LoggingContext.sentinel: if parent_context == LoggingContext.sentinel:
logger.warn( logger.warning(
"Starting db connection from sentinel context: metrics will be lost" "Starting db connection from sentinel context: metrics will be lost"
) )
parent_context = None parent_context = None
@ -719,7 +719,7 @@ class SQLBaseStore(object):
raise raise
# presumably we raced with another transaction: let's retry. # presumably we raced with another transaction: let's retry.
logger.warn( logger.warning(
"IntegrityError when upserting into %s; retrying: %s", table, e "IntegrityError when upserting into %s; retrying: %s", table, e
) )

View file

@ -82,7 +82,7 @@ def _retry_on_integrity_error(func):
@defer.inlineCallbacks @defer.inlineCallbacks
def f(self, *args, **kwargs): def f(self, *args, **kwargs):
try: try:
res = yield func(self, *args, **kwargs) res = yield func(self, *args, delete_existing=False, **kwargs)
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
logger.exception("IntegrityError, retrying.") logger.exception("IntegrityError, retrying.")
res = yield func(self, *args, delete_existing=True, **kwargs) res = yield func(self, *args, delete_existing=True, **kwargs)
@ -1669,7 +1669,6 @@ class EventsStore(
"room_stats_earliest_token", "room_stats_earliest_token",
"rooms", "rooms",
"stream_ordering_to_exterm", "stream_ordering_to_exterm",
"topics",
"users_in_public_rooms", "users_in_public_rooms",
"users_who_share_private_rooms", "users_who_share_private_rooms",
# no useful index, but let's clear them anyway # no useful index, but let's clear them anyway

View file

@ -44,7 +44,7 @@ class PusherWorkerStore(SQLBaseStore):
r["data"] = json.loads(dataJson) r["data"] = json.loads(dataJson)
except Exception as e: except Exception as e:
logger.warn( logger.warning(
"Invalid JSON in data for pusher %d: %s, %s", "Invalid JSON in data for pusher %d: %s, %s",
r["id"], r["id"],
dataJson, dataJson,

View file

@ -196,7 +196,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
" ON event_search USING GIN (vector)" " ON event_search USING GIN (vector)"
) )
except psycopg2.ProgrammingError as e: except psycopg2.ProgrammingError as e:
logger.warn( logger.warning(
"Ignoring error %r when trying to switch from GIST to GIN", e "Ignoring error %r when trying to switch from GIST to GIN", e
) )

View file

@ -260,9 +260,7 @@ class EventsPersistenceStorage(object):
self._event_persist_queue.handle_queue(room_id, persisting_queue) self._event_persist_queue.handle_queue(room_id, persisting_queue)
@defer.inlineCallbacks @defer.inlineCallbacks
def _persist_events( def _persist_events(self, events_and_contexts, backfilled=False):
self, events_and_contexts, backfilled=False, delete_existing=False
):
"""Calculates the change to current state and forward extremities, and """Calculates the change to current state and forward extremities, and
persists the given events and with those updates. persists the given events and with those updates.
@ -412,7 +410,6 @@ class EventsPersistenceStorage(object):
state_delta_for_room=state_delta_for_room, state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties, new_forward_extremeties=new_forward_extremeties,
backfilled=backfilled, backfilled=backfilled,
delete_existing=delete_existing,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -550,7 +547,7 @@ class EventsPersistenceStorage(object):
if missing_event_ids: if missing_event_ids:
# Now pull out the state groups for any missing events from DB # Now pull out the state groups for any missing events from DB
event_to_groups = yield self.state_store._get_state_group_for_events( event_to_groups = yield self.main_store._get_state_group_for_events(
missing_event_ids missing_event_ids
) )
event_id_to_state_group.update(event_to_groups) event_id_to_state_group.update(event_to_groups)

View file

@ -19,6 +19,8 @@ from six import iteritems, itervalues
import attr import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -322,3 +324,234 @@ class StateFilter(object):
) )
return member_filter, non_member_filter return member_filter, non_member_filter
class StateGroupStorage(object):
"""High level interface to fetching state for event.
"""
def __init__(self, hs, stores):
self.stores = stores
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
(prev_group, delta_ids)
"""
return self.stores.main.get_state_group_delta(state_group)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id (str): id of the room for these events
event_ids (iterable[str]): ids of the events
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
return {}
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self.stores.main._get_state_for_groups(groups)
return group_to_state
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
"""Get the event IDs of all the state in the given state group
Args:
state_group (int)
Returns:
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = yield self._get_state_for_groups((state_group,))
return group_to_state[state_group]
@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
Returns:
Deferred[dict[int, list[EventBase]]]:
dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
state_event_map = yield self.stores.main.get_events(
[
ev_id
for group_ids in itervalues(group_to_ids)
for ev_id in itervalues(group_ids)
],
get_prev_content=False,
)
return {
group: [
state_event_map[v]
for v in itervalues(event_id_map)
if v in state_event_map
]
for group, event_id_map in iteritems(group_to_ids)
}
def _get_state_groups_from_groups(self, groups, state_filter):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
groups(list[int]): list of state group IDs to query
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
return self.stores.main._get_state_groups_from_groups(groups, state_filter)
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
Args:
event_ids (list[string])
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self.stores.main._get_state_for_groups(
groups, state_filter
)
state_event_map = yield self.stores.main.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
get_prev_content=False,
)
event_to_state = {
event_id: {
k: state_event_map[v]
for k, v in iteritems(group_to_state[group])
if v in state_event_map
}
for event_id, group in iteritems(event_to_groups)
}
return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self.stores.main._get_state_for_groups(
groups, state_filter
)
event_to_state = {
event_id: group_to_state[group]
for event_id, group in iteritems(event_to_groups)
}
return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], state_filter)
return state_map[event_id]
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
return self.stores.main._get_state_for_groups(groups, state_filter)
def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
"""Store a new set of state, returning a newly assigned state group.
Args:
event_id (str): The event ID for which the state was calculated
room_id (str)
prev_group (int|None): A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key)
to event_id.
Returns:
Deferred[int]: The state group ID
"""
return self.stores.main.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)

View file

@ -309,7 +309,7 @@ class Linearizer(object):
) )
else: else:
logger.warn( logger.warning(
"Unexpected exception waiting for linearizer lock %r for key %r", "Unexpected exception waiting for linearizer lock %r for key %r",
self.name, self.name,
key, key,

View file

@ -107,7 +107,7 @@ def register_cache(cache_type, cache_name, cache, collect_callback=None):
if collect_callback: if collect_callback:
collect_callback() collect_callback()
except Exception as e: except Exception as e:
logger.warn("Error calculating metrics for %s: %s", cache_name, e) logger.warning("Error calculating metrics for %s: %s", cache_name, e)
raise raise
yield GaugeMetricFamily("__unused", "") yield GaugeMetricFamily("__unused", "")

View file

@ -119,7 +119,7 @@ class Measure(object):
context = LoggingContext.current_context() context = LoggingContext.current_context()
if context != self.start_context: if context != self.start_context:
logger.warn( logger.warning(
"Context has unexpectedly changed from '%s' to '%s'. (%r)", "Context has unexpectedly changed from '%s' to '%s'. (%r)",
self.start_context, self.start_context,
context, context,
@ -128,7 +128,7 @@ class Measure(object):
return return
if not context: if not context:
logger.warn("Expected context. (%r)", self.name) logger.warning("Expected context. (%r)", self.name)
return return
current = context.get_resource_usage() current = context.get_resource_usage()
@ -140,7 +140,7 @@ class Measure(object):
block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
except ValueError: except ValueError:
logger.warn( logger.warning(
"Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current
) )

View file

@ -33,4 +33,4 @@ def change_resource_limit(soft_file_no):
resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY) resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
) )
except (ValueError, resource.error) as e: except (ValueError, resource.error) as e:
logger.warn("Failed to set file or core limit: %s", e) logger.warning("Failed to set file or core limit: %s", e)

View file

@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.storage import Storage
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -43,14 +44,13 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks @defer.inlineCallbacks
def filter_events_for_client( def filter_events_for_client(
store, user_id, events, is_peeking=False, always_include_ids=frozenset() storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset()
): ):
""" """
Check which events a user is allowed to see Check which events a user is allowed to see
Args: Args:
store (synapse.storage.DataStore): our datastore (can also be a worker storage
store)
user_id(str): user id to be checked user_id(str): user id to be checked
events(list[synapse.events.EventBase]): sequence of events to be checked events(list[synapse.events.EventBase]): sequence of events to be checked
is_peeking(bool): should be True if: is_peeking(bool): should be True if:
@ -68,12 +68,12 @@ def filter_events_for_client(
events = list(e for e in events if not e.internal_metadata.is_soft_failed()) events = list(e for e in events if not e.internal_metadata.is_soft_failed())
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
event_id_to_state = yield store.get_state_for_events( event_id_to_state = yield storage.state.get_state_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types), state_filter=StateFilter.from_types(types),
) )
ignore_dict_content = yield store.get_global_account_data_by_type_for_user( ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id "m.ignored_user_list", user_id
) )
@ -84,7 +84,7 @@ def filter_events_for_client(
else [] else []
) )
erased_senders = yield store.are_users_erased((e.sender for e in events)) erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
def allowed(event): def allowed(event):
""" """
@ -213,13 +213,17 @@ def filter_events_for_client(
@defer.inlineCallbacks @defer.inlineCallbacks
def filter_events_for_server( def filter_events_for_server(
store, server_name, events, redact=True, check_history_visibility_only=False storage: Storage,
server_name,
events,
redact=True,
check_history_visibility_only=False,
): ):
"""Filter a list of events based on whether given server is allowed to """Filter a list of events based on whether given server is allowed to
see them. see them.
Args: Args:
store (DataStore) storage
server_name (str) server_name (str)
events (iterable[FrozenEvent]) events (iterable[FrozenEvent])
redact (bool): Whether to return a redacted version of the event, or redact (bool): Whether to return a redacted version of the event, or
@ -274,7 +278,7 @@ def filter_events_for_server(
# Lets check to see if all the events have a history visibility # Lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If thats the case then we don't # of "shared" or "world_readable". If thats the case then we don't
# need to check membership (as we know the server is in the room). # need to check membership (as we know the server is in the room).
event_to_state_ids = yield store.get_state_ids_for_events( event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""),) types=((EventTypes.RoomHistoryVisibility, ""),)
@ -292,14 +296,14 @@ def filter_events_for_server(
if not visibility_ids: if not visibility_ids:
all_open = True all_open = True
else: else:
event_map = yield store.get_events(visibility_ids) event_map = yield storage.main.get_events(visibility_ids)
all_open = all( all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable") e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in itervalues(event_map) for e in itervalues(event_map)
) )
if not check_history_visibility_only: if not check_history_visibility_only:
erased_senders = yield store.are_users_erased((e.sender for e in events)) erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
else: else:
# We don't want to check whether users are erased, which is equivalent # We don't want to check whether users are erased, which is equivalent
# to no users having been erased. # to no users having been erased.
@ -328,7 +332,7 @@ def filter_events_for_server(
# first, for each event we're wanting to return, get the event_ids # first, for each event we're wanting to return, get the event_ids
# of the history vis and membership state at those events. # of the history vis and membership state at those events.
event_to_state_ids = yield store.get_state_ids_for_events( event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
@ -358,7 +362,7 @@ def filter_events_for_server(
return False return False
return state_key[idx + 1 :] == server_name return state_key[idx + 1 :] == server_name
event_map = yield store.get_events( event_map = yield storage.main.get_events(
[ [
e_id e_id
for e_id, key in iteritems(event_id_to_state_key) for e_id, key in iteritems(event_id_to_state_key)

View file

@ -561,3 +561,81 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["groups"] return channel.json_body["groups"]
class PurgeRoomTestCase(unittest.HomeserverTestCase):
"""Test /purge_room admin API.
"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
def test_purge_room(self):
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
# All users have to have left the room.
self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
url = "/_synapse/admin/v1/purge_room"
request, channel = self.make_request(
"POST",
url.encode("ascii"),
{"room_id": room_id},
access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Test that the following tables have been purged of all rows related to the room.
for table in (
"current_state_events",
"event_backward_extremities",
"event_forward_extremities",
"event_json",
"event_push_actions",
"event_search",
"events",
"group_rooms",
"public_room_list_stream",
"receipts_graph",
"receipts_linearized",
"room_aliases",
"room_depth",
"room_memberships",
"room_stats_state",
"room_stats_current",
"room_stats_historical",
"room_stats_earliest_token",
"rooms",
"stream_ordering_to_exterm",
"users_in_public_rooms",
"users_who_share_private_rooms",
"appservice_room_list",
"e2e_room_keys",
"event_push_summary",
"pusher_throttle",
"group_summary_rooms",
"local_invites",
"room_account_data",
"room_tags",
):
count = self.get_success(
self.store._simple_select_one_onecol(
table="events",
keyvalues={"room_id": room_id},
retcol="COUNT(*)",
desc="test_purge_room",
)
)
self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))

View file

@ -161,7 +161,11 @@ def make_request(
path = path.encode("ascii") path = path.encode("ascii")
# Decorate it to be the full path, if we're using shorthand # Decorate it to be the full path, if we're using shorthand
if shorthand and not path.startswith(b"/_matrix"): if (
shorthand
and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse")
):
path = b"/_matrix/client/r0/" + path path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/") path = path.replace(b"//", b"/")

View file

@ -35,6 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_datastore = self.store
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -83,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = yield self.store.get_state_groups_ids( state_group_map = yield self.storage.state.get_state_groups_ids(
self.room, [e2.event_id] self.room, [e2.event_id]
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
@ -102,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) state_group_map = yield self.storage.state.get_state_groups(
self.room, [e2.event_id]
)
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0] state_list = list(state_group_map.values())[0]
@ -142,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check we get the full state as of the final event # check we get the full state as of the final event
state = yield self.store.get_state_for_event(e5.event_id) state = yield self.storage.state.get_state_for_event(e5.event_id)
self.assertIsNotNone(e4) self.assertIsNotNone(e4)
@ -158,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check we can filter to the m.room.name event (with a '' state key) # check we can filter to the m.room.name event (with a '' state key)
state = yield self.store.get_state_for_event( state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
) )
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key) # check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.store.get_state_for_event( state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
) )
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key) # check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.store.get_state_for_event( state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
) )
@ -182,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the # check we can grab a specific room member without filtering out the
# other event types # other event types
state = yield self.store.get_state_for_event( state = yield self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}}, types={EventTypes.Member: {self.u_alice.to_string()}},
@ -200,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check that we can grab everything except members # check that we can grab everything except members
state = yield self.store.get_state_for_event( state = yield self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types={EventTypes.Member: set()}, include_others=True
@ -216,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
####################################################### #######################################################
room_id = self.room.to_string() room_id = self.room.to_string()
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) group_ids = yield self.storage.state.get_state_groups_ids(
room_id, [e5.event_id]
)
group = list(group_ids.keys())[0] group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members # test _get_state_for_group_using_cache correctly filters out members
# with types=[] # with types=[]
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types={EventTypes.Member: set()}, include_others=True
@ -238,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types={EventTypes.Member: set()}, include_others=True
@ -251,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with wildcard types # with wildcard types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types={EventTypes.Member: None}, include_others=True
@ -268,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types={EventTypes.Member: None}, include_others=True
@ -288,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types={EventTypes.Member: {e5.state_key}}, include_others=True
@ -305,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types={EventTypes.Member: {e5.state_key}}, include_others=True
@ -318,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types={EventTypes.Member: {e5.state_key}}, include_others=False
@ -332,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
####################################################### #######################################################
# deliberately remove e2 (room name) from the _state_group_cache # deliberately remove e2 (room name) from the _state_group_cache
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( (
group is_all,
) known_absent,
state_dict_ids,
) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertEqual(known_absent, set()) self.assertEqual(known_absent, set())
@ -347,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
state_dict_ids.pop((e2.type, e2.state_key)) state_dict_ids.pop((e2.type, e2.state_key))
self.store._state_group_cache.invalidate(group) self.state_datastore._state_group_cache.invalidate(group)
self.store._state_group_cache.update( self.state_datastore._state_group_cache.update(
sequence=self.store._state_group_cache.sequence, sequence=self.state_datastore._state_group_cache.sequence,
key=group, key=group,
value=state_dict_ids, value=state_dict_ids,
# list fetched keys so it knows it's partial # list fetched keys so it knows it's partial
fetched_keys=((e1.type, e1.state_key),), fetched_keys=((e1.type, e1.state_key),),
) )
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( (
group is_all,
) known_absent,
state_dict_ids,
) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
@ -370,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members # test _get_state_for_group_using_cache correctly filters out members
# with types=[] # with types=[]
room_id = self.room.to_string() room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types={EventTypes.Member: set()}, include_others=True
@ -382,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string() room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types={EventTypes.Member: set()}, include_others=True
@ -395,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# wildcard types # wildcard types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types={EventTypes.Member: None}, include_others=True
@ -406,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types={EventTypes.Member: None}, include_others=True
@ -425,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types={EventTypes.Member: {e5.state_key}}, include_others=True
@ -436,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types={EventTypes.Member: {e5.state_key}}, include_others=True
@ -449,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types={EventTypes.Member: {e5.state_key}}, include_others=False
@ -460,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict) self.assertDictEqual({}, state_dict)
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types={EventTypes.Member: {e5.state_key}}, include_others=False

View file

@ -158,10 +158,12 @@ class Graph(object):
class StateTestCase(unittest.TestCase): class StateTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = StateGroupStore() self.store = StateGroupStore()
storage = Mock(main=self.store, state=self.store)
hs = Mock( hs = Mock(
spec_set=[ spec_set=[
"config", "config",
"get_datastore", "get_datastore",
"get_storage",
"get_auth", "get_auth",
"get_state_handler", "get_state_handler",
"get_clock", "get_clock",
@ -174,6 +176,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
hs.get_storage.return_value = storage
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0

View file

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
import logging import logging
from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
@ -63,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt) events_to_filter.append(evt)
filtered = yield filter_events_for_server( filtered = yield filter_events_for_server(
self.store, "test_server", events_to_filter self.storage, "test_server", events_to_filter
) )
# the result should be 5 redacted events, and 5 unredacted events. # the result should be 5 redacted events, and 5 unredacted events.
@ -101,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
# ... and the filtering happens. # ... and the filtering happens.
filtered = yield filter_events_for_server( filtered = yield filter_events_for_server(
self.store, "test_server", events_to_filter self.storage, "test_server", events_to_filter
) )
for i in range(0, len(events_to_filter)): for i in range(0, len(events_to_filter)):
@ -258,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
logger.info("Starting filtering") logger.info("Starting filtering")
start = time.time() start = time.time()
storage = Mock()
storage.main = test_store
storage.state = test_store
filtered = yield filter_events_for_server( filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter test_store, "test_server", events_to_filter
) )