0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-16 15:01:23 +01:00

Merge remote-tracking branch 'origin/develop' into dbkr/notifications_api

This commit is contained in:
David Baker 2016-08-18 17:15:26 +01:00
commit 602c84cd9c
53 changed files with 1546 additions and 641 deletions

View file

@ -25,5 +25,6 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
tox -e py27

View file

@ -14,6 +14,7 @@ fi
tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin
$TOX_BIN/pip install setuptools
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2

209
synapse/app/appservice.py Normal file
View file

@ -0,0 +1,209 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.appservice")
class AppserviceSlaveStore(
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
SlavedRegistrationStore,
):
pass
class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse appservice now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
appservice_handler = self.get_application_service_handler()
@defer.inlineCallbacks
def replicate(results):
stream = results.get("events")
if stream:
max_stream_id = stream["position"]
yield appservice_handler.notify_interested_services(max_stream_id)
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
replicate(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse appservice", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.appservice"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
if config.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``notify_appservices: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.notify_appservices = True
ps = AppserviceServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ps.replicate()
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-appservice",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -0,0 +1,212 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import (
CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
)
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
BaseSlavedStore,
MediaRepositoryStore,
ClientIpStore,
):
pass
class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "media":
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse media repository now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse media repository", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.media_repository"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-media-repository",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -80,11 +80,6 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__
)
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
@ -168,7 +163,6 @@ class PusherServer(HomeServer):
store = self.get_datastore()
replication_url = self.config.worker_replication_url
pusher_pool = self.get_pusherpool()
clock = self.get_clock()
def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
@ -220,21 +214,11 @@ class PusherServer(HomeServer):
min_stream_id, max_stream_id, affected_room_ids
)
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result)
poke_pushers(result)
except:

View file

@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v1 import events
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@ -74,11 +75,6 @@ class SynchrotronSlavedStore(
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
@ -89,17 +85,23 @@ class SynchrotronSlavedStore(
get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted"
]
get_presence_list_observers_accepted = PresenceStore.__dict__[
"get_presence_list_observers_accepted"
]
UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object):
def __init__(self, hs):
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore()
self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {
@ -119,11 +121,13 @@ class SynchrotronPresence(object):
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
def set_state(self, user, state):
def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work?
pass
get_states = PresenceHandler.get_states.__func__
get_state = PresenceHandler.get_state.__func__
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks
@ -194,19 +198,39 @@ class SynchrotronPresence(object):
self._need_to_send_sync = False
yield self._send_syncing_users_now()
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
parties = yield self._get_interested_parties(
states, calculate_remote_hosts=False
)
room_ids_to_states, users_to_states, _ = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=users_to_states.keys()
)
@defer.inlineCallbacks
def process_replication(self, result):
stream = result.get("presence", {"rows": []})
states = []
for row in stream["rows"]:
(
position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
) = row
self.user_to_current_state[user_id] = UserPresenceState(
state = UserPresenceState(
user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
)
self.user_to_current_state[user_id] = state
states.append(state)
if states and "position" in stream:
stream_id = int(stream["position"])
yield self.notify_from_replication(states, stream_id)
class SynchrotronTyping(object):
@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
root_resource = create_resource_tree(resources, Resource())
@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
clock = self.get_clock()
notifier = self.get_notifier()
presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler()
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
store.get_presence_list_accepted.invalidate_all()
def notify_from_stream(
result, stream_name, stream_key, room=None, user=None
):
@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer):
result, "typing", "typing_key", room="room_id"
)
next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args.update(typing_handler.stream_positions())
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result)
typing_handler.process_replication(result)
presence_handler.process_replication(result)
yield presence_handler.process_replication(result)
notify(result)
except:
logger.exception("Error replicating from %r", replication_url)

View file

@ -14,6 +14,8 @@
# limitations under the License.
from synapse.api.constants import EventTypes
from twisted.internet import defer
import logging
import re
@ -138,65 +140,66 @@ class ApplicationService(object):
return regex_obj["exclusive"]
return False
def _matches_user(self, event, member_list):
if (hasattr(event, "sender") and
self.is_interested_in_user(event.sender)):
return True
@defer.inlineCallbacks
def _matches_user(self, event, store):
if not event:
defer.returnValue(False)
if self.is_interested_in_user(event.sender):
defer.returnValue(True)
# also check m.room.member state key
if (hasattr(event, "type") and event.type == EventTypes.Member
and hasattr(event, "state_key")
and self.is_interested_in_user(event.state_key)):
return True
if (event.type == EventTypes.Member and
self.is_interested_in_user(event.state_key)):
defer.returnValue(True)
if not store:
defer.returnValue(False)
member_list = yield store.get_users_in_room(event.room_id)
# check joined member events
for user_id in member_list:
if self.is_interested_in_user(user_id):
return True
return False
defer.returnValue(True)
defer.returnValue(False)
def _matches_room_id(self, event):
if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id)
return False
def _matches_aliases(self, event, alias_list):
@defer.inlineCallbacks
def _matches_aliases(self, event, store):
if not store or not event:
defer.returnValue(False)
alias_list = yield store.get_aliases_for_room(event.room_id)
for alias in alias_list:
if self.is_interested_in_alias(alias):
return True
return False
defer.returnValue(True)
defer.returnValue(False)
def is_interested(self, event, restrict_to=None, aliases_for_event=None,
member_list=None):
@defer.inlineCallbacks
def is_interested(self, event, store=None):
"""Check if this service is interested in this event.
Args:
event(Event): The event to check.
restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined user_ids in this room.
store(DataStore)
Returns:
bool: True if this service would like to know about this event.
"""
if aliases_for_event is None:
aliases_for_event = []
if member_list is None:
member_list = []
# Do cheap checks first
if self._matches_room_id(event):
defer.returnValue(True)
if restrict_to and restrict_to not in ApplicationService.NS_LIST:
# this is a programming error, so fail early and raise a general
# exception
raise Exception("Unexpected restrict_to value: %s". restrict_to)
if (yield self._matches_aliases(event, store)):
defer.returnValue(True)
if not restrict_to:
return (self._matches_user(event, member_list)
or self._matches_aliases(event, aliases_for_event)
or self._matches_room_id(event))
elif restrict_to == ApplicationService.NS_ALIASES:
return self._matches_aliases(event, aliases_for_event)
elif restrict_to == ApplicationService.NS_ROOMS:
return self._matches_room_id(event)
elif restrict_to == ApplicationService.NS_USERS:
return self._matches_user(event, member_list)
if (yield self._matches_user(event, store)):
defer.returnValue(True)
defer.returnValue(False)
def is_interested_in_user(self, user_id):
return (

View file

@ -48,9 +48,12 @@ UP & quit +---------- YES SUCCESS
This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
from twisted.internet import defer
from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import Measure
import logging
logger = logging.getLogger(__name__)
@ -73,7 +76,7 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController(
self.clock, self.store, self.as_api, create_recoverer
)
self.queuer = _ServiceQueuer(self.txn_ctrl)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
@defer.inlineCallbacks
def start(self):
@ -94,38 +97,36 @@ class _ServiceQueuer(object):
this schedules any other events in the queue to run.
"""
def __init__(self, txn_ctrl):
def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]}
self.pending_requests = {} # dict of {service_id: Deferred}
self.requests_in_flight = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
def enqueue(self, service, event):
# if this service isn't being sent something
if not self.pending_requests.get(service.id):
self._send_request(service, [event])
else:
# add to queue for this service
if service.id not in self.queued_events:
self.queued_events[service.id] = []
self.queued_events[service.id].append(event)
self.queued_events.setdefault(service.id, []).append(event)
preserve_fn(self._send_request)(service)
def _send_request(self, service, events):
# send request and add callbacks
d = self.txn_ctrl.send(service, events)
d.addBoth(self._on_request_finish)
d.addErrback(self._on_request_fail)
self.pending_requests[service.id] = d
@defer.inlineCallbacks
def _send_request(self, service):
if service.id in self.requests_in_flight:
return
def _on_request_finish(self, service):
self.pending_requests[service.id] = None
# if there are queued events, then send them.
if (service.id in self.queued_events
and len(self.queued_events[service.id]) > 0):
self._send_request(service, self.queued_events[service.id])
self.queued_events[service.id] = []
self.requests_in_flight.add(service.id)
try:
while True:
events = self.queued_events.pop(service.id, [])
if not events:
return
def _on_request_fail(self, err):
logger.error("AS request failed: %s", err)
with Measure(self.clock, "servicequeuer.send"):
try:
yield self.txn_ctrl.send(service, events)
except:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
class _TransactionController(object):
@ -155,8 +156,6 @@ class _TransactionController(object):
except Exception as e:
logger.exception(e)
self._start_recoverer(service)
# request has finished
defer.returnValue(service)
@defer.inlineCallbacks
def on_recovered(self, recoverer):

View file

@ -28,6 +28,7 @@ class AppServiceConfig(Config):
def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
def default_config(cls, **kwargs):
return """\

View file

@ -19,7 +19,6 @@ from .room import (
)
from .room_member import RoomMemberHandler
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler
@ -53,8 +52,6 @@ class Handlers(object):
self.message_handler = MessageHandler(hs)
self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs)
self.event_stream_handler = EventStreamHandler(hs)
self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs)

View file

@ -16,7 +16,8 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService
from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn
import logging
@ -42,25 +43,53 @@ class ApplicationServicesHandler(object):
self.appservice_api = hs.get_application_service_api()
self.scheduler = hs.get_application_service_scheduler()
self.started_scheduler = False
self.clock = hs.get_clock()
self.notify_appservices = hs.config.notify_appservices
self.current_max = 0
self.is_processing = False
@defer.inlineCallbacks
def notify_interested_services(self, event):
def notify_interested_services(self, current_id):
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
prolonged length of time.
Args:
event(Event): The event to push out to interested services.
current_id(int): The current maximum ID.
"""
services = yield self.store.get_app_services()
if not services or not self.notify_appservices:
return
self.current_max = max(self.current_max, current_id)
if self.is_processing:
return
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
upper_bound = self.current_max
limit = 100
while True:
upper_bound, events = yield self.store.get_new_events_for_appservice(
upper_bound, limit
)
if not events:
break
for event in events:
# Gather interested services
services = yield self._get_services_for_event(event)
if len(services) == 0:
return # no services need notifying
continue # no services need notifying
# Do we know this user exists? If not, poke the user query API for
# all services which match that user regex. This needs to block as these
# user queries need to be made BEFORE pushing the event.
# Do we know this user exists? If not, poke the user
# query API for all services which match that user regex.
# This needs to block as these user queries need to be
# made BEFORE pushing the event.
yield self._check_user_exists(event.sender)
if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key)
@ -71,7 +100,16 @@ class ApplicationServicesHandler(object):
# Fork off pushes to these services
for service in services:
self.scheduler.submit_event_for_as(service, event)
preserve_fn(self.scheduler.submit_event_for_as)(
service, event
)
yield self.store.set_appservice_last_pos(upper_bound)
if len(events) < limit:
break
finally:
self.is_processing = False
@defer.inlineCallbacks
def query_user_exists(self, user_id):
@ -104,11 +142,12 @@ class ApplicationServicesHandler(object):
association can be found.
"""
room_alias_str = room_alias.to_string()
alias_query_services = yield self._get_services_for_event(
event=None,
restrict_to=ApplicationService.NS_ALIASES,
alias_list=[room_alias_str]
services = yield self.store.get_app_services()
alias_query_services = [
s for s in services if (
s.is_interested_in_alias(room_alias_str)
)
]
for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias(
alias_service, room_alias_str
@ -121,34 +160,19 @@ class ApplicationServicesHandler(object):
defer.returnValue(result)
@defer.inlineCallbacks
def _get_services_for_event(self, event, restrict_to="", alias_list=None):
def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event.
Args:
event(Event): The event to check. Can be None if alias_list is not.
restrict_to(str): The namespace to restrict regex tests to.
alias_list: A list of aliases to get services for. If None, this
list is obtained from the database.
Returns:
list<ApplicationService>: A list of services interested in this
event based on the service regex.
"""
member_list = None
if hasattr(event, "room_id"):
# We need to know the aliases associated with this event.room_id,
# if any.
if not alias_list:
alias_list = yield self.store.get_aliases_for_room(
event.room_id
)
# We need to know the members associated with this event.room_id,
# if any.
member_list = yield self.store.get_users_in_room(event.room_id)
services = yield self.store.get_app_services()
interested_list = [
s for s in services if (
s.is_interested(event, restrict_to, alias_list, member_list)
yield s.is_interested(event, self.store)
)
]
defer.returnValue(interested_list)

View file

@ -70,11 +70,11 @@ class AuthHandler(BaseHandler):
self.ldap_uri = hs.config.ldap_uri
self.ldap_start_tls = hs.config.ldap_start_tls
self.ldap_base = hs.config.ldap_base
self.ldap_filter = hs.config.ldap_filter
self.ldap_attributes = hs.config.ldap_attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password
self.ldap_filter = hs.config.ldap_filter
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@ -660,7 +660,7 @@ class AuthHandler(BaseHandler):
else:
logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results",
len(result)
len(conn.response)
)
defer.returnValue(False)
@ -741,7 +741,7 @@ class AuthHandler(BaseHandler):
def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else []
except_access_token_id = requester.access_token_id if requester else None
try:
yield self.store.user_set_password_hash(user_id, password_hash)
@ -750,10 +750,10 @@ class AuthHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.store.user_delete_access_tokens(
user_id, except_access_token_ids
user_id, except_access_token_id
)
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_ids
user_id, except_access_token_id
)
@defer.inlineCallbacks

View file

@ -274,7 +274,7 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]):
def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. This may return
@ -284,9 +284,6 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
events = yield self.replication_layer.backfill(
dest,
room_id,
@ -455,6 +452,10 @@ class FederationHandler(BaseHandler):
)
max_depth = sorted_extremeties_tuple[0][1]
# We don't want to specify too many extremities as it causes the backfill
# request URI to be too long.
extremities = dict(sorted_extremeties_tuple[:5])
if current_depth > max_depth:
logger.debug(
"Not backfilling as we don't need to. %d < %d",

View file

@ -503,7 +503,7 @@ class PresenceHandler(object):
defer.returnValue(states)
@defer.inlineCallbacks
def _get_interested_parties(self, states):
def _get_interested_parties(self, states, calculate_remote_hosts=True):
"""Given a list of states return which entities (rooms, users, servers)
are interested in the given states.
@ -526,6 +526,7 @@ class PresenceHandler(object):
users_to_states.setdefault(state.user_id, []).append(state)
hosts_to_states = {}
if calculate_remote_hosts:
for room_id, states in room_ids_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
if not local_states:
@ -565,6 +566,16 @@ class PresenceHandler(object):
self._push_to_remotes(hosts_to_states)
@defer.inlineCallbacks
def notify_for_states(self, state, stream_id):
parties = yield self._get_interested_parties([state])
room_ids_to_states, users_to_states, hosts_to_states = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=[UserID.from_string(u) for u in users_to_states.keys()]
)
def _push_to_remotes(self, hosts_to_states):
"""Sends state updates to remote servers.
@ -672,7 +683,7 @@ class PresenceHandler(object):
])
@defer.inlineCallbacks
def set_state(self, target_user, state):
def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
status_msg = state.get("status_msg", None)
@ -689,10 +700,13 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id)
new_fields = {
"state": presence,
"status_msg": status_msg if presence != PresenceState.OFFLINE else None
"state": presence
}
if not ignore_status_msg:
msg = status_msg if presence != PresenceState.OFFLINE else None
new_fields["status_msg"] = msg
if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec()

View file

@ -141,7 +141,7 @@ class RoomMemberHandler(BaseHandler):
third_party_signed=None,
ratelimit=True,
):
key = (target, room_id,)
key = (room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(

View file

@ -67,10 +67,8 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class.
"""
def __init__(self, user_id, rooms, current_token, time_now_ms,
appservice=None):
def __init__(self, user_id, rooms, current_token, time_now_ms):
self.user_id = user_id
self.appservice = appservice
self.rooms = set(rooms)
self.current_token = current_token
self.last_notified_ms = time_now_ms
@ -107,11 +105,6 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id)
if self.appservice:
notifier.appservice_to_user_streams.get(
self.appservice, set()
).discard(self)
def count_listeners(self):
return len(self.notify_deferred.observers())
@ -142,7 +135,6 @@ class Notifier(object):
def __init__(self, hs):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
self.appservice_to_user_streams = {}
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
@ -168,8 +160,6 @@ class Notifier(object):
all_user_streams |= x
for x in self.user_to_user_stream.values():
all_user_streams.add(x)
for x in self.appservice_to_user_streams.values():
all_user_streams |= x
return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners)
@ -182,10 +172,6 @@ class Notifier(object):
"users",
lambda: len(self.user_to_user_stream),
)
metrics.register_callback(
"appservices",
lambda: count(bool, self.appservice_to_user_streams.values()),
)
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
@ -228,21 +214,7 @@ class Notifier(object):
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
self.appservice_handler.notify_interested_services(event)
app_streams = set()
for appservice in self.appservice_to_user_streams:
# TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_user_streams set, but
# that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this
# make the room_to_user_streams check somewhat obselete?
if appservice.is_interested(event):
app_user_streams = self.appservice_to_user_streams.get(
appservice, set()
)
app_streams |= app_user_streams
self.appservice_handler.notify_interested_services(room_stream_id)
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id)
@ -251,11 +223,9 @@ class Notifier(object):
"room_key", room_stream_id,
users=extra_users,
rooms=[event.room_id],
extra_streams=app_streams,
)
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
extra_streams=set()):
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms.
@ -294,7 +264,6 @@ class Notifier(object):
"""
user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user_id)
current_token = yield self.event_sources.get_current_token()
if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user_id)
@ -302,7 +271,6 @@ class Notifier(object):
user_stream = _NotifierUserStream(
user_id=user_id,
rooms=room_ids,
appservice=appservice,
current_token=current_token,
time_now_ms=self.clock.time_msec(),
)
@ -477,11 +445,6 @@ class Notifier(object):
s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)
if user_stream.appservice:
self.appservice_to_user_stream.setdefault(
user_stream.appservice, set()
).add(user_stream)
def _user_joined_room(self, user_id, room_id):
new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None:

View file

@ -38,11 +38,12 @@ class ActionGenerator:
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "handle_push_actions_for_event"):
with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store, context.current_state
)
with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, context.current_state
)

View file

@ -217,6 +217,27 @@ BASE_APPEND_OVERRIDE_RULES = [
'dont_notify'
]
},
# This was changed from underride to override so it's closer in priority
# to the content rules where the user name highlight rule lives. This
# way a room rule is lower priority than both but a custom override rule
# is higher priority than both.
{
'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
]
@ -242,23 +263,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
}
]
},
{
'rule_id': 'global/underride/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
{
'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [

View file

@ -102,14 +102,14 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user(self, user_id, except_token_ids=[]):
def remove_pushers_by_user(self, user_id, except_access_token_id=None):
all = yield self.store.get_all_pushers()
logger.info(
"Removing all pushers for user %s except access tokens ids %r",
user_id, except_token_ids
"Removing all pushers for user %s except access tokens id %r",
user_id, except_access_token_id
)
for p in all:
if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']

View file

@ -41,6 +41,7 @@ STREAM_NAMES = (
("push_rules",),
("pushers",),
("state",),
("caches",),
)
@ -70,6 +71,7 @@ class ReplicationResource(Resource):
* "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules.
* "pushers": Per user changes to their pushers.
* "caches": Cache invalidations.
The API takes two additional query parameters:
@ -129,6 +131,7 @@ class ReplicationResource(Resource):
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
caches_token = self.store.get_cache_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
@ -140,6 +143,7 @@ class ReplicationResource(Resource):
push_rules_token,
pushers_token,
state_token,
caches_token,
))
@request_handler()
@ -188,6 +192,7 @@ class ReplicationResource(Resource):
yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total)
@ -379,6 +384,20 @@ class ReplicationResource(Resource):
"position", "type", "state_key", "event_id"
))
@defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams):
current_position = current_token.caches
caches = request_streams.get("caches")
if caches is not None:
updated_caches = yield self.store.get_all_updated_caches(
caches, current_position, limit
)
writer.write_header_and_rows("caches", updated_caches, (
"position", "cache_func", "keys", "invalidation_ts"
))
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
@ -407,7 +426,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state"
"push_rules", "pushers", "state", "caches",
))):
__slots__ = []

View file

@ -14,15 +14,43 @@
# limitations under the License.
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
from ._slaved_id_tracker import SlavedIdTracker
import logging
logger = logging.getLogger(__name__)
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id",
)
else:
self._cache_id_gen = None
def stream_positions(self):
return {}
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def process_replication(self, result):
stream = result.get("caches")
if stream:
for row in stream["rows"]:
(
position, cache_func, keys, invalidation_ts,
) = row
try:
getattr(self, cache_func).invalidate(tuple(keys))
except AttributeError:
logger.info("Got unexpected cache_func: %r", cache_func)
self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None)

View file

@ -28,3 +28,13 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
get_app_services = DataStore.get_app_services.__func__
get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
create_appservice_txn = DataStore.create_appservice_txn.__func__
get_appservices_by_state = DataStore.get_appservices_by_state.__func__
get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
_get_last_txn = DataStore._get_last_txn.__func__
complete_appservice_txn = DataStore.complete_appservice_txn.__func__
get_appservice_state = DataStore.get_appservice_state.__func__
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
set_appservice_state = DataStore.set_appservice_state.__func__

View file

@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore
class DirectoryStore(BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[
"get_aliases_for_room"
].orig
]

View file

@ -25,6 +25,9 @@ class SlavedRegistrationStore(BaseSlavedStore):
# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
].orig
]
_query_for_auth = DataStore._query_for_auth.__func__
get_user_by_id = RegistrationStore.__dict__[
"get_user_by_id"
]

View file

@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(WhoisRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
)
def __init__(self, hs):
super(PurgeHistoryRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request)

View file

@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet):
hs (synapse.server.HomeServer):
"""
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore()

View file

@ -36,6 +36,10 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(ClientV1RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
def __init__(self, hs):
super(ClientDirectoryServer, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs)
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):

View file

@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
super(EventStreamRestServlet, self).__init__(hs)
self.event_stream_handler = hs.get_event_stream_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(
@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
if "room_id" in request.args:
room_id = request.args["room_id"][0]
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args
chunk = yield handler.get_stream(
chunk = yield self.event_stream_handler.get_stream(
requester.user.to_string(),
pagin_config,
timeout=timeout,
@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks
def on_GET(self, request, event_id):
requester = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler
event = yield handler.get_event(requester.user, event_id)
event = yield self.event_handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec()
if event:

View file

@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns
class InitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/initialSync$")
def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)

View file

@ -56,6 +56,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.cas_enabled = hs.config.cas_enabled
self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request):
flows = []
@ -260,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs)
self.sp_config = hs.config.saml2_config_path
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):
@ -329,6 +331,7 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request):

View file

@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
class ProfileRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)

View file

@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
self.sessions = {}
self.enable_registration = hs.config.enable_registration
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):

View file

@ -35,6 +35,10 @@ logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet):
# No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@ -296,6 +311,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
@ -322,6 +341,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -351,6 +374,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
class RoomStateRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -368,6 +395,10 @@ class RoomStateRestServlet(ClientV1RestServlet):
class RoomInitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -388,6 +419,7 @@ class RoomEventContext(ClientV1RestServlet):
def __init__(self, hs):
super(RoomEventContext, self).__init__(hs)
self.clock = hs.get_clock()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@ -424,6 +456,10 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
register_txn_path(self, PATTERNS, http_server)
@ -462,6 +498,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
@ -542,6 +582,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERNS, http_server)
@ -624,6 +668,10 @@ class SearchRestServlet(ClientV1RestServlet):
"/search$"
)
def __init__(self, hs):
super(SearchRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)

View file

@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet):
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
yield self.presence_handler.set_state(user, {"presence": set_presence})
yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
context = yield self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence,

View file

@ -45,6 +45,7 @@ class DownloadResource(Resource):
@request_handler()
@defer.inlineCallbacks
def _async_render_GET(self, request):
request.setHeader("Content-Security-Policy", "sandbox")
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
yield self._respond_local_file(request, media_id, name)

View file

@ -29,14 +29,13 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
from copy import deepcopy
import os
import re
import fnmatch
import cgi
import ujson as json
import urlparse
import itertools
import logging
logger = logging.getLogger(__name__)
@ -163,7 +162,7 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info)
if self._is_media(media_info['media_type']):
if _is_media(media_info['media_type']):
dims = yield self.media_repo._generate_local_thumbnails(
media_info['filesystem_id'], media_info
)
@ -184,11 +183,9 @@ class PreviewUrlResource(Resource):
logger.warn("Couldn't get dims for %s" % url)
# define our OG response for this media
elif self._is_html(media_info['media_type']):
elif _is_html(media_info['media_type']):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
from lxml import etree
file = open(media_info['filename'])
body = file.read()
file.close()
@ -199,17 +196,35 @@ class PreviewUrlResource(Resource):
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
encoding = match.group(1) if match else "utf-8"
try:
parser = etree.HTMLParser(recover=True, encoding=encoding)
tree = etree.fromstring(body, parser)
og = yield self._calc_og(tree, media_info, requester)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=encoding)
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
og = yield self._calc_og(tree, media_info, requester)
og = decode_and_calc_og(body, media_info['uri'], encoding)
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
_rebase_url(og['og:image'], media_info['uri']), requester.user
)
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info
)
if dims:
og["og:image:width"] = dims['width']
og["og:image:height"] = dims['height']
else:
logger.warn("Couldn't get dims for %s" % og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name, image_info['filesystem_id']
)
og["og:image:type"] = image_info['media_type']
og["matrix:image:size"] = image_info['media_length']
else:
del og["og:image"]
else:
logger.warn("Failed to find any OG data in %s", url)
og = {}
@ -232,139 +247,6 @@ class PreviewUrlResource(Resource):
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
@defer.inlineCallbacks
def _calc_og(self, tree, media_info, requester):
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if 'content' in tag.attrib:
og[tag.attrib['property']] = tag.attrib['content']
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if 'og:title' not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
og['og:title'] = title[0].text.strip() if title else None
if 'og:image' not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(images, key=lambda i: (
-1 * float(i.attrib['width']) * float(i.attrib['height'])
))
if not images:
images = tree.xpath("//img[@src]")
if images:
og['og:image'] = images[0].attrib['src']
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
self._rebase_url(og['og:image'], media_info['uri']), requester.user
)
if self._is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info
)
if dims:
og["og:image:width"] = dims['width']
og["og:image:height"] = dims['height']
else:
logger.warn("Couldn't get dims for %s" % og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name, image_info['filesystem_id']
)
og["og:image:type"] = image_info['media_type']
og["matrix:image:size"] = image_info['media_length']
else:
del og["og:image"]
if 'og:description' not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content")
if meta_description:
og['og:description'] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
# <header/>, <nav/>, <aside/>, <footer/>
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
# We don't just use XPATH here as that is slow on some machines.
# We clone `tree` as we modify it.
cloned_tree = deepcopy(tree.find("body"))
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
for el in cloned_tree.iter(TAGS_TO_REMOVE):
el.getparent().remove(el)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r'\s+', '\n', el.text).strip()
for el in cloned_tree.iter()
if el.text and isinstance(el.tag, basestring) # Removes comments
)
og['og:description'] = summarize_paragraphs(text_nodes)
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
defer.returnValue(og)
def _rebase_url(self, url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
return urlparse.urlunparse(url)
@defer.inlineCallbacks
def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
@ -445,11 +327,165 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None,
})
def _is_media(self, content_type):
def decode_and_calc_og(body, media_uri, request_encoding=None):
from lxml import etree
try:
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
tree = etree.fromstring(body, parser)
og = _calc_og(tree, media_uri)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
og = _calc_og(tree, media_uri)
return og
def _calc_og(tree, media_uri):
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if 'content' in tag.attrib:
og[tag.attrib['property']] = tag.attrib['content']
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if 'og:title' not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
og['og:title'] = title[0].text.strip() if title else None
if 'og:image' not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og['og:image'] = _rebase_url(meta_image[0], media_uri)
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(images, key=lambda i: (
-1 * float(i.attrib['width']) * float(i.attrib['height'])
))
if not images:
images = tree.xpath("//img[@src]")
if images:
og['og:image'] = images[0].attrib['src']
if 'og:description' not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content")
if meta_description:
og['og:description'] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
# <header/>, <nav/>, <aside/>, <footer/>
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
# We don't just use XPATH here as that is slow on some machines.
from lxml import etree
TAGS_TO_REMOVE = (
"header", "nav", "aside", "footer", "script", "style", etree.Comment
)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r'\s+', '\n', el).strip()
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og['og:description'] = summarize_paragraphs(text_nodes)
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
return og
def _iterate_over_text(tree, *tags_to_ignore):
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
# This is basically a stack that we extend using itertools.chain.
# This will either consist of an element to iterate over *or* a string
# to be returned.
elements = iter([tree])
while True:
el = elements.next()
if isinstance(el, basestring):
yield el
elif el is not None and el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
# return it if the text exists.
if el.text:
yield el.text
# We add to the stack all the elements children, interspersed with
# each child's tail text (if it exists). The tail text of a node
# is text that comes *after* the node, so we always include it even
# if we ignore the child node.
elements = itertools.chain(
itertools.chain.from_iterable( # Basically a flatmap
[child, child.tail] if child.tail else [child]
for child in el.iterchildren()
),
elements
)
def _rebase_url(url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
return urlparse.urlunparse(url)
def _is_media(content_type):
if content_type.lower().startswith("image/"):
return True
def _is_html(self, content_type):
def _is_html(content_type):
content_type = content_type.lower()
if (
content_type.startswith("text/html") or

View file

@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room import RoomListHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
@ -94,6 +95,8 @@ class HomeServer(object):
'auth_handler',
'device_handler',
'e2e_keys_handler',
'event_handler',
'event_stream_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
@ -214,6 +217,12 @@ class HomeServer(object):
def build_application_service_handler(self):
return ApplicationServicesHandler(self)
def build_event_handler(self):
return EventHandler(self)
def build_event_stream_handler(self):
return EventStreamHandler(self)
def build_event_sources(self):
return EventSources(self)

View file

@ -50,6 +50,7 @@ from .openid import OpenIdStore
from .client_ips import ClientIpStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from .engines import PostgresEngine
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("deleted_pushers", "stream_id")],
)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
db_conn, "cache_invalidation_stream", "stream_id",
)
else:
self._cache_id_gen = None
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",

View file

@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
from synapse.util.caches import intern_dict
from synapse.storage.engines import PostgresEngine
import synapse.metrics
@ -305,11 +306,12 @@ class SQLBaseStore(object):
func, *args, **kwargs
)
try:
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
finally:
for after_callback, after_args in after_callbacks:
after_callback(*after_args)
defer.returnValue(result)
@ -860,6 +862,62 @@ class SQLBaseStore(object):
return cache, min_val
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
be invalidated.
"""
txn.call_after(cache_func.invalidate, keys)
if isinstance(self.database_engine, PostgresEngine):
# get_next() returns a context manager which is designed to wrap
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn(
txn,
table="cache_invalidation_stream",
values={
"stream_id": stream_id,
"cache_func": cache_func.__name__,
"keys": list(keys),
"invalidation_ts": self.clock.time_msec(),
}
)
def get_all_updated_caches(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit,))
return txn.fetchall()
return self.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying

View file

@ -218,13 +218,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
Returns:
AppServiceTransaction: A new transaction.
"""
return self.runInteraction(
"create_appservice_txn",
self._create_appservice_txn,
service, events
)
def _create_appservice_txn(self, txn, service, events):
def _create_appservice_txn(txn):
# work out new txn id (highest txn id for this service += 1)
# The highest id may be the last one sent (in which case it is last_txn)
# or it may be the highest in the txns list (which are waiting to be/are
@ -252,6 +246,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
service=service, id=new_txn_id, events=events
)
return self.runInteraction(
"create_appservice_txn",
_create_appservice_txn,
)
def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction.
@ -263,15 +262,9 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves if this transaction was stored
successfully.
"""
return self.runInteraction(
"complete_appservice_txn",
self._complete_appservice_txn,
txn_id, service
)
def _complete_appservice_txn(self, txn, txn_id, service):
txn_id = int(txn_id)
def _complete_appservice_txn(txn):
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
# what was there before. If it isn't, we've got problems (e.g. the AS
# has probably missed some events), so whine loudly but still continue,
@ -298,6 +291,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
dict(txn_id=txn_id, as_id=service.id)
)
return self.runInteraction(
"complete_appservice_txn",
_complete_appservice_txn,
)
@defer.inlineCallbacks
def get_oldest_unsent_txn(self, service):
"""Get the oldest transaction which has not been sent for this
@ -309,24 +307,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves to an AppServiceTransaction or
None.
"""
entry = yield self.runInteraction(
"get_oldest_unsent_appservice_txn",
self._get_oldest_unsent_txn,
service
)
if not entry:
defer.returnValue(None)
event_ids = json.loads(entry["event_ids"])
events = yield self._get_events(event_ids)
defer.returnValue(AppServiceTransaction(
service=service, id=entry["txn_id"], events=events
))
def _get_oldest_unsent_txn(self, txn, service):
def _get_oldest_unsent_txn(txn):
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
@ -342,6 +323,22 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
return entry
entry = yield self.runInteraction(
"get_oldest_unsent_appservice_txn",
_get_oldest_unsent_txn,
)
if not entry:
defer.returnValue(None)
event_ids = json.loads(entry["event_ids"])
events = yield self._get_events(event_ids)
defer.returnValue(AppServiceTransaction(
service=service, id=entry["txn_id"], events=events
))
def _get_last_txn(self, txn, service_id):
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
@ -352,3 +349,42 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
return 0
else:
return int(last_txn_id[0]) # select 'last_txn' col
def set_appservice_last_pos(self, pos):
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
return self.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
@defer.inlineCallbacks
def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets"""
def get_new_events_for_appservice_txn(txn):
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e, appservice_stream_position AS a"
" WHERE a.stream_ordering < e.stream_ordering AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (current_id, limit))
rows = txn.fetchall()
upper_bound = current_id
if len(rows) == limit:
upper_bound = rows[-1][0]
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn,
)
events = yield self._get_events(event_ids)
defer.returnValue((upper_bound, events))

View file

@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore):
Returns:
Deferred
"""
try:
yield self._simple_insert(
def alias_txn(txn):
self._simple_insert_txn(
txn,
"room_aliases",
{
"room_alias": room_alias.to_string(),
"room_id": room_id,
"creator": creator,
},
desc="create_room_alias_association",
)
self._simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[{
"room_alias": room_alias.to_string(),
"server": server,
} for server in servers],
)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (room_id,)
)
try:
ret = yield self.runInteraction(
"create_room_alias_association", alias_txn
)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
for server in servers:
# TODO(erikj): Fix this to bulk insert
yield self._simple_insert(
"room_alias_servers",
{
"room_alias": room_alias.to_string(),
"server": server,
},
desc="create_room_alias_association",
)
self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(ret)
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(

View file

@ -600,7 +600,8 @@ class EventsStore(SQLBaseStore):
"rejections",
"redactions",
"room_memberships",
"state_events"
"state_events",
"topics"
):
txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,),

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 33
SCHEMA_VERSION = 34
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore):
desc="add_presence_list_pending",
)
@defer.inlineCallbacks
def set_presence_list_accepted(self, observer_localpart, observed_userid):
result = yield self._simple_update_one(
def update_presence_list_txn(txn):
result = self._simple_update_one_txn(
txn,
table="presence_list",
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
keyvalues={
"user_id": observer_localpart,
"observed_user_id": observed_userid
},
updatevalues={"accepted": True},
desc="set_presence_list_accepted",
)
self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,))
defer.returnValue(result)
self._invalidate_cache_and_stream(
txn, self.get_presence_list_accepted, (observer_localpart,)
)
self._invalidate_cache_and_stream(
txn, self.get_presence_list_observers_accepted, (observed_userid,)
)
return result
return self.runInteraction(
"set_presence_list_accepted", update_presence_list_txn,
)
def get_presence_list(self, observer_localpart, accepted=None):
if accepted:

View file

@ -93,7 +93,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
desc="add_refresh_token_to_user",
)
@defer.inlineCallbacks
def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False):
@ -115,7 +114,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Raises:
StoreError if the user_id could not be registered.
"""
yield self.runInteraction(
return self.runInteraction(
"register",
self._register,
user_id,
@ -127,8 +126,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
create_profile_with_localpart,
admin
)
self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,))
def _register(
self,
@ -210,6 +207,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
(create_profile_with_localpart,)
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
@ -236,22 +238,31 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
return self.runInteraction("get_users_by_id_case_insensitive", f)
@defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately.
"""
yield self._simple_update_one('users', {
def user_set_password_hash_txn(txn):
self._simple_update_one_txn(
txn,
'users', {
'name': user_id
}, {
},
{
'password_hash': password_hash
})
self.get_user_by_id.invalidate((user_id,))
}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
)
return self.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
@defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[],
def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None,
delete_refresh_tokens=False):
"""
@ -259,7 +270,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Args:
user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should
except_token_id (str): list of access_tokens IDs which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
@ -269,53 +280,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Returns:
defer.Deferred:
"""
def f(txn, table, except_tokens, call_after_delete):
sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id]
def f(txn):
keyvalues = {
"user_id": user_id,
}
if device_id is not None:
sql += " AND device_id = ?"
clauses.append(device_id)
keyvalues["device_id"] = device_id
if except_tokens:
sql += " AND id NOT IN (%s)" % (
",".join(["?" for _ in except_tokens]),
if delete_refresh_tokens:
self._simple_delete_txn(
txn,
table="refresh_tokens",
keyvalues=keyvalues,
)
clauses += except_tokens
txn.execute(sql, clauses)
rows = txn.fetchall()
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
if call_after_delete:
for row in chunk:
txn.call_after(call_after_delete, (row[0],))
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values = [v for _, v in items]
if except_token_id:
where_clause += " AND id != ?"
values.append(except_token_id)
txn.execute(
"DELETE FROM %s WHERE token in (%s)" % (
table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
"SELECT token FROM access_tokens WHERE %s" % where_clause,
values
)
rows = self.cursor_to_dict(txn)
for row in rows:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (row["token"],)
)
# delete refresh tokens first, to stop new access tokens being
# allocated while our backs are turned
if delete_refresh_tokens:
yield self.runInteraction(
"user_delete_access_tokens", f,
table="refresh_tokens",
except_tokens=[],
call_after_delete=None,
txn.execute(
"DELETE FROM access_tokens WHERE %s" % where_clause,
values
)
yield self.runInteraction(
"user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
)
def delete_access_token(self, access_token):
@ -328,7 +331,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
},
)
txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (access_token,)
)
return self.runInteraction("delete_access_token", f)

View file

@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore):
user_id, membership_list=[Membership.JOIN],
)
@defer.inlineCallbacks
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore):
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
yield self.runInteraction("forget_membership", f)
self.was_forgotten_at.invalidate_all()
self.who_forgot_in_room.invalidate_all()
self.did_forget.invalidate((user_id, room_id))
txn.call_after(self.was_forgotten_at.invalidate_all)
txn.call_after(self.did_forget.invalidate, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.who_forgot_in_room, (room_id,)
)
return self.runInteraction("forget_membership", f)
@cachedInlineCallbacks(num_args=2)
def did_forget(self, user_id, room_id):

View file

@ -0,0 +1,23 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS appservice_stream_position(
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_ordering BIGINT,
CHECK (Lock='X')
);
INSERT INTO appservice_stream_position (stream_ordering)
SELECT COALESCE(MAX(stream_ordering), 0) FROM events;

View file

@ -0,0 +1,46 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine
import logging
logger = logging.getLogger(__name__)
# This stream is used to notify replication slaves that some caches have
# been invalidated that they cannot infer from the other streams.
CREATE_TABLE = """
CREATE TABLE cache_invalidation_stream (
stream_id BIGINT,
cache_func TEXT,
keys TEXT[],
invalidation_ts BIGINT
);
CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
"""
def run_create(cur, database_engine, *args, **kwargs):
if not isinstance(database_engine, PostgresEngine):
return
for statement in get_statements(CREATE_TABLE.splitlines()):
cur.execute(statement)
def run_upgrade(cur, database_engine, *args, **kwargs):
pass

View file

@ -0,0 +1,20 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name';
UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name';
UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';

View file

@ -14,6 +14,8 @@
# limitations under the License.
from synapse.appservice import ApplicationService
from twisted.internet import defer
from mock import Mock
from tests import unittest
@ -42,20 +44,25 @@ class ApplicationServiceTestCase(unittest.TestCase):
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
)
self.store = Mock()
@defer.inlineCallbacks
def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event))
self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@someone_else:matrix.org"
self.assertFalse(self.service.is_interested(self.event))
self.assertFalse((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_member_is_checked(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
@ -63,30 +70,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event))
self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_id_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue(self.service.is_interested(self.event))
self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_id_no_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse(self.service.is_interested(self.event))
self.assertFalse((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
self.assertTrue(self.service.is_interested(
self.event,
aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"]
))
self.store.get_aliases_for_room.return_value = [
"#irc_foobar:matrix.org", "#athing:matrix.org"
]
self.store.get_users_in_room.return_value = []
self.assertTrue((yield self.service.is_interested(
self.event, self.store
)))
def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
@ -136,15 +149,20 @@ class ApplicationServiceTestCase(unittest.TestCase):
"!irc_foobar:matrix.org"
))
@defer.inlineCallbacks
def test_regex_alias_no_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
self.assertFalse(self.service.is_interested(
self.event,
aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
))
self.store.get_aliases_for_room.return_value = [
"#xmpp_foobar:matrix.org", "#athing:matrix.org"
]
self.store.get_users_in_room.return_value = []
self.assertFalse((yield self.service.is_interested(
self.event, self.store
)))
@defer.inlineCallbacks
def test_regex_multiple_matches(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
@ -153,53 +171,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(
self.event,
aliases_for_event=["#irc_barfoo:matrix.org"]
))
def test_restrict_to_rooms(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!flibble_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.event.room_id = "!wibblewoo:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_ROOMS
))
def test_restrict_to_aliases(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#xmpp_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_ALIASES,
aliases_for_event=["#irc_barfoo:matrix.org"]
))
def test_restrict_to_senders(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#xmpp_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_USERS,
aliases_for_event=["#xmpp_barfoo:matrix.org"]
))
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
self.store.get_users_in_room.return_value = []
self.assertTrue((yield self.service.is_interested(
self.event, self.store
)))
@defer.inlineCallbacks
def test_interested_in_self(self):
# make sure invites get through
self.service.sender = "@appservice:name"
@ -211,20 +189,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
"membership": "invite"
}
self.event.state_key = self.service.sender
self.assertTrue(self.service.is_interested(self.event))
self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
join_list = [
self.store.get_users_in_room.return_value = [
"@alice:here",
"@irc_fo:here", # AS user
"@bob:here",
]
self.store.get_aliases_for_room.return_value = []
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(self.service.is_interested(
event=self.event,
member_list=join_list
))
self.assertTrue((yield self.service.is_interested(
event=self.event, store=self.store
)))

View file

@ -193,7 +193,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self):
self.txn_ctrl = Mock()
self.queuer = _ServiceQueuer(self.txn_ctrl)
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
def test_send_single_event_no_queue(self):
# Expect the event to be sent immediately.

View file

@ -15,6 +15,7 @@
from twisted.internet import defer
from .. import unittest
from tests.utils import MockClock
from synapse.handlers.appservice import ApplicationServicesHandler
@ -32,6 +33,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_datastore = Mock(return_value=self.mock_store)
hs.get_application_service_api = Mock(return_value=self.mock_as_api)
hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
hs.get_clock.return_value = MockClock()
self.handler = ApplicationServicesHandler(hs)
@defer.inlineCallbacks
@ -51,8 +53,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
type="m.room.message",
room_id="!foo:bar"
)
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
self.mock_as_api.push = Mock()
yield self.handler.notify_interested_services(event)
yield self.handler.notify_interested_services(0)
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
)
@ -72,7 +75,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event)
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
yield self.handler.notify_interested_services(0)
self.mock_as_api.query_user.assert_called_once_with(
services[0], user_id
)
@ -94,7 +98,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event)
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
yield self.handler.notify_interested_services(0)
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been."
@ -108,11 +113,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
room_id = "!alpha:bet"
servers = ["aperture"]
interested_service = self._mkservice(is_interested=True)
interested_service = self._mkservice_alias(is_interested_in_alias=True)
services = [
self._mkservice(is_interested=False),
self._mkservice_alias(is_interested_in_alias=False),
interested_service,
self._mkservice(is_interested=False)
self._mkservice_alias(is_interested_in_alias=False)
]
self.mock_store.get_app_services = Mock(return_value=services)
@ -135,3 +140,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
def _mkservice_alias(self, is_interested_in_alias):
service = Mock()
service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
service.token = "mock_service_token"
service.url = "mock_service_url"
return service

View file

@ -15,7 +15,9 @@
from . import unittest
from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs
from synapse.rest.media.v1.preview_url_resource import (
summarize_paragraphs, decode_and_calc_og
)
class PreviewTestCase(unittest.TestCase):
@ -137,3 +139,79 @@ class PreviewTestCase(unittest.TestCase):
" of old wooden houses in Northern Norway, the oldest house dating from"
" 1789. The Arctic Cathedral, a modern church…"
)
class PreviewUrlTestCase(unittest.TestCase):
def test_simple(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})
def test_comment(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
<!-- HTML comment -->
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})
def test_comment2(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
Some text.
<!-- HTML comment -->
Some more text.
<p>Text</p>
More text
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text"
})
def test_script(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
<script> (function() {})() </script>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})