
Merge branch 'develop' of github.com:matrix-org/synapse into neilj/server_notices_on_blocking

This commit is contained in:
Erik Johnston 2018-08-22 17:06:10 +01:00
commit fd2dbf1836
116 changed files with 1625 additions and 765 deletions

View file

@ -1,3 +1,85 @@
Synapse 0.33.3 (2018-08-22)
===========================
Bugfixes
--------
- Fix bug introduced in v0.33.3rc1 which made the ToS give a 500 error ([\#3732](https://github.com/matrix-org/synapse/issues/3732))
Synapse 0.33.3rc2 (2018-08-21)
==============================
Bugfixes
--------
- Fix bug in v0.33.3rc1 which caused infinite loops and OOMs ([\#3723](https://github.com/matrix-org/synapse/issues/3723))
Synapse 0.33.3rc1 (2018-08-21)
==============================
Features
--------
- Add support for the SNI extension to federation TLS connections. Thanks to @vojeroen! ([\#3439](https://github.com/matrix-org/synapse/issues/3439))
- Add /_media/r0/config ([\#3184](https://github.com/matrix-org/synapse/issues/3184))
- speed up /members API and add `at` and `membership` params as per MSC1227 ([\#3568](https://github.com/matrix-org/synapse/issues/3568))
- implement `summary` block in /sync response as per MSC688 ([\#3574](https://github.com/matrix-org/synapse/issues/3574))
- Add lazy-loading support to /messages as per MSC1227 ([\#3589](https://github.com/matrix-org/synapse/issues/3589))
- Add ability to limit number of monthly active users on the server ([\#3633](https://github.com/matrix-org/synapse/issues/3633))
- Support more federation endpoints on workers ([\#3653](https://github.com/matrix-org/synapse/issues/3653))
- Basic support for room versioning ([\#3654](https://github.com/matrix-org/synapse/issues/3654))
- Ability to disable client/server Synapse via conf toggle ([\#3655](https://github.com/matrix-org/synapse/issues/3655))
- Ability to whitelist specific threepids against monthly active user limiting ([\#3662](https://github.com/matrix-org/synapse/issues/3662))
- Add some metrics for the appservice and federation event sending loops ([\#3664](https://github.com/matrix-org/synapse/issues/3664))
- Where server is disabled, block ability for locked out users to read new messages ([\#3670](https://github.com/matrix-org/synapse/issues/3670))
- set admin uri via config, to be used in error messages where the user should contact the administrator ([\#3687](https://github.com/matrix-org/synapse/issues/3687))
- Synapse's presence functionality can now be disabled with the "use_presence" configuration option. ([\#3694](https://github.com/matrix-org/synapse/issues/3694))
- For resource limit blocked users, prevent writing into rooms ([\#3708](https://github.com/matrix-org/synapse/issues/3708))
Bugfixes
--------
- Fix occasional glitches in the synapse_event_persisted_position metric ([\#3658](https://github.com/matrix-org/synapse/issues/3658))
- Fix bug on deleting 3pid when using identity servers that don't support unbind API ([\#3661](https://github.com/matrix-org/synapse/issues/3661))
- Make the tests pass on Twisted < 18.7.0 ([\#3676](https://github.com/matrix-org/synapse/issues/3676))
- Don't ship recaptcha_ajax.js; use it directly from Google ([\#3677](https://github.com/matrix-org/synapse/issues/3677))
- Fixes test_reap_monthly_active_users so it passes under postgres ([\#3681](https://github.com/matrix-org/synapse/issues/3681))
- Fix MAU blocking calculation bug on login ([\#3689](https://github.com/matrix-org/synapse/issues/3689))
- Fix missing yield in synapse.storage.monthly_active_users.initialise_reserved_users ([\#3692](https://github.com/matrix-org/synapse/issues/3692))
- Improve HTTP request logging to include all requests ([\#3700](https://github.com/matrix-org/synapse/issues/3700))
- Avoid timing out requests while we are streaming back the response ([\#3701](https://github.com/matrix-org/synapse/issues/3701))
- Support more federation endpoints on workers ([\#3705](https://github.com/matrix-org/synapse/issues/3705), [\#3713](https://github.com/matrix-org/synapse/issues/3713))
- Fix "Starting db txn 'get_all_updated_receipts' from sentinel context" warning ([\#3710](https://github.com/matrix-org/synapse/issues/3710))
- Fix bug where `state_cache` cache factor ignored environment variables ([\#3719](https://github.com/matrix-org/synapse/issues/3719))
Deprecations and Removals
-------------------------
- The Shared-Secret registration method of the legacy v1/register REST endpoint has been removed. For a replacement, please see [the admin/register API documentation](https://github.com/matrix-org/synapse/blob/master/docs/admin_api/register_api.rst). ([\#3703](https://github.com/matrix-org/synapse/issues/3703))
Internal Changes
----------------
- The test suite now can run under PostgreSQL. ([\#3423](https://github.com/matrix-org/synapse/issues/3423))
- Refactor HTTP replication endpoints to reduce code duplication ([\#3632](https://github.com/matrix-org/synapse/issues/3632))
- Tests now correctly execute on Python 3. ([\#3647](https://github.com/matrix-org/synapse/issues/3647))
- Sytests can now be run inside a Docker container. ([\#3660](https://github.com/matrix-org/synapse/issues/3660))
- Port over enough to Python 3 to allow the sytests to start. ([\#3668](https://github.com/matrix-org/synapse/issues/3668))
- Update docker base image from alpine 3.7 to 3.8. ([\#3669](https://github.com/matrix-org/synapse/issues/3669))
- Rename synapse.util.async to synapse.util.async_helpers to mitigate async becoming a keyword on Python 3.7. ([\#3678](https://github.com/matrix-org/synapse/issues/3678))
- Synapse's tests are now formatted with the black autoformatter. ([\#3679](https://github.com/matrix-org/synapse/issues/3679))
- Implemented a new testing base class to reduce test boilerplate. ([\#3684](https://github.com/matrix-org/synapse/issues/3684))
- Rename MAU prometheus metrics ([\#3690](https://github.com/matrix-org/synapse/issues/3690))
- add new error type ResourceLimit ([\#3707](https://github.com/matrix-org/synapse/issues/3707))
- Logcontexts for replication command handlers ([\#3709](https://github.com/matrix-org/synapse/issues/3709))
- Update admin register API documentation to reference a real user ID. ([\#3712](https://github.com/matrix-org/synapse/issues/3712))
Synapse 0.33.2 (2018-08-09) Synapse 0.33.2 (2018-08-09)
=========================== ===========================
@ -24,7 +106,7 @@ Features
Bugfixes Bugfixes
-------- --------
- Make /directory/list API return 404 for room not found instead of 400 ([\#2952](https://github.com/matrix-org/synapse/issues/2952)) - Make /directory/list API return 404 for room not found instead of 400. Thanks to @fuzzmz! ([\#3620](https://github.com/matrix-org/synapse/issues/3620))
- Default inviter_display_name to mxid for email invites ([\#3391](https://github.com/matrix-org/synapse/issues/3391)) - Default inviter_display_name to mxid for email invites ([\#3391](https://github.com/matrix-org/synapse/issues/3391))
- Don't generate TURN credentials if no TURN config options are set ([\#3514](https://github.com/matrix-org/synapse/issues/3514)) - Don't generate TURN credentials if no TURN config options are set ([\#3514](https://github.com/matrix-org/synapse/issues/3514))
- Correctly announce deleted devices over federation ([\#3520](https://github.com/matrix-org/synapse/issues/3520)) - Correctly announce deleted devices over federation ([\#3520](https://github.com/matrix-org/synapse/issues/3520))

View file

@ -1 +0,0 @@
Add support for the SNI extension to federation TLS connections

View file

@ -1 +0,0 @@
Add /_media/r0/config

View file

@ -1 +0,0 @@
The test suite now can run under PostgreSQL.

View file

@ -1 +0,0 @@
speed up /members API and add `at` and `membership` params as per MSC1227

View file

@ -1 +0,0 @@
implement `summary` block in /sync response as per MSC688

View file

@ -1 +0,0 @@
Add lazy-loading support to /messages as per MSC1227

View file

@ -1 +0,0 @@
Refactor HTTP replication endpoints to reduce code duplication

View file

@ -1 +0,0 @@
Add ability to limit number of monthly active users on the server

View file

@ -1 +0,0 @@
Tests now correctly execute on Python 3.

View file

@ -1 +0,0 @@
Support more federation endpoints on workers

View file

@ -1 +0,0 @@
Basic support for room versioning

View file

@ -1 +0,0 @@
Ability to disable client/server Synapse via conf toggle

View file

@ -1 +0,0 @@
Fix occasional glitches in the synapse_event_persisted_position metric

changelog.d/3659.feature Normal file
View file

@ -0,0 +1 @@
Support profile API endpoints on workers

View file

@ -1 +0,0 @@
Sytests can now be run inside a Docker container.

View file

@ -1 +0,0 @@
Fix bug on deleting 3pid when using identity servers that don't support unbind API

View file

@ -1 +0,0 @@
Ability to whitelist specific threepids against monthly active user limiting

View file

@ -1 +0,0 @@
Add some metrics for the appservice and federation event sending loops

View file

@ -1 +0,0 @@
Update docker base image from alpine 3.7 to 3.8.

View file

@ -1 +0,0 @@
Where server is disabled, block ability for locked out users to read new messages

changelog.d/3673.misc Normal file
View file

@ -0,0 +1 @@
Refactor state module to support multiple room versions

View file

@ -1 +0,0 @@
Make the tests pass on Twisted < 18.7.0

View file

@ -1 +0,0 @@
Don't ship recaptcha_ajax.js; use it directly from Google

View file

@ -1 +0,0 @@
Rename synapse.util.async to synapse.util.async_helpers to mitigate async becoming a keyword on Python 3.7.

View file

@ -1 +0,0 @@
Synapse's tests are now formatted with the black autoformatter.

View file

@ -1 +0,0 @@
Fixes test_reap_monthly_active_users so it passes under postgres

View file

@ -1 +0,0 @@
Implemented a new testing base class to reduce test boilerplate.

View file

@ -1 +0,0 @@
set admin uri via config, to be used in error messages where the user should contact the administrator

View file

@ -1 +0,0 @@
Fix MAU blocking calculation bug on login

View file

@ -1 +0,0 @@
Rename MAU prometheus metrics

View file

@ -1 +0,0 @@
Fix missing yield in synapse.storage.monthly_active_users.initialise_reserved_users

View file

@ -1 +0,0 @@
Support more federation endpoints on workers

View file

@ -1 +0,0 @@
add new error type ResourceLimit

View file

@ -1 +0,0 @@
For resource limit blocked users, prevent writing into rooms

View file

@ -1 +0,0 @@
Update admin register API documentation to reference a real user ID.

changelog.d/3722.bugfix Normal file
View file

@ -0,0 +1 @@
Fix error collecting prometheus metrics when run on dedicated thread due to threading concurrency issues

changelog.d/3724.feature Normal file
View file

@ -0,0 +1 @@
Allow guests to use /rooms/:roomId/event/:eventId

changelog.d/3726.misc Normal file
View file

@ -0,0 +1 @@
Split the state_group_cache into member and non-member state events (and so speed up LL /sync)

changelog.d/3727.misc Normal file
View file

@ -0,0 +1 @@
Log failure to authenticate remote servers as warnings (without stack traces)

changelog.d/3735.misc Normal file
View file

@ -0,0 +1 @@
Fix minor spelling error in federation client documentation.

View file

@ -241,6 +241,14 @@ regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/keys/upload ^/_matrix/client/(api/v1|r0|unstable)/keys/upload
If ``use_presence`` is False in the homeserver config, it can also handle REST
endpoints matching the following regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/presence/[^/]+/status
This "stub" presence handler will pass through ``GET`` requests but make the
``PUT`` effectively a no-op.
It will proxy any requests it cannot handle to the main synapse instance. It It will proxy any requests it cannot handle to the main synapse instance. It
must therefore be configured with the location of the main instance, via must therefore be configured with the location of the main instance, via
the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
@ -257,6 +265,7 @@ Handles some event creation. It can handle REST endpoints matching::
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
^/_matrix/client/(api/v1|r0|unstable)/join/ ^/_matrix/client/(api/v1|r0|unstable)/join/
^/_matrix/client/(api/v1|r0|unstable)/profile/
It will create events locally and then send them on to the main synapse It will create events locally and then send them on to the main synapse
instance to be persisted and handled. instance to be persisted and handled.

View file

@ -17,4 +17,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.33.2" __version__ = "0.33.3"

View file

@ -211,7 +211,7 @@ class Auth(object):
user_agent = request.requestHeaders.getRawHeaders( user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent", b"User-Agent",
default=[b""] default=[b""]
)[0] )[0].decode('ascii', 'surrogateescape')
if user and access_token and ip_addr: if user and access_token and ip_addr:
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id=user.to_string(), user_id=user.to_string(),
@ -682,7 +682,7 @@ class Auth(object):
Returns: Returns:
bool: False if no access_token was given, True otherwise. bool: False if no access_token was given, True otherwise.
""" """
query_params = request.args.get("access_token") query_params = request.args.get(b"access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers) return bool(query_params) or bool(auth_headers)
@ -698,7 +698,7 @@ class Auth(object):
401 since some of the old clients depended on auth errors returning 401 since some of the old clients depended on auth errors returning
403. 403.
Returns: Returns:
str: The access_token unicode: The access_token
Raises: Raises:
AuthError: If there isn't an access_token in the request. AuthError: If there isn't an access_token in the request.
""" """
@ -720,9 +720,9 @@ class Auth(object):
"Too many Authorization headers.", "Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN, errcode=Codes.MISSING_TOKEN,
) )
parts = auth_headers[0].split(" ") parts = auth_headers[0].split(b" ")
if parts[0] == "Bearer" and len(parts) == 2: if parts[0] == b"Bearer" and len(parts) == 2:
return parts[1] return parts[1].decode('ascii')
else: else:
raise AuthError( raise AuthError(
token_not_found_http_status, token_not_found_http_status,
@ -738,7 +738,7 @@ class Auth(object):
errcode=Codes.MISSING_TOKEN errcode=Codes.MISSING_TOKEN
) )
return query_params[0] return query_params[0].decode('ascii')
@defer.inlineCallbacks @defer.inlineCallbacks
def check_in_room_or_world_readable(self, room_id, user_id): def check_in_room_or_world_readable(self, room_id, user_id):
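
A minimal standalone sketch of the bytes-handling pattern in the hunks above (the helper name and its calling convention are illustrative, not Synapse's actual API): on Python 3 the request arguments and Authorization headers arrive as bytes, so comparisons use byte literals and the token is decoded to ASCII only when it is returned.

def get_access_token(args, auth_headers):
    """args: dict of bytes -> list of bytes; auth_headers: list of bytes or None."""
    if auth_headers:
        parts = auth_headers[0].split(b" ")
        if parts[0] == b"Bearer" and len(parts) == 2:
            return parts[1].decode('ascii')
        raise ValueError("Invalid Authorization header")
    params = args.get(b"access_token")
    if not params:
        raise ValueError("Missing access token")
    return params[0].decode('ascii')

# e.g. get_access_token({b"access_token": [b"syt_abc"]}, None) == "syt_abc"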

View file

@ -98,13 +98,17 @@ class ThirdPartyEntityKind(object):
LOCATION = "location" LOCATION = "location"
class RoomVersions(object):
V1 = "1"
VDH_TEST = "vdh-test-version"
# the version we will give rooms which are created on this server # the version we will give rooms which are created on this server
DEFAULT_ROOM_VERSION = "1" DEFAULT_ROOM_VERSION = RoomVersions.V1
# vdh-test-version is a placeholder to get room versioning support working and tested # vdh-test-version is a placeholder to get room versioning support working and tested
# until we have a working v2. # until we have a working v2.
KNOWN_ROOM_VERSIONS = {"1", "vdh-test-version"} KNOWN_ROOM_VERSIONS = {RoomVersions.V1, RoomVersions.VDH_TEST}
ServerNoticeMsgType = "m.server_notice" ServerNoticeMsgType = "m.server_notice"
ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
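
A hypothetical illustration of how the new RoomVersions constants might be consumed when a room is created; the choose_room_version helper is an assumption for illustration and is not part of this diff.

class RoomVersions(object):
    V1 = "1"
    VDH_TEST = "vdh-test-version"

DEFAULT_ROOM_VERSION = RoomVersions.V1
KNOWN_ROOM_VERSIONS = {RoomVersions.V1, RoomVersions.VDH_TEST}

def choose_room_version(requested=None):
    # Fall back to the server default, and reject versions this server
    # does not implement.
    version = requested or DEFAULT_ROOM_VERSION
    if version not in KNOWN_ROOM_VERSIONS:
        raise ValueError("Unsupported room version: %r" % (version,))
    return version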

View file

@ -72,7 +72,7 @@ class Ratelimiter(object):
return allowed, time_allowed return allowed, time_allowed
def prune_message_counts(self, time_now_s): def prune_message_counts(self, time_now_s):
for user_id in self.message_counts.keys(): for user_id in list(self.message_counts.keys()):
message_count, time_start, msg_rate_hz = ( message_count, time_start, msg_rate_hz = (
self.message_counts[user_id] self.message_counts[user_id]
) )
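
The list() copy above matters because on Python 3 dict.keys() is a live view, and deleting entries while iterating over it raises RuntimeError. A self-contained illustration of the same pattern with made-up data:

message_counts = {"@alice:hs": (3, 0, 0.1), "@bob:hs": (0, 0, 0.1)}

# Snapshot the keys before mutating; iterating the view directly while
# deleting would raise "dictionary changed size during iteration".
for user_id in list(message_counts.keys()):
    count, _, _ = message_counts[user_id]
    if count == 0:
        del message_counts[user_id]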

View file

@ -140,7 +140,7 @@ def listen_metrics(bind_addresses, port):
logger.info("Metrics now reporting on %s:%d", host, port) logger.info("Metrics now reporting on %s:%d", host, port)
def listen_tcp(bind_addresses, port, factory, backlog=50): def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
""" """
Create a TCP socket for a port and several addresses Create a TCP socket for a port and several addresses
""" """
@ -156,7 +156,9 @@ def listen_tcp(bind_addresses, port, factory, backlog=50):
check_bind_error(e, address, bind_addresses) check_bind_error(e, address, bind_addresses)
def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50): def listen_ssl(
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
):
""" """
Create an SSL socket for a port and several addresses Create an SSL socket for a port and several addresses
""" """

View file

@ -117,8 +117,9 @@ class ASReplicationHandler(ReplicationClientHandler):
super(ASReplicationHandler, self).__init__(hs.get_datastore()) super(ASReplicationHandler, self).__init__(hs.get_datastore())
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
if stream_name == "events": if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering() max_stream_id = self.store.get_room_max_stream_ordering()
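
Several replication handlers in this commit follow the same shape: on_rdata becomes an inlineCallbacks coroutine and yields the superclass call, so the base handler's now-asynchronous stream bookkeeping completes before the worker-specific processing runs. A stripped-down sketch of that pattern, with illustrative class names:

from twisted.internet import defer

class BaseReplicationHandler(object):
    @defer.inlineCallbacks
    def on_rdata(self, stream_name, token, rows):
        # stand-in for the base class's async work, e.g. advancing stream positions
        yield defer.succeed(None)

class WorkerReplicationHandler(BaseReplicationHandler):
    @defer.inlineCallbacks
    def on_rdata(self, stream_name, token, rows):
        yield super(WorkerReplicationHandler, self).on_rdata(stream_name, token, rows)
        # worker-specific processing starts only after the base work has finished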

View file

@ -45,6 +45,11 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.profile import (
ProfileAvatarURLRestServlet,
ProfileDisplaynameRestServlet,
ProfileRestServlet,
)
from synapse.rest.client.v1.room import ( from synapse.rest.client.v1.room import (
JoinRoomAliasServlet, JoinRoomAliasServlet,
RoomMembershipRestServlet, RoomMembershipRestServlet,
@ -53,6 +58,7 @@ from synapse.rest.client.v1.room import (
) )
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.user_directory import UserDirectoryStore
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
@ -62,6 +68,9 @@ logger = logging.getLogger("synapse.app.event_creator")
class EventCreatorSlavedStore( class EventCreatorSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
DirectoryStore, DirectoryStore,
SlavedTransactionStore, SlavedTransactionStore,
SlavedProfileStore, SlavedProfileStore,
@ -101,6 +110,9 @@ class EventCreatorServer(HomeServer):
RoomMembershipRestServlet(self).register(resource) RoomMembershipRestServlet(self).register(resource)
RoomStateEventRestServlet(self).register(resource) RoomStateEventRestServlet(self).register(resource)
JoinRoomAliasServlet(self).register(resource) JoinRoomAliasServlet(self).register(resource)
ProfileAvatarURLRestServlet(self).register(resource)
ProfileDisplaynameRestServlet(self).register(resource)
ProfileRestServlet(self).register(resource)
resources.update({ resources.update({
"/_matrix/client/r0": resource, "/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource, "/_matrix/client/unstable": resource,

View file

@ -144,8 +144,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
self.send_handler = FederationSenderHandler(hs, self) self.send_handler = FederationSenderHandler(hs, self)
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
super(FederationSenderReplicationHandler, self).on_rdata( yield super(FederationSenderReplicationHandler, self).on_rdata(
stream_name, token, rows stream_name, token, rows
) )
self.send_handler.process_replication_rows(stream_name, token, rows) self.send_handler.process_replication_rows(stream_name, token, rows)

View file

@ -38,6 +38,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns
from synapse.rest.client.v2_alpha._base import client_v2_patterns from synapse.rest.client.v2_alpha._base import client_v2_patterns
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
@ -49,6 +50,35 @@ from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.frontend_proxy") logger = logging.getLogger("synapse.app.frontend_proxy")
class PresenceStatusStubServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
def __init__(self, hs):
super(PresenceStatusStubServlet, self).__init__(hs)
self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
self.main_uri = hs.config.worker_main_http_uri
@defer.inlineCallbacks
def on_GET(self, request, user_id):
# Pass through the auth headers, if any, in case the access token
# is there.
auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
headers = {
"Authorization": auth_headers,
}
result = yield self.http_client.get_json(
self.main_uri + request.uri,
headers=headers,
)
defer.returnValue((200, result))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
yield self.auth.get_user_by_req(request)
defer.returnValue((200, {}))
class KeyUploadServlet(RestServlet): class KeyUploadServlet(RestServlet):
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
@ -135,6 +165,12 @@ class FrontendProxyServer(HomeServer):
elif name == "client": elif name == "client":
resource = JsonResource(self, canonical_json=False) resource = JsonResource(self, canonical_json=False)
KeyUploadServlet(self).register(resource) KeyUploadServlet(self).register(resource)
# If presence is disabled, use the stub servlet that does
# not allow sending presence
if not self.config.use_presence:
PresenceStatusStubServlet(self).register(resource)
resources.update({ resources.update({
"/_matrix/client/r0": resource, "/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource, "/_matrix/client/unstable": resource,
@ -153,7 +189,8 @@ class FrontendProxyServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
self.version_string, self.version_string,
) ),
reactor=self.get_reactor()
) )
logger.info("Synapse client reader now listening on port %d", port) logger.info("Synapse client reader now listening on port %d", port)

View file

@ -148,8 +148,9 @@ class PusherReplicationHandler(ReplicationClientHandler):
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.poke_pushers, stream_name, token, rows) run_in_background(self.poke_pushers, stream_name, token, rows)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -162,11 +163,11 @@ class PusherReplicationHandler(ReplicationClientHandler):
else: else:
yield self.start_pusher(row.user_id, row.app_id, row.pushkey) yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
elif stream_name == "events": elif stream_name == "events":
yield self.pusher_pool.on_new_notifications( self.pusher_pool.on_new_notifications(
token, token, token, token,
) )
elif stream_name == "receipts": elif stream_name == "receipts":
yield self.pusher_pool.on_new_receipts( self.pusher_pool.on_new_receipts(
token, token, set(row.room_id for row in rows) token, token, set(row.room_id for row in rows)
) )
except Exception: except Exception:

View file

@ -114,7 +114,10 @@ class SynchrotronPresence(object):
logger.info("Presence process_id is %r", self.process_id) logger.info("Presence process_id is %r", self.process_id)
def send_user_sync(self, user_id, is_syncing, last_sync_ms): def send_user_sync(self, user_id, is_syncing, last_sync_ms):
self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms) if self.hs.config.use_presence:
self.hs.get_tcp_replication().send_user_sync(
user_id, is_syncing, last_sync_ms
)
def mark_as_coming_online(self, user_id): def mark_as_coming_online(self, user_id):
"""A user has started syncing. Send a UserSync to the master, unless they """A user has started syncing. Send a UserSync to the master, unless they
@ -211,10 +214,13 @@ class SynchrotronPresence(object):
yield self.notify_from_replication(states, stream_id) yield self.notify_from_replication(states, stream_id)
def get_currently_syncing_users(self): def get_currently_syncing_users(self):
return [ if self.hs.config.use_presence:
user_id for user_id, count in iteritems(self.user_to_num_current_syncs) return [
if count > 0 user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
] if count > 0
]
else:
return set()
class SynchrotronTyping(object): class SynchrotronTyping(object):
@ -332,8 +338,9 @@ class SyncReplicationHandler(ReplicationClientHandler):
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.process_and_notify, stream_name, token, rows) run_in_background(self.process_and_notify, stream_name, token, rows)
def get_streams_to_replicate(self): def get_streams_to_replicate(self):

View file

@ -169,8 +169,9 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
self.user_directory = hs.get_user_directory_handler() self.user_directory = hs.get_user_directory_handler()
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
super(UserDirectoryReplicationHandler, self).on_rdata( yield super(UserDirectoryReplicationHandler, self).on_rdata(
stream_name, token, rows stream_name, token, rows
) )
if stream_name == "current_state_deltas": if stream_name == "current_state_deltas":

View file

@ -168,7 +168,8 @@ def setup_logging(config, use_worker_options=False):
if log_file: if log_file:
# TODO: Customisable file size / backup count # TODO: Customisable file size / backup count
handler = logging.handlers.RotatingFileHandler( handler = logging.handlers.RotatingFileHandler(
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3 log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
encoding='utf8'
) )
def sighup(signum, stack): def sighup(signum, stack):
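
The new encoding argument pins the log file to UTF-8, presumably so that non-ASCII log messages are written correctly regardless of the system locale on Python 3. A minimal standalone version of the handler setup (the path is illustrative):

import logging
import logging.handlers

handler = logging.handlers.RotatingFileHandler(
    "/tmp/homeserver.log", maxBytes=(1000 * 1000 * 100), backupCount=3,
    encoding='utf8',
)
logging.getLogger("synapse").addHandler(handler)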

View file

@ -49,6 +49,9 @@ class ServerConfig(Config):
# "disable" federation # "disable" federation
self.send_federation = config.get("send_federation", True) self.send_federation = config.get("send_federation", True)
# Whether to enable user presence.
self.use_presence = config.get("use_presence", True)
# Whether to update the user directory or not. This should be set to # Whether to update the user directory or not. This should be set to
# false only if we are updating the user directory in a worker # false only if we are updating the user directory in a worker
self.update_user_directory = config.get("update_user_directory", True) self.update_user_directory = config.get("update_user_directory", True)
@ -250,6 +253,9 @@ class ServerConfig(Config):
# hard limit. # hard limit.
soft_file_limit: 0 soft_file_limit: 0
# Set to false to disable presence tracking on this homeserver.
use_presence: true
# The GC threshold parameters to pass to `gc.set_threshold`, if defined # The GC threshold parameters to pass to `gc.set_threshold`, if defined
# gc_thresholds: [700, 10, 10] # gc_thresholds: [700, 10, 10]

View file

@ -18,7 +18,9 @@ import logging
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.error import ConnectError
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.names.error import DomainError
from twisted.web.http import HTTPClient from twisted.web.http import HTTPClient
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
@ -47,12 +49,14 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
server_response, server_certificate = yield protocol.remote_key server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate)) defer.returnValue((server_response, server_certificate))
except SynapseKeyClientError as e: except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,)) logger.warn("Error getting key for %r: %s", server_name, e)
if e.status.startswith("4"): if e.status.startswith("4"):
# Don't retry for 4xx responses. # Don't retry for 4xx responses.
raise IOError("Cannot get key for %r" % server_name) raise IOError("Cannot get key for %r" % server_name)
except (ConnectError, DomainError) as e:
logger.warn("Error getting key for %r: %s", server_name, e)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception("Error getting key for %r", server_name)
raise IOError("Cannot get key for %r" % server_name) raise IOError("Cannot get key for %r" % server_name)

View file

@ -58,6 +58,7 @@ class TransactionQueue(object):
""" """
def __init__(self, hs): def __init__(self, hs):
self.hs = hs
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -308,6 +309,9 @@ class TransactionQueue(object):
Args: Args:
states (list(UserPresenceState)) states (list(UserPresenceState))
""" """
if not self.hs.config.use_presence:
# No-op if presence is disabled.
return
# First we queue up the new presence by user ID, so multiple presence # First we queue up the new presence by user ID, so multiple presence
# updates in quick succession are correctly handled # updates in quick succession are correctly handled

View file

@ -106,7 +106,7 @@ class TransportLayerClient(object):
dest (str) dest (str)
room_id (str) room_id (str)
event_tuples (list) event_tuples (list)
limt (int) limit (int)
Returns: Returns:
Deferred: Results in a dict received from the remote homeserver. Deferred: Results in a dict received from the remote homeserver.

View file

@ -261,10 +261,10 @@ class BaseFederationServlet(object):
except NoAuthenticationError: except NoAuthenticationError:
origin = None origin = None
if self.REQUIRE_AUTH: if self.REQUIRE_AUTH:
logger.exception("authenticate_request failed") logger.warn("authenticate_request failed: missing authentication")
raise raise
except Exception: except Exception as e:
logger.exception("authenticate_request failed") logger.warn("authenticate_request failed: %s", e)
raise raise
if origin: if origin:

View file

@ -291,8 +291,9 @@ class FederationHandler(BaseHandler):
ev_ids, get_prev_content=False, check_redacted=False ev_ids, get_prev_content=False, check_redacted=False
) )
room_version = yield self.store.get_room_version(pdu.room_id)
state_map = yield resolve_events_with_factory( state_map = yield resolve_events_with_factory(
state_groups, {pdu.event_id: pdu}, fetch room_version, state_groups, {pdu.event_id: pdu}, fetch
) )
state = (yield self.store.get_events(state_map.values())).values() state = (yield self.store.get_events(state_map.values())).values()
@ -1828,7 +1829,10 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d (d.type, d.state_key): d for d in different_events if d
}) })
room_version = yield self.store.get_room_version(event.room_id)
new_state = self.state_handler.resolve_events( new_state = self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())], [list(local_view.values()), list(remote_view.values())],
event event
) )
@ -2386,8 +2390,7 @@ class FederationHandler(BaseHandler):
extra_users=extra_users extra_users=extra_users
) )
logcontext.run_in_background( self.pusher_pool.on_new_notifications(
self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id, event_stream_id, max_stream_id,
) )

View file

@ -372,6 +372,10 @@ class InitialSyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence(): def get_presence():
# If presence is disabled, return an empty list
if not self.hs.config.use_presence:
defer.returnValue([])
states = yield presence_handler.get_states( states = yield presence_handler.get_states(
[m.user_id for m in room_members], [m.user_id for m in room_members],
as_event=True, as_event=True,

View file

@ -778,11 +778,8 @@ class EventCreationHandler(object):
event, context=context event, context=context
) )
# this intentionally does not yield: we don't care about the result self.pusher_pool.on_new_notifications(
# and don't need to wait for it. event_stream_id, max_stream_id,
run_in_background(
self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id
) )
def _notify(): def _notify():

View file

@ -395,6 +395,10 @@ class PresenceHandler(object):
"""We've seen the user do something that indicates they're interacting """We've seen the user do something that indicates they're interacting
with the app. with the app.
""" """
# If presence is disabled, no-op
if not self.hs.config.use_presence:
return
user_id = user.to_string() user_id = user.to_string()
bump_active_time_counter.inc() bump_active_time_counter.inc()
@ -424,6 +428,11 @@ class PresenceHandler(object):
Useful for streams that are not associated with an actual Useful for streams that are not associated with an actual
client that is being used by a user. client that is being used by a user.
""" """
# Override if it should affect the user's presence, if presence is
# disabled.
if not self.hs.config.use_presence:
affect_presence = False
if affect_presence: if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0) curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1 self.user_to_num_current_syncs[user_id] = curr_sync + 1
@ -469,13 +478,16 @@ class PresenceHandler(object):
Returns: Returns:
set(str): A set of user_id strings. set(str): A set of user_id strings.
""" """
syncing_user_ids = { if self.hs.config.use_presence:
user_id for user_id, count in self.user_to_num_current_syncs.items() syncing_user_ids = {
if count user_id for user_id, count in self.user_to_num_current_syncs.items()
} if count
for user_ids in self.external_process_to_current_syncs.values(): }
syncing_user_ids.update(user_ids) for user_ids in self.external_process_to_current_syncs.values():
return syncing_user_ids syncing_user_ids.update(user_ids)
return syncing_user_ids
else:
return set()
@defer.inlineCallbacks @defer.inlineCallbacks
def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):

View file

@ -32,12 +32,16 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ProfileHandler(BaseHandler): class BaseProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000 """Handles fetching and updating user profile information.
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
BaseProfileHandler can be instantiated directly on workers and will
delegate to master when necessary. The master process should use the
subclass MasterProfileHandler
"""
def __init__(self, hs): def __init__(self, hs):
super(ProfileHandler, self).__init__(hs) super(BaseProfileHandler, self).__init__(hs)
self.federation = hs.get_federation_client() self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler( hs.get_federation_registry().register_query_handler(
@ -46,11 +50,6 @@ class ProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
if hs.config.worker_app is None:
self.clock.looping_call(
self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_profile(self, user_id): def get_profile(self, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -282,6 +281,20 @@ class ProfileHandler(BaseHandler):
room_id, str(e.message) room_id, str(e.message)
) )
class MasterProfileHandler(BaseProfileHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs):
super(MasterProfileHandler, self).__init__(hs)
assert hs.config.worker_app is None
self.clock.looping_call(
self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
)
def _start_update_remote_profile_cache(self): def _start_update_remote_profile_cache(self):
return run_as_background_process( return run_as_background_process(
"Update remote profile", self._update_remote_profile_cache, "Update remote profile", self._update_remote_profile_cache,

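A minimal sketch, assuming the usual Synapse pattern of handler builders on the HomeServer object, of how the two classes above might be selected; the builder function itself is not shown in this diff:

def build_profile_handler(hs):
    if hs.config.worker_app:
        # Workers get the base handler, which can read and update profiles directly.
        return BaseProfileHandler(hs)
    # Only the master runs the remote-profile-cache update loop.
    return MasterProfileHandler(hs)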
View file

@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util import logcontext from synapse.util import logcontext
from synapse.util.logcontext import PreserveLoggingContext
from ._base import BaseHandler from ._base import BaseHandler
@ -116,16 +115,15 @@ class ReceiptsHandler(BaseHandler):
affected_room_ids = list(set([r["room_id"] for r in receipts])) affected_room_ids = list(set([r["room_id"] for r in receipts]))
with PreserveLoggingContext(): self.notifier.on_new_event(
self.notifier.on_new_event( "receipt_key", max_batch_id, rooms=affected_room_ids
"receipt_key", max_batch_id, rooms=affected_room_ids )
) # Note that the min here shouldn't be relied upon to be accurate.
# Note that the min here shouldn't be relied upon to be accurate. self.hs.get_pusherpool().on_new_receipts(
self.hs.get_pusherpool().on_new_receipts( min_batch_id, max_batch_id, affected_room_ids,
min_batch_id, max_batch_id, affected_room_ids )
)
defer.returnValue(True) defer.returnValue(True)
@logcontext.preserve_fn # caller should not yield on this @logcontext.preserve_fn # caller should not yield on this
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -344,6 +344,7 @@ class RoomMemberHandler(object):
latest_event_ids = ( latest_event_ids = (
event_id for (event_id, _, _) in prev_events_and_hashes event_id for (event_id, _, _) in prev_events_and_hashes
) )
current_state_ids = yield self.state_handler.get_current_state_ids( current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids, room_id, latest_event_ids=latest_event_ids,
) )

View file

@ -185,6 +185,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
class SyncHandler(object): class SyncHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.hs_config = hs.config
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
@ -860,7 +861,7 @@ class SyncHandler(object):
since_token is None and since_token is None and
sync_config.filter_collection.blocks_all_presence() sync_config.filter_collection.blocks_all_presence()
) )
if not block_all_presence_data: if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence( yield self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_users sync_result_builder, newly_joined_rooms, newly_joined_users
) )

View file

@ -119,6 +119,8 @@ class UserDirectoryHandler(object):
"""Called to update index of our local user profiles when they change """Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in. irrespective of any rooms the user may be in.
""" """
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
yield self.store.update_profile_in_user_dir( yield self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url, None, user_id, profile.display_name, profile.avatar_url, None,
) )
@ -127,6 +129,8 @@ class UserDirectoryHandler(object):
def handle_user_deactivated(self, user_id): def handle_user_deactivated(self, user_id):
"""Called when a user ID is deactivated """Called when a user ID is deactivated
""" """
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
yield self.store.remove_from_user_dir(user_id) yield self.store.remove_from_user_dir(user_id)
yield self.store.remove_from_user_in_public_room(user_id) yield self.store.remove_from_user_in_public_room(user_id)

View file

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import threading
from prometheus_client.core import Counter, Histogram from prometheus_client.core import Counter, Histogram
@ -111,6 +112,9 @@ in_flight_requests_db_sched_duration = Counter(
# The set of all in flight requests, set[RequestMetrics] # The set of all in flight requests, set[RequestMetrics]
_in_flight_requests = set() _in_flight_requests = set()
# Protects the _in_flight_requests set from concurrent access
_in_flight_requests_lock = threading.Lock()
def _get_in_flight_counts(): def _get_in_flight_counts():
"""Returns a count of all in flight requests by (method, server_name) """Returns a count of all in flight requests by (method, server_name)
@ -120,7 +124,8 @@ def _get_in_flight_counts():
""" """
# Cast to a list to prevent it changing while the Prometheus # Cast to a list to prevent it changing while the Prometheus
# thread is collecting metrics # thread is collecting metrics
reqs = list(_in_flight_requests) with _in_flight_requests_lock:
reqs = list(_in_flight_requests)
for rm in reqs: for rm in reqs:
rm.update_metrics() rm.update_metrics()
@ -154,10 +159,12 @@ class RequestMetrics(object):
# to the "in flight" metrics. # to the "in flight" metrics.
self._request_stats = self.start_context.get_resource_usage() self._request_stats = self.start_context.get_resource_usage()
_in_flight_requests.add(self) with _in_flight_requests_lock:
_in_flight_requests.add(self)
def stop(self, time_sec, request): def stop(self, time_sec, request):
_in_flight_requests.discard(self) with _in_flight_requests_lock:
_in_flight_requests.discard(self)
context = LoggingContext.current_context() context = LoggingContext.current_context()
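
The same locking pattern, reduced to a standalone sketch: a set shared with the Prometheus collector thread is guarded by a lock, and the collector copies it under the lock before iterating, so request threads adding or discarding entries cannot race with collection. Names here are illustrative.

import threading

_in_flight = set()
_in_flight_lock = threading.Lock()

def track_start(metric):
    with _in_flight_lock:
        _in_flight.add(metric)

def track_stop(metric):
    with _in_flight_lock:
        _in_flight.discard(metric)

def collect_counts():
    # Take a snapshot under the lock, then iterate the copy outside it.
    with _in_flight_lock:
        snapshot = list(_in_flight)
    return len(snapshot)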

View file

@ -25,8 +25,9 @@ from canonicaljson import encode_canonical_json, encode_pretty_printed_json, jso
from twisted.internet import defer from twisted.internet import defer
from twisted.python import failure from twisted.python import failure
from twisted.web import resource, server from twisted.web import resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.web.static import NoRangeStaticProducer
from twisted.web.util import redirectTo from twisted.web.util import redirectTo
import synapse.events import synapse.events
@ -37,10 +38,13 @@ from synapse.api.errors import (
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.http.request_metrics import requests_counter
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import Measure
if PY3:
from io import BytesIO
else:
from cStringIO import StringIO as BytesIO
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -60,11 +64,10 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
def wrap_json_request_handler(h): def wrap_json_request_handler(h):
"""Wraps a request handler method with exception handling. """Wraps a request handler method with exception handling.
Also adds logging as per wrap_request_handler_with_logging. Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)", The handler method must have a signature of "handle_foo(self, request)",
where "self" must have a "clock" attribute (and "request" must be a where "request" must be a SynapseRequest.
SynapseRequest).
The handler must return a deferred. If the deferred succeeds we assume that The handler must return a deferred. If the deferred succeeds we assume that
a response has been sent. If the deferred fails with a SynapseError we use a response has been sent. If the deferred fails with a SynapseError we use
@ -108,24 +111,23 @@ def wrap_json_request_handler(h):
pretty_print=_request_user_agent_is_curl(request), pretty_print=_request_user_agent_is_curl(request),
) )
return wrap_request_handler_with_logging(wrapped_request_handler) return wrap_async_request_handler(wrapped_request_handler)
def wrap_html_request_handler(h): def wrap_html_request_handler(h):
"""Wraps a request handler method with exception handling. """Wraps a request handler method with exception handling.
Also adds logging as per wrap_request_handler_with_logging. Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)", The handler method must have a signature of "handle_foo(self, request)",
where "self" must have a "clock" attribute (and "request" must be a where "request" must be a SynapseRequest.
SynapseRequest).
""" """
def wrapped_request_handler(self, request): def wrapped_request_handler(self, request):
d = defer.maybeDeferred(h, self, request) d = defer.maybeDeferred(h, self, request)
d.addErrback(_return_html_error, request) d.addErrback(_return_html_error, request)
return d return d
return wrap_request_handler_with_logging(wrapped_request_handler) return wrap_async_request_handler(wrapped_request_handler)
def _return_html_error(f, request): def _return_html_error(f, request):
@ -170,46 +172,26 @@ def _return_html_error(f, request):
finish_request(request) finish_request(request)
def wrap_request_handler_with_logging(h): def wrap_async_request_handler(h):
"""Wraps a request handler to provide logging and metrics """Wraps an async request handler so that it calls request.processing.
This helps ensure that work done by the request handler after the request is completed
is correctly recorded against the request metrics/logs.
The handler method must have a signature of "handle_foo(self, request)", The handler method must have a signature of "handle_foo(self, request)",
where "self" must have a "clock" attribute (and "request" must be a where "request" must be a SynapseRequest.
SynapseRequest).
As well as calling `request.processing` (which will log the response and The handler may return a deferred, in which case the completion of the request isn't
duration for this request), the wrapped request handler will insert the logged until the deferred completes.
request id into the logging context.
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def wrapped_request_handler(self, request): def wrapped_async_request_handler(self, request):
""" with request.processing():
Args: yield h(self, request)
self:
request (synapse.http.site.SynapseRequest):
"""
request_id = request.get_request_id() # we need to preserve_fn here, because the synchronous render method won't yield for
with LoggingContext(request_id) as request_context: # us (obviously)
request_context.request = request_id return preserve_fn(wrapped_async_request_handler)
with Measure(self.clock, "wrapped_request_handler"):
# we start the request metrics timer here with an initial stab
# at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet.
servlet_name = self.__class__.__name__
with request.processing(servlet_name):
with PreserveLoggingContext(request_context):
d = defer.maybeDeferred(h, self, request)
# record the arrival of the request *after*
# dispatching to the handler, so that the handler
# can update the servlet name in the request
# metrics
requests_counter.labels(request.method,
request.request_metrics.name).inc()
yield d
return wrapped_request_handler
class HttpServer(object): class HttpServer(object):
@ -272,7 +254,7 @@ class JsonResource(HttpServer, resource.Resource):
""" This gets called by twisted every time someone sends us a request. """ This gets called by twisted every time someone sends us a request.
""" """
self._async_render(request) self._async_render(request)
return server.NOT_DONE_YET return NOT_DONE_YET
@wrap_json_request_handler @wrap_json_request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
@ -413,8 +395,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
return return
if pretty_print: if pretty_print:
json_bytes = (encode_pretty_printed_json(json_object) + "\n" json_bytes = encode_pretty_printed_json(json_object) + b"\n"
).encode("utf-8")
else: else:
if canonical_json or synapse.events.USE_FROZEN_DICTS: if canonical_json or synapse.events.USE_FROZEN_DICTS:
# canonicaljson already encodes to bytes # canonicaljson already encodes to bytes
@ -450,8 +431,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
if send_cors: if send_cors:
set_cors_headers(request) set_cors_headers(request)
request.write(json_bytes) # todo: we can almost certainly avoid this copy and encode the json straight into
finish_request(request) # the bytesIO, but it would involve faffing around with string->bytes wrappers.
bytes_io = BytesIO(json_bytes)
producer = NoRangeStaticProducer(request, bytes_io)
producer.start()
return NOT_DONE_YET return NOT_DONE_YET
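
A hypothetical resource method showing how wrap_json_request_handler is meant to be applied after this refactor: the decorated handler receives a SynapseRequest, runs inside request.processing(), and is expected to send its own response (here via respond_with_json) before its deferred completes. The PingResource class and its payload are assumptions for illustration.

class PingResource(object):
    @wrap_json_request_handler
    @defer.inlineCallbacks
    def _async_render_GET(self, request):
        # do some (possibly asynchronous) work, then respond ourselves
        result = yield defer.succeed({"pong": True})
        respond_with_json(request, 200, result, send_cors=True)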

View file

@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):
Args: Args:
request: the twisted HTTP request. request: the twisted HTTP request.
name (str): the name of the query parameter. name (bytes/unicode): the name of the query parameter.
default (int|None): value to use if the parameter is absent, defaults default (int|None): value to use if the parameter is absent, defaults
to None. to None.
required (bool): whether to raise a 400 SynapseError if the required (bool): whether to raise a 400 SynapseError if the
@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):
def parse_integer_from_args(args, name, default=None, required=False): def parse_integer_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
name = name.encode('ascii')
if name in args: if name in args:
try: try:
return int(args[name][0]) return int(args[name][0])
@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):
Args: Args:
request: the twisted HTTP request. request: the twisted HTTP request.
name (str): the name of the query parameter. name (bytes/unicode): the name of the query parameter.
default (bool|None): value to use if the parameter is absent, defaults default (bool|None): value to use if the parameter is absent, defaults
to None. to None.
required (bool): whether to raise a 400 SynapseError if the required (bool): whether to raise a 400 SynapseError if the
@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):
def parse_boolean_from_args(args, name, default=None, required=False): def parse_boolean_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
name = name.encode('ascii')
if name in args: if name in args:
try: try:
return { return {
"true": True, b"true": True,
"false": False, b"false": False,
}[args[name][0]] }[args[name][0]]
except Exception: except Exception:
message = ( message = (
@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):
def parse_string(request, name, default=None, required=False, def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"): allowed_values=None, param_type="string", encoding='ascii'):
"""Parse a string parameter from the request query string. """
Parse a string parameter from the request query string.
If encoding is not None, the content of the query param will be
decoded to Unicode using the encoding; otherwise the raw bytes will be returned.
Args: Args:
request: the twisted HTTP request. request: the twisted HTTP request.
name (str): the name of the query parameter. name (bytes/unicode): the name of the query parameter.
default (str|None): value to use if the parameter is absent, defaults default (bytes/unicode|None): value to use if the parameter is absent,
to None. defaults to None. Must be bytes if encoding is None.
required (bool): whether to raise a 400 SynapseError if the required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False. parameter is absent, defaults to False.
allowed_values (list[str]): List of allowed values for the string, allowed_values (list[bytes/unicode]): List of allowed values for the
or None if any value is allowed, defaults to None string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
encoding: The encoding to decode the name to, and decode the string
content with.
Returns: Returns:
str|None: A string value or the default. bytes/unicode|None: A string value or the default. Unicode if encoding
was given, bytes otherwise.
Raises: Raises:
SynapseError if the parameter is absent and required, or if the SynapseError if the parameter is absent and required, or if the
@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
is not one of those allowed values. is not one of those allowed values.
""" """
return parse_string_from_args( return parse_string_from_args(
request.args, name, default, required, allowed_values, param_type, request.args, name, default, required, allowed_values, param_type, encoding
) )
def parse_string_from_args(args, name, default=None, required=False, def parse_string_from_args(args, name, default=None, required=False,
allowed_values=None, param_type="string"): allowed_values=None, param_type="string", encoding='ascii'):
if not isinstance(name, bytes):
name = name.encode('ascii')
if name in args: if name in args:
value = args[name][0] value = args[name][0]
if encoding:
value = value.decode(encoding)
if allowed_values is not None and value not in allowed_values: if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % ( message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values) name, ", ".join(repr(v) for v in allowed_values)
@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
message = "Missing %s query parameter %r" % (param_type, name) message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
if encoding and isinstance(default, bytes):
return default.decode(encoding)
return default return default
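For orientation, here is a minimal standalone sketch (not part of the patch; the query dict and values are invented) of the bytes/unicode handling the updated helpers now perform: parameter names are normalised to bytes before lookup, and values are decoded with the requested encoding unless encoding=None is passed, in which case the raw bytes are returned.

args = {b"filename": [b"caf\xc3\xa9.txt"], b"h": [b"ab12cd"]}  # invented query args

name = "filename"
if not isinstance(name, bytes):
    name = name.encode('ascii')        # names are looked up as bytes

value = args[name][0]                  # raw bytes from the query string
print(value.decode('utf-8'))           # "café.txt" -- decoded when an encoding is given
print(args[b"h"][0])                   # b"ab12cd"  -- left as bytes when encoding=None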
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import logging import logging
import time import time
@ -19,8 +18,8 @@ import time
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
from synapse.http import redact_uri from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.util.logcontext import ContextResourceUsage, LoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,25 +33,43 @@ class SynapseRequest(Request):
It extends twisted's twisted.web.server.Request, and adds: It extends twisted's twisted.web.server.Request, and adds:
* Unique request ID * Unique request ID
* A log context associated with the request
* Redaction of access_token query-params in __repr__ * Redaction of access_token query-params in __repr__
* Logging at start and end * Logging at start and end
* Metrics to record CPU, wallclock and DB time by endpoint. * Metrics to record CPU, wallclock and DB time by endpoint.
It provides a method `processing` which should be called by the Resource It also provides a method `processing`, which returns a context manager. If this
which is handling the request, and returns a context manager. method is called, the request won't be logged until the context manager is closed;
this is useful for asynchronous request handlers which may go on processing the
request even after the client has disconnected.
Attributes:
logcontext(LoggingContext) : the log context for this request
""" """
def __init__(self, site, channel, *args, **kw): def __init__(self, site, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw) Request.__init__(self, channel, *args, **kw)
self.site = site self.site = site
self._channel = channel self._channel = channel # this is used by the tests
self.authenticated_entity = None self.authenticated_entity = None
self.start_time = 0 self.start_time = 0
# we can't yet create the logcontext, as we don't know the method.
self.logcontext = None
global _next_request_seq global _next_request_seq
self.request_seq = _next_request_seq self.request_seq = _next_request_seq
_next_request_seq += 1 _next_request_seq += 1
# whether an asynchronous request handler has called processing()
self._is_processing = False
# the time when the asynchronous request handler completed its processing
self._processing_finished_time = None
# what time we finished sending the response to the client (or the connection
# dropped)
self.finish_time = None
def __repr__(self): def __repr__(self):
# We overwrite this so that we don't log ``access_token`` # We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
@ -74,11 +91,116 @@ class SynapseRequest(Request):
return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
def render(self, resrc): def render(self, resrc):
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
# create a LogContext for this request
request_id = self.get_request_id()
logcontext = self.logcontext = LoggingContext(request_id)
logcontext.request = request_id
# override the Server header which is set by twisted # override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string) self.setHeader("Server", self.site.server_version_string)
return Request.render(self, resrc)
with PreserveLoggingContext(self.logcontext):
# we start the request metrics timer here with an initial stab
# at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet.
servlet_name = resrc.__class__.__name__
self._started_processing(servlet_name)
Request.render(self, resrc)
# record the arrival of the request *after*
# dispatching to the handler, so that the handler
# can update the servlet name in the request
# metrics
requests_counter.labels(self.method,
self.request_metrics.name).inc()
@contextlib.contextmanager
def processing(self):
"""Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is:
@defer.inlineCallbacks
def handle_request(request):
with request.processing():
yield really_handle_the_request()
Once the context manager is closed, the completion of the request will be logged,
and the various metrics will be updated.
"""
if self._is_processing:
raise RuntimeError("Request is already processing")
self._is_processing = True
try:
yield
except Exception:
# this should already have been caught, and sent back to the client as a 500.
logger.exception("Asynchronous messge handler raised an uncaught exception")
finally:
# the request handler has finished its work and either sent the whole response
# back, or handed over responsibility to a Producer.
self._processing_finished_time = time.time()
self._is_processing = False
# if we've already sent the response, log it now; otherwise, we wait for the
# response to be sent.
if self.finish_time is not None:
self._finished_processing()
def finish(self):
"""Called when all response data has been written to this Request.
Overrides twisted.web.server.Request.finish to record the finish time and do
logging.
"""
self.finish_time = time.time()
Request.finish(self)
if not self._is_processing:
with PreserveLoggingContext(self.logcontext):
self._finished_processing()
def connectionLost(self, reason):
"""Called when the client connection is closed before the response is written.
Overrides twisted.web.server.Request.connectionLost to record the finish time and
do logging.
"""
self.finish_time = time.time()
Request.connectionLost(self, reason)
# we only get here if the connection to the client drops before we send
# the response.
#
# It's useful to log it here so that we can get an idea of when
# the client disconnects.
with PreserveLoggingContext(self.logcontext):
logger.warn(
"Error processing request %r: %s %s", self, reason.type, reason.value,
)
if not self._is_processing:
self._finished_processing()
def _started_processing(self, servlet_name): def _started_processing(self, servlet_name):
"""Record the fact that we are processing this request.
This will log the request's arrival. Once the request completes,
be sure to call finished_processing.
Args:
servlet_name (str): the name of the servlet which will be
processing this request. This is used in the metrics.
It is possible to update this afterwards by updating
self.request_metrics.name.
"""
self.start_time = time.time() self.start_time = time.time()
self.request_metrics = RequestMetrics() self.request_metrics = RequestMetrics()
self.request_metrics.start( self.request_metrics.start(
@ -94,18 +216,32 @@ class SynapseRequest(Request):
) )
def _finished_processing(self): def _finished_processing(self):
try: """Log the completion of this request and update the metrics
context = LoggingContext.current_context() """
usage = context.get_resource_usage()
except Exception:
usage = ContextResourceUsage()
end_time = time.time() if self.logcontext is None:
# this can happen if the connection closed before we read the
# headers (so render was never called). In that case we'll already
# have logged a warning, so just bail out.
return
usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None:
# we completed the request without anything calling processing()
self._processing_finished_time = time.time()
# the time between receiving the request and the request handler finishing
processing_time = self._processing_finished_time - self.start_time
# the time between the request handler finishing and the response being sent
# to the client (nb may be negative)
response_send_time = self.finish_time - self._processing_finished_time
# need to decode as it could be raw utf-8 bytes # need to decode as it could be raw utf-8 bytes
# from a IDN servname in an auth header # from a IDN servname in an auth header
authenticated_entity = self.authenticated_entity authenticated_entity = self.authenticated_entity
if authenticated_entity is not None: if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
authenticated_entity = authenticated_entity.decode("utf-8", "replace") authenticated_entity = authenticated_entity.decode("utf-8", "replace")
# ...or could be raw utf-8 bytes in the User-Agent header. # ...or could be raw utf-8 bytes in the User-Agent header.
@ -116,22 +252,31 @@ class SynapseRequest(Request):
user_agent = self.get_user_agent() user_agent = self.get_user_agent()
if user_agent is not None: if user_agent is not None:
user_agent = user_agent.decode("utf-8", "replace") user_agent = user_agent.decode("utf-8", "replace")
else:
user_agent = "-"
code = str(self.code)
if not self.finished:
# we didn't send the full response before we gave up (presumably because
# the connection dropped)
code += "!"
self.site.access_logger.info( self.site.access_logger.info(
"%s - %s - {%s}" "%s - %s - {%s}"
" Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.site.site_tag,
authenticated_entity, authenticated_entity,
end_time - self.start_time, processing_time,
response_send_time,
usage.ru_utime, usage.ru_utime,
usage.ru_stime, usage.ru_stime,
usage.db_sched_duration_sec, usage.db_sched_duration_sec,
usage.db_txn_duration_sec, usage.db_txn_duration_sec,
int(usage.db_txn_count), int(usage.db_txn_count),
self.sentLength, self.sentLength,
self.code, code,
self.method, self.method,
self.get_redacted_uri(), self.get_redacted_uri(),
self.clientproto, self.clientproto,
@ -140,38 +285,10 @@ class SynapseRequest(Request):
) )
try: try:
self.request_metrics.stop(end_time, self) self.request_metrics.stop(self.finish_time, self)
except Exception as e: except Exception as e:
logger.warn("Failed to stop metrics: %r", e) logger.warn("Failed to stop metrics: %r", e)
@contextlib.contextmanager
def processing(self, servlet_name):
"""Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is:
@defer.inlineCallbacks
def handle_request(request):
with request.processing("FooServlet"):
yield really_handle_the_request()
This will log the request's arrival. Once the context manager is
closed, the completion of the request will be logged, and the various
metrics will be updated.
Args:
servlet_name (str): the name of the servlet which will be
processing this request. This is used in the metrics.
It is possible to update this afterwards by updating
self.request_metrics.servlet_name.
"""
# TODO: we should probably just move this into render() and finish(),
# to save having to call a separate method.
self._started_processing(servlet_name)
yield
self._finished_processing()
class XForwardedForRequest(SynapseRequest): class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
@ -217,7 +334,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False) proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied) self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name) self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string self.server_version_string = server_version_string.encode('ascii')
def log(self, request): def log(self, request):
pass pass
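To make the new access-log format easier to read, here is a rough standalone sketch (variable names mirror the patch; the timeline itself is invented) of the timing split now reported as "Processed request: <processing>sec/<send>sec":

import time

start_time = time.time()                     # request received (render() called)
# ... the request handler does its work ...
processing_finished_time = time.time()       # handler finished, or processing() exited
# ... the response is streamed to the client ...
finish_time = time.time()                    # finish() or connectionLost() fired

processing_time = processing_finished_time - start_time
response_send_time = finish_time - processing_finished_time  # may be negative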
@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import threading
import six import six
from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
@ -78,6 +80,9 @@ _background_process_counts = dict() # type: dict[str, int]
# of process descriptions that no longer have any active processes. # of process descriptions that no longer have any active processes.
_background_processes = dict() # type: dict[str, set[_BackgroundProcess]] _background_processes = dict() # type: dict[str, set[_BackgroundProcess]]
# A lock that covers the above dicts
_bg_metrics_lock = threading.Lock()
class _Collector(object): class _Collector(object):
"""A custom metrics collector for the background process metrics. """A custom metrics collector for the background process metrics.
@ -92,7 +97,11 @@ class _Collector(object):
labels=["name"], labels=["name"],
) )
for desc, processes in six.iteritems(_background_processes): # We copy the dict so that it doesn't change from underneath us
with _bg_metrics_lock:
_background_processes_copy = dict(_background_processes)
for desc, processes in six.iteritems(_background_processes_copy):
background_process_in_flight_count.add_metric( background_process_in_flight_count.add_metric(
(desc,), len(processes), (desc,), len(processes),
) )
@ -167,19 +176,26 @@ def run_as_background_process(desc, func, *args, **kwargs):
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def run(): def run():
count = _background_process_counts.get(desc, 0) with _bg_metrics_lock:
_background_process_counts[desc] = count + 1 count = _background_process_counts.get(desc, 0)
_background_process_counts[desc] = count + 1
_background_process_start_count.labels(desc).inc() _background_process_start_count.labels(desc).inc()
with LoggingContext(desc) as context: with LoggingContext(desc) as context:
context.request = "%s-%i" % (desc, count) context.request = "%s-%i" % (desc, count)
proc = _BackgroundProcess(desc, context) proc = _BackgroundProcess(desc, context)
_background_processes.setdefault(desc, set()).add(proc)
with _bg_metrics_lock:
_background_processes.setdefault(desc, set()).add(proc)
try: try:
yield func(*args, **kwargs) yield func(*args, **kwargs)
finally: finally:
proc.update_metrics() proc.update_metrics()
_background_processes[desc].remove(proc)
with _bg_metrics_lock:
_background_processes[desc].remove(proc)
with PreserveLoggingContext(): with PreserveLoggingContext():
return run() return run()
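As a usage reminder, a hedged sketch of the entry point whose bookkeeping the new lock protects (fetch_stuff is an invented coroutine; the import path and signature are as shown in the patch):

from twisted.internet import defer

from synapse.metrics.background_process_metrics import run_as_background_process


@defer.inlineCallbacks
def fetch_stuff():
    # stand-in for real background work
    yield defer.succeed(None)

# registers the process under "fetch_stuff" and runs it in its own logcontext
run_as_background_process("fetch_stuff", fetch_stuff)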
@ -18,6 +18,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push.pusher import PusherFactory from synapse.push.pusher import PusherFactory
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@ -122,8 +123,14 @@ class PusherPool:
p['app_id'], p['pushkey'], p['user_name'], p['app_id'], p['pushkey'], p['user_name'],
) )
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id): def on_new_notifications(self, min_stream_id, max_stream_id):
run_as_background_process(
"on_new_notifications",
self._on_new_notifications, min_stream_id, max_stream_id,
)
@defer.inlineCallbacks
def _on_new_notifications(self, min_stream_id, max_stream_id):
try: try:
users_affected = yield self.store.get_push_action_users_in_range( users_affected = yield self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id min_stream_id, max_stream_id
@ -147,8 +154,14 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
run_as_background_process(
"on_new_receipts",
self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids,
)
@defer.inlineCallbacks
def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
try: try:
# Need to subtract 1 from the minimum because the lower bound here # Need to subtract 1 from the minimum because the lower bound here
# is not inclusive # is not inclusive
@ -156,7 +156,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
edu_content = content["content"] edu_content = content["content"]
logger.info( logger.info(
"Got %r edu from $s", "Got %r edu from %s",
edu_type, origin, edu_type, origin,
) )
@ -107,7 +107,7 @@ class ReplicationClientHandler(object):
Can be overridden in subclasses to handle more. Can be overridden in subclasses to handle more.
""" """
logger.info("Received rdata %s -> %s", stream_name, token) logger.info("Received rdata %s -> %s", stream_name, token)
self.store.process_replication_rows(stream_name, token, rows) return self.store.process_replication_rows(stream_name, token, rows)
def on_position(self, stream_name, token): def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes """Called when we get new position data. By default this just pokes
@ -115,7 +115,7 @@ class ReplicationClientHandler(object):
Can be overridden in subclasses to handle more. Can be overridden in subclasses to handle more.
""" """
self.store.process_replication_rows(stream_name, token, []) return self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data): def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting """When we received a SYNC we wake up any deferreds that were waiting
@ -59,6 +59,12 @@ class Command(object):
""" """
return self.data return self.data
def get_logcontext_id(self):
"""Get a suitable string for the logcontext when processing this command"""
# by default, we just use the command name.
return self.NAME
class ServerCommand(Command): class ServerCommand(Command):
"""Sent by the server on new connection and includes the server_name. """Sent by the server on new connection and includes the server_name.
@ -116,6 +122,9 @@ class RdataCommand(Command):
_json_encoder.encode(self.row), _json_encoder.encode(self.row),
)) ))
def get_logcontext_id(self):
return "RDATA-" + self.stream_name
class PositionCommand(Command): class PositionCommand(Command):
"""Sent by the client to tell the client the stream postition without """Sent by the client to tell the client the stream postition without
@ -190,6 +199,9 @@ class ReplicateCommand(Command):
def to_line(self): def to_line(self):
return " ".join((self.stream_name, str(self.token),)) return " ".join((self.stream_name, str(self.token),))
def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name
class UserSyncCommand(Command): class UserSyncCommand(Command):
"""Sent by the client to inform the server that a user has started or """Sent by the client to inform the server that a user has started or
@ -63,6 +63,8 @@ from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure from twisted.python.failure import Failure
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from .commands import ( from .commands import (
@ -222,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now lets try and call on_<CMD_NAME> function # Now lets try and call on_<CMD_NAME> function
try: try:
getattr(self, "on_%s" % (cmd_name,))(cmd) run_as_background_process(
"replication-" + cmd.get_logcontext_id(),
getattr(self, "on_%s" % (cmd_name,)),
cmd,
)
except Exception: except Exception:
logger.exception("[%s] Failed to handle line: %r", self.id(), line) logger.exception("[%s] Failed to handle line: %r", self.id(), line)
@ -387,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.name = cmd.data self.name = cmd.data
def on_USER_SYNC(self, cmd): def on_USER_SYNC(self, cmd):
self.streamer.on_user_sync( return self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
) )
@ -397,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL": if stream_name == "ALL":
# Subscribe to all streams we're publishing to. # Subscribe to all streams we're publishing to.
for stream in iterkeys(self.streamer.streams_by_name): deferreds = [
self.subscribe_to_stream(stream, token) run_in_background(
self.subscribe_to_stream,
stream, token,
)
for stream in iterkeys(self.streamer.streams_by_name)
]
return make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else: else:
self.subscribe_to_stream(stream_name, token) return self.subscribe_to_stream(stream_name, token)
def on_FEDERATION_ACK(self, cmd): def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token) return self.streamer.federation_ack(cmd.token)
def on_REMOVE_PUSHER(self, cmd): def on_REMOVE_PUSHER(self, cmd):
self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) return self.streamer.on_remove_pusher(
cmd.app_id, cmd.push_key, cmd.user_id,
)
def on_INVALIDATE_CACHE(self, cmd): def on_INVALIDATE_CACHE(self, cmd):
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd): def on_USER_IP(self, cmd):
self.streamer.on_user_ip( return self.streamer.on_user_ip(
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
cmd.last_seen, cmd.last_seen,
) )
@ -542,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates # Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, []) rows = self.pending_batches.pop(stream_name, [])
rows.append(row) rows.append(row)
return self.handler.on_rdata(stream_name, cmd.token, rows)
self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd): def on_POSITION(self, cmd):
self.handler.on_position(cmd.stream_name, cmd.token) return self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd): def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data) return self.handler.on_sync(cmd.data)
def replicate(self, stream_name, token): def replicate(self, stream_name, token):
"""Send the subscription request to the server """Send the subscription request to the server
@ -53,7 +53,7 @@ class HttpTransactionCache(object):
str: A transaction key str: A transaction key
""" """
token = self.auth.get_access_token_from_request(request) token = self.auth.get_access_token_from_request(request)
return request.path + "/" + token return request.path.decode('utf8') + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs): def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts """A helper function for fetch_or_execute which extracts
@ -84,7 +84,8 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
except Exception: except Exception:
raise SynapseError(400, "Unable to parse state") raise SynapseError(400, "Unable to parse state")
yield self.presence_handler.set_state(user, state) if self.hs.config.use_presence:
yield self.presence_handler.set_state(user, state)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -531,7 +531,7 @@ class RoomEventServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
event = yield self.event_handler.get_event(requester.user, room_id, event_id) event = yield self.event_handler.get_event(requester.user, room_id, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -129,12 +129,9 @@ class RegisterRestServlet(ClientV1RestServlet):
login_type = register_json["type"] login_type = register_json["type"]
is_application_server = login_type == LoginType.APPLICATION_SERVICE is_application_server = login_type == LoginType.APPLICATION_SERVICE
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
can_register = ( can_register = (
self.enable_registration self.enable_registration
or is_application_server or is_application_server
or is_using_shared_secret
) )
if not can_register: if not can_register:
raise SynapseError(403, "Registration has been disabled") raise SynapseError(403, "Registration has been disabled")
@ -144,7 +141,6 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD: self._do_password, LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity, LoginType.EMAIL_IDENTITY: self._do_email_identity,
LoginType.APPLICATION_SERVICE: self._do_app_service, LoginType.APPLICATION_SERVICE: self._do_app_service,
LoginType.SHARED_SECRET: self._do_shared_secret,
} }
session_info = self._get_session_info(request, session) session_info = self._get_session_info(request, session)
@ -325,56 +321,6 @@ class RegisterRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
}) })
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
assert_params_in_dict(register_json, ["mac", "user", "password"])
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
user = register_json["user"].encode("utf-8")
password = register_json["password"].encode("utf-8")
admin = register_json.get("admin", None)
# Its important to check as we use null bytes as HMAC field separators
if b"\x00" in user:
raise SynapseError(400, "Invalid user")
if b"\x00" in password:
raise SynapseError(400, "Invalid password")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(register_json["mac"])
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret.encode(),
digestmod=sha1,
)
want_mac.update(user)
want_mac.update(b"\x00")
want_mac.update(password)
want_mac.update(b"\x00")
want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
localpart=user.lower(),
password=password,
admin=bool(admin),
)
self._remove_session(session)
defer.returnValue({
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
})
else:
raise SynapseError(
403, "HMAC incorrect",
)
class CreateUserRestServlet(ClientV1RestServlet): class CreateUserRestServlet(ClientV1RestServlet):
"""Handles user creation via a server-to-server interface """Handles user creation via a server-to-server interface
@ -140,7 +140,7 @@ class ConsentResource(Resource):
version = parse_string(request, "v", version = parse_string(request, "v",
default=self._default_consent_version) default=self._default_consent_version)
username = parse_string(request, "u", required=True) username = parse_string(request, "u", required=True)
userhmac = parse_string(request, "h", required=True) userhmac = parse_string(request, "h", required=True, encoding=None)
self._check_hash(username, userhmac) self._check_hash(username, userhmac)
@ -175,7 +175,7 @@ class ConsentResource(Resource):
""" """
version = parse_string(request, "v", required=True) version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True) username = parse_string(request, "u", required=True)
userhmac = parse_string(request, "h", required=True) userhmac = parse_string(request, "h", required=True, encoding=None)
self._check_hash(username, userhmac) self._check_hash(username, userhmac)
@ -210,9 +210,18 @@ class ConsentResource(Resource):
finish_request(request) finish_request(request)
def _check_hash(self, userid, userhmac): def _check_hash(self, userid, userhmac):
"""
Args:
userid (unicode):
userhmac (bytes):
Raises:
SynapseError if the hash doesn't match
"""
want_mac = hmac.new( want_mac = hmac.new(
key=self._hmac_secret, key=self._hmac_secret,
msg=userid, msg=userid.encode('utf-8'),
digestmod=sha256, digestmod=sha256,
).hexdigest() ).hexdigest()
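A hedged, standalone illustration (secret and inputs invented; this is not the resource's full logic) of why the "h" parameter is now parsed with encoding=None: the HMAC is computed over the UTF-8 encoded username and compared against the raw bytes supplied by the client.

import hmac
from hashlib import sha256

hmac_secret = b"not-the-real-secret"       # invented for illustration
userid = "@alice:example.com"              # unicode user id from the "u" param
userhmac = b"0123abcd"                     # raw bytes from the "h" param (encoding=None)

want_mac = hmac.new(
    key=hmac_secret,
    msg=userid.encode('utf-8'),            # encode before hashing, as in the patch
    digestmod=sha256,
).hexdigest()

matches = hmac.compare_digest(want_mac.encode('ascii'), userhmac)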
@ -55,7 +55,7 @@ class UploadResource(Resource):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length") content_length = request.getHeader(b"Content-Length").decode('ascii')
if content_length is None: if content_length is None:
raise SynapseError( raise SynapseError(
msg="Request must specify a Content-Length", code=400 msg="Request must specify a Content-Length", code=400
@ -66,10 +66,10 @@ class UploadResource(Resource):
code=413, code=413,
) )
upload_name = parse_string(request, "filename") upload_name = parse_string(request, b"filename", encoding=None)
if upload_name: if upload_name:
try: try:
upload_name = upload_name.decode('UTF-8') upload_name = upload_name.decode('utf8')
except UnicodeDecodeError: except UnicodeDecodeError:
raise SynapseError( raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
@ -78,8 +78,8 @@ class UploadResource(Resource):
headers = request.requestHeaders headers = request.requestHeaders
if headers.hasHeader("Content-Type"): if headers.hasHeader(b"Content-Type"):
media_type = headers.getRawHeaders(b"Content-Type")[0] media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii')
else: else:
raise SynapseError( raise SynapseError(
msg="Upload request missing 'Content-Type'", msg="Upload request missing 'Content-Type'",
@ -38,4 +38,4 @@ else:
return os.urandom(nbytes) return os.urandom(nbytes)
def token_hex(self, nbytes=32): def token_hex(self, nbytes=32):
return binascii.hexlify(self.token_bytes(nbytes)) return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii')
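The net effect of this one-liner, sketched in isolation (the real helper lives on a class; a bare function is used here for illustration): on Python 3 binascii.hexlify() returns bytes, so the token is now decoded to a native str.

import binascii
import os

def token_hex(nbytes=32):
    return binascii.hexlify(os.urandom(nbytes)).decode('ascii')

print(type(token_hex(8)))   # <class 'str'> rather than <class 'bytes'>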
@ -56,7 +56,7 @@ from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.message import EventCreationHandler, MessageHandler
from synapse.handlers.pagination import PaginationHandler from synapse.handlers.pagination import PaginationHandler
from synapse.handlers.presence import PresenceHandler from synapse.handlers.presence import PresenceHandler
from synapse.handlers.profile import ProfileHandler from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.room import RoomContextHandler, RoomCreationHandler from synapse.handlers.room import RoomContextHandler, RoomCreationHandler
@ -308,7 +308,10 @@ class HomeServer(object):
return InitialSyncHandler(self) return InitialSyncHandler(self)
def build_profile_handler(self): def build_profile_handler(self):
return ProfileHandler(self) if self.config.worker_app:
return BaseProfileHandler(self)
else:
return MasterProfileHandler(self)
def build_event_creation_handler(self): def build_event_creation_handler(self):
return EventCreationHandler(self) return EventCreationHandler(self)
@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,23 +14,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import hashlib
import logging import logging
from collections import namedtuple from collections import namedtuple
from six import iteritems, iterkeys, itervalues from six import iteritems, itervalues
from frozendict import frozendict from frozendict import frozendict
from twisted.internet import defer from twisted.internet import defer
from synapse import event_auth from synapse.api.constants import EventTypes, RoomVersions
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.state import v1
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -40,7 +38,7 @@ logger = logging.getLogger(__name__)
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR) SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache")
EVICTION_TIMEOUT_SECONDS = 60 * 60 EVICTION_TIMEOUT_SECONDS = 60 * 60
@ -264,6 +262,7 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
entry = yield self.resolve_state_groups_for_events( entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
@ -338,8 +337,11 @@ class StateHandler(object):
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
Args: Args:
room_id (str): room_id (str)
event_ids (list[str]): event_ids (list[str])
explicit_room_version (str|None): If set uses the the given room
version to choose the resolution algorithm. If None, then
checks the database for room version.
Returns: Returns:
Deferred[_StateCacheEntry]: resolved state Deferred[_StateCacheEntry]: resolved state
@ -353,7 +355,12 @@ class StateHandler(object):
room_id, event_ids room_id, event_ids
) )
if len(state_groups_ids) == 1: if len(state_groups_ids) == 0:
defer.returnValue(_StateCacheEntry(
state={},
state_group=None,
))
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop() name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name) prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@ -365,8 +372,11 @@ class StateHandler(object):
delta_ids=delta_ids, delta_ids=delta_ids,
)) ))
room_version = yield self.store.get_room_version(room_id)
result = yield self._state_resolution_handler.resolve_state_groups( result = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups_ids, None, self._state_map_factory, room_id, room_version, state_groups_ids, None,
self._state_map_factory,
) )
defer.returnValue(result) defer.returnValue(result)
@ -375,7 +385,7 @@ class StateHandler(object):
ev_ids, get_prev_content=False, check_redacted=False, ev_ids, get_prev_content=False, check_redacted=False,
) )
def resolve_events(self, state_sets, event): def resolve_events(self, room_version, state_sets, event):
logger.info( logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets) "Resolving state for %s with %d groups", event.room_id, len(state_sets)
) )
@ -391,7 +401,9 @@ class StateHandler(object):
} }
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map) new_state = resolve_events_with_state_map(
room_version, state_set_ids, state_map,
)
new_state = { new_state = {
key: state_map[ev_id] for key, ev_id in iteritems(new_state) key: state_map[ev_id] for key, ev_id in iteritems(new_state)
@ -430,7 +442,7 @@ class StateResolutionHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def resolve_state_groups( def resolve_state_groups(
self, room_id, state_groups_ids, event_map, state_map_factory, self, room_id, room_version, state_groups_ids, event_map, state_map_factory,
): ):
"""Resolves conflicts between a set of state groups """Resolves conflicts between a set of state groups
@ -439,6 +451,7 @@ class StateResolutionHandler(object):
Args: Args:
room_id (str): room we are resolving for (used for logging) room_id (str): room we are resolving for (used for logging)
room_version (str): version of the room
state_groups_ids (dict[int, dict[(str, str), str]]): state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group map from state group id to the state in that state group
(where 'state' is a map from state key to event id) (where 'state' is a map from state key to event id)
@ -492,6 +505,7 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory( new_state = yield resolve_events_with_factory(
room_version,
list(itervalues(state_groups_ids)), list(itervalues(state_groups_ids)),
event_map=event_map, event_map=event_map,
state_map_factory=state_map_factory, state_map_factory=state_map_factory,
@ -575,16 +589,10 @@ def _make_state_cache_entry(
) )
def _ordered_events(events): def resolve_events_with_state_map(room_version, state_sets, state_map):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
return sorted(events, key=key_func)
def resolve_events_with_state_map(state_sets, state_map):
""" """
Args: Args:
room_version(str): Version of the room
state_sets(list): List of dicts of (type, state_key) -> event_id, state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve. which are the different state groups to resolve.
state_map(dict): a dict from event_id to event, for all events in state_map(dict): a dict from event_id to event, for all events in
@ -594,75 +602,23 @@ def resolve_events_with_state_map(state_sets, state_map):
dict[(str, str), str]: dict[(str, str), str]:
a map from (type, state_key) to event_id. a map from (type, state_key) to event_id.
""" """
if len(state_sets) == 1: if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
return state_sets[0] return v1.resolve_events_with_state_map(
state_sets, state_map,
unconflicted_state, conflicted_state = _seperate( )
state_sets, else:
) # This should only happen if we added a version but forgot to add it to
# the list above.
auth_events = _create_auth_events_from_maps( raise Exception(
unconflicted_state, conflicted_state, state_map "No state resolution algorithm defined for version %r" % (room_version,)
) )
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
def _seperate(state_sets): def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
Args:
state_sets(iterable[dict[(str, str), str]]):
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
(dict[(str, str), str], dict[(str, str), set[str]]):
A tuple of (unconflicted_state, conflicted_state), where:
unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {}
for state_set in state_set_iterator:
for key, value in iteritems(state_set):
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
# There isn't an unconflicted entry so check if there is a
# conflicted entry.
ls = conflicted_state.get(key)
if ls is None:
# There wasn't a conflicted entry so haven't seen this key before.
# Therefore it isn't conflicted yet.
unconflicted_state[key] = value
else:
# This key is already conflicted, add our value to the conflict set.
ls.add(value)
elif unconflicted_value != value:
# If the unconflicted value is not the same as our value then we
# have a new conflict. So move the key from the unconflicted_state
# to the conflicted state.
conflicted_state[key] = {value, unconflicted_value}
unconflicted_state.pop(key, None)
return unconflicted_state, conflicted_state
@defer.inlineCallbacks
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
""" """
Args: Args:
room_version(str): Version of the room
state_sets(list): List of dicts of (type, state_key) -> event_id, state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve. which are the different state groups to resolve.
@ -682,185 +638,13 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
Deferred[dict[(str, str), str]]: Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id. a map from (type, state_key) to event_id.
""" """
if len(state_sets) == 1: if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
defer.returnValue(state_sets[0]) return v1.resolve_events_with_factory(
state_sets, event_map, state_map_factory,
unconflicted_state, conflicted_state = _seperate( )
state_sets, else:
) # This should only happen if we added a version but forgot to add it to
# the list above.
needed_events = set( raise Exception(
event_id "No state resolution algorithm defined for version %r" % (room_version,)
for event_ids in itervalues(conflicted_state)
for event_id in event_ids
)
if event_map is not None:
needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d conflicted events", len(needed_events))
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
#
# dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d auth events", len(new_needed_events))
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
defer.returnValue(_resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
))
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in itervalues(conflicted_state):
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
return auth_events
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
state_map):
conflicted_state = {}
for key, event_ids in iteritems(conflicted_state_ids):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
elif len(events) == 1:
unconflicted_state_ids[key] = events[0].event_id
auth_events = {
key: state_map[ev_id]
for key, ev_id in iteritems(auth_event_ids)
if ev_id in state_map
}
try:
resolved_state = _resolve_state_events(
conflicted_state, auth_events
) )
except Exception:
logger.exception("Failed to resolve state")
raise
new_state = unconflicted_state_ids
for key, event in iteritems(resolved_state):
new_state[key] = event.event_id
return new_state
def _resolve_state_events(conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to
use.
We resolve conflicts in the following order:
1. power levels
2. join rules
3. memberships
4. other events.
"""
resolved_state = {}
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[POWER_KEY] = _resolve_auth_events(
events, auth_events)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
events, auth_events
)
return resolved_state
def _resolve_auth_events(events, auth_events):
reverse = [i for i in reversed(_ordered_events(events))]
auth_keys = set(
key
for event in events
for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
for key in auth_keys:
auth_event = auth_events.get(key, None)
if auth_event:
new_auth_events[key] = auth_event
auth_events = new_auth_events
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
prev_event = event
except AuthError:
return prev_event
return event
def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
return event
except AuthError:
pass
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event
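A hedged sketch of the caller-side consequence of the new dispatch (names mirror the patch; the wrapper function itself is illustrative, not part of the change): resolvers now need the room version, which has to be looked up from the store before resolution.

from twisted.internet import defer


@defer.inlineCallbacks
def resolve_for_room(store, resolution_handler, room_id, state_groups_ids, state_map_factory):
    # look up the room version first; it selects the resolution algorithm
    room_version = yield store.get_room_version(room_id)
    result = yield resolution_handler.resolve_state_groups(
        room_id, room_version, state_groups_ids, None, state_map_factory,
    )
    defer.returnValue(result)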
synapse/state/v1.py (new file, 321 lines)
@ -0,0 +1,321 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import logging
from six import iteritems, iterkeys, itervalues
from twisted.internet import defer
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
logger = logging.getLogger(__name__)
POWER_KEY = (EventTypes.PowerLevels, "")
def resolve_events_with_state_map(state_sets, state_map):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map(dict): a dict from event_id to event, for all events in
state_sets.
Returns
dict[(str, str), str]:
a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
@defer.inlineCallbacks
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
event_map(dict[str,FrozenEvent]|None):
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point for finding the state we need; any missing
events will be requested via state_map_factory.
If None, all events will be fetched via state_map_factory.
state_map_factory(func): will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event.
Returns
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
defer.returnValue(state_sets[0])
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
needed_events = set(
event_id
for event_ids in itervalues(conflicted_state)
for event_id in event_ids
)
if event_map is not None:
needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d conflicted events", len(needed_events))
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
#
# dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d auth events", len(new_needed_events))
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
defer.returnValue(_resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
))
def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
Args:
state_sets(iterable[dict[(str, str), str]]):
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
(dict[(str, str), str], dict[(str, str), set[str]]):
A tuple of (unconflicted_state, conflicted_state), where:
unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {}
for state_set in state_set_iterator:
for key, value in iteritems(state_set):
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
# There isn't an unconflicted entry so check if there is a
# conflicted entry.
ls = conflicted_state.get(key)
if ls is None:
# There wasn't a conflicted entry so haven't seen this key before.
# Therefore it isn't conflicted yet.
unconflicted_state[key] = value
else:
# This key is already conflicted, add our value to the conflict set.
ls.add(value)
elif unconflicted_value != value:
# If the unconflicted value is not the same as our value then we
# have a new conflict. So move the key from the unconflicted_state
# to the conflicted state.
conflicted_state[key] = {value, unconflicted_value}
unconflicted_state.pop(key, None)
return unconflicted_state, conflicted_state
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in itervalues(conflicted_state):
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
return auth_events
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
state_map):
conflicted_state = {}
for key, event_ids in iteritems(conflicted_state_ids):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
elif len(events) == 1:
unconflicted_state_ids[key] = events[0].event_id
auth_events = {
key: state_map[ev_id]
for key, ev_id in iteritems(auth_event_ids)
if ev_id in state_map
}
try:
resolved_state = _resolve_state_events(
conflicted_state, auth_events
)
except Exception:
logger.exception("Failed to resolve state")
raise
new_state = unconflicted_state_ids
for key, event in iteritems(resolved_state):
new_state[key] = event.event_id
return new_state
def _resolve_state_events(conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to
use.
We resolve conflicts in the following order:
1. power levels
2. join rules
3. memberships
4. other events.
"""
resolved_state = {}
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[POWER_KEY] = _resolve_auth_events(
events, auth_events)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
events, auth_events
)
return resolved_state
def _resolve_auth_events(events, auth_events):
reverse = list(reversed(_ordered_events(events)))
auth_keys = set(
key
for event in events
for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
for key in auth_keys:
auth_event = auth_events.get(key, None)
if auth_event:
new_auth_events[key] = auth_event
auth_events = new_auth_events
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
prev_event = event
except AuthError:
return prev_event
return event
def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
return event
except AuthError:
pass
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event
def _ordered_events(events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
return sorted(events, key=key_func)
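For clarity (an illustrative sketch, not taken from the diff), the ordering produced by key_func can be demonstrated on stand-in events; FakeEvent here is a hypothetical substitute for real event objects:

import hashlib
from collections import namedtuple

FakeEvent = namedtuple("FakeEvent", ["depth", "event_id"])  # hypothetical stand-in

def key_func(e):
    # Deeper events sort first; ties are broken by the SHA-1 hex digest of the
    # event ID, which is arbitrary but deterministic across servers.
    return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()

events = [FakeEvent(3, "$a:hs"), FakeEvent(5, "$b:hs"), FakeEvent(5, "$c:hs")]
print(sorted(events, key=key_func))
# The two depth-5 events come first (hash-ordered between themselves),
# followed by the depth-3 event.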

View file

@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
} }
events_map = {ev.event_id: ev for ev, _ in events_context} events_map = {ev.event_id: ev for ev, _ in events_context}
room_version = yield self.get_room_version(room_id)
logger.debug("calling resolve_state_groups from preserve_events") logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups( res = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups, events_map, get_events room_id, room_version, state_groups, events_map, get_events
) )
defer.returnValue((res.state, None)) defer.returnValue((res.state, None))

View file

@ -71,8 +71,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache", desc="get_from_remote_profile_cache",
) )
class ProfileStore(ProfileWorkerStore):
def create_profile(self, user_localpart): def create_profile(self, user_localpart):
return self._simple_insert( return self._simple_insert(
table="profiles", table="profiles",
@ -96,6 +94,8 @@ class ProfileStore(ProfileWorkerStore):
desc="set_profile_avatar_url", desc="set_profile_avatar_url",
) )
class ProfileStore(ProfileWorkerStore):
def add_remote_profile_cache(self, user_id, displayname, avatar_url): def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles. """Ensure we are caching the remote user's profiles.

View file

@ -186,6 +186,35 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked", desc="is_room_blocked",
) )
@cachedInlineCallbacks(max_entries=10000)
def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
user
Args:
user_id (str)
Returns:
RatelimitOverride if there is an override, else None. If the contents
of RatelimitOverride are None or 0 then ratelimiting has been
disabled for that user entirely.
"""
row = yield self._simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
allow_none=True,
desc="get_ratelimit_for_user",
)
if row:
defer.returnValue(RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
))
else:
defer.returnValue(None)
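A hypothetical caller (not shown in this diff) could consume the override along these lines -- a sketch only, following the @defer.inlineCallbacks style used elsewhere in the store:

@defer.inlineCallbacks
def _get_effective_ratelimit(self, user_id, default_rate, default_burst):
    # Hypothetical helper: use the per-user override if one exists, otherwise
    # fall back to the configured defaults. Per the docstring above, override
    # values of None or 0 mean ratelimiting is disabled for the user.
    override = yield self.get_ratelimit_for_user(user_id)
    if override is None:
        defer.returnValue((default_rate, default_burst))
    defer.returnValue((override.messages_per_second, override.burst_count))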
class RoomStore(RoomWorkerStore, SearchStore): class RoomStore(RoomWorkerStore, SearchStore):
@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
"get_all_new_public_rooms", get_all_new_public_rooms "get_all_new_public_rooms", get_all_new_public_rooms
) )
@cachedInlineCallbacks(max_entries=10000)
def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
user
Args:
user_id (str)
Returns:
RatelimitOverride if there is an override, else None. If the contents
of RatelimitOverride are None or 0 then ratelimiting has been
disabled for that user entirely.
"""
row = yield self._simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
allow_none=True,
desc="get_ratelimit_for_user",
)
if row:
defer.returnValue(RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
))
else:
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def block_room(self, room_id, user_id): def block_room(self, room_id, user_id):
yield self._simple_insert( yield self._simple_insert(

View file

@ -60,8 +60,43 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(StateGroupWorkerStore, self).__init__(db_conn, hs) super(StateGroupWorkerStore, self).__init__(db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
# event IDs for the state types in a given state group to avoid hammering
# on the state_group* tables.
#
# The point of using a DictionaryCache is that it can cache a subset
# of the state events for a given state group (i.e. a subset of the keys for a
# given dict which is an entry in the cache for a given state group ID).
#
# However, this poses problems when performing complicated queries
# on the store - for instance: "give me all the state for this group, but
# limit members to this subset of users", as DictionaryCache's API isn't
# rich enough to say "please cache any of these fields, apart from this subset".
# This is problematic when lazy loading members, which requires this behaviour,
# as without it the cache has no choice but to speculatively load all
# state events for the group, which negates the efficiency being sought.
#
# Rather than overcomplicating DictionaryCache's API, we instead split the
# state_group_cache into two halves - one for tracking non-member events,
# and the other for tracking member_events. This means that lazy loading
# queries can be made in a cache-friendly manner by querying both caches
# separately and then merging the result. So for the example above, you
# would query the members cache for a specific subset of state keys
# (which DictionaryCache will handle efficiently and fine) and the non-members
# cache for all state (which DictionaryCache will similarly handle fine)
# and then just merge the results together.
#
# We size the non-members cache to be smaller than the members cache as the
# vast majority of state in Matrix (today) is member events.
self._state_group_cache = DictionaryCache( self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") "*stateGroupCache*",
# TODO: this hasn't been tuned yet
50000 * get_cache_factor_for("stateGroupCache")
)
self._state_group_members_cache = DictionaryCache(
"*stateGroupMembersCache*",
500000 * get_cache_factor_for("stateGroupMembersCache")
) )
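The "query both caches and merge" pattern that the comment above describes looks roughly like this (a simplified sketch; the real lookup added further down in this diff also handles type filtering and cache misses):

@defer.inlineCallbacks
def _example_split_cache_lookup(self, groups):
    # Hypothetical illustration only. A state_key of None acts as a wildcard,
    # so this asks the members cache for every membership event.
    non_member_state = yield self._get_state_for_groups_using_cache(
        groups, self._state_group_cache, None, None,
    )
    member_state = yield self._get_state_for_groups_using_cache(
        groups, self._state_group_members_cache,
        [(EventTypes.Member, None)], None,
    )
    state = non_member_state
    for group in groups:
        state[group].update(member_state[group])
    defer.returnValue(state)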
@defer.inlineCallbacks @defer.inlineCallbacks
@ -275,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, types): def _get_state_groups_from_groups(self, groups, types, members=None):
"""Returns the state groups for a given set of groups, filtering on """Returns the state groups for a given set of groups, filtering on
types of state events. types of state events.
@ -284,6 +319,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
types (Iterable[str, str|None]|None): list of 2-tuples of the form types (Iterable[str, str|None]|None): list of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all (`type`, `state_key`), where a `state_key` of `None` matches all
state_keys for the `type`. If None, all types are returned. state_keys for the `type`. If None, all types are returned.
members (bool|None): If not None, then, in addition to any filtering
implied by types, the results are also filtered to only include
member events (if True), or to exclude member events (if False)
Returns: Returns:
dictionary state_group -> (dict of (type, state_key) -> event id) dictionary state_group -> (dict of (type, state_key) -> event id)
@ -294,14 +332,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
for chunk in chunks: for chunk in chunks:
res = yield self.runInteraction( res = yield self.runInteraction(
"_get_state_groups_from_groups", "_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn, chunk, types, self._get_state_groups_from_groups_txn, chunk, types, members,
) )
results.update(res) results.update(res)
defer.returnValue(results) defer.returnValue(results)
def _get_state_groups_from_groups_txn( def _get_state_groups_from_groups_txn(
self, txn, groups, types=None, self, txn, groups, types=None, members=None,
): ):
results = {group: {} for group in groups} results = {group: {} for group in groups}
@ -339,6 +377,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
%s %s
""") """)
if members is True:
sql += " AND type = '%s'" % (EventTypes.Member,)
elif members is False:
sql += " AND type <> '%s'" % (EventTypes.Member,)
# Turns out that postgres doesn't like doing a list of OR's and # Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific # is about 1000x slower, so we just issue a query for each specific
# type separately. # type separately.
@ -386,6 +429,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
else: else:
where_clause = "" where_clause = ""
if members is True:
where_clause += " AND type = '%s'" % EventTypes.Member
elif members is False:
where_clause += " AND type <> '%s'" % EventTypes.Member
# We don't use WITH RECURSIVE on sqlite3 as there are distributions # We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy) # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups: for group in groups:
@ -580,10 +628,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
def _get_some_state_from_cache(self, group, types, filtered_types=None): def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
"""Checks if group is in cache. See `_get_state_for_groups` """Checks if group is in cache. See `_get_state_for_groups`
Args: Args:
cache(DictionaryCache): the state group cache to use
group(int): The state group to lookup group(int): The state group to lookup
types(list[str, str|None]): List of 2-tuples of the form types(list[str, str|None]): List of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all (`type`, `state_key`), where a `state_key` of `None` matches all
@ -597,11 +646,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
requests state from the cache, if False we need to query the DB for the requests state from the cache, if False we need to query the DB for the
missing state. missing state.
""" """
is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) is_all, known_absent, state_dict_ids = cache.get(group)
type_to_key = {} type_to_key = {}
# tracks whether any of ourrequested types are missing from the cache # tracks whether any of our requested types are missing from the cache
missing_types = False missing_types = False
for typ, state_key in types: for typ, state_key in types:
@ -648,7 +697,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if include(k[0], k[1]) if include(k[0], k[1])
}, got_all }, got_all
def _get_all_state_from_cache(self, group): def _get_all_state_from_cache(self, cache, group):
"""Checks if group is in cache. See `_get_state_for_groups` """Checks if group is in cache. See `_get_state_for_groups`
Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
@ -656,9 +705,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cache, if False we need to query the DB for the missing state. cache, if False we need to query the DB for the missing state.
Args: Args:
cache(DictionaryCache): the state group cache to use
group: The state group to lookup group: The state group to lookup
""" """
is_all, _, state_dict_ids = self._state_group_cache.get(group) is_all, _, state_dict_ids = cache.get(group)
return state_dict_ids, is_all return state_dict_ids, is_all
@ -681,6 +731,62 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list of event types. Other types of events are returned unfiltered. list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events. If None, `types` filtering is applied to all events.
Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]]
a dictionary mapping from state group to state dictionary.
"""
if types is not None:
non_member_types = [t for t in types if t[0] != EventTypes.Member]
if filtered_types is not None and EventTypes.Member not in filtered_types:
# we want all of the membership events
member_types = None
else:
member_types = [t for t in types if t[0] == EventTypes.Member]
else:
non_member_types = None
member_types = None
non_member_state = yield self._get_state_for_groups_using_cache(
groups, self._state_group_cache, non_member_types, filtered_types,
)
# XXX: we could skip this entirely if member_types is []
member_state = yield self._get_state_for_groups_using_cache(
# we set filtered_types=None as member_state only ever contains members.
groups, self._state_group_members_cache, member_types, None,
)
state = non_member_state
for group in groups:
state[group].update(member_state[group])
defer.returnValue(state)
@defer.inlineCallbacks
def _get_state_for_groups_using_cache(
self, groups, cache, types=None, filtered_types=None
):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
cache (DictionaryCache): the cache of group ids to state dicts which
we will pass through - either the normal state cache or the specific
members state cache.
types (None|iterable[(str, None|str)]):
indicates the state type/keys required. If None, the whole
state is fetched and returned.
Otherwise, each entry should be a `(type, state_key)` tuple to
include in the response. A `state_key` of None is a wildcard
meaning that we require all state with that type.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]] Deferred[dict[int, dict[(type, state_key), EventBase]]]
a dictionary mapping from state group to state dictionary. a dictionary mapping from state group to state dictionary.
@ -692,7 +798,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if types is not None: if types is not None:
for group in set(groups): for group in set(groups):
state_dict_ids, got_all = self._get_some_state_from_cache( state_dict_ids, got_all = self._get_some_state_from_cache(
group, types, filtered_types cache, group, types, filtered_types
) )
results[group] = state_dict_ids results[group] = state_dict_ids
@ -701,7 +807,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
else: else:
for group in set(groups): for group in set(groups):
state_dict_ids, got_all = self._get_all_state_from_cache( state_dict_ids, got_all = self._get_all_state_from_cache(
group cache, group
) )
results[group] = state_dict_ids results[group] = state_dict_ids
@ -710,8 +816,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
missing_groups.append(group) missing_groups.append(group)
if missing_groups: if missing_groups:
# Okay, so we have some missing_types, lets fetch them. # Okay, so we have some missing_types, let's fetch them.
cache_seq_num = self._state_group_cache.sequence cache_seq_num = cache.sequence
# the DictionaryCache knows if it has *all* the state, but # the DictionaryCache knows if it has *all* the state, but
# does not know if it has all of the keys of a particular type, # does not know if it has all of the keys of a particular type,
@ -725,7 +831,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
types_to_fetch = types types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups( group_to_state_dict = yield self._get_state_groups_from_groups(
missing_groups, types_to_fetch missing_groups, types_to_fetch, cache == self._state_group_members_cache,
) )
for group, group_state_dict in iteritems(group_to_state_dict): for group, group_state_dict in iteritems(group_to_state_dict):
@ -745,7 +851,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# update the cache with all the things we fetched from the # update the cache with all the things we fetched from the
# database. # database.
self._state_group_cache.update( cache.update(
cache_seq_num, cache_seq_num,
key=group, key=group,
value=group_state_dict, value=group_state_dict,
@ -847,15 +953,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
], ],
) )
# Prefill the state group cache with this group. # Prefill the state group caches with this group.
# It's fine to use the sequence like this as the state group map # It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could # is immutable. (If the map wasn't immutable then this prefill could
# race with another update) # race with another update)
current_member_state_ids = {
s: ev
for (s, ev) in iteritems(current_state_ids)
if s[0] == EventTypes.Member
}
txn.call_after(
self._state_group_members_cache.update,
self._state_group_members_cache.sequence,
key=state_group,
value=dict(current_member_state_ids),
)
current_non_member_state_ids = {
s: ev
for (s, ev) in iteritems(current_state_ids)
if s[0] != EventTypes.Member
}
txn.call_after( txn.call_after(
self._state_group_cache.update, self._state_group_cache.update,
self._state_group_cache.sequence, self._state_group_cache.sequence,
key=state_group, key=state_group,
value=dict(current_state_ids), value=dict(current_non_member_state_ids),
) )
return state_group return state_group

View file

@ -385,7 +385,13 @@ class LoggingContextFilter(logging.Filter):
context = LoggingContext.current_context() context = LoggingContext.current_context()
for key, value in self.defaults.items(): for key, value in self.defaults.items():
setattr(record, key, value) setattr(record, key, value)
context.copy_to(record)
# context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
if context is not None:
context.copy_to(record)
return True return True
@ -396,7 +402,9 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"] __slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context=LoggingContext.sentinel): def __init__(self, new_context=None):
if new_context is None:
new_context = LoggingContext.sentinel
self.new_context = new_context self.new_context = new_context
def __enter__(self): def __enter__(self):
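One likely motivation for switching the default from LoggingContext.sentinel to None is that Python evaluates default argument values once, when the function is defined; resolving the sentinel inside __init__ instead picks up the current value at call time. A generic sketch (plain Python, not Synapse code):

class Config:
    default = "old"

def definition_time(value=Config.default):
    return value                 # frozen at "old" when the def ran

def call_time(value=None):
    if value is None:
        value = Config.default   # looked up on every call
    return value

Config.default = "new"
print(definition_time())  # -> old
print(call_time())        # -> new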

View file

@ -20,6 +20,8 @@ import time
from functools import wraps from functools import wraps
from inspect import getcallargs from inspect import getcallargs
from six import PY3
_TIME_FUNC_ID = 0 _TIME_FUNC_ID = 0
@ -28,8 +30,12 @@ def _log_debug_as_f(f, msg, msg_args):
logger = logging.getLogger(name) logger = logging.getLogger(name)
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
lineno = f.func_code.co_firstlineno if PY3:
pathname = f.func_code.co_filename lineno = f.__code__.co_firstlineno
pathname = f.__code__.co_filename
else:
lineno = f.func_code.co_firstlineno
pathname = f.func_code.co_filename
record = logging.LogRecord( record = logging.LogRecord(
name=name, name=name,

View file

@ -16,6 +16,7 @@
import random import random
import string import string
from six import PY3
from six.moves import range from six.moves import range
_string_with_symbols = ( _string_with_symbols = (
@ -34,6 +35,17 @@ def random_string_with_symbols(length):
def is_ascii(s): def is_ascii(s):
if PY3:
if isinstance(s, bytes):
try:
s.decode('ascii').encode('ascii')
except UnicodeDecodeError:
return False
except UnicodeEncodeError:
return False
return True
try: try:
s.encode("ascii") s.encode("ascii")
except UnicodeEncodeError: except UnicodeEncodeError:
@ -49,6 +61,9 @@ def to_ascii(s):
If given None then will return None. If given None then will return None.
""" """
if PY3:
return s
if s is None: if s is None:
return None return None
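For reference (an illustrative sketch, not part of the diff), the new PY3 branch in is_ascii only special-cases bytes -- text strings still fall through to the original encode-based check -- while to_ascii becomes a no-op on Python 3:

from synapse.util.stringutils import is_ascii, to_ascii  # module path assumed

print(is_ascii(b"hello"))        # True: decodes and re-encodes as ASCII
print(is_ascii(b"caf\xc3\xa9"))  # False: UTF-8 bytes outside ASCII
print(is_ascii(u"hello"))        # True: falls through to s.encode("ascii")
print(to_ascii(u"hello"))        # returned unchanged on Python 3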

View file

@ -30,7 +30,7 @@ def get_version_string(module):
['git', 'rev-parse', '--abbrev-ref', 'HEAD'], ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
stderr=null, stderr=null,
cwd=cwd, cwd=cwd,
).strip() ).strip().decode('ascii')
git_branch = "b=" + git_branch git_branch = "b=" + git_branch
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
git_branch = "" git_branch = ""
@ -40,7 +40,7 @@ def get_version_string(module):
['git', 'describe', '--exact-match'], ['git', 'describe', '--exact-match'],
stderr=null, stderr=null,
cwd=cwd, cwd=cwd,
).strip() ).strip().decode('ascii')
git_tag = "t=" + git_tag git_tag = "t=" + git_tag
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
git_tag = "" git_tag = ""
@ -50,7 +50,7 @@ def get_version_string(module):
['git', 'rev-parse', '--short', 'HEAD'], ['git', 'rev-parse', '--short', 'HEAD'],
stderr=null, stderr=null,
cwd=cwd, cwd=cwd,
).strip() ).strip().decode('ascii')
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
git_commit = "" git_commit = ""
@ -60,7 +60,7 @@ def get_version_string(module):
['git', 'describe', '--dirty=' + dirty_string], ['git', 'describe', '--dirty=' + dirty_string],
stderr=null, stderr=null,
cwd=cwd, cwd=cwd,
).strip().endswith(dirty_string) ).strip().decode('ascii').endswith(dirty_string)
git_dirty = "dirty" if is_dirty else "" git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
@ -77,8 +77,8 @@ def get_version_string(module):
"%s (%s)" % ( "%s (%s)" % (
module.__version__, git_version, module.__version__, git_version,
) )
).encode("ascii") )
except Exception as e: except Exception as e:
logger.info("Failed to check for git repository: %s", e) logger.info("Failed to check for git repository: %s", e)
return module.__version__.encode("ascii") return module.__version__
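The .decode('ascii') calls added throughout this function are needed because subprocess.check_output returns bytes on Python 3, while the surrounding string formatting expects text; a minimal standalone demonstration (generic Python, assumes it is run inside a git checkout):

import subprocess

out = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
print(type(out))                      # <class 'bytes'> on Python 3
branch = out.strip().decode('ascii')  # now text, e.g. "develop"
print("b=" + branch)                  # str + str is fine; str + bytes raises TypeError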

tests/app/__init__.py Normal file
View file

View file

@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.app.frontend_proxy import FrontendProxyServer
from tests.unittest import HomeserverTestCase
class FrontendProxyTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
http_client=None, homeserverToUse=FrontendProxyServer
)
return hs
def test_listen_http_with_presence_enabled(self):
"""
When presence is on, the stub servlet will not register.
"""
# Presence is on
self.hs.config.use_presence = True
config = {
"port": 8080,
"bind_addresses": ["0.0.0.0"],
"resources": [{"names": ["client"]}],
}
# Listen with the config
self.hs._listen_http(config)
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = (
site.resource.children["_matrix"].children["client"].children["r0"]
)
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
# 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
def test_listen_http_with_presence_disabled(self):
"""
When presence is off, the stub servlet will register.
"""
# Presence is off
self.hs.config.use_presence = False
config = {
"port": 8080,
"bind_addresses": ["0.0.0.0"],
"resources": [{"names": ["client"]}],
}
# Listen with the config
self.hs._listen_http(config)
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = (
site.resource.children["_matrix"].children["client"].children["r0"]
)
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
# 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")

View file

@ -20,7 +20,7 @@ from twisted.internet import defer
import synapse.types import synapse.types
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.handlers.profile import ProfileHandler from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
@ -29,7 +29,7 @@ from tests.utils import setup_test_homeserver
class ProfileHandlers(object): class ProfileHandlers(object):
def __init__(self, hs): def __init__(self, hs):
self.profile_handler = ProfileHandler(hs) self.profile_handler = MasterProfileHandler(hs)
class ProfileTestCase(unittest.TestCase): class ProfileTestCase(unittest.TestCase):

Some files were not shown because too many files have changed in this diff.