Merge remote-tracking branch 'origin/develop' into dbkr/notifications_api

This commit is contained in:
David Baker 2016-08-18 17:15:26 +01:00
commit 602c84cd9c
53 changed files with 1546 additions and 641 deletions

View file

@ -25,5 +25,6 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27 tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
tox -e py27 tox -e py27

View file

@ -14,6 +14,7 @@ fi
tox -e py27 --notest -v tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin TOX_BIN=$TOX_DIR/py27/bin
$TOX_BIN/pip install setuptools
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml $TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2 $TOX_BIN/pip install psycopg2

209
synapse/app/appservice.py Normal file
View file

@ -0,0 +1,209 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.appservice")
class AppserviceSlaveStore(
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
SlavedRegistrationStore,
):
pass
class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse appservice now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
appservice_handler = self.get_application_service_handler()
@defer.inlineCallbacks
def replicate(results):
stream = results.get("events")
if stream:
max_stream_id = stream["position"]
yield appservice_handler.notify_interested_services(max_stream_id)
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
replicate(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse appservice", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.appservice"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
if config.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``notify_appservices: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.notify_appservices = True
ps = AppserviceServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ps.replicate()
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-appservice",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -0,0 +1,212 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import (
CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
)
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
BaseSlavedStore,
MediaRepositoryStore,
ClientIpStore,
):
pass
class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "media":
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse media repository now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse media repository", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.media_repository"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-media-repository",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -80,11 +80,6 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__ DataStore.get_profile_displayname.__func__
) )
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = ( who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"] RoomMemberStore.__dict__["who_forgot_in_room"]
) )
@ -168,7 +163,6 @@ class PusherServer(HomeServer):
store = self.get_datastore() store = self.get_datastore()
replication_url = self.config.worker_replication_url replication_url = self.config.worker_replication_url
pusher_pool = self.get_pusherpool() pusher_pool = self.get_pusherpool()
clock = self.get_clock()
def stop_pusher(user_id, app_id, pushkey): def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey) key = "%s:%s" % (app_id, pushkey)
@ -220,21 +214,11 @@ class PusherServer(HomeServer):
min_stream_id, max_stream_id, affected_room_ids min_stream_id, max_stream_id, affected_room_ids
) )
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
next_expire_broken_caches_ms = 0
while True: while True:
try: try:
args = store.stream_positions() args = store.stream_positions()
args["timeout"] = 30000 args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args) result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result) yield store.process_replication(result)
poke_pushers(result) poke_pushers(result)
except: except:

View file

@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v1 import events
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@ -74,11 +75,6 @@ class SynchrotronSlavedStore(
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
): ):
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = ( who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"] RoomMemberStore.__dict__["who_forgot_in_room"]
) )
@ -89,17 +85,23 @@ class SynchrotronSlavedStore(
get_presence_list_accepted = PresenceStore.__dict__[ get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted" "get_presence_list_accepted"
] ]
get_presence_list_observers_accepted = PresenceStore.__dict__[
"get_presence_list_observers_accepted"
]
UPDATE_SYNCING_USERS_MS = 10 * 1000 UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object): class SynchrotronPresence(object):
def __init__(self, hs): def __init__(self, hs):
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.user_to_num_current_syncs = {} self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info() active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = { self.user_to_current_state = {
@ -119,11 +121,13 @@ class SynchrotronPresence(object):
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
def set_state(self, user, state): def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work? # TODO Hows this supposed to work?
pass pass
get_states = PresenceHandler.get_states.__func__ get_states = PresenceHandler.get_states.__func__
get_state = PresenceHandler.get_state.__func__
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks @defer.inlineCallbacks
@ -194,19 +198,39 @@ class SynchrotronPresence(object):
self._need_to_send_sync = False self._need_to_send_sync = False
yield self._send_syncing_users_now() yield self._send_syncing_users_now()
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
parties = yield self._get_interested_parties(
states, calculate_remote_hosts=False
)
room_ids_to_states, users_to_states, _ = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=users_to_states.keys()
)
@defer.inlineCallbacks
def process_replication(self, result): def process_replication(self, result):
stream = result.get("presence", {"rows": []}) stream = result.get("presence", {"rows": []})
states = []
for row in stream["rows"]: for row in stream["rows"]:
( (
position, user_id, state, last_active_ts, position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg, last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active currently_active
) = row ) = row
self.user_to_current_state[user_id] = UserPresenceState( state = UserPresenceState(
user_id, state, last_active_ts, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg, last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active currently_active
) )
self.user_to_current_state[user_id] = state
states.append(state)
if states and "position" in stream:
stream_id = int(stream["position"])
yield self.notify_from_replication(states, stream_id)
class SynchrotronTyping(object): class SynchrotronTyping(object):
@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer):
elif name == "client": elif name == "client":
resource = JsonResource(self, canonical_json=False) resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource) sync.register_servlets(self, resource)
events.register_servlets(self, resource)
resources.update({ resources.update({
"/_matrix/client/r0": resource, "/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource, "/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource, "/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer):
http_client = self.get_simple_http_client() http_client = self.get_simple_http_client()
store = self.get_datastore() store = self.get_datastore()
replication_url = self.config.worker_replication_url replication_url = self.config.worker_replication_url
clock = self.get_clock()
notifier = self.get_notifier() notifier = self.get_notifier()
presence_handler = self.get_presence_handler() presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler() typing_handler = self.get_typing_handler()
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
store.get_presence_list_accepted.invalidate_all()
def notify_from_stream( def notify_from_stream(
result, stream_name, stream_key, room=None, user=None result, stream_name, stream_key, room=None, user=None
): ):
@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer):
result, "typing", "typing_key", room="room_id" result, "typing", "typing_key", room="room_id"
) )
next_expire_broken_caches_ms = 0
while True: while True:
try: try:
args = store.stream_positions() args = store.stream_positions()
args.update(typing_handler.stream_positions()) args.update(typing_handler.stream_positions())
args["timeout"] = 30000 args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args) result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result) yield store.process_replication(result)
typing_handler.process_replication(result) typing_handler.process_replication(result)
presence_handler.process_replication(result) yield presence_handler.process_replication(result)
notify(result) notify(result)
except: except:
logger.exception("Error replicating from %r", replication_url) logger.exception("Error replicating from %r", replication_url)

View file

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from twisted.internet import defer
import logging import logging
import re import re
@ -138,65 +140,66 @@ class ApplicationService(object):
return regex_obj["exclusive"] return regex_obj["exclusive"]
return False return False
def _matches_user(self, event, member_list): @defer.inlineCallbacks
if (hasattr(event, "sender") and def _matches_user(self, event, store):
self.is_interested_in_user(event.sender)): if not event:
return True defer.returnValue(False)
if self.is_interested_in_user(event.sender):
defer.returnValue(True)
# also check m.room.member state key # also check m.room.member state key
if (hasattr(event, "type") and event.type == EventTypes.Member if (event.type == EventTypes.Member and
and hasattr(event, "state_key") self.is_interested_in_user(event.state_key)):
and self.is_interested_in_user(event.state_key)): defer.returnValue(True)
return True
if not store:
defer.returnValue(False)
member_list = yield store.get_users_in_room(event.room_id)
# check joined member events # check joined member events
for user_id in member_list: for user_id in member_list:
if self.is_interested_in_user(user_id): if self.is_interested_in_user(user_id):
return True defer.returnValue(True)
return False defer.returnValue(False)
def _matches_room_id(self, event): def _matches_room_id(self, event):
if hasattr(event, "room_id"): if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id) return self.is_interested_in_room(event.room_id)
return False return False
def _matches_aliases(self, event, alias_list): @defer.inlineCallbacks
def _matches_aliases(self, event, store):
if not store or not event:
defer.returnValue(False)
alias_list = yield store.get_aliases_for_room(event.room_id)
for alias in alias_list: for alias in alias_list:
if self.is_interested_in_alias(alias): if self.is_interested_in_alias(alias):
return True defer.returnValue(True)
return False defer.returnValue(False)
def is_interested(self, event, restrict_to=None, aliases_for_event=None, @defer.inlineCallbacks
member_list=None): def is_interested(self, event, store=None):
"""Check if this service is interested in this event. """Check if this service is interested in this event.
Args: Args:
event(Event): The event to check. event(Event): The event to check.
restrict_to(str): The namespace to restrict regex tests to. store(DataStore)
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined user_ids in this room.
Returns: Returns:
bool: True if this service would like to know about this event. bool: True if this service would like to know about this event.
""" """
if aliases_for_event is None: # Do cheap checks first
aliases_for_event = [] if self._matches_room_id(event):
if member_list is None: defer.returnValue(True)
member_list = []
if restrict_to and restrict_to not in ApplicationService.NS_LIST: if (yield self._matches_aliases(event, store)):
# this is a programming error, so fail early and raise a general defer.returnValue(True)
# exception
raise Exception("Unexpected restrict_to value: %s". restrict_to)
if not restrict_to: if (yield self._matches_user(event, store)):
return (self._matches_user(event, member_list) defer.returnValue(True)
or self._matches_aliases(event, aliases_for_event)
or self._matches_room_id(event)) defer.returnValue(False)
elif restrict_to == ApplicationService.NS_ALIASES:
return self._matches_aliases(event, aliases_for_event)
elif restrict_to == ApplicationService.NS_ROOMS:
return self._matches_room_id(event)
elif restrict_to == ApplicationService.NS_USERS:
return self._matches_user(event, member_list)
def is_interested_in_user(self, user_id): def is_interested_in_user(self, user_id):
return ( return (

View file

@ -48,9 +48,12 @@ UP & quit +---------- YES SUCCESS
This is all tied together by the AppServiceScheduler which DIs the required This is all tied together by the AppServiceScheduler which DIs the required
components. components.
""" """
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState from synapse.appservice import ApplicationServiceState
from twisted.internet import defer from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import Measure
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -73,7 +76,7 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController( self.txn_ctrl = _TransactionController(
self.clock, self.store, self.as_api, create_recoverer self.clock, self.store, self.as_api, create_recoverer
) )
self.queuer = _ServiceQueuer(self.txn_ctrl) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
@ -94,38 +97,36 @@ class _ServiceQueuer(object):
this schedules any other events in the queue to run. this schedules any other events in the queue to run.
""" """
def __init__(self, txn_ctrl): def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]} self.queued_events = {} # dict of {service_id: [events]}
self.pending_requests = {} # dict of {service_id: Deferred} self.requests_in_flight = set()
self.txn_ctrl = txn_ctrl self.txn_ctrl = txn_ctrl
self.clock = clock
def enqueue(self, service, event): def enqueue(self, service, event):
# if this service isn't being sent something # if this service isn't being sent something
if not self.pending_requests.get(service.id): self.queued_events.setdefault(service.id, []).append(event)
self._send_request(service, [event]) preserve_fn(self._send_request)(service)
else:
# add to queue for this service
if service.id not in self.queued_events:
self.queued_events[service.id] = []
self.queued_events[service.id].append(event)
def _send_request(self, service, events): @defer.inlineCallbacks
# send request and add callbacks def _send_request(self, service):
d = self.txn_ctrl.send(service, events) if service.id in self.requests_in_flight:
d.addBoth(self._on_request_finish) return
d.addErrback(self._on_request_fail)
self.pending_requests[service.id] = d
def _on_request_finish(self, service): self.requests_in_flight.add(service.id)
self.pending_requests[service.id] = None try:
# if there are queued events, then send them. while True:
if (service.id in self.queued_events events = self.queued_events.pop(service.id, [])
and len(self.queued_events[service.id]) > 0): if not events:
self._send_request(service, self.queued_events[service.id]) return
self.queued_events[service.id] = []
def _on_request_fail(self, err): with Measure(self.clock, "servicequeuer.send"):
logger.error("AS request failed: %s", err) try:
yield self.txn_ctrl.send(service, events)
except:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
class _TransactionController(object): class _TransactionController(object):
@ -155,8 +156,6 @@ class _TransactionController(object):
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
self._start_recoverer(service) self._start_recoverer(service)
# request has finished
defer.returnValue(service)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_recovered(self, recoverer): def on_recovered(self, recoverer):

View file

@ -28,6 +28,7 @@ class AppServiceConfig(Config):
def read_config(self, config): def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", []) self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
def default_config(cls, **kwargs): def default_config(cls, **kwargs):
return """\ return """\

View file

@ -19,7 +19,6 @@ from .room import (
) )
from .room_member import RoomMemberHandler from .room_member import RoomMemberHandler
from .message import MessageHandler from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler from .federation import FederationHandler
from .profile import ProfileHandler from .profile import ProfileHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
@ -53,8 +52,6 @@ class Handlers(object):
self.message_handler = MessageHandler(hs) self.message_handler = MessageHandler(hs)
self.room_creation_handler = RoomCreationHandler(hs) self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs) self.room_member_handler = RoomMemberHandler(hs)
self.event_stream_handler = EventStreamHandler(hs)
self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs) self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs) self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)

View file

@ -16,7 +16,8 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn
import logging import logging
@ -42,25 +43,53 @@ class ApplicationServicesHandler(object):
self.appservice_api = hs.get_application_service_api() self.appservice_api = hs.get_application_service_api()
self.scheduler = hs.get_application_service_scheduler() self.scheduler = hs.get_application_service_scheduler()
self.started_scheduler = False self.started_scheduler = False
self.clock = hs.get_clock()
self.notify_appservices = hs.config.notify_appservices
self.current_max = 0
self.is_processing = False
@defer.inlineCallbacks @defer.inlineCallbacks
def notify_interested_services(self, event): def notify_interested_services(self, current_id):
"""Notifies (pushes) all application services interested in this event. """Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any Pushing is done asynchronously, so this method won't block for any
prolonged length of time. prolonged length of time.
Args: Args:
event(Event): The event to push out to interested services. current_id(int): The current maximum ID.
""" """
services = yield self.store.get_app_services()
if not services or not self.notify_appservices:
return
self.current_max = max(self.current_max, current_id)
if self.is_processing:
return
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
upper_bound = self.current_max
limit = 100
while True:
upper_bound, events = yield self.store.get_new_events_for_appservice(
upper_bound, limit
)
if not events:
break
for event in events:
# Gather interested services # Gather interested services
services = yield self._get_services_for_event(event) services = yield self._get_services_for_event(event)
if len(services) == 0: if len(services) == 0:
return # no services need notifying continue # no services need notifying
# Do we know this user exists? If not, poke the user query API for # Do we know this user exists? If not, poke the user
# all services which match that user regex. This needs to block as these # query API for all services which match that user regex.
# user queries need to be made BEFORE pushing the event. # This needs to block as these user queries need to be
# made BEFORE pushing the event.
yield self._check_user_exists(event.sender) yield self._check_user_exists(event.sender)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key) yield self._check_user_exists(event.state_key)
@ -71,7 +100,16 @@ class ApplicationServicesHandler(object):
# Fork off pushes to these services # Fork off pushes to these services
for service in services: for service in services:
self.scheduler.submit_event_for_as(service, event) preserve_fn(self.scheduler.submit_event_for_as)(
service, event
)
yield self.store.set_appservice_last_pos(upper_bound)
if len(events) < limit:
break
finally:
self.is_processing = False
@defer.inlineCallbacks @defer.inlineCallbacks
def query_user_exists(self, user_id): def query_user_exists(self, user_id):
@ -104,11 +142,12 @@ class ApplicationServicesHandler(object):
association can be found. association can be found.
""" """
room_alias_str = room_alias.to_string() room_alias_str = room_alias.to_string()
alias_query_services = yield self._get_services_for_event( services = yield self.store.get_app_services()
event=None, alias_query_services = [
restrict_to=ApplicationService.NS_ALIASES, s for s in services if (
alias_list=[room_alias_str] s.is_interested_in_alias(room_alias_str)
) )
]
for alias_service in alias_query_services: for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias( is_known_alias = yield self.appservice_api.query_alias(
alias_service, room_alias_str alias_service, room_alias_str
@ -121,34 +160,19 @@ class ApplicationServicesHandler(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_services_for_event(self, event, restrict_to="", alias_list=None): def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event. """Retrieve a list of application services interested in this event.
Args: Args:
event(Event): The event to check. Can be None if alias_list is not. event(Event): The event to check. Can be None if alias_list is not.
restrict_to(str): The namespace to restrict regex tests to.
alias_list: A list of aliases to get services for. If None, this
list is obtained from the database.
Returns: Returns:
list<ApplicationService>: A list of services interested in this list<ApplicationService>: A list of services interested in this
event based on the service regex. event based on the service regex.
""" """
member_list = None
if hasattr(event, "room_id"):
# We need to know the aliases associated with this event.room_id,
# if any.
if not alias_list:
alias_list = yield self.store.get_aliases_for_room(
event.room_id
)
# We need to know the members associated with this event.room_id,
# if any.
member_list = yield self.store.get_users_in_room(event.room_id)
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
interested_list = [ interested_list = [
s for s in services if ( s for s in services if (
s.is_interested(event, restrict_to, alias_list, member_list) yield s.is_interested(event, self.store)
) )
] ]
defer.returnValue(interested_list) defer.returnValue(interested_list)

View file

@ -70,11 +70,11 @@ class AuthHandler(BaseHandler):
self.ldap_uri = hs.config.ldap_uri self.ldap_uri = hs.config.ldap_uri
self.ldap_start_tls = hs.config.ldap_start_tls self.ldap_start_tls = hs.config.ldap_start_tls
self.ldap_base = hs.config.ldap_base self.ldap_base = hs.config.ldap_base
self.ldap_filter = hs.config.ldap_filter
self.ldap_attributes = hs.config.ldap_attributes self.ldap_attributes = hs.config.ldap_attributes
if self.ldap_mode == LDAPMode.SEARCH: if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = hs.config.ldap_bind_dn self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password self.ldap_bind_password = hs.config.ldap_bind_password
self.ldap_filter = hs.config.ldap_filter
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@ -660,7 +660,7 @@ class AuthHandler(BaseHandler):
else: else:
logger.warn( logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results", "ldap registration failed: unexpected (%d!=1) amount of results",
len(result) len(conn.response)
) )
defer.returnValue(False) defer.returnValue(False)
@ -741,7 +741,7 @@ class AuthHandler(BaseHandler):
def set_password(self, user_id, newpassword, requester=None): def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else [] except_access_token_id = requester.access_token_id if requester else None
try: try:
yield self.store.user_set_password_hash(user_id, password_hash) yield self.store.user_set_password_hash(user_id, password_hash)
@ -750,10 +750,10 @@ class AuthHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e raise e
yield self.store.user_delete_access_tokens( yield self.store.user_delete_access_tokens(
user_id, except_access_token_ids user_id, except_access_token_id
) )
yield self.hs.get_pusherpool().remove_pushers_by_user( yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_ids user_id, except_access_token_id
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -274,7 +274,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]): def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id` """ Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. This may return This will attempt to get more events from the remote. This may return
@ -284,9 +284,6 @@ class FederationHandler(BaseHandler):
if dest == self.server_name: if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.") raise SynapseError(400, "Can't backfill from self.")
if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
events = yield self.replication_layer.backfill( events = yield self.replication_layer.backfill(
dest, dest,
room_id, room_id,
@ -455,6 +452,10 @@ class FederationHandler(BaseHandler):
) )
max_depth = sorted_extremeties_tuple[0][1] max_depth = sorted_extremeties_tuple[0][1]
# We don't want to specify too many extremities as it causes the backfill
# request URI to be too long.
extremities = dict(sorted_extremeties_tuple[:5])
if current_depth > max_depth: if current_depth > max_depth:
logger.debug( logger.debug(
"Not backfilling as we don't need to. %d < %d", "Not backfilling as we don't need to. %d < %d",

View file

@ -503,7 +503,7 @@ class PresenceHandler(object):
defer.returnValue(states) defer.returnValue(states)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_interested_parties(self, states): def _get_interested_parties(self, states, calculate_remote_hosts=True):
"""Given a list of states return which entities (rooms, users, servers) """Given a list of states return which entities (rooms, users, servers)
are interested in the given states. are interested in the given states.
@ -526,6 +526,7 @@ class PresenceHandler(object):
users_to_states.setdefault(state.user_id, []).append(state) users_to_states.setdefault(state.user_id, []).append(state)
hosts_to_states = {} hosts_to_states = {}
if calculate_remote_hosts:
for room_id, states in room_ids_to_states.items(): for room_id, states in room_ids_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states) local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
if not local_states: if not local_states:
@ -565,6 +566,16 @@ class PresenceHandler(object):
self._push_to_remotes(hosts_to_states) self._push_to_remotes(hosts_to_states)
@defer.inlineCallbacks
def notify_for_states(self, state, stream_id):
parties = yield self._get_interested_parties([state])
room_ids_to_states, users_to_states, hosts_to_states = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=[UserID.from_string(u) for u in users_to_states.keys()]
)
def _push_to_remotes(self, hosts_to_states): def _push_to_remotes(self, hosts_to_states):
"""Sends state updates to remote servers. """Sends state updates to remote servers.
@ -672,7 +683,7 @@ class PresenceHandler(object):
]) ])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_state(self, target_user, state): def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user. """Set the presence state of the user.
""" """
status_msg = state.get("status_msg", None) status_msg = state.get("status_msg", None)
@ -689,10 +700,13 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id) prev_state = yield self.current_state_for_user(user_id)
new_fields = { new_fields = {
"state": presence, "state": presence
"status_msg": status_msg if presence != PresenceState.OFFLINE else None
} }
if not ignore_status_msg:
msg = status_msg if presence != PresenceState.OFFLINE else None
new_fields["status_msg"] = msg
if presence == PresenceState.ONLINE: if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec() new_fields["last_active_ts"] = self.clock.time_msec()

View file

@ -141,7 +141,7 @@ class RoomMemberHandler(BaseHandler):
third_party_signed=None, third_party_signed=None,
ratelimit=True, ratelimit=True,
): ):
key = (target, room_id,) key = (room_id,)
with (yield self.member_linearizer.queue(key)): with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership( result = yield self._update_membership(

View file

@ -67,10 +67,8 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user_id, rooms, current_token, time_now_ms, def __init__(self, user_id, rooms, current_token, time_now_ms):
appservice=None):
self.user_id = user_id self.user_id = user_id
self.appservice = appservice
self.rooms = set(rooms) self.rooms = set(rooms)
self.current_token = current_token self.current_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
@ -107,11 +105,6 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id) notifier.user_to_user_stream.pop(self.user_id)
if self.appservice:
notifier.appservice_to_user_streams.get(
self.appservice, set()
).discard(self)
def count_listeners(self): def count_listeners(self):
return len(self.notify_deferred.observers()) return len(self.notify_deferred.observers())
@ -142,7 +135,6 @@ class Notifier(object):
def __init__(self, hs): def __init__(self, hs):
self.user_to_user_stream = {} self.user_to_user_stream = {}
self.room_to_user_streams = {} self.room_to_user_streams = {}
self.appservice_to_user_streams = {}
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -168,8 +160,6 @@ class Notifier(object):
all_user_streams |= x all_user_streams |= x
for x in self.user_to_user_stream.values(): for x in self.user_to_user_stream.values():
all_user_streams.add(x) all_user_streams.add(x)
for x in self.appservice_to_user_streams.values():
all_user_streams |= x
return sum(stream.count_listeners() for stream in all_user_streams) return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners) metrics.register_callback("listeners", count_listeners)
@ -182,10 +172,6 @@ class Notifier(object):
"users", "users",
lambda: len(self.user_to_user_stream), lambda: len(self.user_to_user_stream),
) )
metrics.register_callback(
"appservices",
lambda: count(bool, self.appservice_to_user_streams.values()),
)
def on_new_room_event(self, event, room_stream_id, max_room_stream_id, def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]): extra_users=[]):
@ -228,21 +214,7 @@ class Notifier(object):
def _on_new_room_event(self, event, room_stream_id, extra_users=[]): def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event""" """Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
self.appservice_handler.notify_interested_services(event) self.appservice_handler.notify_interested_services(room_stream_id)
app_streams = set()
for appservice in self.appservice_to_user_streams:
# TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_user_streams set, but
# that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this
# make the room_to_user_streams check somewhat obselete?
if appservice.is_interested(event):
app_user_streams = self.appservice_to_user_streams.get(
appservice, set()
)
app_streams |= app_user_streams
if event.type == EventTypes.Member and event.membership == Membership.JOIN: if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id) self._user_joined_room(event.state_key, event.room_id)
@ -251,11 +223,9 @@ class Notifier(object):
"room_key", room_stream_id, "room_key", room_stream_id,
users=extra_users, users=extra_users,
rooms=[event.room_id], rooms=[event.room_id],
extra_streams=app_streams,
) )
def on_new_event(self, stream_key, new_token, users=[], rooms=[], def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
extra_streams=set()):
""" Used to inform listeners that something has happend event wise. """ Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
@ -294,7 +264,6 @@ class Notifier(object):
""" """
user_stream = self.user_to_user_stream.get(user_id) user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None: if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user_id)
current_token = yield self.event_sources.get_current_token() current_token = yield self.event_sources.get_current_token()
if room_ids is None: if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user_id) rooms = yield self.store.get_rooms_for_user(user_id)
@ -302,7 +271,6 @@ class Notifier(object):
user_stream = _NotifierUserStream( user_stream = _NotifierUserStream(
user_id=user_id, user_id=user_id,
rooms=room_ids, rooms=room_ids,
appservice=appservice,
current_token=current_token, current_token=current_token,
time_now_ms=self.clock.time_msec(), time_now_ms=self.clock.time_msec(),
) )
@ -477,11 +445,6 @@ class Notifier(object):
s = self.room_to_user_streams.setdefault(room, set()) s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream) s.add(user_stream)
if user_stream.appservice:
self.appservice_to_user_stream.setdefault(
user_stream.appservice, set()
).add(user_stream)
def _user_joined_room(self, user_id, room_id): def _user_joined_room(self, user_id, room_id):
new_user_stream = self.user_to_user_stream.get(user_id) new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None: if new_user_stream is not None:

View file

@ -38,11 +38,12 @@ class ActionGenerator:
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context): def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "handle_push_actions_for_event"): with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event( bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store, context.current_state event, self.hs, self.store, context.current_state
) )
with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield bulk_evaluator.action_for_event_by_user( actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, context.current_state event, context.current_state
) )

View file

@ -217,6 +217,27 @@ BASE_APPEND_OVERRIDE_RULES = [
'dont_notify' 'dont_notify'
] ]
}, },
# This was changed from underride to override so it's closer in priority
# to the content rules where the user name highlight rule lives. This
# way a room rule is lower priority than both but a custom override rule
# is higher priority than both.
{
'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
] ]
@ -242,23 +263,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
} }
] ]
}, },
{
'rule_id': 'global/underride/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
{ {
'rule_id': 'global/underride/.m.rule.room_one_to_one', 'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [ 'conditions': [

View file

@ -102,14 +102,14 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user(self, user_id, except_token_ids=[]): def remove_pushers_by_user(self, user_id, except_access_token_id=None):
all = yield self.store.get_all_pushers() all = yield self.store.get_all_pushers()
logger.info( logger.info(
"Removing all pushers for user %s except access tokens ids %r", "Removing all pushers for user %s except access tokens id %r",
user_id, except_token_ids user_id, except_access_token_id
) )
for p in all: for p in all:
if p['user_name'] == user_id and p['access_token'] not in except_token_ids: if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']

View file

@ -41,6 +41,7 @@ STREAM_NAMES = (
("push_rules",), ("push_rules",),
("pushers",), ("pushers",),
("state",), ("state",),
("caches",),
) )
@ -70,6 +71,7 @@ class ReplicationResource(Resource):
* "backfill": Old events that have been backfilled from other servers. * "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules. * "push_rules": Per user changes to push rules.
* "pushers": Per user changes to their pushers. * "pushers": Per user changes to their pushers.
* "caches": Cache invalidations.
The API takes two additional query parameters: The API takes two additional query parameters:
@ -129,6 +131,7 @@ class ReplicationResource(Resource):
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token() pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token() state_token = self.store.get_state_stream_token()
caches_token = self.store.get_cache_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -140,6 +143,7 @@ class ReplicationResource(Resource):
push_rules_token, push_rules_token,
pushers_token, pushers_token,
state_token, state_token,
caches_token,
)) ))
@request_handler() @request_handler()
@ -188,6 +192,7 @@ class ReplicationResource(Resource):
yield self.push_rules(writer, current_token, limit, request_streams) yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams) yield self.state(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total) logger.info("Replicated %d rows", writer.total)
@ -379,6 +384,20 @@ class ReplicationResource(Resource):
"position", "type", "state_key", "event_id" "position", "type", "state_key", "event_id"
)) ))
@defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams):
current_position = current_token.caches
caches = request_streams.get("caches")
if caches is not None:
updated_caches = yield self.store.get_all_updated_caches(
caches, current_position, limit
)
writer.write_header_and_rows("caches", updated_caches, (
"position", "cache_func", "keys", "invalidation_ts"
))
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
@ -407,7 +426,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state" "push_rules", "pushers", "state", "caches",
))): ))):
__slots__ = [] __slots__ = []

View file

@ -14,15 +14,43 @@
# limitations under the License. # limitations under the License.
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer from twisted.internet import defer
from ._slaved_id_tracker import SlavedIdTracker
import logging
logger = logging.getLogger(__name__)
class BaseSlavedStore(SQLBaseStore): class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs) super(BaseSlavedStore, self).__init__(hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id",
)
else:
self._cache_id_gen = None
def stream_positions(self): def stream_positions(self):
return {} pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def process_replication(self, result): def process_replication(self, result):
stream = result.get("caches")
if stream:
for row in stream["rows"]:
(
position, cache_func, keys, invalidation_ts,
) = row
try:
getattr(self, cache_func).invalidate(tuple(keys))
except AttributeError:
logger.info("Got unexpected cache_func: %r", cache_func)
self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None) return defer.succeed(None)

View file

@ -28,3 +28,13 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
get_app_service_by_token = DataStore.get_app_service_by_token.__func__ get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__ get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
get_app_services = DataStore.get_app_services.__func__
get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
create_appservice_txn = DataStore.create_appservice_txn.__func__
get_appservices_by_state = DataStore.get_appservices_by_state.__func__
get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
_get_last_txn = DataStore._get_last_txn.__func__
complete_appservice_txn = DataStore.complete_appservice_txn.__func__
get_appservice_state = DataStore.get_appservice_state.__func__
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
set_appservice_state = DataStore.set_appservice_state.__func__

View file

@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore
class DirectoryStore(BaseSlavedStore): class DirectoryStore(BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[ get_aliases_for_room = DirectoryStore.__dict__[
"get_aliases_for_room" "get_aliases_for_room"
].orig ]

View file

@ -25,6 +25,9 @@ class SlavedRegistrationStore(BaseSlavedStore):
# TODO: use the cached version and invalidate deleted tokens # TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[ get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token" "get_user_by_access_token"
].orig ]
_query_for_auth = DataStore._query_for_auth.__func__ _query_for_auth = DataStore._query_for_auth.__func__
get_user_by_id = RegistrationStore.__dict__[
"get_user_by_id"
]

View file

@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class WhoisRestServlet(ClientV1RestServlet): class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(WhoisRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
) )
def __init__(self, hs):
super(PurgeHistoryRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id): def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)

View file

@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet):
hs (synapse.server.HomeServer): hs (synapse.server.HomeServer):
""" """
self.hs = hs self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory() self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_v1auth() self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore() self.txns = HttpTransactionStore()

View file

@ -36,6 +36,10 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(ClientV1RestServlet): class ClientDirectoryServer(ClientV1RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
def __init__(self, hs):
super(ClientDirectoryServer, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_alias): def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs) super(ClientDirectoryListServer, self).__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):

View file

@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000 DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
super(EventStreamRestServlet, self).__init__(hs)
self.event_stream_handler = hs.get_event_stream_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req( requester = yield self.auth.get_user_by_req(
@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
if "room_id" in request.args: if "room_id" in request.args:
room_id = request.args["room_id"][0] room_id = request.args["room_id"][0]
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args: if "timeout" in request.args:
@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
chunk = yield handler.get_stream( chunk = yield self.event_stream_handler.get_stream(
requester.user.to_string(), requester.user.to_string(),
pagin_config, pagin_config,
timeout=timeout, timeout=timeout,
@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(EventRestServlet, self).__init__(hs) super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler event = yield self.event_handler.get_event(requester.user, event_id)
event = yield handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:

View file

@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns
class InitialSyncRestServlet(ClientV1RestServlet): class InitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/initialSync$") PATTERNS = client_path_patterns("/initialSync$")
def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)

View file

@ -56,6 +56,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.cas_enabled = hs.config.cas_enabled self.cas_enabled = hs.config.cas_enabled
self.auth_handler = self.hs.get_auth_handler() self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler() self.device_handler = self.hs.get_device_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request): def on_GET(self, request):
flows = [] flows = []
@ -260,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs) super(SAML2RestServlet, self).__init__(hs)
self.sp_config = hs.config.saml2_config_path self.sp_config = hs.config.saml2_config_path
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -329,6 +331,7 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_service_url = hs.config.cas_service_url self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View file

@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request
class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
class ProfileRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)

View file

@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
self.sessions = {} self.sessions = {}
self.enable_registration = hs.config.enable_registration self.enable_registration = hs.config.enable_registration
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request): def on_GET(self, request):
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
super(CreateUserRestServlet, self).__init__(hs) super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):

View file

@ -35,6 +35,10 @@ logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet): class RoomCreateRestServlet(ClientV1RestServlet):
# No PATTERN; we have custom dispatch rules here # No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
PATTERNS = "/createRoom" PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events # TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet): class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /room/$roomid/state/$eventtype # /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet): class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet): class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
@ -296,6 +311,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
@ -322,6 +341,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -351,6 +374,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
class RoomStateRestServlet(ClientV1RestServlet): class RoomStateRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -368,6 +395,10 @@ class RoomStateRestServlet(ClientV1RestServlet):
class RoomInitialSyncRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -388,6 +419,7 @@ class RoomEventContext(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomEventContext, self).__init__(hs) super(RoomEventContext, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
@ -424,6 +456,10 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@ -462,6 +498,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet): class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
@ -542,6 +582,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@ -624,6 +668,10 @@ class SearchRestServlet(ClientV1RestServlet):
"/search$" "/search$"
) )
def __init__(self, hs):
super(SearchRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)

View file

@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet):
affect_presence = set_presence != PresenceState.OFFLINE affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence: if affect_presence:
yield self.presence_handler.set_state(user, {"presence": set_presence}) yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
context = yield self.presence_handler.user_syncing( context = yield self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence, user.to_string(), affect_presence=affect_presence,

View file

@ -45,6 +45,7 @@ class DownloadResource(Resource):
@request_handler() @request_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_GET(self, request): def _async_render_GET(self, request):
request.setHeader("Content-Security-Policy", "sandbox")
server_name, media_id, name = parse_media_id(request) server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name: if server_name == self.server_name:
yield self._respond_local_file(request, media_id, name) yield self._respond_local_file(request, media_id, name)

View file

@ -29,14 +29,13 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
from copy import deepcopy
import os import os
import re import re
import fnmatch import fnmatch
import cgi import cgi
import ujson as json import ujson as json
import urlparse import urlparse
import itertools
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -163,7 +162,7 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info) logger.debug("got media_info of '%s'" % media_info)
if self._is_media(media_info['media_type']): if _is_media(media_info['media_type']):
dims = yield self.media_repo._generate_local_thumbnails( dims = yield self.media_repo._generate_local_thumbnails(
media_info['filesystem_id'], media_info media_info['filesystem_id'], media_info
) )
@ -184,11 +183,9 @@ class PreviewUrlResource(Resource):
logger.warn("Couldn't get dims for %s" % url) logger.warn("Couldn't get dims for %s" % url)
# define our OG response for this media # define our OG response for this media
elif self._is_html(media_info['media_type']): elif _is_html(media_info['media_type']):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM # TODO: somehow stop a big HTML tree from exploding synapse's RAM
from lxml import etree
file = open(media_info['filename']) file = open(media_info['filename'])
body = file.read() body = file.read()
file.close() file.close()
@ -199,17 +196,35 @@ class PreviewUrlResource(Resource):
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I) match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
encoding = match.group(1) if match else "utf-8" encoding = match.group(1) if match else "utf-8"
try: og = decode_and_calc_og(body, media_info['uri'], encoding)
parser = etree.HTMLParser(recover=True, encoding=encoding)
tree = etree.fromstring(body, parser)
og = yield self._calc_og(tree, media_info, requester)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=encoding)
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
og = yield self._calc_og(tree, media_info, requester)
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
_rebase_url(og['og:image'], media_info['uri']), requester.user
)
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info
)
if dims:
og["og:image:width"] = dims['width']
og["og:image:height"] = dims['height']
else:
logger.warn("Couldn't get dims for %s" % og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name, image_info['filesystem_id']
)
og["og:image:type"] = image_info['media_type']
og["matrix:image:size"] = image_info['media_length']
else:
del og["og:image"]
else: else:
logger.warn("Failed to find any OG data in %s", url) logger.warn("Failed to find any OG data in %s", url)
og = {} og = {}
@ -232,139 +247,6 @@ class PreviewUrlResource(Resource):
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
@defer.inlineCallbacks
def _calc_og(self, tree, media_info, requester):
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if 'content' in tag.attrib:
og[tag.attrib['property']] = tag.attrib['content']
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if 'og:title' not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
og['og:title'] = title[0].text.strip() if title else None
if 'og:image' not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(images, key=lambda i: (
-1 * float(i.attrib['width']) * float(i.attrib['height'])
))
if not images:
images = tree.xpath("//img[@src]")
if images:
og['og:image'] = images[0].attrib['src']
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
self._rebase_url(og['og:image'], media_info['uri']), requester.user
)
if self._is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info
)
if dims:
og["og:image:width"] = dims['width']
og["og:image:height"] = dims['height']
else:
logger.warn("Couldn't get dims for %s" % og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name, image_info['filesystem_id']
)
og["og:image:type"] = image_info['media_type']
og["matrix:image:size"] = image_info['media_length']
else:
del og["og:image"]
if 'og:description' not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content")
if meta_description:
og['og:description'] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
# <header/>, <nav/>, <aside/>, <footer/>
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
# We don't just use XPATH here as that is slow on some machines.
# We clone `tree` as we modify it.
cloned_tree = deepcopy(tree.find("body"))
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
for el in cloned_tree.iter(TAGS_TO_REMOVE):
el.getparent().remove(el)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r'\s+', '\n', el.text).strip()
for el in cloned_tree.iter()
if el.text and isinstance(el.tag, basestring) # Removes comments
)
og['og:description'] = summarize_paragraphs(text_nodes)
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
defer.returnValue(og)
def _rebase_url(self, url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
return urlparse.urlunparse(url)
@defer.inlineCallbacks @defer.inlineCallbacks
def _download_url(self, url, user): def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice # TODO: we should probably honour robots.txt... except in practice
@ -445,11 +327,165 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None, "etag": headers["ETag"][0] if "ETag" in headers else None,
}) })
def _is_media(self, content_type):
def decode_and_calc_og(body, media_uri, request_encoding=None):
from lxml import etree
try:
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
tree = etree.fromstring(body, parser)
og = _calc_og(tree, media_uri)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
og = _calc_og(tree, media_uri)
return og
def _calc_og(tree, media_uri):
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if 'content' in tag.attrib:
og[tag.attrib['property']] = tag.attrib['content']
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if 'og:title' not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
og['og:title'] = title[0].text.strip() if title else None
if 'og:image' not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og['og:image'] = _rebase_url(meta_image[0], media_uri)
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(images, key=lambda i: (
-1 * float(i.attrib['width']) * float(i.attrib['height'])
))
if not images:
images = tree.xpath("//img[@src]")
if images:
og['og:image'] = images[0].attrib['src']
if 'og:description' not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content")
if meta_description:
og['og:description'] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
# <header/>, <nav/>, <aside/>, <footer/>
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
# We don't just use XPATH here as that is slow on some machines.
from lxml import etree
TAGS_TO_REMOVE = (
"header", "nav", "aside", "footer", "script", "style", etree.Comment
)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r'\s+', '\n', el).strip()
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og['og:description'] = summarize_paragraphs(text_nodes)
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
return og
def _iterate_over_text(tree, *tags_to_ignore):
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
# This is basically a stack that we extend using itertools.chain.
# This will either consist of an element to iterate over *or* a string
# to be returned.
elements = iter([tree])
while True:
el = elements.next()
if isinstance(el, basestring):
yield el
elif el is not None and el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
# return it if the text exists.
if el.text:
yield el.text
# We add to the stack all the elements children, interspersed with
# each child's tail text (if it exists). The tail text of a node
# is text that comes *after* the node, so we always include it even
# if we ignore the child node.
elements = itertools.chain(
itertools.chain.from_iterable( # Basically a flatmap
[child, child.tail] if child.tail else [child]
for child in el.iterchildren()
),
elements
)
def _rebase_url(url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
return urlparse.urlunparse(url)
def _is_media(content_type):
if content_type.lower().startswith("image/"): if content_type.lower().startswith("image/"):
return True return True
def _is_html(self, content_type):
def _is_html(content_type):
content_type = content_type.lower() content_type = content_type.lower()
if ( if (
content_type.startswith("text/html") or content_type.startswith("text/html") or

View file

@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room import RoomListHandler from synapse.handlers.room import RoomListHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier from synapse.notifier import Notifier
@ -94,6 +95,8 @@ class HomeServer(object):
'auth_handler', 'auth_handler',
'device_handler', 'device_handler',
'e2e_keys_handler', 'e2e_keys_handler',
'event_handler',
'event_stream_handler',
'application_service_api', 'application_service_api',
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
@ -214,6 +217,12 @@ class HomeServer(object):
def build_application_service_handler(self): def build_application_service_handler(self):
return ApplicationServicesHandler(self) return ApplicationServicesHandler(self)
def build_event_handler(self):
return EventHandler(self)
def build_event_stream_handler(self):
return EventStreamHandler(self)
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)

View file

@ -50,6 +50,7 @@ from .openid import OpenIdStore
from .client_ips import ClientIpStore from .client_ips import ClientIpStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from .engines import PostgresEngine
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("deleted_pushers", "stream_id")], extra_tables=[("deleted_pushers", "stream_id")],
) )
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
db_conn, "cache_invalidation_stream", "stream_id",
)
else:
self._cache_id_gen = None
events_max = self._stream_id_gen.get_current_token() events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict( event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events", db_conn, "events",

View file

@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
from synapse.storage.engines import PostgresEngine
import synapse.metrics import synapse.metrics
@ -305,11 +306,12 @@ class SQLBaseStore(object):
func, *args, **kwargs func, *args, **kwargs
) )
try:
with PreserveLoggingContext(): with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection( result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
finally:
for after_callback, after_args in after_callbacks: for after_callback, after_args in after_callbacks:
after_callback(*after_args) after_callback(*after_args)
defer.returnValue(result) defer.returnValue(result)
@ -860,6 +862,62 @@ class SQLBaseStore(object):
return cache, min_val return cache, min_val
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
be invalidated.
"""
txn.call_after(cache_func.invalidate, keys)
if isinstance(self.database_engine, PostgresEngine):
# get_next() returns a context manager which is designed to wrap
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn(
txn,
table="cache_invalidation_stream",
values={
"stream_id": stream_id,
"cache_func": cache_func.__name__,
"keys": list(keys),
"invalidation_ts": self.clock.time_msec(),
}
)
def get_all_updated_caches(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit,))
return txn.fetchall()
return self.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying """ This exception is used to rollback a transaction without implying

View file

@ -218,13 +218,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
Returns: Returns:
AppServiceTransaction: A new transaction. AppServiceTransaction: A new transaction.
""" """
return self.runInteraction( def _create_appservice_txn(txn):
"create_appservice_txn",
self._create_appservice_txn,
service, events
)
def _create_appservice_txn(self, txn, service, events):
# work out new txn id (highest txn id for this service += 1) # work out new txn id (highest txn id for this service += 1)
# The highest id may be the last one sent (in which case it is last_txn) # The highest id may be the last one sent (in which case it is last_txn)
# or it may be the highest in the txns list (which are waiting to be/are # or it may be the highest in the txns list (which are waiting to be/are
@ -252,6 +246,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
service=service, id=new_txn_id, events=events service=service, id=new_txn_id, events=events
) )
return self.runInteraction(
"create_appservice_txn",
_create_appservice_txn,
)
def complete_appservice_txn(self, txn_id, service): def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction. """Completes an application service transaction.
@ -263,15 +262,9 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves if this transaction was stored A Deferred which resolves if this transaction was stored
successfully. successfully.
""" """
return self.runInteraction(
"complete_appservice_txn",
self._complete_appservice_txn,
txn_id, service
)
def _complete_appservice_txn(self, txn, txn_id, service):
txn_id = int(txn_id) txn_id = int(txn_id)
def _complete_appservice_txn(txn):
# Debugging query: Make sure the txn being completed is EXACTLY +1 from # Debugging query: Make sure the txn being completed is EXACTLY +1 from
# what was there before. If it isn't, we've got problems (e.g. the AS # what was there before. If it isn't, we've got problems (e.g. the AS
# has probably missed some events), so whine loudly but still continue, # has probably missed some events), so whine loudly but still continue,
@ -298,6 +291,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
dict(txn_id=txn_id, as_id=service.id) dict(txn_id=txn_id, as_id=service.id)
) )
return self.runInteraction(
"complete_appservice_txn",
_complete_appservice_txn,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_oldest_unsent_txn(self, service): def get_oldest_unsent_txn(self, service):
"""Get the oldest transaction which has not been sent for this """Get the oldest transaction which has not been sent for this
@ -309,24 +307,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves to an AppServiceTransaction or A Deferred which resolves to an AppServiceTransaction or
None. None.
""" """
entry = yield self.runInteraction( def _get_oldest_unsent_txn(txn):
"get_oldest_unsent_appservice_txn",
self._get_oldest_unsent_txn,
service
)
if not entry:
defer.returnValue(None)
event_ids = json.loads(entry["event_ids"])
events = yield self._get_events(event_ids)
defer.returnValue(AppServiceTransaction(
service=service, id=entry["txn_id"], events=events
))
def _get_oldest_unsent_txn(self, txn, service):
# Monotonically increasing txn ids, so just select the smallest # Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent) # one in the txns table (we delete them when they are sent)
txn.execute( txn.execute(
@ -342,6 +323,22 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
return entry return entry
entry = yield self.runInteraction(
"get_oldest_unsent_appservice_txn",
_get_oldest_unsent_txn,
)
if not entry:
defer.returnValue(None)
event_ids = json.loads(entry["event_ids"])
events = yield self._get_events(event_ids)
defer.returnValue(AppServiceTransaction(
service=service, id=entry["txn_id"], events=events
))
def _get_last_txn(self, txn, service_id): def _get_last_txn(self, txn, service_id):
txn.execute( txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?", "SELECT last_txn FROM application_services_state WHERE as_id=?",
@ -352,3 +349,42 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
return 0 return 0
else: else:
return int(last_txn_id[0]) # select 'last_txn' col return int(last_txn_id[0]) # select 'last_txn' col
def set_appservice_last_pos(self, pos):
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
return self.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
@defer.inlineCallbacks
def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets"""
def get_new_events_for_appservice_txn(txn):
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e, appservice_stream_position AS a"
" WHERE a.stream_ordering < e.stream_ordering AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (current_id, limit))
rows = txn.fetchall()
upper_bound = current_id
if len(rows) == limit:
upper_bound = rows[-1][0]
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn,
)
events = yield self._get_events(event_ids)
defer.returnValue((upper_bound, events))

View file

@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore):
Returns: Returns:
Deferred Deferred
""" """
try: def alias_txn(txn):
yield self._simple_insert( self._simple_insert_txn(
txn,
"room_aliases", "room_aliases",
{ {
"room_alias": room_alias.to_string(), "room_alias": room_alias.to_string(),
"room_id": room_id, "room_id": room_id,
"creator": creator, "creator": creator,
}, },
desc="create_room_alias_association", )
self._simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[{
"room_alias": room_alias.to_string(),
"server": server,
} for server in servers],
)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (room_id,)
)
try:
ret = yield self.runInteraction(
"create_room_alias_association", alias_txn
) )
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
raise SynapseError( raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string() 409, "Room alias %s already exists" % room_alias.to_string()
) )
defer.returnValue(ret)
for server in servers:
# TODO(erikj): Fix this to bulk insert
yield self._simple_insert(
"room_alias_servers",
{
"room_alias": room_alias.to_string(),
"server": server,
},
desc="create_room_alias_association",
)
self.get_aliases_for_room.invalidate((room_id,))
def get_room_alias_creator(self, room_alias): def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(

View file

@ -600,7 +600,8 @@ class EventsStore(SQLBaseStore):
"rejections", "rejections",
"redactions", "redactions",
"room_memberships", "room_memberships",
"state_events" "state_events",
"topics"
): ):
txn.executemany( txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,), "DELETE FROM %s WHERE event_id = ?" % (table,),

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 33 SCHEMA_VERSION = 34
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore):
desc="add_presence_list_pending", desc="add_presence_list_pending",
) )
@defer.inlineCallbacks
def set_presence_list_accepted(self, observer_localpart, observed_userid): def set_presence_list_accepted(self, observer_localpart, observed_userid):
result = yield self._simple_update_one( def update_presence_list_txn(txn):
result = self._simple_update_one_txn(
txn,
table="presence_list", table="presence_list",
keyvalues={"user_id": observer_localpart, keyvalues={
"observed_user_id": observed_userid}, "user_id": observer_localpart,
"observed_user_id": observed_userid
},
updatevalues={"accepted": True}, updatevalues={"accepted": True},
desc="set_presence_list_accepted",
) )
self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,)) self._invalidate_cache_and_stream(
defer.returnValue(result) txn, self.get_presence_list_accepted, (observer_localpart,)
)
self._invalidate_cache_and_stream(
txn, self.get_presence_list_observers_accepted, (observed_userid,)
)
return result
return self.runInteraction(
"set_presence_list_accepted", update_presence_list_txn,
)
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
if accepted: if accepted:

View file

@ -93,7 +93,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
desc="add_refresh_token_to_user", desc="add_refresh_token_to_user",
) )
@defer.inlineCallbacks
def register(self, user_id, token=None, password_hash=None, def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None, was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False): create_profile_with_localpart=None, admin=False):
@ -115,7 +114,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Raises: Raises:
StoreError if the user_id could not be registered. StoreError if the user_id could not be registered.
""" """
yield self.runInteraction( return self.runInteraction(
"register", "register",
self._register, self._register,
user_id, user_id,
@ -127,8 +126,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
create_profile_with_localpart, create_profile_with_localpart,
admin admin
) )
self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,))
def _register( def _register(
self, self,
@ -210,6 +207,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
(create_profile_with_localpart,) (create_profile_with_localpart,)
) )
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
@cached() @cached()
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
return self._simple_select_one( return self._simple_select_one(
@ -236,22 +238,31 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
return self.runInteraction("get_users_by_id_case_insensitive", f) return self.runInteraction("get_users_by_id_case_insensitive", f)
@defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash): def user_set_password_hash(self, user_id, password_hash):
""" """
NB. This does *not* evict any cache because the one use for this NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately. pointless. Use flush_user separately.
""" """
yield self._simple_update_one('users', { def user_set_password_hash_txn(txn):
self._simple_update_one_txn(
txn,
'users', {
'name': user_id 'name': user_id
}, { },
{
'password_hash': password_hash 'password_hash': password_hash
}) }
self.get_user_by_id.invalidate((user_id,)) )
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
)
return self.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
@defer.inlineCallbacks @defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[], def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None, device_id=None,
delete_refresh_tokens=False): delete_refresh_tokens=False):
""" """
@ -259,7 +270,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Args: Args:
user_id (str): ID of user the tokens belong to user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should except_token_id (str): list of access_tokens IDs which should
*not* be deleted *not* be deleted
device_id (str|None): ID of device the tokens are associated with. device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will If None, tokens associated with any device (or no device) will
@ -269,53 +280,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Returns: Returns:
defer.Deferred: defer.Deferred:
""" """
def f(txn, table, except_tokens, call_after_delete): def f(txn):
sql = "SELECT token FROM %s WHERE user_id = ?" % table keyvalues = {
clauses = [user_id] "user_id": user_id,
}
if device_id is not None: if device_id is not None:
sql += " AND device_id = ?" keyvalues["device_id"] = device_id
clauses.append(device_id)
if except_tokens: if delete_refresh_tokens:
sql += " AND id NOT IN (%s)" % ( self._simple_delete_txn(
",".join(["?" for _ in except_tokens]), txn,
table="refresh_tokens",
keyvalues=keyvalues,
) )
clauses += except_tokens
txn.execute(sql, clauses) items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
rows = txn.fetchall() values = [v for _, v in items]
if except_token_id:
n = 100 where_clause += " AND id != ?"
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] values.append(except_token_id)
for chunk in chunks:
if call_after_delete:
for row in chunk:
txn.call_after(call_after_delete, (row[0],))
txn.execute( txn.execute(
"DELETE FROM %s WHERE token in (%s)" % ( "SELECT token FROM access_tokens WHERE %s" % where_clause,
table, values
",".join(["?" for _ in chunk]), )
), [r[0] for r in chunk] rows = self.cursor_to_dict(txn)
for row in rows:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (row["token"],)
) )
# delete refresh tokens first, to stop new access tokens being txn.execute(
# allocated while our backs are turned "DELETE FROM access_tokens WHERE %s" % where_clause,
if delete_refresh_tokens: values
yield self.runInteraction(
"user_delete_access_tokens", f,
table="refresh_tokens",
except_tokens=[],
call_after_delete=None,
) )
yield self.runInteraction( yield self.runInteraction(
"user_delete_access_tokens", f, "user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
) )
def delete_access_token(self, access_token): def delete_access_token(self, access_token):
@ -328,7 +331,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
}, },
) )
txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (access_token,)
)
return self.runInteraction("delete_access_token", f) return self.runInteraction("delete_access_token", f)

View file

@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore):
user_id, membership_list=[Membership.JOIN], user_id, membership_list=[Membership.JOIN],
) )
@defer.inlineCallbacks
def forget(self, user_id, room_id): def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id.""" """Indicate that user_id wishes to discard history for room_id."""
def f(txn): def f(txn):
@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore):
" room_id = ?" " room_id = ?"
) )
txn.execute(sql, (user_id, room_id)) txn.execute(sql, (user_id, room_id))
yield self.runInteraction("forget_membership", f)
self.was_forgotten_at.invalidate_all() txn.call_after(self.was_forgotten_at.invalidate_all)
self.who_forgot_in_room.invalidate_all() txn.call_after(self.did_forget.invalidate, (user_id, room_id))
self.did_forget.invalidate((user_id, room_id)) self._invalidate_cache_and_stream(
txn, self.who_forgot_in_room, (room_id,)
)
return self.runInteraction("forget_membership", f)
@cachedInlineCallbacks(num_args=2) @cachedInlineCallbacks(num_args=2)
def did_forget(self, user_id, room_id): def did_forget(self, user_id, room_id):

View file

@ -0,0 +1,23 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS appservice_stream_position(
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_ordering BIGINT,
CHECK (Lock='X')
);
INSERT INTO appservice_stream_position (stream_ordering)
SELECT COALESCE(MAX(stream_ordering), 0) FROM events;

View file

@ -0,0 +1,46 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine
import logging
logger = logging.getLogger(__name__)
# This stream is used to notify replication slaves that some caches have
# been invalidated that they cannot infer from the other streams.
CREATE_TABLE = """
CREATE TABLE cache_invalidation_stream (
stream_id BIGINT,
cache_func TEXT,
keys TEXT[],
invalidation_ts BIGINT
);
CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
"""
def run_create(cur, database_engine, *args, **kwargs):
if not isinstance(database_engine, PostgresEngine):
return
for statement in get_statements(CREATE_TABLE.splitlines()):
cur.execute(statement)
def run_upgrade(cur, database_engine, *args, **kwargs):
pass

View file

@ -0,0 +1,20 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name';
UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name';
UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';

View file

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from twisted.internet import defer
from mock import Mock from mock import Mock
from tests import unittest from tests import unittest
@ -42,20 +44,25 @@ class ApplicationServiceTestCase(unittest.TestCase):
type="m.something", room_id="!foo:bar", sender="@someone:somewhere" type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
) )
self.store = Mock()
@defer.inlineCallbacks
def test_regex_user_id_prefix_match(self): def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*") _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self): def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*") _regex("@irc_.*")
) )
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.assertFalse(self.service.is_interested(self.event)) self.assertFalse((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_member_is_checked(self): def test_regex_room_member_is_checked(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*") _regex("@irc_.*")
@ -63,30 +70,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member" self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org" self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_id_match(self): def test_regex_room_id_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_id_no_match(self): def test_regex_room_id_no_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse(self.service.is_interested(self.event)) self.assertFalse((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_alias_match(self): def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.assertTrue(self.service.is_interested( self.store.get_aliases_for_room.return_value = [
self.event, "#irc_foobar:matrix.org", "#athing:matrix.org"
aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"] ]
)) self.store.get_users_in_room.return_value = []
self.assertTrue((yield self.service.is_interested(
self.event, self.store
)))
def test_non_exclusive_alias(self): def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
@ -136,15 +149,20 @@ class ApplicationServiceTestCase(unittest.TestCase):
"!irc_foobar:matrix.org" "!irc_foobar:matrix.org"
)) ))
@defer.inlineCallbacks
def test_regex_alias_no_match(self): def test_regex_alias_no_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.assertFalse(self.service.is_interested( self.store.get_aliases_for_room.return_value = [
self.event, "#xmpp_foobar:matrix.org", "#athing:matrix.org"
aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ]
)) self.store.get_users_in_room.return_value = []
self.assertFalse((yield self.service.is_interested(
self.event, self.store
)))
@defer.inlineCallbacks
def test_regex_multiple_matches(self): def test_regex_multiple_matches(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
@ -153,53 +171,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("@irc_.*") _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested( self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
self.event, self.store.get_users_in_room.return_value = []
aliases_for_event=["#irc_barfoo:matrix.org"] self.assertTrue((yield self.service.is_interested(
)) self.event, self.store
)))
def test_restrict_to_rooms(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!flibble_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.event.room_id = "!wibblewoo:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_ROOMS
))
def test_restrict_to_aliases(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#xmpp_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_ALIASES,
aliases_for_event=["#irc_barfoo:matrix.org"]
))
def test_restrict_to_senders(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#xmpp_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_USERS,
aliases_for_event=["#xmpp_barfoo:matrix.org"]
))
@defer.inlineCallbacks
def test_interested_in_self(self): def test_interested_in_self(self):
# make sure invites get through # make sure invites get through
self.service.sender = "@appservice:name" self.service.sender = "@appservice:name"
@ -211,20 +189,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
"membership": "invite" "membership": "invite"
} }
self.event.state_key = self.service.sender self.event.state_key = self.service.sender
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_member_list_match(self): def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*") _regex("@irc_.*")
) )
join_list = [ self.store.get_users_in_room.return_value = [
"@alice:here", "@alice:here",
"@irc_fo:here", # AS user "@irc_fo:here", # AS user
"@bob:here", "@bob:here",
] ]
self.store.get_aliases_for_room.return_value = []
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(self.service.is_interested( self.assertTrue((yield self.service.is_interested(
event=self.event, event=self.event, store=self.store
member_list=join_list )))
))

View file

@ -193,7 +193,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.txn_ctrl = Mock() self.txn_ctrl = Mock()
self.queuer = _ServiceQueuer(self.txn_ctrl) self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
def test_send_single_event_no_queue(self): def test_send_single_event_no_queue(self):
# Expect the event to be sent immediately. # Expect the event to be sent immediately.

View file

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from .. import unittest from .. import unittest
from tests.utils import MockClock
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
@ -32,6 +33,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_datastore = Mock(return_value=self.mock_store) hs.get_datastore = Mock(return_value=self.mock_store)
hs.get_application_service_api = Mock(return_value=self.mock_as_api) hs.get_application_service_api = Mock(return_value=self.mock_as_api)
hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler) hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
hs.get_clock.return_value = MockClock()
self.handler = ApplicationServicesHandler(hs) self.handler = ApplicationServicesHandler(hs)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -51,8 +53,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
type="m.room.message", type="m.room.message",
room_id="!foo:bar" room_id="!foo:bar"
) )
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
yield self.handler.notify_interested_services(event) yield self.handler.notify_interested_services(0)
self.mock_scheduler.submit_event_for_as.assert_called_once_with( self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event interested_service, event
) )
@ -72,7 +75,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
) )
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock() self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event) self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
yield self.handler.notify_interested_services(0)
self.mock_as_api.query_user.assert_called_once_with( self.mock_as_api.query_user.assert_called_once_with(
services[0], user_id services[0], user_id
) )
@ -94,7 +98,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
) )
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock() self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event) self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
yield self.handler.notify_interested_services(0)
self.assertFalse( self.assertFalse(
self.mock_as_api.query_user.called, self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been." "query_user called when it shouldn't have been."
@ -108,11 +113,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
room_id = "!alpha:bet" room_id = "!alpha:bet"
servers = ["aperture"] servers = ["aperture"]
interested_service = self._mkservice(is_interested=True) interested_service = self._mkservice_alias(is_interested_in_alias=True)
services = [ services = [
self._mkservice(is_interested=False), self._mkservice_alias(is_interested_in_alias=False),
interested_service, interested_service,
self._mkservice(is_interested=False) self._mkservice_alias(is_interested_in_alias=False)
] ]
self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_app_services = Mock(return_value=services)
@ -135,3 +140,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service.token = "mock_service_token" service.token = "mock_service_token"
service.url = "mock_service_url" service.url = "mock_service_url"
return service return service
def _mkservice_alias(self, is_interested_in_alias):
service = Mock()
service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
service.token = "mock_service_token"
service.url = "mock_service_url"
return service

View file

@ -15,7 +15,9 @@
from . import unittest from . import unittest
from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs from synapse.rest.media.v1.preview_url_resource import (
summarize_paragraphs, decode_and_calc_og
)
class PreviewTestCase(unittest.TestCase): class PreviewTestCase(unittest.TestCase):
@ -137,3 +139,79 @@ class PreviewTestCase(unittest.TestCase):
" of old wooden houses in Northern Norway, the oldest house dating from" " of old wooden houses in Northern Norway, the oldest house dating from"
" 1789. The Arctic Cathedral, a modern church…" " 1789. The Arctic Cathedral, a modern church…"
) )
class PreviewUrlTestCase(unittest.TestCase):
def test_simple(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})
def test_comment(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
<!-- HTML comment -->
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})
def test_comment2(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
Some text.
<!-- HTML comment -->
Some more text.
<p>Text</p>
More text
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text"
})
def test_script(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
<script> (function() {})() </script>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})