Merge remote-tracking branch 'origin/develop' into rav/event_auth/1

This commit is contained in:
Richard van der Hoff 2019-10-18 10:11:40 +01:00
commit 80003dfcd5
46 changed files with 458 additions and 221 deletions

View file

@ -1,3 +1,11 @@
Synapse 1.4.1rc1 (2019-10-17)
=============================
Bugfixes
--------
- Fix bug where redacted events were sometimes incorrectly censored in the database, breaking APIs that attempted to fetch such events. ([\#6185](https://github.com/matrix-org/synapse/issues/6185), [5b0e9948](https://github.com/matrix-org/synapse/commit/5b0e9948eaae801643e594b5abc8ee4b10bd194e))
Synapse 1.4.0 (2019-10-03) Synapse 1.4.0 (2019-10-03)
========================== ==========================

View file

@ -47,7 +47,5 @@ prune debian
prune demo/etc prune demo/etc
prune docker prune docker
prune mypy.ini prune mypy.ini
prune snap
prune stubs prune stubs
exclude jenkins*
recursive-exclude jenkins *.sh

1
changelog.d/6114.feature Normal file
View file

@ -0,0 +1 @@
CAS login now provides a default display name for users if a `displayname_attribute` is set in the configuration file.

1
changelog.d/6156.misc Normal file
View file

@ -0,0 +1 @@
Use Postgres ANY for selecting many values.

1
changelog.d/6168.bugfix Normal file
View file

@ -0,0 +1 @@
Fix monthly active user reaping where reserved users are specified.

View file

@ -1 +0,0 @@
Fix bug where redacted events were sometimes incorrectly censored in the database, breaking APIs that attempted to fetch such events.

1
changelog.d/6186.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug where we were updating censored events as bytes rather than text, occaisonally causing invalid JSON being inserted breaking APIs that attempted to fetch such events.

1
changelog.d/6186.misc Normal file
View file

@ -0,0 +1 @@
Reject (accidental) attempts to insert bytes into postgres tables.

1
changelog.d/6189.misc Normal file
View file

@ -0,0 +1 @@
Make `version` optional in body of `PUT /room_keys/version/{version}`, since it's redundant.

1
changelog.d/6191.misc Normal file
View file

@ -0,0 +1 @@
Add snapcraft packaging information. Contributed by @devec0.

1
changelog.d/6193.misc Normal file
View file

@ -0,0 +1 @@
Make storage layer responsible for adding device names to key, rather than the handler.

1
changelog.d/6195.bugfix Normal file
View file

@ -0,0 +1 @@
Fix tracing of non-JSON APIs, /media, /key etc.

View file

@ -1220,6 +1220,7 @@ saml2_config:
# enabled: true # enabled: true
# server_url: "https://cas-server.com" # server_url: "https://cas-server.com"
# service_url: "https://homeserver.domain.com:8448" # service_url: "https://homeserver.domain.com:8448"
# #displayname_attribute: name
# #required_attributes: # #required_attributes:
# # name: value # # name: value

View file

@ -36,7 +36,7 @@ try:
except ImportError: except ImportError:
pass pass
__version__ = "1.4.0" __version__ = "1.4.1rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when # We import here so that we don't have to install a bunch of deps when

View file

@ -605,13 +605,13 @@ def run(hs):
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_monthly_active_users(): def generate_monthly_active_users():
current_mau_count = 0 current_mau_count = 0
reserved_count = 0 reserved_users = ()
store = hs.get_datastore() store = hs.get_datastore()
if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
current_mau_count = yield store.get_monthly_active_count() current_mau_count = yield store.get_monthly_active_count()
reserved_count = yield store.get_registered_reserved_users_count() reserved_users = yield store.get_registered_reserved_users()
current_mau_gauge.set(float(current_mau_count)) current_mau_gauge.set(float(current_mau_count))
registered_reserved_users_mau_gauge.set(float(reserved_count)) registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
max_mau_gauge.set(float(hs.config.max_mau_value)) max_mau_gauge.set(float(hs.config.max_mau_value))
def start_generate_monthly_active_users(): def start_generate_monthly_active_users():

View file

@ -30,11 +30,13 @@ class CasConfig(Config):
self.cas_enabled = cas_config.get("enabled", True) self.cas_enabled = cas_config.get("enabled", True)
self.cas_server_url = cas_config["server_url"] self.cas_server_url = cas_config["server_url"]
self.cas_service_url = cas_config["service_url"] self.cas_service_url = cas_config["service_url"]
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
self.cas_required_attributes = cas_config.get("required_attributes", {}) self.cas_required_attributes = cas_config.get("required_attributes", {})
else: else:
self.cas_enabled = False self.cas_enabled = False
self.cas_server_url = None self.cas_server_url = None
self.cas_service_url = None self.cas_service_url = None
self.cas_displayname_attribute = None
self.cas_required_attributes = {} self.cas_required_attributes = {}
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
@ -45,6 +47,7 @@ class CasConfig(Config):
# enabled: true # enabled: true
# server_url: "https://cas-server.com" # server_url: "https://cas-server.com"
# service_url: "https://homeserver.domain.com:8448" # service_url: "https://homeserver.domain.com:8448"
# #displayname_attribute: name
# #required_attributes: # #required_attributes:
# # name: value # # name: value
""" """

View file

@ -248,16 +248,10 @@ class E2eKeysHandler(object):
results = yield self.store.get_e2e_device_keys(local_query) results = yield self.store.get_e2e_device_keys(local_query)
# Build the result structure, un-jsonify the results, and add the # Build the result structure
# "unsigned" section
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items(): for device_id, device_info in device_keys.items():
r = dict(device_info["keys"]) result_dict[user_id][device_id] = device_info
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
result_dict[user_id][device_id] = r
log_kv(results) log_kv(results)
return result_dict return result_dict

View file

@ -352,8 +352,8 @@ class E2eRoomKeysHandler(object):
A deferred of an empty dict. A deferred of an empty dict.
""" """
if "version" not in version_info: if "version" not in version_info:
raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM) version_info["version"] = version
if version_info["version"] != version: elif version_info["version"] != version:
raise SynapseError( raise SynapseError(
400, "Version in body does not match", Codes.INVALID_PARAM 400, "Version in body does not match", Codes.INVALID_PARAM
) )

View file

@ -388,7 +388,7 @@ class DirectServeResource(resource.Resource):
if not callback: if not callback:
return super().render(request) return super().render(request)
resp = callback(request) resp = trace_servlet(self.__class__.__name__)(callback)(request)
# If it's a coroutine, turn it into a Deferred # If it's a coroutine, turn it into a Deferred
if isinstance(resp, types.CoroutineType): if isinstance(resp, types.CoroutineType):

View file

@ -169,6 +169,7 @@ import contextlib
import inspect import inspect
import logging import logging
import re import re
import types
from functools import wraps from functools import wraps
from typing import Dict from typing import Dict
@ -778,8 +779,7 @@ def trace_servlet(servlet_name, extract_context=False):
return func return func
@wraps(func) @wraps(func)
@defer.inlineCallbacks async def _trace_servlet_inner(request, *args, **kwargs):
def _trace_servlet_inner(request, *args, **kwargs):
request_tags = { request_tags = {
"request_id": request.get_request_id(), "request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
@ -796,9 +796,15 @@ def trace_servlet(servlet_name, extract_context=False):
scope = start_active_span(servlet_name, tags=request_tags) scope = start_active_span(servlet_name, tags=request_tags)
with scope: with scope:
result = yield defer.maybeDeferred(func, request, *args, **kwargs) result = func(request, *args, **kwargs)
if not isinstance(result, (types.CoroutineType, defer.Deferred)):
# Some servlets aren't async and just return results
# directly, so we handle that here.
return result return result
return await result
return _trace_servlet_inner return _trace_servlet_inner
return _trace_servlet_inner_1 return _trace_servlet_inner_1

View file

@ -377,6 +377,7 @@ class CasTicketServlet(RestServlet):
super(CasTicketServlet, self).__init__() super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url self.cas_service_url = hs.config.cas_service_url
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs) self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_simple_http_client() self._http_client = hs.get_simple_http_client()
@ -400,6 +401,7 @@ class CasTicketServlet(RestServlet):
def handle_cas_response(self, request, cas_response_body, client_redirect_url): def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body) user, attributes = self.parse_cas_response(cas_response_body)
displayname = attributes.pop(self.cas_displayname_attribute, None)
for required_attribute, required_value in self.cas_required_attributes.items(): for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden # If required attribute was not in CAS Response - Forbidden
@ -414,7 +416,7 @@ class CasTicketServlet(RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return self._sso_auth_handler.on_successful_auth( return self._sso_auth_handler.on_successful_auth(
user, request, client_redirect_url user, request, client_redirect_url, displayname
) )
def parse_cas_response(self, cas_response_body): def parse_cas_response(self, cas_response_body):

View file

@ -375,7 +375,7 @@ class RoomKeysVersionServlet(RestServlet):
"ed25519:something": "hijklmnop" "ed25519:something": "hijklmnop"
} }
}, },
"version": "42" "version": "12345"
} }
HTTP/1.1 200 OK HTTP/1.1 200 OK

View file

@ -270,7 +270,7 @@ class PreviewUrlResource(DirectServeResource):
logger.debug("Calculated OG for %s as %s" % (url, og)) logger.debug("Calculated OG for %s as %s" % (url, og))
jsonog = json.dumps(og).encode("utf8") jsonog = json.dumps(og)
# store OG in history-aware DB cache # store OG in history-aware DB cache
yield self.store.store_url_cache( yield self.store.store_url_cache(
@ -283,7 +283,7 @@ class PreviewUrlResource(DirectServeResource):
media_info["created_ts"], media_info["created_ts"],
) )
return jsonog return jsonog.encode("utf8")
@defer.inlineCallbacks @defer.inlineCallbacks
def _download_url(self, url, user): def _download_url(self, url, user):

View file

@ -20,6 +20,7 @@ import random
import sys import sys
import threading import threading
import time import time
from typing import Iterable, Tuple
from six import PY2, iteritems, iterkeys, itervalues from six import PY2, iteritems, iterkeys, itervalues
from six.moves import builtins, intern, range from six.moves import builtins, intern, range
@ -1163,19 +1164,18 @@ class SQLBaseStore(object):
if not iterable: if not iterable:
return [] return []
sql = "SELECT %s FROM %s" % (", ".join(retcols), table) clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
clauses = []
values = []
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
for key, value in iteritems(keyvalues): for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
if clauses: sql = "SELECT %s FROM %s WHERE %s" % (
sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) ", ".join(retcols),
table,
" AND ".join(clauses),
)
txn.execute(sql, values) txn.execute(sql, values)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
@ -1324,10 +1324,8 @@ class SQLBaseStore(object):
sql = "DELETE FROM %s" % table sql = "DELETE FROM %s" % table
clauses = [] clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
values = [] clauses = [clause]
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
for key, value in iteritems(keyvalues): for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
@ -1694,3 +1692,30 @@ def db_to_json(db_content):
except Exception: except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content) logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise raise
def make_in_list_sql_clause(
database_engine, column: str, iterable: Iterable
) -> Tuple[str, Iterable]:
"""Returns an SQL clause that checks the given column is in the iterable.
On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
it expands to `column = ANY(?)`. While both DBs support the `IN` form,
using the `ANY` form on postgres means that it views queries with
different length iterables as the same, helping the query stats.
Args:
database_engine
column: Name of the column
iterable: The values to check the column against.
Returns:
A tuple of SQL query and the args
"""
if database_engine.supports_using_any_list:
# This should hopefully be faster, but also makes postgres query
# stats easier to understand.
return "%s = ANY(?)" % (column,), [list(iterable)]
else:
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)

View file

@ -20,7 +20,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -378,15 +378,15 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
else: else:
if not devices: if not devices:
continue continue
sql = (
"SELECT device_id FROM devices" clause, args = make_in_list_sql_clause(
" WHERE user_id = ? AND device_id IN (" txn.database_engine, "device_id", devices
+ ",".join("?" * len(devices))
+ ")"
) )
sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
# TODO: Maybe this needs to be done in batches if there are # TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user. # too many local devices for a given user.
txn.execute(sql, [user_id] + devices) txn.execute(sql, [user_id] + list(args))
for row in txn: for row in txn:
# Only insert into the local inbox if the device exists on # Only insert into the local inbox if the device exists on
# this server # this server

View file

@ -28,7 +28,12 @@ from synapse.logging.opentracing import (
whitelisted_homeserver, whitelisted_homeserver,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import Cache, SQLBaseStore, db_to_json from synapse.storage._base import (
Cache,
SQLBaseStore,
db_to_json,
make_in_list_sql_clause,
)
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util import batch_iter from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@ -448,11 +453,14 @@ class DeviceWorkerStore(SQLBaseStore):
sql = """ sql = """
SELECT DISTINCT user_id FROM device_lists_stream SELECT DISTINCT user_id FROM device_lists_stream
WHERE stream_id > ? WHERE stream_id > ?
AND user_id IN (%s) AND
""" """
for chunk in batch_iter(to_check, 100): for chunk in batch_iter(to_check, 100):
txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk) clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + clause, (from_key,) + tuple(args))
changes.update(user_id for user_id, in txn) changes.update(user_id for user_id, in txn)
return changes return changes

View file

@ -40,7 +40,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
This option only takes effect if include_all_devices is true. This option only takes effect if include_all_devices is true.
Returns: Returns:
Dict mapping from user-id to dict mapping from device_id to Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name". key data. The key data will be a dict in the same format as the
DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
""" """
set_tag("query_list", query_list) set_tag("query_list", query_list)
if not query_list: if not query_list:
@ -54,11 +55,20 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
include_deleted_devices, include_deleted_devices,
) )
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
rv = {}
for user_id, device_keys in iteritems(results): for user_id, device_keys in iteritems(results):
rv[user_id] = {}
for device_id, device_info in iteritems(device_keys): for device_id, device_info in iteritems(device_keys):
device_info["keys"] = db_to_json(device_info.pop("key_json")) r = db_to_json(device_info.pop("key_json"))
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
rv[user_id][device_id] = r
return results return rv
@trace @trace
def _get_e2e_device_keys_txn( def _get_e2e_device_keys_txn(

View file

@ -22,6 +22,13 @@ class PostgresEngine(object):
def __init__(self, database_module, database_config): def __init__(self, database_module, database_config):
self.module = database_module self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE) self.module.extensions.register_type(self.module.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
# actually want to use bytes than wrap it in `bytearray`.
def _disable_bytes_adapter(_):
raise Exception("Passing bytes to DB is disabled.")
self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
self.synchronous_commit = database_config.get("synchronous_commit", True) self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet self._version = None # unknown as yet
@ -79,6 +86,12 @@ class PostgresEngine(object):
""" """
return True return True
@property
def supports_using_any_list(self):
"""Do we support using `a = ANY(?)` and passing a list
"""
return True
def is_deadlock(self, error): def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError): if isinstance(error, self.module.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html # https://www.postgresql.org/docs/current/static/errcodes-appendix.html

View file

@ -46,6 +46,12 @@ class Sqlite3Engine(object):
""" """
return self.module.sqlite_version_info >= (3, 15, 0) return self.module.sqlite_version_info >= (3, 15, 0)
@property
def supports_using_any_list(self):
"""Do we support using `a = ANY(?)` and passing a list
"""
return False
def check_database(self, txn): def check_database(self, txn):
pass pass

View file

@ -25,7 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.signatures import SignatureWorkerStore from synapse.storage.signatures import SignatureWorkerStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -68,7 +68,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else: else:
results = set() results = set()
base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" base_sql = "SELECT auth_id FROM event_auth WHERE "
front = set(event_ids) front = set(event_ids)
while front: while front:
@ -76,7 +76,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
front_list = list(front) front_list = list(front)
chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)] chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
for chunk in chunks: for chunk in chunks:
txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk) clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", chunk
)
txn.execute(base_sql + clause, list(args))
new_front.update([r[0] for r in txn]) new_front.update([r[0] for r in txn])
new_front -= results new_front -= results

View file

@ -39,6 +39,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import BucketCollector from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore from synapse.state import StateResolutionStore
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
@ -641,14 +642,16 @@ class EventsStore(
LEFT JOIN rejections USING (event_id) LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id) LEFT JOIN event_json USING (event_id)
WHERE WHERE
prev_event_id IN (%s) NOT events.outlier
AND NOT events.outlier
AND rejections.event_id IS NULL AND rejections.event_id IS NULL
""" % ( AND
",".join("?" for _ in batch), """
clause, args = make_in_list_sql_clause(
self.database_engine, "prev_event_id", batch
) )
txn.execute(sql, batch) txn.execute(sql + clause, args)
results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100): for chunk in batch_iter(event_ids, 100):
@ -695,13 +698,15 @@ class EventsStore(
LEFT JOIN rejections USING (event_id) LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id) LEFT JOIN event_json USING (event_id)
WHERE WHERE
event_id IN (%s) NOT events.outlier
AND NOT events.outlier AND
""" % ( """
",".join("?" for _ in to_recursively_check),
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", to_recursively_check
) )
txn.execute(sql, to_recursively_check) txn.execute(sql + clause, args)
to_recursively_check = [] to_recursively_check = []
for event_id, prev_event_id, metadata, rejected in txn: for event_id, prev_event_id, metadata, rejected in txn:
@ -1543,10 +1548,14 @@ class EventsStore(
" FROM events as e" " FROM events as e"
" LEFT JOIN rejections as rej USING (event_id)" " LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts" " LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)" " WHERE "
) % (",".join(["?"] * len(ev_map)),) )
txn.execute(sql, list(ev_map)) clause, args = make_in_list_sql_clause(
self.database_engine, "e.event_id", list(ev_map)
)
txn.execute(sql + clause, args)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
for row in rows: for row in rows:
event = ev_map[row["event_id"]] event = ev_map[row["event_id"]]
@ -2249,11 +2258,12 @@ class EventsStore(
sql = """ sql = """
SELECT DISTINCT state_group FROM event_to_state_groups SELECT DISTINCT state_group FROM event_to_state_groups
LEFT JOIN events_to_purge AS ep USING (event_id) LEFT JOIN events_to_purge AS ep USING (event_id)
WHERE state_group IN (%s) AND ep.event_id IS NULL WHERE ep.event_id IS NULL AND
""" % ( """
",".join("?" for _ in current_search), clause, args = make_in_list_sql_clause(
txn.database_engine, "state_group", current_search
) )
txn.execute(sql, list(current_search)) txn.execute(sql + clause, list(args))
referenced = set(sg for sg, in txn) referenced = set(sg for sg, in txn)
referenced_groups |= referenced referenced_groups |= referenced

View file

@ -21,6 +21,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -71,6 +72,19 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"redactions_received_ts", self._redactions_received_ts "redactions_received_ts", self._redactions_received_ts
) )
# This index gets deleted in `event_fix_redactions_bytes` update
self.register_background_index_update(
"event_fix_redactions_bytes_create_index",
index_name="redactions_censored_redacts",
table="redactions",
columns=["redacts"],
where_clause="have_censored",
)
self.register_background_update_handler(
"event_fix_redactions_bytes", self._event_fix_redactions_bytes
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size): def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"] target_min_stream_id = progress["target_min_stream_id_inclusive"]
@ -312,12 +326,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
INNER JOIN event_json USING (event_id) INNER JOIN event_json USING (event_id)
LEFT JOIN rejections USING (event_id) LEFT JOIN rejections USING (event_id)
WHERE WHERE
prev_event_id IN (%s) NOT events.outlier
AND NOT events.outlier AND
""" % ( """
",".join("?" for _ in to_check), clause, args = make_in_list_sql_clause(
self.database_engine, "prev_event_id", to_check
) )
txn.execute(sql, to_check) txn.execute(sql + clause, list(args))
for prev_event_id, event_id, metadata, rejected in txn: for prev_event_id, event_id, metadata, rejected in txn:
if event_id in graph: if event_id in graph:
@ -458,3 +473,33 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
yield self._end_background_update("redactions_received_ts") yield self._end_background_update("redactions_received_ts")
return count return count
@defer.inlineCallbacks
def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON.
"""
def _event_fix_redactions_bytes_txn(txn):
# This update is quite fast due to new index.
txn.execute(
"""
UPDATE event_json
SET
json = convert_from(json::bytea, 'utf8')
FROM redactions
WHERE
redactions.have_censored
AND event_json.event_id = redactions.redacts
AND json NOT LIKE '{%';
"""
)
txn.execute("DROP INDEX redactions_censored_redacts")
yield self.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
yield self._end_background_update("event_fix_redactions_bytes")
return 1

View file

@ -31,12 +31,11 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util import batch_iter from synapse.util import batch_iter
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -623,10 +622,14 @@ class EventsWorkerStore(SQLBaseStore):
" rej.reason " " rej.reason "
" FROM event_json as e" " FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)" " LEFT JOIN rejections as rej USING (event_id)"
" WHERE e.event_id IN (%s)" " WHERE "
) % (",".join(["?"] * len(evs)),) )
txn.execute(sql, evs) clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", evs
)
txn.execute(sql + clause, args)
for row in txn: for row in txn:
event_id = row[0] event_id = row[0]
@ -640,11 +643,11 @@ class EventsWorkerStore(SQLBaseStore):
} }
# check for redactions # check for redactions
redactions_sql = ( redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
"SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)"
) % (",".join(["?"] * len(evs)),)
txn.execute(redactions_sql, evs) clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs)
txn.execute(redactions_sql + clause, args)
for (redacter, redacted) in txn: for (redacter, redacted) in txn:
d = event_dict.get(redacted) d = event_dict.get(redacted)
@ -753,10 +756,11 @@ class EventsWorkerStore(SQLBaseStore):
results = set() results = set()
def have_seen_events_txn(txn, chunk): def have_seen_events_txn(txn, chunk):
sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % ( sql = "SELECT event_id FROM events as e WHERE "
",".join("?" * len(chunk)), clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", chunk
) )
txn.execute(sql, chunk) txn.execute(sql + clause, args)
for (event_id,) in txn: for (event_id,) in txn:
results.add(event_id) results.add(event_id)

View file

@ -51,7 +51,7 @@ class FilteringStore(SQLBaseStore):
"SELECT filter_id FROM user_filters " "SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?" "WHERE user_id = ? AND filter_json = ?"
) )
txn.execute(sql, (user_localpart, def_json)) txn.execute(sql, (user_localpart, bytearray(def_json)))
filter_id_response = txn.fetchone() filter_id_response = txn.fetchone()
if filter_id_response is not None: if filter_id_response is not None:
return filter_id_response[0] return filter_id_response[0]
@ -68,7 +68,7 @@ class FilteringStore(SQLBaseStore):
"INSERT INTO user_filters (user_id, filter_id, filter_json)" "INSERT INTO user_filters (user_id, filter_id, filter_json)"
"VALUES(?, ?, ?)" "VALUES(?, ?, ?)"
) )
txn.execute(sql, (user_localpart, filter_id, def_json)) txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
return filter_id return filter_id

View file

@ -32,7 +32,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
super(MonthlyActiveUsersStore, self).__init__(None, hs) super(MonthlyActiveUsersStore, self).__init__(None, hs)
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.hs = hs self.hs = hs
self.reserved_users = ()
# Do not add more reserved users than the total allowable number # Do not add more reserved users than the total allowable number
self._new_transaction( self._new_transaction(
dbconn, dbconn,
@ -51,7 +50,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
txn (cursor): txn (cursor):
threepids (list[dict]): List of threepid dicts to reserve threepids (list[dict]): List of threepid dicts to reserve
""" """
reserved_user_list = []
for tp in threepids: for tp in threepids:
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
@ -60,10 +58,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
is_support = self.is_support_user_txn(txn, user_id) is_support = self.is_support_user_txn(txn, user_id)
if not is_support: if not is_support:
self.upsert_monthly_active_user_txn(txn, user_id) self.upsert_monthly_active_user_txn(txn, user_id)
reserved_user_list.append(user_id)
else: else:
logger.warning("mau limit reserved threepid %s not found in db" % tp) logger.warning("mau limit reserved threepid %s not found in db" % tp)
self.reserved_users = tuple(reserved_user_list)
@defer.inlineCallbacks @defer.inlineCallbacks
def reap_monthly_active_users(self): def reap_monthly_active_users(self):
@ -74,8 +70,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Deferred[] Deferred[]
""" """
def _reap_users(txn): def _reap_users(txn, reserved_users):
# Purge stale users """
Args:
reserved_users (tuple): reserved users to preserve
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
query_args = [thirty_days_ago] query_args = [thirty_days_ago]
@ -83,20 +82,19 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite. # when len(reserved_users) == 0. Works fine on sqlite.
if len(self.reserved_users) > 0: if len(reserved_users) > 0:
# questionmarks is a hack to overcome sqlite not supporting # questionmarks is a hack to overcome sqlite not supporting
# tuples in 'WHERE IN %s' # tuples in 'WHERE IN %s'
questionmarks = "?" * len(self.reserved_users) question_marks = ",".join("?" * len(reserved_users))
query_args.extend(self.reserved_users) query_args.extend(reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format( sql = base_sql + " AND user_id NOT IN ({})".format(question_marks)
",".join(questionmarks)
)
else: else:
sql = base_sql sql = base_sql
txn.execute(sql, query_args) txn.execute(sql, query_args)
max_mau_value = self.hs.config.max_mau_value
if self.hs.config.limit_usage_by_mau: if self.hs.config.limit_usage_by_mau:
# If MAU user count still exceeds the MAU threshold, then delete on # If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis. # a least recently active basis.
@ -106,12 +104,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# While Postgres does not require 'LIMIT', but also does not support # While Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can # negative LIMIT values. So there is no way to write it that both can
# support # support
safe_guard = self.hs.config.max_mau_value - len(self.reserved_users) if len(reserved_users) == 0:
# Must be greater than zero for postgres sql = """
safe_guard = safe_guard if safe_guard > 0 else 0
query_args = [safe_guard]
base_sql = """
DELETE FROM monthly_active_users DELETE FROM monthly_active_users
WHERE user_id NOT IN ( WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users SELECT user_id FROM monthly_active_users
@ -119,18 +113,43 @@ class MonthlyActiveUsersStore(SQLBaseStore):
LIMIT ? LIMIT ?
) )
""" """
txn.execute(sql, (max_mau_value,))
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite. # when len(reserved_users) == 0. Works fine on sqlite.
if len(self.reserved_users) > 0:
query_args.extend(self.reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format(
",".join(questionmarks)
)
else: else:
sql = base_sql # Must be >= 0 for postgres
num_of_non_reserved_users_to_remove = max(
max_mau_value - len(reserved_users), 0
)
# It is important to filter reserved users twice to guard
# against the case where the reserved user is present in the
# SELECT, meaning that a legitmate mau is deleted.
sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
WHERE user_id NOT IN ({})
ORDER BY timestamp DESC
LIMIT ?
)
AND user_id NOT IN ({})
""".format(
question_marks, question_marks
)
query_args = [
*reserved_users,
num_of_non_reserved_users_to_remove,
*reserved_users,
]
txn.execute(sql, query_args) txn.execute(sql, query_args)
yield self.runInteraction("reap_monthly_active_users", _reap_users) reserved_users = yield self.get_registered_reserved_users()
yield self.runInteraction(
"reap_monthly_active_users", _reap_users, reserved_users
)
# It seems poor to invalidate the whole cache, Postgres supports # It seems poor to invalidate the whole cache, Postgres supports
# 'Returning' which would allow me to invalidate only the # 'Returning' which would allow me to invalidate only the
# specific users, but sqlite has no way to do this and instead # specific users, but sqlite has no way to do this and instead
@ -159,21 +178,25 @@ class MonthlyActiveUsersStore(SQLBaseStore):
return self.runInteraction("count_users", _count_users) return self.runInteraction("count_users", _count_users)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_registered_reserved_users_count(self): def get_registered_reserved_users(self):
"""Of the reserved threepids defined in config, how many are associated """Of the reserved threepids defined in config, which are associated
with registered users? with registered users?
Returns: Returns:
Defered[int]: Number of real reserved users Defered[list]: Real reserved users
""" """
count = 0 users = []
for tp in self.hs.config.mau_limits_reserved_threepids:
for tp in self.hs.config.mau_limits_reserved_threepids[
: self.hs.config.max_mau_value
]:
user_id = yield self.hs.get_datastore().get_user_id_by_threepid( user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
tp["medium"], tp["address"] tp["medium"], tp["address"]
) )
if user_id: if user_id:
count = count + 1 users.append(user_id)
return count
return users
@defer.inlineCallbacks @defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id): def upsert_monthly_active_user(self, user_id):

View file

@ -18,11 +18,10 @@ from collections import namedtuple
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.util import batch_iter from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from ._base import SQLBaseStore
class UserPresenceState( class UserPresenceState(
namedtuple( namedtuple(
@ -119,14 +118,13 @@ class PresenceStore(SQLBaseStore):
) )
# Delete old rows to stop database from getting really big # Delete old rows to stop database from getting really big
sql = ( sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
"DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)"
)
for states in batch_iter(presence_states, 50): for states in batch_iter(presence_states, 50):
args = [stream_id] clause, args = make_in_list_sql_clause(
args.extend(s.user_id for s in states) self.database_engine, "user_id", [s.user_id for s in states]
txn.execute(sql % (",".join("?" for _ in states),), args) )
txn.execute(sql + clause, [stream_id] + list(args))
def get_all_presence_updates(self, last_id, current_id): def get_all_presence_updates(self, last_id, current_id):
if last_id == current_id: if last_id == current_id:

View file

@ -241,7 +241,7 @@ class PusherStore(PusherWorkerStore):
"device_display_name": device_display_name, "device_display_name": device_display_name,
"ts": pushkey_ts, "ts": pushkey_ts,
"lang": lang, "lang": lang,
"data": encode_canonical_json(data), "data": bytearray(encode_canonical_json(data)),
"last_stream_ordering": last_stream_ordering, "last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag, "profile_tag": profile_tag,
"id": stream_id, "id": stream_id,

View file

@ -21,12 +21,11 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import SQLBaseStore
from .util.id_generators import StreamIdGenerator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -217,24 +216,26 @@ class ReceiptsWorkerStore(SQLBaseStore):
def f(txn): def f(txn):
if from_key: if from_key:
sql = ( sql = """
"SELECT * FROM receipts_linearized WHERE" SELECT * FROM receipts_linearized WHERE
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?" stream_id > ? AND stream_id <= ? AND
) % (",".join(["?"] * len(room_ids))) """
args = list(room_ids) clause, args = make_in_list_sql_clause(
args.extend([from_key, to_key]) self.database_engine, "room_id", room_ids
)
txn.execute(sql, args) txn.execute(sql + clause, [from_key, to_key] + list(args))
else: else:
sql = ( sql = """
"SELECT * FROM receipts_linearized WHERE" SELECT * FROM receipts_linearized WHERE
" room_id IN (%s) AND stream_id <= ?" stream_id <= ? AND
) % (",".join(["?"] * len(room_ids))) """
args = list(room_ids) clause, args = make_in_list_sql_clause(
args.append(to_key) self.database_engine, "room_id", room_ids
)
txn.execute(sql, args) txn.execute(sql + clause, [to_key] + list(args))
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
@ -433,13 +434,19 @@ class ReceiptsStore(ReceiptsWorkerStore):
# we need to points in graph -> linearized form. # we need to points in graph -> linearized form.
# TODO: Make this better. # TODO: Make this better.
def graph_to_linear(txn): def graph_to_linear(txn):
query = ( clause, args = make_in_list_sql_clause(
"SELECT event_id WHERE room_id = ? AND stream_ordering IN (" self.database_engine, "event_id", event_ids
" SELECT max(stream_ordering) WHERE event_id IN (%s)" )
")"
) % (",".join(["?"] * len(event_ids)))
txn.execute(query, [room_id] + event_ids) sql = """
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
SELECT max(stream_ordering) WHERE %s
)
""" % (
clause,
)
txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall() rows = txn.fetchall()
if rows: if rows:
return rows[0][0] return rows[0][0]

View file

@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import Sqlite3Engine from synapse.storage.engines import Sqlite3Engine
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
@ -372,6 +372,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
results = [] results = []
if membership_list: if membership_list:
if self._current_state_events_membership_up_to_date: if self._current_state_events_membership_up_to_date:
clause, args = make_in_list_sql_clause(
self.database_engine, "c.membership", membership_list
)
sql = """ sql = """
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
FROM current_state_events AS c FROM current_state_events AS c
@ -379,11 +382,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
WHERE WHERE
c.type = 'm.room.member' c.type = 'm.room.member'
AND state_key = ? AND state_key = ?
AND c.membership IN (%s) AND %s
""" % ( """ % (
",".join("?" * len(membership_list)) clause,
) )
else: else:
clause, args = make_in_list_sql_clause(
self.database_engine, "m.membership", membership_list
)
sql = """ sql = """
SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
FROM current_state_events AS c FROM current_state_events AS c
@ -392,12 +398,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
WHERE WHERE
c.type = 'm.room.member' c.type = 'm.room.member'
AND state_key = ? AND state_key = ?
AND m.membership IN (%s) AND %s
""" % ( """ % (
",".join("?" * len(membership_list)) clause,
) )
txn.execute(sql, (user_id, *membership_list)) txn.execute(sql, (user_id, *args))
results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
if do_invite: if do_invite:

View file

@ -15,12 +15,11 @@
-- There was a bug where we may have updated censored redactions as bytes, -- There was a bug where we may have updated censored redactions as bytes,
-- which can (somehow) cause json to be inserted hex encoded. This goes and -- which can (somehow) cause json to be inserted hex encoded. These updates go
-- undoes any such hex encoded JSON. -- and undoes any such hex encoded JSON.
UPDATE event_json SET json = convert_from(json::bytea, 'utf8')
WHERE event_id IN ( INSERT into background_updates (update_name, progress_json)
SELECT event_json.event_id VALUES ('event_fix_redactions_bytes_create_index', '{}');
FROM event_json
INNER JOIN redactions ON (event_json.event_id = redacts) INSERT into background_updates (update_name, progress_json, depends_on)
WHERE have_censored AND json NOT LIKE '{%' VALUES ('event_fix_redactions_bytes', '{}', 'event_fix_redactions_bytes_create_index');
);

View file

@ -24,6 +24,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from .background_updates import BackgroundUpdateStore from .background_updates import BackgroundUpdateStore
@ -385,8 +386,10 @@ class SearchStore(SearchBackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms. # Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless. # We filter the results below regardless.
if len(room_ids) < 500: if len(room_ids) < 500:
clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) clause, args = make_in_list_sql_clause(
args.extend(room_ids) self.database_engine, "room_id", room_ids
)
clauses = [clause]
local_clauses = [] local_clauses = []
for key in keys: for key in keys:
@ -492,8 +495,10 @@ class SearchStore(SearchBackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms. # Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless. # We filter the results below regardless.
if len(room_ids) < 500: if len(room_ids) < 500:
clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) clause, args = make_in_list_sql_clause(
args.extend(room_ids) self.database_engine, "room_id", room_ids
)
clauses = [clause]
local_clauses = [] local_clauses = []
for key in keys: for key in keys:

View file

@ -56,15 +56,15 @@ class UserErasureWorkerStore(SQLBaseStore):
# iterate it multiple times, and (b) avoiding duplicates. # iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids)) user_ids = tuple(set(user_ids))
def _get_erased_users(txn): rows = yield self._simple_select_many_batch(
txn.execute( table="erased_users",
"SELECT user_id FROM erased_users WHERE user_id IN (%s)" column="user_id",
% (",".join("?" * len(user_ids))), iterable=user_ids,
user_ids, retcols=("user_id",),
desc="are_users_erased",
) )
return set(r[0] for r in txn) erased_users = set(row["user_id"] for row in rows)
erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
res = dict((u, u in erased_users) for u in user_ids) res = dict((u, u in erased_users) for u in user_ids)
return res return res

View file

@ -187,9 +187,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, 404) self.assertEqual(res, 404)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_update_bad_version(self): def test_update_omitted_version(self):
"""Check that we get a 400 if the version in the body is missing or """Check that the update succeeds if the version is missing from the body
doesn't match
""" """
version = yield self.handler.create_version( version = yield self.handler.create_version(
self.local_user, self.local_user,
@ -197,8 +196,6 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
res = None
try:
yield self.handler.update_version( yield self.handler.update_version(
self.local_user, self.local_user,
version, version,
@ -207,9 +204,27 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"auth_data": "revised_first_version_auth_data", "auth_data": "revised_first_version_auth_data",
}, },
) )
except errors.SynapseError as e:
res = e.code # check we can retrieve it as the current version
self.assertEqual(res, 400) res = yield self.handler.get_version_info(self.local_user)
self.assertDictEqual(
res,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": version,
},
)
@defer.inlineCallbacks
def test_update_bad_version(self):
"""Check that we get a 400 if the version in the body doesn't match
"""
version = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
)
self.assertEqual(version, "1")
res = None res = None
try: try:

View file

@ -38,7 +38,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("user", res) self.assertIn("user", res)
self.assertIn("device", res["user"]) self.assertIn("device", res["user"])
dev = res["user"]["device"] dev = res["user"]["device"]
self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev) self.assertDictContainsSubset(json, dev)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_reupload_key(self): def test_reupload_key(self):
@ -68,7 +68,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("device", res["user"]) self.assertIn("device", res["user"])
dev = res["user"]["device"] dev = res["user"]["device"]
self.assertDictContainsSubset( self.assertDictContainsSubset(
{"keys": json, "device_display_name": "display_name"}, dev {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -80,10 +80,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield self.store.store_device("user2", "device1", None) yield self.store.store_device("user2", "device1", None)
yield self.store.store_device("user2", "device2", None) yield self.store.store_device("user2", "device2", None)
yield self.store.set_e2e_device_keys("user1", "device1", now, "json11") yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
yield self.store.set_e2e_device_keys("user1", "device2", now, "json12") yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
yield self.store.set_e2e_device_keys("user2", "device1", now, "json21") yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
yield self.store.set_e2e_device_keys("user2", "device2", now, "json22") yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
res = yield self.store.get_e2e_device_keys( res = yield self.store.get_e2e_device_keys(
(("user1", "device1"), ("user2", "device2")) (("user1", "device1"), ("user2", "device2"))

View file

@ -57,7 +57,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
"(event_id, algorithm, hash) " "(event_id, algorithm, hash) "
"VALUES (?, 'sha256', ?)" "VALUES (?, 'sha256', ?)"
), ),
(event_id, b"ffff"), (event_id, bytearray(b"ffff")),
) )
for i in range(0, 11): for i in range(0, 11):

View file

@ -50,6 +50,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": user2_email}, {"medium": "email", "address": user2_email},
{"medium": "email", "address": user3_email}, {"medium": "email", "address": user3_email},
] ]
self.hs.config.mau_limits_reserved_threepids = threepids
# -1 because user3 is a support user and does not count # -1 because user3 is a support user and does not count
user_num = len(threepids) - 1 user_num = len(threepids) - 1
@ -84,6 +85,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.hs.config.max_mau_value = 0 self.hs.config.max_mau_value = 0
self.reactor.advance(FORTY_DAYS) self.reactor.advance(FORTY_DAYS)
self.hs.config.max_mau_value = 5
self.store.reap_monthly_active_users() self.store.reap_monthly_active_users()
self.pump() self.pump()
@ -147,9 +149,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.reap_monthly_active_users() self.store.reap_monthly_active_users()
self.pump() self.pump()
count = self.store.get_monthly_active_count() count = self.store.get_monthly_active_count()
self.assertEquals( self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
self.get_success(count), initial_users - self.hs.config.max_mau_value
)
self.reactor.advance(FORTY_DAYS) self.reactor.advance(FORTY_DAYS)
self.store.reap_monthly_active_users() self.store.reap_monthly_active_users()
@ -158,6 +158,44 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.store.get_monthly_active_count() count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(count), 0) self.assertEquals(self.get_success(count), 0)
def test_reap_monthly_active_users_reserved_users(self):
""" Tests that reaping correctly handles reaping where reserved users are
present"""
self.hs.config.max_mau_value = 5
initial_users = 5
reserved_user_number = initial_users - 1
threepids = []
for i in range(initial_users):
user = "@user%d:server" % i
email = "user%d@example.com" % i
self.get_success(self.store.upsert_monthly_active_user(user))
threepids.append({"medium": "email", "address": email})
# Need to ensure that the most recent entries in the
# monthly_active_users table are reserved
now = int(self.hs.get_clock().time_msec())
if i != 0:
self.get_success(
self.store.register_user(user_id=user, password_hash=None)
)
self.get_success(
self.store.user_add_threepid(user, "email", email, now, now)
)
self.hs.config.mau_limits_reserved_threepids = threepids
self.store.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
count = self.store.get_monthly_active_count()
self.assertTrue(self.get_success(count), initial_users)
users = self.store.get_registered_reserved_users()
self.assertEquals(len(self.get_success(users)), reserved_user_number)
self.get_success(self.store.reap_monthly_active_users())
count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
def test_populate_monthly_users_is_guest(self): def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list # Test that guest users are not added to mau list
user_id = "@user_id:host" user_id = "@user_id:host"
@ -192,12 +230,13 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_get_reserved_real_user_account(self): def test_get_reserved_real_user_account(self):
# Test no reserved users, or reserved threepids # Test no reserved users, or reserved threepids
count = self.store.get_registered_reserved_users_count() users = self.get_success(self.store.get_registered_reserved_users())
self.assertEquals(self.get_success(count), 0) self.assertEquals(len(users), 0)
# Test reserved users but no registered users # Test reserved users but no registered users
user1 = "@user1:example.com" user1 = "@user1:example.com"
user2 = "@user2:example.com" user2 = "@user2:example.com"
user1_email = "user1@example.com" user1_email = "user1@example.com"
user2_email = "user2@example.com" user2_email = "user2@example.com"
threepids = [ threepids = [
@ -210,8 +249,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
) )
self.pump() self.pump()
count = self.store.get_registered_reserved_users_count() users = self.get_success(self.store.get_registered_reserved_users())
self.assertEquals(self.get_success(count), 0) self.assertEquals(len(users), 0)
# Test reserved registed users # Test reserved registed users
self.store.register_user(user_id=user1, password_hash=None) self.store.register_user(user_id=user1, password_hash=None)
@ -221,8 +260,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
now = int(self.hs.get_clock().time_msec()) now = int(self.hs.get_clock().time_msec())
self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now)
count = self.store.get_registered_reserved_users_count()
self.assertEquals(self.get_success(count), len(threepids)) users = self.get_success(self.store.get_registered_reserved_users())
self.assertEquals(len(users), len(threepids))
def test_support_user_not_add_to_mau_limits(self): def test_support_user_not_add_to_mau_limits(self):
support_user_id = "@support:test" support_user_id = "@support:test"