0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 21:33:53 +01:00

Merge branch 'release-v0.17.1' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2016-08-24 14:39:35 +01:00
commit 37638c06c5
98 changed files with 3277 additions and 1430 deletions

View file

@ -1,3 +1,50 @@
Changes in synapse v0.17.1 (2016-08-24)
=======================================
Changes:
* Delete old received_transactions rows (PR #1038)
* Pass through user-supplied content in /join/$room_id (PR #1039)
Bug fixes:
* Fix bug with backfill (PR #1040)
Changes in synapse v0.17.1-rc1 (2016-08-22)
===========================================
Features:
* Add notification API (PR #1028)
Changes:
* Don't print stack traces when failing to get remote keys (PR #996)
* Various federation /event/ perf improvements (PR #998)
* Only process one local membership event per room at a time (PR #1005)
* Move default display name push rule (PR #1011, #1023)
* Fix up preview URL API. Add tests. (PR #1015)
* Set ``Content-Security-Policy`` on media repo (PR #1021)
* Make notify_interested_services faster (PR #1022)
* Add usage stats to prometheus monitoring (PR #1037)
Bug fixes:
* Fix token login (PR #993)
* Fix CAS login (PR #994, #995)
* Fix /sync to not clobber status_msg (PR #997)
* Fix redacted state events to include prev_content (PR #1003)
* Fix some bugs in the auth/ldap handler (PR #1007)
* Fix backfill request to limit URI length, so that remotes don't reject the
requests due to path length limits (PR #1012)
* Fix AS push code to not send duplicate events (PR #1025)
Changes in synapse v0.17.0 (2016-08-08)
=======================================

View file

@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
- Python 2.7
- At least 512 MB RAM.
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
Synapse is written in python but some of the libraries is uses are written in
C. So before we can install synapse itself we need a working C compiler and the

97
docs/workers.rst Normal file
View file

@ -0,0 +1,97 @@
Scaling synapse via workers
---------------------------
Synapse has experimental support for splitting out functionality into
multiple separate python processes, helping greatly with scalability. These
processes are called 'workers', and are (eventually) intended to scale
horizontally independently.
All processes continue to share the same database instance, and as such, workers
only work with postgres based synapse deployments (sharing a single sqlite
across multiple processes is a recipe for disaster, plus you should be using
postgres anyway if you care about scalability).
The workers communicate with the master synapse process via a synapse-specific
HTTP protocol called 'replication' - analogous to MySQL or Postgres style
database replication; feeding a stream of relevant data to the workers so they
can be kept in sync with the main synapse process and database state.
To enable workers, you need to add a replication listener to the master synapse, e.g.::
listeners:
- port: 9092
bind_address: '127.0.0.1'
type: http
tls: false
x_forwarded: false
resources:
- names: [replication]
compress: false
Under **no circumstances** should this replication API listener be exposed to the
public internet; it currently implements no authentication whatsoever and is
unencrypted HTTP.
You then create a set of configs for the various worker processes. These should be
worker configuration files should be stored in a dedicated subdirectory, to allow
synctl to manipulate them.
The current available worker applications are:
* synapse.app.pusher - handles sending push notifications to sygnal and email
* synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
* synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository.
Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker,
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (worker_app) and the replication
endpoint that it's talking to on the main synapse process (worker_replication_url).
For instance::
worker_app: synapse.app.synchrotron
# The replication listener on the synapse to talk to.
worker_replication_url: http://127.0.0.1:9092/_synapse/replication
worker_listeners:
- type: http
port: 8083
resources:
- names:
- client
worker_daemonize: True
worker_pid_file: /home/matrix/synapse/synchrotron.pid
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
...is a full configuration for a synchrotron worker instance, which will expose a
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
by the main synapse.
Obviously you should configure your loadbalancer to route the /sync endpoint to
the synchrotron instance(s) in this instance.
Finally, to actually run your worker-based synapse, you must pass synctl the -a
commandline option to tell it to operate on all the worker configurations found
in the given directory, e.g.::
synctl -a $CONFIG/workers start
Currently one should always restart all workers when restarting or upgrading
synapse, unless you explicitly know it's safe not to. For instance, restarting
synapse without restarting all the synchrotrons may result in broken typing
notifications.
To manipulate a specific worker, you pass the -w option to synctl::
synctl -w $CONFIG/workers/synchrotron.yaml restart
All of the above is highly experimental and subject to change as Synapse evolves,
but documenting it here to help folks needing highly scalable Synapses similar
to the one running matrix.org!

View file

@ -25,5 +25,6 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
tox -e py27

View file

@ -14,6 +14,7 @@ fi
tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin
$TOX_BIN/pip install setuptools
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.17.0"
__version__ = "0.17.1"

View file

@ -675,27 +675,18 @@ class Auth(object):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
user_prefix = "user_id = "
user = None
user_id = None
guest = False
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
user_id = caveat.caveat_id[len(user_prefix):]
user = UserID.from_string(user_id)
elif caveat.caveat_id == "guest = true":
guest = True
user_id = self.get_user_id_from_macaroon(macaroon)
user = UserID.from_string(user_id)
self.validate_macaroon(
macaroon, rights, self.hs.config.expire_access_token,
user_id=user_id,
)
if user is None:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
guest = False
for caveat in macaroon.caveats:
if caveat.caveat_id == "guest = true":
guest = True
if guest:
ret = {
@ -743,6 +734,29 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN
)
def get_user_id_from_macaroon(self, macaroon):
"""Retrieve the user_id given by the caveats on the macaroon.
Does *not* validate the macaroon.
Args:
macaroon (pymacaroons.Macaroon): The macaroon to validate
Returns:
(str) user id
Raises:
AuthError if there is no user_id caveat in the macaroon
"""
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@ -754,6 +768,7 @@ class Auth(object):
verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet.
user_id (str): The user_id required
"""
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")

209
synapse/app/appservice.py Normal file
View file

@ -0,0 +1,209 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.appservice")
class AppserviceSlaveStore(
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
SlavedRegistrationStore,
):
pass
class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse appservice now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
appservice_handler = self.get_application_service_handler()
@defer.inlineCallbacks
def replicate(results):
stream = results.get("events")
if stream:
max_stream_id = stream["position"]
yield appservice_handler.notify_interested_services(max_stream_id)
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
replicate(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse appservice", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.appservice"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
if config.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``notify_appservices: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.notify_appservices = True
ps = AppserviceServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ps.replicate()
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-appservice",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -51,7 +51,7 @@ from synapse.api.urls import (
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
from synapse.metrics import register_memory_metrics
from synapse.metrics import register_memory_metrics, get_metrics_for
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
@ -385,6 +385,8 @@ def run(hs):
start_time = hs.get_clock().time()
stats = {}
@defer.inlineCallbacks
def phone_stats_home():
logger.info("Gathering stats for reporting")
@ -393,7 +395,10 @@ def run(hs):
if uptime < 0:
uptime = 0
stats = {}
# If the stats directory is empty then this is the first time we've
# reported stats.
first_time = not stats
stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
@ -406,6 +411,25 @@ def run(hs):
daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None:
stats["daily_messages"] = daily_messages
else:
stats.pop("daily_messages", None)
if first_time:
# Add callbacks to report the synapse stats as metrics whenever
# prometheus requests them, typically every 30s.
# As some of the stats are expensive to calculate we only update
# them when synapse phones home to matrix.org every 24 hours.
metrics = get_metrics_for("synapse.usage")
metrics.add_callback("timestamp", lambda: stats["timestamp"])
metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
metrics.add_callback("total_users", lambda: stats["total_users"])
metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
metrics.add_callback(
"daily_active_users", lambda: stats["daily_active_users"]
)
metrics.add_callback(
"daily_messages", lambda: stats.get("daily_messages", 0)
)
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:

View file

@ -0,0 +1,212 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import (
CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
)
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
BaseSlavedStore,
MediaRepositoryStore,
ClientIpStore,
):
pass
class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "media":
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse media repository now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse media repository", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.media_repository"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-media-repository",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -80,11 +80,6 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__
)
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
@ -168,7 +163,6 @@ class PusherServer(HomeServer):
store = self.get_datastore()
replication_url = self.config.worker_replication_url
pusher_pool = self.get_pusherpool()
clock = self.get_clock()
def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
@ -220,21 +214,11 @@ class PusherServer(HomeServer):
min_stream_id, max_stream_id, affected_room_ids
)
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result)
poke_pushers(result)
except:

View file

@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v1 import events
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@ -74,11 +75,6 @@ class SynchrotronSlavedStore(
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
@ -89,17 +85,23 @@ class SynchrotronSlavedStore(
get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted"
]
get_presence_list_observers_accepted = PresenceStore.__dict__[
"get_presence_list_observers_accepted"
]
UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object):
def __init__(self, hs):
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore()
self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {
@ -119,11 +121,13 @@ class SynchrotronPresence(object):
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
def set_state(self, user, state):
def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work?
pass
get_states = PresenceHandler.get_states.__func__
get_state = PresenceHandler.get_state.__func__
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks
@ -194,19 +198,39 @@ class SynchrotronPresence(object):
self._need_to_send_sync = False
yield self._send_syncing_users_now()
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
parties = yield self._get_interested_parties(
states, calculate_remote_hosts=False
)
room_ids_to_states, users_to_states, _ = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=users_to_states.keys()
)
@defer.inlineCallbacks
def process_replication(self, result):
stream = result.get("presence", {"rows": []})
states = []
for row in stream["rows"]:
(
position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
) = row
self.user_to_current_state[user_id] = UserPresenceState(
state = UserPresenceState(
user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
)
self.user_to_current_state[user_id] = state
states.append(state)
if states and "position" in stream:
stream_id = int(stream["position"])
yield self.notify_from_replication(states, stream_id)
class SynchrotronTyping(object):
@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
root_resource = create_resource_tree(resources, Resource())
@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
clock = self.get_clock()
notifier = self.get_notifier()
presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler()
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
store.get_presence_list_accepted.invalidate_all()
def notify_from_stream(
result, stream_name, stream_key, room=None, user=None
):
@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer):
result, "typing", "typing_key", room="room_id"
)
next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args.update(typing_handler.stream_positions())
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result)
typing_handler.process_replication(result)
presence_handler.process_replication(result)
yield presence_handler.process_replication(result)
notify(result)
except:
logger.exception("Error replicating from %r", replication_url)

View file

@ -14,6 +14,8 @@
# limitations under the License.
from synapse.api.constants import EventTypes
from twisted.internet import defer
import logging
import re
@ -79,13 +81,17 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, id=None):
sender=None, id=None, protocols=None):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
self.namespaces = self._check_namespaces(namespaces)
self.id = id
if protocols:
self.protocols = set(protocols)
else:
self.protocols = set()
def _check_namespaces(self, namespaces):
# Sanity check that it is of the form:
@ -138,65 +144,66 @@ class ApplicationService(object):
return regex_obj["exclusive"]
return False
def _matches_user(self, event, member_list):
if (hasattr(event, "sender") and
self.is_interested_in_user(event.sender)):
return True
@defer.inlineCallbacks
def _matches_user(self, event, store):
if not event:
defer.returnValue(False)
if self.is_interested_in_user(event.sender):
defer.returnValue(True)
# also check m.room.member state key
if (hasattr(event, "type") and event.type == EventTypes.Member
and hasattr(event, "state_key")
and self.is_interested_in_user(event.state_key)):
return True
if (event.type == EventTypes.Member and
self.is_interested_in_user(event.state_key)):
defer.returnValue(True)
if not store:
defer.returnValue(False)
member_list = yield store.get_users_in_room(event.room_id)
# check joined member events
for user_id in member_list:
if self.is_interested_in_user(user_id):
return True
return False
defer.returnValue(True)
defer.returnValue(False)
def _matches_room_id(self, event):
if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id)
return False
def _matches_aliases(self, event, alias_list):
@defer.inlineCallbacks
def _matches_aliases(self, event, store):
if not store or not event:
defer.returnValue(False)
alias_list = yield store.get_aliases_for_room(event.room_id)
for alias in alias_list:
if self.is_interested_in_alias(alias):
return True
return False
defer.returnValue(True)
defer.returnValue(False)
def is_interested(self, event, restrict_to=None, aliases_for_event=None,
member_list=None):
@defer.inlineCallbacks
def is_interested(self, event, store=None):
"""Check if this service is interested in this event.
Args:
event(Event): The event to check.
restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined user_ids in this room.
store(DataStore)
Returns:
bool: True if this service would like to know about this event.
"""
if aliases_for_event is None:
aliases_for_event = []
if member_list is None:
member_list = []
# Do cheap checks first
if self._matches_room_id(event):
defer.returnValue(True)
if restrict_to and restrict_to not in ApplicationService.NS_LIST:
# this is a programming error, so fail early and raise a general
# exception
raise Exception("Unexpected restrict_to value: %s". restrict_to)
if (yield self._matches_aliases(event, store)):
defer.returnValue(True)
if not restrict_to:
return (self._matches_user(event, member_list)
or self._matches_aliases(event, aliases_for_event)
or self._matches_room_id(event))
elif restrict_to == ApplicationService.NS_ALIASES:
return self._matches_aliases(event, aliases_for_event)
elif restrict_to == ApplicationService.NS_ROOMS:
return self._matches_room_id(event)
elif restrict_to == ApplicationService.NS_USERS:
return self._matches_user(event, member_list)
if (yield self._matches_user(event, store)):
defer.returnValue(True)
defer.returnValue(False)
def is_interested_in_user(self, user_id):
return (
@ -216,6 +223,9 @@ class ApplicationService(object):
or user_id == self.sender
)
def is_interested_in_protocol(self, protocol):
return protocol in self.protocols
def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)

View file

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
from synapse.types import ThirdPartyEntityKind
import logging
import urllib
@ -24,6 +25,28 @@ import urllib
logger = logging.getLogger(__name__)
def _is_valid_3pe_result(r, field):
if not isinstance(r, dict):
return False
for k in (field, "protocol"):
if k not in r:
return False
if not isinstance(r[k], str):
return False
if "fields" not in r:
return False
fields = r["fields"]
if not isinstance(fields, dict):
return False
for k in fields.keys():
if not isinstance(fields[k], str):
return False
return True
class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and
pushing.
@ -71,6 +94,43 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_alias to %s threw exception %s", uri, ex)
defer.returnValue(False)
@defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
required_field = "alias"
else:
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
)
try:
response = yield self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r",
uri, response
)
defer.returnValue([])
ret = []
for r in response:
if _is_valid_3pe_result(r, field=required_field):
ret.append(r)
else:
logger.warning(
"query_3pe to %s returned an invalid result %r",
uri, r
)
defer.returnValue(ret)
except Exception as ex:
logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([])
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events)

View file

@ -48,9 +48,12 @@ UP & quit +---------- YES SUCCESS
This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
from twisted.internet import defer
from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import Measure
import logging
logger = logging.getLogger(__name__)
@ -73,7 +76,7 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController(
self.clock, self.store, self.as_api, create_recoverer
)
self.queuer = _ServiceQueuer(self.txn_ctrl)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
@defer.inlineCallbacks
def start(self):
@ -94,38 +97,36 @@ class _ServiceQueuer(object):
this schedules any other events in the queue to run.
"""
def __init__(self, txn_ctrl):
def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]}
self.pending_requests = {} # dict of {service_id: Deferred}
self.requests_in_flight = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
def enqueue(self, service, event):
# if this service isn't being sent something
if not self.pending_requests.get(service.id):
self._send_request(service, [event])
else:
# add to queue for this service
if service.id not in self.queued_events:
self.queued_events[service.id] = []
self.queued_events[service.id].append(event)
self.queued_events.setdefault(service.id, []).append(event)
preserve_fn(self._send_request)(service)
def _send_request(self, service, events):
# send request and add callbacks
d = self.txn_ctrl.send(service, events)
d.addBoth(self._on_request_finish)
d.addErrback(self._on_request_fail)
self.pending_requests[service.id] = d
@defer.inlineCallbacks
def _send_request(self, service):
if service.id in self.requests_in_flight:
return
def _on_request_finish(self, service):
self.pending_requests[service.id] = None
# if there are queued events, then send them.
if (service.id in self.queued_events
and len(self.queued_events[service.id]) > 0):
self._send_request(service, self.queued_events[service.id])
self.queued_events[service.id] = []
self.requests_in_flight.add(service.id)
try:
while True:
events = self.queued_events.pop(service.id, [])
if not events:
return
def _on_request_fail(self, err):
logger.error("AS request failed: %s", err)
with Measure(self.clock, "servicequeuer.send"):
try:
yield self.txn_ctrl.send(service, events)
except:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
class _TransactionController(object):
@ -149,14 +150,12 @@ class _TransactionController(object):
if service_is_up:
sent = yield txn.send(self.as_api)
if sent:
txn.complete(self.store)
yield txn.complete(self.store)
else:
self._start_recoverer(service)
preserve_fn(self._start_recoverer)(service)
except Exception as e:
logger.exception(e)
self._start_recoverer(service)
# request has finished
defer.returnValue(service)
preserve_fn(self._start_recoverer)(service)
@defer.inlineCallbacks
def on_recovered(self, recoverer):

View file

@ -28,6 +28,7 @@ class AppServiceConfig(Config):
def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
def default_config(cls, **kwargs):
return """\
@ -122,6 +123,15 @@ def _load_appservice(hostname, as_info, config_filename):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
# protocols check
protocols = as_info.get("protocols")
if protocols:
# Because strings are lists in python
if isinstance(protocols, str) or not isinstance(protocols, list):
raise KeyError("Optional 'protocols' must be a list if present.")
for p in protocols:
if not isinstance(p, str):
raise KeyError("Bad value for 'protocols' item")
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
@ -129,4 +139,5 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
protocols=protocols,
)

View file

@ -22,6 +22,7 @@ from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn
)
from synapse.util.metrics import Measure
from twisted.internet import defer
@ -61,6 +62,10 @@ Attributes:
"""
class KeyLookupError(ValueError):
pass
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@ -239,59 +244,60 @@ class Keyring(object):
@defer.inlineCallbacks
def do_iterations():
merged_results = {}
with Measure(self.clock, "get_server_verify_keys"):
merged_results = {}
missing_keys = {}
for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which verify requests we have keys
# for and which we don't
missing_keys = {}
requests_missing_keys = []
for verify_request in verify_requests:
server_name = verify_request.server_name
result_keys = merged_results[server_name]
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
)
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
for key_id in verify_request.key_ids:
if key_id in result_keys:
with PreserveLoggingContext():
verify_request.deferred.callback((
server_name,
key_id,
result_keys[key_id],
))
break
else:
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
# We now need to figure out which verify requests we have keys
# for and which we don't
missing_keys = {}
requests_missing_keys = []
for verify_request in verify_requests:
server_name = verify_request.server_name
result_keys = merged_results[server_name]
if not missing_keys:
break
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
for verify_request in requests_missing_keys.values():
verify_request.deferred.errback(SynapseError(
401,
"No key for %s with id %s" % (
verify_request.server_name, verify_request.key_ids,
),
Codes.UNAUTHORIZED,
))
for key_id in verify_request.key_ids:
if key_id in result_keys:
with PreserveLoggingContext():
verify_request.deferred.callback((
server_name,
key_id,
result_keys[key_id],
))
break
else:
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
if not missing_keys:
break
for verify_request in requests_missing_keys.values():
verify_request.deferred.errback(SynapseError(
401,
"No key for %s with id %s" % (
verify_request.server_name, verify_request.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
for verify_request in verify_requests:
@ -302,15 +308,15 @@ class Keyring(object):
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
res = yield preserve_context_over_deferred(defer.gatherResults(
[
self.store.get_server_verify_keys(
preserve_fn(self.store.get_server_verify_keys)(
server_name, key_ids
).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)
defer.returnValue(dict(res))
@ -331,13 +337,13 @@ class Keyring(object):
)
defer.returnValue({})
results = yield defer.gatherResults(
results = yield preserve_context_over_deferred(defer.gatherResults(
[
get_key(p_name, p_keys)
preserve_fn(get_key)(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)
union_of_keys = {}
for result in results:
@ -363,7 +369,7 @@ class Keyring(object):
)
except Exception as e:
logger.info(
"Unable to getting key %r for %r directly: %s %s",
"Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
@ -377,13 +383,13 @@ class Keyring(object):
defer.returnValue(keys)
results = yield defer.gatherResults(
results = yield preserve_context_over_deferred(defer.gatherResults(
[
get_key(server_name, key_ids)
preserve_fn(get_key)(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)
merged = {}
for result in results:
@ -425,7 +431,7 @@ class Keyring(object):
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
raise ValueError(
raise KeyLookupError(
"Key response not signed by perspective server"
" %r" % (perspective_name,)
)
@ -448,7 +454,7 @@ class Keyring(object):
list(response[u"signatures"][perspective_name]),
list(perspective_keys)
)
raise ValueError(
raise KeyLookupError(
"Response not signed with a known key for perspective"
" server %r" % (perspective_name,)
)
@ -460,9 +466,9 @@ class Keyring(object):
for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
yield defer.gatherResults(
yield preserve_context_over_deferred(defer.gatherResults(
[
self.store_keys(
preserve_fn(self.store_keys)(
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
@ -470,7 +476,7 @@ class Keyring(object):
for server_name, response_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)
defer.returnValue(keys)
@ -491,10 +497,10 @@ class Keyring(object):
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
raise ValueError("Key response not signed by remote server")
raise KeyLookupError("Key response not signed by remote server")
if "tls_fingerprints" not in response:
raise ValueError("Key response missing TLS fingerprints")
raise KeyLookupError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
@ -508,7 +514,7 @@ class Keyring(object):
response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise ValueError("TLS certificate not allowed by fingerprints")
raise KeyLookupError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
from_server=server_name,
@ -518,7 +524,7 @@ class Keyring(object):
keys.update(response_keys)
yield defer.gatherResults(
yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self.store_keys)(
server_name=key_server_name,
@ -528,7 +534,7 @@ class Keyring(object):
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)
defer.returnValue(keys)
@ -560,14 +566,14 @@ class Keyring(object):
server_name = response_json["server_name"]
if only_from_server:
if server_name != from_server:
raise ValueError(
raise KeyLookupError(
"Expected a response for server %r not %r" % (
from_server, server_name
)
)
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise ValueError(
raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
@ -594,7 +600,7 @@ class Keyring(object):
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
yield defer.gatherResults(
yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self.store.store_server_keys_json)(
server_name=server_name,
@ -607,7 +613,7 @@ class Keyring(object):
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)
results[server_name] = response_keys
@ -635,15 +641,15 @@ class Keyring(object):
if ("signatures" not in response
or server_name not in response["signatures"]):
raise ValueError("Key response not signed by remote server")
raise KeyLookupError("Key response not signed by remote server")
if "tls_certificate" not in response:
raise ValueError("Key response missing TLS certificate")
raise KeyLookupError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise ValueError("TLS certificate doesn't match")
raise KeyLookupError("TLS certificate doesn't match")
# Cache the result in the datastore.
@ -659,7 +665,7 @@ class Keyring(object):
for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]:
raise ValueError(
raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
@ -696,7 +702,7 @@ class Keyring(object):
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults(
yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key
@ -704,4 +710,4 @@ class Keyring(object):
for key_id, key in verify_keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)

View file

@ -88,6 +88,8 @@ def prune_event(event):
if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
if "replaces_state" in event.unsigned:
allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
return type(event)(
allowed_fields,

View file

@ -23,6 +23,7 @@ from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging
@ -102,10 +103,10 @@ class FederationBase(object):
warn, pdu
)
valid_pdus = yield defer.gatherResults(
valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
deferreds,
consumeErrors=True
).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)
if include_none:
defer.returnValue(valid_pdus)
@ -129,7 +130,7 @@ class FederationBase(object):
for pdu in pdus
]
deferreds = self.keyring.verify_json_objects_for_server([
deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])

View file

@ -27,6 +27,7 @@ from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent
import synapse.metrics
@ -51,10 +52,34 @@ sent_edus_counter = metrics.register_counter("sent_edus")
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
PDU_RETRY_TIME_MS = 1 * 60 * 1000
class FederationClient(FederationBase):
def __init__(self, hs):
super(FederationClient, self).__init__(hs)
self.pdu_destination_tried = {}
self._clock.looping_call(
self._clear_tried_cache, 60 * 1000,
)
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
now = self._clock.time_msec()
old_dict = self.pdu_destination_tried
self.pdu_destination_tried = {}
for event_id, destination_dict in old_dict.items():
destination_dict = {
dest: time
for dest, time in destination_dict.items()
if time + PDU_RETRY_TIME_MS > now
}
if destination_dict:
self.pdu_destination_tried[event_id] = destination_dict
def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
@ -201,10 +226,10 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults(
pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)
defer.returnValue(pdus)
@ -240,8 +265,15 @@ class FederationClient(FederationBase):
if ev:
defer.returnValue(ev)
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
pdu = None
for destination in destinations:
now = self._clock.time_msec()
last_attempt = pdu_attempts.get(destination, 0)
if last_attempt + PDU_RETRY_TIME_MS > now:
continue
try:
limiter = yield get_retry_limiter(
destination,
@ -269,25 +301,19 @@ class FederationClient(FederationBase):
break
pdu_attempts[destination] = now
except SynapseError as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
pdu_attempts[destination] = now
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
@ -406,7 +432,7 @@ class FederationClient(FederationBase):
events and the second is a list of event ids that we failed to fetch.
"""
if return_local:
seen_events = yield self.store.get_events(event_ids)
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
signed_events = seen_events.values()
else:
seen_events = yield self.store.have_events(event_ids)
@ -432,14 +458,16 @@ class FederationClient(FederationBase):
batch = set(missing_events[i:i + batch_size])
deferreds = [
self.get_pdu(
preserve_fn(self.get_pdu)(
destinations=random_server_list(),
event_id=e_id,
)
for e_id in batch
]
res = yield defer.DeferredList(deferreds, consumeErrors=True)
res = yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
for success, result in res:
if success:
signed_events.append(result)
@ -828,14 +856,16 @@ class FederationClient(FederationBase):
return srvs
deferreds = [
self.get_pdu(
preserve_fn(self.get_pdu)(
destinations=random_server_list(),
event_id=e_id,
)
for e_id, depth in ordered_missing[:limit - len(signed_events)]
]
res = yield defer.DeferredList(deferreds, consumeErrors=True)
res = yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result and val:
signed_events.append(val)

View file

@ -21,11 +21,11 @@ from .units import Transaction
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
from synapse.util.metrics import measure_func
import synapse.metrics
import logging
@ -51,7 +51,7 @@ class TransactionQueue(object):
self.transport_layer = transport_layer
self._clock = hs.get_clock()
self.clock = hs.get_clock()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
@ -82,7 +82,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
self._next_txn_id = int(self.clock.time_msec())
def can_send_to(self, destination):
"""Can we send messages to the given server?
@ -119,266 +119,215 @@ class TransactionQueue(object):
if not destinations:
return
deferreds = []
for destination in destinations:
deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, deferred, order)
(pdu, order)
)
def chain(failure):
if not deferred.called:
deferred.errback(failure)
preserve_context_over_fn(
self._attempt_new_transaction, destination
)
def log_failure(f):
logger.warn("Failed to send pdu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
deferreds.append(deferred)
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
if not self.can_send_to(destination):
return
deferred = defer.Deferred()
self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred)
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
preserve_context_over_fn(
self._attempt_new_transaction, destination
)
def chain(failure):
if not deferred.called:
deferred.errback(failure)
def log_failure(f):
logger.warn("Failed to send edu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
return deferred
@defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
deferred = defer.Deferred()
if not self.can_send_to(destination):
return
self.pending_failures_by_dest.setdefault(
destination, []
).append(
(failure, deferred)
).append(failure)
preserve_context_over_fn(
self._attempt_new_transaction, destination
)
def chain(f):
if not deferred.called:
deferred.errback(f)
def log_failure(f):
logger.warn("Failed to send failure to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
yield deferred
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
yield run_on_reactor()
while True:
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
logger.debug(
"TX [%s] Transaction already in progress",
destination
)
return
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
logger.debug(
"TX [%s] Transaction already in progress",
destination
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
return
yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures
)
return
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
return
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] _attempt_new_transaction", destination)
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
edus = pending_edus
failures = [x.get_dict() for x in pending_failures]
txn_id = str(self._next_txn_id)
try:
self.pending_transactions[destination] = 1
limiter = yield get_retry_limiter(
destination,
self._clock,
self.store,
)
logger.debug("TX [%s] _attempt_new_transaction", destination)
logger.debug(
"TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
txn_id = str(self._next_txn_id)
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()),
transaction_id=txn_id,
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures),
)
with limiter:
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self._clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
try:
response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200
if response:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warn(
"Transaction returned error for %s: %s",
e_id, r,
)
except HttpResponseException as e:
code = e.code
response = e.response
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
logger.debug(
"TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
yield self.transaction_actions.delivered(
transaction, code, response
)
logger.debug("TX [%s] Persisting transaction...", destination)
logger.debug("TX [%s] Marked as delivered", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id,
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
logger.debug("TX [%s] Yielding to callbacks...", destination)
self._next_txn_id += 1
for deferred in deferreds:
if code == 200:
deferred.callback(None)
else:
deferred.errback(RuntimeError("Got status %d" % code))
yield self.transaction_actions.prepare_to_send(transaction)
# Ensures we don't continue until all callbacks on that
# deferred have fired
try:
yield deferred
except:
pass
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures),
)
logger.debug("TX [%s] Yielded to callbacks", destination)
except NotRetryingDestination:
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
with limiter:
# Actually send the transaction
for deferred in deferreds:
if not deferred.called:
deferred.errback(e)
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self.clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
try:
response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200
# Check to see if there is anything else to send.
self._attempt_new_transaction(destination)
if response:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warn(
"Transaction returned error for %s: %s",
e_id, r,
)
except HttpResponseException as e:
code = e.code
response = e.response
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
yield self.transaction_actions.delivered(
transaction, code, response
)
logger.debug("TX [%s] Marked as delivered", destination)
if code != 200:
for p in pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, destination
)
except NotRetryingDestination:
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)

View file

@ -19,7 +19,6 @@ from .room import (
)
from .room_member import RoomMemberHandler
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler
@ -53,8 +52,6 @@ class Handlers(object):
self.message_handler = MessageHandler(hs)
self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs)
self.event_stream_handler = EventStreamHandler(hs)
self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs)

View file

@ -16,7 +16,8 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService
from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging
@ -42,36 +43,73 @@ class ApplicationServicesHandler(object):
self.appservice_api = hs.get_application_service_api()
self.scheduler = hs.get_application_service_scheduler()
self.started_scheduler = False
self.clock = hs.get_clock()
self.notify_appservices = hs.config.notify_appservices
self.current_max = 0
self.is_processing = False
@defer.inlineCallbacks
def notify_interested_services(self, event):
def notify_interested_services(self, current_id):
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
prolonged length of time.
Args:
event(Event): The event to push out to interested services.
current_id(int): The current maximum ID.
"""
# Gather interested services
services = yield self._get_services_for_event(event)
if len(services) == 0:
return # no services need notifying
services = yield self.store.get_app_services()
if not services or not self.notify_appservices:
return
# Do we know this user exists? If not, poke the user query API for
# all services which match that user regex. This needs to block as these
# user queries need to be made BEFORE pushing the event.
yield self._check_user_exists(event.sender)
if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key)
self.current_max = max(self.current_max, current_id)
if self.is_processing:
return
if not self.started_scheduler:
self.scheduler.start().addErrback(log_failure)
self.started_scheduler = True
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
upper_bound = self.current_max
limit = 100
while True:
upper_bound, events = yield self.store.get_new_events_for_appservice(
upper_bound, limit
)
# Fork off pushes to these services
for service in services:
self.scheduler.submit_event_for_as(service, event)
if not events:
break
for event in events:
# Gather interested services
services = yield self._get_services_for_event(event)
if len(services) == 0:
continue # no services need notifying
# Do we know this user exists? If not, poke the user
# query API for all services which match that user regex.
# This needs to block as these user queries need to be
# made BEFORE pushing the event.
yield self._check_user_exists(event.sender)
if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key)
if not self.started_scheduler:
self.scheduler.start().addErrback(log_failure)
self.started_scheduler = True
# Fork off pushes to these services
for service in services:
preserve_fn(self.scheduler.submit_event_for_as)(
service, event
)
yield self.store.set_appservice_last_pos(upper_bound)
if len(events) < limit:
break
finally:
self.is_processing = False
@defer.inlineCallbacks
def query_user_exists(self, user_id):
@ -104,11 +142,12 @@ class ApplicationServicesHandler(object):
association can be found.
"""
room_alias_str = room_alias.to_string()
alias_query_services = yield self._get_services_for_event(
event=None,
restrict_to=ApplicationService.NS_ALIASES,
alias_list=[room_alias_str]
)
services = yield self.store.get_app_services()
alias_query_services = [
s for s in services if (
s.is_interested_in_alias(room_alias_str)
)
]
for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias(
alias_service, room_alias_str
@ -121,34 +160,35 @@ class ApplicationServicesHandler(object):
defer.returnValue(result)
@defer.inlineCallbacks
def _get_services_for_event(self, event, restrict_to="", alias_list=None):
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield preserve_context_over_deferred(defer.DeferredList([
preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
for service in services
], consumeErrors=True))
ret = []
for (success, result) in results:
if success:
ret.extend(result)
defer.returnValue(ret)
@defer.inlineCallbacks
def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event.
Args:
event(Event): The event to check. Can be None if alias_list is not.
restrict_to(str): The namespace to restrict regex tests to.
alias_list: A list of aliases to get services for. If None, this
list is obtained from the database.
Returns:
list<ApplicationService>: A list of services interested in this
event based on the service regex.
"""
member_list = None
if hasattr(event, "room_id"):
# We need to know the aliases associated with this event.room_id,
# if any.
if not alias_list:
alias_list = yield self.store.get_aliases_for_room(
event.room_id
)
# We need to know the members associated with this event.room_id,
# if any.
member_list = yield self.store.get_users_in_room(event.room_id)
services = yield self.store.get_app_services()
interested_list = [
s for s in services if (
s.is_interested(event, restrict_to, alias_list, member_list)
yield s.is_interested(event, self.store)
)
]
defer.returnValue(interested_list)
@ -163,6 +203,14 @@ class ApplicationServicesHandler(object):
]
defer.returnValue(interested_list)
@defer.inlineCallbacks
def _get_services_for_3pn(self, protocol):
services = yield self.store.get_app_services()
interested_list = [
s for s in services if s.is_interested_in_protocol(protocol)
]
defer.returnValue(interested_list)
@defer.inlineCallbacks
def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id):

View file

@ -70,11 +70,11 @@ class AuthHandler(BaseHandler):
self.ldap_uri = hs.config.ldap_uri
self.ldap_start_tls = hs.config.ldap_start_tls
self.ldap_base = hs.config.ldap_base
self.ldap_filter = hs.config.ldap_filter
self.ldap_attributes = hs.config.ldap_attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password
self.ldap_filter = hs.config.ldap_filter
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@ -660,7 +660,7 @@ class AuthHandler(BaseHandler):
else:
logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results",
len(result)
len(conn.response)
)
defer.returnValue(False)
@ -719,13 +719,14 @@ class AuthHandler(BaseHandler):
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
auth_api = self.hs.get_auth()
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
auth_api = self.hs.get_auth()
auth_api.validate_macaroon(macaroon, "login", True)
return self.get_user_from_macaroon(macaroon)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
user_id = auth_api.get_user_id_from_macaroon(macaroon)
auth_api.validate_macaroon(macaroon, "login", True, user_id)
return user_id
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
@ -736,21 +737,11 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
def get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
errcode=Codes.UNKNOWN_TOKEN
)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else []
except_access_token_id = requester.access_token_id if requester else None
try:
yield self.store.user_set_password_hash(user_id, password_hash)
@ -759,10 +750,10 @@ class AuthHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.store.user_delete_access_tokens(
user_id, except_access_token_ids
user_id, except_access_token_id
)
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_ids
user_id, except_access_token_id
)
@defer.inlineCallbacks

View file

@ -26,7 +26,9 @@ from synapse.api.errors import (
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
)
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
@ -249,7 +251,7 @@ class FederationHandler(BaseHandler):
if ev.type != EventTypes.Member:
continue
try:
domain = UserID.from_string(ev.state_key).domain
domain = get_domain_from_id(ev.state_key)
except:
continue
@ -274,7 +276,7 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]):
def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. This may return
@ -284,9 +286,6 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
events = yield self.replication_layer.backfill(
dest,
room_id,
@ -364,9 +363,9 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch
)
results = yield defer.gatherResults(
results = yield preserve_context_over_deferred(defer.gatherResults(
[
self.replication_layer.get_pdu(
preserve_fn(self.replication_layer.get_pdu)(
[dest],
event_id,
outlier=True,
@ -375,10 +374,10 @@ class FederationHandler(BaseHandler):
for event_id in missing_auth - failed_to_fetch
],
consumeErrors=True
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results})
)).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a})
required_auth.update(
a_id for event in results for a_id, _ in event.auth_events
a_id for event in results for a_id, _ in event.auth_events if event
)
missing_auth = required_auth - set(auth_events)
@ -455,6 +454,10 @@ class FederationHandler(BaseHandler):
)
max_depth = sorted_extremeties_tuple[0][1]
# We don't want to specify too many extremities as it causes the backfill
# request URI to be too long.
extremities = dict(sorted_extremeties_tuple[:5])
if current_depth > max_depth:
logger.debug(
"Not backfilling as we don't need to. %d < %d",
@ -551,10 +554,10 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
states = yield defer.gatherResults([
self.state_handler.resolve_state_groups(room_id, [e])
states = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
for e in event_ids
])
]))
states = dict(zip(event_ids, [s[1] for s in states]))
for e_id, _ in sorted_extremeties_tuple:
@ -1093,16 +1096,17 @@ class FederationHandler(BaseHandler):
)
if event:
# FIXME: This is a temporary work around where we occasionally
# return events slightly differently than when they were
# originally signed
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
if self.hs.is_mine_id(event.event_id):
# FIXME: This is a temporary work around where we occasionally
# return events slightly differently than when they were
# originally signed
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
)
if do_auth:
in_room = yield self.auth.check_host_in_room(
@ -1112,6 +1116,12 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
events = yield self._filter_events_for_server(
origin, event.room_id, [event]
)
event = events[0]
defer.returnValue(event)
else:
defer.returnValue(None)
@ -1158,9 +1168,9 @@ class FederationHandler(BaseHandler):
a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations.
"""
contexts = yield defer.gatherResults(
contexts = yield preserve_context_over_deferred(defer.gatherResults(
[
self._prep_event(
preserve_fn(self._prep_event)(
origin,
ev_info["event"],
state=ev_info.get("state"),
@ -1168,7 +1178,7 @@ class FederationHandler(BaseHandler):
)
for ev_info in event_infos
]
)
))
yield self.store.persist_events(
[
@ -1452,9 +1462,9 @@ class FederationHandler(BaseHandler):
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
different_events = yield defer.gatherResults(
different_events = yield preserve_context_over_deferred(defer.gatherResults(
[
self.store.get_event(
preserve_fn(self.store.get_event)(
d,
allow_none=True,
allow_rejected=False,
@ -1463,7 +1473,7 @@ class FederationHandler(BaseHandler):
if d in have_events and not have_events[d]
],
consumeErrors=True
).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)

View file

@ -28,7 +28,8 @@ from synapse.types import (
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@ -502,15 +503,17 @@ class MessageHandler(BaseHandler):
lambda states: states[event.event_id]
)
(messages, token), current_state = yield defer.gatherResults(
[
self.store.get_recent_events_for_room(
event.room_id,
limit=limit,
end_token=room_end_token,
),
deferred_room_state,
]
(messages, token), current_state = yield preserve_context_over_deferred(
defer.gatherResults(
[
preserve_fn(self.store.get_recent_events_for_room)(
event.room_id,
limit=limit,
end_token=room_end_token,
),
deferred_room_state,
]
)
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
@ -719,9 +722,9 @@ class MessageHandler(BaseHandler):
presence, receipts, (messages, token) = yield defer.gatherResults(
[
get_presence(),
get_receipts(),
self.store.get_recent_events_for_room(
preserve_fn(get_presence)(),
preserve_fn(get_receipts)(),
preserve_fn(self.store.get_recent_events_for_room)(
room_id,
limit=limit,
end_token=now_token.room_key,
@ -755,6 +758,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret)
@measure_func("_create_new_client_event")
@defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids:
@ -806,6 +810,7 @@ class MessageHandler(BaseHandler):
(event, context,)
)
@measure_func("handle_new_client_event")
@defer.inlineCallbacks
def handle_new_client_event(
self,
@ -934,7 +939,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def _notify():
yield run_on_reactor()
self.notifier.on_new_room_event(
yield self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
@ -944,6 +949,6 @@ class MessageHandler(BaseHandler):
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
federation_handler.handle_new_event(
preserve_fn(federation_handler.handle_new_event)(
event, destinations=destinations,
)

View file

@ -503,7 +503,7 @@ class PresenceHandler(object):
defer.returnValue(states)
@defer.inlineCallbacks
def _get_interested_parties(self, states):
def _get_interested_parties(self, states, calculate_remote_hosts=True):
"""Given a list of states return which entities (rooms, users, servers)
are interested in the given states.
@ -526,14 +526,15 @@ class PresenceHandler(object):
users_to_states.setdefault(state.user_id, []).append(state)
hosts_to_states = {}
for room_id, states in room_ids_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
if not local_states:
continue
if calculate_remote_hosts:
for room_id, states in room_ids_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
if not local_states:
continue
hosts = yield self.store.get_joined_hosts_for_room(room_id)
for host in hosts:
hosts_to_states.setdefault(host, []).extend(local_states)
hosts = yield self.store.get_joined_hosts_for_room(room_id)
for host in hosts:
hosts_to_states.setdefault(host, []).extend(local_states)
for user_id, states in users_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
@ -565,6 +566,16 @@ class PresenceHandler(object):
self._push_to_remotes(hosts_to_states)
@defer.inlineCallbacks
def notify_for_states(self, state, stream_id):
parties = yield self._get_interested_parties([state])
room_ids_to_states, users_to_states, hosts_to_states = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=[UserID.from_string(u) for u in users_to_states.keys()]
)
def _push_to_remotes(self, hosts_to_states):
"""Sends state updates to remote servers.
@ -672,7 +683,7 @@ class PresenceHandler(object):
])
@defer.inlineCallbacks
def set_state(self, target_user, state):
def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
status_msg = state.get("status_msg", None)
@ -689,10 +700,13 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id)
new_fields = {
"state": presence,
"status_msg": status_msg if presence != PresenceState.OFFLINE else None
"state": presence
}
if not ignore_status_msg:
msg = status_msg if presence != PresenceState.OFFLINE else None
new_fields["status_msg"] = msg
if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec()

View file

@ -59,10 +59,13 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids,
txn_id=None,
ratelimit=True,
content=None,
):
if content is None:
content = {}
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": membership}
content["membership"] = membership
if requester.is_guest:
content["kind"] = "guest"
@ -140,8 +143,9 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
content=None,
):
key = (target, room_id,)
key = (room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
@ -153,6 +157,7 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
)
defer.returnValue(result)
@ -168,7 +173,11 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
content=None,
):
if content is None:
content = {}
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
@ -218,7 +227,7 @@ class RoomMemberHandler(BaseHandler):
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
content = {"membership": Membership.JOIN}
content["membership"] = Membership.JOIN
profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target)
@ -272,6 +281,7 @@ class RoomMemberHandler(BaseHandler):
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
content=content,
)
@defer.inlineCallbacks

View file

@ -464,10 +464,10 @@ class SyncHandler(object):
else:
state = {}
defer.returnValue({
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
})
defer.returnValue({
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
})
@defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config):
@ -485,9 +485,9 @@ class SyncHandler(object):
)
defer.returnValue(notifs)
# There is no new information in this period, so your notification
# count is whatever it was last time.
defer.returnValue(None)
# There is no new information in this period, so your notification
# count is whatever it was last time.
defer.returnValue(None)
@defer.inlineCallbacks
def generate_sync_result(self, sync_config, since_token=None, full_state=False):

View file

@ -16,7 +16,9 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
)
from synapse.util.metrics import Measure
from synapse.types import UserID
@ -169,13 +171,13 @@ class TypingHandler(object):
deferreds = []
for domain in domains:
if domain == self.server_name:
self._push_update_local(
preserve_fn(self._push_update_local)(
room_id=room_id,
user_id=user_id,
typing=typing
)
else:
deferreds.append(self.federation.send_edu(
deferreds.append(preserve_fn(self.federation.send_edu)(
destination=domain,
edu_type="m.typing",
content={
@ -185,7 +187,9 @@ class TypingHandler(object):
},
))
yield defer.DeferredList(deferreds, consumeErrors=True)
yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
@defer.inlineCallbacks
def _recv_edu(self, origin, content):

View file

@ -155,9 +155,7 @@ class MatrixFederationHttpClient(object):
time_out=timeout / 1000. if timeout else 60,
)
response = yield preserve_context_over_fn(
send_request,
)
response = yield preserve_context_over_fn(send_request)
log_result = "%d %s" % (response.code, response.phrase,)
break

View file

@ -19,6 +19,7 @@ from synapse.api.errors import (
)
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import intern_dict
from synapse.util.metrics import Measure
import synapse.metrics
import synapse.events
@ -74,12 +75,12 @@ response_db_txn_duration = metrics.register_distribution(
_next_request_id = 0
def request_handler(report_metrics=True):
def request_handler(include_metrics=False):
"""Decorator for ``wrap_request_handler``"""
return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
def wrap_request_handler(request_handler, report_metrics):
def wrap_request_handler(request_handler, include_metrics=False):
"""Wraps a method that acts as a request handler with the necessary logging
and exception handling.
@ -103,54 +104,56 @@ def wrap_request_handler(request_handler, report_metrics):
_next_request_id += 1
with LoggingContext(request_id) as request_context:
if report_metrics:
with Measure(self.clock, "wrapped_request_handler"):
request_metrics = RequestMetrics()
request_metrics.start(self.clock)
request_metrics.start(self.clock, name=self.__class__.__name__)
request_context.request = request_id
with request.processing():
try:
with PreserveLoggingContext(request_context):
yield request_handler(self, request)
except CodeMessageException as e:
code = e.code
if isinstance(e, SynapseError):
logger.info(
"%s SynapseError: %s - %s", request, code, e.msg
)
else:
logger.exception(e)
outgoing_responses_counter.inc(request.method, str(code))
respond_with_json(
request, code, cs_exception(e), send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
except:
logger.exception(
"Failed handle request %s.%s on %r: %r",
request_handler.__module__,
request_handler.__name__,
self,
request
)
respond_with_json(
request,
500,
{
"error": "Internal server error",
"errcode": Codes.UNKNOWN,
},
send_cors=True
)
finally:
request_context.request = request_id
with request.processing():
try:
if report_metrics:
request_metrics.stop(
self.clock, request, self.__class__.__name__
with PreserveLoggingContext(request_context):
if include_metrics:
yield request_handler(self, request, request_metrics)
else:
yield request_handler(self, request)
except CodeMessageException as e:
code = e.code
if isinstance(e, SynapseError):
logger.info(
"%s SynapseError: %s - %s", request, code, e.msg
)
else:
logger.exception(e)
outgoing_responses_counter.inc(request.method, str(code))
respond_with_json(
request, code, cs_exception(e), send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
except:
pass
logger.exception(
"Failed handle request %s.%s on %r: %r",
request_handler.__module__,
request_handler.__name__,
self,
request
)
respond_with_json(
request,
500,
{
"error": "Internal server error",
"errcode": Codes.UNKNOWN,
},
send_cors=True
)
finally:
try:
request_metrics.stop(
self.clock, request
)
except Exception as e:
logger.warn("Failed to stop metrics: %r", e)
return wrapped_request_handler
@ -220,9 +223,9 @@ class JsonResource(HttpServer, resource.Resource):
# It does its own metric reporting because _async_render dispatches to
# a callback and it's the class name of that callback we want to report
# against rather than the JsonResource itself.
@request_handler(report_metrics=False)
@request_handler(include_metrics=True)
@defer.inlineCallbacks
def _async_render(self, request):
def _async_render(self, request, request_metrics):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
@ -231,9 +234,6 @@ class JsonResource(HttpServer, resource.Resource):
self._send_response(request, 200, {})
return
request_metrics = RequestMetrics()
request_metrics.start(self.clock)
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
@ -247,12 +247,6 @@ class JsonResource(HttpServer, resource.Resource):
callback = path_entry.callback
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in m.groupdict().items()
@ -263,10 +257,13 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return
self._send_response(request, code, response)
try:
request_metrics.stop(self.clock, request, servlet_classname)
except:
pass
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
request_metrics.name = servlet_classname
return
@ -298,11 +295,12 @@ class JsonResource(HttpServer, resource.Resource):
class RequestMetrics(object):
def start(self, clock):
def start(self, clock, name):
self.start = clock.time_msec()
self.start_context = LoggingContext.current_context()
self.name = name
def stop(self, clock, request, servlet_classname):
def stop(self, clock, request):
context = LoggingContext.current_context()
tag = ""
@ -316,26 +314,26 @@ class RequestMetrics(object):
)
return
incoming_requests_counter.inc(request.method, servlet_classname, tag)
incoming_requests_counter.inc(request.method, self.name, tag)
response_timer.inc_by(
clock.time_msec() - self.start, request.method,
servlet_classname, tag
self.name, tag
)
ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(
ru_utime, request.method, servlet_classname, tag
ru_utime, request.method, self.name, tag
)
response_ru_stime.inc_by(
ru_stime, request.method, servlet_classname, tag
ru_stime, request.method, self.name, tag
)
response_db_txn_count.inc_by(
context.db_txn_count, request.method, servlet_classname, tag
context.db_txn_count, request.method, self.name, tag
)
response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, servlet_classname, tag
context.db_txn_duration, request.method, self.name, tag
)

View file

@ -19,7 +19,8 @@ from synapse.api.errors import AuthError
from synapse.util.logutils import log_function
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
from synapse.util.metrics import Measure
from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client
import synapse.metrics
@ -67,10 +68,8 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class.
"""
def __init__(self, user_id, rooms, current_token, time_now_ms,
appservice=None):
def __init__(self, user_id, rooms, current_token, time_now_ms):
self.user_id = user_id
self.appservice = appservice
self.rooms = set(rooms)
self.current_token = current_token
self.last_notified_ms = time_now_ms
@ -107,11 +106,6 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id)
if self.appservice:
notifier.appservice_to_user_streams.get(
self.appservice, set()
).discard(self)
def count_listeners(self):
return len(self.notify_deferred.observers())
@ -142,7 +136,6 @@ class Notifier(object):
def __init__(self, hs):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
self.appservice_to_user_streams = {}
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
@ -168,8 +161,6 @@ class Notifier(object):
all_user_streams |= x
for x in self.user_to_user_stream.values():
all_user_streams.add(x)
for x in self.appservice_to_user_streams.values():
all_user_streams |= x
return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners)
@ -182,11 +173,8 @@ class Notifier(object):
"users",
lambda: len(self.user_to_user_stream),
)
metrics.register_callback(
"appservices",
lambda: count(bool, self.appservice_to_user_streams.values()),
)
@preserve_fn
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
""" Used by handlers to inform the notifier something has happened
@ -208,6 +196,7 @@ class Notifier(object):
self.notify_replication()
@preserve_fn
def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
@ -225,24 +214,11 @@ class Notifier(object):
else:
self._on_new_room_event(event, room_stream_id, extra_users)
@preserve_fn
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
self.appservice_handler.notify_interested_services(event)
app_streams = set()
for appservice in self.appservice_to_user_streams:
# TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_user_streams set, but
# that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this
# make the room_to_user_streams check somewhat obselete?
if appservice.is_interested(event):
app_user_streams = self.appservice_to_user_streams.get(
appservice, set()
)
app_streams |= app_user_streams
self.appservice_handler.notify_interested_services(room_stream_id)
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id)
@ -251,35 +227,36 @@ class Notifier(object):
"room_key", room_stream_id,
users=extra_users,
rooms=[event.room_id],
extra_streams=app_streams,
)
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
extra_streams=set()):
@preserve_fn
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms.
"""
with PreserveLoggingContext():
user_streams = set()
with Measure(self.clock, "on_new_event"):
user_streams = set()
for user in users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
for user in users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set())
for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set())
time_now_ms = self.clock.time_msec()
for user_stream in user_streams:
try:
user_stream.notify(stream_key, new_token, time_now_ms)
except:
logger.exception("Failed to notify listener")
time_now_ms = self.clock.time_msec()
for user_stream in user_streams:
try:
user_stream.notify(stream_key, new_token, time_now_ms)
except:
logger.exception("Failed to notify listener")
self.notify_replication()
self.notify_replication()
@preserve_fn
def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
@ -294,7 +271,6 @@ class Notifier(object):
"""
user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user_id)
current_token = yield self.event_sources.get_current_token()
if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user_id)
@ -302,7 +278,6 @@ class Notifier(object):
user_stream = _NotifierUserStream(
user_id=user_id,
rooms=room_ids,
appservice=appservice,
current_token=current_token,
time_now_ms=self.clock.time_msec(),
)
@ -477,11 +452,6 @@ class Notifier(object):
s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)
if user_stream.appservice:
self.appservice_to_user_stream.setdefault(
user_stream.appservice, set()
).add(user_stream)
def _user_joined_room(self, user_id, room_id):
new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None:

View file

@ -38,15 +38,16 @@ class ActionGenerator:
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "handle_push_actions_for_event"):
with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store, context.current_state
event, self.hs, self.store, context.state_group, context.current_state
)
with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, context.current_state
)
context.push_actions = [
(uid, actions) for uid, actions in actions_by_user.items()
]
context.push_actions = [
(uid, actions) for uid, actions in actions_by_user.items()
]

View file

@ -217,6 +217,27 @@ BASE_APPEND_OVERRIDE_RULES = [
'dont_notify'
]
},
# This was changed from underride to override so it's closer in priority
# to the content rules where the user name highlight rule lives. This
# way a room rule is lower priority than both but a custom override rule
# is higher priority than both.
{
'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
]
@ -242,23 +263,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
}
]
},
{
'rule_id': 'global/underride/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
{
'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [

View file

@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks
def evaluator_for_event(event, hs, store, current_state):
room_id = event.room_id
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and hs.is_mine_id(e.state_key)
def evaluator_for_event(event, hs, store, state_group, current_state):
rules_by_user = yield store.bulk_get_push_rules_for_room(
event.room_id, state_group, current_state
)
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield store.get_if_users_have_pushers(
local_users_in_room
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state):
if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher:
user_ids.add(invited_user)
rules_by_user = yield _get_rules(room_id, user_ids, store)
rules_by_user[invited_user] = yield store.get_push_rules_for_user(
invited_user
)
defer.returnValue(BulkPushRuleEvaluator(
room_id, rules_by_user, user_ids, store
event.room_id, rules_by_user, store
))
@ -90,10 +66,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562)
"""
def __init__(self, room_id, rules_by_user, users_in_room, store):
def __init__(self, room_id, rules_by_user, store):
self.room_id = room_id
self.rules_by_user = rules_by_user
self.users_in_room = users_in_room
self.store = store
@defer.inlineCallbacks

View file

@ -17,14 +17,15 @@ from twisted.internet import defer
from synapse.util.presentable_names import (
calculate_room_name, name_from_member_event
)
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@defer.inlineCallbacks
def get_badge_count(store, user_id):
invites, joins = yield defer.gatherResults([
store.get_invited_rooms_for_user(user_id),
store.get_rooms_for_user(user_id),
], consumeErrors=True)
invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.get_invited_rooms_for_user)(user_id),
preserve_fn(store.get_rooms_for_user)(user_id),
], consumeErrors=True))
my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read",

View file

@ -17,7 +17,7 @@
from twisted.internet import defer
import pusher
from synapse.util.logcontext import preserve_fn
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.async import run_on_reactor
import logging
@ -102,14 +102,14 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user(self, user_id, except_token_ids=[]):
def remove_pushers_by_user(self, user_id, except_access_token_id=None):
all = yield self.store.get_all_pushers()
logger.info(
"Removing all pushers for user %s except access tokens ids %r",
user_id, except_token_ids
"Removing all pushers for user %s except access tokens id %r",
user_id, except_access_token_id
)
for p in all:
if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
@ -130,10 +130,12 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
p.on_new_notifications(min_stream_id, max_stream_id)
preserve_fn(p.on_new_notifications)(
min_stream_id, max_stream_id
)
)
yield defer.gatherResults(deferreds)
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
except:
logger.exception("Exception in pusher on_new_notifications")
@ -155,10 +157,10 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
p.on_new_receipts(min_stream_id, max_stream_id)
preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
)
yield defer.gatherResults(deferreds)
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
except:
logger.exception("Exception in pusher on_new_receipts")

View file

@ -41,6 +41,7 @@ STREAM_NAMES = (
("push_rules",),
("pushers",),
("state",),
("caches",),
)
@ -70,6 +71,7 @@ class ReplicationResource(Resource):
* "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules.
* "pushers": Per user changes to their pushers.
* "caches": Cache invalidations.
The API takes two additional query parameters:
@ -129,6 +131,7 @@ class ReplicationResource(Resource):
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
caches_token = self.store.get_cache_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
@ -140,6 +143,7 @@ class ReplicationResource(Resource):
push_rules_token,
pushers_token,
state_token,
caches_token,
))
@request_handler()
@ -188,6 +192,7 @@ class ReplicationResource(Resource):
yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total)
@ -379,6 +384,20 @@ class ReplicationResource(Resource):
"position", "type", "state_key", "event_id"
))
@defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams):
current_position = current_token.caches
caches = request_streams.get("caches")
if caches is not None:
updated_caches = yield self.store.get_all_updated_caches(
caches, current_position, limit
)
writer.write_header_and_rows("caches", updated_caches, (
"position", "cache_func", "keys", "invalidation_ts"
))
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
@ -407,7 +426,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state"
"push_rules", "pushers", "state", "caches",
))):
__slots__ = []

View file

@ -14,15 +14,43 @@
# limitations under the License.
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
from ._slaved_id_tracker import SlavedIdTracker
import logging
logger = logging.getLogger(__name__)
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id",
)
else:
self._cache_id_gen = None
def stream_positions(self):
return {}
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def process_replication(self, result):
stream = result.get("caches")
if stream:
for row in stream["rows"]:
(
position, cache_func, keys, invalidation_ts,
) = row
try:
getattr(self, cache_func).invalidate(tuple(keys))
except AttributeError:
logger.info("Got unexpected cache_func: %r", cache_func)
self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None)

View file

@ -28,3 +28,13 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
get_app_services = DataStore.get_app_services.__func__
get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
create_appservice_txn = DataStore.create_appservice_txn.__func__
get_appservices_by_state = DataStore.get_appservices_by_state.__func__
get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
_get_last_txn = DataStore._get_last_txn.__func__
complete_appservice_txn = DataStore.complete_appservice_txn.__func__
get_appservice_state = DataStore.get_appservice_state.__func__
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
set_appservice_state = DataStore.set_appservice_state.__func__

View file

@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore
class DirectoryStore(BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[
"get_aliases_for_room"
].orig
]

View file

@ -25,6 +25,9 @@ class SlavedRegistrationStore(BaseSlavedStore):
# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
].orig
]
_query_for_auth = DataStore._query_for_auth.__func__
get_user_by_id = RegistrationStore.__dict__[
"get_user_by_id"
]

View file

@ -46,7 +46,9 @@ from synapse.rest.client.v2_alpha import (
account_data,
report_event,
openid,
notifications,
devices,
thirdparty,
)
from synapse.http.server import JsonResource
@ -91,4 +93,6 @@ class ClientRestResource(JsonResource):
account_data.register_servlets(hs, client_resource)
report_event.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource)
notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource)

View file

@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(WhoisRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
)
def __init__(self, hs):
super(PurgeHistoryRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request)

View file

@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet):
hs (synapse.server.HomeServer):
"""
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore()

View file

@ -36,6 +36,10 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(ClientV1RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
def __init__(self, hs):
super(ClientDirectoryServer, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs)
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):

View file

@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
super(EventStreamRestServlet, self).__init__(hs)
self.event_stream_handler = hs.get_event_stream_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(
@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
if "room_id" in request.args:
room_id = request.args["room_id"][0]
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args
chunk = yield handler.get_stream(
chunk = yield self.event_stream_handler.get_stream(
requester.user.to_string(),
pagin_config,
timeout=timeout,
@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks
def on_GET(self, request, event_id):
requester = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler
event = yield handler.get_event(requester.user, event_id)
event = yield self.event_handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec()
if event:

View file

@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns
class InitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/initialSync$")
def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)

View file

@ -54,12 +54,9 @@ class LoginRestServlet(ClientV1RestServlet):
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request):
flows = []
@ -110,17 +107,6 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.JWT_TYPE):
result = yield self.do_jwt_login(login_submission)
defer.returnValue(result)
# TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
uri = "%s/proxyValidate" % (self.cas_server_url,)
args = {
"ticket": login_submission["ticket"],
"service": login_submission["service"]
}
body = yield self.http_client.get_raw(uri, args)
result = yield self.do_cas_login(body)
defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
@ -191,51 +177,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
registered_user_id
)
)
result = {
"user_id": registered_user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
token = login_submission.get("token", None)
@ -293,33 +234,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
def _register_device(self, user_id, login_submission):
"""Register a device for a user.
@ -347,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs)
self.sp_config = hs.config.saml2_config_path
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):
@ -384,18 +299,6 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"}))
# TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas", releases=())
def __init__(self, hs):
super(CasRestServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url})
class CasRedirectServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
@ -427,6 +330,8 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request):
@ -479,30 +384,39 @@ class CasTicketServlet(ClientV1RestServlet):
return urlparse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
user = None
attributes = None
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = (root[0].tag.endswith("authenticationSuccess"))
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
if attributes is None:
raise Exception("CAS response does not contain attributes")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
raise LoginError(401, "Invalid CAS response",
errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(401, "Unsuccessful CAS response",
errcode=Codes.UNAUTHORIZED)
return user, attributes
def register_servlets(hs, http_server):
@ -512,5 +426,3 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)

View file

@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
class ProfileRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)

View file

@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
self.sessions = {}
self.enable_registration = hs.config.enable_registration
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):

View file

@ -35,6 +35,10 @@ logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet):
# No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@ -253,6 +268,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
action="join",
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
content=content,
third_party_signed=content.get("third_party_signed", None),
)
@ -296,6 +312,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
@ -322,6 +342,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -351,6 +375,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
class RoomStateRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -368,6 +396,10 @@ class RoomStateRestServlet(ClientV1RestServlet):
class RoomInitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -388,6 +420,7 @@ class RoomEventContext(ClientV1RestServlet):
def __init__(self, hs):
super(RoomEventContext, self).__init__(hs)
self.clock = hs.get_clock()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@ -424,6 +457,10 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
register_txn_path(self, PATTERNS, http_server)
@ -462,6 +499,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
@ -542,6 +583,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERNS, http_server)
@ -624,6 +669,10 @@ class SearchRestServlet(ClientV1RestServlet):
"/search$"
)
def __init__(self, hs):
super(SearchRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)

View file

@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import (
RestServlet, parse_string, parse_integer
)
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id,
)
from ._base import client_v2_patterns
import logging
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
PATTERNS = client_v2_patterns("/notifications$", releases=())
def __init__(self, hs):
super(NotificationsServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
from_token = parse_string(request, "from", required=False)
limit = parse_integer(request, "limit", default=50)
limit = min(limit, 500)
push_actions = yield self.store.get_push_actions_for_user(
user_id, from_token, limit
)
receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
user_id, 'm.read'
)
notif_event_ids = [pa["event_id"] for pa in push_actions]
notif_events = yield self.store.get_events(notif_event_ids)
returned_push_actions = []
next_token = None
for pa in push_actions:
returned_pa = {
"room_id": pa["room_id"],
"profile_tag": pa["profile_tag"],
"actions": pa["actions"],
"ts": pa["received_ts"],
"event": serialize_event(
notif_events[pa["event_id"]],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
),
}
if pa["room_id"] not in receipts_by_room:
returned_pa["read"] = False
else:
receipt = receipts_by_room[pa["room_id"]]
returned_pa["read"] = (
receipt["topological_ordering"], receipt["stream_ordering"]
) >= (
pa["topological_ordering"], pa["stream_ordering"]
)
returned_push_actions.append(returned_pa)
next_token = pa["stream_ordering"]
defer.returnValue((200, {
"notifications": returned_push_actions,
"next_token": next_token,
}))
def register_servlets(hs, http_server):
NotificationsServlet(hs).register(http_server)

View file

@ -403,10 +403,9 @@ class RegisterRestServlet(RestServlet):
# register the user's device
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id = self.device_handler.check_device_registered(
return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
return device_id
@defer.inlineCallbacks
def _do_guest_registration(self):

View file

@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet):
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
yield self.presence_handler.set_state(user, {"presence": set_presence})
yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
context = yield self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence,

View file

@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
releases=())
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
)
defer.returnValue((200, results))
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
releases=())
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
)
defer.returnValue((200, results))
def register_servlets(hs, http_server):
ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server)

View file

@ -15,6 +15,7 @@
from synapse.http.server import request_handler, respond_with_json_bytes
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.api.errors import SynapseError, Codes
from synapse.crypto.keyring import KeyLookupError
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@ -210,9 +211,10 @@ class RemoteKey(Resource):
yield self.keyring.get_server_verify_key_v2_direct(
server_name, key_ids
)
except KeyLookupError as e:
logger.info("Failed to fetch key: %s", e)
except:
logger.exception("Failed to get key for %r", server_name)
pass
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
)

View file

@ -45,6 +45,7 @@ class DownloadResource(Resource):
@request_handler()
@defer.inlineCallbacks
def _async_render_GET(self, request):
request.setHeader("Content-Security-Policy", "sandbox")
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
yield self._respond_local_file(request, media_id, name)

View file

@ -29,14 +29,13 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
from copy import deepcopy
import os
import re
import fnmatch
import cgi
import ujson as json
import urlparse
import itertools
import logging
logger = logging.getLogger(__name__)
@ -163,7 +162,7 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info)
if self._is_media(media_info['media_type']):
if _is_media(media_info['media_type']):
dims = yield self.media_repo._generate_local_thumbnails(
media_info['filesystem_id'], media_info
)
@ -184,11 +183,9 @@ class PreviewUrlResource(Resource):
logger.warn("Couldn't get dims for %s" % url)
# define our OG response for this media
elif self._is_html(media_info['media_type']):
elif _is_html(media_info['media_type']):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
from lxml import etree
file = open(media_info['filename'])
body = file.read()
file.close()
@ -199,17 +196,35 @@ class PreviewUrlResource(Resource):
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
encoding = match.group(1) if match else "utf-8"
try:
parser = etree.HTMLParser(recover=True, encoding=encoding)
tree = etree.fromstring(body, parser)
og = yield self._calc_og(tree, media_info, requester)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=encoding)
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
og = yield self._calc_og(tree, media_info, requester)
og = decode_and_calc_og(body, media_info['uri'], encoding)
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
_rebase_url(og['og:image'], media_info['uri']), requester.user
)
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info
)
if dims:
og["og:image:width"] = dims['width']
og["og:image:height"] = dims['height']
else:
logger.warn("Couldn't get dims for %s" % og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name, image_info['filesystem_id']
)
og["og:image:type"] = image_info['media_type']
og["matrix:image:size"] = image_info['media_length']
else:
del og["og:image"]
else:
logger.warn("Failed to find any OG data in %s", url)
og = {}
@ -232,139 +247,6 @@ class PreviewUrlResource(Resource):
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
@defer.inlineCallbacks
def _calc_og(self, tree, media_info, requester):
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if 'content' in tag.attrib:
og[tag.attrib['property']] = tag.attrib['content']
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if 'og:title' not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
og['og:title'] = title[0].text.strip() if title else None
if 'og:image' not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(images, key=lambda i: (
-1 * float(i.attrib['width']) * float(i.attrib['height'])
))
if not images:
images = tree.xpath("//img[@src]")
if images:
og['og:image'] = images[0].attrib['src']
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
self._rebase_url(og['og:image'], media_info['uri']), requester.user
)
if self._is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info
)
if dims:
og["og:image:width"] = dims['width']
og["og:image:height"] = dims['height']
else:
logger.warn("Couldn't get dims for %s" % og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name, image_info['filesystem_id']
)
og["og:image:type"] = image_info['media_type']
og["matrix:image:size"] = image_info['media_length']
else:
del og["og:image"]
if 'og:description' not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content")
if meta_description:
og['og:description'] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
# <header/>, <nav/>, <aside/>, <footer/>
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
# We don't just use XPATH here as that is slow on some machines.
# We clone `tree` as we modify it.
cloned_tree = deepcopy(tree.find("body"))
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
for el in cloned_tree.iter(TAGS_TO_REMOVE):
el.getparent().remove(el)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r'\s+', '\n', el.text).strip()
for el in cloned_tree.iter()
if el.text and isinstance(el.tag, basestring) # Removes comments
)
og['og:description'] = summarize_paragraphs(text_nodes)
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
defer.returnValue(og)
def _rebase_url(self, url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
return urlparse.urlunparse(url)
@defer.inlineCallbacks
def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
@ -445,17 +327,171 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None,
})
def _is_media(self, content_type):
if content_type.lower().startswith("image/"):
return True
def _is_html(self, content_type):
content_type = content_type.lower()
if (
content_type.startswith("text/html") or
content_type.startswith("application/xhtml")
):
return True
def decode_and_calc_og(body, media_uri, request_encoding=None):
from lxml import etree
try:
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
tree = etree.fromstring(body, parser)
og = _calc_og(tree, media_uri)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
og = _calc_og(tree, media_uri)
return og
def _calc_og(tree, media_uri):
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if 'content' in tag.attrib:
og[tag.attrib['property']] = tag.attrib['content']
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if 'og:title' not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
og['og:title'] = title[0].text.strip() if title else None
if 'og:image' not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og['og:image'] = _rebase_url(meta_image[0], media_uri)
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(images, key=lambda i: (
-1 * float(i.attrib['width']) * float(i.attrib['height'])
))
if not images:
images = tree.xpath("//img[@src]")
if images:
og['og:image'] = images[0].attrib['src']
if 'og:description' not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content")
if meta_description:
og['og:description'] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
# <header/>, <nav/>, <aside/>, <footer/>
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
# We don't just use XPATH here as that is slow on some machines.
from lxml import etree
TAGS_TO_REMOVE = (
"header", "nav", "aside", "footer", "script", "style", etree.Comment
)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r'\s+', '\n', el).strip()
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og['og:description'] = summarize_paragraphs(text_nodes)
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
return og
def _iterate_over_text(tree, *tags_to_ignore):
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
# This is basically a stack that we extend using itertools.chain.
# This will either consist of an element to iterate over *or* a string
# to be returned.
elements = iter([tree])
while True:
el = elements.next()
if isinstance(el, basestring):
yield el
elif el is not None and el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
# return it if the text exists.
if el.text:
yield el.text
# We add to the stack all the elements children, interspersed with
# each child's tail text (if it exists). The tail text of a node
# is text that comes *after* the node, so we always include it even
# if we ignore the child node.
elements = itertools.chain(
itertools.chain.from_iterable( # Basically a flatmap
[child, child.tail] if child.tail else [child]
for child in el.iterchildren()
),
elements
)
def _rebase_url(url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
return urlparse.urlunparse(url)
def _is_media(content_type):
if content_type.lower().startswith("image/"):
return True
def _is_html(content_type):
content_type = content_type.lower()
if (
content_type.startswith("text/html") or
content_type.startswith("application/xhtml")
):
return True
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):

View file

@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room import RoomListHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
@ -94,6 +95,8 @@ class HomeServer(object):
'auth_handler',
'device_handler',
'e2e_keys_handler',
'event_handler',
'event_stream_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
@ -214,6 +217,12 @@ class HomeServer(object):
def build_application_service_handler(self):
return ApplicationServicesHandler(self)
def build_event_handler(self):
return EventHandler(self)
def build_event_stream_handler(self):
return EventStreamHandler(self)
def build_event_sources(self):
return EventSources(self)

View file

@ -1,3 +1,4 @@
import synapse.api.auth
import synapse.handlers
import synapse.handlers.auth
import synapse.handlers.device
@ -6,6 +7,9 @@ import synapse.storage
import synapse.state
class HomeServer(object):
def get_auth(self) -> synapse.api.auth.Auth:
pass
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
pass

View file

@ -50,6 +50,7 @@ from .openid import OpenIdStore
from .client_ips import ClientIpStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from .engines import PostgresEngine
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("deleted_pushers", "stream_id")],
)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
db_conn, "cache_invalidation_stream", "stream_id",
)
else:
self._cache_id_gen = None
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",

View file

@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
from synapse.util.caches import intern_dict
from synapse.storage.engines import PostgresEngine
import synapse.metrics
@ -165,7 +166,7 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters()
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache(
@ -305,13 +306,14 @@ class SQLBaseStore(object):
func, *args, **kwargs
)
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
for after_callback, after_args in after_callbacks:
after_callback(*after_args)
try:
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
finally:
for after_callback, after_args in after_callbacks:
after_callback(*after_args)
defer.returnValue(result)
@defer.inlineCallbacks
@ -860,6 +862,62 @@ class SQLBaseStore(object):
return cache, min_val
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
be invalidated.
"""
txn.call_after(cache_func.invalidate, keys)
if isinstance(self.database_engine, PostgresEngine):
# get_next() returns a context manager which is designed to wrap
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn(
txn,
table="cache_invalidation_stream",
values={
"stream_id": stream_id,
"cache_func": cache_func.__name__,
"keys": list(keys),
"invalidation_ts": self.clock.time_msec(),
}
)
def get_all_updated_caches(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit,))
return txn.fetchall()
return self.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying

View file

@ -218,38 +218,37 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
Returns:
AppServiceTransaction: A new transaction.
"""
def _create_appservice_txn(txn):
# work out new txn id (highest txn id for this service += 1)
# The highest id may be the last one sent (in which case it is last_txn)
# or it may be the highest in the txns list (which are waiting to be/are
# being sent)
last_txn_id = self._get_last_txn(txn, service.id)
txn.execute(
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
(service.id,)
)
highest_txn_id = txn.fetchone()[0]
if highest_txn_id is None:
highest_txn_id = 0
new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table
event_ids = json.dumps([e.event_id for e in events])
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
(service.id, new_txn_id, event_ids)
)
return AppServiceTransaction(
service=service, id=new_txn_id, events=events
)
return self.runInteraction(
"create_appservice_txn",
self._create_appservice_txn,
service, events
)
def _create_appservice_txn(self, txn, service, events):
# work out new txn id (highest txn id for this service += 1)
# The highest id may be the last one sent (in which case it is last_txn)
# or it may be the highest in the txns list (which are waiting to be/are
# being sent)
last_txn_id = self._get_last_txn(txn, service.id)
txn.execute(
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
(service.id,)
)
highest_txn_id = txn.fetchone()[0]
if highest_txn_id is None:
highest_txn_id = 0
new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table
event_ids = json.dumps([e.event_id for e in events])
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
(service.id, new_txn_id, event_ids)
)
return AppServiceTransaction(
service=service, id=new_txn_id, events=events
_create_appservice_txn,
)
def complete_appservice_txn(self, txn_id, service):
@ -263,39 +262,38 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves if this transaction was stored
successfully.
"""
return self.runInteraction(
"complete_appservice_txn",
self._complete_appservice_txn,
txn_id, service
)
def _complete_appservice_txn(self, txn, txn_id, service):
txn_id = int(txn_id)
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
# what was there before. If it isn't, we've got problems (e.g. the AS
# has probably missed some events), so whine loudly but still continue,
# since it shouldn't fail completion of the transaction.
last_txn_id = self._get_last_txn(txn, service.id)
if (last_txn_id + 1) != txn_id:
logger.error(
"appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or "
"sent it to the AS out of order. FIX ME. last_txn=%s "
"completing_txn=%s service_id=%s", last_txn_id, txn_id,
service.id
def _complete_appservice_txn(txn):
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
# what was there before. If it isn't, we've got problems (e.g. the AS
# has probably missed some events), so whine loudly but still continue,
# since it shouldn't fail completion of the transaction.
last_txn_id = self._get_last_txn(txn, service.id)
if (last_txn_id + 1) != txn_id:
logger.error(
"appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or "
"sent it to the AS out of order. FIX ME. last_txn=%s "
"completing_txn=%s service_id=%s", last_txn_id, txn_id,
service.id
)
# Set current txn_id for AS to 'txn_id'
self._simple_upsert_txn(
txn, "application_services_state", dict(as_id=service.id),
dict(last_txn=txn_id)
)
# Set current txn_id for AS to 'txn_id'
self._simple_upsert_txn(
txn, "application_services_state", dict(as_id=service.id),
dict(last_txn=txn_id)
)
# Delete txn
self._simple_delete_txn(
txn, "application_services_txns",
dict(txn_id=txn_id, as_id=service.id)
)
# Delete txn
self._simple_delete_txn(
txn, "application_services_txns",
dict(txn_id=txn_id, as_id=service.id)
return self.runInteraction(
"complete_appservice_txn",
_complete_appservice_txn,
)
@defer.inlineCallbacks
@ -309,10 +307,25 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves to an AppServiceTransaction or
None.
"""
def _get_oldest_unsent_txn(txn):
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
"SELECT * FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,)
)
rows = self.cursor_to_dict(txn)
if not rows:
return None
entry = rows[0]
return entry
entry = yield self.runInteraction(
"get_oldest_unsent_appservice_txn",
self._get_oldest_unsent_txn,
service
_get_oldest_unsent_txn,
)
if not entry:
@ -326,22 +339,6 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
service=service, id=entry["txn_id"], events=events
))
def _get_oldest_unsent_txn(self, txn, service):
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
"SELECT * FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,)
)
rows = self.cursor_to_dict(txn)
if not rows:
return None
entry = rows[0]
return entry
def _get_last_txn(self, txn, service_id):
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
@ -352,3 +349,45 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
return 0
else:
return int(last_txn_id[0]) # select 'last_txn' col
def set_appservice_last_pos(self, pos):
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
return self.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
@defer.inlineCallbacks
def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets"""
def get_new_events_for_appservice_txn(txn):
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
" WHERE"
" (SELECT stream_ordering FROM appservice_stream_position)"
" < e.stream_ordering"
" AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (current_id, limit))
rows = txn.fetchall()
upper_bound = current_id
if len(rows) == limit:
upper_bound = rows[-1][0]
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn,
)
events = yield self._get_events(event_ids)
defer.returnValue((upper_bound, events))

View file

@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore):
Returns:
Deferred
"""
try:
yield self._simple_insert(
def alias_txn(txn):
self._simple_insert_txn(
txn,
"room_aliases",
{
"room_alias": room_alias.to_string(),
"room_id": room_id,
"creator": creator,
},
desc="create_room_alias_association",
)
self._simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[{
"room_alias": room_alias.to_string(),
"server": server,
} for server in servers],
)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (room_id,)
)
try:
ret = yield self.runInteraction(
"create_room_alias_association", alias_txn
)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
for server in servers:
# TODO(erikj): Fix this to bulk insert
yield self._simple_insert(
"room_alias_servers",
{
"room_alias": room_alias.to_string(),
"server": server,
},
desc="create_room_alias_association",
)
self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(ret)
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(

View file

@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore):
)
self._simple_insert_many_txn(txn, "event_push_actions", values)
@cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
@ -337,6 +337,36 @@ class EventPushActionsStore(SQLBaseStore):
# Now return the first `limit`
defer.returnValue(notifs[:limit])
@defer.inlineCallbacks
def get_push_actions_for_user(self, user_id, before=None, limit=50):
def f(txn):
before_clause = ""
if before:
before_clause = "AND stream_ordering < ?"
args = [user_id, before, limit]
else:
args = [user_id, limit]
sql = (
"SELECT epa.event_id, epa.room_id,"
" epa.stream_ordering, epa.topological_ordering,"
" epa.actions, epa.profile_tag, e.received_ts"
" FROM event_push_actions epa, events e"
" WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id"
" AND epa.user_id = ? %s"
" ORDER BY epa.stream_ordering DESC"
" LIMIT ?"
% (before_clause,)
)
txn.execute(sql, args)
return self.cursor_to_dict(txn)
push_actions = yield self.runInteraction(
"get_push_actions_for_user", f
)
for pa in push_actions:
pa["actions"] = json.loads(pa["actions"])
defer.returnValue(push_actions)
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):

View file

@ -20,8 +20,11 @@ from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logcontext import (
preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
)
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
@ -201,7 +204,7 @@ class EventsStore(SQLBaseStore):
deferreds = []
for room_id, evs_ctxs in partitioned.items():
d = self._event_persist_queue.add_to_queue(
d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs,
backfilled=backfilled,
current_state=None,
@ -211,7 +214,9 @@ class EventsStore(SQLBaseStore):
for room_id in partitioned.keys():
self._maybe_start_persisting(room_id)
return defer.gatherResults(deferreds, consumeErrors=True)
return preserve_context_over_deferred(
defer.gatherResults(deferreds, consumeErrors=True)
)
@defer.inlineCallbacks
@log_function
@ -224,7 +229,7 @@ class EventsStore(SQLBaseStore):
self._maybe_start_persisting(event.room_id)
yield deferred
yield preserve_context_over_deferred(deferred)
max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@ -600,7 +605,8 @@ class EventsStore(SQLBaseStore):
"rejections",
"redactions",
"room_memberships",
"state_events"
"state_events",
"topics"
):
txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,),
@ -1086,7 +1092,7 @@ class EventsStore(SQLBaseStore):
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults(
res = yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"],
@ -1095,7 +1101,7 @@ class EventsStore(SQLBaseStore):
for row in rows
],
consumeErrors=True
)
))
defer.returnValue({
e.event.event_id: e
@ -1131,54 +1137,55 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
rejected_reason=None):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
if rejected_reason:
rejected_reason = yield self._simple_select_one_onecol(
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
desc="_get_event_from_row_rejected_reason",
if rejected_reason:
rejected_reason = yield self._simple_select_one_onecol(
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
desc="_get_event_from_row_rejected_reason",
)
original_ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
original_ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
redacted_event = None
if redacted:
redacted_event = prune_event(original_ev)
redacted_event = None
if redacted:
redacted_event = prune_event(original_ev)
redaction_id = yield self._simple_select_one_onecol(
table="redactions",
keyvalues={"redacts": redacted_event.event_id},
retcol="event_id",
desc="_get_event_from_row_redactions",
)
redaction_id = yield self._simple_select_one_onecol(
table="redactions",
keyvalues={"redacts": redacted_event.event_id},
retcol="event_id",
desc="_get_event_from_row_redactions",
redacted_event.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = yield self.get_event(
redaction_id,
check_redacted=False,
allow_none=True,
)
if because:
# It's fine to do add the event directly, since get_pdu_json
# will serialise this field correctly
redacted_event.unsigned["redacted_because"] = because
cache_entry = _EventCacheEntry(
event=original_ev,
redacted_event=redacted_event,
)
redacted_event.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = yield self.get_event(
redaction_id,
check_redacted=False,
allow_none=True,
)
if because:
# It's fine to do add the event directly, since get_pdu_json
# will serialise this field correctly
redacted_event.unsigned["redacted_because"] = because
cache_entry = _EventCacheEntry(
event=original_ev,
redacted_event=redacted_event,
)
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry)

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 33
SCHEMA_VERSION = 34
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore):
desc="add_presence_list_pending",
)
@defer.inlineCallbacks
def set_presence_list_accepted(self, observer_localpart, observed_userid):
result = yield self._simple_update_one(
table="presence_list",
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
updatevalues={"accepted": True},
desc="set_presence_list_accepted",
def update_presence_list_txn(txn):
result = self._simple_update_one_txn(
txn,
table="presence_list",
keyvalues={
"user_id": observer_localpart,
"observed_user_id": observed_userid
},
updatevalues={"accepted": True},
)
self._invalidate_cache_and_stream(
txn, self.get_presence_list_accepted, (observer_localpart,)
)
self._invalidate_cache_and_stream(
txn, self.get_presence_list_observers_accepted, (observed_userid,)
)
return result
return self.runInteraction(
"set_presence_list_accepted", update_presence_list_txn,
)
self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,))
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
if accepted:

View file

@ -16,6 +16,7 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes, Membership
from twisted.internet import defer
import logging
@ -48,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks(lru=True)
@cachedInlineCallbacks()
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
@ -72,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules)
@cachedInlineCallbacks(lru=True)
@cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results)
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
cache_context):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and self.hs.is_mine_id(e.state_key)
)
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, on_invalidate=cache_context.invalidate,
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, on_invalidate=cache_context.invalidate,
)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate,
)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
defer.returnValue(rules_by_user)
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules_enabled(self, user_ids):

View file

@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn
)
@cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
@cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id):
result = yield self._simple_select_many_batch(
table='pushers',

View file

@ -94,6 +94,31 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
@defer.inlineCallbacks
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
sql = (
"SELECT rl.room_id, rl.event_id,"
" e.topological_ordering, e.stream_ordering"
" FROM receipts_linearized AS rl"
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE rl.room_id = e.room_id"
" AND rl.event_id = e.event_id"
" AND user_id = ?"
)
txn.execute(sql, (user_id,))
return txn.fetchall()
rows = yield self.runInteraction(
"get_receipts_for_user_with_orderings", f
)
defer.returnValue({
row[0]: {
"event_id": row[1],
"topological_ordering": row[2],
"stream_ordering": row[3],
} for row in rows
})
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
"""Get receipts for multiple rooms for sending to clients.
@ -120,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res])
@cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
@cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.

View file

@ -93,7 +93,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
desc="add_refresh_token_to_user",
)
@defer.inlineCallbacks
def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False):
@ -115,7 +114,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Raises:
StoreError if the user_id could not be registered.
"""
yield self.runInteraction(
return self.runInteraction(
"register",
self._register,
user_id,
@ -127,8 +126,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
create_profile_with_localpart,
admin
)
self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,))
def _register(
self,
@ -210,6 +207,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
(create_profile_with_localpart,)
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
@ -236,22 +238,31 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
return self.runInteraction("get_users_by_id_case_insensitive", f)
@defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately.
"""
yield self._simple_update_one('users', {
'name': user_id
}, {
'password_hash': password_hash
})
self.get_user_by_id.invalidate((user_id,))
def user_set_password_hash_txn(txn):
self._simple_update_one_txn(
txn,
'users', {
'name': user_id
},
{
'password_hash': password_hash
}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
)
return self.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
@defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[],
def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None,
delete_refresh_tokens=False):
"""
@ -259,7 +270,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Args:
user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should
except_token_id (str): list of access_tokens IDs which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
@ -269,53 +280,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Returns:
defer.Deferred:
"""
def f(txn, table, except_tokens, call_after_delete):
sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id]
def f(txn):
keyvalues = {
"user_id": user_id,
}
if device_id is not None:
sql += " AND device_id = ?"
clauses.append(device_id)
keyvalues["device_id"] = device_id
if except_tokens:
sql += " AND id NOT IN (%s)" % (
",".join(["?" for _ in except_tokens]),
)
clauses += except_tokens
txn.execute(sql, clauses)
rows = txn.fetchall()
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
if call_after_delete:
for row in chunk:
txn.call_after(call_after_delete, (row[0],))
txn.execute(
"DELETE FROM %s WHERE token in (%s)" % (
table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
if delete_refresh_tokens:
self._simple_delete_txn(
txn,
table="refresh_tokens",
keyvalues=keyvalues,
)
# delete refresh tokens first, to stop new access tokens being
# allocated while our backs are turned
if delete_refresh_tokens:
yield self.runInteraction(
"user_delete_access_tokens", f,
table="refresh_tokens",
except_tokens=[],
call_after_delete=None,
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values = [v for _, v in items]
if except_token_id:
where_clause += " AND id != ?"
values.append(except_token_id)
txn.execute(
"SELECT token FROM access_tokens WHERE %s" % where_clause,
values
)
rows = self.cursor_to_dict(txn)
for row in rows:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (row["token"],)
)
txn.execute(
"DELETE FROM access_tokens WHERE %s" % where_clause,
values
)
yield self.runInteraction(
"user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
)
def delete_access_token(self, access_token):
@ -328,7 +331,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
},
)
txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (access_token,)
)
return self.runInteraction("delete_access_token", f)

View file

@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore):
user_id, membership_list=[Membership.JOIN],
)
@defer.inlineCallbacks
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore):
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
yield self.runInteraction("forget_membership", f)
self.was_forgotten_at.invalidate_all()
self.who_forgot_in_room.invalidate_all()
self.did_forget.invalidate((user_id, room_id))
txn.call_after(self.was_forgotten_at.invalidate_all)
txn.call_after(self.did_forget.invalidate, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.who_forgot_in_room, (room_id,)
)
return self.runInteraction("forget_membership", f)
@cachedInlineCallbacks(num_args=2)
def did_forget(self, user_id, room_id):

View file

@ -0,0 +1,23 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS appservice_stream_position(
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_ordering BIGINT,
CHECK (Lock='X')
);
INSERT INTO appservice_stream_position (stream_ordering)
SELECT COALESCE(MAX(stream_ordering), 0) FROM events;

View file

@ -0,0 +1,46 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine
import logging
logger = logging.getLogger(__name__)
# This stream is used to notify replication slaves that some caches have
# been invalidated that they cannot infer from the other streams.
CREATE_TABLE = """
CREATE TABLE cache_invalidation_stream (
stream_id BIGINT,
cache_func TEXT,
keys TEXT[],
invalidation_ts BIGINT
);
CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
"""
def run_create(cur, database_engine, *args, **kwargs):
if not isinstance(database_engine, PostgresEngine):
return
for statement in get_statements(CREATE_TABLE.splitlines()):
cur.execute(statement)
def run_upgrade(cur, database_engine, *args, **kwargs):
pass

View file

@ -0,0 +1,20 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name';
UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name';
UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';

View file

@ -0,0 +1,32 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.engines import PostgresEngine
import logging
logger = logging.getLogger(__name__)
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
cur.execute("TRUNCATE received_transactions")
else:
cur.execute("DELETE FROM received_transactions")
cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)")
def run_upgrade(cur, database_engine, *args, **kwargs):
pass

View file

@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes"""
@cached(lru=True)
@cached()
def get_event_reference_hash(self, event_id):
return self._get_event_reference_hashes_txn(event_id)

View file

@ -174,7 +174,7 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, lru=True, max_entries=1000)
@cached(num_args=2, max_entries=1000)
def _get_state_group_from_group(self, group, types):
raise NotImplementedError()
@ -272,7 +272,7 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@cached(num_args=2, lru=True, max_entries=10000)
@cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",

View file

@ -39,7 +39,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.util.logcontext import preserve_fn
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
@ -234,12 +234,12 @@ class StreamStore(SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
res = yield defer.gatherResults([
res = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(self.get_room_events_stream_for_room)(
room_id, from_key, to_key, limit, order=order,
)
for room_id in rm_ids
])
]))
results.update(dict(zip(rm_ids, res)))
defer.returnValue(results)

View file

@ -62,10 +62,9 @@ class TransactionStore(SQLBaseStore):
self.last_transaction = {}
reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns)
hs.get_clock().looping_call(
self._persist_in_mem_txns,
1000,
)
self._clock.looping_call(self._persist_in_mem_txns, 1000)
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have
@ -127,6 +126,7 @@ class TransactionStore(SQLBaseStore):
"origin": origin,
"response_code": code,
"response_json": buffer(encode_canonical_json(response_dict)),
"ts": self._clock.time_msec(),
},
or_ignore=True,
desc="set_received_txn_response",
@ -383,3 +383,12 @@ class TransactionStore(SQLBaseStore):
yield self.runInteraction("_persist_in_mem_txns", f)
except:
logger.exception("Failed to persist transactions!")
def _cleanup_transactions(self):
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)

View file

@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
# Some arbitrary constants used for internal API enumerations. Don't rely on
# exact values; always pass or compare symbolically
class ThirdPartyEntityKind(object):
USER = 'user'
LOCATION = 'location'

View file

@ -146,10 +146,10 @@ def concurrently_execute(func, args, limit):
except StopIteration:
pass
return defer.gatherResults([
return preserve_context_over_deferred(defer.gatherResults([
preserve_fn(_concurrently_execute_inner)()
for _ in xrange(limit)
], consumeErrors=True).addErrback(unwrapFirstError)
], consumeErrors=True)).addErrback(unwrapFirstError)
class Linearizer(object):
@ -181,7 +181,8 @@ class Linearizer(object):
self.key_to_defer[key] = new_defer
if current_defer:
yield preserve_context_over_deferred(current_defer)
with PreserveLoggingContext():
yield current_defer
@contextmanager
def _ctx_manager():
@ -264,7 +265,7 @@ class ReadWriteLock(object):
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
yield defer.gatherResults(to_wait_on)
yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
@contextmanager
def _ctx_manager():

View file

@ -25,8 +25,7 @@ from synapse.util.logcontext import (
from . import DEBUG_CACHES, register_cache
from twisted.internet import defer
from collections import OrderedDict
from collections import namedtuple
import os
import functools
@ -54,16 +53,11 @@ class Cache(object):
"metrics",
)
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
if lru:
cache_type = TreeCache if tree else dict
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type
)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
def __init__(self, name, max_entries=1000, keylen=1, tree=False):
cache_type = TreeCache if tree else dict
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type
)
self.name = name
self.keylen = keylen
@ -81,8 +75,8 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel, callback=callback)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
@ -94,19 +88,15 @@ class Cache(object):
else:
return default
def update(self, sequence, key, value):
def update(self, sequence, key, value, callback=None):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value)
self.prefill(key, value, callback=callback)
def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[key] = value
def prefill(self, key, value, callback=None):
self.cache.set(key, value, callback=callback)
def invalidate(self, key):
self.check_thread()
@ -151,9 +141,21 @@ class CacheDescriptor(object):
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
Cached functions can be "chained" (i.e. a cached function can call other cached
functions and get appropriately invalidated when they called caches are
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example::
@cachedInlineCallbacks(cache_context=True)
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2)
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
inlineCallbacks=False):
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
inlineCallbacks=False, cache_context=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@ -165,15 +167,33 @@ class CacheDescriptor(object):
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
self.tree = tree
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
if "cache_context" in all_args.args:
if not cache_context:
raise ValueError(
"Cannot have a 'cache_context' arg without setting"
" cache_context=True"
)
try:
self.arg_names.remove("cache_context")
except ValueError:
pass
elif cache_context:
raise ValueError(
"Cannot have cache_context=True without having an arg"
" named `cache_context`"
)
self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
" (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,)
)
@ -182,16 +202,29 @@ class CacheDescriptor(object):
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
tree=self.tree,
)
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
# Add temp cache_context so inspect.getcallargs doesn't explode
if self.add_cache_context:
kwargs["cache_context"] = None
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext(cache, cache_key)
try:
cached_result_d = cache.get(cache_key)
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@ -228,7 +261,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret)
cache.update(sequence, cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
@ -297,6 +330,10 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
@ -311,7 +348,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg
try:
res = cache.get(tuple(key))
res = cache.get(tuple(key), callback=invalidate_callback)
if not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
@ -345,7 +382,10 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
cache.update(sequence, tuple(key), observer)
cache.update(
sequence, tuple(key), observer,
callback=invalidate_callback
)
def invalidate(f, key):
cache.invalidate(key)
@ -376,24 +416,29 @@ class CacheListDescriptor(object):
return wrapped
def cached(max_entries=1000, num_args=1, lru=True, tree=False):
class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
def invalidate(self):
self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
tree=tree,
cache_context=cache_context,
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
tree=tree,
inlineCallbacks=True,
cache_context=cache_context,
)

View file

@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
class _Node(object):
__slots__ = ["prev_node", "next_node", "key", "value"]
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
def __init__(self, prev_node, next_node, key, value):
def __init__(self, prev_node, next_node, key, value, callbacks=set()):
self.prev_node = prev_node
self.next_node = next_node
self.key = key
self.value = value
self.callbacks = callbacks
class LruCache(object):
@ -44,6 +45,9 @@ class LruCache(object):
Least-recently-used cache.
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type()
@ -62,10 +66,10 @@ class LruCache(object):
return inner
def add_node(key, value):
def add_node(key, value, callbacks=set()):
prev_node = list_root
next_node = prev_node.next_node
node = _Node(prev_node, next_node, key, value)
node = _Node(prev_node, next_node, key, value, callbacks)
prev_node.next_node = node
next_node.prev_node = node
cache[key] = node
@ -88,23 +92,41 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
for cb in node.callbacks:
cb()
node.callbacks.clear()
@synchronized
def cache_get(key, default=None):
def cache_get(key, default=None, callback=None):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
if callback:
node.callbacks.add(callback)
return node.value
else:
return default
@synchronized
def cache_set(key, value):
def cache_set(key, value, callback=None):
node = cache.get(key, None)
if node is not None:
if value != node.value:
for cb in node.callbacks:
cb()
node.callbacks.clear()
if callback:
node.callbacks.add(callback)
move_node_to_front(node)
node.value = value
else:
add_node(key, value)
if callback:
callbacks = set([callback])
else:
callbacks = set()
add_node(key, value, callbacks)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
@ -148,6 +170,9 @@ class LruCache(object):
def cache_clear():
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
for cb in node.callbacks:
cb()
cache.clear()
@synchronized

View file

@ -64,6 +64,9 @@ class TreeCache(object):
self.size -= cnt
return popped
def values(self):
return [e.value for e in self.root.values()]
def __len__(self):
return self.size

View file

@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs):
return res
def preserve_context_over_deferred(deferred):
def preserve_context_over_deferred(deferred, context=None):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
"""
current_context = LoggingContext.current_context()
d = _PreservingContextDeferred(current_context)
if context is None:
context = LoggingContext.current_context()
d = _PreservingContextDeferred(context)
deferred.chainDeferred(d)
return d
@ -316,8 +317,13 @@ def preserve_fn(f):
def g(*args, **kwargs):
with PreserveLoggingContext(current):
return f(*args, **kwargs)
res = f(*args, **kwargs)
if isinstance(res, defer.Deferred):
return preserve_context_over_deferred(
res, context=LoggingContext.sentinel
)
else:
return res
return g

View file

@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.util.logcontext import LoggingContext
import synapse.metrics
from functools import wraps
import logging
@ -47,6 +49,18 @@ block_db_txn_duration = metrics.register_distribution(
)
def measure_func(name):
def wrapper(func):
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
with Measure(self.clock, name):
r = yield func(self, *args, **kwargs)
defer.returnValue(r)
return measured_func
return wrapper
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
@ -64,7 +78,6 @@ class Measure(object):
self.start = self.clock.time_msec()
self.start_context = LoggingContext.current_context()
if not self.start_context:
logger.warn("Entered Measure without log context: %s", self.name)
self.start_context = LoggingContext("Measure")
self.start_context.__enter__()
self.created_context = True
@ -74,7 +87,7 @@ class Measure(object):
self.db_txn_duration = self.start_context.db_txn_duration
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None or not self.start_context:
if isinstance(exc_type, Exception) or not self.start_context:
return
duration = self.clock.time_msec() - self.start
@ -85,7 +98,7 @@ class Measure(object):
if context != self.start_context:
logger.warn(
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
context, self.start_context, self.name
self.start_context, context, self.name
)
return

View file

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership, EventTypes
from synapse.util.logcontext import preserve_fn
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging
@ -55,12 +55,12 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
given events
events ([synapse.events.EventBase]): list of events to filter
"""
forgotten = yield defer.gatherResults([
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.who_forgot_in_room)(
room_id,
)
for room_id in frozenset(e.room_id for e in events)
], consumeErrors=True)
], consumeErrors=True))
# Set of membership event_ids that have been forgotten
event_id_forgotten = frozenset(

View file

@ -14,6 +14,8 @@
# limitations under the License.
from synapse.appservice import ApplicationService
from twisted.internet import defer
from mock import Mock
from tests import unittest
@ -42,20 +44,25 @@ class ApplicationServiceTestCase(unittest.TestCase):
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
)
self.store = Mock()
@defer.inlineCallbacks
def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event))
self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@someone_else:matrix.org"
self.assertFalse(self.service.is_interested(self.event))
self.assertFalse((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_member_is_checked(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
@ -63,30 +70,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event))
self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_id_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue(self.service.is_interested(self.event))
self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_id_no_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse(self.service.is_interested(self.event))
self.assertFalse((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
self.assertTrue(self.service.is_interested(
self.event,
aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"]
))
self.store.get_aliases_for_room.return_value = [
"#irc_foobar:matrix.org", "#athing:matrix.org"
]
self.store.get_users_in_room.return_value = []
self.assertTrue((yield self.service.is_interested(
self.event, self.store
)))
def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
@ -136,15 +149,20 @@ class ApplicationServiceTestCase(unittest.TestCase):
"!irc_foobar:matrix.org"
))
@defer.inlineCallbacks
def test_regex_alias_no_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
self.assertFalse(self.service.is_interested(
self.event,
aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
))
self.store.get_aliases_for_room.return_value = [
"#xmpp_foobar:matrix.org", "#athing:matrix.org"
]
self.store.get_users_in_room.return_value = []
self.assertFalse((yield self.service.is_interested(
self.event, self.store
)))
@defer.inlineCallbacks
def test_regex_multiple_matches(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
@ -153,53 +171,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(
self.event,
aliases_for_event=["#irc_barfoo:matrix.org"]
))
def test_restrict_to_rooms(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!flibble_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.event.room_id = "!wibblewoo:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_ROOMS
))
def test_restrict_to_aliases(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#xmpp_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_ALIASES,
aliases_for_event=["#irc_barfoo:matrix.org"]
))
def test_restrict_to_senders(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#xmpp_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_USERS,
aliases_for_event=["#xmpp_barfoo:matrix.org"]
))
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
self.store.get_users_in_room.return_value = []
self.assertTrue((yield self.service.is_interested(
self.event, self.store
)))
@defer.inlineCallbacks
def test_interested_in_self(self):
# make sure invites get through
self.service.sender = "@appservice:name"
@ -211,20 +189,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
"membership": "invite"
}
self.event.state_key = self.service.sender
self.assertTrue(self.service.is_interested(self.event))
self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
join_list = [
self.store.get_users_in_room.return_value = [
"@alice:here",
"@irc_fo:here", # AS user
"@bob:here",
]
self.store.get_aliases_for_room.return_value = []
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(self.service.is_interested(
event=self.event,
member_list=join_list
))
self.assertTrue((yield self.service.is_interested(
event=self.event, store=self.store
)))

View file

@ -193,7 +193,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self):
self.txn_ctrl = Mock()
self.queuer = _ServiceQueuer(self.txn_ctrl)
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
def test_send_single_event_no_queue(self):
# Expect the event to be sent immediately.

View file

@ -15,6 +15,7 @@
from twisted.internet import defer
from .. import unittest
from tests.utils import MockClock
from synapse.handlers.appservice import ApplicationServicesHandler
@ -32,6 +33,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_datastore = Mock(return_value=self.mock_store)
hs.get_application_service_api = Mock(return_value=self.mock_as_api)
hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
hs.get_clock.return_value = MockClock()
self.handler = ApplicationServicesHandler(hs)
@defer.inlineCallbacks
@ -51,8 +53,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
type="m.room.message",
room_id="!foo:bar"
)
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
self.mock_as_api.push = Mock()
yield self.handler.notify_interested_services(event)
yield self.handler.notify_interested_services(0)
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
)
@ -72,7 +75,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event)
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
yield self.handler.notify_interested_services(0)
self.mock_as_api.query_user.assert_called_once_with(
services[0], user_id
)
@ -94,7 +98,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event)
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
yield self.handler.notify_interested_services(0)
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been."
@ -108,11 +113,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
room_id = "!alpha:bet"
servers = ["aperture"]
interested_service = self._mkservice(is_interested=True)
interested_service = self._mkservice_alias(is_interested_in_alias=True)
services = [
self._mkservice(is_interested=False),
self._mkservice_alias(is_interested_in_alias=False),
interested_service,
self._mkservice(is_interested=False)
self._mkservice_alias(is_interested_in_alias=False)
]
self.mock_store.get_app_services = Mock(return_value=services)
@ -135,3 +140,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
def _mkservice_alias(self, is_interested_in_alias):
service = Mock()
service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
service.token = "mock_service_token"
service.url = "mock_service_url"
return service

View file

@ -14,11 +14,13 @@
# limitations under the License.
import pymacaroons
from twisted.internet import defer
import synapse
import synapse.api.errors
from synapse.handlers.auth import AuthHandler
from tests import unittest
from tests.utils import setup_test_homeserver
from twisted.internet import defer
class AuthHandlers(object):
@ -31,11 +33,12 @@ class AuthTestCase(unittest.TestCase):
def setUp(self):
self.hs = yield setup_test_homeserver(handlers=None)
self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler
def test_token_is_a_macaroon(self):
self.hs.config.macaroon_secret_key = "this key is a huge secret"
token = self.hs.handlers.auth_handler.generate_access_token("some_user")
token = self.auth_handler.generate_access_token("some_user")
# Check that we can parse the thing with pymacaroons
macaroon = pymacaroons.Macaroon.deserialize(token)
# The most basic of sanity checks
@ -46,7 +49,7 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.macaroon_secret_key = "this key is a massive secret"
self.hs.clock.now = 5000
token = self.hs.handlers.auth_handler.generate_access_token("a_user")
token = self.auth_handler.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat):
@ -67,3 +70,46 @@ class AuthTestCase(unittest.TestCase):
v.satisfy_general(verify_type)
v.satisfy_general(verify_expiry)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000
token = self.auth_handler.generate_short_term_login_token(
"a_user", 5000
)
self.assertEqual(
"a_user",
self.auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
)
# when we advance the clock, the token should be rejected
self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
def test_short_term_login_token_cannot_replace_user_id(self):
token = self.auth_handler.generate_short_term_login_token(
"a_user", 5000
)
macaroon = pymacaroons.Macaroon.deserialize(token)
self.assertEqual(
"a_user",
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
)
# add another "user_id" caveat, which might allow us to override the
# user_id.
macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)

View file

@ -17,6 +17,8 @@
from tests import unittest
from twisted.internet import defer
from mock import Mock
from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached
@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3)
def test_eviction_lru(self):
cache = Cache("test", max_entries=2, lru=True)
cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
def test_invalidate_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 1)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
@defer.inlineCallbacks
def test_eviction_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
yield a.func2("foo2")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 4)
self.assertEquals(callcount2[0], 3)
@defer.inlineCallbacks
def test_double_get(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
yield a.func2("foo")
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 2)
a.func.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 3)

View file

@ -15,7 +15,9 @@
from . import unittest
from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs
from synapse.rest.media.v1.preview_url_resource import (
summarize_paragraphs, decode_and_calc_og
)
class PreviewTestCase(unittest.TestCase):
@ -137,3 +139,79 @@ class PreviewTestCase(unittest.TestCase):
" of old wooden houses in Northern Norway, the oldest house dating from"
" 1789. The Arctic Cathedral, a modern church…"
)
class PreviewUrlTestCase(unittest.TestCase):
def test_simple(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})
def test_comment(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
<!-- HTML comment -->
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})
def test_comment2(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
Some text.
<!-- HTML comment -->
Some more text.
<p>Text</p>
More text
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text"
})
def test_script(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
<script> (function() {})() </script>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})

View file

@ -19,6 +19,8 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
from mock import Mock
class LruCacheTestCase(unittest.TestCase):
@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.setdefault("key", 2), 1)
self.assertEquals(cache.get("key"), 1)
cache["key"] = 2 # Make sure overriding works.
self.assertEquals(cache.get("key"), 2)
def test_pop(self):
cache = LruCache(1)
@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1
cache.clear()
self.assertEquals(len(cache), 0)
class LruCacheCallbacksTestCase(unittest.TestCase):
def test_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_multi_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_set(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.set("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_pop(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.pop("key")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
cache.pop("key")
self.assertEquals(m.call_count, 1)
def test_del_multi(self):
m1 = Mock()
m2 = Mock()
m3 = Mock()
m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache)
cache.set(("a", "1"), "value", m1)
cache.set(("a", "2"), "value", m2)
cache.set(("b", "1"), "value", m3)
cache.set(("b", "2"), "value", m4)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
cache.del_multi(("a",))
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
def test_clear(self):
m1 = Mock()
m2 = Mock()
cache = LruCache(5)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
cache.clear()
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
def test_eviction(self):
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
cache = LruCache(2)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value", m3)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.get("key2")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key1", "value", m1)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 1)