Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2020-10-30 12:08:09 +00:00
commit 1ff3bc332a
108 changed files with 2360 additions and 1540 deletions

View file

@ -46,7 +46,7 @@ locally. You'll need python 3.6 or later, and to install a number of tools:
```
# Install the dependencies
pip install -e ".[lint]"
pip install -e ".[lint,mypy]"
# Run the linter script
./scripts-dev/lint.sh
@ -63,7 +63,7 @@ run-time:
./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
```
You can also provided the `-d` option, which will lint the files that have been
You can also provide the `-d` option, which will lint the files that have been
changed since the last git commit. This will often be significantly faster than
linting the whole codebase.

View file

@ -57,7 +57,7 @@ light workloads.
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
- Python 3.5.2 or later, up to Python 3.8.
- Python 3.5.2 or later, up to Python 3.9.
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
Synapse is written in Python but some of the libraries it uses are written in

View file

@ -256,9 +256,9 @@ directory of your choice::
Synapse has a number of external dependencies that are easiest
to install using pip and a virtualenv::
virtualenv -p python3 env
source env/bin/activate
python -m pip install --no-use-pep517 -e ".[all]"
python3 -m venv ./env
source ./env/bin/activate
pip install -e ".[all,test]"
This will download and install all the needed dependencies into a virtual env.
@ -270,9 +270,9 @@ check that everything is installed as it should be::
This should end with a 'PASSED' result::
Ran 143 tests in 0.601s
Ran 1266 tests in 643.930s
PASSED (successes=143)
PASSED (skips=15, successes=1251)
Running the Integration Tests
=============================

View file

@ -75,6 +75,22 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.23.0
====================
Structured logging configuration breaking changes
-------------------------------------------------
This release deprecates use of the ``structured: true`` logging configuration for
structured logging. If your logging configuration contains ``structured: true``
then it should be modified based on the `structured logging documentation
<https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md>`_.
The ``structured`` and ``drains`` logging options are now deprecated and should
be replaced by standard logging configuration of ``handlers`` and ``formatters``.
A future release of Synapse will make using ``structured: true`` an error.
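For illustration only, a minimal new-style configuration expressed as a Python
``dictConfig`` dictionary (the YAML file form described in the structured logging
documentation is equivalent) might look like the following; the file path and
log levels are placeholders:

.. code:: python

   # Sketch of a handlers/formatters configuration equivalent to the old
   # ``structured: true`` + ``drains`` style; assumes Synapse is installed so
   # that synapse.logging.TerseJsonFormatter can be resolved.
   import logging.config

   logging.config.dictConfig({
       "version": 1,
       "formatters": {
           "structured": {"class": "synapse.logging.TerseJsonFormatter"},
       },
       "handlers": {
           "file": {
               "class": "logging.handlers.TimedRotatingFileHandler",
               "formatter": "structured",
               "filename": "/path/to/my/logs/homeserver.log",
               "when": "midnight",
               "backupCount": 3,
               "encoding": "utf8",
           },
       },
       "loggers": {
           "synapse": {"level": "INFO", "handlers": ["file"]},
       },
   })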
Upgrading to v1.22.0
====================

1
changelog.d/8559.misc Normal file
View file

@ -0,0 +1 @@
Optimise `/createRoom` with multiple invited users.

1
changelog.d/8595.misc Normal file
View file

@ -0,0 +1 @@
Implement and use an @lru_cache decorator.

1
changelog.d/8607.feature Normal file
View file

@ -0,0 +1 @@
Support generating structured logs via the standard logging configuration.

1
changelog.d/8610.feature Normal file
View file

@ -0,0 +1 @@
Add an admin API to allow server admins to list users' pushers. Contributed by @dklimpel.

1
changelog.d/8616.misc Normal file
View file

@ -0,0 +1 @@
Change schema to support access tokens belonging to one user but granting access to another.

1
changelog.d/8633.misc Normal file
View file

@ -0,0 +1 @@
Run `mypy` as part of the lint.sh script.

1
changelog.d/8655.misc Normal file
View file

@ -0,0 +1 @@
Add more type hints to the application services code.

1
changelog.d/8664.misc Normal file
View file

@ -0,0 +1 @@
Tell Black to format code for Python 3.5.

1
changelog.d/8665.doc Normal file
View file

@ -0,0 +1 @@
Note support for Python 3.9.

1
changelog.d/8666.doc Normal file
View file

@ -0,0 +1 @@
Minor updates to docs on running tests.

1
changelog.d/8667.doc Normal file
View file

@ -0,0 +1 @@
Interlink prometheus/grafana documentation.

1
changelog.d/8669.misc Normal file
View file

@ -0,0 +1 @@
Don't pull event from DB when handling replication traffic.

1
changelog.d/8671.misc Normal file
View file

@ -0,0 +1 @@
Abstract some invite-related code in preparation for landing knocking.

1
changelog.d/8676.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug where an appservice may not be forwarded events for a room it was recently invited to. Broken in v1.22.0.

1
changelog.d/8678.bugfix Normal file
View file

@ -0,0 +1 @@
Fix `Object of type frozendict is not JSON serializable` exceptions when using third-party event rules.

1
changelog.d/8679.misc Normal file
View file

@ -0,0 +1 @@
Clarify representation of events in logfiles.

1
changelog.d/8680.misc Normal file
View file

@ -0,0 +1 @@
Don't require `hiredis` package to be installed to run unit tests.

1
changelog.d/8682.bugfix Normal file
View file

@ -0,0 +1 @@
Fix exception during handling multiple concurrent requests for remote media when using multiple media repositories.

1
changelog.d/8684.misc Normal file
View file

@ -0,0 +1 @@
Fix typing info on cache call signature to accept `on_invalidate`.

1
changelog.d/8685.feature Normal file
View file

@ -0,0 +1 @@
Support generating structured logs via the standard logging configuration.

1
changelog.d/8688.misc Normal file
View file

@ -0,0 +1 @@
Abstract some invite-related code in preparation for landing knocking.

1
changelog.d/8689.feature Normal file
View file

@ -0,0 +1 @@
Add an admin API to allow server admins to list users' pushers. Contributed by @dklimpel.

1
changelog.d/8690.misc Normal file
View file

@ -0,0 +1 @@
Fail tests if they do not await coroutines.

View file

@ -3,4 +3,4 @@
0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
3. Set up additional recording rules
3. Set up required recording rules. https://github.com/matrix-org/synapse/tree/master/contrib/prometheus

View file

@ -611,3 +611,82 @@ The following parameters should be set in the URL:
- ``user_id`` - fully qualified: for example, ``@user:server.com``.
- ``device_id`` - The device to delete.
List all pushers
================
Gets information about all pushers for a specific ``user_id``.
The API is::
GET /_synapse/admin/v1/users/<user_id>/pushers
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
A response body like the following is returned:
.. code:: json
{
"pushers": [
{
"app_display_name":"HTTP Push Notifications",
"app_id":"m.http",
"data": {
"url":"example.com"
},
"device_display_name":"pushy push",
"kind":"http",
"lang":"None",
"profile_tag":"",
"pushkey":"a@example.com"
}
],
"total": 1
}
**Parameters**
The following parameters should be set in the URL:
- ``user_id`` - fully qualified: for example, ``@user:server.com``.
**Response**
The following fields are returned in the JSON response body:
- ``pushers`` - An array containing the current pushers for the user
- ``app_display_name`` - string - A string that will allow the user to identify
what application owns this pusher.
- ``app_id`` - string - This is a reverse-DNS style identifier for the application.
Max length, 64 chars.
- ``data`` - A dictionary of information for the pusher implementation itself.
- ``url`` - string - Required if ``kind`` is ``http``. The URL to use to send
notifications to.
- ``format`` - string - The format to use when sending notifications to the
Push Gateway.
- ``device_display_name`` - string - A string that will allow the user to identify
what device owns this pusher.
- ``profile_tag`` - string - This string determines which set of device specific rules
this pusher executes.
- ``kind`` - string - The kind of pusher. "http" is a pusher that sends HTTP pokes.
- ``lang`` - string - The preferred language for receiving notifications
(e.g. 'en' or 'en-US')
- ``pushkey`` - string - This is a unique identifier for this pusher.
Max length, 512 bytes.
- ``total`` - integer - Number of pushers.
See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_
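As an illustration, a hedged Python sketch of calling this endpoint with the
``requests`` library (the homeserver URL, user ID and admin token below are
placeholders):

.. code:: python

   # Sketch only: assumes a reachable homeserver and an access token that
   # belongs to a server admin.
   import requests

   HOMESERVER = "https://matrix.example.com"
   ADMIN_TOKEN = "<admin access token>"
   USER_ID = "@user:server.com"

   resp = requests.get(
       f"{HOMESERVER}/_synapse/admin/v1/users/{USER_ID}/pushers",
       headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
   )
   resp.raise_for_status()
   body = resp.json()

   print("total pushers:", body["total"])
   for pusher in body["pushers"]:
       print(pusher["kind"], pusher["app_display_name"], pusher["pushkey"])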

View file

@ -60,6 +60,8 @@
1. Restart Prometheus.
1. Consider using the [grafana dashboard](https://github.com/matrix-org/synapse/tree/master/contrib/grafana/) and required [recording rules](https://github.com/matrix-org/synapse/tree/master/contrib/prometheus/)
## Monitoring workers
To monitor a Synapse installation using

View file

@ -3,7 +3,11 @@
# This is a YAML file containing a standard Python logging configuration
# dictionary. See [1] for details on the valid settings.
#
# Synapse also supports structured logging for machine readable logs which can
# be ingested by ELK stacks. See [2] for details.
#
# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
version: 1

View file

@ -1,11 +1,116 @@
# Structured Logging
A structured logging system can be useful when your logs are destined for a machine to parse and process. By maintaining its machine-readable characteristics, it enables more efficient searching and aggregations when consumed by software such as the "ELK stack".
A structured logging system can be useful when your logs are destined for a
machine to parse and process. By maintaining its machine-readable characteristics,
it enables more efficient searching and aggregations when consumed by software
such as the "ELK stack".
Synapse's structured logging system is configured via the file that Synapse's `log_config` config option points to. The file must be YAML and contain `structured: true`. It must contain a list of "drains" (places where logs go to).
Synapse's structured logging system is configured via the file that Synapse's
`log_config` config option points to. The file should include a formatter which
uses the `synapse.logging.TerseJsonFormatter` class included with Synapse and a
handler which uses the above formatter.
There is also a `synapse.logging.JsonFormatter` option which does not include
a timestamp in the resulting JSON. This is useful if the log ingester adds its
own timestamp.
A structured logging configuration looks similar to the following:
```yaml
version: 1
formatters:
structured:
class: synapse.logging.TerseJsonFormatter
handlers:
file:
class: logging.handlers.TimedRotatingFileHandler
formatter: structured
filename: /path/to/my/logs/homeserver.log
when: midnight
backupCount: 3 # Does not include the current log file.
encoding: utf8
loggers:
synapse:
level: INFO
handlers: [file]
synapse.storage.SQL:
level: WARNING
```
The above logging config will set Synapse's default logging level to 'INFO',
with the SQL layer at 'WARNING', and will log to a file, stored as JSON.
It is also possible to configure Synapse to log to a remote endpoint by using the
`synapse.logging.RemoteHandler` class included with Synapse. It takes the
following arguments:
- `host`: Hostname or IP address of the log aggregator.
- `port`: Numerical port to contact on the host.
- `maximum_buffer`: (Optional, defaults to 1000) The maximum buffer size to allow.
A remote structured logging configuration looks similar to the following:
```yaml
version: 1
formatters:
structured:
class: synapse.logging.TerseJsonFormatter
handlers:
remote:
class: synapse.logging.RemoteHandler
formatter: structured
host: 10.1.2.3
port: 9999
loggers:
synapse:
level: INFO
handlers: [remote]
synapse.storage.SQL:
level: WARNING
```
The above logging config will set Synapse's default logging level to 'INFO',
with the SQL layer at 'WARNING', and will log JSON-formatted messages to a
remote endpoint at 10.1.2.3:9999.
## Upgrading from legacy structured logging configuration
Versions of Synapse prior to v1.23.0 included a custom structured logging
configuration which is deprecated. It used a `structured: true` flag and
configured `drains` instead of `handlers` and `formatters`.
Synapse currently automatically converts the old configuration to the new
configuration, but this will be removed in a future version of Synapse. The
following reference can be used to update your configuration. Based on the drain
`type`, we can pick a new handler:
1. For a type of `console`, `console_json`, or `console_json_terse`: a handler
with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout`
or `ext://sys.stderr` should be used.
2. For a type of `file` or `file_json`: a handler of `logging.FileHandler` with
a location of the file path should be used.
3. For a type of `network_json_terse`: a handler of `synapse.logging.RemoteHandler`
with the host and port should be used.
Then based on the drain `type` we can pick a new formatter:
1. For a type of `console` or `file` no formatter is necessary.
2. For a type of `console_json` or `file_json`: a formatter of
`synapse.logging.JsonFormatter` should be used.
3. For a type of `console_json_terse` or `network_json_terse`: a formatter of
`synapse.logging.TerseJsonFormatter` should be used.
Each new handler and formatter should be added to the logging configuration
and then assigned to either a logger or the root logger.
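As a rough summary of the reference above (illustrative only, not part of
Synapse itself), the drain-type-to-handler/formatter mapping can be written out as:

```python
# Lookup tables summarising the conversion reference above; purely
# illustrative, these names are not used by Synapse itself.
HANDLER_FOR_DRAIN_TYPE = {
    "console": "logging.StreamHandler",
    "console_json": "logging.StreamHandler",
    "console_json_terse": "logging.StreamHandler",
    "file": "logging.FileHandler",
    "file_json": "logging.FileHandler",
    "network_json_terse": "synapse.logging.RemoteHandler",
}

FORMATTER_FOR_DRAIN_TYPE = {
    "console": None,  # human-readable output, no formatter needed
    "file": None,
    "console_json": "synapse.logging.JsonFormatter",
    "file_json": "synapse.logging.JsonFormatter",
    "console_json_terse": "synapse.logging.TerseJsonFormatter",
    "network_json_terse": "synapse.logging.TerseJsonFormatter",
}
```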
An example legacy configuration:
```yaml
structured: true
@ -24,60 +129,33 @@ drains:
location: homeserver.log
```
The above logging config will set Synapse's default logging level to 'INFO', with the SQL layer at 'WARNING', and will have two logging drains (to the console and to a file, stored as JSON).
Would be converted into a new configuration:
## Drain Types
```yaml
version: 1
Drain types can be specified by the `type` key.
formatters:
json:
class: synapse.logging.JsonFormatter
### `console`
handlers:
console:
class: logging.StreamHandler
location: ext://sys.stdout
file:
class: logging.FileHandler
formatter: json
filename: homeserver.log
Outputs human-readable logs to the console.
loggers:
synapse:
level: INFO
handlers: [console, file]
synapse.storage.SQL:
level: WARNING
```
Arguments:
- `location`: Either `stdout` or `stderr`.
### `console_json`
Outputs machine-readable JSON logs to the console.
Arguments:
- `location`: Either `stdout` or `stderr`.
### `console_json_terse`
Outputs machine-readable JSON logs to the console, separated by newlines. This
format is not designed to be read and re-formatted into human-readable text, but
is optimal for a logging aggregation system.
Arguments:
- `location`: Either `stdout` or `stderr`.
### `file`
Outputs human-readable logs to a file.
Arguments:
- `location`: An absolute path to the file to log to.
### `file_json`
Outputs machine-readable logs to a file.
Arguments:
- `location`: An absolute path to the file to log to.
### `network_json_terse`
Delivers machine-readable JSON logs to a log aggregator over TCP. This is
compatible with LogStash's TCP input with the codec set to `json_lines`.
Arguments:
- `host`: Hostname or IP address of the log aggregator.
- `port`: Numerical port to contact on the host.
The new logging configuration is a bit more verbose, but significantly more
flexible. It allows for configurations that were not previously possible, such as
sending plain logs over the network or using different handlers for different
modules.

View file

@ -57,6 +57,7 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py,
@ -82,6 +83,9 @@ ignore_missing_imports = True
[mypy-zope]
ignore_missing_imports = True
[mypy-bcrypt]
ignore_missing_imports = True
[mypy-constantly]
ignore_missing_imports = True

View file

@ -35,7 +35,7 @@
showcontent = true
[tool.black]
target-version = ['py34']
target-version = ['py35']
exclude = '''
(

View file

@ -80,7 +80,7 @@ else
# then lint everything!
if [[ -z ${files+x} ]]; then
# Lint all source code files and directories
files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py")
files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
fi
fi
@ -94,3 +94,4 @@ isort "${files[@]}"
python3 -m black "${files[@]}"
./scripts-dev/config-lint.sh
flake8 "${files[@]}"
mypy

View file

@ -19,9 +19,10 @@ can crop up, e.g the cache descriptors.
from typing import Callable, Optional
from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType
from mypy.types import CallableType, NoneType
class SynapsePlugin(Plugin):
@ -40,8 +41,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
It already has *almost* the correct signature, except:
1. the `self` argument needs to be marked as "bound"; and
2. any `cache_context` argument should be removed.
1. the `self` argument needs to be marked as "bound";
2. any `cache_context` argument should be removed;
3. an optional keyword argument `on_invalidate` should be added.
"""
# First we mark this as a bound function signature.
@ -58,19 +60,33 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
context_arg_index = idx
break
arg_types = list(signature.arg_types)
arg_names = list(signature.arg_names)
arg_kinds = list(signature.arg_kinds)
if context_arg_index:
arg_types = list(signature.arg_types)
arg_types.pop(context_arg_index)
arg_names = list(signature.arg_names)
arg_names.pop(context_arg_index)
arg_kinds = list(signature.arg_kinds)
arg_kinds.pop(context_arg_index)
signature = signature.copy_modified(
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
)
# Third, we add an optional "on_invalidate" argument.
#
# This is a callable which accepts no input and returns nothing.
calltyp = CallableType(
arg_types=[],
arg_kinds=[],
arg_names=[],
ret_type=NoneType(),
fallback=ctx.api.named_generic_type("builtins.function", []),
)
arg_types.append(calltyp)
arg_names.append("on_invalidate")
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
signature = signature.copy_modified(
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
)
return signature
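For context, the call pattern this plugin change is meant to type-check looks
roughly like the following sketch (mirroring the `ApplicationService` change later
in this commit; the class and method names here are illustrative):

```python
# Sketch: a method decorated with cache_context=True receives a cache_context
# whose .invalidate callback is passed as `on_invalidate` to another cached
# method -- the optional keyword argument this plugin teaches mypy about.
from synapse.util.caches.descriptors import _CacheContext, cached


class ExampleChecker:
    @cached(num_args=1, cache_context=True)
    async def matches_room(
        self, room_id: str, store, cache_context: _CacheContext
    ) -> bool:
        member_list = await store.get_users_in_room(
            room_id, on_invalidate=cache_context.invalidate
        )
        return len(member_list) > 0
```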

View file

@ -131,6 +131,7 @@ setup(
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
],
scripts=["synctl"] + glob.glob("scripts/*"),
cmdclass={"test": TestCommand},

View file

@ -33,6 +33,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@ -190,10 +191,6 @@ class Auth:
user_id, app_service = await self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
user_id=user_id,
@ -203,31 +200,38 @@ class Auth:
device_id="dummy-device", # stubbed
)
return synapse.types.create_requester(user_id, app_service=app_service)
requester = synapse.types.create_requester(
user_id, app_service=app_service
)
request.requester = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("user_id", user_id)
opentracing.set_tag("appservice_id", app_service.id)
return requester
user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
shadow_banned = user_info["shadow_banned"]
token_id = user_info.token_id
is_guest = user_info.is_guest
shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
if await self.store.is_account_expired(user_id, self.clock.time_msec()):
if await self.store.is_account_expired(
user_info.user_id, self.clock.time_msec()
):
raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
# device_id may not be present if get_user_by_access_token has been
# stubbed out.
device_id = user_info.get("device_id")
device_id = user_info.device_id
if user and access_token and ip_addr:
if access_token and ip_addr:
await self.store.insert_client_ip(
user_id=user.to_string(),
user_id=user_info.token_owner,
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
@ -241,19 +245,23 @@ class Auth:
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
request.authenticated_entity = user.to_string()
opentracing.set_tag("authenticated_entity", user.to_string())
if device_id:
opentracing.set_tag("device_id", device_id)
return synapse.types.create_requester(
user,
requester = synapse.types.create_requester(
user_info.user_id,
token_id,
is_guest,
shadow_banned,
device_id,
app_service=app_service,
authenticated_entity=user_info.token_owner,
)
request.requester = requester
opentracing.set_tag("authenticated_entity", user_info.token_owner)
opentracing.set_tag("user_id", user_info.user_id)
if device_id:
opentracing.set_tag("device_id", device_id)
return requester
except KeyError:
raise MissingClientTokenError()
@ -284,7 +292,7 @@ class Auth:
async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
) -> dict:
) -> TokenLookupResult:
""" Validate access token and get user_id from it
Args:
@ -293,13 +301,7 @@ class Auth:
allow this
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
Returns:
dict that includes:
`user` (UserID)
`is_guest` (bool)
`shadow_banned` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises:
InvalidClientTokenError if a user by that token exists, but the token is
expired
@ -309,9 +311,9 @@ class Auth:
if rights == "access":
# first look in the database
r = await self._look_up_user_by_access_token(token)
r = await self.store.get_user_by_access_token(token)
if r:
valid_until_ms = r["valid_until_ms"]
valid_until_ms = r.valid_until_ms
if (
not allow_expired
and valid_until_ms is not None
@ -328,7 +330,6 @@ class Auth:
# otherwise it needs to be a valid macaroon
try:
user_id, guest = self._parse_and_validate_macaroon(token, rights)
user = UserID.from_string(user_id)
if rights == "access":
if not guest:
@ -354,23 +355,17 @@ class Auth:
raise InvalidClientTokenError(
"Guest access token used for regular user"
)
ret = {
"user": user,
"is_guest": True,
"shadow_banned": False,
"token_id": None,
ret = TokenLookupResult(
user_id=user_id,
is_guest=True,
# all guests get the same device id
"device_id": GUEST_DEVICE_ID,
}
device_id=GUEST_DEVICE_ID,
)
elif rights == "delete_pusher":
# We don't store these tokens in the database
ret = {
"user": user,
"is_guest": False,
"shadow_banned": False,
"token_id": None,
"device_id": None,
}
ret = TokenLookupResult(user_id=user_id, is_guest=False)
else:
raise RuntimeError("Unknown rights setting %s", rights)
return ret
@ -479,31 +474,15 @@ class Auth:
now = self.hs.get_clock().time_msec()
return now < expiry
async def _look_up_user_by_access_token(self, token):
ret = await self.store.get_user_by_access_token(token)
if not ret:
return None
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
# the fields.
user_info = {
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
"shadow_banned": ret.get("shadow_banned"),
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}
return user_info
def get_appservice_by_req(self, request):
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
if not service:
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.authenticated_entity = service.sender
request.requester = synapse.types.create_requester(
service.sender, app_service=service
)
return service
async def is_server_admin(self, user: UserID) -> bool:
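The `TokenLookupResult` class referenced above is defined in
`synapse/storage/databases/main/registration.py` and is not shown in this diff.
Judging from the attributes accessed here (`user_id`, `token_id`, `is_guest`,
`shadow_banned`, `device_id`, `token_owner`, `valid_until_ms`), it is roughly an
attrs container along the following lines (a sketch, not the canonical definition):

```python
# Sketch of the shape of TokenLookupResult, inferred from the attribute
# accesses in the hunk above; the real definition may differ in detail.
from typing import Optional

import attr


@attr.s(frozen=True, slots=True)
class TokenLookupResult:
    user_id = attr.ib(type=str)
    is_guest = attr.ib(type=bool, default=False)
    shadow_banned = attr.ib(type=bool, default=False)
    token_id = attr.ib(type=Optional[int], default=None)
    device_id = attr.ib(type=Optional[str], default=None)
    valid_until_ms = attr.ib(type=Optional[int], default=None)
    token_owner = attr.ib(type=str)

    @token_owner.default
    def _default_token_owner(self):
        # For ordinary tokens the owner is simply the user the token acts as;
        # a token that grants access to another user would set a different owner.
        return self.user_id
```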

View file

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Match, Optional
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
from synapse.appservice.api import ApplicationServiceApi
@ -52,11 +52,11 @@ class ApplicationService:
self,
token,
hostname,
id,
sender,
url=None,
namespaces=None,
hs_token=None,
sender=None,
id=None,
protocols=None,
rate_limited=True,
ip_range_whitelist=None,
@ -164,9 +164,9 @@ class ApplicationService:
does_match = await self.matches_user_in_member_list(event.room_id, store)
return does_match
@cached(num_args=1)
@cached(num_args=1, cache_context=True)
async def matches_user_in_member_list(
self, room_id: str, store: "DataStore"
self, room_id: str, store: "DataStore", cache_context: _CacheContext,
) -> bool:
"""Check if this service is interested a room based upon it's membership
@ -177,7 +177,9 @@ class ApplicationService:
Returns:
True if this service would like to know about this room.
"""
member_list = await store.get_users_in_room(room_id)
member_list = await store.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
# check joined member events
for user_id in member_list:

View file

@ -23,7 +23,6 @@ from string import Template
import yaml
from twisted.logger import (
ILogObserver,
LogBeginner,
STDLibLogObserver,
eventAsText,
@ -32,11 +31,9 @@ from twisted.logger import (
import synapse
from synapse.app import _base as appbase
from synapse.logging._structured import (
reload_structured_logging,
setup_structured_logging,
)
from synapse.logging._structured import setup_structured_logging
from synapse.logging.context import LoggingContextFilter
from synapse.logging.filter import MetadataFilter
from synapse.util.versionstring import get_version_string
from ._base import Config, ConfigError
@ -48,7 +45,11 @@ DEFAULT_LOG_CONFIG = Template(
# This is a YAML file containing a standard Python logging configuration
# dictionary. See [1] for details on the valid settings.
#
# Synapse also supports structured logging for machine readable logs which can
# be ingested by ELK stacks. See [2] for details.
#
# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
version: 1
@ -176,11 +177,11 @@ class LoggingConfig(Config):
log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None:
"""
Set up Python stdlib logging.
Set up Python standard library logging.
"""
if log_config is None:
if log_config_path is None:
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
" - %(message)s"
@ -196,7 +197,8 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
handler.setFormatter(formatter)
logger.addHandler(handler)
else:
logging.config.dictConfig(log_config)
# Load the logging configuration.
_load_logging_config(log_config_path)
# We add a log record factory that runs all messages through the
# LoggingContextFilter so that we get the context *at the time we log*
@ -204,12 +206,14 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
# filter options, but care must when using e.g. MemoryHandler to buffer
# writes.
log_filter = LoggingContextFilter(request="")
log_context_filter = LoggingContextFilter(request="")
log_metadata_filter = MetadataFilter({"server_name": config.server_name})
old_factory = logging.getLogRecordFactory()
def factory(*args, **kwargs):
record = old_factory(*args, **kwargs)
log_filter.filter(record)
log_context_filter.filter(record)
log_metadata_filter.filter(record)
return record
logging.setLogRecordFactory(factory)
@ -255,21 +259,40 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
if not config.no_redirect_stdio:
print("Redirected stdout/stderr to logs")
return observer
def _reload_stdlib_logging(*args, log_config=None):
logger = logging.getLogger("")
def _load_logging_config(log_config_path: str) -> None:
"""
Configure logging from a log config path.
"""
with open(log_config_path, "rb") as f:
log_config = yaml.safe_load(f.read())
if not log_config:
logger.warning("Reloaded a blank config?")
logging.warning("Loaded a blank logging config?")
# If the old structured logging configuration is being used, convert it to
# the new style configuration.
if "structured" in log_config and log_config.get("structured"):
log_config = setup_structured_logging(log_config)
logging.config.dictConfig(log_config)
def _reload_logging_config(log_config_path):
"""
Reload the log configuration from the file and apply it.
"""
# If no log config path was given, it cannot be reloaded.
if log_config_path is None:
return
_load_logging_config(log_config_path)
logging.info("Reloaded log config from %s due to SIGHUP", log_config_path)
def setup_logging(
hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
) -> ILogObserver:
) -> None:
"""
Set up the logging subsystem.
@ -282,41 +305,18 @@ def setup_logging(
logBeginner: The Twisted logBeginner to use.
Returns:
The "root" Twisted Logger observer, suitable for sending logs to from a
Logger instance.
"""
log_config = config.worker_log_config if use_worker_options else config.log_config
log_config_path = (
config.worker_log_config if use_worker_options else config.log_config
)
def read_config(*args, callback=None):
if log_config is None:
return None
# Perform one-time logging configuration.
_setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner)
# Add a SIGHUP handler to reload the logging configuration, if one is available.
appbase.register_sighup(_reload_logging_config, log_config_path)
with open(log_config, "rb") as f:
log_config_body = yaml.safe_load(f.read())
if callback:
callback(log_config=log_config_body)
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
return log_config_body
log_config_body = read_config()
if log_config_body and log_config_body.get("structured") is True:
logger = setup_structured_logging(
hs, config, log_config_body, logBeginner=logBeginner
)
appbase.register_sighup(read_config, callback=reload_structured_logging)
else:
logger = _setup_stdlib_logging(config, log_config_body, logBeginner=logBeginner)
appbase.register_sighup(read_config, callback=_reload_stdlib_logging)
# make sure that the first thing we log is a thing we can grep backwards
# for
# Log immediately so we can grep backwards.
logging.warning("***** STARTING SERVER *****")
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
logging.info("Instance name: %s", hs.get_instance_name())
return logger

View file

@ -368,7 +368,7 @@ class FrozenEvent(EventBase):
return self.__repr__()
def __repr__(self):
return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % (
self.get("event_id", None),
self.get("type", None),
self.get("state_key", None),
@ -451,7 +451,7 @@ class FrozenEventV2(EventBase):
return self.__repr__()
def __repr__(self):
return "<%s event_id='%s', type='%s', state_key='%s'>" % (
return "<%s event_id=%r, type=%r, state_key=%r>" % (
self.__class__.__name__,
self.event_id,
self.get("type", None),

View file

@ -154,7 +154,7 @@ class Authenticator:
)
logger.debug("Request from %s", origin)
request.authenticated_entity = origin
request.requester = origin
# If we get a valid signed request from the other side, it's probably
# alive

View file

@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from prometheus_client import Counter
@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
class ApplicationServicesHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
self.appservice_api = hs.get_application_service_api()
@ -247,7 +250,9 @@ class ApplicationServicesHandler:
service, "presence", new_token
)
async def _handle_typing(self, service: ApplicationService, new_token: int):
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
typing_source = self.event_sources.sources["typing"]
# Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as(
@ -259,7 +264,7 @@ class ApplicationServicesHandler:
)
return typing
async def _handle_receipts(self, service: ApplicationService):
async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
@ -271,7 +276,7 @@ class ApplicationServicesHandler:
async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]]
):
) -> List[JsonDict]:
events = [] # type: List[JsonDict]
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
@ -301,11 +306,11 @@ class ApplicationServicesHandler:
return events
async def query_user_exists(self, user_id):
async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.
Args:
user_id(str): The user to query if they exist on any AS.
user_id: The user to query if they exist on any AS.
Returns:
True if this user exists on at least one application service.
"""
@ -316,11 +321,13 @@ class ApplicationServicesHandler:
return True
return False
async def query_room_alias_exists(self, room_alias):
async def query_room_alias_exists(
self, room_alias: RoomAlias
) -> Optional[RoomAliasMapping]:
"""Check if an application service knows this room alias exists.
Args:
room_alias(RoomAlias): The room alias to query.
room_alias: The room alias to query.
Returns:
namedtuple: with keys "room_id" and "servers" or None if no
association can be found.
@ -336,10 +343,13 @@ class ApplicationServicesHandler:
)
if is_known_alias:
# the alias exists now so don't query more ASes.
result = await self.store.get_association_from_room_alias(room_alias)
return result
return await self.store.get_association_from_room_alias(room_alias)
async def query_3pe(self, kind, protocol, fields):
return None
async def query_3pe(
self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
) -> List[JsonDict]:
services = self._get_services_for_3pn(protocol)
results = await make_deferred_yieldable(
@ -361,7 +371,9 @@ class ApplicationServicesHandler:
return ret
async def get_3pe_protocols(self, only_protocol=None):
async def get_3pe_protocols(
self, only_protocol: Optional[str] = None
) -> Dict[str, JsonDict]:
services = self.store.get_app_services()
protocols = {} # type: Dict[str, List[JsonDict]]
@ -379,7 +391,7 @@ class ApplicationServicesHandler:
if info is not None:
protocols[p].append(info)
def _merge_instances(infos):
def _merge_instances(infos: List[JsonDict]) -> JsonDict:
if not infos:
return {}
@ -394,19 +406,17 @@ class ApplicationServicesHandler:
return combined
for p in protocols.keys():
protocols[p] = _merge_instances(protocols[p])
return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
return protocols
async def _get_services_for_event(self, event):
async def _get_services_for_event(
self, event: EventBase
) -> List[ApplicationService]:
"""Retrieve a list of application services interested in this event.
Args:
event(Event): The event to check. Can be None if alias_list is not.
event: The event to check. Can be None if alias_list is not.
Returns:
list<ApplicationService>: A list of services interested in this
event based on the service regex.
A list of services interested in this event based on the service regex.
"""
services = self.store.get_app_services()
@ -420,17 +430,15 @@ class ApplicationServicesHandler:
return interested_list
def _get_services_for_user(self, user_id):
def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
services = self.store.get_app_services()
interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
return interested_list
return [s for s in services if (s.is_interested_in_user(user_id))]
def _get_services_for_3pn(self, protocol):
def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
services = self.store.get_app_services()
interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
return interested_list
return [s for s in services if s.is_interested_in_protocol(protocol)]
async def _is_unknown_user(self, user_id):
async def _is_unknown_user(self, user_id: str) -> bool:
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
@ -445,9 +453,8 @@ class ApplicationServicesHandler:
service_list = [s for s in services if s.sender == user_id]
return len(service_list) == 0
async def _check_user_exists(self, user_id):
async def _check_user_exists(self, user_id: str) -> bool:
unknown_user = await self._is_unknown_user(user_id)
if unknown_user:
exists = await self.query_user_exists(user_id)
return exists
return await self.query_user_exists(user_id)
return True

View file

@ -18,10 +18,20 @@ import logging
import time
import unicodedata
import urllib.parse
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
import attr
import bcrypt # type: ignore[import]
import bcrypt
import pymacaroons
from synapse.api.constants import LoginType
@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -149,11 +162,7 @@ class SsoLoginExtraAttributes:
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
@ -982,17 +991,17 @@ class AuthHandler(BaseHandler):
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
user_id=str(user_info["user"]),
device_id=user_info["device_id"],
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
if inspect.isawaitable(result):
await result
# delete pushers associated with this access token
if user_info["token_id"] is not None:
if user_info.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"],)
user_info.user_id, (user_info.token_id,)
)
async def delete_access_tokens_for_user(

View file

@ -50,9 +50,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder
from synapse.util import json_decoder, json_encoder
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@ -928,7 +927,7 @@ class EventCreationHandler:
# Ensure that we can round trip before trying to persist in db
try:
dump = frozendict_json_encoder.encode(event.content)
dump = json_encoder.encode(event.content)
json_decoder.decode(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)
@ -1100,34 +1099,13 @@ class EventCreationHandler:
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
def is_inviter_member_event(e):
return e.type == EventTypes.Member and e.sender == event.sender
current_state_ids = await context.get_current_state_ids()
# We know this event is not an outlier, so this must be
# non-None.
assert current_state_ids is not None
state_to_include_ids = [
e_id
for k, e_id in current_state_ids.items()
if k[0] in self.room_invite_state_types
or k == (EventTypes.Member, event.sender)
]
state_to_include = await self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [
{
"type": e.type,
"state_key": e.state_key,
"content": e.content,
"sender": e.sender,
}
for e in state_to_include.values()
]
event.unsigned[
"invite_room_state"
] = await self.store.get_stripped_room_state_from_event_context(
context,
self.room_invite_state_types,
membership_user_id=event.sender,
)
invitee = UserID.from_string(event.state_key)
if not self.hs.is_mine(invitee):

View file

@ -48,7 +48,7 @@ from synapse.util.wheel_timer import WheelTimer
MYPY = False
if MYPY:
import synapse.server
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -101,7 +101,7 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class BasePresenceHandler(abc.ABC):
"""Parts of the PresenceHandler that are shared between workers and master"""
def __init__(self, hs: "synapse.server.HomeServer"):
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@ -199,7 +199,7 @@ class BasePresenceHandler(abc.ABC):
class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "synapse.server.HomeServer"):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.is_mine_id = hs.is_mine_id
@ -1011,7 +1011,7 @@ def format_user_presence_state(state, now, include_user_id=True):
class PresenceEventSource:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
# We can't call get_presence_handler here because there's a cycle:
#
# Presence -> Notifier -> PresenceEventSource -> Presence
@ -1071,12 +1071,14 @@ class PresenceEventSource:
users_interested_in = await self._get_interested_in(user, explicit_room_id)
user_ids_changed = set()
user_ids_changed = set() # type: Collection[str]
changed = None
if from_key:
changed = stream_change_cache.get_all_entities_changed(from_key)
if changed is not None and len(changed) < 500:
assert isinstance(user_ids_changed, set)
# For small deltas, it's quicker to get all changes and then
# work out if we share a room or they're in our presence list
get_updates_counter.labels("stream").inc()

View file

@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
if (
not user_data.is_guest
or UserID.from_string(user_data.user_id).localpart != localpart
):
raise AuthError(
403,
"Cannot register taken user ID without valid guest "
@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
await self.pusher_pool.add_pusher(
user_id=user_id,

View file

@ -771,22 +771,29 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False,
)
for invitee in invite_list:
# we avoid dropping the lock between invites, as otherwise joins can
# start coming in and making the createRoom slow.
#
# we also don't need to check the requester's shadow-ban here, as we
# have already done so above (and potentially emptied invite_list).
with (await self.room_member_handler.member_linearizer.queue((room_id,))):
content = {}
is_direct = config.get("is_direct", None)
if is_direct:
content["is_direct"] = is_direct
# Note that update_membership with an action of "invite" can raise a
# ShadowBanError, but this was handled above by emptying invite_list.
_, last_stream_id = await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
"invite",
ratelimit=False,
content=content,
)
for invitee in invite_list:
(
_,
last_stream_id,
) = await self.room_member_handler.update_membership_locked(
requester,
UserID.from_string(invitee),
room_id,
"invite",
ratelimit=False,
content=content,
)
for invite_3pid in invite_3pid_list:
id_server = invite_3pid["id_server"]

View file

@ -327,7 +327,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# haproxy would have timed the request out anyway...
raise SynapseError(504, "took too long to process")
result = await self._update_membership(
result = await self.update_membership_locked(
requester,
target,
room_id,
@ -342,7 +342,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
return result
async def _update_membership(
async def update_membership_locked(
self,
requester: Requester,
target: UserID,
@ -355,6 +355,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content: Optional[dict] = None,
require_consent: bool = True,
) -> Tuple[str, int]:
"""Helper for update_membership.
Assumes that the membership linearizer is already held for the room.
"""
content_specified = bool(content)
if content is None:
content = {}

View file

@ -359,7 +359,7 @@ class SimpleHttpClient:
agent=self.agent,
data=body_producer,
headers=headers,
**self._extra_treq_args
**self._extra_treq_args,
) # type: defer.Deferred
# we use our own timeout mechanism rather than treq's as a workaround

View file

@ -35,8 +35,6 @@ from twisted.web.server import NOT_DONE_YET, Request
from twisted.web.static import File, NoRangeStaticProducer
from twisted.web.util import redirectTo
import synapse.events
import synapse.metrics
from synapse.api.errors import (
CodeMessageException,
Codes,
@ -620,7 +618,7 @@ def respond_with_json(
if pretty_print:
encoder = iterencode_pretty_printed_json
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
if canonical_json:
encoder = iterencode_canonical_json
else:
encoder = _encode_json_bytes

View file

@ -14,7 +14,7 @@
import contextlib
import logging
import time
from typing import Optional
from typing import Optional, Union
from twisted.python.failure import Failure
from twisted.web.server import Request, Site
@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.types import Requester
logger = logging.getLogger(__name__)
@ -54,9 +55,12 @@ class SynapseRequest(Request):
Request.__init__(self, channel, *args, **kw)
self.site = channel.site
self._channel = channel # this is used by the tests
self.authenticated_entity = None
self.start_time = 0.0
# The requester, if authenticated. For federation requests this is the
# server name, for client requests this is the Requester object.
self.requester = None # type: Optional[Union[Requester, str]]
# we can't yet create the logcontext, as we don't know the method.
self.logcontext = None # type: Optional[LoggingContext]
@ -271,11 +275,23 @@ class SynapseRequest(Request):
# to the client (nb may be negative)
response_send_time = self.finish_time - self._processing_finished_time
# need to decode as it could be raw utf-8 bytes
# from a IDN servname in an auth header
authenticated_entity = self.authenticated_entity
if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
# Convert the requester into a string that we can log
authenticated_entity = None
if isinstance(self.requester, str):
authenticated_entity = self.requester
elif isinstance(self.requester, Requester):
authenticated_entity = self.requester.authenticated_entity
# If this is a request where the target user doesn't match the user who
# authenticated (e.g. an admin is puppeting a user) then we log both.
if self.requester.user.to_string() != authenticated_entity:
authenticated_entity = "{},{}".format(
authenticated_entity, self.requester.user.to_string(),
)
elif self.requester is not None:
# This shouldn't happen, but we log it so we don't lose information
# and can see that we're doing something wrong.
authenticated_entity = repr(self.requester) # type: ignore[unreachable]
# ...or could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically

View file

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# These are imported to allow for nicer logging configuration files.
from synapse.logging._remote import RemoteHandler
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
__all__ = ["RemoteHandler", "JsonFormatter", "TerseJsonFormatter"]

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import traceback
from collections import deque
@ -21,10 +22,11 @@ from math import floor
from typing import Callable, Optional
import attr
from typing_extensions import Deque
from zope.interface import implementer
from twisted.application.internet import ClientService
from twisted.internet.defer import Deferred
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.endpoints import (
HostnameEndpoint,
TCP4ClientEndpoint,
@ -32,7 +34,9 @@ from twisted.internet.endpoints import (
)
from twisted.internet.interfaces import IPushProducer, ITransport
from twisted.internet.protocol import Factory, Protocol
from twisted.logger import ILogObserver, Logger, LogLevel
from twisted.python.failure import Failure
logger = logging.getLogger(__name__)
@attr.s
@ -45,11 +49,11 @@ class LogProducer:
Args:
buffer: Log buffer to read logs from.
transport: Transport to write to.
format_event: A callable to format the log entry to a string.
format: A callable to format the log record to a string.
"""
transport = attr.ib(type=ITransport)
format_event = attr.ib(type=Callable[[dict], str])
_format = attr.ib(type=Callable[[logging.LogRecord], str])
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
@ -61,16 +65,19 @@ class LogProducer:
self._buffer = deque()
def resumeProducing(self):
# If we're already producing, nothing to do.
self._paused = False
# Loop until paused.
while self._paused is False and (self._buffer and self.transport.connected):
try:
# Request the next event and format it.
event = self._buffer.popleft()
msg = self.format_event(event)
# Request the next record and format it.
record = self._buffer.popleft()
msg = self._format(record)
# Send it as a new line over the transport.
self.transport.write(msg.encode("utf8"))
self.transport.write(b"\n")
except Exception:
# Something has gone wrong writing to the transport -- log it
# and break out of the while.
@ -78,76 +85,85 @@ class LogProducer:
break
@attr.s
@implementer(ILogObserver)
class TCPLogObserver:
class RemoteHandler(logging.Handler):
"""
An IObserver that writes JSON logs to a TCP target.
A logging handler that writes logs to a TCP target.
Args:
hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
format_event: A callable to format the log entry to a string.
maximum_buffer: The maximum buffer size.
"""
hs = attr.ib()
host = attr.ib(type=str)
port = attr.ib(type=int)
format_event = attr.ib(type=Callable[[dict], str])
maximum_buffer = attr.ib(type=int)
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
_connection_waiter = attr.ib(default=None, type=Optional[Deferred])
_logger = attr.ib(default=attr.Factory(Logger))
_producer = attr.ib(default=None, type=Optional[LogProducer])
def __init__(
self,
host: str,
port: int,
maximum_buffer: int = 1000,
level=logging.NOTSET,
_reactor=None,
):
super().__init__(level=level)
self.host = host
self.port = port
self.maximum_buffer = maximum_buffer
def start(self) -> None:
self._buffer = deque() # type: Deque[logging.LogRecord]
self._connection_waiter = None # type: Optional[Deferred]
self._producer = None # type: Optional[LogProducer]
# Connect without DNS lookups if it's a direct IP.
if _reactor is None:
from twisted.internet import reactor
_reactor = reactor
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
endpoint = TCP4ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
else:
raise ValueError("Unknown IP address provided: %s" % (self.host,))
except ValueError:
endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
endpoint = HostnameEndpoint(_reactor, self.host, self.port)
factory = Factory.forProtocol(Protocol)
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
self._service = ClientService(endpoint, factory, clock=_reactor)
self._service.startService()
self._stopping = False
self._connect()
def stop(self):
def close(self):
self._stopping = True
self._service.stopService()
def _connect(self) -> None:
"""
Triggers an attempt to connect then write to the remote if not already writing.
"""
# Do not attempt to open multiple connections.
if self._connection_waiter:
return
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
@self._connection_waiter.addErrback
def fail(r):
r.printTraceback(file=sys.__stderr__)
def fail(failure: Failure) -> None:
# If the Deferred was cancelled (e.g. during shutdown) do not try to
# reconnect (this will cause an infinite loop of errors).
if failure.check(CancelledError) and self._stopping:
return
# For a different error, print the traceback and re-connect.
failure.printTraceback(file=sys.__stderr__)
self._connection_waiter = None
self._connect()
@self._connection_waiter.addCallback
def writer(r):
def writer(result: Protocol) -> None:
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
if self._producer and r.transport is self._producer.transport:
if self._producer and result.transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
@ -158,29 +174,29 @@ class TCPLogObserver:
# Make a new producer and start it.
self._producer = LogProducer(
buffer=self._buffer,
transport=r.transport,
format_event=self.format_event,
buffer=self._buffer, transport=result.transport, format=self.format,
)
r.transport.registerProducer(self._producer, True)
result.transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
self._connection_waiter.addCallbacks(writer, fail)
def _handle_pressure(self) -> None:
"""
Handle backpressure by shedding events.
Handle backpressure by shedding records.
Records are shed from the buffer, in this order, until it is below the maximum:
- Shed DEBUG events
- Shed INFO events
- Shed the middle 50% of the events.
- Shed DEBUG records.
- Shed INFO records.
- Shed the middle 50% of the records.
"""
if len(self._buffer) <= self.maximum_buffer:
return
# Strip out DEBUGs
self._buffer = deque(
filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer)
filter(lambda record: record.levelno > logging.DEBUG, self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
@ -188,7 +204,7 @@ class TCPLogObserver:
# Strip out INFOs
self._buffer = deque(
filter(lambda event: event["log_level"] != LogLevel.info, self._buffer)
filter(lambda record: record.levelno > logging.INFO, self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
@ -209,17 +225,17 @@ class TCPLogObserver:
self._buffer.extend(reversed(end_buffer))
def __call__(self, event: dict) -> None:
self._buffer.append(event)
def emit(self, record: logging.LogRecord) -> None:
self._buffer.append(record)
# Handle backpressure, if it exists.
try:
self._handle_pressure()
except Exception:
# If handling backpressure fails,clear the buffer and log the
# If handling backpressure fails, clear the buffer and log the
# exception.
self._buffer.clear()
self._logger.failure("Failed clearing backpressure")
logger.warning("Failed clearing backpressure")
# Try and write immediately.
self._connect()
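
As an illustrative aside (not part of this changeset's code), the new handler can in principle be attached to the standard library logging machinery directly. The sketch below assumes it is importable as `synapse.logging.RemoteHandler`, which is the class path used by the generated handler configuration later in this diff; the host, port and logger name are made up:

```python
# Hedged sketch: wiring RemoteHandler into stdlib logging by hand.
import logging

from synapse.logging import RemoteHandler  # assumed export path

handler = RemoteHandler("127.0.0.1", 9000, maximum_buffer=1000)
handler.setLevel(logging.INFO)
logging.getLogger("synapse").addHandler(handler)
```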

View file

@ -12,138 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os.path
import sys
import typing
import warnings
from typing import List
from typing import Any, Dict, Generator, Optional, Tuple
import attr
from constantly import NamedConstant, Names, ValueConstant, Values
from zope.interface import implementer
from twisted.logger import (
FileLogObserver,
FilteringLogObserver,
ILogObserver,
LogBeginner,
Logger,
LogLevel,
LogLevelFilterPredicate,
LogPublisher,
eventAsText,
jsonFileLogObserver,
)
from constantly import NamedConstant, Names
from synapse.config._base import ConfigError
from synapse.logging._terse_json import (
TerseJSONToConsoleLogObserver,
TerseJSONToTCPLogObserver,
)
from synapse.logging.context import current_context
def stdlib_log_level_to_twisted(level: str) -> LogLevel:
"""
Convert a stdlib log level to Twisted's log level.
"""
lvl = level.lower().replace("warning", "warn")
return LogLevel.levelWithName(lvl)
@attr.s
@implementer(ILogObserver)
class LogContextObserver:
"""
An ILogObserver which adds Synapse-specific log context information.
Attributes:
observer (ILogObserver): The target parent observer.
"""
observer = attr.ib()
def __call__(self, event: dict) -> None:
"""
Consume a log event and emit it to the parent observer after filtering
and adding log context information.
Args:
event (dict)
"""
# Filter out some useless events that Twisted outputs
if "log_text" in event:
if event["log_text"].startswith("DNSDatagramProtocol starting on "):
return
if event["log_text"].startswith("(UDP Port "):
return
if event["log_text"].startswith("Timing out client") or event[
"log_format"
].startswith("Timing out client"):
return
context = current_context()
# Copy the context information to the log event.
context.copy_to_twisted_log_entry(event)
self.observer(event)
class PythonStdlibToTwistedLogger(logging.Handler):
"""
Transform a Python stdlib log message into a Twisted one.
"""
def __init__(self, observer, *args, **kwargs):
"""
Args:
observer (ILogObserver): A Twisted logging observer.
*args, **kwargs: Args/kwargs to be passed to logging.Handler.
"""
self.observer = observer
super().__init__(*args, **kwargs)
def emit(self, record: logging.LogRecord) -> None:
"""
Emit a record to Twisted's observer.
Args:
record (logging.LogRecord)
"""
self.observer(
{
"log_time": record.created,
"log_text": record.getMessage(),
"log_format": "{log_text}",
"log_namespace": record.name,
"log_level": stdlib_log_level_to_twisted(record.levelname),
}
)
def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver:
"""
A log observer that formats events like the traditional log formatter and
sends them to `outFile`.
Args:
outFile (file object): The file object to write to.
"""
def formatEvent(_event: dict) -> str:
event = dict(_event)
event["log_level"] = event["log_level"].name.upper()
event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + (
event.get("log_format", "{log_text}") or "{log_text}"
)
return eventAsText(event, includeSystem=False) + "\n"
return FileLogObserver(outFile, formatEvent)
class DrainType(Names):
@ -155,30 +29,12 @@ class DrainType(Names):
NETWORK_JSON_TERSE = NamedConstant()
class OutputPipeType(Values):
stdout = ValueConstant(sys.__stdout__)
stderr = ValueConstant(sys.__stderr__)
@attr.s
class DrainConfiguration:
name = attr.ib()
type = attr.ib()
location = attr.ib()
options = attr.ib(default=None)
@attr.s
class NetworkJSONTerseOptions:
maximum_buffer = attr.ib(type=int)
DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
DEFAULT_LOGGERS = {"synapse": {"level": "info"}}
def parse_drain_configs(
drains: dict,
) -> typing.Generator[DrainConfiguration, None, None]:
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
"""
Parse the drain configurations.
@ -186,11 +42,12 @@ def parse_drain_configs(
drains (dict): The drain configurations, keyed by name.
Yields:
DrainConfiguration instances.
Tuples of (handler name, handler configuration dict).
Raises:
ConfigError: If any of the drain configuration items are invalid.
"""
for name, config in drains.items():
if "type" not in config:
raise ConfigError("Logging drains require a 'type' key.")
@ -202,6 +59,18 @@ def parse_drain_configs(
"%s is not a known logging drain type." % (config["type"],)
)
# Either use the default formatter or the tersejson one.
if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,):
formatter = "json" # type: Optional[str]
elif logging_type in (
DrainType.CONSOLE_JSON_TERSE,
DrainType.NETWORK_JSON_TERSE,
):
formatter = "tersejson"
else:
# A formatter of None implies using the default formatter.
formatter = None
if logging_type in [
DrainType.CONSOLE,
DrainType.CONSOLE_JSON,
@ -217,9 +86,11 @@ def parse_drain_configs(
% (logging_type,)
)
pipe = OutputPipeType.lookupByName(location).value
yield DrainConfiguration(name=name, type=logging_type, location=pipe)
yield name, {
"class": "logging.StreamHandler",
"formatter": formatter,
"stream": "ext://sys." + location,
}
elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
if "location" not in config:
@ -233,18 +104,25 @@ def parse_drain_configs(
"File paths need to be absolute, '%s' is a relative path"
% (location,)
)
yield DrainConfiguration(name=name, type=logging_type, location=location)
yield name, {
"class": "logging.FileHandler",
"formatter": formatter,
"filename": location,
}
elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
host = config.get("host")
port = config.get("port")
maximum_buffer = config.get("maximum_buffer", 1000)
yield DrainConfiguration(
name=name,
type=logging_type,
location=(host, port),
options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer),
)
yield name, {
"class": "synapse.logging.RemoteHandler",
"formatter": formatter,
"host": host,
"port": port,
"maximum_buffer": maximum_buffer,
}
else:
raise ConfigError(
@ -253,126 +131,29 @@ def parse_drain_configs(
)
class StoppableLogPublisher(LogPublisher):
def setup_structured_logging(log_config: dict,) -> dict:
"""
A log publisher that can tell its observers to shut down any external
communications.
Convert a legacy structured logging configuration (from Synapse < v1.23.0)
to one compatible with the new standard library handlers.
"""
def stop(self):
for obs in self._observers:
if hasattr(obs, "stop"):
obs.stop()
def setup_structured_logging(
hs,
config,
log_config: dict,
logBeginner: LogBeginner,
redirect_stdlib_logging: bool = True,
) -> LogPublisher:
"""
Set up Twisted's structured logging system.
Args:
hs: The homeserver to use.
config (HomeserverConfig): The configuration of the Synapse homeserver.
log_config (dict): The log configuration to use.
"""
if config.no_redirect_stdio:
raise ConfigError(
"no_redirect_stdio cannot be defined using structured logging."
)
logger = Logger()
if "drains" not in log_config:
raise ConfigError("The logging configuration requires a list of drains.")
observers = [] # type: List[ILogObserver]
new_config = {
"version": 1,
"formatters": {
"json": {"class": "synapse.logging.JsonFormatter"},
"tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
},
"handlers": {},
"loggers": log_config.get("loggers", DEFAULT_LOGGERS),
"root": {"handlers": []},
}
for observer in parse_drain_configs(log_config["drains"]):
# Pipe drains
if observer.type == DrainType.CONSOLE:
logger.debug(
"Starting up the {name} console logger drain", name=observer.name
)
observers.append(SynapseFileLogObserver(observer.location))
elif observer.type == DrainType.CONSOLE_JSON:
logger.debug(
"Starting up the {name} JSON console logger drain", name=observer.name
)
observers.append(jsonFileLogObserver(observer.location))
elif observer.type == DrainType.CONSOLE_JSON_TERSE:
logger.debug(
"Starting up the {name} terse JSON console logger drain",
name=observer.name,
)
observers.append(
TerseJSONToConsoleLogObserver(observer.location, metadata={})
)
for handler_name, handler in parse_drain_configs(log_config["drains"]):
new_config["handlers"][handler_name] = handler
# File drains
elif observer.type == DrainType.FILE:
logger.debug("Starting up the {name} file logger drain", name=observer.name)
log_file = open(observer.location, "at", buffering=1, encoding="utf8")
observers.append(SynapseFileLogObserver(log_file))
elif observer.type == DrainType.FILE_JSON:
logger.debug(
"Starting up the {name} JSON file logger drain", name=observer.name
)
log_file = open(observer.location, "at", buffering=1, encoding="utf8")
observers.append(jsonFileLogObserver(log_file))
# Add each handler to the root logger.
new_config["root"]["handlers"].append(handler_name)
elif observer.type == DrainType.NETWORK_JSON_TERSE:
metadata = {"server_name": hs.config.server_name}
log_observer = TerseJSONToTCPLogObserver(
hs=hs,
host=observer.location[0],
port=observer.location[1],
metadata=metadata,
maximum_buffer=observer.options.maximum_buffer,
)
log_observer.start()
observers.append(log_observer)
else:
# We should never get here, but, just in case, throw an error.
raise ConfigError("%s drain type cannot be configured" % (observer.type,))
publisher = StoppableLogPublisher(*observers)
log_filter = LogLevelFilterPredicate()
for namespace, namespace_config in log_config.get(
"loggers", DEFAULT_LOGGERS
).items():
# Set the log level for twisted.logger.Logger namespaces
log_filter.setLogLevelForNamespace(
namespace,
stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")),
)
# Also set the log levels for the stdlib logger namespaces, to prevent
# them getting to PythonStdlibToTwistedLogger and having to be formatted
if "level" in namespace_config:
logging.getLogger(namespace).setLevel(namespace_config.get("level"))
f = FilteringLogObserver(publisher, [log_filter])
lco = LogContextObserver(f)
if redirect_stdlib_logging:
stuff_into_twisted = PythonStdlibToTwistedLogger(lco)
stdliblogger = logging.getLogger()
stdliblogger.addHandler(stuff_into_twisted)
# Always redirect standard I/O, otherwise other logging outputs might miss
# it.
logBeginner.beginLoggingTo([lco], redirectStandardIO=True)
return publisher
def reload_structured_logging(*args, log_config=None) -> None:
warnings.warn(
"Currently the structured logging system can not be reloaded, doing nothing"
)
return new_config
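
For orientation, a rough sketch of what this conversion might produce is shown below. The drain names, host and port are invented, and the exact output shape should be checked against `parse_drain_configs` above rather than taken from this example:

```python
# Hypothetical legacy config (Synapse < v1.23.0); values are made up.
legacy_log_config = {
    "structured": True,
    "drains": {
        "console": {"type": "console_json_terse", "location": "stdout"},
        "remote": {
            "type": "network_json_terse",
            "host": "10.1.2.3",
            "port": 9999,
            "maximum_buffer": 1000,
        },
    },
}

new_config = setup_structured_logging(legacy_log_config)
# new_config is a standard dictConfig-style dict along the lines of:
# {
#   "version": 1,
#   "formatters": {"json": {...}, "tersejson": {...}},
#   "handlers": {
#     "console": {"class": "logging.StreamHandler", "formatter": "tersejson",
#                 "stream": "ext://sys.stdout"},
#     "remote": {"class": "synapse.logging.RemoteHandler", "formatter": "tersejson",
#                "host": "10.1.2.3", "port": 9999, "maximum_buffer": 1000},
#   },
#   "loggers": {"synapse": {"level": "info"}},
#   "root": {"handlers": ["console", "remote"]},
# }
```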

View file

@ -16,141 +16,65 @@
"""
Log formatters that output terse JSON.
"""
import json
from typing import IO
from twisted.logger import FileLogObserver
from synapse.logging._remote import TCPLogObserver
import logging
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
def flatten_event(event: dict, metadata: dict, include_time: bool = False):
"""
Flatten a Twisted logging event to an dictionary capable of being sent
as a log event to a logging aggregation system.
The format is vastly simplified and is not designed to be a "human readable
string" in the sense that traditional logs are. Instead, the structure is
optimised for searchability and filtering, with human-understandable log
keys.
Args:
event (dict): The Twisted logging event we are flattening.
metadata (dict): Additional data to include with each log message. This
can be information like the server name. Since the target log
consumer does not know who we are other than by host IP, this
allows us to forward through static information.
include_time (bool): Should we include the `time` key? If False, the
event time is stripped from the event.
"""
new_event = {}
# If it's a failure, make the new event's log_failure be the traceback text.
if "log_failure" in event:
new_event["log_failure"] = event["log_failure"].getTraceback()
# If it's a warning, copy over a string representation of the warning.
if "warning" in event:
new_event["warning"] = str(event["warning"])
# Stdlib logging events have "log_text" as their human-readable portion,
# Twisted ones have "log_format". For now, include the log_format, so that
# context only given in the log format (e.g. what is being logged) is
# available.
if "log_text" in event:
new_event["log"] = event["log_text"]
else:
new_event["log"] = event["log_format"]
# We want to include the timestamp when forwarding over the network, but
# exclude it when we are writing to stdout. This is because the log ingester
# (e.g. logstash, fluentd) can add its own timestamp.
if include_time:
new_event["time"] = round(event["log_time"], 2)
# Convert the log level to a textual representation.
new_event["level"] = event["log_level"].name.upper()
# Ignore these keys, and do not transfer them over to the new log object.
# They are either useless (isError), transferred manually above (log_time,
# log_level, etc), or contain Python objects which are not useful for output
# (log_logger, log_source).
keys_to_delete = [
"isError",
"log_failure",
"log_format",
"log_level",
"log_logger",
"log_source",
"log_system",
"log_time",
"log_text",
"observer",
"warning",
]
# If it's from the Twisted legacy logger (twisted.python.log), it adds some
# more keys we want to purge.
if event.get("log_namespace") == "log_legacy":
keys_to_delete.extend(["message", "system", "time"])
# Rather than modify the dictionary in place, construct a new one with only
# the content we want. The original event should be considered 'frozen'.
for key in event.keys():
if key in keys_to_delete:
continue
if isinstance(event[key], (str, int, bool, float)) or event[key] is None:
# If it's a plain type, include it as is.
new_event[key] = event[key]
else:
# If it's not one of those basic types, write out a string
# representation. This should probably be a warning in development,
# so that we are sure we are only outputting useful data.
new_event[key] = str(event[key])
# Add the metadata information to the event (e.g. the server_name).
new_event.update(metadata)
return new_event
# The properties of a standard LogRecord.
_LOG_RECORD_ATTRIBUTES = {
"args",
"asctime",
"created",
"exc_info",
# exc_text isn't a public attribute, but is used to cache the result of formatException.
"exc_text",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"message",
"module",
"msecs",
"msg",
"name",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"thread",
"threadName",
}
def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver:
"""
A log observer that formats events to a flattened JSON representation.
class JsonFormatter(logging.Formatter):
def format(self, record: logging.LogRecord) -> str:
event = {
"log": record.getMessage(),
"namespace": record.name,
"level": record.levelname,
}
Args:
outFile: The file object to write to.
metadata: Metadata to be added to each log object.
"""
return self._format(record, event)
def formatEvent(_event: dict) -> str:
flattened = flatten_event(_event, metadata)
return _encoder.encode(flattened) + "\n"
def _format(self, record: logging.LogRecord, event: dict) -> str:
# Add any extra attributes to the event.
for key, value in record.__dict__.items():
if key not in _LOG_RECORD_ATTRIBUTES:
event[key] = value
return FileLogObserver(outFile, formatEvent)
return _encoder.encode(event)
def TerseJSONToTCPLogObserver(
hs, host: str, port: int, metadata: dict, maximum_buffer: int
) -> FileLogObserver:
"""
A log observer that formats events to a flattened JSON representation.
class TerseJsonFormatter(JsonFormatter):
def format(self, record: logging.LogRecord) -> str:
event = {
"log": record.getMessage(),
"namespace": record.name,
"level": record.levelname,
"time": round(record.created, 2),
}
Args:
hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
metadata: Metadata to be added to each log object.
maximum_buffer: The maximum buffer size.
"""
def formatEvent(_event: dict) -> str:
flattened = flatten_event(_event, metadata, include_time=True)
return _encoder.encode(flattened) + "\n"
return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)
return self._format(record, event)
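
A minimal, non-authoritative illustration of the formatter output; the timestamp and the extra `request` attribute are made up for the example:

```python
import logging

handler = logging.StreamHandler()
handler.setFormatter(TerseJsonFormatter())

demo_logger = logging.getLogger("synapse.demo")
demo_logger.addHandler(handler)
demo_logger.warning("something happened", extra={"request": "GET-123"})
# Emits one JSON object per record (illustrative values):
# {"log":"something happened","namespace":"synapse.demo","level":"WARNING","time":1604000000.12,"request":"GET-123"}
```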

33
synapse/logging/filter.py Normal file
View file

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing_extensions import Literal
class MetadataFilter(logging.Filter):
"""Logging filter that adds constant values to each record.
Args:
metadata: Key-value pairs to add to each record.
"""
def __init__(self, metadata: dict):
self._metadata = metadata
def filter(self, record: logging.LogRecord) -> Literal[True]:
for key, value in self._metadata.items():
setattr(record, key, value)
return True
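
A short usage sketch (the server name is an example value): attaching the filter to a handler makes the metadata appear as extra attributes on every record, which the JSON formatters above then serialise alongside the standard keys:

```python
import logging

handler = logging.StreamHandler()
handler.addFilter(MetadataFilter({"server_name": "example.com"}))
logging.getLogger().addHandler(handler)
```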

View file

@ -28,6 +28,7 @@ from typing import (
Union,
)
import attr
from prometheus_client import Counter
from twisted.internet import defer
@ -173,6 +174,17 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
return bool(self.events)
@attr.s(slots=True, frozen=True)
class _PendingRoomEventEntry:
event_pos = attr.ib(type=PersistedEventPosition)
extra_users = attr.ib(type=Collection[UserID])
room_id = attr.ib(type=str)
type = attr.ib(type=str)
state_key = attr.ib(type=Optional[str])
membership = attr.ib(type=Optional[str])
class Notifier:
""" This class is responsible for notifying any listeners when there are
new events available for it.
@ -190,9 +202,7 @@ class Notifier:
self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = (
[]
) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry]
# Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]]
@ -255,7 +265,29 @@ class Notifier:
max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
):
""" Used by handlers to inform the notifier something has happened
"""Unwraps event and calls `on_new_room_event_args`.
"""
self.on_new_room_event_args(
event_pos=event_pos,
room_id=event.room_id,
event_type=event.type,
state_key=event.get("state_key"),
membership=event.content.get("membership"),
max_room_stream_token=max_room_stream_token,
extra_users=extra_users,
)
def on_new_room_event_args(
self,
room_id: str,
event_type: str,
state_key: Optional[str],
membership: Optional[str],
event_pos: PersistedEventPosition,
max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
):
"""Used by handlers to inform the notifier something has happened
in the room, room event wise.
This triggers the notifier to wake up any listeners that are
@ -266,7 +298,16 @@ class Notifier:
until all previous events have been persisted before notifying
the client streams.
"""
self.pending_new_room_events.append((event_pos, event, extra_users))
self.pending_new_room_events.append(
_PendingRoomEventEntry(
event_pos=event_pos,
extra_users=extra_users,
room_id=room_id,
type=event_type,
state_key=state_key,
membership=membership,
)
)
self._notify_pending_new_room_events(max_room_stream_token)
self.notify_replication()
@ -284,18 +325,19 @@ class Notifier:
users = set() # type: Set[UserID]
rooms = set() # type: Set[str]
for event_pos, event, extra_users in pending:
if event_pos.persisted_after(max_room_stream_token):
self.pending_new_room_events.append((event_pos, event, extra_users))
for entry in pending:
if entry.event_pos.persisted_after(max_room_stream_token):
self.pending_new_room_events.append(entry)
else:
if (
event.type == EventTypes.Member
and event.membership == Membership.JOIN
entry.type == EventTypes.Member
and entry.membership == Membership.JOIN
and entry.state_key
):
self._user_joined_room(event.state_key, event.room_id)
self._user_joined_room(entry.state_key, entry.room_id)
users.update(extra_users)
rooms.add(event.room_id)
users.update(entry.extra_users)
rooms.add(entry.room_id)
if users or rooms:
self.on_new_event(

View file

@ -15,8 +15,8 @@
# limitations under the License.
import logging
from collections import namedtuple
import attr
from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, RelationTypes
@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent
@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
dict of user_id -> push_rules
"""
room_id = event.room_id
rules_for_room = await self._get_rules_for_room(room_id)
rules_for_room = self._get_rules_for_room(room_id)
rules_by_user = await rules_for_room.get_rules(event, context)
@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
return rules_by_user
@cached()
@lru_cache()
def _get_rules_for_room(self, room_id):
"""Get the current RulesForRoom object for the given room id
@ -275,12 +276,14 @@ class RulesForRoom:
the entire cache for the room.
"""
def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
def __init__(
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
):
"""
Args:
hs (HomeServer)
room_id (str)
rules_for_room_cache(Cache): The cache object that caches these
rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
room_push_rule_cache_metrics (CacheMetric)
"""
@ -489,13 +492,21 @@ class RulesForRoom:
self.state_group = state_group
class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
# We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
# which namedtuple does for us (i.e. two _CacheContext are the same if
# their caches and keys match). This is important in particular to
# dedupe when we add callbacks to lru cache nodes, otherwise the number
# of callbacks would grow.
@attr.attrs(slots=True, frozen=True)
class _Invalidation:
# _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
# which means that it is stored on the bulk_get_push_rules cache entry. In order
# to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
# we need to ensure that two _Invalidation objects are "equal" if they refer to the
# same `cache` and `room_id`.
#
# attrs provides suitable __hash__ and __eq__ methods, provided we remember to
# set `frozen=True`.
cache = attr.ib(type=LruCache)
room_id = attr.ib(type=str)
def __call__(self):
rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
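
The comment above leans on attrs' value-based equality for frozen classes; a minimal, generic illustration (not Synapse code) of why duplicate callbacks can be deduplicated:

```python
import attr


@attr.attrs(slots=True, frozen=True)
class _Example:
    cache = attr.ib()
    room_id = attr.ib(type=str)


shared_cache = object()
a = _Example(shared_cache, "!room:example.com")
b = _Example(shared_cache, "!room:example.com")

# Equal field values => equal objects with equal hashes, so a set keeps one.
assert a == b and hash(a) == hash(b)
assert len({a, b}) == 1
```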

View file

@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"])
if requester.user:
request.authenticated_entity = requester.user.to_string()
request.requester = requester
logger.info("remote_join: %s into room: %s", user_id, room_id)
@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"])
if requester.user:
request.authenticated_entity = requester.user.to_string()
request.requester = requester
# hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_reject_invite(

View file

@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
if requester.user:
request.authenticated_entity = requester.user.to_string()
request.requester = requester
logger.info(
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id

View file

@ -141,21 +141,25 @@ class ReplicationDataHandler:
if row.type != EventsStreamEventRow.TypeId:
continue
assert isinstance(row, EventsStreamRow)
assert isinstance(row.data, EventsStreamEventRow)
event = await self.store.get_event(
row.data.event_id, allow_rejected=True
)
if event.rejected_reason:
if row.data.rejected:
continue
extra_users = () # type: Tuple[UserID, ...]
if event.type == EventTypes.Member:
extra_users = (UserID.from_string(event.state_key),)
if row.data.type == EventTypes.Member and row.data.state_key:
extra_users = (UserID.from_string(row.data.state_key),)
max_token = self.store.get_room_max_token()
event_pos = PersistedEventPosition(instance_name, token)
self.notifier.on_new_room_event(
event, event_pos, max_token, extra_users
self.notifier.on_new_room_event_args(
event_pos=event_pos,
max_room_stream_token=max_token,
extra_users=extra_users,
room_id=row.data.room_id,
event_type=row.data.type,
state_key=row.data.state_key,
membership=row.data.membership,
)
# Notify any waiting deferreds. The list is ordered by position so we

View file

@ -15,12 +15,15 @@
# limitations under the License.
import heapq
from collections.abc import Iterable
from typing import List, Tuple, Type
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token
if TYPE_CHECKING:
from synapse.server import HomeServer
"""Handling of the 'events' replication stream
This stream contains rows of various types. Each row therefore contains a 'type'
@ -81,12 +84,14 @@ class BaseEventsStreamRow:
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
event_id = attr.ib() # str
room_id = attr.ib() # str
type = attr.ib() # str
state_key = attr.ib() # str, optional
redacts = attr.ib() # str, optional
relates_to = attr.ib() # str, optional
event_id = attr.ib(type=str)
room_id = attr.ib(type=str)
type = attr.ib(type=str)
state_key = attr.ib(type=Optional[str])
redacts = attr.ib(type=Optional[str])
relates_to = attr.ib(type=Optional[str])
membership = attr.ib(type=Optional[str])
rejected = attr.ib(type=bool)
@attr.s(slots=True, frozen=True)
@ -113,7 +118,7 @@ class EventsStream(Stream):
NAME = "events"
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),

View file

@ -50,6 +50,7 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import (
AccountValidityRenewServlet,
DeactivateAccountRestServlet,
PushersRestServlet,
ResetPasswordRestServlet,
SearchUsersRestServlet,
UserAdminServlet,
@ -226,8 +227,9 @@ def register_servlets(hs, http_server):
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):

View file

@ -39,6 +39,17 @@ from synapse.types import JsonDict, UserID
logger = logging.getLogger(__name__)
_GET_PUSHERS_ALLOWED_KEYS = {
"app_display_name",
"app_id",
"data",
"device_display_name",
"kind",
"lang",
"profile_tag",
"pushkey",
}
class UsersRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
@ -713,6 +724,47 @@ class UserMembershipRestServlet(RestServlet):
return 200, ret
class PushersRestServlet(RestServlet):
"""
Gets information about all pushers for a specific `user_id`.
Example:
http://localhost:8008/_synapse/admin/v1/users/
@user:server/pushers
Returns:
pushers: A list of dictionaries, one per pusher registered for the user.
total: Number of pushers in the list `pushers`.
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
def __init__(self, hs):
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
raise SynapseError(400, "Can only lookup local users")
if not await self.store.get_user_by_id(user_id):
raise NotFoundError("User not found")
pushers = await self.store.get_pushers_by_user_id(user_id)
filtered_pushers = [
{k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
for p in pushers
]
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
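
An illustrative request/response for the new endpoint; the access token and pusher details are invented, and only keys in `_GET_PUSHERS_ALLOWED_KEYS` are returned:

```python
import requests

resp = requests.get(
    "http://localhost:8008/_synapse/admin/v1/users/@user:server/pushers",
    headers={"Authorization": "Bearer <admin_access_token>"},
)
# Example response body:
# {
#   "pushers": [
#     {
#       "app_display_name": "HTTP Push Notifications",
#       "app_id": "m.http",
#       "data": {"url": "https://example.com/_matrix/push/v1/notify"},
#       "device_display_name": "pushy push",
#       "kind": "http",
#       "lang": "en",
#       "profile_tag": "",
#       "pushkey": "a@example.com"
#     }
#   ],
#   "total": 1
# }
```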
class UserMediaRestServlet(RestServlet):
"""
Gets information about all uploaded local media for a specific `user_id`.

View file

@ -305,15 +305,12 @@ class MediaRepository:
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise generate a new
# one.
if media_info:
file_id = media_info["filesystem_id"]
else:
file_id = random_string(24)
file_info = FileInfo(server_name, file_id)
# If we have an entry in the DB, try and look for it
if media_info:
file_id = media_info["filesystem_id"]
file_info = FileInfo(server_name, file_id)
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
raise NotFoundError()
@ -324,14 +321,34 @@ class MediaRepository:
# Failed to find the file anywhere, lets download it.
media_info = await self._download_remote_file(server_name, media_id, file_id)
try:
media_info = await self._download_remote_file(server_name, media_id,)
except SynapseError:
raise
except Exception as e:
# An exception may be because we downloaded media in another
# process, so let's check if we magically have the media.
media_info = await self.store.get_cached_remote_media(server_name, media_id)
if not media_info:
raise e
file_id = media_info["filesystem_id"]
file_info = FileInfo(server_name, file_id)
# We generate thumbnails even if another process downloaded the media
# as a) it's conceivable that the other download request dies before it
# generates thumbnails, but mainly b) we want to be sure the thumbnails
# have finished being generated before responding to the client,
# otherwise they'll request thumbnails and get a 404 if they're not
# ready yet.
await self._generate_thumbnails(
server_name, media_id, file_id, media_info["media_type"]
)
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
async def _download_remote_file(
self, server_name: str, media_id: str, file_id: str
) -> dict:
async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@ -346,6 +363,8 @@ class MediaRepository:
The media info of the file.
"""
file_id = random_string(24)
file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
@ -401,22 +420,32 @@ class MediaRepository:
await finish()
media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
# Multiple remote media download requests can race (when using
# multiple media repos), so this may throw a violation constraint
# exception. If it does we'll delete the newly downloaded file from
# disk (as we're in the ctx manager).
#
# However: we've already called `finish()` so we may have also
# written to the storage providers. This is preferable to the
# alternative where we call `finish()` *after* this, where we could
# end up having an entry in the DB but fail to write the files to
# the storage providers.
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
logger.info("Stored remote media in file %r", fname)
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
media_info = {
"media_type": media_type,
"media_length": length,
@ -425,8 +454,6 @@ class MediaRepository:
"filesystem_id": file_id,
}
await self._generate_thumbnails(server_name, media_id, file_id, media_type)
return media_info
def _get_thumbnail_requirements(self, media_type):
@ -692,42 +719,60 @@ class MediaRepository:
if not t_byte_source:
continue
try:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
url_cache=url_cache,
)
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
url_cache=url_cache,
)
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
t_byte_source.close()
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
await self.media_storage.write_to_file(t_byte_source, f)
await finish()
finally:
t_byte_source.close()
t_len = os.path.getsize(output_path)
t_len = os.path.getsize(fname)
# Write to database
if server_name:
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
t_width,
t_height,
t_type,
t_method,
t_len,
)
else:
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
# Write to database
if server_name:
# Multiple remote media download requests can race (when
# using multiple media repos), so this may throw a violation
# constraint exception. If it does we'll delete the newly
# generated thumbnail from disk (as we're in the ctx
# manager).
#
# However: we've already called `finish()` so we may have
# also written to the storage providers. This is preferable
# to the alternative where we call `finish()` *after* this,
# where we could end up having an entry in the DB but fail
# to write the files to the storage providers.
try:
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
t_width,
t_height,
t_type,
t_method,
t_len,
)
except Exception as e:
thumbnail_exists = await self.store.get_remote_media_thumbnail(
server_name, media_id, t_width, t_height, t_type,
)
if not thumbnail_exists:
raise e
else:
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return {"width": m_width, "height": m_height}

View file

@ -52,6 +52,7 @@ class MediaStorage:
storage_providers: Sequence["StorageProviderWrapper"],
):
self.hs = hs
self.reactor = hs.get_reactor()
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
@ -70,13 +71,16 @@ class MediaStorage:
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
await defer_to_thread(
self.hs.get_reactor(), _write_file_synchronously, source, f
)
await self.write_to_file(source, f)
await finish_cb()
return fname
async def write_to_file(self, source: IO, output: IO):
"""Asynchronously write the `source` to `output`.
"""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@contextlib.contextmanager
def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
@ -112,14 +116,20 @@ class MediaStorage:
finished_called = [False]
async def finish():
for provider in self.storage_providers:
await provider.store_file(path, file_info)
finished_called[0] = True
try:
with open(fname, "wb") as f:
async def finish():
# Ensure that all writes have been flushed and close the
# file.
f.flush()
f.close()
for provider in self.storage_providers:
await provider.store_file(path, file_info)
finished_called[0] = True
yield f, fname, finish
except Exception:
try:
@ -210,7 +220,7 @@ class MediaStorage:
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.hs.get_reactor()
open(local_path, "wb"), self.reactor
)
await res.write_to_consumer(consumer)
await consumer.wait()

View file

@ -94,7 +94,7 @@ def make_pool(
cp_openfun=lambda conn: engine.on_new_connection(
LoggingDatabaseConnection(conn, engine, "on_new_connection")
),
**db_config.config.get("args", {})
**db_config.config.get("args", {}),
)
@ -632,7 +632,7 @@ class DatabasePool:
func,
*args,
db_autocommit=db_autocommit,
**kwargs
**kwargs,
)
for after_callback, after_args, after_kwargs in after_callbacks:

View file

@ -15,21 +15,31 @@
# limitations under the License.
import logging
import re
from typing import List
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
AppServiceTransaction,
)
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
def _make_exclusive_regex(services_cache):
def _make_exclusive_regex(
services_cache: List[ApplicationService],
) -> Optional[Pattern]:
# We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
exclusive_user_regex = re.compile(exclusive_user_regex)
exclusive_user_pattern = re.compile(
exclusive_user_regex
) # type: Optional[Pattern]
else:
# We handle this case specially otherwise the constructed regex
# will always match
exclusive_user_regex = None
exclusive_user_pattern = None
return exclusive_user_regex
return exclusive_user_pattern
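
A toy illustration of the pattern this builds (the exclusive user namespaces are made up):

```python
import re

exclusive_user_regexes = [r"@irc_.*:example\.com", r"@slack_.*:example\.com"]
exclusive_user_pattern = re.compile(
    "|".join("(" + r + ")" for r in exclusive_user_regexes)
)

assert exclusive_user_pattern.match("@irc_alice:example.com")
assert exclusive_user_pattern.match("@bob:example.com") is None
```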
class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
def get_app_services(self):
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id):
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
"""Check if the user is one associated with an app service (exclusively)
"""
if self.exclusive_user_regex:
@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
else:
return False
def get_app_service_by_user_id(self, user_id):
def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
"""Retrieve an application service from their user ID.
All application services have associated with them a particular user ID.
@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
a user ID to an application service.
Args:
user_id(str): The user ID to see if it is an application service.
user_id: The user ID to see if it is an application service.
Returns:
synapse.appservice.ApplicationService or None.
The application service or None.
"""
for service in self.services_cache:
if service.sender == user_id:
return service
return None
def get_app_service_by_token(self, token):
def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice token.
Args:
token (str): The application service token.
token: The application service token.
Returns:
synapse.appservice.ApplicationService or None.
The application service or None.
"""
for service in self.services_cache:
if service.token == token:
return service
return None
def get_app_service_by_id(self, as_id):
def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice ID.
Args:
as_id (str): The application service ID.
as_id: The application service ID.
Returns:
synapse.appservice.ApplicationService or None.
The application service or None.
"""
for service in self.services_cache:
if service.id == as_id:
@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore
):
async def get_appservices_by_state(self, state):
async def get_appservices_by_state(
self, state: ApplicationServiceState
) -> List[ApplicationService]:
"""Get a list of application services based on their state.
Args:
state(ApplicationServiceState): The state to filter on.
state: The state to filter on.
Returns:
A list of ApplicationServices, which may be empty.
"""
@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
services.append(service)
return services
async def get_appservice_state(self, service):
async def get_appservice_state(
self, service: ApplicationService
) -> Optional[ApplicationServiceState]:
"""Get the application service state.
Args:
service(ApplicationService): The service whose state to set.
service: The service whose state to set.
Returns:
An ApplicationServiceState.
An ApplicationServiceState, or None if the service has no recorded state.
"""
result = await self.db_pool.simple_select_one(
"application_services_state",
@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
return result.get("state")
return None
async def set_appservice_state(self, service, state) -> None:
async def set_appservice_state(
self, service: ApplicationService, state: ApplicationServiceState
) -> None:
"""Set the application service state.
Args:
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
service: The service whose state to set.
state: The connectivity state to apply.
"""
await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
"create_appservice_txn", _create_appservice_txn
)
async def complete_appservice_txn(self, txn_id, service) -> None:
async def complete_appservice_txn(
self, txn_id: int, service: ApplicationService
) -> None:
"""Completes an application service transaction.
Args:
txn_id(str): The transaction ID being completed.
service(ApplicationService): The application service which was sent
this transaction.
txn_id: The transaction ID being completed.
service: The application service which was sent this transaction.
"""
txn_id = int(txn_id)
@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
# has probably missed some events), so whine loudly but still continue,
# since it shouldn't fail completion of the transaction.
last_txn_id = self._get_last_txn(txn, service.id)
if (last_txn_id + 1) != txn_id:
logger.error(
"appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or "
@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
"complete_appservice_txn", _complete_appservice_txn
)
async def get_oldest_unsent_txn(self, service):
"""Get the oldest transaction which has not been sent for this
service.
async def get_oldest_unsent_txn(
self, service: ApplicationService
) -> Optional[AppServiceTransaction]:
"""Get the oldest transaction which has not been sent for this service.
Args:
service(ApplicationService): The app service to get the oldest txn.
service: The app service to get the oldest txn.
Returns:
An AppServiceTransaction or None.
"""
@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
service=service, id=entry["txn_id"], events=events, ephemeral=[]
)
def _get_last_txn(self, txn, service_id):
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,),
@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
else:
return int(last_txn_id[0]) # select 'last_txn' col
async def set_appservice_last_pos(self, pos) -> None:
async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
async def get_new_events_for_appservice(self, current_id, limit):
async def get_new_events_for_appservice(
self, current_id: int, limit: int
) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn):
@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
)
async def set_type_stream_id_for_appservice(
self, service: ApplicationService, type: str, pos: int
self, service: ApplicationService, type: str, pos: Optional[int]
) -> None:
if type not in ("read_receipt", "presence"):
raise ValueError(

View file

@ -22,7 +22,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -104,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
pruned_json = frozendict_json_encoder.encode(
pruned_json = json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
@ -170,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
return
# Prune the event's dict then convert it to JSON.
pruned_json = frozendict_json_encoder.encode(
pruned_json = json_encoder.encode(
prune_event_dict(event.room_version, event.get_dict())
)

View file

@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@ -769,9 +769,7 @@ class PersistEventsStore:
logger.exception("")
raise
metadata_json = frozendict_json_encoder.encode(
event.internal_metadata.get_dict()
)
metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
@ -826,10 +824,10 @@ class PersistEventsStore:
{
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": frozendict_json_encoder.encode(
"internal_metadata": json_encoder.encode(
event.internal_metadata.get_dict()
),
"json": frozendict_json_encoder.encode(event_dict(event)),
"json": json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
}
for event, _ in events_and_contexts

View file

@ -31,6 +31,7 @@ from synapse.api.room_versions import (
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import (
@ -44,7 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.types import Collection, JsonDict, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
@ -525,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore):
return event_map
async def get_stripped_room_state_from_event_context(
self,
context: EventContext,
state_types_to_include: List[EventTypes],
membership_user_id: Optional[str] = None,
) -> List[JsonDict]:
"""
Retrieve the stripped state from a room, given an event context to retrieve state
from as well as the state types to include. Optionally, include the membership
events from a specific user.
"Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys
are included from each state event.
Args:
context: The event context to retrieve state of the room from.
state_types_to_include: The type of state events to include.
membership_user_id: An optional user ID to include the stripped membership state
events of. This is useful when generating the stripped state of a room for
invites. We want to send membership events of the inviter, so that the
invitee can display the inviter's profile information if the room lacks any.
Returns:
A list of dictionaries, each representing a stripped state event from the room.
"""
current_state_ids = await context.get_current_state_ids()
# We know this event is not an outlier, so this must be
# non-None.
assert current_state_ids is not None
# The state to include
state_to_include_ids = [
e_id
for k, e_id in current_state_ids.items()
if k[0] in state_types_to_include
or (membership_user_id and k == (EventTypes.Member, membership_user_id))
]
state_to_include = await self.get_events(state_to_include_ids)
return [
{
"type": e.type,
"state_key": e.state_key,
"content": e.content,
"sender": e.sender,
}
for e in state_to_include.values()
]
def _do_fetch(self, conn):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
@ -1065,11 +1117,13 @@ class EventsWorkerStore(SQLBaseStore):
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" AND instance_name = ?"
" ORDER BY stream_ordering ASC"
@ -1100,12 +1154,14 @@ class EventsWorkerStore(SQLBaseStore):
def get_ex_outlier_stream_rows_txn(txn):
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" AND out.instance_name = ?"

View file

@ -452,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_remote_media_thumbnails",
)
async def get_remote_media_thumbnail(
self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
) -> Optional[Dict[str, Any]]:
"""Fetch the thumbnail info of given width, height and type.
"""
return await self.db_pool.simple_select_one(
table="remote_media_cache_thumbnails",
keyvalues={
"media_origin": origin,
"media_id": media_id,
"thumbnail_width": t_width,
"thumbnail_height": t_height,
"thumbnail_type": t_type,
},
retcols=(
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
"filesystem_id",
),
allow_none=True,
desc="get_remote_media_thumbnail",
)
async def store_remote_media_thumbnail(
self,
origin,

View file

@ -18,6 +18,8 @@ import logging
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
@attr.s(frozen=True, slots=True)
class TokenLookupResult:
"""Result of looking up an access token.
Attributes:
user_id: The user that this token authenticates as
is_guest
shadow_banned
token_id: The ID of the access token looked up
device_id: The device associated with the token, if any.
valid_until_ms: The timestamp the token expires, if any.
token_owner: The "owner" of the token. This is either the same as the
user, or a server admin who is logged in as the user.
"""
user_id = attr.ib(type=str)
is_guest = attr.ib(type=bool, default=False)
shadow_banned = attr.ib(type=bool, default=False)
token_id = attr.ib(type=Optional[int], default=None)
device_id = attr.ib(type=Optional[str], default=None)
valid_until_ms = attr.ib(type=Optional[int], default=None)
token_owner = attr.ib(type=str)
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
def _default_token_owner(self):
return self.user_id
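
A small sketch of the defaulting behaviour described above (user IDs invented): `token_owner` falls back to `user_id` unless the token is puppeting another user, as the new `puppets_user_id` column later in this diff allows:

```python
own = TokenLookupResult(user_id="@alice:example.com", token_id=7)
assert own.token_owner == "@alice:example.com"

puppet = TokenLookupResult(
    user_id="@alice:example.com", token_owner="@admin:example.com"
)
assert puppet.token_owner == "@admin:example.com"
```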
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return is_trial
@cached()
async def get_user_by_access_token(self, token: str) -> Optional[dict]:
async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
"""Get a user from the given access token.
Args:
token: The access token of a user.
Returns:
None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
None, if the token did not match, otherwise a `TokenLookupResult`
"""
return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name,
SELECT users.name as user_id,
users.is_guest,
users.shadow_banned,
access_tokens.id as token_id,
access_tokens.device_id,
access_tokens.valid_until_ms
access_tokens.valid_until_ms,
access_tokens.user_id as token_owner
FROM users
INNER JOIN access_tokens on users.name = access_tokens.user_id
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
WHERE token = ?
"""
txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]
return TokenLookupResult(**rows[0])
return None
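The `COALESCE(puppets_user_id, access_tokens.user_id)` join means a puppet token resolves to the puppeted user while `token_owner` stays the token's real owner. A rough pure-Python rendering of that choice (the row shape here is hypothetical):

def effective_user_for_token(token_row: dict) -> str:
    # Prefer the puppeted user when the column is set, otherwise the token's
    # own user - the same choice the SQL COALESCE makes.
    return token_row.get("puppets_user_id") or token_row["user_id"]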

View file

@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Whether the access token is an admin token for controlling another user.
ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;

View file

@ -29,6 +29,7 @@ from typing import (
Tuple,
Type,
TypeVar,
Union,
)
import attr
@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError
if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore
# define a version of typing.Collection that works on python 3.5
@ -74,6 +76,7 @@ class Requester(
"shadow_banned",
"device_id",
"app_service",
"authenticated_entity",
],
)
):
@ -104,6 +107,7 @@ class Requester(
"shadow_banned": self.shadow_banned,
"device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None,
"authenticated_entity": self.authenticated_entity,
}
@staticmethod
@ -129,16 +133,18 @@ class Requester(
shadow_banned=input["shadow_banned"],
device_id=input["device_id"],
app_service=appservice,
authenticated_entity=input["authenticated_entity"],
)
def create_requester(
user_id,
access_token_id=None,
is_guest=False,
shadow_banned=False,
device_id=None,
app_service=None,
user_id: Union[str, "UserID"],
access_token_id: Optional[int] = None,
is_guest: Optional[bool] = False,
shadow_banned: Optional[bool] = False,
device_id: Optional[str] = None,
app_service: Optional["ApplicationService"] = None,
authenticated_entity: Optional[str] = None,
):
"""
Create a new ``Requester`` object
@ -151,14 +157,27 @@ def create_requester(
shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
authenticated_entity: The entity that authenticated when making the request.
This is different to the user_id when an admin user or the server is
"puppeting" the user.
Returns:
Requester
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
if authenticated_entity is None:
authenticated_entity = user_id.to_string()
return Requester(
user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
user_id,
access_token_id,
is_guest,
shadow_banned,
device_id,
app_service,
authenticated_entity,
)
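A hedged usage sketch for the new `authenticated_entity` parameter (user IDs are invented): it defaults to the requesting user, and only differs when an admin or the server is puppeting someone.

from synapse.types import create_requester

# Ordinary request: the authenticated entity is the user themselves.
req = create_requester("@alice:example.com")
assert req.authenticated_entity == "@alice:example.com"

# Admin puppeting alice: user_id is alice, but the audit trail keeps the admin.
puppet_req = create_requester(
    "@alice:example.com", authenticated_entity="@admin:example.com"
)
assert puppet_req.user_id.to_string() == "@alice:example.com"
assert puppet_req.authenticated_entity == "@admin:example.com"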

View file

@ -18,6 +18,7 @@ import logging
import re
import attr
from frozendict import frozendict
from twisted.internet import defer, task
@ -31,9 +32,26 @@ def _reject_invalid_json(val):
raise ValueError("Invalid JSON value: '%s'" % val)
# Create a custom encoder to reduce the whitespace produced by JSON encoding and
# ensure that valid JSON is produced.
json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
def _handle_frozendict(obj):
"""Helper for json_encoder. Makes frozendicts serializable by returning
the underlying dict
"""
if type(obj) is frozendict:
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
return obj._dict
raise TypeError(
"Object of type %s is not JSON serializable" % obj.__class__.__name__
)
# A custom JSON encoder which:
# * handles frozendicts
# * produces valid JSON (no NaNs etc)
# * reduces redundant whitespace
json_encoder = json.JSONEncoder(
allow_nan=False, separators=(",", ":"), default=_handle_frozendict
)
# Create a custom decoder to reject Python extensions to JSON.
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
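A quick sketch of what the combined encoder buys (assuming the frozendict package wraps an ordinary insertion-ordered dict): frozendicts serialize like plain dicts and the output carries no redundant whitespace.

from frozendict import frozendict

from synapse.util import json_encoder

content = frozendict({"msgtype": "m.text", "body": "hello"})
assert json_encoder.encode(content) == '{"msgtype":"m.text","body":"hello"}'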

View file

@ -13,10 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import functools
import inspect
import logging
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from typing import (
Any,
Callable,
Generic,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)
from weakref import WeakValueDictionary
from twisted.internet import defer
@ -24,6 +37,7 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
class _CacheDescriptorBase:
def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
self.orig = orig
arg_spec = inspect.getfullargspec(orig)
@ -97,8 +111,107 @@ class _CacheDescriptorBase:
self.add_cache_context = cache_context
self.cache_key_builder = get_cache_key_builder(
self.arg_names, self.arg_defaults
)
class CacheDescriptor(_CacheDescriptorBase):
class _LruCachedFunction(Generic[F]):
cache = None # type: LruCache[CacheKey, Any]
__call__ = None # type: F
def lru_cache(
max_entries: int = 1000, cache_context: bool = False,
) -> Callable[[F], _LruCachedFunction[F]]:
"""A method decorator that applies a memoizing cache around the function.
This is more-or-less a drop-in equivalent to functools.lru_cache, although note
that the signature is slightly different.
The main differences with functools.lru_cache are:
(a) the size of the cache can be controlled via the cache_factor mechanism
(b) the wrapped function can request a "cache_context" which provides a
callback mechanism to indicate that the result is no longer valid
(c) prometheus metrics are exposed automatically.
The function should take zero or more arguments, which are used as the key for the
cache. Single-argument functions use that argument as the cache key; otherwise the
arguments are built into a tuple.
Cached functions can be "chained" (i.e. a cached function can call other cached
functions and get appropriately invalidated when the caches they call are
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example:
@lru_cache(cache_context=True)
def foo(self, key, cache_context):
r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
return r1 + r2
The wrapped function also has a 'cache' property which offers direct access to the
underlying LruCache.
"""
def func(orig: F) -> _LruCachedFunction[F]:
desc = LruCacheDescriptor(
orig, max_entries=max_entries, cache_context=cache_context,
)
return cast(_LruCachedFunction[F], desc)
return func
class LruCacheDescriptor(_CacheDescriptorBase):
"""Helper for @lru_cache"""
class _Sentinel(enum.Enum):
sentinel = object()
def __init__(
self, orig, max_entries: int = 1000, cache_context: bool = False,
):
super().__init__(orig, num_args=None, cache_context=cache_context)
self.max_entries = max_entries
def __get__(self, obj, owner):
cache = LruCache(
cache_name=self.orig.__name__, max_size=self.max_entries,
) # type: LruCache[CacheKey, Any]
get_cache_key = self.cache_key_builder
sentinel = LruCacheDescriptor._Sentinel.sentinel
@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
invalidate_callback = kwargs.pop("on_invalidate", None)
callbacks = (invalidate_callback,) if invalidate_callback else ()
cache_key = get_cache_key(args, kwargs)
ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
if ret != sentinel:
return ret
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
ret2 = self.orig(obj, *args, **kwargs)
cache.set(cache_key, ret2, callbacks=callbacks)
return ret2
wrapped = cast(_CachedFunction, _wrapped)
wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
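A minimal usage sketch for `@lru_cache` (class and method names are invented); the per-instance `cache` attribute gives direct access to the underlying LruCache:

from synapse.util.caches.descriptors import lru_cache


class RoomSummaries:
    @lru_cache(max_entries=100)
    def get_summary(self, room_id: str) -> dict:
        # Imagine something expensive here; the result is memoized per room_id.
        return {"room_id": room_id}


summaries = RoomSummaries()
summaries.get_summary("!room:example.com")  # computed, then stored
summaries.get_summary("!room:example.com")  # served from the LruCache
# Single-argument functions use the bare argument as the cache key.
assert summaries.get_summary.cache.get("!room:example.com") is not None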
class DeferredCacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_context=False,
iterable=False,
):
super().__init__(orig, num_args=num_args, cache_context=cache_context)
self.max_entries = max_entries
@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable,
) # type: DeferredCache[CacheKey, Any]
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
the cache_key.
We loop through each arg name, looking up if it's in the `kwargs`,
otherwise using the next argument in `args`. If there are no more
args then we try looking the arg name up in the defaults
"""
pos = 0
for nm in self.arg_names:
if nm in kwargs:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
pos += 1
else:
yield self.arg_defaults[nm]
# By default our cache key is a tuple, but if there is only one item
# then don't bother wrapping in a tuple. This is to save memory.
if self.num_args == 1:
nm = self.arg_names[0]
def get_cache_key(args, kwargs):
if nm in kwargs:
return kwargs[nm]
elif len(args):
return args[0]
else:
return self.arg_defaults[nm]
else:
def get_cache_key(args, kwargs):
return tuple(get_cache_key_gen(args, kwargs))
get_cache_key = self.cache_key_builder
@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
else:
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return wrapped
class CacheListDescriptor(_CacheDescriptorBase):
class DeferredCacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
@ -382,11 +459,13 @@ class _CacheContext:
on a lower level.
"""
Cache = Union[DeferredCache, LruCache]
_cache_context_objects = (
WeakValueDictionary()
) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
self._cache = cache
self._cache_key = cache_key
@ -396,8 +475,8 @@ class _CacheContext:
@classmethod
def get_instance(
cls, cache, cache_key
): # type: (DeferredCache, CacheKey) -> _CacheContext
cls, cache: "_CacheContext.Cache", cache_key: CacheKey
) -> "_CacheContext":
"""Returns an instance constructed with the given arguments.
A new instance is only created if none already exists.
@ -418,7 +497,7 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: CacheDescriptor(
func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
@ -460,7 +539,7 @@ def cachedList(
def batch_do_something(self, first_arg, second_args):
...
"""
func = lambda orig: CacheListDescriptor(
func = lambda orig: DeferredCacheListDescriptor(
orig,
cached_method_name=cached_method_name,
list_name=list_name,
@ -468,3 +547,65 @@ def cachedList(
)
return cast(Callable[[F], _CachedFunction[F]], func)
def get_cache_key_builder(
param_names: Sequence[str], param_defaults: Mapping[str, Any]
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
"""Construct a function which will build cache keys suitable for a cached function
Args:
param_names: list of formal parameter names for the cached function
param_defaults: a mapping from parameter name to default value for that param
Returns:
A function which will take an (args, kwargs) pair and return a cache key
"""
# By default our cache key is a tuple, but if there is only one item
# then don't bother wrapping in a tuple. This is to save memory.
if len(param_names) == 1:
nm = param_names[0]
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
if nm in kwargs:
return kwargs[nm]
elif len(args):
return args[0]
else:
return param_defaults[nm]
else:
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
return get_cache_key
def _get_cache_key_gen(
param_names: Iterable[str],
param_defaults: Mapping[str, Any],
args: Sequence[Any],
kwargs: Mapping[str, Any],
) -> Iterable[Any]:
"""Given some args/kwargs return a generator that resolves into
the cache_key.
This is essentially the same operation as `inspect.getcallargs`, but optimised so
that we don't need to inspect the target function for each call.
"""
# We loop through each arg name, looking up if it's in the `kwargs`,
# otherwise using the next argument in `args`. If there are no more
# args then we try looking the arg name up in the defaults.
pos = 0
for nm in param_names:
if nm in kwargs:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
pos += 1
else:
yield param_defaults[nm]
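To make the key shapes concrete (argument names and values are illustrative): a single-parameter function gets a bare key, while multi-parameter functions get a tuple with defaults filled in.

from synapse.util.caches.descriptors import get_cache_key_builder

build_key = get_cache_key_builder(["room_id"], {})
assert build_key(("!r:example.com",), {}) == "!r:example.com"

build_key_2 = get_cache_key_builder(["room_id", "limit"], {"limit": 10})
assert build_key_2(("!r:example.com",), {}) == ("!r:example.com", 10)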

View file

@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from frozendict import frozendict
@ -49,23 +47,3 @@ def unfreeze(o):
pass
return o
def _handle_frozendict(obj):
"""Helper for EventEncoder. Makes frozendicts serializable by returning
the underlying dict
"""
if type(obj) is frozendict:
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
return obj._dict
raise TypeError(
"Object of type %s is not JSON serializable" % obj.__class__.__name__
)
# A JSONEncoder which is capable of encoding frozendicts without barfing.
# Additionally reduce the whitespace produced by JSON encoding.
frozendict_json_encoder = json.JSONEncoder(
allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
)

View file

@ -110,7 +110,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
failure_ts,
retry_interval,
backoff_on_failure=backoff_on_failure,
**kwargs
**kwargs,
)

View file

@ -21,45 +21,6 @@ except ImportError:
from twisted.internet.pollreactor import PollReactor as Reactor
from twisted.internet.main import installReactor
from synapse.config.homeserver import HomeServerConfig
from synapse.util import Clock
from tests.utils import default_config, setup_test_homeserver
async def make_homeserver(reactor, config=None):
"""
Make a Homeserver suitable for running benchmarks against.
Args:
reactor: A Twisted reactor to run under.
config: A HomeServerConfig to use, or None.
"""
cleanup_tasks = []
clock = Clock(reactor)
if not config:
config = default_config("test")
config_obj = HomeServerConfig()
config_obj.parse_config_dict(config, "", "")
hs = setup_test_homeserver(
cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock
)
stor = hs.get_datastore()
# Run the database background updates.
if hasattr(stor.db_pool.updates, "do_next_background_update"):
while not await stor.db_pool.updates.has_completed_background_updates():
await stor.db_pool.updates.do_next_background_update(1)
def cleanup():
for i in cleanup_tasks:
i()
return hs, clock.sleep, cleanup
def make_reactor():
"""

View file

@ -12,20 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from argparse import REMAINDER
from contextlib import redirect_stderr
from io import StringIO
import pyperf
from synmark import make_reactor
from synmark.suites import SUITES
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.logger import globalLogBeginner, textFileLogObserver
from twisted.python.failure import Failure
from synmark import make_reactor
from synmark.suites import SUITES
from tests.utils import setupdb

View file

@ -13,20 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import warnings
from io import StringIO
from mock import Mock
from pyperf import perf_counter
from synmark import make_homeserver
from twisted.internet.defer import Deferred
from twisted.internet.protocol import ServerFactory
from twisted.logger import LogBeginner, Logger, LogPublisher
from twisted.logger import LogBeginner, LogPublisher
from twisted.protocols.basic import LineOnlyReceiver
from synapse.logging._structured import setup_structured_logging
from synapse.config.logger import _setup_stdlib_logging
from synapse.logging import RemoteHandler
from synapse.util import Clock
class LineCounter(LineOnlyReceiver):
@ -62,7 +64,15 @@ async def main(reactor, loops):
logger_factory.on_done = Deferred()
port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1")
hs, wait, cleanup = await make_homeserver(reactor)
# A fake homeserver config.
class Config:
server_name = "synmark-" + str(loops)
no_redirect_stdio = True
hs_config = Config()
# To be able to sleep.
clock = Clock(reactor)
errors = StringIO()
publisher = LogPublisher()
@ -72,47 +82,49 @@ async def main(reactor, loops):
)
log_config = {
"loggers": {"synapse": {"level": "DEBUG"}},
"drains": {
"version": 1,
"loggers": {"synapse": {"level": "DEBUG", "handlers": ["tersejson"]}},
"formatters": {"tersejson": {"class": "synapse.logging.TerseJsonFormatter"}},
"handlers": {
"tersejson": {
"type": "network_json_terse",
"class": "synapse.logging.RemoteHandler",
"host": "127.0.0.1",
"port": port.getHost().port,
"maximum_buffer": 100,
"_reactor": reactor,
}
},
}
logger = Logger(namespace="synapse.logging.test_terse_json", observer=publisher)
logging_system = setup_structured_logging(
hs, hs.config, log_config, logBeginner=beginner, redirect_stdlib_logging=False
logger = logging.getLogger("synapse.logging.test_terse_json")
_setup_stdlib_logging(
hs_config, log_config, logBeginner=beginner,
)
# Wait for it to connect...
await logging_system._observers[0]._service.whenConnected()
for handler in logging.getLogger("synapse").handlers:
if isinstance(handler, RemoteHandler):
break
else:
raise RuntimeError("Improperly configured: no RemoteHandler found.")
await handler._service.whenConnected()
start = perf_counter()
# Send a bunch of useful messages
for i in range(0, loops):
logger.info("test message %s" % (i,))
logger.info("test message %s", i)
if (
len(logging_system._observers[0]._buffer)
== logging_system._observers[0].maximum_buffer
):
while (
len(logging_system._observers[0]._buffer)
> logging_system._observers[0].maximum_buffer / 2
):
await wait(0.01)
if len(handler._buffer) == handler.maximum_buffer:
while len(handler._buffer) > handler.maximum_buffer / 2:
await clock.sleep(0.01)
await logger_factory.on_done
end = perf_counter() - start
logging_system.stop()
handler.close()
port.stopListening()
cleanup()
return end

View file

@ -29,6 +29,7 @@ from synapse.api.errors import (
MissingClientTokenError,
ResourceLimitError,
)
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID
from tests import unittest
@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"}
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(
{"name": "@baldrick:matrix.org", "device_id": "device"}
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
)
)
@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize())
)
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
self.assertEqual(user_id, user_info.user_id)
# TODO: device_id should come from the macaroon, but currently comes
# from the db.
self.assertEqual(user_info["device_id"], "device")
self.assertEqual(user_info.device_id, "device")
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized)
)
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
self.assertTrue(is_guest)
self.assertEqual(user_id, user_info.user_id)
self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks
@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
if token != tok:
return defer.succeed(None)
return defer.succeed(
{
"name": USER_ID,
"is_guest": False,
"token_id": 1234,
"device_id": "DEVICE",
}
TokenLookupResult(
user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
)
)
self.store.get_user_by_access_token = get_user

View file

@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=True,
None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=False,
None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)

View file

@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self):
self.service = ApplicationService(
id="unique_identifier",
sender="@as:test",
url="some_url",
token="some_token",
hostname="matrix.org", # only used by get_groups_for_user

View file

@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that our device ID has changed
user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
self.assertEqual(user_info["device_id"], retrieved_device_id)
self.assertEqual(user_info.device_id, retrieved_device_id)
# make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))

View file

@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.info = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
)
self.token_id = self.info["token_id"]
self.token_id = self.info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id)

View file

@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
class LoggerCleanupMixin:
def get_logger(self, handler):
"""
Attach a handler to a logger and add clean-ups to revert this.
"""
# Create a logger and add the handler to it.
logger = logging.getLogger(__name__)
logger.addHandler(handler)
# Ensure the logger actually logs something.
logger.setLevel(logging.INFO)
# Ensure the logger gets cleaned-up appropriately.
self.addCleanup(logger.removeHandler, handler)
self.addCleanup(logger.setLevel, logging.NOTSET)
return logger

View file

@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.test.proto_helpers import AccumulatingProtocol
from synapse.logging import RemoteHandler
from tests.logging import LoggerCleanupMixin
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
def connect_logging_client(reactor, client_id):
# This is essentially tests.server.connect_client, but disabling autoflush on
# the client transport. This is necessary to avoid an infinite loop due to
# sending of data via the logging transport causing additional logs to be
# written.
factory = reactor.tcpClients.pop(client_id)[2]
client = factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, reactor))
client.makeConnection(FakeTransport(server, reactor, autoflush=False))
return client, server
class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
def setUp(self):
self.reactor, _ = get_clock()
def test_log_output(self):
"""
The remote handler delivers logs over TCP.
"""
handler = RemoteHandler("127.0.0.1", 9000, _reactor=self.reactor)
logger = self.get_logger(handler)
logger.info("Hello there, %s!", "wally")
# Trigger the connection
client, server = connect_logging_client(self.reactor, 0)
# Trigger data being sent
client.transport.flush()
# One log message, with a single trailing newline
logs = server.data.decode("utf8").splitlines()
self.assertEqual(len(logs), 1)
self.assertEqual(server.data.count(b"\n"), 1)
# Ensure the data passed through properly.
self.assertEqual(logs[0], "Hello there, wally!")
def test_log_backpressure_debug(self):
"""
When backpressure is hit, DEBUG logs will be shed.
"""
handler = RemoteHandler(
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
)
logger = self.get_logger(handler)
# Send some debug messages
for i in range(0, 3):
logger.debug("debug %s" % (i,))
# Send a bunch of useful messages
for i in range(0, 7):
logger.info("info %s" % (i,))
# The last debug message pushes it past the maximum buffer
logger.debug("too much debug")
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
client.transport.flush()
# Only the 7 infos made it through, the debugs were elided
logs = server.data.splitlines()
self.assertEqual(len(logs), 7)
self.assertNotIn(b"debug", server.data)
def test_log_backpressure_info(self):
"""
When backpressure is hit, DEBUG and INFO logs will be shed.
"""
handler = RemoteHandler(
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
)
logger = self.get_logger(handler)
# Send some debug messages
for i in range(0, 3):
logger.debug("debug %s" % (i,))
# Send a bunch of useful messages
for i in range(0, 10):
logger.warning("warn %s" % (i,))
# Send a bunch of info messages
for i in range(0, 3):
logger.info("info %s" % (i,))
# The last debug message pushes it past the maximum buffer
logger.debug("too much debug")
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
client.transport.flush()
# The 10 warnings made it through, the debugs and infos were elided
logs = server.data.splitlines()
self.assertEqual(len(logs), 10)
self.assertNotIn(b"debug", server.data)
self.assertNotIn(b"info", server.data)
def test_log_backpressure_cut_middle(self):
"""
When backpressure is hit and no more DEBUGs and INFOs can be culled,
the middle messages are cut out.
"""
handler = RemoteHandler(
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
)
logger = self.get_logger(handler)
# Send a bunch of useful messages
for i in range(0, 20):
logger.warning("warn %s" % (i,))
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
client.transport.flush()
# The first five and last five warnings made it through; the middle ten
# were cut out
logs = server.data.decode("utf8").splitlines()
self.assertEqual(
["warn %s" % (i,) for i in range(5)]
+ ["warn %s" % (i,) for i in range(15, 20)],
logs,
)
def test_cancel_connection(self):
"""
Gracefully handle the connection being cancelled.
"""
handler = RemoteHandler(
"127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
)
logger = self.get_logger(handler)
# Send a message.
logger.info("Hello there, %s!", "wally")
# Do not accept the connection and shutdown. This causes the pending
# connection to be cancelled (and should not raise any exceptions).
handler.close()
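For orientation, a hedged sketch of wiring the handler exercised above into a standard `logging.config.dictConfig` (host, port and buffer size are placeholders), mirroring the move away from the deprecated `drains` style:

import logging.config

logging.config.dictConfig(
    {
        "version": 1,
        "formatters": {
            "tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
        },
        "handlers": {
            "remote": {
                "class": "synapse.logging.RemoteHandler",
                "formatter": "tersejson",
                "host": "127.0.0.1",
                "port": 9000,
                "maximum_buffer": 100,
            },
        },
        "loggers": {"synapse": {"level": "INFO", "handlers": ["remote"]}},
    }
)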

View file

@ -1,214 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import os.path
import shutil
import sys
import textwrap
from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile
from synapse.config.logger import setup_logging
from synapse.logging._structured import setup_structured_logging
from synapse.logging.context import LoggingContext
from tests.unittest import DEBUG, HomeserverTestCase
class FakeBeginner:
def beginLoggingTo(self, observers, **kwargs):
self.observers = observers
class StructuredLoggingTestBase:
"""
Test base that registers a cleanup handler to reset the stdlib log handler
to 'unset'.
"""
def prepare(self, reactor, clock, hs):
def _cleanup():
logging.getLogger("synapse").setLevel(logging.NOTSET)
self.addCleanup(_cleanup)
class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase):
"""
Tests for Synapse's structured logging support.
"""
def test_output_to_json_round_trip(self):
"""
Synapse logs can be outputted to JSON and then read back again.
"""
temp_dir = self.mktemp()
os.mkdir(temp_dir)
self.addCleanup(shutil.rmtree, temp_dir)
json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json"))
log_config = {
"drains": {"jsonfile": {"type": "file_json", "location": json_log_file}}
}
# Begin the logger with our config
beginner = FakeBeginner()
setup_structured_logging(
self.hs, self.hs.config, log_config, logBeginner=beginner
)
# Make a logger and send an event
logger = Logger(
namespace="tests.logging.test_structured", observer=beginner.observers[0]
)
logger.info("Hello there, {name}!", name="wally")
# Read the log file and check it has the event we sent
with open(json_log_file, "r") as f:
logged_events = list(eventsFromJSONLogFile(f))
self.assertEqual(len(logged_events), 1)
# The event pulled from the file should render fine
self.assertEqual(
eventAsText(logged_events[0], includeTimestamp=False),
"[tests.logging.test_structured#info] Hello there, wally!",
)
def test_output_to_text(self):
"""
Synapse logs can be outputted to text.
"""
temp_dir = self.mktemp()
os.mkdir(temp_dir)
self.addCleanup(shutil.rmtree, temp_dir)
log_file = os.path.abspath(os.path.join(temp_dir, "out.log"))
log_config = {"drains": {"file": {"type": "file", "location": log_file}}}
# Begin the logger with our config
beginner = FakeBeginner()
setup_structured_logging(
self.hs, self.hs.config, log_config, logBeginner=beginner
)
# Make a logger and send an event
logger = Logger(
namespace="tests.logging.test_structured", observer=beginner.observers[0]
)
logger.info("Hello there, {name}!", name="wally")
# Read the log file and check it has the event we sent
with open(log_file, "r") as f:
logged_events = f.read().strip().split("\n")
self.assertEqual(len(logged_events), 1)
# The event pulled from the file should render fine
self.assertTrue(
logged_events[0].endswith(
" - tests.logging.test_structured - INFO - None - Hello there, wally!"
)
)
def test_collects_logcontext(self):
"""
Test that log outputs have the attached logging context.
"""
log_config = {"drains": {}}
# Begin the logger with our config
beginner = FakeBeginner()
publisher = setup_structured_logging(
self.hs, self.hs.config, log_config, logBeginner=beginner
)
logs = []
publisher.addObserver(logs.append)
# Make a logger and send an event
logger = Logger(
namespace="tests.logging.test_structured", observer=beginner.observers[0]
)
with LoggingContext("testcontext", request="somereq"):
logger.info("Hello there, {name}!", name="steve")
self.assertEqual(len(logs), 1)
self.assertEqual(logs[0]["request"], "somereq")
class StructuredLoggingConfigurationFileTestCase(
StructuredLoggingTestBase, HomeserverTestCase
):
def make_homeserver(self, reactor, clock):
tempdir = self.mktemp()
os.mkdir(tempdir)
log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml"))
self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log"))
config = self.default_config()
config["log_config"] = log_config_file
with open(log_config_file, "w") as f:
f.write(
textwrap.dedent(
"""\
structured: true
drains:
file:
type: file_json
location: %s
"""
% (self.homeserver_log,)
)
)
self.addCleanup(self._sys_cleanup)
return self.setup_test_homeserver(config=config)
def _sys_cleanup(self):
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
# Do not remove! We need the logging system to be set other than WARNING.
@DEBUG
def test_log_output(self):
"""
When a structured logging config is given, Synapse will use it.
"""
beginner = FakeBeginner()
publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner)
# Make a logger and send an event
logger = Logger(namespace="tests.logging.test_structured", observer=publisher)
with LoggingContext("testcontext", request="somereq"):
logger.info("Hello there, {name}!", name="steve")
with open(self.homeserver_log, "r") as f:
logged_events = [
eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f)
]
logs = "\n".join(logged_events)
self.assertTrue("***** STARTING SERVER *****" in logs)
self.assertTrue("Hello there, steve!" in logs)

View file

@ -14,57 +14,33 @@
# limitations under the License.
import json
from collections import Counter
import logging
from io import StringIO
from twisted.logger import Logger
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
from synapse.logging._structured import setup_structured_logging
from tests.server import connect_client
from tests.unittest import HomeserverTestCase
from .test_structured import FakeBeginner, StructuredLoggingTestBase
from tests.logging import LoggerCleanupMixin
from tests.unittest import TestCase
class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
def test_log_output(self):
class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
def test_terse_json_output(self):
"""
The Terse JSON outputter delivers simplified structured logs over TCP.
The Terse JSON formatter converts log messages to JSON.
"""
log_config = {
"drains": {
"tersejson": {
"type": "network_json_terse",
"host": "127.0.0.1",
"port": 8000,
}
}
}
output = StringIO()
# Begin the logger with our config
beginner = FakeBeginner()
setup_structured_logging(
self.hs, self.hs.config, log_config, logBeginner=beginner
)
handler = logging.StreamHandler(output)
handler.setFormatter(TerseJsonFormatter())
logger = self.get_logger(handler)
logger = Logger(
namespace="tests.logging.test_terse_json", observer=beginner.observers[0]
)
logger.info("Hello there, {name}!", name="wally")
logger.info("Hello there, %s!", "wally")
# Trigger the connection
self.pump()
_, server = connect_client(self.reactor, 0)
# Trigger data being sent
self.pump()
# One log message, with a single trailing newline
logs = server.data.decode("utf8").splitlines()
# One log message, with a single trailing newline.
data = output.getvalue()
logs = data.splitlines()
self.assertEqual(len(logs), 1)
self.assertEqual(server.data.count(b"\n"), 1)
self.assertEqual(data.count("\n"), 1)
log = json.loads(logs[0])
# The terse logger should give us these keys.
@ -72,163 +48,74 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
"log",
"time",
"level",
"log_namespace",
"request",
"scope",
"server_name",
"name",
"namespace",
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
def test_extra_data(self):
"""
Additional information can be included in the structured logging.
"""
output = StringIO()
handler = logging.StreamHandler(output)
handler.setFormatter(TerseJsonFormatter())
logger = self.get_logger(handler)
logger.info(
"Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
)
# One log message, with a single trailing newline.
data = output.getvalue()
logs = data.splitlines()
self.assertEqual(len(logs), 1)
self.assertEqual(data.count("\n"), 1)
log = json.loads(logs[0])
# The terse logger should give us these keys.
expected_log_keys = [
"log",
"time",
"level",
"namespace",
# The additional keys given via extra.
"foo",
"int",
"bool",
]
self.assertCountEqual(log.keys(), expected_log_keys)
# It contains the data we expect.
self.assertEqual(log["name"], "wally")
# Check the values of the extra fields.
self.assertEqual(log["foo"], "bar")
self.assertEqual(log["int"], 3)
self.assertIs(log["bool"], True)
def test_log_backpressure_debug(self):
def test_json_output(self):
"""
When backpressure is hit, DEBUG logs will be shed.
The JSON formatter converts log messages to JSON.
"""
log_config = {
"loggers": {"synapse": {"level": "DEBUG"}},
"drains": {
"tersejson": {
"type": "network_json_terse",
"host": "127.0.0.1",
"port": 8000,
"maximum_buffer": 10,
}
},
}
output = StringIO()
# Begin the logger with our config
beginner = FakeBeginner()
setup_structured_logging(
self.hs,
self.hs.config,
log_config,
logBeginner=beginner,
redirect_stdlib_logging=False,
)
handler = logging.StreamHandler(output)
handler.setFormatter(JsonFormatter())
logger = self.get_logger(handler)
logger = Logger(
namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
)
logger.info("Hello there, %s!", "wally")
# Send some debug messages
for i in range(0, 3):
logger.debug("debug %s" % (i,))
# One log message, with a single trailing newline.
data = output.getvalue()
logs = data.splitlines()
self.assertEqual(len(logs), 1)
self.assertEqual(data.count("\n"), 1)
log = json.loads(logs[0])
# Send a bunch of useful messages
for i in range(0, 7):
logger.info("test message %s" % (i,))
# The last debug message pushes it past the maximum buffer
logger.debug("too much debug")
# Allow the reconnection
_, server = connect_client(self.reactor, 0)
self.pump()
# Only the 7 infos made it through, the debugs were elided
logs = server.data.splitlines()
self.assertEqual(len(logs), 7)
def test_log_backpressure_info(self):
"""
When backpressure is hit, DEBUG and INFO logs will be shed.
"""
log_config = {
"loggers": {"synapse": {"level": "DEBUG"}},
"drains": {
"tersejson": {
"type": "network_json_terse",
"host": "127.0.0.1",
"port": 8000,
"maximum_buffer": 10,
}
},
}
# Begin the logger with our config
beginner = FakeBeginner()
setup_structured_logging(
self.hs,
self.hs.config,
log_config,
logBeginner=beginner,
redirect_stdlib_logging=False,
)
logger = Logger(
namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
)
# Send some debug messages
for i in range(0, 3):
logger.debug("debug %s" % (i,))
# Send a bunch of useful messages
for i in range(0, 10):
logger.warn("test warn %s" % (i,))
# Send a bunch of info messages
for i in range(0, 3):
logger.info("test message %s" % (i,))
# The last debug message pushes it past the maximum buffer
logger.debug("too much debug")
# Allow the reconnection
client, server = connect_client(self.reactor, 0)
self.pump()
# The 10 warnings made it through, the debugs and infos were elided
logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
self.assertEqual(len(logs), 10)
self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
def test_log_backpressure_cut_middle(self):
"""
When backpressure is hit and no more DEBUGs and INFOs can be culled,
the middle messages are cut out.
"""
log_config = {
"loggers": {"synapse": {"level": "DEBUG"}},
"drains": {
"tersejson": {
"type": "network_json_terse",
"host": "127.0.0.1",
"port": 8000,
"maximum_buffer": 10,
}
},
}
# Begin the logger with our config
beginner = FakeBeginner()
setup_structured_logging(
self.hs,
self.hs.config,
log_config,
logBeginner=beginner,
redirect_stdlib_logging=False,
)
logger = Logger(
namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
)
# Send a bunch of useful messages
for i in range(0, 20):
logger.warn("test warn", num=i)
# Allow the reconnection
client, server = connect_client(self.reactor, 0)
self.pump()
# The first five and last five warnings made it through; the middle ones
# were elided
logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
self.assertEqual(len(logs), 10)
self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs])
# The terse logger should give us these keys.
expected_log_keys = [
"log",
"level",
"namespace",
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")

View file

@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher(

View file

@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(

View file

@ -16,7 +16,6 @@ import logging
from typing import Any, Callable, List, Optional, Tuple
import attr
import hiredis
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
@ -39,12 +38,22 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport, render
try:
import hiredis
except ImportError:
hiredis = None
logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
# hiredis is an optional dependency so we don't want to require it for running
# the tests.
if not hiredis:
skip = "Requires hiredis"
servlets = [
streams.register_servlets,
]
@ -269,7 +278,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
homeserver_to_use=GenericWorkerServer,
config=config,
reactor=self.reactor,
**kwargs
**kwargs,
)
# If the instance is in the `instance_map` config then workers may try

View file

@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
sender=sender,
type="test_event",
content={"body": body},
**kwargs
**kwargs,
)
)

View file

@ -0,0 +1,277 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from binascii import unhexlify
from typing import Tuple
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.web.http import HTTPChannel
from twisted.web.server import Request
from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.server import HomeServer
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport
logger = logging.getLogger(__name__)
test_server_connection_factory = None
class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks that running multiple media repos works correctly.
"""
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
self.reactor.lookups["example.com"] = "127.0.0.2"
def default_config(self):
conf = super().default_config()
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
return conf
def _get_media_req(
self, hs: HomeServer, target: str, media_id: str
) -> Tuple[FakeChannel, Request]:
"""Request some remote media from the given HS by calling the download
API.
This then triggers an outbound request from the HS to the target.
Returns:
The channel for the *client* request and the *outbound* request for
the media which the caller should respond to.
"""
request, channel = self.make_request(
"GET",
"/{}/{}".format(target, media_id),
shorthand=False,
access_token=self.access_token,
)
request.render(hs.get_media_repository_resource().children[b"download"])
self.pump()
clients = self.reactor.tcpClients
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
# build the test server
server_tls_protocol = _build_test_server(get_connection_factory())
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
# a FakeTransport.
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(
FakeTransport(server_tls_protocol, self.reactor, client_protocol)
)
# tell the server tls protocol to send its stuff back to the client, too
server_tls_protocol.makeConnection(
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
)
# fish the test server back out of the server-side TLS protocol.
http_server = server_tls_protocol.wrappedProtocol
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b"GET")
self.assertEqual(
request.path,
"/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"),
)
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
)
return channel, request
def test_basic(self):
"""Test basic fetching of remote media from a single worker.
"""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
request.setResponseCode(200)
request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
request.write(b"Hello!")
request.finish()
self.pump(0.1)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"Hello!")
def test_download_simple_file_race(self):
"""Test that fetching remote media from two different processes at the
same time works.
"""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
hs2 = self.make_worker_hs("synapse.app.generic_worker")
start_count = self._count_remote_media()
# Make two requests without responding to the outbound media requests.
channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
# Respond to the first outbound media request and check that the client
# request is successful
request1.setResponseCode(200)
request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
request1.write(b"Hello!")
request1.finish()
self.pump(0.1)
self.assertEqual(channel1.code, 200, channel1.result["body"])
self.assertEqual(channel1.result["body"], b"Hello!")
# Now respond to the second with the same content.
request2.setResponseCode(200)
request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
request2.write(b"Hello!")
request2.finish()
self.pump(0.1)
self.assertEqual(channel2.code, 200, channel2.result["body"])
self.assertEqual(channel2.result["body"], b"Hello!")
# We expect only one new file to have been persisted.
self.assertEqual(start_count + 1, self._count_remote_media())
def test_download_image_race(self):
"""Test that fetching remote *images* from two different processes at
the same time works.
This checks that races generating thumbnails are handled correctly.
"""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
hs2 = self.make_worker_hs("synapse.app.generic_worker")
start_count = self._count_remote_thumbnails()
channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
png_data = unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
b"0000001f15c4890000000a49444154789c63000100000500010d"
b"0a2db40000000049454e44ae426082"
)
request1.setResponseCode(200)
request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
request1.write(png_data)
request1.finish()
self.pump(0.1)
self.assertEqual(channel1.code, 200, channel1.result["body"])
self.assertEqual(channel1.result["body"], png_data)
request2.setResponseCode(200)
request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
request2.write(png_data)
request2.finish()
self.pump(0.1)
self.assertEqual(channel2.code, 200, channel2.result["body"])
self.assertEqual(channel2.result["body"], png_data)
# We expect only three new thumbnails to have been persisted.
self.assertEqual(start_count + 3, self._count_remote_thumbnails())
def _count_remote_media(self) -> int:
"""Count the number of files in our remote media directory.
"""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_content"
)
return sum(len(files) for _, _, files in os.walk(path))
def _count_remote_thumbnails(self) -> int:
"""Count the number of files in our remote thumbnails directory.
"""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
)
return sum(len(files) for _, _, files in os.walk(path))
def get_connection_factory():
# this needs to happen once, but not until we are ready to run the first test
global test_server_connection_factory
if test_server_connection_factory is None:
test_server_connection_factory = TestServerTLSConnectionFactory(
sanlist=[b"DNS:example.com"]
)
return test_server_connection_factory
def _build_test_server(connection_creator):
"""Construct a test server
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
Args:
connection_creator (IOpenSSLServerConnectionCreator): thing to build
SSL connections
sanlist (list[bytes]): list of the SAN entries for the cert returned
by the server
Returns:
TLSMemoryBIOProtocol
"""
server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request
server_tls_factory = TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=server_factory
)
return server_tls_factory.buildProtocol(None)
def _log_request(request):
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)

View file

@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_dict["token_id"]
token_id = user_dict.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(

View file

@ -1118,6 +1118,130 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
class PushersRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.url = "/_synapse/admin/v1/users/%s/pushers" % urllib.parse.quote(
self.other_user
)
def test_no_auth(self):
"""
Try to list pushers of a user without authentication.
"""
request, channel = self.make_request("GET", self.url, b"{}")
self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
other_user_token = self.login("user", "pass")
request, channel = self.make_request(
"GET", self.url, access_token=other_user_token,
)
self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
"""
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
Tests that a lookup for a user that is not local returns a 400
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_get_pushers(self):
"""
Tests that a normal lookup for pushers is successful
"""
# Get pushers
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
# Register the pusher
other_user_token = self.login("user", "pass")
user_tuple = self.get_success(
self.store.get_user_by_access_token(other_user_token)
)
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
user_id=self.other_user,
access_token=token_id,
kind="http",
app_id="m.http",
app_display_name="HTTP Push Notifications",
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "example.com"},
)
)
# Get pushers
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
for p in channel.json_body["pushers"]:
self.assertIn("pushkey", p)
self.assertIn("kind", p)
self.assertIn("app_id", p)
self.assertIn("app_display_name", p)
self.assertIn("device_display_name", p)
self.assertIn("profile_tag", p)
self.assertIn("lang", p)
self.assertIn("url", p["data"])
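For reference, a hedged sketch (values invented to match the pusher registered in this test, not taken from a real server) of roughly what `channel.json_body` looks like for this admin endpoint:

example_response = {
    "total": 1,
    "pushers": [
        {
            "pushkey": "a@example.com",
            "kind": "http",
            "app_id": "m.http",
            "app_display_name": "HTTP Push Notifications",
            "device_display_name": "pushy push",
            "profile_tag": "",
            "lang": None,
            "data": {"url": "example.com"},
        }
    ],
}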
class UserMediaRestTestCase(unittest.HomeserverTestCase):
servlets = [

View file

@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
)
self.hs.get_datastore().services_cache.append(appservice)
