Merge remote-tracking branch 'origin/release-v1.21.3' into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2020-10-22 09:57:06 +01:00
commit ab4cd7f802
84 changed files with 1264 additions and 781 deletions

1
.gitignore vendored
View file

@ -21,6 +21,7 @@ _trial_temp*/
/.python-version
/*.signing.key
/env/
/.venv*/
/homeserver*.yaml
/logs
/media_store/

1
changelog.d/8504.bugfix Normal file
View file

@ -0,0 +1 @@
Expose the `uk.half-shot.msc2778.login.application_service` to clients from the login API. This feature was added in v1.21.0, but was not exposed as a potential login flow.

1
changelog.d/8544.feature Normal file
View file

@ -0,0 +1 @@
Allow running background tasks in a separate worker process.

1
changelog.d/8545.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a long standing bug where email notifications for encrypted messages were blank.

1
changelog.d/8561.misc Normal file
View file

@ -0,0 +1 @@
Move metric registration code down into `LruCache`.

1
changelog.d/8562.misc Normal file
View file

@ -0,0 +1 @@
Add type annotations for `LruCache`.

1
changelog.d/8563.misc Normal file
View file

@ -0,0 +1 @@
Replace `DeferredCache` with the lighter-weight `LruCache` where possible.

1
changelog.d/8564.feature Normal file
View file

@ -0,0 +1 @@
Support modifying event content in `ThirdPartyRules` modules.

1
changelog.d/8566.misc Normal file
View file

@ -0,0 +1 @@
Add virtualenv-generated folders to `.gitignore`.

1
changelog.d/8567.bugfix Normal file
View file

@ -0,0 +1 @@
Fix increase in the number of `There was no active span...` errors logged when using OpenTracing.

1
changelog.d/8568.misc Normal file
View file

@ -0,0 +1 @@
Add `get_immediate` method to `DeferredCache`.

1
changelog.d/8569.misc Normal file
View file

@ -0,0 +1 @@
Fix mypy not properly checking across the codebase, additionally, fix a typing assertion error in `handlers/auth.py`.

1
changelog.d/8571.misc Normal file
View file

@ -0,0 +1 @@
Fix `synmark` benchmark runner.

1
changelog.d/8572.misc Normal file
View file

@ -0,0 +1 @@
Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s.

1
changelog.d/8577.misc Normal file
View file

@ -0,0 +1 @@
Adjust a protocol-type definition to fit `sqlite3` assertions.

1
changelog.d/8578.misc Normal file
View file

@ -0,0 +1 @@
Support macOS on the `synmark` benchmark runner.

1
changelog.d/8583.misc Normal file
View file

@ -0,0 +1 @@
Update `mypy` static type checker to 0.790.

1
changelog.d/8585.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug that prevented errors encountered during execution of the `synapse_port_db` from being correctly printed.

1
changelog.d/8587.misc Normal file
View file

@ -0,0 +1 @@
Re-organize the structured logging code to separate the TCP transport handling from the JSON formatting.

1
changelog.d/8589.removal Normal file
View file

@ -0,0 +1 @@
Drop unused `device_max_stream_id` table.

1
changelog.d/8590.misc Normal file
View file

@ -0,0 +1 @@
Implement [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409) to send typing, read receipts, and presence events to appservices.

1
changelog.d/8591.misc Normal file
View file

@ -0,0 +1 @@
Move metric registration code down into `LruCache`.

1
changelog.d/8592.misc Normal file
View file

@ -0,0 +1 @@
Remove extraneous unittest logging decorators from unit tests.

1
changelog.d/8593.misc Normal file
View file

@ -0,0 +1 @@
Minor optimisations in caching code.

1
changelog.d/8594.misc Normal file
View file

@ -0,0 +1 @@
Minor optimisations in caching code.

1
changelog.d/8599.feature Normal file
View file

@ -0,0 +1 @@
Allow running background tasks in a separate worker process.

1
changelog.d/8600.misc Normal file
View file

@ -0,0 +1 @@
Update `mypy` static type checker to 0.790.

1
changelog.d/8606.feature Normal file
View file

@ -0,0 +1 @@
Limit appservice transactions to 100 persistent and 100 ephemeral events.

1
changelog.d/8609.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to profile and base handler.

View file

@ -15,8 +15,9 @@ files =
synapse/events/builder.py,
synapse/events/spamcheck.py,
synapse/federation,
synapse/handlers/appservice.py,
synapse/handlers/_base.py,
synapse/handlers/account_data.py,
synapse/handlers/appservice.py,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
synapse/handlers/deactivate_account.py,
@ -32,6 +33,7 @@ files =
synapse/handlers/pagination.py,
synapse/handlers/password_policy.py,
synapse/handlers/presence.py,
synapse/handlers/profile.py,
synapse/handlers/read_marker.py,
synapse/handlers/room.py,
synapse/handlers/room_member.py,

View file

@ -22,6 +22,7 @@ import logging
import sys
import time
import traceback
from typing import Optional
import yaml
@ -152,7 +153,7 @@ IGNORED_TABLES = {
# Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes.
end_error = None
end_error = None # type: Optional[str]
# The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run
@ -635,7 +636,7 @@ class Porter(object):
self.progress.done()
except Exception as e:
global end_error_exec_info
end_error = e
end_error = str(e)
end_error_exec_info = sys.exc_info()
logger.exception("")
finally:

View file

@ -102,6 +102,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"flake8",
]
CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
# Dependencies which are exclusively required by unit test code. This is
# NOT a list of all modules that are necessary to run the unit tests.
# Tests assume that all optional dependencies are installed.

View file

@ -34,7 +34,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging import opentracing as opentracing
from synapse.types import StateMap, UserID
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@ -70,8 +69,9 @@ class Auth:
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.token_cache = LruCache(10000)
register_cache("cache", "token_cache", self.token_cache)
self.token_cache = LruCache(
10000, "token_cache"
) # type: LruCache[str, Tuple[str, bool]]
self._auth_blocking = AuthBlocking(self.hs)

View file

@ -60,6 +60,13 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
# Maximum number of events to provide in an AS transaction.
MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
# Maximum number of ephemeral events to provide in an AS transaction.
MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
class ApplicationServiceScheduler:
""" Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
@ -136,10 +143,17 @@ class _ServiceQueuer:
self.requests_in_flight.add(service.id)
try:
while True:
events = self.queued_events.pop(service.id, [])
ephemeral = self.queued_ephemeral.pop(service.id, [])
all_events = self.queued_events.get(service.id, [])
events = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
all_events_ephemeral = self.queued_ephemeral.get(service.id, [])
ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
if not events and not ephemeral:
return
try:
await self.txn_ctrl.send(service, events, ephemeral)
except Exception:

View file

@ -98,7 +98,7 @@ class EventBuilder:
return self._state_key is not None
async def build(
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
) -> EventBase:
"""Transform into a fully signed and hashed event

View file

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
import synapse.state
import synapse.storage
@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.ratelimiting import Ratelimiter
from synapse.types import UserID
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -30,11 +34,7 @@ class BaseHandler:
Common base class for the event handlers.
"""
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
@ -56,7 +56,7 @@ class BaseHandler:
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
)
) # type: Optional[Ratelimiter]
else:
self.admin_redaction_ratelimiter = None
@ -127,15 +127,15 @@ class BaseHandler:
if guest_access != "can_join":
if context:
current_state_ids = await context.get_current_state_ids()
current_state = await self.store.get_events(
current_state_dict = await self.store.get_events(
list(current_state_ids.values())
)
current_state = list(current_state_dict.values())
else:
current_state = await self.state_handler.get_current_state(
current_state_map = await self.state_handler.get_current_state(
event.room_id
)
current_state = list(current_state.values())
current_state = list(current_state_map.values())
logger.info("maybe_kick_guest_users %r", current_state)
await self.kick_guest_users(current_state)

View file

@ -22,7 +22,7 @@ from typing import List
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
@ -63,16 +63,10 @@ class AccountValidityHandler:
self._raw_from = email.utils.parseaddr(self._from_string)[1]
# Check the renewal emails to send and send them every 30min.
def send_emails():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"send_renewals", self._send_renewal_emails
)
if hs.config.run_background_tasks:
self.clock.looping_call(send_emails, 30 * 60 * 1000)
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
@wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``

View file

@ -1122,20 +1122,22 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash.
"""
def _do_validate_hash():
def _do_validate_hash(checked_hash: bytes):
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
stored_hash,
checked_hash,
)
if stored_hash:
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode("ascii")
return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
return await defer_to_thread(
self.hs.get_reactor(), _do_validate_hash, stored_hash
)
else:
return False

View file

@ -293,6 +293,10 @@ class InitialSyncHandler(BaseHandler):
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
# The member_event_id will always be available if membership is set
# to leave.
assert member_event_id
result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
@ -315,7 +319,7 @@ class InitialSyncHandler(BaseHandler):
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
membership: Membership,
membership: str,
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
@ -367,7 +371,7 @@ class InitialSyncHandler(BaseHandler):
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
membership: Membership,
membership: str,
is_peeking: bool,
) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id)

View file

@ -1364,7 +1364,12 @@ class EventCreationHandler:
for k, v in original_event.internal_metadata.get_dict().items():
setattr(builder.internal_metadata, k, v)
event = await builder.build(prev_event_ids=original_event.prev_event_ids())
# the event type hasn't changed, so there's no point in re-calculating the
# auth events.
event = await builder.build(
prev_event_ids=original_event.prev_event_ids(),
auth_event_ids=original_event.auth_event_ids(),
)
# we rebuild the event context, to be on the safe side. If nothing else,
# delta_ids might need an update.

View file

@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import (
AuthError,
@ -24,11 +24,20 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, create_requester, get_domain_from_id
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import (
JsonDict,
Requester,
UserID,
create_requester,
get_domain_from_id,
)
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
MAX_DISPLAYNAME_LEN = 256
@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.federation = hs.get_federation_client()
@ -57,10 +66,10 @@ class ProfileHandler(BaseHandler):
if hs.config.run_background_tasks:
self.clock.looping_call(
self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
)
async def get_profile(self, user_id):
async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
@ -91,7 +100,7 @@ class ProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
async def get_profile_from_cache(self, user_id):
async def get_profile_from_cache(self, user_id: str) -> JsonDict:
"""Get the profile information from our local cache. If the user is
ours then the profile information will always be corect. Otherwise,
it may be out of date/missing.
@ -115,7 +124,7 @@ class ProfileHandler(BaseHandler):
profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {}
async def get_displayname(self, target_user):
async def get_displayname(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user):
try:
displayname = await self.store.get_profile_displayname(
@ -143,15 +152,19 @@ class ProfileHandler(BaseHandler):
return result["displayname"]
async def set_displayname(
self, target_user, requester, new_displayname, by_admin=False
):
self,
target_user: UserID,
requester: Requester,
new_displayname: str,
by_admin: bool = False,
) -> None:
"""Set the displayname of a user
Args:
target_user (UserID): the user whose displayname is to be changed.
requester (Requester): The user attempting to make this change.
new_displayname (str): The displayname to give this user.
by_admin (bool): Whether this change was made by an administrator.
target_user: the user whose displayname is to be changed.
requester: The user attempting to make this change.
new_displayname: The displayname to give this user.
by_admin: Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
@ -176,8 +189,9 @@ class ProfileHandler(BaseHandler):
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
)
displayname_to_set = new_displayname # type: Optional[str]
if new_displayname == "":
new_displayname = None
displayname_to_set = None
# If the admin changes the display name of a user, the requesting user cannot send
# the join event to update the displayname in the rooms.
@ -185,7 +199,9 @@ class ProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
await self.store.set_profile_displayname(target_user.localpart, new_displayname)
await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set
)
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
@ -195,7 +211,7 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
async def get_avatar_url(self, target_user):
async def get_avatar_url(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user):
try:
avatar_url = await self.store.get_profile_avatar_url(
@ -222,15 +238,19 @@ class ProfileHandler(BaseHandler):
return result["avatar_url"]
async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False
self,
target_user: UserID,
requester: Requester,
new_avatar_url: str,
by_admin: bool = False,
):
"""Set a new avatar URL for a user.
Args:
target_user (UserID): the user whose avatar URL is to be changed.
requester (Requester): The user attempting to make this change.
new_avatar_url (str): The avatar URL to give this user.
by_admin (bool): Whether this change was made by an administrator.
target_user: the user whose avatar URL is to be changed.
requester: The user attempting to make this change.
new_avatar_url: The avatar URL to give this user.
by_admin: Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
@ -267,7 +287,7 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
async def on_profile_query(self, args):
async def on_profile_query(self, args: JsonDict) -> JsonDict:
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
@ -292,7 +312,9 @@ class ProfileHandler(BaseHandler):
return response
async def _update_join_states(self, requester, target_user):
async def _update_join_states(
self, requester: Requester, target_user: UserID
) -> None:
if not self.hs.is_mine(target_user):
return
@ -323,15 +345,17 @@ class ProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s", room_id, str(e)
)
async def check_profile_query_allowed(self, target_user, requester=None):
async def check_profile_query_allowed(
self, target_user: UserID, requester: Optional[UserID] = None
) -> None:
"""Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
share a room.
Args:
target_user (UserID): The owner of the queried profile.
requester (None|UserID): The user querying for the profile.
target_user: The owner of the queried profile.
requester: The user querying for the profile.
Raises:
SynapseError(403): The two users share no room, or ne user couldn't
@ -370,11 +394,7 @@ class ProfileHandler(BaseHandler):
raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
raise
def _start_update_remote_profile_cache(self):
return run_as_background_process(
"Update remote profile", self._update_remote_profile_cache
)
@wrap_as_background_process("Update remote profile")
async def _update_remote_profile_cache(self):
"""Called periodically to check profiles of remote users we haven't
checked in a while.

225
synapse/logging/_remote.py Normal file
View file

@ -0,0 +1,225 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import traceback
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
from typing import Callable, Optional
import attr
from zope.interface import implementer
from twisted.application.internet import ClientService
from twisted.internet.defer import Deferred
from twisted.internet.endpoints import (
HostnameEndpoint,
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
from twisted.internet.interfaces import IPushProducer, ITransport
from twisted.internet.protocol import Factory, Protocol
from twisted.logger import ILogObserver, Logger, LogLevel
@attr.s
@implementer(IPushProducer)
class LogProducer:
"""
An IPushProducer that writes logs from its buffer to its transport when it
is resumed.
Args:
buffer: Log buffer to read logs from.
transport: Transport to write to.
format_event: A callable to format the log entry to a string.
"""
transport = attr.ib(type=ITransport)
format_event = attr.ib(type=Callable[[dict], str])
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
def pauseProducing(self):
self._paused = True
def stopProducing(self):
self._paused = True
self._buffer = deque()
def resumeProducing(self):
self._paused = False
while self._paused is False and (self._buffer and self.transport.connected):
try:
# Request the next event and format it.
event = self._buffer.popleft()
msg = self.format_event(event)
# Send it as a new line over the transport.
self.transport.write(msg.encode("utf8"))
except Exception:
# Something has gone wrong writing to the transport -- log it
# and break out of the while.
traceback.print_exc(file=sys.__stderr__)
break
@attr.s
@implementer(ILogObserver)
class TCPLogObserver:
"""
An IObserver that writes JSON logs to a TCP target.
Args:
hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
format_event: A callable to format the log entry to a string.
maximum_buffer: The maximum buffer size.
"""
hs = attr.ib()
host = attr.ib(type=str)
port = attr.ib(type=int)
format_event = attr.ib(type=Callable[[dict], str])
maximum_buffer = attr.ib(type=int)
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
_connection_waiter = attr.ib(default=None, type=Optional[Deferred])
_logger = attr.ib(default=attr.Factory(Logger))
_producer = attr.ib(default=None, type=Optional[LogProducer])
def start(self) -> None:
# Connect without DNS lookups if it's a direct IP.
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
endpoint = TCP4ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
else:
raise ValueError("Unknown IP address provided: %s" % (self.host,))
except ValueError:
endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
factory = Factory.forProtocol(Protocol)
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
self._service.startService()
self._connect()
def stop(self):
self._service.stopService()
def _connect(self) -> None:
"""
Triggers an attempt to connect then write to the remote if not already writing.
"""
if self._connection_waiter:
return
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
@self._connection_waiter.addErrback
def fail(r):
r.printTraceback(file=sys.__stderr__)
self._connection_waiter = None
self._connect()
@self._connection_waiter.addCallback
def writer(r):
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
if self._producer and r.transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
# If the producer is still producing, stop it.
if self._producer:
self._producer.stopProducing()
# Make a new producer and start it.
self._producer = LogProducer(
buffer=self._buffer,
transport=r.transport,
format_event=self.format_event,
)
r.transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
def _handle_pressure(self) -> None:
"""
Handle backpressure by shedding events.
The buffer will, in this order, until the buffer is below the maximum:
- Shed DEBUG events
- Shed INFO events
- Shed the middle 50% of the events.
"""
if len(self._buffer) <= self.maximum_buffer:
return
# Strip out DEBUGs
self._buffer = deque(
filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
return
# Strip out INFOs
self._buffer = deque(
filter(lambda event: event["log_level"] != LogLevel.info, self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
return
# Cut the middle entries out
buffer_split = floor(self.maximum_buffer / 2)
old_buffer = self._buffer
self._buffer = deque()
for i in range(buffer_split):
self._buffer.append(old_buffer.popleft())
end_buffer = []
for i in range(buffer_split):
end_buffer.append(old_buffer.pop())
self._buffer.extend(reversed(end_buffer))
def __call__(self, event: dict) -> None:
self._buffer.append(event)
# Handle backpressure, if it exists.
try:
self._handle_pressure()
except Exception:
# If handling backpressure fails,clear the buffer and log the
# exception.
self._buffer.clear()
self._logger.failure("Failed clearing backpressure")
# Try and write immediately.
self._connect()

View file

@ -18,26 +18,11 @@ Log formatters that output terse JSON.
"""
import json
import sys
import traceback
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
from typing import IO, Optional
from typing import IO
import attr
from zope.interface import implementer
from twisted.logger import FileLogObserver
from twisted.application.internet import ClientService
from twisted.internet.defer import Deferred
from twisted.internet.endpoints import (
HostnameEndpoint,
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
from twisted.internet.interfaces import IPushProducer, ITransport
from twisted.internet.protocol import Factory, Protocol
from twisted.logger import FileLogObserver, ILogObserver, Logger
from synapse.logging._remote import TCPLogObserver
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
@ -150,180 +135,22 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
return FileLogObserver(outFile, formatEvent)
@attr.s
@implementer(IPushProducer)
class LogProducer:
def TerseJSONToTCPLogObserver(
hs, host: str, port: int, metadata: dict, maximum_buffer: int
) -> FileLogObserver:
"""
An IPushProducer that writes logs from its buffer to its transport when it
is resumed.
Args:
buffer: Log buffer to read logs from.
transport: Transport to write to.
"""
transport = attr.ib(type=ITransport)
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
def pauseProducing(self):
self._paused = True
def stopProducing(self):
self._paused = True
self._buffer = deque()
def resumeProducing(self):
self._paused = False
while self._paused is False and (self._buffer and self.transport.connected):
try:
event = self._buffer.popleft()
self.transport.write(_encoder.encode(event).encode("utf8"))
self.transport.write(b"\n")
except Exception:
# Something has gone wrong writing to the transport -- log it
# and break out of the while.
traceback.print_exc(file=sys.__stderr__)
break
@attr.s
@implementer(ILogObserver)
class TerseJSONToTCPLogObserver:
"""
An IObserver that writes JSON logs to a TCP target.
A log observer that formats events to a flattened JSON representation.
Args:
hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
metadata: Metadata to be added to each log entry.
metadata: Metadata to be added to each log object.
maximum_buffer: The maximum buffer size.
"""
hs = attr.ib()
host = attr.ib(type=str)
port = attr.ib(type=int)
metadata = attr.ib(type=dict)
maximum_buffer = attr.ib(type=int)
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
_connection_waiter = attr.ib(default=None, type=Optional[Deferred])
_logger = attr.ib(default=attr.Factory(Logger))
_producer = attr.ib(default=None, type=Optional[LogProducer])
def formatEvent(_event: dict) -> str:
flattened = flatten_event(_event, metadata, include_time=True)
return _encoder.encode(flattened) + "\n"
def start(self) -> None:
# Connect without DNS lookups if it's a direct IP.
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
endpoint = TCP4ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
except ValueError:
endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
factory = Factory.forProtocol(Protocol)
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
self._service.startService()
self._connect()
def stop(self):
self._service.stopService()
def _connect(self) -> None:
"""
Triggers an attempt to connect then write to the remote if not already writing.
"""
if self._connection_waiter:
return
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
@self._connection_waiter.addErrback
def fail(r):
r.printTraceback(file=sys.__stderr__)
self._connection_waiter = None
self._connect()
@self._connection_waiter.addCallback
def writer(r):
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
if self._producer and r.transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
# If the producer is still producing, stop it.
if self._producer:
self._producer.stopProducing()
# Make a new producer and start it.
self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
r.transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
def _handle_pressure(self) -> None:
"""
Handle backpressure by shedding events.
The buffer will, in this order, until the buffer is below the maximum:
- Shed DEBUG events
- Shed INFO events
- Shed the middle 50% of the events.
"""
if len(self._buffer) <= self.maximum_buffer:
return
# Strip out DEBUGs
self._buffer = deque(
filter(lambda event: event["level"] != "DEBUG", self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
return
# Strip out INFOs
self._buffer = deque(
filter(lambda event: event["level"] != "INFO", self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
return
# Cut the middle entries out
buffer_split = floor(self.maximum_buffer / 2)
old_buffer = self._buffer
self._buffer = deque()
for i in range(buffer_split):
self._buffer.append(old_buffer.popleft())
end_buffer = []
for i in range(buffer_split):
end_buffer.append(old_buffer.pop())
self._buffer.extend(reversed(end_buffer))
def __call__(self, event: dict) -> None:
flattened = flatten_event(event, self.metadata, include_time=True)
self._buffer.append(flattened)
# Handle backpressure, if it exists.
try:
self._handle_pressure()
except Exception:
# If handling backpressure fails,clear the buffer and log the
# exception.
self._buffer.clear()
self._logger.failure("Failed clearing backpressure")
# Try and write immediately.
self._connect()
return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)

View file

@ -24,6 +24,7 @@ from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import start_active_span
if TYPE_CHECKING:
import resource
@ -197,14 +198,14 @@ def run_as_background_process(desc: str, func, *args, **kwargs):
with BackgroundProcessLoggingContext(desc) as context:
context.request = "%s-%i" % (desc, count)
try:
result = func(*args, **kwargs)
with start_active_span(desc, tags={"request_id": context.request}):
result = func(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
if inspect.isawaitable(result):
result = await result
return result
return result
except Exception:
logger.exception(
"Background process '%s' threw an exception", desc,

View file

@ -496,6 +496,6 @@ class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
# dedupe when we add callbacks to lru cache nodes, otherwise the number
# of callbacks would grow.
def __call__(self):
rules = self.cache.get(self.room_id, None, update_metrics=False)
rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()

View file

@ -387,8 +387,8 @@ class Mailer:
return ret
async def get_message_vars(self, notif, event, room_state_ids):
if event.type != EventTypes.Message:
return
if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
return None
sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
sender_state_event = await self.store.get_event(sender_state_event_id)
@ -399,10 +399,8 @@ class Mailer:
# sender_hash % the number of default images to choose from
sender_hash = string_ordinal_total(event.sender)
msgtype = event.content.get("msgtype")
ret = {
"msgtype": msgtype,
"event_type": event.type,
"is_historical": event.event_id != notif["event_id"],
"id": event.event_id,
"ts": event.origin_server_ts,
@ -411,6 +409,14 @@ class Mailer:
"sender_hash": sender_hash,
}
# Encrypted messages don't have any additional useful information.
if event.type == EventTypes.Encrypted:
return ret
msgtype = event.content.get("msgtype")
ret["msgtype"] = msgtype
if msgtype == "m.text":
self.add_text_message_vars(ret, event)
elif msgtype == "m.image":

View file

@ -16,11 +16,10 @@
import logging
import re
from typing import Any, Dict, List, Optional, Pattern, Union
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from synapse.events import EventBase
from synapse.types import UserID
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@ -174,20 +173,21 @@ class PushRuleEvaluatorForEvent:
# Similar to _glob_matches, but do not treat display_name as a glob.
r = regex_cache.get((display_name, False, True), None)
if not r:
r = re.escape(display_name)
r = _re_word_boundary(r)
r = re.compile(r, flags=re.IGNORECASE)
r1 = re.escape(display_name)
r1 = _re_word_boundary(r1)
r = re.compile(r1, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r
return r.search(body)
return bool(r.search(body))
def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None)
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000)
register_cache("cache", "regex_push_cache", regex_cache)
regex_cache = LruCache(
50000, "regex_push_cache"
) # type: LruCache[Tuple[str, bool, bool], Pattern]
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
@ -205,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
if not r:
r = _glob_to_re(glob, word_boundary)
regex_cache[(glob, True, word_boundary)] = r
return r.search(value)
return bool(r.search(value))
except re.error:
logger.warning("Failed to parse glob to regex: %r", glob)
return False

View file

@ -15,7 +15,7 @@
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.lrucache import LruCache
from ._base import BaseSlavedStore
@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000
) # type: DeferredCache[tuple, int]
self.client_ip_last_seen = LruCache(
cache_name="client_ip_last_seen", keylen=4, max_size=50000
) # type: LruCache[tuple, int]
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
@ -41,7 +41,7 @@ class SlavedClientIpStore(BaseSlavedStore):
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
self.client_ip_last_seen.prefill(key, now)
self.client_ip_last_seen.set(key, now)
self.hs.get_tcp_replication().send_user_ip(
user_id, access_token, ip, user_agent, device_id, now

View file

@ -1,41 +1,47 @@
{% for message in notif.messages %}
{%- for message in notif.messages %}
<tr class="{{ "historical_message" if message.is_historical else "message" }}">
<td class="sender_avatar">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
{% if message.sender_avatar_url %}
{%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
{%- if message.sender_avatar_url %}
<img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
{% else %}
{% if message.sender_hash % 3 == 0 %}
{%- else %}
{%- if message.sender_hash % 3 == 0 %}
<img class="sender_avatar" src="https://riot.im/img/external/avatar-1.png" />
{% elif message.sender_hash % 3 == 1 %}
{%- elif message.sender_hash % 3 == 1 %}
<img class="sender_avatar" src="https://riot.im/img/external/avatar-2.png" />
{% else %}
{%- else %}
<img class="sender_avatar" src="https://riot.im/img/external/avatar-3.png" />
{% endif %}
{% endif %}
{% endif %}
{%- endif %}
{%- endif %}
{%- endif %}
</td>
<td class="message_contents">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
<div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
{% endif %}
{%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
<div class="sender_name">{%- if message.msgtype == "m.emote" %}*{%- endif %} {{ message.sender_name }}</div>
{%- endif %}
<div class="message_body">
{% if message.msgtype == "m.text" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.emote" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.notice" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.image" %}
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
{% elif message.msgtype == "m.file" %}
<span class="filename">{{ message.body_text_plain }}</span>
{% endif %}
{%- if message.event_type == "m.room.encrypted" %}
An encrypted message.
{%- elif message.event_type == "m.room.message" %}
{%- if message.msgtype == "m.text" %}
{{ message.body_text_html }}
{%- elif message.msgtype == "m.emote" %}
{{ message.body_text_html }}
{%- elif message.msgtype == "m.notice" %}
{{ message.body_text_html }}
{%- elif message.msgtype == "m.image" %}
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
{%- elif message.msgtype == "m.file" %}
<span class="filename">{{ message.body_text_plain }}</span>
{%- else %}
A message with unrecognised content.
{%- endif %}
{%- endif %}
</div>
</td>
<td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
</tr>
{% endfor %}
{%- endfor %}
<tr class="notif_link">
<td></td>
<td>

View file

@ -1,16 +1,22 @@
{% for message in notif.messages %}
{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
{% if message.msgtype == "m.text" %}
{%- for message in notif.messages %}
{%- if message.event_type == "m.room.encrypted" %}
An encrypted message.
{%- elif message.event_type == "m.room.message" %}
{%- if message.msgtype == "m.emote" %}* {%- endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
{%- if message.msgtype == "m.text" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.emote" %}
{%- elif message.msgtype == "m.emote" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.notice" %}
{%- elif message.msgtype == "m.notice" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.image" %}
{%- elif message.msgtype == "m.image" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.file" %}
{%- elif message.msgtype == "m.file" %}
{{ message.body_text_plain }}
{% endif %}
{% endfor %}
{%- else %}
A message with unrecognised content.
{%- endif %}
{%- endif %}
{%- endfor %}
View {{ room.title }} at {{ notif.link }}

View file

@ -2,8 +2,8 @@
<html lang="en">
<head>
<style type="text/css">
{% include 'mail.css' without context %}
{% include "mail-%s.css" % app_name ignore missing without context %}
{%- include 'mail.css' without context %}
{%- include "mail-%s.css" % app_name ignore missing without context %}
</style>
</head>
<body>
@ -18,21 +18,21 @@
<div class="summarytext">{{ summary_text }}</div>
</td>
<td class="logo">
{% if app_name == "Riot" %}
{%- if app_name == "Riot" %}
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
{% elif app_name == "Vector" %}
{%- elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
{% elif app_name == "Element" %}
{%- elif app_name == "Element" %}
<img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
{% else %}
{%- else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %}
{%- endif %}
</td>
</tr>
</table>
{% for room in rooms %}
{% include 'room.html' with context %}
{% endfor %}
{%- for room in rooms %}
{%- include 'room.html' with context %}
{%- endfor %}
<div class="footer">
<a href="{{ unsubscribe_link }}">Unsubscribe</a>
<br/>
@ -41,12 +41,12 @@
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
an event was received at {{ reason.received_at|format_ts("%c") }}
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
{% if reason.last_sent_ts %}
{%- if reason.last_sent_ts %}
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
{% else %}
{%- else %}
and we don't have a last time we sent a mail for this room.
{% endif %}
{%- endif %}
</div>
</div>
</td>

View file

@ -2,9 +2,9 @@ Hi {{ user_display_name }},
{{ summary_text }}
{% for room in rooms %}
{% include 'room.txt' with context %}
{% endfor %}
{%- for room in rooms %}
{%- include 'room.txt' with context %}
{%- endfor %}
You can disable these notifications at {{ unsubscribe_link }}

View file

@ -1,23 +1,23 @@
<table class="room">
<tr class="room_header">
<td class="room_avatar">
{% if room.avatar_url %}
{%- if room.avatar_url %}
<img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
{% else %}
{% if room.hash % 3 == 0 %}
{%- else %}
{%- if room.hash % 3 == 0 %}
<img alt="" src="https://riot.im/img/external/avatar-1.png" />
{% elif room.hash % 3 == 1 %}
{%- elif room.hash % 3 == 1 %}
<img alt="" src="https://riot.im/img/external/avatar-2.png" />
{% else %}
{%- else %}
<img alt="" src="https://riot.im/img/external/avatar-3.png" />
{% endif %}
{% endif %}
{%- endif %}
{%- endif %}
</td>
<td class="room_name" colspan="2">
{{ room.title }}
</td>
</tr>
{% if room.invite %}
{%- if room.invite %}
<tr>
<td></td>
<td>
@ -25,9 +25,9 @@
</td>
<td></td>
</tr>
{% else %}
{% for notif in room.notifs %}
{% include 'notif.html' with context %}
{% endfor %}
{% endif %}
{%- else %}
{%- for notif in room.notifs %}
{%- include 'notif.html' with context %}
{%- endfor %}
{%- endif %}
</table>

View file

@ -1,9 +1,9 @@
{{ room.title }}
{% if room.invite %}
{%- if room.invite %}
You've been invited, join at {{ room.link }}
{% else %}
{% for notif in room.notifs %}
{% include 'notif.txt' with context %}
{% endfor %}
{% endif %}
{%- else %}
{%- for notif in room.notifs %}
{%- include 'notif.txt' with context %}
{%- endfor %}
{%- endif %}

View file

@ -110,6 +110,8 @@ class LoginRestServlet(RestServlet):
({"type": t} for t in self.auth_handler.get_supported_login_types())
)
flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
return 200, {"flows": flows}
def on_OPTIONS(self, request: SynapseRequest):

View file

@ -76,14 +76,16 @@ class SQLBaseStore(metaclass=ABCMeta):
"""
try:
if key is None:
getattr(self, cache_name).invalidate_all()
else:
getattr(self, cache_name).invalidate(tuple(key))
cache = getattr(self, cache_name)
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
pass
return
if key is None:
cache.invalidate_all()
else:
cache.invalidate(tuple(key))
def db_to_json(db_content):

View file

@ -160,7 +160,7 @@ class LoggingDatabaseConnection:
self.conn.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback) -> bool:
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
return self.conn.__exit__(exc_type, exc_value, traceback)
# Proxy through any unknown lookups to the DB conn class.

View file

@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@ -410,8 +410,8 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000
self.client_ip_last_seen = LruCache(
cache_name="client_ip_last_seen", keylen=4, max_size=50000
)
super().__init__(database, db_conn, hs)
@ -442,7 +442,7 @@ class ClientIpStore(ClientIpWorkerStore):
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
self.client_ip_last_seen.prefill(key, now)
self.client_ip_last_seen.set(key, now)
self._batch_row_update[key] = (user_agent, device_id, now)

View file

@ -34,8 +34,8 @@ from synapse.storage.database import (
)
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@ -1005,8 +1005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = DeferredCache(
name="device_id_exists", keylen=2, max_entries=10000
self.device_id_exists_cache = LruCache(
cache_name="device_id_exists", keylen=2, max_size=10000
)
async def store_device(
@ -1052,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
self.device_id_exists_cache.prefill(key, True)
self.device_id_exists_cache.set(key, True)
return inserted
except StoreError:
raise

View file

@ -1051,9 +1051,7 @@ class PersistEventsStore:
def prefill():
for cache_entry in to_prefill:
self.store._get_event_cache.prefill(
(cache_entry[0].event_id,), cache_entry
)
self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
txn.call_after(prefill)

View file

@ -33,7 +33,10 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
@ -42,8 +45,8 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@ -137,20 +140,16 @@ class EventsWorkerStore(SQLBaseStore):
db_conn, "events", "stream_ordering", step=-1
)
if not hs.config.worker.worker_app:
if hs.config.run_background_tasks:
# We periodically clean out old transaction ID mappings
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
"_cleanup_old_transaction_ids",
self._cleanup_old_transaction_ids,
self._cleanup_old_transaction_ids, 5 * 60 * 1000,
)
self._get_event_cache = DeferredCache(
"*getEvent*",
self._get_event_cache = LruCache(
cache_name="*getEvent*",
keylen=3,
max_entries=hs.config.caches.event_cache_size,
apply_cache_factor_from_config=False,
max_size=hs.config.caches.event_cache_size,
)
self._event_fetch_lock = threading.Condition()
@ -749,7 +748,7 @@ class EventsWorkerStore(SQLBaseStore):
event=original_ev, redacted_event=redacted_event
)
self._get_event_cache.prefill((event_id,), cache_entry)
self._get_event_cache.set((event_id,), cache_entry)
result_map[event_id] = cache_entry
return result_map
@ -1375,6 +1374,7 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
async def _cleanup_old_transaction_ids(self):
"""Cleans out transaction id mappings older than 24hrs.
"""

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profile_displayname(
self, user_localpart: str, new_displayname: str
self, user_localpart: str, new_displayname: Optional[str]
) -> None:
await self.db_pool.simple_update_one(
table="profiles",
@ -144,7 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> Dict[str, str]:
) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`
"""

View file

@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore):
lock=False,
)
user_has_pusher = self.get_if_user_has_pusher.cache.get(
user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
(user_id,), None, update_metrics=False
)

View file

@ -25,7 +25,6 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -413,18 +412,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
if receipt_type != "m.read":
return
# Returns either an ObservableDeferred or the raw result
res = self.get_users_with_read_receipts_in_room.cache.get(
res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
room_id, None, update_metrics=False
)
# first handle the ObservableDeferred case
if isinstance(res, ObservableDeferred):
if res.has_called():
res = res.get_result()
else:
res = None
if res and user_id in res:
# We'd only be adding to the set, so no point invalidating if the
# user is already there

View file

@ -20,7 +20,10 @@ from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
@ -67,16 +70,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
run_as_background_process,
60 * 1000,
"_count_known_servers",
self._count_known_servers,
self._count_known_servers, 60 * 1000,
)
self.hs.get_clock().call_later(
1000,
run_as_background_process,
"_count_known_servers",
self._count_known_servers,
1000, self._count_known_servers,
)
LaterGauge(
"synapse_federation_known_servers",
@ -85,6 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
@wrap_as_background_process("_count_known_servers")
async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@ -531,7 +529,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# If we do then we can reuse that result and simply update it with
# any membership changes in `delta_ids`
if context.prev_group and context.delta_ids:
prev_res = self._get_joined_users_from_context.cache.get(
prev_res = self._get_joined_users_from_context.cache.get_immediate(
(room_id, context.prev_group), None
)
if prev_res and isinstance(prev_res, dict):

View file

@ -13,6 +13,5 @@
* limitations under the License.
*/
ALTER TABLE application_services_state
ADD COLUMN read_receipt_stream_id INT,
ADD COLUMN presence_stream_id INT;
ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT;
ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT;

View file

@ -0,0 +1 @@
DROP TABLE device_max_stream_id;

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, Iterator, List, Tuple
from typing import Any, Iterable, Iterator, List, Optional, Tuple
from typing_extensions import Protocol
@ -65,5 +65,5 @@ class Connection(Protocol):
def __enter__(self) -> "Connection":
...
def __exit__(self, exc_type, exc_value, traceback) -> bool:
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
...

View file

@ -16,7 +16,7 @@
import logging
from sys import intern
from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Sized
import attr
from prometheus_client.core import Gauge
@ -92,7 +92,7 @@ class CacheMetric:
def register_cache(
cache_type: str,
cache_name: str,
cache,
cache: Sized,
collect_callback: Optional[Callable] = None,
resizable: bool = True,
resize_callback: Optional[Callable] = None,
@ -100,12 +100,15 @@ def register_cache(
"""Register a cache object for metric collection and resizing.
Args:
cache_type
cache_type: a string indicating the "type" of the cache. This is used
only for deduplication so isn't too important provided it's constant.
cache_name: name of the cache
cache: cache itself
cache: cache itself, which must implement __len__(), and may optionally implement
a max_size property
collect_callback: If given, a function which is called during metric
collection to update additional metrics.
resizable: Whether this cache supports being resized.
resizable: Whether this cache supports being resized, in which case either
resize_callback must be provided, or the cache must support set_max_size().
resize_callback: A function which can be called to resize the cache.
Returns:

View file

@ -17,14 +17,23 @@
import enum
import threading
from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast
from typing import (
Callable,
Generic,
Iterable,
MutableMapping,
Optional,
TypeVar,
Union,
cast,
)
from prometheus_client import Gauge
from twisted.internet import defer
from twisted.python import failure
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
@ -34,7 +43,7 @@ cache_pending_metric = Gauge(
["name"],
)
T = TypeVar("T")
KT = TypeVar("KT")
VT = TypeVar("VT")
@ -49,15 +58,12 @@ class DeferredCache(Generic[KT, VT]):
"""Wraps an LruCache, adding support for Deferred results.
It expects that each entry added with set() will be a Deferred; likewise get()
may return an ObservableDeferred.
will return a Deferred.
"""
__slots__ = (
"cache",
"name",
"keylen",
"thread",
"metrics",
"_pending_deferred_cache",
)
@ -89,37 +95,27 @@ class DeferredCache(Generic[KT, VT]):
cache_type()
) # type: MutableMapping[KT, CacheEntry]
def metrics_cb():
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
# cache is used for completed results and maps to the result itself, rather than
# a Deferred.
self.cache = LruCache(
max_size=max_entries,
keylen=keylen,
cache_name=name,
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config,
)
) # type: LruCache[KT, VT]
self.name = name
self.keylen = keylen
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
self.cache,
collect_callback=self._metrics_collection_callback,
)
@property
def max_entries(self):
return self.cache.max_size
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
def _metrics_collection_callback(self):
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@ -133,62 +129,113 @@ class DeferredCache(Generic[KT, VT]):
def get(
self,
key: KT,
default=_Sentinel.sentinel,
callback: Optional[Callable[[], None]] = None,
update_metrics: bool = True,
):
) -> defer.Deferred:
"""Looks the key up in the caches.
For symmetry with set(), this method does *not* follow the synapse logcontext
rules: the logcontext will not be cleared on return, and the Deferred will run
its callbacks in the sentinel context. In other words: wrap the result with
make_deferred_yieldable() before `await`ing it.
Args:
key(tuple)
default: What is returned if key is not in the caches. If not
specified then function throws KeyError instead
callback(fn): Gets called when the entry in the cache is invalidated
key:
callback: Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either an ObservableDeferred or the result itself
A Deferred which completes with the result. Note that this may later fail
if there is an ongoing set() operation which later completes with a failure.
Raises:
KeyError if the key is not found in the cache
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
if update_metrics:
self.metrics.inc_hits()
return val.deferred
m = self.cache.metrics
assert m # we always have a name, so should always have metrics
m.inc_hits()
return val.deferred.observe()
val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
if val is not _Sentinel.sentinel:
self.metrics.inc_hits()
return val
if update_metrics:
self.metrics.inc_misses()
if default is _Sentinel.sentinel:
val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
)
if val2 is _Sentinel.sentinel:
raise KeyError()
else:
return default
return defer.succeed(val2)
def get_immediate(
self, key: KT, default: T, update_metrics: bool = True
) -> Union[VT, T]:
"""If we have a *completed* cached value, return it."""
return self.cache.get(key, default, update_metrics=update_metrics)
def set(
self,
key: KT,
value: defer.Deferred,
callback: Optional[Callable[[], None]] = None,
) -> ObservableDeferred:
) -> defer.Deferred:
"""Adds a new entry to the cache (or updates an existing one).
The given `value` *must* be a Deferred.
First any existing entry for the same key is invalidated. Then a new entry
is added to the cache for the given key.
Until the `value` completes, calls to `get()` for the key will also result in an
incomplete Deferred, which will ultimately complete with the same result as
`value`.
If `value` completes successfully, subsequent calls to `get()` will then return
a completed deferred with the same result. If it *fails*, the cache is
invalidated and subequent calls to `get()` will raise a KeyError.
If another call to `set()` happens before `value` completes, then (a) any
invalidation callbacks registered in the interim will be called, (b) any
`get()`s in the interim will continue to complete with the result from the
*original* `value`, (c) any future calls to `get()` will complete with the
result from the *new* `value`.
It is expected that `value` does *not* follow the synapse logcontext rules - ie,
if it is incomplete, it runs its callbacks in the sentinel context.
Args:
key: Key to be set
value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated
"""
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
# XXX: why don't we invalidate the entry in `self.cache` yet?
# we can save a whole load of effort if the deferred is ready.
if value.called:
result = value.result
if not isinstance(result, failure.Failure):
self.cache.set(key, result, callbacks)
return value
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later.
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
self._pending_deferred_cache[key] = entry
def compare_and_pop():
@ -232,7 +279,9 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
# we return a new Deferred which will be called before any subsequent observers.
return observable.observe()
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
callbacks = [callback] if callback else []
@ -257,11 +306,12 @@ class DeferredCache(Generic[KT, VT]):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
key = cast(KT, key)
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()

View file

@ -23,7 +23,6 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__)
@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase):
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
) # type: DeferredCache[Tuple, Any]
) # type: DeferredCache[CacheKey, Any]
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
@ -202,32 +201,20 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_key = get_cache_key(args, kwargs)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
try:
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
if isinstance(cached_result_d, ObservableDeferred):
observer = cached_result_d.observe()
else:
observer = defer.succeed(cached_result_d)
ret = cache.get(cache_key, callback=invalidate_callback)
except KeyError:
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext.get_instance(
cache, cache_key
)
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
ret = cache.set(cache_key, ret, callback=invalidate_callback)
def onErr(f):
cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
return make_deferred_yieldable(observer)
return make_deferred_yieldable(ret)
wrapped = cast(_CachedFunction, _wrapped)
@ -286,7 +273,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
def __get__(self, obj, objtype=None):
cached_method = getattr(obj, self.cached_method_name)
cache = cached_method.cache
cache = cached_method.cache # type: DeferredCache[CacheKey, Any]
num_args = cached_method.num_args
@functools.wraps(self.orig)
@ -326,14 +313,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
for arg in list_args:
try:
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
res = res.observe()
if not res.called:
res.addCallback(update_results_dict, arg)
cached_defers.append(res)
else:
results[arg] = res.get_result()
results[arg] = res.result
except KeyError:
missing.add(arg)

View file

@ -12,15 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
import threading
from collections import namedtuple
from typing import Any
from synapse.util.caches.lrucache import LruCache
from . import register_cache
logger = logging.getLogger(__name__)
@ -40,24 +39,25 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
return len(self.value)
class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()
class DictionaryCache:
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries, size_callback=len)
self.cache = LruCache(
max_size=max_entries, cache_name=name, size_callback=len
) # type: LruCache[Any, DictionaryEntry]
self.name = name
self.sequence = 0
self.thread = None
# caches_by_name[name] = self.cache
class Sentinel:
__slots__ = []
self.sentinel = Sentinel()
self.metrics = register_cache("dictionary", name, self.cache)
def check_thread(self):
expected_thread = self.thread
@ -80,10 +80,8 @@ class DictionaryCache:
Returns:
DictionaryEntry
"""
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
self.metrics.inc_hits()
entry = self.cache.get(key, _Sentinel.sentinel)
if entry is not _Sentinel.sentinel:
if dict_keys is None:
return DictionaryEntry(
entry.full, entry.known_absent, dict(entry.value)
@ -95,7 +93,6 @@ class DictionaryCache:
{k: entry.value[k] for k in dict_keys if k in entry.value},
)
self.metrics.inc_misses()
return DictionaryEntry(False, set(), {})
def invalidate(self, key):

View file

@ -15,11 +15,35 @@
import threading
from functools import wraps
from typing import Callable, Optional, Type, Union
from typing import (
Any,
Callable,
Generic,
Iterable,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import Literal
from synapse.config import cache as cache_config
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache
# Function type: the type used for invalidation callbacks
FT = TypeVar("FT", bound=Callable[..., Any])
# Key and Value type for the cache
KT = TypeVar("KT")
VT = TypeVar("VT")
# a general type var, distinct from either KT or VT
T = TypeVar("T")
def enumerate_leaves(node, depth):
if depth == 0:
@ -41,29 +65,31 @@ class _Node:
self.callbacks = callbacks
class LruCache:
class LruCache(Generic[KT, VT]):
"""
Least-recently-used cache.
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
def __init__(
self,
max_size: int,
cache_name: Optional[str] = None,
keylen: int = 1,
cache_type: Type[Union[dict, TreeCache]] = dict,
size_callback: Optional[Callable] = None,
evicted_callback: Optional[Callable] = None,
metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True,
):
"""
Args:
max_size: The maximum amount of entries the cache can hold
cache_name: The name of this cache, for the prometheus metrics. If unset,
no metrics will be reported on this cache.
keylen: The length of the tuple used as the cache key. Ignored unless
cache_type is `TreeCache`.
@ -73,9 +99,13 @@ class LruCache:
size_callback (func(V) -> int | None):
evicted_callback (func(int)|None):
if not None, called on eviction with the size of the evicted
entry
metrics_collection_callback:
metrics collection callback. This is called early in the metrics
collection process, before any of the metrics registered with the
prometheus Registry are collected, so can be used to update any dynamic
metrics.
Ignored if cache_name is None.
apply_cache_factor_from_config (bool): If true, `max_size` will be
multiplied by a cache factor derived from the homeserver config
@ -94,6 +124,23 @@ class LruCache:
else:
self.max_size = int(max_size)
# register_cache might call our "set_cache_factor" callback; there's nothing to
# do yet when we get resized.
self._on_resize = None # type: Optional[Callable[[],None]]
if cache_name is not None:
metrics = register_cache(
"lru_cache",
cache_name,
self,
collect_callback=metrics_collection_callback,
) # type: Optional[CacheMetric]
else:
metrics = None
# this is exposed for access from outside this class
self.metrics = metrics
list_root = _Node(None, None, None, None)
list_root.next_node = list_root
list_root.prev_node = list_root
@ -105,16 +152,16 @@ class LruCache:
todelete = list_root.prev_node
evicted_len = delete_node(todelete)
cache.pop(todelete.key, None)
if evicted_callback:
evicted_callback(evicted_len)
if metrics:
metrics.inc_evictions(evicted_len)
def synchronized(f):
def synchronized(f: FT) -> FT:
@wraps(f)
def inner(*args, **kwargs):
with lock:
return f(*args, **kwargs)
return inner
return cast(FT, inner)
cached_cache_len = [0]
if size_callback is not None:
@ -168,18 +215,45 @@ class LruCache:
node.callbacks.clear()
return deleted_len
@overload
def cache_get(
key: KT,
default: Literal[None] = None,
callbacks: Iterable[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Optional[VT]:
...
@overload
def cache_get(
key: KT,
default: T,
callbacks: Iterable[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Union[T, VT]:
...
@synchronized
def cache_get(key, default=None, callbacks=[]):
def cache_get(
key: KT,
default: Optional[T] = None,
callbacks: Iterable[Callable[[], None]] = [],
update_metrics: bool = True,
):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
node.callbacks.update(callbacks)
if update_metrics and metrics:
metrics.inc_hits()
return node.value
else:
if update_metrics and metrics:
metrics.inc_misses()
return default
@synchronized
def cache_set(key, value, callbacks=[]):
def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
@ -208,7 +282,7 @@ class LruCache:
evict()
@synchronized
def cache_set_default(key, value):
def cache_set_default(key: KT, value: VT) -> VT:
node = cache.get(key, None)
if node is not None:
return node.value
@ -217,8 +291,16 @@ class LruCache:
evict()
return value
@overload
def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
...
@overload
def cache_pop(key: KT, default: T) -> Union[T, VT]:
...
@synchronized
def cache_pop(key, default=None):
def cache_pop(key: KT, default: Optional[T] = None):
node = cache.get(key, None)
if node:
delete_node(node)
@ -228,18 +310,18 @@ class LruCache:
return default
@synchronized
def cache_del_multi(key):
def cache_del_multi(key: KT) -> None:
"""
This will only work if constructed with cache_type=TreeCache
"""
popped = cache.pop(key)
if popped is None:
return
for leaf in enumerate_leaves(popped, keylen - len(key)):
for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
delete_node(leaf)
@synchronized
def cache_clear():
def cache_clear() -> None:
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
@ -250,15 +332,21 @@ class LruCache:
cached_cache_len[0] = 0
@synchronized
def cache_contains(key):
def cache_contains(key: KT) -> bool:
return key in cache
self.sentinel = object()
# make sure that we clear out any excess entries after we get resized.
self._on_resize = evict
self.get = cache_get
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
# `invalidate` is exposed for consistency with DeferredCache, so that it can be
# invalidated by the cache invalidation replication stream.
self.invalidate = cache_pop
if cache_type is TreeCache:
self.del_multi = cache_del_multi
self.len = synchronized(cache_len)
@ -302,6 +390,7 @@ class LruCache:
new_size = int(self._original_max_size * factor)
if new_size != self.max_size:
self.max_size = new_size
self._on_resize()
if self._on_resize:
self._on_resize()
return True
return False

View file

@ -15,7 +15,10 @@
import sys
from twisted.internet import epollreactor
try:
from twisted.internet.epollreactor import EPollReactor as Reactor
except ImportError:
from twisted.internet.pollreactor import PollReactor as Reactor
from twisted.internet.main import installReactor
from synapse.config.homeserver import HomeServerConfig
@ -41,7 +44,7 @@ async def make_homeserver(reactor, config=None):
config_obj = HomeServerConfig()
config_obj.parse_config_dict(config, "", "")
hs = await setup_test_homeserver(
hs = setup_test_homeserver(
cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock
)
stor = hs.get_datastore()
@ -63,7 +66,7 @@ def make_reactor():
Instantiate and install a Twisted reactor suitable for testing (i.e. not the
default global one).
"""
reactor = epollreactor.EPollReactor()
reactor = Reactor()
if "twisted.internet.reactor" in sys.modules:
del sys.modules["twisted.internet.reactor"]

View file

@ -260,6 +260,31 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self):
srv_1_defer = defer.Deferred()
srv_2_defer = defer.Deferred()
send_return_list = [srv_1_defer, srv_2_defer]
def do_send(x, y, z):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
service = Mock(id=4, name="service")
event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
for event in event_list:
self.queuer.enqueue_event(service, event)
# Expect the first event to be sent immediately.
self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
srv_1_defer.callback(service)
# Then send the next 100 events
self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
srv_2_defer.callback(service)
# Then the final 99 events
self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_single_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
@ -296,3 +321,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
# Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_send_large_txns_ephemeral(self):
d = defer.Deferred()
self.txn_ctrl.send = Mock(
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
event_list = first_chunk + second_chunk
self.queuer.enqueue_ephemeral(service, event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
self.assertEquals(2, self.txn_ctrl.send.call_count)

View file

@ -78,7 +78,7 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
"server_name",
"name",
]
self.assertEqual(set(log.keys()), set(expected_log_keys))
self.assertCountEqual(log.keys(), expected_log_keys)
# It contains the data we expect.
self.assertEqual(log["name"], "wally")

View file

@ -158,8 +158,21 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
def test_encrypted_message(self):
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.helper.invite(
room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
)
self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
# The other user sends some messages
self.helper.send_event(room, "m.room.encrypted", {}, tok=self.others[0].token)
# We should get emailed about that message
self._check_for_mail()
def _check_for_mail(self):
"Check that the user receives an email notification"
"""Check that the user receives an email notification"""
# Get the stream ordering before it gets sent
pushers = self.get_success(

View file

@ -352,7 +352,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(request.code, 401)
@unittest.INFO
def test_pending_invites(self):
"""Tests that deactivating a user rejects every pending invite for them."""
store = self.hs.get_datastore()

View file

@ -104,7 +104,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a")
@unittest.INFO
def test_fallback_captcha(self):
"""Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec

View file

@ -15,237 +15,9 @@
# limitations under the License.
from mock import Mock
from twisted.internet import defer
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached
from tests import unittest
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):
class A:
@cached()
def func(self, key):
return key
a = A()
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals((yield a.func("bar")), "bar")
@defer.inlineCallbacks
def test_hit(self):
callcount = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks
def test_invalidate(self):
callcount = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
class A:
@cached()
def func(self, key):
return key
A().func.invalidate(("what",))
@defer.inlineCallbacks
def test_max_entries(self):
callcount = [0]
class A:
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
return key
a = A()
for k in range(0, 12):
yield a.func(k)
self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
for k in range(0, 12):
yield a.func(k)
self.assertTrue(
callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
)
def test_prefill(self):
callcount = [0]
d = defer.succeed(123)
class A:
@cached()
def func(self, key):
callcount[0] += 1
return d
a = A()
a.func.prefill(("foo",), ObservableDeferred(d))
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
def test_invalidate_context(self):
callcount = [0]
callcount2 = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 1)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
@defer.inlineCallbacks
def test_eviction_context(self):
callcount = [0]
callcount2 = [0]
class A:
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
yield a.func2("foo2")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 4)
self.assertEquals(callcount2[0], 3)
@defer.inlineCallbacks
def test_double_get(self):
callcount = [0]
callcount2 = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
yield a.func2("foo")
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 2)
a.func.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 3)
class UpsertManyTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.storage = hs.get_datastore()

View file

@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from twisted.internet import defer
from synapse.util.caches.deferred_cache import DeferredCache
from tests.unittest import TestCase
class DeferredCacheTestCase(unittest.TestCase):
class DeferredCacheTestCase(TestCase):
def test_empty(self):
cache = DeferredCache("test")
failed = False
@ -36,7 +37,118 @@ class DeferredCacheTestCase(unittest.TestCase):
cache = DeferredCache("test")
cache.prefill("foo", 123)
self.assertEquals(cache.get("foo"), 123)
self.assertEquals(self.successResultOf(cache.get("foo")), 123)
def test_hit_deferred(self):
cache = DeferredCache("test")
origin_d = defer.Deferred()
set_d = cache.set("k1", origin_d)
# get should return an incomplete deferred
get_d = cache.get("k1")
self.assertFalse(get_d.called)
# add a callback that will make sure that the set_d gets called before the get_d
def check1(r):
self.assertTrue(set_d.called)
return r
# TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
# maybe we should fix that?
# get_d.addCallback(check1)
# now fire off all the deferreds
origin_d.callback(99)
self.assertEqual(self.successResultOf(origin_d), 99)
self.assertEqual(self.successResultOf(set_d), 99)
self.assertEqual(self.successResultOf(get_d), 99)
def test_callbacks(self):
"""Invalidation callbacks are called at the right time"""
cache = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
origin_d = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
# we don't expect the invalidation callback for the original value to have
# been called yet, even though get() will now return a different result.
# I'm not sure if that is by design or not.
self.assertEqual(callbacks, set())
# now fire off all the deferreds
origin_d.callback(20)
self.assertEqual(self.successResultOf(set_d), 20)
self.assertEqual(self.successResultOf(get_d), 20)
# now the original invalidation callback should have been called, but none of
# the others
self.assertEqual(callbacks, {"prefill"})
callbacks.clear()
# another update should invalidate both the previous results
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"set", "get"})
def test_set_fail(self):
cache = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
origin_d = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
# none of the callbacks should have been called yet
self.assertEqual(callbacks, set())
# oh noes! fails!
e = Exception("oops")
origin_d.errback(e)
self.assertIs(self.failureResultOf(set_d, Exception).value, e)
self.assertIs(self.failureResultOf(get_d, Exception).value, e)
# the callbacks for the failed requests should have been called.
# I'm not sure if this is deliberate or not.
self.assertEqual(callbacks, {"get", "set"})
callbacks.clear()
# the old value should still be returned now?
get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2"))
self.assertEqual(self.successResultOf(get_d2), 10)
# replacing the value now should run the callbacks for those requests
# which got the original result
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"prefill", "get2"})
def test_get_immediate(self):
cache = DeferredCache("test")
d1 = defer.Deferred()
cache.set("key1", d1)
# get_immediate should return default
v = cache.get_immediate("key1", 1)
self.assertEqual(v, 1)
# now complete the set
d1.callback(2)
# get_immediate should return result
v = cache.get_immediate("key1", 1)
self.assertEqual(v, 2)
def test_invalidate(self):
cache = DeferredCache("test")
@ -66,23 +178,24 @@ class DeferredCacheTestCase(unittest.TestCase):
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())
# lookup should return pending deferreds
self.assertFalse(cache.get("key1").called)
self.assertFalse(cache.get("key2").called)
# let one of the lookups complete
d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now the cache will return a completed deferred
self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
# now do the invalidation
cache.invalidate_all()
# lookup should return none
self.assertIsNone(cache.get("key1", None))
self.assertIsNone(cache.get("key2", None))
# lookup should fail
with self.assertRaises(KeyError):
cache.get("key1")
with self.assertRaises(KeyError):
cache.get("key2")
# both callbacks should have been callbacked
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
@ -90,7 +203,8 @@ class DeferredCacheTestCase(unittest.TestCase):
# letting the other lookup complete should do nothing
d1.callback("result1")
self.assertIsNone(cache.get("key1", None))
with self.assertRaises(KeyError):
cache.get("key1", None)
def test_eviction(self):
cache = DeferredCache(

View file

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Set
import mock
@ -130,6 +131,57 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
def test_cache_with_async_exception(self):
"""The wrapped function returns a failure
"""
class Cls:
result = None
call_count = 0
@cached()
def fn(self, arg1):
self.call_count += 1
return self.result
obj = Cls()
callbacks = set() # type: Set[str]
# set off an asynchronous request
obj.result = origin_d = defer.Deferred()
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
self.assertFalse(d1.called)
# a second request should also return a deferred, but should not call the
# function itself.
d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
self.assertFalse(d2.called)
self.assertEqual(obj.call_count, 1)
# no callbacks yet
self.assertEqual(callbacks, set())
# the original request fails
e = Exception("bzz")
origin_d.errback(e)
# ... which should cause the lookups to fail similarly
self.assertIs(self.failureResultOf(d1, Exception).value, e)
self.assertIs(self.failureResultOf(d2, Exception).value, e)
# ... and the callbacks to have been, uh, called.
self.assertEqual(callbacks, {"d1", "d2"})
# ... leaving the cache empty
self.assertEqual(len(obj.fn.cache.cache), 0)
# and a second call should work as normal
obj.result = defer.succeed(100)
d3 = obj.fn(1)
self.assertEqual(self.successResultOf(d3), 100)
self.assertEqual(obj.call_count, 2)
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
@ -311,6 +363,235 @@ class DescriptorTestCase(unittest.TestCase):
self.failureResultOf(d, SynapseError)
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached
The following is a set of tests that got lost in a different file for a while.
There are probably duplicates of the tests in DescriptorTestCase. Ideally the
duplicates would be removed and the two sets of classes combined.
"""
@defer.inlineCallbacks
def test_passthrough(self):
class A:
@cached()
def func(self, key):
return key
a = A()
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals((yield a.func("bar")), "bar")
@defer.inlineCallbacks
def test_hit(self):
callcount = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks
def test_invalidate(self):
callcount = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
class A:
@cached()
def func(self, key):
return key
A().func.invalidate(("what",))
@defer.inlineCallbacks
def test_max_entries(self):
callcount = [0]
class A:
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
return key
a = A()
for k in range(0, 12):
yield a.func(k)
self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
for k in range(0, 12):
yield a.func(k)
self.assertTrue(
callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
)
def test_prefill(self):
callcount = [0]
d = defer.succeed(123)
class A:
@cached()
def func(self, key):
callcount[0] += 1
return d
a = A()
a.func.prefill(("foo",), 456)
self.assertEquals(a.func("foo").result, 456)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
def test_invalidate_context(self):
callcount = [0]
callcount2 = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 1)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
@defer.inlineCallbacks
def test_eviction_context(self):
callcount = [0]
callcount2 = [0]
class A:
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
yield a.func2("foo2")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 4)
self.assertEquals(callcount2[0], 3)
@defer.inlineCallbacks
def test_double_get(self):
callcount = [0]
callcount2 = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
yield a.func2("foo")
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 2)
a.func.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 3)
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):

View file

@ -19,7 +19,8 @@ from mock import Mock
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
from .. import unittest
from tests import unittest
from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase):
@ -59,7 +60,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEquals(cache.pop("key"), None)
def test_del_multi(self):
cache = LruCache(4, 2, cache_type=TreeCache)
cache = LruCache(4, keylen=2, cache_type=TreeCache)
cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof"
cache[("vehicles", "car")] = "vroom"
@ -83,6 +84,11 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
cache.clear()
self.assertEquals(len(cache), 0)
@override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
def test_special_size(self):
cache = LruCache(10, "mycache")
self.assertEqual(cache.max_size, 100)
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self):
@ -160,7 +166,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
m2 = Mock()
m3 = Mock()
m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache)
cache = LruCache(4, keylen=2, cache_type=TreeCache)
cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2])

View file

@ -158,12 +158,9 @@ commands=
coverage html
[testenv:mypy]
skip_install = True
deps =
{[base]deps}
mypy==0.782
mypy-zope
extras = all
extras = all,mypy
commands = mypy
# To find all folders that pass mypy you run: