mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 18:03:53 +01:00
Merge remote-tracking branch 'matrix-org/develop' into travis/new-worker-docs
This commit is contained in:
commit
88964b987e
126 changed files with 3867 additions and 2528 deletions
86
CHANGES.rst
86
CHANGES.rst
|
@ -1,11 +1,89 @@
|
||||||
Unreleased
|
Changes in synapse v0.27.2 (2018-03-26)
|
||||||
==========
|
=======================================
|
||||||
|
|
||||||
synctl no longer starts the main synapse when using ``-a`` option with workers.
|
Bug fixes:
|
||||||
A new worker file should be added with ``worker_app: synapse.app.homeserver``.
|
|
||||||
|
* Fix bug which broke TCP replication between workers (PR #3015)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.27.1 (2018-03-26)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
Meta release as v0.27.0 temporarily pointed to the wrong commit
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.27.0 (2018-03-26)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
No changes since v0.27.0-rc2
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.27.0-rc2 (2018-03-19)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Pulls in v0.26.1
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix bug introduced in v0.27.0-rc1 that causes much increased memory usage in state cache (PR #3005)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.26.1 (2018-03-15)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix bug where an invalid event caused server to stop functioning correctly,
|
||||||
|
due to parsing and serializing bugs in ujson library (PR #3008)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.27.0-rc1 (2018-03-14)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
The common case for running Synapse is not to run separate workers, but for those that do, be aware that synctl no longer starts the main synapse when using ``-a`` option with workers. A new worker file should be added with ``worker_app: synapse.app.homeserver``.
|
||||||
|
|
||||||
This release also begins the process of renaming a number of the metrics
|
This release also begins the process of renaming a number of the metrics
|
||||||
reported to prometheus. See `docs/metrics-howto.rst <docs/metrics-howto.rst#block-and-response-metrics-renamed-for-0-27-0>`_.
|
reported to prometheus. See `docs/metrics-howto.rst <docs/metrics-howto.rst#block-and-response-metrics-renamed-for-0-27-0>`_.
|
||||||
|
Note that the v0.28.0 release will remove the deprecated metric names.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Add ability for ASes to override message send time (PR #2754)
|
||||||
|
* Add support for custom storage providers for media repository (PR #2867, #2777, #2783, #2789, #2791, #2804, #2812, #2814, #2857, #2868, #2767)
|
||||||
|
* Add purge API features, see `docs/admin_api/purge_history_api.rst <docs/admin_api/purge_history_api.rst>`_ for full details (PR #2858, #2867, #2882, #2946, #2962, #2943)
|
||||||
|
* Add support for whitelisting 3PIDs that users can register. (PR #2813)
|
||||||
|
* Add ``/room/{id}/event/{id}`` API (PR #2766)
|
||||||
|
* Add an admin API to get all the media in a room (PR #2818) Thanks to @turt2live!
|
||||||
|
* Add ``federation_domain_whitelist`` option (PR #2820, #2821)
|
||||||
|
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Continue to factor out processing from main process and into worker processes. See updated `docs/workers.rst <docs/workers.rst>`_ (PR #2892 - #2904, #2913, #2920 - #2926, #2947, #2847, #2854, #2872, #2873, #2874, #2928, #2929, #2934, #2856, #2976 - #2984, #2987 - #2989, #2991 - #2993, #2995, #2784)
|
||||||
|
* Ensure state cache is used when persisting events (PR #2864, #2871, #2802, #2835, #2836, #2841, #2842, #2849)
|
||||||
|
* Change the default config to bind on both IPv4 and IPv6 on all platforms (PR #2435) Thanks to @silkeh!
|
||||||
|
* No longer require a specific version of saml2 (PR #2695) Thanks to @okurz!
|
||||||
|
* Remove ``verbosity``/``log_file`` from generated config (PR #2755)
|
||||||
|
* Add and improve metrics and logging (PR #2770, #2778, #2785, #2786, #2787, #2793, #2794, #2795, #2809, #2810, #2833, #2834, #2844, #2965, #2927, #2975, #2790, #2796, #2838)
|
||||||
|
* When using synctl with workers, don't start the main synapse automatically (PR #2774)
|
||||||
|
* Minor performance improvements (PR #2773, #2792)
|
||||||
|
* Use a connection pool for non-federation outbound connections (PR #2817)
|
||||||
|
* Make it possible to run unit tests against postgres (PR #2829)
|
||||||
|
* Update pynacl dependency to 1.2.1 or higher (PR #2888) Thanks to @bachp!
|
||||||
|
* Remove ability for AS users to call /events and /sync (PR #2948)
|
||||||
|
* Use bcrypt.checkpw (PR #2949) Thanks to @krombel!
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix broken ``ldap_config`` config option (PR #2683) Thanks to @seckrv!
|
||||||
|
* Fix error message when user is not allowed to unban (PR #2761) Thanks to @turt2live!
|
||||||
|
* Fix publicised groups GET API (singular) over federation (PR #2772)
|
||||||
|
* Fix user directory when using ``user_directory_search_all_users`` config option (PR #2803, #2831)
|
||||||
|
* Fix error on ``/publicRooms`` when no rooms exist (PR #2827)
|
||||||
|
* Fix bug in quarantine_media (PR #2837)
|
||||||
|
* Fix url_previews when no Content-Type is returned from URL (PR #2845)
|
||||||
|
* Fix rare race in sync API when joining room (PR #2944)
|
||||||
|
* Fix slow event search, switch back from GIST to GIN indexes (PR #2769, #2848)
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.26.0 (2018-01-05)
|
Changes in synapse v0.26.0 (2018-01-05)
|
||||||
|
|
|
@ -30,8 +30,12 @@ use github's pull request workflow to review the contribution, and either ask
|
||||||
you to make any refinements needed or merge it and make them ourselves. The
|
you to make any refinements needed or merge it and make them ourselves. The
|
||||||
changes will then land on master when we next do a release.
|
changes will then land on master when we next do a release.
|
||||||
|
|
||||||
We use Jenkins for continuous integration (http://matrix.org/jenkins), and
|
We use `Jenkins <http://matrix.org/jenkins>`_ and
|
||||||
typically all pull requests get automatically tested Jenkins: if your change breaks the build, Jenkins will yell about it in #matrix-dev:matrix.org so please lurk there and keep an eye open.
|
`Travis <https://travis-ci.org/matrix-org/synapse>`_ for continuous
|
||||||
|
integration. All pull requests to synapse get automatically tested by Travis;
|
||||||
|
the Jenkins builds require an adminstrator to start them. If your change
|
||||||
|
breaks the build, this will be shown in github, so please keep an eye on the
|
||||||
|
pull request for feedback.
|
||||||
|
|
||||||
Code style
|
Code style
|
||||||
~~~~~~~~~~
|
~~~~~~~~~~
|
||||||
|
|
15
README.rst
15
README.rst
|
@ -354,6 +354,10 @@ https://matrix.org/docs/projects/try-matrix-now.html (or build your own with one
|
||||||
Fedora
|
Fedora
|
||||||
------
|
------
|
||||||
|
|
||||||
|
Synapse is in the Fedora repositories as ``matrix-synapse``::
|
||||||
|
|
||||||
|
sudo dnf install matrix-synapse
|
||||||
|
|
||||||
Oleg Girko provides Fedora RPMs at
|
Oleg Girko provides Fedora RPMs at
|
||||||
https://obs.infoserver.lv/project/monitor/matrix-synapse
|
https://obs.infoserver.lv/project/monitor/matrix-synapse
|
||||||
|
|
||||||
|
@ -890,6 +894,17 @@ This should end with a 'PASSED' result::
|
||||||
|
|
||||||
PASSED (successes=143)
|
PASSED (successes=143)
|
||||||
|
|
||||||
|
Running the Integration Tests
|
||||||
|
=============================
|
||||||
|
|
||||||
|
Synapse is accompanied by `SyTest <https://github.com/matrix-org/sytest>`_,
|
||||||
|
a Matrix homeserver integration testing suite, which uses HTTP requests to
|
||||||
|
access the API as a Matrix client would. It is able to run Synapse directly from
|
||||||
|
the source tree, so installation of the server is not required.
|
||||||
|
|
||||||
|
Testing with SyTest is recommended for verifying that changes related to the
|
||||||
|
Client-Server API are functioning correctly. See the `installation instructions
|
||||||
|
<https://github.com/matrix-org/sytest#installing>`_ for details.
|
||||||
|
|
||||||
Building Internal API Documentation
|
Building Internal API Documentation
|
||||||
===================================
|
===================================
|
||||||
|
|
12
UPGRADE.rst
12
UPGRADE.rst
|
@ -48,6 +48,18 @@ returned by the Client-Server API:
|
||||||
# configured on port 443.
|
# configured on port 443.
|
||||||
curl -kv https://<host.name>/_matrix/client/versions 2>&1 | grep "Server:"
|
curl -kv https://<host.name>/_matrix/client/versions 2>&1 | grep "Server:"
|
||||||
|
|
||||||
|
Upgrading to $NEXT_VERSION
|
||||||
|
====================
|
||||||
|
|
||||||
|
This release expands the anonymous usage stats sent if the opt-in
|
||||||
|
``report_stats`` configuration is set to ``true``. We now capture RSS memory
|
||||||
|
and cpu use at a very coarse level. This requires administrators to install
|
||||||
|
the optional ``psutil`` python module.
|
||||||
|
|
||||||
|
We would appreciate it if you could assist by ensuring this module is available
|
||||||
|
and ``report_stats`` is enabled. This will let us see if performance changes to
|
||||||
|
synapse are having an impact to the general community.
|
||||||
|
|
||||||
Upgrading to v0.15.0
|
Upgrading to v0.15.0
|
||||||
====================
|
====================
|
||||||
|
|
||||||
|
|
|
@ -8,20 +8,56 @@ Depending on the amount of history being purged a call to the API may take
|
||||||
several minutes or longer. During this period users will not be able to
|
several minutes or longer. During this period users will not be able to
|
||||||
paginate further back in the room from the point being purged from.
|
paginate further back in the room from the point being purged from.
|
||||||
|
|
||||||
The API is simply:
|
The API is:
|
||||||
|
|
||||||
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
|
``POST /_matrix/client/r0/admin/purge_history/<room_id>[/<event_id>]``
|
||||||
|
|
||||||
including an ``access_token`` of a server admin.
|
including an ``access_token`` of a server admin.
|
||||||
|
|
||||||
By default, events sent by local users are not deleted, as they may represent
|
By default, events sent by local users are not deleted, as they may represent
|
||||||
the only copies of this content in existence. (Events sent by remote users are
|
the only copies of this content in existence. (Events sent by remote users are
|
||||||
deleted, and room state data before the cutoff is always removed).
|
deleted.)
|
||||||
|
|
||||||
To delete local events as well, set ``delete_local_events`` in the body:
|
Room state data (such as joins, leaves, topic) is always preserved.
|
||||||
|
|
||||||
|
To delete local message events as well, set ``delete_local_events`` in the body:
|
||||||
|
|
||||||
.. code:: json
|
.. code:: json
|
||||||
|
|
||||||
{
|
{
|
||||||
"delete_local_events": true
|
"delete_local_events": true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
The caller must specify the point in the room to purge up to. This can be
|
||||||
|
specified by including an event_id in the URI, or by setting a
|
||||||
|
``purge_up_to_event_id`` or ``purge_up_to_ts`` in the request body. If an event
|
||||||
|
id is given, that event (and others at the same graph depth) will be retained.
|
||||||
|
If ``purge_up_to_ts`` is given, it should be a timestamp since the unix epoch,
|
||||||
|
in milliseconds.
|
||||||
|
|
||||||
|
The API starts the purge running, and returns immediately with a JSON body with
|
||||||
|
a purge id:
|
||||||
|
|
||||||
|
.. code:: json
|
||||||
|
|
||||||
|
{
|
||||||
|
"purge_id": "<opaque id>"
|
||||||
|
}
|
||||||
|
|
||||||
|
Purge status query
|
||||||
|
------------------
|
||||||
|
|
||||||
|
It is possible to poll for updates on recent purges with a second API;
|
||||||
|
|
||||||
|
``GET /_matrix/client/r0/admin/purge_history_status/<purge_id>``
|
||||||
|
|
||||||
|
(again, with a suitable ``access_token``). This API returns a JSON body like
|
||||||
|
the following:
|
||||||
|
|
||||||
|
.. code:: json
|
||||||
|
|
||||||
|
{
|
||||||
|
"status": "active"
|
||||||
|
}
|
||||||
|
|
||||||
|
The status will be one of ``active``, ``complete``, or ``failed``.
|
||||||
|
|
|
@ -279,9 +279,9 @@ Obviously that option means that the operations done in
|
||||||
that might be fixed by setting a different logcontext via a ``with
|
that might be fixed by setting a different logcontext via a ``with
|
||||||
LoggingContext(...)`` in ``background_operation``).
|
LoggingContext(...)`` in ``background_operation``).
|
||||||
|
|
||||||
The second option is to use ``logcontext.preserve_fn``, which wraps a function
|
The second option is to use ``logcontext.run_in_background``, which wraps a
|
||||||
so that it doesn't reset the logcontext even when it returns an incomplete
|
function so that it doesn't reset the logcontext even when it returns an
|
||||||
deferred, and adds a callback to the returned deferred to reset the
|
incomplete deferred, and adds a callback to the returned deferred to reset the
|
||||||
logcontext. In other words, it turns a function that follows the Synapse rules
|
logcontext. In other words, it turns a function that follows the Synapse rules
|
||||||
about logcontexts and Deferreds into one which behaves more like an external
|
about logcontexts and Deferreds into one which behaves more like an external
|
||||||
function — the opposite operation to that described in the previous section.
|
function — the opposite operation to that described in the previous section.
|
||||||
|
@ -293,7 +293,7 @@ It can be used like this:
|
||||||
def do_request_handling():
|
def do_request_handling():
|
||||||
yield foreground_operation()
|
yield foreground_operation()
|
||||||
|
|
||||||
logcontext.preserve_fn(background_operation)()
|
logcontext.run_in_background(background_operation)
|
||||||
|
|
||||||
# this will now be logged against the request context
|
# this will now be logged against the request context
|
||||||
logger.debug("Request handling complete")
|
logger.debug("Request handling complete")
|
||||||
|
|
|
@ -235,7 +235,7 @@ file. For example::
|
||||||
``synapse.app.event_creator``
|
``synapse.app.event_creator``
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
Handles some event creation. It can handle REST endpoints matching:
|
Handles some event creation. It can handle REST endpoints matching::
|
||||||
|
|
||||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
|
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
|
||||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
|
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
|
||||||
|
|
|
@ -16,4 +16,4 @@
|
||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.26.0"
|
__version__ = "0.27.2"
|
||||||
|
|
|
@ -15,9 +15,10 @@
|
||||||
|
|
||||||
"""Contains exceptions and error codes."""
|
"""Contains exceptions and error codes."""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import simplejson as json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ from synapse.storage.presence import UserPresenceState
|
||||||
from synapse.types import UserID, RoomID
|
from synapse.types import UserID, RoomID
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
import jsonschema
|
import jsonschema
|
||||||
from jsonschema import FormatChecker
|
from jsonschema import FormatChecker
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ from synapse.util.logcontext import LoggingContext, preserve_fn
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import NoResource
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.appservice")
|
logger = logging.getLogger("synapse.app.appservice")
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ class AppserviceServer(HomeServer):
|
||||||
if name == "metrics":
|
if name == "metrics":
|
||||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, Resource())
|
root_resource = create_resource_tree(resources, NoResource())
|
||||||
|
|
||||||
_base.listen_tcp(
|
_base.listen_tcp(
|
||||||
bind_addresses,
|
bind_addresses,
|
||||||
|
|
|
@ -44,7 +44,7 @@ from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import NoResource
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.client_reader")
|
logger = logging.getLogger("synapse.app.client_reader")
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ class ClientReaderServer(HomeServer):
|
||||||
"/_matrix/client/api/v1": resource,
|
"/_matrix/client/api/v1": resource,
|
||||||
})
|
})
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, Resource())
|
root_resource = create_resource_tree(resources, NoResource())
|
||||||
|
|
||||||
_base.listen_tcp(
|
_base.listen_tcp(
|
||||||
bind_addresses,
|
bind_addresses,
|
||||||
|
@ -156,7 +156,6 @@ def start(config_options):
|
||||||
)
|
)
|
||||||
|
|
||||||
ss.setup()
|
ss.setup()
|
||||||
ss.get_handlers()
|
|
||||||
ss.start_listening(config.worker_listeners)
|
ss.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def start():
|
def start():
|
||||||
|
|
|
@ -27,14 +27,24 @@ from synapse.http.server import JsonResource
|
||||||
from synapse.http.site import SynapseSite
|
from synapse.http.site import SynapseSite
|
||||||
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
|
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
|
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||||
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||||
from synapse.replication.slave.storage.devices import SlavedDeviceStore
|
from synapse.replication.slave.storage.devices import SlavedDeviceStore
|
||||||
|
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
|
from synapse.replication.slave.storage.profile import SlavedProfileStore
|
||||||
|
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
||||||
|
from synapse.replication.slave.storage.pushers import SlavedPusherStore
|
||||||
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
from synapse.replication.slave.storage.room import RoomStore
|
from synapse.replication.slave.storage.room import RoomStore
|
||||||
|
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.rest.client.v1.room import RoomSendEventRestServlet
|
from synapse.rest.client.v1.room import (
|
||||||
|
RoomSendEventRestServlet, RoomMembershipRestServlet, RoomStateEventRestServlet,
|
||||||
|
JoinRoomAliasServlet,
|
||||||
|
)
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
|
@ -42,12 +52,19 @@ from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import NoResource
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.event_creator")
|
logger = logging.getLogger("synapse.app.event_creator")
|
||||||
|
|
||||||
|
|
||||||
class EventCreatorSlavedStore(
|
class EventCreatorSlavedStore(
|
||||||
|
DirectoryStore,
|
||||||
|
TransactionStore,
|
||||||
|
SlavedProfileStore,
|
||||||
|
SlavedAccountDataStore,
|
||||||
|
SlavedPusherStore,
|
||||||
|
SlavedReceiptsStore,
|
||||||
|
SlavedPushRuleStore,
|
||||||
SlavedDeviceStore,
|
SlavedDeviceStore,
|
||||||
SlavedClientIpStore,
|
SlavedClientIpStore,
|
||||||
SlavedApplicationServiceStore,
|
SlavedApplicationServiceStore,
|
||||||
|
@ -77,6 +94,9 @@ class EventCreatorServer(HomeServer):
|
||||||
elif name == "client":
|
elif name == "client":
|
||||||
resource = JsonResource(self, canonical_json=False)
|
resource = JsonResource(self, canonical_json=False)
|
||||||
RoomSendEventRestServlet(self).register(resource)
|
RoomSendEventRestServlet(self).register(resource)
|
||||||
|
RoomMembershipRestServlet(self).register(resource)
|
||||||
|
RoomStateEventRestServlet(self).register(resource)
|
||||||
|
JoinRoomAliasServlet(self).register(resource)
|
||||||
resources.update({
|
resources.update({
|
||||||
"/_matrix/client/r0": resource,
|
"/_matrix/client/r0": resource,
|
||||||
"/_matrix/client/unstable": resource,
|
"/_matrix/client/unstable": resource,
|
||||||
|
@ -84,7 +104,7 @@ class EventCreatorServer(HomeServer):
|
||||||
"/_matrix/client/api/v1": resource,
|
"/_matrix/client/api/v1": resource,
|
||||||
})
|
})
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, Resource())
|
root_resource = create_resource_tree(resources, NoResource())
|
||||||
|
|
||||||
_base.listen_tcp(
|
_base.listen_tcp(
|
||||||
bind_addresses,
|
bind_addresses,
|
||||||
|
@ -153,7 +173,6 @@ def start(config_options):
|
||||||
)
|
)
|
||||||
|
|
||||||
ss.setup()
|
ss.setup()
|
||||||
ss.get_handlers()
|
|
||||||
ss.start_listening(config.worker_listeners)
|
ss.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def start():
|
def start():
|
||||||
|
|
|
@ -41,7 +41,7 @@ from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import NoResource
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.federation_reader")
|
logger = logging.getLogger("synapse.app.federation_reader")
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ class FederationReaderServer(HomeServer):
|
||||||
FEDERATION_PREFIX: TransportLayerServer(self),
|
FEDERATION_PREFIX: TransportLayerServer(self),
|
||||||
})
|
})
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, Resource())
|
root_resource = create_resource_tree(resources, NoResource())
|
||||||
|
|
||||||
_base.listen_tcp(
|
_base.listen_tcp(
|
||||||
bind_addresses,
|
bind_addresses,
|
||||||
|
@ -144,7 +144,6 @@ def start(config_options):
|
||||||
)
|
)
|
||||||
|
|
||||||
ss.setup()
|
ss.setup()
|
||||||
ss.get_handlers()
|
|
||||||
ss.start_listening(config.worker_listeners)
|
ss.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def start():
|
def start():
|
||||||
|
|
|
@ -42,7 +42,7 @@ from synapse.util.logcontext import LoggingContext, preserve_fn
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import NoResource
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.federation_sender")
|
logger = logging.getLogger("synapse.app.federation_sender")
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ class FederationSenderServer(HomeServer):
|
||||||
if name == "metrics":
|
if name == "metrics":
|
||||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, Resource())
|
root_resource = create_resource_tree(resources, NoResource())
|
||||||
|
|
||||||
_base.listen_tcp(
|
_base.listen_tcp(
|
||||||
bind_addresses,
|
bind_addresses,
|
||||||
|
|
|
@ -44,7 +44,7 @@ from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import NoResource
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.frontend_proxy")
|
logger = logging.getLogger("synapse.app.frontend_proxy")
|
||||||
|
|
||||||
|
@ -142,7 +142,7 @@ class FrontendProxyServer(HomeServer):
|
||||||
"/_matrix/client/api/v1": resource,
|
"/_matrix/client/api/v1": resource,
|
||||||
})
|
})
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, Resource())
|
root_resource = create_resource_tree(resources, NoResource())
|
||||||
|
|
||||||
_base.listen_tcp(
|
_base.listen_tcp(
|
||||||
bind_addresses,
|
bind_addresses,
|
||||||
|
@ -211,7 +211,6 @@ def start(config_options):
|
||||||
)
|
)
|
||||||
|
|
||||||
ss.setup()
|
ss.setup()
|
||||||
ss.get_handlers()
|
|
||||||
ss.start_listening(config.worker_listeners)
|
ss.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def start():
|
def start():
|
||||||
|
|
|
@ -56,7 +56,7 @@ from synapse.util.rlimit import change_resource_limit
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.application import service
|
from twisted.application import service
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.web.resource import EncodingResourceWrapper, Resource
|
from twisted.web.resource import EncodingResourceWrapper, NoResource
|
||||||
from twisted.web.server import GzipEncoderFactory
|
from twisted.web.server import GzipEncoderFactory
|
||||||
from twisted.web.static import File
|
from twisted.web.static import File
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ class SynapseHomeServer(HomeServer):
|
||||||
if WEB_CLIENT_PREFIX in resources:
|
if WEB_CLIENT_PREFIX in resources:
|
||||||
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
|
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
|
||||||
else:
|
else:
|
||||||
root_resource = Resource()
|
root_resource = NoResource()
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, root_resource)
|
root_resource = create_resource_tree(resources, root_resource)
|
||||||
|
|
||||||
|
@ -348,7 +348,7 @@ def setup(config_options):
|
||||||
hs.get_state_handler().start_caching()
|
hs.get_state_handler().start_caching()
|
||||||
hs.get_datastore().start_profiling()
|
hs.get_datastore().start_profiling()
|
||||||
hs.get_datastore().start_doing_background_updates()
|
hs.get_datastore().start_doing_background_updates()
|
||||||
hs.get_replication_layer().start_get_pdu_cache()
|
hs.get_federation_client().start_get_pdu_cache()
|
||||||
|
|
||||||
register_memory_metrics(hs)
|
register_memory_metrics(hs)
|
||||||
|
|
||||||
|
@ -402,6 +402,10 @@ def run(hs):
|
||||||
|
|
||||||
stats = {}
|
stats = {}
|
||||||
|
|
||||||
|
# Contains the list of processes we will be monitoring
|
||||||
|
# currently either 0 or 1
|
||||||
|
stats_process = []
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def phone_stats_home():
|
def phone_stats_home():
|
||||||
logger.info("Gathering stats for reporting")
|
logger.info("Gathering stats for reporting")
|
||||||
|
@ -428,6 +432,13 @@ def run(hs):
|
||||||
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
|
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
|
||||||
stats["daily_sent_messages"] = daily_sent_messages
|
stats["daily_sent_messages"] = daily_sent_messages
|
||||||
|
|
||||||
|
if len(stats_process) > 0:
|
||||||
|
stats["memory_rss"] = 0
|
||||||
|
stats["cpu_average"] = 0
|
||||||
|
for process in stats_process:
|
||||||
|
stats["memory_rss"] += process.memory_info().rss
|
||||||
|
stats["cpu_average"] += int(process.cpu_percent(interval=None))
|
||||||
|
|
||||||
logger.info("Reporting stats to matrix.org: %s" % (stats,))
|
logger.info("Reporting stats to matrix.org: %s" % (stats,))
|
||||||
try:
|
try:
|
||||||
yield hs.get_simple_http_client().put_json(
|
yield hs.get_simple_http_client().put_json(
|
||||||
|
@ -437,10 +448,32 @@ def run(hs):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn("Error reporting stats: %s", e)
|
logger.warn("Error reporting stats: %s", e)
|
||||||
|
|
||||||
|
def performance_stats_init():
|
||||||
|
try:
|
||||||
|
import psutil
|
||||||
|
process = psutil.Process()
|
||||||
|
# Ensure we can fetch both, and make the initial request for cpu_percent
|
||||||
|
# so the next request will use this as the initial point.
|
||||||
|
process.memory_info().rss
|
||||||
|
process.cpu_percent(interval=None)
|
||||||
|
logger.info("report_stats can use psutil")
|
||||||
|
stats_process.append(process)
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
logger.warn(
|
||||||
|
"report_stats enabled but psutil is not installed or incorrect version."
|
||||||
|
" Disabling reporting of memory/cpu stats."
|
||||||
|
" Ensuring psutil is available will help matrix.org track performance"
|
||||||
|
" changes across releases."
|
||||||
|
)
|
||||||
|
|
||||||
if hs.config.report_stats:
|
if hs.config.report_stats:
|
||||||
logger.info("Scheduling stats reporting for 3 hour intervals")
|
logger.info("Scheduling stats reporting for 3 hour intervals")
|
||||||
clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000)
|
clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000)
|
||||||
|
|
||||||
|
# We need to defer this init for the cases that we daemonize
|
||||||
|
# otherwise the process ID we get is that of the non-daemon process
|
||||||
|
clock.call_later(0, performance_stats_init)
|
||||||
|
|
||||||
# We wait 5 minutes to send the first set of stats as the server can
|
# We wait 5 minutes to send the first set of stats as the server can
|
||||||
# be quite busy the first few minutes
|
# be quite busy the first few minutes
|
||||||
clock.call_later(5 * 60, phone_stats_home)
|
clock.call_later(5 * 60, phone_stats_home)
|
||||||
|
|
|
@ -43,7 +43,7 @@ from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import NoResource
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.media_repository")
|
logger = logging.getLogger("synapse.app.media_repository")
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ class MediaRepositoryServer(HomeServer):
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, Resource())
|
root_resource = create_resource_tree(resources, NoResource())
|
||||||
|
|
||||||
_base.listen_tcp(
|
_base.listen_tcp(
|
||||||
bind_addresses,
|
bind_addresses,
|
||||||
|
@ -158,7 +158,6 @@ def start(config_options):
|
||||||
)
|
)
|
||||||
|
|
||||||
ss.setup()
|
ss.setup()
|
||||||
ss.get_handlers()
|
|
||||||
ss.start_listening(config.worker_listeners)
|
ss.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def start():
|
def start():
|
||||||
|
|
|
@ -32,13 +32,12 @@ from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.storage.roommember import RoomMemberStore
|
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.logcontext import LoggingContext, preserve_fn
|
from synapse.util.logcontext import LoggingContext, preserve_fn
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import NoResource
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.pusher")
|
logger = logging.getLogger("synapse.app.pusher")
|
||||||
|
|
||||||
|
@ -75,10 +74,6 @@ class PusherSlaveStore(
|
||||||
DataStore.get_profile_displayname.__func__
|
DataStore.get_profile_displayname.__func__
|
||||||
)
|
)
|
||||||
|
|
||||||
who_forgot_in_room = (
|
|
||||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PusherServer(HomeServer):
|
class PusherServer(HomeServer):
|
||||||
def setup(self):
|
def setup(self):
|
||||||
|
@ -99,7 +94,7 @@ class PusherServer(HomeServer):
|
||||||
if name == "metrics":
|
if name == "metrics":
|
||||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, Resource())
|
root_resource = create_resource_tree(resources, NoResource())
|
||||||
|
|
||||||
_base.listen_tcp(
|
_base.listen_tcp(
|
||||||
bind_addresses,
|
bind_addresses,
|
||||||
|
|
|
@ -56,14 +56,12 @@ from synapse.util.manhole import manhole
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import NoResource
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.synchrotron")
|
logger = logging.getLogger("synapse.app.synchrotron")
|
||||||
|
|
||||||
|
|
||||||
class SynchrotronSlavedStore(
|
class SynchrotronSlavedStore(
|
||||||
SlavedPushRuleStore,
|
|
||||||
SlavedEventStore,
|
|
||||||
SlavedReceiptsStore,
|
SlavedReceiptsStore,
|
||||||
SlavedAccountDataStore,
|
SlavedAccountDataStore,
|
||||||
SlavedApplicationServiceStore,
|
SlavedApplicationServiceStore,
|
||||||
|
@ -73,14 +71,12 @@ class SynchrotronSlavedStore(
|
||||||
SlavedGroupServerStore,
|
SlavedGroupServerStore,
|
||||||
SlavedDeviceInboxStore,
|
SlavedDeviceInboxStore,
|
||||||
SlavedDeviceStore,
|
SlavedDeviceStore,
|
||||||
|
SlavedPushRuleStore,
|
||||||
|
SlavedEventStore,
|
||||||
SlavedClientIpStore,
|
SlavedClientIpStore,
|
||||||
RoomStore,
|
RoomStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
):
|
):
|
||||||
who_forgot_in_room = (
|
|
||||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
|
||||||
)
|
|
||||||
|
|
||||||
did_forget = (
|
did_forget = (
|
||||||
RoomMemberStore.__dict__["did_forget"]
|
RoomMemberStore.__dict__["did_forget"]
|
||||||
)
|
)
|
||||||
|
@ -273,7 +269,7 @@ class SynchrotronServer(HomeServer):
|
||||||
"/_matrix/client/api/v1": resource,
|
"/_matrix/client/api/v1": resource,
|
||||||
})
|
})
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, Resource())
|
root_resource = create_resource_tree(resources, NoResource())
|
||||||
|
|
||||||
_base.listen_tcp(
|
_base.listen_tcp(
|
||||||
bind_addresses,
|
bind_addresses,
|
||||||
|
|
|
@ -38,7 +38,7 @@ def pid_running(pid):
|
||||||
try:
|
try:
|
||||||
os.kill(pid, 0)
|
os.kill(pid, 0)
|
||||||
return True
|
return True
|
||||||
except OSError, err:
|
except OSError as err:
|
||||||
if err.errno == errno.EPERM:
|
if err.errno == errno.EPERM:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
@ -98,7 +98,7 @@ def stop(pidfile, app):
|
||||||
try:
|
try:
|
||||||
os.kill(pid, signal.SIGTERM)
|
os.kill(pid, signal.SIGTERM)
|
||||||
write("stopped %s" % (app,), colour=GREEN)
|
write("stopped %s" % (app,), colour=GREEN)
|
||||||
except OSError, err:
|
except OSError as err:
|
||||||
if err.errno == errno.ESRCH:
|
if err.errno == errno.ESRCH:
|
||||||
write("%s not running" % (app,), colour=YELLOW)
|
write("%s not running" % (app,), colour=YELLOW)
|
||||||
elif err.errno == errno.EPERM:
|
elif err.errno == errno.EPERM:
|
||||||
|
|
|
@ -43,7 +43,7 @@ from synapse.util.logcontext import LoggingContext, preserve_fn
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import NoResource
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.user_dir")
|
logger = logging.getLogger("synapse.app.user_dir")
|
||||||
|
|
||||||
|
@ -116,7 +116,7 @@ class UserDirectoryServer(HomeServer):
|
||||||
"/_matrix/client/api/v1": resource,
|
"/_matrix/client/api/v1": resource,
|
||||||
})
|
})
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, Resource())
|
root_resource = create_resource_tree(resources, NoResource())
|
||||||
|
|
||||||
_base.listen_tcp(
|
_base.listen_tcp(
|
||||||
bind_addresses,
|
bind_addresses,
|
||||||
|
|
|
@ -77,7 +77,9 @@ class RegistrationConfig(Config):
|
||||||
|
|
||||||
# Set the number of bcrypt rounds used to generate password hash.
|
# Set the number of bcrypt rounds used to generate password hash.
|
||||||
# Larger numbers increase the work factor needed to generate the hash.
|
# Larger numbers increase the work factor needed to generate the hash.
|
||||||
# The default number of rounds is 12.
|
# The default number is 12 (which equates to 2^12 rounds).
|
||||||
|
# N.B. that increasing this will exponentially increase the time required
|
||||||
|
# to register or login - e.g. 24 => 2^24 rounds which will take >20 mins.
|
||||||
bcrypt_rounds: 12
|
bcrypt_rounds: 12
|
||||||
|
|
||||||
# Allows users to register as guests without a password/email/etc, and
|
# Allows users to register as guests without a password/email/etc, and
|
||||||
|
|
|
@ -15,11 +15,3 @@
|
||||||
|
|
||||||
""" This package includes all the federation specific logic.
|
""" This package includes all the federation specific logic.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .replication import ReplicationLayer
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_http_replication(hs):
|
|
||||||
transport = hs.get_federation_transport_client()
|
|
||||||
|
|
||||||
return ReplicationLayer(hs, transport)
|
|
||||||
|
|
|
@ -27,7 +27,13 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class FederationBase(object):
|
class FederationBase(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
|
||||||
|
self.server_name = hs.hostname
|
||||||
|
self.keyring = hs.get_keyring()
|
||||||
self.spam_checker = hs.get_spam_checker()
|
self.spam_checker = hs.get_spam_checker()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self._clock = hs.get_clock()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||||
|
|
|
@ -58,6 +58,7 @@ class FederationClient(FederationBase):
|
||||||
self._clear_tried_cache, 60 * 1000,
|
self._clear_tried_cache, 60 * 1000,
|
||||||
)
|
)
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
self.transport_layer = hs.get_federation_transport_client()
|
||||||
|
|
||||||
def _clear_tried_cache(self):
|
def _clear_tried_cache(self):
|
||||||
"""Clear pdu_destination_tried cache"""
|
"""Clear pdu_destination_tried cache"""
|
||||||
|
|
|
@ -17,12 +17,14 @@ import logging
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, FederationError, SynapseError
|
from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
|
||||||
from synapse.crypto.event_signing import compute_event_signature
|
from synapse.crypto.event_signing import compute_event_signature
|
||||||
from synapse.federation.federation_base import (
|
from synapse.federation.federation_base import (
|
||||||
FederationBase,
|
FederationBase,
|
||||||
event_from_pdu_json,
|
event_from_pdu_json,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from synapse.federation.persistence import TransactionActions
|
||||||
from synapse.federation.units import Edu, Transaction
|
from synapse.federation.units import Edu, Transaction
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
|
@ -52,50 +54,19 @@ class FederationServer(FederationBase):
|
||||||
super(FederationServer, self).__init__(hs)
|
super(FederationServer, self).__init__(hs)
|
||||||
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self.handler = hs.get_handlers().federation_handler
|
||||||
|
|
||||||
self._server_linearizer = async.Linearizer("fed_server")
|
self._server_linearizer = async.Linearizer("fed_server")
|
||||||
self._transaction_linearizer = async.Linearizer("fed_txn_handler")
|
self._transaction_linearizer = async.Linearizer("fed_txn_handler")
|
||||||
|
|
||||||
|
self.transaction_actions = TransactionActions(self.store)
|
||||||
|
|
||||||
|
self.registry = hs.get_federation_registry()
|
||||||
|
|
||||||
# We cache responses to state queries, as they take a while and often
|
# We cache responses to state queries, as they take a while and often
|
||||||
# come in waves.
|
# come in waves.
|
||||||
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
|
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
|
||||||
|
|
||||||
def set_handler(self, handler):
|
|
||||||
"""Sets the handler that the replication layer will use to communicate
|
|
||||||
receipt of new PDUs from other home servers. The required methods are
|
|
||||||
documented on :py:class:`.ReplicationHandler`.
|
|
||||||
"""
|
|
||||||
self.handler = handler
|
|
||||||
|
|
||||||
def register_edu_handler(self, edu_type, handler):
|
|
||||||
if edu_type in self.edu_handlers:
|
|
||||||
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
|
|
||||||
|
|
||||||
self.edu_handlers[edu_type] = handler
|
|
||||||
|
|
||||||
def register_query_handler(self, query_type, handler):
|
|
||||||
"""Sets the handler callable that will be used to handle an incoming
|
|
||||||
federation Query of the given type.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_type (str): Category name of the query, which should match
|
|
||||||
the string used by make_query.
|
|
||||||
handler (callable): Invoked to handle incoming queries of this type
|
|
||||||
|
|
||||||
handler is invoked as:
|
|
||||||
result = handler(args)
|
|
||||||
|
|
||||||
where 'args' is a dict mapping strings to strings of the query
|
|
||||||
arguments. It should return a Deferred that will eventually yield an
|
|
||||||
object to encode as JSON.
|
|
||||||
"""
|
|
||||||
if query_type in self.query_handlers:
|
|
||||||
raise KeyError(
|
|
||||||
"Already have a Query handler for %s" % (query_type,)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.query_handlers[query_type] = handler
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_backfill_request(self, origin, room_id, versions, limit):
|
def on_backfill_request(self, origin, room_id, versions, limit):
|
||||||
|
@ -229,16 +200,7 @@ class FederationServer(FederationBase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def received_edu(self, origin, edu_type, content):
|
def received_edu(self, origin, edu_type, content):
|
||||||
received_edus_counter.inc()
|
received_edus_counter.inc()
|
||||||
|
yield self.registry.on_edu(edu_type, origin, content)
|
||||||
if edu_type in self.edu_handlers:
|
|
||||||
try:
|
|
||||||
yield self.edu_handlers[edu_type](origin, content)
|
|
||||||
except SynapseError as e:
|
|
||||||
logger.info("Failed to handle edu %r: %r", edu_type, e)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Failed to handle edu %r", edu_type)
|
|
||||||
else:
|
|
||||||
logger.warn("Received EDU of type %s with no handler", edu_type)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -328,14 +290,8 @@ class FederationServer(FederationBase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_query_request(self, query_type, args):
|
def on_query_request(self, query_type, args):
|
||||||
received_queries_counter.inc(query_type)
|
received_queries_counter.inc(query_type)
|
||||||
|
resp = yield self.registry.on_query(query_type, args)
|
||||||
if query_type in self.query_handlers:
|
defer.returnValue((200, resp))
|
||||||
response = yield self.query_handlers[query_type](args)
|
|
||||||
defer.returnValue((200, response))
|
|
||||||
else:
|
|
||||||
defer.returnValue(
|
|
||||||
(404, "No handler for Query type '%s'" % (query_type,))
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_make_join_request(self, room_id, user_id):
|
def on_make_join_request(self, room_id, user_id):
|
||||||
|
@ -607,3 +563,66 @@ class FederationServer(FederationBase):
|
||||||
origin, room_id, event_dict
|
origin, room_id, event_dict
|
||||||
)
|
)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
|
||||||
|
class FederationHandlerRegistry(object):
|
||||||
|
"""Allows classes to register themselves as handlers for a given EDU or
|
||||||
|
query type for incoming federation traffic.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.edu_handlers = {}
|
||||||
|
self.query_handlers = {}
|
||||||
|
|
||||||
|
def register_edu_handler(self, edu_type, handler):
|
||||||
|
"""Sets the handler callable that will be used to handle an incoming
|
||||||
|
federation EDU of the given type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
edu_type (str): The type of the incoming EDU to register handler for
|
||||||
|
handler (Callable[[str, dict]]): A callable invoked on incoming EDU
|
||||||
|
of the given type. The arguments are the origin server name and
|
||||||
|
the EDU contents.
|
||||||
|
"""
|
||||||
|
if edu_type in self.edu_handlers:
|
||||||
|
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
|
||||||
|
|
||||||
|
self.edu_handlers[edu_type] = handler
|
||||||
|
|
||||||
|
def register_query_handler(self, query_type, handler):
|
||||||
|
"""Sets the handler callable that will be used to handle an incoming
|
||||||
|
federation query of the given type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_type (str): Category name of the query, which should match
|
||||||
|
the string used by make_query.
|
||||||
|
handler (Callable[[dict], Deferred[dict]]): Invoked to handle
|
||||||
|
incoming queries of this type. The return will be yielded
|
||||||
|
on and the result used as the response to the query request.
|
||||||
|
"""
|
||||||
|
if query_type in self.query_handlers:
|
||||||
|
raise KeyError(
|
||||||
|
"Already have a Query handler for %s" % (query_type,)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query_handlers[query_type] = handler
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_edu(self, edu_type, origin, content):
|
||||||
|
handler = self.edu_handlers.get(edu_type)
|
||||||
|
if not handler:
|
||||||
|
logger.warn("No handler registered for EDU type %s", edu_type)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield handler(origin, content)
|
||||||
|
except SynapseError as e:
|
||||||
|
logger.info("Failed to handle edu %r: %r", edu_type, e)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to handle edu %r", edu_type)
|
||||||
|
|
||||||
|
def on_query(self, query_type, args):
|
||||||
|
handler = self.query_handlers.get(query_type)
|
||||||
|
if not handler:
|
||||||
|
logger.warn("No handler registered for query type %s", query_type)
|
||||||
|
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
|
||||||
|
|
||||||
|
return handler(args)
|
||||||
|
|
|
@ -1,73 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""This layer is responsible for replicating with remote home servers using
|
|
||||||
a given transport.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .federation_client import FederationClient
|
|
||||||
from .federation_server import FederationServer
|
|
||||||
|
|
||||||
from .persistence import TransactionActions
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ReplicationLayer(FederationClient, FederationServer):
|
|
||||||
"""This layer is responsible for replicating with remote home servers over
|
|
||||||
the given transport. I.e., does the sending and receiving of PDUs to
|
|
||||||
remote home servers.
|
|
||||||
|
|
||||||
The layer communicates with the rest of the server via a registered
|
|
||||||
ReplicationHandler.
|
|
||||||
|
|
||||||
In more detail, the layer:
|
|
||||||
* Receives incoming data and processes it into transactions and pdus.
|
|
||||||
* Fetches any PDUs it thinks it might have missed.
|
|
||||||
* Keeps the current state for contexts up to date by applying the
|
|
||||||
suitable conflict resolution.
|
|
||||||
* Sends outgoing pdus wrapped in transactions.
|
|
||||||
* Fills out the references to previous pdus/transactions appropriately
|
|
||||||
for outgoing data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hs, transport_layer):
|
|
||||||
self.server_name = hs.hostname
|
|
||||||
|
|
||||||
self.keyring = hs.get_keyring()
|
|
||||||
|
|
||||||
self.transport_layer = transport_layer
|
|
||||||
|
|
||||||
self.federation_client = self
|
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
|
||||||
|
|
||||||
self.handler = None
|
|
||||||
self.edu_handlers = {}
|
|
||||||
self.query_handlers = {}
|
|
||||||
|
|
||||||
self._clock = hs.get_clock()
|
|
||||||
|
|
||||||
self.transaction_actions = TransactionActions(self.store)
|
|
||||||
|
|
||||||
self.hs = hs
|
|
||||||
|
|
||||||
super(ReplicationLayer, self).__init__(hs)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "<ReplicationLayer(%s)>" % self.server_name
|
|
|
@ -1190,7 +1190,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
|
||||||
def register_servlets(hs, resource, authenticator, ratelimiter):
|
def register_servlets(hs, resource, authenticator, ratelimiter):
|
||||||
for servletclass in FEDERATION_SERVLET_CLASSES:
|
for servletclass in FEDERATION_SERVLET_CLASSES:
|
||||||
servletclass(
|
servletclass(
|
||||||
handler=hs.get_replication_layer(),
|
handler=hs.get_federation_server(),
|
||||||
authenticator=authenticator,
|
authenticator=authenticator,
|
||||||
ratelimiter=ratelimiter,
|
ratelimiter=ratelimiter,
|
||||||
server_name=hs.hostname,
|
server_name=hs.hostname,
|
||||||
|
|
|
@ -17,7 +17,6 @@ from .register import RegistrationHandler
|
||||||
from .room import (
|
from .room import (
|
||||||
RoomCreationHandler, RoomContextHandler,
|
RoomCreationHandler, RoomContextHandler,
|
||||||
)
|
)
|
||||||
from .room_member import RoomMemberHandler
|
|
||||||
from .message import MessageHandler
|
from .message import MessageHandler
|
||||||
from .federation import FederationHandler
|
from .federation import FederationHandler
|
||||||
from .directory import DirectoryHandler
|
from .directory import DirectoryHandler
|
||||||
|
@ -49,7 +48,6 @@ class Handlers(object):
|
||||||
self.registration_handler = RegistrationHandler(hs)
|
self.registration_handler = RegistrationHandler(hs)
|
||||||
self.message_handler = MessageHandler(hs)
|
self.message_handler = MessageHandler(hs)
|
||||||
self.room_creation_handler = RoomCreationHandler(hs)
|
self.room_creation_handler = RoomCreationHandler(hs)
|
||||||
self.room_member_handler = RoomMemberHandler(hs)
|
|
||||||
self.federation_handler = FederationHandler(hs)
|
self.federation_handler = FederationHandler(hs)
|
||||||
self.directory_handler = DirectoryHandler(hs)
|
self.directory_handler = DirectoryHandler(hs)
|
||||||
self.admin_handler = AdminHandler(hs)
|
self.admin_handler = AdminHandler(hs)
|
||||||
|
|
|
@ -158,7 +158,7 @@ class BaseHandler(object):
|
||||||
# homeserver.
|
# homeserver.
|
||||||
requester = synapse.types.create_requester(
|
requester = synapse.types.create_requester(
|
||||||
target_user, is_guest=True)
|
target_user, is_guest=True)
|
||||||
handler = self.hs.get_handlers().room_member_handler
|
handler = self.hs.get_room_member_handler()
|
||||||
yield handler.update_membership(
|
yield handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
target_user,
|
target_user,
|
||||||
|
|
|
@ -863,8 +863,10 @@ class AuthHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _do_validate_hash():
|
def _do_validate_hash():
|
||||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
return bcrypt.checkpw(
|
||||||
stored_hash.encode('utf8')) == stored_hash
|
password.encode('utf8') + self.hs.config.password_pepper,
|
||||||
|
stored_hash.encode('utf8')
|
||||||
|
)
|
||||||
|
|
||||||
if stored_hash:
|
if stored_hash:
|
||||||
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
|
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
|
||||||
|
|
|
@ -37,14 +37,15 @@ class DeviceHandler(BaseHandler):
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self.federation_sender = hs.get_federation_sender()
|
self.federation_sender = hs.get_federation_sender()
|
||||||
self.federation = hs.get_replication_layer()
|
|
||||||
|
|
||||||
self._edu_updater = DeviceListEduUpdater(hs, self)
|
self._edu_updater = DeviceListEduUpdater(hs, self)
|
||||||
|
|
||||||
self.federation.register_edu_handler(
|
federation_registry = hs.get_federation_registry()
|
||||||
|
|
||||||
|
federation_registry.register_edu_handler(
|
||||||
"m.device_list_update", self._edu_updater.incoming_device_list_update,
|
"m.device_list_update", self._edu_updater.incoming_device_list_update,
|
||||||
)
|
)
|
||||||
self.federation.register_query_handler(
|
federation_registry.register_query_handler(
|
||||||
"user_devices", self.on_federation_query_user_devices,
|
"user_devices", self.on_federation_query_user_devices,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -154,7 +155,7 @@ class DeviceHandler(BaseHandler):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.store.delete_device(user_id, device_id)
|
yield self.store.delete_device(user_id, device_id)
|
||||||
except errors.StoreError, e:
|
except errors.StoreError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
# no match
|
# no match
|
||||||
pass
|
pass
|
||||||
|
@ -203,7 +204,7 @@ class DeviceHandler(BaseHandler):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.store.delete_devices(user_id, device_ids)
|
yield self.store.delete_devices(user_id, device_ids)
|
||||||
except errors.StoreError, e:
|
except errors.StoreError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
# no match
|
# no match
|
||||||
pass
|
pass
|
||||||
|
@ -242,7 +243,7 @@ class DeviceHandler(BaseHandler):
|
||||||
new_display_name=content.get("display_name")
|
new_display_name=content.get("display_name")
|
||||||
)
|
)
|
||||||
yield self.notify_device_update(user_id, [device_id])
|
yield self.notify_device_update(user_id, [device_id])
|
||||||
except errors.StoreError, e:
|
except errors.StoreError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
raise errors.NotFoundError()
|
raise errors.NotFoundError()
|
||||||
else:
|
else:
|
||||||
|
@ -430,7 +431,7 @@ class DeviceListEduUpdater(object):
|
||||||
|
|
||||||
def __init__(self, hs, device_handler):
|
def __init__(self, hs, device_handler):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_federation_client()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.device_handler = device_handler
|
self.device_handler = device_handler
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ class DeviceMessageHandler(object):
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
self.federation = hs.get_federation_sender()
|
self.federation = hs.get_federation_sender()
|
||||||
|
|
||||||
hs.get_replication_layer().register_edu_handler(
|
hs.get_federation_registry().register_edu_handler(
|
||||||
"m.direct_to_device", self.on_direct_to_device_edu
|
"m.direct_to_device", self.on_direct_to_device_edu
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -36,8 +36,8 @@ class DirectoryHandler(BaseHandler):
|
||||||
self.appservice_handler = hs.get_application_service_handler()
|
self.appservice_handler = hs.get_application_service_handler()
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
|
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_federation_client()
|
||||||
self.federation.register_query_handler(
|
hs.get_federation_registry().register_query_handler(
|
||||||
"directory", self.on_directory_query
|
"directory", self.on_directory_query
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2016 OpenMarket Ltd
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -13,7 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
@ -32,7 +33,7 @@ logger = logging.getLogger(__name__)
|
||||||
class E2eKeysHandler(object):
|
class E2eKeysHandler(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_federation_client()
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
@ -40,7 +41,7 @@ class E2eKeysHandler(object):
|
||||||
# doesn't really work as part of the generic query API, because the
|
# doesn't really work as part of the generic query API, because the
|
||||||
# query request requires an object POST, but we abuse the
|
# query request requires an object POST, but we abuse the
|
||||||
# "query handler" interface.
|
# "query handler" interface.
|
||||||
self.federation.register_query_handler(
|
hs.get_federation_registry().register_query_handler(
|
||||||
"client_keys", self.on_federation_query_client_keys
|
"client_keys", self.on_federation_query_client_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -134,23 +135,8 @@ class E2eKeysHandler(object):
|
||||||
if user_id in destination_query:
|
if user_id in destination_query:
|
||||||
results[user_id] = keys
|
results[user_id] = keys
|
||||||
|
|
||||||
except CodeMessageException as e:
|
|
||||||
failures[destination] = {
|
|
||||||
"status": e.code, "message": e.message
|
|
||||||
}
|
|
||||||
except NotRetryingDestination as e:
|
|
||||||
failures[destination] = {
|
|
||||||
"status": 503, "message": "Not ready for retry",
|
|
||||||
}
|
|
||||||
except FederationDeniedError as e:
|
|
||||||
failures[destination] = {
|
|
||||||
"status": 403, "message": "Federation Denied",
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# include ConnectionRefused and other errors
|
failures[destination] = _exception_to_failure(e)
|
||||||
failures[destination] = {
|
|
||||||
"status": 503, "message": e.message
|
|
||||||
}
|
|
||||||
|
|
||||||
yield make_deferred_yieldable(defer.gatherResults([
|
yield make_deferred_yieldable(defer.gatherResults([
|
||||||
preserve_fn(do_remote_query)(destination)
|
preserve_fn(do_remote_query)(destination)
|
||||||
|
@ -252,19 +238,8 @@ class E2eKeysHandler(object):
|
||||||
for user_id, keys in remote_result["one_time_keys"].items():
|
for user_id, keys in remote_result["one_time_keys"].items():
|
||||||
if user_id in device_keys:
|
if user_id in device_keys:
|
||||||
json_result[user_id] = keys
|
json_result[user_id] = keys
|
||||||
except CodeMessageException as e:
|
|
||||||
failures[destination] = {
|
|
||||||
"status": e.code, "message": e.message
|
|
||||||
}
|
|
||||||
except NotRetryingDestination as e:
|
|
||||||
failures[destination] = {
|
|
||||||
"status": 503, "message": "Not ready for retry",
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# include ConnectionRefused and other errors
|
failures[destination] = _exception_to_failure(e)
|
||||||
failures[destination] = {
|
|
||||||
"status": 503, "message": e.message
|
|
||||||
}
|
|
||||||
|
|
||||||
yield make_deferred_yieldable(defer.gatherResults([
|
yield make_deferred_yieldable(defer.gatherResults([
|
||||||
preserve_fn(claim_client_keys)(destination)
|
preserve_fn(claim_client_keys)(destination)
|
||||||
|
@ -362,6 +337,31 @@ class E2eKeysHandler(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _exception_to_failure(e):
|
||||||
|
if isinstance(e, CodeMessageException):
|
||||||
|
return {
|
||||||
|
"status": e.code, "message": e.message,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(e, NotRetryingDestination):
|
||||||
|
return {
|
||||||
|
"status": 503, "message": "Not ready for retry",
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(e, FederationDeniedError):
|
||||||
|
return {
|
||||||
|
"status": 403, "message": "Federation Denied",
|
||||||
|
}
|
||||||
|
|
||||||
|
# include ConnectionRefused and other errors
|
||||||
|
#
|
||||||
|
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
|
||||||
|
# give a string for e.message, which simplejson then fails to serialize.
|
||||||
|
return {
|
||||||
|
"status": 503, "message": str(e.message),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _one_time_keys_match(old_key_json, new_key):
|
def _one_time_keys_match(old_key_json, new_key):
|
||||||
old_key = json.loads(old_key_json)
|
old_key = json.loads(old_key_json)
|
||||||
|
|
||||||
|
|
|
@ -68,7 +68,7 @@ class FederationHandler(BaseHandler):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.replication_layer = hs.get_replication_layer()
|
self.replication_layer = hs.get_federation_client()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
|
@ -78,8 +78,6 @@ class FederationHandler(BaseHandler):
|
||||||
self.spam_checker = hs.get_spam_checker()
|
self.spam_checker = hs.get_spam_checker()
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
|
|
||||||
self.replication_layer.set_handler(self)
|
|
||||||
|
|
||||||
# When joining a room we need to queue any events for that room up
|
# When joining a room we need to queue any events for that room up
|
||||||
self.room_queues = {}
|
self.room_queues = {}
|
||||||
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
|
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
|
||||||
|
@ -1447,16 +1445,24 @@ class FederationHandler(BaseHandler):
|
||||||
auth_events=auth_events,
|
auth_events=auth_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not event.internal_metadata.is_outlier() and not backfilled:
|
try:
|
||||||
yield self.action_generator.handle_push_actions_for_event(
|
if not event.internal_metadata.is_outlier() and not backfilled:
|
||||||
event, context
|
yield self.action_generator.handle_push_actions_for_event(
|
||||||
)
|
event, context
|
||||||
|
)
|
||||||
|
|
||||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||||
event,
|
event,
|
||||||
context=context,
|
context=context,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
|
except: # noqa: E722, as we reraise the exception this is fine.
|
||||||
|
# Ensure that we actually remove the entries in the push actions
|
||||||
|
# staging area
|
||||||
|
logcontext.preserve_fn(
|
||||||
|
self.store.remove_push_actions_from_staging
|
||||||
|
)(event.event_id)
|
||||||
|
raise
|
||||||
|
|
||||||
if not backfilled:
|
if not backfilled:
|
||||||
# this intentionally does not yield: we don't care about the result
|
# this intentionally does not yield: we don't care about the result
|
||||||
|
@ -2145,7 +2151,7 @@ class FederationHandler(BaseHandler):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
yield self._check_signature(event, context)
|
yield self._check_signature(event, context)
|
||||||
member_handler = self.hs.get_handlers().room_member_handler
|
member_handler = self.hs.get_room_member_handler()
|
||||||
yield member_handler.send_membership_event(None, event, context)
|
yield member_handler.send_membership_event(None, event, context)
|
||||||
else:
|
else:
|
||||||
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
|
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
|
||||||
|
@ -2189,7 +2195,7 @@ class FederationHandler(BaseHandler):
|
||||||
# TODO: Make sure the signatures actually are correct.
|
# TODO: Make sure the signatures actually are correct.
|
||||||
event.signatures.update(returned_invite.signatures)
|
event.signatures.update(returned_invite.signatures)
|
||||||
|
|
||||||
member_handler = self.hs.get_handlers().room_member_handler
|
member_handler = self.hs.get_room_member_handler()
|
||||||
yield member_handler.send_membership_event(None, event, context)
|
yield member_handler.send_membership_event(None, event, context)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -15,6 +15,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Utilities for interacting with Identity Servers"""
|
"""Utilities for interacting with Identity Servers"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import simplejson as json
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
|
@ -24,9 +29,6 @@ from ._base import BaseHandler
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.api.errors import SynapseError, Codes
|
from synapse.api.errors import SynapseError, Codes
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,8 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer, reactor
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
from synapse.api.errors import AuthError, Codes, SynapseError
|
||||||
|
@ -24,9 +25,10 @@ from synapse.types import (
|
||||||
UserID, RoomAlias, RoomStreamToken,
|
UserID, RoomAlias, RoomStreamToken,
|
||||||
)
|
)
|
||||||
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
|
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn, run_in_background
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.util.frozenutils import unfreeze
|
from synapse.util.frozenutils import frozendict_json_encoder
|
||||||
|
from synapse.util.stringutils import random_string
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
from synapse.replication.http.send_event import send_event_to_master
|
from synapse.replication.http.send_event import send_event_to_master
|
||||||
|
|
||||||
|
@ -36,11 +38,41 @@ from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import ujson
|
import simplejson
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PurgeStatus(object):
|
||||||
|
"""Object tracking the status of a purge request
|
||||||
|
|
||||||
|
This class contains information on the progress of a purge request, for
|
||||||
|
return by get_purge_status.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
status (int): Tracks whether this request has completed. One of
|
||||||
|
STATUS_{ACTIVE,COMPLETE,FAILED}
|
||||||
|
"""
|
||||||
|
|
||||||
|
STATUS_ACTIVE = 0
|
||||||
|
STATUS_COMPLETE = 1
|
||||||
|
STATUS_FAILED = 2
|
||||||
|
|
||||||
|
STATUS_TEXT = {
|
||||||
|
STATUS_ACTIVE: "active",
|
||||||
|
STATUS_COMPLETE: "complete",
|
||||||
|
STATUS_FAILED: "failed",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.status = PurgeStatus.STATUS_ACTIVE
|
||||||
|
|
||||||
|
def asdict(self):
|
||||||
|
return {
|
||||||
|
"status": PurgeStatus.STATUS_TEXT[self.status]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class MessageHandler(BaseHandler):
|
class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -50,18 +82,87 @@ class MessageHandler(BaseHandler):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
self.pagination_lock = ReadWriteLock()
|
self.pagination_lock = ReadWriteLock()
|
||||||
|
self._purges_in_progress_by_room = set()
|
||||||
|
# map from purge id to PurgeStatus
|
||||||
|
self._purges_by_id = {}
|
||||||
|
|
||||||
|
def start_purge_history(self, room_id, topological_ordering,
|
||||||
|
delete_local_events=False):
|
||||||
|
"""Start off a history purge on a room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str): The room to purge from
|
||||||
|
|
||||||
|
topological_ordering (int): minimum topo ordering to preserve
|
||||||
|
delete_local_events (bool): True to delete local events as well as
|
||||||
|
remote ones
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: unique ID for this purge transaction.
|
||||||
|
"""
|
||||||
|
if room_id in self._purges_in_progress_by_room:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"History purge already in progress for %s" % (room_id, ),
|
||||||
|
)
|
||||||
|
|
||||||
|
purge_id = random_string(16)
|
||||||
|
|
||||||
|
# we log the purge_id here so that it can be tied back to the
|
||||||
|
# request id in the log lines.
|
||||||
|
logger.info("[purge] starting purge_id %s", purge_id)
|
||||||
|
|
||||||
|
self._purges_by_id[purge_id] = PurgeStatus()
|
||||||
|
run_in_background(
|
||||||
|
self._purge_history,
|
||||||
|
purge_id, room_id, topological_ordering, delete_local_events,
|
||||||
|
)
|
||||||
|
return purge_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def purge_history(self, room_id, event_id, delete_local_events=False):
|
def _purge_history(self, purge_id, room_id, topological_ordering,
|
||||||
event = yield self.store.get_event(event_id)
|
delete_local_events):
|
||||||
|
"""Carry out a history purge on a room.
|
||||||
|
|
||||||
if event.room_id != room_id:
|
Args:
|
||||||
raise SynapseError(400, "Event is for wrong room.")
|
purge_id (str): The id for this purge
|
||||||
|
room_id (str): The room to purge from
|
||||||
|
topological_ordering (int): minimum topo ordering to preserve
|
||||||
|
delete_local_events (bool): True to delete local events as well as
|
||||||
|
remote ones
|
||||||
|
|
||||||
depth = event.depth
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
self._purges_in_progress_by_room.add(room_id)
|
||||||
|
try:
|
||||||
|
with (yield self.pagination_lock.write(room_id)):
|
||||||
|
yield self.store.purge_history(
|
||||||
|
room_id, topological_ordering, delete_local_events,
|
||||||
|
)
|
||||||
|
logger.info("[purge] complete")
|
||||||
|
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
|
||||||
|
except Exception:
|
||||||
|
logger.error("[purge] failed: %s", Failure().getTraceback().rstrip())
|
||||||
|
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
|
||||||
|
finally:
|
||||||
|
self._purges_in_progress_by_room.discard(room_id)
|
||||||
|
|
||||||
with (yield self.pagination_lock.write(room_id)):
|
# remove the purge from the list 24 hours after it completes
|
||||||
yield self.store.purge_history(room_id, depth, delete_local_events)
|
def clear_purge():
|
||||||
|
del self._purges_by_id[purge_id]
|
||||||
|
reactor.callLater(24 * 3600, clear_purge)
|
||||||
|
|
||||||
|
def get_purge_status(self, purge_id):
|
||||||
|
"""Get the current status of an active purge
|
||||||
|
|
||||||
|
Args:
|
||||||
|
purge_id (str): purge_id returned by start_purge_history
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PurgeStatus|None
|
||||||
|
"""
|
||||||
|
return self._purges_by_id.get(purge_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_messages(self, requester, room_id=None, pagin_config=None,
|
def get_messages(self, requester, room_id=None, pagin_config=None,
|
||||||
|
@ -553,24 +654,21 @@ class EventCreationHandler(object):
|
||||||
event,
|
event,
|
||||||
context,
|
context,
|
||||||
ratelimit=True,
|
ratelimit=True,
|
||||||
extra_users=[]
|
extra_users=[],
|
||||||
):
|
):
|
||||||
# We now need to go and hit out to wherever we need to hit out to.
|
"""Processes a new event. This includes checking auth, persisting it,
|
||||||
|
notifying users, sending to remote servers, etc.
|
||||||
|
|
||||||
# If we're a worker we need to hit out to the master.
|
If called from a worker will hit out to the master process for final
|
||||||
if self.config.worker_app:
|
processing.
|
||||||
yield send_event_to_master(
|
|
||||||
self.http_client,
|
|
||||||
host=self.config.worker_replication_host,
|
|
||||||
port=self.config.worker_replication_http_port,
|
|
||||||
requester=requester,
|
|
||||||
event=event,
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if ratelimit:
|
Args:
|
||||||
yield self.base_handler.ratelimit(requester)
|
requester (Requester)
|
||||||
|
event (FrozenEvent)
|
||||||
|
context (EventContext)
|
||||||
|
ratelimit (bool)
|
||||||
|
extra_users (list(UserID)): Any extra users to notify about event
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.auth.check_from_context(event, context)
|
yield self.auth.check_from_context(event, context)
|
||||||
|
@ -580,12 +678,63 @@ class EventCreationHandler(object):
|
||||||
|
|
||||||
# Ensure that we can round trip before trying to persist in db
|
# Ensure that we can round trip before trying to persist in db
|
||||||
try:
|
try:
|
||||||
dump = ujson.dumps(unfreeze(event.content))
|
dump = frozendict_json_encoder.encode(event.content)
|
||||||
ujson.loads(dump)
|
simplejson.loads(dump)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to encode content: %r", event.content)
|
logger.exception("Failed to encode content: %r", event.content)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
yield self.action_generator.handle_push_actions_for_event(
|
||||||
|
event, context
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# If we're a worker we need to hit out to the master.
|
||||||
|
if self.config.worker_app:
|
||||||
|
yield send_event_to_master(
|
||||||
|
self.http_client,
|
||||||
|
host=self.config.worker_replication_host,
|
||||||
|
port=self.config.worker_replication_http_port,
|
||||||
|
requester=requester,
|
||||||
|
event=event,
|
||||||
|
context=context,
|
||||||
|
ratelimit=ratelimit,
|
||||||
|
extra_users=extra_users,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
yield self.persist_and_notify_client_event(
|
||||||
|
requester,
|
||||||
|
event,
|
||||||
|
context,
|
||||||
|
ratelimit=ratelimit,
|
||||||
|
extra_users=extra_users,
|
||||||
|
)
|
||||||
|
except: # noqa: E722, as we reraise the exception this is fine.
|
||||||
|
# Ensure that we actually remove the entries in the push actions
|
||||||
|
# staging area, if we calculated them.
|
||||||
|
preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def persist_and_notify_client_event(
|
||||||
|
self,
|
||||||
|
requester,
|
||||||
|
event,
|
||||||
|
context,
|
||||||
|
ratelimit=True,
|
||||||
|
extra_users=[],
|
||||||
|
):
|
||||||
|
"""Called when we have fully built the event, have already
|
||||||
|
calculated the push actions for the event, and checked auth.
|
||||||
|
|
||||||
|
This should only be run on master.
|
||||||
|
"""
|
||||||
|
assert not self.config.worker_app
|
||||||
|
|
||||||
|
if ratelimit:
|
||||||
|
yield self.base_handler.ratelimit(requester)
|
||||||
|
|
||||||
yield self.base_handler.maybe_kick_guest_users(event, context)
|
yield self.base_handler.maybe_kick_guest_users(event, context)
|
||||||
|
|
||||||
if event.type == EventTypes.CanonicalAlias:
|
if event.type == EventTypes.CanonicalAlias:
|
||||||
|
@ -679,20 +828,10 @@ class EventCreationHandler(object):
|
||||||
"Changing the room create event is forbidden",
|
"Changing the room create event is forbidden",
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.action_generator.handle_push_actions_for_event(
|
(event_stream_id, max_stream_id) = yield self.store.persist_event(
|
||||||
event, context
|
event, context=context
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
(event_stream_id, max_stream_id) = yield self.store.persist_event(
|
|
||||||
event, context=context
|
|
||||||
)
|
|
||||||
except: # noqa: E722, as we reraise the exception this is fine.
|
|
||||||
# Ensure that we actually remove the entries in the push actions
|
|
||||||
# staging area
|
|
||||||
preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id)
|
|
||||||
raise
|
|
||||||
|
|
||||||
# this intentionally does not yield: we don't care about the result
|
# this intentionally does not yield: we don't care about the result
|
||||||
# and don't need to wait for it.
|
# and don't need to wait for it.
|
||||||
preserve_fn(self.pusher_pool.on_new_notifications)(
|
preserve_fn(self.pusher_pool.on_new_notifications)(
|
||||||
|
|
|
@ -93,29 +93,30 @@ class PresenceHandler(object):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.wheel_timer = WheelTimer()
|
self.wheel_timer = WheelTimer()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.replication = hs.get_replication_layer()
|
|
||||||
self.federation = hs.get_federation_sender()
|
self.federation = hs.get_federation_sender()
|
||||||
|
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
self.replication.register_edu_handler(
|
federation_registry = hs.get_federation_registry()
|
||||||
|
|
||||||
|
federation_registry.register_edu_handler(
|
||||||
"m.presence", self.incoming_presence
|
"m.presence", self.incoming_presence
|
||||||
)
|
)
|
||||||
self.replication.register_edu_handler(
|
federation_registry.register_edu_handler(
|
||||||
"m.presence_invite",
|
"m.presence_invite",
|
||||||
lambda origin, content: self.invite_presence(
|
lambda origin, content: self.invite_presence(
|
||||||
observed_user=UserID.from_string(content["observed_user"]),
|
observed_user=UserID.from_string(content["observed_user"]),
|
||||||
observer_user=UserID.from_string(content["observer_user"]),
|
observer_user=UserID.from_string(content["observer_user"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.replication.register_edu_handler(
|
federation_registry.register_edu_handler(
|
||||||
"m.presence_accept",
|
"m.presence_accept",
|
||||||
lambda origin, content: self.accept_presence(
|
lambda origin, content: self.accept_presence(
|
||||||
observed_user=UserID.from_string(content["observed_user"]),
|
observed_user=UserID.from_string(content["observed_user"]),
|
||||||
observer_user=UserID.from_string(content["observer_user"]),
|
observer_user=UserID.from_string(content["observer_user"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.replication.register_edu_handler(
|
federation_registry.register_edu_handler(
|
||||||
"m.presence_deny",
|
"m.presence_deny",
|
||||||
lambda origin, content: self.deny_presence(
|
lambda origin, content: self.deny_presence(
|
||||||
observed_user=UserID.from_string(content["observed_user"]),
|
observed_user=UserID.from_string(content["observed_user"]),
|
||||||
|
|
|
@ -31,14 +31,17 @@ class ProfileHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ProfileHandler, self).__init__(hs)
|
super(ProfileHandler, self).__init__(hs)
|
||||||
|
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_federation_client()
|
||||||
self.federation.register_query_handler(
|
hs.get_federation_registry().register_query_handler(
|
||||||
"profile", self.on_profile_query
|
"profile", self.on_profile_query
|
||||||
)
|
)
|
||||||
|
|
||||||
self.user_directory_handler = hs.get_user_directory_handler()
|
self.user_directory_handler = hs.get_user_directory_handler()
|
||||||
|
|
||||||
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
|
if hs.config.worker_app is None:
|
||||||
|
self.clock.looping_call(
|
||||||
|
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_profile(self, user_id):
|
def get_profile(self, user_id):
|
||||||
|
@ -233,7 +236,7 @@ class ProfileHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
handler = self.hs.get_handlers().room_member_handler
|
handler = self.hs.get_room_member_handler()
|
||||||
try:
|
try:
|
||||||
# Assume the target_user isn't a guest,
|
# Assume the target_user isn't a guest,
|
||||||
# because we don't let guests set profile or avatar data.
|
# because we don't let guests set profile or avatar data.
|
||||||
|
|
|
@ -41,9 +41,9 @@ class ReadMarkerHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with (yield self.read_marker_linearizer.queue((room_id, user_id))):
|
with (yield self.read_marker_linearizer.queue((room_id, user_id))):
|
||||||
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
|
existing_read_marker = yield self.store.get_account_data_for_room_and_type(
|
||||||
|
user_id, room_id, "m.fully_read",
|
||||||
existing_read_marker = account_data.get("m.fully_read", None)
|
)
|
||||||
|
|
||||||
should_update = True
|
should_update = True
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ class ReceiptsHandler(BaseHandler):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.federation = hs.get_federation_sender()
|
self.federation = hs.get_federation_sender()
|
||||||
hs.get_replication_layer().register_edu_handler(
|
hs.get_federation_registry().register_edu_handler(
|
||||||
"m.receipt", self._received_remote_receipt
|
"m.receipt", self._received_remote_receipt
|
||||||
)
|
)
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
|
|
|
@ -24,7 +24,7 @@ from synapse.api.errors import (
|
||||||
from synapse.http.client import CaptchaServerHttpClient
|
from synapse.http.client import CaptchaServerHttpClient
|
||||||
from synapse import types
|
from synapse import types
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor, Linearizer
|
||||||
from synapse.util.threepids import check_3pid_allowed
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
@ -46,6 +46,10 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
|
self._generate_user_id_linearizer = Linearizer(
|
||||||
|
name="_generate_user_id_linearizer",
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_username(self, localpart, guest_access_token=None,
|
def check_username(self, localpart, guest_access_token=None,
|
||||||
assigned_user_id=None):
|
assigned_user_id=None):
|
||||||
|
@ -345,9 +349,11 @@ class RegistrationHandler(BaseHandler):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _generate_user_id(self, reseed=False):
|
def _generate_user_id(self, reseed=False):
|
||||||
if reseed or self._next_generated_user_id is None:
|
if reseed or self._next_generated_user_id is None:
|
||||||
self._next_generated_user_id = (
|
with (yield self._generate_user_id_linearizer.queue(())):
|
||||||
yield self.store.find_next_generated_user_id_localpart()
|
if reseed or self._next_generated_user_id is None:
|
||||||
)
|
self._next_generated_user_id = (
|
||||||
|
yield self.store.find_next_generated_user_id_localpart()
|
||||||
|
)
|
||||||
|
|
||||||
id = self._next_generated_user_id
|
id = self._next_generated_user_id
|
||||||
self._next_generated_user_id += 1
|
self._next_generated_user_id += 1
|
||||||
|
@ -446,16 +452,34 @@ class RegistrationHandler(BaseHandler):
|
||||||
return self.hs.get_auth_handler()
|
return self.hs.get_auth_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def guest_access_token_for(self, medium, address, inviter_user_id):
|
def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
|
||||||
|
"""Get a guest access token for a 3PID, creating a guest account if
|
||||||
|
one doesn't already exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
medium (str)
|
||||||
|
address (str)
|
||||||
|
inviter_user_id (str): The user ID who is trying to invite the
|
||||||
|
3PID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
|
||||||
|
3PID guest account.
|
||||||
|
"""
|
||||||
access_token = yield self.store.get_3pid_guest_access_token(medium, address)
|
access_token = yield self.store.get_3pid_guest_access_token(medium, address)
|
||||||
if access_token:
|
if access_token:
|
||||||
defer.returnValue(access_token)
|
user_info = yield self.auth.get_user_by_access_token(
|
||||||
|
access_token
|
||||||
|
)
|
||||||
|
|
||||||
_, access_token = yield self.register(
|
defer.returnValue((user_info["user"].to_string(), access_token))
|
||||||
|
|
||||||
|
user_id, access_token = yield self.register(
|
||||||
generate_token=True,
|
generate_token=True,
|
||||||
make_guest=True
|
make_guest=True
|
||||||
)
|
)
|
||||||
access_token = yield self.store.save_or_get_3pid_guest_access_token(
|
access_token = yield self.store.save_or_get_3pid_guest_access_token(
|
||||||
medium, address, access_token, inviter_user_id
|
medium, address, access_token, inviter_user_id
|
||||||
)
|
)
|
||||||
defer.returnValue(access_token)
|
|
||||||
|
defer.returnValue((user_id, access_token))
|
||||||
|
|
|
@ -165,7 +165,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
|
|
||||||
creation_content = config.get("creation_content", {})
|
creation_content = config.get("creation_content", {})
|
||||||
|
|
||||||
room_member_handler = self.hs.get_handlers().room_member_handler
|
room_member_handler = self.hs.get_room_member_handler()
|
||||||
|
|
||||||
yield self._send_events_for_new_room(
|
yield self._send_events_for_new_room(
|
||||||
requester,
|
requester,
|
||||||
|
@ -224,7 +224,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
id_server = invite_3pid["id_server"]
|
id_server = invite_3pid["id_server"]
|
||||||
address = invite_3pid["address"]
|
address = invite_3pid["address"]
|
||||||
medium = invite_3pid["medium"]
|
medium = invite_3pid["medium"]
|
||||||
yield self.hs.get_handlers().room_member_handler.do_3pid_invite(
|
yield self.hs.get_room_member_handler().do_3pid_invite(
|
||||||
room_id,
|
room_id,
|
||||||
requester.user,
|
requester.user,
|
||||||
medium,
|
medium,
|
||||||
|
@ -475,12 +475,9 @@ class RoomEventSource(object):
|
||||||
user.to_string()
|
user.to_string()
|
||||||
)
|
)
|
||||||
if app_service:
|
if app_service:
|
||||||
events, end_key = yield self.store.get_appservice_room_stream(
|
# We no longer support AS users using /sync directly.
|
||||||
service=app_service,
|
# See https://github.com/matrix-org/matrix-doc/issues/1144
|
||||||
from_key=from_key,
|
raise NotImplementedError()
|
||||||
to_key=to_key,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
room_events = yield self.store.get_membership_changes_for_user(
|
room_events = yield self.store.get_membership_changes_for_user(
|
||||||
user.to_string(), from_key, to_key
|
user.to_string(), from_key, to_key
|
||||||
|
|
|
@ -409,7 +409,7 @@ class RoomListHandler(BaseHandler):
|
||||||
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
|
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
|
||||||
search_filter=None, include_all_networks=False,
|
search_filter=None, include_all_networks=False,
|
||||||
third_party_instance_id=None,):
|
third_party_instance_id=None,):
|
||||||
repl_layer = self.hs.get_replication_layer()
|
repl_layer = self.hs.get_federation_client()
|
||||||
if search_filter:
|
if search_filter:
|
||||||
# We can't cache when asking for search
|
# We can't cache when asking for search
|
||||||
return repl_layer.get_public_rooms(
|
return repl_layer.get_public_rooms(
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import abc
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
@ -30,22 +30,32 @@ from synapse.api.errors import AuthError, SynapseError, Codes
|
||||||
from synapse.types import UserID, RoomID
|
from synapse.types import UserID, RoomID
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.distributor import user_left_room, user_joined_room
|
from synapse.util.distributor import user_left_room, user_joined_room
|
||||||
from ._base import BaseHandler
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
id_server_scheme = "https://"
|
id_server_scheme = "https://"
|
||||||
|
|
||||||
|
|
||||||
class RoomMemberHandler(BaseHandler):
|
class RoomMemberHandler(object):
|
||||||
# TODO(paul): This handler currently contains a messy conflation of
|
# TODO(paul): This handler currently contains a messy conflation of
|
||||||
# low-level API that works on UserID objects and so on, and REST-level
|
# low-level API that works on UserID objects and so on, and REST-level
|
||||||
# API that takes ID strings and returns pagination chunks. These concerns
|
# API that takes ID strings and returns pagination chunks. These concerns
|
||||||
# ought to be separated out a lot better.
|
# ought to be separated out a lot better.
|
||||||
|
|
||||||
def __init__(self, hs):
|
__metaclass__ = abc.ABCMeta
|
||||||
super(RoomMemberHandler, self).__init__(hs)
|
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.state_handler = hs.get_state_handler()
|
||||||
|
self.config = hs.config
|
||||||
|
self.simple_http_client = hs.get_simple_http_client()
|
||||||
|
|
||||||
|
self.federation_handler = hs.get_handlers().federation_handler
|
||||||
|
self.directory_handler = hs.get_handlers().directory_handler
|
||||||
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
self.profile_handler = hs.get_profile_handler()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
self.event_creation_hander = hs.get_event_creation_handler()
|
self.event_creation_hander = hs.get_event_creation_handler()
|
||||||
|
|
||||||
|
@ -54,9 +64,87 @@ class RoomMemberHandler(BaseHandler):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.spam_checker = hs.get_spam_checker()
|
self.spam_checker = hs.get_spam_checker()
|
||||||
|
|
||||||
self.distributor = hs.get_distributor()
|
@abc.abstractmethod
|
||||||
self.distributor.declare("user_joined_room")
|
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
|
||||||
self.distributor.declare("user_left_room")
|
"""Try and join a room that this server is not in
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requester (Requester)
|
||||||
|
remote_room_hosts (list[str]): List of servers that can be used
|
||||||
|
to join via.
|
||||||
|
room_id (str): Room that we are trying to join
|
||||||
|
user (UserID): User who is trying to join
|
||||||
|
content (dict): A dict that should be used as the content of the
|
||||||
|
join event.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _remote_reject_invite(self, remote_room_hosts, room_id, target):
|
||||||
|
"""Attempt to reject an invite for a room this server is not in. If we
|
||||||
|
fail to do so we locally mark the invite as rejected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requester (Requester)
|
||||||
|
remote_room_hosts (list[str]): List of servers to use to try and
|
||||||
|
reject invite
|
||||||
|
room_id (str)
|
||||||
|
target (UserID): The user rejecting the invite
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[dict]: A dictionary to be returned to the client, may
|
||||||
|
include event_id etc, or nothing if we locally rejected
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
|
||||||
|
"""Get a guest access token for a 3PID, creating a guest account if
|
||||||
|
one doesn't already exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requester (Requester)
|
||||||
|
medium (str)
|
||||||
|
address (str)
|
||||||
|
inviter_user_id (str): The user ID who is trying to invite the
|
||||||
|
3PID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
|
||||||
|
3PID guest account.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _user_joined_room(self, target, room_id):
|
||||||
|
"""Notifies distributor on master process that the user has joined the
|
||||||
|
room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (UserID)
|
||||||
|
room_id (str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred|None
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _user_left_room(self, target, room_id):
|
||||||
|
"""Notifies distributor on master process that the user has left the
|
||||||
|
room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (UserID)
|
||||||
|
room_id (str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred|None
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _local_membership_update(
|
def _local_membership_update(
|
||||||
|
@ -120,32 +208,15 @@ class RoomMemberHandler(BaseHandler):
|
||||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||||
if newly_joined:
|
if newly_joined:
|
||||||
yield user_joined_room(self.distributor, target, room_id)
|
yield self._user_joined_room(target, room_id)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
if prev_member_event_id:
|
if prev_member_event_id:
|
||||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
if prev_member_event.membership == Membership.JOIN:
|
if prev_member_event.membership == Membership.JOIN:
|
||||||
user_left_room(self.distributor, target, room_id)
|
yield self._user_left_room(target, room_id)
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def remote_join(self, remote_room_hosts, room_id, user, content):
|
|
||||||
if len(remote_room_hosts) == 0:
|
|
||||||
raise SynapseError(404, "No known servers")
|
|
||||||
|
|
||||||
# We don't do an auth check if we are doing an invite
|
|
||||||
# join dance for now, since we're kinda implicitly checking
|
|
||||||
# that we are allowed to join when we decide whether or not we
|
|
||||||
# need to do the invite/join dance.
|
|
||||||
yield self.hs.get_handlers().federation_handler.do_invite_join(
|
|
||||||
remote_room_hosts,
|
|
||||||
room_id,
|
|
||||||
user.to_string(),
|
|
||||||
content,
|
|
||||||
)
|
|
||||||
yield user_joined_room(self.distributor, user, room_id)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def update_membership(
|
def update_membership(
|
||||||
self,
|
self,
|
||||||
|
@ -204,8 +275,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
# if this is a join with a 3pid signature, we may need to turn a 3pid
|
# if this is a join with a 3pid signature, we may need to turn a 3pid
|
||||||
# invite into a normal invite before we can handle the join.
|
# invite into a normal invite before we can handle the join.
|
||||||
if third_party_signed is not None:
|
if third_party_signed is not None:
|
||||||
replication = self.hs.get_replication_layer()
|
yield self.federation_handler.exchange_third_party_invite(
|
||||||
yield replication.exchange_third_party_invite(
|
|
||||||
third_party_signed["sender"],
|
third_party_signed["sender"],
|
||||||
target.to_string(),
|
target.to_string(),
|
||||||
room_id,
|
room_id,
|
||||||
|
@ -226,7 +296,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
requester.user,
|
requester.user,
|
||||||
)
|
)
|
||||||
if not is_requester_admin:
|
if not is_requester_admin:
|
||||||
if self.hs.config.block_non_admin_invites:
|
if self.config.block_non_admin_invites:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Blocking invite: user is not admin and non-admin "
|
"Blocking invite: user is not admin and non-admin "
|
||||||
"invites disabled"
|
"invites disabled"
|
||||||
|
@ -285,7 +355,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
raise AuthError(403, "Guest access not allowed")
|
raise AuthError(403, "Guest access not allowed")
|
||||||
|
|
||||||
if not is_host_in_room:
|
if not is_host_in_room:
|
||||||
inviter = yield self.get_inviter(target.to_string(), room_id)
|
inviter = yield self._get_inviter(target.to_string(), room_id)
|
||||||
if inviter and not self.hs.is_mine(inviter):
|
if inviter and not self.hs.is_mine(inviter):
|
||||||
remote_room_hosts.append(inviter.domain)
|
remote_room_hosts.append(inviter.domain)
|
||||||
|
|
||||||
|
@ -299,15 +369,15 @@ class RoomMemberHandler(BaseHandler):
|
||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
content["kind"] = "guest"
|
content["kind"] = "guest"
|
||||||
|
|
||||||
ret = yield self.remote_join(
|
ret = yield self._remote_join(
|
||||||
remote_room_hosts, room_id, target, content
|
requester, remote_room_hosts, room_id, target, content
|
||||||
)
|
)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
elif effective_membership_state == Membership.LEAVE:
|
elif effective_membership_state == Membership.LEAVE:
|
||||||
if not is_host_in_room:
|
if not is_host_in_room:
|
||||||
# perhaps we've been invited
|
# perhaps we've been invited
|
||||||
inviter = yield self.get_inviter(target.to_string(), room_id)
|
inviter = yield self._get_inviter(target.to_string(), room_id)
|
||||||
if not inviter:
|
if not inviter:
|
||||||
raise SynapseError(404, "Not a known room")
|
raise SynapseError(404, "Not a known room")
|
||||||
|
|
||||||
|
@ -321,28 +391,10 @@ class RoomMemberHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
# send the rejection to the inviter's HS.
|
# send the rejection to the inviter's HS.
|
||||||
remote_room_hosts = remote_room_hosts + [inviter.domain]
|
remote_room_hosts = remote_room_hosts + [inviter.domain]
|
||||||
fed_handler = self.hs.get_handlers().federation_handler
|
res = yield self._remote_reject_invite(
|
||||||
try:
|
requester, remote_room_hosts, room_id, target,
|
||||||
ret = yield fed_handler.do_remotely_reject_invite(
|
)
|
||||||
remote_room_hosts,
|
defer.returnValue(res)
|
||||||
room_id,
|
|
||||||
target.to_string(),
|
|
||||||
)
|
|
||||||
defer.returnValue(ret)
|
|
||||||
except Exception as e:
|
|
||||||
# if we were unable to reject the exception, just mark
|
|
||||||
# it as rejected on our end and plough ahead.
|
|
||||||
#
|
|
||||||
# The 'except' clause is very broad, but we need to
|
|
||||||
# capture everything from DNS failures upwards
|
|
||||||
#
|
|
||||||
logger.warn("Failed to reject invite: %s", e)
|
|
||||||
|
|
||||||
yield self.store.locally_reject_invite(
|
|
||||||
target.to_string(), room_id
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue({})
|
|
||||||
|
|
||||||
res = yield self._local_membership_update(
|
res = yield self._local_membership_update(
|
||||||
requester=requester,
|
requester=requester,
|
||||||
|
@ -438,12 +490,12 @@ class RoomMemberHandler(BaseHandler):
|
||||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||||
if newly_joined:
|
if newly_joined:
|
||||||
yield user_joined_room(self.distributor, target_user, room_id)
|
yield self._user_joined_room(target_user, room_id)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
if prev_member_event_id:
|
if prev_member_event_id:
|
||||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
if prev_member_event.membership == Membership.JOIN:
|
if prev_member_event.membership == Membership.JOIN:
|
||||||
user_left_room(self.distributor, target_user, room_id)
|
yield self._user_left_room(target_user, room_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _can_guest_join(self, current_state_ids):
|
def _can_guest_join(self, current_state_ids):
|
||||||
|
@ -477,7 +529,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if room alias could not be found.
|
SynapseError if room alias could not be found.
|
||||||
"""
|
"""
|
||||||
directory_handler = self.hs.get_handlers().directory_handler
|
directory_handler = self.directory_handler
|
||||||
mapping = yield directory_handler.get_association(room_alias)
|
mapping = yield directory_handler.get_association(room_alias)
|
||||||
|
|
||||||
if not mapping:
|
if not mapping:
|
||||||
|
@ -489,7 +541,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
defer.returnValue((RoomID.from_string(room_id), servers))
|
defer.returnValue((RoomID.from_string(room_id), servers))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_inviter(self, user_id, room_id):
|
def _get_inviter(self, user_id, room_id):
|
||||||
invite = yield self.store.get_invite_for_user_in_room(
|
invite = yield self.store.get_invite_for_user_in_room(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
@ -508,7 +560,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
requester,
|
requester,
|
||||||
txn_id
|
txn_id
|
||||||
):
|
):
|
||||||
if self.hs.config.block_non_admin_invites:
|
if self.config.block_non_admin_invites:
|
||||||
is_requester_admin = yield self.auth.is_server_admin(
|
is_requester_admin = yield self.auth.is_server_admin(
|
||||||
requester.user,
|
requester.user,
|
||||||
)
|
)
|
||||||
|
@ -555,7 +607,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
str: the matrix ID of the 3pid, or None if it is not recognized.
|
str: the matrix ID of the 3pid, or None if it is not recognized.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
data = yield self.hs.get_simple_http_client().get_json(
|
data = yield self.simple_http_client.get_json(
|
||||||
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
|
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
|
||||||
{
|
{
|
||||||
"medium": medium,
|
"medium": medium,
|
||||||
|
@ -566,7 +618,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
if "mxid" in data:
|
if "mxid" in data:
|
||||||
if "signatures" not in data:
|
if "signatures" not in data:
|
||||||
raise AuthError(401, "No signatures on 3pid binding")
|
raise AuthError(401, "No signatures on 3pid binding")
|
||||||
self.verify_any_signature(data, id_server)
|
yield self._verify_any_signature(data, id_server)
|
||||||
defer.returnValue(data["mxid"])
|
defer.returnValue(data["mxid"])
|
||||||
|
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
|
@ -574,11 +626,11 @@ class RoomMemberHandler(BaseHandler):
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def verify_any_signature(self, data, server_hostname):
|
def _verify_any_signature(self, data, server_hostname):
|
||||||
if server_hostname not in data["signatures"]:
|
if server_hostname not in data["signatures"]:
|
||||||
raise AuthError(401, "No signature from server %s" % (server_hostname,))
|
raise AuthError(401, "No signature from server %s" % (server_hostname,))
|
||||||
for key_name, signature in data["signatures"][server_hostname].items():
|
for key_name, signature in data["signatures"][server_hostname].items():
|
||||||
key_data = yield self.hs.get_simple_http_client().get_json(
|
key_data = yield self.simple_http_client.get_json(
|
||||||
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
|
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
|
||||||
(id_server_scheme, server_hostname, key_name,),
|
(id_server_scheme, server_hostname, key_name,),
|
||||||
)
|
)
|
||||||
|
@ -603,7 +655,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
user,
|
user,
|
||||||
txn_id
|
txn_id
|
||||||
):
|
):
|
||||||
room_state = yield self.hs.get_state_handler().get_current_state(room_id)
|
room_state = yield self.state_handler.get_current_state(room_id)
|
||||||
|
|
||||||
inviter_display_name = ""
|
inviter_display_name = ""
|
||||||
inviter_avatar_url = ""
|
inviter_avatar_url = ""
|
||||||
|
@ -634,6 +686,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
|
|
||||||
token, public_keys, fallback_public_key, display_name = (
|
token, public_keys, fallback_public_key, display_name = (
|
||||||
yield self._ask_id_server_for_third_party_invite(
|
yield self._ask_id_server_for_third_party_invite(
|
||||||
|
requester=requester,
|
||||||
id_server=id_server,
|
id_server=id_server,
|
||||||
medium=medium,
|
medium=medium,
|
||||||
address=address,
|
address=address,
|
||||||
|
@ -670,6 +723,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _ask_id_server_for_third_party_invite(
|
def _ask_id_server_for_third_party_invite(
|
||||||
self,
|
self,
|
||||||
|
requester,
|
||||||
id_server,
|
id_server,
|
||||||
medium,
|
medium,
|
||||||
address,
|
address,
|
||||||
|
@ -686,6 +740,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
Asks an identity server for a third party invite.
|
Asks an identity server for a third party invite.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
requester (Requester)
|
||||||
id_server (str): hostname + optional port for the identity server.
|
id_server (str): hostname + optional port for the identity server.
|
||||||
medium (str): The literal string "email".
|
medium (str): The literal string "email".
|
||||||
address (str): The third party address being invited.
|
address (str): The third party address being invited.
|
||||||
|
@ -727,24 +782,20 @@ class RoomMemberHandler(BaseHandler):
|
||||||
"sender_avatar_url": inviter_avatar_url,
|
"sender_avatar_url": inviter_avatar_url,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.hs.config.invite_3pid_guest:
|
if self.config.invite_3pid_guest:
|
||||||
registration_handler = self.hs.get_handlers().registration_handler
|
guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
|
||||||
guest_access_token = yield registration_handler.guest_access_token_for(
|
requester=requester,
|
||||||
medium=medium,
|
medium=medium,
|
||||||
address=address,
|
address=address,
|
||||||
inviter_user_id=inviter_user_id,
|
inviter_user_id=inviter_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
|
|
||||||
guest_access_token
|
|
||||||
)
|
|
||||||
|
|
||||||
invite_config.update({
|
invite_config.update({
|
||||||
"guest_access_token": guest_access_token,
|
"guest_access_token": guest_access_token,
|
||||||
"guest_user_id": guest_user_info["user"].to_string(),
|
"guest_user_id": guest_user_id,
|
||||||
})
|
})
|
||||||
|
|
||||||
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
|
data = yield self.simple_http_client.post_urlencoded_get_json(
|
||||||
is_url,
|
is_url,
|
||||||
invite_config
|
invite_config
|
||||||
)
|
)
|
||||||
|
@ -766,27 +817,6 @@ class RoomMemberHandler(BaseHandler):
|
||||||
display_name = data["display_name"]
|
display_name = data["display_name"]
|
||||||
defer.returnValue((token, public_keys, fallback_public_key, display_name))
|
defer.returnValue((token, public_keys, fallback_public_key, display_name))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def forget(self, user, room_id):
|
|
||||||
user_id = user.to_string()
|
|
||||||
|
|
||||||
member = yield self.state_handler.get_current_state(
|
|
||||||
room_id=room_id,
|
|
||||||
event_type=EventTypes.Member,
|
|
||||||
state_key=user_id
|
|
||||||
)
|
|
||||||
membership = member.membership if member else None
|
|
||||||
|
|
||||||
if membership is not None and membership not in [
|
|
||||||
Membership.LEAVE, Membership.BAN
|
|
||||||
]:
|
|
||||||
raise SynapseError(400, "User %s in room %s" % (
|
|
||||||
user_id, room_id
|
|
||||||
))
|
|
||||||
|
|
||||||
if membership:
|
|
||||||
yield self.store.forget(user_id, room_id)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _is_host_in_room(self, current_state_ids):
|
def _is_host_in_room(self, current_state_ids):
|
||||||
# Have we just created the room, and is this about to be the very
|
# Have we just created the room, and is this about to be the very
|
||||||
|
@ -808,3 +838,94 @@ class RoomMemberHandler(BaseHandler):
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
|
||||||
|
class RoomMemberMasterHandler(RoomMemberHandler):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomMemberMasterHandler, self).__init__(hs)
|
||||||
|
|
||||||
|
self.distributor = hs.get_distributor()
|
||||||
|
self.distributor.declare("user_joined_room")
|
||||||
|
self.distributor.declare("user_left_room")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
|
||||||
|
"""Implements RoomMemberHandler._remote_join
|
||||||
|
"""
|
||||||
|
if len(remote_room_hosts) == 0:
|
||||||
|
raise SynapseError(404, "No known servers")
|
||||||
|
|
||||||
|
# We don't do an auth check if we are doing an invite
|
||||||
|
# join dance for now, since we're kinda implicitly checking
|
||||||
|
# that we are allowed to join when we decide whether or not we
|
||||||
|
# need to do the invite/join dance.
|
||||||
|
yield self.federation_handler.do_invite_join(
|
||||||
|
remote_room_hosts,
|
||||||
|
room_id,
|
||||||
|
user.to_string(),
|
||||||
|
content,
|
||||||
|
)
|
||||||
|
yield self._user_joined_room(user, room_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
|
||||||
|
"""Implements RoomMemberHandler._remote_reject_invite
|
||||||
|
"""
|
||||||
|
fed_handler = self.federation_handler
|
||||||
|
try:
|
||||||
|
ret = yield fed_handler.do_remotely_reject_invite(
|
||||||
|
remote_room_hosts,
|
||||||
|
room_id,
|
||||||
|
target.to_string(),
|
||||||
|
)
|
||||||
|
defer.returnValue(ret)
|
||||||
|
except Exception as e:
|
||||||
|
# if we were unable to reject the exception, just mark
|
||||||
|
# it as rejected on our end and plough ahead.
|
||||||
|
#
|
||||||
|
# The 'except' clause is very broad, but we need to
|
||||||
|
# capture everything from DNS failures upwards
|
||||||
|
#
|
||||||
|
logger.warn("Failed to reject invite: %s", e)
|
||||||
|
|
||||||
|
yield self.store.locally_reject_invite(
|
||||||
|
target.to_string(), room_id
|
||||||
|
)
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
|
||||||
|
"""Implements RoomMemberHandler.get_or_register_3pid_guest
|
||||||
|
"""
|
||||||
|
rg = self.registration_handler
|
||||||
|
return rg.get_or_register_3pid_guest(medium, address, inviter_user_id)
|
||||||
|
|
||||||
|
def _user_joined_room(self, target, room_id):
|
||||||
|
"""Implements RoomMemberHandler._user_joined_room
|
||||||
|
"""
|
||||||
|
return user_joined_room(self.distributor, target, room_id)
|
||||||
|
|
||||||
|
def _user_left_room(self, target, room_id):
|
||||||
|
"""Implements RoomMemberHandler._user_left_room
|
||||||
|
"""
|
||||||
|
return user_left_room(self.distributor, target, room_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def forget(self, user, room_id):
|
||||||
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
member = yield self.state_handler.get_current_state(
|
||||||
|
room_id=room_id,
|
||||||
|
event_type=EventTypes.Member,
|
||||||
|
state_key=user_id
|
||||||
|
)
|
||||||
|
membership = member.membership if member else None
|
||||||
|
|
||||||
|
if membership is not None and membership not in [
|
||||||
|
Membership.LEAVE, Membership.BAN
|
||||||
|
]:
|
||||||
|
raise SynapseError(400, "User %s in room %s" % (
|
||||||
|
user_id, room_id
|
||||||
|
))
|
||||||
|
|
||||||
|
if membership:
|
||||||
|
yield self.store.forget(user_id, room_id)
|
||||||
|
|
102
synapse/handlers/room_member_worker.py
Normal file
102
synapse/handlers/room_member_worker.py
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.handlers.room_member import RoomMemberHandler
|
||||||
|
from synapse.replication.http.membership import (
|
||||||
|
remote_join, remote_reject_invite, get_or_register_3pid_guest,
|
||||||
|
notify_user_membership_change,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RoomMemberWorkerHandler(RoomMemberHandler):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
|
||||||
|
"""Implements RoomMemberHandler._remote_join
|
||||||
|
"""
|
||||||
|
if len(remote_room_hosts) == 0:
|
||||||
|
raise SynapseError(404, "No known servers")
|
||||||
|
|
||||||
|
ret = yield remote_join(
|
||||||
|
self.simple_http_client,
|
||||||
|
host=self.config.worker_replication_host,
|
||||||
|
port=self.config.worker_replication_http_port,
|
||||||
|
requester=requester,
|
||||||
|
remote_room_hosts=remote_room_hosts,
|
||||||
|
room_id=room_id,
|
||||||
|
user_id=user.to_string(),
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self._user_joined_room(user, room_id)
|
||||||
|
|
||||||
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
|
||||||
|
"""Implements RoomMemberHandler._remote_reject_invite
|
||||||
|
"""
|
||||||
|
return remote_reject_invite(
|
||||||
|
self.simple_http_client,
|
||||||
|
host=self.config.worker_replication_host,
|
||||||
|
port=self.config.worker_replication_http_port,
|
||||||
|
requester=requester,
|
||||||
|
remote_room_hosts=remote_room_hosts,
|
||||||
|
room_id=room_id,
|
||||||
|
user_id=target.to_string(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _user_joined_room(self, target, room_id):
|
||||||
|
"""Implements RoomMemberHandler._user_joined_room
|
||||||
|
"""
|
||||||
|
return notify_user_membership_change(
|
||||||
|
self.simple_http_client,
|
||||||
|
host=self.config.worker_replication_host,
|
||||||
|
port=self.config.worker_replication_http_port,
|
||||||
|
user_id=target.to_string(),
|
||||||
|
room_id=room_id,
|
||||||
|
change="joined",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _user_left_room(self, target, room_id):
|
||||||
|
"""Implements RoomMemberHandler._user_left_room
|
||||||
|
"""
|
||||||
|
return notify_user_membership_change(
|
||||||
|
self.simple_http_client,
|
||||||
|
host=self.config.worker_replication_host,
|
||||||
|
port=self.config.worker_replication_http_port,
|
||||||
|
user_id=target.to_string(),
|
||||||
|
room_id=room_id,
|
||||||
|
change="left",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
|
||||||
|
"""Implements RoomMemberHandler.get_or_register_3pid_guest
|
||||||
|
"""
|
||||||
|
return get_or_register_3pid_guest(
|
||||||
|
self.simple_http_client,
|
||||||
|
host=self.config.worker_replication_host,
|
||||||
|
port=self.config.worker_replication_http_port,
|
||||||
|
requester=requester,
|
||||||
|
medium=medium,
|
||||||
|
address=address,
|
||||||
|
inviter_user_id=inviter_user_id,
|
||||||
|
)
|
|
@ -235,10 +235,10 @@ class SyncHandler(object):
|
||||||
defer.returnValue(rules)
|
defer.returnValue(rules)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def ephemeral_by_room(self, sync_config, now_token, since_token=None):
|
def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
|
||||||
"""Get the ephemeral events for each room the user is in
|
"""Get the ephemeral events for each room the user is in
|
||||||
Args:
|
Args:
|
||||||
sync_config (SyncConfig): The flags, filters and user for the sync.
|
sync_result_builder(SyncResultBuilder)
|
||||||
now_token (StreamToken): Where the server is currently up to.
|
now_token (StreamToken): Where the server is currently up to.
|
||||||
since_token (StreamToken): Where the server was when the client
|
since_token (StreamToken): Where the server was when the client
|
||||||
last synced.
|
last synced.
|
||||||
|
@ -248,10 +248,12 @@ class SyncHandler(object):
|
||||||
typing events for that room.
|
typing events for that room.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
sync_config = sync_result_builder.sync_config
|
||||||
|
|
||||||
with Measure(self.clock, "ephemeral_by_room"):
|
with Measure(self.clock, "ephemeral_by_room"):
|
||||||
typing_key = since_token.typing_key if since_token else "0"
|
typing_key = since_token.typing_key if since_token else "0"
|
||||||
|
|
||||||
room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string())
|
room_ids = sync_result_builder.joined_room_ids
|
||||||
|
|
||||||
typing_source = self.event_sources.sources["typing"]
|
typing_source = self.event_sources.sources["typing"]
|
||||||
typing, typing_key = yield typing_source.get_new_events(
|
typing, typing_key = yield typing_source.get_new_events(
|
||||||
|
@ -565,10 +567,22 @@ class SyncHandler(object):
|
||||||
# Always use the `now_token` in `SyncResultBuilder`
|
# Always use the `now_token` in `SyncResultBuilder`
|
||||||
now_token = yield self.event_sources.get_current_token()
|
now_token = yield self.event_sources.get_current_token()
|
||||||
|
|
||||||
|
user_id = sync_config.user.to_string()
|
||||||
|
app_service = self.store.get_app_service_by_user_id(user_id)
|
||||||
|
if app_service:
|
||||||
|
# We no longer support AS users using /sync directly.
|
||||||
|
# See https://github.com/matrix-org/matrix-doc/issues/1144
|
||||||
|
raise NotImplementedError()
|
||||||
|
else:
|
||||||
|
joined_room_ids = yield self.get_rooms_for_user_at(
|
||||||
|
user_id, now_token.room_stream_id,
|
||||||
|
)
|
||||||
|
|
||||||
sync_result_builder = SyncResultBuilder(
|
sync_result_builder = SyncResultBuilder(
|
||||||
sync_config, full_state,
|
sync_config, full_state,
|
||||||
since_token=since_token,
|
since_token=since_token,
|
||||||
now_token=now_token,
|
now_token=now_token,
|
||||||
|
joined_room_ids=joined_room_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
account_data_by_room = yield self._generate_sync_entry_for_account_data(
|
account_data_by_room = yield self._generate_sync_entry_for_account_data(
|
||||||
|
@ -603,7 +617,6 @@ class SyncHandler(object):
|
||||||
device_id = sync_config.device_id
|
device_id = sync_config.device_id
|
||||||
one_time_key_counts = {}
|
one_time_key_counts = {}
|
||||||
if device_id:
|
if device_id:
|
||||||
user_id = sync_config.user.to_string()
|
|
||||||
one_time_key_counts = yield self.store.count_e2e_one_time_keys(
|
one_time_key_counts = yield self.store.count_e2e_one_time_keys(
|
||||||
user_id, device_id
|
user_id, device_id
|
||||||
)
|
)
|
||||||
|
@ -891,7 +904,7 @@ class SyncHandler(object):
|
||||||
ephemeral_by_room = {}
|
ephemeral_by_room = {}
|
||||||
else:
|
else:
|
||||||
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
|
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
|
||||||
sync_result_builder.sync_config,
|
sync_result_builder,
|
||||||
now_token=sync_result_builder.now_token,
|
now_token=sync_result_builder.now_token,
|
||||||
since_token=sync_result_builder.since_token,
|
since_token=sync_result_builder.since_token,
|
||||||
)
|
)
|
||||||
|
@ -996,15 +1009,8 @@ class SyncHandler(object):
|
||||||
if rooms_changed:
|
if rooms_changed:
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
|
||||||
app_service = self.store.get_app_service_by_user_id(user_id)
|
|
||||||
if app_service:
|
|
||||||
rooms = yield self.store.get_app_service_rooms(app_service)
|
|
||||||
joined_room_ids = set(r.room_id for r in rooms)
|
|
||||||
else:
|
|
||||||
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
|
|
||||||
|
|
||||||
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
|
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
|
||||||
for room_id in joined_room_ids:
|
for room_id in sync_result_builder.joined_room_ids:
|
||||||
if self.store.has_room_changed_since(room_id, stream_id):
|
if self.store.has_room_changed_since(room_id, stream_id):
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
@ -1028,13 +1034,6 @@ class SyncHandler(object):
|
||||||
|
|
||||||
assert since_token
|
assert since_token
|
||||||
|
|
||||||
app_service = self.store.get_app_service_by_user_id(user_id)
|
|
||||||
if app_service:
|
|
||||||
rooms = yield self.store.get_app_service_rooms(app_service)
|
|
||||||
joined_room_ids = set(r.room_id for r in rooms)
|
|
||||||
else:
|
|
||||||
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
|
|
||||||
|
|
||||||
# Get a list of membership change events that have happened.
|
# Get a list of membership change events that have happened.
|
||||||
rooms_changed = yield self.store.get_membership_changes_for_user(
|
rooms_changed = yield self.store.get_membership_changes_for_user(
|
||||||
user_id, since_token.room_key, now_token.room_key
|
user_id, since_token.room_key, now_token.room_key
|
||||||
|
@ -1057,7 +1056,7 @@ class SyncHandler(object):
|
||||||
# we do send down the room, and with full state, where necessary
|
# we do send down the room, and with full state, where necessary
|
||||||
|
|
||||||
old_state_ids = None
|
old_state_ids = None
|
||||||
if room_id in joined_room_ids and non_joins:
|
if room_id in sync_result_builder.joined_room_ids and non_joins:
|
||||||
# Always include if the user (re)joined the room, especially
|
# Always include if the user (re)joined the room, especially
|
||||||
# important so that device list changes are calculated correctly.
|
# important so that device list changes are calculated correctly.
|
||||||
# If there are non join member events, but we are still in the room,
|
# If there are non join member events, but we are still in the room,
|
||||||
|
@ -1067,7 +1066,7 @@ class SyncHandler(object):
|
||||||
# User is in the room so we don't need to do the invite/leave checks
|
# User is in the room so we don't need to do the invite/leave checks
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if room_id in joined_room_ids or has_join:
|
if room_id in sync_result_builder.joined_room_ids or has_join:
|
||||||
old_state_ids = yield self.get_state_at(room_id, since_token)
|
old_state_ids = yield self.get_state_at(room_id, since_token)
|
||||||
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
|
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
|
||||||
old_mem_ev = None
|
old_mem_ev = None
|
||||||
|
@ -1079,7 +1078,7 @@ class SyncHandler(object):
|
||||||
newly_joined_rooms.append(room_id)
|
newly_joined_rooms.append(room_id)
|
||||||
|
|
||||||
# If user is in the room then we don't need to do the invite/leave checks
|
# If user is in the room then we don't need to do the invite/leave checks
|
||||||
if room_id in joined_room_ids:
|
if room_id in sync_result_builder.joined_room_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not non_joins:
|
if not non_joins:
|
||||||
|
@ -1146,7 +1145,7 @@ class SyncHandler(object):
|
||||||
|
|
||||||
# Get all events for rooms we're currently joined to.
|
# Get all events for rooms we're currently joined to.
|
||||||
room_to_events = yield self.store.get_room_events_stream_for_rooms(
|
room_to_events = yield self.store.get_room_events_stream_for_rooms(
|
||||||
room_ids=joined_room_ids,
|
room_ids=sync_result_builder.joined_room_ids,
|
||||||
from_key=since_token.room_key,
|
from_key=since_token.room_key,
|
||||||
to_key=now_token.room_key,
|
to_key=now_token.room_key,
|
||||||
limit=timeline_limit + 1,
|
limit=timeline_limit + 1,
|
||||||
|
@ -1154,7 +1153,7 @@ class SyncHandler(object):
|
||||||
|
|
||||||
# We loop through all room ids, even if there are no new events, in case
|
# We loop through all room ids, even if there are no new events, in case
|
||||||
# there are non room events taht we need to notify about.
|
# there are non room events taht we need to notify about.
|
||||||
for room_id in joined_room_ids:
|
for room_id in sync_result_builder.joined_room_ids:
|
||||||
room_entry = room_to_events.get(room_id, None)
|
room_entry = room_to_events.get(room_id, None)
|
||||||
|
|
||||||
if room_entry:
|
if room_entry:
|
||||||
|
@ -1362,6 +1361,54 @@ class SyncHandler(object):
|
||||||
else:
|
else:
|
||||||
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
|
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_rooms_for_user_at(self, user_id, stream_ordering):
|
||||||
|
"""Get set of joined rooms for a user at the given stream ordering.
|
||||||
|
|
||||||
|
The stream ordering *must* be recent, otherwise this may throw an
|
||||||
|
exception if older than a month. (This function is called with the
|
||||||
|
current token, which should be perfectly fine).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str)
|
||||||
|
stream_ordering (int)
|
||||||
|
|
||||||
|
ReturnValue:
|
||||||
|
Deferred[frozenset[str]]: Set of room_ids the user is in at given
|
||||||
|
stream_ordering.
|
||||||
|
"""
|
||||||
|
joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
joined_room_ids = set()
|
||||||
|
|
||||||
|
# We need to check that the stream ordering of the join for each room
|
||||||
|
# is before the stream_ordering asked for. This might not be the case
|
||||||
|
# if the user joins a room between us getting the current token and
|
||||||
|
# calling `get_rooms_for_user_with_stream_ordering`.
|
||||||
|
# If the membership's stream ordering is after the given stream
|
||||||
|
# ordering, we need to go and work out if the user was in the room
|
||||||
|
# before.
|
||||||
|
for room_id, membership_stream_ordering in joined_rooms:
|
||||||
|
if membership_stream_ordering <= stream_ordering:
|
||||||
|
joined_room_ids.add(room_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info("User joined room after current token: %s", room_id)
|
||||||
|
|
||||||
|
extrems = yield self.store.get_forward_extremeties_for_room(
|
||||||
|
room_id, stream_ordering,
|
||||||
|
)
|
||||||
|
users_in_room = yield self.state.get_current_user_in_room(
|
||||||
|
room_id, extrems,
|
||||||
|
)
|
||||||
|
if user_id in users_in_room:
|
||||||
|
joined_room_ids.add(room_id)
|
||||||
|
|
||||||
|
joined_room_ids = frozenset(joined_room_ids)
|
||||||
|
defer.returnValue(joined_room_ids)
|
||||||
|
|
||||||
|
|
||||||
def _action_has_highlight(actions):
|
def _action_has_highlight(actions):
|
||||||
for action in actions:
|
for action in actions:
|
||||||
|
@ -1411,7 +1458,8 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
|
||||||
|
|
||||||
class SyncResultBuilder(object):
|
class SyncResultBuilder(object):
|
||||||
"Used to help build up a new SyncResult for a user"
|
"Used to help build up a new SyncResult for a user"
|
||||||
def __init__(self, sync_config, full_state, since_token, now_token):
|
def __init__(self, sync_config, full_state, since_token, now_token,
|
||||||
|
joined_room_ids):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
sync_config(SyncConfig)
|
sync_config(SyncConfig)
|
||||||
|
@ -1423,6 +1471,7 @@ class SyncResultBuilder(object):
|
||||||
self.full_state = full_state
|
self.full_state = full_state
|
||||||
self.since_token = since_token
|
self.since_token = since_token
|
||||||
self.now_token = now_token
|
self.now_token = now_token
|
||||||
|
self.joined_room_ids = joined_room_ids
|
||||||
|
|
||||||
self.presence = []
|
self.presence = []
|
||||||
self.account_data = []
|
self.account_data = []
|
||||||
|
|
|
@ -56,7 +56,7 @@ class TypingHandler(object):
|
||||||
|
|
||||||
self.federation = hs.get_federation_sender()
|
self.federation = hs.get_federation_sender()
|
||||||
|
|
||||||
hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
|
hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
|
||||||
|
|
||||||
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -36,7 +37,7 @@ from twisted.web.util import redirectTo
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
import ujson
|
import simplejson
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -59,6 +60,11 @@ response_count = metrics.register_counter(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
requests_counter = metrics.register_counter(
|
||||||
|
"requests_received",
|
||||||
|
labels=["method", "servlet", ],
|
||||||
|
)
|
||||||
|
|
||||||
outgoing_responses_counter = metrics.register_counter(
|
outgoing_responses_counter = metrics.register_counter(
|
||||||
"responses",
|
"responses",
|
||||||
labels=["method", "code"],
|
labels=["method", "code"],
|
||||||
|
@ -145,7 +151,8 @@ def wrap_request_handler(request_handler, include_metrics=False):
|
||||||
# at the servlet name. For most requests that name will be
|
# at the servlet name. For most requests that name will be
|
||||||
# JsonResource (or a subclass), and JsonResource._async_render
|
# JsonResource (or a subclass), and JsonResource._async_render
|
||||||
# will update it once it picks a servlet.
|
# will update it once it picks a servlet.
|
||||||
request_metrics.start(self.clock, name=self.__class__.__name__)
|
servlet_name = self.__class__.__name__
|
||||||
|
request_metrics.start(self.clock, name=servlet_name)
|
||||||
|
|
||||||
request_context.request = request_id
|
request_context.request = request_id
|
||||||
with request.processing():
|
with request.processing():
|
||||||
|
@ -154,6 +161,7 @@ def wrap_request_handler(request_handler, include_metrics=False):
|
||||||
if include_metrics:
|
if include_metrics:
|
||||||
yield request_handler(self, request, request_metrics)
|
yield request_handler(self, request, request_metrics)
|
||||||
else:
|
else:
|
||||||
|
requests_counter.inc(request.method, servlet_name)
|
||||||
yield request_handler(self, request)
|
yield request_handler(self, request)
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
code = e.code
|
code = e.code
|
||||||
|
@ -229,7 +237,7 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
""" This implements the HttpServer interface and provides JSON support for
|
""" This implements the HttpServer interface and provides JSON support for
|
||||||
Resources.
|
Resources.
|
||||||
|
|
||||||
Register callbacks via register_path()
|
Register callbacks via register_paths()
|
||||||
|
|
||||||
Callbacks can return a tuple of status code and a dict in which case the
|
Callbacks can return a tuple of status code and a dict in which case the
|
||||||
the dict will automatically be sent to the client as a JSON object.
|
the dict will automatically be sent to the client as a JSON object.
|
||||||
|
@ -276,49 +284,59 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
This checks if anyone has registered a callback for that method and
|
This checks if anyone has registered a callback for that method and
|
||||||
path.
|
path.
|
||||||
"""
|
"""
|
||||||
|
callback, group_dict = self._get_handler_for_request(request)
|
||||||
|
|
||||||
|
servlet_instance = getattr(callback, "__self__", None)
|
||||||
|
if servlet_instance is not None:
|
||||||
|
servlet_classname = servlet_instance.__class__.__name__
|
||||||
|
else:
|
||||||
|
servlet_classname = "%r" % callback
|
||||||
|
|
||||||
|
request_metrics.name = servlet_classname
|
||||||
|
requests_counter.inc(request.method, servlet_classname)
|
||||||
|
|
||||||
|
# Now trigger the callback. If it returns a response, we send it
|
||||||
|
# here. If it throws an exception, that is handled by the wrapper
|
||||||
|
# installed by @request_handler.
|
||||||
|
|
||||||
|
kwargs = intern_dict({
|
||||||
|
name: urllib.unquote(value).decode("UTF-8") if value else value
|
||||||
|
for name, value in group_dict.items()
|
||||||
|
})
|
||||||
|
|
||||||
|
callback_return = yield callback(request, **kwargs)
|
||||||
|
if callback_return is not None:
|
||||||
|
code, response = callback_return
|
||||||
|
self._send_response(request, code, response)
|
||||||
|
|
||||||
|
def _get_handler_for_request(self, request):
|
||||||
|
"""Finds a callback method to handle the given request
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (twisted.web.http.Request):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Callable, dict[str, str]]: callback method, and the dict
|
||||||
|
mapping keys to path components as specified in the handler's
|
||||||
|
path match regexp.
|
||||||
|
|
||||||
|
The callback will normally be a method registered via
|
||||||
|
register_paths, so will return (possibly via Deferred) either
|
||||||
|
None, or a tuple of (http code, response body).
|
||||||
|
"""
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
self._send_response(request, 200, {})
|
return _options_handler, {}
|
||||||
return
|
|
||||||
|
|
||||||
# Loop through all the registered callbacks to check if the method
|
# Loop through all the registered callbacks to check if the method
|
||||||
# and path regex match
|
# and path regex match
|
||||||
for path_entry in self.path_regexs.get(request.method, []):
|
for path_entry in self.path_regexs.get(request.method, []):
|
||||||
m = path_entry.pattern.match(request.path)
|
m = path_entry.pattern.match(request.path)
|
||||||
if not m:
|
if m:
|
||||||
continue
|
# We found a match!
|
||||||
|
return path_entry.callback, m.groupdict()
|
||||||
# We found a match! First update the metrics object to indicate
|
|
||||||
# which servlet is handling the request.
|
|
||||||
|
|
||||||
callback = path_entry.callback
|
|
||||||
|
|
||||||
servlet_instance = getattr(callback, "__self__", None)
|
|
||||||
if servlet_instance is not None:
|
|
||||||
servlet_classname = servlet_instance.__class__.__name__
|
|
||||||
else:
|
|
||||||
servlet_classname = "%r" % callback
|
|
||||||
|
|
||||||
request_metrics.name = servlet_classname
|
|
||||||
|
|
||||||
# Now trigger the callback. If it returns a response, we send it
|
|
||||||
# here. If it throws an exception, that is handled by the wrapper
|
|
||||||
# installed by @request_handler.
|
|
||||||
|
|
||||||
kwargs = intern_dict({
|
|
||||||
name: urllib.unquote(value).decode("UTF-8") if value else value
|
|
||||||
for name, value in m.groupdict().items()
|
|
||||||
})
|
|
||||||
|
|
||||||
callback_return = yield callback(request, **kwargs)
|
|
||||||
if callback_return is not None:
|
|
||||||
code, response = callback_return
|
|
||||||
self._send_response(request, code, response)
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||||
request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest"
|
return _unrecognised_request_handler, {}
|
||||||
raise UnrecognizedRequestError()
|
|
||||||
|
|
||||||
def _send_response(self, request, code, response_json_object,
|
def _send_response(self, request, code, response_json_object,
|
||||||
response_code_message=None):
|
response_code_message=None):
|
||||||
|
@ -335,6 +353,34 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _options_handler(request):
|
||||||
|
"""Request handler for OPTIONS requests
|
||||||
|
|
||||||
|
This is a request handler suitable for return from
|
||||||
|
_get_handler_for_request. It returns a 200 and an empty body.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (twisted.web.http.Request):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[int, dict]: http code, response body.
|
||||||
|
"""
|
||||||
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
|
def _unrecognised_request_handler(request):
|
||||||
|
"""Request handler for unrecognised requests
|
||||||
|
|
||||||
|
This is a request handler suitable for return from
|
||||||
|
_get_handler_for_request. It actually just raises an
|
||||||
|
UnrecognizedRequestError.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (twisted.web.http.Request):
|
||||||
|
"""
|
||||||
|
raise UnrecognizedRequestError()
|
||||||
|
|
||||||
|
|
||||||
class RequestMetrics(object):
|
class RequestMetrics(object):
|
||||||
def start(self, clock, name):
|
def start(self, clock, name):
|
||||||
self.start = clock.time_msec()
|
self.start = clock.time_msec()
|
||||||
|
@ -415,8 +461,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
|
||||||
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
||||||
json_bytes = encode_canonical_json(json_object)
|
json_bytes = encode_canonical_json(json_object)
|
||||||
else:
|
else:
|
||||||
# ujson doesn't like frozen_dicts.
|
json_bytes = simplejson.dumps(json_object)
|
||||||
json_bytes = ujson.dumps(json_object, ensure_ascii=False)
|
|
||||||
|
|
||||||
return respond_with_json_bytes(
|
return respond_with_json_bytes(
|
||||||
request, code, json_bytes,
|
request, code, json_bytes,
|
||||||
|
@ -443,6 +488,7 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
|
||||||
request.setHeader(b"Content-Type", b"application/json")
|
request.setHeader(b"Content-Type", b"application/json")
|
||||||
request.setHeader(b"Server", version_string)
|
request.setHeader(b"Server", version_string)
|
||||||
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
|
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
|
||||||
|
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
|
||||||
|
|
||||||
if send_cors:
|
if send_cors:
|
||||||
set_cors_headers(request)
|
set_cors_headers(request)
|
||||||
|
|
|
@ -57,15 +57,31 @@ class Metrics(object):
|
||||||
return metric
|
return metric
|
||||||
|
|
||||||
def register_counter(self, *args, **kwargs):
|
def register_counter(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
CounterMetric
|
||||||
|
"""
|
||||||
return self._register(CounterMetric, *args, **kwargs)
|
return self._register(CounterMetric, *args, **kwargs)
|
||||||
|
|
||||||
def register_callback(self, *args, **kwargs):
|
def register_callback(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
CallbackMetric
|
||||||
|
"""
|
||||||
return self._register(CallbackMetric, *args, **kwargs)
|
return self._register(CallbackMetric, *args, **kwargs)
|
||||||
|
|
||||||
def register_distribution(self, *args, **kwargs):
|
def register_distribution(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
DistributionMetric
|
||||||
|
"""
|
||||||
return self._register(DistributionMetric, *args, **kwargs)
|
return self._register(DistributionMetric, *args, **kwargs)
|
||||||
|
|
||||||
def register_cache(self, *args, **kwargs):
|
def register_cache(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
CacheMetric
|
||||||
|
"""
|
||||||
return self._register(CacheMetric, *args, **kwargs)
|
return self._register(CacheMetric, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -144,6 +144,7 @@ class BulkPushRuleEvaluator(object):
|
||||||
Deferred
|
Deferred
|
||||||
"""
|
"""
|
||||||
rules_by_user = yield self._get_rules_for_event(event, context)
|
rules_by_user = yield self._get_rules_for_event(event, context)
|
||||||
|
actions_by_user = {}
|
||||||
|
|
||||||
room_members = yield self.store.get_joined_users_from_context(
|
room_members = yield self.store.get_joined_users_from_context(
|
||||||
event, context
|
event, context
|
||||||
|
@ -189,14 +190,17 @@ class BulkPushRuleEvaluator(object):
|
||||||
if matches:
|
if matches:
|
||||||
actions = [x for x in rule['actions'] if x != 'dont_notify']
|
actions = [x for x in rule['actions'] if x != 'dont_notify']
|
||||||
if actions and 'notify' in actions:
|
if actions and 'notify' in actions:
|
||||||
# Push rules say we should notify the user of this event,
|
# Push rules say we should notify the user of this event
|
||||||
# so we mark it in the DB in the staging area. (This
|
actions_by_user[uid] = actions
|
||||||
# will then get handled when we persist the event)
|
|
||||||
yield self.store.add_push_actions_to_staging(
|
|
||||||
event.event_id, uid, actions,
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Mark in the DB staging area the push actions for users who should be
|
||||||
|
# notified for this event. (This will then get handled when we persist
|
||||||
|
# the event)
|
||||||
|
yield self.store.add_push_actions_to_staging(
|
||||||
|
event.event_id, actions_by_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _condition_checker(evaluator, conditions, uid, display_name, cache):
|
def _condition_checker(evaluator, conditions, uid, display_name, cache):
|
||||||
for cond in conditions:
|
for cond in conditions:
|
||||||
|
|
|
@ -24,17 +24,16 @@ REQUIREMENTS = {
|
||||||
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
|
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
|
||||||
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
|
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
|
||||||
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
|
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
|
||||||
"pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
|
"pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
|
||||||
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
||||||
"Twisted>=16.0.0": ["twisted>=16.0.0"],
|
"Twisted>=16.0.0": ["twisted>=16.0.0"],
|
||||||
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
||||||
"pyyaml": ["yaml"],
|
"pyyaml": ["yaml"],
|
||||||
"pyasn1": ["pyasn1"],
|
"pyasn1": ["pyasn1"],
|
||||||
"daemonize": ["daemonize"],
|
"daemonize": ["daemonize"],
|
||||||
"bcrypt": ["bcrypt"],
|
"bcrypt": ["bcrypt>=3.1.0"],
|
||||||
"pillow": ["PIL"],
|
"pillow": ["PIL"],
|
||||||
"pydenticon": ["pydenticon"],
|
"pydenticon": ["pydenticon"],
|
||||||
"ujson": ["ujson"],
|
|
||||||
"blist": ["blist"],
|
"blist": ["blist"],
|
||||||
"pysaml2>=3.0.0": ["saml2>=3.0.0"],
|
"pysaml2>=3.0.0": ["saml2>=3.0.0"],
|
||||||
"pymacaroons-pynacl": ["pymacaroons"],
|
"pymacaroons-pynacl": ["pymacaroons"],
|
||||||
|
|
|
@ -13,10 +13,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import send_event
|
|
||||||
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
|
from synapse.replication.http import membership, send_event
|
||||||
|
|
||||||
|
|
||||||
REPLICATION_PREFIX = "/_synapse/replication"
|
REPLICATION_PREFIX = "/_synapse/replication"
|
||||||
|
@ -29,3 +27,4 @@ class ReplicationRestResource(JsonResource):
|
||||||
|
|
||||||
def register_servlets(self, hs):
|
def register_servlets(self, hs):
|
||||||
send_event.register_servlets(hs, self)
|
send_event.register_servlets(hs, self)
|
||||||
|
membership.register_servlets(hs, self)
|
||||||
|
|
334
synapse/replication/http/membership.py
Normal file
334
synapse/replication/http/membership.py
Normal file
|
@ -0,0 +1,334 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError, MatrixCodeMessageException
|
||||||
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
|
from synapse.types import Requester, UserID
|
||||||
|
from synapse.util.distributor import user_left_room, user_joined_room
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def remote_join(client, host, port, requester, remote_room_hosts,
|
||||||
|
room_id, user_id, content):
|
||||||
|
"""Ask the master to do a remote join for the given user to the given room
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client (SimpleHttpClient)
|
||||||
|
host (str): host of master
|
||||||
|
port (int): port on master listening for HTTP replication
|
||||||
|
requester (Requester)
|
||||||
|
remote_room_hosts (list[str]): Servers to try and join via
|
||||||
|
room_id (str)
|
||||||
|
user_id (str)
|
||||||
|
content (dict): The event content to use for the join event
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"requester": requester.serialize(),
|
||||||
|
"remote_room_hosts": remote_room_hosts,
|
||||||
|
"room_id": room_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = yield client.post_json_get_json(uri, payload)
|
||||||
|
except MatrixCodeMessageException as e:
|
||||||
|
# We convert to SynapseError as we know that it was a SynapseError
|
||||||
|
# on the master process that we should send to the client. (And
|
||||||
|
# importantly, not stack traces everywhere)
|
||||||
|
raise SynapseError(e.code, e.msg, e.errcode)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def remote_reject_invite(client, host, port, requester, remote_room_hosts,
|
||||||
|
room_id, user_id):
|
||||||
|
"""Ask master to reject the invite for the user and room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client (SimpleHttpClient)
|
||||||
|
host (str): host of master
|
||||||
|
port (int): port on master listening for HTTP replication
|
||||||
|
requester (Requester)
|
||||||
|
remote_room_hosts (list[str]): Servers to try and reject via
|
||||||
|
room_id (str)
|
||||||
|
user_id (str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"requester": requester.serialize(),
|
||||||
|
"remote_room_hosts": remote_room_hosts,
|
||||||
|
"room_id": room_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = yield client.post_json_get_json(uri, payload)
|
||||||
|
except MatrixCodeMessageException as e:
|
||||||
|
# We convert to SynapseError as we know that it was a SynapseError
|
||||||
|
# on the master process that we should send to the client. (And
|
||||||
|
# importantly, not stack traces everywhere)
|
||||||
|
raise SynapseError(e.code, e.msg, e.errcode)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_or_register_3pid_guest(client, host, port, requester,
|
||||||
|
medium, address, inviter_user_id):
|
||||||
|
"""Ask the master to get/create a guest account for given 3PID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client (SimpleHttpClient)
|
||||||
|
host (str): host of master
|
||||||
|
port (int): port on master listening for HTTP replication
|
||||||
|
requester (Requester)
|
||||||
|
medium (str)
|
||||||
|
address (str)
|
||||||
|
inviter_user_id (str): The user ID who is trying to invite the
|
||||||
|
3PID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
|
||||||
|
3PID guest account.
|
||||||
|
"""
|
||||||
|
|
||||||
|
uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"requester": requester.serialize(),
|
||||||
|
"medium": medium,
|
||||||
|
"address": address,
|
||||||
|
"inviter_user_id": inviter_user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = yield client.post_json_get_json(uri, payload)
|
||||||
|
except MatrixCodeMessageException as e:
|
||||||
|
# We convert to SynapseError as we know that it was a SynapseError
|
||||||
|
# on the master process that we should send to the client. (And
|
||||||
|
# importantly, not stack traces everywhere)
|
||||||
|
raise SynapseError(e.code, e.msg, e.errcode)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def notify_user_membership_change(client, host, port, user_id, room_id, change):
|
||||||
|
"""Notify master that a user has joined or left the room
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client (SimpleHttpClient)
|
||||||
|
host (str): host of master
|
||||||
|
port (int): port on master listening for HTTP replication.
|
||||||
|
user_id (str)
|
||||||
|
room_id (str)
|
||||||
|
change (str): Either "join" or "left"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
assert change in ("joined", "left")
|
||||||
|
|
||||||
|
uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = yield client.post_json_get_json(uri, payload)
|
||||||
|
except MatrixCodeMessageException as e:
|
||||||
|
# We convert to SynapseError as we know that it was a SynapseError
|
||||||
|
# on the master process that we should send to the client. (And
|
||||||
|
# importantly, not stack traces everywhere)
|
||||||
|
raise SynapseError(e.code, e.msg, e.errcode)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicationRemoteJoinRestServlet(RestServlet):
|
||||||
|
PATTERNS = [re.compile("^/_synapse/replication/remote_join$")]
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ReplicationRemoteJoinRestServlet, self).__init__()
|
||||||
|
|
||||||
|
self.federation_handler = hs.get_handlers().federation_handler
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
remote_room_hosts = content["remote_room_hosts"]
|
||||||
|
room_id = content["room_id"]
|
||||||
|
user_id = content["user_id"]
|
||||||
|
event_content = content["content"]
|
||||||
|
|
||||||
|
requester = Requester.deserialize(self.store, content["requester"])
|
||||||
|
|
||||||
|
if requester.user:
|
||||||
|
request.authenticated_entity = requester.user.to_string()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"remote_join: %s into room: %s",
|
||||||
|
user_id, room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.federation_handler.do_invite_join(
|
||||||
|
remote_room_hosts,
|
||||||
|
room_id,
|
||||||
|
user_id,
|
||||||
|
event_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicationRemoteRejectInviteRestServlet(RestServlet):
|
||||||
|
PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")]
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ReplicationRemoteRejectInviteRestServlet, self).__init__()
|
||||||
|
|
||||||
|
self.federation_handler = hs.get_handlers().federation_handler
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
remote_room_hosts = content["remote_room_hosts"]
|
||||||
|
room_id = content["room_id"]
|
||||||
|
user_id = content["user_id"]
|
||||||
|
|
||||||
|
requester = Requester.deserialize(self.store, content["requester"])
|
||||||
|
|
||||||
|
if requester.user:
|
||||||
|
request.authenticated_entity = requester.user.to_string()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"remote_reject_invite: %s out of room: %s",
|
||||||
|
user_id, room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
event = yield self.federation_handler.do_remotely_reject_invite(
|
||||||
|
remote_room_hosts,
|
||||||
|
room_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
ret = event.get_pdu_json()
|
||||||
|
except Exception as e:
|
||||||
|
# if we were unable to reject the exception, just mark
|
||||||
|
# it as rejected on our end and plough ahead.
|
||||||
|
#
|
||||||
|
# The 'except' clause is very broad, but we need to
|
||||||
|
# capture everything from DNS failures upwards
|
||||||
|
#
|
||||||
|
logger.warn("Failed to reject invite: %s", e)
|
||||||
|
|
||||||
|
yield self.store.locally_reject_invite(
|
||||||
|
user_id, room_id
|
||||||
|
)
|
||||||
|
ret = {}
|
||||||
|
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicationRegister3PIDGuestRestServlet(RestServlet):
|
||||||
|
PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")]
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ReplicationRegister3PIDGuestRestServlet, self).__init__()
|
||||||
|
|
||||||
|
self.registeration_handler = hs.get_handlers().registration_handler
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
medium = content["medium"]
|
||||||
|
address = content["address"]
|
||||||
|
inviter_user_id = content["inviter_user_id"]
|
||||||
|
|
||||||
|
requester = Requester.deserialize(self.store, content["requester"])
|
||||||
|
|
||||||
|
if requester.user:
|
||||||
|
request.authenticated_entity = requester.user.to_string()
|
||||||
|
|
||||||
|
logger.info("get_or_register_3pid_guest: %r", content)
|
||||||
|
|
||||||
|
ret = yield self.registeration_handler.get_or_register_3pid_guest(
|
||||||
|
medium, address, inviter_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicationUserJoinedLeftRoomRestServlet(RestServlet):
|
||||||
|
PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")]
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__()
|
||||||
|
|
||||||
|
self.registeration_handler = hs.get_handlers().registration_handler
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.distributor = hs.get_distributor()
|
||||||
|
|
||||||
|
def on_POST(self, request, change):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
user_id = content["user_id"]
|
||||||
|
room_id = content["room_id"]
|
||||||
|
|
||||||
|
logger.info("user membership change: %s in %s", user_id, room_id)
|
||||||
|
|
||||||
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
|
if change == "joined":
|
||||||
|
user_joined_room(self.distributor, user, room_id)
|
||||||
|
elif change == "left":
|
||||||
|
user_left_room(self.distributor, user, room_id)
|
||||||
|
else:
|
||||||
|
raise Exception("Unrecognized change: %r", change)
|
||||||
|
|
||||||
|
return (200, {})
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
ReplicationRemoteJoinRestServlet(hs).register(http_server)
|
||||||
|
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
|
||||||
|
ReplicationRegister3PIDGuestRestServlet(hs).register(http_server)
|
||||||
|
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
|
|
@ -15,12 +15,17 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, MatrixCodeMessageException
|
from synapse.api.errors import (
|
||||||
|
SynapseError, MatrixCodeMessageException, CodeMessageException,
|
||||||
|
)
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
|
from synapse.util.async import sleep
|
||||||
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.types import Requester
|
from synapse.types import Requester, UserID
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
@ -29,7 +34,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_event_to_master(client, host, port, requester, event, context):
|
def send_event_to_master(client, host, port, requester, event, context,
|
||||||
|
ratelimit, extra_users):
|
||||||
"""Send event to be handled on the master
|
"""Send event to be handled on the master
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -39,8 +45,12 @@ def send_event_to_master(client, host, port, requester, event, context):
|
||||||
requester (Requester)
|
requester (Requester)
|
||||||
event (FrozenEvent)
|
event (FrozenEvent)
|
||||||
context (EventContext)
|
context (EventContext)
|
||||||
|
ratelimit (bool)
|
||||||
|
extra_users (list(UserID)): Any extra users to notify about event
|
||||||
"""
|
"""
|
||||||
uri = "http://%s:%s/_synapse/replication/send_event" % (host, port,)
|
uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
|
||||||
|
host, port, event.event_id,
|
||||||
|
)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"event": event.get_pdu_json(),
|
"event": event.get_pdu_json(),
|
||||||
|
@ -48,10 +58,27 @@ def send_event_to_master(client, host, port, requester, event, context):
|
||||||
"rejected_reason": event.rejected_reason,
|
"rejected_reason": event.rejected_reason,
|
||||||
"context": context.serialize(event),
|
"context": context.serialize(event),
|
||||||
"requester": requester.serialize(),
|
"requester": requester.serialize(),
|
||||||
|
"ratelimit": ratelimit,
|
||||||
|
"extra_users": [u.to_string() for u in extra_users],
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = yield client.post_json_get_json(uri, payload)
|
# We keep retrying the same request for timeouts. This is so that we
|
||||||
|
# have a good idea that the request has either succeeded or failed on
|
||||||
|
# the master, and so whether we should clean up or not.
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
result = yield client.put_json(uri, payload)
|
||||||
|
break
|
||||||
|
except CodeMessageException as e:
|
||||||
|
if e.code != 504:
|
||||||
|
raise
|
||||||
|
|
||||||
|
logger.warn("send_event request timed out")
|
||||||
|
|
||||||
|
# If we timed out we probably don't need to worry about backing
|
||||||
|
# off too much, but lets just wait a little anyway.
|
||||||
|
yield sleep(1)
|
||||||
except MatrixCodeMessageException as e:
|
except MatrixCodeMessageException as e:
|
||||||
# We convert to SynapseError as we know that it was a SynapseError
|
# We convert to SynapseError as we know that it was a SynapseError
|
||||||
# on the master process that we should send to the client. (And
|
# on the master process that we should send to the client. (And
|
||||||
|
@ -66,7 +93,7 @@ class ReplicationSendEventRestServlet(RestServlet):
|
||||||
|
|
||||||
The API looks like:
|
The API looks like:
|
||||||
|
|
||||||
POST /_synapse/replication/send_event
|
POST /_synapse/replication/send_event/:event_id
|
||||||
|
|
||||||
{
|
{
|
||||||
"event": { .. serialized event .. },
|
"event": { .. serialized event .. },
|
||||||
|
@ -74,9 +101,11 @@ class ReplicationSendEventRestServlet(RestServlet):
|
||||||
"rejected_reason": .., // The event.rejected_reason field
|
"rejected_reason": .., // The event.rejected_reason field
|
||||||
"context": { .. serialized event context .. },
|
"context": { .. serialized event context .. },
|
||||||
"requester": { .. serialized requester .. },
|
"requester": { .. serialized requester .. },
|
||||||
|
"ratelimit": true,
|
||||||
|
"extra_users": [],
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
PATTERNS = [re.compile("^/_synapse/replication/send_event$")]
|
PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")]
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ReplicationSendEventRestServlet, self).__init__()
|
super(ReplicationSendEventRestServlet, self).__init__()
|
||||||
|
@ -85,8 +114,23 @@ class ReplicationSendEventRestServlet(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
# The responses are tiny, so we may as well cache them for a while
|
||||||
|
self.response_cache = ResponseCache(hs, timeout_ms=30 * 60 * 1000)
|
||||||
|
|
||||||
|
def on_PUT(self, request, event_id):
|
||||||
|
result = self.response_cache.get(event_id)
|
||||||
|
if not result:
|
||||||
|
result = self.response_cache.set(
|
||||||
|
event_id,
|
||||||
|
self._handle_request(request)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn("Returning cached response")
|
||||||
|
return make_deferred_yieldable(result)
|
||||||
|
|
||||||
|
@preserve_fn
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def _handle_request(self, request):
|
||||||
with Measure(self.clock, "repl_send_event_parse"):
|
with Measure(self.clock, "repl_send_event_parse"):
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
@ -98,6 +142,9 @@ class ReplicationSendEventRestServlet(RestServlet):
|
||||||
requester = Requester.deserialize(self.store, content["requester"])
|
requester = Requester.deserialize(self.store, content["requester"])
|
||||||
context = yield EventContext.deserialize(self.store, content["context"])
|
context = yield EventContext.deserialize(self.store, content["context"])
|
||||||
|
|
||||||
|
ratelimit = content["ratelimit"]
|
||||||
|
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
||||||
|
|
||||||
if requester.user:
|
if requester.user:
|
||||||
request.authenticated_entity = requester.user.to_string()
|
request.authenticated_entity = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -106,8 +153,10 @@ class ReplicationSendEventRestServlet(RestServlet):
|
||||||
event.event_id, event.room_id,
|
event.event_id, event.room_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.event_creation_handler.handle_new_client_event(
|
yield self.event_creation_handler.persist_and_notify_client_event(
|
||||||
requester, event, context,
|
requester, event, context,
|
||||||
|
ratelimit=ratelimit,
|
||||||
|
extra_users=extra_users,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2016 OpenMarket Ltd
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -13,50 +14,20 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||||
from synapse.storage import DataStore
|
from synapse.storage.account_data import AccountDataWorkerStore
|
||||||
from synapse.storage.account_data import AccountDataStore
|
from synapse.storage.tags import TagsWorkerStore
|
||||||
from synapse.storage.tags import TagsStore
|
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
|
||||||
|
|
||||||
|
|
||||||
class SlavedAccountDataStore(BaseSlavedStore):
|
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
|
|
||||||
self._account_data_id_gen = SlavedIdTracker(
|
self._account_data_id_gen = SlavedIdTracker(
|
||||||
db_conn, "account_data_max_stream_id", "stream_id",
|
db_conn, "account_data_max_stream_id", "stream_id",
|
||||||
)
|
)
|
||||||
self._account_data_stream_cache = StreamChangeCache(
|
|
||||||
"AccountDataAndTagsChangeCache",
|
|
||||||
self._account_data_id_gen.get_current_token(),
|
|
||||||
)
|
|
||||||
|
|
||||||
get_account_data_for_user = (
|
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
|
||||||
AccountDataStore.__dict__["get_account_data_for_user"]
|
|
||||||
)
|
|
||||||
|
|
||||||
get_global_account_data_by_type_for_users = (
|
|
||||||
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
|
|
||||||
)
|
|
||||||
|
|
||||||
get_global_account_data_by_type_for_user = (
|
|
||||||
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
|
|
||||||
)
|
|
||||||
|
|
||||||
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
|
|
||||||
get_tags_for_room = (
|
|
||||||
DataStore.get_tags_for_room.__func__
|
|
||||||
)
|
|
||||||
get_account_data_for_room = (
|
|
||||||
DataStore.get_account_data_for_room.__func__
|
|
||||||
)
|
|
||||||
|
|
||||||
get_updated_tags = DataStore.get_updated_tags.__func__
|
|
||||||
get_updated_account_data_for_user = (
|
|
||||||
DataStore.get_updated_account_data_for_user.__func__
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_max_account_data_stream_id(self):
|
def get_max_account_data_stream_id(self):
|
||||||
return self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
|
@ -85,6 +56,10 @@ class SlavedAccountDataStore(BaseSlavedStore):
|
||||||
(row.data_type, row.user_id,)
|
(row.data_type, row.user_id,)
|
||||||
)
|
)
|
||||||
self.get_account_data_for_user.invalidate((row.user_id,))
|
self.get_account_data_for_user.invalidate((row.user_id,))
|
||||||
|
self.get_account_data_for_room.invalidate((row.user_id, row.room_id,))
|
||||||
|
self.get_account_data_for_room_and_type.invalidate(
|
||||||
|
(row.user_id, row.room_id, row.data_type,),
|
||||||
|
)
|
||||||
self._account_data_stream_cache.entity_has_changed(
|
self._account_data_stream_cache.entity_has_changed(
|
||||||
row.user_id, token
|
row.user_id, token
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -13,33 +14,11 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import BaseSlavedStore
|
from synapse.storage.appservice import (
|
||||||
from synapse.storage import DataStore
|
ApplicationServiceWorkerStore, ApplicationServiceTransactionWorkerStore,
|
||||||
from synapse.config.appservice import load_appservices
|
)
|
||||||
from synapse.storage.appservice import _make_exclusive_regex
|
|
||||||
|
|
||||||
|
|
||||||
class SlavedApplicationServiceStore(BaseSlavedStore):
|
class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore,
|
||||||
def __init__(self, db_conn, hs):
|
ApplicationServiceWorkerStore):
|
||||||
super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
|
pass
|
||||||
self.services_cache = load_appservices(
|
|
||||||
hs.config.server_name,
|
|
||||||
hs.config.app_service_config_files
|
|
||||||
)
|
|
||||||
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
|
|
||||||
|
|
||||||
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
|
|
||||||
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
|
|
||||||
get_app_services = DataStore.get_app_services.__func__
|
|
||||||
get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
|
|
||||||
create_appservice_txn = DataStore.create_appservice_txn.__func__
|
|
||||||
get_appservices_by_state = DataStore.get_appservices_by_state.__func__
|
|
||||||
get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
|
|
||||||
_get_last_txn = DataStore._get_last_txn.__func__
|
|
||||||
complete_appservice_txn = DataStore.complete_appservice_txn.__func__
|
|
||||||
get_appservice_state = DataStore.get_appservice_state.__func__
|
|
||||||
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
|
|
||||||
set_appservice_state = DataStore.set_appservice_state.__func__
|
|
||||||
get_if_app_services_interested_in_user = (
|
|
||||||
DataStore.get_if_app_services_interested_in_user.__func__
|
|
||||||
)
|
|
||||||
|
|
|
@ -14,10 +14,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from synapse.storage.directory import DirectoryStore
|
from synapse.storage.directory import DirectoryWorkerStore
|
||||||
|
|
||||||
|
|
||||||
class DirectoryStore(BaseSlavedStore):
|
class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore):
|
||||||
get_aliases_for_room = DirectoryStore.__dict__[
|
pass
|
||||||
"get_aliases_for_room"
|
|
||||||
]
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2016 OpenMarket Ltd
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -15,14 +16,13 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.storage import DataStore
|
from synapse.storage.event_federation import EventFederationWorkerStore
|
||||||
from synapse.storage.event_federation import EventFederationStore
|
from synapse.storage.event_push_actions import EventPushActionsWorkerStore
|
||||||
from synapse.storage.event_push_actions import EventPushActionsStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.roommember import RoomMemberStore
|
from synapse.storage.roommember import RoomMemberWorkerStore
|
||||||
from synapse.storage.state import StateGroupWorkerStore
|
from synapse.storage.state import StateGroupWorkerStore
|
||||||
from synapse.storage.stream import StreamStore
|
from synapse.storage.stream import StreamWorkerStore
|
||||||
from synapse.storage.signatures import SignatureStore
|
from synapse.storage.signatures import SignatureWorkerStore
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
|
@ -38,157 +38,33 @@ logger = logging.getLogger(__name__)
|
||||||
# the method descriptor on the DataStore and chuck them into our class.
|
# the method descriptor on the DataStore and chuck them into our class.
|
||||||
|
|
||||||
|
|
||||||
class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
|
class SlavedEventStore(EventFederationWorkerStore,
|
||||||
|
RoomMemberWorkerStore,
|
||||||
|
EventPushActionsWorkerStore,
|
||||||
|
StreamWorkerStore,
|
||||||
|
EventsWorkerStore,
|
||||||
|
StateGroupWorkerStore,
|
||||||
|
SignatureWorkerStore,
|
||||||
|
BaseSlavedStore):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(SlavedEventStore, self).__init__(db_conn, hs)
|
|
||||||
self._stream_id_gen = SlavedIdTracker(
|
self._stream_id_gen = SlavedIdTracker(
|
||||||
db_conn, "events", "stream_ordering",
|
db_conn, "events", "stream_ordering",
|
||||||
)
|
)
|
||||||
self._backfill_id_gen = SlavedIdTracker(
|
self._backfill_id_gen = SlavedIdTracker(
|
||||||
db_conn, "events", "stream_ordering", step=-1
|
db_conn, "events", "stream_ordering", step=-1
|
||||||
)
|
)
|
||||||
events_max = self._stream_id_gen.get_current_token()
|
|
||||||
event_cache_prefill, min_event_val = self._get_cache_dict(
|
|
||||||
db_conn, "events",
|
|
||||||
entity_column="room_id",
|
|
||||||
stream_column="stream_ordering",
|
|
||||||
max_value=events_max,
|
|
||||||
)
|
|
||||||
self._events_stream_cache = StreamChangeCache(
|
|
||||||
"EventsRoomStreamChangeCache", min_event_val,
|
|
||||||
prefilled_cache=event_cache_prefill,
|
|
||||||
)
|
|
||||||
self._membership_stream_cache = StreamChangeCache(
|
|
||||||
"MembershipStreamChangeCache", events_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.stream_ordering_month_ago = 0
|
super(SlavedEventStore, self).__init__(db_conn, hs)
|
||||||
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
|
||||||
|
|
||||||
# Cached functions can't be accessed through a class instance so we need
|
# Cached functions can't be accessed through a class instance so we need
|
||||||
# to reach inside the __dict__ to extract them.
|
# to reach inside the __dict__ to extract them.
|
||||||
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
|
|
||||||
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
|
|
||||||
get_hosts_in_room = RoomMemberStore.__dict__["get_hosts_in_room"]
|
|
||||||
get_users_who_share_room_with_user = (
|
|
||||||
RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
|
|
||||||
)
|
|
||||||
get_latest_event_ids_in_room = EventFederationStore.__dict__[
|
|
||||||
"get_latest_event_ids_in_room"
|
|
||||||
]
|
|
||||||
get_invited_rooms_for_user = RoomMemberStore.__dict__[
|
|
||||||
"get_invited_rooms_for_user"
|
|
||||||
]
|
|
||||||
get_unread_event_push_actions_by_room_for_user = (
|
|
||||||
EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
|
|
||||||
)
|
|
||||||
_get_unread_counts_by_receipt_txn = (
|
|
||||||
DataStore._get_unread_counts_by_receipt_txn.__func__
|
|
||||||
)
|
|
||||||
_get_unread_counts_by_pos_txn = (
|
|
||||||
DataStore._get_unread_counts_by_pos_txn.__func__
|
|
||||||
)
|
|
||||||
get_recent_event_ids_for_room = (
|
|
||||||
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
|
||||||
)
|
|
||||||
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
|
|
||||||
has_room_changed_since = DataStore.has_room_changed_since.__func__
|
|
||||||
|
|
||||||
get_unread_push_actions_for_user_in_range_for_http = (
|
def get_room_max_stream_ordering(self):
|
||||||
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
|
return self._stream_id_gen.get_current_token()
|
||||||
)
|
|
||||||
get_unread_push_actions_for_user_in_range_for_email = (
|
|
||||||
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
|
|
||||||
)
|
|
||||||
get_push_action_users_in_range = (
|
|
||||||
DataStore.get_push_action_users_in_range.__func__
|
|
||||||
)
|
|
||||||
get_event = DataStore.get_event.__func__
|
|
||||||
get_events = DataStore.get_events.__func__
|
|
||||||
get_rooms_for_user_where_membership_is = (
|
|
||||||
DataStore.get_rooms_for_user_where_membership_is.__func__
|
|
||||||
)
|
|
||||||
get_membership_changes_for_user = (
|
|
||||||
DataStore.get_membership_changes_for_user.__func__
|
|
||||||
)
|
|
||||||
get_room_events_max_id = DataStore.get_room_events_max_id.__func__
|
|
||||||
get_room_events_stream_for_room = (
|
|
||||||
DataStore.get_room_events_stream_for_room.__func__
|
|
||||||
)
|
|
||||||
get_events_around = DataStore.get_events_around.__func__
|
|
||||||
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
|
|
||||||
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
|
|
||||||
_get_joined_users_from_context = (
|
|
||||||
RoomMemberStore.__dict__["_get_joined_users_from_context"]
|
|
||||||
)
|
|
||||||
|
|
||||||
get_joined_hosts = DataStore.get_joined_hosts.__func__
|
def get_room_min_stream_ordering(self):
|
||||||
_get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"]
|
return self._backfill_id_gen.get_current_token()
|
||||||
|
|
||||||
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
|
|
||||||
get_room_events_stream_for_rooms = (
|
|
||||||
DataStore.get_room_events_stream_for_rooms.__func__
|
|
||||||
)
|
|
||||||
is_host_joined = RoomMemberStore.__dict__["is_host_joined"]
|
|
||||||
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
|
|
||||||
|
|
||||||
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
|
|
||||||
|
|
||||||
_get_events = DataStore._get_events.__func__
|
|
||||||
_get_events_from_cache = DataStore._get_events_from_cache.__func__
|
|
||||||
|
|
||||||
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
|
|
||||||
_enqueue_events = DataStore._enqueue_events.__func__
|
|
||||||
_do_fetch = DataStore._do_fetch.__func__
|
|
||||||
_fetch_event_rows = DataStore._fetch_event_rows.__func__
|
|
||||||
_get_event_from_row = DataStore._get_event_from_row.__func__
|
|
||||||
_get_rooms_for_user_where_membership_is_txn = (
|
|
||||||
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
|
|
||||||
)
|
|
||||||
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
|
||||||
|
|
||||||
get_backfill_events = DataStore.get_backfill_events.__func__
|
|
||||||
_get_backfill_events = DataStore._get_backfill_events.__func__
|
|
||||||
get_missing_events = DataStore.get_missing_events.__func__
|
|
||||||
_get_missing_events = DataStore._get_missing_events.__func__
|
|
||||||
|
|
||||||
get_auth_chain = DataStore.get_auth_chain.__func__
|
|
||||||
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
|
|
||||||
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
|
|
||||||
|
|
||||||
get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__
|
|
||||||
|
|
||||||
get_forward_extremeties_for_room = (
|
|
||||||
DataStore.get_forward_extremeties_for_room.__func__
|
|
||||||
)
|
|
||||||
_get_forward_extremeties_for_room = (
|
|
||||||
EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
|
|
||||||
)
|
|
||||||
|
|
||||||
get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
|
|
||||||
|
|
||||||
get_federation_out_pos = DataStore.get_federation_out_pos.__func__
|
|
||||||
update_federation_out_pos = DataStore.update_federation_out_pos.__func__
|
|
||||||
|
|
||||||
get_latest_event_ids_and_hashes_in_room = (
|
|
||||||
DataStore.get_latest_event_ids_and_hashes_in_room.__func__
|
|
||||||
)
|
|
||||||
_get_latest_event_ids_and_hashes_in_room = (
|
|
||||||
DataStore._get_latest_event_ids_and_hashes_in_room.__func__
|
|
||||||
)
|
|
||||||
_get_event_reference_hashes_txn = (
|
|
||||||
DataStore._get_event_reference_hashes_txn.__func__
|
|
||||||
)
|
|
||||||
add_event_hashes = (
|
|
||||||
DataStore.add_event_hashes.__func__
|
|
||||||
)
|
|
||||||
get_event_reference_hashes = (
|
|
||||||
SignatureStore.__dict__["get_event_reference_hashes"]
|
|
||||||
)
|
|
||||||
get_event_reference_hash = (
|
|
||||||
SignatureStore.__dict__["get_event_reference_hash"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedEventStore, self).stream_positions()
|
result = super(SlavedEventStore, self).stream_positions()
|
||||||
|
|
21
synapse/replication/slave/storage/profile.py
Normal file
21
synapse/replication/slave/storage/profile.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
|
from synapse.storage.profile import ProfileWorkerStore
|
||||||
|
|
||||||
|
|
||||||
|
class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
|
||||||
|
pass
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -15,29 +16,15 @@
|
||||||
|
|
||||||
from .events import SlavedEventStore
|
from .events import SlavedEventStore
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
from synapse.storage import DataStore
|
from synapse.storage.push_rule import PushRulesWorkerStore
|
||||||
from synapse.storage.push_rule import PushRuleStore
|
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
|
||||||
|
|
||||||
|
|
||||||
class SlavedPushRuleStore(SlavedEventStore):
|
class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
|
|
||||||
self._push_rules_stream_id_gen = SlavedIdTracker(
|
self._push_rules_stream_id_gen = SlavedIdTracker(
|
||||||
db_conn, "push_rules_stream", "stream_id",
|
db_conn, "push_rules_stream", "stream_id",
|
||||||
)
|
)
|
||||||
self.push_rules_stream_cache = StreamChangeCache(
|
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
|
||||||
"PushRulesStreamChangeCache",
|
|
||||||
self._push_rules_stream_id_gen.get_current_token(),
|
|
||||||
)
|
|
||||||
|
|
||||||
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
|
|
||||||
get_push_rules_enabled_for_user = (
|
|
||||||
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
|
|
||||||
)
|
|
||||||
have_push_rules_changed_for_user = (
|
|
||||||
DataStore.have_push_rules_changed_for_user.__func__
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_push_rules_stream_token(self):
|
def get_push_rules_stream_token(self):
|
||||||
return (
|
return (
|
||||||
|
@ -45,6 +32,9 @@ class SlavedPushRuleStore(SlavedEventStore):
|
||||||
self._stream_id_gen.get_current_token(),
|
self._stream_id_gen.get_current_token(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_max_push_rules_stream_id(self):
|
||||||
|
return self._push_rules_stream_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedPushRuleStore, self).stream_positions()
|
result = super(SlavedPushRuleStore, self).stream_positions()
|
||||||
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
|
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2016 OpenMarket Ltd
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -16,10 +17,10 @@
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
from synapse.storage import DataStore
|
from synapse.storage.pusher import PusherWorkerStore
|
||||||
|
|
||||||
|
|
||||||
class SlavedPusherStore(BaseSlavedStore):
|
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(SlavedPusherStore, self).__init__(db_conn, hs)
|
super(SlavedPusherStore, self).__init__(db_conn, hs)
|
||||||
|
@ -28,13 +29,6 @@ class SlavedPusherStore(BaseSlavedStore):
|
||||||
extra_tables=[("deleted_pushers", "stream_id")],
|
extra_tables=[("deleted_pushers", "stream_id")],
|
||||||
)
|
)
|
||||||
|
|
||||||
get_all_pushers = DataStore.get_all_pushers.__func__
|
|
||||||
get_pushers_by = DataStore.get_pushers_by.__func__
|
|
||||||
get_pushers_by_app_id_and_pushkey = (
|
|
||||||
DataStore.get_pushers_by_app_id_and_pushkey.__func__
|
|
||||||
)
|
|
||||||
_decode_pushers_rows = DataStore._decode_pushers_rows.__func__
|
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedPusherStore, self).stream_positions()
|
result = super(SlavedPusherStore, self).stream_positions()
|
||||||
result["pushers"] = self._pushers_id_gen.get_current_token()
|
result["pushers"] = self._pushers_id_gen.get_current_token()
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2016 OpenMarket Ltd
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -16,9 +17,7 @@
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
from synapse.storage import DataStore
|
from synapse.storage.receipts import ReceiptsWorkerStore
|
||||||
from synapse.storage.receipts import ReceiptsStore
|
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
|
||||||
|
|
||||||
# So, um, we want to borrow a load of functions intended for reading from
|
# So, um, we want to borrow a load of functions intended for reading from
|
||||||
# a DataStore, but we don't want to take functions that either write to the
|
# a DataStore, but we don't want to take functions that either write to the
|
||||||
|
@ -29,36 +28,19 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
# the method descriptor on the DataStore and chuck them into our class.
|
# the method descriptor on the DataStore and chuck them into our class.
|
||||||
|
|
||||||
|
|
||||||
class SlavedReceiptsStore(BaseSlavedStore):
|
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(SlavedReceiptsStore, self).__init__(db_conn, hs)
|
# We instantiate this first as the ReceiptsWorkerStore constructor
|
||||||
|
# needs to be able to call get_max_receipt_stream_id
|
||||||
self._receipts_id_gen = SlavedIdTracker(
|
self._receipts_id_gen = SlavedIdTracker(
|
||||||
db_conn, "receipts_linearized", "stream_id"
|
db_conn, "receipts_linearized", "stream_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._receipts_stream_cache = StreamChangeCache(
|
super(SlavedReceiptsStore, self).__init__(db_conn, hs)
|
||||||
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
|
||||||
)
|
|
||||||
|
|
||||||
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
|
def get_max_receipt_stream_id(self):
|
||||||
get_linearized_receipts_for_room = (
|
return self._receipts_id_gen.get_current_token()
|
||||||
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
|
|
||||||
)
|
|
||||||
_get_linearized_receipts_for_rooms = (
|
|
||||||
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
|
|
||||||
)
|
|
||||||
get_last_receipt_event_id_for_user = (
|
|
||||||
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
|
|
||||||
)
|
|
||||||
|
|
||||||
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
|
|
||||||
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
|
|
||||||
|
|
||||||
get_linearized_receipts_for_rooms = (
|
|
||||||
DataStore.get_linearized_receipts_for_rooms.__func__
|
|
||||||
)
|
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedReceiptsStore, self).stream_positions()
|
result = super(SlavedReceiptsStore, self).stream_positions()
|
||||||
|
@ -71,6 +53,8 @@ class SlavedReceiptsStore(BaseSlavedStore):
|
||||||
self.get_last_receipt_event_id_for_user.invalidate(
|
self.get_last_receipt_event_id_for_user.invalidate(
|
||||||
(user_id, room_id, receipt_type)
|
(user_id, room_id, receipt_type)
|
||||||
)
|
)
|
||||||
|
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
|
||||||
|
self.get_receipts_for_room.invalidate((room_id, receipt_type))
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "receipts":
|
if stream_name == "receipts":
|
||||||
|
|
|
@ -14,20 +14,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from synapse.storage import DataStore
|
from synapse.storage.registration import RegistrationWorkerStore
|
||||||
from synapse.storage.registration import RegistrationStore
|
|
||||||
|
|
||||||
|
|
||||||
class SlavedRegistrationStore(BaseSlavedStore):
|
class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
|
||||||
def __init__(self, db_conn, hs):
|
pass
|
||||||
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
|
|
||||||
|
|
||||||
# TODO: use the cached version and invalidate deleted tokens
|
|
||||||
get_user_by_access_token = RegistrationStore.__dict__[
|
|
||||||
"get_user_by_access_token"
|
|
||||||
]
|
|
||||||
|
|
||||||
_query_for_auth = DataStore._query_for_auth.__func__
|
|
||||||
get_user_by_id = RegistrationStore.__dict__[
|
|
||||||
"get_user_by_id"
|
|
||||||
]
|
|
||||||
|
|
|
@ -14,32 +14,19 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from synapse.storage import DataStore
|
from synapse.storage.room import RoomWorkerStore
|
||||||
from synapse.storage.room import RoomStore
|
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
|
|
||||||
class RoomStore(BaseSlavedStore):
|
class RoomStore(RoomWorkerStore, BaseSlavedStore):
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(RoomStore, self).__init__(db_conn, hs)
|
super(RoomStore, self).__init__(db_conn, hs)
|
||||||
self._public_room_id_gen = SlavedIdTracker(
|
self._public_room_id_gen = SlavedIdTracker(
|
||||||
db_conn, "public_room_list_stream", "stream_id"
|
db_conn, "public_room_list_stream", "stream_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
get_public_room_ids = DataStore.get_public_room_ids.__func__
|
def get_current_public_room_stream_id(self):
|
||||||
get_current_public_room_stream_id = (
|
return self._public_room_id_gen.get_current_token()
|
||||||
DataStore.get_current_public_room_stream_id.__func__
|
|
||||||
)
|
|
||||||
get_public_room_ids_at_stream_id = (
|
|
||||||
RoomStore.__dict__["get_public_room_ids_at_stream_id"]
|
|
||||||
)
|
|
||||||
get_public_room_ids_at_stream_id_txn = (
|
|
||||||
DataStore.get_public_room_ids_at_stream_id_txn.__func__
|
|
||||||
)
|
|
||||||
get_published_at_stream_id_txn = (
|
|
||||||
DataStore.get_published_at_stream_id_txn.__func__
|
|
||||||
)
|
|
||||||
get_public_room_changes = DataStore.get_public_room_changes.__func__
|
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(RoomStore, self).stream_positions()
|
result = super(RoomStore, self).stream_positions()
|
||||||
|
|
|
@ -19,11 +19,13 @@ allowed to be sent by which side.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import ujson as json
|
import simplejson
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False)
|
||||||
|
|
||||||
|
|
||||||
class Command(object):
|
class Command(object):
|
||||||
"""The base command class.
|
"""The base command class.
|
||||||
|
@ -100,14 +102,14 @@ class RdataCommand(Command):
|
||||||
return cls(
|
return cls(
|
||||||
stream_name,
|
stream_name,
|
||||||
None if token == "batch" else int(token),
|
None if token == "batch" else int(token),
|
||||||
json.loads(row_json)
|
simplejson.loads(row_json)
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_line(self):
|
def to_line(self):
|
||||||
return " ".join((
|
return " ".join((
|
||||||
self.stream_name,
|
self.stream_name,
|
||||||
str(self.token) if self.token is not None else "batch",
|
str(self.token) if self.token is not None else "batch",
|
||||||
json.dumps(self.row),
|
_json_encoder.encode(self.row),
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
|
@ -298,10 +300,12 @@ class InvalidateCacheCommand(Command):
|
||||||
def from_line(cls, line):
|
def from_line(cls, line):
|
||||||
cache_func, keys_json = line.split(" ", 1)
|
cache_func, keys_json = line.split(" ", 1)
|
||||||
|
|
||||||
return cls(cache_func, json.loads(keys_json))
|
return cls(cache_func, simplejson.loads(keys_json))
|
||||||
|
|
||||||
def to_line(self):
|
def to_line(self):
|
||||||
return " ".join((self.cache_func, json.dumps(self.keys)))
|
return " ".join((
|
||||||
|
self.cache_func, _json_encoder.encode(self.keys),
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
class UserIpCommand(Command):
|
class UserIpCommand(Command):
|
||||||
|
@ -325,14 +329,14 @@ class UserIpCommand(Command):
|
||||||
def from_line(cls, line):
|
def from_line(cls, line):
|
||||||
user_id, jsn = line.split(" ", 1)
|
user_id, jsn = line.split(" ", 1)
|
||||||
|
|
||||||
access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
|
access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
user_id, access_token, ip, user_agent, device_id, last_seen
|
user_id, access_token, ip, user_agent, device_id, last_seen
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_line(self):
|
def to_line(self):
|
||||||
return self.user_id + " " + json.dumps((
|
return self.user_id + " " + _json_encoder.encode((
|
||||||
self.access_token, self.ip, self.user_agent, self.device_id,
|
self.access_token, self.ip, self.user_agent, self.device_id,
|
||||||
self.last_seen,
|
self.last_seen,
|
||||||
))
|
))
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
from synapse.api.errors import AuthError, SynapseError
|
from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
|
||||||
|
@ -114,12 +114,18 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
class PurgeHistoryRestServlet(ClientV1RestServlet):
|
class PurgeHistoryRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns(
|
PATTERNS = client_path_patterns(
|
||||||
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
|
"/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer)
|
||||||
|
"""
|
||||||
super(PurgeHistoryRestServlet, self).__init__(hs)
|
super(PurgeHistoryRestServlet, self).__init__(hs)
|
||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_id, event_id):
|
def on_POST(self, request, room_id, event_id):
|
||||||
|
@ -133,12 +139,89 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
delete_local_events = bool(body.get("delete_local_events", False))
|
delete_local_events = bool(body.get("delete_local_events", False))
|
||||||
|
|
||||||
yield self.handlers.message_handler.purge_history(
|
# establish the topological ordering we should keep events from. The
|
||||||
room_id, event_id,
|
# user can provide an event_id in the URL or the request body, or can
|
||||||
|
# provide a timestamp in the request body.
|
||||||
|
if event_id is None:
|
||||||
|
event_id = body.get('purge_up_to_event_id')
|
||||||
|
|
||||||
|
if event_id is not None:
|
||||||
|
event = yield self.store.get_event(event_id)
|
||||||
|
|
||||||
|
if event.room_id != room_id:
|
||||||
|
raise SynapseError(400, "Event is for wrong room.")
|
||||||
|
|
||||||
|
depth = event.depth
|
||||||
|
logger.info(
|
||||||
|
"[purge] purging up to depth %i (event_id %s)",
|
||||||
|
depth, event_id,
|
||||||
|
)
|
||||||
|
elif 'purge_up_to_ts' in body:
|
||||||
|
ts = body['purge_up_to_ts']
|
||||||
|
if not isinstance(ts, int):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "purge_up_to_ts must be an int",
|
||||||
|
errcode=Codes.BAD_JSON,
|
||||||
|
)
|
||||||
|
|
||||||
|
stream_ordering = (
|
||||||
|
yield self.store.find_first_stream_ordering_after_ts(ts)
|
||||||
|
)
|
||||||
|
|
||||||
|
(_, depth, _) = (
|
||||||
|
yield self.store.get_room_event_after_stream_ordering(
|
||||||
|
room_id, stream_ordering,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"[purge] purging up to depth %i (received_ts %i => "
|
||||||
|
"stream_ordering %i)",
|
||||||
|
depth, ts, stream_ordering,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"must specify purge_up_to_event_id or purge_up_to_ts",
|
||||||
|
errcode=Codes.BAD_JSON,
|
||||||
|
)
|
||||||
|
|
||||||
|
purge_id = yield self.handlers.message_handler.start_purge_history(
|
||||||
|
room_id, depth,
|
||||||
delete_local_events=delete_local_events,
|
delete_local_events=delete_local_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {
|
||||||
|
"purge_id": purge_id,
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
|
||||||
|
PATTERNS = client_path_patterns(
|
||||||
|
"/admin/purge_history_status/(?P<purge_id>[^/]+)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer)
|
||||||
|
"""
|
||||||
|
super(PurgeHistoryStatusRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, purge_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||||
|
|
||||||
|
if not is_admin:
|
||||||
|
raise AuthError(403, "You are not a server admin")
|
||||||
|
|
||||||
|
purge_status = self.handlers.message_handler.get_purge_status(purge_id)
|
||||||
|
if purge_status is None:
|
||||||
|
raise NotFoundError("purge id '%s' not found" % purge_id)
|
||||||
|
|
||||||
|
defer.returnValue((200, purge_status.asdict()))
|
||||||
|
|
||||||
|
|
||||||
class DeactivateAccountRestServlet(ClientV1RestServlet):
|
class DeactivateAccountRestServlet(ClientV1RestServlet):
|
||||||
|
@ -180,6 +263,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
|
||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_id):
|
def on_POST(self, request, room_id):
|
||||||
|
@ -238,7 +322,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
|
||||||
logger.info("Kicking %r from %r...", user_id, room_id)
|
logger.info("Kicking %r from %r...", user_id, room_id)
|
||||||
|
|
||||||
target_requester = create_requester(user_id)
|
target_requester = create_requester(user_id)
|
||||||
yield self.handlers.room_member_handler.update_membership(
|
yield self.room_member_handler.update_membership(
|
||||||
requester=target_requester,
|
requester=target_requester,
|
||||||
target=target_requester.user,
|
target=target_requester.user,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
@ -247,9 +331,9 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
|
||||||
ratelimit=False
|
ratelimit=False
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.handlers.room_member_handler.forget(target_requester.user, room_id)
|
yield self.room_member_handler.forget(target_requester.user, room_id)
|
||||||
|
|
||||||
yield self.handlers.room_member_handler.update_membership(
|
yield self.room_member_handler.update_membership(
|
||||||
requester=target_requester,
|
requester=target_requester,
|
||||||
target=target_requester.user,
|
target=target_requester.user,
|
||||||
room_id=new_room_id,
|
room_id=new_room_id,
|
||||||
|
@ -508,6 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
WhoisRestServlet(hs).register(http_server)
|
WhoisRestServlet(hs).register(http_server)
|
||||||
PurgeMediaCacheRestServlet(hs).register(http_server)
|
PurgeMediaCacheRestServlet(hs).register(http_server)
|
||||||
|
PurgeHistoryStatusRestServlet(hs).register(http_server)
|
||||||
DeactivateAccountRestServlet(hs).register(http_server)
|
DeactivateAccountRestServlet(hs).register(http_server)
|
||||||
PurgeHistoryRestServlet(hs).register(http_server)
|
PurgeHistoryRestServlet(hs).register(http_server)
|
||||||
UsersRestServlet(hs).register(http_server)
|
UsersRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -30,7 +30,7 @@ from synapse.http.servlet import (
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -84,6 +84,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||||
super(RoomStateEventRestServlet, self).__init__(hs)
|
super(RoomStateEventRestServlet, self).__init__(hs)
|
||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
self.event_creation_hander = hs.get_event_creation_handler()
|
self.event_creation_hander = hs.get_event_creation_handler()
|
||||||
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server):
|
||||||
# /room/$roomid/state/$eventtype
|
# /room/$roomid/state/$eventtype
|
||||||
|
@ -156,7 +157,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
if event_type == EventTypes.Member:
|
if event_type == EventTypes.Member:
|
||||||
membership = content.get("membership", None)
|
membership = content.get("membership", None)
|
||||||
event = yield self.handlers.room_member_handler.update_membership(
|
event = yield self.room_member_handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
target=UserID.from_string(state_key),
|
target=UserID.from_string(state_key),
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
@ -229,7 +230,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
||||||
class JoinRoomAliasServlet(ClientV1RestServlet):
|
class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(JoinRoomAliasServlet, self).__init__(hs)
|
super(JoinRoomAliasServlet, self).__init__(hs)
|
||||||
self.handlers = hs.get_handlers()
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server):
|
||||||
# /join/$room_identifier[/$txn_id]
|
# /join/$room_identifier[/$txn_id]
|
||||||
|
@ -257,7 +258,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||||
except Exception:
|
except Exception:
|
||||||
remote_room_hosts = None
|
remote_room_hosts = None
|
||||||
elif RoomAlias.is_valid(room_identifier):
|
elif RoomAlias.is_valid(room_identifier):
|
||||||
handler = self.handlers.room_member_handler
|
handler = self.room_member_handler
|
||||||
room_alias = RoomAlias.from_string(room_identifier)
|
room_alias = RoomAlias.from_string(room_identifier)
|
||||||
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
|
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
|
||||||
room_id = room_id.to_string()
|
room_id = room_id.to_string()
|
||||||
|
@ -266,7 +267,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||||
room_identifier,
|
room_identifier,
|
||||||
))
|
))
|
||||||
|
|
||||||
yield self.handlers.room_member_handler.update_membership(
|
yield self.room_member_handler.update_membership(
|
||||||
requester=requester,
|
requester=requester,
|
||||||
target=requester.user,
|
target=requester.user,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
@ -562,7 +563,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
|
||||||
class RoomForgetRestServlet(ClientV1RestServlet):
|
class RoomForgetRestServlet(ClientV1RestServlet):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RoomForgetRestServlet, self).__init__(hs)
|
super(RoomForgetRestServlet, self).__init__(hs)
|
||||||
self.handlers = hs.get_handlers()
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server):
|
||||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
|
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
|
||||||
|
@ -575,7 +576,7 @@ class RoomForgetRestServlet(ClientV1RestServlet):
|
||||||
allow_guest=False,
|
allow_guest=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.handlers.room_member_handler.forget(
|
yield self.room_member_handler.forget(
|
||||||
user=requester.user,
|
user=requester.user,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
)
|
)
|
||||||
|
@ -593,12 +594,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RoomMembershipRestServlet, self).__init__(hs)
|
super(RoomMembershipRestServlet, self).__init__(hs)
|
||||||
self.handlers = hs.get_handlers()
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server):
|
||||||
# /rooms/$roomid/[invite|join|leave]
|
# /rooms/$roomid/[invite|join|leave]
|
||||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
|
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
|
||||||
"(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
|
"(?P<membership_action>join|invite|leave|ban|unban|kick)")
|
||||||
register_txn_path(self, PATTERNS, http_server)
|
register_txn_path(self, PATTERNS, http_server)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -622,7 +623,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||||
content = {}
|
content = {}
|
||||||
|
|
||||||
if membership_action == "invite" and self._has_3pid_invite_keys(content):
|
if membership_action == "invite" and self._has_3pid_invite_keys(content):
|
||||||
yield self.handlers.room_member_handler.do_3pid_invite(
|
yield self.room_member_handler.do_3pid_invite(
|
||||||
room_id,
|
room_id,
|
||||||
requester.user,
|
requester.user,
|
||||||
content["medium"],
|
content["medium"],
|
||||||
|
@ -644,7 +645,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||||
if 'reason' in content and membership_action in ['kick', 'ban']:
|
if 'reason' in content and membership_action in ['kick', 'ban']:
|
||||||
event_content = {'reason': content['reason']}
|
event_content = {'reason': content['reason']}
|
||||||
|
|
||||||
yield self.handlers.room_member_handler.update_membership(
|
yield self.room_member_handler.update_membership(
|
||||||
requester=requester,
|
requester=requester,
|
||||||
target=target,
|
target=target,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
|
|
@ -183,7 +183,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.registration_handler = hs.get_handlers().registration_handler
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
self.room_member_handler = hs.get_handlers().room_member_handler
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ from ._base import set_timeline_upper_limit
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -58,23 +58,13 @@ class MediaStorage(object):
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[str]: the file path written to in the primary media store
|
Deferred[str]: the file path written to in the primary media store
|
||||||
"""
|
"""
|
||||||
path = self._file_info_to_path(file_info)
|
|
||||||
fname = os.path.join(self.local_media_directory, path)
|
|
||||||
|
|
||||||
dirname = os.path.dirname(fname)
|
with self.store_into_file(file_info) as (f, fname, finish_cb):
|
||||||
if not os.path.exists(dirname):
|
# Write to the main repository
|
||||||
os.makedirs(dirname)
|
yield make_deferred_yieldable(threads.deferToThread(
|
||||||
|
_write_file_synchronously, source, f,
|
||||||
# Write to the main repository
|
))
|
||||||
yield make_deferred_yieldable(threads.deferToThread(
|
yield finish_cb()
|
||||||
_write_file_synchronously, source, fname,
|
|
||||||
))
|
|
||||||
|
|
||||||
# Tell the storage providers about the new file. They'll decide
|
|
||||||
# if they should upload it and whether to do so synchronously
|
|
||||||
# or not.
|
|
||||||
for provider in self.storage_providers:
|
|
||||||
yield provider.store_file(path, file_info)
|
|
||||||
|
|
||||||
defer.returnValue(fname)
|
defer.returnValue(fname)
|
||||||
|
|
||||||
|
@ -240,21 +230,16 @@ class MediaStorage(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _write_file_synchronously(source, fname):
|
def _write_file_synchronously(source, dest):
|
||||||
"""Write `source` to the path `fname` synchronously. Should be called
|
"""Write `source` to the file like `dest` synchronously. Should be called
|
||||||
from a thread.
|
from a thread.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source: A file like object to be written
|
source: A file like object that's to be written
|
||||||
fname (str): Path to write to
|
dest: A file like object to be written to
|
||||||
"""
|
"""
|
||||||
dirname = os.path.dirname(fname)
|
|
||||||
if not os.path.exists(dirname):
|
|
||||||
os.makedirs(dirname)
|
|
||||||
|
|
||||||
source.seek(0) # Ensure we read from the start of the file
|
source.seek(0) # Ensure we read from the start of the file
|
||||||
with open(fname, "wb") as f:
|
shutil.copyfileobj(source, dest)
|
||||||
shutil.copyfileobj(source, f)
|
|
||||||
|
|
||||||
|
|
||||||
class FileResponder(Responder):
|
class FileResponder(Responder):
|
||||||
|
|
|
@ -23,7 +23,7 @@ import re
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
import urlparse
|
import urlparse
|
||||||
|
|
||||||
from twisted.web.server import NOT_DONE_YET
|
from twisted.web.server import NOT_DONE_YET
|
||||||
|
|
|
@ -32,8 +32,10 @@ from synapse.appservice.scheduler import ApplicationServiceScheduler
|
||||||
from synapse.crypto.keyring import Keyring
|
from synapse.crypto.keyring import Keyring
|
||||||
from synapse.events.builder import EventBuilderFactory
|
from synapse.events.builder import EventBuilderFactory
|
||||||
from synapse.events.spamcheck import SpamChecker
|
from synapse.events.spamcheck import SpamChecker
|
||||||
from synapse.federation import initialize_http_replication
|
from synapse.federation.federation_client import FederationClient
|
||||||
|
from synapse.federation.federation_server import FederationServer
|
||||||
from synapse.federation.send_queue import FederationRemoteSendQueue
|
from synapse.federation.send_queue import FederationRemoteSendQueue
|
||||||
|
from synapse.federation.federation_server import FederationHandlerRegistry
|
||||||
from synapse.federation.transport.client import TransportLayerClient
|
from synapse.federation.transport.client import TransportLayerClient
|
||||||
from synapse.federation.transaction_queue import TransactionQueue
|
from synapse.federation.transaction_queue import TransactionQueue
|
||||||
from synapse.handlers import Handlers
|
from synapse.handlers import Handlers
|
||||||
|
@ -45,6 +47,8 @@ from synapse.handlers.device import DeviceHandler
|
||||||
from synapse.handlers.e2e_keys import E2eKeysHandler
|
from synapse.handlers.e2e_keys import E2eKeysHandler
|
||||||
from synapse.handlers.presence import PresenceHandler
|
from synapse.handlers.presence import PresenceHandler
|
||||||
from synapse.handlers.room_list import RoomListHandler
|
from synapse.handlers.room_list import RoomListHandler
|
||||||
|
from synapse.handlers.room_member import RoomMemberMasterHandler
|
||||||
|
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
|
||||||
from synapse.handlers.set_password import SetPasswordHandler
|
from synapse.handlers.set_password import SetPasswordHandler
|
||||||
from synapse.handlers.sync import SyncHandler
|
from synapse.handlers.sync import SyncHandler
|
||||||
from synapse.handlers.typing import TypingHandler
|
from synapse.handlers.typing import TypingHandler
|
||||||
|
@ -98,7 +102,8 @@ class HomeServer(object):
|
||||||
DEPENDENCIES = [
|
DEPENDENCIES = [
|
||||||
'http_client',
|
'http_client',
|
||||||
'db_pool',
|
'db_pool',
|
||||||
'replication_layer',
|
'federation_client',
|
||||||
|
'federation_server',
|
||||||
'handlers',
|
'handlers',
|
||||||
'v1auth',
|
'v1auth',
|
||||||
'auth',
|
'auth',
|
||||||
|
@ -145,6 +150,8 @@ class HomeServer(object):
|
||||||
'groups_attestation_signing',
|
'groups_attestation_signing',
|
||||||
'groups_attestation_renewer',
|
'groups_attestation_renewer',
|
||||||
'spam_checker',
|
'spam_checker',
|
||||||
|
'room_member_handler',
|
||||||
|
'federation_registry',
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, hostname, **kwargs):
|
def __init__(self, hostname, **kwargs):
|
||||||
|
@ -193,8 +200,11 @@ class HomeServer(object):
|
||||||
def get_ratelimiter(self):
|
def get_ratelimiter(self):
|
||||||
return self.ratelimiter
|
return self.ratelimiter
|
||||||
|
|
||||||
def build_replication_layer(self):
|
def build_federation_client(self):
|
||||||
return initialize_http_replication(self)
|
return FederationClient(self)
|
||||||
|
|
||||||
|
def build_federation_server(self):
|
||||||
|
return FederationServer(self)
|
||||||
|
|
||||||
def build_handlers(self):
|
def build_handlers(self):
|
||||||
return Handlers(self)
|
return Handlers(self)
|
||||||
|
@ -382,6 +392,14 @@ class HomeServer(object):
|
||||||
def build_spam_checker(self):
|
def build_spam_checker(self):
|
||||||
return SpamChecker(self)
|
return SpamChecker(self)
|
||||||
|
|
||||||
|
def build_room_member_handler(self):
|
||||||
|
if self.config.worker_app:
|
||||||
|
return RoomMemberWorkerHandler(self)
|
||||||
|
return RoomMemberMasterHandler(self)
|
||||||
|
|
||||||
|
def build_federation_registry(self):
|
||||||
|
return FederationHandlerRegistry()
|
||||||
|
|
||||||
def remove_pusher(self, app_id, push_key, user_id):
|
def remove_pusher(self, app_id, push_key, user_id):
|
||||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||||
|
|
||||||
|
|
|
@ -132,7 +132,7 @@ class StateHandler(object):
|
||||||
|
|
||||||
state_map = yield self.store.get_events(state.values(), get_prev_content=False)
|
state_map = yield self.store.get_events(state.values(), get_prev_content=False)
|
||||||
state = {
|
state = {
|
||||||
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
|
key: state_map[e_id] for key, e_id in state.iteritems() if e_id in state_map
|
||||||
}
|
}
|
||||||
|
|
||||||
defer.returnValue(state)
|
defer.returnValue(state)
|
||||||
|
@ -378,7 +378,7 @@ class StateHandler(object):
|
||||||
new_state = resolve_events_with_state_map(state_set_ids, state_map)
|
new_state = resolve_events_with_state_map(state_set_ids, state_map)
|
||||||
|
|
||||||
new_state = {
|
new_state = {
|
||||||
key: state_map[ev_id] for key, ev_id in new_state.items()
|
key: state_map[ev_id] for key, ev_id in new_state.iteritems()
|
||||||
}
|
}
|
||||||
|
|
||||||
return new_state
|
return new_state
|
||||||
|
@ -458,15 +458,15 @@ class StateResolutionHandler(object):
|
||||||
# build a map from state key to the event_ids which set that state.
|
# build a map from state key to the event_ids which set that state.
|
||||||
# dict[(str, str), set[str])
|
# dict[(str, str), set[str])
|
||||||
state = {}
|
state = {}
|
||||||
for st in state_groups_ids.values():
|
for st in state_groups_ids.itervalues():
|
||||||
for key, e_id in st.items():
|
for key, e_id in st.iteritems():
|
||||||
state.setdefault(key, set()).add(e_id)
|
state.setdefault(key, set()).add(e_id)
|
||||||
|
|
||||||
# build a map from state key to the event_ids which set that state,
|
# build a map from state key to the event_ids which set that state,
|
||||||
# including only those where there are state keys in conflict.
|
# including only those where there are state keys in conflict.
|
||||||
conflicted_state = {
|
conflicted_state = {
|
||||||
k: list(v)
|
k: list(v)
|
||||||
for k, v in state.items()
|
for k, v in state.iteritems()
|
||||||
if len(v) > 1
|
if len(v) > 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -480,36 +480,37 @@ class StateResolutionHandler(object):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
new_state = {
|
new_state = {
|
||||||
key: e_ids.pop() for key, e_ids in state.items()
|
key: e_ids.pop() for key, e_ids in state.iteritems()
|
||||||
}
|
}
|
||||||
|
|
||||||
# if the new state matches any of the input state groups, we can
|
with Measure(self.clock, "state.create_group_ids"):
|
||||||
# use that state group again. Otherwise we will generate a state_id
|
# if the new state matches any of the input state groups, we can
|
||||||
# which will be used as a cache key for future resolutions, but
|
# use that state group again. Otherwise we will generate a state_id
|
||||||
# not get persisted.
|
# which will be used as a cache key for future resolutions, but
|
||||||
state_group = None
|
# not get persisted.
|
||||||
new_state_event_ids = frozenset(new_state.values())
|
state_group = None
|
||||||
for sg, events in state_groups_ids.items():
|
new_state_event_ids = frozenset(new_state.itervalues())
|
||||||
if new_state_event_ids == frozenset(e_id for e_id in events):
|
for sg, events in state_groups_ids.iteritems():
|
||||||
state_group = sg
|
if new_state_event_ids == frozenset(e_id for e_id in events):
|
||||||
break
|
state_group = sg
|
||||||
|
break
|
||||||
|
|
||||||
# TODO: We want to create a state group for this set of events, to
|
# TODO: We want to create a state group for this set of events, to
|
||||||
# increase cache hits, but we need to make sure that it doesn't
|
# increase cache hits, but we need to make sure that it doesn't
|
||||||
# end up as a prev_group without being added to the database
|
# end up as a prev_group without being added to the database
|
||||||
|
|
||||||
prev_group = None
|
prev_group = None
|
||||||
delta_ids = None
|
delta_ids = None
|
||||||
for old_group, old_ids in state_groups_ids.iteritems():
|
for old_group, old_ids in state_groups_ids.iteritems():
|
||||||
if not set(new_state) - set(old_ids):
|
if not set(new_state) - set(old_ids):
|
||||||
n_delta_ids = {
|
n_delta_ids = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in new_state.iteritems()
|
for k, v in new_state.iteritems()
|
||||||
if old_ids.get(k) != v
|
if old_ids.get(k) != v
|
||||||
}
|
}
|
||||||
if not delta_ids or len(n_delta_ids) < len(delta_ids):
|
if not delta_ids or len(n_delta_ids) < len(delta_ids):
|
||||||
prev_group = old_group
|
prev_group = old_group
|
||||||
delta_ids = n_delta_ids
|
delta_ids = n_delta_ids
|
||||||
|
|
||||||
cache = _StateCacheEntry(
|
cache = _StateCacheEntry(
|
||||||
state=new_state,
|
state=new_state,
|
||||||
|
@ -702,7 +703,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
|
||||||
|
|
||||||
auth_events = {
|
auth_events = {
|
||||||
key: state_map[ev_id]
|
key: state_map[ev_id]
|
||||||
for key, ev_id in auth_event_ids.items()
|
for key, ev_id in auth_event_ids.iteritems()
|
||||||
if ev_id in state_map
|
if ev_id in state_map
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -740,7 +741,7 @@ def _resolve_state_events(conflicted_state, auth_events):
|
||||||
|
|
||||||
auth_events.update(resolved_state)
|
auth_events.update(resolved_state)
|
||||||
|
|
||||||
for key, events in conflicted_state.items():
|
for key, events in conflicted_state.iteritems():
|
||||||
if key[0] == EventTypes.JoinRules:
|
if key[0] == EventTypes.JoinRules:
|
||||||
logger.debug("Resolving conflicted join rules %r", events)
|
logger.debug("Resolving conflicted join rules %r", events)
|
||||||
resolved_state[key] = _resolve_auth_events(
|
resolved_state[key] = _resolve_auth_events(
|
||||||
|
@ -750,7 +751,7 @@ def _resolve_state_events(conflicted_state, auth_events):
|
||||||
|
|
||||||
auth_events.update(resolved_state)
|
auth_events.update(resolved_state)
|
||||||
|
|
||||||
for key, events in conflicted_state.items():
|
for key, events in conflicted_state.iteritems():
|
||||||
if key[0] == EventTypes.Member:
|
if key[0] == EventTypes.Member:
|
||||||
logger.debug("Resolving conflicted member lists %r", events)
|
logger.debug("Resolving conflicted member lists %r", events)
|
||||||
resolved_state[key] = _resolve_auth_events(
|
resolved_state[key] = _resolve_auth_events(
|
||||||
|
@ -760,7 +761,7 @@ def _resolve_state_events(conflicted_state, auth_events):
|
||||||
|
|
||||||
auth_events.update(resolved_state)
|
auth_events.update(resolved_state)
|
||||||
|
|
||||||
for key, events in conflicted_state.items():
|
for key, events in conflicted_state.iteritems():
|
||||||
if key not in resolved_state:
|
if key not in resolved_state:
|
||||||
logger.debug("Resolving conflicted state %r:%r", key, events)
|
logger.debug("Resolving conflicted state %r:%r", key, events)
|
||||||
resolved_state[key] = _resolve_normal_events(
|
resolved_state[key] = _resolve_normal_events(
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -19,7 +20,6 @@ from synapse.storage.devices import DeviceStore
|
||||||
from .appservice import (
|
from .appservice import (
|
||||||
ApplicationServiceStore, ApplicationServiceTransactionStore
|
ApplicationServiceStore, ApplicationServiceTransactionStore
|
||||||
)
|
)
|
||||||
from ._base import LoggingTransaction
|
|
||||||
from .directory import DirectoryStore
|
from .directory import DirectoryStore
|
||||||
from .events import EventsStore
|
from .events import EventsStore
|
||||||
from .presence import PresenceStore, UserPresenceState
|
from .presence import PresenceStore, UserPresenceState
|
||||||
|
@ -104,12 +104,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
db_conn, "events", "stream_ordering", step=-1,
|
db_conn, "events", "stream_ordering", step=-1,
|
||||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
|
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
|
||||||
)
|
)
|
||||||
self._receipts_id_gen = StreamIdGenerator(
|
|
||||||
db_conn, "receipts_linearized", "stream_id"
|
|
||||||
)
|
|
||||||
self._account_data_id_gen = StreamIdGenerator(
|
|
||||||
db_conn, "account_data_max_stream_id", "stream_id"
|
|
||||||
)
|
|
||||||
self._presence_id_gen = StreamIdGenerator(
|
self._presence_id_gen = StreamIdGenerator(
|
||||||
db_conn, "presence_stream", "stream_id"
|
db_conn, "presence_stream", "stream_id"
|
||||||
)
|
)
|
||||||
|
@ -146,27 +140,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
else:
|
else:
|
||||||
self._cache_id_gen = None
|
self._cache_id_gen = None
|
||||||
|
|
||||||
events_max = self._stream_id_gen.get_current_token()
|
|
||||||
event_cache_prefill, min_event_val = self._get_cache_dict(
|
|
||||||
db_conn, "events",
|
|
||||||
entity_column="room_id",
|
|
||||||
stream_column="stream_ordering",
|
|
||||||
max_value=events_max,
|
|
||||||
)
|
|
||||||
self._events_stream_cache = StreamChangeCache(
|
|
||||||
"EventsRoomStreamChangeCache", min_event_val,
|
|
||||||
prefilled_cache=event_cache_prefill,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._membership_stream_cache = StreamChangeCache(
|
|
||||||
"MembershipStreamChangeCache", events_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
account_max = self._account_data_id_gen.get_current_token()
|
|
||||||
self._account_data_stream_cache = StreamChangeCache(
|
|
||||||
"AccountDataAndTagsChangeCache", account_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||||
|
|
||||||
presence_cache_prefill, min_presence_val = self._get_cache_dict(
|
presence_cache_prefill, min_presence_val = self._get_cache_dict(
|
||||||
|
@ -180,18 +153,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
prefilled_cache=presence_cache_prefill
|
prefilled_cache=presence_cache_prefill
|
||||||
)
|
)
|
||||||
|
|
||||||
push_rules_prefill, push_rules_id = self._get_cache_dict(
|
|
||||||
db_conn, "push_rules_stream",
|
|
||||||
entity_column="user_id",
|
|
||||||
stream_column="stream_id",
|
|
||||||
max_value=self._push_rules_stream_id_gen.get_current_token()[0],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.push_rules_stream_cache = StreamChangeCache(
|
|
||||||
"PushRulesStreamChangeCache", push_rules_id,
|
|
||||||
prefilled_cache=push_rules_prefill,
|
|
||||||
)
|
|
||||||
|
|
||||||
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
|
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
|
||||||
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
|
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
|
||||||
db_conn, "device_inbox",
|
db_conn, "device_inbox",
|
||||||
|
@ -226,6 +187,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
"DeviceListFederationStreamChangeCache", device_list_max,
|
"DeviceListFederationStreamChangeCache", device_list_max,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
events_max = self._stream_id_gen.get_current_token()
|
||||||
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
|
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
|
||||||
db_conn, "current_state_delta_stream",
|
db_conn, "current_state_delta_stream",
|
||||||
entity_column="room_id",
|
entity_column="room_id",
|
||||||
|
@ -250,20 +212,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
prefilled_cache=_group_updates_prefill,
|
prefilled_cache=_group_updates_prefill,
|
||||||
)
|
)
|
||||||
|
|
||||||
cur = LoggingTransaction(
|
|
||||||
db_conn.cursor(),
|
|
||||||
name="_find_stream_orderings_for_times_txn",
|
|
||||||
database_engine=self.database_engine,
|
|
||||||
after_callbacks=[],
|
|
||||||
final_callbacks=[],
|
|
||||||
)
|
|
||||||
self._find_stream_orderings_for_times_txn(cur)
|
|
||||||
cur.close()
|
|
||||||
|
|
||||||
self.find_stream_orderings_looping_call = self._clock.looping_call(
|
|
||||||
self._find_stream_orderings_for_times, 10 * 60 * 1000
|
|
||||||
)
|
|
||||||
|
|
||||||
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
||||||
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
|
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
|
||||||
|
|
||||||
|
|
|
@ -48,16 +48,16 @@ class LoggingTransaction(object):
|
||||||
passed to the constructor. Adds logging and metrics to the .execute()
|
passed to the constructor. Adds logging and metrics to the .execute()
|
||||||
method."""
|
method."""
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"txn", "name", "database_engine", "after_callbacks", "final_callbacks",
|
"txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, txn, name, database_engine, after_callbacks,
|
def __init__(self, txn, name, database_engine, after_callbacks,
|
||||||
final_callbacks):
|
exception_callbacks):
|
||||||
object.__setattr__(self, "txn", txn)
|
object.__setattr__(self, "txn", txn)
|
||||||
object.__setattr__(self, "name", name)
|
object.__setattr__(self, "name", name)
|
||||||
object.__setattr__(self, "database_engine", database_engine)
|
object.__setattr__(self, "database_engine", database_engine)
|
||||||
object.__setattr__(self, "after_callbacks", after_callbacks)
|
object.__setattr__(self, "after_callbacks", after_callbacks)
|
||||||
object.__setattr__(self, "final_callbacks", final_callbacks)
|
object.__setattr__(self, "exception_callbacks", exception_callbacks)
|
||||||
|
|
||||||
def call_after(self, callback, *args, **kwargs):
|
def call_after(self, callback, *args, **kwargs):
|
||||||
"""Call the given callback on the main twisted thread after the
|
"""Call the given callback on the main twisted thread after the
|
||||||
|
@ -66,8 +66,8 @@ class LoggingTransaction(object):
|
||||||
"""
|
"""
|
||||||
self.after_callbacks.append((callback, args, kwargs))
|
self.after_callbacks.append((callback, args, kwargs))
|
||||||
|
|
||||||
def call_finally(self, callback, *args, **kwargs):
|
def call_on_exception(self, callback, *args, **kwargs):
|
||||||
self.final_callbacks.append((callback, args, kwargs))
|
self.exception_callbacks.append((callback, args, kwargs))
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return getattr(self.txn, name)
|
return getattr(self.txn, name)
|
||||||
|
@ -215,7 +215,7 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
self._clock.looping_call(loop, 10000)
|
self._clock.looping_call(loop, 10000)
|
||||||
|
|
||||||
def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
|
def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
|
||||||
logging_context, func, *args, **kwargs):
|
logging_context, func, *args, **kwargs):
|
||||||
start = time.time() * 1000
|
start = time.time() * 1000
|
||||||
txn_id = self._TXN_ID
|
txn_id = self._TXN_ID
|
||||||
|
@ -236,7 +236,7 @@ class SQLBaseStore(object):
|
||||||
txn = conn.cursor()
|
txn = conn.cursor()
|
||||||
txn = LoggingTransaction(
|
txn = LoggingTransaction(
|
||||||
txn, name, self.database_engine, after_callbacks,
|
txn, name, self.database_engine, after_callbacks,
|
||||||
final_callbacks,
|
exception_callbacks,
|
||||||
)
|
)
|
||||||
r = func(txn, *args, **kwargs)
|
r = func(txn, *args, **kwargs)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
@ -308,11 +308,11 @@ class SQLBaseStore(object):
|
||||||
current_context = LoggingContext.current_context()
|
current_context = LoggingContext.current_context()
|
||||||
|
|
||||||
after_callbacks = []
|
after_callbacks = []
|
||||||
final_callbacks = []
|
exception_callbacks = []
|
||||||
|
|
||||||
def inner_func(conn, *args, **kwargs):
|
def inner_func(conn, *args, **kwargs):
|
||||||
return self._new_transaction(
|
return self._new_transaction(
|
||||||
conn, desc, after_callbacks, final_callbacks, current_context,
|
conn, desc, after_callbacks, exception_callbacks, current_context,
|
||||||
func, *args, **kwargs
|
func, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -321,9 +321,10 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||||
after_callback(*after_args, **after_kwargs)
|
after_callback(*after_args, **after_kwargs)
|
||||||
finally:
|
except: # noqa: E722, as we reraise the exception this is fine.
|
||||||
for after_callback, after_args, after_kwargs in final_callbacks:
|
for after_callback, after_args, after_kwargs in exception_callbacks:
|
||||||
after_callback(*after_args, **after_kwargs)
|
after_callback(*after_args, **after_kwargs)
|
||||||
|
raise
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@ -1000,7 +1001,8 @@ class SQLBaseStore(object):
|
||||||
# __exit__ called after the transaction finishes.
|
# __exit__ called after the transaction finishes.
|
||||||
ctx = self._cache_id_gen.get_next()
|
ctx = self._cache_id_gen.get_next()
|
||||||
stream_id = ctx.__enter__()
|
stream_id = ctx.__enter__()
|
||||||
txn.call_finally(ctx.__exit__, None, None, None)
|
txn.call_on_exception(ctx.__exit__, None, None, None)
|
||||||
|
txn.call_after(ctx.__exit__, None, None, None)
|
||||||
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
||||||
|
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -13,18 +14,46 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||||
|
|
||||||
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
|
||||||
|
|
||||||
import ujson as json
|
import abc
|
||||||
|
import simplejson as json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AccountDataStore(SQLBaseStore):
|
class AccountDataWorkerStore(SQLBaseStore):
|
||||||
|
"""This is an abstract base class where subclasses must implement
|
||||||
|
`get_max_account_data_stream_id` which can be called in the initializer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# This ABCMeta metaclass ensures that we cannot be instantiated without
|
||||||
|
# the abstract methods being implemented.
|
||||||
|
__metaclass__ = abc.ABCMeta
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
account_max = self.get_max_account_data_stream_id()
|
||||||
|
self._account_data_stream_cache = StreamChangeCache(
|
||||||
|
"AccountDataAndTagsChangeCache", account_max,
|
||||||
|
)
|
||||||
|
|
||||||
|
super(AccountDataWorkerStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_max_account_data_stream_id(self):
|
||||||
|
"""Get the current max stream ID for account data stream
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_account_data_for_user(self, user_id):
|
def get_account_data_for_user(self, user_id):
|
||||||
|
@ -104,6 +133,7 @@ class AccountDataStore(SQLBaseStore):
|
||||||
for row in rows
|
for row in rows
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@cached(num_args=2)
|
||||||
def get_account_data_for_room(self, user_id, room_id):
|
def get_account_data_for_room(self, user_id, room_id):
|
||||||
"""Get all the client account_data for a user for a room.
|
"""Get all the client account_data for a user for a room.
|
||||||
|
|
||||||
|
@ -127,6 +157,38 @@ class AccountDataStore(SQLBaseStore):
|
||||||
"get_account_data_for_room", get_account_data_for_room_txn
|
"get_account_data_for_room", get_account_data_for_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached(num_args=3, max_entries=5000)
|
||||||
|
def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
|
||||||
|
"""Get the client account_data of given type for a user for a room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id(str): The user to get the account_data for.
|
||||||
|
room_id(str): The room to get the account_data for.
|
||||||
|
account_data_type (str): The account data type to get.
|
||||||
|
Returns:
|
||||||
|
A deferred of the room account_data for that type, or None if
|
||||||
|
there isn't any set.
|
||||||
|
"""
|
||||||
|
def get_account_data_for_room_and_type_txn(txn):
|
||||||
|
content_json = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="room_account_data",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
"account_data_type": account_data_type,
|
||||||
|
},
|
||||||
|
retcol="content",
|
||||||
|
allow_none=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return json.loads(content_json) if content_json else None
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_account_data_for_room_and_type",
|
||||||
|
get_account_data_for_room_and_type_txn,
|
||||||
|
)
|
||||||
|
|
||||||
def get_all_updated_account_data(self, last_global_id, last_room_id,
|
def get_all_updated_account_data(self, last_global_id, last_room_id,
|
||||||
current_id, limit):
|
current_id, limit):
|
||||||
"""Get all the client account_data that has changed on the server
|
"""Get all the client account_data that has changed on the server
|
||||||
|
@ -209,6 +271,36 @@ class AccountDataStore(SQLBaseStore):
|
||||||
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
|
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
|
||||||
|
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
|
||||||
|
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
|
||||||
|
"m.ignored_user_list", ignorer_user_id,
|
||||||
|
on_invalidate=cache_context.invalidate,
|
||||||
|
)
|
||||||
|
if not ignored_account_data:
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
defer.returnValue(
|
||||||
|
ignored_user_id in ignored_account_data.get("ignored_users", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AccountDataStore(AccountDataWorkerStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
self._account_data_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "account_data_max_stream_id", "stream_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
super(AccountDataStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
def get_max_account_data_stream_id(self):
|
||||||
|
"""Get the current max stream id for the private user data stream
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A deferred int.
|
||||||
|
"""
|
||||||
|
return self._account_data_id_gen.get_current_token()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
|
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
|
||||||
"""Add some account_data to a room for a user.
|
"""Add some account_data to a room for a user.
|
||||||
|
@ -251,6 +343,10 @@ class AccountDataStore(SQLBaseStore):
|
||||||
|
|
||||||
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
||||||
self.get_account_data_for_user.invalidate((user_id,))
|
self.get_account_data_for_user.invalidate((user_id,))
|
||||||
|
self.get_account_data_for_room.invalidate((user_id, room_id,))
|
||||||
|
self.get_account_data_for_room_and_type.prefill(
|
||||||
|
(user_id, room_id, account_data_type,), content,
|
||||||
|
)
|
||||||
|
|
||||||
result = self._account_data_id_gen.get_current_token()
|
result = self._account_data_id_gen.get_current_token()
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
@ -321,16 +417,3 @@ class AccountDataStore(SQLBaseStore):
|
||||||
"update_account_data_max_stream_id",
|
"update_account_data_max_stream_id",
|
||||||
_update,
|
_update,
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
|
|
||||||
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
|
|
||||||
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
|
|
||||||
"m.ignored_user_list", ignorer_user_id,
|
|
||||||
on_invalidate=cache_context.invalidate,
|
|
||||||
)
|
|
||||||
if not ignored_account_data:
|
|
||||||
defer.returnValue(False)
|
|
||||||
|
|
||||||
defer.returnValue(
|
|
||||||
ignored_user_id in ignored_account_data.get("ignored_users", {})
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -17,10 +18,9 @@ import re
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
|
||||||
from synapse.appservice import AppServiceTransaction
|
from synapse.appservice import AppServiceTransaction
|
||||||
from synapse.config.appservice import load_appservices
|
from synapse.config.appservice import load_appservices
|
||||||
from synapse.storage.roommember import RoomsForUser
|
from synapse.storage.events import EventsWorkerStore
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,17 +46,16 @@ def _make_exclusive_regex(services_cache):
|
||||||
return exclusive_user_regex
|
return exclusive_user_regex
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceStore(SQLBaseStore):
|
class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(ApplicationServiceStore, self).__init__(db_conn, hs)
|
|
||||||
self.hostname = hs.hostname
|
|
||||||
self.services_cache = load_appservices(
|
self.services_cache = load_appservices(
|
||||||
hs.hostname,
|
hs.hostname,
|
||||||
hs.config.app_service_config_files
|
hs.config.app_service_config_files
|
||||||
)
|
)
|
||||||
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
|
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
|
||||||
|
|
||||||
|
super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
def get_app_services(self):
|
def get_app_services(self):
|
||||||
return self.services_cache
|
return self.services_cache
|
||||||
|
|
||||||
|
@ -112,83 +111,17 @@ class ApplicationServiceStore(SQLBaseStore):
|
||||||
return service
|
return service
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_app_service_rooms(self, service):
|
|
||||||
"""Get a list of RoomsForUser for this application service.
|
|
||||||
|
|
||||||
Application services may be "interested" in lots of rooms depending on
|
class ApplicationServiceStore(ApplicationServiceWorkerStore):
|
||||||
the room ID, the room aliases, or the members in the room. This function
|
# This is currently empty due to there not being any AS storage functions
|
||||||
takes all of these into account and returns a list of RoomsForUser which
|
# that can't be run on the workers. Since this may change in future, and
|
||||||
represent the entire list of room IDs that this application service
|
# to keep consistency with the other stores, we keep this empty class for
|
||||||
wants to know about.
|
# now.
|
||||||
|
pass
|
||||||
Args:
|
|
||||||
service: The application service to get a room list for.
|
|
||||||
Returns:
|
|
||||||
A list of RoomsForUser.
|
|
||||||
"""
|
|
||||||
return self.runInteraction(
|
|
||||||
"get_app_service_rooms",
|
|
||||||
self._get_app_service_rooms_txn,
|
|
||||||
service,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_app_service_rooms_txn(self, txn, service):
|
|
||||||
# get all rooms matching the room ID regex.
|
|
||||||
room_entries = self._simple_select_list_txn(
|
|
||||||
txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
|
|
||||||
)
|
|
||||||
matching_room_list = set([
|
|
||||||
r["room_id"] for r in room_entries if
|
|
||||||
service.is_interested_in_room(r["room_id"])
|
|
||||||
])
|
|
||||||
|
|
||||||
# resolve room IDs for matching room alias regex.
|
|
||||||
room_alias_mappings = self._simple_select_list_txn(
|
|
||||||
txn=txn, table="room_aliases", keyvalues=None,
|
|
||||||
retcols=["room_id", "room_alias"]
|
|
||||||
)
|
|
||||||
matching_room_list |= set([
|
|
||||||
r["room_id"] for r in room_alias_mappings if
|
|
||||||
service.is_interested_in_alias(r["room_alias"])
|
|
||||||
])
|
|
||||||
|
|
||||||
# get all rooms for every user for this AS. This is scoped to users on
|
|
||||||
# this HS only.
|
|
||||||
user_list = self._simple_select_list_txn(
|
|
||||||
txn=txn, table="users", keyvalues=None, retcols=["name"]
|
|
||||||
)
|
|
||||||
user_list = [
|
|
||||||
u["name"] for u in user_list if
|
|
||||||
service.is_interested_in_user(u["name"])
|
|
||||||
]
|
|
||||||
rooms_for_user_matching_user_id = set() # RoomsForUser list
|
|
||||||
for user_id in user_list:
|
|
||||||
# FIXME: This assumes this store is linked with RoomMemberStore :(
|
|
||||||
rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
|
|
||||||
txn=txn,
|
|
||||||
user_id=user_id,
|
|
||||||
membership_list=[Membership.JOIN]
|
|
||||||
)
|
|
||||||
rooms_for_user_matching_user_id |= set(rooms_for_user)
|
|
||||||
|
|
||||||
# make RoomsForUser tuples for room ids and aliases which are not in the
|
|
||||||
# main rooms_for_user_list - e.g. they are rooms which do not have AS
|
|
||||||
# registered users in it.
|
|
||||||
known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
|
|
||||||
missing_rooms_for_user = [
|
|
||||||
RoomsForUser(r, service.sender, "join") for r in
|
|
||||||
matching_room_list if r not in known_room_ids
|
|
||||||
]
|
|
||||||
rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
|
|
||||||
|
|
||||||
return rooms_for_user_matching_user_id
|
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceTransactionStore(SQLBaseStore):
|
class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||||
|
EventsWorkerStore):
|
||||||
def __init__(self, db_conn, hs):
|
|
||||||
super(ApplicationServiceTransactionStore, self).__init__(db_conn, hs)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_appservices_by_state(self, state):
|
def get_appservices_by_state(self, state):
|
||||||
"""Get a list of application services based on their state.
|
"""Get a list of application services based on their state.
|
||||||
|
@ -433,3 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
events = yield self._get_events(event_ids)
|
events = yield self._get_events(event_ids)
|
||||||
|
|
||||||
defer.returnValue((upper_bound, events))
|
defer.returnValue((upper_bound, events))
|
||||||
|
|
||||||
|
|
||||||
|
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
|
||||||
|
# This is currently empty due to there not being any AS storage functions
|
||||||
|
# that can't be run on the workers. Since this may change in future, and
|
||||||
|
# to keep consistency with the other stores, we keep this empty class for
|
||||||
|
# now.
|
||||||
|
pass
|
||||||
|
|
|
@ -19,7 +19,7 @@ from . import engines
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import ujson
|
import simplejson
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
rows = []
|
rows = []
|
||||||
for destination, edu in remote_messages_by_destination.items():
|
for destination, edu in remote_messages_by_destination.items():
|
||||||
edu_json = ujson.dumps(edu)
|
edu_json = simplejson.dumps(edu)
|
||||||
rows.append((destination, stream_id, now_ms, edu_json))
|
rows.append((destination, stream_id, now_ms, edu_json))
|
||||||
txn.executemany(sql, rows)
|
txn.executemany(sql, rows)
|
||||||
|
|
||||||
|
@ -177,7 +177,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
" WHERE user_id = ?"
|
" WHERE user_id = ?"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id,))
|
txn.execute(sql, (user_id,))
|
||||||
message_json = ujson.dumps(messages_by_device["*"])
|
message_json = simplejson.dumps(messages_by_device["*"])
|
||||||
for row in txn:
|
for row in txn:
|
||||||
# Add the message for all devices for this user on this
|
# Add the message for all devices for this user on this
|
||||||
# server.
|
# server.
|
||||||
|
@ -199,7 +199,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
# Only insert into the local inbox if the device exists on
|
# Only insert into the local inbox if the device exists on
|
||||||
# this server
|
# this server
|
||||||
device = row[0]
|
device = row[0]
|
||||||
message_json = ujson.dumps(messages_by_device[device])
|
message_json = simplejson.dumps(messages_by_device[device])
|
||||||
messages_json_for_user[device] = message_json
|
messages_json_for_user[device] = message_json
|
||||||
|
|
||||||
if messages_json_for_user:
|
if messages_json_for_user:
|
||||||
|
@ -253,7 +253,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
messages = []
|
messages = []
|
||||||
for row in txn:
|
for row in txn:
|
||||||
stream_pos = row[0]
|
stream_pos = row[0]
|
||||||
messages.append(ujson.loads(row[1]))
|
messages.append(simplejson.loads(row[1]))
|
||||||
if len(messages) < limit:
|
if len(messages) < limit:
|
||||||
stream_pos = current_stream_id
|
stream_pos = current_stream_id
|
||||||
return (messages, stream_pos)
|
return (messages, stream_pos)
|
||||||
|
@ -389,7 +389,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
messages = []
|
messages = []
|
||||||
for row in txn:
|
for row in txn:
|
||||||
stream_pos = row[0]
|
stream_pos = row[0]
|
||||||
messages.append(ujson.loads(row[1]))
|
messages.append(simplejson.loads(row[1]))
|
||||||
if len(messages) < limit:
|
if len(messages) < limit:
|
||||||
stream_pos = current_stream_id
|
stream_pos = current_stream_id
|
||||||
return (messages, stream_pos)
|
return (messages, stream_pos)
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
|
|
@ -29,8 +29,7 @@ RoomAliasMapping = namedtuple(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DirectoryStore(SQLBaseStore):
|
class DirectoryWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_association_from_room_alias(self, room_alias):
|
def get_association_from_room_alias(self, room_alias):
|
||||||
""" Get's the room_id and server list for a given room_alias
|
""" Get's the room_id and server list for a given room_alias
|
||||||
|
@ -69,6 +68,28 @@ class DirectoryStore(SQLBaseStore):
|
||||||
RoomAliasMapping(room_id, room_alias.to_string(), servers)
|
RoomAliasMapping(room_id, room_alias.to_string(), servers)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_room_alias_creator(self, room_alias):
|
||||||
|
return self._simple_select_one_onecol(
|
||||||
|
table="room_aliases",
|
||||||
|
keyvalues={
|
||||||
|
"room_alias": room_alias,
|
||||||
|
},
|
||||||
|
retcol="creator",
|
||||||
|
desc="get_room_alias_creator",
|
||||||
|
allow_none=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached(max_entries=5000)
|
||||||
|
def get_aliases_for_room(self, room_id):
|
||||||
|
return self._simple_select_onecol(
|
||||||
|
"room_aliases",
|
||||||
|
{"room_id": room_id},
|
||||||
|
"room_alias",
|
||||||
|
desc="get_aliases_for_room",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DirectoryStore(DirectoryWorkerStore):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
|
def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
|
||||||
""" Creates an associatin between a room alias and room_id/servers
|
""" Creates an associatin between a room alias and room_id/servers
|
||||||
|
@ -116,17 +137,6 @@ class DirectoryStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
def get_room_alias_creator(self, room_alias):
|
|
||||||
return self._simple_select_one_onecol(
|
|
||||||
table="room_aliases",
|
|
||||||
keyvalues={
|
|
||||||
"room_alias": room_alias,
|
|
||||||
},
|
|
||||||
retcol="creator",
|
|
||||||
desc="get_room_alias_creator",
|
|
||||||
allow_none=True
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_room_alias(self, room_alias):
|
def delete_room_alias(self, room_alias):
|
||||||
room_id = yield self.runInteraction(
|
room_id = yield self.runInteraction(
|
||||||
|
@ -135,7 +145,6 @@ class DirectoryStore(SQLBaseStore):
|
||||||
room_alias,
|
room_alias,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_aliases_for_room.invalidate((room_id,))
|
|
||||||
defer.returnValue(room_id)
|
defer.returnValue(room_id)
|
||||||
|
|
||||||
def _delete_room_alias_txn(self, txn, room_alias):
|
def _delete_room_alias_txn(self, txn, room_alias):
|
||||||
|
@ -160,17 +169,12 @@ class DirectoryStore(SQLBaseStore):
|
||||||
(room_alias.to_string(),)
|
(room_alias.to_string(),)
|
||||||
)
|
)
|
||||||
|
|
||||||
return room_id
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_aliases_for_room, (room_id,)
|
||||||
@cached(max_entries=5000)
|
|
||||||
def get_aliases_for_room(self, room_id):
|
|
||||||
return self._simple_select_onecol(
|
|
||||||
"room_aliases",
|
|
||||||
{"room_id": room_id},
|
|
||||||
"room_alias",
|
|
||||||
desc="get_aliases_for_room",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return room_id
|
||||||
|
|
||||||
def update_aliases_for_room(self, old_room_id, new_room_id, creator):
|
def update_aliases_for_room(self, old_room_id, new_room_id, creator):
|
||||||
def _update_aliases_for_room_txn(txn):
|
def _update_aliases_for_room_txn(txn):
|
||||||
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
|
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
|
||||||
|
|
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,10 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
from synapse.storage.events import EventsWorkerStore
|
||||||
|
from synapse.storage.signatures import SignatureWorkerStore
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from unpaddedbase64 import encode_base64
|
from unpaddedbase64 import encode_base64
|
||||||
|
@ -27,30 +30,8 @@ from Queue import PriorityQueue, Empty
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EventFederationStore(SQLBaseStore):
|
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||||
""" Responsible for storing and serving up the various graphs associated
|
SQLBaseStore):
|
||||||
with an event. Including the main event graph and the auth chains for an
|
|
||||||
event.
|
|
||||||
|
|
||||||
Also has methods for getting the front (latest) and back (oldest) edges
|
|
||||||
of the event graphs. These are used to generate the parents for new events
|
|
||||||
and backfilling from another server respectively.
|
|
||||||
"""
|
|
||||||
|
|
||||||
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
|
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
|
||||||
super(EventFederationStore, self).__init__(db_conn, hs)
|
|
||||||
|
|
||||||
self.register_background_update_handler(
|
|
||||||
self.EVENT_AUTH_STATE_ONLY,
|
|
||||||
self._background_delete_non_state_event_auth,
|
|
||||||
)
|
|
||||||
|
|
||||||
hs.get_clock().looping_call(
|
|
||||||
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_auth_chain(self, event_ids, include_given=False):
|
def get_auth_chain(self, event_ids, include_given=False):
|
||||||
"""Get auth events for given event_ids. The events *must* be state events.
|
"""Get auth events for given event_ids. The events *must* be state events.
|
||||||
|
|
||||||
|
@ -228,88 +209,6 @@ class EventFederationStore(SQLBaseStore):
|
||||||
|
|
||||||
return int(min_depth) if min_depth is not None else None
|
return int(min_depth) if min_depth is not None else None
|
||||||
|
|
||||||
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
|
|
||||||
min_depth = self._get_min_depth_interaction(txn, room_id)
|
|
||||||
|
|
||||||
if min_depth and depth >= min_depth:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._simple_upsert_txn(
|
|
||||||
txn,
|
|
||||||
table="room_depth",
|
|
||||||
keyvalues={
|
|
||||||
"room_id": room_id,
|
|
||||||
},
|
|
||||||
values={
|
|
||||||
"min_depth": depth,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_mult_prev_events(self, txn, events):
|
|
||||||
"""
|
|
||||||
For the given event, update the event edges table and forward and
|
|
||||||
backward extremities tables.
|
|
||||||
"""
|
|
||||||
self._simple_insert_many_txn(
|
|
||||||
txn,
|
|
||||||
table="event_edges",
|
|
||||||
values=[
|
|
||||||
{
|
|
||||||
"event_id": ev.event_id,
|
|
||||||
"prev_event_id": e_id,
|
|
||||||
"room_id": ev.room_id,
|
|
||||||
"is_state": False,
|
|
||||||
}
|
|
||||||
for ev in events
|
|
||||||
for e_id, _ in ev.prev_events
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
self._update_backward_extremeties(txn, events)
|
|
||||||
|
|
||||||
def _update_backward_extremeties(self, txn, events):
|
|
||||||
"""Updates the event_backward_extremities tables based on the new/updated
|
|
||||||
events being persisted.
|
|
||||||
|
|
||||||
This is called for new events *and* for events that were outliers, but
|
|
||||||
are now being persisted as non-outliers.
|
|
||||||
|
|
||||||
Forward extremities are handled when we first start persisting the events.
|
|
||||||
"""
|
|
||||||
events_by_room = {}
|
|
||||||
for ev in events:
|
|
||||||
events_by_room.setdefault(ev.room_id, []).append(ev)
|
|
||||||
|
|
||||||
query = (
|
|
||||||
"INSERT INTO event_backward_extremities (event_id, room_id)"
|
|
||||||
" SELECT ?, ? WHERE NOT EXISTS ("
|
|
||||||
" SELECT 1 FROM event_backward_extremities"
|
|
||||||
" WHERE event_id = ? AND room_id = ?"
|
|
||||||
" )"
|
|
||||||
" AND NOT EXISTS ("
|
|
||||||
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
|
|
||||||
" AND outlier = ?"
|
|
||||||
" )"
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.executemany(query, [
|
|
||||||
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
|
|
||||||
for ev in events for e_id, _ in ev.prev_events
|
|
||||||
if not ev.internal_metadata.is_outlier()
|
|
||||||
])
|
|
||||||
|
|
||||||
query = (
|
|
||||||
"DELETE FROM event_backward_extremities"
|
|
||||||
" WHERE event_id = ? AND room_id = ?"
|
|
||||||
)
|
|
||||||
txn.executemany(
|
|
||||||
query,
|
|
||||||
[
|
|
||||||
(ev.event_id, ev.room_id) for ev in events
|
|
||||||
if not ev.internal_metadata.is_outlier()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
|
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
|
||||||
"""For a given room_id and stream_ordering, return the forward
|
"""For a given room_id and stream_ordering, return the forward
|
||||||
extremeties of the room at that point in "time".
|
extremeties of the room at that point in "time".
|
||||||
|
@ -371,28 +270,6 @@ class EventFederationStore(SQLBaseStore):
|
||||||
get_forward_extremeties_for_room_txn
|
get_forward_extremeties_for_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def _delete_old_forward_extrem_cache(self):
|
|
||||||
def _delete_old_forward_extrem_cache_txn(txn):
|
|
||||||
# Delete entries older than a month, while making sure we don't delete
|
|
||||||
# the only entries for a room.
|
|
||||||
sql = ("""
|
|
||||||
DELETE FROM stream_ordering_to_exterm
|
|
||||||
WHERE
|
|
||||||
room_id IN (
|
|
||||||
SELECT room_id
|
|
||||||
FROM stream_ordering_to_exterm
|
|
||||||
WHERE stream_ordering > ?
|
|
||||||
) AND stream_ordering < ?
|
|
||||||
""")
|
|
||||||
txn.execute(
|
|
||||||
sql,
|
|
||||||
(self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
|
|
||||||
)
|
|
||||||
return self.runInteraction(
|
|
||||||
"_delete_old_forward_extrem_cache",
|
|
||||||
_delete_old_forward_extrem_cache_txn
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_backfill_events(self, room_id, event_list, limit):
|
def get_backfill_events(self, room_id, event_list, limit):
|
||||||
"""Get a list of Events for a given topic that occurred before (and
|
"""Get a list of Events for a given topic that occurred before (and
|
||||||
including) the events in event_list. Return a list of max size `limit`
|
including) the events in event_list. Return a list of max size `limit`
|
||||||
|
@ -522,6 +399,135 @@ class EventFederationStore(SQLBaseStore):
|
||||||
|
|
||||||
return event_results
|
return event_results
|
||||||
|
|
||||||
|
|
||||||
|
class EventFederationStore(EventFederationWorkerStore):
|
||||||
|
""" Responsible for storing and serving up the various graphs associated
|
||||||
|
with an event. Including the main event graph and the auth chains for an
|
||||||
|
event.
|
||||||
|
|
||||||
|
Also has methods for getting the front (latest) and back (oldest) edges
|
||||||
|
of the event graphs. These are used to generate the parents for new events
|
||||||
|
and backfilling from another server respectively.
|
||||||
|
"""
|
||||||
|
|
||||||
|
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(EventFederationStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
self.register_background_update_handler(
|
||||||
|
self.EVENT_AUTH_STATE_ONLY,
|
||||||
|
self._background_delete_non_state_event_auth,
|
||||||
|
)
|
||||||
|
|
||||||
|
hs.get_clock().looping_call(
|
||||||
|
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
|
||||||
|
min_depth = self._get_min_depth_interaction(txn, room_id)
|
||||||
|
|
||||||
|
if min_depth and depth >= min_depth:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._simple_upsert_txn(
|
||||||
|
txn,
|
||||||
|
table="room_depth",
|
||||||
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"min_depth": depth,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_mult_prev_events(self, txn, events):
|
||||||
|
"""
|
||||||
|
For the given event, update the event edges table and forward and
|
||||||
|
backward extremities tables.
|
||||||
|
"""
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="event_edges",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"event_id": ev.event_id,
|
||||||
|
"prev_event_id": e_id,
|
||||||
|
"room_id": ev.room_id,
|
||||||
|
"is_state": False,
|
||||||
|
}
|
||||||
|
for ev in events
|
||||||
|
for e_id, _ in ev.prev_events
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self._update_backward_extremeties(txn, events)
|
||||||
|
|
||||||
|
def _update_backward_extremeties(self, txn, events):
|
||||||
|
"""Updates the event_backward_extremities tables based on the new/updated
|
||||||
|
events being persisted.
|
||||||
|
|
||||||
|
This is called for new events *and* for events that were outliers, but
|
||||||
|
are now being persisted as non-outliers.
|
||||||
|
|
||||||
|
Forward extremities are handled when we first start persisting the events.
|
||||||
|
"""
|
||||||
|
events_by_room = {}
|
||||||
|
for ev in events:
|
||||||
|
events_by_room.setdefault(ev.room_id, []).append(ev)
|
||||||
|
|
||||||
|
query = (
|
||||||
|
"INSERT INTO event_backward_extremities (event_id, room_id)"
|
||||||
|
" SELECT ?, ? WHERE NOT EXISTS ("
|
||||||
|
" SELECT 1 FROM event_backward_extremities"
|
||||||
|
" WHERE event_id = ? AND room_id = ?"
|
||||||
|
" )"
|
||||||
|
" AND NOT EXISTS ("
|
||||||
|
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
|
||||||
|
" AND outlier = ?"
|
||||||
|
" )"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.executemany(query, [
|
||||||
|
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
|
||||||
|
for ev in events for e_id, _ in ev.prev_events
|
||||||
|
if not ev.internal_metadata.is_outlier()
|
||||||
|
])
|
||||||
|
|
||||||
|
query = (
|
||||||
|
"DELETE FROM event_backward_extremities"
|
||||||
|
" WHERE event_id = ? AND room_id = ?"
|
||||||
|
)
|
||||||
|
txn.executemany(
|
||||||
|
query,
|
||||||
|
[
|
||||||
|
(ev.event_id, ev.room_id) for ev in events
|
||||||
|
if not ev.internal_metadata.is_outlier()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _delete_old_forward_extrem_cache(self):
|
||||||
|
def _delete_old_forward_extrem_cache_txn(txn):
|
||||||
|
# Delete entries older than a month, while making sure we don't delete
|
||||||
|
# the only entries for a room.
|
||||||
|
sql = ("""
|
||||||
|
DELETE FROM stream_ordering_to_exterm
|
||||||
|
WHERE
|
||||||
|
room_id IN (
|
||||||
|
SELECT room_id
|
||||||
|
FROM stream_ordering_to_exterm
|
||||||
|
WHERE stream_ordering > ?
|
||||||
|
) AND stream_ordering < ?
|
||||||
|
""")
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
|
||||||
|
)
|
||||||
|
return self.runInteraction(
|
||||||
|
"_delete_old_forward_extrem_cache",
|
||||||
|
_delete_old_forward_extrem_cache_txn
|
||||||
|
)
|
||||||
|
|
||||||
def clean_room_for_join(self, room_id):
|
def clean_room_for_join(self, room_id):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"clean_room_for_join",
|
"clean_room_for_join",
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015 OpenMarket Ltd
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -13,7 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore, LoggingTransaction
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
@ -21,7 +22,7 @@ from synapse.types import RoomStreamToken
|
||||||
from .stream import lower_bound
|
from .stream import lower_bound
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -62,77 +63,28 @@ def _deserialize_action(actions, is_highlight):
|
||||||
return DEFAULT_NOTIF_ACTION
|
return DEFAULT_NOTIF_ACTION
|
||||||
|
|
||||||
|
|
||||||
class EventPushActionsStore(SQLBaseStore):
|
class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
|
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(EventPushActionsStore, self).__init__(db_conn, hs)
|
super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self.register_background_index_update(
|
# These get correctly set by _find_stream_orderings_for_times_txn
|
||||||
self.EPA_HIGHLIGHT_INDEX,
|
self.stream_ordering_month_ago = None
|
||||||
index_name="event_push_actions_u_highlight",
|
self.stream_ordering_day_ago = None
|
||||||
table="event_push_actions",
|
|
||||||
columns=["user_id", "stream_ordering"],
|
cur = LoggingTransaction(
|
||||||
|
db_conn.cursor(),
|
||||||
|
name="_find_stream_orderings_for_times_txn",
|
||||||
|
database_engine=self.database_engine,
|
||||||
|
after_callbacks=[],
|
||||||
|
exception_callbacks=[],
|
||||||
)
|
)
|
||||||
|
self._find_stream_orderings_for_times_txn(cur)
|
||||||
|
cur.close()
|
||||||
|
|
||||||
self.register_background_index_update(
|
self.find_stream_orderings_looping_call = self._clock.looping_call(
|
||||||
"event_push_actions_highlights_index",
|
self._find_stream_orderings_for_times, 10 * 60 * 1000
|
||||||
index_name="event_push_actions_highlights_index",
|
|
||||||
table="event_push_actions",
|
|
||||||
columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
|
|
||||||
where_clause="highlight=1"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._doing_notif_rotation = False
|
|
||||||
self._rotate_notif_loop = self._clock.looping_call(
|
|
||||||
self._rotate_notifs, 30 * 60 * 1000
|
|
||||||
)
|
|
||||||
|
|
||||||
def _set_push_actions_for_event_and_users_txn(self, txn, event):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
event: the event set actions for
|
|
||||||
tuples: list of tuples of (user_id, actions)
|
|
||||||
"""
|
|
||||||
|
|
||||||
sql = """
|
|
||||||
INSERT INTO event_push_actions (
|
|
||||||
room_id, event_id, user_id, actions, stream_ordering,
|
|
||||||
topological_ordering, notif, highlight
|
|
||||||
)
|
|
||||||
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
|
|
||||||
FROM event_push_actions_staging
|
|
||||||
WHERE event_id = ?
|
|
||||||
"""
|
|
||||||
|
|
||||||
txn.execute(sql, (
|
|
||||||
event.room_id, event.internal_metadata.stream_ordering,
|
|
||||||
event.depth, event.event_id,
|
|
||||||
))
|
|
||||||
|
|
||||||
user_ids = self._simple_select_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="event_push_actions_staging",
|
|
||||||
keyvalues={
|
|
||||||
"event_id": event.event_id,
|
|
||||||
},
|
|
||||||
retcol="user_id",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._simple_delete_txn(
|
|
||||||
txn,
|
|
||||||
table="event_push_actions_staging",
|
|
||||||
keyvalues={
|
|
||||||
"event_id": event.event_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
for uid in user_ids:
|
|
||||||
txn.call_after(
|
|
||||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
|
||||||
(event.room_id, uid,)
|
|
||||||
)
|
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
|
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
|
||||||
def get_unread_event_push_actions_by_room_for_user(
|
def get_unread_event_push_actions_by_room_for_user(
|
||||||
self, room_id, user_id, last_read_event_id
|
self, room_id, user_id, last_read_event_id
|
||||||
|
@ -449,6 +401,280 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
# Now return the first `limit`
|
# Now return the first `limit`
|
||||||
defer.returnValue(notifs[:limit])
|
defer.returnValue(notifs[:limit])
|
||||||
|
|
||||||
|
def add_push_actions_to_staging(self, event_id, user_id_actions):
|
||||||
|
"""Add the push actions for the event to the push action staging area.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id (str)
|
||||||
|
user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
|
||||||
|
user_id to list of push actions, where an action can either be
|
||||||
|
a string or dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not user_id_actions:
|
||||||
|
return
|
||||||
|
|
||||||
|
# This is a helper function for generating the necessary tuple that
|
||||||
|
# can be used to inert into the `event_push_actions_staging` table.
|
||||||
|
def _gen_entry(user_id, actions):
|
||||||
|
is_highlight = 1 if _action_has_highlight(actions) else 0
|
||||||
|
return (
|
||||||
|
event_id, # event_id column
|
||||||
|
user_id, # user_id column
|
||||||
|
_serialize_action(actions, is_highlight), # actions column
|
||||||
|
1, # notif column
|
||||||
|
is_highlight, # highlight column
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_push_actions_to_staging_txn(txn):
|
||||||
|
# We don't use _simple_insert_many here to avoid the overhead
|
||||||
|
# of generating lists of dicts.
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
INSERT INTO event_push_actions_staging
|
||||||
|
(event_id, user_id, actions, notif, highlight)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.executemany(sql, (
|
||||||
|
_gen_entry(user_id, actions)
|
||||||
|
for user_id, actions in user_id_actions.iteritems()
|
||||||
|
))
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_push_actions_from_staging(self, event_id):
|
||||||
|
"""Called if we failed to persist the event to ensure that stale push
|
||||||
|
actions don't build up in the DB
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id (str)
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self._simple_delete(
|
||||||
|
table="event_push_actions_staging",
|
||||||
|
keyvalues={
|
||||||
|
"event_id": event_id,
|
||||||
|
},
|
||||||
|
desc="remove_push_actions_from_staging",
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _find_stream_orderings_for_times(self):
|
||||||
|
yield self.runInteraction(
|
||||||
|
"_find_stream_orderings_for_times",
|
||||||
|
self._find_stream_orderings_for_times_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
def _find_stream_orderings_for_times_txn(self, txn):
|
||||||
|
logger.info("Searching for stream ordering 1 month ago")
|
||||||
|
self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
|
||||||
|
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Found stream ordering 1 month ago: it's %d",
|
||||||
|
self.stream_ordering_month_ago
|
||||||
|
)
|
||||||
|
logger.info("Searching for stream ordering 1 day ago")
|
||||||
|
self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
|
||||||
|
txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Found stream ordering 1 day ago: it's %d",
|
||||||
|
self.stream_ordering_day_ago
|
||||||
|
)
|
||||||
|
|
||||||
|
def find_first_stream_ordering_after_ts(self, ts):
|
||||||
|
"""Gets the stream ordering corresponding to a given timestamp.
|
||||||
|
|
||||||
|
Specifically, finds the stream_ordering of the first event that was
|
||||||
|
received on or after the timestamp. This is done by a binary search on
|
||||||
|
the events table, since there is no index on received_ts, so is
|
||||||
|
relatively slow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ts (int): timestamp in millis
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[int]: stream ordering of the first event received on/after
|
||||||
|
the timestamp
|
||||||
|
"""
|
||||||
|
return self.runInteraction(
|
||||||
|
"_find_first_stream_ordering_after_ts_txn",
|
||||||
|
self._find_first_stream_ordering_after_ts_txn,
|
||||||
|
ts,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _find_first_stream_ordering_after_ts_txn(txn, ts):
|
||||||
|
"""
|
||||||
|
Find the stream_ordering of the first event that was received on or
|
||||||
|
after a given timestamp. This is relatively slow as there is no index
|
||||||
|
on received_ts but we can then use this to delete push actions before
|
||||||
|
this.
|
||||||
|
|
||||||
|
received_ts must necessarily be in the same order as stream_ordering
|
||||||
|
and stream_ordering is indexed, so we manually binary search using
|
||||||
|
stream_ordering
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn (twisted.enterprise.adbapi.Transaction):
|
||||||
|
ts (int): timestamp to search for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: stream ordering
|
||||||
|
"""
|
||||||
|
txn.execute("SELECT MAX(stream_ordering) FROM events")
|
||||||
|
max_stream_ordering = txn.fetchone()[0]
|
||||||
|
|
||||||
|
if max_stream_ordering is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# We want the first stream_ordering in which received_ts is greater
|
||||||
|
# than or equal to ts. Call this point X.
|
||||||
|
#
|
||||||
|
# We maintain the invariants:
|
||||||
|
#
|
||||||
|
# range_start <= X <= range_end
|
||||||
|
#
|
||||||
|
range_start = 0
|
||||||
|
range_end = max_stream_ordering + 1
|
||||||
|
|
||||||
|
# Given a stream_ordering, look up the timestamp at that
|
||||||
|
# stream_ordering.
|
||||||
|
#
|
||||||
|
# The array may be sparse (we may be missing some stream_orderings).
|
||||||
|
# We treat the gaps as the same as having the same value as the
|
||||||
|
# preceding entry, because we will pick the lowest stream_ordering
|
||||||
|
# which satisfies our requirement of received_ts >= ts.
|
||||||
|
#
|
||||||
|
# For example, if our array of events indexed by stream_ordering is
|
||||||
|
# [10, <none>, 20], we should treat this as being equivalent to
|
||||||
|
# [10, 10, 20].
|
||||||
|
#
|
||||||
|
sql = (
|
||||||
|
"SELECT received_ts FROM events"
|
||||||
|
" WHERE stream_ordering <= ?"
|
||||||
|
" ORDER BY stream_ordering DESC"
|
||||||
|
" LIMIT 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
while range_end - range_start > 0:
|
||||||
|
middle = (range_end + range_start) // 2
|
||||||
|
txn.execute(sql, (middle,))
|
||||||
|
row = txn.fetchone()
|
||||||
|
if row is None:
|
||||||
|
# no rows with stream_ordering<=middle
|
||||||
|
range_start = middle + 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
middle_ts = row[0]
|
||||||
|
if ts > middle_ts:
|
||||||
|
# we got a timestamp lower than the one we were looking for.
|
||||||
|
# definitely need to look higher: X > middle.
|
||||||
|
range_start = middle + 1
|
||||||
|
else:
|
||||||
|
# we got a timestamp higher than (or the same as) the one we
|
||||||
|
# were looking for. We aren't yet sure about the point we
|
||||||
|
# looked up, but we can be sure that X <= middle.
|
||||||
|
range_end = middle
|
||||||
|
|
||||||
|
return range_end
|
||||||
|
|
||||||
|
|
||||||
|
class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||||
|
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(EventPushActionsStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
self.EPA_HIGHLIGHT_INDEX,
|
||||||
|
index_name="event_push_actions_u_highlight",
|
||||||
|
table="event_push_actions",
|
||||||
|
columns=["user_id", "stream_ordering"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
"event_push_actions_highlights_index",
|
||||||
|
index_name="event_push_actions_highlights_index",
|
||||||
|
table="event_push_actions",
|
||||||
|
columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
|
||||||
|
where_clause="highlight=1"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._doing_notif_rotation = False
|
||||||
|
self._rotate_notif_loop = self._clock.looping_call(
|
||||||
|
self._rotate_notifs, 30 * 60 * 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
|
||||||
|
all_events_and_contexts):
|
||||||
|
"""Handles moving push actions from staging table to main
|
||||||
|
event_push_actions table for all events in `events_and_contexts`.
|
||||||
|
|
||||||
|
Also ensures that all events in `all_events_and_contexts` are removed
|
||||||
|
from the push action staging area.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||||
|
we are persisting
|
||||||
|
all_events_and_contexts (list[(EventBase, EventContext)]): all
|
||||||
|
events that we were going to persist. This includes events
|
||||||
|
we've already persisted, etc, that wouldn't appear in
|
||||||
|
events_and_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
INSERT INTO event_push_actions (
|
||||||
|
room_id, event_id, user_id, actions, stream_ordering,
|
||||||
|
topological_ordering, notif, highlight
|
||||||
|
)
|
||||||
|
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
|
||||||
|
FROM event_push_actions_staging
|
||||||
|
WHERE event_id = ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
if events_and_contexts:
|
||||||
|
txn.executemany(sql, (
|
||||||
|
(
|
||||||
|
event.room_id, event.internal_metadata.stream_ordering,
|
||||||
|
event.depth, event.event_id,
|
||||||
|
)
|
||||||
|
for event, _ in events_and_contexts
|
||||||
|
))
|
||||||
|
|
||||||
|
for event, _ in events_and_contexts:
|
||||||
|
user_ids = self._simple_select_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="event_push_actions_staging",
|
||||||
|
keyvalues={
|
||||||
|
"event_id": event.event_id,
|
||||||
|
},
|
||||||
|
retcol="user_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
for uid in user_ids:
|
||||||
|
txn.call_after(
|
||||||
|
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
||||||
|
(event.room_id, uid,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now we delete the staging area for *all* events that were being
|
||||||
|
# persisted.
|
||||||
|
txn.executemany(
|
||||||
|
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
|
||||||
|
(
|
||||||
|
(event.event_id,)
|
||||||
|
for event, _ in all_events_and_contexts
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_push_actions_for_user(self, user_id, before=None, limit=50,
|
def get_push_actions_for_user(self, user_id, before=None, limit=50,
|
||||||
only_highlight=False):
|
only_highlight=False):
|
||||||
|
@ -567,69 +793,6 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
|
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
|
||||||
""", (room_id, user_id, stream_ordering))
|
""", (room_id, user_id, stream_ordering))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _find_stream_orderings_for_times(self):
|
|
||||||
yield self.runInteraction(
|
|
||||||
"_find_stream_orderings_for_times",
|
|
||||||
self._find_stream_orderings_for_times_txn
|
|
||||||
)
|
|
||||||
|
|
||||||
def _find_stream_orderings_for_times_txn(self, txn):
|
|
||||||
logger.info("Searching for stream ordering 1 month ago")
|
|
||||||
self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
|
|
||||||
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"Found stream ordering 1 month ago: it's %d",
|
|
||||||
self.stream_ordering_month_ago
|
|
||||||
)
|
|
||||||
logger.info("Searching for stream ordering 1 day ago")
|
|
||||||
self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
|
|
||||||
txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"Found stream ordering 1 day ago: it's %d",
|
|
||||||
self.stream_ordering_day_ago
|
|
||||||
)
|
|
||||||
|
|
||||||
def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
|
|
||||||
"""
|
|
||||||
Find the stream_ordering of the first event that was received after
|
|
||||||
a given timestamp. This is relatively slow as there is no index on
|
|
||||||
received_ts but we can then use this to delete push actions before
|
|
||||||
this.
|
|
||||||
|
|
||||||
received_ts must necessarily be in the same order as stream_ordering
|
|
||||||
and stream_ordering is indexed, so we manually binary search using
|
|
||||||
stream_ordering
|
|
||||||
"""
|
|
||||||
txn.execute("SELECT MAX(stream_ordering) FROM events")
|
|
||||||
max_stream_ordering = txn.fetchone()[0]
|
|
||||||
|
|
||||||
if max_stream_ordering is None:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
range_start = 0
|
|
||||||
range_end = max_stream_ordering
|
|
||||||
|
|
||||||
sql = (
|
|
||||||
"SELECT received_ts FROM events"
|
|
||||||
" WHERE stream_ordering > ?"
|
|
||||||
" ORDER BY stream_ordering"
|
|
||||||
" LIMIT 1"
|
|
||||||
)
|
|
||||||
|
|
||||||
while range_end - range_start > 1:
|
|
||||||
middle = int((range_end + range_start) / 2)
|
|
||||||
txn.execute(sql, (middle,))
|
|
||||||
middle_ts = txn.fetchone()[0]
|
|
||||||
if ts > middle_ts:
|
|
||||||
range_start = middle
|
|
||||||
else:
|
|
||||||
range_end = middle
|
|
||||||
|
|
||||||
return range_end
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _rotate_notifs(self):
|
def _rotate_notifs(self):
|
||||||
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
|
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
|
||||||
|
@ -755,50 +918,6 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
(rotate_to_stream_ordering,)
|
(rotate_to_stream_ordering,)
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_push_actions_to_staging(self, event_id, user_id, actions):
|
|
||||||
"""Add the push actions for the user and event to the push
|
|
||||||
action staging area.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_id (str)
|
|
||||||
user_id (str)
|
|
||||||
actions (list[dict|str]): An action can either be a string or
|
|
||||||
dict.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred
|
|
||||||
"""
|
|
||||||
|
|
||||||
is_highlight = 1 if _action_has_highlight(actions) else 0
|
|
||||||
|
|
||||||
return self._simple_insert(
|
|
||||||
table="event_push_actions_staging",
|
|
||||||
values={
|
|
||||||
"event_id": event_id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"actions": _serialize_action(actions, is_highlight),
|
|
||||||
"notif": 1,
|
|
||||||
"highlight": is_highlight,
|
|
||||||
},
|
|
||||||
desc="add_push_actions_to_staging",
|
|
||||||
)
|
|
||||||
|
|
||||||
def remove_push_actions_from_staging(self, event_id):
|
|
||||||
"""Called if we failed to persist the event to ensure that stale push
|
|
||||||
actions don't build up in the DB
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_id (str)
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self._simple_delete(
|
|
||||||
table="event_push_actions_staging",
|
|
||||||
keyvalues={
|
|
||||||
"event_id": event_id,
|
|
||||||
},
|
|
||||||
desc="remove_push_actions_from_staging",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _action_has_highlight(actions):
|
def _action_has_highlight(actions):
|
||||||
for action in actions:
|
for action in actions:
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -12,33 +13,29 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from ._base import SQLBaseStore
|
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from collections import OrderedDict, deque, namedtuple
|
||||||
|
from functools import wraps
|
||||||
|
import logging
|
||||||
|
|
||||||
from synapse.events import FrozenEvent, USE_FROZEN_DICTS
|
import simplejson as json
|
||||||
from synapse.events.utils import prune_event
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
|
||||||
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
|
from synapse.util.frozenutils import frozendict_json_encoder
|
||||||
from synapse.util.logcontext import (
|
from synapse.util.logcontext import (
|
||||||
preserve_fn, PreserveLoggingContext, make_deferred_yieldable
|
PreserveLoggingContext, make_deferred_yieldable,
|
||||||
)
|
)
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
|
||||||
from collections import deque, namedtuple, OrderedDict
|
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
import logging
|
|
||||||
import ujson as json
|
|
||||||
|
|
||||||
# these are only included to make the type annotations work
|
# these are only included to make the type annotations work
|
||||||
from synapse.events import EventBase # noqa: F401
|
from synapse.events import EventBase # noqa: F401
|
||||||
from synapse.events.snapshot import EventContext # noqa: F401
|
from synapse.events.snapshot import EventContext # noqa: F401
|
||||||
|
@ -52,23 +49,25 @@ event_counter = metrics.register_counter(
|
||||||
"persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
|
"persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The number of times we are recalculating the current state
|
||||||
|
state_delta_counter = metrics.register_counter(
|
||||||
|
"state_delta",
|
||||||
|
)
|
||||||
|
# The number of times we are recalculating state when there is only a
|
||||||
|
# single forward extremity
|
||||||
|
state_delta_single_event_counter = metrics.register_counter(
|
||||||
|
"state_delta_single_event",
|
||||||
|
)
|
||||||
|
# The number of times we are reculating state when we could have resonably
|
||||||
|
# calculated the delta when we calculated the state for an event we were
|
||||||
|
# persisting.
|
||||||
|
state_delta_reuse_delta_counter = metrics.register_counter(
|
||||||
|
"state_delta_reuse_delta",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def encode_json(json_object):
|
def encode_json(json_object):
|
||||||
if USE_FROZEN_DICTS:
|
return frozendict_json_encoder.encode(json_object)
|
||||||
# ujson doesn't like frozen_dicts
|
|
||||||
return encode_canonical_json(json_object)
|
|
||||||
else:
|
|
||||||
return json.dumps(json_object, ensure_ascii=False)
|
|
||||||
|
|
||||||
|
|
||||||
# These values are used in the `enqueus_event` and `_do_fetch` methods to
|
|
||||||
# control how we batch/bulk fetch events from the database.
|
|
||||||
# The values are plucked out of thing air to make initial sync run faster
|
|
||||||
# on jki.re
|
|
||||||
# TODO: Make these configurable.
|
|
||||||
EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
|
|
||||||
EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
|
|
||||||
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
|
|
||||||
|
|
||||||
|
|
||||||
class _EventPeristenceQueue(object):
|
class _EventPeristenceQueue(object):
|
||||||
|
@ -199,13 +198,12 @@ def _retry_on_integrity_error(func):
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
class EventsStore(SQLBaseStore):
|
class EventsStore(EventsWorkerStore):
|
||||||
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
||||||
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(EventsStore, self).__init__(db_conn, hs)
|
super(EventsStore, self).__init__(db_conn, hs)
|
||||||
self._clock = hs.get_clock()
|
|
||||||
self.register_background_update_handler(
|
self.register_background_update_handler(
|
||||||
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
|
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
|
||||||
)
|
)
|
||||||
|
@ -293,10 +291,11 @@ class EventsStore(SQLBaseStore):
|
||||||
def _maybe_start_persisting(self, room_id):
|
def _maybe_start_persisting(self, room_id):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def persisting_queue(item):
|
def persisting_queue(item):
|
||||||
yield self._persist_events(
|
with Measure(self._clock, "persist_events"):
|
||||||
item.events_and_contexts,
|
yield self._persist_events(
|
||||||
backfilled=item.backfilled,
|
item.events_and_contexts,
|
||||||
)
|
backfilled=item.backfilled,
|
||||||
|
)
|
||||||
|
|
||||||
self._event_persist_queue.handle_queue(room_id, persisting_queue)
|
self._event_persist_queue.handle_queue(room_id, persisting_queue)
|
||||||
|
|
||||||
|
@ -378,7 +377,8 @@ class EventsStore(SQLBaseStore):
|
||||||
room_id, ev_ctx_rm, latest_event_ids
|
room_id, ev_ctx_rm, latest_event_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if new_latest_event_ids == set(latest_event_ids):
|
latest_event_ids = set(latest_event_ids)
|
||||||
|
if new_latest_event_ids == latest_event_ids:
|
||||||
# No change in extremities, so no change in state
|
# No change in extremities, so no change in state
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -399,6 +399,26 @@ class EventsStore(SQLBaseStore):
|
||||||
if all_single_prev_not_state:
|
if all_single_prev_not_state:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
state_delta_counter.inc()
|
||||||
|
if len(new_latest_event_ids) == 1:
|
||||||
|
state_delta_single_event_counter.inc()
|
||||||
|
|
||||||
|
# This is a fairly handwavey check to see if we could
|
||||||
|
# have guessed what the delta would have been when
|
||||||
|
# processing one of these events.
|
||||||
|
# What we're interested in is if the latest extremities
|
||||||
|
# were the same when we created the event as they are
|
||||||
|
# now. When this server creates a new event (as opposed
|
||||||
|
# to receiving it over federation) it will use the
|
||||||
|
# forward extremities as the prev_events, so we can
|
||||||
|
# guess this by looking at the prev_events and checking
|
||||||
|
# if they match the current forward extremities.
|
||||||
|
for ev, _ in ev_ctx_rm:
|
||||||
|
prev_event_ids = set(e for e, _ in ev.prev_events)
|
||||||
|
if latest_event_ids == prev_event_ids:
|
||||||
|
state_delta_reuse_delta_counter.inc()
|
||||||
|
break
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Calculating state delta for room %s", room_id,
|
"Calculating state delta for room %s", room_id,
|
||||||
)
|
)
|
||||||
|
@ -609,62 +629,6 @@ class EventsStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue((to_delete, to_insert))
|
defer.returnValue((to_delete, to_insert))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_event(self, event_id, check_redacted=True,
|
|
||||||
get_prev_content=False, allow_rejected=False,
|
|
||||||
allow_none=False):
|
|
||||||
"""Get an event from the database by event_id.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_id (str): The event_id of the event to fetch
|
|
||||||
check_redacted (bool): If True, check if event has been redacted
|
|
||||||
and redact it.
|
|
||||||
get_prev_content (bool): If True and event is a state event,
|
|
||||||
include the previous states content in the unsigned field.
|
|
||||||
allow_rejected (bool): If True return rejected events.
|
|
||||||
allow_none (bool): If True, return None if no event found, if
|
|
||||||
False throw an exception.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred : A FrozenEvent.
|
|
||||||
"""
|
|
||||||
events = yield self._get_events(
|
|
||||||
[event_id],
|
|
||||||
check_redacted=check_redacted,
|
|
||||||
get_prev_content=get_prev_content,
|
|
||||||
allow_rejected=allow_rejected,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not events and not allow_none:
|
|
||||||
raise SynapseError(404, "Could not find event %s" % (event_id,))
|
|
||||||
|
|
||||||
defer.returnValue(events[0] if events else None)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_events(self, event_ids, check_redacted=True,
|
|
||||||
get_prev_content=False, allow_rejected=False):
|
|
||||||
"""Get events from the database
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_ids (list): The event_ids of the events to fetch
|
|
||||||
check_redacted (bool): If True, check if event has been redacted
|
|
||||||
and redact it.
|
|
||||||
get_prev_content (bool): If True and event is a state event,
|
|
||||||
include the previous states content in the unsigned field.
|
|
||||||
allow_rejected (bool): If True return rejected events.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred : Dict from event_id to event.
|
|
||||||
"""
|
|
||||||
events = yield self._get_events(
|
|
||||||
event_ids,
|
|
||||||
check_redacted=check_redacted,
|
|
||||||
get_prev_content=get_prev_content,
|
|
||||||
allow_rejected=allow_rejected,
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue({e.event_id: e for e in events})
|
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
|
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
|
||||||
delete_existing=False, state_delta_for_room={},
|
delete_existing=False, state_delta_for_room={},
|
||||||
|
@ -693,6 +657,8 @@ class EventsStore(SQLBaseStore):
|
||||||
list of the event ids which are the forward extremities.
|
list of the event ids which are the forward extremities.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
all_events_and_contexts = events_and_contexts
|
||||||
|
|
||||||
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
|
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
|
||||||
|
|
||||||
self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
|
self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
|
||||||
|
@ -755,6 +721,7 @@ class EventsStore(SQLBaseStore):
|
||||||
self._update_metadata_tables_txn(
|
self._update_metadata_tables_txn(
|
||||||
txn,
|
txn,
|
||||||
events_and_contexts=events_and_contexts,
|
events_and_contexts=events_and_contexts,
|
||||||
|
all_events_and_contexts=all_events_and_contexts,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -817,7 +784,7 @@ class EventsStore(SQLBaseStore):
|
||||||
|
|
||||||
for member in members_changed:
|
for member in members_changed:
|
||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
txn, self.get_rooms_for_user, (member,)
|
txn, self.get_rooms_for_user_with_stream_ordering, (member,)
|
||||||
)
|
)
|
||||||
|
|
||||||
for host in set(get_domain_from_id(u) for u in members_changed):
|
for host in set(get_domain_from_id(u) for u in members_changed):
|
||||||
|
@ -1152,26 +1119,33 @@ class EventsStore(SQLBaseStore):
|
||||||
ec for ec in events_and_contexts if ec[0] not in to_remove
|
ec for ec in events_and_contexts if ec[0] not in to_remove
|
||||||
]
|
]
|
||||||
|
|
||||||
def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
|
def _update_metadata_tables_txn(self, txn, events_and_contexts,
|
||||||
|
all_events_and_contexts, backfilled):
|
||||||
"""Update all the miscellaneous tables for new events
|
"""Update all the miscellaneous tables for new events
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||||
we are persisting
|
we are persisting
|
||||||
|
all_events_and_contexts (list[(EventBase, EventContext)]): all
|
||||||
|
events that we were going to persist. This includes events
|
||||||
|
we've already persisted, etc, that wouldn't appear in
|
||||||
|
events_and_context.
|
||||||
backfilled (bool): True if the events were backfilled
|
backfilled (bool): True if the events were backfilled
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Insert all the push actions into the event_push_actions table.
|
||||||
|
self._set_push_actions_for_event_and_users_txn(
|
||||||
|
txn,
|
||||||
|
events_and_contexts=events_and_contexts,
|
||||||
|
all_events_and_contexts=all_events_and_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
if not events_and_contexts:
|
if not events_and_contexts:
|
||||||
# nothing to do here
|
# nothing to do here
|
||||||
return
|
return
|
||||||
|
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
# Insert all the push actions into the event_push_actions table.
|
|
||||||
self._set_push_actions_for_event_and_users_txn(
|
|
||||||
txn, event,
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.type == EventTypes.Redaction and event.redacts is not None:
|
if event.type == EventTypes.Redaction and event.redacts is not None:
|
||||||
# Remove the entries in the event_push_actions table for the
|
# Remove the entries in the event_push_actions table for the
|
||||||
# redacted event.
|
# redacted event.
|
||||||
|
@ -1375,292 +1349,6 @@ class EventsStore(SQLBaseStore):
|
||||||
"have_events", f,
|
"have_events", f,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _get_events(self, event_ids, check_redacted=True,
|
|
||||||
get_prev_content=False, allow_rejected=False):
|
|
||||||
if not event_ids:
|
|
||||||
defer.returnValue([])
|
|
||||||
|
|
||||||
event_id_list = event_ids
|
|
||||||
event_ids = set(event_ids)
|
|
||||||
|
|
||||||
event_entry_map = self._get_events_from_cache(
|
|
||||||
event_ids,
|
|
||||||
allow_rejected=allow_rejected,
|
|
||||||
)
|
|
||||||
|
|
||||||
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
|
|
||||||
|
|
||||||
if missing_events_ids:
|
|
||||||
missing_events = yield self._enqueue_events(
|
|
||||||
missing_events_ids,
|
|
||||||
check_redacted=check_redacted,
|
|
||||||
allow_rejected=allow_rejected,
|
|
||||||
)
|
|
||||||
|
|
||||||
event_entry_map.update(missing_events)
|
|
||||||
|
|
||||||
events = []
|
|
||||||
for event_id in event_id_list:
|
|
||||||
entry = event_entry_map.get(event_id, None)
|
|
||||||
if not entry:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if allow_rejected or not entry.event.rejected_reason:
|
|
||||||
if check_redacted and entry.redacted_event:
|
|
||||||
event = entry.redacted_event
|
|
||||||
else:
|
|
||||||
event = entry.event
|
|
||||||
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
if get_prev_content:
|
|
||||||
if "replaces_state" in event.unsigned:
|
|
||||||
prev = yield self.get_event(
|
|
||||||
event.unsigned["replaces_state"],
|
|
||||||
get_prev_content=False,
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
if prev:
|
|
||||||
event.unsigned = dict(event.unsigned)
|
|
||||||
event.unsigned["prev_content"] = prev.content
|
|
||||||
event.unsigned["prev_sender"] = prev.sender
|
|
||||||
|
|
||||||
defer.returnValue(events)
|
|
||||||
|
|
||||||
def _invalidate_get_event_cache(self, event_id):
|
|
||||||
self._get_event_cache.invalidate((event_id,))
|
|
||||||
|
|
||||||
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
|
|
||||||
"""Fetch events from the caches
|
|
||||||
|
|
||||||
Args:
|
|
||||||
events (list(str)): list of event_ids to fetch
|
|
||||||
allow_rejected (bool): Whether to teturn events that were rejected
|
|
||||||
update_metrics (bool): Whether to update the cache hit ratio metrics
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict of event_id -> _EventCacheEntry for each event_id in cache. If
|
|
||||||
allow_rejected is `False` then there will still be an entry but it
|
|
||||||
will be `None`
|
|
||||||
"""
|
|
||||||
event_map = {}
|
|
||||||
|
|
||||||
for event_id in events:
|
|
||||||
ret = self._get_event_cache.get(
|
|
||||||
(event_id,), None,
|
|
||||||
update_metrics=update_metrics,
|
|
||||||
)
|
|
||||||
if not ret:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if allow_rejected or not ret.event.rejected_reason:
|
|
||||||
event_map[event_id] = ret
|
|
||||||
else:
|
|
||||||
event_map[event_id] = None
|
|
||||||
|
|
||||||
return event_map
|
|
||||||
|
|
||||||
def _do_fetch(self, conn):
|
|
||||||
"""Takes a database connection and waits for requests for events from
|
|
||||||
the _event_fetch_list queue.
|
|
||||||
"""
|
|
||||||
event_list = []
|
|
||||||
i = 0
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
with self._event_fetch_lock:
|
|
||||||
event_list = self._event_fetch_list
|
|
||||||
self._event_fetch_list = []
|
|
||||||
|
|
||||||
if not event_list:
|
|
||||||
single_threaded = self.database_engine.single_threaded
|
|
||||||
if single_threaded or i > EVENT_QUEUE_ITERATIONS:
|
|
||||||
self._event_fetch_ongoing -= 1
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
i = 0
|
|
||||||
|
|
||||||
event_id_lists = zip(*event_list)[0]
|
|
||||||
event_ids = [
|
|
||||||
item for sublist in event_id_lists for item in sublist
|
|
||||||
]
|
|
||||||
|
|
||||||
rows = self._new_transaction(
|
|
||||||
conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
row_dict = {
|
|
||||||
r["event_id"]: r
|
|
||||||
for r in rows
|
|
||||||
}
|
|
||||||
|
|
||||||
# We only want to resolve deferreds from the main thread
|
|
||||||
def fire(lst, res):
|
|
||||||
for ids, d in lst:
|
|
||||||
if not d.called:
|
|
||||||
try:
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
d.callback([
|
|
||||||
res[i]
|
|
||||||
for i in ids
|
|
||||||
if i in res
|
|
||||||
])
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to callback")
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
reactor.callFromThread(fire, event_list, row_dict)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("do_fetch")
|
|
||||||
|
|
||||||
# We only want to resolve deferreds from the main thread
|
|
||||||
def fire(evs):
|
|
||||||
for _, d in evs:
|
|
||||||
if not d.called:
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
d.errback(e)
|
|
||||||
|
|
||||||
if event_list:
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
reactor.callFromThread(fire, event_list)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
|
|
||||||
"""Fetches events from the database using the _event_fetch_list. This
|
|
||||||
allows batch and bulk fetching of events - it allows us to fetch events
|
|
||||||
without having to create a new transaction for each request for events.
|
|
||||||
"""
|
|
||||||
if not events:
|
|
||||||
defer.returnValue({})
|
|
||||||
|
|
||||||
events_d = defer.Deferred()
|
|
||||||
with self._event_fetch_lock:
|
|
||||||
self._event_fetch_list.append(
|
|
||||||
(events, events_d)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._event_fetch_lock.notify()
|
|
||||||
|
|
||||||
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
|
|
||||||
self._event_fetch_ongoing += 1
|
|
||||||
should_start = True
|
|
||||||
else:
|
|
||||||
should_start = False
|
|
||||||
|
|
||||||
if should_start:
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
self.runWithConnection(
|
|
||||||
self._do_fetch
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Loading %d events", len(events))
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
rows = yield events_d
|
|
||||||
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
|
|
||||||
|
|
||||||
if not allow_rejected:
|
|
||||||
rows[:] = [r for r in rows if not r["rejects"]]
|
|
||||||
|
|
||||||
res = yield make_deferred_yieldable(defer.gatherResults(
|
|
||||||
[
|
|
||||||
preserve_fn(self._get_event_from_row)(
|
|
||||||
row["internal_metadata"], row["json"], row["redacts"],
|
|
||||||
rejected_reason=row["rejects"],
|
|
||||||
)
|
|
||||||
for row in rows
|
|
||||||
],
|
|
||||||
consumeErrors=True
|
|
||||||
))
|
|
||||||
|
|
||||||
defer.returnValue({
|
|
||||||
e.event.event_id: e
|
|
||||||
for e in res if e
|
|
||||||
})
|
|
||||||
|
|
||||||
def _fetch_event_rows(self, txn, events):
|
|
||||||
rows = []
|
|
||||||
N = 200
|
|
||||||
for i in range(1 + len(events) / N):
|
|
||||||
evs = events[i * N:(i + 1) * N]
|
|
||||||
if not evs:
|
|
||||||
break
|
|
||||||
|
|
||||||
sql = (
|
|
||||||
"SELECT "
|
|
||||||
" e.event_id as event_id, "
|
|
||||||
" e.internal_metadata,"
|
|
||||||
" e.json,"
|
|
||||||
" r.redacts as redacts,"
|
|
||||||
" rej.event_id as rejects "
|
|
||||||
" FROM event_json as e"
|
|
||||||
" LEFT JOIN rejections as rej USING (event_id)"
|
|
||||||
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
|
|
||||||
" WHERE e.event_id IN (%s)"
|
|
||||||
) % (",".join(["?"] * len(evs)),)
|
|
||||||
|
|
||||||
txn.execute(sql, evs)
|
|
||||||
rows.extend(self.cursor_to_dict(txn))
|
|
||||||
|
|
||||||
return rows
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _get_event_from_row(self, internal_metadata, js, redacted,
|
|
||||||
rejected_reason=None):
|
|
||||||
with Measure(self._clock, "_get_event_from_row"):
|
|
||||||
d = json.loads(js)
|
|
||||||
internal_metadata = json.loads(internal_metadata)
|
|
||||||
|
|
||||||
if rejected_reason:
|
|
||||||
rejected_reason = yield self._simple_select_one_onecol(
|
|
||||||
table="rejections",
|
|
||||||
keyvalues={"event_id": rejected_reason},
|
|
||||||
retcol="reason",
|
|
||||||
desc="_get_event_from_row_rejected_reason",
|
|
||||||
)
|
|
||||||
|
|
||||||
original_ev = FrozenEvent(
|
|
||||||
d,
|
|
||||||
internal_metadata_dict=internal_metadata,
|
|
||||||
rejected_reason=rejected_reason,
|
|
||||||
)
|
|
||||||
|
|
||||||
redacted_event = None
|
|
||||||
if redacted:
|
|
||||||
redacted_event = prune_event(original_ev)
|
|
||||||
|
|
||||||
redaction_id = yield self._simple_select_one_onecol(
|
|
||||||
table="redactions",
|
|
||||||
keyvalues={"redacts": redacted_event.event_id},
|
|
||||||
retcol="event_id",
|
|
||||||
desc="_get_event_from_row_redactions",
|
|
||||||
)
|
|
||||||
|
|
||||||
redacted_event.unsigned["redacted_by"] = redaction_id
|
|
||||||
# Get the redaction event.
|
|
||||||
|
|
||||||
because = yield self.get_event(
|
|
||||||
redaction_id,
|
|
||||||
check_redacted=False,
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if because:
|
|
||||||
# It's fine to do add the event directly, since get_pdu_json
|
|
||||||
# will serialise this field correctly
|
|
||||||
redacted_event.unsigned["redacted_because"] = because
|
|
||||||
|
|
||||||
cache_entry = _EventCacheEntry(
|
|
||||||
event=original_ev,
|
|
||||||
redacted_event=redacted_event,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
|
|
||||||
|
|
||||||
defer.returnValue(cache_entry)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def count_daily_messages(self):
|
def count_daily_messages(self):
|
||||||
"""
|
"""
|
||||||
|
@ -2375,7 +2063,7 @@ class EventsStore(SQLBaseStore):
|
||||||
to_2, so_2 = yield self._get_event_ordering(event_id2)
|
to_2, so_2 = yield self._get_event_ordering(event_id2)
|
||||||
defer.returnValue((to_1, so_1) > (to_2, so_2))
|
defer.returnValue((to_1, so_1) > (to_2, so_2))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@cachedInlineCallbacks(max_entries=5000)
|
||||||
def _get_event_ordering(self, event_id):
|
def _get_event_ordering(self, event_id):
|
||||||
res = yield self._simple_select_one(
|
res = yield self._simple_select_one(
|
||||||
table="events",
|
table="events",
|
||||||
|
|
395
synapse/storage/events_worker.py
Normal file
395
synapse/storage/events_worker.py
Normal file
|
@ -0,0 +1,395 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
|
from synapse.events import FrozenEvent
|
||||||
|
from synapse.events.utils import prune_event
|
||||||
|
|
||||||
|
from synapse.util.logcontext import (
|
||||||
|
preserve_fn, PreserveLoggingContext, make_deferred_yieldable
|
||||||
|
)
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import simplejson as json
|
||||||
|
|
||||||
|
# these are only included to make the type annotations work
|
||||||
|
from synapse.events import EventBase # noqa: F401
|
||||||
|
from synapse.events.snapshot import EventContext # noqa: F401
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# These values are used in the `enqueus_event` and `_do_fetch` methods to
|
||||||
|
# control how we batch/bulk fetch events from the database.
|
||||||
|
# The values are plucked out of thing air to make initial sync run faster
|
||||||
|
# on jki.re
|
||||||
|
# TODO: Make these configurable.
|
||||||
|
EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
|
||||||
|
EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
|
||||||
|
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
|
||||||
|
|
||||||
|
|
||||||
|
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
|
||||||
|
|
||||||
|
|
||||||
|
class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_event(self, event_id, check_redacted=True,
|
||||||
|
get_prev_content=False, allow_rejected=False,
|
||||||
|
allow_none=False):
|
||||||
|
"""Get an event from the database by event_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id (str): The event_id of the event to fetch
|
||||||
|
check_redacted (bool): If True, check if event has been redacted
|
||||||
|
and redact it.
|
||||||
|
get_prev_content (bool): If True and event is a state event,
|
||||||
|
include the previous states content in the unsigned field.
|
||||||
|
allow_rejected (bool): If True return rejected events.
|
||||||
|
allow_none (bool): If True, return None if no event found, if
|
||||||
|
False throw an exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred : A FrozenEvent.
|
||||||
|
"""
|
||||||
|
events = yield self._get_events(
|
||||||
|
[event_id],
|
||||||
|
check_redacted=check_redacted,
|
||||||
|
get_prev_content=get_prev_content,
|
||||||
|
allow_rejected=allow_rejected,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not events and not allow_none:
|
||||||
|
raise SynapseError(404, "Could not find event %s" % (event_id,))
|
||||||
|
|
||||||
|
defer.returnValue(events[0] if events else None)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_events(self, event_ids, check_redacted=True,
|
||||||
|
get_prev_content=False, allow_rejected=False):
|
||||||
|
"""Get events from the database
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_ids (list): The event_ids of the events to fetch
|
||||||
|
check_redacted (bool): If True, check if event has been redacted
|
||||||
|
and redact it.
|
||||||
|
get_prev_content (bool): If True and event is a state event,
|
||||||
|
include the previous states content in the unsigned field.
|
||||||
|
allow_rejected (bool): If True return rejected events.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred : Dict from event_id to event.
|
||||||
|
"""
|
||||||
|
events = yield self._get_events(
|
||||||
|
event_ids,
|
||||||
|
check_redacted=check_redacted,
|
||||||
|
get_prev_content=get_prev_content,
|
||||||
|
allow_rejected=allow_rejected,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({e.event_id: e for e in events})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_events(self, event_ids, check_redacted=True,
|
||||||
|
get_prev_content=False, allow_rejected=False):
|
||||||
|
if not event_ids:
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
|
event_id_list = event_ids
|
||||||
|
event_ids = set(event_ids)
|
||||||
|
|
||||||
|
event_entry_map = self._get_events_from_cache(
|
||||||
|
event_ids,
|
||||||
|
allow_rejected=allow_rejected,
|
||||||
|
)
|
||||||
|
|
||||||
|
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
|
||||||
|
|
||||||
|
if missing_events_ids:
|
||||||
|
missing_events = yield self._enqueue_events(
|
||||||
|
missing_events_ids,
|
||||||
|
check_redacted=check_redacted,
|
||||||
|
allow_rejected=allow_rejected,
|
||||||
|
)
|
||||||
|
|
||||||
|
event_entry_map.update(missing_events)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
for event_id in event_id_list:
|
||||||
|
entry = event_entry_map.get(event_id, None)
|
||||||
|
if not entry:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if allow_rejected or not entry.event.rejected_reason:
|
||||||
|
if check_redacted and entry.redacted_event:
|
||||||
|
event = entry.redacted_event
|
||||||
|
else:
|
||||||
|
event = entry.event
|
||||||
|
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
if get_prev_content:
|
||||||
|
if "replaces_state" in event.unsigned:
|
||||||
|
prev = yield self.get_event(
|
||||||
|
event.unsigned["replaces_state"],
|
||||||
|
get_prev_content=False,
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
if prev:
|
||||||
|
event.unsigned = dict(event.unsigned)
|
||||||
|
event.unsigned["prev_content"] = prev.content
|
||||||
|
event.unsigned["prev_sender"] = prev.sender
|
||||||
|
|
||||||
|
defer.returnValue(events)
|
||||||
|
|
||||||
|
def _invalidate_get_event_cache(self, event_id):
|
||||||
|
self._get_event_cache.invalidate((event_id,))
|
||||||
|
|
||||||
|
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
|
||||||
|
"""Fetch events from the caches
|
||||||
|
|
||||||
|
Args:
|
||||||
|
events (list(str)): list of event_ids to fetch
|
||||||
|
allow_rejected (bool): Whether to teturn events that were rejected
|
||||||
|
update_metrics (bool): Whether to update the cache hit ratio metrics
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict of event_id -> _EventCacheEntry for each event_id in cache. If
|
||||||
|
allow_rejected is `False` then there will still be an entry but it
|
||||||
|
will be `None`
|
||||||
|
"""
|
||||||
|
event_map = {}
|
||||||
|
|
||||||
|
for event_id in events:
|
||||||
|
ret = self._get_event_cache.get(
|
||||||
|
(event_id,), None,
|
||||||
|
update_metrics=update_metrics,
|
||||||
|
)
|
||||||
|
if not ret:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if allow_rejected or not ret.event.rejected_reason:
|
||||||
|
event_map[event_id] = ret
|
||||||
|
else:
|
||||||
|
event_map[event_id] = None
|
||||||
|
|
||||||
|
return event_map
|
||||||
|
|
||||||
|
def _do_fetch(self, conn):
|
||||||
|
"""Takes a database connection and waits for requests for events from
|
||||||
|
the _event_fetch_list queue.
|
||||||
|
"""
|
||||||
|
event_list = []
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
with self._event_fetch_lock:
|
||||||
|
event_list = self._event_fetch_list
|
||||||
|
self._event_fetch_list = []
|
||||||
|
|
||||||
|
if not event_list:
|
||||||
|
single_threaded = self.database_engine.single_threaded
|
||||||
|
if single_threaded or i > EVENT_QUEUE_ITERATIONS:
|
||||||
|
self._event_fetch_ongoing -= 1
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
event_id_lists = zip(*event_list)[0]
|
||||||
|
event_ids = [
|
||||||
|
item for sublist in event_id_lists for item in sublist
|
||||||
|
]
|
||||||
|
|
||||||
|
rows = self._new_transaction(
|
||||||
|
conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
row_dict = {
|
||||||
|
r["event_id"]: r
|
||||||
|
for r in rows
|
||||||
|
}
|
||||||
|
|
||||||
|
# We only want to resolve deferreds from the main thread
|
||||||
|
def fire(lst, res):
|
||||||
|
for ids, d in lst:
|
||||||
|
if not d.called:
|
||||||
|
try:
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
d.callback([
|
||||||
|
res[i]
|
||||||
|
for i in ids
|
||||||
|
if i in res
|
||||||
|
])
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to callback")
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
reactor.callFromThread(fire, event_list, row_dict)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("do_fetch")
|
||||||
|
|
||||||
|
# We only want to resolve deferreds from the main thread
|
||||||
|
def fire(evs):
|
||||||
|
for _, d in evs:
|
||||||
|
if not d.called:
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
d.errback(e)
|
||||||
|
|
||||||
|
if event_list:
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
reactor.callFromThread(fire, event_list)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
|
||||||
|
"""Fetches events from the database using the _event_fetch_list. This
|
||||||
|
allows batch and bulk fetching of events - it allows us to fetch events
|
||||||
|
without having to create a new transaction for each request for events.
|
||||||
|
"""
|
||||||
|
if not events:
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
events_d = defer.Deferred()
|
||||||
|
with self._event_fetch_lock:
|
||||||
|
self._event_fetch_list.append(
|
||||||
|
(events, events_d)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._event_fetch_lock.notify()
|
||||||
|
|
||||||
|
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
|
||||||
|
self._event_fetch_ongoing += 1
|
||||||
|
should_start = True
|
||||||
|
else:
|
||||||
|
should_start = False
|
||||||
|
|
||||||
|
if should_start:
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
self.runWithConnection(
|
||||||
|
self._do_fetch
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Loading %d events", len(events))
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
rows = yield events_d
|
||||||
|
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
|
||||||
|
|
||||||
|
if not allow_rejected:
|
||||||
|
rows[:] = [r for r in rows if not r["rejects"]]
|
||||||
|
|
||||||
|
res = yield make_deferred_yieldable(defer.gatherResults(
|
||||||
|
[
|
||||||
|
preserve_fn(self._get_event_from_row)(
|
||||||
|
row["internal_metadata"], row["json"], row["redacts"],
|
||||||
|
rejected_reason=row["rejects"],
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
],
|
||||||
|
consumeErrors=True
|
||||||
|
))
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
e.event.event_id: e
|
||||||
|
for e in res if e
|
||||||
|
})
|
||||||
|
|
||||||
|
def _fetch_event_rows(self, txn, events):
|
||||||
|
rows = []
|
||||||
|
N = 200
|
||||||
|
for i in range(1 + len(events) / N):
|
||||||
|
evs = events[i * N:(i + 1) * N]
|
||||||
|
if not evs:
|
||||||
|
break
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"SELECT "
|
||||||
|
" e.event_id as event_id, "
|
||||||
|
" e.internal_metadata,"
|
||||||
|
" e.json,"
|
||||||
|
" r.redacts as redacts,"
|
||||||
|
" rej.event_id as rejects "
|
||||||
|
" FROM event_json as e"
|
||||||
|
" LEFT JOIN rejections as rej USING (event_id)"
|
||||||
|
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
|
||||||
|
" WHERE e.event_id IN (%s)"
|
||||||
|
) % (",".join(["?"] * len(evs)),)
|
||||||
|
|
||||||
|
txn.execute(sql, evs)
|
||||||
|
rows.extend(self.cursor_to_dict(txn))
|
||||||
|
|
||||||
|
return rows
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_event_from_row(self, internal_metadata, js, redacted,
|
||||||
|
rejected_reason=None):
|
||||||
|
with Measure(self._clock, "_get_event_from_row"):
|
||||||
|
d = json.loads(js)
|
||||||
|
internal_metadata = json.loads(internal_metadata)
|
||||||
|
|
||||||
|
if rejected_reason:
|
||||||
|
rejected_reason = yield self._simple_select_one_onecol(
|
||||||
|
table="rejections",
|
||||||
|
keyvalues={"event_id": rejected_reason},
|
||||||
|
retcol="reason",
|
||||||
|
desc="_get_event_from_row_rejected_reason",
|
||||||
|
)
|
||||||
|
|
||||||
|
original_ev = FrozenEvent(
|
||||||
|
d,
|
||||||
|
internal_metadata_dict=internal_metadata,
|
||||||
|
rejected_reason=rejected_reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
redacted_event = None
|
||||||
|
if redacted:
|
||||||
|
redacted_event = prune_event(original_ev)
|
||||||
|
|
||||||
|
redaction_id = yield self._simple_select_one_onecol(
|
||||||
|
table="redactions",
|
||||||
|
keyvalues={"redacts": redacted_event.event_id},
|
||||||
|
retcol="event_id",
|
||||||
|
desc="_get_event_from_row_redactions",
|
||||||
|
)
|
||||||
|
|
||||||
|
redacted_event.unsigned["redacted_by"] = redaction_id
|
||||||
|
# Get the redaction event.
|
||||||
|
|
||||||
|
because = yield self.get_event(
|
||||||
|
redaction_id,
|
||||||
|
check_redacted=False,
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if because:
|
||||||
|
# It's fine to do add the event directly, since get_pdu_json
|
||||||
|
# will serialise this field correctly
|
||||||
|
redacted_event.unsigned["redacted_because"] = because
|
||||||
|
|
||||||
|
cache_entry = _EventCacheEntry(
|
||||||
|
event=original_ev,
|
||||||
|
redacted_event=redacted_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
|
||||||
|
|
||||||
|
defer.returnValue(cache_entry)
|
|
@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
|
|
||||||
|
|
||||||
# The category ID for the "default" category. We don't store as null in the
|
# The category ID for the "default" category. We don't store as null in the
|
||||||
|
|
|
@ -21,14 +21,7 @@ from synapse.api.errors import StoreError
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
|
||||||
class ProfileStore(SQLBaseStore):
|
class ProfileWorkerStore(SQLBaseStore):
|
||||||
def create_profile(self, user_localpart):
|
|
||||||
return self._simple_insert(
|
|
||||||
table="profiles",
|
|
||||||
values={"user_id": user_localpart},
|
|
||||||
desc="create_profile",
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_profileinfo(self, user_localpart):
|
def get_profileinfo(self, user_localpart):
|
||||||
try:
|
try:
|
||||||
|
@ -61,14 +54,6 @@ class ProfileStore(SQLBaseStore):
|
||||||
desc="get_profile_displayname",
|
desc="get_profile_displayname",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_profile_displayname(self, user_localpart, new_displayname):
|
|
||||||
return self._simple_update_one(
|
|
||||||
table="profiles",
|
|
||||||
keyvalues={"user_id": user_localpart},
|
|
||||||
updatevalues={"displayname": new_displayname},
|
|
||||||
desc="set_profile_displayname",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_profile_avatar_url(self, user_localpart):
|
def get_profile_avatar_url(self, user_localpart):
|
||||||
return self._simple_select_one_onecol(
|
return self._simple_select_one_onecol(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
|
@ -77,14 +62,6 @@ class ProfileStore(SQLBaseStore):
|
||||||
desc="get_profile_avatar_url",
|
desc="get_profile_avatar_url",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
|
|
||||||
return self._simple_update_one(
|
|
||||||
table="profiles",
|
|
||||||
keyvalues={"user_id": user_localpart},
|
|
||||||
updatevalues={"avatar_url": new_avatar_url},
|
|
||||||
desc="set_profile_avatar_url",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_from_remote_profile_cache(self, user_id):
|
def get_from_remote_profile_cache(self, user_id):
|
||||||
return self._simple_select_one(
|
return self._simple_select_one(
|
||||||
table="remote_profile_cache",
|
table="remote_profile_cache",
|
||||||
|
@ -94,6 +71,31 @@ class ProfileStore(SQLBaseStore):
|
||||||
desc="get_from_remote_profile_cache",
|
desc="get_from_remote_profile_cache",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileStore(ProfileWorkerStore):
|
||||||
|
def create_profile(self, user_localpart):
|
||||||
|
return self._simple_insert(
|
||||||
|
table="profiles",
|
||||||
|
values={"user_id": user_localpart},
|
||||||
|
desc="create_profile",
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_profile_displayname(self, user_localpart, new_displayname):
|
||||||
|
return self._simple_update_one(
|
||||||
|
table="profiles",
|
||||||
|
keyvalues={"user_id": user_localpart},
|
||||||
|
updatevalues={"displayname": new_displayname},
|
||||||
|
desc="set_profile_displayname",
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
|
||||||
|
return self._simple_update_one(
|
||||||
|
table="profiles",
|
||||||
|
keyvalues={"user_id": user_localpart},
|
||||||
|
updatevalues={"avatar_url": new_avatar_url},
|
||||||
|
desc="set_profile_avatar_url",
|
||||||
|
)
|
||||||
|
|
||||||
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
|
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
|
||||||
"""Ensure we are caching the remote user's profiles.
|
"""Ensure we are caching the remote user's profiles.
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -14,11 +15,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.storage.appservice import ApplicationServiceWorkerStore
|
||||||
|
from synapse.storage.pusher import PusherWorkerStore
|
||||||
|
from synapse.storage.receipts import ReceiptsWorkerStore
|
||||||
|
from synapse.storage.roommember import RoomMemberWorkerStore
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
||||||
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
from synapse.push.baserules import list_with_base_rules
|
from synapse.push.baserules import list_with_base_rules
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
|
|
||||||
|
@ -48,7 +55,43 @@ def _load_rules(rawrules, enabled_map):
|
||||||
return rules
|
return rules
|
||||||
|
|
||||||
|
|
||||||
class PushRuleStore(SQLBaseStore):
|
class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||||
|
ReceiptsWorkerStore,
|
||||||
|
PusherWorkerStore,
|
||||||
|
RoomMemberWorkerStore,
|
||||||
|
SQLBaseStore):
|
||||||
|
"""This is an abstract base class where subclasses must implement
|
||||||
|
`get_max_push_rules_stream_id` which can be called in the initializer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# This ABCMeta metaclass ensures that we cannot be instantiated without
|
||||||
|
# the abstract methods being implemented.
|
||||||
|
__metaclass__ = abc.ABCMeta
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(PushRulesWorkerStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
push_rules_prefill, push_rules_id = self._get_cache_dict(
|
||||||
|
db_conn, "push_rules_stream",
|
||||||
|
entity_column="user_id",
|
||||||
|
stream_column="stream_id",
|
||||||
|
max_value=self.get_max_push_rules_stream_id(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.push_rules_stream_cache = StreamChangeCache(
|
||||||
|
"PushRulesStreamChangeCache", push_rules_id,
|
||||||
|
prefilled_cache=push_rules_prefill,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_max_push_rules_stream_id(self):
|
||||||
|
"""Get the position of the push rules stream.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=5000)
|
@cachedInlineCallbacks(max_entries=5000)
|
||||||
def get_push_rules_for_user(self, user_id):
|
def get_push_rules_for_user(self, user_id):
|
||||||
rows = yield self._simple_select_list(
|
rows = yield self._simple_select_list(
|
||||||
|
@ -89,6 +132,22 @@ class PushRuleStore(SQLBaseStore):
|
||||||
r['rule_id']: False if r['enabled'] == 0 else True for r in results
|
r['rule_id']: False if r['enabled'] == 0 else True for r in results
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def have_push_rules_changed_for_user(self, user_id, last_id):
|
||||||
|
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
|
||||||
|
return defer.succeed(False)
|
||||||
|
else:
|
||||||
|
def have_push_rules_changed_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT COUNT(stream_id) FROM push_rules_stream"
|
||||||
|
" WHERE user_id = ? AND ? < stream_id"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id, last_id))
|
||||||
|
count, = txn.fetchone()
|
||||||
|
return bool(count)
|
||||||
|
return self.runInteraction(
|
||||||
|
"have_push_rules_changed", have_push_rules_changed_txn
|
||||||
|
)
|
||||||
|
|
||||||
@cachedList(cached_method_name="get_push_rules_for_user",
|
@cachedList(cached_method_name="get_push_rules_for_user",
|
||||||
list_name="user_ids", num_args=1, inlineCallbacks=True)
|
list_name="user_ids", num_args=1, inlineCallbacks=True)
|
||||||
def bulk_get_push_rules(self, user_ids):
|
def bulk_get_push_rules(self, user_ids):
|
||||||
|
@ -228,6 +287,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
|
results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
|
||||||
|
class PushRuleStore(PushRulesWorkerStore):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_push_rule(
|
def add_push_rule(
|
||||||
self, user_id, rule_id, priority_class, conditions, actions,
|
self, user_id, rule_id, priority_class, conditions, actions,
|
||||||
|
@ -526,21 +587,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
room stream ordering it corresponds to."""
|
room stream ordering it corresponds to."""
|
||||||
return self._push_rules_stream_id_gen.get_current_token()
|
return self._push_rules_stream_id_gen.get_current_token()
|
||||||
|
|
||||||
def have_push_rules_changed_for_user(self, user_id, last_id):
|
def get_max_push_rules_stream_id(self):
|
||||||
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
|
return self.get_push_rules_stream_token()[0]
|
||||||
return defer.succeed(False)
|
|
||||||
else:
|
|
||||||
def have_push_rules_changed_txn(txn):
|
|
||||||
sql = (
|
|
||||||
"SELECT COUNT(stream_id) FROM push_rules_stream"
|
|
||||||
" WHERE user_id = ? AND ? < stream_id"
|
|
||||||
)
|
|
||||||
txn.execute(sql, (user_id, last_id))
|
|
||||||
count, = txn.fetchone()
|
|
||||||
return bool(count)
|
|
||||||
return self.runInteraction(
|
|
||||||
"have_push_rules_changed", have_push_rules_changed_txn
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RuleNotFoundException(Exception):
|
class RuleNotFoundException(Exception):
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -27,7 +28,7 @@ import types
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PusherStore(SQLBaseStore):
|
class PusherWorkerStore(SQLBaseStore):
|
||||||
def _decode_pushers_rows(self, rows):
|
def _decode_pushers_rows(self, rows):
|
||||||
for r in rows:
|
for r in rows:
|
||||||
dataJson = r['data']
|
dataJson = r['data']
|
||||||
|
@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore):
|
||||||
rows = yield self.runInteraction("get_all_pushers", get_pushers)
|
rows = yield self.runInteraction("get_all_pushers", get_pushers)
|
||||||
defer.returnValue(rows)
|
defer.returnValue(rows)
|
||||||
|
|
||||||
def get_pushers_stream_token(self):
|
|
||||||
return self._pushers_id_gen.get_current_token()
|
|
||||||
|
|
||||||
def get_all_updated_pushers(self, last_id, current_id, limit):
|
def get_all_updated_pushers(self, last_id, current_id, limit):
|
||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return defer.succeed(([], []))
|
return defer.succeed(([], []))
|
||||||
|
@ -198,6 +196,11 @@ class PusherStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
|
||||||
|
class PusherStore(PusherWorkerStore):
|
||||||
|
def get_pushers_stream_token(self):
|
||||||
|
return self._pushers_id_gen.get_current_token()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_pusher(self, user_id, access_token, kind, app_id,
|
def add_pusher(self, user_id, access_token, kind, app_id,
|
||||||
app_display_name, device_display_name,
|
app_display_name, device_display_name,
|
||||||
|
@ -230,14 +233,18 @@ class PusherStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
if newly_inserted:
|
if newly_inserted:
|
||||||
# get_if_user_has_pusher only cares if the user has
|
self.runInteraction(
|
||||||
# at least *one* pusher.
|
"add_pusher",
|
||||||
self.get_if_user_has_pusher.invalidate(user_id,)
|
self._invalidate_cache_and_stream,
|
||||||
|
self.get_if_user_has_pusher, (user_id,)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
|
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
|
||||||
def delete_pusher_txn(txn, stream_id):
|
def delete_pusher_txn(txn, stream_id):
|
||||||
txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_if_user_has_pusher, (user_id,)
|
||||||
|
)
|
||||||
|
|
||||||
self._simple_delete_one_txn(
|
self._simple_delete_one_txn(
|
||||||
txn,
|
txn,
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -14,51 +15,50 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
from .util.id_generators import StreamIdGenerator
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
|
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ReceiptsStore(SQLBaseStore):
|
class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
|
"""This is an abstract base class where subclasses must implement
|
||||||
|
`get_max_receipt_stream_id` which can be called in the initializer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# This ABCMeta metaclass ensures that we cannot be instantiated without
|
||||||
|
# the abstract methods being implemented.
|
||||||
|
__metaclass__ = abc.ABCMeta
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(ReceiptsStore, self).__init__(db_conn, hs)
|
super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self._receipts_stream_cache = StreamChangeCache(
|
self._receipts_stream_cache = StreamChangeCache(
|
||||||
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_max_receipt_stream_id(self):
|
||||||
|
"""Get the current max stream ID for receipts stream
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cachedInlineCallbacks()
|
||||||
def get_users_with_read_receipts_in_room(self, room_id):
|
def get_users_with_read_receipts_in_room(self, room_id):
|
||||||
receipts = yield self.get_receipts_for_room(room_id, "m.read")
|
receipts = yield self.get_receipts_for_room(room_id, "m.read")
|
||||||
defer.returnValue(set(r['user_id'] for r in receipts))
|
defer.returnValue(set(r['user_id'] for r in receipts))
|
||||||
|
|
||||||
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
|
|
||||||
user_id):
|
|
||||||
if receipt_type != "m.read":
|
|
||||||
return
|
|
||||||
|
|
||||||
# Returns an ObservableDeferred
|
|
||||||
res = self.get_users_with_read_receipts_in_room.cache.get(
|
|
||||||
room_id, None, update_metrics=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if res:
|
|
||||||
if isinstance(res, defer.Deferred) and res.called:
|
|
||||||
res = res.result
|
|
||||||
if user_id in res:
|
|
||||||
# We'd only be adding to the set, so no point invalidating if the
|
|
||||||
# user is already there
|
|
||||||
return
|
|
||||||
|
|
||||||
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
|
|
||||||
|
|
||||||
@cached(num_args=2)
|
@cached(num_args=2)
|
||||||
def get_receipts_for_room(self, room_id, receipt_type):
|
def get_receipts_for_room(self, room_id, receipt_type):
|
||||||
return self._simple_select_list(
|
return self._simple_select_list(
|
||||||
|
@ -270,6 +270,59 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
}
|
}
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
def get_all_updated_receipts(self, last_id, current_id, limit=None):
|
||||||
|
if last_id == current_id:
|
||||||
|
return defer.succeed([])
|
||||||
|
|
||||||
|
def get_all_updated_receipts_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
|
||||||
|
" FROM receipts_linearized"
|
||||||
|
" WHERE ? < stream_id AND stream_id <= ?"
|
||||||
|
" ORDER BY stream_id ASC"
|
||||||
|
)
|
||||||
|
args = [last_id, current_id]
|
||||||
|
if limit is not None:
|
||||||
|
sql += " LIMIT ?"
|
||||||
|
args.append(limit)
|
||||||
|
txn.execute(sql, args)
|
||||||
|
|
||||||
|
return txn.fetchall()
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_all_updated_receipts", get_all_updated_receipts_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
|
||||||
|
user_id):
|
||||||
|
if receipt_type != "m.read":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Returns an ObservableDeferred
|
||||||
|
res = self.get_users_with_read_receipts_in_room.cache.get(
|
||||||
|
room_id, None, update_metrics=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if res:
|
||||||
|
if isinstance(res, defer.Deferred) and res.called:
|
||||||
|
res = res.result
|
||||||
|
if user_id in res:
|
||||||
|
# We'd only be adding to the set, so no point invalidating if the
|
||||||
|
# user is already there
|
||||||
|
return
|
||||||
|
|
||||||
|
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptsStore(ReceiptsWorkerStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
# We instantiate this first as the ReceiptsWorkerStore constructor
|
||||||
|
# needs to be able to call get_max_receipt_stream_id
|
||||||
|
self._receipts_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "receipts_linearized", "stream_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
super(ReceiptsStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
def get_max_receipt_stream_id(self):
|
def get_max_receipt_stream_id(self):
|
||||||
return self._receipts_id_gen.get_current_token()
|
return self._receipts_id_gen.get_current_token()
|
||||||
|
|
||||||
|
@ -457,25 +510,3 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
"data": json.dumps(data),
|
"data": json.dumps(data),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_all_updated_receipts(self, last_id, current_id, limit=None):
|
|
||||||
if last_id == current_id:
|
|
||||||
return defer.succeed([])
|
|
||||||
|
|
||||||
def get_all_updated_receipts_txn(txn):
|
|
||||||
sql = (
|
|
||||||
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
|
|
||||||
" FROM receipts_linearized"
|
|
||||||
" WHERE ? < stream_id AND stream_id <= ?"
|
|
||||||
" ORDER BY stream_id ASC"
|
|
||||||
)
|
|
||||||
args = [last_id, current_id]
|
|
||||||
if limit is not None:
|
|
||||||
sql += " LIMIT ?"
|
|
||||||
args.append(limit)
|
|
||||||
txn.execute(sql, args)
|
|
||||||
|
|
||||||
return txn.fetchall()
|
|
||||||
return self.runInteraction(
|
|
||||||
"get_all_updated_receipts", get_all_updated_receipts_txn
|
|
||||||
)
|
|
||||||
|
|
|
@ -19,10 +19,70 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError, Codes
|
from synapse.api.errors import StoreError, Codes
|
||||||
from synapse.storage import background_updates
|
from synapse.storage import background_updates
|
||||||
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(background_updates.BackgroundUpdateStore):
|
class RegistrationWorkerStore(SQLBaseStore):
|
||||||
|
@cached()
|
||||||
|
def get_user_by_id(self, user_id):
|
||||||
|
return self._simple_select_one(
|
||||||
|
table="users",
|
||||||
|
keyvalues={
|
||||||
|
"name": user_id,
|
||||||
|
},
|
||||||
|
retcols=["name", "password_hash", "is_guest"],
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_user_by_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def get_user_by_access_token(self, token):
|
||||||
|
"""Get a user from the given access token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (str): The access token of a user.
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: None, if the token did not match, otherwise dict
|
||||||
|
including the keys `name`, `is_guest`, `device_id`, `token_id`.
|
||||||
|
"""
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_user_by_access_token",
|
||||||
|
self._query_for_auth,
|
||||||
|
token
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def is_server_admin(self, user):
|
||||||
|
res = yield self._simple_select_one_onecol(
|
||||||
|
table="users",
|
||||||
|
keyvalues={"name": user.to_string()},
|
||||||
|
retcol="admin",
|
||||||
|
allow_none=True,
|
||||||
|
desc="is_server_admin",
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(res if res else False)
|
||||||
|
|
||||||
|
def _query_for_auth(self, txn, token):
|
||||||
|
sql = (
|
||||||
|
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
|
||||||
|
" access_tokens.device_id"
|
||||||
|
" FROM users"
|
||||||
|
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
||||||
|
" WHERE token = ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(sql, (token,))
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
if rows:
|
||||||
|
return rows[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationStore(RegistrationWorkerStore,
|
||||||
|
background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(RegistrationStore, self).__init__(db_conn, hs)
|
super(RegistrationStore, self).__init__(db_conn, hs)
|
||||||
|
@ -187,18 +247,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||||
|
|
||||||
@cached()
|
|
||||||
def get_user_by_id(self, user_id):
|
|
||||||
return self._simple_select_one(
|
|
||||||
table="users",
|
|
||||||
keyvalues={
|
|
||||||
"name": user_id,
|
|
||||||
},
|
|
||||||
retcols=["name", "password_hash", "is_guest"],
|
|
||||||
allow_none=True,
|
|
||||||
desc="get_user_by_id",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_users_by_id_case_insensitive(self, user_id):
|
def get_users_by_id_case_insensitive(self, user_id):
|
||||||
"""Gets users that match user_id case insensitively.
|
"""Gets users that match user_id case insensitively.
|
||||||
Returns a mapping of user_id -> password_hash.
|
Returns a mapping of user_id -> password_hash.
|
||||||
|
@ -304,34 +352,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
return self.runInteraction("delete_access_token", f)
|
return self.runInteraction("delete_access_token", f)
|
||||||
|
|
||||||
@cached()
|
|
||||||
def get_user_by_access_token(self, token):
|
|
||||||
"""Get a user from the given access token.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token (str): The access token of a user.
|
|
||||||
Returns:
|
|
||||||
defer.Deferred: None, if the token did not match, otherwise dict
|
|
||||||
including the keys `name`, `is_guest`, `device_id`, `token_id`.
|
|
||||||
"""
|
|
||||||
return self.runInteraction(
|
|
||||||
"get_user_by_access_token",
|
|
||||||
self._query_for_auth,
|
|
||||||
token
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def is_server_admin(self, user):
|
|
||||||
res = yield self._simple_select_one_onecol(
|
|
||||||
table="users",
|
|
||||||
keyvalues={"name": user.to_string()},
|
|
||||||
retcol="admin",
|
|
||||||
allow_none=True,
|
|
||||||
desc="is_server_admin",
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(res if res else False)
|
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cachedInlineCallbacks()
|
||||||
def is_guest(self, user_id):
|
def is_guest(self, user_id):
|
||||||
res = yield self._simple_select_one_onecol(
|
res = yield self._simple_select_one_onecol(
|
||||||
|
@ -344,22 +364,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
defer.returnValue(res if res else False)
|
defer.returnValue(res if res else False)
|
||||||
|
|
||||||
def _query_for_auth(self, txn, token):
|
|
||||||
sql = (
|
|
||||||
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
|
|
||||||
" access_tokens.device_id"
|
|
||||||
" FROM users"
|
|
||||||
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
|
||||||
" WHERE token = ?"
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(sql, (token,))
|
|
||||||
rows = self.cursor_to_dict(txn)
|
|
||||||
if rows:
|
|
||||||
return rows[0]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
||||||
yield self._simple_upsert("user_threepids", {
|
yield self._simple_upsert("user_threepids", {
|
||||||
|
@ -456,14 +460,12 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
"""
|
"""
|
||||||
def _find_next_generated_user_id(txn):
|
def _find_next_generated_user_id(txn):
|
||||||
txn.execute("SELECT name FROM users")
|
txn.execute("SELECT name FROM users")
|
||||||
rows = self.cursor_to_dict(txn)
|
|
||||||
|
|
||||||
regex = re.compile("^@(\d+):")
|
regex = re.compile("^@(\d+):")
|
||||||
|
|
||||||
found = set()
|
found = set()
|
||||||
|
|
||||||
for r in rows:
|
for user_id, in txn:
|
||||||
user_id = r["name"]
|
|
||||||
match = regex.search(user_id)
|
match = regex.search(user_id)
|
||||||
if match:
|
if match:
|
||||||
found.add(int(match.group(1)))
|
found.add(int(match.group(1)))
|
||||||
|
|
|
@ -16,12 +16,13 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.search import SearchStore
|
from synapse.storage.search import SearchStore
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -38,7 +39,138 @@ RatelimitOverride = collections.namedtuple(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RoomStore(SearchStore):
|
class RoomWorkerStore(SQLBaseStore):
|
||||||
|
def get_public_room_ids(self):
|
||||||
|
return self._simple_select_onecol(
|
||||||
|
table="rooms",
|
||||||
|
keyvalues={
|
||||||
|
"is_public": True,
|
||||||
|
},
|
||||||
|
retcol="room_id",
|
||||||
|
desc="get_public_room_ids",
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached(num_args=2, max_entries=100)
|
||||||
|
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
|
||||||
|
"""Get pulbic rooms for a particular list, or across all lists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id (int)
|
||||||
|
network_tuple (ThirdPartyInstanceID): The list to use (None, None)
|
||||||
|
means the main list, None means all lsits.
|
||||||
|
"""
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_public_room_ids_at_stream_id",
|
||||||
|
self.get_public_room_ids_at_stream_id_txn,
|
||||||
|
stream_id, network_tuple=network_tuple
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
|
||||||
|
network_tuple):
|
||||||
|
return {
|
||||||
|
rm
|
||||||
|
for rm, vis in self.get_published_at_stream_id_txn(
|
||||||
|
txn, stream_id, network_tuple=network_tuple
|
||||||
|
).items()
|
||||||
|
if vis
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
|
||||||
|
if network_tuple:
|
||||||
|
# We want to get from a particular list. No aggregation required.
|
||||||
|
|
||||||
|
sql = ("""
|
||||||
|
SELECT room_id, visibility FROM public_room_list_stream
|
||||||
|
INNER JOIN (
|
||||||
|
SELECT room_id, max(stream_id) AS stream_id
|
||||||
|
FROM public_room_list_stream
|
||||||
|
WHERE stream_id <= ? %s
|
||||||
|
GROUP BY room_id
|
||||||
|
) grouped USING (room_id, stream_id)
|
||||||
|
""")
|
||||||
|
|
||||||
|
if network_tuple.appservice_id is not None:
|
||||||
|
txn.execute(
|
||||||
|
sql % ("AND appservice_id = ? AND network_id = ?",),
|
||||||
|
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
txn.execute(
|
||||||
|
sql % ("AND appservice_id IS NULL",),
|
||||||
|
(stream_id,)
|
||||||
|
)
|
||||||
|
return dict(txn)
|
||||||
|
else:
|
||||||
|
# We want to get from all lists, so we need to aggregate the results
|
||||||
|
|
||||||
|
logger.info("Executing full list")
|
||||||
|
|
||||||
|
sql = ("""
|
||||||
|
SELECT room_id, visibility
|
||||||
|
FROM public_room_list_stream
|
||||||
|
INNER JOIN (
|
||||||
|
SELECT
|
||||||
|
room_id, max(stream_id) AS stream_id, appservice_id,
|
||||||
|
network_id
|
||||||
|
FROM public_room_list_stream
|
||||||
|
WHERE stream_id <= ?
|
||||||
|
GROUP BY room_id, appservice_id, network_id
|
||||||
|
) grouped USING (room_id, stream_id)
|
||||||
|
""")
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(stream_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
# A room is visible if its visible on any list.
|
||||||
|
for room_id, visibility in txn:
|
||||||
|
results[room_id] = bool(visibility) or results.get(room_id, False)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_public_room_changes(self, prev_stream_id, new_stream_id,
|
||||||
|
network_tuple):
|
||||||
|
def get_public_room_changes_txn(txn):
|
||||||
|
then_rooms = self.get_public_room_ids_at_stream_id_txn(
|
||||||
|
txn, prev_stream_id, network_tuple
|
||||||
|
)
|
||||||
|
|
||||||
|
now_rooms_dict = self.get_published_at_stream_id_txn(
|
||||||
|
txn, new_stream_id, network_tuple
|
||||||
|
)
|
||||||
|
|
||||||
|
now_rooms_visible = set(
|
||||||
|
rm for rm, vis in now_rooms_dict.items() if vis
|
||||||
|
)
|
||||||
|
now_rooms_not_visible = set(
|
||||||
|
rm for rm, vis in now_rooms_dict.items() if not vis
|
||||||
|
)
|
||||||
|
|
||||||
|
newly_visible = now_rooms_visible - then_rooms
|
||||||
|
newly_unpublished = now_rooms_not_visible & then_rooms
|
||||||
|
|
||||||
|
return newly_visible, newly_unpublished
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_public_room_changes", get_public_room_changes_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached(max_entries=10000)
|
||||||
|
def is_room_blocked(self, room_id):
|
||||||
|
return self._simple_select_one_onecol(
|
||||||
|
table="blocked_rooms",
|
||||||
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
},
|
||||||
|
retcol="1",
|
||||||
|
allow_none=True,
|
||||||
|
desc="is_room_blocked",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RoomStore(RoomWorkerStore, SearchStore):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def store_room(self, room_id, room_creator_user_id, is_public):
|
def store_room(self, room_id, room_creator_user_id, is_public):
|
||||||
|
@ -225,16 +357,6 @@ class RoomStore(SearchStore):
|
||||||
)
|
)
|
||||||
self.hs.get_notifier().on_new_replication_data()
|
self.hs.get_notifier().on_new_replication_data()
|
||||||
|
|
||||||
def get_public_room_ids(self):
|
|
||||||
return self._simple_select_onecol(
|
|
||||||
table="rooms",
|
|
||||||
keyvalues={
|
|
||||||
"is_public": True,
|
|
||||||
},
|
|
||||||
retcol="room_id",
|
|
||||||
desc="get_public_room_ids",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_room_count(self):
|
def get_room_count(self):
|
||||||
"""Retrieve a list of all rooms
|
"""Retrieve a list of all rooms
|
||||||
"""
|
"""
|
||||||
|
@ -326,113 +448,6 @@ class RoomStore(SearchStore):
|
||||||
def get_current_public_room_stream_id(self):
|
def get_current_public_room_stream_id(self):
|
||||||
return self._public_room_id_gen.get_current_token()
|
return self._public_room_id_gen.get_current_token()
|
||||||
|
|
||||||
@cached(num_args=2, max_entries=100)
|
|
||||||
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
|
|
||||||
"""Get pulbic rooms for a particular list, or across all lists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id (int)
|
|
||||||
network_tuple (ThirdPartyInstanceID): The list to use (None, None)
|
|
||||||
means the main list, None means all lsits.
|
|
||||||
"""
|
|
||||||
return self.runInteraction(
|
|
||||||
"get_public_room_ids_at_stream_id",
|
|
||||||
self.get_public_room_ids_at_stream_id_txn,
|
|
||||||
stream_id, network_tuple=network_tuple
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
|
|
||||||
network_tuple):
|
|
||||||
return {
|
|
||||||
rm
|
|
||||||
for rm, vis in self.get_published_at_stream_id_txn(
|
|
||||||
txn, stream_id, network_tuple=network_tuple
|
|
||||||
).items()
|
|
||||||
if vis
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
|
|
||||||
if network_tuple:
|
|
||||||
# We want to get from a particular list. No aggregation required.
|
|
||||||
|
|
||||||
sql = ("""
|
|
||||||
SELECT room_id, visibility FROM public_room_list_stream
|
|
||||||
INNER JOIN (
|
|
||||||
SELECT room_id, max(stream_id) AS stream_id
|
|
||||||
FROM public_room_list_stream
|
|
||||||
WHERE stream_id <= ? %s
|
|
||||||
GROUP BY room_id
|
|
||||||
) grouped USING (room_id, stream_id)
|
|
||||||
""")
|
|
||||||
|
|
||||||
if network_tuple.appservice_id is not None:
|
|
||||||
txn.execute(
|
|
||||||
sql % ("AND appservice_id = ? AND network_id = ?",),
|
|
||||||
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
txn.execute(
|
|
||||||
sql % ("AND appservice_id IS NULL",),
|
|
||||||
(stream_id,)
|
|
||||||
)
|
|
||||||
return dict(txn)
|
|
||||||
else:
|
|
||||||
# We want to get from all lists, so we need to aggregate the results
|
|
||||||
|
|
||||||
logger.info("Executing full list")
|
|
||||||
|
|
||||||
sql = ("""
|
|
||||||
SELECT room_id, visibility
|
|
||||||
FROM public_room_list_stream
|
|
||||||
INNER JOIN (
|
|
||||||
SELECT
|
|
||||||
room_id, max(stream_id) AS stream_id, appservice_id,
|
|
||||||
network_id
|
|
||||||
FROM public_room_list_stream
|
|
||||||
WHERE stream_id <= ?
|
|
||||||
GROUP BY room_id, appservice_id, network_id
|
|
||||||
) grouped USING (room_id, stream_id)
|
|
||||||
""")
|
|
||||||
|
|
||||||
txn.execute(
|
|
||||||
sql,
|
|
||||||
(stream_id,)
|
|
||||||
)
|
|
||||||
|
|
||||||
results = {}
|
|
||||||
# A room is visible if its visible on any list.
|
|
||||||
for room_id, visibility in txn:
|
|
||||||
results[room_id] = bool(visibility) or results.get(room_id, False)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def get_public_room_changes(self, prev_stream_id, new_stream_id,
|
|
||||||
network_tuple):
|
|
||||||
def get_public_room_changes_txn(txn):
|
|
||||||
then_rooms = self.get_public_room_ids_at_stream_id_txn(
|
|
||||||
txn, prev_stream_id, network_tuple
|
|
||||||
)
|
|
||||||
|
|
||||||
now_rooms_dict = self.get_published_at_stream_id_txn(
|
|
||||||
txn, new_stream_id, network_tuple
|
|
||||||
)
|
|
||||||
|
|
||||||
now_rooms_visible = set(
|
|
||||||
rm for rm, vis in now_rooms_dict.items() if vis
|
|
||||||
)
|
|
||||||
now_rooms_not_visible = set(
|
|
||||||
rm for rm, vis in now_rooms_dict.items() if not vis
|
|
||||||
)
|
|
||||||
|
|
||||||
newly_visible = now_rooms_visible - then_rooms
|
|
||||||
newly_unpublished = now_rooms_not_visible & then_rooms
|
|
||||||
|
|
||||||
return newly_visible, newly_unpublished
|
|
||||||
|
|
||||||
return self.runInteraction(
|
|
||||||
"get_public_room_changes", get_public_room_changes_txn
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_all_new_public_rooms(self, prev_id, current_id, limit):
|
def get_all_new_public_rooms(self, prev_id, current_id, limit):
|
||||||
def get_all_new_public_rooms(txn):
|
def get_all_new_public_rooms(txn):
|
||||||
sql = ("""
|
sql = ("""
|
||||||
|
@ -482,18 +497,6 @@ class RoomStore(SearchStore):
|
||||||
else:
|
else:
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
@cached(max_entries=10000)
|
|
||||||
def is_room_blocked(self, room_id):
|
|
||||||
return self._simple_select_one_onecol(
|
|
||||||
table="blocked_rooms",
|
|
||||||
keyvalues={
|
|
||||||
"room_id": room_id,
|
|
||||||
},
|
|
||||||
retcol="1",
|
|
||||||
allow_none=True,
|
|
||||||
desc="is_room_blocked",
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def block_room(self, room_id, user_id):
|
def block_room(self, room_id, user_id):
|
||||||
yield self._simple_insert(
|
yield self._simple_insert(
|
||||||
|
@ -504,7 +507,11 @@ class RoomStore(SearchStore):
|
||||||
},
|
},
|
||||||
desc="block_room",
|
desc="block_room",
|
||||||
)
|
)
|
||||||
self.is_room_blocked.invalidate((room_id,))
|
yield self.runInteraction(
|
||||||
|
"block_room_invalidation",
|
||||||
|
self._invalidate_cache_and_stream,
|
||||||
|
self.is_room_blocked, (room_id,),
|
||||||
|
)
|
||||||
|
|
||||||
def get_media_mxcs_in_room(self, room_id):
|
def get_media_mxcs_in_room(self, room_id):
|
||||||
"""Retrieves all the local and remote media MXC URIs in a given room
|
"""Retrieves all the local and remote media MXC URIs in a given room
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -17,7 +18,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from synapse.storage.events import EventsWorkerStore
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
@ -27,7 +28,7 @@ from synapse.api.constants import Membership, EventTypes
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import ujson as json
|
import simplejson as json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -37,6 +38,11 @@ RoomsForUser = namedtuple(
|
||||||
("room_id", "sender", "membership", "event_id", "stream_ordering")
|
("room_id", "sender", "membership", "event_id", "stream_ordering")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
GetRoomsForUserWithStreamOrdering = namedtuple(
|
||||||
|
"_GetRoomsForUserWithStreamOrdering",
|
||||||
|
("room_id", "stream_ordering",)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# We store this using a namedtuple so that we save about 3x space over using a
|
# We store this using a namedtuple so that we save about 3x space over using a
|
||||||
# dict.
|
# dict.
|
||||||
|
@ -48,97 +54,7 @@ ProfileInfo = namedtuple(
|
||||||
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
|
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
|
||||||
|
|
||||||
|
|
||||||
class RoomMemberStore(SQLBaseStore):
|
class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
def __init__(self, db_conn, hs):
|
|
||||||
super(RoomMemberStore, self).__init__(db_conn, hs)
|
|
||||||
self.register_background_update_handler(
|
|
||||||
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
|
|
||||||
)
|
|
||||||
|
|
||||||
def _store_room_members_txn(self, txn, events, backfilled):
|
|
||||||
"""Store a room member in the database.
|
|
||||||
"""
|
|
||||||
self._simple_insert_many_txn(
|
|
||||||
txn,
|
|
||||||
table="room_memberships",
|
|
||||||
values=[
|
|
||||||
{
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"user_id": event.state_key,
|
|
||||||
"sender": event.user_id,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"membership": event.membership,
|
|
||||||
"display_name": event.content.get("displayname", None),
|
|
||||||
"avatar_url": event.content.get("avatar_url", None),
|
|
||||||
}
|
|
||||||
for event in events
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
for event in events:
|
|
||||||
txn.call_after(
|
|
||||||
self._membership_stream_cache.entity_has_changed,
|
|
||||||
event.state_key, event.internal_metadata.stream_ordering
|
|
||||||
)
|
|
||||||
txn.call_after(
|
|
||||||
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
|
|
||||||
)
|
|
||||||
|
|
||||||
# We update the local_invites table only if the event is "current",
|
|
||||||
# i.e., its something that has just happened.
|
|
||||||
# The only current event that can also be an outlier is if its an
|
|
||||||
# invite that has come in across federation.
|
|
||||||
is_new_state = not backfilled and (
|
|
||||||
not event.internal_metadata.is_outlier()
|
|
||||||
or event.internal_metadata.is_invite_from_remote()
|
|
||||||
)
|
|
||||||
is_mine = self.hs.is_mine_id(event.state_key)
|
|
||||||
if is_new_state and is_mine:
|
|
||||||
if event.membership == Membership.INVITE:
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
table="local_invites",
|
|
||||||
values={
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"invitee": event.state_key,
|
|
||||||
"inviter": event.sender,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"stream_id": event.internal_metadata.stream_ordering,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sql = (
|
|
||||||
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
|
|
||||||
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
|
||||||
" AND replaced_by is NULL"
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(sql, (
|
|
||||||
event.internal_metadata.stream_ordering,
|
|
||||||
event.event_id,
|
|
||||||
event.room_id,
|
|
||||||
event.state_key,
|
|
||||||
))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def locally_reject_invite(self, user_id, room_id):
|
|
||||||
sql = (
|
|
||||||
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
|
|
||||||
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
|
||||||
" AND replaced_by is NULL"
|
|
||||||
)
|
|
||||||
|
|
||||||
def f(txn, stream_ordering):
|
|
||||||
txn.execute(sql, (
|
|
||||||
stream_ordering,
|
|
||||||
True,
|
|
||||||
room_id,
|
|
||||||
user_id,
|
|
||||||
))
|
|
||||||
|
|
||||||
with self._stream_id_gen.get_next() as stream_ordering:
|
|
||||||
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
|
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
|
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
|
||||||
def get_hosts_in_room(self, room_id, cache_context):
|
def get_hosts_in_room(self, room_id, cache_context):
|
||||||
"""Returns the set of all hosts currently in the room
|
"""Returns the set of all hosts currently in the room
|
||||||
|
@ -270,12 +186,32 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=500000, iterable=True)
|
@cachedInlineCallbacks(max_entries=500000, iterable=True)
|
||||||
def get_rooms_for_user(self, user_id):
|
def get_rooms_for_user_with_stream_ordering(self, user_id):
|
||||||
"""Returns a set of room_ids the user is currently joined to
|
"""Returns a set of room_ids the user is currently joined to
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
|
||||||
|
the rooms the user is in currently, along with the stream ordering
|
||||||
|
of the most recent join for that user and room.
|
||||||
"""
|
"""
|
||||||
rooms = yield self.get_rooms_for_user_where_membership_is(
|
rooms = yield self.get_rooms_for_user_where_membership_is(
|
||||||
user_id, membership_list=[Membership.JOIN],
|
user_id, membership_list=[Membership.JOIN],
|
||||||
)
|
)
|
||||||
|
defer.returnValue(frozenset(
|
||||||
|
GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
|
||||||
|
for r in rooms
|
||||||
|
))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_rooms_for_user(self, user_id, on_invalidate=None):
|
||||||
|
"""Returns a set of room_ids the user is currently joined to
|
||||||
|
"""
|
||||||
|
rooms = yield self.get_rooms_for_user_with_stream_ordering(
|
||||||
|
user_id, on_invalidate=on_invalidate,
|
||||||
|
)
|
||||||
defer.returnValue(frozenset(r.room_id for r in rooms))
|
defer.returnValue(frozenset(r.room_id for r in rooms))
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
|
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
|
||||||
|
@ -295,89 +231,6 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue(user_who_share_room)
|
defer.returnValue(user_who_share_room)
|
||||||
|
|
||||||
def forget(self, user_id, room_id):
|
|
||||||
"""Indicate that user_id wishes to discard history for room_id."""
|
|
||||||
def f(txn):
|
|
||||||
sql = (
|
|
||||||
"UPDATE"
|
|
||||||
" room_memberships"
|
|
||||||
" SET"
|
|
||||||
" forgotten = 1"
|
|
||||||
" WHERE"
|
|
||||||
" user_id = ?"
|
|
||||||
" AND"
|
|
||||||
" room_id = ?"
|
|
||||||
)
|
|
||||||
txn.execute(sql, (user_id, room_id))
|
|
||||||
|
|
||||||
txn.call_after(self.was_forgotten_at.invalidate_all)
|
|
||||||
txn.call_after(self.did_forget.invalidate, (user_id, room_id))
|
|
||||||
self._invalidate_cache_and_stream(
|
|
||||||
txn, self.who_forgot_in_room, (room_id,)
|
|
||||||
)
|
|
||||||
return self.runInteraction("forget_membership", f)
|
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2)
|
|
||||||
def did_forget(self, user_id, room_id):
|
|
||||||
"""Returns whether user_id has elected to discard history for room_id.
|
|
||||||
|
|
||||||
Returns False if they have since re-joined."""
|
|
||||||
def f(txn):
|
|
||||||
sql = (
|
|
||||||
"SELECT"
|
|
||||||
" COUNT(*)"
|
|
||||||
" FROM"
|
|
||||||
" room_memberships"
|
|
||||||
" WHERE"
|
|
||||||
" user_id = ?"
|
|
||||||
" AND"
|
|
||||||
" room_id = ?"
|
|
||||||
" AND"
|
|
||||||
" forgotten = 0"
|
|
||||||
)
|
|
||||||
txn.execute(sql, (user_id, room_id))
|
|
||||||
rows = txn.fetchall()
|
|
||||||
return rows[0][0]
|
|
||||||
count = yield self.runInteraction("did_forget_membership", f)
|
|
||||||
defer.returnValue(count == 0)
|
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3)
|
|
||||||
def was_forgotten_at(self, user_id, room_id, event_id):
|
|
||||||
"""Returns whether user_id has elected to discard history for room_id at
|
|
||||||
event_id.
|
|
||||||
|
|
||||||
event_id must be a membership event."""
|
|
||||||
def f(txn):
|
|
||||||
sql = (
|
|
||||||
"SELECT"
|
|
||||||
" forgotten"
|
|
||||||
" FROM"
|
|
||||||
" room_memberships"
|
|
||||||
" WHERE"
|
|
||||||
" user_id = ?"
|
|
||||||
" AND"
|
|
||||||
" room_id = ?"
|
|
||||||
" AND"
|
|
||||||
" event_id = ?"
|
|
||||||
)
|
|
||||||
txn.execute(sql, (user_id, room_id, event_id))
|
|
||||||
rows = txn.fetchall()
|
|
||||||
return rows[0][0]
|
|
||||||
forgot = yield self.runInteraction("did_forget_membership_at", f)
|
|
||||||
defer.returnValue(forgot == 1)
|
|
||||||
|
|
||||||
@cached()
|
|
||||||
def who_forgot_in_room(self, room_id):
|
|
||||||
return self._simple_select_list(
|
|
||||||
table="room_memberships",
|
|
||||||
retcols=("user_id", "event_id"),
|
|
||||||
keyvalues={
|
|
||||||
"room_id": room_id,
|
|
||||||
"forgotten": 1,
|
|
||||||
},
|
|
||||||
desc="who_forgot"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_joined_users_from_context(self, event, context):
|
def get_joined_users_from_context(self, event, context):
|
||||||
state_group = context.state_group
|
state_group = context.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
|
@ -600,6 +453,185 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue(joined_hosts)
|
defer.returnValue(joined_hosts)
|
||||||
|
|
||||||
|
@cached(max_entries=10000, iterable=True)
|
||||||
|
def _get_joined_hosts_cache(self, room_id):
|
||||||
|
return _JoinedHostsCache(self, room_id)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def who_forgot_in_room(self, room_id):
|
||||||
|
return self._simple_select_list(
|
||||||
|
table="room_memberships",
|
||||||
|
retcols=("user_id", "event_id"),
|
||||||
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
"forgotten": 1,
|
||||||
|
},
|
||||||
|
desc="who_forgot"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RoomMemberStore(RoomMemberWorkerStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(RoomMemberStore, self).__init__(db_conn, hs)
|
||||||
|
self.register_background_update_handler(
|
||||||
|
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
|
||||||
|
)
|
||||||
|
|
||||||
|
def _store_room_members_txn(self, txn, events, backfilled):
|
||||||
|
"""Store a room member in the database.
|
||||||
|
"""
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="room_memberships",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"user_id": event.state_key,
|
||||||
|
"sender": event.user_id,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"membership": event.membership,
|
||||||
|
"display_name": event.content.get("displayname", None),
|
||||||
|
"avatar_url": event.content.get("avatar_url", None),
|
||||||
|
}
|
||||||
|
for event in events
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
for event in events:
|
||||||
|
txn.call_after(
|
||||||
|
self._membership_stream_cache.entity_has_changed,
|
||||||
|
event.state_key, event.internal_metadata.stream_ordering
|
||||||
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# We update the local_invites table only if the event is "current",
|
||||||
|
# i.e., its something that has just happened.
|
||||||
|
# The only current event that can also be an outlier is if its an
|
||||||
|
# invite that has come in across federation.
|
||||||
|
is_new_state = not backfilled and (
|
||||||
|
not event.internal_metadata.is_outlier()
|
||||||
|
or event.internal_metadata.is_invite_from_remote()
|
||||||
|
)
|
||||||
|
is_mine = self.hs.is_mine_id(event.state_key)
|
||||||
|
if is_new_state and is_mine:
|
||||||
|
if event.membership == Membership.INVITE:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="local_invites",
|
||||||
|
values={
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"invitee": event.state_key,
|
||||||
|
"inviter": event.sender,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"stream_id": event.internal_metadata.stream_ordering,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sql = (
|
||||||
|
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
|
||||||
|
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
||||||
|
" AND replaced_by is NULL"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(sql, (
|
||||||
|
event.internal_metadata.stream_ordering,
|
||||||
|
event.event_id,
|
||||||
|
event.room_id,
|
||||||
|
event.state_key,
|
||||||
|
))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def locally_reject_invite(self, user_id, room_id):
|
||||||
|
sql = (
|
||||||
|
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
|
||||||
|
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
||||||
|
" AND replaced_by is NULL"
|
||||||
|
)
|
||||||
|
|
||||||
|
def f(txn, stream_ordering):
|
||||||
|
txn.execute(sql, (
|
||||||
|
stream_ordering,
|
||||||
|
True,
|
||||||
|
room_id,
|
||||||
|
user_id,
|
||||||
|
))
|
||||||
|
|
||||||
|
with self._stream_id_gen.get_next() as stream_ordering:
|
||||||
|
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
|
||||||
|
|
||||||
|
def forget(self, user_id, room_id):
|
||||||
|
"""Indicate that user_id wishes to discard history for room_id."""
|
||||||
|
def f(txn):
|
||||||
|
sql = (
|
||||||
|
"UPDATE"
|
||||||
|
" room_memberships"
|
||||||
|
" SET"
|
||||||
|
" forgotten = 1"
|
||||||
|
" WHERE"
|
||||||
|
" user_id = ?"
|
||||||
|
" AND"
|
||||||
|
" room_id = ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id, room_id))
|
||||||
|
|
||||||
|
txn.call_after(self.was_forgotten_at.invalidate_all)
|
||||||
|
txn.call_after(self.did_forget.invalidate, (user_id, room_id))
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.who_forgot_in_room, (room_id,)
|
||||||
|
)
|
||||||
|
return self.runInteraction("forget_membership", f)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=2)
|
||||||
|
def did_forget(self, user_id, room_id):
|
||||||
|
"""Returns whether user_id has elected to discard history for room_id.
|
||||||
|
|
||||||
|
Returns False if they have since re-joined."""
|
||||||
|
def f(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT"
|
||||||
|
" COUNT(*)"
|
||||||
|
" FROM"
|
||||||
|
" room_memberships"
|
||||||
|
" WHERE"
|
||||||
|
" user_id = ?"
|
||||||
|
" AND"
|
||||||
|
" room_id = ?"
|
||||||
|
" AND"
|
||||||
|
" forgotten = 0"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id, room_id))
|
||||||
|
rows = txn.fetchall()
|
||||||
|
return rows[0][0]
|
||||||
|
count = yield self.runInteraction("did_forget_membership", f)
|
||||||
|
defer.returnValue(count == 0)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=3)
|
||||||
|
def was_forgotten_at(self, user_id, room_id, event_id):
|
||||||
|
"""Returns whether user_id has elected to discard history for room_id at
|
||||||
|
event_id.
|
||||||
|
|
||||||
|
event_id must be a membership event."""
|
||||||
|
def f(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT"
|
||||||
|
" forgotten"
|
||||||
|
" FROM"
|
||||||
|
" room_memberships"
|
||||||
|
" WHERE"
|
||||||
|
" user_id = ?"
|
||||||
|
" AND"
|
||||||
|
" room_id = ?"
|
||||||
|
" AND"
|
||||||
|
" event_id = ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id, room_id, event_id))
|
||||||
|
rows = txn.fetchall()
|
||||||
|
return rows[0][0]
|
||||||
|
forgot = yield self.runInteraction("did_forget_membership_at", f)
|
||||||
|
defer.returnValue(forgot == 1)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _background_add_membership_profile(self, progress, batch_size):
|
def _background_add_membership_profile(self, progress, batch_size):
|
||||||
target_min_stream_id = progress.get(
|
target_min_stream_id = progress.get(
|
||||||
|
@ -675,10 +707,6 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@cached(max_entries=10000, iterable=True)
|
|
||||||
def _get_joined_hosts_cache(self, room_id):
|
|
||||||
return _JoinedHostsCache(self, room_id)
|
|
||||||
|
|
||||||
|
|
||||||
class _JoinedHostsCache(object):
|
class _JoinedHostsCache(object):
|
||||||
"""Cache for joined hosts in a room that is optimised to handle updates
|
"""Cache for joined hosts in a room that is optimised to handle updates
|
||||||
|
|
|
@ -12,9 +12,10 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import simplejson as json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ import logging
|
||||||
from synapse.storage.prepare_database import get_statements
|
from synapse.storage.prepare_database import get_statements
|
||||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
|
|
||||||
import ujson
|
import simplejson
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ def run_create(cur, database_engine, *args, **kwargs):
|
||||||
"max_stream_id_exclusive": max_stream_id + 1,
|
"max_stream_id_exclusive": max_stream_id + 1,
|
||||||
"rows_inserted": 0,
|
"rows_inserted": 0,
|
||||||
}
|
}
|
||||||
progress_json = ujson.dumps(progress)
|
progress_json = simplejson.dumps(progress)
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"INSERT into background_updates (update_name, progress_json)"
|
"INSERT into background_updates (update_name, progress_json)"
|
||||||
|
|
|
@ -16,7 +16,7 @@ import logging
|
||||||
|
|
||||||
from synapse.storage.prepare_database import get_statements
|
from synapse.storage.prepare_database import get_statements
|
||||||
|
|
||||||
import ujson
|
import simplejson
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs):
|
||||||
"max_stream_id_exclusive": max_stream_id + 1,
|
"max_stream_id_exclusive": max_stream_id + 1,
|
||||||
"rows_inserted": 0,
|
"rows_inserted": 0,
|
||||||
}
|
}
|
||||||
progress_json = ujson.dumps(progress)
|
progress_json = simplejson.dumps(progress)
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"INSERT into background_updates (update_name, progress_json)"
|
"INSERT into background_updates (update_name, progress_json)"
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue