Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2020-10-16 11:34:53 +01:00
commit e9b5e642c3
76 changed files with 2027 additions and 791 deletions

View file

@ -1,3 +1,27 @@
Synapse 1.21.2 (2020-10-15)
===========================
Debian packages and Docker images have been rebuilt using the latest versions of dependency libraries, including authlib 0.15.1. Please see bugfixes below.
Security advisory
-----------------
* HTML pages served via Synapse were vulnerable to cross-site scripting (XSS)
attacks. All server administrators are encouraged to upgrade.
([\#8444](https://github.com/matrix-org/synapse/pull/8444))
([CVE-2020-26891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26891))
This fix was originally included in v1.21.0 but was missing a security advisory.
This was reported by [Denis Kasak](https://github.com/dkasak).
Bugfixes
--------
- Fix rare bug where sending an event would fail due to a racey assertion. ([\#8530](https://github.com/matrix-org/synapse/issues/8530))
- An updated version of the authlib dependency is included in the Docker and Debian images to fix an issue using OpenID Connect. See [\#8534](https://github.com/matrix-org/synapse/issues/8534) for details.
Synapse 1.21.1 (2020-10-13)
===========================

View file

@ -63,6 +63,10 @@ run-time:
./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
```
You can also provided the `-d` option, which will lint the files that have been
changed since the last git commit. This will often be significantly faster than
linting the whole codebase.
Before pushing new changes, ensure they don't produce linting errors. Commit any
files that were corrected.

1
changelog.d/8437.feature Normal file
View file

@ -0,0 +1 @@
Implement [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409) to send typing, read receipts, and presence events to appservices.

1
changelog.d/8472.misc Normal file
View file

@ -0,0 +1 @@
Add `-d` option to `./scripts-dev/lint.sh` to lint files that have changed since the last git commit.

1
changelog.d/8488.misc Normal file
View file

@ -0,0 +1 @@
Allow events to be sent to clients sooner when using sharded event persisters.

1
changelog.d/8503.misc Normal file
View file

@ -0,0 +1 @@
Add user agent to user_daily_visits table.

1
changelog.d/8515.misc Normal file
View file

@ -0,0 +1 @@
Apply some internal fixes to the `HomeServer` class to make its code more idiomatic and statically-verifiable.

1
changelog.d/8526.doc Normal file
View file

@ -0,0 +1 @@
Added note about docker in manhole.md regarding which ip address to bind to. Contributed by @Maquis196.

1
changelog.d/8529.doc Normal file
View file

@ -0,0 +1 @@
Document the new behaviour of the `allowed_lifetime_min` and `allowed_lifetime_max` settings in the room retention configuration.

1
changelog.d/8535.feature Normal file
View file

@ -0,0 +1 @@
Support modifying event content in `ThirdPartyRules` modules.

1
changelog.d/8537.misc Normal file
View file

@ -0,0 +1 @@
Factor out common code between `RoomMemberHandler._locally_reject_invite` and `EventCreationHandler.create_event`.

1
changelog.d/8542.misc Normal file
View file

@ -0,0 +1 @@
Improve database performance by executing more queries without starting transactions.

1
changelog.d/8547.misc Normal file
View file

@ -0,0 +1 @@
Enable mypy type checking for `synapse.util.caches`.

1
changelog.d/8548.misc Normal file
View file

@ -0,0 +1 @@
Rename `Cache` to `DeferredCache`, to better reflect its purpose.

7
debian/changelog vendored
View file

@ -1,3 +1,10 @@
matrix-synapse-py3 (1.21.2) stable; urgency=medium
[ Synapse Packaging team ]
* New synapse release 1.21.2.
-- Synapse Packaging team <packages@matrix.org> Thu, 15 Oct 2020 09:23:27 -0400
matrix-synapse-py3 (1.21.1) stable; urgency=medium
[ Synapse Packaging team ]

View file

@ -5,8 +5,45 @@ The "manhole" allows server administrators to access a Python shell on a running
Synapse installation. This is a very powerful mechanism for administration and
debugging.
**_Security Warning_**
Note that this will give administrative access to synapse to **all users** with
shell access to the server. It should therefore **not** be enabled in
environments where untrusted users have shell access.
***
To enable it, first uncomment the `manhole` listener configuration in
`homeserver.yaml`:
`homeserver.yaml`. The configuration is slightly different if you're using docker.
#### Docker config
If you are using Docker, set `bind_addresses` to `['0.0.0.0']` as shown:
```yaml
listeners:
- port: 9000
bind_addresses: ['0.0.0.0']
type: manhole
```
When using `docker run` to start the server, you will then need to change the command to the following to include the
`manhole` port forwarding. The `-p 127.0.0.1:9000:9000` below is important: it
ensures that access to the `manhole` is only possible for local users.
```bash
docker run -d --name synapse \
--mount type=volume,src=synapse-data,dst=/data \
-p 8008:8008 \
-p 127.0.0.1:9000:9000 \
matrixdotorg/synapse:latest
```
#### Native config
If you are not using docker, set `bind_addresses` to `['::1', '127.0.0.1']` as shown.
The `bind_addresses` in the example below is important: it ensures that access to the
`manhole` is only possible for local users).
```yaml
listeners:
@ -15,12 +52,7 @@ listeners:
type: manhole
```
(`bind_addresses` in the above is important: it ensures that access to the
manhole is only possible for local users).
Note that this will give administrative access to synapse to **all users** with
shell access to the server. It should therefore **not** be enabled in
environments where untrusted users have shell access.
#### Accessing synapse manhole
Then restart synapse, and point an ssh client at port 9000 on localhost, using
the username `matrix`:

View file

@ -136,24 +136,34 @@ the server's database.
### Lifetime limits
**Note: this feature is mainly useful within a closed federation or on
servers that don't federate, because there currently is no way to
enforce these limits in an open federation.**
Server admins can restrict the values their local users are allowed to
use for both `min_lifetime` and `max_lifetime`. These limits can be
defined as such in the `retention` section of the configuration file:
Server admins can set limits on the values of `max_lifetime` to use when
purging old events in a room. These limits can be defined as such in the
`retention` section of the configuration file:
```yaml
allowed_lifetime_min: 1d
allowed_lifetime_max: 1y
```
Here, `allowed_lifetime_min` is the lowest value a local user can set
for both `min_lifetime` and `max_lifetime`, and `allowed_lifetime_max`
is the highest value. Both parameters are optional (e.g. setting
`allowed_lifetime_min` but not `allowed_lifetime_max` only enforces a
minimum and no maximum).
The limits are considered when running purge jobs. If necessary, the
effective value of `max_lifetime` will be brought between
`allowed_lifetime_min` and `allowed_lifetime_max` (inclusive).
This means that, if the value of `max_lifetime` defined in the room's state
is lower than `allowed_lifetime_min`, the value of `allowed_lifetime_min`
will be used instead. Likewise, if the value of `max_lifetime` is higher
than `allowed_lifetime_max`, the value of `allowed_lifetime_max` will be
used instead.
In the example above, we ensure Synapse never deletes events that are less
than one day old, and that it always deletes events that are over a year
old.
If a default policy is set, and its `max_lifetime` value is lower than
`allowed_lifetime_min` or higher than `allowed_lifetime_max`, the same
process applies.
Both parameters are optional; if one is omitted Synapse won't use it to
adjust the effective value of `max_lifetime`.
Like other settings in this section, these parameters can be expressed
either as a duration or as a number of milliseconds.

View file

@ -15,6 +15,7 @@ files =
synapse/events/builder.py,
synapse/events/spamcheck.py,
synapse/federation,
synapse/handlers/appservice.py,
synapse/handlers/account_data.py,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
@ -64,9 +65,7 @@ files =
synapse/streams,
synapse/types.py,
synapse/util/async_helpers.py,
synapse/util/caches/descriptors.py,
synapse/util/caches/response_cache.py,
synapse/util/caches/stream_change_cache.py,
synapse/util/caches,
synapse/util/metrics.py,
tests/replication,
tests/test_utils,

View file

@ -1,4 +1,4 @@
#!/bin/sh
#!/bin/bash
#
# Runs linting scripts over the local Synapse checkout
# isort - sorts import statements
@ -7,15 +7,90 @@
set -e
if [ $# -ge 1 ]
then
files=$*
else
files="synapse tests scripts-dev scripts contrib synctl"
usage() {
echo
echo "Usage: $0 [-h] [-d] [paths...]"
echo
echo "-d"
echo " Lint files that have changed since the last git commit."
echo
echo " If paths are provided and this option is set, both provided paths and those"
echo " that have changed since the last commit will be linted."
echo
echo " If no paths are provided and this option is not set, all files will be linted."
echo
echo " Note that paths with a file extension that is not '.py' will be excluded."
echo "-h"
echo " Display this help text."
}
USING_DIFF=0
files=()
while getopts ":dh" opt; do
case $opt in
d)
USING_DIFF=1
;;
h)
usage
exit
;;
\?)
echo "ERROR: Invalid option: -$OPTARG" >&2
usage
exit
;;
esac
done
# Strip any options from the command line arguments now that
# we've finished processing them
shift "$((OPTIND-1))"
if [ $USING_DIFF -eq 1 ]; then
# Check both staged and non-staged changes
for path in $(git diff HEAD --name-only); do
filename=$(basename "$path")
file_extension="${filename##*.}"
# If an extension is present, and it's something other than 'py',
# then ignore this file
if [[ -n ${file_extension+x} && $file_extension != "py" ]]; then
continue
fi
# Append this path to our list of files to lint
files+=("$path")
done
fi
echo "Linting these locations: $files"
isort $files
python3 -m black $files
# Append any remaining arguments as files to lint
files+=("$@")
if [[ $USING_DIFF -eq 1 ]]; then
# If we were asked to lint changed files, and no paths were found as a result...
if [ ${#files[@]} -eq 0 ]; then
# Then print and exit
echo "No files found to lint."
exit 0
fi
else
# If we were not asked to lint changed files, and no paths were found as a result,
# then lint everything!
if [[ -z ${files+x} ]]; then
# Lint all source code files and directories
files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py")
fi
fi
echo "Linting these paths: ${files[*]}"
echo
# Print out the commands being run
set -x
isort "${files[@]}"
python3 -m black "${files[@]}"
./scripts-dev/config-lint.sh
flake8 $files
flake8 "${files[@]}"

View file

@ -15,12 +15,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
from setuptools import setup, find_packages, Command
import sys
from setuptools import Command, find_packages, setup
here = os.path.abspath(os.path.dirname(__file__))

View file

@ -1,13 +1,12 @@
from .sorteddict import (
SortedDict,
SortedKeysView,
SortedItemsView,
SortedValuesView,
)
from .sorteddict import SortedDict, SortedItemsView, SortedKeysView, SortedValuesView
from .sortedlist import SortedKeyList, SortedList, SortedListWithKey
__all__ = [
"SortedDict",
"SortedKeysView",
"SortedItemsView",
"SortedValuesView",
"SortedKeyList",
"SortedList",
"SortedListWithKey",
]

View file

@ -0,0 +1,177 @@
# stub for SortedList. This is an exact copy of
# https://github.com/grantjenks/python-sortedcontainers/blob/a419ffbd2b1c935b09f11f0971696e537fd0c510/sortedcontainers/sortedlist.pyi
# (from https://github.com/grantjenks/python-sortedcontainers/pull/107)
from typing import (
Any,
Callable,
Generic,
Iterable,
Iterator,
List,
MutableSequence,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)
_T = TypeVar("_T")
_SL = TypeVar("_SL", bound=SortedList)
_SKL = TypeVar("_SKL", bound=SortedKeyList)
_Key = Callable[[_T], Any]
_Repr = Callable[[], str]
def recursive_repr(fillvalue: str = ...) -> Callable[[_Repr], _Repr]: ...
class SortedList(MutableSequence[_T]):
DEFAULT_LOAD_FACTOR: int = ...
def __init__(
self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ...,
): ...
# NB: currently mypy does not honour return type, see mypy #3307
@overload
def __new__(cls: Type[_SL], iterable: None, key: None) -> _SL: ...
@overload
def __new__(cls: Type[_SL], iterable: None, key: _Key[_T]) -> SortedKeyList[_T]: ...
@overload
def __new__(cls: Type[_SL], iterable: Iterable[_T], key: None) -> _SL: ...
@overload
def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ...
@property
def key(self) -> Optional[Callable[[_T], Any]]: ...
def _reset(self, load: int) -> None: ...
def clear(self) -> None: ...
def _clear(self) -> None: ...
def add(self, value: _T) -> None: ...
def _expand(self, pos: int) -> None: ...
def update(self, iterable: Iterable[_T]) -> None: ...
def _update(self, iterable: Iterable[_T]) -> None: ...
def discard(self, value: _T) -> None: ...
def remove(self, value: _T) -> None: ...
def _delete(self, pos: int, idx: int) -> None: ...
def _loc(self, pos: int, idx: int) -> int: ...
def _pos(self, idx: int) -> int: ...
def _build_index(self) -> None: ...
def __contains__(self, value: Any) -> bool: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
@overload
def __getitem__(self, index: int) -> _T: ...
@overload
def __getitem__(self, index: slice) -> List[_T]: ...
@overload
def _getitem(self, index: int) -> _T: ...
@overload
def _getitem(self, index: slice) -> List[_T]: ...
@overload
def __setitem__(self, index: int, value: _T) -> None: ...
@overload
def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ...
def __iter__(self) -> Iterator[_T]: ...
def __reversed__(self) -> Iterator[_T]: ...
def __len__(self) -> int: ...
def reverse(self) -> None: ...
def islice(
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
) -> Iterator[_T]: ...
def _islice(
self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool,
) -> Iterator[_T]: ...
def irange(
self,
minimum: Optional[int] = ...,
maximum: Optional[int] = ...,
inclusive: Tuple[bool, bool] = ...,
reverse: bool = ...,
) -> Iterator[_T]: ...
def bisect_left(self, value: _T) -> int: ...
def bisect_right(self, value: _T) -> int: ...
def bisect(self, value: _T) -> int: ...
def _bisect_right(self, value: _T) -> int: ...
def count(self, value: _T) -> int: ...
def copy(self: _SL) -> _SL: ...
def __copy__(self: _SL) -> _SL: ...
def append(self, value: _T) -> None: ...
def extend(self, values: Iterable[_T]) -> None: ...
def insert(self, index: int, value: _T) -> None: ...
def pop(self, index: int = ...) -> _T: ...
def index(
self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
) -> int: ...
def __add__(self: _SL, other: Iterable[_T]) -> _SL: ...
def __radd__(self: _SL, other: Iterable[_T]) -> _SL: ...
def __iadd__(self: _SL, other: Iterable[_T]) -> _SL: ...
def __mul__(self: _SL, num: int) -> _SL: ...
def __rmul__(self: _SL, num: int) -> _SL: ...
def __imul__(self: _SL, num: int) -> _SL: ...
def __eq__(self, other: Any) -> bool: ...
def __ne__(self, other: Any) -> bool: ...
def __lt__(self, other: Sequence[_T]) -> bool: ...
def __gt__(self, other: Sequence[_T]) -> bool: ...
def __le__(self, other: Sequence[_T]) -> bool: ...
def __ge__(self, other: Sequence[_T]) -> bool: ...
def __repr__(self) -> str: ...
def _check(self) -> None: ...
class SortedKeyList(SortedList[_T]):
def __init__(
self, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
) -> None: ...
def __new__(
cls, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
) -> SortedKeyList[_T]: ...
@property
def key(self) -> Callable[[_T], Any]: ...
def clear(self) -> None: ...
def _clear(self) -> None: ...
def add(self, value: _T) -> None: ...
def _expand(self, pos: int) -> None: ...
def update(self, iterable: Iterable[_T]) -> None: ...
def _update(self, iterable: Iterable[_T]) -> None: ...
# NB: Must be T to be safely passed to self.func, yet base class imposes Any
def __contains__(self, value: _T) -> bool: ... # type: ignore
def discard(self, value: _T) -> None: ...
def remove(self, value: _T) -> None: ...
def _delete(self, pos: int, idx: int) -> None: ...
def irange(
self,
minimum: Optional[int] = ...,
maximum: Optional[int] = ...,
inclusive: Tuple[bool, bool] = ...,
reverse: bool = ...,
): ...
def irange_key(
self,
min_key: Optional[Any] = ...,
max_key: Optional[Any] = ...,
inclusive: Tuple[bool, bool] = ...,
reserve: bool = ...,
): ...
def bisect_left(self, value: _T) -> int: ...
def bisect_right(self, value: _T) -> int: ...
def bisect(self, value: _T) -> int: ...
def bisect_key_left(self, key: Any) -> int: ...
def _bisect_key_left(self, key: Any) -> int: ...
def bisect_key_right(self, key: Any) -> int: ...
def _bisect_key_right(self, key: Any) -> int: ...
def bisect_key(self, key: Any) -> int: ...
def count(self, value: _T) -> int: ...
def copy(self: _SKL) -> _SKL: ...
def __copy__(self: _SKL) -> _SKL: ...
def index(
self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
) -> int: ...
def __add__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
def __radd__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
def __iadd__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
def __mul__(self: _SKL, num: int) -> _SKL: ...
def __rmul__(self: _SKL, num: int) -> _SKL: ...
def __imul__(self: _SKL, num: int) -> _SKL: ...
def __repr__(self) -> str: ...
def _check(self) -> None: ...
SortedListWithKey = SortedKeyList

View file

@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.21.1"
__version__ = "1.21.2"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -14,14 +14,15 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Iterable, List, Match, Optional
from synapse.api.constants import EventTypes
from synapse.appservice.api import ApplicationServiceApi
from synapse.types import GroupID, get_domain_from_id
from synapse.events import EventBase
from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.appservice.api import ApplicationServiceApi
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@ -32,38 +33,6 @@ class ApplicationServiceState:
UP = "up"
class AppServiceTransaction:
"""Represents an application service transaction."""
def __init__(self, service, id, events):
self.service = service
self.id = id
self.events = events
async def send(self, as_api: ApplicationServiceApi) -> bool:
"""Sends this transaction using the provided AS API interface.
Args:
as_api: The API to use to send.
Returns:
True if the transaction was sent.
"""
return await as_api.push_bulk(
service=self.service, events=self.events, txn_id=self.id
)
async def complete(self, store: "DataStore") -> None:
"""Completes this transaction as successful.
Marks this transaction ID on the application service and removes the
transaction contents from the database.
Args:
store: The database store to operate on.
"""
await store.complete_appservice_txn(service=self.service, txn_id=self.id)
class ApplicationService:
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
@ -91,6 +60,7 @@ class ApplicationService:
protocols=None,
rate_limited=True,
ip_range_whitelist=None,
supports_ephemeral=False,
):
self.token = token
self.url = (
@ -102,6 +72,7 @@ class ApplicationService:
self.namespaces = self._check_namespaces(namespaces)
self.id = id
self.ip_range_whitelist = ip_range_whitelist
self.supports_ephemeral = supports_ephemeral
if "|" in self.id:
raise Exception("application service ID cannot contain '|' character")
@ -161,19 +132,21 @@ class ApplicationService:
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
return namespaces
def _matches_regex(self, test_string, namespace_key):
def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
for regex_obj in self.namespaces[namespace_key]:
if regex_obj["regex"].match(test_string):
return regex_obj
return None
def _is_exclusive(self, ns_key, test_string):
def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
regex_obj = self._matches_regex(test_string, ns_key)
if regex_obj:
return regex_obj["exclusive"]
return False
async def _matches_user(self, event, store):
async def _matches_user(
self, event: Optional[EventBase], store: Optional["DataStore"] = None
) -> bool:
if not event:
return False
@ -188,14 +161,23 @@ class ApplicationService:
if not store:
return False
does_match = await self._matches_user_in_member_list(event.room_id, store)
does_match = await self.matches_user_in_member_list(event.room_id, store)
return does_match
@cached(num_args=1, cache_context=True)
async def _matches_user_in_member_list(self, room_id, store, cache_context):
member_list = await store.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
@cached(num_args=1)
async def matches_user_in_member_list(
self, room_id: str, store: "DataStore"
) -> bool:
"""Check if this service is interested a room based upon it's membership
Args:
room_id: The room to check.
store: The datastore to query.
Returns:
True if this service would like to know about this room.
"""
member_list = await store.get_users_in_room(room_id)
# check joined member events
for user_id in member_list:
@ -203,12 +185,14 @@ class ApplicationService:
return True
return False
def _matches_room_id(self, event):
def _matches_room_id(self, event: EventBase) -> bool:
if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id)
return False
async def _matches_aliases(self, event, store):
async def _matches_aliases(
self, event: EventBase, store: Optional["DataStore"] = None
) -> bool:
if not store or not event:
return False
@ -218,12 +202,15 @@ class ApplicationService:
return True
return False
async def is_interested(self, event, store=None) -> bool:
async def is_interested(
self, event: EventBase, store: Optional["DataStore"] = None
) -> bool:
"""Check if this service is interested in this event.
Args:
event(Event): The event to check.
store(DataStore)
event: The event to check.
store: The datastore to query.
Returns:
True if this service would like to know about this event.
"""
@ -231,39 +218,66 @@ class ApplicationService:
if self._matches_room_id(event):
return True
if await self._matches_aliases(event, store):
# This will check the namespaces first before
# checking the store, so should be run before _matches_aliases
if await self._matches_user(event, store):
return True
if await self._matches_user(event, store):
# This will check the store, so should be run last
if await self._matches_aliases(event, store):
return True
return False
def is_interested_in_user(self, user_id):
@cached(num_args=1)
async def is_interested_in_presence(
self, user_id: UserID, store: "DataStore"
) -> bool:
"""Check if this service is interested a user's presence
Args:
user_id: The user to check.
store: The datastore to query.
Returns:
True if this service would like to know about presence for this user.
"""
# Find all the rooms the sender is in
if self.is_interested_in_user(user_id.to_string()):
return True
room_ids = await store.get_rooms_for_user(user_id.to_string())
# Then find out if the appservice is interested in any of those rooms
for room_id in room_ids:
if await self.matches_user_in_member_list(room_id, store):
return True
return False
def is_interested_in_user(self, user_id: str) -> bool:
return (
self._matches_regex(user_id, ApplicationService.NS_USERS)
bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
or user_id == self.sender
)
def is_interested_in_alias(self, alias):
def is_interested_in_alias(self, alias: str) -> bool:
return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
def is_interested_in_room(self, room_id):
def is_interested_in_room(self, room_id: str) -> bool:
return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
def is_exclusive_user(self, user_id):
def is_exclusive_user(self, user_id: str) -> bool:
return (
self._is_exclusive(ApplicationService.NS_USERS, user_id)
or user_id == self.sender
)
def is_interested_in_protocol(self, protocol):
def is_interested_in_protocol(self, protocol: str) -> bool:
return protocol in self.protocols
def is_exclusive_alias(self, alias):
def is_exclusive_alias(self, alias: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
def is_exclusive_room(self, room_id):
def is_exclusive_room(self, room_id: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def get_exclusive_user_regexes(self):
@ -276,14 +290,14 @@ class ApplicationService:
if regex_obj["exclusive"]
]
def get_groups_for_user(self, user_id):
def get_groups_for_user(self, user_id: str) -> Iterable[str]:
"""Get the groups that this user is associated with by this AS
Args:
user_id (str): The ID of the user.
user_id: The ID of the user.
Returns:
iterable[str]: an iterable that yields group_id strings.
An iterable that yields group_id strings.
"""
return (
regex_obj["group_id"]
@ -291,7 +305,7 @@ class ApplicationService:
if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
)
def is_rate_limited(self):
def is_rate_limited(self) -> bool:
return self.rate_limited
def __str__(self):
@ -300,3 +314,45 @@ class ApplicationService:
dict_copy["token"] = "<redacted>"
dict_copy["hs_token"] = "<redacted>"
return "ApplicationService: %s" % (dict_copy,)
class AppServiceTransaction:
"""Represents an application service transaction."""
def __init__(
self,
service: ApplicationService,
id: int,
events: List[EventBase],
ephemeral: List[JsonDict],
):
self.service = service
self.id = id
self.events = events
self.ephemeral = ephemeral
async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface.
Args:
as_api: The API to use to send.
Returns:
True if the transaction was sent.
"""
return await as_api.push_bulk(
service=self.service,
events=self.events,
ephemeral=self.ephemeral,
txn_id=self.id,
)
async def complete(self, store: "DataStore") -> None:
"""Completes this transaction as successful.
Marks this transaction ID on the application service and removes the
transaction contents from the database.
Args:
store: The database store to operate on.
"""
await store.complete_appservice_txn(service=self.service, txn_id=self.id)

View file

@ -14,12 +14,13 @@
# limitations under the License.
import logging
import urllib
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from prometheus_client import Counter
from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events import EventBase
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, ThirdPartyInstanceID
@ -201,7 +202,13 @@ class ApplicationServiceApi(SimpleHttpClient):
key = (service.id, protocol)
return await self.protocol_meta_cache.wrap(key, _get)
async def push_bulk(self, service, events, txn_id=None):
async def push_bulk(
self,
service: "ApplicationService",
events: List[EventBase],
ephemeral: List[JsonDict],
txn_id: Optional[int] = None,
):
if service.url is None:
return True
@ -211,15 +218,19 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning(
"push_bulk: Missing txn ID sending events to %s", service.url
)
txn_id = str(0)
txn_id = str(txn_id)
txn_id = 0
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
# Never send ephemeral events to appservices that do not support it
if service.supports_ephemeral:
body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
else:
body = {"events": events}
uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try:
await self.put_json(
uri=uri,
json_body={"events": events},
args={"access_token": service.hs_token},
uri=uri, json_body=body, args={"access_token": service.hs_token},
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))

View file

@ -49,10 +49,13 @@ This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
import logging
from typing import List
from synapse.appservice import ApplicationServiceState
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@ -82,8 +85,13 @@ class ApplicationServiceScheduler:
for service in services:
self.txn_ctrl.start_recoverer(service)
def submit_event_for_as(self, service, event):
self.queuer.enqueue(service, event)
def submit_event_for_as(self, service: ApplicationService, event: EventBase):
self.queuer.enqueue_event(service, event)
def submit_ephemeral_events_for_as(
self, service: ApplicationService, events: List[JsonDict]
):
self.queuer.enqueue_ephemeral(service, events)
class _ServiceQueuer:
@ -96,17 +104,15 @@ class _ServiceQueuer:
def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]}
self.queued_ephemeral = {} # dict of {service_id: [events]}
# the appservices which currently have a transaction in flight
self.requests_in_flight = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
def enqueue(self, service, event):
self.queued_events.setdefault(service.id, []).append(event)
def _start_background_request(self, service):
# start a sender for this appservice if we don't already have one
if service.id in self.requests_in_flight:
return
@ -114,7 +120,15 @@ class _ServiceQueuer:
"as-sender-%s" % (service.id,), self._send_request, service
)
async def _send_request(self, service):
def enqueue_event(self, service: ApplicationService, event: EventBase):
self.queued_events.setdefault(service.id, []).append(event)
self._start_background_request(service)
def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]):
self.queued_ephemeral.setdefault(service.id, []).extend(events)
self._start_background_request(service)
async def _send_request(self, service: ApplicationService):
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert service.id not in self.requests_in_flight
@ -123,10 +137,11 @@ class _ServiceQueuer:
try:
while True:
events = self.queued_events.pop(service.id, [])
if not events:
ephemeral = self.queued_ephemeral.pop(service.id, [])
if not events and not ephemeral:
return
try:
await self.txn_ctrl.send(service, events)
await self.txn_ctrl.send(service, events, ephemeral)
except Exception:
logger.exception("AS request failed")
finally:
@ -158,9 +173,16 @@ class _TransactionController:
# for UTs
self.RECOVERER_CLASS = _Recoverer
async def send(self, service, events):
async def send(
self,
service: ApplicationService,
events: List[EventBase],
ephemeral: List[JsonDict] = [],
):
try:
txn = await self.store.create_appservice_txn(service=service, events=events)
txn = await self.store.create_appservice_txn(
service=service, events=events, ephemeral=ephemeral
)
service_is_up = await self._is_service_up(service)
if service_is_up:
sent = await txn.send(self.as_api)
@ -204,7 +226,7 @@ class _TransactionController:
recoverer.recover()
logger.info("Now %i active recoverers", len(self.recoverers))
async def _is_service_up(self, service):
async def _is_service_up(self, service: ApplicationService) -> bool:
state = await self.store.get_appservice_state(service)
return state == ApplicationServiceState.UP or state is None

View file

@ -160,6 +160,8 @@ def _load_appservice(hostname, as_info, config_filename):
if as_info.get("ip_range_whitelist"):
ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist"))
supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False)
return ApplicationService(
token=as_info["as_token"],
hostname=hostname,
@ -168,6 +170,7 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
supports_ephemeral=supports_ephemeral,
protocols=protocols,
rate_limited=rate_limited,
ip_range_whitelist=ip_range_whitelist,

View file

@ -312,6 +312,12 @@ class EventBase(metaclass=abc.ABCMeta):
"""
return [e for e, _ in self.auth_events]
def freeze(self):
"""'Freeze' the event dict, so it cannot be modified by accident"""
# this will be a no-op if the event dict is already frozen.
self._dict = freeze(self._dict)
class FrozenEvent(EventBase):
format_version = EventFormatVersions.V1 # All events of this type are V1

View file

@ -97,32 +97,37 @@ class EventBuilder:
def is_state(self):
return self._state_key is not None
async def build(self, prev_event_ids: List[str]) -> EventBase:
async def build(
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
) -> EventBase:
"""Transform into a fully signed and hashed event
Args:
prev_event_ids: The event IDs to use as the prev events
auth_event_ids: The event IDs to use as the auth events.
Should normally be set to None, which will cause them to be calculated
based on the room state at the prev_events.
Returns:
The signed and hashed event.
"""
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
auth_ids = self._auth.compute_auth_events(self, state_ids)
if auth_event_ids is None:
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
auth_event_ids = self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
# The types of auth/prev events changes between event versions.
auth_events = await self._store.add_event_hashes(
auth_ids
auth_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events = await self._store.add_event_hashes(
prev_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
else:
auth_events = auth_ids
auth_events = auth_event_ids
prev_events = prev_event_ids
old_depth = await self._store.get_max_depth_of(prev_event_ids)

View file

@ -12,7 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from typing import Callable, Union
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@ -44,15 +45,20 @@ class ThirdPartyEventRules:
async def check_event_allowed(
self, event: EventBase, context: EventContext
) -> bool:
) -> Union[bool, dict]:
"""Check if a provided event should be allowed in the given context.
The module can return:
* True: the event is allowed.
* False: the event is not allowed, and should be rejected with M_FORBIDDEN.
* a dict: replacement event data.
Args:
event: The event to be checked.
context: The context of the event.
Returns:
True if the event should be allowed, False if not.
The result from the ThirdPartyRules module, as above
"""
if self.third_party_rules is None:
return True
@ -63,9 +69,10 @@ class ThirdPartyEventRules:
events = await self.store.get_events(prev_state_ids.values())
state_events = {(ev.type, ev.state_key): ev for ev in events.values()}
# The module can modify the event slightly if it wants, but caution should be
# exercised, and it's likely to go very wrong if applied to events received over
# federation.
# Ensure that the event is frozen, to make sure that the module is not tempted
# to try to modify it. Any attempt to modify it at this point will invalidate
# the hashes and signatures.
event.freeze()
return await self.third_party_rules.check_event_allowed(event, state_events)

View file

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import Dict, List, Optional
from prometheus_client import Counter
@ -21,13 +22,16 @@ from twisted.internet import defer
import synapse
from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import (
event_processing_loop_counter,
event_processing_loop_room_count,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import RoomStreamToken
from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@ -44,6 +48,7 @@ class ApplicationServicesHandler:
self.started_scheduler = False
self.clock = hs.get_clock()
self.notify_appservices = hs.config.notify_appservices
self.event_sources = hs.get_event_sources()
self.current_max = 0
self.is_processing = False
@ -82,7 +87,7 @@ class ApplicationServicesHandler:
if not events:
break
events_by_room = {}
events_by_room = {} # type: Dict[str, List[EventBase]]
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
@ -161,6 +166,104 @@ class ApplicationServicesHandler:
finally:
self.is_processing = False
async def notify_interested_services_ephemeral(
self, stream_key: str, new_token: Optional[int], users: Collection[UserID] = [],
):
"""This is called by the notifier in the background
when a ephemeral event handled by the homeserver.
This will determine which appservices
are interested in the event, and submit them.
Events will only be pushed to appservices
that have opted into ephemeral events
Args:
stream_key: The stream the event came from.
new_token: The latest stream token
users: The user(s) involved with the event.
"""
services = [
service
for service in self.store.get_app_services()
if service.supports_ephemeral
]
if not services or not self.notify_appservices:
return
logger.info("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
# Only handle typing if we have the latest token
if stream_key == "typing_key" and new_token is not None:
events = await self._handle_typing(service, new_token)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
# We don't persist the token for typing_key for performance reasons
elif stream_key == "receipt_key":
events = await self._handle_receipts(service)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
await self.store.set_type_stream_id_for_appservice(
service, "read_receipt", new_token
)
elif stream_key == "presence_key":
events = await self._handle_presence(service, users)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
await self.store.set_type_stream_id_for_appservice(
service, "presence", new_token
)
async def _handle_typing(self, service: ApplicationService, new_token: int):
typing_source = self.event_sources.sources["typing"]
# Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as(
service=service,
# For performance reasons, we don't persist the previous
# token in the DB and instead fetch the latest typing information
# for appservices.
from_key=new_token - 1,
)
return typing
async def _handle_receipts(self, service: ApplicationService):
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
receipts_source = self.event_sources.sources["receipt"]
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key
)
return receipts
async def _handle_presence(
self, service: ApplicationService, users: Collection[UserID]
):
events = [] # type: List[JsonDict]
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
)
for user in users:
interested = await service.is_interested_in_presence(user, self.store)
if not interested:
continue
presence_events, _ = await presence_source.get_new_events(
user=user, service=service, from_key=from_key,
)
time_now = self.clock.time_msec()
presence_events = [
{
"type": "m.presence",
"sender": event.user_id,
"content": format_user_presence_state(
event, time_now, include_user_id=False
),
}
for event in presence_events
]
events = events + presence_events
async def query_user_exists(self, user_id):
"""Check if any application service knows this user_id exists.
@ -223,7 +326,7 @@ class ApplicationServicesHandler:
async def get_3pe_protocols(self, only_protocol=None):
services = self.store.get_app_services()
protocols = {}
protocols = {} # type: Dict[str, List[JsonDict]]
# Collect up all the individual protocol responses out of the ASes
for s in services:

View file

@ -1507,18 +1507,9 @@ class FederationHandler(BaseHandler):
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
except AuthError as e:
except SynapseError as e:
logger.warning("Failed to create join to %s because %s", room_id, e)
raise e
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.info("Creation of join %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
raise
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
@ -1567,15 +1558,6 @@ class FederationHandler(BaseHandler):
context = await self._handle_new_event(origin, event)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.info("Sending of join %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
event.event_id,
@ -1748,15 +1730,6 @@ class FederationHandler(BaseHandler):
builder=builder
)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.warning("Creation of leave %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
@ -1789,16 +1762,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
context = await self._handle_new_event(origin, event)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.info("Sending of leave %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
await self._handle_new_event(origin, event)
logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s",
@ -2694,18 +2658,6 @@ class FederationHandler(BaseHandler):
builder=builder
)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.info(
"Creation of threepid invite %s forbidden by third-party rules",
event,
)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)
@ -2756,18 +2708,6 @@ class FederationHandler(BaseHandler):
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.warning(
"Exchange of threepid invite %s forbidden by third-party rules", event
)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)

View file

@ -437,9 +437,9 @@ class EventCreationHandler:
self,
requester: Requester,
event_dict: dict,
token_id: Optional[str] = None,
txn_id: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
require_consent: bool = True,
) -> Tuple[EventBase, EventContext]:
"""
@ -453,13 +453,18 @@ class EventCreationHandler:
Args:
requester
event_dict: An entire event
token_id
txn_id
prev_event_ids:
the forward extremities to use as the prev_events for the
new event.
If None, they will be requested from the database.
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
require_consent: Whether to check if the requester has
consented to the privacy policy.
Raises:
@ -511,14 +516,17 @@ class EventCreationHandler:
if require_consent and not is_exempt:
await self.assert_accepted_privacy_policy(requester)
if token_id is not None:
builder.internal_metadata.token_id = token_id
if requester.access_token_id is not None:
builder.internal_metadata.token_id = requester.access_token_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = await self.create_new_client_event(
builder=builder, requester=requester, prev_event_ids=prev_event_ids,
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
)
# In an ideal world we wouldn't need the second part of this condition. However,
@ -726,7 +734,7 @@ class EventCreationHandler:
return event, event.internal_metadata.stream_ordering
event, context = await self.create_event(
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
requester, event_dict, txn_id=txn_id
)
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
@ -757,6 +765,7 @@ class EventCreationHandler:
builder: EventBuilder,
requester: Optional[Requester] = None,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@ -769,6 +778,11 @@ class EventCreationHandler:
If None, they will be requested from the database.
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
Returns:
Tuple of created event, context
"""
@ -790,11 +804,30 @@ class EventCreationHandler:
builder.type == EventTypes.Create or len(prev_event_ids) > 0
), "Attempting to create an event with no prev_events"
event = await builder.build(prev_event_ids=prev_event_ids)
event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
)
context = await self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
third_party_result = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not third_party_result:
logger.info(
"Event %s forbidden by third-party rules", event,
)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
elif isinstance(third_party_result, dict):
# the third-party rules want to replace the event. We'll need to build a new
# event.
event, context = await self._rebuild_event_after_third_party_rules(
third_party_result, event
)
self.validator.validate_new(event, self.config)
# If this event is an annotation then we check that that the sender
@ -881,14 +914,6 @@ class EventCreationHandler:
else:
room_version = await self.store.get_room_version_id(event.room_id)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
if event.internal_metadata.is_out_of_band_membership():
# the only sort of out-of-band-membership events we expect to see here
# are invite rejections we have generated ourselves.
@ -1291,3 +1316,57 @@ class EventCreationHandler:
room_id,
)
del self._rooms_to_exclude_from_dummy_event_insertion[room_id]
async def _rebuild_event_after_third_party_rules(
self, third_party_result: dict, original_event: EventBase
) -> Tuple[EventBase, EventContext]:
# the third_party_event_rules want to replace the event.
# we do some basic checks, and then return the replacement event and context.
# Construct a new EventBuilder and validate it, which helps with the
# rest of these checks.
try:
builder = self.event_builder_factory.for_room_version(
original_event.room_version, third_party_result
)
self.validator.validate_builder(builder)
except SynapseError as e:
raise Exception(
"Third party rules module created an invalid event: " + e.msg,
)
immutable_fields = [
# changing the room is going to break things: we've already checked that the
# room exists, and are holding a concurrency limiter token for that room.
# Also, we might need to use a different room version.
"room_id",
# changing the type or state key might work, but we'd need to check that the
# calling functions aren't making assumptions about them.
"type",
"state_key",
]
for k in immutable_fields:
if getattr(builder, k, None) != original_event.get(k):
raise Exception(
"Third party rules module created an invalid event: "
"cannot change field " + k
)
# check that the new sender belongs to this HS
if not self.hs.is_mine_id(builder.sender):
raise Exception(
"Third party rules module created an invalid event: "
"invalid sender " + builder.sender
)
# copy over the original internal metadata
for k, v in original_event.internal_metadata.get_dict().items():
setattr(builder.internal_metadata, k, v)
event = await builder.build(prev_event_ids=original_event.prev_event_ids())
# we rebuild the event context, to be on the safe side. If nothing else,
# delta_ids might need an update.
context = await self.state.compute_event_context(event)
return event, context

View file

@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Tuple
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@ -140,5 +142,36 @@ class ReceiptEventSource:
return (events, to_key)
async def get_new_events_as(
self, from_key: int, service: ApplicationService
) -> Tuple[List[JsonDict], int]:
"""Returns a set of new receipt events that an appservice
may be interested in.
Args:
from_key: the stream position at which events should be fetched from
service: The appservice which may be interested
"""
from_key = int(from_key)
to_key = self.get_current_key()
if from_key == to_key:
return [], to_key
# We first need to fetch all new receipts
rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
from_key=from_key, to_key=to_key
)
# Then filter down to rooms that the AS can read
events = []
for room_id, event in rooms_to_events.items():
if not await service.matches_user_in_member_list(room_id, self.store):
continue
events.append(event)
return (events, to_key)
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()

View file

@ -214,7 +214,6 @@ class RoomCreationHandler(BaseHandler):
"replacement_room": new_room_id,
},
},
token_id=requester.access_token_id,
)
old_room_version = await self.store.get_room_version_id(old_room_id)
await self.auth.check_from_context(

View file

@ -17,12 +17,10 @@ import abc
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse import types
from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@ -31,12 +29,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase
from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
@ -194,7 +188,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# For backwards compatibility:
"membership": membership,
},
token_id=requester.access_token_id,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
require_consent=require_consent,
@ -1153,31 +1146,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id = invite_event.room_id
target_user = invite_event.state_key
room_version = await self.store.get_room_version(room_id)
content["membership"] = Membership.LEAVE
# the auth events for the new event are the same as that of the invite, plus
# the invite itself.
#
# the prev_events are just the invite.
invite_hash = invite_event.event_id # type: Union[str, Tuple]
if room_version.event_format == EventFormatVersions.V1:
alg, h = compute_event_reference_hash(invite_event)
invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
auth_events = tuple(invite_event.auth_events) + (invite_hash,)
prev_events = (invite_hash,)
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(invite_event.depth + 1, MAX_DEPTH)
event_dict = {
"depth": depth,
"auth_events": auth_events,
"prev_events": prev_events,
"type": EventTypes.Member,
"room_id": room_id,
"sender": target_user,
@ -1185,24 +1157,23 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user,
}
event = create_local_event_from_event_dict(
clock=self.clock,
hostname=self.hs.hostname,
signing_key=self.hs.signing_key,
room_version=room_version,
event_dict=event_dict,
# the auth events for the new event are the same as that of the invite, plus
# the invite itself.
#
# the prev_events are just the invite.
prev_event_ids = [invite_event.event_id]
auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event(
requester,
event_dict,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
)
event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True
if txn_id is not None:
event.internal_metadata.txn_id = txn_id
if requester.access_token_id is not None:
event.internal_metadata.token_id = requester.access_token_id
EventValidator().validate_new(event, self.config)
context = await self.state_handler.compute_event_context(event)
context.app_service = requester.app_service
result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[UserID.from_string(target_user)],
)

View file

@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple

View file

@ -12,16 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from collections import namedtuple
from typing import TYPE_CHECKING, List, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
from synapse.types import UserID, get_domain_from_id
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@ -430,6 +430,33 @@ class TypingNotificationEventSource:
"content": {"user_ids": list(typing)},
}
async def get_new_events_as(
self, from_key: int, service: ApplicationService
) -> Tuple[List[JsonDict], int]:
"""Returns a set of new typing events that an appservice
may be interested in.
Args:
from_key: the stream position at which events should be fetched from
service: The appservice which may be interested
"""
with Measure(self.clock, "typing.get_new_events_as"):
from_key = int(from_key)
handler = self.get_typing_handler()
events = []
for room_id in handler._room_serials.keys():
if handler._room_serials[room_id] <= from_key:
continue
if not await service.matches_user_in_member_list(
room_id, handler.store
):
continue
events.append(self._make_event_for(room_id))
return (events, handler._latest_room_serial)
async def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)

View file

@ -329,6 +329,22 @@ class Notifier:
except Exception:
logger.exception("Error notifying application services of event")
async def _notify_app_services_ephemeral(
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
users: Collection[UserID] = [],
):
try:
stream_token = None
if isinstance(new_token, int):
stream_token = new_token
await self.appservice_handler.notify_interested_services_ephemeral(
stream_key, stream_token, users
)
except Exception:
logger.exception("Error notifying application services of event")
async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
try:
await self._pusher_pool.on_new_notifications(max_room_stream_token)
@ -367,6 +383,15 @@ class Notifier:
self.notify_replication()
# Notify appservices
run_as_background_process(
"_notify_app_services_ephemeral",
self._notify_app_services_ephemeral,
stream_key,
new_token,
users,
)
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""

View file

@ -15,7 +15,7 @@
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.descriptors import Cache
from synapse.util.caches.deferred_cache import DeferredCache
from ._base import BaseSlavedStore
@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000
)
) # type: DeferredCache[tuple, int]
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())

View file

@ -205,7 +205,13 @@ class HomeServer(metaclass=abc.ABCMeta):
# instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty()
def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs):
def __init__(
self,
hostname: str,
config: HomeServerConfig,
reactor=None,
version_string="Synapse",
):
"""
Args:
hostname : The hostname for the server.
@ -236,11 +242,9 @@ class HomeServer(metaclass=abc.ABCMeta):
burst_count=config.rc_registration.burst_count,
)
self.datastores = None # type: Optional[Databases]
self.version_string = version_string
# Other kwargs are explicit dependencies
for depname in kwargs:
setattr(self, depname, kwargs[depname])
self.datastores = None # type: Optional[Databases]
def get_instance_id(self) -> str:
"""A unique ID for this synapse process instance.

View file

@ -893,6 +893,12 @@ class DatabasePool:
attempts = 0
while True:
try:
# We can autocommit if we are going to use native upserts
autocommit = (
self.engine.can_native_upsert
and table not in self._unsafe_to_upsert_tables
)
return await self.runInteraction(
desc,
self.simple_upsert_txn,
@ -901,6 +907,7 @@ class DatabasePool:
values,
insertion_values,
lock=lock,
db_autocommit=autocommit,
)
except self.engine.module.IntegrityError as e:
attempts += 1
@ -1063,6 +1070,43 @@ class DatabasePool:
)
txn.execute(sql, list(allvalues.values()))
async def simple_upsert_many(
self,
table: str,
key_names: Collection[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
desc: str,
) -> None:
"""
Upsert, many times.
Args:
table: The table to upsert into
key_names: The key column names.
key_values: A list of each row's key column values.
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
"""
# We can autocommit if we are going to use native upserts
autocommit = (
self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
)
return await self.runInteraction(
desc,
self.simple_upsert_many_txn,
table,
key_names,
key_values,
value_names,
value_values,
db_autocommit=autocommit,
)
def simple_upsert_many_txn(
self,
txn: LoggingTransaction,
@ -1214,7 +1258,13 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
"""
return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
desc,
self.simple_select_one_txn,
table,
keyvalues,
retcols,
allow_none,
db_autocommit=True,
)
@overload
@ -1265,6 +1315,7 @@ class DatabasePool:
keyvalues,
retcol,
allow_none=allow_none,
db_autocommit=True,
)
@overload
@ -1346,7 +1397,12 @@ class DatabasePool:
Results in a list
"""
return await self.runInteraction(
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
desc,
self.simple_select_onecol_txn,
table,
keyvalues,
retcol,
db_autocommit=True,
)
async def simple_select_list(
@ -1371,7 +1427,12 @@ class DatabasePool:
A list of dictionaries.
"""
return await self.runInteraction(
desc, self.simple_select_list_txn, table, keyvalues, retcols
desc,
self.simple_select_list_txn,
table,
keyvalues,
retcols,
db_autocommit=True,
)
@classmethod
@ -1450,6 +1511,7 @@ class DatabasePool:
chunk,
keyvalues,
retcols,
db_autocommit=True,
)
results.extend(rows)
@ -1548,7 +1610,12 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
"""
await self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
desc,
self.simple_update_one_txn,
table,
keyvalues,
updatevalues,
db_autocommit=True,
)
@classmethod
@ -1607,7 +1674,9 @@ class DatabasePool:
keyvalues: dict of column names and values to select the row with
desc: description of the transaction, for logging and metrics
"""
await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
await self.runInteraction(
desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
)
@staticmethod
def simple_delete_one_txn(
@ -1646,7 +1715,9 @@ class DatabasePool:
Returns:
The number of deleted rows.
"""
return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
return await self.runInteraction(
desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True
)
@staticmethod
def simple_delete_txn(
@ -1694,7 +1765,13 @@ class DatabasePool:
Number rows deleted
"""
return await self.runInteraction(
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
desc,
self.simple_delete_many_txn,
table,
column,
iterable,
keyvalues,
db_autocommit=True,
)
@staticmethod
@ -1860,7 +1937,13 @@ class DatabasePool:
"""
return await self.runInteraction(
desc, self.simple_search_list_txn, table, term, col, retcols
desc,
self.simple_search_list_txn,
table,
term,
col,
retcols,
db_autocommit=True,
)
@classmethod

View file

@ -15,12 +15,15 @@
# limitations under the License.
import logging
import re
from typing import List
from synapse.appservice import AppServiceTransaction
from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.types import JsonDict
from synapse.util import json_encoder
logger = logging.getLogger(__name__)
@ -172,15 +175,23 @@ class ApplicationServiceTransactionWorkerStore(
"application_services_state", {"as_id": service.id}, {"state": state}
)
async def create_appservice_txn(self, service, events):
async def create_appservice_txn(
self,
service: ApplicationService,
events: List[EventBase],
ephemeral: List[JsonDict],
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events.
with the given list of events. Ephemeral events are NOT persisted to the
database and are not resent if a transaction is retried.
Args:
service(ApplicationService): The service who the transaction is for.
events(list<Event>): A list of events to put in the transaction.
service: The service who the transaction is for.
events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction.
Returns:
AppServiceTransaction: A new transaction.
A new transaction.
"""
def _create_appservice_txn(txn):
@ -207,7 +218,9 @@ class ApplicationServiceTransactionWorkerStore(
"VALUES(?,?,?)",
(service.id, new_txn_id, event_ids),
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
return AppServiceTransaction(
service=service, id=new_txn_id, events=events, ephemeral=ephemeral
)
return await self.db_pool.runInteraction(
"create_appservice_txn", _create_appservice_txn
@ -296,7 +309,9 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
return AppServiceTransaction(
service=service, id=entry["txn_id"], events=events, ephemeral=[]
)
def _get_last_txn(self, txn, service_id):
txn.execute(
@ -320,7 +335,7 @@ class ApplicationServiceTransactionWorkerStore(
)
async def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets"""
"""Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn):
sql = (
@ -351,6 +366,39 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, events
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
def get_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type
txn.execute(
"SELECT ? FROM application_services_state WHERE as_id=?",
(stream_id_type, service.id,),
)
last_txn_id = txn.fetchone()
if last_txn_id is None or last_txn_id[0] is None: # no row exists
return 0
else:
return int(last_txn_id[0])
return await self.db_pool.runInteraction(
"get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
)
async def set_type_stream_id_for_appservice(
self, service: ApplicationService, type: str, pos: int
) -> None:
def set_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type
txn.execute(
"UPDATE ? SET device_list_stream_id = ? WHERE as_id=?",
(stream_id_type, pos, service.id),
)
await self.db_pool.runInteraction(
"set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
)
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
# This is currently empty due to there not being any AS storage functions

View file

@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.util.caches.descriptors import Cache
from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__)
@ -410,7 +410,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = Cache(
self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000
)

View file

@ -34,7 +34,8 @@ from synapse.storage.database import (
)
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import Cache, cached, cachedList
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@ -1004,7 +1005,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
self.device_id_exists_cache = DeferredCache(
name="device_id_exists", keylen=2, max_entries=10000
)

View file

@ -42,7 +42,8 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@ -145,7 +146,7 @@ class EventsWorkerStore(SQLBaseStore):
self._cleanup_old_transaction_ids,
)
self._get_event_cache = Cache(
self._get_event_cache = DeferredCache(
"*getEvent*",
keylen=3,
max_entries=hs.config.caches.event_cache_size,

View file

@ -122,9 +122,7 @@ class KeyStore(SQLBaseStore):
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
await self.db_pool.runInteraction(
"store_server_verify_keys",
self.db_pool.simple_upsert_many_txn,
await self.db_pool.simple_upsert_many(
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
@ -135,6 +133,7 @@ class KeyStore(SQLBaseStore):
"verify_key",
),
value_values=value_values,
desc="store_server_verify_keys",
)
invalidate = self._get_server_verify_key.invalidate

View file

@ -281,9 +281,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
a_day_in_milliseconds = 24 * 60 * 60 * 1000
now = self._clock.time_msec()
# A note on user_agent. Technically a given device can have multiple
# user agents, so we need to decide which one to pick. We could have handled this
# in number of ways, but given that we don't _that_ much have gone for MAX()
# For more details of the other options considered see
# https://github.com/matrix-org/synapse/pull/8503#discussion_r502306111
sql = """
INSERT INTO user_daily_visits (user_id, device_id, timestamp)
SELECT u.user_id, u.device_id, ?
INSERT INTO user_daily_visits (user_id, device_id, timestamp, user_agent)
SELECT u.user_id, u.device_id, ?, MAX(u.user_agent)
FROM user_ips AS u
LEFT JOIN (
SELECT user_id, device_id, timestamp FROM user_daily_visits
@ -294,7 +299,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
WHERE last_seen > ? AND last_seen <= ?
AND udv.timestamp IS NULL AND users.is_guest=0
AND users.appservice_id IS NULL
GROUP BY u.user_id, u.device_id
GROUP BY u.user_id, u.device_id, u.user_agent
"""
# This means that the day has rolled over but there could still

View file

@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
@ -274,6 +275,60 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
}
return results
@cached(num_args=2,)
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]:
"""Get receipts for all rooms between two stream_ids.
Args:
to_key: Max stream id to fetch receipts upto.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
A dictionary of roomids to a list of receipts.
"""
def f(txn):
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
"""
txn.execute(sql, [from_key, to_key])
else:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id <= ?
"""
txn.execute(sql, [to_key])
return self.db_pool.cursor_to_dict(txn)
txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
)
results = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
{"type": "m.receipt", "room_id": row["room_id"], "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = db_to_json(row["data"])
return results
async def get_users_sent_receipts_between(
self, last_id: int, current_id: int
) -> List[str]:

View file

@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add new column to user_daily_visits to track user agent
ALTER TABLE user_daily_visits
ADD COLUMN user_agent TEXT;

View file

@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE application_services_state
ADD COLUMN read_receipt_stream_id INT,
ADD COLUMN presence_stream_id INT;

View file

@ -208,42 +208,56 @@ class TransactionStore(TransactionWorkerStore):
"""
self._destination_retry_cache.pop(destination, None)
return await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
failure_ts,
retry_last_ts,
retry_interval,
)
if self.database_engine.can_native_upsert:
return await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings_native,
destination,
failure_ts,
retry_last_ts,
retry_interval,
db_autocommit=True, # Safe as its a single upsert
)
else:
return await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings_emulated,
destination,
failure_ts,
retry_last_ts,
retry_interval,
)
def _set_destination_retry_timings(
def _set_destination_retry_timings_native(
self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
assert self.database_engine.can_native_upsert
if self.database_engine.can_native_upsert:
# Upsert retry time interval if retry_interval is zero (i.e. we're
# resetting it) or greater than the existing retry interval.
# Upsert retry time interval if retry_interval is zero (i.e. we're
# resetting it) or greater than the existing retry interval.
#
# WARNING: This is executed in autocommit, so we shouldn't add any more
# SQL calls in here (without being very careful).
sql = """
INSERT INTO destinations (
destination, failure_ts, retry_last_ts, retry_interval
)
VALUES (?, ?, ?, ?)
ON CONFLICT (destination) DO UPDATE SET
failure_ts = EXCLUDED.failure_ts,
retry_last_ts = EXCLUDED.retry_last_ts,
retry_interval = EXCLUDED.retry_interval
WHERE
EXCLUDED.retry_interval = 0
OR destinations.retry_interval IS NULL
OR destinations.retry_interval < EXCLUDED.retry_interval
"""
sql = """
INSERT INTO destinations (
destination, failure_ts, retry_last_ts, retry_interval
)
VALUES (?, ?, ?, ?)
ON CONFLICT (destination) DO UPDATE SET
failure_ts = EXCLUDED.failure_ts,
retry_last_ts = EXCLUDED.retry_last_ts,
retry_interval = EXCLUDED.retry_interval
WHERE
EXCLUDED.retry_interval = 0
OR destinations.retry_interval IS NULL
OR destinations.retry_interval < EXCLUDED.retry_interval
"""
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
return
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
def _set_destination_retry_timings_emulated(
self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us

View file

@ -480,21 +480,16 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_id_tuples: iterable of 2-tuple of user IDs.
"""
def _add_users_who_share_room_txn(txn):
self.db_pool.simple_upsert_many_txn(
txn,
table="users_who_share_private_rooms",
key_names=["user_id", "other_user_id", "room_id"],
key_values=[
(user_id, other_user_id, room_id)
for user_id, other_user_id in user_id_tuples
],
value_names=(),
value_values=None,
)
await self.db_pool.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
await self.db_pool.simple_upsert_many(
table="users_who_share_private_rooms",
key_names=["user_id", "other_user_id", "room_id"],
key_values=[
(user_id, other_user_id, room_id)
for user_id, other_user_id in user_id_tuples
],
value_names=(),
value_values=None,
desc="add_users_who_share_room",
)
async def add_users_in_public_rooms(
@ -508,19 +503,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_ids
"""
def _add_users_in_public_rooms_txn(txn):
self.db_pool.simple_upsert_many_txn(
txn,
table="users_in_public_rooms",
key_names=["user_id", "room_id"],
key_values=[(user_id, room_id) for user_id in user_ids],
value_names=(),
value_values=None,
)
await self.db_pool.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
await self.db_pool.simple_upsert_many(
table="users_in_public_rooms",
key_names=["user_id", "room_id"],
key_values=[(user_id, room_id) for user_id in user_ids],
value_names=(),
value_values=None,
desc="add_users_in_public_rooms",
)
async def delete_all_from_user_dir(self) -> None:

View file

@ -618,14 +618,7 @@ class _MultiWriterCtxManager:
db_autocommit=True,
)
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
with self.id_gen._lock:
assert max(self.id_gen._current_positions.values(), default=0) < min(
self.stream_ids
)
self.id_gen._unfinished_ids.update(self.stream_ids)
if self.multiple_ids is None:

View file

@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import threading
from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast
from prometheus_client import Gauge
from twisted.internet import defer
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
["name"],
)
KT = TypeVar("KT")
VT = TypeVar("VT")
class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()
class DeferredCache(Generic[KT, VT]):
"""Wraps an LruCache, adding support for Deferred results.
It expects that each entry added with set() will be a Deferred; likewise get()
may return an ObservableDeferred.
"""
__slots__ = (
"cache",
"name",
"keylen",
"thread",
"metrics",
"_pending_deferred_cache",
)
def __init__(
self,
name: str,
max_entries: int = 1000,
keylen: int = 1,
tree: bool = False,
iterable: bool = False,
apply_cache_factor_from_config: bool = True,
):
"""
Args:
name: The name of the cache
max_entries: Maximum amount of entries that the cache will hold
keylen: The length of the tuple used as the cache key. Ignored unless
`tree` is True.
tree: Use a TreeCache instead of a dict as the underlying cache type
iterable: If True, count each item in the cached object as an entry,
rather than each cached object
apply_cache_factor_from_config: Whether cache factors specified in the
config file affect `max_entries`
"""
cache_type = TreeCache if tree else dict
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache = (
cache_type()
) # type: MutableMapping[KT, CacheEntry]
# cache is used for completed results and maps to the result itself, rather than
# a Deferred.
self.cache = LruCache(
max_size=max_entries,
keylen=keylen,
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
apply_cache_factor_from_config=apply_cache_factor_from_config,
)
self.name = name
self.keylen = keylen
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
self.cache,
collect_callback=self._metrics_collection_callback,
)
@property
def max_entries(self):
return self.cache.max_size
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
def _metrics_collection_callback(self):
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(
self,
key: KT,
default=_Sentinel.sentinel,
callback: Optional[Callable[[], None]] = None,
update_metrics: bool = True,
):
"""Looks the key up in the caches.
Args:
key(tuple)
default: What is returned if key is not in the caches. If not
specified then function throws KeyError instead
callback(fn): Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either an ObservableDeferred or the result itself
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
if update_metrics:
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
if val is not _Sentinel.sentinel:
self.metrics.inc_hits()
return val
if update_metrics:
self.metrics.inc_misses()
if default is _Sentinel.sentinel:
raise KeyError()
else:
return default
def set(
self,
key: KT,
value: defer.Deferred,
callback: Optional[Callable[[], None]] = None,
) -> ObservableDeferred:
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
def eb(_fail):
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
self.check_thread()
self.cache.pop(key, None)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
entry.invalidate()
def invalidate_many(self, key: KT):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
def invalidate_all(self):
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
entry.invalidate()
self._pending_deferred_cache.clear()
class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
def __init__(
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()

View file

@ -13,25 +13,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import logging
import threading
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from weakref import WeakValueDictionary
from prometheus_client import Gauge
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from . import register_cache
from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__)
@ -55,239 +48,6 @@ class _CachedFunction(Generic[F]):
__call__ = None # type: F
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
["name"],
)
_CacheSentinel = object()
class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
def __init__(self, deferred, callbacks):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
class Cache:
__slots__ = (
"cache",
"name",
"keylen",
"thread",
"metrics",
"_pending_deferred_cache",
)
def __init__(
self,
name: str,
max_entries: int = 1000,
keylen: int = 1,
tree: bool = False,
iterable: bool = False,
apply_cache_factor_from_config: bool = True,
):
"""
Args:
name: The name of the cache
max_entries: Maximum amount of entries that the cache will hold
keylen: The length of the tuple used as the cache key
tree: Use a TreeCache instead of a dict as the underlying cache type
iterable: If True, count each item in the cached object as an entry,
rather than each cached object
apply_cache_factor_from_config: Whether cache factors specified in the
config file affect `max_entries`
Returns:
Cache
"""
cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
self.cache = LruCache(
max_size=max_entries,
keylen=keylen,
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
apply_cache_factor_from_config=apply_cache_factor_from_config,
)
self.name = name
self.keylen = keylen
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
self.cache,
collect_callback=self._metrics_collection_callback,
)
@property
def max_entries(self):
return self.cache.max_size
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
def _metrics_collection_callback(self):
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
"""Looks the key up in the caches.
Args:
key(tuple)
default: What is returned if key is not in the caches. If not
specified then function throws KeyError instead
callback(fn): Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
val.callbacks.update(callbacks)
if update_metrics:
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
if update_metrics:
self.metrics.inc_misses()
if default is _CacheSentinel:
raise KeyError()
else:
return default
def set(self, key, value, callback=None):
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
def eb(_fail):
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
self.check_thread()
self.cache.pop(key, None)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
entry.invalidate()
def invalidate_many(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
def invalidate_all(self):
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
entry.invalidate()
self._pending_deferred_cache.clear()
class _CacheDescriptorBase:
def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
self.orig = orig
@ -390,13 +150,13 @@ class CacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable
def __get__(self, obj, owner):
cache = Cache(
cache = DeferredCache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
)
) # type: DeferredCache[Tuple, Any]
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
@ -640,9 +400,9 @@ class _CacheContext:
_cache_context_objects = (
WeakValueDictionary()
) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None
def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
self._cache = cache
self._cache_key = cache_key
@ -651,7 +411,9 @@ class _CacheContext:
self._cache.invalidate(self._cache_key)
@classmethod
def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext
def get_instance(
cls, cache, cache_key
): # type: (DeferredCache, CacheKey) -> _CacheContext
"""Returns an instance constructed with the given arguments.
A new instance is only created if none already exists.

View file

@ -64,7 +64,8 @@ class LruCache:
Args:
max_size: The maximum amount of entries the cache can hold
keylen: The length of the tuple used as the cache key
keylen: The length of the tuple used as the cache key. Ignored unless
cache_type is `TreeCache`.
cache_type (type):
type of underlying cache to be used. Typically one of dict

View file

@ -34,7 +34,7 @@ class TTLCache:
self._data = {}
# the _CacheEntries, sorted by expiry time
self._expiry_list = SortedList()
self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
self._timer = timer

View file

@ -22,7 +22,7 @@ class FrontendProxyTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
http_client=None, homeserverToUse=GenericWorkerServer
http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs

View file

@ -26,7 +26,7 @@ from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
http_client=None, homeserverToUse=GenericWorkerServer
http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
http_client=None, homeserverToUse=SynapseHomeServer
http_client=None, homeserver_to_use=SynapseHomeServer
)
return hs

View file

@ -60,7 +60,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved
service=service, events=events, ephemeral=[] # txn made and saved
)
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@ -81,7 +81,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved
service=service, events=events, ephemeral=[] # txn made and saved
)
self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed
@ -106,7 +106,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events
service=service, events=events, ephemeral=[]
)
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
@ -202,26 +202,28 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
# Expect the event to be sent immediately.
service = Mock(id=4)
event = Mock()
self.queuer.enqueue(service, event)
self.txn_ctrl.send.assert_called_once_with(service, [event])
self.queuer.enqueue_event(service, event)
self.txn_ctrl.send.assert_called_once_with(service, [event], [])
def test_send_single_event_with_queue(self):
d = defer.Deferred()
self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d))
self.txn_ctrl.send = Mock(
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
service = Mock(id=4)
event = Mock(event_id="first")
event2 = Mock(event_id="second")
event3 = Mock(event_id="third")
# Send an event and don't resolve it just yet.
self.queuer.enqueue(service, event)
self.queuer.enqueue_event(service, event)
# Send more events: expect send() to NOT be called multiple times.
self.queuer.enqueue(service, event2)
self.queuer.enqueue(service, event3)
self.txn_ctrl.send.assert_called_with(service, [event])
self.queuer.enqueue_event(service, event2)
self.queuer.enqueue_event(service, event3)
self.txn_ctrl.send.assert_called_with(service, [event], [])
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [event2, event3])
self.txn_ctrl.send.assert_called_with(service, [event2, event3], [])
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self):
@ -239,21 +241,58 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
send_return_list = [srv_1_defer, srv_2_defer]
def do_send(x, y):
def do_send(x, y, z):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
# send events for different ASes and make sure they are sent
self.queuer.enqueue(srv1, srv_1_event)
self.queuer.enqueue(srv1, srv_1_event2)
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event])
self.queuer.enqueue(srv2, srv_2_event)
self.queuer.enqueue(srv2, srv_2_event2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event])
self.queuer.enqueue_event(srv1, srv_1_event)
self.queuer.enqueue_event(srv1, srv_1_event2)
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [])
self.queuer.enqueue_event(srv2, srv_2_event)
self.queuer.enqueue_event(srv2, srv_2_event2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [])
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2])
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_single_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
event_list = [Mock(name="event")]
self.queuer.enqueue_ephemeral(service, event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
def test_send_multiple_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
self.queuer.enqueue_ephemeral(service, event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
def test_send_single_ephemeral_with_queue(self):
d = defer.Deferred()
self.txn_ctrl.send = Mock(
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
service = Mock(id=4)
event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")]
event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")]
# Send an event and don't resolve it just yet.
self.queuer.enqueue_ephemeral(service, event_list_1)
# Send more events: expect send() to NOT be called multiple times.
self.queuer.enqueue_ephemeral(service, event_list_2)
self.queuer.enqueue_ephemeral(service, event_list_3)
self.txn_ctrl.send.assert_called_with(service, [], event_list_1)
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send
d.callback(service)
# Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
self.assertEquals(2, self.txn_ctrl.send.call_count)

View file

@ -66,7 +66,6 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
"sender": self.requester.user.to_string(),
"content": {"msgtype": "m.text", "body": random_string(5)},
},
token_id=self.token_id,
txn_id=txn_id,
)
)

View file

@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.store.get_latest_event_ids_in_room(room_id)
)
event = self.get_success(builder.build(prev_event_ids))
event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(self.federation_handler.on_receive_pdu(hostname, event))

View file

@ -59,7 +59,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
http_client=None,
homeserverToUse=GenericWorkerServer,
homeserver_to_use=GenericWorkerServer,
config=self._get_worker_hs_config(),
reactor=self.reactor,
)
@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config.update(extra_config)
worker_hs = self.setup_test_homeserver(
homeserverToUse=GenericWorkerServer,
homeserver_to_use=GenericWorkerServer,
config=config,
reactor=self.reactor,
**kwargs

View file

@ -31,7 +31,7 @@ class FederationAckTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
return hs

View file

@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
}
builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(builder.build(prev_event_ids))
join_event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate()

View file

@ -14,8 +14,12 @@
# limitations under the License.
import logging
from mock import patch
from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
@ -36,6 +40,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
sync.register_servlets,
]
def prepare(self, reactor, clock, hs):
@ -43,6 +48,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
self.room_creator = self.hs.get_room_creation_handler()
self.store = hs.get_datastore()
def default_config(self):
conf = super().default_config()
conf["redis"] = {"enabled": "true"}
@ -53,6 +61,29 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
}
return conf
def _create_room(self, room_id: str, user_id: str, tok: str):
"""Create a room with given room_id
"""
# We control the room ID generation by patching out the
# `_generate_room_id` method
async def generate_room(
creator_id: str, is_public: bool, room_version: RoomVersion
):
await self.store.store_room(
room_id=room_id,
room_creator_user_id=creator_id,
is_public=is_public,
room_version=room_version,
)
return room_id
with patch(
"synapse.handlers.room.RoomCreationHandler._generate_room_id"
) as mock:
mock.side_effect = generate_room
self.helper.create_room_as(user_id, tok=tok)
def test_basic(self):
"""Simple test to ensure that multiple rooms can be created and joined,
and that different rooms get handled by different instances.
@ -100,3 +131,189 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(persisted_on_1)
self.assertTrue(persisted_on_2)
def test_vector_clock_token(self):
"""Tests that using a stream token with a vector clock component works
correctly with basic /sync and /messages usage.
"""
self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "worker1"},
)
worker_hs2 = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "worker2"},
)
sync_hs = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "sync"},
)
# Specially selected room IDs that get persisted on different workers.
room_id1 = "!foo:test"
room_id2 = "!baz:test"
self.assertEqual(
self.hs.config.worker.events_shard_config.get_instance(room_id1), "worker1"
)
self.assertEqual(
self.hs.config.worker.events_shard_config.get_instance(room_id2), "worker2"
)
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
store = self.hs.get_datastore()
# Create two room on the different workers.
self._create_room(room_id1, user_id, access_token)
self._create_room(room_id2, user_id, access_token)
# The other user joins
self.helper.join(
room=room_id1, user=self.other_user_id, tok=self.other_access_token
)
self.helper.join(
room=room_id2, user=self.other_user_id, tok=self.other_access_token
)
# Do an initial sync so that we're up to date.
request, channel = self.make_request("GET", "/sync", access_token=access_token)
self.render_on_worker(sync_hs, request)
next_batch = channel.json_body["next_batch"]
# We now gut wrench into the events stream MultiWriterIdGenerator on
# worker2 to mimic it getting stuck persisting an event. This ensures
# that when we send an event on worker1 we end up in a state where
# worker2 events stream position lags that on worker1, resulting in a
# RoomStreamToken with a non-empty instance map component.
#
# Worker2's event stream position will not advance until we call
# __aexit__ again.
actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
self.get_success(actx.__aenter__())
response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)
first_event_in_room1 = response["event_id"]
# Assert that the current stream token has an instance map component, as
# we are trying to test vector clock tokens.
room_stream_token = store.get_room_max_token()
self.assertNotEqual(len(room_stream_token.instance_map), 0)
# Check that syncing still gets the new event, despite the gap in the
# stream IDs.
request, channel = self.make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token
)
self.render_on_worker(sync_hs, request)
# We should only see the new event and nothing else
self.assertIn(room_id1, channel.json_body["rooms"]["join"])
self.assertNotIn(room_id2, channel.json_body["rooms"]["join"])
events = channel.json_body["rooms"]["join"][room_id1]["timeline"]["events"]
self.assertListEqual(
[first_event_in_room1], [event["event_id"] for event in events]
)
# Get the next batch and makes sure its a vector clock style token.
vector_clock_token = channel.json_body["next_batch"]
self.assertTrue(vector_clock_token.startswith("m"))
# Now that we've got a vector clock token we finish the fake persisting
# an event we started above.
self.get_success(actx.__aexit__(None, None, None))
# Now try and send an event to the other rooom so that we can test that
# the vector clock style token works as a `since` token.
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"]
request, channel = self.make_request(
"GET",
"/sync?since={}".format(vector_clock_token),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
self.assertNotIn(room_id1, channel.json_body["rooms"]["join"])
self.assertIn(room_id2, channel.json_body["rooms"]["join"])
events = channel.json_body["rooms"]["join"][room_id2]["timeline"]["events"]
self.assertListEqual(
[first_event_in_room2], [event["event_id"] for event in events]
)
next_batch = channel.json_body["next_batch"]
# We also want to test that the vector clock style token works with
# pagination. We do this by sending a couple of new events into the room
# and syncing again to get a prev_batch token for each room, then
# paginating from there back to the vector clock token.
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
request, channel = self.make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token
)
self.render_on_worker(sync_hs, request)
prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][
"prev_batch"
]
prev_batch2 = channel.json_body["rooms"]["join"][room_id2]["timeline"][
"prev_batch"
]
# Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion.
request, channel = self.make_request(
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id1, prev_batch1, vector_clock_token
),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"])
# Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken.
request, channel = self.make_request(
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id2, prev_batch2, vector_clock_token
),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
self.assertEqual(len(channel.json_body["chunk"]), 1)
self.assertEqual(
channel.json_body["chunk"][0]["event_id"], first_event_in_room2
)
# Paginating forwards should give the same results
request, channel = self.make_request(
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id1, vector_clock_token, prev_batch1
),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"])
request, channel = self.make_request(
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id2, vector_clock_token, prev_batch2,
),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
self.assertEqual(len(channel.json_body["chunk"]), 1)
self.assertEqual(
channel.json_body["chunk"][0]["event_id"], first_event_in_room2
)

View file

@ -114,16 +114,36 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_modify_event(self):
"""Tests that the module can successfully tweak an event before it is persisted.
"""
# first patch the event checker so that it will modify the event
def test_cannot_modify_event(self):
"""cannot accidentally modify an event before it is persisted"""
# first patch the event checker so that it will try to modify the event
async def check(ev: EventBase, state):
ev.content = {"x": "y"}
return True
current_rules_module().check_event_allowed = check
# now send the event
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
{"x": "x"},
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.result["code"], b"500", channel.result)
def test_modify_event(self):
"""The module can return a modified version of the event"""
# first patch the event checker so that it will modify the event
async def check(ev: EventBase, state):
d = ev.get_dict()
d["content"] = {"x": "y"}
return d
current_rules_module().check_event_allowed = check
# now send the event
request, channel = self.make_request(
"PUT",

View file

@ -20,82 +20,11 @@ from mock import Mock
from twisted.internet import defer
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.caches.descriptors import cached
from tests import unittest
class CacheTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.cache = Cache("test")
def test_empty(self):
failed = False
try:
self.cache.get("foo")
except KeyError:
failed = True
self.assertTrue(failed)
def test_hit(self):
self.cache.prefill("foo", 123)
self.assertEquals(self.cache.get("foo"), 123)
def test_invalidate(self):
self.cache.prefill(("foo",), 123)
self.cache.invalidate(("foo",))
failed = False
try:
self.cache.get(("foo",))
except KeyError:
failed = True
self.assertTrue(failed)
def test_eviction(self):
cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
cache.prefill(3, "three") # 1 will be evicted
failed = False
try:
cache.get(1)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(2)
cache.get(3)
def test_eviction_lru(self):
cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
# Now access 1 again, thus causing 2 to be least-recently used
cache.get(1)
cache.prefill(3, "three")
failed = False
try:
cache.get(2)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(1)
cache.get(3)
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):

View file

@ -244,7 +244,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events)
self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(service.id, 9644, events)
yield self._insert_txn(service.id, 9645, events)
txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events)
self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
@ -270,7 +270,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643)
txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events)
self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
@ -293,7 +293,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(self.as_list[3]["id"], 9643, events)
txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events)
self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)

View file

@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._event_id = event_id
@defer.inlineCallbacks
def build(self, prev_event_ids):
def build(self, prev_event_ids, auth_event_ids):
built_event = yield defer.ensureDeferred(
self._base_builder.build(prev_event_ids)
self._base_builder.build(prev_event_ids, auth_event_ids)
)
built_event._event_id = self._event_id

View file

@ -15,7 +15,7 @@
# limitations under the License.
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
from synapse.util.caches.descriptors import Cache
from synapse.util.caches.deferred_cache import DeferredCache
from tests import unittest
@ -138,7 +138,7 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
Caches produce metrics reflecting their state when scraped.
"""
CACHE_NAME = "cache_metrics_test_fgjkbdfg"
cache = Cache(CACHE_NAME, max_entries=777)
cache = DeferredCache(CACHE_NAME, max_entries=777)
items = {
x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")

View file

@ -20,7 +20,7 @@ import hmac
import inspect
import logging
import time
from typing import Optional, Tuple, Type, TypeVar, Union
from typing import Optional, Tuple, Type, TypeVar, Union, overload
from mock import Mock, patch
@ -364,6 +364,36 @@ class HomeserverTestCase(TestCase):
Function to optionally be overridden in subclasses.
"""
# Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest
# when the `request` arg isn't given, so we define an explicit override to
# cover that case.
@overload
def make_request(
self,
method: Union[bytes, str],
path: Union[bytes, str],
content: Union[bytes, dict] = b"",
access_token: Optional[str] = None,
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
) -> Tuple[SynapseRequest, FakeChannel]:
...
@overload
def make_request(
self,
method: Union[bytes, str],
path: Union[bytes, str],
content: Union[bytes, dict] = b"",
access_token: Optional[str] = None,
request: Type[T] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
) -> Tuple[T, FakeChannel]:
...
def make_request(
self,
method: Union[bytes, str],

View file

@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from twisted.internet import defer
from synapse.util.caches.deferred_cache import DeferredCache
class DeferredCacheTestCase(unittest.TestCase):
def test_empty(self):
cache = DeferredCache("test")
failed = False
try:
cache.get("foo")
except KeyError:
failed = True
self.assertTrue(failed)
def test_hit(self):
cache = DeferredCache("test")
cache.prefill("foo", 123)
self.assertEquals(cache.get("foo"), 123)
def test_invalidate(self):
cache = DeferredCache("test")
cache.prefill(("foo",), 123)
cache.invalidate(("foo",))
failed = False
try:
cache.get(("foo",))
except KeyError:
failed = True
self.assertTrue(failed)
def test_invalidate_all(self):
cache = DeferredCache("testcache")
callback_record = [False, False]
def record_callback(idx):
callback_record[idx] = True
# add a couple of pending entries
d1 = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
cache.invalidate_all()
# lookup should return none
self.assertIsNone(cache.get("key1", None))
self.assertIsNone(cache.get("key2", None))
# both callbacks should have been callbacked
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
# letting the other lookup complete should do nothing
d1.callback("result1")
self.assertIsNone(cache.get("key1", None))
def test_eviction(self):
cache = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
cache.prefill(1, "one")
cache.prefill(2, "two")
cache.prefill(3, "three") # 1 will be evicted
failed = False
try:
cache.get(1)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(2)
cache.get(3)
def test_eviction_lru(self):
cache = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
cache.prefill(1, "one")
cache.prefill(2, "two")
# Now access 1 again, thus causing 2 to be least-recently used
cache.get(1)
cache.prefill(3, "three")
failed = False
try:
cache.get(2)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(1)
cache.get(3)

View file

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import partial
import mock
@ -42,49 +41,6 @@ def run_on_reactor():
return make_deferred_yieldable(d)
class CacheTestCase(unittest.TestCase):
def test_invalidate_all(self):
cache = descriptors.Cache("testcache")
callback_record = [False, False]
def record_callback(idx):
callback_record[idx] = True
# add a couple of pending entries
d1 = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
cache.invalidate_all()
# lookup should return none
self.assertIsNone(cache.get("key1", None))
self.assertIsNone(cache.get("key2", None))
# both callbacks should have been callbacked
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
# letting the other lookup complete should do nothing
d1.callback("result1")
self.assertIsNone(cache.get("key1", None))
class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):

View file

@ -21,6 +21,7 @@ import time
import uuid
import warnings
from inspect import getcallargs
from typing import Type
from urllib import parse as urlparse
from mock import Mock, patch
@ -194,8 +195,8 @@ def setup_test_homeserver(
name="test",
config=None,
reactor=None,
homeserverToUse=TestHomeServer,
**kargs
homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
@ -218,8 +219,8 @@ def setup_test_homeserver(
config.ldap_enabled = False
if "clock" not in kargs:
kargs["clock"] = MockClock()
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
@ -264,18 +265,20 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
hs = homeserverToUse(
name,
config=config,
version_string="Synapse/tests",
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
**kargs
hs = homeserver_to_use(
name, config=config, version_string="Synapse/tests", reactor=reactor,
)
# Install @cache_in_self attributes
for key, val in kwargs.items():
setattr(hs, key, val)
# Mock TLS
hs.tls_server_context_factory = Mock()
hs.tls_client_options_factory = Mock()
hs.setup()
if homeserverToUse.__name__ == "TestHomeServer":
if homeserver_to_use == TestHomeServer:
hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
@ -339,7 +342,7 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash
fed = kargs.get("resource_for_federation", None)
fed = kwargs.get("resource_for_federation", None)
if fed:
register_federation_servlets(hs, fed)