mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 11:03:51 +01:00
Merge branch 'release-v0.17.1' of github.com:matrix-org/synapse
This commit is contained in:
commit
37638c06c5
98 changed files with 3277 additions and 1430 deletions
47
CHANGES.rst
47
CHANGES.rst
|
@ -1,3 +1,50 @@
|
|||
Changes in synapse v0.17.1 (2016-08-24)
|
||||
=======================================
|
||||
|
||||
Changes:
|
||||
|
||||
* Delete old received_transactions rows (PR #1038)
|
||||
* Pass through user-supplied content in /join/$room_id (PR #1039)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix bug with backfill (PR #1040)
|
||||
|
||||
|
||||
Changes in synapse v0.17.1-rc1 (2016-08-22)
|
||||
===========================================
|
||||
|
||||
Features:
|
||||
|
||||
* Add notification API (PR #1028)
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Don't print stack traces when failing to get remote keys (PR #996)
|
||||
* Various federation /event/ perf improvements (PR #998)
|
||||
* Only process one local membership event per room at a time (PR #1005)
|
||||
* Move default display name push rule (PR #1011, #1023)
|
||||
* Fix up preview URL API. Add tests. (PR #1015)
|
||||
* Set ``Content-Security-Policy`` on media repo (PR #1021)
|
||||
* Make notify_interested_services faster (PR #1022)
|
||||
* Add usage stats to prometheus monitoring (PR #1037)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix token login (PR #993)
|
||||
* Fix CAS login (PR #994, #995)
|
||||
* Fix /sync to not clobber status_msg (PR #997)
|
||||
* Fix redacted state events to include prev_content (PR #1003)
|
||||
* Fix some bugs in the auth/ldap handler (PR #1007)
|
||||
* Fix backfill request to limit URI length, so that remotes don't reject the
|
||||
requests due to path length limits (PR #1012)
|
||||
* Fix AS push code to not send duplicate events (PR #1025)
|
||||
|
||||
|
||||
|
||||
Changes in synapse v0.17.0 (2016-08-08)
|
||||
=======================================
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
|
|||
System requirements:
|
||||
- POSIX-compliant system (tested on Linux & OS X)
|
||||
- Python 2.7
|
||||
- At least 512 MB RAM.
|
||||
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
|
||||
|
||||
Synapse is written in python but some of the libraries is uses are written in
|
||||
C. So before we can install synapse itself we need a working C compiler and the
|
||||
|
|
97
docs/workers.rst
Normal file
97
docs/workers.rst
Normal file
|
@ -0,0 +1,97 @@
|
|||
Scaling synapse via workers
|
||||
---------------------------
|
||||
|
||||
Synapse has experimental support for splitting out functionality into
|
||||
multiple separate python processes, helping greatly with scalability. These
|
||||
processes are called 'workers', and are (eventually) intended to scale
|
||||
horizontally independently.
|
||||
|
||||
All processes continue to share the same database instance, and as such, workers
|
||||
only work with postgres based synapse deployments (sharing a single sqlite
|
||||
across multiple processes is a recipe for disaster, plus you should be using
|
||||
postgres anyway if you care about scalability).
|
||||
|
||||
The workers communicate with the master synapse process via a synapse-specific
|
||||
HTTP protocol called 'replication' - analogous to MySQL or Postgres style
|
||||
database replication; feeding a stream of relevant data to the workers so they
|
||||
can be kept in sync with the main synapse process and database state.
|
||||
|
||||
To enable workers, you need to add a replication listener to the master synapse, e.g.::
|
||||
|
||||
listeners:
|
||||
- port: 9092
|
||||
bind_address: '127.0.0.1'
|
||||
type: http
|
||||
tls: false
|
||||
x_forwarded: false
|
||||
resources:
|
||||
- names: [replication]
|
||||
compress: false
|
||||
|
||||
Under **no circumstances** should this replication API listener be exposed to the
|
||||
public internet; it currently implements no authentication whatsoever and is
|
||||
unencrypted HTTP.
|
||||
|
||||
You then create a set of configs for the various worker processes. These should be
|
||||
worker configuration files should be stored in a dedicated subdirectory, to allow
|
||||
synctl to manipulate them.
|
||||
|
||||
The current available worker applications are:
|
||||
* synapse.app.pusher - handles sending push notifications to sygnal and email
|
||||
* synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
|
||||
* synapse.app.appservice - handles output traffic to Application Services
|
||||
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
|
||||
* synapse.app.media_repository - handles the media repository.
|
||||
|
||||
Each worker configuration file inherits the configuration of the main homeserver
|
||||
configuration file. You can then override configuration specific to that worker,
|
||||
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
|
||||
You should minimise the number of overrides though to maintain a usable config.
|
||||
|
||||
You must specify the type of worker application (worker_app) and the replication
|
||||
endpoint that it's talking to on the main synapse process (worker_replication_url).
|
||||
|
||||
For instance::
|
||||
|
||||
worker_app: synapse.app.synchrotron
|
||||
|
||||
# The replication listener on the synapse to talk to.
|
||||
worker_replication_url: http://127.0.0.1:9092/_synapse/replication
|
||||
|
||||
worker_listeners:
|
||||
- type: http
|
||||
port: 8083
|
||||
resources:
|
||||
- names:
|
||||
- client
|
||||
|
||||
worker_daemonize: True
|
||||
worker_pid_file: /home/matrix/synapse/synchrotron.pid
|
||||
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
|
||||
|
||||
...is a full configuration for a synchrotron worker instance, which will expose a
|
||||
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
|
||||
by the main synapse.
|
||||
|
||||
Obviously you should configure your loadbalancer to route the /sync endpoint to
|
||||
the synchrotron instance(s) in this instance.
|
||||
|
||||
Finally, to actually run your worker-based synapse, you must pass synctl the -a
|
||||
commandline option to tell it to operate on all the worker configurations found
|
||||
in the given directory, e.g.::
|
||||
|
||||
synctl -a $CONFIG/workers start
|
||||
|
||||
Currently one should always restart all workers when restarting or upgrading
|
||||
synapse, unless you explicitly know it's safe not to. For instance, restarting
|
||||
synapse without restarting all the synchrotrons may result in broken typing
|
||||
notifications.
|
||||
|
||||
To manipulate a specific worker, you pass the -w option to synctl::
|
||||
|
||||
synctl -w $CONFIG/workers/synchrotron.yaml restart
|
||||
|
||||
All of the above is highly experimental and subject to change as Synapse evolves,
|
||||
but documenting it here to help folks needing highly scalable Synapses similar
|
||||
to the one running matrix.org!
|
||||
|
|
@ -25,5 +25,6 @@ rm .coverage* || echo "No coverage files to remove"
|
|||
tox --notest -e py27
|
||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
$TOX_BIN/pip install lxml
|
||||
|
||||
tox -e py27
|
||||
|
|
|
@ -14,6 +14,7 @@ fi
|
|||
tox -e py27 --notest -v
|
||||
|
||||
TOX_BIN=$TOX_DIR/py27/bin
|
||||
$TOX_BIN/pip install setuptools
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
$TOX_BIN/pip install lxml
|
||||
$TOX_BIN/pip install psycopg2
|
||||
|
|
|
@ -16,4 +16,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.17.0"
|
||||
__version__ = "0.17.1"
|
||||
|
|
|
@ -675,27 +675,18 @@ class Auth(object):
|
|||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||
|
||||
user_prefix = "user_id = "
|
||||
user = None
|
||||
user_id = None
|
||||
guest = False
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(user_prefix):
|
||||
user_id = caveat.caveat_id[len(user_prefix):]
|
||||
user_id = self.get_user_id_from_macaroon(macaroon)
|
||||
user = UserID.from_string(user_id)
|
||||
elif caveat.caveat_id == "guest = true":
|
||||
guest = True
|
||||
|
||||
self.validate_macaroon(
|
||||
macaroon, rights, self.hs.config.expire_access_token,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
guest = False
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id == "guest = true":
|
||||
guest = True
|
||||
|
||||
if guest:
|
||||
ret = {
|
||||
|
@ -743,6 +734,29 @@ class Auth(object):
|
|||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
|
||||
def get_user_id_from_macaroon(self, macaroon):
|
||||
"""Retrieve the user_id given by the caveats on the macaroon.
|
||||
|
||||
Does *not* validate the macaroon.
|
||||
|
||||
Args:
|
||||
macaroon (pymacaroons.Macaroon): The macaroon to validate
|
||||
|
||||
Returns:
|
||||
(str) user id
|
||||
|
||||
Raises:
|
||||
AuthError if there is no user_id caveat in the macaroon
|
||||
"""
|
||||
user_prefix = "user_id = "
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(user_prefix):
|
||||
return caveat.caveat_id[len(user_prefix):]
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
|
||||
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
|
||||
"""
|
||||
validate that a Macaroon is understood by and was signed by this server.
|
||||
|
@ -754,6 +768,7 @@ class Auth(object):
|
|||
verify_expiry(bool): Whether to verify whether the macaroon has expired.
|
||||
This should really always be True, but no clients currently implement
|
||||
token refresh, so we can't enforce expiry yet.
|
||||
user_id (str): The user_id required
|
||||
"""
|
||||
v = pymacaroons.Verifier()
|
||||
v.satisfy_exact("gen = 1")
|
||||
|
|
209
synapse/app/appservice.py
Normal file
209
synapse/app/appservice.py
Normal file
|
@ -0,0 +1,209 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import gc
|
||||
|
||||
logger = logging.getLogger("synapse.app.appservice")
|
||||
|
||||
|
||||
class AppserviceSlaveStore(
|
||||
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
|
||||
SlavedRegistrationStore,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class AppserviceServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def _listen_http(self, listener_config):
|
||||
port = listener_config["port"]
|
||||
bind_address = listener_config.get("bind_address", "")
|
||||
site_tag = listener_config.get("tag", port)
|
||||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "metrics":
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
reactor.listenTCP(
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
),
|
||||
interface=bind_address
|
||||
)
|
||||
logger.info("Synapse appservice now listening on port %d", port)
|
||||
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
reactor.listenTCP(
|
||||
listener["port"],
|
||||
manhole(
|
||||
username="matrix",
|
||||
password="rabbithole",
|
||||
globals={"hs": self},
|
||||
),
|
||||
interface=listener.get("bind_address", '127.0.0.1')
|
||||
)
|
||||
else:
|
||||
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.worker_replication_url
|
||||
appservice_handler = self.get_application_service_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(results):
|
||||
stream = results.get("events")
|
||||
if stream:
|
||||
max_stream_id = stream["position"]
|
||||
yield appservice_handler.notify_interested_services(max_stream_id)
|
||||
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
yield store.process_replication(result)
|
||||
replicate(result)
|
||||
except:
|
||||
logger.exception("Error replicating from %r", replication_url)
|
||||
yield sleep(30)
|
||||
|
||||
|
||||
def start(config_options):
|
||||
try:
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse appservice", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
assert config.worker_app == "synapse.app.appservice"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
if config.notify_appservices:
|
||||
sys.stderr.write(
|
||||
"\nThe appservices must be disabled in the main synapse process"
|
||||
"\nbefore they can be run in a separate worker."
|
||||
"\nPlease add ``notify_appservices: false`` to the main config"
|
||||
"\n"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Force the pushers to start since they will be disabled in the main config
|
||||
config.notify_appservices = True
|
||||
|
||||
ps = AppserviceServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
ps.setup()
|
||||
ps.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ps.replicate()
|
||||
ps.get_datastore().start_profiling()
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-appservice",
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
|
@ -51,7 +51,7 @@ from synapse.api.urls import (
|
|||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.metrics import register_memory_metrics
|
||||
from synapse.metrics import register_memory_metrics, get_metrics_for
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
|
||||
from synapse.federation.transport.server import TransportLayerServer
|
||||
|
@ -385,6 +385,8 @@ def run(hs):
|
|||
|
||||
start_time = hs.get_clock().time()
|
||||
|
||||
stats = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def phone_stats_home():
|
||||
logger.info("Gathering stats for reporting")
|
||||
|
@ -393,7 +395,10 @@ def run(hs):
|
|||
if uptime < 0:
|
||||
uptime = 0
|
||||
|
||||
stats = {}
|
||||
# If the stats directory is empty then this is the first time we've
|
||||
# reported stats.
|
||||
first_time = not stats
|
||||
|
||||
stats["homeserver"] = hs.config.server_name
|
||||
stats["timestamp"] = now
|
||||
stats["uptime_seconds"] = uptime
|
||||
|
@ -406,6 +411,25 @@ def run(hs):
|
|||
daily_messages = yield hs.get_datastore().count_daily_messages()
|
||||
if daily_messages is not None:
|
||||
stats["daily_messages"] = daily_messages
|
||||
else:
|
||||
stats.pop("daily_messages", None)
|
||||
|
||||
if first_time:
|
||||
# Add callbacks to report the synapse stats as metrics whenever
|
||||
# prometheus requests them, typically every 30s.
|
||||
# As some of the stats are expensive to calculate we only update
|
||||
# them when synapse phones home to matrix.org every 24 hours.
|
||||
metrics = get_metrics_for("synapse.usage")
|
||||
metrics.add_callback("timestamp", lambda: stats["timestamp"])
|
||||
metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
|
||||
metrics.add_callback("total_users", lambda: stats["total_users"])
|
||||
metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
|
||||
metrics.add_callback(
|
||||
"daily_active_users", lambda: stats["daily_active_users"]
|
||||
)
|
||||
metrics.add_callback(
|
||||
"daily_messages", lambda: stats.get("daily_messages", 0)
|
||||
)
|
||||
|
||||
logger.info("Reporting stats to matrix.org: %s" % (stats,))
|
||||
try:
|
||||
|
|
212
synapse/app/media_repository.py
Normal file
212
synapse/app/media_repository.py
Normal file
|
@ -0,0 +1,212 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.client_ips import ClientIpStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.media_repository import MediaRepositoryStore
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from synapse.api.urls import (
|
||||
CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
|
||||
)
|
||||
from synapse.crypto import context_factory
|
||||
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import gc
|
||||
|
||||
logger = logging.getLogger("synapse.app.media_repository")
|
||||
|
||||
|
||||
class MediaRepositorySlavedStore(
|
||||
SlavedApplicationServiceStore,
|
||||
SlavedRegistrationStore,
|
||||
BaseSlavedStore,
|
||||
MediaRepositoryStore,
|
||||
ClientIpStore,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class MediaRepositoryServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def _listen_http(self, listener_config):
|
||||
port = listener_config["port"]
|
||||
bind_address = listener_config.get("bind_address", "")
|
||||
site_tag = listener_config.get("tag", port)
|
||||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "metrics":
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
elif name == "media":
|
||||
media_repo = MediaRepositoryResource(self)
|
||||
resources.update({
|
||||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||
self, self.config.uploads_path
|
||||
),
|
||||
})
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
reactor.listenTCP(
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
),
|
||||
interface=bind_address
|
||||
)
|
||||
logger.info("Synapse media repository now listening on port %d", port)
|
||||
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
reactor.listenTCP(
|
||||
listener["port"],
|
||||
manhole(
|
||||
username="matrix",
|
||||
password="rabbithole",
|
||||
globals={"hs": self},
|
||||
),
|
||||
interface=listener.get("bind_address", '127.0.0.1')
|
||||
)
|
||||
else:
|
||||
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.worker_replication_url
|
||||
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
yield store.process_replication(result)
|
||||
except:
|
||||
logger.exception("Error replicating from %r", replication_url)
|
||||
yield sleep(5)
|
||||
|
||||
|
||||
def start(config_options):
|
||||
try:
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse media repository", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
assert config.worker_app == "synapse.app.media_repository"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
||||
ss = MediaRepositoryServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
ss.setup()
|
||||
ss.get_handlers()
|
||||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ss.get_datastore().start_profiling()
|
||||
ss.replicate()
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-media-repository",
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
|
@ -80,11 +80,6 @@ class PusherSlaveStore(
|
|||
DataStore.get_profile_displayname.__func__
|
||||
)
|
||||
|
||||
# XXX: This is a bit broken because we don't persist forgotten rooms
|
||||
# in a way that they can be streamed. This means that we don't have a
|
||||
# way to invalidate the forgotten rooms cache correctly.
|
||||
# For now we expire the cache every 10 minutes.
|
||||
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
|
||||
who_forgot_in_room = (
|
||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
||||
)
|
||||
|
@ -168,7 +163,6 @@ class PusherServer(HomeServer):
|
|||
store = self.get_datastore()
|
||||
replication_url = self.config.worker_replication_url
|
||||
pusher_pool = self.get_pusherpool()
|
||||
clock = self.get_clock()
|
||||
|
||||
def stop_pusher(user_id, app_id, pushkey):
|
||||
key = "%s:%s" % (app_id, pushkey)
|
||||
|
@ -220,21 +214,11 @@ class PusherServer(HomeServer):
|
|||
min_stream_id, max_stream_id, affected_room_ids
|
||||
)
|
||||
|
||||
def expire_broken_caches():
|
||||
store.who_forgot_in_room.invalidate_all()
|
||||
|
||||
next_expire_broken_caches_ms = 0
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
now_ms = clock.time_msec()
|
||||
if now_ms > next_expire_broken_caches_ms:
|
||||
expire_broken_caches()
|
||||
next_expire_broken_caches_ms = (
|
||||
now_ms + store.BROKEN_CACHE_EXPIRY_MS
|
||||
)
|
||||
yield store.process_replication(result)
|
||||
poke_pushers(result)
|
||||
except:
|
||||
|
|
|
@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite
|
|||
from synapse.http.server import JsonResource
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.rest.client.v2_alpha import sync
|
||||
from synapse.rest.client.v1 import events
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
|
@ -74,11 +75,6 @@ class SynchrotronSlavedStore(
|
|||
BaseSlavedStore,
|
||||
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
||||
):
|
||||
# XXX: This is a bit broken because we don't persist forgotten rooms
|
||||
# in a way that they can be streamed. This means that we don't have a
|
||||
# way to invalidate the forgotten rooms cache correctly.
|
||||
# For now we expire the cache every 10 minutes.
|
||||
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
|
||||
who_forgot_in_room = (
|
||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
||||
)
|
||||
|
@ -89,17 +85,23 @@ class SynchrotronSlavedStore(
|
|||
get_presence_list_accepted = PresenceStore.__dict__[
|
||||
"get_presence_list_accepted"
|
||||
]
|
||||
get_presence_list_observers_accepted = PresenceStore.__dict__[
|
||||
"get_presence_list_observers_accepted"
|
||||
]
|
||||
|
||||
|
||||
UPDATE_SYNCING_USERS_MS = 10 * 1000
|
||||
|
||||
|
||||
class SynchrotronPresence(object):
|
||||
def __init__(self, hs):
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.store = hs.get_datastore()
|
||||
self.user_to_num_current_syncs = {}
|
||||
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
|
||||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
self.user_to_current_state = {
|
||||
|
@ -119,11 +121,13 @@ class SynchrotronPresence(object):
|
|||
|
||||
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
|
||||
|
||||
def set_state(self, user, state):
|
||||
def set_state(self, user, state, ignore_status_msg=False):
|
||||
# TODO Hows this supposed to work?
|
||||
pass
|
||||
|
||||
get_states = PresenceHandler.get_states.__func__
|
||||
get_state = PresenceHandler.get_state.__func__
|
||||
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
|
||||
current_state_for_users = PresenceHandler.current_state_for_users.__func__
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -194,19 +198,39 @@ class SynchrotronPresence(object):
|
|||
self._need_to_send_sync = False
|
||||
yield self._send_syncing_users_now()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify_from_replication(self, states, stream_id):
|
||||
parties = yield self._get_interested_parties(
|
||||
states, calculate_remote_hosts=False
|
||||
)
|
||||
room_ids_to_states, users_to_states, _ = parties
|
||||
|
||||
self.notifier.on_new_event(
|
||||
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
|
||||
users=users_to_states.keys()
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def process_replication(self, result):
|
||||
stream = result.get("presence", {"rows": []})
|
||||
states = []
|
||||
for row in stream["rows"]:
|
||||
(
|
||||
position, user_id, state, last_active_ts,
|
||||
last_federation_update_ts, last_user_sync_ts, status_msg,
|
||||
currently_active
|
||||
) = row
|
||||
self.user_to_current_state[user_id] = UserPresenceState(
|
||||
state = UserPresenceState(
|
||||
user_id, state, last_active_ts,
|
||||
last_federation_update_ts, last_user_sync_ts, status_msg,
|
||||
currently_active
|
||||
)
|
||||
self.user_to_current_state[user_id] = state
|
||||
states.append(state)
|
||||
|
||||
if states and "position" in stream:
|
||||
stream_id = int(stream["position"])
|
||||
yield self.notify_from_replication(states, stream_id)
|
||||
|
||||
|
||||
class SynchrotronTyping(object):
|
||||
|
@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer):
|
|||
elif name == "client":
|
||||
resource = JsonResource(self, canonical_json=False)
|
||||
sync.register_servlets(self, resource)
|
||||
events.register_servlets(self, resource)
|
||||
resources.update({
|
||||
"/_matrix/client/r0": resource,
|
||||
"/_matrix/client/unstable": resource,
|
||||
"/_matrix/client/v2_alpha": resource,
|
||||
"/_matrix/client/api/v1": resource,
|
||||
})
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
|
@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer):
|
|||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.worker_replication_url
|
||||
clock = self.get_clock()
|
||||
notifier = self.get_notifier()
|
||||
presence_handler = self.get_presence_handler()
|
||||
typing_handler = self.get_typing_handler()
|
||||
|
||||
def expire_broken_caches():
|
||||
store.who_forgot_in_room.invalidate_all()
|
||||
store.get_presence_list_accepted.invalidate_all()
|
||||
|
||||
def notify_from_stream(
|
||||
result, stream_name, stream_key, room=None, user=None
|
||||
):
|
||||
|
@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer):
|
|||
result, "typing", "typing_key", room="room_id"
|
||||
)
|
||||
|
||||
next_expire_broken_caches_ms = 0
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args.update(typing_handler.stream_positions())
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
now_ms = clock.time_msec()
|
||||
if now_ms > next_expire_broken_caches_ms:
|
||||
expire_broken_caches()
|
||||
next_expire_broken_caches_ms = (
|
||||
now_ms + store.BROKEN_CACHE_EXPIRY_MS
|
||||
)
|
||||
yield store.process_replication(result)
|
||||
typing_handler.process_replication(result)
|
||||
presence_handler.process_replication(result)
|
||||
yield presence_handler.process_replication(result)
|
||||
notify(result)
|
||||
except:
|
||||
logger.exception("Error replicating from %r", replication_url)
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# limitations under the License.
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
|
@ -79,13 +81,17 @@ class ApplicationService(object):
|
|||
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
|
||||
|
||||
def __init__(self, token, url=None, namespaces=None, hs_token=None,
|
||||
sender=None, id=None):
|
||||
sender=None, id=None, protocols=None):
|
||||
self.token = token
|
||||
self.url = url
|
||||
self.hs_token = hs_token
|
||||
self.sender = sender
|
||||
self.namespaces = self._check_namespaces(namespaces)
|
||||
self.id = id
|
||||
if protocols:
|
||||
self.protocols = set(protocols)
|
||||
else:
|
||||
self.protocols = set()
|
||||
|
||||
def _check_namespaces(self, namespaces):
|
||||
# Sanity check that it is of the form:
|
||||
|
@ -138,65 +144,66 @@ class ApplicationService(object):
|
|||
return regex_obj["exclusive"]
|
||||
return False
|
||||
|
||||
def _matches_user(self, event, member_list):
|
||||
if (hasattr(event, "sender") and
|
||||
self.is_interested_in_user(event.sender)):
|
||||
return True
|
||||
@defer.inlineCallbacks
|
||||
def _matches_user(self, event, store):
|
||||
if not event:
|
||||
defer.returnValue(False)
|
||||
|
||||
if self.is_interested_in_user(event.sender):
|
||||
defer.returnValue(True)
|
||||
# also check m.room.member state key
|
||||
if (hasattr(event, "type") and event.type == EventTypes.Member
|
||||
and hasattr(event, "state_key")
|
||||
and self.is_interested_in_user(event.state_key)):
|
||||
return True
|
||||
if (event.type == EventTypes.Member and
|
||||
self.is_interested_in_user(event.state_key)):
|
||||
defer.returnValue(True)
|
||||
|
||||
if not store:
|
||||
defer.returnValue(False)
|
||||
|
||||
member_list = yield store.get_users_in_room(event.room_id)
|
||||
|
||||
# check joined member events
|
||||
for user_id in member_list:
|
||||
if self.is_interested_in_user(user_id):
|
||||
return True
|
||||
return False
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(False)
|
||||
|
||||
def _matches_room_id(self, event):
|
||||
if hasattr(event, "room_id"):
|
||||
return self.is_interested_in_room(event.room_id)
|
||||
return False
|
||||
|
||||
def _matches_aliases(self, event, alias_list):
|
||||
@defer.inlineCallbacks
|
||||
def _matches_aliases(self, event, store):
|
||||
if not store or not event:
|
||||
defer.returnValue(False)
|
||||
|
||||
alias_list = yield store.get_aliases_for_room(event.room_id)
|
||||
for alias in alias_list:
|
||||
if self.is_interested_in_alias(alias):
|
||||
return True
|
||||
return False
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(False)
|
||||
|
||||
def is_interested(self, event, restrict_to=None, aliases_for_event=None,
|
||||
member_list=None):
|
||||
@defer.inlineCallbacks
|
||||
def is_interested(self, event, store=None):
|
||||
"""Check if this service is interested in this event.
|
||||
|
||||
Args:
|
||||
event(Event): The event to check.
|
||||
restrict_to(str): The namespace to restrict regex tests to.
|
||||
aliases_for_event(list): A list of all the known room aliases for
|
||||
this event.
|
||||
member_list(list): A list of all joined user_ids in this room.
|
||||
store(DataStore)
|
||||
Returns:
|
||||
bool: True if this service would like to know about this event.
|
||||
"""
|
||||
if aliases_for_event is None:
|
||||
aliases_for_event = []
|
||||
if member_list is None:
|
||||
member_list = []
|
||||
# Do cheap checks first
|
||||
if self._matches_room_id(event):
|
||||
defer.returnValue(True)
|
||||
|
||||
if restrict_to and restrict_to not in ApplicationService.NS_LIST:
|
||||
# this is a programming error, so fail early and raise a general
|
||||
# exception
|
||||
raise Exception("Unexpected restrict_to value: %s". restrict_to)
|
||||
if (yield self._matches_aliases(event, store)):
|
||||
defer.returnValue(True)
|
||||
|
||||
if not restrict_to:
|
||||
return (self._matches_user(event, member_list)
|
||||
or self._matches_aliases(event, aliases_for_event)
|
||||
or self._matches_room_id(event))
|
||||
elif restrict_to == ApplicationService.NS_ALIASES:
|
||||
return self._matches_aliases(event, aliases_for_event)
|
||||
elif restrict_to == ApplicationService.NS_ROOMS:
|
||||
return self._matches_room_id(event)
|
||||
elif restrict_to == ApplicationService.NS_USERS:
|
||||
return self._matches_user(event, member_list)
|
||||
if (yield self._matches_user(event, store)):
|
||||
defer.returnValue(True)
|
||||
|
||||
defer.returnValue(False)
|
||||
|
||||
def is_interested_in_user(self, user_id):
|
||||
return (
|
||||
|
@ -216,6 +223,9 @@ class ApplicationService(object):
|
|||
or user_id == self.sender
|
||||
)
|
||||
|
||||
def is_interested_in_protocol(self, protocol):
|
||||
return protocol in self.protocols
|
||||
|
||||
def is_exclusive_alias(self, alias):
|
||||
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ from twisted.internet import defer
|
|||
from synapse.api.errors import CodeMessageException
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.types import ThirdPartyEntityKind
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
|
@ -24,6 +25,28 @@ import urllib
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_valid_3pe_result(r, field):
|
||||
if not isinstance(r, dict):
|
||||
return False
|
||||
|
||||
for k in (field, "protocol"):
|
||||
if k not in r:
|
||||
return False
|
||||
if not isinstance(r[k], str):
|
||||
return False
|
||||
|
||||
if "fields" not in r:
|
||||
return False
|
||||
fields = r["fields"]
|
||||
if not isinstance(fields, dict):
|
||||
return False
|
||||
for k in fields.keys():
|
||||
if not isinstance(fields[k], str):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class ApplicationServiceApi(SimpleHttpClient):
|
||||
"""This class manages HS -> AS communications, including querying and
|
||||
pushing.
|
||||
|
@ -71,6 +94,43 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
logger.warning("query_alias to %s threw exception %s", uri, ex)
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_3pe(self, service, kind, protocol, fields):
|
||||
if kind == ThirdPartyEntityKind.USER:
|
||||
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
|
||||
required_field = "userid"
|
||||
elif kind == ThirdPartyEntityKind.LOCATION:
|
||||
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
|
||||
required_field = "alias"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
||||
)
|
||||
|
||||
try:
|
||||
response = yield self.get_json(uri, fields)
|
||||
if not isinstance(response, list):
|
||||
logger.warning(
|
||||
"query_3pe to %s returned an invalid response %r",
|
||||
uri, response
|
||||
)
|
||||
defer.returnValue([])
|
||||
|
||||
ret = []
|
||||
for r in response:
|
||||
if _is_valid_3pe_result(r, field=required_field):
|
||||
ret.append(r)
|
||||
else:
|
||||
logger.warning(
|
||||
"query_3pe to %s returned an invalid result %r",
|
||||
uri, r
|
||||
)
|
||||
|
||||
defer.returnValue(ret)
|
||||
except Exception as ex:
|
||||
logger.warning("query_3pe to %s threw exception %s", uri, ex)
|
||||
defer.returnValue([])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def push_bulk(self, service, events, txn_id=None):
|
||||
events = self._serialize(events)
|
||||
|
|
|
@ -48,9 +48,12 @@ UP & quit +---------- YES SUCCESS
|
|||
This is all tied together by the AppServiceScheduler which DIs the required
|
||||
components.
|
||||
"""
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.appservice import ApplicationServiceState
|
||||
from twisted.internet import defer
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -73,7 +76,7 @@ class ApplicationServiceScheduler(object):
|
|||
self.txn_ctrl = _TransactionController(
|
||||
self.clock, self.store, self.as_api, create_recoverer
|
||||
)
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl)
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start(self):
|
||||
|
@ -94,38 +97,36 @@ class _ServiceQueuer(object):
|
|||
this schedules any other events in the queue to run.
|
||||
"""
|
||||
|
||||
def __init__(self, txn_ctrl):
|
||||
def __init__(self, txn_ctrl, clock):
|
||||
self.queued_events = {} # dict of {service_id: [events]}
|
||||
self.pending_requests = {} # dict of {service_id: Deferred}
|
||||
self.requests_in_flight = set()
|
||||
self.txn_ctrl = txn_ctrl
|
||||
self.clock = clock
|
||||
|
||||
def enqueue(self, service, event):
|
||||
# if this service isn't being sent something
|
||||
if not self.pending_requests.get(service.id):
|
||||
self._send_request(service, [event])
|
||||
else:
|
||||
# add to queue for this service
|
||||
if service.id not in self.queued_events:
|
||||
self.queued_events[service.id] = []
|
||||
self.queued_events[service.id].append(event)
|
||||
self.queued_events.setdefault(service.id, []).append(event)
|
||||
preserve_fn(self._send_request)(service)
|
||||
|
||||
def _send_request(self, service, events):
|
||||
# send request and add callbacks
|
||||
d = self.txn_ctrl.send(service, events)
|
||||
d.addBoth(self._on_request_finish)
|
||||
d.addErrback(self._on_request_fail)
|
||||
self.pending_requests[service.id] = d
|
||||
@defer.inlineCallbacks
|
||||
def _send_request(self, service):
|
||||
if service.id in self.requests_in_flight:
|
||||
return
|
||||
|
||||
def _on_request_finish(self, service):
|
||||
self.pending_requests[service.id] = None
|
||||
# if there are queued events, then send them.
|
||||
if (service.id in self.queued_events
|
||||
and len(self.queued_events[service.id]) > 0):
|
||||
self._send_request(service, self.queued_events[service.id])
|
||||
self.queued_events[service.id] = []
|
||||
self.requests_in_flight.add(service.id)
|
||||
try:
|
||||
while True:
|
||||
events = self.queued_events.pop(service.id, [])
|
||||
if not events:
|
||||
return
|
||||
|
||||
def _on_request_fail(self, err):
|
||||
logger.error("AS request failed: %s", err)
|
||||
with Measure(self.clock, "servicequeuer.send"):
|
||||
try:
|
||||
yield self.txn_ctrl.send(service, events)
|
||||
except:
|
||||
logger.exception("AS request failed")
|
||||
finally:
|
||||
self.requests_in_flight.discard(service.id)
|
||||
|
||||
|
||||
class _TransactionController(object):
|
||||
|
@ -149,14 +150,12 @@ class _TransactionController(object):
|
|||
if service_is_up:
|
||||
sent = yield txn.send(self.as_api)
|
||||
if sent:
|
||||
txn.complete(self.store)
|
||||
yield txn.complete(self.store)
|
||||
else:
|
||||
self._start_recoverer(service)
|
||||
preserve_fn(self._start_recoverer)(service)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
self._start_recoverer(service)
|
||||
# request has finished
|
||||
defer.returnValue(service)
|
||||
preserve_fn(self._start_recoverer)(service)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_recovered(self, recoverer):
|
||||
|
|
|
@ -28,6 +28,7 @@ class AppServiceConfig(Config):
|
|||
|
||||
def read_config(self, config):
|
||||
self.app_service_config_files = config.get("app_service_config_files", [])
|
||||
self.notify_appservices = config.get("notify_appservices", True)
|
||||
|
||||
def default_config(cls, **kwargs):
|
||||
return """\
|
||||
|
@ -122,6 +123,15 @@ def _load_appservice(hostname, as_info, config_filename):
|
|||
raise ValueError(
|
||||
"Missing/bad type 'exclusive' key in %s", regex_obj
|
||||
)
|
||||
# protocols check
|
||||
protocols = as_info.get("protocols")
|
||||
if protocols:
|
||||
# Because strings are lists in python
|
||||
if isinstance(protocols, str) or not isinstance(protocols, list):
|
||||
raise KeyError("Optional 'protocols' must be a list if present.")
|
||||
for p in protocols:
|
||||
if not isinstance(p, str):
|
||||
raise KeyError("Bad value for 'protocols' item")
|
||||
return ApplicationService(
|
||||
token=as_info["as_token"],
|
||||
url=as_info["url"],
|
||||
|
@ -129,4 +139,5 @@ def _load_appservice(hostname, as_info, config_filename):
|
|||
hs_token=as_info["hs_token"],
|
||||
sender=user_id,
|
||||
id=as_info["id"],
|
||||
protocols=protocols,
|
||||
)
|
||||
|
|
|
@ -22,6 +22,7 @@ from synapse.util.logcontext import (
|
|||
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
|
||||
preserve_fn
|
||||
)
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -61,6 +62,10 @@ Attributes:
|
|||
"""
|
||||
|
||||
|
||||
class KeyLookupError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class Keyring(object):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -239,6 +244,7 @@ class Keyring(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def do_iterations():
|
||||
with Measure(self.clock, "get_server_verify_keys"):
|
||||
merged_results = {}
|
||||
|
||||
missing_keys = {}
|
||||
|
@ -302,15 +308,15 @@ class Keyring(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys_from_store(self, server_name_and_key_ids):
|
||||
res = yield defer.gatherResults(
|
||||
res = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
self.store.get_server_verify_keys(
|
||||
preserve_fn(self.store.get_server_verify_keys)(
|
||||
server_name, key_ids
|
||||
).addCallback(lambda ks, server: (server, ks), server_name)
|
||||
for server_name, key_ids in server_name_and_key_ids
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue(dict(res))
|
||||
|
||||
|
@ -331,13 +337,13 @@ class Keyring(object):
|
|||
)
|
||||
defer.returnValue({})
|
||||
|
||||
results = yield defer.gatherResults(
|
||||
results = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
get_key(p_name, p_keys)
|
||||
preserve_fn(get_key)(p_name, p_keys)
|
||||
for p_name, p_keys in self.perspective_servers.items()
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)).addErrback(unwrapFirstError)
|
||||
|
||||
union_of_keys = {}
|
||||
for result in results:
|
||||
|
@ -363,7 +369,7 @@ class Keyring(object):
|
|||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Unable to getting key %r for %r directly: %s %s",
|
||||
"Unable to get key %r for %r directly: %s %s",
|
||||
key_ids, server_name,
|
||||
type(e).__name__, str(e.message),
|
||||
)
|
||||
|
@ -377,13 +383,13 @@ class Keyring(object):
|
|||
|
||||
defer.returnValue(keys)
|
||||
|
||||
results = yield defer.gatherResults(
|
||||
results = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
get_key(server_name, key_ids)
|
||||
preserve_fn(get_key)(server_name, key_ids)
|
||||
for server_name, key_ids in server_name_and_key_ids
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)).addErrback(unwrapFirstError)
|
||||
|
||||
merged = {}
|
||||
for result in results:
|
||||
|
@ -425,7 +431,7 @@ class Keyring(object):
|
|||
for response in responses:
|
||||
if (u"signatures" not in response
|
||||
or perspective_name not in response[u"signatures"]):
|
||||
raise ValueError(
|
||||
raise KeyLookupError(
|
||||
"Key response not signed by perspective server"
|
||||
" %r" % (perspective_name,)
|
||||
)
|
||||
|
@ -448,7 +454,7 @@ class Keyring(object):
|
|||
list(response[u"signatures"][perspective_name]),
|
||||
list(perspective_keys)
|
||||
)
|
||||
raise ValueError(
|
||||
raise KeyLookupError(
|
||||
"Response not signed with a known key for perspective"
|
||||
" server %r" % (perspective_name,)
|
||||
)
|
||||
|
@ -460,9 +466,9 @@ class Keyring(object):
|
|||
for server_name, response_keys in processed_response.items():
|
||||
keys.setdefault(server_name, {}).update(response_keys)
|
||||
|
||||
yield defer.gatherResults(
|
||||
yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
self.store_keys(
|
||||
preserve_fn(self.store_keys)(
|
||||
server_name=server_name,
|
||||
from_server=perspective_name,
|
||||
verify_keys=response_keys,
|
||||
|
@ -470,7 +476,7 @@ class Keyring(object):
|
|||
for server_name, response_keys in keys.items()
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
)).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue(keys)
|
||||
|
||||
|
@ -491,10 +497,10 @@ class Keyring(object):
|
|||
|
||||
if (u"signatures" not in response
|
||||
or server_name not in response[u"signatures"]):
|
||||
raise ValueError("Key response not signed by remote server")
|
||||
raise KeyLookupError("Key response not signed by remote server")
|
||||
|
||||
if "tls_fingerprints" not in response:
|
||||
raise ValueError("Key response missing TLS fingerprints")
|
||||
raise KeyLookupError("Key response missing TLS fingerprints")
|
||||
|
||||
certificate_bytes = crypto.dump_certificate(
|
||||
crypto.FILETYPE_ASN1, tls_certificate
|
||||
|
@ -508,7 +514,7 @@ class Keyring(object):
|
|||
response_sha256_fingerprints.add(fingerprint[u"sha256"])
|
||||
|
||||
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
|
||||
raise ValueError("TLS certificate not allowed by fingerprints")
|
||||
raise KeyLookupError("TLS certificate not allowed by fingerprints")
|
||||
|
||||
response_keys = yield self.process_v2_response(
|
||||
from_server=server_name,
|
||||
|
@ -518,7 +524,7 @@ class Keyring(object):
|
|||
|
||||
keys.update(response_keys)
|
||||
|
||||
yield defer.gatherResults(
|
||||
yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self.store_keys)(
|
||||
server_name=key_server_name,
|
||||
|
@ -528,7 +534,7 @@ class Keyring(object):
|
|||
for key_server_name, verify_keys in keys.items()
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
)).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue(keys)
|
||||
|
||||
|
@ -560,14 +566,14 @@ class Keyring(object):
|
|||
server_name = response_json["server_name"]
|
||||
if only_from_server:
|
||||
if server_name != from_server:
|
||||
raise ValueError(
|
||||
raise KeyLookupError(
|
||||
"Expected a response for server %r not %r" % (
|
||||
from_server, server_name
|
||||
)
|
||||
)
|
||||
for key_id in response_json["signatures"].get(server_name, {}):
|
||||
if key_id not in response_json["verify_keys"]:
|
||||
raise ValueError(
|
||||
raise KeyLookupError(
|
||||
"Key response must include verification keys for all"
|
||||
" signatures"
|
||||
)
|
||||
|
@ -594,7 +600,7 @@ class Keyring(object):
|
|||
response_keys.update(verify_keys)
|
||||
response_keys.update(old_verify_keys)
|
||||
|
||||
yield defer.gatherResults(
|
||||
yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self.store.store_server_keys_json)(
|
||||
server_name=server_name,
|
||||
|
@ -607,7 +613,7 @@ class Keyring(object):
|
|||
for key_id in updated_key_ids
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)).addErrback(unwrapFirstError)
|
||||
|
||||
results[server_name] = response_keys
|
||||
|
||||
|
@ -635,15 +641,15 @@ class Keyring(object):
|
|||
|
||||
if ("signatures" not in response
|
||||
or server_name not in response["signatures"]):
|
||||
raise ValueError("Key response not signed by remote server")
|
||||
raise KeyLookupError("Key response not signed by remote server")
|
||||
|
||||
if "tls_certificate" not in response:
|
||||
raise ValueError("Key response missing TLS certificate")
|
||||
raise KeyLookupError("Key response missing TLS certificate")
|
||||
|
||||
tls_certificate_b64 = response["tls_certificate"]
|
||||
|
||||
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
|
||||
raise ValueError("TLS certificate doesn't match")
|
||||
raise KeyLookupError("TLS certificate doesn't match")
|
||||
|
||||
# Cache the result in the datastore.
|
||||
|
||||
|
@ -659,7 +665,7 @@ class Keyring(object):
|
|||
|
||||
for key_id in response["signatures"][server_name]:
|
||||
if key_id not in response["verify_keys"]:
|
||||
raise ValueError(
|
||||
raise KeyLookupError(
|
||||
"Key response must include verification keys for all"
|
||||
" signatures"
|
||||
)
|
||||
|
@ -696,7 +702,7 @@ class Keyring(object):
|
|||
A deferred that completes when the keys are stored.
|
||||
"""
|
||||
# TODO(markjh): Store whether the keys have expired.
|
||||
yield defer.gatherResults(
|
||||
yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self.store.store_server_verify_key)(
|
||||
server_name, server_name, key.time_added, key
|
||||
|
@ -704,4 +710,4 @@ class Keyring(object):
|
|||
for key_id, key in verify_keys.items()
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)).addErrback(unwrapFirstError)
|
||||
|
|
|
@ -88,6 +88,8 @@ def prune_event(event):
|
|||
|
||||
if "age_ts" in event.unsigned:
|
||||
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
|
||||
if "replaces_state" in event.unsigned:
|
||||
allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
|
||||
|
||||
return type(event)(
|
||||
allowed_fields,
|
||||
|
|
|
@ -23,6 +23,7 @@ from synapse.crypto.event_signing import check_event_content_hash
|
|||
from synapse.api.errors import SynapseError
|
||||
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -102,10 +103,10 @@ class FederationBase(object):
|
|||
warn, pdu
|
||||
)
|
||||
|
||||
valid_pdus = yield defer.gatherResults(
|
||||
valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
deferreds,
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
)).addErrback(unwrapFirstError)
|
||||
|
||||
if include_none:
|
||||
defer.returnValue(valid_pdus)
|
||||
|
@ -129,7 +130,7 @@ class FederationBase(object):
|
|||
for pdu in pdus
|
||||
]
|
||||
|
||||
deferreds = self.keyring.verify_json_objects_for_server([
|
||||
deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
|
||||
(p.origin, p.get_pdu_json())
|
||||
for p in redacted_pdus
|
||||
])
|
||||
|
|
|
@ -27,6 +27,7 @@ from synapse.util import unwrapFirstError
|
|||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.events import FrozenEvent
|
||||
import synapse.metrics
|
||||
|
||||
|
@ -51,10 +52,34 @@ sent_edus_counter = metrics.register_counter("sent_edus")
|
|||
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
|
||||
|
||||
|
||||
PDU_RETRY_TIME_MS = 1 * 60 * 1000
|
||||
|
||||
|
||||
class FederationClient(FederationBase):
|
||||
def __init__(self, hs):
|
||||
super(FederationClient, self).__init__(hs)
|
||||
|
||||
self.pdu_destination_tried = {}
|
||||
self._clock.looping_call(
|
||||
self._clear_tried_cache, 60 * 1000,
|
||||
)
|
||||
|
||||
def _clear_tried_cache(self):
|
||||
"""Clear pdu_destination_tried cache"""
|
||||
now = self._clock.time_msec()
|
||||
|
||||
old_dict = self.pdu_destination_tried
|
||||
self.pdu_destination_tried = {}
|
||||
|
||||
for event_id, destination_dict in old_dict.items():
|
||||
destination_dict = {
|
||||
dest: time
|
||||
for dest, time in destination_dict.items()
|
||||
if time + PDU_RETRY_TIME_MS > now
|
||||
}
|
||||
if destination_dict:
|
||||
self.pdu_destination_tried[event_id] = destination_dict
|
||||
|
||||
def start_get_pdu_cache(self):
|
||||
self._get_pdu_cache = ExpiringCache(
|
||||
cache_name="get_pdu_cache",
|
||||
|
@ -201,10 +226,10 @@ class FederationClient(FederationBase):
|
|||
]
|
||||
|
||||
# FIXME: We should handle signature failures more gracefully.
|
||||
pdus[:] = yield defer.gatherResults(
|
||||
pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
self._check_sigs_and_hashes(pdus),
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue(pdus)
|
||||
|
||||
|
@ -240,8 +265,15 @@ class FederationClient(FederationBase):
|
|||
if ev:
|
||||
defer.returnValue(ev)
|
||||
|
||||
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
|
||||
|
||||
pdu = None
|
||||
for destination in destinations:
|
||||
now = self._clock.time_msec()
|
||||
last_attempt = pdu_attempts.get(destination, 0)
|
||||
if last_attempt + PDU_RETRY_TIME_MS > now:
|
||||
continue
|
||||
|
||||
try:
|
||||
limiter = yield get_retry_limiter(
|
||||
destination,
|
||||
|
@ -269,25 +301,19 @@ class FederationClient(FederationBase):
|
|||
|
||||
break
|
||||
|
||||
pdu_attempts[destination] = now
|
||||
|
||||
except SynapseError as e:
|
||||
logger.info(
|
||||
"Failed to get PDU %s from %s because %s",
|
||||
event_id, destination, e,
|
||||
)
|
||||
continue
|
||||
except CodeMessageException as e:
|
||||
if 400 <= e.code < 500:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
"Failed to get PDU %s from %s because %s",
|
||||
event_id, destination, e,
|
||||
)
|
||||
continue
|
||||
except NotRetryingDestination as e:
|
||||
logger.info(e.message)
|
||||
continue
|
||||
except Exception as e:
|
||||
pdu_attempts[destination] = now
|
||||
|
||||
logger.info(
|
||||
"Failed to get PDU %s from %s because %s",
|
||||
event_id, destination, e,
|
||||
|
@ -406,7 +432,7 @@ class FederationClient(FederationBase):
|
|||
events and the second is a list of event ids that we failed to fetch.
|
||||
"""
|
||||
if return_local:
|
||||
seen_events = yield self.store.get_events(event_ids)
|
||||
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
|
||||
signed_events = seen_events.values()
|
||||
else:
|
||||
seen_events = yield self.store.have_events(event_ids)
|
||||
|
@ -432,14 +458,16 @@ class FederationClient(FederationBase):
|
|||
batch = set(missing_events[i:i + batch_size])
|
||||
|
||||
deferreds = [
|
||||
self.get_pdu(
|
||||
preserve_fn(self.get_pdu)(
|
||||
destinations=random_server_list(),
|
||||
event_id=e_id,
|
||||
)
|
||||
for e_id in batch
|
||||
]
|
||||
|
||||
res = yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||
res = yield preserve_context_over_deferred(
|
||||
defer.DeferredList(deferreds, consumeErrors=True)
|
||||
)
|
||||
for success, result in res:
|
||||
if success:
|
||||
signed_events.append(result)
|
||||
|
@ -828,14 +856,16 @@ class FederationClient(FederationBase):
|
|||
return srvs
|
||||
|
||||
deferreds = [
|
||||
self.get_pdu(
|
||||
preserve_fn(self.get_pdu)(
|
||||
destinations=random_server_list(),
|
||||
event_id=e_id,
|
||||
)
|
||||
for e_id, depth in ordered_missing[:limit - len(signed_events)]
|
||||
]
|
||||
|
||||
res = yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||
res = yield preserve_context_over_deferred(
|
||||
defer.DeferredList(deferreds, consumeErrors=True)
|
||||
)
|
||||
for (result, val), (e_id, _) in zip(res, ordered_missing):
|
||||
if result and val:
|
||||
signed_events.append(val)
|
||||
|
|
|
@ -21,11 +21,11 @@ from .units import Transaction
|
|||
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.retryutils import (
|
||||
get_retry_limiter, NotRetryingDestination,
|
||||
)
|
||||
from synapse.util.metrics import measure_func
|
||||
import synapse.metrics
|
||||
|
||||
import logging
|
||||
|
@ -51,7 +51,7 @@ class TransactionQueue(object):
|
|||
|
||||
self.transport_layer = transport_layer
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
# Is a mapping from destinations -> deferreds. Used to keep track
|
||||
# of which destinations have transactions in flight and when they are
|
||||
|
@ -82,7 +82,7 @@ class TransactionQueue(object):
|
|||
self.pending_failures_by_dest = {}
|
||||
|
||||
# HACK to get unique tx id
|
||||
self._next_txn_id = int(self._clock.time_msec())
|
||||
self._next_txn_id = int(self.clock.time_msec())
|
||||
|
||||
def can_send_to(self, destination):
|
||||
"""Can we send messages to the given server?
|
||||
|
@ -119,89 +119,46 @@ class TransactionQueue(object):
|
|||
if not destinations:
|
||||
return
|
||||
|
||||
deferreds = []
|
||||
|
||||
for destination in destinations:
|
||||
deferred = defer.Deferred()
|
||||
self.pending_pdus_by_dest.setdefault(destination, []).append(
|
||||
(pdu, deferred, order)
|
||||
(pdu, order)
|
||||
)
|
||||
|
||||
def chain(failure):
|
||||
if not deferred.called:
|
||||
deferred.errback(failure)
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn("Failed to send pdu to %s: %s", destination, f.value)
|
||||
|
||||
deferred.addErrback(log_failure)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self._attempt_new_transaction(destination).addErrback(chain)
|
||||
|
||||
deferreds.append(deferred)
|
||||
|
||||
# NO inlineCallbacks
|
||||
def enqueue_edu(self, edu):
|
||||
destination = edu.destination
|
||||
|
||||
if not self.can_send_to(destination):
|
||||
return
|
||||
|
||||
deferred = defer.Deferred()
|
||||
self.pending_edus_by_dest.setdefault(destination, []).append(
|
||||
(edu, deferred)
|
||||
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
|
||||
def chain(failure):
|
||||
if not deferred.called:
|
||||
deferred.errback(failure)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn("Failed to send edu to %s: %s", destination, f.value)
|
||||
|
||||
deferred.addErrback(log_failure)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self._attempt_new_transaction(destination).addErrback(chain)
|
||||
|
||||
return deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def enqueue_failure(self, failure, destination):
|
||||
if destination == self.server_name or destination == "localhost":
|
||||
return
|
||||
|
||||
deferred = defer.Deferred()
|
||||
|
||||
if not self.can_send_to(destination):
|
||||
return
|
||||
|
||||
self.pending_failures_by_dest.setdefault(
|
||||
destination, []
|
||||
).append(
|
||||
(failure, deferred)
|
||||
).append(failure)
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
|
||||
def chain(f):
|
||||
if not deferred.called:
|
||||
deferred.errback(f)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn("Failed to send failure to %s: %s", destination, f.value)
|
||||
|
||||
deferred.addErrback(log_failure)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self._attempt_new_transaction(destination).addErrback(chain)
|
||||
|
||||
yield deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _attempt_new_transaction(self, destination):
|
||||
yield run_on_reactor()
|
||||
|
||||
while True:
|
||||
# list of (pending_pdu, deferred, order)
|
||||
if destination in self.pending_transactions:
|
||||
# XXX: pending_transactions can get stuck on by a never-ending
|
||||
|
@ -226,27 +183,31 @@ class TransactionQueue(object):
|
|||
logger.debug("TX [%s] Nothing to send", destination)
|
||||
return
|
||||
|
||||
yield self._send_new_transaction(
|
||||
destination, pending_pdus, pending_edus, pending_failures
|
||||
)
|
||||
|
||||
@measure_func("_send_new_transaction")
|
||||
@defer.inlineCallbacks
|
||||
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
|
||||
pending_failures):
|
||||
|
||||
# Sort based on the order field
|
||||
pending_pdus.sort(key=lambda t: t[1])
|
||||
pdus = [x[0] for x in pending_pdus]
|
||||
edus = pending_edus
|
||||
failures = [x.get_dict() for x in pending_failures]
|
||||
|
||||
try:
|
||||
self.pending_transactions[destination] = 1
|
||||
|
||||
logger.debug("TX [%s] _attempt_new_transaction", destination)
|
||||
|
||||
# Sort based on the order field
|
||||
pending_pdus.sort(key=lambda t: t[2])
|
||||
|
||||
pdus = [x[0] for x in pending_pdus]
|
||||
edus = [x[0] for x in pending_edus]
|
||||
failures = [x[0].get_dict() for x in pending_failures]
|
||||
deferreds = [
|
||||
x[1]
|
||||
for x in pending_pdus + pending_edus + pending_failures
|
||||
]
|
||||
|
||||
txn_id = str(self._next_txn_id)
|
||||
|
||||
limiter = yield get_retry_limiter(
|
||||
destination,
|
||||
self._clock,
|
||||
self.clock,
|
||||
self.store,
|
||||
)
|
||||
|
||||
|
@ -262,7 +223,7 @@ class TransactionQueue(object):
|
|||
logger.debug("TX [%s] Persisting transaction...", destination)
|
||||
|
||||
transaction = Transaction.create_new(
|
||||
origin_server_ts=int(self._clock.time_msec()),
|
||||
origin_server_ts=int(self.clock.time_msec()),
|
||||
transaction_id=txn_id,
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
|
@ -293,7 +254,7 @@ class TransactionQueue(object):
|
|||
# keys work
|
||||
def json_data_cb():
|
||||
data = transaction.get_dict()
|
||||
now = int(self._clock.time_msec())
|
||||
now = int(self.clock.time_msec())
|
||||
if "pdus" in data:
|
||||
for p in data["pdus"]:
|
||||
if "age_ts" in p:
|
||||
|
@ -333,22 +294,11 @@ class TransactionQueue(object):
|
|||
|
||||
logger.debug("TX [%s] Marked as delivered", destination)
|
||||
|
||||
logger.debug("TX [%s] Yielding to callbacks...", destination)
|
||||
|
||||
for deferred in deferreds:
|
||||
if code == 200:
|
||||
deferred.callback(None)
|
||||
else:
|
||||
deferred.errback(RuntimeError("Got status %d" % code))
|
||||
|
||||
# Ensures we don't continue until all callbacks on that
|
||||
# deferred have fired
|
||||
try:
|
||||
yield deferred
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.debug("TX [%s] Yielded to callbacks", destination)
|
||||
if code != 200:
|
||||
for p in pdus:
|
||||
logger.info(
|
||||
"Failed to send event %s to %s", p.event_id, destination
|
||||
)
|
||||
except NotRetryingDestination:
|
||||
logger.info(
|
||||
"TX [%s] not ready for retry yet - "
|
||||
|
@ -363,6 +313,9 @@ class TransactionQueue(object):
|
|||
destination,
|
||||
e,
|
||||
)
|
||||
|
||||
for p in pdus:
|
||||
logger.info("Failed to send event %s to %s", p.event_id, destination)
|
||||
except Exception as e:
|
||||
# We capture this here as there as nothing actually listens
|
||||
# for this finishing functions deferred.
|
||||
|
@ -372,13 +325,9 @@ class TransactionQueue(object):
|
|||
e,
|
||||
)
|
||||
|
||||
for deferred in deferreds:
|
||||
if not deferred.called:
|
||||
deferred.errback(e)
|
||||
for p in pdus:
|
||||
logger.info("Failed to send event %s to %s", p.event_id, destination)
|
||||
|
||||
finally:
|
||||
# We want to be *very* sure we delete this after we stop processing
|
||||
self.pending_transactions.pop(destination, None)
|
||||
|
||||
# Check to see if there is anything else to send.
|
||||
self._attempt_new_transaction(destination)
|
||||
|
|
|
@ -19,7 +19,6 @@ from .room import (
|
|||
)
|
||||
from .room_member import RoomMemberHandler
|
||||
from .message import MessageHandler
|
||||
from .events import EventStreamHandler, EventHandler
|
||||
from .federation import FederationHandler
|
||||
from .profile import ProfileHandler
|
||||
from .directory import DirectoryHandler
|
||||
|
@ -53,8 +52,6 @@ class Handlers(object):
|
|||
self.message_handler = MessageHandler(hs)
|
||||
self.room_creation_handler = RoomCreationHandler(hs)
|
||||
self.room_member_handler = RoomMemberHandler(hs)
|
||||
self.event_stream_handler = EventStreamHandler(hs)
|
||||
self.event_handler = EventHandler(hs)
|
||||
self.federation_handler = FederationHandler(hs)
|
||||
self.profile_handler = ProfileHandler(hs)
|
||||
self.directory_handler = DirectoryHandler(hs)
|
||||
|
|
|
@ -16,7 +16,8 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -42,25 +43,53 @@ class ApplicationServicesHandler(object):
|
|||
self.appservice_api = hs.get_application_service_api()
|
||||
self.scheduler = hs.get_application_service_scheduler()
|
||||
self.started_scheduler = False
|
||||
self.clock = hs.get_clock()
|
||||
self.notify_appservices = hs.config.notify_appservices
|
||||
|
||||
self.current_max = 0
|
||||
self.is_processing = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify_interested_services(self, event):
|
||||
def notify_interested_services(self, current_id):
|
||||
"""Notifies (pushes) all application services interested in this event.
|
||||
|
||||
Pushing is done asynchronously, so this method won't block for any
|
||||
prolonged length of time.
|
||||
|
||||
Args:
|
||||
event(Event): The event to push out to interested services.
|
||||
current_id(int): The current maximum ID.
|
||||
"""
|
||||
services = yield self.store.get_app_services()
|
||||
if not services or not self.notify_appservices:
|
||||
return
|
||||
|
||||
self.current_max = max(self.current_max, current_id)
|
||||
if self.is_processing:
|
||||
return
|
||||
|
||||
with Measure(self.clock, "notify_interested_services"):
|
||||
self.is_processing = True
|
||||
try:
|
||||
upper_bound = self.current_max
|
||||
limit = 100
|
||||
while True:
|
||||
upper_bound, events = yield self.store.get_new_events_for_appservice(
|
||||
upper_bound, limit
|
||||
)
|
||||
|
||||
if not events:
|
||||
break
|
||||
|
||||
for event in events:
|
||||
# Gather interested services
|
||||
services = yield self._get_services_for_event(event)
|
||||
if len(services) == 0:
|
||||
return # no services need notifying
|
||||
continue # no services need notifying
|
||||
|
||||
# Do we know this user exists? If not, poke the user query API for
|
||||
# all services which match that user regex. This needs to block as these
|
||||
# user queries need to be made BEFORE pushing the event.
|
||||
# Do we know this user exists? If not, poke the user
|
||||
# query API for all services which match that user regex.
|
||||
# This needs to block as these user queries need to be
|
||||
# made BEFORE pushing the event.
|
||||
yield self._check_user_exists(event.sender)
|
||||
if event.type == EventTypes.Member:
|
||||
yield self._check_user_exists(event.state_key)
|
||||
|
@ -71,7 +100,16 @@ class ApplicationServicesHandler(object):
|
|||
|
||||
# Fork off pushes to these services
|
||||
for service in services:
|
||||
self.scheduler.submit_event_for_as(service, event)
|
||||
preserve_fn(self.scheduler.submit_event_for_as)(
|
||||
service, event
|
||||
)
|
||||
|
||||
yield self.store.set_appservice_last_pos(upper_bound)
|
||||
|
||||
if len(events) < limit:
|
||||
break
|
||||
finally:
|
||||
self.is_processing = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_user_exists(self, user_id):
|
||||
|
@ -104,11 +142,12 @@ class ApplicationServicesHandler(object):
|
|||
association can be found.
|
||||
"""
|
||||
room_alias_str = room_alias.to_string()
|
||||
alias_query_services = yield self._get_services_for_event(
|
||||
event=None,
|
||||
restrict_to=ApplicationService.NS_ALIASES,
|
||||
alias_list=[room_alias_str]
|
||||
services = yield self.store.get_app_services()
|
||||
alias_query_services = [
|
||||
s for s in services if (
|
||||
s.is_interested_in_alias(room_alias_str)
|
||||
)
|
||||
]
|
||||
for alias_service in alias_query_services:
|
||||
is_known_alias = yield self.appservice_api.query_alias(
|
||||
alias_service, room_alias_str
|
||||
|
@ -121,34 +160,35 @@ class ApplicationServicesHandler(object):
|
|||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_services_for_event(self, event, restrict_to="", alias_list=None):
|
||||
def query_3pe(self, kind, protocol, fields):
|
||||
services = yield self._get_services_for_3pn(protocol)
|
||||
|
||||
results = yield preserve_context_over_deferred(defer.DeferredList([
|
||||
preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
|
||||
for service in services
|
||||
], consumeErrors=True))
|
||||
|
||||
ret = []
|
||||
for (success, result) in results:
|
||||
if success:
|
||||
ret.extend(result)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_services_for_event(self, event):
|
||||
"""Retrieve a list of application services interested in this event.
|
||||
|
||||
Args:
|
||||
event(Event): The event to check. Can be None if alias_list is not.
|
||||
restrict_to(str): The namespace to restrict regex tests to.
|
||||
alias_list: A list of aliases to get services for. If None, this
|
||||
list is obtained from the database.
|
||||
Returns:
|
||||
list<ApplicationService>: A list of services interested in this
|
||||
event based on the service regex.
|
||||
"""
|
||||
member_list = None
|
||||
if hasattr(event, "room_id"):
|
||||
# We need to know the aliases associated with this event.room_id,
|
||||
# if any.
|
||||
if not alias_list:
|
||||
alias_list = yield self.store.get_aliases_for_room(
|
||||
event.room_id
|
||||
)
|
||||
# We need to know the members associated with this event.room_id,
|
||||
# if any.
|
||||
member_list = yield self.store.get_users_in_room(event.room_id)
|
||||
|
||||
services = yield self.store.get_app_services()
|
||||
interested_list = [
|
||||
s for s in services if (
|
||||
s.is_interested(event, restrict_to, alias_list, member_list)
|
||||
yield s.is_interested(event, self.store)
|
||||
)
|
||||
]
|
||||
defer.returnValue(interested_list)
|
||||
|
@ -163,6 +203,14 @@ class ApplicationServicesHandler(object):
|
|||
]
|
||||
defer.returnValue(interested_list)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_services_for_3pn(self, protocol):
|
||||
services = yield self.store.get_app_services()
|
||||
interested_list = [
|
||||
s for s in services if s.is_interested_in_protocol(protocol)
|
||||
]
|
||||
defer.returnValue(interested_list)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _is_unknown_user(self, user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
|
|
|
@ -70,11 +70,11 @@ class AuthHandler(BaseHandler):
|
|||
self.ldap_uri = hs.config.ldap_uri
|
||||
self.ldap_start_tls = hs.config.ldap_start_tls
|
||||
self.ldap_base = hs.config.ldap_base
|
||||
self.ldap_filter = hs.config.ldap_filter
|
||||
self.ldap_attributes = hs.config.ldap_attributes
|
||||
if self.ldap_mode == LDAPMode.SEARCH:
|
||||
self.ldap_bind_dn = hs.config.ldap_bind_dn
|
||||
self.ldap_bind_password = hs.config.ldap_bind_password
|
||||
self.ldap_filter = hs.config.ldap_filter
|
||||
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
@ -660,7 +660,7 @@ class AuthHandler(BaseHandler):
|
|||
else:
|
||||
logger.warn(
|
||||
"ldap registration failed: unexpected (%d!=1) amount of results",
|
||||
len(result)
|
||||
len(conn.response)
|
||||
)
|
||||
defer.returnValue(False)
|
||||
|
||||
|
@ -719,13 +719,14 @@ class AuthHandler(BaseHandler):
|
|||
return macaroon.serialize()
|
||||
|
||||
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
||||
auth_api = self.hs.get_auth()
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
||||
auth_api = self.hs.get_auth()
|
||||
auth_api.validate_macaroon(macaroon, "login", True)
|
||||
return self.get_user_from_macaroon(macaroon)
|
||||
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
|
||||
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
|
||||
user_id = auth_api.get_user_id_from_macaroon(macaroon)
|
||||
auth_api.validate_macaroon(macaroon, "login", True, user_id)
|
||||
return user_id
|
||||
except Exception:
|
||||
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
||||
|
||||
def _generate_base_macaroon(self, user_id):
|
||||
macaroon = pymacaroons.Macaroon(
|
||||
|
@ -736,21 +737,11 @@ class AuthHandler(BaseHandler):
|
|||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
return macaroon
|
||||
|
||||
def get_user_from_macaroon(self, macaroon):
|
||||
user_prefix = "user_id = "
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(user_prefix):
|
||||
return caveat.caveat_id[len(user_prefix):]
|
||||
raise AuthError(
|
||||
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_password(self, user_id, newpassword, requester=None):
|
||||
password_hash = self.hash(newpassword)
|
||||
|
||||
except_access_token_ids = [requester.access_token_id] if requester else []
|
||||
except_access_token_id = requester.access_token_id if requester else None
|
||||
|
||||
try:
|
||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||
|
@ -759,10 +750,10 @@ class AuthHandler(BaseHandler):
|
|||
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
||||
raise e
|
||||
yield self.store.user_delete_access_tokens(
|
||||
user_id, except_access_token_ids
|
||||
user_id, except_access_token_id
|
||||
)
|
||||
yield self.hs.get_pusherpool().remove_pushers_by_user(
|
||||
user_id, except_access_token_ids
|
||||
user_id, except_access_token_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -26,7 +26,9 @@ from synapse.api.errors import (
|
|||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
|
||||
from synapse.util.logcontext import (
|
||||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
|
||||
)
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
|
@ -249,7 +251,7 @@ class FederationHandler(BaseHandler):
|
|||
if ev.type != EventTypes.Member:
|
||||
continue
|
||||
try:
|
||||
domain = UserID.from_string(ev.state_key).domain
|
||||
domain = get_domain_from_id(ev.state_key)
|
||||
except:
|
||||
continue
|
||||
|
||||
|
@ -274,7 +276,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def backfill(self, dest, room_id, limit, extremities=[]):
|
||||
def backfill(self, dest, room_id, limit, extremities):
|
||||
""" Trigger a backfill request to `dest` for the given `room_id`
|
||||
|
||||
This will attempt to get more events from the remote. This may return
|
||||
|
@ -284,9 +286,6 @@ class FederationHandler(BaseHandler):
|
|||
if dest == self.server_name:
|
||||
raise SynapseError(400, "Can't backfill from self.")
|
||||
|
||||
if not extremities:
|
||||
extremities = yield self.store.get_oldest_events_in_room(room_id)
|
||||
|
||||
events = yield self.replication_layer.backfill(
|
||||
dest,
|
||||
room_id,
|
||||
|
@ -364,9 +363,9 @@ class FederationHandler(BaseHandler):
|
|||
missing_auth - failed_to_fetch
|
||||
)
|
||||
|
||||
results = yield defer.gatherResults(
|
||||
results = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
self.replication_layer.get_pdu(
|
||||
preserve_fn(self.replication_layer.get_pdu)(
|
||||
[dest],
|
||||
event_id,
|
||||
outlier=True,
|
||||
|
@ -375,10 +374,10 @@ class FederationHandler(BaseHandler):
|
|||
for event_id in missing_auth - failed_to_fetch
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
auth_events.update({a.event_id: a for a in results})
|
||||
)).addErrback(unwrapFirstError)
|
||||
auth_events.update({a.event_id: a for a in results if a})
|
||||
required_auth.update(
|
||||
a_id for event in results for a_id, _ in event.auth_events
|
||||
a_id for event in results for a_id, _ in event.auth_events if event
|
||||
)
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
|
||||
|
@ -455,6 +454,10 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
max_depth = sorted_extremeties_tuple[0][1]
|
||||
|
||||
# We don't want to specify too many extremities as it causes the backfill
|
||||
# request URI to be too long.
|
||||
extremities = dict(sorted_extremeties_tuple[:5])
|
||||
|
||||
if current_depth > max_depth:
|
||||
logger.debug(
|
||||
"Not backfilling as we don't need to. %d < %d",
|
||||
|
@ -551,10 +554,10 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
event_ids = list(extremities.keys())
|
||||
|
||||
states = yield defer.gatherResults([
|
||||
self.state_handler.resolve_state_groups(room_id, [e])
|
||||
states = yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
|
||||
for e in event_ids
|
||||
])
|
||||
]))
|
||||
states = dict(zip(event_ids, [s[1] for s in states]))
|
||||
|
||||
for e_id, _ in sorted_extremeties_tuple:
|
||||
|
@ -1093,6 +1096,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
if event:
|
||||
if self.hs.is_mine_id(event.event_id):
|
||||
# FIXME: This is a temporary work around where we occasionally
|
||||
# return events slightly differently than when they were
|
||||
# originally signed
|
||||
|
@ -1112,6 +1116,12 @@ class FederationHandler(BaseHandler):
|
|||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
events = yield self._filter_events_for_server(
|
||||
origin, event.room_id, [event]
|
||||
)
|
||||
|
||||
event = events[0]
|
||||
|
||||
defer.returnValue(event)
|
||||
else:
|
||||
defer.returnValue(None)
|
||||
|
@ -1158,9 +1168,9 @@ class FederationHandler(BaseHandler):
|
|||
a bunch of outliers, but not a chunk of individual events that depend
|
||||
on each other for state calculations.
|
||||
"""
|
||||
contexts = yield defer.gatherResults(
|
||||
contexts = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
self._prep_event(
|
||||
preserve_fn(self._prep_event)(
|
||||
origin,
|
||||
ev_info["event"],
|
||||
state=ev_info.get("state"),
|
||||
|
@ -1168,7 +1178,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
for ev_info in event_infos
|
||||
]
|
||||
)
|
||||
))
|
||||
|
||||
yield self.store.persist_events(
|
||||
[
|
||||
|
@ -1452,9 +1462,9 @@ class FederationHandler(BaseHandler):
|
|||
# Do auth conflict res.
|
||||
logger.info("Different auth: %s", different_auth)
|
||||
|
||||
different_events = yield defer.gatherResults(
|
||||
different_events = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
self.store.get_event(
|
||||
preserve_fn(self.store.get_event)(
|
||||
d,
|
||||
allow_none=True,
|
||||
allow_rejected=False,
|
||||
|
@ -1463,7 +1473,7 @@ class FederationHandler(BaseHandler):
|
|||
if d in have_events and not have_events[d]
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
)).addErrback(unwrapFirstError)
|
||||
|
||||
if different_events:
|
||||
local_view = dict(auth_events)
|
||||
|
|
|
@ -28,7 +28,8 @@ from synapse.types import (
|
|||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
|
||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
@ -502,15 +503,17 @@ class MessageHandler(BaseHandler):
|
|||
lambda states: states[event.event_id]
|
||||
)
|
||||
|
||||
(messages, token), current_state = yield defer.gatherResults(
|
||||
(messages, token), current_state = yield preserve_context_over_deferred(
|
||||
defer.gatherResults(
|
||||
[
|
||||
self.store.get_recent_events_for_room(
|
||||
preserve_fn(self.store.get_recent_events_for_room)(
|
||||
event.room_id,
|
||||
limit=limit,
|
||||
end_token=room_end_token,
|
||||
),
|
||||
deferred_room_state,
|
||||
]
|
||||
)
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
messages = yield filter_events_for_client(
|
||||
|
@ -719,9 +722,9 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
presence, receipts, (messages, token) = yield defer.gatherResults(
|
||||
[
|
||||
get_presence(),
|
||||
get_receipts(),
|
||||
self.store.get_recent_events_for_room(
|
||||
preserve_fn(get_presence)(),
|
||||
preserve_fn(get_receipts)(),
|
||||
preserve_fn(self.store.get_recent_events_for_room)(
|
||||
room_id,
|
||||
limit=limit,
|
||||
end_token=now_token.room_key,
|
||||
|
@ -755,6 +758,7 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@measure_func("_create_new_client_event")
|
||||
@defer.inlineCallbacks
|
||||
def _create_new_client_event(self, builder, prev_event_ids=None):
|
||||
if prev_event_ids:
|
||||
|
@ -806,6 +810,7 @@ class MessageHandler(BaseHandler):
|
|||
(event, context,)
|
||||
)
|
||||
|
||||
@measure_func("handle_new_client_event")
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_client_event(
|
||||
self,
|
||||
|
@ -934,7 +939,7 @@ class MessageHandler(BaseHandler):
|
|||
@defer.inlineCallbacks
|
||||
def _notify():
|
||||
yield run_on_reactor()
|
||||
self.notifier.on_new_room_event(
|
||||
yield self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=extra_users
|
||||
)
|
||||
|
@ -944,6 +949,6 @@ class MessageHandler(BaseHandler):
|
|||
# If invite, remove room_state from unsigned before sending.
|
||||
event.unsigned.pop("invite_room_state", None)
|
||||
|
||||
federation_handler.handle_new_event(
|
||||
preserve_fn(federation_handler.handle_new_event)(
|
||||
event, destinations=destinations,
|
||||
)
|
||||
|
|
|
@ -503,7 +503,7 @@ class PresenceHandler(object):
|
|||
defer.returnValue(states)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_interested_parties(self, states):
|
||||
def _get_interested_parties(self, states, calculate_remote_hosts=True):
|
||||
"""Given a list of states return which entities (rooms, users, servers)
|
||||
are interested in the given states.
|
||||
|
||||
|
@ -526,6 +526,7 @@ class PresenceHandler(object):
|
|||
users_to_states.setdefault(state.user_id, []).append(state)
|
||||
|
||||
hosts_to_states = {}
|
||||
if calculate_remote_hosts:
|
||||
for room_id, states in room_ids_to_states.items():
|
||||
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
|
||||
if not local_states:
|
||||
|
@ -565,6 +566,16 @@ class PresenceHandler(object):
|
|||
|
||||
self._push_to_remotes(hosts_to_states)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify_for_states(self, state, stream_id):
|
||||
parties = yield self._get_interested_parties([state])
|
||||
room_ids_to_states, users_to_states, hosts_to_states = parties
|
||||
|
||||
self.notifier.on_new_event(
|
||||
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
|
||||
users=[UserID.from_string(u) for u in users_to_states.keys()]
|
||||
)
|
||||
|
||||
def _push_to_remotes(self, hosts_to_states):
|
||||
"""Sends state updates to remote servers.
|
||||
|
||||
|
@ -672,7 +683,7 @@ class PresenceHandler(object):
|
|||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_state(self, target_user, state):
|
||||
def set_state(self, target_user, state, ignore_status_msg=False):
|
||||
"""Set the presence state of the user.
|
||||
"""
|
||||
status_msg = state.get("status_msg", None)
|
||||
|
@ -689,10 +700,13 @@ class PresenceHandler(object):
|
|||
prev_state = yield self.current_state_for_user(user_id)
|
||||
|
||||
new_fields = {
|
||||
"state": presence,
|
||||
"status_msg": status_msg if presence != PresenceState.OFFLINE else None
|
||||
"state": presence
|
||||
}
|
||||
|
||||
if not ignore_status_msg:
|
||||
msg = status_msg if presence != PresenceState.OFFLINE else None
|
||||
new_fields["status_msg"] = msg
|
||||
|
||||
if presence == PresenceState.ONLINE:
|
||||
new_fields["last_active_ts"] = self.clock.time_msec()
|
||||
|
||||
|
|
|
@ -59,10 +59,13 @@ class RoomMemberHandler(BaseHandler):
|
|||
prev_event_ids,
|
||||
txn_id=None,
|
||||
ratelimit=True,
|
||||
content=None,
|
||||
):
|
||||
if content is None:
|
||||
content = {}
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
|
||||
content = {"membership": membership}
|
||||
content["membership"] = membership
|
||||
if requester.is_guest:
|
||||
content["kind"] = "guest"
|
||||
|
||||
|
@ -140,8 +143,9 @@ class RoomMemberHandler(BaseHandler):
|
|||
remote_room_hosts=None,
|
||||
third_party_signed=None,
|
||||
ratelimit=True,
|
||||
content=None,
|
||||
):
|
||||
key = (target, room_id,)
|
||||
key = (room_id,)
|
||||
|
||||
with (yield self.member_linearizer.queue(key)):
|
||||
result = yield self._update_membership(
|
||||
|
@ -153,6 +157,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
remote_room_hosts=remote_room_hosts,
|
||||
third_party_signed=third_party_signed,
|
||||
ratelimit=ratelimit,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
@ -168,7 +173,11 @@ class RoomMemberHandler(BaseHandler):
|
|||
remote_room_hosts=None,
|
||||
third_party_signed=None,
|
||||
ratelimit=True,
|
||||
content=None,
|
||||
):
|
||||
if content is None:
|
||||
content = {}
|
||||
|
||||
effective_membership_state = action
|
||||
if action in ["kick", "unban"]:
|
||||
effective_membership_state = "leave"
|
||||
|
@ -218,7 +227,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
if inviter and not self.hs.is_mine(inviter):
|
||||
remote_room_hosts.append(inviter.domain)
|
||||
|
||||
content = {"membership": Membership.JOIN}
|
||||
content["membership"] = Membership.JOIN
|
||||
|
||||
profile = self.hs.get_handlers().profile_handler
|
||||
content["displayname"] = yield profile.get_displayname(target)
|
||||
|
@ -272,6 +281,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
txn_id=txn_id,
|
||||
ratelimit=ratelimit,
|
||||
prev_event_ids=latest_event_ids,
|
||||
content=content,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -16,7 +16,9 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError, AuthError
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.util.logcontext import (
|
||||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
||||
)
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.types import UserID
|
||||
|
||||
|
@ -169,13 +171,13 @@ class TypingHandler(object):
|
|||
deferreds = []
|
||||
for domain in domains:
|
||||
if domain == self.server_name:
|
||||
self._push_update_local(
|
||||
preserve_fn(self._push_update_local)(
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
typing=typing
|
||||
)
|
||||
else:
|
||||
deferreds.append(self.federation.send_edu(
|
||||
deferreds.append(preserve_fn(self.federation.send_edu)(
|
||||
destination=domain,
|
||||
edu_type="m.typing",
|
||||
content={
|
||||
|
@ -185,7 +187,9 @@ class TypingHandler(object):
|
|||
},
|
||||
))
|
||||
|
||||
yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||
yield preserve_context_over_deferred(
|
||||
defer.DeferredList(deferreds, consumeErrors=True)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _recv_edu(self, origin, content):
|
||||
|
|
|
@ -155,9 +155,7 @@ class MatrixFederationHttpClient(object):
|
|||
time_out=timeout / 1000. if timeout else 60,
|
||||
)
|
||||
|
||||
response = yield preserve_context_over_fn(
|
||||
send_request,
|
||||
)
|
||||
response = yield preserve_context_over_fn(send_request)
|
||||
|
||||
log_result = "%d %s" % (response.code, response.phrase,)
|
||||
break
|
||||
|
|
|
@ -19,6 +19,7 @@ from synapse.api.errors import (
|
|||
)
|
||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||
from synapse.util.caches import intern_dict
|
||||
from synapse.util.metrics import Measure
|
||||
import synapse.metrics
|
||||
import synapse.events
|
||||
|
||||
|
@ -74,12 +75,12 @@ response_db_txn_duration = metrics.register_distribution(
|
|||
_next_request_id = 0
|
||||
|
||||
|
||||
def request_handler(report_metrics=True):
|
||||
def request_handler(include_metrics=False):
|
||||
"""Decorator for ``wrap_request_handler``"""
|
||||
return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
|
||||
return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
|
||||
|
||||
|
||||
def wrap_request_handler(request_handler, report_metrics):
|
||||
def wrap_request_handler(request_handler, include_metrics=False):
|
||||
"""Wraps a method that acts as a request handler with the necessary logging
|
||||
and exception handling.
|
||||
|
||||
|
@ -103,14 +104,17 @@ def wrap_request_handler(request_handler, report_metrics):
|
|||
_next_request_id += 1
|
||||
|
||||
with LoggingContext(request_id) as request_context:
|
||||
if report_metrics:
|
||||
with Measure(self.clock, "wrapped_request_handler"):
|
||||
request_metrics = RequestMetrics()
|
||||
request_metrics.start(self.clock)
|
||||
request_metrics.start(self.clock, name=self.__class__.__name__)
|
||||
|
||||
request_context.request = request_id
|
||||
with request.processing():
|
||||
try:
|
||||
with PreserveLoggingContext(request_context):
|
||||
if include_metrics:
|
||||
yield request_handler(self, request, request_metrics)
|
||||
else:
|
||||
yield request_handler(self, request)
|
||||
except CodeMessageException as e:
|
||||
code = e.code
|
||||
|
@ -145,12 +149,11 @@ def wrap_request_handler(request_handler, report_metrics):
|
|||
)
|
||||
finally:
|
||||
try:
|
||||
if report_metrics:
|
||||
request_metrics.stop(
|
||||
self.clock, request, self.__class__.__name__
|
||||
self.clock, request
|
||||
)
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warn("Failed to stop metrics: %r", e)
|
||||
return wrapped_request_handler
|
||||
|
||||
|
||||
|
@ -220,9 +223,9 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
# It does its own metric reporting because _async_render dispatches to
|
||||
# a callback and it's the class name of that callback we want to report
|
||||
# against rather than the JsonResource itself.
|
||||
@request_handler(report_metrics=False)
|
||||
@request_handler(include_metrics=True)
|
||||
@defer.inlineCallbacks
|
||||
def _async_render(self, request):
|
||||
def _async_render(self, request, request_metrics):
|
||||
""" This gets called from render() every time someone sends us a request.
|
||||
This checks if anyone has registered a callback for that method and
|
||||
path.
|
||||
|
@ -231,9 +234,6 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
self._send_response(request, 200, {})
|
||||
return
|
||||
|
||||
request_metrics = RequestMetrics()
|
||||
request_metrics.start(self.clock)
|
||||
|
||||
# Loop through all the registered callbacks to check if the method
|
||||
# and path regex match
|
||||
for path_entry in self.path_regexs.get(request.method, []):
|
||||
|
@ -247,12 +247,6 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
|
||||
callback = path_entry.callback
|
||||
|
||||
servlet_instance = getattr(callback, "__self__", None)
|
||||
if servlet_instance is not None:
|
||||
servlet_classname = servlet_instance.__class__.__name__
|
||||
else:
|
||||
servlet_classname = "%r" % callback
|
||||
|
||||
kwargs = intern_dict({
|
||||
name: urllib.unquote(value).decode("UTF-8") if value else value
|
||||
for name, value in m.groupdict().items()
|
||||
|
@ -263,10 +257,13 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
code, response = callback_return
|
||||
self._send_response(request, code, response)
|
||||
|
||||
try:
|
||||
request_metrics.stop(self.clock, request, servlet_classname)
|
||||
except:
|
||||
pass
|
||||
servlet_instance = getattr(callback, "__self__", None)
|
||||
if servlet_instance is not None:
|
||||
servlet_classname = servlet_instance.__class__.__name__
|
||||
else:
|
||||
servlet_classname = "%r" % callback
|
||||
|
||||
request_metrics.name = servlet_classname
|
||||
|
||||
return
|
||||
|
||||
|
@ -298,11 +295,12 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
|
||||
|
||||
class RequestMetrics(object):
|
||||
def start(self, clock):
|
||||
def start(self, clock, name):
|
||||
self.start = clock.time_msec()
|
||||
self.start_context = LoggingContext.current_context()
|
||||
self.name = name
|
||||
|
||||
def stop(self, clock, request, servlet_classname):
|
||||
def stop(self, clock, request):
|
||||
context = LoggingContext.current_context()
|
||||
|
||||
tag = ""
|
||||
|
@ -316,26 +314,26 @@ class RequestMetrics(object):
|
|||
)
|
||||
return
|
||||
|
||||
incoming_requests_counter.inc(request.method, servlet_classname, tag)
|
||||
incoming_requests_counter.inc(request.method, self.name, tag)
|
||||
|
||||
response_timer.inc_by(
|
||||
clock.time_msec() - self.start, request.method,
|
||||
servlet_classname, tag
|
||||
self.name, tag
|
||||
)
|
||||
|
||||
ru_utime, ru_stime = context.get_resource_usage()
|
||||
|
||||
response_ru_utime.inc_by(
|
||||
ru_utime, request.method, servlet_classname, tag
|
||||
ru_utime, request.method, self.name, tag
|
||||
)
|
||||
response_ru_stime.inc_by(
|
||||
ru_stime, request.method, servlet_classname, tag
|
||||
ru_stime, request.method, self.name, tag
|
||||
)
|
||||
response_db_txn_count.inc_by(
|
||||
context.db_txn_count, request.method, servlet_classname, tag
|
||||
context.db_txn_count, request.method, self.name, tag
|
||||
)
|
||||
response_db_txn_duration.inc_by(
|
||||
context.db_txn_duration, request.method, servlet_classname, tag
|
||||
context.db_txn_duration, request.method, self.name, tag
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -19,7 +19,8 @@ from synapse.api.errors import AuthError
|
|||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.types import StreamToken
|
||||
from synapse.visibility import filter_events_for_client
|
||||
import synapse.metrics
|
||||
|
@ -67,10 +68,8 @@ class _NotifierUserStream(object):
|
|||
so that it can remove itself from the indexes in the Notifier class.
|
||||
"""
|
||||
|
||||
def __init__(self, user_id, rooms, current_token, time_now_ms,
|
||||
appservice=None):
|
||||
def __init__(self, user_id, rooms, current_token, time_now_ms):
|
||||
self.user_id = user_id
|
||||
self.appservice = appservice
|
||||
self.rooms = set(rooms)
|
||||
self.current_token = current_token
|
||||
self.last_notified_ms = time_now_ms
|
||||
|
@ -107,11 +106,6 @@ class _NotifierUserStream(object):
|
|||
|
||||
notifier.user_to_user_stream.pop(self.user_id)
|
||||
|
||||
if self.appservice:
|
||||
notifier.appservice_to_user_streams.get(
|
||||
self.appservice, set()
|
||||
).discard(self)
|
||||
|
||||
def count_listeners(self):
|
||||
return len(self.notify_deferred.observers())
|
||||
|
||||
|
@ -142,7 +136,6 @@ class Notifier(object):
|
|||
def __init__(self, hs):
|
||||
self.user_to_user_stream = {}
|
||||
self.room_to_user_streams = {}
|
||||
self.appservice_to_user_streams = {}
|
||||
|
||||
self.event_sources = hs.get_event_sources()
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -168,8 +161,6 @@ class Notifier(object):
|
|||
all_user_streams |= x
|
||||
for x in self.user_to_user_stream.values():
|
||||
all_user_streams.add(x)
|
||||
for x in self.appservice_to_user_streams.values():
|
||||
all_user_streams |= x
|
||||
|
||||
return sum(stream.count_listeners() for stream in all_user_streams)
|
||||
metrics.register_callback("listeners", count_listeners)
|
||||
|
@ -182,11 +173,8 @@ class Notifier(object):
|
|||
"users",
|
||||
lambda: len(self.user_to_user_stream),
|
||||
)
|
||||
metrics.register_callback(
|
||||
"appservices",
|
||||
lambda: count(bool, self.appservice_to_user_streams.values()),
|
||||
)
|
||||
|
||||
@preserve_fn
|
||||
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
|
||||
extra_users=[]):
|
||||
""" Used by handlers to inform the notifier something has happened
|
||||
|
@ -208,6 +196,7 @@ class Notifier(object):
|
|||
|
||||
self.notify_replication()
|
||||
|
||||
@preserve_fn
|
||||
def _notify_pending_new_room_events(self, max_room_stream_id):
|
||||
"""Notify for the room events that were queued waiting for a previous
|
||||
event to be persisted.
|
||||
|
@ -225,24 +214,11 @@ class Notifier(object):
|
|||
else:
|
||||
self._on_new_room_event(event, room_stream_id, extra_users)
|
||||
|
||||
@preserve_fn
|
||||
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
|
||||
"""Notify any user streams that are interested in this room event"""
|
||||
# poke any interested application service.
|
||||
self.appservice_handler.notify_interested_services(event)
|
||||
|
||||
app_streams = set()
|
||||
|
||||
for appservice in self.appservice_to_user_streams:
|
||||
# TODO (kegan): Redundant appservice listener checks?
|
||||
# App services will already be in the room_to_user_streams set, but
|
||||
# that isn't enough. They need to be checked here in order to
|
||||
# receive *invites* for users they are interested in. Does this
|
||||
# make the room_to_user_streams check somewhat obselete?
|
||||
if appservice.is_interested(event):
|
||||
app_user_streams = self.appservice_to_user_streams.get(
|
||||
appservice, set()
|
||||
)
|
||||
app_streams |= app_user_streams
|
||||
self.appservice_handler.notify_interested_services(room_stream_id)
|
||||
|
||||
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
||||
self._user_joined_room(event.state_key, event.room_id)
|
||||
|
@ -251,16 +227,16 @@ class Notifier(object):
|
|||
"room_key", room_stream_id,
|
||||
users=extra_users,
|
||||
rooms=[event.room_id],
|
||||
extra_streams=app_streams,
|
||||
)
|
||||
|
||||
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
|
||||
extra_streams=set()):
|
||||
@preserve_fn
|
||||
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
|
||||
""" Used to inform listeners that something has happend event wise.
|
||||
|
||||
Will wake up all listeners for the given users and rooms.
|
||||
"""
|
||||
with PreserveLoggingContext():
|
||||
with Measure(self.clock, "on_new_event"):
|
||||
user_streams = set()
|
||||
|
||||
for user in users:
|
||||
|
@ -280,6 +256,7 @@ class Notifier(object):
|
|||
|
||||
self.notify_replication()
|
||||
|
||||
@preserve_fn
|
||||
def on_new_replication_data(self):
|
||||
"""Used to inform replication listeners that something has happend
|
||||
without waking up any of the normal user event streams"""
|
||||
|
@ -294,7 +271,6 @@ class Notifier(object):
|
|||
"""
|
||||
user_stream = self.user_to_user_stream.get(user_id)
|
||||
if user_stream is None:
|
||||
appservice = yield self.store.get_app_service_by_user_id(user_id)
|
||||
current_token = yield self.event_sources.get_current_token()
|
||||
if room_ids is None:
|
||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||
|
@ -302,7 +278,6 @@ class Notifier(object):
|
|||
user_stream = _NotifierUserStream(
|
||||
user_id=user_id,
|
||||
rooms=room_ids,
|
||||
appservice=appservice,
|
||||
current_token=current_token,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
)
|
||||
|
@ -477,11 +452,6 @@ class Notifier(object):
|
|||
s = self.room_to_user_streams.setdefault(room, set())
|
||||
s.add(user_stream)
|
||||
|
||||
if user_stream.appservice:
|
||||
self.appservice_to_user_stream.setdefault(
|
||||
user_stream.appservice, set()
|
||||
).add(user_stream)
|
||||
|
||||
def _user_joined_room(self, user_id, room_id):
|
||||
new_user_stream = self.user_to_user_stream.get(user_id)
|
||||
if new_user_stream is not None:
|
||||
|
|
|
@ -38,11 +38,12 @@ class ActionGenerator:
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def handle_push_actions_for_event(self, event, context):
|
||||
with Measure(self.clock, "handle_push_actions_for_event"):
|
||||
with Measure(self.clock, "evaluator_for_event"):
|
||||
bulk_evaluator = yield evaluator_for_event(
|
||||
event, self.hs, self.store, context.current_state
|
||||
event, self.hs, self.store, context.state_group, context.current_state
|
||||
)
|
||||
|
||||
with Measure(self.clock, "action_for_event_by_user"):
|
||||
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
||||
event, context.current_state
|
||||
)
|
||||
|
|
|
@ -217,6 +217,27 @@ BASE_APPEND_OVERRIDE_RULES = [
|
|||
'dont_notify'
|
||||
]
|
||||
},
|
||||
# This was changed from underride to override so it's closer in priority
|
||||
# to the content rules where the user name highlight rule lives. This
|
||||
# way a room rule is lower priority than both but a custom override rule
|
||||
# is higher priority than both.
|
||||
{
|
||||
'rule_id': 'global/override/.m.rule.contains_display_name',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'contains_display_name'
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
'notify',
|
||||
{
|
||||
'set_tweak': 'sound',
|
||||
'value': 'default'
|
||||
}, {
|
||||
'set_tweak': 'highlight'
|
||||
}
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
@ -242,23 +263,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
|
|||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
'rule_id': 'global/underride/.m.rule.contains_display_name',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'contains_display_name'
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
'notify',
|
||||
{
|
||||
'set_tweak': 'sound',
|
||||
'value': 'default'
|
||||
}, {
|
||||
'set_tweak': 'highlight'
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
'rule_id': 'global/underride/.m.rule.room_one_to_one',
|
||||
'conditions': [
|
||||
|
|
|
@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store):
|
|||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def evaluator_for_event(event, hs, store, current_state):
|
||||
room_id = event.room_id
|
||||
# We also will want to generate notifs for other people in the room so
|
||||
# their unread countss are correct in the event stream, but to avoid
|
||||
# generating them for bot / AS users etc, we only do so for people who've
|
||||
# sent a read receipt into the room.
|
||||
|
||||
local_users_in_room = set(
|
||||
e.state_key for e in current_state.values()
|
||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
||||
and hs.is_mine_id(e.state_key)
|
||||
def evaluator_for_event(event, hs, store, state_group, current_state):
|
||||
rules_by_user = yield store.bulk_get_push_rules_for_room(
|
||||
event.room_id, state_group, current_state
|
||||
)
|
||||
|
||||
# users in the room who have pushers need to get push rules run because
|
||||
# that's how their pushers work
|
||||
if_users_with_pushers = yield store.get_if_users_have_pushers(
|
||||
local_users_in_room
|
||||
)
|
||||
user_ids = set(
|
||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
||||
)
|
||||
|
||||
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
|
||||
|
||||
# any users with pushers must be ours: they have pushers
|
||||
for uid in users_with_receipts:
|
||||
if uid in local_users_in_room:
|
||||
user_ids.add(uid)
|
||||
|
||||
# if this event is an invite event, we may need to run rules for the user
|
||||
# who's been invited, otherwise they won't get told they've been invited
|
||||
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
|
||||
|
@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state):
|
|||
if invited_user and hs.is_mine_id(invited_user):
|
||||
has_pusher = yield store.user_has_pusher(invited_user)
|
||||
if has_pusher:
|
||||
user_ids.add(invited_user)
|
||||
|
||||
rules_by_user = yield _get_rules(room_id, user_ids, store)
|
||||
rules_by_user[invited_user] = yield store.get_push_rules_for_user(
|
||||
invited_user
|
||||
)
|
||||
|
||||
defer.returnValue(BulkPushRuleEvaluator(
|
||||
room_id, rules_by_user, user_ids, store
|
||||
event.room_id, rules_by_user, store
|
||||
))
|
||||
|
||||
|
||||
|
@ -90,10 +66,9 @@ class BulkPushRuleEvaluator:
|
|||
the same logic to run the actual rules, but could be optimised further
|
||||
(see https://matrix.org/jira/browse/SYN-562)
|
||||
"""
|
||||
def __init__(self, room_id, rules_by_user, users_in_room, store):
|
||||
def __init__(self, room_id, rules_by_user, store):
|
||||
self.room_id = room_id
|
||||
self.rules_by_user = rules_by_user
|
||||
self.users_in_room = users_in_room
|
||||
self.store = store
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -17,14 +17,15 @@ from twisted.internet import defer
|
|||
from synapse.util.presentable_names import (
|
||||
calculate_room_name, name_from_member_event
|
||||
)
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_badge_count(store, user_id):
|
||||
invites, joins = yield defer.gatherResults([
|
||||
store.get_invited_rooms_for_user(user_id),
|
||||
store.get_rooms_for_user(user_id),
|
||||
], consumeErrors=True)
|
||||
invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(store.get_invited_rooms_for_user)(user_id),
|
||||
preserve_fn(store.get_rooms_for_user)(user_id),
|
||||
], consumeErrors=True))
|
||||
|
||||
my_receipts_by_room = yield store.get_receipts_for_user(
|
||||
user_id, "m.read",
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
import pusher
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
import logging
|
||||
|
@ -102,14 +102,14 @@ class PusherPool:
|
|||
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_pushers_by_user(self, user_id, except_token_ids=[]):
|
||||
def remove_pushers_by_user(self, user_id, except_access_token_id=None):
|
||||
all = yield self.store.get_all_pushers()
|
||||
logger.info(
|
||||
"Removing all pushers for user %s except access tokens ids %r",
|
||||
user_id, except_token_ids
|
||||
"Removing all pushers for user %s except access tokens id %r",
|
||||
user_id, except_access_token_id
|
||||
)
|
||||
for p in all:
|
||||
if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
|
||||
if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
|
||||
logger.info(
|
||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||
p['app_id'], p['pushkey'], p['user_name']
|
||||
|
@ -130,10 +130,12 @@ class PusherPool:
|
|||
if u in self.pushers:
|
||||
for p in self.pushers[u].values():
|
||||
deferreds.append(
|
||||
p.on_new_notifications(min_stream_id, max_stream_id)
|
||||
preserve_fn(p.on_new_notifications)(
|
||||
min_stream_id, max_stream_id
|
||||
)
|
||||
)
|
||||
|
||||
yield defer.gatherResults(deferreds)
|
||||
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
|
||||
except:
|
||||
logger.exception("Exception in pusher on_new_notifications")
|
||||
|
||||
|
@ -155,10 +157,10 @@ class PusherPool:
|
|||
if u in self.pushers:
|
||||
for p in self.pushers[u].values():
|
||||
deferreds.append(
|
||||
p.on_new_receipts(min_stream_id, max_stream_id)
|
||||
preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
|
||||
)
|
||||
|
||||
yield defer.gatherResults(deferreds)
|
||||
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
|
||||
except:
|
||||
logger.exception("Exception in pusher on_new_receipts")
|
||||
|
||||
|
|
|
@ -41,6 +41,7 @@ STREAM_NAMES = (
|
|||
("push_rules",),
|
||||
("pushers",),
|
||||
("state",),
|
||||
("caches",),
|
||||
)
|
||||
|
||||
|
||||
|
@ -70,6 +71,7 @@ class ReplicationResource(Resource):
|
|||
* "backfill": Old events that have been backfilled from other servers.
|
||||
* "push_rules": Per user changes to push rules.
|
||||
* "pushers": Per user changes to their pushers.
|
||||
* "caches": Cache invalidations.
|
||||
|
||||
The API takes two additional query parameters:
|
||||
|
||||
|
@ -129,6 +131,7 @@ class ReplicationResource(Resource):
|
|||
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
|
||||
pushers_token = self.store.get_pushers_stream_token()
|
||||
state_token = self.store.get_state_stream_token()
|
||||
caches_token = self.store.get_cache_stream_token()
|
||||
|
||||
defer.returnValue(_ReplicationToken(
|
||||
room_stream_token,
|
||||
|
@ -140,6 +143,7 @@ class ReplicationResource(Resource):
|
|||
push_rules_token,
|
||||
pushers_token,
|
||||
state_token,
|
||||
caches_token,
|
||||
))
|
||||
|
||||
@request_handler()
|
||||
|
@ -188,6 +192,7 @@ class ReplicationResource(Resource):
|
|||
yield self.push_rules(writer, current_token, limit, request_streams)
|
||||
yield self.pushers(writer, current_token, limit, request_streams)
|
||||
yield self.state(writer, current_token, limit, request_streams)
|
||||
yield self.caches(writer, current_token, limit, request_streams)
|
||||
self.streams(writer, current_token, request_streams)
|
||||
|
||||
logger.info("Replicated %d rows", writer.total)
|
||||
|
@ -379,6 +384,20 @@ class ReplicationResource(Resource):
|
|||
"position", "type", "state_key", "event_id"
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def caches(self, writer, current_token, limit, request_streams):
|
||||
current_position = current_token.caches
|
||||
|
||||
caches = request_streams.get("caches")
|
||||
|
||||
if caches is not None:
|
||||
updated_caches = yield self.store.get_all_updated_caches(
|
||||
caches, current_position, limit
|
||||
)
|
||||
writer.write_header_and_rows("caches", updated_caches, (
|
||||
"position", "cache_func", "keys", "invalidation_ts"
|
||||
))
|
||||
|
||||
|
||||
class _Writer(object):
|
||||
"""Writes the streams as a JSON object as the response to the request"""
|
||||
|
@ -407,7 +426,7 @@ class _Writer(object):
|
|||
|
||||
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
|
||||
"events", "presence", "typing", "receipts", "account_data", "backfill",
|
||||
"push_rules", "pushers", "state"
|
||||
"push_rules", "pushers", "state", "caches",
|
||||
))):
|
||||
__slots__ = []
|
||||
|
||||
|
|
|
@ -14,15 +14,43 @@
|
|||
# limitations under the License.
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseSlavedStore(SQLBaseStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(BaseSlavedStore, self).__init__(hs)
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
self._cache_id_gen = SlavedIdTracker(
|
||||
db_conn, "cache_invalidation_stream", "stream_id",
|
||||
)
|
||||
else:
|
||||
self._cache_id_gen = None
|
||||
|
||||
def stream_positions(self):
|
||||
return {}
|
||||
pos = {}
|
||||
if self._cache_id_gen:
|
||||
pos["caches"] = self._cache_id_gen.get_current_token()
|
||||
return pos
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("caches")
|
||||
if stream:
|
||||
for row in stream["rows"]:
|
||||
(
|
||||
position, cache_func, keys, invalidation_ts,
|
||||
) = row
|
||||
|
||||
try:
|
||||
getattr(self, cache_func).invalidate(tuple(keys))
|
||||
except AttributeError:
|
||||
logger.info("Got unexpected cache_func: %r", cache_func)
|
||||
self._cache_id_gen.advance(int(stream["position"]))
|
||||
return defer.succeed(None)
|
||||
|
|
|
@ -28,3 +28,13 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
|
|||
|
||||
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
|
||||
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
|
||||
get_app_services = DataStore.get_app_services.__func__
|
||||
get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
|
||||
create_appservice_txn = DataStore.create_appservice_txn.__func__
|
||||
get_appservices_by_state = DataStore.get_appservices_by_state.__func__
|
||||
get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
|
||||
_get_last_txn = DataStore._get_last_txn.__func__
|
||||
complete_appservice_txn = DataStore.complete_appservice_txn.__func__
|
||||
get_appservice_state = DataStore.get_appservice_state.__func__
|
||||
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
|
||||
set_appservice_state = DataStore.set_appservice_state.__func__
|
||||
|
|
|
@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore
|
|||
class DirectoryStore(BaseSlavedStore):
|
||||
get_aliases_for_room = DirectoryStore.__dict__[
|
||||
"get_aliases_for_room"
|
||||
].orig
|
||||
]
|
||||
|
|
|
@ -25,6 +25,9 @@ class SlavedRegistrationStore(BaseSlavedStore):
|
|||
# TODO: use the cached version and invalidate deleted tokens
|
||||
get_user_by_access_token = RegistrationStore.__dict__[
|
||||
"get_user_by_access_token"
|
||||
].orig
|
||||
]
|
||||
|
||||
_query_for_auth = DataStore._query_for_auth.__func__
|
||||
get_user_by_id = RegistrationStore.__dict__[
|
||||
"get_user_by_id"
|
||||
]
|
||||
|
|
|
@ -46,7 +46,9 @@ from synapse.rest.client.v2_alpha import (
|
|||
account_data,
|
||||
report_event,
|
||||
openid,
|
||||
notifications,
|
||||
devices,
|
||||
thirdparty,
|
||||
)
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
|
@ -91,4 +93,6 @@ class ClientRestResource(JsonResource):
|
|||
account_data.register_servlets(hs, client_resource)
|
||||
report_event.register_servlets(hs, client_resource)
|
||||
openid.register_servlets(hs, client_resource)
|
||||
notifications.register_servlets(hs, client_resource)
|
||||
devices.register_servlets(hs, client_resource)
|
||||
thirdparty.register_servlets(hs, client_resource)
|
||||
|
|
|
@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
|
|||
class WhoisRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(WhoisRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
target_user = UserID.from_string(user_id)
|
||||
|
@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
|
|||
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PurgeHistoryRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
|
|
@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet):
|
|||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
self.hs = hs
|
||||
self.handlers = hs.get_handlers()
|
||||
self.builder_factory = hs.get_event_builder_factory()
|
||||
self.auth = hs.get_v1auth()
|
||||
self.txns = HttpTransactionStore()
|
||||
|
|
|
@ -36,6 +36,10 @@ def register_servlets(hs, http_server):
|
|||
class ClientDirectoryServer(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ClientDirectoryServer, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_alias):
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
|
@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(ClientDirectoryListServer, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
|
|
|
@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
|
||||
DEFAULT_LONGPOLL_TIME_MS = 30000
|
||||
|
||||
def __init__(self, hs):
|
||||
super(EventStreamRestServlet, self).__init__(hs)
|
||||
self.event_stream_handler = hs.get_event_stream_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
|
@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
if "room_id" in request.args:
|
||||
room_id = request.args["room_id"][0]
|
||||
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if "timeout" in request.args:
|
||||
|
@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
|
||||
as_client_event = "raw" not in request.args
|
||||
|
||||
chunk = yield handler.get_stream(
|
||||
chunk = yield self.event_stream_handler.get_stream(
|
||||
requester.user.to_string(),
|
||||
pagin_config,
|
||||
timeout=timeout,
|
||||
|
@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(EventRestServlet, self).__init__(hs)
|
||||
self.clock = hs.get_clock()
|
||||
self.event_handler = hs.get_event_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
handler = self.handlers.event_handler
|
||||
event = yield handler.get_event(requester.user, event_id)
|
||||
event = yield self.event_handler.get_event(requester.user, event_id)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
if event:
|
||||
|
|
|
@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns
|
|||
class InitialSyncRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/initialSync$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(InitialSyncRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
|
|
@ -54,12 +54,9 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
self.jwt_secret = hs.config.jwt_secret
|
||||
self.jwt_algorithm = hs.config.jwt_algorithm
|
||||
self.cas_enabled = hs.config.cas_enabled
|
||||
self.cas_server_url = hs.config.cas_server_url
|
||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self.servername = hs.config.server_name
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.auth_handler = self.hs.get_auth_handler()
|
||||
self.device_handler = self.hs.get_device_handler()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def on_GET(self, request):
|
||||
flows = []
|
||||
|
@ -110,17 +107,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
LoginRestServlet.JWT_TYPE):
|
||||
result = yield self.do_jwt_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
# TODO Delete this after all CAS clients switch to token login instead
|
||||
elif self.cas_enabled and (login_submission["type"] ==
|
||||
LoginRestServlet.CAS_TYPE):
|
||||
uri = "%s/proxyValidate" % (self.cas_server_url,)
|
||||
args = {
|
||||
"ticket": login_submission["ticket"],
|
||||
"service": login_submission["service"]
|
||||
}
|
||||
body = yield self.http_client.get_raw(uri, args)
|
||||
result = yield self.do_cas_login(body)
|
||||
defer.returnValue(result)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
result = yield self.do_token_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
|
@ -191,51 +177,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
# TODO Delete this after all CAS clients switch to token login instead
|
||||
@defer.inlineCallbacks
|
||||
def do_cas_login(self, cas_response_body):
|
||||
user, attributes = self.parse_cas_response(cas_response_body)
|
||||
|
||||
for required_attribute, required_value in self.cas_required_attributes.items():
|
||||
# If required attribute was not in CAS Response - Forbidden
|
||||
if required_attribute not in attributes:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
# Also need to check value
|
||||
if required_value is not None:
|
||||
actual_value = attributes[required_attribute]
|
||||
# If required attribute value does not match expected - Forbidden
|
||||
if required_value != actual_value:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.auth_handler
|
||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||
if registered_user_id:
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
registered_user_id
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": registered_user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
|
||||
else:
|
||||
user_id, access_token = (
|
||||
yield self.handlers.registration_handler.register(localpart=user)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_jwt_login(self, login_submission):
|
||||
token = login_submission.get("token", None)
|
||||
|
@ -293,33 +234,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
# TODO Delete this after all CAS clients switch to token login instead
|
||||
def parse_cas_response(self, cas_response_body):
|
||||
root = ET.fromstring(cas_response_body)
|
||||
if not root.tag.endswith("serviceResponse"):
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
if not root[0].tag.endswith("authenticationSuccess"):
|
||||
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
if child.tag.endswith("attributes"):
|
||||
attributes = {}
|
||||
for attribute in child:
|
||||
# ElementTree library expands the namespace in attribute tags
|
||||
# to the full URL of the namespace.
|
||||
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
|
||||
# We don't care about namespace here and it will always be encased in
|
||||
# curly braces, so we remove them.
|
||||
if "}" in attribute.tag:
|
||||
attributes[attribute.tag.split("}")[1]] = attribute.text
|
||||
else:
|
||||
attributes[attribute.tag] = attribute.text
|
||||
if user is None or attributes is None:
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
return (user, attributes)
|
||||
|
||||
def _register_device(self, user_id, login_submission):
|
||||
"""Register a device for a user.
|
||||
|
||||
|
@ -347,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(SAML2RestServlet, self).__init__(hs)
|
||||
self.sp_config = hs.config.saml2_config_path
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
@ -384,18 +299,6 @@ class SAML2RestServlet(ClientV1RestServlet):
|
|||
defer.returnValue((200, {"status": "not_authenticated"}))
|
||||
|
||||
|
||||
# TODO Delete this after all CAS clients switch to token login instead
|
||||
class CasRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/login/cas", releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
super(CasRestServlet, self).__init__(hs)
|
||||
self.cas_server_url = hs.config.cas_server_url
|
||||
|
||||
def on_GET(self, request):
|
||||
return (200, {"serverUrl": self.cas_server_url})
|
||||
|
||||
|
||||
class CasRedirectServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
|
||||
|
||||
|
@ -427,6 +330,8 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||
self.cas_server_url = hs.config.cas_server_url
|
||||
self.cas_service_url = hs.config.cas_service_url
|
||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
|
@ -479,30 +384,39 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||
return urlparse.urlunparse(url_parts)
|
||||
|
||||
def parse_cas_response(self, cas_response_body):
|
||||
user = None
|
||||
attributes = None
|
||||
try:
|
||||
root = ET.fromstring(cas_response_body)
|
||||
if not root.tag.endswith("serviceResponse"):
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
if not root[0].tag.endswith("authenticationSuccess"):
|
||||
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
raise Exception("root of CAS response is not serviceResponse")
|
||||
success = (root[0].tag.endswith("authenticationSuccess"))
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
if child.tag.endswith("attributes"):
|
||||
attributes = {}
|
||||
for attribute in child:
|
||||
# ElementTree library expands the namespace in attribute tags
|
||||
# to the full URL of the namespace.
|
||||
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
|
||||
# We don't care about namespace here and it will always be encased in
|
||||
# curly braces, so we remove them.
|
||||
if "}" in attribute.tag:
|
||||
attributes[attribute.tag.split("}")[1]] = attribute.text
|
||||
else:
|
||||
attributes[attribute.tag] = attribute.text
|
||||
if user is None or attributes is None:
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
return (user, attributes)
|
||||
# ElementTree library expands the namespace in
|
||||
# attribute tags to the full URL of the namespace.
|
||||
# We don't care about namespace here and it will always
|
||||
# be encased in curly braces, so we remove them.
|
||||
tag = attribute.tag
|
||||
if "}" in tag:
|
||||
tag = tag.split("}")[1]
|
||||
attributes[tag] = attribute.text
|
||||
if user is None:
|
||||
raise Exception("CAS response does not contain user")
|
||||
if attributes is None:
|
||||
raise Exception("CAS response does not contain attributes")
|
||||
except Exception:
|
||||
logger.error("Error parsing CAS response", exc_info=1)
|
||||
raise LoginError(401, "Invalid CAS response",
|
||||
errcode=Codes.UNAUTHORIZED)
|
||||
if not success:
|
||||
raise LoginError(401, "Unsuccessful CAS response",
|
||||
errcode=Codes.UNAUTHORIZED)
|
||||
return user, attributes
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
|
@ -512,5 +426,3 @@ def register_servlets(hs, http_server):
|
|||
if hs.config.cas_enabled:
|
||||
CasRedirectServlet(hs).register(http_server)
|
||||
CasTicketServlet(hs).register(http_server)
|
||||
CasRestServlet(hs).register(http_server)
|
||||
# TODO PasswordResetRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request
|
|||
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileDisplaynameRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
|||
class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileAvatarURLRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
|||
class ProfileRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
|
|
@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
self.sessions = {}
|
||||
self.enable_registration = hs.config.enable_registration
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def on_GET(self, request):
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
|
@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||
super(CreateUserRestServlet, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
|
|
@ -35,6 +35,10 @@ logger = logging.getLogger(__name__)
|
|||
class RoomCreateRestServlet(ClientV1RestServlet):
|
||||
# No PATTERN; we have custom dispatch rules here
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomCreateRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = "/createRoom"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
|
|||
|
||||
# TODO: Needs unit testing for generic events
|
||||
class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(RoomStateEventRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
# /room/$roomid/state/$eventtype
|
||||
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
|
||||
|
@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
# TODO: Needs unit testing for generic events + feedback
|
||||
class RoomSendEventRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomSendEventRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
# /rooms/$roomid/send/$event_type[/$txn_id]
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
|
||||
|
@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
# TODO: Needs unit testing for room ID + alias joins
|
||||
class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(JoinRoomAliasServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
# /join/$room_identifier[/$txn_id]
|
||||
|
@ -253,6 +268,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
|||
action="join",
|
||||
txn_id=txn_id,
|
||||
remote_room_hosts=remote_room_hosts,
|
||||
content=content,
|
||||
third_party_signed=content.get("third_party_signed", None),
|
||||
)
|
||||
|
||||
|
@ -296,6 +312,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
|||
class RoomMemberListRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomMemberListRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
# TODO support Pagination stream API (limit/tokens)
|
||||
|
@ -322,6 +342,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
|
|||
class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomMessageListRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
@ -351,6 +375,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
|||
class RoomStateRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomStateRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
@ -368,6 +396,10 @@ class RoomStateRestServlet(ClientV1RestServlet):
|
|||
class RoomInitialSyncRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomInitialSyncRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
@ -388,6 +420,7 @@ class RoomEventContext(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(RoomEventContext, self).__init__(hs)
|
||||
self.clock = hs.get_clock()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_id):
|
||||
|
@ -424,6 +457,10 @@ class RoomEventContext(ClientV1RestServlet):
|
|||
|
||||
|
||||
class RoomForgetRestServlet(ClientV1RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(RoomForgetRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
@ -462,6 +499,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
|
|||
# TODO: Needs unit testing
|
||||
class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomMembershipRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
# /rooms/$roomid/[invite|join|leave]
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
|
||||
|
@ -542,6 +583,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
|
||||
|
||||
class RoomRedactEventRestServlet(ClientV1RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(RoomRedactEventRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
@ -624,6 +669,10 @@ class SearchRestServlet(ClientV1RestServlet):
|
|||
"/search$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(SearchRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
|
99
synapse/rest/client/v2_alpha/notifications.py
Normal file
99
synapse/rest/client/v2_alpha/notifications.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import (
|
||||
RestServlet, parse_string, parse_integer
|
||||
)
|
||||
from synapse.events.utils import (
|
||||
serialize_event, format_event_for_client_v2_without_room_id,
|
||||
)
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotificationsServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/notifications$", releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
super(NotificationsServlet, self).__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
from_token = parse_string(request, "from", required=False)
|
||||
limit = parse_integer(request, "limit", default=50)
|
||||
|
||||
limit = min(limit, 500)
|
||||
|
||||
push_actions = yield self.store.get_push_actions_for_user(
|
||||
user_id, from_token, limit
|
||||
)
|
||||
|
||||
receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
|
||||
user_id, 'm.read'
|
||||
)
|
||||
|
||||
notif_event_ids = [pa["event_id"] for pa in push_actions]
|
||||
notif_events = yield self.store.get_events(notif_event_ids)
|
||||
|
||||
returned_push_actions = []
|
||||
|
||||
next_token = None
|
||||
|
||||
for pa in push_actions:
|
||||
returned_pa = {
|
||||
"room_id": pa["room_id"],
|
||||
"profile_tag": pa["profile_tag"],
|
||||
"actions": pa["actions"],
|
||||
"ts": pa["received_ts"],
|
||||
"event": serialize_event(
|
||||
notif_events[pa["event_id"]],
|
||||
self.clock.time_msec(),
|
||||
event_format=format_event_for_client_v2_without_room_id,
|
||||
),
|
||||
}
|
||||
|
||||
if pa["room_id"] not in receipts_by_room:
|
||||
returned_pa["read"] = False
|
||||
else:
|
||||
receipt = receipts_by_room[pa["room_id"]]
|
||||
|
||||
returned_pa["read"] = (
|
||||
receipt["topological_ordering"], receipt["stream_ordering"]
|
||||
) >= (
|
||||
pa["topological_ordering"], pa["stream_ordering"]
|
||||
)
|
||||
returned_push_actions.append(returned_pa)
|
||||
next_token = pa["stream_ordering"]
|
||||
|
||||
defer.returnValue((200, {
|
||||
"notifications": returned_push_actions,
|
||||
"next_token": next_token,
|
||||
}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
NotificationsServlet(hs).register(http_server)
|
|
@ -403,10 +403,9 @@ class RegisterRestServlet(RestServlet):
|
|||
# register the user's device
|
||||
device_id = params.get("device_id")
|
||||
initial_display_name = params.get("initial_device_display_name")
|
||||
device_id = self.device_handler.check_device_registered(
|
||||
return self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
return device_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_guest_registration(self):
|
||||
|
|
|
@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet):
|
|||
affect_presence = set_presence != PresenceState.OFFLINE
|
||||
|
||||
if affect_presence:
|
||||
yield self.presence_handler.set_state(user, {"presence": set_presence})
|
||||
yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
|
||||
|
||||
context = yield self.presence_handler.user_syncing(
|
||||
user.to_string(), affect_presence=affect_presence,
|
||||
|
|
78
synapse/rest/client/v2_alpha/thirdparty.py
Normal file
78
synapse/rest/client/v2_alpha/thirdparty.py
Normal file
|
@ -0,0 +1,78 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.types import ThirdPartyEntityKind
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThirdPartyUserServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
|
||||
releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThirdPartyUserServlet, self).__init__()
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, protocol):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
fields = request.args
|
||||
del fields["access_token"]
|
||||
|
||||
results = yield self.appservice_handler.query_3pe(
|
||||
ThirdPartyEntityKind.USER, protocol, fields
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
||||
|
||||
class ThirdPartyLocationServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
|
||||
releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThirdPartyLocationServlet, self).__init__()
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, protocol):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
fields = request.args
|
||||
del fields["access_token"]
|
||||
|
||||
results = yield self.appservice_handler.query_3pe(
|
||||
ThirdPartyEntityKind.LOCATION, protocol, fields
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ThirdPartyUserServlet(hs).register(http_server)
|
||||
ThirdPartyLocationServlet(hs).register(http_server)
|
|
@ -15,6 +15,7 @@
|
|||
from synapse.http.server import request_handler, respond_with_json_bytes
|
||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.crypto.keyring import KeyLookupError
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
@ -210,9 +211,10 @@ class RemoteKey(Resource):
|
|||
yield self.keyring.get_server_verify_key_v2_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
except KeyLookupError as e:
|
||||
logger.info("Failed to fetch key: %s", e)
|
||||
except:
|
||||
logger.exception("Failed to get key for %r", server_name)
|
||||
pass
|
||||
yield self.query_keys(
|
||||
request, query, query_remote_on_cache_miss=False
|
||||
)
|
||||
|
|
|
@ -45,6 +45,7 @@ class DownloadResource(Resource):
|
|||
@request_handler()
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
request.setHeader("Content-Security-Policy", "sandbox")
|
||||
server_name, media_id, name = parse_media_id(request)
|
||||
if server_name == self.server_name:
|
||||
yield self._respond_local_file(request, media_id, name)
|
||||
|
|
|
@ -29,14 +29,13 @@ from synapse.http.server import (
|
|||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.stringutils import is_ascii
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import os
|
||||
import re
|
||||
import fnmatch
|
||||
import cgi
|
||||
import ujson as json
|
||||
import urlparse
|
||||
import itertools
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -163,7 +162,7 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
logger.debug("got media_info of '%s'" % media_info)
|
||||
|
||||
if self._is_media(media_info['media_type']):
|
||||
if _is_media(media_info['media_type']):
|
||||
dims = yield self.media_repo._generate_local_thumbnails(
|
||||
media_info['filesystem_id'], media_info
|
||||
)
|
||||
|
@ -184,11 +183,9 @@ class PreviewUrlResource(Resource):
|
|||
logger.warn("Couldn't get dims for %s" % url)
|
||||
|
||||
# define our OG response for this media
|
||||
elif self._is_html(media_info['media_type']):
|
||||
elif _is_html(media_info['media_type']):
|
||||
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
|
||||
|
||||
from lxml import etree
|
||||
|
||||
file = open(media_info['filename'])
|
||||
body = file.read()
|
||||
file.close()
|
||||
|
@ -199,17 +196,35 @@ class PreviewUrlResource(Resource):
|
|||
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
|
||||
encoding = match.group(1) if match else "utf-8"
|
||||
|
||||
try:
|
||||
parser = etree.HTMLParser(recover=True, encoding=encoding)
|
||||
tree = etree.fromstring(body, parser)
|
||||
og = yield self._calc_og(tree, media_info, requester)
|
||||
except UnicodeDecodeError:
|
||||
# blindly try decoding the body as utf-8, which seems to fix
|
||||
# the charset mismatches on https://google.com
|
||||
parser = etree.HTMLParser(recover=True, encoding=encoding)
|
||||
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
|
||||
og = yield self._calc_og(tree, media_info, requester)
|
||||
og = decode_and_calc_og(body, media_info['uri'], encoding)
|
||||
|
||||
# pre-cache the image for posterity
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
if 'og:image' in og and og['og:image']:
|
||||
image_info = yield self._download_url(
|
||||
_rebase_url(og['og:image'], media_info['uri']), requester.user
|
||||
)
|
||||
|
||||
if _is_media(image_info['media_type']):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
dims = yield self.media_repo._generate_local_thumbnails(
|
||||
image_info['filesystem_id'], image_info
|
||||
)
|
||||
if dims:
|
||||
og["og:image:width"] = dims['width']
|
||||
og["og:image:height"] = dims['height']
|
||||
else:
|
||||
logger.warn("Couldn't get dims for %s" % og["og:image"])
|
||||
|
||||
og["og:image"] = "mxc://%s/%s" % (
|
||||
self.server_name, image_info['filesystem_id']
|
||||
)
|
||||
og["og:image:type"] = image_info['media_type']
|
||||
og["matrix:image:size"] = image_info['media_length']
|
||||
else:
|
||||
del og["og:image"]
|
||||
else:
|
||||
logger.warn("Failed to find any OG data in %s", url)
|
||||
og = {}
|
||||
|
@ -232,139 +247,6 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _calc_og(self, tree, media_info, requester):
|
||||
# suck our tree into lxml and define our OG response.
|
||||
|
||||
# if we see any image URLs in the OG response, then spider them
|
||||
# (although the client could choose to do this by asking for previews of those
|
||||
# URLs to avoid DoSing the server)
|
||||
|
||||
# "og:type" : "video",
|
||||
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
|
||||
# "og:site_name" : "YouTube",
|
||||
# "og:video:type" : "application/x-shockwave-flash",
|
||||
# "og:description" : "Fun stuff happening here",
|
||||
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
|
||||
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
|
||||
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
|
||||
# "og:video:width" : "1280"
|
||||
# "og:video:height" : "720",
|
||||
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
||||
|
||||
og = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
if 'content' in tag.attrib:
|
||||
og[tag.attrib['property']] = tag.attrib['content']
|
||||
|
||||
# TODO: grab article: meta tags too, e.g.:
|
||||
|
||||
# "article:publisher" : "https://www.facebook.com/thethudonline" />
|
||||
# "article:author" content="https://www.facebook.com/thethudonline" />
|
||||
# "article:tag" content="baby" />
|
||||
# "article:section" content="Breaking News" />
|
||||
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
|
||||
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
|
||||
|
||||
if 'og:title' not in og:
|
||||
# do some basic spidering of the HTML
|
||||
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
||||
og['og:title'] = title[0].text.strip() if title else None
|
||||
|
||||
if 'og:image' not in og:
|
||||
# TODO: extract a favicon failing all else
|
||||
meta_image = tree.xpath(
|
||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
||||
)
|
||||
if meta_image:
|
||||
og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
|
||||
else:
|
||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = sorted(images, key=lambda i: (
|
||||
-1 * float(i.attrib['width']) * float(i.attrib['height'])
|
||||
))
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src]")
|
||||
if images:
|
||||
og['og:image'] = images[0].attrib['src']
|
||||
|
||||
# pre-cache the image for posterity
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
if 'og:image' in og and og['og:image']:
|
||||
image_info = yield self._download_url(
|
||||
self._rebase_url(og['og:image'], media_info['uri']), requester.user
|
||||
)
|
||||
|
||||
if self._is_media(image_info['media_type']):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
dims = yield self.media_repo._generate_local_thumbnails(
|
||||
image_info['filesystem_id'], image_info
|
||||
)
|
||||
if dims:
|
||||
og["og:image:width"] = dims['width']
|
||||
og["og:image:height"] = dims['height']
|
||||
else:
|
||||
logger.warn("Couldn't get dims for %s" % og["og:image"])
|
||||
|
||||
og["og:image"] = "mxc://%s/%s" % (
|
||||
self.server_name, image_info['filesystem_id']
|
||||
)
|
||||
og["og:image:type"] = image_info['media_type']
|
||||
og["matrix:image:size"] = image_info['media_length']
|
||||
else:
|
||||
del og["og:image"]
|
||||
|
||||
if 'og:description' not in og:
|
||||
meta_description = tree.xpath(
|
||||
"//*/meta"
|
||||
"[translate(@name, 'DESCRIPTION', 'description')='description']"
|
||||
"/@content")
|
||||
if meta_description:
|
||||
og['og:description'] = meta_description[0]
|
||||
else:
|
||||
# grab any text nodes which are inside the <body/> tag...
|
||||
# unless they are within an HTML5 semantic markup tag...
|
||||
# <header/>, <nav/>, <aside/>, <footer/>
|
||||
# ...or if they are within a <script/> or <style/> tag.
|
||||
# This is a very very very coarse approximation to a plain text
|
||||
# render of the page.
|
||||
|
||||
# We don't just use XPATH here as that is slow on some machines.
|
||||
|
||||
# We clone `tree` as we modify it.
|
||||
cloned_tree = deepcopy(tree.find("body"))
|
||||
|
||||
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
|
||||
for el in cloned_tree.iter(TAGS_TO_REMOVE):
|
||||
el.getparent().remove(el)
|
||||
|
||||
# Split all the text nodes into paragraphs (by splitting on new
|
||||
# lines)
|
||||
text_nodes = (
|
||||
re.sub(r'\s+', '\n', el.text).strip()
|
||||
for el in cloned_tree.iter()
|
||||
if el.text and isinstance(el.tag, basestring) # Removes comments
|
||||
)
|
||||
og['og:description'] = summarize_paragraphs(text_nodes)
|
||||
|
||||
# TODO: delete the url downloads to stop diskfilling,
|
||||
# as we only ever cared about its OG
|
||||
defer.returnValue(og)
|
||||
|
||||
def _rebase_url(self, url, base):
|
||||
base = list(urlparse.urlparse(base))
|
||||
url = list(urlparse.urlparse(url))
|
||||
if not url[0]: # fix up schema
|
||||
url[0] = base[0] or "http"
|
||||
if not url[1]: # fix up hostname
|
||||
url[1] = base[1]
|
||||
if not url[2].startswith('/'):
|
||||
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
|
||||
return urlparse.urlunparse(url)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _download_url(self, url, user):
|
||||
# TODO: we should probably honour robots.txt... except in practice
|
||||
|
@ -445,11 +327,165 @@ class PreviewUrlResource(Resource):
|
|||
"etag": headers["ETag"][0] if "ETag" in headers else None,
|
||||
})
|
||||
|
||||
def _is_media(self, content_type):
|
||||
|
||||
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
||||
from lxml import etree
|
||||
|
||||
try:
|
||||
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
||||
tree = etree.fromstring(body, parser)
|
||||
og = _calc_og(tree, media_uri)
|
||||
except UnicodeDecodeError:
|
||||
# blindly try decoding the body as utf-8, which seems to fix
|
||||
# the charset mismatches on https://google.com
|
||||
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
||||
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
|
||||
og = _calc_og(tree, media_uri)
|
||||
|
||||
return og
|
||||
|
||||
|
||||
def _calc_og(tree, media_uri):
|
||||
# suck our tree into lxml and define our OG response.
|
||||
|
||||
# if we see any image URLs in the OG response, then spider them
|
||||
# (although the client could choose to do this by asking for previews of those
|
||||
# URLs to avoid DoSing the server)
|
||||
|
||||
# "og:type" : "video",
|
||||
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
|
||||
# "og:site_name" : "YouTube",
|
||||
# "og:video:type" : "application/x-shockwave-flash",
|
||||
# "og:description" : "Fun stuff happening here",
|
||||
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
|
||||
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
|
||||
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
|
||||
# "og:video:width" : "1280"
|
||||
# "og:video:height" : "720",
|
||||
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
||||
|
||||
og = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
if 'content' in tag.attrib:
|
||||
og[tag.attrib['property']] = tag.attrib['content']
|
||||
|
||||
# TODO: grab article: meta tags too, e.g.:
|
||||
|
||||
# "article:publisher" : "https://www.facebook.com/thethudonline" />
|
||||
# "article:author" content="https://www.facebook.com/thethudonline" />
|
||||
# "article:tag" content="baby" />
|
||||
# "article:section" content="Breaking News" />
|
||||
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
|
||||
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
|
||||
|
||||
if 'og:title' not in og:
|
||||
# do some basic spidering of the HTML
|
||||
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
||||
og['og:title'] = title[0].text.strip() if title else None
|
||||
|
||||
if 'og:image' not in og:
|
||||
# TODO: extract a favicon failing all else
|
||||
meta_image = tree.xpath(
|
||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
||||
)
|
||||
if meta_image:
|
||||
og['og:image'] = _rebase_url(meta_image[0], media_uri)
|
||||
else:
|
||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = sorted(images, key=lambda i: (
|
||||
-1 * float(i.attrib['width']) * float(i.attrib['height'])
|
||||
))
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src]")
|
||||
if images:
|
||||
og['og:image'] = images[0].attrib['src']
|
||||
|
||||
if 'og:description' not in og:
|
||||
meta_description = tree.xpath(
|
||||
"//*/meta"
|
||||
"[translate(@name, 'DESCRIPTION', 'description')='description']"
|
||||
"/@content")
|
||||
if meta_description:
|
||||
og['og:description'] = meta_description[0]
|
||||
else:
|
||||
# grab any text nodes which are inside the <body/> tag...
|
||||
# unless they are within an HTML5 semantic markup tag...
|
||||
# <header/>, <nav/>, <aside/>, <footer/>
|
||||
# ...or if they are within a <script/> or <style/> tag.
|
||||
# This is a very very very coarse approximation to a plain text
|
||||
# render of the page.
|
||||
|
||||
# We don't just use XPATH here as that is slow on some machines.
|
||||
|
||||
from lxml import etree
|
||||
|
||||
TAGS_TO_REMOVE = (
|
||||
"header", "nav", "aside", "footer", "script", "style", etree.Comment
|
||||
)
|
||||
|
||||
# Split all the text nodes into paragraphs (by splitting on new
|
||||
# lines)
|
||||
text_nodes = (
|
||||
re.sub(r'\s+', '\n', el).strip()
|
||||
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
|
||||
)
|
||||
og['og:description'] = summarize_paragraphs(text_nodes)
|
||||
|
||||
# TODO: delete the url downloads to stop diskfilling,
|
||||
# as we only ever cared about its OG
|
||||
return og
|
||||
|
||||
|
||||
def _iterate_over_text(tree, *tags_to_ignore):
|
||||
"""Iterate over the tree returning text nodes in a depth first fashion,
|
||||
skipping text nodes inside certain tags.
|
||||
"""
|
||||
# This is basically a stack that we extend using itertools.chain.
|
||||
# This will either consist of an element to iterate over *or* a string
|
||||
# to be returned.
|
||||
elements = iter([tree])
|
||||
while True:
|
||||
el = elements.next()
|
||||
if isinstance(el, basestring):
|
||||
yield el
|
||||
elif el is not None and el.tag not in tags_to_ignore:
|
||||
# el.text is the text before the first child, so we can immediately
|
||||
# return it if the text exists.
|
||||
if el.text:
|
||||
yield el.text
|
||||
|
||||
# We add to the stack all the elements children, interspersed with
|
||||
# each child's tail text (if it exists). The tail text of a node
|
||||
# is text that comes *after* the node, so we always include it even
|
||||
# if we ignore the child node.
|
||||
elements = itertools.chain(
|
||||
itertools.chain.from_iterable( # Basically a flatmap
|
||||
[child, child.tail] if child.tail else [child]
|
||||
for child in el.iterchildren()
|
||||
),
|
||||
elements
|
||||
)
|
||||
|
||||
|
||||
def _rebase_url(url, base):
|
||||
base = list(urlparse.urlparse(base))
|
||||
url = list(urlparse.urlparse(url))
|
||||
if not url[0]: # fix up schema
|
||||
url[0] = base[0] or "http"
|
||||
if not url[1]: # fix up hostname
|
||||
url[1] = base[1]
|
||||
if not url[2].startswith('/'):
|
||||
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
|
||||
return urlparse.urlunparse(url)
|
||||
|
||||
|
||||
def _is_media(content_type):
|
||||
if content_type.lower().startswith("image/"):
|
||||
return True
|
||||
|
||||
def _is_html(self, content_type):
|
||||
|
||||
def _is_html(content_type):
|
||||
content_type = content_type.lower()
|
||||
if (
|
||||
content_type.startswith("text/html") or
|
||||
|
|
|
@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler
|
|||
from synapse.handlers.room import RoomListHandler
|
||||
from synapse.handlers.sync import SyncHandler
|
||||
from synapse.handlers.typing import TypingHandler
|
||||
from synapse.handlers.events import EventHandler, EventStreamHandler
|
||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.notifier import Notifier
|
||||
|
@ -94,6 +95,8 @@ class HomeServer(object):
|
|||
'auth_handler',
|
||||
'device_handler',
|
||||
'e2e_keys_handler',
|
||||
'event_handler',
|
||||
'event_stream_handler',
|
||||
'application_service_api',
|
||||
'application_service_scheduler',
|
||||
'application_service_handler',
|
||||
|
@ -214,6 +217,12 @@ class HomeServer(object):
|
|||
def build_application_service_handler(self):
|
||||
return ApplicationServicesHandler(self)
|
||||
|
||||
def build_event_handler(self):
|
||||
return EventHandler(self)
|
||||
|
||||
def build_event_stream_handler(self):
|
||||
return EventStreamHandler(self)
|
||||
|
||||
def build_event_sources(self):
|
||||
return EventSources(self)
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import synapse.api.auth
|
||||
import synapse.handlers
|
||||
import synapse.handlers.auth
|
||||
import synapse.handlers.device
|
||||
|
@ -6,6 +7,9 @@ import synapse.storage
|
|||
import synapse.state
|
||||
|
||||
class HomeServer(object):
|
||||
def get_auth(self) -> synapse.api.auth.Auth:
|
||||
pass
|
||||
|
||||
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
|
||||
pass
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ from .openid import OpenIdStore
|
|||
from .client_ips import ClientIpStore
|
||||
|
||||
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
|
||||
from .engines import PostgresEngine
|
||||
|
||||
from synapse.api.constants import PresenceState
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
extra_tables=[("deleted_pushers", "stream_id")],
|
||||
)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
self._cache_id_gen = StreamIdGenerator(
|
||||
db_conn, "cache_invalidation_stream", "stream_id",
|
||||
)
|
||||
else:
|
||||
self._cache_id_gen = None
|
||||
|
||||
events_max = self._stream_id_gen.get_current_token()
|
||||
event_cache_prefill, min_event_val = self._get_cache_dict(
|
||||
db_conn, "events",
|
||||
|
|
|
@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
|||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||
from synapse.util.caches.descriptors import Cache
|
||||
from synapse.util.caches import intern_dict
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
import synapse.metrics
|
||||
|
||||
|
||||
|
@ -165,7 +166,7 @@ class SQLBaseStore(object):
|
|||
self._txn_perf_counters = PerformanceCounters()
|
||||
self._get_event_counters = PerformanceCounters()
|
||||
|
||||
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
|
||||
self._get_event_cache = Cache("*getEvent*", keylen=3,
|
||||
max_entries=hs.config.event_cache_size)
|
||||
|
||||
self._state_group_cache = DictionaryCache(
|
||||
|
@ -305,11 +306,12 @@ class SQLBaseStore(object):
|
|||
func, *args, **kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
result = yield self._db_pool.runWithConnection(
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
|
||||
finally:
|
||||
for after_callback, after_args in after_callbacks:
|
||||
after_callback(*after_args)
|
||||
defer.returnValue(result)
|
||||
|
@ -860,6 +862,62 @@ class SQLBaseStore(object):
|
|||
|
||||
return cache, min_val
|
||||
|
||||
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||
will know to invalidate their caches.
|
||||
|
||||
This should only be used to invalidate caches where slaves won't
|
||||
otherwise know from other replication streams that the cache should
|
||||
be invalidated.
|
||||
"""
|
||||
txn.call_after(cache_func.invalidate, keys)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# get_next() returns a context manager which is designed to wrap
|
||||
# the transaction. However, we want to only get an ID when we want
|
||||
# to use it, here, so we need to call __enter__ manually, and have
|
||||
# __exit__ called after the transaction finishes.
|
||||
ctx = self._cache_id_gen.get_next()
|
||||
stream_id = ctx.__enter__()
|
||||
txn.call_after(ctx.__exit__, None, None, None)
|
||||
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="cache_invalidation_stream",
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"cache_func": cache_func.__name__,
|
||||
"keys": list(keys),
|
||||
"invalidation_ts": self.clock.time_msec(),
|
||||
}
|
||||
)
|
||||
|
||||
def get_all_updated_caches(self, last_id, current_id, limit):
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_updated_caches_txn(txn):
|
||||
# We purposefully don't bound by the current token, as we want to
|
||||
# send across cache invalidations as quickly as possible. Cache
|
||||
# invalidations are idempotent, so duplicates are fine.
|
||||
sql = (
|
||||
"SELECT stream_id, cache_func, keys, invalidation_ts"
|
||||
" FROM cache_invalidation_stream"
|
||||
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, limit,))
|
||||
return txn.fetchall()
|
||||
return self.runInteraction(
|
||||
"get_all_updated_caches", get_all_updated_caches_txn
|
||||
)
|
||||
|
||||
def get_cache_stream_token(self):
|
||||
if self._cache_id_gen:
|
||||
return self._cache_id_gen.get_current_token()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
class _RollbackButIsFineException(Exception):
|
||||
""" This exception is used to rollback a transaction without implying
|
||||
|
|
|
@ -218,13 +218,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||
Returns:
|
||||
AppServiceTransaction: A new transaction.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"create_appservice_txn",
|
||||
self._create_appservice_txn,
|
||||
service, events
|
||||
)
|
||||
|
||||
def _create_appservice_txn(self, txn, service, events):
|
||||
def _create_appservice_txn(txn):
|
||||
# work out new txn id (highest txn id for this service += 1)
|
||||
# The highest id may be the last one sent (in which case it is last_txn)
|
||||
# or it may be the highest in the txns list (which are waiting to be/are
|
||||
|
@ -252,6 +246,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||
service=service, id=new_txn_id, events=events
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"create_appservice_txn",
|
||||
_create_appservice_txn,
|
||||
)
|
||||
|
||||
def complete_appservice_txn(self, txn_id, service):
|
||||
"""Completes an application service transaction.
|
||||
|
||||
|
@ -263,15 +262,9 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||
A Deferred which resolves if this transaction was stored
|
||||
successfully.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"complete_appservice_txn",
|
||||
self._complete_appservice_txn,
|
||||
txn_id, service
|
||||
)
|
||||
|
||||
def _complete_appservice_txn(self, txn, txn_id, service):
|
||||
txn_id = int(txn_id)
|
||||
|
||||
def _complete_appservice_txn(txn):
|
||||
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
|
||||
# what was there before. If it isn't, we've got problems (e.g. the AS
|
||||
# has probably missed some events), so whine loudly but still continue,
|
||||
|
@ -298,6 +291,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||
dict(txn_id=txn_id, as_id=service.id)
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"complete_appservice_txn",
|
||||
_complete_appservice_txn,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_oldest_unsent_txn(self, service):
|
||||
"""Get the oldest transaction which has not been sent for this
|
||||
|
@ -309,24 +307,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||
A Deferred which resolves to an AppServiceTransaction or
|
||||
None.
|
||||
"""
|
||||
entry = yield self.runInteraction(
|
||||
"get_oldest_unsent_appservice_txn",
|
||||
self._get_oldest_unsent_txn,
|
||||
service
|
||||
)
|
||||
|
||||
if not entry:
|
||||
defer.returnValue(None)
|
||||
|
||||
event_ids = json.loads(entry["event_ids"])
|
||||
|
||||
events = yield self._get_events(event_ids)
|
||||
|
||||
defer.returnValue(AppServiceTransaction(
|
||||
service=service, id=entry["txn_id"], events=events
|
||||
))
|
||||
|
||||
def _get_oldest_unsent_txn(self, txn, service):
|
||||
def _get_oldest_unsent_txn(txn):
|
||||
# Monotonically increasing txn ids, so just select the smallest
|
||||
# one in the txns table (we delete them when they are sent)
|
||||
txn.execute(
|
||||
|
@ -342,6 +323,22 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||
|
||||
return entry
|
||||
|
||||
entry = yield self.runInteraction(
|
||||
"get_oldest_unsent_appservice_txn",
|
||||
_get_oldest_unsent_txn,
|
||||
)
|
||||
|
||||
if not entry:
|
||||
defer.returnValue(None)
|
||||
|
||||
event_ids = json.loads(entry["event_ids"])
|
||||
|
||||
events = yield self._get_events(event_ids)
|
||||
|
||||
defer.returnValue(AppServiceTransaction(
|
||||
service=service, id=entry["txn_id"], events=events
|
||||
))
|
||||
|
||||
def _get_last_txn(self, txn, service_id):
|
||||
txn.execute(
|
||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||
|
@ -352,3 +349,45 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||
return 0
|
||||
else:
|
||||
return int(last_txn_id[0]) # select 'last_txn' col
|
||||
|
||||
def set_appservice_last_pos(self, pos):
|
||||
def set_appservice_last_pos_txn(txn):
|
||||
txn.execute(
|
||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||
)
|
||||
return self.runInteraction(
|
||||
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_new_events_for_appservice(self, current_id, limit):
|
||||
"""Get all new evnets"""
|
||||
|
||||
def get_new_events_for_appservice_txn(txn):
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id"
|
||||
" FROM events AS e"
|
||||
" WHERE"
|
||||
" (SELECT stream_ordering FROM appservice_stream_position)"
|
||||
" < e.stream_ordering"
|
||||
" AND e.stream_ordering <= ?"
|
||||
" ORDER BY e.stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (current_id, limit))
|
||||
rows = txn.fetchall()
|
||||
|
||||
upper_bound = current_id
|
||||
if len(rows) == limit:
|
||||
upper_bound = rows[-1][0]
|
||||
|
||||
return upper_bound, [row[1] for row in rows]
|
||||
|
||||
upper_bound, event_ids = yield self.runInteraction(
|
||||
"get_new_events_for_appservice", get_new_events_for_appservice_txn,
|
||||
)
|
||||
|
||||
events = yield self._get_events(event_ids)
|
||||
|
||||
defer.returnValue((upper_bound, events))
|
||||
|
|
|
@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore):
|
|||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
try:
|
||||
yield self._simple_insert(
|
||||
def alias_txn(txn):
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"room_aliases",
|
||||
{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"room_id": room_id,
|
||||
"creator": creator,
|
||||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="room_alias_servers",
|
||||
values=[{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"server": server,
|
||||
} for server in servers],
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_aliases_for_room, (room_id,)
|
||||
)
|
||||
|
||||
try:
|
||||
ret = yield self.runInteraction(
|
||||
"create_room_alias_association", alias_txn
|
||||
)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
raise SynapseError(
|
||||
409, "Room alias %s already exists" % room_alias.to_string()
|
||||
)
|
||||
|
||||
for server in servers:
|
||||
# TODO(erikj): Fix this to bulk insert
|
||||
yield self._simple_insert(
|
||||
"room_alias_servers",
|
||||
{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"server": server,
|
||||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
self.get_aliases_for_room.invalidate((room_id,))
|
||||
defer.returnValue(ret)
|
||||
|
||||
def get_room_alias_creator(self, room_alias):
|
||||
return self._simple_select_one_onecol(
|
||||
|
|
|
@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore):
|
|||
)
|
||||
self._simple_insert_many_txn(txn, "event_push_actions", values)
|
||||
|
||||
@cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
|
||||
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
|
||||
def get_unread_event_push_actions_by_room_for_user(
|
||||
self, room_id, user_id, last_read_event_id
|
||||
):
|
||||
|
@ -337,6 +337,36 @@ class EventPushActionsStore(SQLBaseStore):
|
|||
# Now return the first `limit`
|
||||
defer.returnValue(notifs[:limit])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_push_actions_for_user(self, user_id, before=None, limit=50):
|
||||
def f(txn):
|
||||
before_clause = ""
|
||||
if before:
|
||||
before_clause = "AND stream_ordering < ?"
|
||||
args = [user_id, before, limit]
|
||||
else:
|
||||
args = [user_id, limit]
|
||||
sql = (
|
||||
"SELECT epa.event_id, epa.room_id,"
|
||||
" epa.stream_ordering, epa.topological_ordering,"
|
||||
" epa.actions, epa.profile_tag, e.received_ts"
|
||||
" FROM event_push_actions epa, events e"
|
||||
" WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id"
|
||||
" AND epa.user_id = ? %s"
|
||||
" ORDER BY epa.stream_ordering DESC"
|
||||
" LIMIT ?"
|
||||
% (before_clause,)
|
||||
)
|
||||
txn.execute(sql, args)
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
push_actions = yield self.runInteraction(
|
||||
"get_push_actions_for_user", f
|
||||
)
|
||||
for pa in push_actions:
|
||||
pa["actions"] = json.loads(pa["actions"])
|
||||
defer.returnValue(push_actions)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_time_of_last_push_action_before(self, stream_ordering):
|
||||
def f(txn):
|
||||
|
|
|
@ -20,8 +20,11 @@ from synapse.events import FrozenEvent, USE_FROZEN_DICTS
|
|||
from synapse.events.utils import prune_event
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
|
||||
from synapse.util.logcontext import (
|
||||
preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
|
||||
)
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
|
@ -201,7 +204,7 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
deferreds = []
|
||||
for room_id, evs_ctxs in partitioned.items():
|
||||
d = self._event_persist_queue.add_to_queue(
|
||||
d = preserve_fn(self._event_persist_queue.add_to_queue)(
|
||||
room_id, evs_ctxs,
|
||||
backfilled=backfilled,
|
||||
current_state=None,
|
||||
|
@ -211,7 +214,9 @@ class EventsStore(SQLBaseStore):
|
|||
for room_id in partitioned.keys():
|
||||
self._maybe_start_persisting(room_id)
|
||||
|
||||
return defer.gatherResults(deferreds, consumeErrors=True)
|
||||
return preserve_context_over_deferred(
|
||||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -224,7 +229,7 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
self._maybe_start_persisting(event.room_id)
|
||||
|
||||
yield deferred
|
||||
yield preserve_context_over_deferred(deferred)
|
||||
|
||||
max_persisted_id = yield self._stream_id_gen.get_current_token()
|
||||
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
|
||||
|
@ -600,7 +605,8 @@ class EventsStore(SQLBaseStore):
|
|||
"rejections",
|
||||
"redactions",
|
||||
"room_memberships",
|
||||
"state_events"
|
||||
"state_events",
|
||||
"topics"
|
||||
):
|
||||
txn.executemany(
|
||||
"DELETE FROM %s WHERE event_id = ?" % (table,),
|
||||
|
@ -1086,7 +1092,7 @@ class EventsStore(SQLBaseStore):
|
|||
if not allow_rejected:
|
||||
rows[:] = [r for r in rows if not r["rejects"]]
|
||||
|
||||
res = yield defer.gatherResults(
|
||||
res = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self._get_event_from_row)(
|
||||
row["internal_metadata"], row["json"], row["redacts"],
|
||||
|
@ -1095,7 +1101,7 @@ class EventsStore(SQLBaseStore):
|
|||
for row in rows
|
||||
],
|
||||
consumeErrors=True
|
||||
)
|
||||
))
|
||||
|
||||
defer.returnValue({
|
||||
e.event.event_id: e
|
||||
|
@ -1131,6 +1137,7 @@ class EventsStore(SQLBaseStore):
|
|||
@defer.inlineCallbacks
|
||||
def _get_event_from_row(self, internal_metadata, js, redacted,
|
||||
rejected_reason=None):
|
||||
with Measure(self._clock, "_get_event_from_row"):
|
||||
d = json.loads(js)
|
||||
internal_metadata = json.loads(internal_metadata)
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 33
|
||||
SCHEMA_VERSION = 34
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
|
|
@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore):
|
|||
desc="add_presence_list_pending",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
||||
result = yield self._simple_update_one(
|
||||
def update_presence_list_txn(txn):
|
||||
result = self._simple_update_one_txn(
|
||||
txn,
|
||||
table="presence_list",
|
||||
keyvalues={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid},
|
||||
keyvalues={
|
||||
"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid
|
||||
},
|
||||
updatevalues={"accepted": True},
|
||||
desc="set_presence_list_accepted",
|
||||
)
|
||||
self.get_presence_list_accepted.invalidate((observer_localpart,))
|
||||
self.get_presence_list_observers_accepted.invalidate((observed_userid,))
|
||||
defer.returnValue(result)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_presence_list_accepted, (observer_localpart,)
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_presence_list_observers_accepted, (observed_userid,)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
return self.runInteraction(
|
||||
"set_presence_list_accepted", update_presence_list_txn,
|
||||
)
|
||||
|
||||
def get_presence_list(self, observer_localpart, accepted=None):
|
||||
if accepted:
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
from ._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
||||
from synapse.push.baserules import list_with_base_rules
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
@ -48,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
|
|||
|
||||
|
||||
class PushRuleStore(SQLBaseStore):
|
||||
@cachedInlineCallbacks(lru=True)
|
||||
@cachedInlineCallbacks()
|
||||
def get_push_rules_for_user(self, user_id):
|
||||
rows = yield self._simple_select_list(
|
||||
table="push_rules",
|
||||
|
@ -72,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue(rules)
|
||||
|
||||
@cachedInlineCallbacks(lru=True)
|
||||
@cachedInlineCallbacks()
|
||||
def get_push_rules_enabled_for_user(self, user_id):
|
||||
results = yield self._simple_select_list(
|
||||
table="push_rules_enable",
|
||||
|
@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue(results)
|
||||
|
||||
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
# of None don't hit previous cached calls with a None state_group.
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
|
||||
cache_context):
|
||||
# We don't use `state_group`, its there so that we can cache based
|
||||
# on it. However, its important that its never None, since two current_state's
|
||||
# with a state_group of None are likely to be different.
|
||||
# See bulk_get_push_rules_for_room for how we work around this.
|
||||
assert state_group is not None
|
||||
|
||||
# We also will want to generate notifs for other people in the room so
|
||||
# their unread countss are correct in the event stream, but to avoid
|
||||
# generating them for bot / AS users etc, we only do so for people who've
|
||||
# sent a read receipt into the room.
|
||||
local_users_in_room = set(
|
||||
e.state_key for e in current_state.values()
|
||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
||||
and self.hs.is_mine_id(e.state_key)
|
||||
)
|
||||
|
||||
# users in the room who have pushers need to get push rules run because
|
||||
# that's how their pushers work
|
||||
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
||||
local_users_in_room, on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
user_ids = set(
|
||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
||||
)
|
||||
|
||||
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
|
||||
room_id, on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
|
||||
# any users with pushers must be ours: they have pushers
|
||||
for uid in users_with_receipts:
|
||||
if uid in local_users_in_room:
|
||||
user_ids.add(uid)
|
||||
|
||||
rules_by_user = yield self.bulk_get_push_rules(
|
||||
user_ids, on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
|
||||
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
|
||||
|
||||
defer.returnValue(rules_by_user)
|
||||
|
||||
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
|
||||
list_name="user_ids", num_args=1, inlineCallbacks=True)
|
||||
def bulk_get_push_rules_enabled(self, user_ids):
|
||||
|
|
|
@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
|
|||
"get_all_updated_pushers", get_all_updated_pushers_txn
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
|
||||
@cachedInlineCallbacks(num_args=1, max_entries=15000)
|
||||
def get_if_user_has_pusher(self, user_id):
|
||||
result = yield self._simple_select_many_batch(
|
||||
table='pushers',
|
||||
|
|
|
@ -94,6 +94,31 @@ class ReceiptsStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT rl.room_id, rl.event_id,"
|
||||
" e.topological_ordering, e.stream_ordering"
|
||||
" FROM receipts_linearized AS rl"
|
||||
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||
" WHERE rl.room_id = e.room_id"
|
||||
" AND rl.event_id = e.event_id"
|
||||
" AND user_id = ?"
|
||||
)
|
||||
txn.execute(sql, (user_id,))
|
||||
return txn.fetchall()
|
||||
rows = yield self.runInteraction(
|
||||
"get_receipts_for_user_with_orderings", f
|
||||
)
|
||||
defer.returnValue({
|
||||
row[0]: {
|
||||
"event_id": row[1],
|
||||
"topological_ordering": row[2],
|
||||
"stream_ordering": row[3],
|
||||
} for row in rows
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
"""Get receipts for multiple rooms for sending to clients.
|
||||
|
@ -120,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue([ev for res in results.values() for ev in res])
|
||||
|
||||
@cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
|
||||
@cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
|
||||
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||
"""Get receipts for a single room for sending to clients.
|
||||
|
||||
|
|
|
@ -93,7 +93,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
desc="add_refresh_token_to_user",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register(self, user_id, token=None, password_hash=None,
|
||||
was_guest=False, make_guest=False, appservice_id=None,
|
||||
create_profile_with_localpart=None, admin=False):
|
||||
|
@ -115,7 +114,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
Raises:
|
||||
StoreError if the user_id could not be registered.
|
||||
"""
|
||||
yield self.runInteraction(
|
||||
return self.runInteraction(
|
||||
"register",
|
||||
self._register,
|
||||
user_id,
|
||||
|
@ -127,8 +126,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
create_profile_with_localpart,
|
||||
admin
|
||||
)
|
||||
self.get_user_by_id.invalidate((user_id,))
|
||||
self.is_guest.invalidate((user_id,))
|
||||
|
||||
def _register(
|
||||
self,
|
||||
|
@ -210,6 +207,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
(create_profile_with_localpart,)
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_id, (user_id,)
|
||||
)
|
||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||
|
||||
@cached()
|
||||
def get_user_by_id(self, user_id):
|
||||
return self._simple_select_one(
|
||||
|
@ -236,22 +238,31 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
|
||||
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_set_password_hash(self, user_id, password_hash):
|
||||
"""
|
||||
NB. This does *not* evict any cache because the one use for this
|
||||
removes most of the entries subsequently anyway so it would be
|
||||
pointless. Use flush_user separately.
|
||||
"""
|
||||
yield self._simple_update_one('users', {
|
||||
def user_set_password_hash_txn(txn):
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
'users', {
|
||||
'name': user_id
|
||||
}, {
|
||||
},
|
||||
{
|
||||
'password_hash': password_hash
|
||||
})
|
||||
self.get_user_by_id.invalidate((user_id,))
|
||||
}
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_id, (user_id,)
|
||||
)
|
||||
return self.runInteraction(
|
||||
"user_set_password_hash", user_set_password_hash_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_delete_access_tokens(self, user_id, except_token_ids=[],
|
||||
def user_delete_access_tokens(self, user_id, except_token_id=None,
|
||||
device_id=None,
|
||||
delete_refresh_tokens=False):
|
||||
"""
|
||||
|
@ -259,7 +270,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
|
||||
Args:
|
||||
user_id (str): ID of user the tokens belong to
|
||||
except_token_ids (list[str]): list of access_tokens which should
|
||||
except_token_id (str): list of access_tokens IDs which should
|
||||
*not* be deleted
|
||||
device_id (str|None): ID of device the tokens are associated with.
|
||||
If None, tokens associated with any device (or no device) will
|
||||
|
@ -269,53 +280,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
def f(txn, table, except_tokens, call_after_delete):
|
||||
sql = "SELECT token FROM %s WHERE user_id = ?" % table
|
||||
clauses = [user_id]
|
||||
|
||||
def f(txn):
|
||||
keyvalues = {
|
||||
"user_id": user_id,
|
||||
}
|
||||
if device_id is not None:
|
||||
sql += " AND device_id = ?"
|
||||
clauses.append(device_id)
|
||||
keyvalues["device_id"] = device_id
|
||||
|
||||
if except_tokens:
|
||||
sql += " AND id NOT IN (%s)" % (
|
||||
",".join(["?" for _ in except_tokens]),
|
||||
if delete_refresh_tokens:
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="refresh_tokens",
|
||||
keyvalues=keyvalues,
|
||||
)
|
||||
clauses += except_tokens
|
||||
|
||||
txn.execute(sql, clauses)
|
||||
|
||||
rows = txn.fetchall()
|
||||
|
||||
n = 100
|
||||
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
|
||||
for chunk in chunks:
|
||||
if call_after_delete:
|
||||
for row in chunk:
|
||||
txn.call_after(call_after_delete, (row[0],))
|
||||
items = keyvalues.items()
|
||||
where_clause = " AND ".join(k + " = ?" for k, _ in items)
|
||||
values = [v for _, v in items]
|
||||
if except_token_id:
|
||||
where_clause += " AND id != ?"
|
||||
values.append(except_token_id)
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM %s WHERE token in (%s)" % (
|
||||
table,
|
||||
",".join(["?" for _ in chunk]),
|
||||
), [r[0] for r in chunk]
|
||||
"SELECT token FROM access_tokens WHERE %s" % where_clause,
|
||||
values
|
||||
)
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
for row in rows:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_access_token, (row["token"],)
|
||||
)
|
||||
|
||||
# delete refresh tokens first, to stop new access tokens being
|
||||
# allocated while our backs are turned
|
||||
if delete_refresh_tokens:
|
||||
yield self.runInteraction(
|
||||
"user_delete_access_tokens", f,
|
||||
table="refresh_tokens",
|
||||
except_tokens=[],
|
||||
call_after_delete=None,
|
||||
txn.execute(
|
||||
"DELETE FROM access_tokens WHERE %s" % where_clause,
|
||||
values
|
||||
)
|
||||
|
||||
yield self.runInteraction(
|
||||
"user_delete_access_tokens", f,
|
||||
table="access_tokens",
|
||||
except_tokens=except_token_ids,
|
||||
call_after_delete=self.get_user_by_access_token.invalidate,
|
||||
)
|
||||
|
||||
def delete_access_token(self, access_token):
|
||||
|
@ -328,7 +331,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
},
|
||||
)
|
||||
|
||||
txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_access_token, (access_token,)
|
||||
)
|
||||
|
||||
return self.runInteraction("delete_access_token", f)
|
||||
|
||||
|
|
|
@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore):
|
|||
user_id, membership_list=[Membership.JOIN],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def forget(self, user_id, room_id):
|
||||
"""Indicate that user_id wishes to discard history for room_id."""
|
||||
def f(txn):
|
||||
|
@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore):
|
|||
" room_id = ?"
|
||||
)
|
||||
txn.execute(sql, (user_id, room_id))
|
||||
yield self.runInteraction("forget_membership", f)
|
||||
self.was_forgotten_at.invalidate_all()
|
||||
self.who_forgot_in_room.invalidate_all()
|
||||
self.did_forget.invalidate((user_id, room_id))
|
||||
|
||||
txn.call_after(self.was_forgotten_at.invalidate_all)
|
||||
txn.call_after(self.did_forget.invalidate, (user_id, room_id))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.who_forgot_in_room, (room_id,)
|
||||
)
|
||||
return self.runInteraction("forget_membership", f)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2)
|
||||
def did_forget(self, user_id, room_id):
|
||||
|
|
23
synapse/storage/schema/delta/34/appservice_stream.sql
Normal file
23
synapse/storage/schema/delta/34/appservice_stream.sql
Normal file
|
@ -0,0 +1,23 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS appservice_stream_position(
|
||||
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
|
||||
stream_ordering BIGINT,
|
||||
CHECK (Lock='X')
|
||||
);
|
||||
|
||||
INSERT INTO appservice_stream_position (stream_ordering)
|
||||
SELECT COALESCE(MAX(stream_ordering), 0) FROM events;
|
46
synapse/storage/schema/delta/34/cache_stream.py
Normal file
46
synapse/storage/schema/delta/34/cache_stream.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.prepare_database import get_statements
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# This stream is used to notify replication slaves that some caches have
|
||||
# been invalidated that they cannot infer from the other streams.
|
||||
CREATE_TABLE = """
|
||||
CREATE TABLE cache_invalidation_stream (
|
||||
stream_id BIGINT,
|
||||
cache_func TEXT,
|
||||
keys TEXT[],
|
||||
invalidation_ts BIGINT
|
||||
);
|
||||
|
||||
CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
|
||||
"""
|
||||
|
||||
|
||||
def run_create(cur, database_engine, *args, **kwargs):
|
||||
if not isinstance(database_engine, PostgresEngine):
|
||||
return
|
||||
|
||||
for statement in get_statements(CREATE_TABLE.splitlines()):
|
||||
cur.execute(statement)
|
||||
|
||||
|
||||
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||
pass
|
20
synapse/storage/schema/delta/34/push_display_name_rename.sql
Normal file
20
synapse/storage/schema/delta/34/push_display_name_rename.sql
Normal file
|
@ -0,0 +1,20 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name';
|
||||
UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
|
||||
|
||||
DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name';
|
||||
UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
|
32
synapse/storage/schema/delta/34/received_txn_purge.py
Normal file
32
synapse/storage/schema/delta/34/received_txn_purge.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_create(cur, database_engine, *args, **kwargs):
|
||||
if isinstance(database_engine, PostgresEngine):
|
||||
cur.execute("TRUNCATE received_transactions")
|
||||
else:
|
||||
cur.execute("DELETE FROM received_transactions")
|
||||
|
||||
cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)")
|
||||
|
||||
|
||||
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||
pass
|
|
@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList
|
|||
class SignatureStore(SQLBaseStore):
|
||||
"""Persistence for event signatures and hashes"""
|
||||
|
||||
@cached(lru=True)
|
||||
@cached()
|
||||
def get_event_reference_hash(self, event_id):
|
||||
return self._get_event_reference_hashes_txn(event_id)
|
||||
|
||||
|
|
|
@ -174,7 +174,7 @@ class StateStore(SQLBaseStore):
|
|||
return [r[0] for r in results]
|
||||
return self.runInteraction("get_current_state_for_key", f)
|
||||
|
||||
@cached(num_args=2, lru=True, max_entries=1000)
|
||||
@cached(num_args=2, max_entries=1000)
|
||||
def _get_state_group_from_group(self, group, types):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -272,7 +272,7 @@ class StateStore(SQLBaseStore):
|
|||
state_map = yield self.get_state_for_events([event_id], types)
|
||||
defer.returnValue(state_map[event_id])
|
||||
|
||||
@cached(num_args=2, lru=True, max_entries=10000)
|
||||
@cached(num_args=2, max_entries=10000)
|
||||
def _get_state_group_for_event(self, room_id, event_id):
|
||||
return self._simple_select_one_onecol(
|
||||
table="event_to_state_groups",
|
||||
|
|
|
@ -39,7 +39,7 @@ from ._base import SQLBaseStore
|
|||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
|
||||
import logging
|
||||
|
@ -234,12 +234,12 @@ class StreamStore(SQLBaseStore):
|
|||
results = {}
|
||||
room_ids = list(room_ids)
|
||||
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
|
||||
res = yield defer.gatherResults([
|
||||
res = yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(self.get_room_events_stream_for_room)(
|
||||
room_id, from_key, to_key, limit, order=order,
|
||||
)
|
||||
for room_id in rm_ids
|
||||
])
|
||||
]))
|
||||
results.update(dict(zip(rm_ids, res)))
|
||||
|
||||
defer.returnValue(results)
|
||||
|
|
|
@ -62,10 +62,9 @@ class TransactionStore(SQLBaseStore):
|
|||
self.last_transaction = {}
|
||||
|
||||
reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns)
|
||||
hs.get_clock().looping_call(
|
||||
self._persist_in_mem_txns,
|
||||
1000,
|
||||
)
|
||||
self._clock.looping_call(self._persist_in_mem_txns, 1000)
|
||||
|
||||
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
|
||||
|
||||
def get_received_txn_response(self, transaction_id, origin):
|
||||
"""For an incoming transaction from a given origin, check if we have
|
||||
|
@ -127,6 +126,7 @@ class TransactionStore(SQLBaseStore):
|
|||
"origin": origin,
|
||||
"response_code": code,
|
||||
"response_json": buffer(encode_canonical_json(response_dict)),
|
||||
"ts": self._clock.time_msec(),
|
||||
},
|
||||
or_ignore=True,
|
||||
desc="set_received_txn_response",
|
||||
|
@ -383,3 +383,12 @@ class TransactionStore(SQLBaseStore):
|
|||
yield self.runInteraction("_persist_in_mem_txns", f)
|
||||
except:
|
||||
logger.exception("Failed to persist transactions!")
|
||||
|
||||
def _cleanup_transactions(self):
|
||||
now = self._clock.time_msec()
|
||||
month_ago = now - 30 * 24 * 60 * 60 * 1000
|
||||
|
||||
def _cleanup_transactions_txn(txn):
|
||||
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
|
||||
|
||||
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)
|
||||
|
|
|
@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||
return "t%d-%d" % (self.topological, self.stream)
|
||||
else:
|
||||
return "s%d" % (self.stream,)
|
||||
|
||||
|
||||
# Some arbitrary constants used for internal API enumerations. Don't rely on
|
||||
# exact values; always pass or compare symbolically
|
||||
class ThirdPartyEntityKind(object):
|
||||
USER = 'user'
|
||||
LOCATION = 'location'
|
||||
|
|
|
@ -146,10 +146,10 @@ def concurrently_execute(func, args, limit):
|
|||
except StopIteration:
|
||||
pass
|
||||
|
||||
return defer.gatherResults([
|
||||
return preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(_concurrently_execute_inner)()
|
||||
for _ in xrange(limit)
|
||||
], consumeErrors=True).addErrback(unwrapFirstError)
|
||||
], consumeErrors=True)).addErrback(unwrapFirstError)
|
||||
|
||||
|
||||
class Linearizer(object):
|
||||
|
@ -181,7 +181,8 @@ class Linearizer(object):
|
|||
self.key_to_defer[key] = new_defer
|
||||
|
||||
if current_defer:
|
||||
yield preserve_context_over_deferred(current_defer)
|
||||
with PreserveLoggingContext():
|
||||
yield current_defer
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
|
@ -264,7 +265,7 @@ class ReadWriteLock(object):
|
|||
curr_readers.clear()
|
||||
self.key_to_current_writer[key] = new_defer
|
||||
|
||||
yield defer.gatherResults(to_wait_on)
|
||||
yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
|
|
|
@ -25,8 +25,7 @@ from synapse.util.logcontext import (
|
|||
from . import DEBUG_CACHES, register_cache
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections import namedtuple
|
||||
|
||||
import os
|
||||
import functools
|
||||
|
@ -54,16 +53,11 @@ class Cache(object):
|
|||
"metrics",
|
||||
)
|
||||
|
||||
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
|
||||
if lru:
|
||||
def __init__(self, name, max_entries=1000, keylen=1, tree=False):
|
||||
cache_type = TreeCache if tree else dict
|
||||
self.cache = LruCache(
|
||||
max_size=max_entries, keylen=keylen, cache_type=cache_type
|
||||
)
|
||||
self.max_entries = None
|
||||
else:
|
||||
self.cache = OrderedDict()
|
||||
self.max_entries = max_entries
|
||||
|
||||
self.name = name
|
||||
self.keylen = keylen
|
||||
|
@ -81,8 +75,8 @@ class Cache(object):
|
|||
"Cache objects can only be accessed from the main thread"
|
||||
)
|
||||
|
||||
def get(self, key, default=_CacheSentinel):
|
||||
val = self.cache.get(key, _CacheSentinel)
|
||||
def get(self, key, default=_CacheSentinel, callback=None):
|
||||
val = self.cache.get(key, _CacheSentinel, callback=callback)
|
||||
if val is not _CacheSentinel:
|
||||
self.metrics.inc_hits()
|
||||
return val
|
||||
|
@ -94,19 +88,15 @@ class Cache(object):
|
|||
else:
|
||||
return default
|
||||
|
||||
def update(self, sequence, key, value):
|
||||
def update(self, sequence, key, value, callback=None):
|
||||
self.check_thread()
|
||||
if self.sequence == sequence:
|
||||
# Only update the cache if the caches sequence number matches the
|
||||
# number that the cache had before the SELECT was started (SYN-369)
|
||||
self.prefill(key, value)
|
||||
self.prefill(key, value, callback=callback)
|
||||
|
||||
def prefill(self, key, value):
|
||||
if self.max_entries is not None:
|
||||
while len(self.cache) >= self.max_entries:
|
||||
self.cache.popitem(last=False)
|
||||
|
||||
self.cache[key] = value
|
||||
def prefill(self, key, value, callback=None):
|
||||
self.cache.set(key, value, callback=callback)
|
||||
|
||||
def invalidate(self, key):
|
||||
self.check_thread()
|
||||
|
@ -151,9 +141,21 @@ class CacheDescriptor(object):
|
|||
The wrapped function has another additional callable, called "prefill",
|
||||
which can be used to insert values into the cache specifically, without
|
||||
calling the calculation function.
|
||||
|
||||
Cached functions can be "chained" (i.e. a cached function can call other cached
|
||||
functions and get appropriately invalidated when they called caches are
|
||||
invalidated) by adding a special "cache_context" argument to the function
|
||||
and passing that as a kwarg to all caches called. For example::
|
||||
|
||||
@cachedInlineCallbacks(cache_context=True)
|
||||
def foo(self, key, cache_context):
|
||||
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
|
||||
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
|
||||
defer.returnValue(r1 + r2)
|
||||
|
||||
"""
|
||||
def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
|
||||
inlineCallbacks=False):
|
||||
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
|
||||
inlineCallbacks=False, cache_context=False):
|
||||
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
|
||||
|
||||
self.orig = orig
|
||||
|
@ -165,15 +167,33 @@ class CacheDescriptor(object):
|
|||
|
||||
self.max_entries = max_entries
|
||||
self.num_args = num_args
|
||||
self.lru = lru
|
||||
self.tree = tree
|
||||
|
||||
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
|
||||
all_args = inspect.getargspec(orig)
|
||||
self.arg_names = all_args.args[1:num_args + 1]
|
||||
|
||||
if "cache_context" in all_args.args:
|
||||
if not cache_context:
|
||||
raise ValueError(
|
||||
"Cannot have a 'cache_context' arg without setting"
|
||||
" cache_context=True"
|
||||
)
|
||||
try:
|
||||
self.arg_names.remove("cache_context")
|
||||
except ValueError:
|
||||
pass
|
||||
elif cache_context:
|
||||
raise ValueError(
|
||||
"Cannot have cache_context=True without having an arg"
|
||||
" named `cache_context`"
|
||||
)
|
||||
|
||||
self.add_cache_context = cache_context
|
||||
|
||||
if len(self.arg_names) < self.num_args:
|
||||
raise Exception(
|
||||
"Not enough explicit positional arguments to key off of for %r."
|
||||
" (@cached cannot key off of *args or **kwars)"
|
||||
" (@cached cannot key off of *args or **kwargs)"
|
||||
% (orig.__name__,)
|
||||
)
|
||||
|
||||
|
@ -182,16 +202,29 @@ class CacheDescriptor(object):
|
|||
name=self.orig.__name__,
|
||||
max_entries=self.max_entries,
|
||||
keylen=self.num_args,
|
||||
lru=self.lru,
|
||||
tree=self.tree,
|
||||
)
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
def wrapped(*args, **kwargs):
|
||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||
# whenever we are invalidated
|
||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||
|
||||
# Add temp cache_context so inspect.getcallargs doesn't explode
|
||||
if self.add_cache_context:
|
||||
kwargs["cache_context"] = None
|
||||
|
||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
|
||||
|
||||
# Add our own `cache_context` to argument list if the wrapped function
|
||||
# has asked for one
|
||||
if self.add_cache_context:
|
||||
kwargs["cache_context"] = _CacheContext(cache, cache_key)
|
||||
|
||||
try:
|
||||
cached_result_d = cache.get(cache_key)
|
||||
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
|
||||
|
||||
observer = cached_result_d.observe()
|
||||
if DEBUG_CACHES:
|
||||
|
@ -228,7 +261,7 @@ class CacheDescriptor(object):
|
|||
ret.addErrback(onErr)
|
||||
|
||||
ret = ObservableDeferred(ret, consumeErrors=True)
|
||||
cache.update(sequence, cache_key, ret)
|
||||
cache.update(sequence, cache_key, ret, callback=invalidate_callback)
|
||||
|
||||
return preserve_context_over_deferred(ret.observe())
|
||||
|
||||
|
@ -297,6 +330,10 @@ class CacheListDescriptor(object):
|
|||
|
||||
@functools.wraps(self.orig)
|
||||
def wrapped(*args, **kwargs):
|
||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||
# whenever we are invalidated
|
||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||
|
||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
||||
list_args = arg_dict[self.list_name]
|
||||
|
@ -311,7 +348,7 @@ class CacheListDescriptor(object):
|
|||
key[self.list_pos] = arg
|
||||
|
||||
try:
|
||||
res = cache.get(tuple(key))
|
||||
res = cache.get(tuple(key), callback=invalidate_callback)
|
||||
if not res.has_succeeded():
|
||||
res = res.observe()
|
||||
res.addCallback(lambda r, arg: (arg, r), arg)
|
||||
|
@ -345,7 +382,10 @@ class CacheListDescriptor(object):
|
|||
|
||||
key = list(keyargs)
|
||||
key[self.list_pos] = arg
|
||||
cache.update(sequence, tuple(key), observer)
|
||||
cache.update(
|
||||
sequence, tuple(key), observer,
|
||||
callback=invalidate_callback
|
||||
)
|
||||
|
||||
def invalidate(f, key):
|
||||
cache.invalidate(key)
|
||||
|
@ -376,24 +416,29 @@ class CacheListDescriptor(object):
|
|||
return wrapped
|
||||
|
||||
|
||||
def cached(max_entries=1000, num_args=1, lru=True, tree=False):
|
||||
class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
|
||||
def invalidate(self):
|
||||
self.cache.invalidate(self.key)
|
||||
|
||||
|
||||
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
|
||||
return lambda orig: CacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
num_args=num_args,
|
||||
lru=lru,
|
||||
tree=tree,
|
||||
cache_context=cache_context,
|
||||
)
|
||||
|
||||
|
||||
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
|
||||
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
|
||||
return lambda orig: CacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
num_args=num_args,
|
||||
lru=lru,
|
||||
tree=tree,
|
||||
inlineCallbacks=True,
|
||||
cache_context=cache_context,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
|
|||
|
||||
|
||||
class _Node(object):
|
||||
__slots__ = ["prev_node", "next_node", "key", "value"]
|
||||
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
|
||||
|
||||
def __init__(self, prev_node, next_node, key, value):
|
||||
def __init__(self, prev_node, next_node, key, value, callbacks=set()):
|
||||
self.prev_node = prev_node
|
||||
self.next_node = next_node
|
||||
self.key = key
|
||||
self.value = value
|
||||
self.callbacks = callbacks
|
||||
|
||||
|
||||
class LruCache(object):
|
||||
|
@ -44,6 +45,9 @@ class LruCache(object):
|
|||
Least-recently-used cache.
|
||||
Supports del_multi only if cache_type=TreeCache
|
||||
If cache_type=TreeCache, all keys must be tuples.
|
||||
|
||||
Can also set callbacks on objects when getting/setting which are fired
|
||||
when that key gets invalidated/evicted.
|
||||
"""
|
||||
def __init__(self, max_size, keylen=1, cache_type=dict):
|
||||
cache = cache_type()
|
||||
|
@ -62,10 +66,10 @@ class LruCache(object):
|
|||
|
||||
return inner
|
||||
|
||||
def add_node(key, value):
|
||||
def add_node(key, value, callbacks=set()):
|
||||
prev_node = list_root
|
||||
next_node = prev_node.next_node
|
||||
node = _Node(prev_node, next_node, key, value)
|
||||
node = _Node(prev_node, next_node, key, value, callbacks)
|
||||
prev_node.next_node = node
|
||||
next_node.prev_node = node
|
||||
cache[key] = node
|
||||
|
@ -88,23 +92,41 @@ class LruCache(object):
|
|||
prev_node.next_node = next_node
|
||||
next_node.prev_node = prev_node
|
||||
|
||||
for cb in node.callbacks:
|
||||
cb()
|
||||
node.callbacks.clear()
|
||||
|
||||
@synchronized
|
||||
def cache_get(key, default=None):
|
||||
def cache_get(key, default=None, callback=None):
|
||||
node = cache.get(key, None)
|
||||
if node is not None:
|
||||
move_node_to_front(node)
|
||||
if callback:
|
||||
node.callbacks.add(callback)
|
||||
return node.value
|
||||
else:
|
||||
return default
|
||||
|
||||
@synchronized
|
||||
def cache_set(key, value):
|
||||
def cache_set(key, value, callback=None):
|
||||
node = cache.get(key, None)
|
||||
if node is not None:
|
||||
if value != node.value:
|
||||
for cb in node.callbacks:
|
||||
cb()
|
||||
node.callbacks.clear()
|
||||
|
||||
if callback:
|
||||
node.callbacks.add(callback)
|
||||
|
||||
move_node_to_front(node)
|
||||
node.value = value
|
||||
else:
|
||||
add_node(key, value)
|
||||
if callback:
|
||||
callbacks = set([callback])
|
||||
else:
|
||||
callbacks = set()
|
||||
add_node(key, value, callbacks)
|
||||
if len(cache) > max_size:
|
||||
todelete = list_root.prev_node
|
||||
delete_node(todelete)
|
||||
|
@ -148,6 +170,9 @@ class LruCache(object):
|
|||
def cache_clear():
|
||||
list_root.next_node = list_root
|
||||
list_root.prev_node = list_root
|
||||
for node in cache.values():
|
||||
for cb in node.callbacks:
|
||||
cb()
|
||||
cache.clear()
|
||||
|
||||
@synchronized
|
||||
|
|
|
@ -64,6 +64,9 @@ class TreeCache(object):
|
|||
self.size -= cnt
|
||||
return popped
|
||||
|
||||
def values(self):
|
||||
return [e.value for e in self.root.values()]
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
|
|
|
@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs):
|
|||
return res
|
||||
|
||||
|
||||
def preserve_context_over_deferred(deferred):
|
||||
def preserve_context_over_deferred(deferred, context=None):
|
||||
"""Given a deferred wrap it such that any callbacks added later to it will
|
||||
be invoked with the current context.
|
||||
"""
|
||||
current_context = LoggingContext.current_context()
|
||||
d = _PreservingContextDeferred(current_context)
|
||||
if context is None:
|
||||
context = LoggingContext.current_context()
|
||||
d = _PreservingContextDeferred(context)
|
||||
deferred.chainDeferred(d)
|
||||
return d
|
||||
|
||||
|
@ -316,8 +317,13 @@ def preserve_fn(f):
|
|||
|
||||
def g(*args, **kwargs):
|
||||
with PreserveLoggingContext(current):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
res = f(*args, **kwargs)
|
||||
if isinstance(res, defer.Deferred):
|
||||
return preserve_context_over_deferred(
|
||||
res, context=LoggingContext.sentinel
|
||||
)
|
||||
else:
|
||||
return res
|
||||
return g
|
||||
|
||||
|
||||
|
|
|
@ -13,10 +13,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
import synapse.metrics
|
||||
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
|
||||
|
@ -47,6 +49,18 @@ block_db_txn_duration = metrics.register_distribution(
|
|||
)
|
||||
|
||||
|
||||
def measure_func(name):
|
||||
def wrapper(func):
|
||||
@wraps(func)
|
||||
@defer.inlineCallbacks
|
||||
def measured_func(self, *args, **kwargs):
|
||||
with Measure(self.clock, name):
|
||||
r = yield func(self, *args, **kwargs)
|
||||
defer.returnValue(r)
|
||||
return measured_func
|
||||
return wrapper
|
||||
|
||||
|
||||
class Measure(object):
|
||||
__slots__ = [
|
||||
"clock", "name", "start_context", "start", "new_context", "ru_utime",
|
||||
|
@ -64,7 +78,6 @@ class Measure(object):
|
|||
self.start = self.clock.time_msec()
|
||||
self.start_context = LoggingContext.current_context()
|
||||
if not self.start_context:
|
||||
logger.warn("Entered Measure without log context: %s", self.name)
|
||||
self.start_context = LoggingContext("Measure")
|
||||
self.start_context.__enter__()
|
||||
self.created_context = True
|
||||
|
@ -74,7 +87,7 @@ class Measure(object):
|
|||
self.db_txn_duration = self.start_context.db_txn_duration
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is not None or not self.start_context:
|
||||
if isinstance(exc_type, Exception) or not self.start_context:
|
||||
return
|
||||
|
||||
duration = self.clock.time_msec() - self.start
|
||||
|
@ -85,7 +98,7 @@ class Measure(object):
|
|||
if context != self.start_context:
|
||||
logger.warn(
|
||||
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
|
||||
context, self.start_context, self.name
|
||||
self.start_context, context, self.name
|
||||
)
|
||||
return
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -55,12 +55,12 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
|
|||
given events
|
||||
events ([synapse.events.EventBase]): list of events to filter
|
||||
"""
|
||||
forgotten = yield defer.gatherResults([
|
||||
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(store.who_forgot_in_room)(
|
||||
room_id,
|
||||
)
|
||||
for room_id in frozenset(e.room_id for e in events)
|
||||
], consumeErrors=True)
|
||||
], consumeErrors=True))
|
||||
|
||||
# Set of membership event_ids that have been forgotten
|
||||
event_id_forgotten = frozenset(
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# limitations under the License.
|
||||
from synapse.appservice import ApplicationService
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from mock import Mock
|
||||
from tests import unittest
|
||||
|
||||
|
@ -42,20 +44,25 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
|
||||
)
|
||||
|
||||
self.store = Mock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_user_id_prefix_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.assertTrue(self.service.is_interested(self.event))
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_user_id_prefix_no_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@someone_else:matrix.org"
|
||||
self.assertFalse(self.service.is_interested(self.event))
|
||||
self.assertFalse((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_room_member_is_checked(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
|
@ -63,30 +70,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
self.event.sender = "@someone_else:matrix.org"
|
||||
self.event.type = "m.room.member"
|
||||
self.event.state_key = "@irc_foobar:matrix.org"
|
||||
self.assertTrue(self.service.is_interested(self.event))
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_room_id_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||
_regex("!some_prefix.*some_suffix:matrix.org")
|
||||
)
|
||||
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
|
||||
self.assertTrue(self.service.is_interested(self.event))
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_room_id_no_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||
_regex("!some_prefix.*some_suffix:matrix.org")
|
||||
)
|
||||
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
|
||||
self.assertFalse(self.service.is_interested(self.event))
|
||||
self.assertFalse((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_alias_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#irc_.*:matrix.org")
|
||||
)
|
||||
self.assertTrue(self.service.is_interested(
|
||||
self.event,
|
||||
aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"]
|
||||
))
|
||||
self.store.get_aliases_for_room.return_value = [
|
||||
"#irc_foobar:matrix.org", "#athing:matrix.org"
|
||||
]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertTrue((yield self.service.is_interested(
|
||||
self.event, self.store
|
||||
)))
|
||||
|
||||
def test_non_exclusive_alias(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
|
@ -136,15 +149,20 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
"!irc_foobar:matrix.org"
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_alias_no_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#irc_.*:matrix.org")
|
||||
)
|
||||
self.assertFalse(self.service.is_interested(
|
||||
self.event,
|
||||
aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
|
||||
))
|
||||
self.store.get_aliases_for_room.return_value = [
|
||||
"#xmpp_foobar:matrix.org", "#athing:matrix.org"
|
||||
]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertFalse((yield self.service.is_interested(
|
||||
self.event, self.store
|
||||
)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_multiple_matches(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#irc_.*:matrix.org")
|
||||
|
@ -153,53 +171,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.assertTrue(self.service.is_interested(
|
||||
self.event,
|
||||
aliases_for_event=["#irc_barfoo:matrix.org"]
|
||||
))
|
||||
|
||||
def test_restrict_to_rooms(self):
|
||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||
_regex("!flibble_.*:matrix.org")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.event.room_id = "!wibblewoo:matrix.org"
|
||||
self.assertFalse(self.service.is_interested(
|
||||
self.event,
|
||||
restrict_to=ApplicationService.NS_ROOMS
|
||||
))
|
||||
|
||||
def test_restrict_to_aliases(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#xmpp_.*:matrix.org")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.assertFalse(self.service.is_interested(
|
||||
self.event,
|
||||
restrict_to=ApplicationService.NS_ALIASES,
|
||||
aliases_for_event=["#irc_barfoo:matrix.org"]
|
||||
))
|
||||
|
||||
def test_restrict_to_senders(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#xmpp_.*:matrix.org")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
||||
self.assertFalse(self.service.is_interested(
|
||||
self.event,
|
||||
restrict_to=ApplicationService.NS_USERS,
|
||||
aliases_for_event=["#xmpp_barfoo:matrix.org"]
|
||||
))
|
||||
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertTrue((yield self.service.is_interested(
|
||||
self.event, self.store
|
||||
)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_interested_in_self(self):
|
||||
# make sure invites get through
|
||||
self.service.sender = "@appservice:name"
|
||||
|
@ -211,20 +189,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
"membership": "invite"
|
||||
}
|
||||
self.event.state_key = self.service.sender
|
||||
self.assertTrue(self.service.is_interested(self.event))
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_member_list_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
join_list = [
|
||||
self.store.get_users_in_room.return_value = [
|
||||
"@alice:here",
|
||||
"@irc_fo:here", # AS user
|
||||
"@bob:here",
|
||||
]
|
||||
self.store.get_aliases_for_room.return_value = []
|
||||
|
||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
||||
self.assertTrue(self.service.is_interested(
|
||||
event=self.event,
|
||||
member_list=join_list
|
||||
))
|
||||
self.assertTrue((yield self.service.is_interested(
|
||||
event=self.event, store=self.store
|
||||
)))
|
||||
|
|
|
@ -193,7 +193,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
self.txn_ctrl = Mock()
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl)
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
|
||||
|
||||
def test_send_single_event_no_queue(self):
|
||||
# Expect the event to be sent immediately.
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
from .. import unittest
|
||||
from tests.utils import MockClock
|
||||
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
|
||||
|
@ -32,6 +33,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
hs.get_datastore = Mock(return_value=self.mock_store)
|
||||
hs.get_application_service_api = Mock(return_value=self.mock_as_api)
|
||||
hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
|
||||
hs.get_clock.return_value = MockClock()
|
||||
self.handler = ApplicationServicesHandler(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -51,8 +53,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
type="m.room.message",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||
self.mock_as_api.push = Mock()
|
||||
yield self.handler.notify_interested_services(event)
|
||||
yield self.handler.notify_interested_services(0)
|
||||
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
|
||||
interested_service, event
|
||||
)
|
||||
|
@ -72,7 +75,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
)
|
||||
self.mock_as_api.push = Mock()
|
||||
self.mock_as_api.query_user = Mock()
|
||||
yield self.handler.notify_interested_services(event)
|
||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||
yield self.handler.notify_interested_services(0)
|
||||
self.mock_as_api.query_user.assert_called_once_with(
|
||||
services[0], user_id
|
||||
)
|
||||
|
@ -94,7 +98,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
)
|
||||
self.mock_as_api.push = Mock()
|
||||
self.mock_as_api.query_user = Mock()
|
||||
yield self.handler.notify_interested_services(event)
|
||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||
yield self.handler.notify_interested_services(0)
|
||||
self.assertFalse(
|
||||
self.mock_as_api.query_user.called,
|
||||
"query_user called when it shouldn't have been."
|
||||
|
@ -108,11 +113,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
|
||||
room_id = "!alpha:bet"
|
||||
servers = ["aperture"]
|
||||
interested_service = self._mkservice(is_interested=True)
|
||||
interested_service = self._mkservice_alias(is_interested_in_alias=True)
|
||||
services = [
|
||||
self._mkservice(is_interested=False),
|
||||
self._mkservice_alias(is_interested_in_alias=False),
|
||||
interested_service,
|
||||
self._mkservice(is_interested=False)
|
||||
self._mkservice_alias(is_interested_in_alias=False)
|
||||
]
|
||||
|
||||
self.mock_store.get_app_services = Mock(return_value=services)
|
||||
|
@ -135,3 +140,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
service.token = "mock_service_token"
|
||||
service.url = "mock_service_url"
|
||||
return service
|
||||
|
||||
def _mkservice_alias(self, is_interested_in_alias):
|
||||
service = Mock()
|
||||
service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
|
||||
service.token = "mock_service_token"
|
||||
service.url = "mock_service_url"
|
||||
return service
|
||||
|
|
|
@ -14,11 +14,13 @@
|
|||
# limitations under the License.
|
||||
|
||||
import pymacaroons
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse
|
||||
import synapse.api.errors
|
||||
from synapse.handlers.auth import AuthHandler
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
from twisted.internet import defer
|
||||
|
||||
|
||||
class AuthHandlers(object):
|
||||
|
@ -31,11 +33,12 @@ class AuthTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.hs = yield setup_test_homeserver(handlers=None)
|
||||
self.hs.handlers = AuthHandlers(self.hs)
|
||||
self.auth_handler = self.hs.handlers.auth_handler
|
||||
|
||||
def test_token_is_a_macaroon(self):
|
||||
self.hs.config.macaroon_secret_key = "this key is a huge secret"
|
||||
|
||||
token = self.hs.handlers.auth_handler.generate_access_token("some_user")
|
||||
token = self.auth_handler.generate_access_token("some_user")
|
||||
# Check that we can parse the thing with pymacaroons
|
||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||
# The most basic of sanity checks
|
||||
|
@ -46,7 +49,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.hs.config.macaroon_secret_key = "this key is a massive secret"
|
||||
self.hs.clock.now = 5000
|
||||
|
||||
token = self.hs.handlers.auth_handler.generate_access_token("a_user")
|
||||
token = self.auth_handler.generate_access_token("a_user")
|
||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||
|
||||
def verify_gen(caveat):
|
||||
|
@ -67,3 +70,46 @@ class AuthTestCase(unittest.TestCase):
|
|||
v.satisfy_general(verify_type)
|
||||
v.satisfy_general(verify_expiry)
|
||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||
|
||||
def test_short_term_login_token_gives_user_id(self):
|
||||
self.hs.clock.now = 1000
|
||||
|
||||
token = self.auth_handler.generate_short_term_login_token(
|
||||
"a_user", 5000
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
"a_user",
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
token
|
||||
)
|
||||
)
|
||||
|
||||
# when we advance the clock, the token should be rejected
|
||||
self.hs.clock.now = 6000
|
||||
with self.assertRaises(synapse.api.errors.AuthError):
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
token
|
||||
)
|
||||
|
||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||
token = self.auth_handler.generate_short_term_login_token(
|
||||
"a_user", 5000
|
||||
)
|
||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||
|
||||
self.assertEqual(
|
||||
"a_user",
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
macaroon.serialize()
|
||||
)
|
||||
)
|
||||
|
||||
# add another "user_id" caveat, which might allow us to override the
|
||||
# user_id.
|
||||
macaroon.add_first_party_caveat("user_id = b_user")
|
||||
|
||||
with self.assertRaises(synapse.api.errors.AuthError):
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
macaroon.serialize()
|
||||
)
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
|
||||
from synapse.util.caches.descriptors import Cache, cached
|
||||
|
@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
|
|||
cache.get(3)
|
||||
|
||||
def test_eviction_lru(self):
|
||||
cache = Cache("test", max_entries=2, lru=True)
|
||||
cache = Cache("test", max_entries=2)
|
||||
|
||||
cache.prefill(1, "one")
|
||||
cache.prefill(2, "two")
|
||||
|
@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
|
|||
|
||||
self.assertEquals(a.func("foo").result, d.result)
|
||||
self.assertEquals(callcount[0], 0)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_invalidate_context(self):
|
||||
callcount = [0]
|
||||
callcount2 = [0]
|
||||
|
||||
class A(object):
|
||||
@cached()
|
||||
def func(self, key):
|
||||
callcount[0] += 1
|
||||
return key
|
||||
|
||||
@cached(cache_context=True)
|
||||
def func2(self, key, cache_context):
|
||||
callcount2[0] += 1
|
||||
return self.func(key, on_invalidate=cache_context.invalidate)
|
||||
|
||||
a = A()
|
||||
yield a.func2("foo")
|
||||
|
||||
self.assertEquals(callcount[0], 1)
|
||||
self.assertEquals(callcount2[0], 1)
|
||||
|
||||
a.func.invalidate(("foo",))
|
||||
yield a.func("foo")
|
||||
|
||||
self.assertEquals(callcount[0], 2)
|
||||
self.assertEquals(callcount2[0], 1)
|
||||
|
||||
yield a.func2("foo")
|
||||
|
||||
self.assertEquals(callcount[0], 2)
|
||||
self.assertEquals(callcount2[0], 2)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_eviction_context(self):
|
||||
callcount = [0]
|
||||
callcount2 = [0]
|
||||
|
||||
class A(object):
|
||||
@cached(max_entries=2)
|
||||
def func(self, key):
|
||||
callcount[0] += 1
|
||||
return key
|
||||
|
||||
@cached(cache_context=True)
|
||||
def func2(self, key, cache_context):
|
||||
callcount2[0] += 1
|
||||
return self.func(key, on_invalidate=cache_context.invalidate)
|
||||
|
||||
a = A()
|
||||
yield a.func2("foo")
|
||||
yield a.func2("foo2")
|
||||
|
||||
self.assertEquals(callcount[0], 2)
|
||||
self.assertEquals(callcount2[0], 2)
|
||||
|
||||
yield a.func("foo3")
|
||||
|
||||
self.assertEquals(callcount[0], 3)
|
||||
self.assertEquals(callcount2[0], 2)
|
||||
|
||||
yield a.func2("foo")
|
||||
|
||||
self.assertEquals(callcount[0], 4)
|
||||
self.assertEquals(callcount2[0], 3)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_double_get(self):
|
||||
callcount = [0]
|
||||
callcount2 = [0]
|
||||
|
||||
class A(object):
|
||||
@cached()
|
||||
def func(self, key):
|
||||
callcount[0] += 1
|
||||
return key
|
||||
|
||||
@cached(cache_context=True)
|
||||
def func2(self, key, cache_context):
|
||||
callcount2[0] += 1
|
||||
return self.func(key, on_invalidate=cache_context.invalidate)
|
||||
|
||||
a = A()
|
||||
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
|
||||
|
||||
yield a.func2("foo")
|
||||
|
||||
self.assertEquals(callcount[0], 1)
|
||||
self.assertEquals(callcount2[0], 1)
|
||||
|
||||
a.func2.invalidate(("foo",))
|
||||
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
|
||||
|
||||
yield a.func2("foo")
|
||||
a.func2.invalidate(("foo",))
|
||||
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
|
||||
|
||||
self.assertEquals(callcount[0], 1)
|
||||
self.assertEquals(callcount2[0], 2)
|
||||
|
||||
a.func.invalidate(("foo",))
|
||||
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
|
||||
yield a.func("foo")
|
||||
|
||||
self.assertEquals(callcount[0], 2)
|
||||
self.assertEquals(callcount2[0], 2)
|
||||
|
||||
yield a.func2("foo")
|
||||
|
||||
self.assertEquals(callcount[0], 2)
|
||||
self.assertEquals(callcount2[0], 3)
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
|
||||
from . import unittest
|
||||
|
||||
from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs
|
||||
from synapse.rest.media.v1.preview_url_resource import (
|
||||
summarize_paragraphs, decode_and_calc_og
|
||||
)
|
||||
|
||||
|
||||
class PreviewTestCase(unittest.TestCase):
|
||||
|
@ -137,3 +139,79 @@ class PreviewTestCase(unittest.TestCase):
|
|||
" of old wooden houses in Northern Norway, the oldest house dating from"
|
||||
" 1789. The Arctic Cathedral, a modern church…"
|
||||
)
|
||||
|
||||
|
||||
class PreviewUrlTestCase(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
html = """
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
Some text.
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {
|
||||
"og:title": "Foo",
|
||||
"og:description": "Some text."
|
||||
})
|
||||
|
||||
def test_comment(self):
|
||||
html = """
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
<!-- HTML comment -->
|
||||
Some text.
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {
|
||||
"og:title": "Foo",
|
||||
"og:description": "Some text."
|
||||
})
|
||||
|
||||
def test_comment2(self):
|
||||
html = """
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
Some text.
|
||||
<!-- HTML comment -->
|
||||
Some more text.
|
||||
<p>Text</p>
|
||||
More text
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {
|
||||
"og:title": "Foo",
|
||||
"og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text"
|
||||
})
|
||||
|
||||
def test_script(self):
|
||||
html = """
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
<script> (function() {})() </script>
|
||||
Some text.
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {
|
||||
"og:title": "Foo",
|
||||
"og:description": "Some text."
|
||||
})
|
||||
|
|
|
@ -19,6 +19,8 @@ from .. import unittest
|
|||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.caches.treecache import TreeCache
|
||||
|
||||
from mock import Mock
|
||||
|
||||
|
||||
class LruCacheTestCase(unittest.TestCase):
|
||||
|
||||
|
@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
|
|||
self.assertEquals(cache.get("key"), 1)
|
||||
self.assertEquals(cache.setdefault("key", 2), 1)
|
||||
self.assertEquals(cache.get("key"), 1)
|
||||
cache["key"] = 2 # Make sure overriding works.
|
||||
self.assertEquals(cache.get("key"), 2)
|
||||
|
||||
def test_pop(self):
|
||||
cache = LruCache(1)
|
||||
|
@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
|
|||
cache["key"] = 1
|
||||
cache.clear()
|
||||
self.assertEquals(len(cache), 0)
|
||||
|
||||
|
||||
class LruCacheCallbacksTestCase(unittest.TestCase):
|
||||
def test_get(self):
|
||||
m = Mock()
|
||||
cache = LruCache(1)
|
||||
|
||||
cache.set("key", "value")
|
||||
self.assertFalse(m.called)
|
||||
|
||||
cache.get("key", callback=m)
|
||||
self.assertFalse(m.called)
|
||||
|
||||
cache.get("key", "value")
|
||||
self.assertFalse(m.called)
|
||||
|
||||
cache.set("key", "value2")
|
||||
self.assertEquals(m.call_count, 1)
|
||||
|
||||
cache.set("key", "value")
|
||||
self.assertEquals(m.call_count, 1)
|
||||
|
||||
def test_multi_get(self):
|
||||
m = Mock()
|
||||
cache = LruCache(1)
|
||||
|
||||
cache.set("key", "value")
|
||||
self.assertFalse(m.called)
|
||||
|
||||
cache.get("key", callback=m)
|
||||
self.assertFalse(m.called)
|
||||
|
||||
cache.get("key", callback=m)
|
||||
self.assertFalse(m.called)
|
||||
|
||||
cache.set("key", "value2")
|
||||
self.assertEquals(m.call_count, 1)
|
||||
|
||||
cache.set("key", "value")
|
||||
self.assertEquals(m.call_count, 1)
|
||||
|
||||
def test_set(self):
|
||||
m = Mock()
|
||||
cache = LruCache(1)
|
||||
|
||||
cache.set("key", "value", m)
|
||||
self.assertFalse(m.called)
|
||||
|
||||
cache.set("key", "value")
|
||||
self.assertFalse(m.called)
|
||||
|
||||
cache.set("key", "value2")
|
||||
self.assertEquals(m.call_count, 1)
|
||||
|
||||
cache.set("key", "value")
|
||||
self.assertEquals(m.call_count, 1)
|
||||
|
||||
def test_pop(self):
|
||||
m = Mock()
|
||||
cache = LruCache(1)
|
||||
|
||||
cache.set("key", "value", m)
|
||||
self.assertFalse(m.called)
|
||||
|
||||
cache.pop("key")
|
||||
self.assertEquals(m.call_count, 1)
|
||||
|
||||
cache.set("key", "value")
|
||||
self.assertEquals(m.call_count, 1)
|
||||
|
||||
cache.pop("key")
|
||||
self.assertEquals(m.call_count, 1)
|
||||
|
||||
def test_del_multi(self):
|
||||
m1 = Mock()
|
||||
m2 = Mock()
|
||||
m3 = Mock()
|
||||
m4 = Mock()
|
||||
cache = LruCache(4, 2, cache_type=TreeCache)
|
||||
|
||||
cache.set(("a", "1"), "value", m1)
|
||||
cache.set(("a", "2"), "value", m2)
|
||||
cache.set(("b", "1"), "value", m3)
|
||||
cache.set(("b", "2"), "value", m4)
|
||||
|
||||
self.assertEquals(m1.call_count, 0)
|
||||
self.assertEquals(m2.call_count, 0)
|
||||
self.assertEquals(m3.call_count, 0)
|
||||
self.assertEquals(m4.call_count, 0)
|
||||
|
||||
cache.del_multi(("a",))
|
||||
|
||||
self.assertEquals(m1.call_count, 1)
|
||||
self.assertEquals(m2.call_count, 1)
|
||||
self.assertEquals(m3.call_count, 0)
|
||||
self.assertEquals(m4.call_count, 0)
|
||||
|
||||
def test_clear(self):
|
||||
m1 = Mock()
|
||||
m2 = Mock()
|
||||
cache = LruCache(5)
|
||||
|
||||
cache.set("key1", "value", m1)
|
||||
cache.set("key2", "value", m2)
|
||||
|
||||
self.assertEquals(m1.call_count, 0)
|
||||
self.assertEquals(m2.call_count, 0)
|
||||
|
||||
cache.clear()
|
||||
|
||||
self.assertEquals(m1.call_count, 1)
|
||||
self.assertEquals(m2.call_count, 1)
|
||||
|
||||
def test_eviction(self):
|
||||
m1 = Mock(name="m1")
|
||||
m2 = Mock(name="m2")
|
||||
m3 = Mock(name="m3")
|
||||
cache = LruCache(2)
|
||||
|
||||
cache.set("key1", "value", m1)
|
||||
cache.set("key2", "value", m2)
|
||||
|
||||
self.assertEquals(m1.call_count, 0)
|
||||
self.assertEquals(m2.call_count, 0)
|
||||
self.assertEquals(m3.call_count, 0)
|
||||
|
||||
cache.set("key3", "value", m3)
|
||||
|
||||
self.assertEquals(m1.call_count, 1)
|
||||
self.assertEquals(m2.call_count, 0)
|
||||
self.assertEquals(m3.call_count, 0)
|
||||
|
||||
cache.set("key3", "value")
|
||||
|
||||
self.assertEquals(m1.call_count, 1)
|
||||
self.assertEquals(m2.call_count, 0)
|
||||
self.assertEquals(m3.call_count, 0)
|
||||
|
||||
cache.get("key2")
|
||||
|
||||
self.assertEquals(m1.call_count, 1)
|
||||
self.assertEquals(m2.call_count, 0)
|
||||
self.assertEquals(m3.call_count, 0)
|
||||
|
||||
cache.set("key1", "value", m1)
|
||||
|
||||
self.assertEquals(m1.call_count, 1)
|
||||
self.assertEquals(m2.call_count, 0)
|
||||
self.assertEquals(m3.call_count, 1)
|
||||
|
|
Loading…
Reference in a new issue