mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 07:33:47 +01:00
Merge remote-tracking branch 'origin/develop' into paul/thirdpartylookup
This commit is contained in:
commit
d5bf7a4a99
47 changed files with 1203 additions and 546 deletions
|
@ -25,5 +25,6 @@ rm .coverage* || echo "No coverage files to remove"
|
|||
tox --notest -e py27
|
||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
$TOX_BIN/pip install lxml
|
||||
|
||||
tox -e py27
|
||||
|
|
212
synapse/app/media_repository.py
Normal file
212
synapse/app/media_repository.py
Normal file
|
@ -0,0 +1,212 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.client_ips import ClientIpStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.media_repository import MediaRepositoryStore
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from synapse.api.urls import (
|
||||
CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
|
||||
)
|
||||
from synapse.crypto import context_factory
|
||||
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import gc
|
||||
|
||||
logger = logging.getLogger("synapse.app.media_repository")
|
||||
|
||||
|
||||
class MediaRepositorySlavedStore(
|
||||
SlavedApplicationServiceStore,
|
||||
SlavedRegistrationStore,
|
||||
BaseSlavedStore,
|
||||
MediaRepositoryStore,
|
||||
ClientIpStore,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class MediaRepositoryServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def _listen_http(self, listener_config):
|
||||
port = listener_config["port"]
|
||||
bind_address = listener_config.get("bind_address", "")
|
||||
site_tag = listener_config.get("tag", port)
|
||||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "metrics":
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
elif name == "media":
|
||||
media_repo = MediaRepositoryResource(self)
|
||||
resources.update({
|
||||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||
self, self.config.uploads_path
|
||||
),
|
||||
})
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
reactor.listenTCP(
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
),
|
||||
interface=bind_address
|
||||
)
|
||||
logger.info("Synapse media repository now listening on port %d", port)
|
||||
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
reactor.listenTCP(
|
||||
listener["port"],
|
||||
manhole(
|
||||
username="matrix",
|
||||
password="rabbithole",
|
||||
globals={"hs": self},
|
||||
),
|
||||
interface=listener.get("bind_address", '127.0.0.1')
|
||||
)
|
||||
else:
|
||||
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.worker_replication_url
|
||||
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
yield store.process_replication(result)
|
||||
except:
|
||||
logger.exception("Error replicating from %r", replication_url)
|
||||
yield sleep(5)
|
||||
|
||||
|
||||
def start(config_options):
|
||||
try:
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse media repository", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
assert config.worker_app == "synapse.app.media_repository"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
||||
ss = MediaRepositoryServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
ss.setup()
|
||||
ss.get_handlers()
|
||||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ss.get_datastore().start_profiling()
|
||||
ss.replicate()
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-media-repository",
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
|
@ -80,11 +80,6 @@ class PusherSlaveStore(
|
|||
DataStore.get_profile_displayname.__func__
|
||||
)
|
||||
|
||||
# XXX: This is a bit broken because we don't persist forgotten rooms
|
||||
# in a way that they can be streamed. This means that we don't have a
|
||||
# way to invalidate the forgotten rooms cache correctly.
|
||||
# For now we expire the cache every 10 minutes.
|
||||
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
|
||||
who_forgot_in_room = (
|
||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
||||
)
|
||||
|
@ -168,7 +163,6 @@ class PusherServer(HomeServer):
|
|||
store = self.get_datastore()
|
||||
replication_url = self.config.worker_replication_url
|
||||
pusher_pool = self.get_pusherpool()
|
||||
clock = self.get_clock()
|
||||
|
||||
def stop_pusher(user_id, app_id, pushkey):
|
||||
key = "%s:%s" % (app_id, pushkey)
|
||||
|
@ -220,21 +214,11 @@ class PusherServer(HomeServer):
|
|||
min_stream_id, max_stream_id, affected_room_ids
|
||||
)
|
||||
|
||||
def expire_broken_caches():
|
||||
store.who_forgot_in_room.invalidate_all()
|
||||
|
||||
next_expire_broken_caches_ms = 0
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
now_ms = clock.time_msec()
|
||||
if now_ms > next_expire_broken_caches_ms:
|
||||
expire_broken_caches()
|
||||
next_expire_broken_caches_ms = (
|
||||
now_ms + store.BROKEN_CACHE_EXPIRY_MS
|
||||
)
|
||||
yield store.process_replication(result)
|
||||
poke_pushers(result)
|
||||
except:
|
||||
|
|
|
@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite
|
|||
from synapse.http.server import JsonResource
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.rest.client.v2_alpha import sync
|
||||
from synapse.rest.client.v1 import events
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
|
@ -74,11 +75,6 @@ class SynchrotronSlavedStore(
|
|||
BaseSlavedStore,
|
||||
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
||||
):
|
||||
# XXX: This is a bit broken because we don't persist forgotten rooms
|
||||
# in a way that they can be streamed. This means that we don't have a
|
||||
# way to invalidate the forgotten rooms cache correctly.
|
||||
# For now we expire the cache every 10 minutes.
|
||||
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
|
||||
who_forgot_in_room = (
|
||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
||||
)
|
||||
|
@ -89,17 +85,23 @@ class SynchrotronSlavedStore(
|
|||
get_presence_list_accepted = PresenceStore.__dict__[
|
||||
"get_presence_list_accepted"
|
||||
]
|
||||
get_presence_list_observers_accepted = PresenceStore.__dict__[
|
||||
"get_presence_list_observers_accepted"
|
||||
]
|
||||
|
||||
|
||||
UPDATE_SYNCING_USERS_MS = 10 * 1000
|
||||
|
||||
|
||||
class SynchrotronPresence(object):
|
||||
def __init__(self, hs):
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.store = hs.get_datastore()
|
||||
self.user_to_num_current_syncs = {}
|
||||
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
|
||||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
self.user_to_current_state = {
|
||||
|
@ -124,6 +126,8 @@ class SynchrotronPresence(object):
|
|||
pass
|
||||
|
||||
get_states = PresenceHandler.get_states.__func__
|
||||
get_state = PresenceHandler.get_state.__func__
|
||||
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
|
||||
current_state_for_users = PresenceHandler.current_state_for_users.__func__
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -194,19 +198,39 @@ class SynchrotronPresence(object):
|
|||
self._need_to_send_sync = False
|
||||
yield self._send_syncing_users_now()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify_from_replication(self, states, stream_id):
|
||||
parties = yield self._get_interested_parties(
|
||||
states, calculate_remote_hosts=False
|
||||
)
|
||||
room_ids_to_states, users_to_states, _ = parties
|
||||
|
||||
self.notifier.on_new_event(
|
||||
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
|
||||
users=users_to_states.keys()
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def process_replication(self, result):
|
||||
stream = result.get("presence", {"rows": []})
|
||||
states = []
|
||||
for row in stream["rows"]:
|
||||
(
|
||||
position, user_id, state, last_active_ts,
|
||||
last_federation_update_ts, last_user_sync_ts, status_msg,
|
||||
currently_active
|
||||
) = row
|
||||
self.user_to_current_state[user_id] = UserPresenceState(
|
||||
state = UserPresenceState(
|
||||
user_id, state, last_active_ts,
|
||||
last_federation_update_ts, last_user_sync_ts, status_msg,
|
||||
currently_active
|
||||
)
|
||||
self.user_to_current_state[user_id] = state
|
||||
states.append(state)
|
||||
|
||||
if states and "position" in stream:
|
||||
stream_id = int(stream["position"])
|
||||
yield self.notify_from_replication(states, stream_id)
|
||||
|
||||
|
||||
class SynchrotronTyping(object):
|
||||
|
@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer):
|
|||
elif name == "client":
|
||||
resource = JsonResource(self, canonical_json=False)
|
||||
sync.register_servlets(self, resource)
|
||||
events.register_servlets(self, resource)
|
||||
resources.update({
|
||||
"/_matrix/client/r0": resource,
|
||||
"/_matrix/client/unstable": resource,
|
||||
"/_matrix/client/v2_alpha": resource,
|
||||
"/_matrix/client/api/v1": resource,
|
||||
})
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
|
@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer):
|
|||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.worker_replication_url
|
||||
clock = self.get_clock()
|
||||
notifier = self.get_notifier()
|
||||
presence_handler = self.get_presence_handler()
|
||||
typing_handler = self.get_typing_handler()
|
||||
|
||||
def expire_broken_caches():
|
||||
store.who_forgot_in_room.invalidate_all()
|
||||
store.get_presence_list_accepted.invalidate_all()
|
||||
|
||||
def notify_from_stream(
|
||||
result, stream_name, stream_key, room=None, user=None
|
||||
):
|
||||
|
@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer):
|
|||
result, "typing", "typing_key", room="room_id"
|
||||
)
|
||||
|
||||
next_expire_broken_caches_ms = 0
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args.update(typing_handler.stream_positions())
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
now_ms = clock.time_msec()
|
||||
if now_ms > next_expire_broken_caches_ms:
|
||||
expire_broken_caches()
|
||||
next_expire_broken_caches_ms = (
|
||||
now_ms + store.BROKEN_CACHE_EXPIRY_MS
|
||||
)
|
||||
yield store.process_replication(result)
|
||||
typing_handler.process_replication(result)
|
||||
presence_handler.process_replication(result)
|
||||
yield presence_handler.process_replication(result)
|
||||
notify(result)
|
||||
except:
|
||||
logger.exception("Error replicating from %r", replication_url)
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# limitations under the License.
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
|
@ -138,65 +140,66 @@ class ApplicationService(object):
|
|||
return regex_obj["exclusive"]
|
||||
return False
|
||||
|
||||
def _matches_user(self, event, member_list):
|
||||
if (hasattr(event, "sender") and
|
||||
self.is_interested_in_user(event.sender)):
|
||||
return True
|
||||
@defer.inlineCallbacks
|
||||
def _matches_user(self, event, store):
|
||||
if not event:
|
||||
defer.returnValue(False)
|
||||
|
||||
if self.is_interested_in_user(event.sender):
|
||||
defer.returnValue(True)
|
||||
# also check m.room.member state key
|
||||
if (hasattr(event, "type") and event.type == EventTypes.Member
|
||||
and hasattr(event, "state_key")
|
||||
and self.is_interested_in_user(event.state_key)):
|
||||
return True
|
||||
if (event.type == EventTypes.Member and
|
||||
self.is_interested_in_user(event.state_key)):
|
||||
defer.returnValue(True)
|
||||
|
||||
if not store:
|
||||
defer.returnValue(False)
|
||||
|
||||
member_list = yield store.get_users_in_room(event.room_id)
|
||||
|
||||
# check joined member events
|
||||
for user_id in member_list:
|
||||
if self.is_interested_in_user(user_id):
|
||||
return True
|
||||
return False
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(False)
|
||||
|
||||
def _matches_room_id(self, event):
|
||||
if hasattr(event, "room_id"):
|
||||
return self.is_interested_in_room(event.room_id)
|
||||
return False
|
||||
|
||||
def _matches_aliases(self, event, alias_list):
|
||||
@defer.inlineCallbacks
|
||||
def _matches_aliases(self, event, store):
|
||||
if not store or not event:
|
||||
defer.returnValue(False)
|
||||
|
||||
alias_list = yield store.get_aliases_for_room(event.room_id)
|
||||
for alias in alias_list:
|
||||
if self.is_interested_in_alias(alias):
|
||||
return True
|
||||
return False
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(False)
|
||||
|
||||
def is_interested(self, event, restrict_to=None, aliases_for_event=None,
|
||||
member_list=None):
|
||||
@defer.inlineCallbacks
|
||||
def is_interested(self, event, store=None):
|
||||
"""Check if this service is interested in this event.
|
||||
|
||||
Args:
|
||||
event(Event): The event to check.
|
||||
restrict_to(str): The namespace to restrict regex tests to.
|
||||
aliases_for_event(list): A list of all the known room aliases for
|
||||
this event.
|
||||
member_list(list): A list of all joined user_ids in this room.
|
||||
store(DataStore)
|
||||
Returns:
|
||||
bool: True if this service would like to know about this event.
|
||||
"""
|
||||
if aliases_for_event is None:
|
||||
aliases_for_event = []
|
||||
if member_list is None:
|
||||
member_list = []
|
||||
# Do cheap checks first
|
||||
if self._matches_room_id(event):
|
||||
defer.returnValue(True)
|
||||
|
||||
if restrict_to and restrict_to not in ApplicationService.NS_LIST:
|
||||
# this is a programming error, so fail early and raise a general
|
||||
# exception
|
||||
raise Exception("Unexpected restrict_to value: %s". restrict_to)
|
||||
if (yield self._matches_aliases(event, store)):
|
||||
defer.returnValue(True)
|
||||
|
||||
if not restrict_to:
|
||||
return (self._matches_user(event, member_list)
|
||||
or self._matches_aliases(event, aliases_for_event)
|
||||
or self._matches_room_id(event))
|
||||
elif restrict_to == ApplicationService.NS_ALIASES:
|
||||
return self._matches_aliases(event, aliases_for_event)
|
||||
elif restrict_to == ApplicationService.NS_ROOMS:
|
||||
return self._matches_room_id(event)
|
||||
elif restrict_to == ApplicationService.NS_USERS:
|
||||
return self._matches_user(event, member_list)
|
||||
if (yield self._matches_user(event, store)):
|
||||
defer.returnValue(True)
|
||||
|
||||
defer.returnValue(False)
|
||||
|
||||
def is_interested_in_user(self, user_id):
|
||||
return (
|
||||
|
|
|
@ -48,9 +48,12 @@ UP & quit +---------- YES SUCCESS
|
|||
This is all tied together by the AppServiceScheduler which DIs the required
|
||||
components.
|
||||
"""
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.appservice import ApplicationServiceState
|
||||
from twisted.internet import defer
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -73,7 +76,7 @@ class ApplicationServiceScheduler(object):
|
|||
self.txn_ctrl = _TransactionController(
|
||||
self.clock, self.store, self.as_api, create_recoverer
|
||||
)
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl)
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start(self):
|
||||
|
@ -94,38 +97,36 @@ class _ServiceQueuer(object):
|
|||
this schedules any other events in the queue to run.
|
||||
"""
|
||||
|
||||
def __init__(self, txn_ctrl):
|
||||
def __init__(self, txn_ctrl, clock):
|
||||
self.queued_events = {} # dict of {service_id: [events]}
|
||||
self.pending_requests = {} # dict of {service_id: Deferred}
|
||||
self.requests_in_flight = set()
|
||||
self.txn_ctrl = txn_ctrl
|
||||
self.clock = clock
|
||||
|
||||
def enqueue(self, service, event):
|
||||
# if this service isn't being sent something
|
||||
if not self.pending_requests.get(service.id):
|
||||
self._send_request(service, [event])
|
||||
else:
|
||||
# add to queue for this service
|
||||
if service.id not in self.queued_events:
|
||||
self.queued_events[service.id] = []
|
||||
self.queued_events[service.id].append(event)
|
||||
self.queued_events.setdefault(service.id, []).append(event)
|
||||
preserve_fn(self._send_request)(service)
|
||||
|
||||
def _send_request(self, service, events):
|
||||
# send request and add callbacks
|
||||
d = self.txn_ctrl.send(service, events)
|
||||
d.addBoth(self._on_request_finish)
|
||||
d.addErrback(self._on_request_fail)
|
||||
self.pending_requests[service.id] = d
|
||||
@defer.inlineCallbacks
|
||||
def _send_request(self, service):
|
||||
if service.id in self.requests_in_flight:
|
||||
return
|
||||
|
||||
def _on_request_finish(self, service):
|
||||
self.pending_requests[service.id] = None
|
||||
# if there are queued events, then send them.
|
||||
if (service.id in self.queued_events
|
||||
and len(self.queued_events[service.id]) > 0):
|
||||
self._send_request(service, self.queued_events[service.id])
|
||||
self.queued_events[service.id] = []
|
||||
self.requests_in_flight.add(service.id)
|
||||
try:
|
||||
while True:
|
||||
events = self.queued_events.pop(service.id, [])
|
||||
if not events:
|
||||
return
|
||||
|
||||
def _on_request_fail(self, err):
|
||||
logger.error("AS request failed: %s", err)
|
||||
with Measure(self.clock, "servicequeuer.send"):
|
||||
try:
|
||||
yield self.txn_ctrl.send(service, events)
|
||||
except:
|
||||
logger.exception("AS request failed")
|
||||
finally:
|
||||
self.requests_in_flight.discard(service.id)
|
||||
|
||||
|
||||
class _TransactionController(object):
|
||||
|
@ -155,8 +156,6 @@ class _TransactionController(object):
|
|||
except Exception as e:
|
||||
logger.exception(e)
|
||||
self._start_recoverer(service)
|
||||
# request has finished
|
||||
defer.returnValue(service)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_recovered(self, recoverer):
|
||||
|
|
|
@ -19,7 +19,6 @@ from .room import (
|
|||
)
|
||||
from .room_member import RoomMemberHandler
|
||||
from .message import MessageHandler
|
||||
from .events import EventStreamHandler, EventHandler
|
||||
from .federation import FederationHandler
|
||||
from .profile import ProfileHandler
|
||||
from .directory import DirectoryHandler
|
||||
|
@ -53,8 +52,6 @@ class Handlers(object):
|
|||
self.message_handler = MessageHandler(hs)
|
||||
self.room_creation_handler = RoomCreationHandler(hs)
|
||||
self.room_member_handler = RoomMemberHandler(hs)
|
||||
self.event_stream_handler = EventStreamHandler(hs)
|
||||
self.event_handler = EventHandler(hs)
|
||||
self.federation_handler = FederationHandler(hs)
|
||||
self.profile_handler = ProfileHandler(hs)
|
||||
self.directory_handler = DirectoryHandler(hs)
|
||||
|
|
|
@ -16,7 +16,8 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -42,36 +43,60 @@ class ApplicationServicesHandler(object):
|
|||
self.appservice_api = hs.get_application_service_api()
|
||||
self.scheduler = hs.get_application_service_scheduler()
|
||||
self.started_scheduler = False
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify_interested_services(self, event):
|
||||
def notify_interested_services(self, current_id):
|
||||
"""Notifies (pushes) all application services interested in this event.
|
||||
|
||||
Pushing is done asynchronously, so this method won't block for any
|
||||
prolonged length of time.
|
||||
|
||||
Args:
|
||||
event(Event): The event to push out to interested services.
|
||||
current_id(int): The current maximum ID.
|
||||
"""
|
||||
# Gather interested services
|
||||
services = yield self._get_services_for_event(event)
|
||||
if len(services) == 0:
|
||||
return # no services need notifying
|
||||
services = yield self.store.get_app_services()
|
||||
if not services:
|
||||
return
|
||||
|
||||
# Do we know this user exists? If not, poke the user query API for
|
||||
# all services which match that user regex. This needs to block as these
|
||||
# user queries need to be made BEFORE pushing the event.
|
||||
yield self._check_user_exists(event.sender)
|
||||
if event.type == EventTypes.Member:
|
||||
yield self._check_user_exists(event.state_key)
|
||||
with Measure(self.clock, "notify_interested_services"):
|
||||
upper_bound = current_id
|
||||
limit = 100
|
||||
while True:
|
||||
upper_bound, events = yield self.store.get_new_events_for_appservice(
|
||||
upper_bound, limit
|
||||
)
|
||||
|
||||
if not self.started_scheduler:
|
||||
self.scheduler.start().addErrback(log_failure)
|
||||
self.started_scheduler = True
|
||||
logger.info("Current_id: %r, upper_bound: %r", current_id, upper_bound)
|
||||
|
||||
# Fork off pushes to these services
|
||||
for service in services:
|
||||
self.scheduler.submit_event_for_as(service, event)
|
||||
if not events:
|
||||
break
|
||||
|
||||
for event in events:
|
||||
# Gather interested services
|
||||
services = yield self._get_services_for_event(event)
|
||||
if len(services) == 0:
|
||||
continue # no services need notifying
|
||||
|
||||
# Do we know this user exists? If not, poke the user query API for
|
||||
# all services which match that user regex. This needs to block as
|
||||
# these user queries need to be made BEFORE pushing the event.
|
||||
yield self._check_user_exists(event.sender)
|
||||
if event.type == EventTypes.Member:
|
||||
yield self._check_user_exists(event.state_key)
|
||||
|
||||
if not self.started_scheduler:
|
||||
self.scheduler.start().addErrback(log_failure)
|
||||
self.started_scheduler = True
|
||||
|
||||
# Fork off pushes to these services
|
||||
for service in services:
|
||||
preserve_fn(self.scheduler.submit_event_for_as)(service, event)
|
||||
|
||||
yield self.store.set_appservice_last_pos(upper_bound)
|
||||
|
||||
if len(events) < limit:
|
||||
break
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_user_exists(self, user_id):
|
||||
|
@ -104,11 +129,12 @@ class ApplicationServicesHandler(object):
|
|||
association can be found.
|
||||
"""
|
||||
room_alias_str = room_alias.to_string()
|
||||
alias_query_services = yield self._get_services_for_event(
|
||||
event=None,
|
||||
restrict_to=ApplicationService.NS_ALIASES,
|
||||
alias_list=[room_alias_str]
|
||||
)
|
||||
services = yield self.store.get_app_services()
|
||||
alias_query_services = [
|
||||
s for s in services if (
|
||||
s.is_interested_in_alias(room_alias_str)
|
||||
)
|
||||
]
|
||||
for alias_service in alias_query_services:
|
||||
is_known_alias = yield self.appservice_api.query_alias(
|
||||
alias_service, room_alias_str
|
||||
|
@ -136,34 +162,19 @@ class ApplicationServicesHandler(object):
|
|||
defer.returnValue(results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_services_for_event(self, event, restrict_to="", alias_list=None):
|
||||
def _get_services_for_event(self, event):
|
||||
"""Retrieve a list of application services interested in this event.
|
||||
|
||||
Args:
|
||||
event(Event): The event to check. Can be None if alias_list is not.
|
||||
restrict_to(str): The namespace to restrict regex tests to.
|
||||
alias_list: A list of aliases to get services for. If None, this
|
||||
list is obtained from the database.
|
||||
Returns:
|
||||
list<ApplicationService>: A list of services interested in this
|
||||
event based on the service regex.
|
||||
"""
|
||||
member_list = None
|
||||
if hasattr(event, "room_id"):
|
||||
# We need to know the aliases associated with this event.room_id,
|
||||
# if any.
|
||||
if not alias_list:
|
||||
alias_list = yield self.store.get_aliases_for_room(
|
||||
event.room_id
|
||||
)
|
||||
# We need to know the members associated with this event.room_id,
|
||||
# if any.
|
||||
member_list = yield self.store.get_users_in_room(event.room_id)
|
||||
|
||||
services = yield self.store.get_app_services()
|
||||
interested_list = [
|
||||
s for s in services if (
|
||||
s.is_interested(event, restrict_to, alias_list, member_list)
|
||||
yield s.is_interested(event, self.store)
|
||||
)
|
||||
]
|
||||
defer.returnValue(interested_list)
|
||||
|
|
|
@ -741,7 +741,7 @@ class AuthHandler(BaseHandler):
|
|||
def set_password(self, user_id, newpassword, requester=None):
|
||||
password_hash = self.hash(newpassword)
|
||||
|
||||
except_access_token_ids = [requester.access_token_id] if requester else []
|
||||
except_access_token_id = requester.access_token_id if requester else None
|
||||
|
||||
try:
|
||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||
|
@ -750,10 +750,10 @@ class AuthHandler(BaseHandler):
|
|||
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
||||
raise e
|
||||
yield self.store.user_delete_access_tokens(
|
||||
user_id, except_access_token_ids
|
||||
user_id, except_access_token_id
|
||||
)
|
||||
yield self.hs.get_pusherpool().remove_pushers_by_user(
|
||||
user_id, except_access_token_ids
|
||||
user_id, except_access_token_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -274,7 +274,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def backfill(self, dest, room_id, limit, extremities=[]):
|
||||
def backfill(self, dest, room_id, limit, extremities):
|
||||
""" Trigger a backfill request to `dest` for the given `room_id`
|
||||
|
||||
This will attempt to get more events from the remote. This may return
|
||||
|
@ -284,9 +284,6 @@ class FederationHandler(BaseHandler):
|
|||
if dest == self.server_name:
|
||||
raise SynapseError(400, "Can't backfill from self.")
|
||||
|
||||
if not extremities:
|
||||
extremities = yield self.store.get_oldest_events_in_room(room_id)
|
||||
|
||||
events = yield self.replication_layer.backfill(
|
||||
dest,
|
||||
room_id,
|
||||
|
@ -455,6 +452,10 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
max_depth = sorted_extremeties_tuple[0][1]
|
||||
|
||||
# We don't want to specify too many extremities as it causes the backfill
|
||||
# request URI to be too long.
|
||||
extremities = dict(sorted_extremeties_tuple[:5])
|
||||
|
||||
if current_depth > max_depth:
|
||||
logger.debug(
|
||||
"Not backfilling as we don't need to. %d < %d",
|
||||
|
|
|
@ -503,7 +503,7 @@ class PresenceHandler(object):
|
|||
defer.returnValue(states)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_interested_parties(self, states):
|
||||
def _get_interested_parties(self, states, calculate_remote_hosts=True):
|
||||
"""Given a list of states return which entities (rooms, users, servers)
|
||||
are interested in the given states.
|
||||
|
||||
|
@ -526,14 +526,15 @@ class PresenceHandler(object):
|
|||
users_to_states.setdefault(state.user_id, []).append(state)
|
||||
|
||||
hosts_to_states = {}
|
||||
for room_id, states in room_ids_to_states.items():
|
||||
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
|
||||
if not local_states:
|
||||
continue
|
||||
if calculate_remote_hosts:
|
||||
for room_id, states in room_ids_to_states.items():
|
||||
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
|
||||
if not local_states:
|
||||
continue
|
||||
|
||||
hosts = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
for host in hosts:
|
||||
hosts_to_states.setdefault(host, []).extend(local_states)
|
||||
hosts = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
for host in hosts:
|
||||
hosts_to_states.setdefault(host, []).extend(local_states)
|
||||
|
||||
for user_id, states in users_to_states.items():
|
||||
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
|
||||
|
@ -565,6 +566,16 @@ class PresenceHandler(object):
|
|||
|
||||
self._push_to_remotes(hosts_to_states)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify_for_states(self, state, stream_id):
|
||||
parties = yield self._get_interested_parties([state])
|
||||
room_ids_to_states, users_to_states, hosts_to_states = parties
|
||||
|
||||
self.notifier.on_new_event(
|
||||
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
|
||||
users=[UserID.from_string(u) for u in users_to_states.keys()]
|
||||
)
|
||||
|
||||
def _push_to_remotes(self, hosts_to_states):
|
||||
"""Sends state updates to remote servers.
|
||||
|
||||
|
|
|
@ -67,10 +67,8 @@ class _NotifierUserStream(object):
|
|||
so that it can remove itself from the indexes in the Notifier class.
|
||||
"""
|
||||
|
||||
def __init__(self, user_id, rooms, current_token, time_now_ms,
|
||||
appservice=None):
|
||||
def __init__(self, user_id, rooms, current_token, time_now_ms):
|
||||
self.user_id = user_id
|
||||
self.appservice = appservice
|
||||
self.rooms = set(rooms)
|
||||
self.current_token = current_token
|
||||
self.last_notified_ms = time_now_ms
|
||||
|
@ -107,11 +105,6 @@ class _NotifierUserStream(object):
|
|||
|
||||
notifier.user_to_user_stream.pop(self.user_id)
|
||||
|
||||
if self.appservice:
|
||||
notifier.appservice_to_user_streams.get(
|
||||
self.appservice, set()
|
||||
).discard(self)
|
||||
|
||||
def count_listeners(self):
|
||||
return len(self.notify_deferred.observers())
|
||||
|
||||
|
@ -142,7 +135,6 @@ class Notifier(object):
|
|||
def __init__(self, hs):
|
||||
self.user_to_user_stream = {}
|
||||
self.room_to_user_streams = {}
|
||||
self.appservice_to_user_streams = {}
|
||||
|
||||
self.event_sources = hs.get_event_sources()
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -168,8 +160,6 @@ class Notifier(object):
|
|||
all_user_streams |= x
|
||||
for x in self.user_to_user_stream.values():
|
||||
all_user_streams.add(x)
|
||||
for x in self.appservice_to_user_streams.values():
|
||||
all_user_streams |= x
|
||||
|
||||
return sum(stream.count_listeners() for stream in all_user_streams)
|
||||
metrics.register_callback("listeners", count_listeners)
|
||||
|
@ -182,10 +172,6 @@ class Notifier(object):
|
|||
"users",
|
||||
lambda: len(self.user_to_user_stream),
|
||||
)
|
||||
metrics.register_callback(
|
||||
"appservices",
|
||||
lambda: count(bool, self.appservice_to_user_streams.values()),
|
||||
)
|
||||
|
||||
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
|
||||
extra_users=[]):
|
||||
|
@ -228,21 +214,7 @@ class Notifier(object):
|
|||
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
|
||||
"""Notify any user streams that are interested in this room event"""
|
||||
# poke any interested application service.
|
||||
self.appservice_handler.notify_interested_services(event)
|
||||
|
||||
app_streams = set()
|
||||
|
||||
for appservice in self.appservice_to_user_streams:
|
||||
# TODO (kegan): Redundant appservice listener checks?
|
||||
# App services will already be in the room_to_user_streams set, but
|
||||
# that isn't enough. They need to be checked here in order to
|
||||
# receive *invites* for users they are interested in. Does this
|
||||
# make the room_to_user_streams check somewhat obselete?
|
||||
if appservice.is_interested(event):
|
||||
app_user_streams = self.appservice_to_user_streams.get(
|
||||
appservice, set()
|
||||
)
|
||||
app_streams |= app_user_streams
|
||||
self.appservice_handler.notify_interested_services(room_stream_id)
|
||||
|
||||
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
||||
self._user_joined_room(event.state_key, event.room_id)
|
||||
|
@ -251,11 +223,9 @@ class Notifier(object):
|
|||
"room_key", room_stream_id,
|
||||
users=extra_users,
|
||||
rooms=[event.room_id],
|
||||
extra_streams=app_streams,
|
||||
)
|
||||
|
||||
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
|
||||
extra_streams=set()):
|
||||
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
|
||||
""" Used to inform listeners that something has happend event wise.
|
||||
|
||||
Will wake up all listeners for the given users and rooms.
|
||||
|
@ -294,7 +264,6 @@ class Notifier(object):
|
|||
"""
|
||||
user_stream = self.user_to_user_stream.get(user_id)
|
||||
if user_stream is None:
|
||||
appservice = yield self.store.get_app_service_by_user_id(user_id)
|
||||
current_token = yield self.event_sources.get_current_token()
|
||||
if room_ids is None:
|
||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||
|
@ -302,7 +271,6 @@ class Notifier(object):
|
|||
user_stream = _NotifierUserStream(
|
||||
user_id=user_id,
|
||||
rooms=room_ids,
|
||||
appservice=appservice,
|
||||
current_token=current_token,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
)
|
||||
|
@ -477,11 +445,6 @@ class Notifier(object):
|
|||
s = self.room_to_user_streams.setdefault(room, set())
|
||||
s.add(user_stream)
|
||||
|
||||
if user_stream.appservice:
|
||||
self.appservice_to_user_stream.setdefault(
|
||||
user_stream.appservice, set()
|
||||
).add(user_stream)
|
||||
|
||||
def _user_joined_room(self, user_id, room_id):
|
||||
new_user_stream = self.user_to_user_stream.get(user_id)
|
||||
if new_user_stream is not None:
|
||||
|
|
|
@ -38,15 +38,16 @@ class ActionGenerator:
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def handle_push_actions_for_event(self, event, context):
|
||||
with Measure(self.clock, "handle_push_actions_for_event"):
|
||||
with Measure(self.clock, "evaluator_for_event"):
|
||||
bulk_evaluator = yield evaluator_for_event(
|
||||
event, self.hs, self.store, context.current_state
|
||||
)
|
||||
|
||||
with Measure(self.clock, "action_for_event_by_user"):
|
||||
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
||||
event, context.current_state
|
||||
)
|
||||
|
||||
context.push_actions = [
|
||||
(uid, actions) for uid, actions in actions_by_user.items()
|
||||
]
|
||||
context.push_actions = [
|
||||
(uid, actions) for uid, actions in actions_by_user.items()
|
||||
]
|
||||
|
|
|
@ -217,6 +217,27 @@ BASE_APPEND_OVERRIDE_RULES = [
|
|||
'dont_notify'
|
||||
]
|
||||
},
|
||||
# This was changed from underride to override so it's closer in priority
|
||||
# to the content rules where the user name highlight rule lives. This
|
||||
# way a room rule is lower priority than both but a custom override rule
|
||||
# is higher priority than both.
|
||||
{
|
||||
'rule_id': 'global/override/.m.rule.contains_display_name',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'contains_display_name'
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
'notify',
|
||||
{
|
||||
'set_tweak': 'sound',
|
||||
'value': 'default'
|
||||
}, {
|
||||
'set_tweak': 'highlight'
|
||||
}
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
@ -242,23 +263,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
|
|||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
'rule_id': 'global/underride/.m.rule.contains_display_name',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'contains_display_name'
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
'notify',
|
||||
{
|
||||
'set_tweak': 'sound',
|
||||
'value': 'default'
|
||||
}, {
|
||||
'set_tweak': 'highlight'
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
'rule_id': 'global/underride/.m.rule.room_one_to_one',
|
||||
'conditions': [
|
||||
|
|
|
@ -102,14 +102,14 @@ class PusherPool:
|
|||
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_pushers_by_user(self, user_id, except_token_ids=[]):
|
||||
def remove_pushers_by_user(self, user_id, except_access_token_id=None):
|
||||
all = yield self.store.get_all_pushers()
|
||||
logger.info(
|
||||
"Removing all pushers for user %s except access tokens ids %r",
|
||||
user_id, except_token_ids
|
||||
"Removing all pushers for user %s except access tokens id %r",
|
||||
user_id, except_access_token_id
|
||||
)
|
||||
for p in all:
|
||||
if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
|
||||
if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
|
||||
logger.info(
|
||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||
p['app_id'], p['pushkey'], p['user_name']
|
||||
|
|
|
@ -41,6 +41,7 @@ STREAM_NAMES = (
|
|||
("push_rules",),
|
||||
("pushers",),
|
||||
("state",),
|
||||
("caches",),
|
||||
)
|
||||
|
||||
|
||||
|
@ -70,6 +71,7 @@ class ReplicationResource(Resource):
|
|||
* "backfill": Old events that have been backfilled from other servers.
|
||||
* "push_rules": Per user changes to push rules.
|
||||
* "pushers": Per user changes to their pushers.
|
||||
* "caches": Cache invalidations.
|
||||
|
||||
The API takes two additional query parameters:
|
||||
|
||||
|
@ -129,6 +131,7 @@ class ReplicationResource(Resource):
|
|||
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
|
||||
pushers_token = self.store.get_pushers_stream_token()
|
||||
state_token = self.store.get_state_stream_token()
|
||||
caches_token = self.store.get_cache_stream_token()
|
||||
|
||||
defer.returnValue(_ReplicationToken(
|
||||
room_stream_token,
|
||||
|
@ -140,6 +143,7 @@ class ReplicationResource(Resource):
|
|||
push_rules_token,
|
||||
pushers_token,
|
||||
state_token,
|
||||
caches_token,
|
||||
))
|
||||
|
||||
@request_handler()
|
||||
|
@ -188,6 +192,7 @@ class ReplicationResource(Resource):
|
|||
yield self.push_rules(writer, current_token, limit, request_streams)
|
||||
yield self.pushers(writer, current_token, limit, request_streams)
|
||||
yield self.state(writer, current_token, limit, request_streams)
|
||||
yield self.caches(writer, current_token, limit, request_streams)
|
||||
self.streams(writer, current_token, request_streams)
|
||||
|
||||
logger.info("Replicated %d rows", writer.total)
|
||||
|
@ -379,6 +384,20 @@ class ReplicationResource(Resource):
|
|||
"position", "type", "state_key", "event_id"
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def caches(self, writer, current_token, limit, request_streams):
|
||||
current_position = current_token.caches
|
||||
|
||||
caches = request_streams.get("caches")
|
||||
|
||||
if caches is not None:
|
||||
updated_caches = yield self.store.get_all_updated_caches(
|
||||
caches, current_position, limit
|
||||
)
|
||||
writer.write_header_and_rows("caches", updated_caches, (
|
||||
"position", "cache_func", "keys", "invalidation_ts"
|
||||
))
|
||||
|
||||
|
||||
class _Writer(object):
|
||||
"""Writes the streams as a JSON object as the response to the request"""
|
||||
|
@ -407,7 +426,7 @@ class _Writer(object):
|
|||
|
||||
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
|
||||
"events", "presence", "typing", "receipts", "account_data", "backfill",
|
||||
"push_rules", "pushers", "state"
|
||||
"push_rules", "pushers", "state", "caches",
|
||||
))):
|
||||
__slots__ = []
|
||||
|
||||
|
|
|
@ -14,15 +14,43 @@
|
|||
# limitations under the License.
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseSlavedStore(SQLBaseStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(BaseSlavedStore, self).__init__(hs)
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
self._cache_id_gen = SlavedIdTracker(
|
||||
db_conn, "cache_invalidation_stream", "stream_id",
|
||||
)
|
||||
else:
|
||||
self._cache_id_gen = None
|
||||
|
||||
def stream_positions(self):
|
||||
return {}
|
||||
pos = {}
|
||||
if self._cache_id_gen:
|
||||
pos["caches"] = self._cache_id_gen.get_current_token()
|
||||
return pos
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("caches")
|
||||
if stream:
|
||||
for row in stream["rows"]:
|
||||
(
|
||||
position, cache_func, keys, invalidation_ts,
|
||||
) = row
|
||||
|
||||
try:
|
||||
getattr(self, cache_func).invalidate(tuple(keys))
|
||||
except AttributeError:
|
||||
logger.info("Got unexpected cache_func: %r", cache_func)
|
||||
self._cache_id_gen.advance(int(stream["position"]))
|
||||
return defer.succeed(None)
|
||||
|
|
|
@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore
|
|||
class DirectoryStore(BaseSlavedStore):
|
||||
get_aliases_for_room = DirectoryStore.__dict__[
|
||||
"get_aliases_for_room"
|
||||
].orig
|
||||
]
|
||||
|
|
|
@ -25,6 +25,6 @@ class SlavedRegistrationStore(BaseSlavedStore):
|
|||
# TODO: use the cached version and invalidate deleted tokens
|
||||
get_user_by_access_token = RegistrationStore.__dict__[
|
||||
"get_user_by_access_token"
|
||||
].orig
|
||||
]
|
||||
|
||||
_query_for_auth = DataStore._query_for_auth.__func__
|
||||
|
|
|
@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
|
|||
class WhoisRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(WhoisRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
target_user = UserID.from_string(user_id)
|
||||
|
@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
|
|||
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PurgeHistoryRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
|
|
@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet):
|
|||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
self.hs = hs
|
||||
self.handlers = hs.get_handlers()
|
||||
self.builder_factory = hs.get_event_builder_factory()
|
||||
self.auth = hs.get_v1auth()
|
||||
self.txns = HttpTransactionStore()
|
||||
|
|
|
@ -36,6 +36,10 @@ def register_servlets(hs, http_server):
|
|||
class ClientDirectoryServer(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ClientDirectoryServer, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_alias):
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
|
@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(ClientDirectoryListServer, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
|
|
|
@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
|
||||
DEFAULT_LONGPOLL_TIME_MS = 30000
|
||||
|
||||
def __init__(self, hs):
|
||||
super(EventStreamRestServlet, self).__init__(hs)
|
||||
self.event_stream_handler = hs.get_event_stream_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
|
@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
if "room_id" in request.args:
|
||||
room_id = request.args["room_id"][0]
|
||||
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if "timeout" in request.args:
|
||||
|
@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
|
||||
as_client_event = "raw" not in request.args
|
||||
|
||||
chunk = yield handler.get_stream(
|
||||
chunk = yield self.event_stream_handler.get_stream(
|
||||
requester.user.to_string(),
|
||||
pagin_config,
|
||||
timeout=timeout,
|
||||
|
@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(EventRestServlet, self).__init__(hs)
|
||||
self.clock = hs.get_clock()
|
||||
self.event_handler = hs.get_event_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
handler = self.handlers.event_handler
|
||||
event = yield handler.get_event(requester.user, event_id)
|
||||
event = yield self.event_handler.get_event(requester.user, event_id)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
if event:
|
||||
|
|
|
@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns
|
|||
class InitialSyncRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/initialSync$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(InitialSyncRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
|
|
@ -56,6 +56,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
self.cas_enabled = hs.config.cas_enabled
|
||||
self.auth_handler = self.hs.get_auth_handler()
|
||||
self.device_handler = self.hs.get_device_handler()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def on_GET(self, request):
|
||||
flows = []
|
||||
|
@ -260,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(SAML2RestServlet, self).__init__(hs)
|
||||
self.sp_config = hs.config.saml2_config_path
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
@ -329,6 +331,7 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||
self.cas_service_url = hs.config.cas_service_url
|
||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
|
|
|
@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request
|
|||
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileDisplaynameRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
|||
class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileAvatarURLRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
|||
class ProfileRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
|
|
@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
self.sessions = {}
|
||||
self.enable_registration = hs.config.enable_registration
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def on_GET(self, request):
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
|
@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||
super(CreateUserRestServlet, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
|
|
@ -35,6 +35,10 @@ logger = logging.getLogger(__name__)
|
|||
class RoomCreateRestServlet(ClientV1RestServlet):
|
||||
# No PATTERN; we have custom dispatch rules here
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomCreateRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = "/createRoom"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
|
|||
|
||||
# TODO: Needs unit testing for generic events
|
||||
class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(RoomStateEventRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
# /room/$roomid/state/$eventtype
|
||||
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
|
||||
|
@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
# TODO: Needs unit testing for generic events + feedback
|
||||
class RoomSendEventRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomSendEventRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
# /rooms/$roomid/send/$event_type[/$txn_id]
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
|
||||
|
@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
# TODO: Needs unit testing for room ID + alias joins
|
||||
class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(JoinRoomAliasServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
# /join/$room_identifier[/$txn_id]
|
||||
|
@ -296,6 +311,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
|||
class RoomMemberListRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomMemberListRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
# TODO support Pagination stream API (limit/tokens)
|
||||
|
@ -322,6 +341,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
|
|||
class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomMessageListRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
@ -351,6 +374,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
|||
class RoomStateRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomStateRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
@ -368,6 +395,10 @@ class RoomStateRestServlet(ClientV1RestServlet):
|
|||
class RoomInitialSyncRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomInitialSyncRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
@ -388,6 +419,7 @@ class RoomEventContext(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(RoomEventContext, self).__init__(hs)
|
||||
self.clock = hs.get_clock()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_id):
|
||||
|
@ -424,6 +456,10 @@ class RoomEventContext(ClientV1RestServlet):
|
|||
|
||||
|
||||
class RoomForgetRestServlet(ClientV1RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(RoomForgetRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
@ -462,6 +498,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
|
|||
# TODO: Needs unit testing
|
||||
class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomMembershipRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
# /rooms/$roomid/[invite|join|leave]
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
|
||||
|
@ -542,6 +582,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
|
||||
|
||||
class RoomRedactEventRestServlet(ClientV1RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(RoomRedactEventRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
@ -624,6 +668,10 @@ class SearchRestServlet(ClientV1RestServlet):
|
|||
"/search$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(SearchRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
|
|
@ -45,6 +45,7 @@ class DownloadResource(Resource):
|
|||
@request_handler()
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
request.setHeader("Content-Security-Policy", "sandbox")
|
||||
server_name, media_id, name = parse_media_id(request)
|
||||
if server_name == self.server_name:
|
||||
yield self._respond_local_file(request, media_id, name)
|
||||
|
|
|
@ -29,14 +29,13 @@ from synapse.http.server import (
|
|||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.stringutils import is_ascii
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import os
|
||||
import re
|
||||
import fnmatch
|
||||
import cgi
|
||||
import ujson as json
|
||||
import urlparse
|
||||
import itertools
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -163,7 +162,7 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
logger.debug("got media_info of '%s'" % media_info)
|
||||
|
||||
if self._is_media(media_info['media_type']):
|
||||
if _is_media(media_info['media_type']):
|
||||
dims = yield self.media_repo._generate_local_thumbnails(
|
||||
media_info['filesystem_id'], media_info
|
||||
)
|
||||
|
@ -184,11 +183,9 @@ class PreviewUrlResource(Resource):
|
|||
logger.warn("Couldn't get dims for %s" % url)
|
||||
|
||||
# define our OG response for this media
|
||||
elif self._is_html(media_info['media_type']):
|
||||
elif _is_html(media_info['media_type']):
|
||||
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
|
||||
|
||||
from lxml import etree
|
||||
|
||||
file = open(media_info['filename'])
|
||||
body = file.read()
|
||||
file.close()
|
||||
|
@ -199,17 +196,35 @@ class PreviewUrlResource(Resource):
|
|||
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
|
||||
encoding = match.group(1) if match else "utf-8"
|
||||
|
||||
try:
|
||||
parser = etree.HTMLParser(recover=True, encoding=encoding)
|
||||
tree = etree.fromstring(body, parser)
|
||||
og = yield self._calc_og(tree, media_info, requester)
|
||||
except UnicodeDecodeError:
|
||||
# blindly try decoding the body as utf-8, which seems to fix
|
||||
# the charset mismatches on https://google.com
|
||||
parser = etree.HTMLParser(recover=True, encoding=encoding)
|
||||
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
|
||||
og = yield self._calc_og(tree, media_info, requester)
|
||||
og = decode_and_calc_og(body, media_info['uri'], encoding)
|
||||
|
||||
# pre-cache the image for posterity
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
if 'og:image' in og and og['og:image']:
|
||||
image_info = yield self._download_url(
|
||||
_rebase_url(og['og:image'], media_info['uri']), requester.user
|
||||
)
|
||||
|
||||
if _is_media(image_info['media_type']):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
dims = yield self.media_repo._generate_local_thumbnails(
|
||||
image_info['filesystem_id'], image_info
|
||||
)
|
||||
if dims:
|
||||
og["og:image:width"] = dims['width']
|
||||
og["og:image:height"] = dims['height']
|
||||
else:
|
||||
logger.warn("Couldn't get dims for %s" % og["og:image"])
|
||||
|
||||
og["og:image"] = "mxc://%s/%s" % (
|
||||
self.server_name, image_info['filesystem_id']
|
||||
)
|
||||
og["og:image:type"] = image_info['media_type']
|
||||
og["matrix:image:size"] = image_info['media_length']
|
||||
else:
|
||||
del og["og:image"]
|
||||
else:
|
||||
logger.warn("Failed to find any OG data in %s", url)
|
||||
og = {}
|
||||
|
@ -232,139 +247,6 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _calc_og(self, tree, media_info, requester):
|
||||
# suck our tree into lxml and define our OG response.
|
||||
|
||||
# if we see any image URLs in the OG response, then spider them
|
||||
# (although the client could choose to do this by asking for previews of those
|
||||
# URLs to avoid DoSing the server)
|
||||
|
||||
# "og:type" : "video",
|
||||
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
|
||||
# "og:site_name" : "YouTube",
|
||||
# "og:video:type" : "application/x-shockwave-flash",
|
||||
# "og:description" : "Fun stuff happening here",
|
||||
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
|
||||
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
|
||||
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
|
||||
# "og:video:width" : "1280"
|
||||
# "og:video:height" : "720",
|
||||
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
||||
|
||||
og = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
if 'content' in tag.attrib:
|
||||
og[tag.attrib['property']] = tag.attrib['content']
|
||||
|
||||
# TODO: grab article: meta tags too, e.g.:
|
||||
|
||||
# "article:publisher" : "https://www.facebook.com/thethudonline" />
|
||||
# "article:author" content="https://www.facebook.com/thethudonline" />
|
||||
# "article:tag" content="baby" />
|
||||
# "article:section" content="Breaking News" />
|
||||
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
|
||||
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
|
||||
|
||||
if 'og:title' not in og:
|
||||
# do some basic spidering of the HTML
|
||||
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
||||
og['og:title'] = title[0].text.strip() if title else None
|
||||
|
||||
if 'og:image' not in og:
|
||||
# TODO: extract a favicon failing all else
|
||||
meta_image = tree.xpath(
|
||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
||||
)
|
||||
if meta_image:
|
||||
og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
|
||||
else:
|
||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = sorted(images, key=lambda i: (
|
||||
-1 * float(i.attrib['width']) * float(i.attrib['height'])
|
||||
))
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src]")
|
||||
if images:
|
||||
og['og:image'] = images[0].attrib['src']
|
||||
|
||||
# pre-cache the image for posterity
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
if 'og:image' in og and og['og:image']:
|
||||
image_info = yield self._download_url(
|
||||
self._rebase_url(og['og:image'], media_info['uri']), requester.user
|
||||
)
|
||||
|
||||
if self._is_media(image_info['media_type']):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
dims = yield self.media_repo._generate_local_thumbnails(
|
||||
image_info['filesystem_id'], image_info
|
||||
)
|
||||
if dims:
|
||||
og["og:image:width"] = dims['width']
|
||||
og["og:image:height"] = dims['height']
|
||||
else:
|
||||
logger.warn("Couldn't get dims for %s" % og["og:image"])
|
||||
|
||||
og["og:image"] = "mxc://%s/%s" % (
|
||||
self.server_name, image_info['filesystem_id']
|
||||
)
|
||||
og["og:image:type"] = image_info['media_type']
|
||||
og["matrix:image:size"] = image_info['media_length']
|
||||
else:
|
||||
del og["og:image"]
|
||||
|
||||
if 'og:description' not in og:
|
||||
meta_description = tree.xpath(
|
||||
"//*/meta"
|
||||
"[translate(@name, 'DESCRIPTION', 'description')='description']"
|
||||
"/@content")
|
||||
if meta_description:
|
||||
og['og:description'] = meta_description[0]
|
||||
else:
|
||||
# grab any text nodes which are inside the <body/> tag...
|
||||
# unless they are within an HTML5 semantic markup tag...
|
||||
# <header/>, <nav/>, <aside/>, <footer/>
|
||||
# ...or if they are within a <script/> or <style/> tag.
|
||||
# This is a very very very coarse approximation to a plain text
|
||||
# render of the page.
|
||||
|
||||
# We don't just use XPATH here as that is slow on some machines.
|
||||
|
||||
# We clone `tree` as we modify it.
|
||||
cloned_tree = deepcopy(tree.find("body"))
|
||||
|
||||
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
|
||||
for el in cloned_tree.iter(TAGS_TO_REMOVE):
|
||||
el.getparent().remove(el)
|
||||
|
||||
# Split all the text nodes into paragraphs (by splitting on new
|
||||
# lines)
|
||||
text_nodes = (
|
||||
re.sub(r'\s+', '\n', el.text).strip()
|
||||
for el in cloned_tree.iter()
|
||||
if el.text and isinstance(el.tag, basestring) # Removes comments
|
||||
)
|
||||
og['og:description'] = summarize_paragraphs(text_nodes)
|
||||
|
||||
# TODO: delete the url downloads to stop diskfilling,
|
||||
# as we only ever cared about its OG
|
||||
defer.returnValue(og)
|
||||
|
||||
def _rebase_url(self, url, base):
|
||||
base = list(urlparse.urlparse(base))
|
||||
url = list(urlparse.urlparse(url))
|
||||
if not url[0]: # fix up schema
|
||||
url[0] = base[0] or "http"
|
||||
if not url[1]: # fix up hostname
|
||||
url[1] = base[1]
|
||||
if not url[2].startswith('/'):
|
||||
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
|
||||
return urlparse.urlunparse(url)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _download_url(self, url, user):
|
||||
# TODO: we should probably honour robots.txt... except in practice
|
||||
|
@ -445,17 +327,171 @@ class PreviewUrlResource(Resource):
|
|||
"etag": headers["ETag"][0] if "ETag" in headers else None,
|
||||
})
|
||||
|
||||
def _is_media(self, content_type):
|
||||
if content_type.lower().startswith("image/"):
|
||||
return True
|
||||
|
||||
def _is_html(self, content_type):
|
||||
content_type = content_type.lower()
|
||||
if (
|
||||
content_type.startswith("text/html") or
|
||||
content_type.startswith("application/xhtml")
|
||||
):
|
||||
return True
|
||||
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
||||
from lxml import etree
|
||||
|
||||
try:
|
||||
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
||||
tree = etree.fromstring(body, parser)
|
||||
og = _calc_og(tree, media_uri)
|
||||
except UnicodeDecodeError:
|
||||
# blindly try decoding the body as utf-8, which seems to fix
|
||||
# the charset mismatches on https://google.com
|
||||
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
||||
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
|
||||
og = _calc_og(tree, media_uri)
|
||||
|
||||
return og
|
||||
|
||||
|
||||
def _calc_og(tree, media_uri):
|
||||
# suck our tree into lxml and define our OG response.
|
||||
|
||||
# if we see any image URLs in the OG response, then spider them
|
||||
# (although the client could choose to do this by asking for previews of those
|
||||
# URLs to avoid DoSing the server)
|
||||
|
||||
# "og:type" : "video",
|
||||
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
|
||||
# "og:site_name" : "YouTube",
|
||||
# "og:video:type" : "application/x-shockwave-flash",
|
||||
# "og:description" : "Fun stuff happening here",
|
||||
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
|
||||
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
|
||||
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
|
||||
# "og:video:width" : "1280"
|
||||
# "og:video:height" : "720",
|
||||
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
||||
|
||||
og = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
if 'content' in tag.attrib:
|
||||
og[tag.attrib['property']] = tag.attrib['content']
|
||||
|
||||
# TODO: grab article: meta tags too, e.g.:
|
||||
|
||||
# "article:publisher" : "https://www.facebook.com/thethudonline" />
|
||||
# "article:author" content="https://www.facebook.com/thethudonline" />
|
||||
# "article:tag" content="baby" />
|
||||
# "article:section" content="Breaking News" />
|
||||
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
|
||||
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
|
||||
|
||||
if 'og:title' not in og:
|
||||
# do some basic spidering of the HTML
|
||||
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
||||
og['og:title'] = title[0].text.strip() if title else None
|
||||
|
||||
if 'og:image' not in og:
|
||||
# TODO: extract a favicon failing all else
|
||||
meta_image = tree.xpath(
|
||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
||||
)
|
||||
if meta_image:
|
||||
og['og:image'] = _rebase_url(meta_image[0], media_uri)
|
||||
else:
|
||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = sorted(images, key=lambda i: (
|
||||
-1 * float(i.attrib['width']) * float(i.attrib['height'])
|
||||
))
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src]")
|
||||
if images:
|
||||
og['og:image'] = images[0].attrib['src']
|
||||
|
||||
if 'og:description' not in og:
|
||||
meta_description = tree.xpath(
|
||||
"//*/meta"
|
||||
"[translate(@name, 'DESCRIPTION', 'description')='description']"
|
||||
"/@content")
|
||||
if meta_description:
|
||||
og['og:description'] = meta_description[0]
|
||||
else:
|
||||
# grab any text nodes which are inside the <body/> tag...
|
||||
# unless they are within an HTML5 semantic markup tag...
|
||||
# <header/>, <nav/>, <aside/>, <footer/>
|
||||
# ...or if they are within a <script/> or <style/> tag.
|
||||
# This is a very very very coarse approximation to a plain text
|
||||
# render of the page.
|
||||
|
||||
# We don't just use XPATH here as that is slow on some machines.
|
||||
|
||||
from lxml import etree
|
||||
|
||||
TAGS_TO_REMOVE = (
|
||||
"header", "nav", "aside", "footer", "script", "style", etree.Comment
|
||||
)
|
||||
|
||||
# Split all the text nodes into paragraphs (by splitting on new
|
||||
# lines)
|
||||
text_nodes = (
|
||||
re.sub(r'\s+', '\n', el).strip()
|
||||
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
|
||||
)
|
||||
og['og:description'] = summarize_paragraphs(text_nodes)
|
||||
|
||||
# TODO: delete the url downloads to stop diskfilling,
|
||||
# as we only ever cared about its OG
|
||||
return og
|
||||
|
||||
|
||||
def _iterate_over_text(tree, *tags_to_ignore):
|
||||
"""Iterate over the tree returning text nodes in a depth first fashion,
|
||||
skipping text nodes inside certain tags.
|
||||
"""
|
||||
# This is basically a stack that we extend using itertools.chain.
|
||||
# This will either consist of an element to iterate over *or* a string
|
||||
# to be returned.
|
||||
elements = iter([tree])
|
||||
while True:
|
||||
el = elements.next()
|
||||
if isinstance(el, basestring):
|
||||
yield el
|
||||
elif el is not None and el.tag not in tags_to_ignore:
|
||||
# el.text is the text before the first child, so we can immediately
|
||||
# return it if the text exists.
|
||||
if el.text:
|
||||
yield el.text
|
||||
|
||||
# We add to the stack all the elements children, interspersed with
|
||||
# each child's tail text (if it exists). The tail text of a node
|
||||
# is text that comes *after* the node, so we always include it even
|
||||
# if we ignore the child node.
|
||||
elements = itertools.chain(
|
||||
itertools.chain.from_iterable( # Basically a flatmap
|
||||
[child, child.tail] if child.tail else [child]
|
||||
for child in el.iterchildren()
|
||||
),
|
||||
elements
|
||||
)
|
||||
|
||||
|
||||
def _rebase_url(url, base):
|
||||
base = list(urlparse.urlparse(base))
|
||||
url = list(urlparse.urlparse(url))
|
||||
if not url[0]: # fix up schema
|
||||
url[0] = base[0] or "http"
|
||||
if not url[1]: # fix up hostname
|
||||
url[1] = base[1]
|
||||
if not url[2].startswith('/'):
|
||||
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
|
||||
return urlparse.urlunparse(url)
|
||||
|
||||
|
||||
def _is_media(content_type):
|
||||
if content_type.lower().startswith("image/"):
|
||||
return True
|
||||
|
||||
|
||||
def _is_html(content_type):
|
||||
content_type = content_type.lower()
|
||||
if (
|
||||
content_type.startswith("text/html") or
|
||||
content_type.startswith("application/xhtml")
|
||||
):
|
||||
return True
|
||||
|
||||
|
||||
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
|
||||
|
|
|
@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler
|
|||
from synapse.handlers.room import RoomListHandler
|
||||
from synapse.handlers.sync import SyncHandler
|
||||
from synapse.handlers.typing import TypingHandler
|
||||
from synapse.handlers.events import EventHandler, EventStreamHandler
|
||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.notifier import Notifier
|
||||
|
@ -94,6 +95,8 @@ class HomeServer(object):
|
|||
'auth_handler',
|
||||
'device_handler',
|
||||
'e2e_keys_handler',
|
||||
'event_handler',
|
||||
'event_stream_handler',
|
||||
'application_service_api',
|
||||
'application_service_scheduler',
|
||||
'application_service_handler',
|
||||
|
@ -214,6 +217,12 @@ class HomeServer(object):
|
|||
def build_application_service_handler(self):
|
||||
return ApplicationServicesHandler(self)
|
||||
|
||||
def build_event_handler(self):
|
||||
return EventHandler(self)
|
||||
|
||||
def build_event_stream_handler(self):
|
||||
return EventStreamHandler(self)
|
||||
|
||||
def build_event_sources(self):
|
||||
return EventSources(self)
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ from .openid import OpenIdStore
|
|||
from .client_ips import ClientIpStore
|
||||
|
||||
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
|
||||
from .engines import PostgresEngine
|
||||
|
||||
from synapse.api.constants import PresenceState
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
extra_tables=[("deleted_pushers", "stream_id")],
|
||||
)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
self._cache_id_gen = StreamIdGenerator(
|
||||
db_conn, "cache_invalidation_stream", "stream_id",
|
||||
)
|
||||
else:
|
||||
self._cache_id_gen = None
|
||||
|
||||
events_max = self._stream_id_gen.get_current_token()
|
||||
event_cache_prefill, min_event_val = self._get_cache_dict(
|
||||
db_conn, "events",
|
||||
|
|
|
@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
|||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||
from synapse.util.caches.descriptors import Cache
|
||||
from synapse.util.caches import intern_dict
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
import synapse.metrics
|
||||
|
||||
|
||||
|
@ -305,13 +306,14 @@ class SQLBaseStore(object):
|
|||
func, *args, **kwargs
|
||||
)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
result = yield self._db_pool.runWithConnection(
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
|
||||
for after_callback, after_args in after_callbacks:
|
||||
after_callback(*after_args)
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
result = yield self._db_pool.runWithConnection(
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
finally:
|
||||
for after_callback, after_args in after_callbacks:
|
||||
after_callback(*after_args)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -860,6 +862,62 @@ class SQLBaseStore(object):
|
|||
|
||||
return cache, min_val
|
||||
|
||||
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||
will know to invalidate their caches.
|
||||
|
||||
This should only be used to invalidate caches where slaves won't
|
||||
otherwise know from other replication streams that the cache should
|
||||
be invalidated.
|
||||
"""
|
||||
txn.call_after(cache_func.invalidate, keys)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# get_next() returns a context manager which is designed to wrap
|
||||
# the transaction. However, we want to only get an ID when we want
|
||||
# to use it, here, so we need to call __enter__ manually, and have
|
||||
# __exit__ called after the transaction finishes.
|
||||
ctx = self._cache_id_gen.get_next()
|
||||
stream_id = ctx.__enter__()
|
||||
txn.call_after(ctx.__exit__, None, None, None)
|
||||
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="cache_invalidation_stream",
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"cache_func": cache_func.__name__,
|
||||
"keys": list(keys),
|
||||
"invalidation_ts": self.clock.time_msec(),
|
||||
}
|
||||
)
|
||||
|
||||
def get_all_updated_caches(self, last_id, current_id, limit):
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_updated_caches_txn(txn):
|
||||
# We purposefully don't bound by the current token, as we want to
|
||||
# send across cache invalidations as quickly as possible. Cache
|
||||
# invalidations are idempotent, so duplicates are fine.
|
||||
sql = (
|
||||
"SELECT stream_id, cache_func, keys, invalidation_ts"
|
||||
" FROM cache_invalidation_stream"
|
||||
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, limit,))
|
||||
return txn.fetchall()
|
||||
return self.runInteraction(
|
||||
"get_all_updated_caches", get_all_updated_caches_txn
|
||||
)
|
||||
|
||||
def get_cache_stream_token(self):
|
||||
if self._cache_id_gen:
|
||||
return self._cache_id_gen.get_current_token()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
class _RollbackButIsFineException(Exception):
|
||||
""" This exception is used to rollback a transaction without implying
|
||||
|
|
|
@ -352,3 +352,42 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||
return 0
|
||||
else:
|
||||
return int(last_txn_id[0]) # select 'last_txn' col
|
||||
|
||||
def set_appservice_last_pos(self, pos):
|
||||
def set_appservice_last_pos_txn(txn):
|
||||
txn.execute(
|
||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||
)
|
||||
return self.runInteraction(
|
||||
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_new_events_for_appservice(self, current_id, limit):
|
||||
"""Get all new evnets"""
|
||||
|
||||
def get_new_events_for_appservice_txn(txn):
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id"
|
||||
" FROM events AS e, appservice_stream_position AS a"
|
||||
" WHERE a.stream_ordering < e.stream_ordering AND e.stream_ordering <= ?"
|
||||
" ORDER BY e.stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (current_id, limit))
|
||||
rows = txn.fetchall()
|
||||
|
||||
upper_bound = current_id
|
||||
if len(rows) == limit:
|
||||
upper_bound = rows[-1][0]
|
||||
|
||||
return upper_bound, [row[1] for row in rows]
|
||||
|
||||
upper_bound, event_ids = yield self.runInteraction(
|
||||
"get_new_events_for_appservice", get_new_events_for_appservice_txn,
|
||||
)
|
||||
|
||||
events = yield self._get_events(event_ids)
|
||||
|
||||
defer.returnValue((upper_bound, events))
|
||||
|
|
|
@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore):
|
|||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
try:
|
||||
yield self._simple_insert(
|
||||
def alias_txn(txn):
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"room_aliases",
|
||||
{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"room_id": room_id,
|
||||
"creator": creator,
|
||||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="room_alias_servers",
|
||||
values=[{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"server": server,
|
||||
} for server in servers],
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_aliases_for_room, (room_id,)
|
||||
)
|
||||
|
||||
try:
|
||||
ret = yield self.runInteraction(
|
||||
"create_room_alias_association", alias_txn
|
||||
)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
raise SynapseError(
|
||||
409, "Room alias %s already exists" % room_alias.to_string()
|
||||
)
|
||||
|
||||
for server in servers:
|
||||
# TODO(erikj): Fix this to bulk insert
|
||||
yield self._simple_insert(
|
||||
"room_alias_servers",
|
||||
{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"server": server,
|
||||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
self.get_aliases_for_room.invalidate((room_id,))
|
||||
defer.returnValue(ret)
|
||||
|
||||
def get_room_alias_creator(self, room_alias):
|
||||
return self._simple_select_one_onecol(
|
||||
|
|
|
@ -600,7 +600,8 @@ class EventsStore(SQLBaseStore):
|
|||
"rejections",
|
||||
"redactions",
|
||||
"room_memberships",
|
||||
"state_events"
|
||||
"state_events",
|
||||
"topics"
|
||||
):
|
||||
txn.executemany(
|
||||
"DELETE FROM %s WHERE event_id = ?" % (table,),
|
||||
|
|
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 33
|
||||
SCHEMA_VERSION = 34
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
|
|
@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore):
|
|||
desc="add_presence_list_pending",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
||||
result = yield self._simple_update_one(
|
||||
table="presence_list",
|
||||
keyvalues={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid},
|
||||
updatevalues={"accepted": True},
|
||||
desc="set_presence_list_accepted",
|
||||
def update_presence_list_txn(txn):
|
||||
result = self._simple_update_one_txn(
|
||||
txn,
|
||||
table="presence_list",
|
||||
keyvalues={
|
||||
"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid
|
||||
},
|
||||
updatevalues={"accepted": True},
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_presence_list_accepted, (observer_localpart,)
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_presence_list_observers_accepted, (observed_userid,)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
return self.runInteraction(
|
||||
"set_presence_list_accepted", update_presence_list_txn,
|
||||
)
|
||||
self.get_presence_list_accepted.invalidate((observer_localpart,))
|
||||
self.get_presence_list_observers_accepted.invalidate((observed_userid,))
|
||||
defer.returnValue(result)
|
||||
|
||||
def get_presence_list(self, observer_localpart, accepted=None):
|
||||
if accepted:
|
||||
|
|
|
@ -251,7 +251,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
self.get_user_by_id.invalidate((user_id,))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_delete_access_tokens(self, user_id, except_token_ids=[],
|
||||
def user_delete_access_tokens(self, user_id, except_token_id=None,
|
||||
device_id=None,
|
||||
delete_refresh_tokens=False):
|
||||
"""
|
||||
|
@ -259,7 +259,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
|
||||
Args:
|
||||
user_id (str): ID of user the tokens belong to
|
||||
except_token_ids (list[str]): list of access_tokens which should
|
||||
except_token_id (str): list of access_tokens IDs which should
|
||||
*not* be deleted
|
||||
device_id (str|None): ID of device the tokens are associated with.
|
||||
If None, tokens associated with any device (or no device) will
|
||||
|
@ -269,53 +269,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
def f(txn, table, except_tokens, call_after_delete):
|
||||
sql = "SELECT token FROM %s WHERE user_id = ?" % table
|
||||
clauses = [user_id]
|
||||
|
||||
def f(txn):
|
||||
keyvalues = {
|
||||
"user_id": user_id,
|
||||
}
|
||||
if device_id is not None:
|
||||
sql += " AND device_id = ?"
|
||||
clauses.append(device_id)
|
||||
keyvalues["device_id"] = device_id
|
||||
|
||||
if except_tokens:
|
||||
sql += " AND id NOT IN (%s)" % (
|
||||
",".join(["?" for _ in except_tokens]),
|
||||
)
|
||||
clauses += except_tokens
|
||||
|
||||
txn.execute(sql, clauses)
|
||||
|
||||
rows = txn.fetchall()
|
||||
|
||||
n = 100
|
||||
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
|
||||
for chunk in chunks:
|
||||
if call_after_delete:
|
||||
for row in chunk:
|
||||
txn.call_after(call_after_delete, (row[0],))
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM %s WHERE token in (%s)" % (
|
||||
table,
|
||||
",".join(["?" for _ in chunk]),
|
||||
), [r[0] for r in chunk]
|
||||
if delete_refresh_tokens:
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="refresh_tokens",
|
||||
keyvalues=keyvalues,
|
||||
)
|
||||
|
||||
# delete refresh tokens first, to stop new access tokens being
|
||||
# allocated while our backs are turned
|
||||
if delete_refresh_tokens:
|
||||
yield self.runInteraction(
|
||||
"user_delete_access_tokens", f,
|
||||
table="refresh_tokens",
|
||||
except_tokens=[],
|
||||
call_after_delete=None,
|
||||
items = keyvalues.items()
|
||||
where_clause = " AND ".join(k + " = ?" for k, _ in items)
|
||||
values = [v for _, v in items]
|
||||
if except_token_id:
|
||||
where_clause += " AND id != ?"
|
||||
values.append(except_token_id)
|
||||
|
||||
txn.execute(
|
||||
"SELECT token FROM access_tokens WHERE %s" % where_clause,
|
||||
values
|
||||
)
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
for row in rows:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_access_token, (row["token"],)
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM access_tokens WHERE %s" % where_clause,
|
||||
values
|
||||
)
|
||||
|
||||
yield self.runInteraction(
|
||||
"user_delete_access_tokens", f,
|
||||
table="access_tokens",
|
||||
except_tokens=except_token_ids,
|
||||
call_after_delete=self.get_user_by_access_token.invalidate,
|
||||
)
|
||||
|
||||
def delete_access_token(self, access_token):
|
||||
|
@ -328,7 +320,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
},
|
||||
)
|
||||
|
||||
txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_access_token, (access_token,)
|
||||
)
|
||||
|
||||
return self.runInteraction("delete_access_token", f)
|
||||
|
||||
|
|
|
@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore):
|
|||
user_id, membership_list=[Membership.JOIN],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def forget(self, user_id, room_id):
|
||||
"""Indicate that user_id wishes to discard history for room_id."""
|
||||
def f(txn):
|
||||
|
@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore):
|
|||
" room_id = ?"
|
||||
)
|
||||
txn.execute(sql, (user_id, room_id))
|
||||
yield self.runInteraction("forget_membership", f)
|
||||
self.was_forgotten_at.invalidate_all()
|
||||
self.who_forgot_in_room.invalidate_all()
|
||||
self.did_forget.invalidate((user_id, room_id))
|
||||
|
||||
txn.call_after(self.was_forgotten_at.invalidate_all)
|
||||
txn.call_after(self.did_forget.invalidate, (user_id, room_id))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.who_forgot_in_room, (room_id,)
|
||||
)
|
||||
return self.runInteraction("forget_membership", f)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2)
|
||||
def did_forget(self, user_id, room_id):
|
||||
|
|
23
synapse/storage/schema/delta/34/appservice_stream.sql
Normal file
23
synapse/storage/schema/delta/34/appservice_stream.sql
Normal file
|
@ -0,0 +1,23 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS appservice_stream_position(
|
||||
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
|
||||
stream_ordering BIGINT,
|
||||
CHECK (Lock='X')
|
||||
);
|
||||
|
||||
INSERT INTO appservice_stream_position (stream_ordering)
|
||||
SELECT COALESCE(MAX(stream_ordering), 0) FROM events;
|
46
synapse/storage/schema/delta/34/cache_stream.py
Normal file
46
synapse/storage/schema/delta/34/cache_stream.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.prepare_database import get_statements
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# This stream is used to notify replication slaves that some caches have
|
||||
# been invalidated that they cannot infer from the other streams.
|
||||
CREATE_TABLE = """
|
||||
CREATE TABLE cache_invalidation_stream (
|
||||
stream_id BIGINT,
|
||||
cache_func TEXT,
|
||||
keys TEXT[],
|
||||
invalidation_ts BIGINT
|
||||
);
|
||||
|
||||
CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
|
||||
"""
|
||||
|
||||
|
||||
def run_create(cur, database_engine, *args, **kwargs):
|
||||
if not isinstance(database_engine, PostgresEngine):
|
||||
return
|
||||
|
||||
for statement in get_statements(CREATE_TABLE.splitlines()):
|
||||
cur.execute(statement)
|
||||
|
||||
|
||||
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||
pass
|
20
synapse/storage/schema/delta/34/push_display_name_rename.sql
Normal file
20
synapse/storage/schema/delta/34/push_display_name_rename.sql
Normal file
|
@ -0,0 +1,20 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name';
|
||||
UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
|
||||
|
||||
DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name';
|
||||
UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
|
|
@ -14,6 +14,8 @@
|
|||
# limitations under the License.
|
||||
from synapse.appservice import ApplicationService
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from mock import Mock
|
||||
from tests import unittest
|
||||
|
||||
|
@ -42,20 +44,25 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
|
||||
)
|
||||
|
||||
self.store = Mock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_user_id_prefix_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.assertTrue(self.service.is_interested(self.event))
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_user_id_prefix_no_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@someone_else:matrix.org"
|
||||
self.assertFalse(self.service.is_interested(self.event))
|
||||
self.assertFalse((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_room_member_is_checked(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
|
@ -63,30 +70,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
self.event.sender = "@someone_else:matrix.org"
|
||||
self.event.type = "m.room.member"
|
||||
self.event.state_key = "@irc_foobar:matrix.org"
|
||||
self.assertTrue(self.service.is_interested(self.event))
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_room_id_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||
_regex("!some_prefix.*some_suffix:matrix.org")
|
||||
)
|
||||
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
|
||||
self.assertTrue(self.service.is_interested(self.event))
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_room_id_no_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||
_regex("!some_prefix.*some_suffix:matrix.org")
|
||||
)
|
||||
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
|
||||
self.assertFalse(self.service.is_interested(self.event))
|
||||
self.assertFalse((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_alias_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#irc_.*:matrix.org")
|
||||
)
|
||||
self.assertTrue(self.service.is_interested(
|
||||
self.event,
|
||||
aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"]
|
||||
))
|
||||
self.store.get_aliases_for_room.return_value = [
|
||||
"#irc_foobar:matrix.org", "#athing:matrix.org"
|
||||
]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertTrue((yield self.service.is_interested(
|
||||
self.event, self.store
|
||||
)))
|
||||
|
||||
def test_non_exclusive_alias(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
|
@ -136,15 +149,20 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
"!irc_foobar:matrix.org"
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_alias_no_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#irc_.*:matrix.org")
|
||||
)
|
||||
self.assertFalse(self.service.is_interested(
|
||||
self.event,
|
||||
aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
|
||||
))
|
||||
self.store.get_aliases_for_room.return_value = [
|
||||
"#xmpp_foobar:matrix.org", "#athing:matrix.org"
|
||||
]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertFalse((yield self.service.is_interested(
|
||||
self.event, self.store
|
||||
)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_multiple_matches(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#irc_.*:matrix.org")
|
||||
|
@ -153,53 +171,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.assertTrue(self.service.is_interested(
|
||||
self.event,
|
||||
aliases_for_event=["#irc_barfoo:matrix.org"]
|
||||
))
|
||||
|
||||
def test_restrict_to_rooms(self):
|
||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||
_regex("!flibble_.*:matrix.org")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.event.room_id = "!wibblewoo:matrix.org"
|
||||
self.assertFalse(self.service.is_interested(
|
||||
self.event,
|
||||
restrict_to=ApplicationService.NS_ROOMS
|
||||
))
|
||||
|
||||
def test_restrict_to_aliases(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#xmpp_.*:matrix.org")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.assertFalse(self.service.is_interested(
|
||||
self.event,
|
||||
restrict_to=ApplicationService.NS_ALIASES,
|
||||
aliases_for_event=["#irc_barfoo:matrix.org"]
|
||||
))
|
||||
|
||||
def test_restrict_to_senders(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#xmpp_.*:matrix.org")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
||||
self.assertFalse(self.service.is_interested(
|
||||
self.event,
|
||||
restrict_to=ApplicationService.NS_USERS,
|
||||
aliases_for_event=["#xmpp_barfoo:matrix.org"]
|
||||
))
|
||||
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertTrue((yield self.service.is_interested(
|
||||
self.event, self.store
|
||||
)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_interested_in_self(self):
|
||||
# make sure invites get through
|
||||
self.service.sender = "@appservice:name"
|
||||
|
@ -211,20 +189,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
"membership": "invite"
|
||||
}
|
||||
self.event.state_key = self.service.sender
|
||||
self.assertTrue(self.service.is_interested(self.event))
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_member_list_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
join_list = [
|
||||
self.store.get_users_in_room.return_value = [
|
||||
"@alice:here",
|
||||
"@irc_fo:here", # AS user
|
||||
"@bob:here",
|
||||
]
|
||||
self.store.get_aliases_for_room.return_value = []
|
||||
|
||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
||||
self.assertTrue(self.service.is_interested(
|
||||
event=self.event,
|
||||
member_list=join_list
|
||||
))
|
||||
self.assertTrue((yield self.service.is_interested(
|
||||
event=self.event, store=self.store
|
||||
)))
|
||||
|
|
|
@ -193,7 +193,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
self.txn_ctrl = Mock()
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl)
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
|
||||
|
||||
def test_send_single_event_no_queue(self):
|
||||
# Expect the event to be sent immediately.
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
from .. import unittest
|
||||
from tests.utils import MockClock
|
||||
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
|
||||
|
@ -32,6 +33,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
hs.get_datastore = Mock(return_value=self.mock_store)
|
||||
hs.get_application_service_api = Mock(return_value=self.mock_as_api)
|
||||
hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
|
||||
hs.get_clock.return_value = MockClock()
|
||||
self.handler = ApplicationServicesHandler(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -51,8 +53,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
type="m.room.message",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||
self.mock_as_api.push = Mock()
|
||||
yield self.handler.notify_interested_services(event)
|
||||
yield self.handler.notify_interested_services(0)
|
||||
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
|
||||
interested_service, event
|
||||
)
|
||||
|
@ -72,7 +75,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
)
|
||||
self.mock_as_api.push = Mock()
|
||||
self.mock_as_api.query_user = Mock()
|
||||
yield self.handler.notify_interested_services(event)
|
||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||
yield self.handler.notify_interested_services(0)
|
||||
self.mock_as_api.query_user.assert_called_once_with(
|
||||
services[0], user_id
|
||||
)
|
||||
|
@ -94,7 +98,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
)
|
||||
self.mock_as_api.push = Mock()
|
||||
self.mock_as_api.query_user = Mock()
|
||||
yield self.handler.notify_interested_services(event)
|
||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||
yield self.handler.notify_interested_services(0)
|
||||
self.assertFalse(
|
||||
self.mock_as_api.query_user.called,
|
||||
"query_user called when it shouldn't have been."
|
||||
|
@ -108,11 +113,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
|
||||
room_id = "!alpha:bet"
|
||||
servers = ["aperture"]
|
||||
interested_service = self._mkservice(is_interested=True)
|
||||
interested_service = self._mkservice_alias(is_interested_in_alias=True)
|
||||
services = [
|
||||
self._mkservice(is_interested=False),
|
||||
self._mkservice_alias(is_interested_in_alias=False),
|
||||
interested_service,
|
||||
self._mkservice(is_interested=False)
|
||||
self._mkservice_alias(is_interested_in_alias=False)
|
||||
]
|
||||
|
||||
self.mock_store.get_app_services = Mock(return_value=services)
|
||||
|
@ -135,3 +140,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
service.token = "mock_service_token"
|
||||
service.url = "mock_service_url"
|
||||
return service
|
||||
|
||||
def _mkservice_alias(self, is_interested_in_alias):
|
||||
service = Mock()
|
||||
service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
|
||||
service.token = "mock_service_token"
|
||||
service.url = "mock_service_url"
|
||||
return service
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
|
||||
from . import unittest
|
||||
|
||||
from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs
|
||||
from synapse.rest.media.v1.preview_url_resource import (
|
||||
summarize_paragraphs, decode_and_calc_og
|
||||
)
|
||||
|
||||
|
||||
class PreviewTestCase(unittest.TestCase):
|
||||
|
@ -137,3 +139,79 @@ class PreviewTestCase(unittest.TestCase):
|
|||
" of old wooden houses in Northern Norway, the oldest house dating from"
|
||||
" 1789. The Arctic Cathedral, a modern church…"
|
||||
)
|
||||
|
||||
|
||||
class PreviewUrlTestCase(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
html = """
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
Some text.
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {
|
||||
"og:title": "Foo",
|
||||
"og:description": "Some text."
|
||||
})
|
||||
|
||||
def test_comment(self):
|
||||
html = """
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
<!-- HTML comment -->
|
||||
Some text.
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {
|
||||
"og:title": "Foo",
|
||||
"og:description": "Some text."
|
||||
})
|
||||
|
||||
def test_comment2(self):
|
||||
html = """
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
Some text.
|
||||
<!-- HTML comment -->
|
||||
Some more text.
|
||||
<p>Text</p>
|
||||
More text
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {
|
||||
"og:title": "Foo",
|
||||
"og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text"
|
||||
})
|
||||
|
||||
def test_script(self):
|
||||
html = """
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
<script> (function() {})() </script>
|
||||
Some text.
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
|
||||
self.assertEquals(og, {
|
||||
"og:title": "Foo",
|
||||
"og:description": "Some text."
|
||||
})
|
||||
|
|
Loading…
Reference in a new issue