mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-17 07:21:37 +01:00
Merge branch 'release-v0.9.1' of github.com:matrix-org/synapse
This commit is contained in:
commit
6d1dea337b
64 changed files with 2529 additions and 1282 deletions
27
CHANGES.rst
27
CHANGES.rst
|
@ -1,3 +1,30 @@
|
||||||
|
Changes in synapse v0.9.1 (2015-05-26)
|
||||||
|
======================================
|
||||||
|
|
||||||
|
General:
|
||||||
|
|
||||||
|
* Add support for backfilling when a client paginates. This allows servers to
|
||||||
|
request history for a room from remote servers when a client tries to
|
||||||
|
paginate history the server does not have - SYN-36
|
||||||
|
* Fix bug where you couldn't disable non-default pushrules - SYN-378
|
||||||
|
* Fix ``register_new_user`` script - SYN-359
|
||||||
|
* Improve performance of fetching events from the database, this improves both
|
||||||
|
initialSync and sending of events.
|
||||||
|
* Improve performance of event streams, allowing synapse to handle more
|
||||||
|
simultaneous connected clients.
|
||||||
|
|
||||||
|
Federation:
|
||||||
|
|
||||||
|
* Fix bug with existing backfill implementation where it returned the wrong
|
||||||
|
selection of events in some circumstances.
|
||||||
|
* Improve performance of joining remote rooms.
|
||||||
|
|
||||||
|
Configuration:
|
||||||
|
|
||||||
|
* Add support for changing the bind host of the metrics listener via the
|
||||||
|
``metrics_bind_host`` option.
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.9.0-r5 (2015-05-21)
|
Changes in synapse v0.9.0-r5 (2015-05-21)
|
||||||
=========================================
|
=========================================
|
||||||
|
|
||||||
|
|
|
@ -117,7 +117,7 @@ Installing prerequisites on Mac OS X::
|
||||||
|
|
||||||
To install the synapse homeserver run::
|
To install the synapse homeserver run::
|
||||||
|
|
||||||
$ virtualenv ~/.synapse
|
$ virtualenv -p python2.7 ~/.synapse
|
||||||
$ source ~/.synapse/bin/activate
|
$ source ~/.synapse/bin/activate
|
||||||
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
|
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ for port in 8080 8081 8082; do
|
||||||
#rm $DIR/etc/$port.config
|
#rm $DIR/etc/$port.config
|
||||||
python -m synapse.app.homeserver \
|
python -m synapse.app.homeserver \
|
||||||
--generate-config \
|
--generate-config \
|
||||||
|
--enable_registration \
|
||||||
-H "localhost:$https_port" \
|
-H "localhost:$https_port" \
|
||||||
--config-path "$DIR/etc/$port.config" \
|
--config-path "$DIR/etc/$port.config" \
|
||||||
|
|
||||||
|
|
116
scripts-dev/convert_server_keys.py
Normal file
116
scripts-dev/convert_server_keys.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
import psycopg2
|
||||||
|
import yaml
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import hashlib
|
||||||
|
from syutil.base64util import encode_base64
|
||||||
|
from syutil.crypto.signing_key import read_signing_keys
|
||||||
|
from syutil.crypto.jsonsign import sign_json
|
||||||
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
|
||||||
|
|
||||||
|
def select_v1_keys(connection):
|
||||||
|
cursor = connection.cursor()
|
||||||
|
cursor.execute("SELECT server_name, key_id, verify_key FROM server_signature_keys")
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
results = {}
|
||||||
|
for server_name, key_id, verify_key in rows:
|
||||||
|
results.setdefault(server_name, {})[key_id] = encode_base64(verify_key)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def select_v1_certs(connection):
|
||||||
|
cursor = connection.cursor()
|
||||||
|
cursor.execute("SELECT server_name, tls_certificate FROM server_tls_certificates")
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
results = {}
|
||||||
|
for server_name, tls_certificate in rows:
|
||||||
|
results[server_name] = tls_certificate
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def select_v2_json(connection):
|
||||||
|
cursor = connection.cursor()
|
||||||
|
cursor.execute("SELECT server_name, key_id, key_json FROM server_keys_json")
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
results = {}
|
||||||
|
for server_name, key_id, key_json in rows:
|
||||||
|
results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8"))
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def convert_v1_to_v2(server_name, valid_until, keys, certificate):
|
||||||
|
return {
|
||||||
|
"old_verify_keys": {},
|
||||||
|
"server_name": server_name,
|
||||||
|
"verify_keys": {
|
||||||
|
key_id: {"key": key}
|
||||||
|
for key_id, key in keys.items()
|
||||||
|
},
|
||||||
|
"valid_until_ts": valid_until,
|
||||||
|
"tls_fingerprints": [fingerprint(certificate)],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def fingerprint(certificate):
|
||||||
|
finger = hashlib.sha256(certificate)
|
||||||
|
return {"sha256": encode_base64(finger.digest())}
|
||||||
|
|
||||||
|
|
||||||
|
def rows_v2(server, json):
|
||||||
|
valid_until = json["valid_until_ts"]
|
||||||
|
key_json = encode_canonical_json(json)
|
||||||
|
for key_id in json["verify_keys"]:
|
||||||
|
yield (server, key_id, "-", valid_until, valid_until, buffer(key_json))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = yaml.load(open(sys.argv[1]))
|
||||||
|
valid_until = int(time.time() / (3600 * 24)) * 1000 * 3600 * 24
|
||||||
|
|
||||||
|
server_name = config["server_name"]
|
||||||
|
signing_key = read_signing_keys(open(config["signing_key_path"]))[0]
|
||||||
|
|
||||||
|
database = config["database"]
|
||||||
|
assert database["name"] == "psycopg2", "Can only convert for postgresql"
|
||||||
|
args = database["args"]
|
||||||
|
args.pop("cp_max")
|
||||||
|
args.pop("cp_min")
|
||||||
|
connection = psycopg2.connect(**args)
|
||||||
|
keys = select_v1_keys(connection)
|
||||||
|
certificates = select_v1_certs(connection)
|
||||||
|
json = select_v2_json(connection)
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for server in keys:
|
||||||
|
if not server in json:
|
||||||
|
v2_json = convert_v1_to_v2(
|
||||||
|
server, valid_until, keys[server], certificates[server]
|
||||||
|
)
|
||||||
|
v2_json = sign_json(v2_json, server_name, signing_key)
|
||||||
|
result[server] = v2_json
|
||||||
|
|
||||||
|
yaml.safe_dump(result, sys.stdout, default_flow_style=False)
|
||||||
|
|
||||||
|
rows = list(
|
||||||
|
row for server, json in result.items()
|
||||||
|
for row in rows_v2(server, json)
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor = connection.cursor()
|
||||||
|
cursor.executemany(
|
||||||
|
"INSERT INTO server_keys_json ("
|
||||||
|
" server_name, key_id, from_server,"
|
||||||
|
" ts_added_ms, ts_valid_until_ms, key_json"
|
||||||
|
") VALUES (%s, %s, %s, %s, %s, %s)",
|
||||||
|
rows
|
||||||
|
)
|
||||||
|
connection.commit()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -33,9 +33,10 @@ def request_registration(user, password, server_location, shared_secret):
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"username": user,
|
"user": user,
|
||||||
"password": password,
|
"password": password,
|
||||||
"mac": mac,
|
"mac": mac,
|
||||||
|
"type": "org.matrix.login.shared_secret",
|
||||||
}
|
}
|
||||||
|
|
||||||
server_location = server_location.rstrip("/")
|
server_location = server_location.rstrip("/")
|
||||||
|
@ -43,7 +44,7 @@ def request_registration(user, password, server_location, shared_secret):
|
||||||
print "Sending registration request..."
|
print "Sending registration request..."
|
||||||
|
|
||||||
req = urllib2.Request(
|
req = urllib2.Request(
|
||||||
"%s/_matrix/client/v2_alpha/register" % (server_location,),
|
"%s/_matrix/client/api/v1/register" % (server_location,),
|
||||||
data=json.dumps(data),
|
data=json.dumps(data),
|
||||||
headers={'Content-Type': 'application/json'}
|
headers={'Content-Type': 'application/json'}
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,4 +16,4 @@
|
||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.9.0-r5"
|
__version__ = "0.9.1"
|
||||||
|
|
|
@ -32,9 +32,9 @@ from synapse.server import HomeServer
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from twisted.application import service
|
from twisted.application import service
|
||||||
from twisted.enterprise import adbapi
|
from twisted.enterprise import adbapi
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource, EncodingResourceWrapper
|
||||||
from twisted.web.static import File
|
from twisted.web.static import File
|
||||||
from twisted.web.server import Site
|
from twisted.web.server import Site, GzipEncoderFactory
|
||||||
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
|
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
|
||||||
from synapse.http.server import JsonResource, RootRedirect
|
from synapse.http.server import JsonResource, RootRedirect
|
||||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||||
|
@ -69,16 +69,26 @@ import subprocess
|
||||||
logger = logging.getLogger("synapse.app.homeserver")
|
logger = logging.getLogger("synapse.app.homeserver")
|
||||||
|
|
||||||
|
|
||||||
|
class GzipFile(File):
|
||||||
|
def getChild(self, path, request):
|
||||||
|
child = File.getChild(self, path, request)
|
||||||
|
return EncodingResourceWrapper(child, [GzipEncoderFactory()])
|
||||||
|
|
||||||
|
|
||||||
|
def gz_wrap(r):
|
||||||
|
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
|
||||||
|
|
||||||
|
|
||||||
class SynapseHomeServer(HomeServer):
|
class SynapseHomeServer(HomeServer):
|
||||||
|
|
||||||
def build_http_client(self):
|
def build_http_client(self):
|
||||||
return MatrixFederationHttpClient(self)
|
return MatrixFederationHttpClient(self)
|
||||||
|
|
||||||
def build_resource_for_client(self):
|
def build_resource_for_client(self):
|
||||||
return ClientV1RestResource(self)
|
return gz_wrap(ClientV1RestResource(self))
|
||||||
|
|
||||||
def build_resource_for_client_v2_alpha(self):
|
def build_resource_for_client_v2_alpha(self):
|
||||||
return ClientV2AlphaRestResource(self)
|
return gz_wrap(ClientV2AlphaRestResource(self))
|
||||||
|
|
||||||
def build_resource_for_federation(self):
|
def build_resource_for_federation(self):
|
||||||
return JsonResource(self)
|
return JsonResource(self)
|
||||||
|
@ -87,9 +97,16 @@ class SynapseHomeServer(HomeServer):
|
||||||
import syweb
|
import syweb
|
||||||
syweb_path = os.path.dirname(syweb.__file__)
|
syweb_path = os.path.dirname(syweb.__file__)
|
||||||
webclient_path = os.path.join(syweb_path, "webclient")
|
webclient_path = os.path.join(syweb_path, "webclient")
|
||||||
|
# GZip is disabled here due to
|
||||||
|
# https://twistedmatrix.com/trac/ticket/7678
|
||||||
|
# (It can stay enabled for the API resources: they call
|
||||||
|
# write() with the whole body and then finish() straight
|
||||||
|
# after and so do not trigger the bug.
|
||||||
|
# return GzipFile(webclient_path) # TODO configurable?
|
||||||
return File(webclient_path) # TODO configurable?
|
return File(webclient_path) # TODO configurable?
|
||||||
|
|
||||||
def build_resource_for_static_content(self):
|
def build_resource_for_static_content(self):
|
||||||
|
# This is old and should go away: not going to bother adding gzip
|
||||||
return File("static")
|
return File("static")
|
||||||
|
|
||||||
def build_resource_for_content_repo(self):
|
def build_resource_for_content_repo(self):
|
||||||
|
@ -260,9 +277,12 @@ class SynapseHomeServer(HomeServer):
|
||||||
config,
|
config,
|
||||||
metrics_resource,
|
metrics_resource,
|
||||||
),
|
),
|
||||||
interface="127.0.0.1",
|
interface=config.metrics_bind_host,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Metrics now running on %s port %d",
|
||||||
|
config.metrics_bind_host, config.metrics_port,
|
||||||
)
|
)
|
||||||
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
|
|
||||||
|
|
||||||
def run_startup_checks(self, db_conn, database_engine):
|
def run_startup_checks(self, db_conn, database_engine):
|
||||||
all_users_native = are_all_users_on_domain(
|
all_users_native = are_all_users_on_domain(
|
||||||
|
|
|
@ -148,8 +148,8 @@ class ApplicationService(object):
|
||||||
and self.is_interested_in_user(event.state_key)):
|
and self.is_interested_in_user(event.state_key)):
|
||||||
return True
|
return True
|
||||||
# check joined member events
|
# check joined member events
|
||||||
for member in member_list:
|
for user_id in member_list:
|
||||||
if self.is_interested_in_user(member.state_key):
|
if self.is_interested_in_user(user_id):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -173,7 +173,7 @@ class ApplicationService(object):
|
||||||
restrict_to(str): The namespace to restrict regex tests to.
|
restrict_to(str): The namespace to restrict regex tests to.
|
||||||
aliases_for_event(list): A list of all the known room aliases for
|
aliases_for_event(list): A list of all the known room aliases for
|
||||||
this event.
|
this event.
|
||||||
member_list(list): A list of all joined room members in this room.
|
member_list(list): A list of all joined user_ids in this room.
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if this service would like to know about this event.
|
bool: True if this service would like to know about this event.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -20,6 +20,7 @@ class MetricsConfig(Config):
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
self.enable_metrics = config["enable_metrics"]
|
self.enable_metrics = config["enable_metrics"]
|
||||||
self.metrics_port = config.get("metrics_port")
|
self.metrics_port = config.get("metrics_port")
|
||||||
|
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name):
|
def default_config(self, config_dir_path, server_name):
|
||||||
return """\
|
return """\
|
||||||
|
@ -28,6 +29,9 @@ class MetricsConfig(Config):
|
||||||
# Enable collection and rendering of performance metrics
|
# Enable collection and rendering of performance metrics
|
||||||
enable_metrics: False
|
enable_metrics: False
|
||||||
|
|
||||||
# Separate port to accept metrics requests on (on localhost)
|
# Separate port to accept metrics requests on
|
||||||
# metrics_port: 8081
|
# metrics_port: 8081
|
||||||
|
|
||||||
|
# Which host to bind the metric listener to
|
||||||
|
# metrics_bind_host: 127.0.0.1
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient
|
||||||
from twisted.internet.protocol import Factory
|
from twisted.internet.protocol import Factory
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from synapse.http.endpoint import matrix_federation_endpoint
|
from synapse.http.endpoint import matrix_federation_endpoint
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import (
|
||||||
|
preserve_context_over_fn, preserve_context_over_deferred
|
||||||
|
)
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
try:
|
try:
|
||||||
with PreserveLoggingContext():
|
protocol = yield preserve_context_over_fn(
|
||||||
protocol = yield endpoint.connect(factory)
|
endpoint.connect, factory
|
||||||
server_response, server_certificate = yield protocol.remote_key
|
)
|
||||||
defer.returnValue((server_response, server_certificate))
|
server_response, server_certificate = yield preserve_context_over_deferred(
|
||||||
return
|
protocol.remote_key
|
||||||
|
)
|
||||||
|
defer.returnValue((server_response, server_certificate))
|
||||||
|
return
|
||||||
except SynapseKeyClientError as e:
|
except SynapseKeyClientError as e:
|
||||||
logger.exception("Error getting key for %r" % (server_name,))
|
logger.exception("Error getting key for %r" % (server_name,))
|
||||||
if e.status.startswith("4"):
|
if e.status.startswith("4"):
|
||||||
|
|
|
@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes
|
||||||
|
|
||||||
from synapse.util.retryutils import get_retry_limiter
|
from synapse.util.retryutils import get_retry_limiter
|
||||||
|
|
||||||
from synapse.util.async import create_observer
|
from synapse.util.async import ObservableDeferred
|
||||||
|
|
||||||
from OpenSSL import crypto
|
from OpenSSL import crypto
|
||||||
|
|
||||||
|
@ -111,6 +111,10 @@ class Keyring(object):
|
||||||
|
|
||||||
if download is None:
|
if download is None:
|
||||||
download = self._get_server_verify_key_impl(server_name, key_ids)
|
download = self._get_server_verify_key_impl(server_name, key_ids)
|
||||||
|
download = ObservableDeferred(
|
||||||
|
download,
|
||||||
|
consumeErrors=True
|
||||||
|
)
|
||||||
self.key_downloads[server_name] = download
|
self.key_downloads[server_name] = download
|
||||||
|
|
||||||
@download.addBoth
|
@download.addBoth
|
||||||
|
@ -118,30 +122,31 @@ class Keyring(object):
|
||||||
del self.key_downloads[server_name]
|
del self.key_downloads[server_name]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
r = yield create_observer(download)
|
r = yield download.observe()
|
||||||
defer.returnValue(r)
|
defer.returnValue(r)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_server_verify_key_impl(self, server_name, key_ids):
|
def _get_server_verify_key_impl(self, server_name, key_ids):
|
||||||
keys = None
|
keys = None
|
||||||
|
|
||||||
perspective_results = []
|
@defer.inlineCallbacks
|
||||||
for perspective_name, perspective_keys in self.perspective_servers.items():
|
def get_key(perspective_name, perspective_keys):
|
||||||
@defer.inlineCallbacks
|
try:
|
||||||
def get_key():
|
result = yield self.get_server_verify_key_v2_indirect(
|
||||||
try:
|
server_name, key_ids, perspective_name, perspective_keys
|
||||||
result = yield self.get_server_verify_key_v2_indirect(
|
)
|
||||||
server_name, key_ids, perspective_name, perspective_keys
|
defer.returnValue(result)
|
||||||
)
|
except Exception as e:
|
||||||
defer.returnValue(result)
|
logging.info(
|
||||||
except:
|
"Unable to getting key %r for %r from %r: %s %s",
|
||||||
logging.info(
|
key_ids, server_name, perspective_name,
|
||||||
"Unable to getting key %r for %r from %r",
|
type(e).__name__, str(e.message),
|
||||||
key_ids, server_name, perspective_name,
|
)
|
||||||
)
|
|
||||||
perspective_results.append(get_key())
|
|
||||||
|
|
||||||
perspective_results = yield defer.gatherResults(perspective_results)
|
perspective_results = yield defer.gatherResults([
|
||||||
|
get_key(p_name, p_keys)
|
||||||
|
for p_name, p_keys in self.perspective_servers.items()
|
||||||
|
])
|
||||||
|
|
||||||
for results in perspective_results:
|
for results in perspective_results:
|
||||||
if results is not None:
|
if results is not None:
|
||||||
|
@ -154,17 +159,22 @@ class Keyring(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
with limiter:
|
with limiter:
|
||||||
if keys is None:
|
if not keys:
|
||||||
try:
|
try:
|
||||||
keys = yield self.get_server_verify_key_v2_direct(
|
keys = yield self.get_server_verify_key_v2_direct(
|
||||||
server_name, key_ids
|
server_name, key_ids
|
||||||
)
|
)
|
||||||
except:
|
except Exception as e:
|
||||||
pass
|
logging.info(
|
||||||
|
"Unable to getting key %r for %r directly: %s %s",
|
||||||
|
key_ids, server_name,
|
||||||
|
type(e).__name__, str(e.message),
|
||||||
|
)
|
||||||
|
|
||||||
keys = yield self.get_server_verify_key_v1_direct(
|
if not keys:
|
||||||
server_name, key_ids
|
keys = yield self.get_server_verify_key_v1_direct(
|
||||||
)
|
server_name, key_ids
|
||||||
|
)
|
||||||
|
|
||||||
for key_id in key_ids:
|
for key_id in key_ids:
|
||||||
if key_id in keys:
|
if key_id in keys:
|
||||||
|
@ -184,7 +194,7 @@ class Keyring(object):
|
||||||
# TODO(mark): Set the minimum_valid_until_ts to that needed by
|
# TODO(mark): Set the minimum_valid_until_ts to that needed by
|
||||||
# the events being validated or the current time if validating
|
# the events being validated or the current time if validating
|
||||||
# an incoming request.
|
# an incoming request.
|
||||||
responses = yield self.client.post_json(
|
query_response = yield self.client.post_json(
|
||||||
destination=perspective_name,
|
destination=perspective_name,
|
||||||
path=b"/_matrix/key/v2/query",
|
path=b"/_matrix/key/v2/query",
|
||||||
data={
|
data={
|
||||||
|
@ -200,6 +210,8 @@ class Keyring(object):
|
||||||
|
|
||||||
keys = {}
|
keys = {}
|
||||||
|
|
||||||
|
responses = query_response["server_keys"]
|
||||||
|
|
||||||
for response in responses:
|
for response in responses:
|
||||||
if (u"signatures" not in response
|
if (u"signatures" not in response
|
||||||
or perspective_name not in response[u"signatures"]):
|
or perspective_name not in response[u"signatures"]):
|
||||||
|
@ -323,7 +335,7 @@ class Keyring(object):
|
||||||
verify_key.time_added = time_now_ms
|
verify_key.time_added = time_now_ms
|
||||||
old_verify_keys[key_id] = verify_key
|
old_verify_keys[key_id] = verify_key
|
||||||
|
|
||||||
for key_id in response_json["signatures"][server_name]:
|
for key_id in response_json["signatures"].get(server_name, {}):
|
||||||
if key_id not in response_json["verify_keys"]:
|
if key_id not in response_json["verify_keys"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Key response must include verification keys for all"
|
"Key response must include verification keys for all"
|
||||||
|
|
|
@ -24,6 +24,8 @@ from synapse.crypto.event_signing import check_event_content_hash
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
|
from synapse.util import unwrapFirstError
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,6 +80,7 @@ class FederationBase(object):
|
||||||
destinations=[pdu.origin],
|
destinations=[pdu.origin],
|
||||||
event_id=pdu.event_id,
|
event_id=pdu.event_id,
|
||||||
outlier=outlier,
|
outlier=outlier,
|
||||||
|
timeout=10000,
|
||||||
)
|
)
|
||||||
|
|
||||||
if new_pdu:
|
if new_pdu:
|
||||||
|
@ -94,7 +97,7 @@ class FederationBase(object):
|
||||||
yield defer.gatherResults(
|
yield defer.gatherResults(
|
||||||
[do(pdu) for pdu in pdus],
|
[do(pdu) for pdu in pdus],
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(signed_pdus)
|
defer.returnValue(signed_pdus)
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ from .units import Edu
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException, HttpResponseException, SynapseError,
|
CodeMessageException, HttpResponseException, SynapseError,
|
||||||
)
|
)
|
||||||
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.expiringcache import ExpiringCache
|
from synapse.util.expiringcache import ExpiringCache
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
|
@ -164,16 +165,17 @@ class FederationClient(FederationBase):
|
||||||
for p in transaction_data["pdus"]
|
for p in transaction_data["pdus"]
|
||||||
]
|
]
|
||||||
|
|
||||||
for i, pdu in enumerate(pdus):
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
pdus[i] = yield self._check_sigs_and_hash(pdu)
|
pdus[:] = yield defer.gatherResults(
|
||||||
|
[self._check_sigs_and_hash(pdu) for pdu in pdus],
|
||||||
# FIXME: We should handle signature failures more gracefully.
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(pdus)
|
defer.returnValue(pdus)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_pdu(self, destinations, event_id, outlier=False):
|
def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
|
||||||
"""Requests the PDU with given origin and ID from the remote home
|
"""Requests the PDU with given origin and ID from the remote home
|
||||||
servers.
|
servers.
|
||||||
|
|
||||||
|
@ -189,6 +191,8 @@ class FederationClient(FederationBase):
|
||||||
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
|
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
|
||||||
it's from an arbitary point in the context as opposed to part
|
it's from an arbitary point in the context as opposed to part
|
||||||
of the current block of PDUs. Defaults to `False`
|
of the current block of PDUs. Defaults to `False`
|
||||||
|
timeout (int): How long to try (in ms) each destination for before
|
||||||
|
moving to the next destination. None indicates no timeout.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Results in the requested PDU.
|
Deferred: Results in the requested PDU.
|
||||||
|
@ -212,7 +216,7 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
with limiter:
|
with limiter:
|
||||||
transaction_data = yield self.transport_layer.get_event(
|
transaction_data = yield self.transport_layer.get_event(
|
||||||
destination, event_id
|
destination, event_id, timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("transaction_data %r", transaction_data)
|
logger.debug("transaction_data %r", transaction_data)
|
||||||
|
@ -222,7 +226,7 @@ class FederationClient(FederationBase):
|
||||||
for p in transaction_data["pdus"]
|
for p in transaction_data["pdus"]
|
||||||
]
|
]
|
||||||
|
|
||||||
if pdu_list:
|
if pdu_list and pdu_list[0]:
|
||||||
pdu = pdu_list[0]
|
pdu = pdu_list[0]
|
||||||
|
|
||||||
# Check signatures are correct.
|
# Check signatures are correct.
|
||||||
|
@ -255,7 +259,7 @@ class FederationClient(FederationBase):
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self._get_pdu_cache is not None:
|
if self._get_pdu_cache is not None and pdu:
|
||||||
self._get_pdu_cache[event_id] = pdu
|
self._get_pdu_cache[event_id] = pdu
|
||||||
|
|
||||||
defer.returnValue(pdu)
|
defer.returnValue(pdu)
|
||||||
|
@ -370,13 +374,17 @@ class FederationClient(FederationBase):
|
||||||
for p in content.get("auth_chain", [])
|
for p in content.get("auth_chain", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
signed_state = yield self._check_sigs_and_hash_and_fetch(
|
signed_state, signed_auth = yield defer.gatherResults(
|
||||||
destination, state, outlier=True
|
[
|
||||||
)
|
self._check_sigs_and_hash_and_fetch(
|
||||||
|
destination, state, outlier=True
|
||||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
),
|
||||||
destination, auth_chain, outlier=True
|
self._check_sigs_and_hash_and_fetch(
|
||||||
)
|
destination, auth_chain, outlier=True
|
||||||
|
)
|
||||||
|
],
|
||||||
|
consumeErrors=True
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
auth_chain.sort(key=lambda e: e.depth)
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
|
@ -518,7 +526,7 @@ class FederationClient(FederationBase):
|
||||||
# Are we missing any?
|
# Are we missing any?
|
||||||
|
|
||||||
seen_events = set(earliest_events_ids)
|
seen_events = set(earliest_events_ids)
|
||||||
seen_events.update(e.event_id for e in signed_events)
|
seen_events.update(e.event_id for e in signed_events if e)
|
||||||
|
|
||||||
missing_events = {}
|
missing_events = {}
|
||||||
for e in itertools.chain(latest_events, signed_events):
|
for e in itertools.chain(latest_events, signed_events):
|
||||||
|
@ -561,7 +569,7 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
res = yield defer.DeferredList(deferreds, consumeErrors=True)
|
res = yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||||
for (result, val), (e_id, _) in zip(res, ordered_missing):
|
for (result, val), (e_id, _) in zip(res, ordered_missing):
|
||||||
if result:
|
if result and val:
|
||||||
signed_events.append(val)
|
signed_events.append(val)
|
||||||
else:
|
else:
|
||||||
failed_to_fetch.add(e_id)
|
failed_to_fetch.add(e_id)
|
||||||
|
|
|
@ -20,7 +20,6 @@ from .federation_base import FederationBase
|
||||||
from .units import Transaction, Edu
|
from .units import Transaction, Edu
|
||||||
|
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
|
@ -123,29 +122,28 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
logger.debug("[%s] Transaction is new", transaction.transaction_id)
|
logger.debug("[%s] Transaction is new", transaction.transaction_id)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
results = []
|
||||||
results = []
|
|
||||||
|
|
||||||
for pdu in pdu_list:
|
for pdu in pdu_list:
|
||||||
d = self._handle_new_pdu(transaction.origin, pdu)
|
d = self._handle_new_pdu(transaction.origin, pdu)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield d
|
yield d
|
||||||
results.append({})
|
results.append({})
|
||||||
except FederationError as e:
|
except FederationError as e:
|
||||||
self.send_failure(e, transaction.origin)
|
self.send_failure(e, transaction.origin)
|
||||||
results.append({"error": str(e)})
|
results.append({"error": str(e)})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
results.append({"error": str(e)})
|
results.append({"error": str(e)})
|
||||||
logger.exception("Failed to handle PDU")
|
logger.exception("Failed to handle PDU")
|
||||||
|
|
||||||
if hasattr(transaction, "edus"):
|
if hasattr(transaction, "edus"):
|
||||||
for edu in [Edu(**x) for x in transaction.edus]:
|
for edu in [Edu(**x) for x in transaction.edus]:
|
||||||
self.received_edu(
|
self.received_edu(
|
||||||
transaction.origin,
|
transaction.origin,
|
||||||
edu.edu_type,
|
edu.edu_type,
|
||||||
edu.content
|
edu.content
|
||||||
)
|
)
|
||||||
|
|
||||||
for failure in getattr(transaction, "pdu_failures", []):
|
for failure in getattr(transaction, "pdu_failures", []):
|
||||||
logger.info("Got failure %r", failure)
|
logger.info("Got failure %r", failure)
|
||||||
|
|
|
@ -207,13 +207,13 @@ class TransactionQueue(object):
|
||||||
# request at which point pending_pdus_by_dest just keeps growing.
|
# request at which point pending_pdus_by_dest just keeps growing.
|
||||||
# we need application-layer timeouts of some flavour of these
|
# we need application-layer timeouts of some flavour of these
|
||||||
# requests
|
# requests
|
||||||
logger.info(
|
logger.debug(
|
||||||
"TX [%s] Transaction already in progress",
|
"TX [%s] Transaction already in progress",
|
||||||
destination
|
destination
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("TX [%s] _attempt_new_transaction", destination)
|
logger.debug("TX [%s] _attempt_new_transaction", destination)
|
||||||
|
|
||||||
# list of (pending_pdu, deferred, order)
|
# list of (pending_pdu, deferred, order)
|
||||||
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
||||||
|
@ -221,11 +221,11 @@ class TransactionQueue(object):
|
||||||
pending_failures = self.pending_failures_by_dest.pop(destination, [])
|
pending_failures = self.pending_failures_by_dest.pop(destination, [])
|
||||||
|
|
||||||
if pending_pdus:
|
if pending_pdus:
|
||||||
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
||||||
destination, len(pending_pdus))
|
destination, len(pending_pdus))
|
||||||
|
|
||||||
if not pending_pdus and not pending_edus and not pending_failures:
|
if not pending_pdus and not pending_edus and not pending_failures:
|
||||||
logger.info("TX [%s] Nothing to send", destination)
|
logger.debug("TX [%s] Nothing to send", destination)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Sort based on the order field
|
# Sort based on the order field
|
||||||
|
@ -242,6 +242,8 @@ class TransactionQueue(object):
|
||||||
try:
|
try:
|
||||||
self.pending_transactions[destination] = 1
|
self.pending_transactions[destination] = 1
|
||||||
|
|
||||||
|
txn_id = str(self._next_txn_id)
|
||||||
|
|
||||||
limiter = yield get_retry_limiter(
|
limiter = yield get_retry_limiter(
|
||||||
destination,
|
destination,
|
||||||
self._clock,
|
self._clock,
|
||||||
|
@ -249,9 +251,9 @@ class TransactionQueue(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"TX [%s] Attempting new transaction"
|
"TX [%s] {%s} Attempting new transaction"
|
||||||
" (pdus: %d, edus: %d, failures: %d)",
|
" (pdus: %d, edus: %d, failures: %d)",
|
||||||
destination,
|
destination, txn_id,
|
||||||
len(pending_pdus),
|
len(pending_pdus),
|
||||||
len(pending_edus),
|
len(pending_edus),
|
||||||
len(pending_failures)
|
len(pending_failures)
|
||||||
|
@ -261,7 +263,7 @@ class TransactionQueue(object):
|
||||||
|
|
||||||
transaction = Transaction.create_new(
|
transaction = Transaction.create_new(
|
||||||
origin_server_ts=int(self._clock.time_msec()),
|
origin_server_ts=int(self._clock.time_msec()),
|
||||||
transaction_id=str(self._next_txn_id),
|
transaction_id=txn_id,
|
||||||
origin=self.server_name,
|
origin=self.server_name,
|
||||||
destination=destination,
|
destination=destination,
|
||||||
pdus=pdus,
|
pdus=pdus,
|
||||||
|
@ -275,9 +277,13 @@ class TransactionQueue(object):
|
||||||
|
|
||||||
logger.debug("TX [%s] Persisted transaction", destination)
|
logger.debug("TX [%s] Persisted transaction", destination)
|
||||||
logger.info(
|
logger.info(
|
||||||
"TX [%s] Sending transaction [%s]",
|
"TX [%s] {%s} Sending transaction [%s],"
|
||||||
destination,
|
" (PDUs: %d, EDUs: %d, failures: %d)",
|
||||||
|
destination, txn_id,
|
||||||
transaction.transaction_id,
|
transaction.transaction_id,
|
||||||
|
len(pending_pdus),
|
||||||
|
len(pending_edus),
|
||||||
|
len(pending_failures),
|
||||||
)
|
)
|
||||||
|
|
||||||
with limiter:
|
with limiter:
|
||||||
|
@ -313,7 +319,10 @@ class TransactionQueue(object):
|
||||||
code = e.code
|
code = e.code
|
||||||
response = e.response
|
response = e.response
|
||||||
|
|
||||||
logger.info("TX [%s] got %d response", destination, code)
|
logger.info(
|
||||||
|
"TX [%s] {%s} got %d response",
|
||||||
|
destination, txn_id, code
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug("TX [%s] Sent transaction", destination)
|
logger.debug("TX [%s] Sent transaction", destination)
|
||||||
logger.debug("TX [%s] Marking as delivered...", destination)
|
logger.debug("TX [%s] Marking as delivered...", destination)
|
||||||
|
|
|
@ -50,13 +50,15 @@ class TransportLayerClient(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_event(self, destination, event_id):
|
def get_event(self, destination, event_id, timeout=None):
|
||||||
""" Requests the pdu with give id and origin from the given server.
|
""" Requests the pdu with give id and origin from the given server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destination (str): The host name of the remote home server we want
|
destination (str): The host name of the remote home server we want
|
||||||
to get the state from.
|
to get the state from.
|
||||||
event_id (str): The id of the event being requested.
|
event_id (str): The id of the event being requested.
|
||||||
|
timeout (int): How long to try (in ms) the destination for before
|
||||||
|
giving up. None indicates no timeout.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Results in a dict received from the remote homeserver.
|
Deferred: Results in a dict received from the remote homeserver.
|
||||||
|
@ -65,7 +67,7 @@ class TransportLayerClient(object):
|
||||||
destination, event_id)
|
destination, event_id)
|
||||||
|
|
||||||
path = PREFIX + "/event/%s/" % (event_id, )
|
path = PREFIX + "/event/%s/" % (event_id, )
|
||||||
return self.client.get_json(destination, path=path)
|
return self.client.get_json(destination, path=path, timeout=timeout)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def backfill(self, destination, room_id, event_tuples, limit):
|
def backfill(self, destination, room_id, event_tuples, limit):
|
||||||
|
|
|
@ -196,6 +196,14 @@ class FederationSendServlet(BaseFederationServlet):
|
||||||
transaction_id, str(transaction_data)
|
transaction_id, str(transaction_data)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
|
||||||
|
transaction_id, origin,
|
||||||
|
len(transaction_data.get("pdus", [])),
|
||||||
|
len(transaction_data.get("edus", [])),
|
||||||
|
len(transaction_data.get("failures", [])),
|
||||||
|
)
|
||||||
|
|
||||||
# We should ideally be getting this from the security layer.
|
# We should ideally be getting this from the security layer.
|
||||||
# origin = body["origin"]
|
# origin = body["origin"]
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||||
from synapse.api.constants import Membership, EventTypes
|
from synapse.api.constants import Membership, EventTypes
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -103,7 +105,9 @@ class BaseHandler(object):
|
||||||
if not suppress_auth:
|
if not suppress_auth:
|
||||||
self.auth.check(event, auth_events=context.current_state)
|
self.auth.check(event, auth_events=context.current_state)
|
||||||
|
|
||||||
yield self.store.persist_event(event, context=context)
|
(event_stream_id, max_stream_id) = yield self.store.persist_event(
|
||||||
|
event, context=context
|
||||||
|
)
|
||||||
|
|
||||||
federation_handler = self.hs.get_handlers().federation_handler
|
federation_handler = self.hs.get_handlers().federation_handler
|
||||||
|
|
||||||
|
@ -137,10 +141,12 @@ class BaseHandler(object):
|
||||||
"Failed to get destination from event %s", s.event_id
|
"Failed to get destination from event %s", s.event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Don't block waiting on waking up all the listeners.
|
with PreserveLoggingContext():
|
||||||
notify_d = self.notifier.on_new_room_event(
|
# Don't block waiting on waking up all the listeners.
|
||||||
event, extra_users=extra_users
|
notify_d = self.notifier.on_new_room_event(
|
||||||
)
|
event, event_stream_id, max_stream_id,
|
||||||
|
extra_users=extra_users
|
||||||
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
def log_failure(f):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
@ -147,10 +147,7 @@ class ApplicationServicesHandler(object):
|
||||||
)
|
)
|
||||||
# We need to know the members associated with this event.room_id,
|
# We need to know the members associated with this event.room_id,
|
||||||
# if any.
|
# if any.
|
||||||
member_list = yield self.store.get_room_members(
|
member_list = yield self.store.get_users_in_room(event.room_id)
|
||||||
room_id=event.room_id,
|
|
||||||
membership=Membership.JOIN
|
|
||||||
)
|
|
||||||
|
|
||||||
services = yield self.store.get_app_services()
|
services = yield self.store.get_app_services()
|
||||||
interested_list = [
|
interested_list = [
|
||||||
|
|
|
@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes
|
||||||
from synapse.types import RoomAlias
|
from synapse.types import RoomAlias
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import string
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -40,6 +41,10 @@ class DirectoryHandler(BaseHandler):
|
||||||
def _create_association(self, room_alias, room_id, servers=None):
|
def _create_association(self, room_alias, room_id, servers=None):
|
||||||
# general association creation for both human users and app services
|
# general association creation for both human users and app services
|
||||||
|
|
||||||
|
for wchar in string.whitespace:
|
||||||
|
if wchar in room_alias.localpart:
|
||||||
|
raise SynapseError(400, "Invalid characters in room alias")
|
||||||
|
|
||||||
if not self.hs.is_mine(room_alias):
|
if not self.hs.is_mine(room_alias):
|
||||||
raise SynapseError(400, "Room alias must be local")
|
raise SynapseError(400, "Room alias must be local")
|
||||||
# TODO(erikj): Change this.
|
# TODO(erikj): Change this.
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
|
@ -81,10 +80,9 @@ class EventStreamHandler(BaseHandler):
|
||||||
# thundering herds on restart.
|
# thundering herds on restart.
|
||||||
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
|
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
events, tokens = yield self.notifier.get_events_for(
|
||||||
events, tokens = yield self.notifier.get_events_for(
|
auth_user, room_ids, pagin_config, timeout
|
||||||
auth_user, room_ids, pagin_config, timeout
|
)
|
||||||
)
|
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
|
|
|
@ -18,9 +18,11 @@
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError, FederationError, StoreError,
|
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||||
|
from synapse.util import unwrapFirstError
|
||||||
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.frozenutils import unfreeze
|
from synapse.util.frozenutils import unfreeze
|
||||||
|
@ -29,6 +31,8 @@ from synapse.crypto.event_signing import (
|
||||||
)
|
)
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
|
@ -156,7 +160,7 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self._handle_new_event(
|
_, event_stream_id, max_stream_id = yield self._handle_new_event(
|
||||||
origin,
|
origin,
|
||||||
event,
|
event,
|
||||||
state=state,
|
state=state,
|
||||||
|
@ -197,9 +201,11 @@ class FederationHandler(BaseHandler):
|
||||||
target_user = UserID.from_string(target_user_id)
|
target_user = UserID.from_string(target_user_id)
|
||||||
extra_users.append(target_user)
|
extra_users.append(target_user)
|
||||||
|
|
||||||
d = self.notifier.on_new_room_event(
|
with PreserveLoggingContext():
|
||||||
event, extra_users=extra_users
|
d = self.notifier.on_new_room_event(
|
||||||
)
|
event, event_stream_id, max_stream_id,
|
||||||
|
extra_users=extra_users
|
||||||
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
def log_failure(f):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
|
@ -218,36 +224,209 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def backfill(self, dest, room_id, limit):
|
def backfill(self, dest, room_id, limit, extremities=[]):
|
||||||
""" Trigger a backfill request to `dest` for the given `room_id`
|
""" Trigger a backfill request to `dest` for the given `room_id`
|
||||||
"""
|
"""
|
||||||
extremities = yield self.store.get_oldest_events_in_room(room_id)
|
if not extremities:
|
||||||
|
extremities = yield self.store.get_oldest_events_in_room(room_id)
|
||||||
|
|
||||||
pdus = yield self.replication_layer.backfill(
|
events = yield self.replication_layer.backfill(
|
||||||
dest,
|
dest,
|
||||||
room_id,
|
room_id,
|
||||||
limit,
|
limit=limit,
|
||||||
extremities=extremities,
|
extremities=extremities,
|
||||||
)
|
)
|
||||||
|
|
||||||
events = []
|
event_map = {e.event_id: e for e in events}
|
||||||
|
|
||||||
for pdu in pdus:
|
event_ids = set(e.event_id for e in events)
|
||||||
event = pdu
|
|
||||||
|
|
||||||
# FIXME (erikj): Not sure this actually works :/
|
edges = [
|
||||||
context = yield self.state_handler.compute_event_context(event)
|
ev.event_id
|
||||||
|
for ev in events
|
||||||
|
if set(e_id for e_id, _ in ev.prev_events) - event_ids
|
||||||
|
]
|
||||||
|
|
||||||
events.append((event, context))
|
# For each edge get the current state.
|
||||||
|
|
||||||
yield self.store.persist_event(
|
auth_events = {}
|
||||||
event,
|
events_to_state = {}
|
||||||
context=context,
|
for e_id in edges:
|
||||||
backfilled=True
|
state, auth = yield self.replication_layer.get_state_for_room(
|
||||||
|
destination=dest,
|
||||||
|
room_id=room_id,
|
||||||
|
event_id=e_id
|
||||||
|
)
|
||||||
|
auth_events.update({a.event_id: a for a in auth})
|
||||||
|
events_to_state[e_id] = state
|
||||||
|
|
||||||
|
yield defer.gatherResults(
|
||||||
|
[
|
||||||
|
self._handle_new_event(dest, a)
|
||||||
|
for a in auth_events.values()
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
yield defer.gatherResults(
|
||||||
|
[
|
||||||
|
self._handle_new_event(
|
||||||
|
dest, event_map[e_id],
|
||||||
|
state=events_to_state[e_id],
|
||||||
|
backfilled=True,
|
||||||
|
)
|
||||||
|
for e_id in events_to_state
|
||||||
|
],
|
||||||
|
consumeErrors=True
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
events.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
|
for event in events:
|
||||||
|
if event in events_to_state:
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield self._handle_new_event(
|
||||||
|
dest, event,
|
||||||
|
backfilled=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(events)
|
defer.returnValue(events)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def maybe_backfill(self, room_id, current_depth):
|
||||||
|
"""Checks the database to see if we should backfill before paginating,
|
||||||
|
and if so do.
|
||||||
|
"""
|
||||||
|
extremities = yield self.store.get_oldest_events_with_depth_in_room(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not extremities:
|
||||||
|
logger.debug("Not backfilling as no extremeties found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if we reached a point where we should start backfilling.
|
||||||
|
sorted_extremeties_tuple = sorted(
|
||||||
|
extremities.items(),
|
||||||
|
key=lambda e: -int(e[1])
|
||||||
|
)
|
||||||
|
max_depth = sorted_extremeties_tuple[0][1]
|
||||||
|
|
||||||
|
if current_depth > max_depth:
|
||||||
|
logger.debug(
|
||||||
|
"Not backfilling as we don't need to. %d < %d",
|
||||||
|
max_depth, current_depth,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Now we need to decide which hosts to hit first.
|
||||||
|
|
||||||
|
# First we try hosts that are already in the room
|
||||||
|
# TODO: HEURISTIC ALERT.
|
||||||
|
|
||||||
|
curr_state = yield self.state_handler.get_current_state(room_id)
|
||||||
|
|
||||||
|
def get_domains_from_state(state):
|
||||||
|
joined_users = [
|
||||||
|
(state_key, int(event.depth))
|
||||||
|
for (e_type, state_key), event in state.items()
|
||||||
|
if e_type == EventTypes.Member
|
||||||
|
and event.membership == Membership.JOIN
|
||||||
|
]
|
||||||
|
|
||||||
|
joined_domains = {}
|
||||||
|
for u, d in joined_users:
|
||||||
|
try:
|
||||||
|
dom = UserID.from_string(u).domain
|
||||||
|
old_d = joined_domains.get(dom)
|
||||||
|
if old_d:
|
||||||
|
joined_domains[dom] = min(d, old_d)
|
||||||
|
else:
|
||||||
|
joined_domains[dom] = d
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return sorted(joined_domains.items(), key=lambda d: d[1])
|
||||||
|
|
||||||
|
curr_domains = get_domains_from_state(curr_state)
|
||||||
|
|
||||||
|
likely_domains = [
|
||||||
|
domain for domain, depth in curr_domains
|
||||||
|
if domain is not self.server_name
|
||||||
|
]
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def try_backfill(domains):
|
||||||
|
# TODO: Should we try multiple of these at a time?
|
||||||
|
for dom in domains:
|
||||||
|
try:
|
||||||
|
events = yield self.backfill(
|
||||||
|
dom, room_id,
|
||||||
|
limit=100,
|
||||||
|
extremities=[e for e in extremities.keys()]
|
||||||
|
)
|
||||||
|
except SynapseError:
|
||||||
|
logger.info(
|
||||||
|
"Failed to backfill from %s because %s",
|
||||||
|
dom, e,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except CodeMessageException as e:
|
||||||
|
if 400 <= e.code < 500:
|
||||||
|
raise
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Failed to backfill from %s because %s",
|
||||||
|
dom, e,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except NotRetryingDestination as e:
|
||||||
|
logger.info(e.message)
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to backfill from %s because %s",
|
||||||
|
dom, e,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if events:
|
||||||
|
defer.returnValue(True)
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
success = yield try_backfill(likely_domains)
|
||||||
|
if success:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
# Huh, well *those* domains didn't work out. Lets try some domains
|
||||||
|
# from the time.
|
||||||
|
|
||||||
|
tried_domains = set(likely_domains)
|
||||||
|
tried_domains.add(self.server_name)
|
||||||
|
|
||||||
|
event_ids = list(extremities.keys())
|
||||||
|
|
||||||
|
states = yield defer.gatherResults([
|
||||||
|
self.state_handler.resolve_state_groups([e])
|
||||||
|
for e in event_ids
|
||||||
|
])
|
||||||
|
states = dict(zip(event_ids, [s[1] for s in states]))
|
||||||
|
|
||||||
|
for e_id, _ in sorted_extremeties_tuple:
|
||||||
|
likely_domains = get_domains_from_state(states[e_id])
|
||||||
|
|
||||||
|
success = yield try_backfill([
|
||||||
|
dom for dom in likely_domains
|
||||||
|
if dom not in tried_domains
|
||||||
|
])
|
||||||
|
if success:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
tried_domains.update(likely_domains)
|
||||||
|
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_invite(self, target_host, event):
|
def send_invite(self, target_host, event):
|
||||||
""" Sends the invite to the remote server for signing.
|
""" Sends the invite to the remote server for signing.
|
||||||
|
@ -376,30 +555,14 @@ class FederationHandler(BaseHandler):
|
||||||
# FIXME
|
# FIXME
|
||||||
pass
|
pass
|
||||||
|
|
||||||
for e in auth_chain:
|
yield self._handle_auth_events(
|
||||||
e.internal_metadata.outlier = True
|
origin, [e for e in auth_chain if e.event_id != event.event_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def handle_state(e):
|
||||||
if e.event_id == event.event_id:
|
if e.event_id == event.event_id:
|
||||||
continue
|
return
|
||||||
|
|
||||||
try:
|
|
||||||
auth_ids = [e_id for e_id, _ in e.auth_events]
|
|
||||||
auth = {
|
|
||||||
(e.type, e.state_key): e for e in auth_chain
|
|
||||||
if e.event_id in auth_ids
|
|
||||||
}
|
|
||||||
yield self._handle_new_event(
|
|
||||||
origin, e, auth_events=auth
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to handle auth event %s",
|
|
||||||
e.event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
for e in state:
|
|
||||||
if e.event_id == event.event_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
e.internal_metadata.outlier = True
|
e.internal_metadata.outlier = True
|
||||||
try:
|
try:
|
||||||
|
@ -417,13 +580,15 @@ class FederationHandler(BaseHandler):
|
||||||
e.event_id,
|
e.event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
yield defer.DeferredList([handle_state(e) for e in state])
|
||||||
|
|
||||||
auth_ids = [e_id for e_id, _ in event.auth_events]
|
auth_ids = [e_id for e_id, _ in event.auth_events]
|
||||||
auth_events = {
|
auth_events = {
|
||||||
(e.type, e.state_key): e for e in auth_chain
|
(e.type, e.state_key): e for e in auth_chain
|
||||||
if e.event_id in auth_ids
|
if e.event_id in auth_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
yield self._handle_new_event(
|
_, event_stream_id, max_stream_id = yield self._handle_new_event(
|
||||||
origin,
|
origin,
|
||||||
new_event,
|
new_event,
|
||||||
state=state,
|
state=state,
|
||||||
|
@ -431,9 +596,11 @@ class FederationHandler(BaseHandler):
|
||||||
auth_events=auth_events,
|
auth_events=auth_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
d = self.notifier.on_new_room_event(
|
with PreserveLoggingContext():
|
||||||
new_event, extra_users=[joinee]
|
d = self.notifier.on_new_room_event(
|
||||||
)
|
new_event, event_stream_id, max_stream_id,
|
||||||
|
extra_users=[joinee]
|
||||||
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
def log_failure(f):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
|
@ -498,7 +665,9 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
event.internal_metadata.outlier = False
|
event.internal_metadata.outlier = False
|
||||||
|
|
||||||
context = yield self._handle_new_event(origin, event)
|
context, event_stream_id, max_stream_id = yield self._handle_new_event(
|
||||||
|
origin, event
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
|
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
|
||||||
|
@ -512,9 +681,10 @@ class FederationHandler(BaseHandler):
|
||||||
target_user = UserID.from_string(target_user_id)
|
target_user = UserID.from_string(target_user_id)
|
||||||
extra_users.append(target_user)
|
extra_users.append(target_user)
|
||||||
|
|
||||||
d = self.notifier.on_new_room_event(
|
with PreserveLoggingContext():
|
||||||
event, extra_users=extra_users
|
d = self.notifier.on_new_room_event(
|
||||||
)
|
event, event_stream_id, max_stream_id, extra_users=extra_users
|
||||||
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
def log_failure(f):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
|
@ -587,16 +757,18 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
context = yield self.state_handler.compute_event_context(event)
|
context = yield self.state_handler.compute_event_context(event)
|
||||||
|
|
||||||
yield self.store.persist_event(
|
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||||
event,
|
event,
|
||||||
context=context,
|
context=context,
|
||||||
backfilled=False,
|
backfilled=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
target_user = UserID.from_string(event.state_key)
|
target_user = UserID.from_string(event.state_key)
|
||||||
d = self.notifier.on_new_room_event(
|
with PreserveLoggingContext():
|
||||||
event, extra_users=[target_user],
|
d = self.notifier.on_new_room_event(
|
||||||
)
|
event, event_stream_id, max_stream_id,
|
||||||
|
extra_users=[target_user],
|
||||||
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
def log_failure(f):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
|
@ -745,9 +917,12 @@ class FederationHandler(BaseHandler):
|
||||||
# This is a hack to fix some old rooms where the initial join event
|
# This is a hack to fix some old rooms where the initial join event
|
||||||
# didn't reference the create event in its auth events.
|
# didn't reference the create event in its auth events.
|
||||||
if event.type == EventTypes.Member and not event.auth_events:
|
if event.type == EventTypes.Member and not event.auth_events:
|
||||||
if len(event.prev_events) == 1:
|
if len(event.prev_events) == 1 and event.depth < 5:
|
||||||
c = yield self.store.get_event(event.prev_events[0][0])
|
c = yield self.store.get_event(
|
||||||
if c.type == EventTypes.Create:
|
event.prev_events[0][0],
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
if c and c.type == EventTypes.Create:
|
||||||
auth_events[(c.type, c.state_key)] = c
|
auth_events[(c.type, c.state_key)] = c
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -773,7 +948,7 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
yield self.store.persist_event(
|
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||||
event,
|
event,
|
||||||
context=context,
|
context=context,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
|
@ -781,7 +956,7 @@ class FederationHandler(BaseHandler):
|
||||||
current_state=current_state,
|
current_state=current_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(context)
|
defer.returnValue((context, event_stream_id, max_stream_id))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
|
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
|
||||||
|
@ -921,7 +1096,7 @@ class FederationHandler(BaseHandler):
|
||||||
if d in have_events and not have_events[d]
|
if d in have_events and not have_events[d]
|
||||||
],
|
],
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
if different_events:
|
if different_events:
|
||||||
local_view = dict(auth_events)
|
local_view = dict(auth_events)
|
||||||
|
@ -1166,3 +1341,52 @@ class FederationHandler(BaseHandler):
|
||||||
},
|
},
|
||||||
"missing": [e.event_id for e in missing_locals],
|
"missing": [e.event_id for e in missing_locals],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_auth_events(self, origin, auth_events):
|
||||||
|
auth_ids_to_deferred = {}
|
||||||
|
|
||||||
|
def process_auth_ev(ev):
|
||||||
|
auth_ids = [e_id for e_id, _ in ev.auth_events]
|
||||||
|
|
||||||
|
prev_ds = [
|
||||||
|
auth_ids_to_deferred[i]
|
||||||
|
for i in auth_ids
|
||||||
|
if i in auth_ids_to_deferred
|
||||||
|
]
|
||||||
|
|
||||||
|
d = defer.Deferred()
|
||||||
|
|
||||||
|
auth_ids_to_deferred[ev.event_id] = d
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def f(*_):
|
||||||
|
ev.internal_metadata.outlier = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
auth = {
|
||||||
|
(e.type, e.state_key): e for e in auth_events
|
||||||
|
if e.event_id in auth_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
yield self._handle_new_event(
|
||||||
|
origin, ev, auth_events=auth
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to handle auth event %s",
|
||||||
|
ev.event_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
d.callback(None)
|
||||||
|
|
||||||
|
if prev_ds:
|
||||||
|
dx = defer.DeferredList(prev_ds)
|
||||||
|
dx.addBoth(f)
|
||||||
|
else:
|
||||||
|
f()
|
||||||
|
|
||||||
|
for e in auth_events:
|
||||||
|
process_auth_ev(e)
|
||||||
|
|
||||||
|
yield defer.DeferredList(auth_ids_to_deferred.values())
|
||||||
|
|
|
@ -20,8 +20,9 @@ from synapse.api.errors import RoomError, SynapseError
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID, RoomStreamToken
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
@ -89,9 +90,19 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
if not pagin_config.from_token:
|
if not pagin_config.from_token:
|
||||||
pagin_config.from_token = (
|
pagin_config.from_token = (
|
||||||
yield self.hs.get_event_sources().get_current_token()
|
yield self.hs.get_event_sources().get_current_token(
|
||||||
|
direction='b'
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
room_token = RoomStreamToken.parse(pagin_config.from_token.room_key)
|
||||||
|
if room_token.topological is None:
|
||||||
|
raise SynapseError(400, "Invalid token")
|
||||||
|
|
||||||
|
yield self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||||
|
room_id, room_token.topological
|
||||||
|
)
|
||||||
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
events, next_key = yield data_source.get_pagination_rows(
|
events, next_key = yield data_source.get_pagination_rows(
|
||||||
|
@ -303,7 +314,7 @@ class MessageHandler(BaseHandler):
|
||||||
event.room_id
|
event.room_id
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
start_token = now_token.copy_and_replace("room_key", token[0])
|
start_token = now_token.copy_and_replace("room_key", token[0])
|
||||||
end_token = now_token.copy_and_replace("room_key", token[1])
|
end_token = now_token.copy_and_replace("room_key", token[1])
|
||||||
|
@ -328,7 +339,7 @@ class MessageHandler(BaseHandler):
|
||||||
yield defer.gatherResults(
|
yield defer.gatherResults(
|
||||||
[handle_room(e) for e in room_list],
|
[handle_room(e) for e in room_list],
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
ret = {
|
ret = {
|
||||||
"rooms": rooms_ret,
|
"rooms": rooms_ret,
|
||||||
|
|
|
@ -18,8 +18,8 @@ from twisted.internet import defer
|
||||||
from synapse.api.errors import SynapseError, AuthError
|
from synapse.api.errors import SynapseError, AuthError
|
||||||
from synapse.api.constants import PresenceState
|
from synapse.api.constants import PresenceState
|
||||||
|
|
||||||
from synapse.util.logutils import log_function
|
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
|
from synapse.util.logutils import log_function
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
|
@ -146,6 +146,10 @@ class PresenceHandler(BaseHandler):
|
||||||
self._user_cachemap = {}
|
self._user_cachemap = {}
|
||||||
self._user_cachemap_latest_serial = 0
|
self._user_cachemap_latest_serial = 0
|
||||||
|
|
||||||
|
# map room_ids to the latest presence serial for a member of that
|
||||||
|
# room
|
||||||
|
self._room_serials = {}
|
||||||
|
|
||||||
metrics.register_callback(
|
metrics.register_callback(
|
||||||
"userCachemap:size",
|
"userCachemap:size",
|
||||||
lambda: len(self._user_cachemap),
|
lambda: len(self._user_cachemap),
|
||||||
|
@ -278,15 +282,14 @@ class PresenceHandler(BaseHandler):
|
||||||
now_online = state["presence"] != PresenceState.OFFLINE
|
now_online = state["presence"] != PresenceState.OFFLINE
|
||||||
was_polling = target_user in self._user_cachemap
|
was_polling = target_user in self._user_cachemap
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
if now_online and not was_polling:
|
||||||
if now_online and not was_polling:
|
self.start_polling_presence(target_user, state=state)
|
||||||
self.start_polling_presence(target_user, state=state)
|
elif not now_online and was_polling:
|
||||||
elif not now_online and was_polling:
|
self.stop_polling_presence(target_user)
|
||||||
self.stop_polling_presence(target_user)
|
|
||||||
|
|
||||||
# TODO(paul): perform a presence push as part of start/stop poll so
|
# TODO(paul): perform a presence push as part of start/stop poll so
|
||||||
# we don't have to do this all the time
|
# we don't have to do this all the time
|
||||||
self.changed_presencelike_data(target_user, state)
|
self.changed_presencelike_data(target_user, state)
|
||||||
|
|
||||||
def bump_presence_active_time(self, user, now=None):
|
def bump_presence_active_time(self, user, now=None):
|
||||||
if now is None:
|
if now is None:
|
||||||
|
@ -298,13 +301,34 @@ class PresenceHandler(BaseHandler):
|
||||||
|
|
||||||
self.changed_presencelike_data(user, {"last_active": now})
|
self.changed_presencelike_data(user, {"last_active": now})
|
||||||
|
|
||||||
|
def get_joined_rooms_for_user(self, user):
|
||||||
|
"""Get the list of rooms a user is joined to.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserID): The user.
|
||||||
|
Returns:
|
||||||
|
A Deferred of a list of room id strings.
|
||||||
|
"""
|
||||||
|
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||||
|
return rm_handler.get_joined_rooms_for_user(user)
|
||||||
|
|
||||||
|
def get_joined_users_for_room_id(self, room_id):
|
||||||
|
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||||
|
return rm_handler.get_room_members(room_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def changed_presencelike_data(self, user, state):
|
def changed_presencelike_data(self, user, state):
|
||||||
statuscache = self._get_or_make_usercache(user)
|
"""Updates the presence state of a local user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserID): The user being updated.
|
||||||
|
state(dict): The new presence state for the user.
|
||||||
|
Returns:
|
||||||
|
A Deferred
|
||||||
|
"""
|
||||||
self._user_cachemap_latest_serial += 1
|
self._user_cachemap_latest_serial += 1
|
||||||
statuscache.update(state, serial=self._user_cachemap_latest_serial)
|
statuscache = yield self.update_presence_cache(user, state)
|
||||||
|
yield self.push_presence(user, statuscache=statuscache)
|
||||||
return self.push_presence(user, statuscache=statuscache)
|
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def started_user_eventstream(self, user):
|
def started_user_eventstream(self, user):
|
||||||
|
@ -318,14 +342,21 @@ class PresenceHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_joined_room(self, user, room_id):
|
def user_joined_room(self, user, room_id):
|
||||||
if self.hs.is_mine(user):
|
"""Called via the distributor whenever a user joins a room.
|
||||||
statuscache = self._get_or_make_usercache(user)
|
Notifies the new member of the presence of the current members.
|
||||||
|
Notifies the current members of the room of the new member's presence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserID): The user who joined the room.
|
||||||
|
room_id(str): The room id the user joined.
|
||||||
|
"""
|
||||||
|
if self.hs.is_mine(user):
|
||||||
# No actual update but we need to bump the serial anyway for the
|
# No actual update but we need to bump the serial anyway for the
|
||||||
# event source
|
# event source
|
||||||
self._user_cachemap_latest_serial += 1
|
self._user_cachemap_latest_serial += 1
|
||||||
statuscache.update({}, serial=self._user_cachemap_latest_serial)
|
statuscache = yield self.update_presence_cache(
|
||||||
|
user, room_ids=[room_id]
|
||||||
|
)
|
||||||
self.push_update_to_local_and_remote(
|
self.push_update_to_local_and_remote(
|
||||||
observed_user=user,
|
observed_user=user,
|
||||||
room_ids=[room_id],
|
room_ids=[room_id],
|
||||||
|
@ -333,18 +364,22 @@ class PresenceHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
# We also want to tell them about current presence of people.
|
# We also want to tell them about current presence of people.
|
||||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
curr_users = yield self.get_joined_users_for_room_id(room_id)
|
||||||
curr_users = yield rm_handler.get_room_members(room_id)
|
|
||||||
|
|
||||||
for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
|
for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
|
||||||
|
statuscache = yield self.update_presence_cache(
|
||||||
|
local_user, room_ids=[room_id], add_to_cache=False
|
||||||
|
)
|
||||||
|
|
||||||
self.push_update_to_local_and_remote(
|
self.push_update_to_local_and_remote(
|
||||||
observed_user=local_user,
|
observed_user=local_user,
|
||||||
users_to_push=[user],
|
users_to_push=[user],
|
||||||
statuscache=self._get_or_offline_usercache(local_user),
|
statuscache=statuscache,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_invite(self, observer_user, observed_user):
|
def send_invite(self, observer_user, observed_user):
|
||||||
|
"""Request the presence of a local or remote user for a local user"""
|
||||||
if not self.hs.is_mine(observer_user):
|
if not self.hs.is_mine(observer_user):
|
||||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||||
|
|
||||||
|
@ -379,6 +414,15 @@ class PresenceHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def invite_presence(self, observed_user, observer_user):
|
def invite_presence(self, observed_user, observer_user):
|
||||||
|
"""Handles a m.presence_invite EDU. A remote or local user has
|
||||||
|
requested presence updates for a local user. If the invite is accepted
|
||||||
|
then allow the local or remote user to see the presence of the local
|
||||||
|
user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observed_user(UserID): The local user whose presence is requested.
|
||||||
|
observer_user(UserID): The remote or local user requesting presence.
|
||||||
|
"""
|
||||||
accept = yield self._should_accept_invite(observed_user, observer_user)
|
accept = yield self._should_accept_invite(observed_user, observer_user)
|
||||||
|
|
||||||
if accept:
|
if accept:
|
||||||
|
@ -405,16 +449,34 @@ class PresenceHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def accept_presence(self, observed_user, observer_user):
|
def accept_presence(self, observed_user, observer_user):
|
||||||
|
"""Handles a m.presence_accept EDU. Mark a presence invite from a
|
||||||
|
local or remote user as accepted in a local user's presence list.
|
||||||
|
Starts polling for presence updates from the local or remote user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observed_user(UserID): The user to update in the presence list.
|
||||||
|
observer_user(UserID): The owner of the presence list to update.
|
||||||
|
"""
|
||||||
yield self.store.set_presence_list_accepted(
|
yield self.store.set_presence_list_accepted(
|
||||||
observer_user.localpart, observed_user.to_string()
|
observer_user.localpart, observed_user.to_string()
|
||||||
)
|
)
|
||||||
with PreserveLoggingContext():
|
|
||||||
self.start_polling_presence(
|
self.start_polling_presence(
|
||||||
observer_user, target_user=observed_user
|
observer_user, target_user=observed_user
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def deny_presence(self, observed_user, observer_user):
|
def deny_presence(self, observed_user, observer_user):
|
||||||
|
"""Handle a m.presence_deny EDU. Removes a local or remote user from a
|
||||||
|
local user's presence list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observed_user(UserID): The local or remote user to remove from the
|
||||||
|
list.
|
||||||
|
observer_user(UserID): The local owner of the presence list.
|
||||||
|
Returns:
|
||||||
|
A Deferred.
|
||||||
|
"""
|
||||||
yield self.store.del_presence_list(
|
yield self.store.del_presence_list(
|
||||||
observer_user.localpart, observed_user.to_string()
|
observer_user.localpart, observed_user.to_string()
|
||||||
)
|
)
|
||||||
|
@ -423,6 +485,16 @@ class PresenceHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def drop(self, observed_user, observer_user):
|
def drop(self, observed_user, observer_user):
|
||||||
|
"""Remove a local or remote user from a local user's presence list and
|
||||||
|
unsubscribe the local user from updates that user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observed_user(UserId): The local or remote user to remove from the
|
||||||
|
list.
|
||||||
|
observer_user(UserId): The local owner of the presence list.
|
||||||
|
Returns:
|
||||||
|
A Deferred.
|
||||||
|
"""
|
||||||
if not self.hs.is_mine(observer_user):
|
if not self.hs.is_mine(observer_user):
|
||||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||||
|
|
||||||
|
@ -430,34 +502,66 @@ class PresenceHandler(BaseHandler):
|
||||||
observer_user.localpart, observed_user.to_string()
|
observer_user.localpart, observed_user.to_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
self.stop_polling_presence(
|
||||||
self.stop_polling_presence(
|
observer_user, target_user=observed_user
|
||||||
observer_user, target_user=observed_user
|
)
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_presence_list(self, observer_user, accepted=None):
|
def get_presence_list(self, observer_user, accepted=None):
|
||||||
|
"""Get the presence list for a local user. The retured list includes
|
||||||
|
the current presence state for each user listed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observer_user(UserID): The local user whose presence list to fetch.
|
||||||
|
accepted(bool or None): If not none then only include users who
|
||||||
|
have or have not accepted the presence invite request.
|
||||||
|
Returns:
|
||||||
|
A Deferred list of presence state events.
|
||||||
|
"""
|
||||||
if not self.hs.is_mine(observer_user):
|
if not self.hs.is_mine(observer_user):
|
||||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||||
|
|
||||||
presence = yield self.store.get_presence_list(
|
presence_list = yield self.store.get_presence_list(
|
||||||
observer_user.localpart, accepted=accepted
|
observer_user.localpart, accepted=accepted
|
||||||
)
|
)
|
||||||
|
|
||||||
for p in presence:
|
results = []
|
||||||
observed_user = UserID.from_string(p.pop("observed_user_id"))
|
for row in presence_list:
|
||||||
p["observed_user"] = observed_user
|
observed_user = UserID.from_string(row["observed_user_id"])
|
||||||
p.update(self._get_or_offline_usercache(observed_user).get_state())
|
result = {
|
||||||
if "last_active" in p:
|
"observed_user": observed_user, "accepted": row["accepted"]
|
||||||
p["last_active_ago"] = int(
|
}
|
||||||
self.clock.time_msec() - p.pop("last_active")
|
result.update(
|
||||||
|
self._get_or_offline_usercache(observed_user).get_state()
|
||||||
|
)
|
||||||
|
if "last_active" in result:
|
||||||
|
result["last_active_ago"] = int(
|
||||||
|
self.clock.time_msec() - result.pop("last_active")
|
||||||
)
|
)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
defer.returnValue(presence)
|
defer.returnValue(results)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def start_polling_presence(self, user, target_user=None, state=None):
|
def start_polling_presence(self, user, target_user=None, state=None):
|
||||||
|
"""Subscribe a local user to presence updates from a local or remote
|
||||||
|
user. If no target_user is supplied then subscribe to all users stored
|
||||||
|
in the presence list for the local user.
|
||||||
|
|
||||||
|
Additonally this pushes the current presence state of this user to all
|
||||||
|
target_users. That state can be provided directly or will be read from
|
||||||
|
the stored state for the local user.
|
||||||
|
|
||||||
|
Also this attempts to notify the local user of the current state of
|
||||||
|
any local target users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserID): The local user that whishes for presence updates.
|
||||||
|
target_user(UserID): The local or remote user whose updates are
|
||||||
|
wanted.
|
||||||
|
state(dict): Optional presence state for the local user.
|
||||||
|
"""
|
||||||
logger.debug("Start polling for presence from %s", user)
|
logger.debug("Start polling for presence from %s", user)
|
||||||
|
|
||||||
if target_user:
|
if target_user:
|
||||||
|
@ -473,8 +577,7 @@ class PresenceHandler(BaseHandler):
|
||||||
|
|
||||||
# Also include people in all my rooms
|
# Also include people in all my rooms
|
||||||
|
|
||||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
room_ids = yield self.get_joined_rooms_for_user(user)
|
||||||
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
|
|
||||||
|
|
||||||
if state is None:
|
if state is None:
|
||||||
state = yield self.store.get_presence_state(user.localpart)
|
state = yield self.store.get_presence_state(user.localpart)
|
||||||
|
@ -498,9 +601,7 @@ class PresenceHandler(BaseHandler):
|
||||||
# We want to tell the person that just came online
|
# We want to tell the person that just came online
|
||||||
# presence state of people they are interested in?
|
# presence state of people they are interested in?
|
||||||
self.push_update_to_clients(
|
self.push_update_to_clients(
|
||||||
observed_user=target_user,
|
|
||||||
users_to_push=[user],
|
users_to_push=[user],
|
||||||
statuscache=self._get_or_offline_usercache(target_user),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
deferreds = []
|
deferreds = []
|
||||||
|
@ -517,6 +618,12 @@ class PresenceHandler(BaseHandler):
|
||||||
yield defer.DeferredList(deferreds, consumeErrors=True)
|
yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||||
|
|
||||||
def _start_polling_local(self, user, target_user):
|
def _start_polling_local(self, user, target_user):
|
||||||
|
"""Subscribe a local user to presence updates for a local user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserId): The local user that wishes for updates.
|
||||||
|
target_user(UserId): The local users whose updates are wanted.
|
||||||
|
"""
|
||||||
target_localpart = target_user.localpart
|
target_localpart = target_user.localpart
|
||||||
|
|
||||||
if target_localpart not in self._local_pushmap:
|
if target_localpart not in self._local_pushmap:
|
||||||
|
@ -525,6 +632,17 @@ class PresenceHandler(BaseHandler):
|
||||||
self._local_pushmap[target_localpart].add(user)
|
self._local_pushmap[target_localpart].add(user)
|
||||||
|
|
||||||
def _start_polling_remote(self, user, domain, remoteusers):
|
def _start_polling_remote(self, user, domain, remoteusers):
|
||||||
|
"""Subscribe a local user to presence updates for remote users on a
|
||||||
|
given remote domain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserID): The local user that wishes for updates.
|
||||||
|
domain(str): The remote server the local user wants updates from.
|
||||||
|
remoteusers(UserID): The remote users that local user wants to be
|
||||||
|
told about.
|
||||||
|
Returns:
|
||||||
|
A Deferred.
|
||||||
|
"""
|
||||||
to_poll = set()
|
to_poll = set()
|
||||||
|
|
||||||
for u in remoteusers:
|
for u in remoteusers:
|
||||||
|
@ -545,6 +663,17 @@ class PresenceHandler(BaseHandler):
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def stop_polling_presence(self, user, target_user=None):
|
def stop_polling_presence(self, user, target_user=None):
|
||||||
|
"""Unsubscribe a local user from presence updates from a local or
|
||||||
|
remote user. If no target user is supplied then unsubscribe the user
|
||||||
|
from all presence updates that the user had subscribed to.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserID): The local user that no longer wishes for updates.
|
||||||
|
target_user(UserID or None): The user whose updates are no longer
|
||||||
|
wanted.
|
||||||
|
Returns:
|
||||||
|
A Deferred.
|
||||||
|
"""
|
||||||
logger.debug("Stop polling for presence from %s", user)
|
logger.debug("Stop polling for presence from %s", user)
|
||||||
|
|
||||||
if not target_user or self.hs.is_mine(target_user):
|
if not target_user or self.hs.is_mine(target_user):
|
||||||
|
@ -573,6 +702,13 @@ class PresenceHandler(BaseHandler):
|
||||||
return defer.DeferredList(deferreds, consumeErrors=True)
|
return defer.DeferredList(deferreds, consumeErrors=True)
|
||||||
|
|
||||||
def _stop_polling_local(self, user, target_user):
|
def _stop_polling_local(self, user, target_user):
|
||||||
|
"""Unsubscribe a local user from presence updates from a local user on
|
||||||
|
this server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserID): The local user that no longer wishes for updates.
|
||||||
|
target_user(UserID): The user whose updates are no longer wanted.
|
||||||
|
"""
|
||||||
for localpart in self._local_pushmap.keys():
|
for localpart in self._local_pushmap.keys():
|
||||||
if target_user and localpart != target_user.localpart:
|
if target_user and localpart != target_user.localpart:
|
||||||
continue
|
continue
|
||||||
|
@ -585,6 +721,17 @@ class PresenceHandler(BaseHandler):
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _stop_polling_remote(self, user, domain, remoteusers):
|
def _stop_polling_remote(self, user, domain, remoteusers):
|
||||||
|
"""Unsubscribe a local user from presence updates from remote users on
|
||||||
|
a given domain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserID): The local user that no longer wishes for updates.
|
||||||
|
domain(str): The remote server to unsubscribe from.
|
||||||
|
remoteusers([UserID]): The users on that remote server that the
|
||||||
|
local user no longer wishes to be updated about.
|
||||||
|
Returns:
|
||||||
|
A Deferred.
|
||||||
|
"""
|
||||||
to_unpoll = set()
|
to_unpoll = set()
|
||||||
|
|
||||||
for u in remoteusers:
|
for u in remoteusers:
|
||||||
|
@ -606,6 +753,19 @@ class PresenceHandler(BaseHandler):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def push_presence(self, user, statuscache):
|
def push_presence(self, user, statuscache):
|
||||||
|
"""
|
||||||
|
Notify local and remote users of a change in presence of a local user.
|
||||||
|
Pushes the update to local clients and remote domains that are directly
|
||||||
|
subscribed to the presence of the local user.
|
||||||
|
Also pushes that update to any local user or remote domain that shares
|
||||||
|
a room with the local user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserID): The local user whose presence was updated.
|
||||||
|
statuscache(UserPresenceCache): Cache of the user's presence state
|
||||||
|
Returns:
|
||||||
|
A Deferred.
|
||||||
|
"""
|
||||||
assert(self.hs.is_mine(user))
|
assert(self.hs.is_mine(user))
|
||||||
|
|
||||||
logger.debug("Pushing presence update from %s", user)
|
logger.debug("Pushing presence update from %s", user)
|
||||||
|
@ -617,8 +777,7 @@ class PresenceHandler(BaseHandler):
|
||||||
# and also user is informed of server-forced pushes
|
# and also user is informed of server-forced pushes
|
||||||
localusers.add(user)
|
localusers.add(user)
|
||||||
|
|
||||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
room_ids = yield self.get_joined_rooms_for_user(user)
|
||||||
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
|
|
||||||
|
|
||||||
if not localusers and not room_ids:
|
if not localusers and not room_ids:
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
@ -632,45 +791,24 @@ class PresenceHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
yield self.distributor.fire("user_presence_changed", user, statuscache)
|
yield self.distributor.fire("user_presence_changed", user, statuscache)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _push_presence_remote(self, user, destination, state=None):
|
|
||||||
if state is None:
|
|
||||||
state = yield self.store.get_presence_state(user.localpart)
|
|
||||||
del state["mtime"]
|
|
||||||
state["presence"] = state.pop("state")
|
|
||||||
|
|
||||||
if user in self._user_cachemap:
|
|
||||||
state["last_active"] = (
|
|
||||||
self._user_cachemap[user].get_state()["last_active"]
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self.distributor.fire(
|
|
||||||
"collect_presencelike_data", user, state
|
|
||||||
)
|
|
||||||
|
|
||||||
if "last_active" in state:
|
|
||||||
state = dict(state)
|
|
||||||
state["last_active_ago"] = int(
|
|
||||||
self.clock.time_msec() - state.pop("last_active")
|
|
||||||
)
|
|
||||||
|
|
||||||
user_state = {
|
|
||||||
"user_id": user.to_string(),
|
|
||||||
}
|
|
||||||
user_state.update(**state)
|
|
||||||
|
|
||||||
yield self.federation.send_edu(
|
|
||||||
destination=destination,
|
|
||||||
edu_type="m.presence",
|
|
||||||
content={
|
|
||||||
"push": [
|
|
||||||
user_state,
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def incoming_presence(self, origin, content):
|
def incoming_presence(self, origin, content):
|
||||||
|
"""Handle an incoming m.presence EDU.
|
||||||
|
For each presence update in the "push" list update our local cache and
|
||||||
|
notify the appropriate local clients. Only clients that share a room
|
||||||
|
or are directly subscribed to the presence for a user should be
|
||||||
|
notified of the update.
|
||||||
|
For each subscription request in the "poll" list start pushing presence
|
||||||
|
updates to the remote server.
|
||||||
|
For unsubscribe request in the "unpoll" list stop pushing presence
|
||||||
|
updates to the remote server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
orgin(str): The source of this m.presence EDU.
|
||||||
|
content(dict): The content of this m.presence EDU.
|
||||||
|
Returns:
|
||||||
|
A Deferred.
|
||||||
|
"""
|
||||||
deferreds = []
|
deferreds = []
|
||||||
|
|
||||||
for push in content.get("push", []):
|
for push in content.get("push", []):
|
||||||
|
@ -684,8 +822,7 @@ class PresenceHandler(BaseHandler):
|
||||||
" | %d interested local observers %r", len(observers), observers
|
" | %d interested local observers %r", len(observers), observers
|
||||||
)
|
)
|
||||||
|
|
||||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
room_ids = yield self.get_joined_rooms_for_user(user)
|
||||||
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
|
|
||||||
if room_ids:
|
if room_ids:
|
||||||
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
|
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
|
||||||
|
|
||||||
|
@ -704,20 +841,15 @@ class PresenceHandler(BaseHandler):
|
||||||
self.clock.time_msec() - state.pop("last_active_ago")
|
self.clock.time_msec() - state.pop("last_active_ago")
|
||||||
)
|
)
|
||||||
|
|
||||||
statuscache = self._get_or_make_usercache(user)
|
|
||||||
|
|
||||||
self._user_cachemap_latest_serial += 1
|
self._user_cachemap_latest_serial += 1
|
||||||
statuscache.update(state, serial=self._user_cachemap_latest_serial)
|
yield self.update_presence_cache(user, state, room_ids=room_ids)
|
||||||
|
|
||||||
if not observers and not room_ids:
|
if not observers and not room_ids:
|
||||||
logger.debug(" | no interested observers or room IDs")
|
logger.debug(" | no interested observers or room IDs")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.push_update_to_clients(
|
self.push_update_to_clients(
|
||||||
observed_user=user,
|
users_to_push=observers, room_ids=room_ids
|
||||||
users_to_push=observers,
|
|
||||||
room_ids=room_ids,
|
|
||||||
statuscache=statuscache,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
@ -766,13 +898,58 @@ class PresenceHandler(BaseHandler):
|
||||||
if not self._remote_sendmap[user]:
|
if not self._remote_sendmap[user]:
|
||||||
del self._remote_sendmap[user]
|
del self._remote_sendmap[user]
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||||
yield defer.DeferredList(deferreds, consumeErrors=True)
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_presence_cache(self, user, state={}, room_ids=None,
|
||||||
|
add_to_cache=True):
|
||||||
|
"""Update the presence cache for a user with a new state and bump the
|
||||||
|
serial to the latest value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserID): The user being updated
|
||||||
|
state(dict): The presence state being updated
|
||||||
|
room_ids(None or list of str): A list of room_ids to update. If
|
||||||
|
room_ids is None then fetch the list of room_ids the user is
|
||||||
|
joined to.
|
||||||
|
add_to_cache: Whether to add an entry to the presence cache if the
|
||||||
|
user isn't already in the cache.
|
||||||
|
Returns:
|
||||||
|
A Deferred UserPresenceCache for the user being updated.
|
||||||
|
"""
|
||||||
|
if room_ids is None:
|
||||||
|
room_ids = yield self.get_joined_rooms_for_user(user)
|
||||||
|
|
||||||
|
for room_id in room_ids:
|
||||||
|
self._room_serials[room_id] = self._user_cachemap_latest_serial
|
||||||
|
if add_to_cache:
|
||||||
|
statuscache = self._get_or_make_usercache(user)
|
||||||
|
else:
|
||||||
|
statuscache = self._get_or_offline_usercache(user)
|
||||||
|
statuscache.update(state, serial=self._user_cachemap_latest_serial)
|
||||||
|
defer.returnValue(statuscache)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def push_update_to_local_and_remote(self, observed_user, statuscache,
|
def push_update_to_local_and_remote(self, observed_user, statuscache,
|
||||||
users_to_push=[], room_ids=[],
|
users_to_push=[], room_ids=[],
|
||||||
remote_domains=[]):
|
remote_domains=[]):
|
||||||
|
"""Notify local clients and remote servers of a change in the presence
|
||||||
|
of a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observed_user(UserID): The user to push the presence state for.
|
||||||
|
statuscache(UserPresenceCache): The cache for the presence state to
|
||||||
|
push.
|
||||||
|
users_to_push([UserID]): A list of local and remote users to
|
||||||
|
notify.
|
||||||
|
room_ids([str]): Notify the local and remote occupants of these
|
||||||
|
rooms.
|
||||||
|
remote_domains([str]): A list of remote servers to notify in
|
||||||
|
addition to those implied by the users_to_push and the
|
||||||
|
room_ids.
|
||||||
|
Returns:
|
||||||
|
A Deferred.
|
||||||
|
"""
|
||||||
|
|
||||||
localusers, remoteusers = partitionbool(
|
localusers, remoteusers = partitionbool(
|
||||||
users_to_push,
|
users_to_push,
|
||||||
|
@ -782,10 +959,7 @@ class PresenceHandler(BaseHandler):
|
||||||
localusers = set(localusers)
|
localusers = set(localusers)
|
||||||
|
|
||||||
self.push_update_to_clients(
|
self.push_update_to_clients(
|
||||||
observed_user=observed_user,
|
users_to_push=localusers, room_ids=room_ids
|
||||||
users_to_push=localusers,
|
|
||||||
room_ids=room_ids,
|
|
||||||
statuscache=statuscache,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
remote_domains = set(remote_domains)
|
remote_domains = set(remote_domains)
|
||||||
|
@ -810,11 +984,65 @@ class PresenceHandler(BaseHandler):
|
||||||
|
|
||||||
defer.returnValue((localusers, remote_domains))
|
defer.returnValue((localusers, remote_domains))
|
||||||
|
|
||||||
def push_update_to_clients(self, observed_user, users_to_push=[],
|
def push_update_to_clients(self, users_to_push=[], room_ids=[]):
|
||||||
room_ids=[], statuscache=None):
|
"""Notify clients of a new presence event.
|
||||||
self.notifier.on_new_user_event(
|
|
||||||
users_to_push,
|
Args:
|
||||||
room_ids,
|
users_to_push([UserID]): List of users to notify.
|
||||||
|
room_ids([str]): List of room_ids to notify.
|
||||||
|
"""
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
self.notifier.on_new_user_event(
|
||||||
|
"presence_key",
|
||||||
|
self._user_cachemap_latest_serial,
|
||||||
|
users_to_push,
|
||||||
|
room_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _push_presence_remote(self, user, destination, state=None):
|
||||||
|
"""Push a user's presence to a remote server. If a presence state event
|
||||||
|
that event is sent. Otherwise a new state event is constructed from the
|
||||||
|
stored presence state.
|
||||||
|
The last_active is replaced with last_active_ago in case the wallclock
|
||||||
|
time on the remote server is different to the time on this server.
|
||||||
|
Sends an EDU to the remote server with the current presence state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user(UserID): The user to push the presence state for.
|
||||||
|
destination(str): The remote server to send state to.
|
||||||
|
state(dict): The state to push, or None to use the current stored
|
||||||
|
state.
|
||||||
|
Returns:
|
||||||
|
A Deferred.
|
||||||
|
"""
|
||||||
|
if state is None:
|
||||||
|
state = yield self.store.get_presence_state(user.localpart)
|
||||||
|
del state["mtime"]
|
||||||
|
state["presence"] = state.pop("state")
|
||||||
|
|
||||||
|
if user in self._user_cachemap:
|
||||||
|
state["last_active"] = (
|
||||||
|
self._user_cachemap[user].get_state()["last_active"]
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.distributor.fire(
|
||||||
|
"collect_presencelike_data", user, state
|
||||||
|
)
|
||||||
|
|
||||||
|
if "last_active" in state:
|
||||||
|
state = dict(state)
|
||||||
|
state["last_active_ago"] = int(
|
||||||
|
self.clock.time_msec() - state.pop("last_active")
|
||||||
|
)
|
||||||
|
|
||||||
|
user_state = {"user_id": user.to_string(), }
|
||||||
|
user_state.update(state)
|
||||||
|
|
||||||
|
yield self.federation.send_edu(
|
||||||
|
destination=destination,
|
||||||
|
edu_type="m.presence",
|
||||||
|
content={"push": [user_state, ], }
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -823,39 +1051,11 @@ class PresenceEventSource(object):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def is_visible(self, observer_user, observed_user):
|
|
||||||
if observer_user == observed_user:
|
|
||||||
defer.returnValue(True)
|
|
||||||
|
|
||||||
presence = self.hs.get_handlers().presence_handler
|
|
||||||
|
|
||||||
if (yield presence.store.user_rooms_intersect(
|
|
||||||
[u.to_string() for u in observer_user, observed_user])):
|
|
||||||
defer.returnValue(True)
|
|
||||||
|
|
||||||
if self.hs.is_mine(observed_user):
|
|
||||||
pushmap = presence._local_pushmap
|
|
||||||
|
|
||||||
defer.returnValue(
|
|
||||||
observed_user.localpart in pushmap and
|
|
||||||
observer_user in pushmap[observed_user.localpart]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
recvmap = presence._remote_recvmap
|
|
||||||
|
|
||||||
defer.returnValue(
|
|
||||||
observed_user in recvmap and
|
|
||||||
observer_user in recvmap[observed_user]
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_new_events_for_user(self, user, from_key, limit):
|
def get_new_events_for_user(self, user, from_key, limit):
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
|
|
||||||
observer_user = user
|
|
||||||
|
|
||||||
presence = self.hs.get_handlers().presence_handler
|
presence = self.hs.get_handlers().presence_handler
|
||||||
cachemap = presence._user_cachemap
|
cachemap = presence._user_cachemap
|
||||||
|
|
||||||
|
@ -864,17 +1064,27 @@ class PresenceEventSource(object):
|
||||||
clock = self.clock
|
clock = self.clock
|
||||||
latest_serial = 0
|
latest_serial = 0
|
||||||
|
|
||||||
|
user_ids_to_check = {user}
|
||||||
|
presence_list = yield presence.store.get_presence_list(
|
||||||
|
user.localpart, accepted=True
|
||||||
|
)
|
||||||
|
if presence_list is not None:
|
||||||
|
user_ids_to_check |= set(
|
||||||
|
UserID.from_string(p["observed_user_id"]) for p in presence_list
|
||||||
|
)
|
||||||
|
room_ids = yield presence.get_joined_rooms_for_user(user)
|
||||||
|
for room_id in set(room_ids) & set(presence._room_serials):
|
||||||
|
if presence._room_serials[room_id] > from_key:
|
||||||
|
joined = yield presence.get_joined_users_for_room_id(room_id)
|
||||||
|
user_ids_to_check |= set(joined)
|
||||||
|
|
||||||
updates = []
|
updates = []
|
||||||
# TODO(paul): use a DeferredList ? How to limit concurrency.
|
for observed_user in user_ids_to_check & set(cachemap):
|
||||||
for observed_user in cachemap.keys():
|
|
||||||
cached = cachemap[observed_user]
|
cached = cachemap[observed_user]
|
||||||
|
|
||||||
if cached.serial <= from_key or cached.serial > max_serial:
|
if cached.serial <= from_key or cached.serial > max_serial:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not (yield self.is_visible(observer_user, observed_user)):
|
|
||||||
continue
|
|
||||||
|
|
||||||
latest_serial = max(cached.serial, latest_serial)
|
latest_serial = max(cached.serial, latest_serial)
|
||||||
updates.append(cached.make_event(user=observed_user, clock=clock))
|
updates.append(cached.make_event(user=observed_user, clock=clock))
|
||||||
|
|
||||||
|
@ -911,8 +1121,6 @@ class PresenceEventSource(object):
|
||||||
def get_pagination_rows(self, user, pagination_config, key):
|
def get_pagination_rows(self, user, pagination_config, key):
|
||||||
# TODO (erikj): Does this make sense? Ordering?
|
# TODO (erikj): Does this make sense? Ordering?
|
||||||
|
|
||||||
observer_user = user
|
|
||||||
|
|
||||||
from_key = int(pagination_config.from_key)
|
from_key = int(pagination_config.from_key)
|
||||||
|
|
||||||
if pagination_config.to_key:
|
if pagination_config.to_key:
|
||||||
|
@ -923,14 +1131,26 @@ class PresenceEventSource(object):
|
||||||
presence = self.hs.get_handlers().presence_handler
|
presence = self.hs.get_handlers().presence_handler
|
||||||
cachemap = presence._user_cachemap
|
cachemap = presence._user_cachemap
|
||||||
|
|
||||||
|
user_ids_to_check = {user}
|
||||||
|
presence_list = yield presence.store.get_presence_list(
|
||||||
|
user.localpart, accepted=True
|
||||||
|
)
|
||||||
|
if presence_list is not None:
|
||||||
|
user_ids_to_check |= set(
|
||||||
|
UserID.from_string(p["observed_user_id"]) for p in presence_list
|
||||||
|
)
|
||||||
|
room_ids = yield presence.get_joined_rooms_for_user(user)
|
||||||
|
for room_id in set(room_ids) & set(presence._room_serials):
|
||||||
|
if presence._room_serials[room_id] >= from_key:
|
||||||
|
joined = yield presence.get_joined_users_for_room_id(room_id)
|
||||||
|
user_ids_to_check |= set(joined)
|
||||||
|
|
||||||
updates = []
|
updates = []
|
||||||
# TODO(paul): use a DeferredList ? How to limit concurrency.
|
for observed_user in user_ids_to_check & set(cachemap):
|
||||||
for observed_user in cachemap.keys():
|
|
||||||
if not (to_key < cachemap[observed_user].serial <= from_key):
|
if not (to_key < cachemap[observed_user].serial <= from_key):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (yield self.is_visible(observer_user, observed_user)):
|
updates.append((observed_user, cachemap[observed_user]))
|
||||||
updates.append((observed_user, cachemap[observed_user]))
|
|
||||||
|
|
||||||
# TODO(paul): limit
|
# TODO(paul): limit
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,8 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
from synapse.util import unwrapFirstError
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
@ -88,6 +88,9 @@ class ProfileHandler(BaseHandler):
|
||||||
if target_user != auth_user:
|
if target_user != auth_user:
|
||||||
raise AuthError(400, "Cannot set another user's displayname")
|
raise AuthError(400, "Cannot set another user's displayname")
|
||||||
|
|
||||||
|
if new_displayname == '':
|
||||||
|
new_displayname = None
|
||||||
|
|
||||||
yield self.store.set_profile_displayname(
|
yield self.store.set_profile_displayname(
|
||||||
target_user.localpart, new_displayname
|
target_user.localpart, new_displayname
|
||||||
)
|
)
|
||||||
|
@ -154,14 +157,13 @@ class ProfileHandler(BaseHandler):
|
||||||
if not self.hs.is_mine(user):
|
if not self.hs.is_mine(user):
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
(displayname, avatar_url) = yield defer.gatherResults(
|
||||||
(displayname, avatar_url) = yield defer.gatherResults(
|
[
|
||||||
[
|
self.store.get_profile_displayname(user.localpart),
|
||||||
self.store.get_profile_displayname(user.localpart),
|
self.store.get_profile_avatar_url(user.localpart),
|
||||||
self.store.get_profile_avatar_url(user.localpart),
|
],
|
||||||
],
|
consumeErrors=True
|
||||||
consumeErrors=True
|
).addErrback(unwrapFirstError)
|
||||||
)
|
|
||||||
|
|
||||||
state["displayname"] = displayname
|
state["displayname"] = displayname
|
||||||
state["avatar_url"] = avatar_url
|
state["avatar_url"] = avatar_url
|
||||||
|
|
|
@ -21,11 +21,12 @@ from ._base import BaseHandler
|
||||||
from synapse.types import UserID, RoomAlias, RoomID
|
from synapse.types import UserID, RoomAlias, RoomID
|
||||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||||
from synapse.api.errors import StoreError, SynapseError
|
from synapse.api.errors import StoreError, SynapseError
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils, unwrapFirstError
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import string
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -50,6 +51,10 @@ class RoomCreationHandler(BaseHandler):
|
||||||
self.ratelimit(user_id)
|
self.ratelimit(user_id)
|
||||||
|
|
||||||
if "room_alias_name" in config:
|
if "room_alias_name" in config:
|
||||||
|
for wchar in string.whitespace:
|
||||||
|
if wchar in config["room_alias_name"]:
|
||||||
|
raise SynapseError(400, "Invalid characters in room alias")
|
||||||
|
|
||||||
room_alias = RoomAlias.create(
|
room_alias = RoomAlias.create(
|
||||||
config["room_alias_name"],
|
config["room_alias_name"],
|
||||||
self.hs.hostname,
|
self.hs.hostname,
|
||||||
|
@ -535,7 +540,7 @@ class RoomListHandler(BaseHandler):
|
||||||
for room in chunk
|
for room in chunk
|
||||||
],
|
],
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
for i, room in enumerate(chunk):
|
for i, room in enumerate(chunk):
|
||||||
room["num_joined_members"] = len(results[i])
|
room["num_joined_members"] = len(results[i])
|
||||||
|
@ -575,8 +580,8 @@ class RoomEventSource(object):
|
||||||
|
|
||||||
defer.returnValue((events, end_key))
|
defer.returnValue((events, end_key))
|
||||||
|
|
||||||
def get_current_key(self):
|
def get_current_key(self, direction='f'):
|
||||||
return self.store.get_room_events_max_id()
|
return self.store.get_room_events_max_id(direction)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_pagination_rows(self, user, config, key):
|
def get_pagination_rows(self, user, config, key):
|
||||||
|
|
|
@ -92,7 +92,7 @@ class SyncHandler(BaseHandler):
|
||||||
result = yield self.current_sync_for_user(sync_config, since_token)
|
result = yield self.current_sync_for_user(sync_config, since_token)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
else:
|
else:
|
||||||
def current_sync_callback():
|
def current_sync_callback(before_token, after_token):
|
||||||
return self.current_sync_for_user(sync_config, since_token)
|
return self.current_sync_for_user(sync_config, since_token)
|
||||||
|
|
||||||
rm_handler = self.hs.get_handlers().room_member_handler
|
rm_handler = self.hs.get_handlers().room_member_handler
|
||||||
|
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, AuthError
|
from synapse.api.errors import SynapseError, AuthError
|
||||||
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -216,7 +217,10 @@ class TypingNotificationHandler(BaseHandler):
|
||||||
self._latest_room_serial += 1
|
self._latest_room_serial += 1
|
||||||
self._room_serials[room_id] = self._latest_room_serial
|
self._room_serials[room_id] = self._latest_room_serial
|
||||||
|
|
||||||
self.notifier.on_new_user_event(rooms=[room_id])
|
with PreserveLoggingContext():
|
||||||
|
self.notifier.on_new_user_event(
|
||||||
|
"typing_key", self._latest_room_serial, rooms=[room_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TypingNotificationEventSource(object):
|
class TypingNotificationEventSource(object):
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.api.errors import CodeMessageException
|
from synapse.api.errors import CodeMessageException
|
||||||
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
from syutil.jsonutil import encode_canonical_json
|
from syutil.jsonutil import encode_canonical_json
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
|
@ -61,7 +62,10 @@ class SimpleHttpClient(object):
|
||||||
# A small wrapper around self.agent.request() so we can easily attach
|
# A small wrapper around self.agent.request() so we can easily attach
|
||||||
# counters to it
|
# counters to it
|
||||||
outgoing_requests_counter.inc(method)
|
outgoing_requests_counter.inc(method)
|
||||||
d = self.agent.request(method, *args, **kwargs)
|
d = preserve_context_over_fn(
|
||||||
|
self.agent.request,
|
||||||
|
method, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def _cb(response):
|
def _cb(response):
|
||||||
incoming_responses_counter.inc(method, response.code)
|
incoming_responses_counter.inc(method, response.code)
|
||||||
|
|
|
@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone
|
||||||
|
|
||||||
from synapse.http.endpoint import matrix_federation_endpoint
|
from synapse.http.endpoint import matrix_federation_endpoint
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
from syutil.jsonutil import encode_canonical_json
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
@ -110,7 +110,8 @@ class MatrixFederationHttpClient(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _create_request(self, destination, method, path_bytes,
|
def _create_request(self, destination, method, path_bytes,
|
||||||
body_callback, headers_dict={}, param_bytes=b"",
|
body_callback, headers_dict={}, param_bytes=b"",
|
||||||
query_bytes=b"", retry_on_dns_fail=True):
|
query_bytes=b"", retry_on_dns_fail=True,
|
||||||
|
timeout=None):
|
||||||
""" Creates and sends a request to the given url
|
""" Creates and sends a request to the given url
|
||||||
"""
|
"""
|
||||||
headers_dict[b"User-Agent"] = [self.version_string]
|
headers_dict[b"User-Agent"] = [self.version_string]
|
||||||
|
@ -144,22 +145,22 @@ class MatrixFederationHttpClient(object):
|
||||||
producer = body_callback(method, url_bytes, headers_dict)
|
producer = body_callback(method, url_bytes, headers_dict)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with PreserveLoggingContext():
|
request_deferred = preserve_context_over_fn(
|
||||||
request_deferred = self.agent.request(
|
self.agent.request,
|
||||||
destination,
|
destination,
|
||||||
endpoint,
|
endpoint,
|
||||||
method,
|
method,
|
||||||
path_bytes,
|
path_bytes,
|
||||||
param_bytes,
|
param_bytes,
|
||||||
query_bytes,
|
query_bytes,
|
||||||
Headers(headers_dict),
|
Headers(headers_dict),
|
||||||
producer
|
producer
|
||||||
)
|
)
|
||||||
|
|
||||||
response = yield self.clock.time_bound_deferred(
|
response = yield self.clock.time_bound_deferred(
|
||||||
request_deferred,
|
request_deferred,
|
||||||
time_out=60,
|
time_out=timeout/1000. if timeout else 60,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Got response to %s", method)
|
logger.debug("Got response to %s", method)
|
||||||
break
|
break
|
||||||
|
@ -181,7 +182,7 @@ class MatrixFederationHttpClient(object):
|
||||||
_flatten_response_never_received(e),
|
_flatten_response_never_received(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
if retries_left:
|
if retries_left and not timeout:
|
||||||
yield sleep(2 ** (5 - retries_left))
|
yield sleep(2 ** (5 - retries_left))
|
||||||
retries_left -= 1
|
retries_left -= 1
|
||||||
else:
|
else:
|
||||||
|
@ -334,7 +335,8 @@ class MatrixFederationHttpClient(object):
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
|
def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
|
||||||
|
timeout=None):
|
||||||
""" GETs some json from the given host homeserver and path
|
""" GETs some json from the given host homeserver and path
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -343,6 +345,9 @@ class MatrixFederationHttpClient(object):
|
||||||
path (str): The HTTP path.
|
path (str): The HTTP path.
|
||||||
args (dict): A dictionary used to create query strings, defaults to
|
args (dict): A dictionary used to create query strings, defaults to
|
||||||
None.
|
None.
|
||||||
|
timeout (int): How long to try (in ms) the destination for before
|
||||||
|
giving up. None indicates no timeout and that the request will
|
||||||
|
be retried.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Succeeds when we get *any* HTTP response.
|
Deferred: Succeeds when we get *any* HTTP response.
|
||||||
|
|
||||||
|
@ -370,7 +375,8 @@ class MatrixFederationHttpClient(object):
|
||||||
path.encode("ascii"),
|
path.encode("ascii"),
|
||||||
query_bytes=query_bytes,
|
query_bytes=query_bytes,
|
||||||
body_callback=body_callback,
|
body_callback=body_callback,
|
||||||
retry_on_dns_fail=retry_on_dns_fail
|
retry_on_dns_fail=retry_on_dns_fail,
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
|
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
|
||||||
)
|
)
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
from syutil.jsonutil import (
|
from syutil.jsonutil import (
|
||||||
|
@ -85,7 +85,9 @@ def request_handler(request_handler):
|
||||||
"Received request: %s %s",
|
"Received request: %s %s",
|
||||||
request.method, request.path
|
request.method, request.path
|
||||||
)
|
)
|
||||||
yield request_handler(self, request)
|
d = request_handler(self, request)
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
yield d
|
||||||
code = request.code
|
code = request.code
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
code = e.code
|
code = e.code
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.types import StreamToken
|
from synapse.types import StreamToken
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
|
@ -42,63 +42,78 @@ def count(func, l):
|
||||||
|
|
||||||
class _NotificationListener(object):
|
class _NotificationListener(object):
|
||||||
""" This represents a single client connection to the events stream.
|
""" This represents a single client connection to the events stream.
|
||||||
|
|
||||||
The events stream handler will have yielded to the deferred, so to
|
The events stream handler will have yielded to the deferred, so to
|
||||||
notify the handler it is sufficient to resolve the deferred.
|
notify the handler it is sufficient to resolve the deferred.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, deferred):
|
||||||
|
self.deferred = deferred
|
||||||
|
|
||||||
|
def notified(self):
|
||||||
|
return self.deferred.called
|
||||||
|
|
||||||
|
def notify(self, token):
|
||||||
|
""" Inform whoever is listening about the new events.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.deferred.callback(token)
|
||||||
|
except defer.AlreadyCalledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _NotifierUserStream(object):
|
||||||
|
"""This represents a user connected to the event stream.
|
||||||
|
It tracks the most recent stream token for that user.
|
||||||
|
At a given point a user may have a number of streams listening for
|
||||||
|
events.
|
||||||
|
|
||||||
This listener will also keep track of which rooms it is listening in
|
This listener will also keep track of which rooms it is listening in
|
||||||
so that it can remove itself from the indexes in the Notifier class.
|
so that it can remove itself from the indexes in the Notifier class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, user, rooms, from_token, limit, timeout, deferred,
|
def __init__(self, user, rooms, current_token, time_now_ms,
|
||||||
appservice=None):
|
appservice=None):
|
||||||
self.user = user
|
self.user = str(user)
|
||||||
self.appservice = appservice
|
self.appservice = appservice
|
||||||
self.from_token = from_token
|
self.listeners = set()
|
||||||
self.limit = limit
|
self.rooms = set(rooms)
|
||||||
self.timeout = timeout
|
self.current_token = current_token
|
||||||
self.deferred = deferred
|
self.last_notified_ms = time_now_ms
|
||||||
self.rooms = rooms
|
|
||||||
self.timer = None
|
|
||||||
|
|
||||||
def notified(self):
|
def notify(self, stream_key, stream_id, time_now_ms):
|
||||||
return self.deferred.called
|
"""Notify any listeners for this user of a new event from an
|
||||||
|
event source.
|
||||||
|
Args:
|
||||||
|
stream_key(str): The stream the event came from.
|
||||||
|
stream_id(str): The new id for the stream the event came from.
|
||||||
|
time_now_ms(int): The current time in milliseconds.
|
||||||
|
"""
|
||||||
|
self.current_token = self.current_token.copy_and_advance(
|
||||||
|
stream_key, stream_id
|
||||||
|
)
|
||||||
|
if self.listeners:
|
||||||
|
self.last_notified_ms = time_now_ms
|
||||||
|
listeners = self.listeners
|
||||||
|
self.listeners = set()
|
||||||
|
for listener in listeners:
|
||||||
|
listener.notify(self.current_token)
|
||||||
|
|
||||||
def notify(self, notifier, events, start_token, end_token):
|
def remove(self, notifier):
|
||||||
""" Inform whoever is listening about the new events. This will
|
""" Remove this listener from all the indexes in the Notifier
|
||||||
also remove this listener from all the indexes in the Notifier
|
|
||||||
it knows about.
|
it knows about.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = (events, (start_token, end_token))
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.deferred.callback(result)
|
|
||||||
notified_events_counter.inc_by(len(events))
|
|
||||||
except defer.AlreadyCalledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Should the following be done be using intrusively linked lists?
|
|
||||||
# -- erikj
|
|
||||||
|
|
||||||
for room in self.rooms:
|
for room in self.rooms:
|
||||||
lst = notifier.room_to_listeners.get(room, set())
|
lst = notifier.room_to_user_streams.get(room, set())
|
||||||
lst.discard(self)
|
lst.discard(self)
|
||||||
|
|
||||||
notifier.user_to_listeners.get(self.user, set()).discard(self)
|
notifier.user_to_user_stream.pop(self.user)
|
||||||
|
|
||||||
if self.appservice:
|
if self.appservice:
|
||||||
notifier.appservice_to_listeners.get(
|
notifier.appservice_to_user_streams.get(
|
||||||
self.appservice, set()
|
self.appservice, set()
|
||||||
).discard(self)
|
).discard(self)
|
||||||
|
|
||||||
# Cancel the timeout for this notifer if one exists.
|
|
||||||
if self.timer is not None:
|
|
||||||
try:
|
|
||||||
notifier.clock.cancel_call_later(self.timer)
|
|
||||||
except:
|
|
||||||
logger.warn("Failed to cancel notifier timer")
|
|
||||||
|
|
||||||
|
|
||||||
class Notifier(object):
|
class Notifier(object):
|
||||||
""" This class is responsible for notifying any listeners when there are
|
""" This class is responsible for notifying any listeners when there are
|
||||||
|
@ -107,14 +122,18 @@ class Notifier(object):
|
||||||
Primarily used from the /events stream.
|
Primarily used from the /events stream.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
self.room_to_listeners = {}
|
self.user_to_user_stream = {}
|
||||||
self.user_to_listeners = {}
|
self.room_to_user_streams = {}
|
||||||
self.appservice_to_listeners = {}
|
self.appservice_to_user_streams = {}
|
||||||
|
|
||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.pending_new_room_events = []
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@ -122,45 +141,80 @@ class Notifier(object):
|
||||||
"user_joined_room", self._user_joined_room
|
"user_joined_room", self._user_joined_room
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.clock.looping_call(
|
||||||
|
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
|
||||||
|
)
|
||||||
|
|
||||||
# This is not a very cheap test to perform, but it's only executed
|
# This is not a very cheap test to perform, but it's only executed
|
||||||
# when rendering the metrics page, which is likely once per minute at
|
# when rendering the metrics page, which is likely once per minute at
|
||||||
# most when scraping it.
|
# most when scraping it.
|
||||||
def count_listeners():
|
def count_listeners():
|
||||||
all_listeners = set()
|
all_user_streams = set()
|
||||||
|
|
||||||
for x in self.room_to_listeners.values():
|
for x in self.room_to_user_streams.values():
|
||||||
all_listeners |= x
|
all_user_streams |= x
|
||||||
for x in self.user_to_listeners.values():
|
for x in self.user_to_user_stream.values():
|
||||||
all_listeners |= x
|
all_user_streams.add(x)
|
||||||
for x in self.appservice_to_listeners.values():
|
for x in self.appservice_to_user_streams.values():
|
||||||
all_listeners |= x
|
all_user_streams |= x
|
||||||
|
|
||||||
return len(all_listeners)
|
return sum(len(stream.listeners) for stream in all_user_streams)
|
||||||
metrics.register_callback("listeners", count_listeners)
|
metrics.register_callback("listeners", count_listeners)
|
||||||
|
|
||||||
metrics.register_callback(
|
metrics.register_callback(
|
||||||
"rooms",
|
"rooms",
|
||||||
lambda: count(bool, self.room_to_listeners.values()),
|
lambda: count(bool, self.room_to_user_streams.values()),
|
||||||
)
|
)
|
||||||
metrics.register_callback(
|
metrics.register_callback(
|
||||||
"users",
|
"users",
|
||||||
lambda: count(bool, self.user_to_listeners.values()),
|
lambda: len(self.user_to_user_stream),
|
||||||
)
|
)
|
||||||
metrics.register_callback(
|
metrics.register_callback(
|
||||||
"appservices",
|
"appservices",
|
||||||
lambda: count(bool, self.appservice_to_listeners.values()),
|
lambda: count(bool, self.appservice_to_user_streams.values()),
|
||||||
)
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_new_room_event(self, event, extra_users=[]):
|
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
|
||||||
|
extra_users=[]):
|
||||||
""" Used by handlers to inform the notifier something has happened
|
""" Used by handlers to inform the notifier something has happened
|
||||||
in the room, room event wise.
|
in the room, room event wise.
|
||||||
|
|
||||||
This triggers the notifier to wake up any listeners that are
|
This triggers the notifier to wake up any listeners that are
|
||||||
listening to the room, and any listeners for the users in the
|
listening to the room, and any listeners for the users in the
|
||||||
`extra_users` param.
|
`extra_users` param.
|
||||||
|
|
||||||
|
The events can be peristed out of order. The notifier will wait
|
||||||
|
until all previous events have been persisted before notifying
|
||||||
|
the client streams.
|
||||||
"""
|
"""
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
self.pending_new_room_events.append((
|
||||||
|
room_stream_id, event, extra_users
|
||||||
|
))
|
||||||
|
self._notify_pending_new_room_events(max_room_stream_id)
|
||||||
|
|
||||||
|
def _notify_pending_new_room_events(self, max_room_stream_id):
|
||||||
|
"""Notify for the room events that were queued waiting for a previous
|
||||||
|
event to be persisted.
|
||||||
|
Args:
|
||||||
|
max_room_stream_id(int): The highest stream_id below which all
|
||||||
|
events have been persisted.
|
||||||
|
"""
|
||||||
|
pending = self.pending_new_room_events
|
||||||
|
self.pending_new_room_events = []
|
||||||
|
for room_stream_id, event, extra_users in pending:
|
||||||
|
if room_stream_id > max_room_stream_id:
|
||||||
|
self.pending_new_room_events.append((
|
||||||
|
room_stream_id, event, extra_users
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
self._on_new_room_event(event, room_stream_id, extra_users)
|
||||||
|
|
||||||
|
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
|
||||||
|
"""Notify any user streams that are interested in this room event"""
|
||||||
# poke any interested application service.
|
# poke any interested application service.
|
||||||
self.hs.get_handlers().appservice_handler.notify_interested_services(
|
self.hs.get_handlers().appservice_handler.notify_interested_services(
|
||||||
event
|
event
|
||||||
|
@ -168,192 +222,129 @@ class Notifier(object):
|
||||||
|
|
||||||
room_id = event.room_id
|
room_id = event.room_id
|
||||||
|
|
||||||
room_source = self.event_sources.sources["room"]
|
room_user_streams = self.room_to_user_streams.get(room_id, set())
|
||||||
|
|
||||||
room_listeners = self.room_to_listeners.get(room_id, set())
|
user_streams = room_user_streams.copy()
|
||||||
|
|
||||||
_discard_if_notified(room_listeners)
|
|
||||||
|
|
||||||
listeners = room_listeners.copy()
|
|
||||||
|
|
||||||
for user in extra_users:
|
for user in extra_users:
|
||||||
user_listeners = self.user_to_listeners.get(user, set())
|
user_stream = self.user_to_user_stream.get(str(user))
|
||||||
|
if user_stream is not None:
|
||||||
|
user_streams.add(user_stream)
|
||||||
|
|
||||||
_discard_if_notified(user_listeners)
|
for appservice in self.appservice_to_user_streams:
|
||||||
|
|
||||||
listeners |= user_listeners
|
|
||||||
|
|
||||||
for appservice in self.appservice_to_listeners:
|
|
||||||
# TODO (kegan): Redundant appservice listener checks?
|
# TODO (kegan): Redundant appservice listener checks?
|
||||||
# App services will already be in the room_to_listeners set, but
|
# App services will already be in the room_to_user_streams set, but
|
||||||
# that isn't enough. They need to be checked here in order to
|
# that isn't enough. They need to be checked here in order to
|
||||||
# receive *invites* for users they are interested in. Does this
|
# receive *invites* for users they are interested in. Does this
|
||||||
# make the room_to_listeners check somewhat obselete?
|
# make the room_to_user_streams check somewhat obselete?
|
||||||
if appservice.is_interested(event):
|
if appservice.is_interested(event):
|
||||||
app_listeners = self.appservice_to_listeners.get(
|
app_user_streams = self.appservice_to_user_streams.get(
|
||||||
appservice, set()
|
appservice, set()
|
||||||
)
|
)
|
||||||
|
user_streams |= app_user_streams
|
||||||
|
|
||||||
_discard_if_notified(app_listeners)
|
logger.debug("on_new_room_event listeners %s", user_streams)
|
||||||
|
|
||||||
listeners |= app_listeners
|
time_now_ms = self.clock.time_msec()
|
||||||
|
for user_stream in user_streams:
|
||||||
logger.debug("on_new_room_event listeners %s", listeners)
|
try:
|
||||||
|
user_stream.notify(
|
||||||
# TODO (erikj): Can we make this more efficient by hitting the
|
"room_key", "s%d" % (room_stream_id,), time_now_ms
|
||||||
# db once?
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def notify(listener):
|
|
||||||
events, end_key = yield room_source.get_new_events_for_user(
|
|
||||||
listener.user,
|
|
||||||
listener.from_token.room_key,
|
|
||||||
listener.limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
if events:
|
|
||||||
end_token = listener.from_token.copy_and_replace(
|
|
||||||
"room_key", end_key
|
|
||||||
)
|
)
|
||||||
|
except:
|
||||||
listener.notify(
|
logger.exception("Failed to notify listener")
|
||||||
self, events, listener.from_token, end_token
|
|
||||||
)
|
|
||||||
|
|
||||||
def eb(failure):
|
|
||||||
logger.exception("Failed to notify listener", failure)
|
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
yield defer.DeferredList(
|
|
||||||
[notify(l).addErrback(eb) for l in listeners],
|
|
||||||
consumeErrors=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_new_user_event(self, users=[], rooms=[]):
|
def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]):
|
||||||
""" Used to inform listeners that something has happend
|
""" Used to inform listeners that something has happend
|
||||||
presence/user event wise.
|
presence/user event wise.
|
||||||
|
|
||||||
Will wake up all listeners for the given users and rooms.
|
Will wake up all listeners for the given users and rooms.
|
||||||
"""
|
"""
|
||||||
# TODO(paul): This is horrible, having to manually list every event
|
yield run_on_reactor()
|
||||||
# source here individually
|
user_streams = set()
|
||||||
presence_source = self.event_sources.sources["presence"]
|
|
||||||
typing_source = self.event_sources.sources["typing"]
|
|
||||||
|
|
||||||
listeners = set()
|
|
||||||
|
|
||||||
for user in users:
|
for user in users:
|
||||||
user_listeners = self.user_to_listeners.get(user, set())
|
user_stream = self.user_to_user_stream.get(str(user))
|
||||||
|
if user_stream is not None:
|
||||||
_discard_if_notified(user_listeners)
|
user_streams.add(user_stream)
|
||||||
|
|
||||||
listeners |= user_listeners
|
|
||||||
|
|
||||||
for room in rooms:
|
for room in rooms:
|
||||||
room_listeners = self.room_to_listeners.get(room, set())
|
user_streams |= self.room_to_user_streams.get(room, set())
|
||||||
|
|
||||||
_discard_if_notified(room_listeners)
|
time_now_ms = self.clock.time_msec()
|
||||||
|
for user_stream in user_streams:
|
||||||
listeners |= room_listeners
|
try:
|
||||||
|
user_stream.notify(stream_key, new_token, time_now_ms)
|
||||||
@defer.inlineCallbacks
|
except:
|
||||||
def notify(listener):
|
logger.exception("Failed to notify listener")
|
||||||
presence_events, presence_end_key = (
|
|
||||||
yield presence_source.get_new_events_for_user(
|
|
||||||
listener.user,
|
|
||||||
listener.from_token.presence_key,
|
|
||||||
listener.limit,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
typing_events, typing_end_key = (
|
|
||||||
yield typing_source.get_new_events_for_user(
|
|
||||||
listener.user,
|
|
||||||
listener.from_token.typing_key,
|
|
||||||
listener.limit,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if presence_events or typing_events:
|
|
||||||
end_token = listener.from_token.copy_and_replace(
|
|
||||||
"presence_key", presence_end_key
|
|
||||||
).copy_and_replace(
|
|
||||||
"typing_key", typing_end_key
|
|
||||||
)
|
|
||||||
|
|
||||||
listener.notify(
|
|
||||||
self,
|
|
||||||
presence_events + typing_events,
|
|
||||||
listener.from_token,
|
|
||||||
end_token
|
|
||||||
)
|
|
||||||
|
|
||||||
def eb(failure):
|
|
||||||
logger.error(
|
|
||||||
"Failed to notify listener",
|
|
||||||
exc_info=(
|
|
||||||
failure.type,
|
|
||||||
failure.value,
|
|
||||||
failure.getTracebackObject())
|
|
||||||
)
|
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
yield defer.DeferredList(
|
|
||||||
[notify(l).addErrback(eb) for l in listeners],
|
|
||||||
consumeErrors=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wait_for_events(self, user, rooms, filter, timeout, callback):
|
def wait_for_events(self, user, rooms, timeout, callback,
|
||||||
|
from_token=StreamToken("s0", "0", "0")):
|
||||||
"""Wait until the callback returns a non empty response or the
|
"""Wait until the callback returns a non empty response or the
|
||||||
timeout fires.
|
timeout fires.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
deferred = defer.Deferred()
|
deferred = defer.Deferred()
|
||||||
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
from_token = StreamToken("s0", "0", "0")
|
user = str(user)
|
||||||
|
user_stream = self.user_to_user_stream.get(user)
|
||||||
|
if user_stream is None:
|
||||||
|
appservice = yield self.store.get_app_service_by_user_id(user)
|
||||||
|
current_token = yield self.event_sources.get_current_token()
|
||||||
|
rooms = yield self.store.get_rooms_for_user(user)
|
||||||
|
rooms = [room.room_id for room in rooms]
|
||||||
|
user_stream = _NotifierUserStream(
|
||||||
|
user=user,
|
||||||
|
rooms=rooms,
|
||||||
|
appservice=appservice,
|
||||||
|
current_token=current_token,
|
||||||
|
time_now_ms=time_now_ms,
|
||||||
|
)
|
||||||
|
self._register_with_keys(user_stream)
|
||||||
|
else:
|
||||||
|
current_token = user_stream.current_token
|
||||||
|
|
||||||
listener = [_NotificationListener(
|
listener = [_NotificationListener(deferred)]
|
||||||
user=user,
|
|
||||||
rooms=rooms,
|
|
||||||
from_token=from_token,
|
|
||||||
limit=1,
|
|
||||||
timeout=timeout,
|
|
||||||
deferred=deferred,
|
|
||||||
)]
|
|
||||||
|
|
||||||
if timeout:
|
if timeout and not current_token.is_after(from_token):
|
||||||
self._register_with_keys(listener[0])
|
user_stream.listeners.add(listener[0])
|
||||||
|
|
||||||
|
if current_token.is_after(from_token):
|
||||||
|
result = yield callback(from_token, current_token)
|
||||||
|
else:
|
||||||
|
result = None
|
||||||
|
|
||||||
result = yield callback()
|
|
||||||
timer = [None]
|
timer = [None]
|
||||||
|
|
||||||
|
if result:
|
||||||
|
user_stream.listeners.discard(listener[0])
|
||||||
|
defer.returnValue(result)
|
||||||
|
return
|
||||||
|
|
||||||
if timeout:
|
if timeout:
|
||||||
timed_out = [False]
|
timed_out = [False]
|
||||||
|
|
||||||
def _timeout_listener():
|
def _timeout_listener():
|
||||||
timed_out[0] = True
|
timed_out[0] = True
|
||||||
timer[0] = None
|
timer[0] = None
|
||||||
listener[0].notify(self, [], from_token, from_token)
|
user_stream.listeners.discard(listener[0])
|
||||||
|
listener[0].notify(current_token)
|
||||||
|
|
||||||
# We create multiple notification listeners so we have to manage
|
# We create multiple notification listeners so we have to manage
|
||||||
# canceling the timeout ourselves.
|
# canceling the timeout ourselves.
|
||||||
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
|
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
|
||||||
|
|
||||||
while not result and not timed_out[0]:
|
while not result and not timed_out[0]:
|
||||||
yield deferred
|
new_token = yield deferred
|
||||||
deferred = defer.Deferred()
|
deferred = defer.Deferred()
|
||||||
listener[0] = _NotificationListener(
|
listener[0] = _NotificationListener(deferred)
|
||||||
user=user,
|
user_stream.listeners.add(listener[0])
|
||||||
rooms=rooms,
|
result = yield callback(current_token, new_token)
|
||||||
from_token=from_token,
|
current_token = new_token
|
||||||
limit=1,
|
|
||||||
timeout=timeout,
|
|
||||||
deferred=deferred,
|
|
||||||
)
|
|
||||||
self._register_with_keys(listener[0])
|
|
||||||
result = yield callback()
|
|
||||||
|
|
||||||
if timer[0] is not None:
|
if timer[0] is not None:
|
||||||
try:
|
try:
|
||||||
|
@ -363,125 +354,79 @@ class Notifier(object):
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def get_events_for(self, user, rooms, pagination_config, timeout):
|
def get_events_for(self, user, rooms, pagination_config, timeout):
|
||||||
""" For the given user and rooms, return any new events for them. If
|
""" For the given user and rooms, return any new events for them. If
|
||||||
there are no new events wait for up to `timeout` milliseconds for any
|
there are no new events wait for up to `timeout` milliseconds for any
|
||||||
new events to happen before returning.
|
new events to happen before returning.
|
||||||
"""
|
"""
|
||||||
deferred = defer.Deferred()
|
from_token = pagination_config.from_token
|
||||||
|
|
||||||
self._get_events(
|
|
||||||
deferred, user, rooms, pagination_config.from_token,
|
|
||||||
pagination_config.limit, timeout
|
|
||||||
).addErrback(deferred.errback)
|
|
||||||
|
|
||||||
return deferred
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _get_events(self, deferred, user, rooms, from_token, limit, timeout):
|
|
||||||
if not from_token:
|
if not from_token:
|
||||||
from_token = yield self.event_sources.get_current_token()
|
from_token = yield self.event_sources.get_current_token()
|
||||||
|
|
||||||
appservice = yield self.hs.get_datastore().get_app_service_by_user_id(
|
limit = pagination_config.limit
|
||||||
user.to_string()
|
|
||||||
)
|
|
||||||
|
|
||||||
listener = _NotificationListener(
|
@defer.inlineCallbacks
|
||||||
user,
|
def check_for_updates(before_token, after_token):
|
||||||
rooms,
|
events = []
|
||||||
from_token,
|
end_token = from_token
|
||||||
limit,
|
for name, source in self.event_sources.sources.items():
|
||||||
timeout,
|
keyname = "%s_key" % name
|
||||||
deferred,
|
before_id = getattr(before_token, keyname)
|
||||||
appservice=appservice
|
after_id = getattr(after_token, keyname)
|
||||||
)
|
if before_id == after_id:
|
||||||
|
continue
|
||||||
def _timeout_listener():
|
stuff, new_key = yield source.get_new_events_for_user(
|
||||||
# TODO (erikj): We should probably set to_token to the current
|
user, getattr(from_token, keyname), limit,
|
||||||
# max rather than reusing from_token.
|
|
||||||
# Remove the timer from the listener so we don't try to cancel it.
|
|
||||||
listener.timer = None
|
|
||||||
listener.notify(
|
|
||||||
self,
|
|
||||||
[],
|
|
||||||
listener.from_token,
|
|
||||||
listener.from_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
if timeout:
|
|
||||||
self._register_with_keys(listener)
|
|
||||||
|
|
||||||
yield self._check_for_updates(listener)
|
|
||||||
|
|
||||||
if not timeout:
|
|
||||||
_timeout_listener()
|
|
||||||
else:
|
|
||||||
# Only add the timer if the listener hasn't been notified
|
|
||||||
if not listener.notified():
|
|
||||||
listener.timer = self.clock.call_later(
|
|
||||||
timeout/1000.0, _timeout_listener
|
|
||||||
)
|
)
|
||||||
return
|
events.extend(stuff)
|
||||||
|
end_token = end_token.copy_and_replace(keyname, new_key)
|
||||||
|
|
||||||
|
if events:
|
||||||
|
defer.returnValue((events, (from_token, end_token)))
|
||||||
|
else:
|
||||||
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
result = yield self.wait_for_events(
|
||||||
|
user, rooms, timeout, check_for_updates, from_token=from_token
|
||||||
|
)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
result = ([], (from_token, from_token))
|
||||||
|
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _register_with_keys(self, listener):
|
def remove_expired_streams(self):
|
||||||
for room in listener.rooms:
|
time_now_ms = self.clock.time_msec()
|
||||||
s = self.room_to_listeners.setdefault(room, set())
|
expired_streams = []
|
||||||
s.add(listener)
|
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
|
||||||
|
for stream in self.user_to_user_stream.values():
|
||||||
|
if stream.listeners:
|
||||||
|
continue
|
||||||
|
if stream.last_notified_ms < expire_before_ts:
|
||||||
|
expired_streams.append(stream)
|
||||||
|
|
||||||
self.user_to_listeners.setdefault(listener.user, set()).add(listener)
|
for expired_stream in expired_streams:
|
||||||
|
expired_stream.remove(self)
|
||||||
|
|
||||||
if listener.appservice:
|
|
||||||
self.appservice_to_listeners.setdefault(
|
|
||||||
listener.appservice, set()
|
|
||||||
).add(listener)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
@log_function
|
@log_function
|
||||||
def _check_for_updates(self, listener):
|
def _register_with_keys(self, user_stream):
|
||||||
# TODO (erikj): We need to think about limits across multiple sources
|
self.user_to_user_stream[user_stream.user] = user_stream
|
||||||
events = []
|
|
||||||
|
|
||||||
from_token = listener.from_token
|
for room in user_stream.rooms:
|
||||||
limit = listener.limit
|
s = self.room_to_user_streams.setdefault(room, set())
|
||||||
|
s.add(user_stream)
|
||||||
|
|
||||||
# TODO (erikj): DeferredList?
|
if user_stream.appservice:
|
||||||
for name, source in self.event_sources.sources.items():
|
self.appservice_to_user_stream.setdefault(
|
||||||
keyname = "%s_key" % name
|
user_stream.appservice, set()
|
||||||
|
).add(user_stream)
|
||||||
stuff, new_key = yield source.get_new_events_for_user(
|
|
||||||
listener.user,
|
|
||||||
getattr(from_token, keyname),
|
|
||||||
limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
events.extend(stuff)
|
|
||||||
|
|
||||||
from_token = from_token.copy_and_replace(keyname, new_key)
|
|
||||||
|
|
||||||
end_token = from_token
|
|
||||||
|
|
||||||
if events:
|
|
||||||
listener.notify(self, events, listener.from_token, end_token)
|
|
||||||
|
|
||||||
defer.returnValue(listener)
|
|
||||||
|
|
||||||
def _user_joined_room(self, user, room_id):
|
def _user_joined_room(self, user, room_id):
|
||||||
new_listeners = self.user_to_listeners.get(user, set())
|
user = str(user)
|
||||||
|
new_user_stream = self.user_to_user_stream.get(user)
|
||||||
listeners = self.room_to_listeners.setdefault(room_id, set())
|
if new_user_stream is not None:
|
||||||
listeners |= new_listeners
|
room_streams = self.room_to_user_streams.setdefault(room_id, set())
|
||||||
|
room_streams.add(new_user_stream)
|
||||||
for l in new_listeners:
|
new_user_stream.rooms.add(room_id)
|
||||||
l.rooms.add(room_id)
|
|
||||||
|
|
||||||
|
|
||||||
def _discard_if_notified(listener_set):
|
|
||||||
"""Remove any 'stale' listeners from the given set.
|
|
||||||
"""
|
|
||||||
to_discard = set()
|
|
||||||
for l in listener_set:
|
|
||||||
if l.notified():
|
|
||||||
to_discard.add(l)
|
|
||||||
|
|
||||||
listener_set -= to_discard
|
|
||||||
|
|
|
@ -74,15 +74,18 @@ class Pusher(object):
|
||||||
|
|
||||||
rawrules = yield self.store.get_push_rules_for_user(self.user_name)
|
rawrules = yield self.store.get_push_rules_for_user(self.user_name)
|
||||||
|
|
||||||
for r in rawrules:
|
rules = []
|
||||||
r['conditions'] = json.loads(r['conditions'])
|
for rawrule in rawrules:
|
||||||
r['actions'] = json.loads(r['actions'])
|
rule = dict(rawrule)
|
||||||
|
rule['conditions'] = json.loads(rawrule['conditions'])
|
||||||
|
rule['actions'] = json.loads(rawrule['actions'])
|
||||||
|
rules.append(rule)
|
||||||
|
|
||||||
enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name)
|
enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name)
|
||||||
|
|
||||||
user = UserID.from_string(self.user_name)
|
user = UserID.from_string(self.user_name)
|
||||||
|
|
||||||
rules = baserules.list_with_base_rules(rawrules, user)
|
rules = baserules.list_with_base_rules(rules, user)
|
||||||
|
|
||||||
room_id = ev['room_id']
|
room_id = ev['room_id']
|
||||||
|
|
||||||
|
|
|
@ -118,11 +118,14 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
user.to_string()
|
user.to_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
for r in rawrules:
|
ruleslist = []
|
||||||
r["conditions"] = json.loads(r["conditions"])
|
for rawrule in rawrules:
|
||||||
r["actions"] = json.loads(r["actions"])
|
rule = dict(rawrule)
|
||||||
|
rule["conditions"] = json.loads(rawrule["conditions"])
|
||||||
|
rule["actions"] = json.loads(rawrule["actions"])
|
||||||
|
ruleslist.append(rule)
|
||||||
|
|
||||||
ruleslist = baserules.list_with_base_rules(rawrules, user)
|
ruleslist = baserules.list_with_base_rules(ruleslist, user)
|
||||||
|
|
||||||
rules = {'global': {}, 'device': {}}
|
rules = {'global': {}, 'device': {}}
|
||||||
|
|
||||||
|
|
|
@ -82,8 +82,10 @@ class RegisterRestServlet(RestServlet):
|
||||||
[LoginType.EMAIL_IDENTITY]
|
[LoginType.EMAIL_IDENTITY]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
result = None
|
||||||
if service:
|
if service:
|
||||||
is_application_server = True
|
is_application_server = True
|
||||||
|
params = body
|
||||||
elif 'mac' in body:
|
elif 'mac' in body:
|
||||||
# Check registration-specific shared secret auth
|
# Check registration-specific shared secret auth
|
||||||
if 'username' not in body:
|
if 'username' not in body:
|
||||||
|
@ -92,6 +94,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
body['username'], body['mac']
|
body['username'], body['mac']
|
||||||
)
|
)
|
||||||
is_using_shared_secret = True
|
is_using_shared_secret = True
|
||||||
|
params = body
|
||||||
else:
|
else:
|
||||||
authed, result, params = yield self.auth_handler.check_auth(
|
authed, result, params = yield self.auth_handler.check_auth(
|
||||||
flows, body, self.hs.get_ip_from_request(request)
|
flows, body, self.hs.get_ip_from_request(request)
|
||||||
|
@ -118,7 +121,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
password=new_password
|
password=new_password
|
||||||
)
|
)
|
||||||
|
|
||||||
if LoginType.EMAIL_IDENTITY in result:
|
if result and LoginType.EMAIL_IDENTITY in result:
|
||||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||||
|
|
||||||
for reqd in ['medium', 'address', 'validated_at']:
|
for reqd in ['medium', 'address', 'validated_at']:
|
||||||
|
|
|
@ -25,7 +25,7 @@ from twisted.internet import defer
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from twisted.protocols.basic import FileSender
|
from twisted.protocols.basic import FileSender
|
||||||
|
|
||||||
from synapse.util.async import create_observer
|
from synapse.util.async import ObservableDeferred
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
@ -83,13 +83,17 @@ class BaseMediaResource(Resource):
|
||||||
download = self.downloads.get(key)
|
download = self.downloads.get(key)
|
||||||
if download is None:
|
if download is None:
|
||||||
download = self._get_remote_media_impl(server_name, media_id)
|
download = self._get_remote_media_impl(server_name, media_id)
|
||||||
|
download = ObservableDeferred(
|
||||||
|
download,
|
||||||
|
consumeErrors=True
|
||||||
|
)
|
||||||
self.downloads[key] = download
|
self.downloads[key] = download
|
||||||
|
|
||||||
@download.addBoth
|
@download.addBoth
|
||||||
def callback(media_info):
|
def callback(media_info):
|
||||||
del self.downloads[key]
|
del self.downloads[key]
|
||||||
return media_info
|
return media_info
|
||||||
return create_observer(download)
|
return download.observe()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_remote_media_impl(self, server_name, media_id):
|
def _get_remote_media_impl(self, server_name, media_id):
|
||||||
|
|
|
@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
SCHEMA_VERSION = 18
|
SCHEMA_VERSION = 19
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
|
|
@ -15,10 +15,8 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.events import FrozenEvent
|
|
||||||
from synapse.events.utils import prune_event
|
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
|
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
||||||
from synapse.util.lrucache import LruCache
|
from synapse.util.lrucache import LruCache
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
|
@ -27,8 +25,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from collections import namedtuple, OrderedDict
|
from collections import namedtuple, OrderedDict
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import simplejson as json
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
|
@ -48,7 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
|
||||||
|
|
||||||
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
|
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
|
||||||
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
|
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
|
||||||
sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
|
|
||||||
|
|
||||||
caches_by_name = {}
|
caches_by_name = {}
|
||||||
cache_counter = metrics.register_cache(
|
cache_counter = metrics.register_cache(
|
||||||
|
@ -307,6 +304,12 @@ class SQLBaseStore(object):
|
||||||
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
|
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
|
||||||
max_entries=hs.config.event_cache_size)
|
max_entries=hs.config.event_cache_size)
|
||||||
|
|
||||||
|
self._event_fetch_lock = threading.Condition()
|
||||||
|
self._event_fetch_list = []
|
||||||
|
self._event_fetch_ongoing = 0
|
||||||
|
|
||||||
|
self._pending_ds = []
|
||||||
|
|
||||||
self.database_engine = hs.database_engine
|
self.database_engine = hs.database_engine
|
||||||
|
|
||||||
self._stream_id_gen = StreamIdGenerator()
|
self._stream_id_gen = StreamIdGenerator()
|
||||||
|
@ -315,6 +318,7 @@ class SQLBaseStore(object):
|
||||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
||||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
||||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
||||||
|
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
||||||
|
|
||||||
def start_profiling(self):
|
def start_profiling(self):
|
||||||
self._previous_loop_ts = self._clock.time_msec()
|
self._previous_loop_ts = self._clock.time_msec()
|
||||||
|
@ -345,6 +349,75 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
self._clock.looping_call(loop, 10000)
|
self._clock.looping_call(loop, 10000)
|
||||||
|
|
||||||
|
def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
|
||||||
|
start = time.time() * 1000
|
||||||
|
txn_id = self._TXN_ID
|
||||||
|
|
||||||
|
# We don't really need these to be unique, so lets stop it from
|
||||||
|
# growing really large.
|
||||||
|
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
|
||||||
|
|
||||||
|
name = "%s-%x" % (desc, txn_id, )
|
||||||
|
|
||||||
|
transaction_logger.debug("[TXN START] {%s}", name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
i = 0
|
||||||
|
N = 5
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
txn = conn.cursor()
|
||||||
|
txn = LoggingTransaction(
|
||||||
|
txn, name, self.database_engine, after_callbacks
|
||||||
|
)
|
||||||
|
r = func(txn, *args, **kwargs)
|
||||||
|
conn.commit()
|
||||||
|
return r
|
||||||
|
except self.database_engine.module.OperationalError as e:
|
||||||
|
# This can happen if the database disappears mid
|
||||||
|
# transaction.
|
||||||
|
logger.warn(
|
||||||
|
"[TXN OPERROR] {%s} %s %d/%d",
|
||||||
|
name, e, i, N
|
||||||
|
)
|
||||||
|
if i < N:
|
||||||
|
i += 1
|
||||||
|
try:
|
||||||
|
conn.rollback()
|
||||||
|
except self.database_engine.module.Error as e1:
|
||||||
|
logger.warn(
|
||||||
|
"[TXN EROLL] {%s} %s",
|
||||||
|
name, e1,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
except self.database_engine.module.DatabaseError as e:
|
||||||
|
if self.database_engine.is_deadlock(e):
|
||||||
|
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
|
||||||
|
if i < N:
|
||||||
|
i += 1
|
||||||
|
try:
|
||||||
|
conn.rollback()
|
||||||
|
except self.database_engine.module.Error as e1:
|
||||||
|
logger.warn(
|
||||||
|
"[TXN EROLL] {%s} %s",
|
||||||
|
name, e1,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("[TXN FAIL] {%s} %s", name, e)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
end = time.time() * 1000
|
||||||
|
duration = end - start
|
||||||
|
|
||||||
|
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
|
||||||
|
|
||||||
|
self._current_txn_total_time += duration
|
||||||
|
self._txn_perf_counters.update(desc, start, end)
|
||||||
|
sql_txn_timer.inc_by(duration, desc)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def runInteraction(self, desc, func, *args, **kwargs):
|
def runInteraction(self, desc, func, *args, **kwargs):
|
||||||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
||||||
|
@ -356,84 +429,52 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
def inner_func(conn, *args, **kwargs):
|
def inner_func(conn, *args, **kwargs):
|
||||||
with LoggingContext("runInteraction") as context:
|
with LoggingContext("runInteraction") as context:
|
||||||
|
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
||||||
|
|
||||||
if self.database_engine.is_connection_closed(conn):
|
if self.database_engine.is_connection_closed(conn):
|
||||||
logger.debug("Reconnecting closed database connection")
|
logger.debug("Reconnecting closed database connection")
|
||||||
conn.reconnect()
|
conn.reconnect()
|
||||||
|
|
||||||
current_context.copy_to(context)
|
current_context.copy_to(context)
|
||||||
start = time.time() * 1000
|
return self._new_transaction(
|
||||||
txn_id = self._TXN_ID
|
conn, desc, after_callbacks, func, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
# We don't really need these to be unique, so lets stop it from
|
result = yield preserve_context_over_fn(
|
||||||
# growing really large.
|
self._db_pool.runWithConnection,
|
||||||
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
|
inner_func, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
name = "%s-%x" % (desc, txn_id, )
|
|
||||||
|
|
||||||
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
|
||||||
transaction_logger.debug("[TXN START] {%s}", name)
|
|
||||||
try:
|
|
||||||
i = 0
|
|
||||||
N = 5
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
txn = conn.cursor()
|
|
||||||
txn = LoggingTransaction(
|
|
||||||
txn, name, self.database_engine, after_callbacks
|
|
||||||
)
|
|
||||||
return func(txn, *args, **kwargs)
|
|
||||||
except self.database_engine.module.OperationalError as e:
|
|
||||||
# This can happen if the database disappears mid
|
|
||||||
# transaction.
|
|
||||||
logger.warn(
|
|
||||||
"[TXN OPERROR] {%s} %s %d/%d",
|
|
||||||
name, e, i, N
|
|
||||||
)
|
|
||||||
if i < N:
|
|
||||||
i += 1
|
|
||||||
try:
|
|
||||||
conn.rollback()
|
|
||||||
except self.database_engine.module.Error as e1:
|
|
||||||
logger.warn(
|
|
||||||
"[TXN EROLL] {%s} %s",
|
|
||||||
name, e1,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
except self.database_engine.module.DatabaseError as e:
|
|
||||||
if self.database_engine.is_deadlock(e):
|
|
||||||
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
|
|
||||||
if i < N:
|
|
||||||
i += 1
|
|
||||||
try:
|
|
||||||
conn.rollback()
|
|
||||||
except self.database_engine.module.Error as e1:
|
|
||||||
logger.warn(
|
|
||||||
"[TXN EROLL] {%s} %s",
|
|
||||||
name, e1,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug("[TXN FAIL] {%s} %s", name, e)
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
end = time.time() * 1000
|
|
||||||
duration = end - start
|
|
||||||
|
|
||||||
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
|
|
||||||
|
|
||||||
self._current_txn_total_time += duration
|
|
||||||
self._txn_perf_counters.update(desc, start, end)
|
|
||||||
sql_txn_timer.inc_by(duration, desc)
|
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
result = yield self._db_pool.runWithConnection(
|
|
||||||
inner_func, *args, **kwargs
|
|
||||||
)
|
|
||||||
for after_callback, after_args in after_callbacks:
|
for after_callback, after_args in after_callbacks:
|
||||||
after_callback(*after_args)
|
after_callback(*after_args)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def runWithConnection(self, func, *args, **kwargs):
|
||||||
|
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
||||||
|
current_context = LoggingContext.current_context()
|
||||||
|
|
||||||
|
start_time = time.time() * 1000
|
||||||
|
|
||||||
|
def inner_func(conn, *args, **kwargs):
|
||||||
|
with LoggingContext("runWithConnection") as context:
|
||||||
|
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
||||||
|
|
||||||
|
if self.database_engine.is_connection_closed(conn):
|
||||||
|
logger.debug("Reconnecting closed database connection")
|
||||||
|
conn.reconnect()
|
||||||
|
|
||||||
|
current_context.copy_to(context)
|
||||||
|
|
||||||
|
return func(conn, *args, **kwargs)
|
||||||
|
|
||||||
|
result = yield preserve_context_over_fn(
|
||||||
|
self._db_pool.runWithConnection,
|
||||||
|
inner_func, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
def cursor_to_dict(self, cursor):
|
def cursor_to_dict(self, cursor):
|
||||||
"""Converts a SQL cursor into an list of dicts.
|
"""Converts a SQL cursor into an list of dicts.
|
||||||
|
|
||||||
|
@ -871,158 +912,6 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
return self.runInteraction("_simple_max_id", func)
|
return self.runInteraction("_simple_max_id", func)
|
||||||
|
|
||||||
def _get_events(self, event_ids, check_redacted=True,
|
|
||||||
get_prev_content=False):
|
|
||||||
return self.runInteraction(
|
|
||||||
"_get_events", self._get_events_txn, event_ids,
|
|
||||||
check_redacted=check_redacted, get_prev_content=get_prev_content,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_events_txn(self, txn, event_ids, check_redacted=True,
|
|
||||||
get_prev_content=False):
|
|
||||||
if not event_ids:
|
|
||||||
return []
|
|
||||||
|
|
||||||
events = [
|
|
||||||
self._get_event_txn(
|
|
||||||
txn, event_id,
|
|
||||||
check_redacted=check_redacted,
|
|
||||||
get_prev_content=get_prev_content
|
|
||||||
)
|
|
||||||
for event_id in event_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
return [e for e in events if e]
|
|
||||||
|
|
||||||
def _invalidate_get_event_cache(self, event_id):
|
|
||||||
for check_redacted in (False, True):
|
|
||||||
for get_prev_content in (False, True):
|
|
||||||
self._get_event_cache.invalidate(event_id, check_redacted,
|
|
||||||
get_prev_content)
|
|
||||||
|
|
||||||
def _get_event_txn(self, txn, event_id, check_redacted=True,
|
|
||||||
get_prev_content=False, allow_rejected=False):
|
|
||||||
|
|
||||||
start_time = time.time() * 1000
|
|
||||||
|
|
||||||
def update_counter(desc, last_time):
|
|
||||||
curr_time = self._get_event_counters.update(desc, last_time)
|
|
||||||
sql_getevents_timer.inc_by(curr_time - last_time, desc)
|
|
||||||
return curr_time
|
|
||||||
|
|
||||||
try:
|
|
||||||
ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
|
|
||||||
|
|
||||||
if allow_rejected or not ret.rejected_reason:
|
|
||||||
return ret
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
start_time = update_counter("event_cache", start_time)
|
|
||||||
|
|
||||||
sql = (
|
|
||||||
"SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
|
|
||||||
"FROM event_json as e "
|
|
||||||
"LEFT JOIN redactions as r ON e.event_id = r.redacts "
|
|
||||||
"LEFT JOIN rejections as rej on rej.event_id = e.event_id "
|
|
||||||
"WHERE e.event_id = ? "
|
|
||||||
"LIMIT 1 "
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(sql, (event_id,))
|
|
||||||
|
|
||||||
res = txn.fetchone()
|
|
||||||
|
|
||||||
if not res:
|
|
||||||
return None
|
|
||||||
|
|
||||||
internal_metadata, js, redacted, rejected_reason = res
|
|
||||||
|
|
||||||
start_time = update_counter("select_event", start_time)
|
|
||||||
|
|
||||||
result = self._get_event_from_row_txn(
|
|
||||||
txn, internal_metadata, js, redacted,
|
|
||||||
check_redacted=check_redacted,
|
|
||||||
get_prev_content=get_prev_content,
|
|
||||||
rejected_reason=rejected_reason,
|
|
||||||
)
|
|
||||||
self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
|
|
||||||
|
|
||||||
if allow_rejected or not rejected_reason:
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
|
|
||||||
check_redacted=True, get_prev_content=False,
|
|
||||||
rejected_reason=None):
|
|
||||||
|
|
||||||
start_time = time.time() * 1000
|
|
||||||
|
|
||||||
def update_counter(desc, last_time):
|
|
||||||
curr_time = self._get_event_counters.update(desc, last_time)
|
|
||||||
sql_getevents_timer.inc_by(curr_time - last_time, desc)
|
|
||||||
return curr_time
|
|
||||||
|
|
||||||
d = json.loads(js)
|
|
||||||
start_time = update_counter("decode_json", start_time)
|
|
||||||
|
|
||||||
internal_metadata = json.loads(internal_metadata)
|
|
||||||
start_time = update_counter("decode_internal", start_time)
|
|
||||||
|
|
||||||
ev = FrozenEvent(
|
|
||||||
d,
|
|
||||||
internal_metadata_dict=internal_metadata,
|
|
||||||
rejected_reason=rejected_reason,
|
|
||||||
)
|
|
||||||
start_time = update_counter("build_frozen_event", start_time)
|
|
||||||
|
|
||||||
if check_redacted and redacted:
|
|
||||||
ev = prune_event(ev)
|
|
||||||
|
|
||||||
ev.unsigned["redacted_by"] = redacted
|
|
||||||
# Get the redaction event.
|
|
||||||
|
|
||||||
because = self._get_event_txn(
|
|
||||||
txn,
|
|
||||||
redacted,
|
|
||||||
check_redacted=False
|
|
||||||
)
|
|
||||||
|
|
||||||
if because:
|
|
||||||
ev.unsigned["redacted_because"] = because
|
|
||||||
start_time = update_counter("redact_event", start_time)
|
|
||||||
|
|
||||||
if get_prev_content and "replaces_state" in ev.unsigned:
|
|
||||||
prev = self._get_event_txn(
|
|
||||||
txn,
|
|
||||||
ev.unsigned["replaces_state"],
|
|
||||||
get_prev_content=False,
|
|
||||||
)
|
|
||||||
if prev:
|
|
||||||
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
|
||||||
start_time = update_counter("get_prev_content", start_time)
|
|
||||||
|
|
||||||
return ev
|
|
||||||
|
|
||||||
def _parse_events(self, rows):
|
|
||||||
return self.runInteraction(
|
|
||||||
"_parse_events", self._parse_events_txn, rows
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_events_txn(self, txn, rows):
|
|
||||||
event_ids = [r["event_id"] for r in rows]
|
|
||||||
|
|
||||||
return self._get_events_txn(txn, event_ids)
|
|
||||||
|
|
||||||
def _has_been_redacted_txn(self, txn, event):
|
|
||||||
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
|
|
||||||
txn.execute(sql, (event.event_id,))
|
|
||||||
result = txn.fetchone()
|
|
||||||
return result[0] if result else None
|
|
||||||
|
|
||||||
def get_next_stream_id(self):
|
def get_next_stream_id(self):
|
||||||
with self._next_stream_id_lock:
|
with self._next_stream_id_lock:
|
||||||
i = self._next_stream_id
|
i = self._next_stream_id
|
||||||
|
|
|
@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup
|
||||||
|
|
||||||
|
|
||||||
class PostgresEngine(object):
|
class PostgresEngine(object):
|
||||||
|
single_threaded = False
|
||||||
|
|
||||||
def __init__(self, database_module):
|
def __init__(self, database_module):
|
||||||
self.module = database_module
|
self.module = database_module
|
||||||
self.module.extensions.register_type(self.module.extensions.UNICODE)
|
self.module.extensions.register_type(self.module.extensions.UNICODE)
|
||||||
|
|
|
@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database
|
||||||
|
|
||||||
|
|
||||||
class Sqlite3Engine(object):
|
class Sqlite3Engine(object):
|
||||||
|
single_threaded = True
|
||||||
|
|
||||||
def __init__(self, database_module):
|
def __init__(self, database_module):
|
||||||
self.module = database_module
|
self.module = database_module
|
||||||
|
|
||||||
|
|
|
@ -13,10 +13,13 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore, cached
|
||||||
from syutil.base64util import encode_base64
|
from syutil.base64util import encode_base64
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from Queue import PriorityQueue, Empty
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -33,16 +36,7 @@ class EventFederationStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_auth_chain(self, event_ids):
|
def get_auth_chain(self, event_ids):
|
||||||
return self.runInteraction(
|
return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
|
||||||
"get_auth_chain",
|
|
||||||
self._get_auth_chain_txn,
|
|
||||||
event_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_auth_chain_txn(self, txn, event_ids):
|
|
||||||
results = self._get_auth_chain_ids_txn(txn, event_ids)
|
|
||||||
|
|
||||||
return self._get_events_txn(txn, results)
|
|
||||||
|
|
||||||
def get_auth_chain_ids(self, event_ids):
|
def get_auth_chain_ids(self, event_ids):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
|
@ -79,6 +73,28 @@ class EventFederationStore(SQLBaseStore):
|
||||||
room_id,
|
room_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_oldest_events_with_depth_in_room(self, room_id):
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_oldest_events_with_depth_in_room",
|
||||||
|
self.get_oldest_events_with_depth_in_room_txn,
|
||||||
|
room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
|
||||||
|
sql = (
|
||||||
|
"SELECT b.event_id, MAX(e.depth) FROM events as e"
|
||||||
|
" INNER JOIN event_edges as g"
|
||||||
|
" ON g.event_id = e.event_id AND g.room_id = e.room_id"
|
||||||
|
" INNER JOIN event_backward_extremities as b"
|
||||||
|
" ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
|
||||||
|
" WHERE b.room_id = ? AND g.is_state is ?"
|
||||||
|
" GROUP BY b.event_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(sql, (room_id, False,))
|
||||||
|
|
||||||
|
return dict(txn.fetchall())
|
||||||
|
|
||||||
def _get_oldest_events_in_room_txn(self, txn, room_id):
|
def _get_oldest_events_in_room_txn(self, txn, room_id):
|
||||||
return self._simple_select_onecol_txn(
|
return self._simple_select_onecol_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -247,11 +263,13 @@ class EventFederationStore(SQLBaseStore):
|
||||||
do_insert = depth < min_depth if min_depth else True
|
do_insert = depth < min_depth if min_depth else True
|
||||||
|
|
||||||
if do_insert:
|
if do_insert:
|
||||||
self._simple_insert_txn(
|
self._simple_upsert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="room_depth",
|
table="room_depth",
|
||||||
values={
|
keyvalues={
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
|
},
|
||||||
|
values={
|
||||||
"min_depth": depth,
|
"min_depth": depth,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -306,31 +324,28 @@ class EventFederationStore(SQLBaseStore):
|
||||||
|
|
||||||
txn.execute(query, (event_id, room_id))
|
txn.execute(query, (event_id, room_id))
|
||||||
|
|
||||||
# Insert all the prev_events as a backwards thing, they'll get
|
query = (
|
||||||
# deleted in a second if they're incorrect anyway.
|
"INSERT INTO event_backward_extremities (event_id, room_id)"
|
||||||
self._simple_insert_many_txn(
|
" SELECT ?, ? WHERE NOT EXISTS ("
|
||||||
txn,
|
" SELECT 1 FROM event_backward_extremities"
|
||||||
table="event_backward_extremities",
|
" WHERE event_id = ? AND room_id = ?"
|
||||||
values=[
|
" )"
|
||||||
{
|
" AND NOT EXISTS ("
|
||||||
"event_id": e_id,
|
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
|
||||||
"room_id": room_id,
|
" AND outlier = ?"
|
||||||
}
|
" )"
|
||||||
for e_id, _ in prev_events
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Also delete from the backwards extremities table all ones that
|
txn.executemany(query, [
|
||||||
# reference events that we have already seen
|
(e_id, room_id, e_id, room_id, e_id, room_id, False)
|
||||||
|
for e_id, _ in prev_events
|
||||||
|
])
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"DELETE FROM event_backward_extremities WHERE EXISTS ("
|
"DELETE FROM event_backward_extremities"
|
||||||
"SELECT 1 FROM events "
|
" WHERE event_id = ? AND room_id = ?"
|
||||||
"WHERE "
|
|
||||||
"event_backward_extremities.event_id = events.event_id "
|
|
||||||
"AND not events.outlier "
|
|
||||||
")"
|
|
||||||
)
|
)
|
||||||
txn.execute(query)
|
txn.execute(query, (event_id, room_id))
|
||||||
|
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_latest_event_ids_in_room.invalidate, room_id
|
self.get_latest_event_ids_in_room.invalidate, room_id
|
||||||
|
@ -349,6 +364,10 @@ class EventFederationStore(SQLBaseStore):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_backfill_events",
|
"get_backfill_events",
|
||||||
self._get_backfill_events, room_id, event_list, limit
|
self._get_backfill_events, room_id, event_list, limit
|
||||||
|
).addCallback(
|
||||||
|
self._get_events
|
||||||
|
).addCallback(
|
||||||
|
lambda l: sorted(l, key=lambda e: -e.depth)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_backfill_events(self, txn, room_id, event_list, limit):
|
def _get_backfill_events(self, txn, room_id, event_list, limit):
|
||||||
|
@ -357,54 +376,75 @@ class EventFederationStore(SQLBaseStore):
|
||||||
room_id, repr(event_list), limit
|
room_id, repr(event_list), limit
|
||||||
)
|
)
|
||||||
|
|
||||||
event_results = event_list
|
event_results = set()
|
||||||
|
|
||||||
front = event_list
|
# We want to make sure that we do a breadth-first, "depth" ordered
|
||||||
|
# search.
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"SELECT prev_event_id FROM event_edges "
|
"SELECT depth, prev_event_id FROM event_edges"
|
||||||
"WHERE room_id = ? AND event_id = ? "
|
" INNER JOIN events"
|
||||||
"LIMIT ?"
|
" ON prev_event_id = events.event_id"
|
||||||
|
" AND event_edges.room_id = events.room_id"
|
||||||
|
" WHERE event_edges.room_id = ? AND event_edges.event_id = ?"
|
||||||
|
" AND event_edges.is_state = ?"
|
||||||
|
" LIMIT ?"
|
||||||
)
|
)
|
||||||
|
|
||||||
# We iterate through all event_ids in `front` to select their previous
|
queue = PriorityQueue()
|
||||||
# events. These are dumped in `new_front`.
|
|
||||||
# We continue until we reach the limit *or* new_front is empty (i.e.,
|
|
||||||
# we've run out of things to select
|
|
||||||
while front and len(event_results) < limit:
|
|
||||||
|
|
||||||
new_front = []
|
for event_id in event_list:
|
||||||
for event_id in front:
|
depth = self._simple_select_one_onecol_txn(
|
||||||
logger.debug(
|
txn,
|
||||||
"_backfill_interaction: id=%s",
|
table="events",
|
||||||
event_id
|
keyvalues={
|
||||||
)
|
"event_id": event_id,
|
||||||
|
},
|
||||||
|
retcol="depth"
|
||||||
|
)
|
||||||
|
|
||||||
txn.execute(
|
queue.put((-depth, event_id))
|
||||||
query,
|
|
||||||
(room_id, event_id, limit - len(event_results))
|
|
||||||
)
|
|
||||||
|
|
||||||
for row in txn.fetchall():
|
while not queue.empty() and len(event_results) < limit:
|
||||||
logger.debug(
|
try:
|
||||||
"_backfill_interaction: got id=%s",
|
_, event_id = queue.get_nowait()
|
||||||
*row
|
except Empty:
|
||||||
)
|
break
|
||||||
new_front.append(row[0])
|
|
||||||
|
|
||||||
front = new_front
|
if event_id in event_results:
|
||||||
event_results += new_front
|
continue
|
||||||
|
|
||||||
return self._get_events_txn(txn, event_results)
|
event_results.add(event_id)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
query,
|
||||||
|
(room_id, event_id, False, limit - len(event_results))
|
||||||
|
)
|
||||||
|
|
||||||
|
for row in txn.fetchall():
|
||||||
|
if row[1] not in event_results:
|
||||||
|
queue.put((-row[0], row[1]))
|
||||||
|
|
||||||
|
return event_results
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def get_missing_events(self, room_id, earliest_events, latest_events,
|
def get_missing_events(self, room_id, earliest_events, latest_events,
|
||||||
limit, min_depth):
|
limit, min_depth):
|
||||||
return self.runInteraction(
|
ids = yield self.runInteraction(
|
||||||
"get_missing_events",
|
"get_missing_events",
|
||||||
self._get_missing_events,
|
self._get_missing_events,
|
||||||
room_id, earliest_events, latest_events, limit, min_depth
|
room_id, earliest_events, latest_events, limit, min_depth
|
||||||
)
|
)
|
||||||
|
|
||||||
|
events = yield self._get_events(ids)
|
||||||
|
|
||||||
|
events = sorted(
|
||||||
|
[ev for ev in events if ev.depth >= min_depth],
|
||||||
|
key=lambda e: e.depth,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(events[:limit])
|
||||||
|
|
||||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
|
def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
|
||||||
limit, min_depth):
|
limit, min_depth):
|
||||||
|
|
||||||
|
@ -436,14 +476,7 @@ class EventFederationStore(SQLBaseStore):
|
||||||
front = new_front
|
front = new_front
|
||||||
event_results |= new_front
|
event_results |= new_front
|
||||||
|
|
||||||
events = self._get_events_txn(txn, event_results)
|
return event_results
|
||||||
|
|
||||||
events = sorted(
|
|
||||||
[ev for ev in events if ev.depth >= min_depth],
|
|
||||||
key=lambda e: e.depth,
|
|
||||||
)
|
|
||||||
|
|
||||||
return events[:limit]
|
|
||||||
|
|
||||||
def clean_room_for_join(self, room_id):
|
def clean_room_for_join(self, room_id):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
|
@ -456,3 +489,4 @@ class EventFederationStore(SQLBaseStore):
|
||||||
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
||||||
|
|
||||||
txn.execute(query, (room_id,))
|
txn.execute(query, (room_id,))
|
||||||
|
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
|
||||||
|
|
|
@ -15,20 +15,36 @@
|
||||||
|
|
||||||
from _base import SQLBaseStore, _RollbackButIsFineException
|
from _base import SQLBaseStore, _RollbackButIsFineException
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
|
from synapse.events import FrozenEvent
|
||||||
|
from synapse.events.utils import prune_event
|
||||||
|
|
||||||
|
from synapse.util.logcontext import preserve_context_over_deferred
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||||
|
|
||||||
from syutil.base64util import decode_base64
|
from syutil.base64util import decode_base64
|
||||||
from syutil.jsonutil import encode_canonical_json
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import simplejson as json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# These values are used in the `enqueus_event` and `_do_fetch` methods to
|
||||||
|
# control how we batch/bulk fetch events from the database.
|
||||||
|
# The values are plucked out of thing air to make initial sync run faster
|
||||||
|
# on jki.re
|
||||||
|
# TODO: Make these configurable.
|
||||||
|
EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
|
||||||
|
EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
|
||||||
|
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
|
||||||
|
|
||||||
|
|
||||||
class EventsStore(SQLBaseStore):
|
class EventsStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -41,20 +57,32 @@ class EventsStore(SQLBaseStore):
|
||||||
self.min_token -= 1
|
self.min_token -= 1
|
||||||
stream_ordering = self.min_token
|
stream_ordering = self.min_token
|
||||||
|
|
||||||
|
if stream_ordering is None:
|
||||||
|
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
|
||||||
|
else:
|
||||||
|
@contextmanager
|
||||||
|
def stream_ordering_manager():
|
||||||
|
yield stream_ordering
|
||||||
|
stream_ordering_manager = stream_ordering_manager()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.runInteraction(
|
with stream_ordering_manager as stream_ordering:
|
||||||
"persist_event",
|
yield self.runInteraction(
|
||||||
self._persist_event_txn,
|
"persist_event",
|
||||||
event=event,
|
self._persist_event_txn,
|
||||||
context=context,
|
event=event,
|
||||||
backfilled=backfilled,
|
context=context,
|
||||||
stream_ordering=stream_ordering,
|
backfilled=backfilled,
|
||||||
is_new_state=is_new_state,
|
stream_ordering=stream_ordering,
|
||||||
current_state=current_state,
|
is_new_state=is_new_state,
|
||||||
)
|
current_state=current_state,
|
||||||
|
)
|
||||||
except _RollbackButIsFineException:
|
except _RollbackButIsFineException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
|
||||||
|
defer.returnValue((stream_ordering, max_persisted_id))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_event(self, event_id, check_redacted=True,
|
def get_event(self, event_id, check_redacted=True,
|
||||||
get_prev_content=False, allow_rejected=False,
|
get_prev_content=False, allow_rejected=False,
|
||||||
|
@ -74,18 +102,17 @@ class EventsStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
Deferred : A FrozenEvent.
|
Deferred : A FrozenEvent.
|
||||||
"""
|
"""
|
||||||
event = yield self.runInteraction(
|
events = yield self._get_events(
|
||||||
"get_event", self._get_event_txn,
|
[event_id],
|
||||||
event_id,
|
|
||||||
check_redacted=check_redacted,
|
check_redacted=check_redacted,
|
||||||
get_prev_content=get_prev_content,
|
get_prev_content=get_prev_content,
|
||||||
allow_rejected=allow_rejected,
|
allow_rejected=allow_rejected,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not event and not allow_none:
|
if not events and not allow_none:
|
||||||
raise RuntimeError("Could not find event %s" % (event_id,))
|
raise RuntimeError("Could not find event %s" % (event_id,))
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(events[0] if events else None)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _persist_event_txn(self, txn, event, context, backfilled,
|
def _persist_event_txn(self, txn, event, context, backfilled,
|
||||||
|
@ -95,15 +122,6 @@ class EventsStore(SQLBaseStore):
|
||||||
# Remove the any existing cache entries for the event_id
|
# Remove the any existing cache entries for the event_id
|
||||||
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
||||||
|
|
||||||
if stream_ordering is None:
|
|
||||||
with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
|
|
||||||
return self._persist_event_txn(
|
|
||||||
txn, event, context, backfilled,
|
|
||||||
stream_ordering=stream_ordering,
|
|
||||||
is_new_state=is_new_state,
|
|
||||||
current_state=current_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
# We purposefully do this first since if we include a `current_state`
|
# We purposefully do this first since if we include a `current_state`
|
||||||
# key, we *want* to update the `current_state_events` table
|
# key, we *want* to update the `current_state_events` table
|
||||||
if current_state:
|
if current_state:
|
||||||
|
@ -134,19 +152,17 @@ class EventsStore(SQLBaseStore):
|
||||||
outlier = event.internal_metadata.is_outlier()
|
outlier = event.internal_metadata.is_outlier()
|
||||||
|
|
||||||
if not outlier:
|
if not outlier:
|
||||||
self._store_state_groups_txn(txn, event, context)
|
|
||||||
|
|
||||||
self._update_min_depth_for_room_txn(
|
self._update_min_depth_for_room_txn(
|
||||||
txn,
|
txn,
|
||||||
event.room_id,
|
event.room_id,
|
||||||
event.depth
|
event.depth
|
||||||
)
|
)
|
||||||
|
|
||||||
have_persisted = self._simple_select_one_onecol_txn(
|
have_persisted = self._simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_json",
|
table="events",
|
||||||
keyvalues={"event_id": event.event_id},
|
keyvalues={"event_id": event.event_id},
|
||||||
retcol="event_id",
|
retcols=["event_id", "outlier"],
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -161,7 +177,9 @@ class EventsStore(SQLBaseStore):
|
||||||
# if we are persisting an event that we had persisted as an outlier,
|
# if we are persisting an event that we had persisted as an outlier,
|
||||||
# but is no longer one.
|
# but is no longer one.
|
||||||
if have_persisted:
|
if have_persisted:
|
||||||
if not outlier:
|
if not outlier and have_persisted["outlier"]:
|
||||||
|
self._store_state_groups_txn(txn, event, context)
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"UPDATE event_json SET internal_metadata = ?"
|
"UPDATE event_json SET internal_metadata = ?"
|
||||||
" WHERE event_id = ?"
|
" WHERE event_id = ?"
|
||||||
|
@ -181,6 +199,9 @@ class EventsStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not outlier:
|
||||||
|
self._store_state_groups_txn(txn, event, context)
|
||||||
|
|
||||||
self._handle_prev_events(
|
self._handle_prev_events(
|
||||||
txn,
|
txn,
|
||||||
outlier=outlier,
|
outlier=outlier,
|
||||||
|
@ -400,3 +421,407 @@ class EventsStore(SQLBaseStore):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"have_events", f,
|
"have_events", f,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_events(self, event_ids, check_redacted=True,
|
||||||
|
get_prev_content=False, allow_rejected=False):
|
||||||
|
if not event_ids:
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
|
event_map = self._get_events_from_cache(
|
||||||
|
event_ids,
|
||||||
|
check_redacted=check_redacted,
|
||||||
|
get_prev_content=get_prev_content,
|
||||||
|
allow_rejected=allow_rejected,
|
||||||
|
)
|
||||||
|
|
||||||
|
missing_events_ids = [e for e in event_ids if e not in event_map]
|
||||||
|
|
||||||
|
if not missing_events_ids:
|
||||||
|
defer.returnValue([
|
||||||
|
event_map[e_id] for e_id in event_ids
|
||||||
|
if e_id in event_map and event_map[e_id]
|
||||||
|
])
|
||||||
|
|
||||||
|
missing_events = yield self._enqueue_events(
|
||||||
|
missing_events_ids,
|
||||||
|
check_redacted=check_redacted,
|
||||||
|
get_prev_content=get_prev_content,
|
||||||
|
allow_rejected=allow_rejected,
|
||||||
|
)
|
||||||
|
|
||||||
|
event_map.update(missing_events)
|
||||||
|
|
||||||
|
defer.returnValue([
|
||||||
|
event_map[e_id] for e_id in event_ids
|
||||||
|
if e_id in event_map and event_map[e_id]
|
||||||
|
])
|
||||||
|
|
||||||
|
def _get_events_txn(self, txn, event_ids, check_redacted=True,
|
||||||
|
get_prev_content=False, allow_rejected=False):
|
||||||
|
if not event_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
event_map = self._get_events_from_cache(
|
||||||
|
event_ids,
|
||||||
|
check_redacted=check_redacted,
|
||||||
|
get_prev_content=get_prev_content,
|
||||||
|
allow_rejected=allow_rejected,
|
||||||
|
)
|
||||||
|
|
||||||
|
missing_events_ids = [e for e in event_ids if e not in event_map]
|
||||||
|
|
||||||
|
if not missing_events_ids:
|
||||||
|
return [
|
||||||
|
event_map[e_id] for e_id in event_ids
|
||||||
|
if e_id in event_map and event_map[e_id]
|
||||||
|
]
|
||||||
|
|
||||||
|
missing_events = self._fetch_events_txn(
|
||||||
|
txn,
|
||||||
|
missing_events_ids,
|
||||||
|
check_redacted=check_redacted,
|
||||||
|
get_prev_content=get_prev_content,
|
||||||
|
allow_rejected=allow_rejected,
|
||||||
|
)
|
||||||
|
|
||||||
|
event_map.update(missing_events)
|
||||||
|
|
||||||
|
return [
|
||||||
|
event_map[e_id] for e_id in event_ids
|
||||||
|
if e_id in event_map and event_map[e_id]
|
||||||
|
]
|
||||||
|
|
||||||
|
def _invalidate_get_event_cache(self, event_id):
|
||||||
|
for check_redacted in (False, True):
|
||||||
|
for get_prev_content in (False, True):
|
||||||
|
self._get_event_cache.invalidate(event_id, check_redacted,
|
||||||
|
get_prev_content)
|
||||||
|
|
||||||
|
def _get_event_txn(self, txn, event_id, check_redacted=True,
|
||||||
|
get_prev_content=False, allow_rejected=False):
|
||||||
|
|
||||||
|
events = self._get_events_txn(
|
||||||
|
txn, [event_id],
|
||||||
|
check_redacted=check_redacted,
|
||||||
|
get_prev_content=get_prev_content,
|
||||||
|
allow_rejected=allow_rejected,
|
||||||
|
)
|
||||||
|
|
||||||
|
return events[0] if events else None
|
||||||
|
|
||||||
|
def _get_events_from_cache(self, events, check_redacted, get_prev_content,
|
||||||
|
allow_rejected):
|
||||||
|
event_map = {}
|
||||||
|
|
||||||
|
for event_id in events:
|
||||||
|
try:
|
||||||
|
ret = self._get_event_cache.get(
|
||||||
|
event_id, check_redacted, get_prev_content
|
||||||
|
)
|
||||||
|
|
||||||
|
if allow_rejected or not ret.rejected_reason:
|
||||||
|
event_map[event_id] = ret
|
||||||
|
else:
|
||||||
|
event_map[event_id] = None
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return event_map
|
||||||
|
|
||||||
|
def _do_fetch(self, conn):
|
||||||
|
"""Takes a database connection and waits for requests for events from
|
||||||
|
the _event_fetch_list queue.
|
||||||
|
"""
|
||||||
|
event_list = []
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
with self._event_fetch_lock:
|
||||||
|
event_list = self._event_fetch_list
|
||||||
|
self._event_fetch_list = []
|
||||||
|
|
||||||
|
if not event_list:
|
||||||
|
single_threaded = self.database_engine.single_threaded
|
||||||
|
if single_threaded or i > EVENT_QUEUE_ITERATIONS:
|
||||||
|
self._event_fetch_ongoing -= 1
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
event_id_lists = zip(*event_list)[0]
|
||||||
|
event_ids = [
|
||||||
|
item for sublist in event_id_lists for item in sublist
|
||||||
|
]
|
||||||
|
|
||||||
|
rows = self._new_transaction(
|
||||||
|
conn, "do_fetch", [], self._fetch_event_rows, event_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
row_dict = {
|
||||||
|
r["event_id"]: r
|
||||||
|
for r in rows
|
||||||
|
}
|
||||||
|
|
||||||
|
# We only want to resolve deferreds from the main thread
|
||||||
|
def fire(lst, res):
|
||||||
|
for ids, d in lst:
|
||||||
|
if not d.called:
|
||||||
|
try:
|
||||||
|
d.callback([
|
||||||
|
res[i]
|
||||||
|
for i in ids
|
||||||
|
if i in res
|
||||||
|
])
|
||||||
|
except:
|
||||||
|
logger.exception("Failed to callback")
|
||||||
|
reactor.callFromThread(fire, event_list, row_dict)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("do_fetch")
|
||||||
|
|
||||||
|
# We only want to resolve deferreds from the main thread
|
||||||
|
def fire(evs):
|
||||||
|
for _, d in evs:
|
||||||
|
if not d.called:
|
||||||
|
d.errback(e)
|
||||||
|
|
||||||
|
if event_list:
|
||||||
|
reactor.callFromThread(fire, event_list)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _enqueue_events(self, events, check_redacted=True,
|
||||||
|
get_prev_content=False, allow_rejected=False):
|
||||||
|
"""Fetches events from the database using the _event_fetch_list. This
|
||||||
|
allows batch and bulk fetching of events - it allows us to fetch events
|
||||||
|
without having to create a new transaction for each request for events.
|
||||||
|
"""
|
||||||
|
if not events:
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
events_d = defer.Deferred()
|
||||||
|
with self._event_fetch_lock:
|
||||||
|
self._event_fetch_list.append(
|
||||||
|
(events, events_d)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._event_fetch_lock.notify()
|
||||||
|
|
||||||
|
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
|
||||||
|
self._event_fetch_ongoing += 1
|
||||||
|
should_start = True
|
||||||
|
else:
|
||||||
|
should_start = False
|
||||||
|
|
||||||
|
if should_start:
|
||||||
|
self.runWithConnection(
|
||||||
|
self._do_fetch
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = yield preserve_context_over_deferred(events_d)
|
||||||
|
|
||||||
|
if not allow_rejected:
|
||||||
|
rows[:] = [r for r in rows if not r["rejects"]]
|
||||||
|
|
||||||
|
res = yield defer.gatherResults(
|
||||||
|
[
|
||||||
|
self._get_event_from_row(
|
||||||
|
row["internal_metadata"], row["json"], row["redacts"],
|
||||||
|
check_redacted=check_redacted,
|
||||||
|
get_prev_content=get_prev_content,
|
||||||
|
rejected_reason=row["rejects"],
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
],
|
||||||
|
consumeErrors=True
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
e.event_id: e
|
||||||
|
for e in res if e
|
||||||
|
})
|
||||||
|
|
||||||
|
def _fetch_event_rows(self, txn, events):
|
||||||
|
rows = []
|
||||||
|
N = 200
|
||||||
|
for i in range(1 + len(events) / N):
|
||||||
|
evs = events[i*N:(i + 1)*N]
|
||||||
|
if not evs:
|
||||||
|
break
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"SELECT "
|
||||||
|
" e.event_id as event_id, "
|
||||||
|
" e.internal_metadata,"
|
||||||
|
" e.json,"
|
||||||
|
" r.redacts as redacts,"
|
||||||
|
" rej.event_id as rejects "
|
||||||
|
" FROM event_json as e"
|
||||||
|
" LEFT JOIN rejections as rej USING (event_id)"
|
||||||
|
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
|
||||||
|
" WHERE e.event_id IN (%s)"
|
||||||
|
) % (",".join(["?"]*len(evs)),)
|
||||||
|
|
||||||
|
txn.execute(sql, evs)
|
||||||
|
rows.extend(self.cursor_to_dict(txn))
|
||||||
|
|
||||||
|
return rows
|
||||||
|
|
||||||
|
def _fetch_events_txn(self, txn, events, check_redacted=True,
|
||||||
|
get_prev_content=False, allow_rejected=False):
|
||||||
|
if not events:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
rows = self._fetch_event_rows(
|
||||||
|
txn, events,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not allow_rejected:
|
||||||
|
rows[:] = [r for r in rows if not r["rejects"]]
|
||||||
|
|
||||||
|
res = [
|
||||||
|
self._get_event_from_row_txn(
|
||||||
|
txn,
|
||||||
|
row["internal_metadata"], row["json"], row["redacts"],
|
||||||
|
check_redacted=check_redacted,
|
||||||
|
get_prev_content=get_prev_content,
|
||||||
|
rejected_reason=row["rejects"],
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
r.event_id: r
|
||||||
|
for r in res
|
||||||
|
}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_event_from_row(self, internal_metadata, js, redacted,
|
||||||
|
check_redacted=True, get_prev_content=False,
|
||||||
|
rejected_reason=None):
|
||||||
|
d = json.loads(js)
|
||||||
|
internal_metadata = json.loads(internal_metadata)
|
||||||
|
|
||||||
|
if rejected_reason:
|
||||||
|
rejected_reason = yield self._simple_select_one_onecol(
|
||||||
|
table="rejections",
|
||||||
|
keyvalues={"event_id": rejected_reason},
|
||||||
|
retcol="reason",
|
||||||
|
desc="_get_event_from_row",
|
||||||
|
)
|
||||||
|
|
||||||
|
ev = FrozenEvent(
|
||||||
|
d,
|
||||||
|
internal_metadata_dict=internal_metadata,
|
||||||
|
rejected_reason=rejected_reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
if check_redacted and redacted:
|
||||||
|
ev = prune_event(ev)
|
||||||
|
|
||||||
|
redaction_id = yield self._simple_select_one_onecol(
|
||||||
|
table="redactions",
|
||||||
|
keyvalues={"redacts": ev.event_id},
|
||||||
|
retcol="event_id",
|
||||||
|
desc="_get_event_from_row",
|
||||||
|
)
|
||||||
|
|
||||||
|
ev.unsigned["redacted_by"] = redaction_id
|
||||||
|
# Get the redaction event.
|
||||||
|
|
||||||
|
because = yield self.get_event(
|
||||||
|
redaction_id,
|
||||||
|
check_redacted=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if because:
|
||||||
|
ev.unsigned["redacted_because"] = because
|
||||||
|
|
||||||
|
if get_prev_content and "replaces_state" in ev.unsigned:
|
||||||
|
prev = yield self.get_event(
|
||||||
|
ev.unsigned["replaces_state"],
|
||||||
|
get_prev_content=False,
|
||||||
|
)
|
||||||
|
if prev:
|
||||||
|
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
||||||
|
|
||||||
|
self._get_event_cache.prefill(
|
||||||
|
ev.event_id, check_redacted, get_prev_content, ev
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(ev)
|
||||||
|
|
||||||
|
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
|
||||||
|
check_redacted=True, get_prev_content=False,
|
||||||
|
rejected_reason=None):
|
||||||
|
d = json.loads(js)
|
||||||
|
internal_metadata = json.loads(internal_metadata)
|
||||||
|
|
||||||
|
if rejected_reason:
|
||||||
|
rejected_reason = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="rejections",
|
||||||
|
keyvalues={"event_id": rejected_reason},
|
||||||
|
retcol="reason",
|
||||||
|
)
|
||||||
|
|
||||||
|
ev = FrozenEvent(
|
||||||
|
d,
|
||||||
|
internal_metadata_dict=internal_metadata,
|
||||||
|
rejected_reason=rejected_reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
if check_redacted and redacted:
|
||||||
|
ev = prune_event(ev)
|
||||||
|
|
||||||
|
redaction_id = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="redactions",
|
||||||
|
keyvalues={"redacts": ev.event_id},
|
||||||
|
retcol="event_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
ev.unsigned["redacted_by"] = redaction_id
|
||||||
|
# Get the redaction event.
|
||||||
|
|
||||||
|
because = self._get_event_txn(
|
||||||
|
txn,
|
||||||
|
redaction_id,
|
||||||
|
check_redacted=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if because:
|
||||||
|
ev.unsigned["redacted_because"] = because
|
||||||
|
|
||||||
|
if get_prev_content and "replaces_state" in ev.unsigned:
|
||||||
|
prev = self._get_event_txn(
|
||||||
|
txn,
|
||||||
|
ev.unsigned["replaces_state"],
|
||||||
|
get_prev_content=False,
|
||||||
|
)
|
||||||
|
if prev:
|
||||||
|
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
||||||
|
|
||||||
|
self._get_event_cache.prefill(
|
||||||
|
ev.event_id, check_redacted, get_prev_content, ev
|
||||||
|
)
|
||||||
|
|
||||||
|
return ev
|
||||||
|
|
||||||
|
def _parse_events(self, rows):
|
||||||
|
return self.runInteraction(
|
||||||
|
"_parse_events", self._parse_events_txn, rows
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_events_txn(self, txn, rows):
|
||||||
|
event_ids = [r["event_id"] for r in rows]
|
||||||
|
|
||||||
|
return self._get_events_txn(txn, event_ids)
|
||||||
|
|
||||||
|
def _has_been_redacted_txn(self, txn, event):
|
||||||
|
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
|
||||||
|
txn.execute(sql, (event.event_id,))
|
||||||
|
result = txn.fetchone()
|
||||||
|
return result[0] if result else None
|
||||||
|
|
|
@ -13,7 +13,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore, cached
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
|
||||||
class PresenceStore(SQLBaseStore):
|
class PresenceStore(SQLBaseStore):
|
||||||
|
@ -87,31 +89,48 @@ class PresenceStore(SQLBaseStore):
|
||||||
desc="add_presence_list_pending",
|
desc="add_presence_list_pending",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
||||||
return self._simple_update_one(
|
result = yield self._simple_update_one(
|
||||||
table="presence_list",
|
table="presence_list",
|
||||||
keyvalues={"user_id": observer_localpart,
|
keyvalues={"user_id": observer_localpart,
|
||||||
"observed_user_id": observed_userid},
|
"observed_user_id": observed_userid},
|
||||||
updatevalues={"accepted": True},
|
updatevalues={"accepted": True},
|
||||||
desc="set_presence_list_accepted",
|
desc="set_presence_list_accepted",
|
||||||
)
|
)
|
||||||
|
self.get_presence_list_accepted.invalidate(observer_localpart)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
def get_presence_list(self, observer_localpart, accepted=None):
|
def get_presence_list(self, observer_localpart, accepted=None):
|
||||||
keyvalues = {"user_id": observer_localpart}
|
if accepted:
|
||||||
if accepted is not None:
|
return self.get_presence_list_accepted(observer_localpart)
|
||||||
keyvalues["accepted"] = accepted
|
else:
|
||||||
|
keyvalues = {"user_id": observer_localpart}
|
||||||
|
if accepted is not None:
|
||||||
|
keyvalues["accepted"] = accepted
|
||||||
|
|
||||||
|
return self._simple_select_list(
|
||||||
|
table="presence_list",
|
||||||
|
keyvalues=keyvalues,
|
||||||
|
retcols=["observed_user_id", "accepted"],
|
||||||
|
desc="get_presence_list",
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def get_presence_list_accepted(self, observer_localpart):
|
||||||
return self._simple_select_list(
|
return self._simple_select_list(
|
||||||
table="presence_list",
|
table="presence_list",
|
||||||
keyvalues=keyvalues,
|
keyvalues={"user_id": observer_localpart, "accepted": True},
|
||||||
retcols=["observed_user_id", "accepted"],
|
retcols=["observed_user_id", "accepted"],
|
||||||
desc="get_presence_list",
|
desc="get_presence_list_accepted",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def del_presence_list(self, observer_localpart, observed_userid):
|
def del_presence_list(self, observer_localpart, observed_userid):
|
||||||
return self._simple_delete_one(
|
yield self._simple_delete_one(
|
||||||
table="presence_list",
|
table="presence_list",
|
||||||
keyvalues={"user_id": observer_localpart,
|
keyvalues={"user_id": observer_localpart,
|
||||||
"observed_user_id": observed_userid},
|
"observed_user_id": observed_userid},
|
||||||
desc="del_presence_list",
|
desc="del_presence_list",
|
||||||
)
|
)
|
||||||
|
self.get_presence_list_accepted.invalidate(observer_localpart)
|
||||||
|
|
|
@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PushRuleStore(SQLBaseStore):
|
class PushRuleStore(SQLBaseStore):
|
||||||
|
@cached()
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_push_rules_for_user(self, user_name):
|
def get_push_rules_for_user(self, user_name):
|
||||||
rows = yield self._simple_select_list(
|
rows = yield self._simple_select_list(
|
||||||
|
@ -31,6 +32,7 @@ class PushRuleStore(SQLBaseStore):
|
||||||
"user_name": user_name,
|
"user_name": user_name,
|
||||||
},
|
},
|
||||||
retcols=PushRuleTable.fields,
|
retcols=PushRuleTable.fields,
|
||||||
|
desc="get_push_rules_enabled_for_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
rows.sort(
|
rows.sort(
|
||||||
|
@ -150,6 +152,10 @@ class PushRuleStore(SQLBaseStore):
|
||||||
|
|
||||||
txn.execute(sql, (user_name, priority_class, new_rule_priority))
|
txn.execute(sql, (user_name, priority_class, new_rule_priority))
|
||||||
|
|
||||||
|
txn.call_after(
|
||||||
|
self.get_push_rules_for_user.invalidate, user_name
|
||||||
|
)
|
||||||
|
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_push_rules_enabled_for_user.invalidate, user_name
|
self.get_push_rules_enabled_for_user.invalidate, user_name
|
||||||
)
|
)
|
||||||
|
@ -182,6 +188,9 @@ class PushRuleStore(SQLBaseStore):
|
||||||
new_rule['priority_class'] = priority_class
|
new_rule['priority_class'] = priority_class
|
||||||
new_rule['priority'] = new_prio
|
new_rule['priority'] = new_prio
|
||||||
|
|
||||||
|
txn.call_after(
|
||||||
|
self.get_push_rules_for_user.invalidate, user_name
|
||||||
|
)
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_push_rules_enabled_for_user.invalidate, user_name
|
self.get_push_rules_enabled_for_user.invalidate, user_name
|
||||||
)
|
)
|
||||||
|
@ -208,17 +217,34 @@ class PushRuleStore(SQLBaseStore):
|
||||||
{'user_name': user_name, 'rule_id': rule_id},
|
{'user_name': user_name, 'rule_id': rule_id},
|
||||||
desc="delete_push_rule",
|
desc="delete_push_rule",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.get_push_rules_for_user.invalidate(user_name)
|
||||||
self.get_push_rules_enabled_for_user.invalidate(user_name)
|
self.get_push_rules_enabled_for_user.invalidate(user_name)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_push_rule_enabled(self, user_name, rule_id, enabled):
|
def set_push_rule_enabled(self, user_name, rule_id, enabled):
|
||||||
yield self._simple_upsert(
|
ret = yield self.runInteraction(
|
||||||
|
"_set_push_rule_enabled_txn",
|
||||||
|
self._set_push_rule_enabled_txn,
|
||||||
|
user_name, rule_id, enabled
|
||||||
|
)
|
||||||
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled):
|
||||||
|
new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
|
||||||
|
self._simple_upsert_txn(
|
||||||
|
txn,
|
||||||
PushRuleEnableTable.table_name,
|
PushRuleEnableTable.table_name,
|
||||||
{'user_name': user_name, 'rule_id': rule_id},
|
{'user_name': user_name, 'rule_id': rule_id},
|
||||||
{'enabled': 1 if enabled else 0},
|
{'enabled': 1 if enabled else 0},
|
||||||
desc="set_push_rule_enabled",
|
{'id': new_id},
|
||||||
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self.get_push_rules_for_user.invalidate, user_name
|
||||||
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self.get_push_rules_enabled_for_user.invalidate, user_name
|
||||||
)
|
)
|
||||||
self.get_push_rules_enabled_for_user.invalidate(user_name)
|
|
||||||
|
|
||||||
|
|
||||||
class RuleNotFoundException(Exception):
|
class RuleNotFoundException(Exception):
|
||||||
|
|
|
@ -77,16 +77,16 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Results in a MembershipEvent or None.
|
Deferred: Results in a MembershipEvent or None.
|
||||||
"""
|
"""
|
||||||
def f(txn):
|
return self.runInteraction(
|
||||||
events = self._get_members_events_txn(
|
"get_room_member",
|
||||||
txn,
|
self._get_members_events_txn,
|
||||||
room_id,
|
room_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
).addCallback(
|
||||||
|
self._get_events
|
||||||
return events[0] if events else None
|
).addCallback(
|
||||||
|
lambda events: events[0] if events else None
|
||||||
return self.runInteraction("get_room_member", f)
|
)
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_users_in_room(self, room_id):
|
def get_users_in_room(self, room_id):
|
||||||
|
@ -112,15 +112,12 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
list of namedtuples representing the members in this room.
|
list of namedtuples representing the members in this room.
|
||||||
"""
|
"""
|
||||||
|
return self.runInteraction(
|
||||||
def f(txn):
|
"get_room_members",
|
||||||
return self._get_members_events_txn(
|
self._get_members_events_txn,
|
||||||
txn,
|
room_id,
|
||||||
room_id,
|
membership=membership,
|
||||||
membership=membership,
|
).addCallback(self._get_events)
|
||||||
)
|
|
||||||
|
|
||||||
return self.runInteraction("get_room_members", f)
|
|
||||||
|
|
||||||
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
|
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
|
||||||
""" Get all the rooms for this user where the membership for this user
|
""" Get all the rooms for this user where the membership for this user
|
||||||
|
@ -192,14 +189,14 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_members_query", self._get_members_events_txn,
|
"get_members_query", self._get_members_events_txn,
|
||||||
where_clause, where_values
|
where_clause, where_values
|
||||||
)
|
).addCallbacks(self._get_events)
|
||||||
|
|
||||||
def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
|
def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
|
||||||
rows = self._get_members_rows_txn(
|
rows = self._get_members_rows_txn(
|
||||||
txn,
|
txn,
|
||||||
room_id, membership, user_id,
|
room_id, membership, user_id,
|
||||||
)
|
)
|
||||||
return self._get_events_txn(txn, [r["event_id"] for r in rows])
|
return [r["event_id"] for r in rows]
|
||||||
|
|
||||||
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
|
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
|
||||||
where_clause = "c.room_id = ?"
|
where_clause = "c.room_id = ?"
|
||||||
|
|
19
synapse/storage/schema/delta/19/event_index.sql
Normal file
19
synapse/storage/schema/delta/19/event_index.sql
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
/* Copyright 2015 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
CREATE INDEX events_order_topo_stream_room ON events(
|
||||||
|
topological_ordering, stream_ordering, room_id
|
||||||
|
);
|
|
@ -43,6 +43,7 @@ class StateStore(SQLBaseStore):
|
||||||
* `state_groups_state`: Maps state group to state events.
|
* `state_groups_state`: Maps state group to state events.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def get_state_groups(self, event_ids):
|
def get_state_groups(self, event_ids):
|
||||||
""" Get the state groups for the given list of event_ids
|
""" Get the state groups for the given list of event_ids
|
||||||
|
|
||||||
|
@ -71,17 +72,29 @@ class StateStore(SQLBaseStore):
|
||||||
retcol="event_id",
|
retcol="event_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
state = self._get_events_txn(txn, state_ids)
|
res[group] = state_ids
|
||||||
|
|
||||||
res[group] = state
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return self.runInteraction(
|
states = yield self.runInteraction(
|
||||||
"get_state_groups",
|
"get_state_groups",
|
||||||
f,
|
f,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def c(vals):
|
||||||
|
vals[:] = yield self._get_events(vals, get_prev_content=False)
|
||||||
|
|
||||||
|
yield defer.gatherResults(
|
||||||
|
[
|
||||||
|
c(vals)
|
||||||
|
for vals in states.values()
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(states)
|
||||||
|
|
||||||
def _store_state_groups_txn(self, txn, event, context):
|
def _store_state_groups_txn(self, txn, event, context):
|
||||||
if context.current_state is None:
|
if context.current_state is None:
|
||||||
return
|
return
|
||||||
|
@ -152,11 +165,12 @@ class StateStore(SQLBaseStore):
|
||||||
args = (room_id, )
|
args = (room_id, )
|
||||||
|
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
results = self.cursor_to_dict(txn)
|
results = txn.fetchall()
|
||||||
|
|
||||||
return self._parse_events_txn(txn, results)
|
return [r[0] for r in results]
|
||||||
|
|
||||||
events = yield self.runInteraction("get_current_state", f)
|
event_ids = yield self.runInteraction("get_current_state", f)
|
||||||
|
events = yield self._get_events(event_ids, get_prev_content=False)
|
||||||
defer.returnValue(events)
|
defer.returnValue(events)
|
||||||
|
|
||||||
@cached(num_args=3)
|
@cached(num_args=3)
|
||||||
|
|
|
@ -37,11 +37,9 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.types import RoomStreamToken
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,76 +53,26 @@ _STREAM_TOKEN = "stream"
|
||||||
_TOPOLOGICAL_TOKEN = "topological"
|
_TOPOLOGICAL_TOKEN = "topological"
|
||||||
|
|
||||||
|
|
||||||
class _StreamToken(namedtuple("_StreamToken", "topological stream")):
|
def lower_bound(token):
|
||||||
"""Tokens are positions between events. The token "s1" comes after event 1.
|
if token.topological is None:
|
||||||
|
return "(%d < %s)" % (token.stream, "stream_ordering")
|
||||||
|
else:
|
||||||
|
return "(%d < %s OR (%d = %s AND %d < %s))" % (
|
||||||
|
token.topological, "topological_ordering",
|
||||||
|
token.topological, "topological_ordering",
|
||||||
|
token.stream, "stream_ordering",
|
||||||
|
)
|
||||||
|
|
||||||
s0 s1
|
|
||||||
| |
|
|
||||||
[0] V [1] V [2]
|
|
||||||
|
|
||||||
Tokens can either be a point in the live event stream or a cursor going
|
def upper_bound(token):
|
||||||
through historic events.
|
if token.topological is None:
|
||||||
|
return "(%d >= %s)" % (token.stream, "stream_ordering")
|
||||||
When traversing the live event stream events are ordered by when they
|
else:
|
||||||
arrived at the homeserver.
|
return "(%d > %s OR (%d = %s AND %d >= %s))" % (
|
||||||
|
token.topological, "topological_ordering",
|
||||||
When traversing historic events the events are ordered by their depth in
|
token.topological, "topological_ordering",
|
||||||
the event graph "topological_ordering" and then by when they arrived at the
|
token.stream, "stream_ordering",
|
||||||
homeserver "stream_ordering".
|
)
|
||||||
|
|
||||||
Live tokens start with an "s" followed by the "stream_ordering" id of the
|
|
||||||
event it comes after. Historic tokens start with a "t" followed by the
|
|
||||||
"topological_ordering" id of the event it comes after, follewed by "-",
|
|
||||||
followed by the "stream_ordering" id of the event it comes after.
|
|
||||||
"""
|
|
||||||
__slots__ = []
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def parse(cls, string):
|
|
||||||
try:
|
|
||||||
if string[0] == 's':
|
|
||||||
return cls(topological=None, stream=int(string[1:]))
|
|
||||||
if string[0] == 't':
|
|
||||||
parts = string[1:].split('-', 1)
|
|
||||||
return cls(topological=int(parts[0]), stream=int(parts[1]))
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
raise SynapseError(400, "Invalid token %r" % (string,))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def parse_stream_token(cls, string):
|
|
||||||
try:
|
|
||||||
if string[0] == 's':
|
|
||||||
return cls(topological=None, stream=int(string[1:]))
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
raise SynapseError(400, "Invalid token %r" % (string,))
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
if self.topological is not None:
|
|
||||||
return "t%d-%d" % (self.topological, self.stream)
|
|
||||||
else:
|
|
||||||
return "s%d" % (self.stream,)
|
|
||||||
|
|
||||||
def lower_bound(self):
|
|
||||||
if self.topological is None:
|
|
||||||
return "(%d < %s)" % (self.stream, "stream_ordering")
|
|
||||||
else:
|
|
||||||
return "(%d < %s OR (%d = %s AND %d < %s))" % (
|
|
||||||
self.topological, "topological_ordering",
|
|
||||||
self.topological, "topological_ordering",
|
|
||||||
self.stream, "stream_ordering",
|
|
||||||
)
|
|
||||||
|
|
||||||
def upper_bound(self):
|
|
||||||
if self.topological is None:
|
|
||||||
return "(%d >= %s)" % (self.stream, "stream_ordering")
|
|
||||||
else:
|
|
||||||
return "(%d > %s OR (%d = %s AND %d >= %s))" % (
|
|
||||||
self.topological, "topological_ordering",
|
|
||||||
self.topological, "topological_ordering",
|
|
||||||
self.stream, "stream_ordering",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StreamStore(SQLBaseStore):
|
class StreamStore(SQLBaseStore):
|
||||||
|
@ -139,8 +87,8 @@ class StreamStore(SQLBaseStore):
|
||||||
limit = MAX_STREAM_SIZE
|
limit = MAX_STREAM_SIZE
|
||||||
|
|
||||||
# From and to keys should be integers from ordering.
|
# From and to keys should be integers from ordering.
|
||||||
from_id = _StreamToken.parse_stream_token(from_key)
|
from_id = RoomStreamToken.parse_stream_token(from_key)
|
||||||
to_id = _StreamToken.parse_stream_token(to_key)
|
to_id = RoomStreamToken.parse_stream_token(to_key)
|
||||||
|
|
||||||
if from_key == to_key:
|
if from_key == to_key:
|
||||||
defer.returnValue(([], to_key))
|
defer.returnValue(([], to_key))
|
||||||
|
@ -234,8 +182,8 @@ class StreamStore(SQLBaseStore):
|
||||||
limit = MAX_STREAM_SIZE
|
limit = MAX_STREAM_SIZE
|
||||||
|
|
||||||
# From and to keys should be integers from ordering.
|
# From and to keys should be integers from ordering.
|
||||||
from_id = _StreamToken.parse_stream_token(from_key)
|
from_id = RoomStreamToken.parse_stream_token(from_key)
|
||||||
to_id = _StreamToken.parse_stream_token(to_key)
|
to_id = RoomStreamToken.parse_stream_token(to_key)
|
||||||
|
|
||||||
if from_key == to_key:
|
if from_key == to_key:
|
||||||
return defer.succeed(([], to_key))
|
return defer.succeed(([], to_key))
|
||||||
|
@ -276,7 +224,7 @@ class StreamStore(SQLBaseStore):
|
||||||
|
|
||||||
return self.runInteraction("get_room_events_stream", f)
|
return self.runInteraction("get_room_events_stream", f)
|
||||||
|
|
||||||
@log_function
|
@defer.inlineCallbacks
|
||||||
def paginate_room_events(self, room_id, from_key, to_key=None,
|
def paginate_room_events(self, room_id, from_key, to_key=None,
|
||||||
direction='b', limit=-1,
|
direction='b', limit=-1,
|
||||||
with_feedback=False):
|
with_feedback=False):
|
||||||
|
@ -288,17 +236,17 @@ class StreamStore(SQLBaseStore):
|
||||||
args = [False, room_id]
|
args = [False, room_id]
|
||||||
if direction == 'b':
|
if direction == 'b':
|
||||||
order = "DESC"
|
order = "DESC"
|
||||||
bounds = _StreamToken.parse(from_key).upper_bound()
|
bounds = upper_bound(RoomStreamToken.parse(from_key))
|
||||||
if to_key:
|
if to_key:
|
||||||
bounds = "%s AND %s" % (
|
bounds = "%s AND %s" % (
|
||||||
bounds, _StreamToken.parse(to_key).lower_bound()
|
bounds, lower_bound(RoomStreamToken.parse(to_key))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
order = "ASC"
|
order = "ASC"
|
||||||
bounds = _StreamToken.parse(from_key).lower_bound()
|
bounds = lower_bound(RoomStreamToken.parse(from_key))
|
||||||
if to_key:
|
if to_key:
|
||||||
bounds = "%s AND %s" % (
|
bounds = "%s AND %s" % (
|
||||||
bounds, _StreamToken.parse(to_key).upper_bound()
|
bounds, upper_bound(RoomStreamToken.parse(to_key))
|
||||||
)
|
)
|
||||||
|
|
||||||
if int(limit) > 0:
|
if int(limit) > 0:
|
||||||
|
@ -333,28 +281,30 @@ class StreamStore(SQLBaseStore):
|
||||||
# when we are going backwards so we subtract one from the
|
# when we are going backwards so we subtract one from the
|
||||||
# stream part.
|
# stream part.
|
||||||
toke -= 1
|
toke -= 1
|
||||||
next_token = str(_StreamToken(topo, toke))
|
next_token = str(RoomStreamToken(topo, toke))
|
||||||
else:
|
else:
|
||||||
# TODO (erikj): We should work out what to do here instead.
|
# TODO (erikj): We should work out what to do here instead.
|
||||||
next_token = to_key if to_key else from_key
|
next_token = to_key if to_key else from_key
|
||||||
|
|
||||||
events = self._get_events_txn(
|
return rows, next_token,
|
||||||
txn,
|
|
||||||
[r["event_id"] for r in rows],
|
|
||||||
get_prev_content=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self._set_before_and_after(events, rows)
|
rows, token = yield self.runInteraction("paginate_room_events", f)
|
||||||
|
|
||||||
return events, next_token,
|
events = yield self._get_events(
|
||||||
|
[r["event_id"] for r in rows],
|
||||||
|
get_prev_content=True
|
||||||
|
)
|
||||||
|
|
||||||
return self.runInteraction("paginate_room_events", f)
|
self._set_before_and_after(events, rows)
|
||||||
|
|
||||||
|
defer.returnValue((events, token))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def get_recent_events_for_room(self, room_id, limit, end_token,
|
def get_recent_events_for_room(self, room_id, limit, end_token,
|
||||||
with_feedback=False, from_token=None):
|
with_feedback=False, from_token=None):
|
||||||
# TODO (erikj): Handle compressed feedback
|
# TODO (erikj): Handle compressed feedback
|
||||||
|
|
||||||
end_token = _StreamToken.parse_stream_token(end_token)
|
end_token = RoomStreamToken.parse_stream_token(end_token)
|
||||||
|
|
||||||
if from_token is None:
|
if from_token is None:
|
||||||
sql = (
|
sql = (
|
||||||
|
@ -365,7 +315,7 @@ class StreamStore(SQLBaseStore):
|
||||||
" LIMIT ?"
|
" LIMIT ?"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from_token = _StreamToken.parse_stream_token(from_token)
|
from_token = RoomStreamToken.parse_stream_token(from_token)
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT stream_ordering, topological_ordering, event_id"
|
"SELECT stream_ordering, topological_ordering, event_id"
|
||||||
" FROM events"
|
" FROM events"
|
||||||
|
@ -395,30 +345,49 @@ class StreamStore(SQLBaseStore):
|
||||||
# stream part.
|
# stream part.
|
||||||
topo = rows[0]["topological_ordering"]
|
topo = rows[0]["topological_ordering"]
|
||||||
toke = rows[0]["stream_ordering"] - 1
|
toke = rows[0]["stream_ordering"] - 1
|
||||||
start_token = str(_StreamToken(topo, toke))
|
start_token = str(RoomStreamToken(topo, toke))
|
||||||
|
|
||||||
token = (start_token, str(end_token))
|
token = (start_token, str(end_token))
|
||||||
else:
|
else:
|
||||||
token = (str(end_token), str(end_token))
|
token = (str(end_token), str(end_token))
|
||||||
|
|
||||||
events = self._get_events_txn(
|
return rows, token
|
||||||
txn,
|
|
||||||
[r["event_id"] for r in rows],
|
|
||||||
get_prev_content=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self._set_before_and_after(events, rows)
|
rows, token = yield self.runInteraction(
|
||||||
|
|
||||||
return events, token
|
|
||||||
|
|
||||||
return self.runInteraction(
|
|
||||||
"get_recent_events_for_room", get_recent_events_for_room_txn
|
"get_recent_events_for_room", get_recent_events_for_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.debug("stream before")
|
||||||
|
events = yield self._get_events(
|
||||||
|
[r["event_id"] for r in rows],
|
||||||
|
get_prev_content=True
|
||||||
|
)
|
||||||
|
logger.debug("stream after")
|
||||||
|
|
||||||
|
self._set_before_and_after(events, rows)
|
||||||
|
|
||||||
|
defer.returnValue((events, token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_room_events_max_id(self):
|
def get_room_events_max_id(self, direction='f'):
|
||||||
token = yield self._stream_id_gen.get_max_token(self)
|
token = yield self._stream_id_gen.get_max_token(self)
|
||||||
defer.returnValue("s%d" % (token,))
|
if direction != 'b':
|
||||||
|
defer.returnValue("s%d" % (token,))
|
||||||
|
else:
|
||||||
|
topo = yield self.runInteraction(
|
||||||
|
"_get_max_topological_txn", self._get_max_topological_txn
|
||||||
|
)
|
||||||
|
defer.returnValue("t%d-%d" % (topo, token))
|
||||||
|
|
||||||
|
def _get_max_topological_txn(self, txn):
|
||||||
|
txn.execute(
|
||||||
|
"SELECT MAX(topological_ordering) FROM events"
|
||||||
|
" WHERE outlier = ?",
|
||||||
|
(False,)
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = txn.fetchall()
|
||||||
|
return rows[0][0] if rows else 0
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_min_token(self):
|
def _get_min_token(self):
|
||||||
|
@ -439,5 +408,5 @@ class StreamStore(SQLBaseStore):
|
||||||
stream = row["stream_ordering"]
|
stream = row["stream_ordering"]
|
||||||
topo = event.depth
|
topo = event.depth
|
||||||
internal = event.internal_metadata
|
internal = event.internal_metadata
|
||||||
internal.before = str(_StreamToken(topo, stream - 1))
|
internal.before = str(RoomStreamToken(topo, stream - 1))
|
||||||
internal.after = str(_StreamToken(topo, stream))
|
internal.after = str(RoomStreamToken(topo, stream))
|
||||||
|
|
|
@ -78,14 +78,18 @@ class StreamIdGenerator(object):
|
||||||
self._current_max = None
|
self._current_max = None
|
||||||
self._unfinished_ids = deque()
|
self._unfinished_ids = deque()
|
||||||
|
|
||||||
def get_next_txn(self, txn):
|
@defer.inlineCallbacks
|
||||||
|
def get_next(self, store):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with stream_id_gen.get_next_txn(txn) as stream_id:
|
with yield stream_id_gen.get_next as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
if not self._current_max:
|
if not self._current_max:
|
||||||
self._get_or_compute_current_max(txn)
|
yield store.runInteraction(
|
||||||
|
"_compute_current_max",
|
||||||
|
self._get_or_compute_current_max,
|
||||||
|
)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._current_max += 1
|
self._current_max += 1
|
||||||
|
@ -101,7 +105,7 @@ class StreamIdGenerator(object):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._unfinished_ids.remove(next_id)
|
self._unfinished_ids.remove(next_id)
|
||||||
|
|
||||||
return manager()
|
defer.returnValue(manager())
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_max_token(self, store):
|
def get_max_token(self, store):
|
||||||
|
|
|
@ -31,7 +31,7 @@ class NullSource(object):
|
||||||
def get_new_events_for_user(self, user, from_key, limit):
|
def get_new_events_for_user(self, user, from_key, limit):
|
||||||
return defer.succeed(([], from_key))
|
return defer.succeed(([], from_key))
|
||||||
|
|
||||||
def get_current_key(self):
|
def get_current_key(self, direction='f'):
|
||||||
return defer.succeed(0)
|
return defer.succeed(0)
|
||||||
|
|
||||||
def get_pagination_rows(self, user, pagination_config, key):
|
def get_pagination_rows(self, user, pagination_config, key):
|
||||||
|
@ -52,10 +52,10 @@ class EventSources(object):
|
||||||
}
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_current_token(self):
|
def get_current_token(self, direction='f'):
|
||||||
token = StreamToken(
|
token = StreamToken(
|
||||||
room_key=(
|
room_key=(
|
||||||
yield self.sources["room"].get_current_key()
|
yield self.sources["room"].get_current_key(direction)
|
||||||
),
|
),
|
||||||
presence_key=(
|
presence_key=(
|
||||||
yield self.sources["presence"].get_current_key()
|
yield self.sources["presence"].get_current_key()
|
||||||
|
|
|
@ -70,6 +70,8 @@ class DomainSpecificString(
|
||||||
"""Return a string encoding the fields of the structure object."""
|
"""Return a string encoding the fields of the structure object."""
|
||||||
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
|
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
|
||||||
|
|
||||||
|
__str__ = to_string
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, localpart, domain,):
|
def create(cls, localpart, domain,):
|
||||||
return cls(localpart=localpart, domain=domain)
|
return cls(localpart=localpart, domain=domain)
|
||||||
|
@ -107,7 +109,6 @@ class StreamToken(
|
||||||
def from_string(cls, string):
|
def from_string(cls, string):
|
||||||
try:
|
try:
|
||||||
keys = string.split(cls._SEPARATOR)
|
keys = string.split(cls._SEPARATOR)
|
||||||
|
|
||||||
return cls(*keys)
|
return cls(*keys)
|
||||||
except:
|
except:
|
||||||
raise SynapseError(400, "Invalid Token")
|
raise SynapseError(400, "Invalid Token")
|
||||||
|
@ -115,10 +116,95 @@ class StreamToken(
|
||||||
def to_string(self):
|
def to_string(self):
|
||||||
return self._SEPARATOR.join([str(k) for k in self])
|
return self._SEPARATOR.join([str(k) for k in self])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def room_stream_id(self):
|
||||||
|
# TODO(markjh): Awful hack to work around hacks in the presence tests
|
||||||
|
# which assume that the keys are integers.
|
||||||
|
if type(self.room_key) is int:
|
||||||
|
return self.room_key
|
||||||
|
else:
|
||||||
|
return int(self.room_key[1:].split("-")[-1])
|
||||||
|
|
||||||
|
def is_after(self, other_token):
|
||||||
|
"""Does this token contain events that the other doesn't?"""
|
||||||
|
return (
|
||||||
|
(other_token.room_stream_id < self.room_stream_id)
|
||||||
|
or (int(other_token.presence_key) < int(self.presence_key))
|
||||||
|
or (int(other_token.typing_key) < int(self.typing_key))
|
||||||
|
)
|
||||||
|
|
||||||
|
def copy_and_advance(self, key, new_value):
|
||||||
|
"""Advance the given key in the token to a new value if and only if the
|
||||||
|
new value is after the old value.
|
||||||
|
"""
|
||||||
|
new_token = self.copy_and_replace(key, new_value)
|
||||||
|
if key == "room_key":
|
||||||
|
new_id = new_token.room_stream_id
|
||||||
|
old_id = self.room_stream_id
|
||||||
|
else:
|
||||||
|
new_id = int(getattr(new_token, key))
|
||||||
|
old_id = int(getattr(self, key))
|
||||||
|
if old_id < new_id:
|
||||||
|
return new_token
|
||||||
|
else:
|
||||||
|
return self
|
||||||
|
|
||||||
def copy_and_replace(self, key, new_value):
|
def copy_and_replace(self, key, new_value):
|
||||||
d = self._asdict()
|
d = self._asdict()
|
||||||
d[key] = new_value
|
d[key] = new_value
|
||||||
return StreamToken(**d)
|
return StreamToken(**d)
|
||||||
|
|
||||||
|
|
||||||
|
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
||||||
|
"""Tokens are positions between events. The token "s1" comes after event 1.
|
||||||
|
|
||||||
|
s0 s1
|
||||||
|
| |
|
||||||
|
[0] V [1] V [2]
|
||||||
|
|
||||||
|
Tokens can either be a point in the live event stream or a cursor going
|
||||||
|
through historic events.
|
||||||
|
|
||||||
|
When traversing the live event stream events are ordered by when they
|
||||||
|
arrived at the homeserver.
|
||||||
|
|
||||||
|
When traversing historic events the events are ordered by their depth in
|
||||||
|
the event graph "topological_ordering" and then by when they arrived at the
|
||||||
|
homeserver "stream_ordering".
|
||||||
|
|
||||||
|
Live tokens start with an "s" followed by the "stream_ordering" id of the
|
||||||
|
event it comes after. Historic tokens start with a "t" followed by the
|
||||||
|
"topological_ordering" id of the event it comes after, follewed by "-",
|
||||||
|
followed by the "stream_ordering" id of the event it comes after.
|
||||||
|
"""
|
||||||
|
__slots__ = []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse(cls, string):
|
||||||
|
try:
|
||||||
|
if string[0] == 's':
|
||||||
|
return cls(topological=None, stream=int(string[1:]))
|
||||||
|
if string[0] == 't':
|
||||||
|
parts = string[1:].split('-', 1)
|
||||||
|
return cls(topological=int(parts[0]), stream=int(parts[1]))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
raise SynapseError(400, "Invalid token %r" % (string,))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_stream_token(cls, string):
|
||||||
|
try:
|
||||||
|
if string[0] == 's':
|
||||||
|
return cls(topological=None, stream=int(string[1:]))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
raise SynapseError(400, "Invalid token %r" % (string,))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
if self.topological is not None:
|
||||||
|
return "t%d-%d" % (self.topological, self.stream)
|
||||||
|
else:
|
||||||
|
return "s%d" % (self.stream,)
|
||||||
|
|
||||||
|
|
||||||
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
|
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
|
|
||||||
from twisted.internet import defer, reactor, task
|
from twisted.internet import defer, reactor, task
|
||||||
|
|
||||||
|
@ -23,6 +23,40 @@ import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def unwrapFirstError(failure):
|
||||||
|
# defer.gatherResults and DeferredLists wrap failures.
|
||||||
|
failure.trap(defer.FirstError)
|
||||||
|
return failure.value.subFailure
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap_deferred(d):
|
||||||
|
"""Given a deferred that we know has completed, return its value or raise
|
||||||
|
the failure as an exception
|
||||||
|
"""
|
||||||
|
if not d.called:
|
||||||
|
raise RuntimeError("deferred has not finished")
|
||||||
|
|
||||||
|
res = []
|
||||||
|
|
||||||
|
def f(r):
|
||||||
|
res.append(r)
|
||||||
|
return r
|
||||||
|
d.addCallback(f)
|
||||||
|
|
||||||
|
if res:
|
||||||
|
return res[0]
|
||||||
|
|
||||||
|
def f(r):
|
||||||
|
res.append(r)
|
||||||
|
return r
|
||||||
|
d.addErrback(f)
|
||||||
|
|
||||||
|
if res:
|
||||||
|
res[0].raiseException()
|
||||||
|
else:
|
||||||
|
raise RuntimeError("deferred did not call callbacks")
|
||||||
|
|
||||||
|
|
||||||
class Clock(object):
|
class Clock(object):
|
||||||
"""A small utility that obtains current time-of-day so that time may be
|
"""A small utility that obtains current time-of-day so that time may be
|
||||||
mocked during unit-tests.
|
mocked during unit-tests.
|
||||||
|
@ -46,13 +80,16 @@ class Clock(object):
|
||||||
def stop_looping_call(self, loop):
|
def stop_looping_call(self, loop):
|
||||||
loop.stop()
|
loop.stop()
|
||||||
|
|
||||||
def call_later(self, delay, callback):
|
def call_later(self, delay, callback, *args, **kwargs):
|
||||||
current_context = LoggingContext.current_context()
|
current_context = LoggingContext.current_context()
|
||||||
|
|
||||||
def wrapped_callback():
|
def wrapped_callback(*args, **kwargs):
|
||||||
LoggingContext.thread_local.current_context = current_context
|
with PreserveLoggingContext():
|
||||||
callback()
|
LoggingContext.thread_local.current_context = current_context
|
||||||
return reactor.callLater(delay, wrapped_callback)
|
callback(*args, **kwargs)
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
|
||||||
|
|
||||||
def cancel_call_later(self, timer):
|
def cancel_call_later(self, timer):
|
||||||
timer.cancel()
|
timer.cancel()
|
||||||
|
|
|
@ -16,15 +16,13 @@
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
from .logcontext import PreserveLoggingContext
|
from .logcontext import preserve_context_over_deferred
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def sleep(seconds):
|
def sleep(seconds):
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
reactor.callLater(seconds, d.callback, seconds)
|
reactor.callLater(seconds, d.callback, seconds)
|
||||||
with PreserveLoggingContext():
|
return preserve_context_over_deferred(d)
|
||||||
yield d
|
|
||||||
|
|
||||||
|
|
||||||
def run_on_reactor():
|
def run_on_reactor():
|
||||||
|
@ -34,20 +32,56 @@ def run_on_reactor():
|
||||||
return sleep(0)
|
return sleep(0)
|
||||||
|
|
||||||
|
|
||||||
def create_observer(deferred):
|
class ObservableDeferred(object):
|
||||||
"""Creates a deferred that observes the result or failure of the given
|
"""Wraps a deferred object so that we can add observer deferreds. These
|
||||||
deferred *without* affecting the given deferred.
|
observer deferreds do not affect the callback chain of the original
|
||||||
|
deferred.
|
||||||
|
|
||||||
|
If consumeErrors is true errors will be captured from the origin deferred.
|
||||||
"""
|
"""
|
||||||
d = defer.Deferred()
|
|
||||||
|
|
||||||
def callback(r):
|
__slots__ = ["_deferred", "_observers", "_result"]
|
||||||
d.callback(r)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def errback(f):
|
def __init__(self, deferred, consumeErrors=False):
|
||||||
d.errback(f)
|
object.__setattr__(self, "_deferred", deferred)
|
||||||
return f
|
object.__setattr__(self, "_result", None)
|
||||||
|
object.__setattr__(self, "_observers", [])
|
||||||
|
|
||||||
deferred.addCallbacks(callback, errback)
|
def callback(r):
|
||||||
|
self._result = (True, r)
|
||||||
|
while self._observers:
|
||||||
|
try:
|
||||||
|
self._observers.pop().callback(r)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return r
|
||||||
|
|
||||||
return d
|
def errback(f):
|
||||||
|
self._result = (False, f)
|
||||||
|
while self._observers:
|
||||||
|
try:
|
||||||
|
self._observers.pop().errback(f)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if consumeErrors:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return f
|
||||||
|
|
||||||
|
deferred.addCallbacks(callback, errback)
|
||||||
|
|
||||||
|
def observe(self):
|
||||||
|
if not self._result:
|
||||||
|
d = defer.Deferred()
|
||||||
|
self._observers.append(d)
|
||||||
|
return d
|
||||||
|
else:
|
||||||
|
success, res = self._result
|
||||||
|
return defer.succeed(res) if success else defer.fail(res)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self._deferred, name)
|
||||||
|
|
||||||
|
def __setattr__(self, name, value):
|
||||||
|
setattr(self._deferred, name, value)
|
||||||
|
|
|
@ -13,10 +13,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.util.logcontext import (
|
||||||
|
PreserveLoggingContext, preserve_context_over_deferred,
|
||||||
|
)
|
||||||
|
|
||||||
|
from synapse.util import unwrapFirstError
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -93,7 +97,6 @@ class Signal(object):
|
||||||
Each observer callable may return a Deferred."""
|
Each observer callable may return a Deferred."""
|
||||||
self.observers.append(observer)
|
self.observers.append(observer)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def fire(self, *args, **kwargs):
|
def fire(self, *args, **kwargs):
|
||||||
"""Invokes every callable in the observer list, passing in the args and
|
"""Invokes every callable in the observer list, passing in the args and
|
||||||
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
||||||
|
@ -101,24 +104,28 @@ class Signal(object):
|
||||||
|
|
||||||
Returns a Deferred that will complete when all the observers have
|
Returns a Deferred that will complete when all the observers have
|
||||||
completed."""
|
completed."""
|
||||||
with PreserveLoggingContext():
|
|
||||||
deferreds = []
|
|
||||||
for observer in self.observers:
|
|
||||||
d = defer.maybeDeferred(observer, *args, **kwargs)
|
|
||||||
|
|
||||||
def eb(failure):
|
def do(observer):
|
||||||
logger.warning(
|
def eb(failure):
|
||||||
"%s signal observer %s failed: %r",
|
logger.warning(
|
||||||
self.name, observer, failure,
|
"%s signal observer %s failed: %r",
|
||||||
exc_info=(
|
self.name, observer, failure,
|
||||||
failure.type,
|
exc_info=(
|
||||||
failure.value,
|
failure.type,
|
||||||
failure.getTracebackObject()))
|
failure.value,
|
||||||
if not self.suppress_failures:
|
failure.getTracebackObject()))
|
||||||
failure.raiseException()
|
if not self.suppress_failures:
|
||||||
deferreds.append(d.addErrback(eb))
|
return failure
|
||||||
results = []
|
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
|
||||||
for deferred in deferreds:
|
|
||||||
result = yield deferred
|
with PreserveLoggingContext():
|
||||||
results.append(result)
|
deferreds = [
|
||||||
defer.returnValue(results)
|
do(observer)
|
||||||
|
for observer in self.observers
|
||||||
|
]
|
||||||
|
|
||||||
|
d = defer.gatherResults(deferreds, consumeErrors=True)
|
||||||
|
|
||||||
|
d.addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
return preserve_context_over_deferred(d)
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -129,3 +131,53 @@ class PreserveLoggingContext(object):
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
"""Restores the current logging context"""
|
"""Restores the current logging context"""
|
||||||
LoggingContext.thread_local.current_context = self.current_context
|
LoggingContext.thread_local.current_context = self.current_context
|
||||||
|
|
||||||
|
if self.current_context is not LoggingContext.sentinel:
|
||||||
|
if self.current_context.parent_context is None:
|
||||||
|
logger.warn(
|
||||||
|
"Restoring dead context: %s",
|
||||||
|
self.current_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def preserve_context_over_fn(fn, *args, **kwargs):
|
||||||
|
"""Takes a function and invokes it with the given arguments, but removes
|
||||||
|
and restores the current logging context while doing so.
|
||||||
|
|
||||||
|
If the result is a deferred, call preserve_context_over_deferred before
|
||||||
|
returning it.
|
||||||
|
"""
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
res = fn(*args, **kwargs)
|
||||||
|
|
||||||
|
if isinstance(res, defer.Deferred):
|
||||||
|
return preserve_context_over_deferred(res)
|
||||||
|
else:
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def preserve_context_over_deferred(deferred):
|
||||||
|
"""Given a deferred wrap it such that any callbacks added later to it will
|
||||||
|
be invoked with the current context.
|
||||||
|
"""
|
||||||
|
d = defer.Deferred()
|
||||||
|
|
||||||
|
current_context = LoggingContext.current_context()
|
||||||
|
|
||||||
|
def cb(res):
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
LoggingContext.thread_local.current_context = current_context
|
||||||
|
res = d.callback(res)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def eb(failure):
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
LoggingContext.thread_local.current_context = current_context
|
||||||
|
res = d.errback(failure)
|
||||||
|
return res
|
||||||
|
|
||||||
|
if deferred.called:
|
||||||
|
return deferred
|
||||||
|
|
||||||
|
deferred.addCallbacks(cb, eb)
|
||||||
|
return d
|
||||||
|
|
|
@ -217,18 +217,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
_regex("@irc_.*")
|
_regex("@irc_.*")
|
||||||
)
|
)
|
||||||
join_list = [
|
join_list = [
|
||||||
Mock(
|
"@alice:here",
|
||||||
type="m.room.member", room_id="!foo:bar", sender="@alice:here",
|
"@irc_fo:here", # AS user
|
||||||
state_key="@alice:here"
|
"@bob:here",
|
||||||
),
|
|
||||||
Mock(
|
|
||||||
type="m.room.member", room_id="!foo:bar", sender="@irc_fo:here",
|
|
||||||
state_key="@irc_fo:here" # AS user
|
|
||||||
),
|
|
||||||
Mock(
|
|
||||||
type="m.room.member", room_id="!foo:bar", sender="@bob:here",
|
|
||||||
state_key="@bob:here"
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
self.event.sender = "@xmpp_foobar:matrix.org"
|
||||||
|
|
|
@ -83,7 +83,7 @@ class FederationTestCase(unittest.TestCase):
|
||||||
"hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
|
"hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
|
||||||
})
|
})
|
||||||
|
|
||||||
self.datastore.persist_event.return_value = defer.succeed(None)
|
self.datastore.persist_event.return_value = defer.succeed((1,1))
|
||||||
self.datastore.get_room.return_value = defer.succeed(True)
|
self.datastore.get_room.return_value = defer.succeed(True)
|
||||||
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
||||||
|
|
||||||
|
@ -126,5 +126,5 @@ class FederationTestCase(unittest.TestCase):
|
||||||
self.auth.check.assert_called_once_with(ANY, auth_events={})
|
self.auth.check.assert_called_once_with(ANY, auth_events={})
|
||||||
|
|
||||||
self.notifier.on_new_room_event.assert_called_once_with(
|
self.notifier.on_new_room_event.assert_called_once_with(
|
||||||
ANY, extra_users=[]
|
ANY, 1, 1, extra_users=[]
|
||||||
)
|
)
|
||||||
|
|
|
@ -233,7 +233,7 @@ class MockedDatastorePresenceTestCase(PresenceTestCase):
|
||||||
if not user_localpart in self.PRESENCE_LIST:
|
if not user_localpart in self.PRESENCE_LIST:
|
||||||
return defer.succeed([])
|
return defer.succeed([])
|
||||||
return defer.succeed([
|
return defer.succeed([
|
||||||
{"observed_user_id": u} for u in
|
{"observed_user_id": u, "accepted": accepted} for u in
|
||||||
self.PRESENCE_LIST[user_localpart]])
|
self.PRESENCE_LIST[user_localpart]])
|
||||||
datastore.get_presence_list = get_presence_list
|
datastore.get_presence_list = get_presence_list
|
||||||
|
|
||||||
|
@ -624,6 +624,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
"""
|
"""
|
||||||
PRESENCE_LIST = {
|
PRESENCE_LIST = {
|
||||||
'apple': [ "@banana:test", "@clementine:test" ],
|
'apple': [ "@banana:test", "@clementine:test" ],
|
||||||
|
'banana': [ "@apple:test" ],
|
||||||
}
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -733,10 +734,12 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
[
|
[
|
||||||
{"observed_user": self.u_banana,
|
{"observed_user": self.u_banana,
|
||||||
"presence": OFFLINE},
|
"presence": OFFLINE,
|
||||||
|
"accepted": True},
|
||||||
{"observed_user": self.u_clementine,
|
{"observed_user": self.u_clementine,
|
||||||
"presence": OFFLINE},
|
"presence": OFFLINE,
|
||||||
|
"accepted": True},
|
||||||
],
|
],
|
||||||
presence
|
presence
|
||||||
)
|
)
|
||||||
|
@ -757,9 +760,11 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
self.assertEquals([
|
self.assertEquals([
|
||||||
{"observed_user": self.u_banana,
|
{"observed_user": self.u_banana,
|
||||||
"presence": ONLINE,
|
"presence": ONLINE,
|
||||||
"last_active_ago": 2000},
|
"last_active_ago": 2000,
|
||||||
|
"accepted": True},
|
||||||
{"observed_user": self.u_clementine,
|
{"observed_user": self.u_clementine,
|
||||||
"presence": OFFLINE},
|
"presence": OFFLINE,
|
||||||
|
"accepted": True},
|
||||||
], presence)
|
], presence)
|
||||||
|
|
||||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
(events, _) = yield self.event_source.get_new_events_for_user(
|
||||||
|
@ -836,12 +841,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_recv_remote(self):
|
def test_recv_remote(self):
|
||||||
# TODO(paul): Gut-wrenching
|
self.room_members = [self.u_apple, self.u_banana, self.u_potato]
|
||||||
potato_set = self.handler._remote_recvmap.setdefault(self.u_potato,
|
|
||||||
set())
|
|
||||||
potato_set.add(self.u_apple)
|
|
||||||
|
|
||||||
self.room_members = [self.u_banana, self.u_potato]
|
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||||
|
|
||||||
|
@ -886,11 +886,8 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_recv_remote_offline(self):
|
def test_recv_remote_offline(self):
|
||||||
""" Various tests relating to SYN-261 """
|
""" Various tests relating to SYN-261 """
|
||||||
potato_set = self.handler._remote_recvmap.setdefault(self.u_potato,
|
|
||||||
set())
|
|
||||||
potato_set.add(self.u_apple)
|
|
||||||
|
|
||||||
self.room_members = [self.u_banana, self.u_potato]
|
self.room_members = [self.u_apple, self.u_banana, self.u_potato]
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||||
|
|
||||||
|
@ -1097,12 +1094,8 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
|
||||||
|
|
||||||
# apple should see both banana and clementine currently offline
|
# apple should see both banana and clementine currently offline
|
||||||
self.mock_update_client.assert_has_calls([
|
self.mock_update_client.assert_has_calls([
|
||||||
call(users_to_push=[self.u_apple],
|
call(users_to_push=[self.u_apple]),
|
||||||
observed_user=self.u_banana,
|
call(users_to_push=[self.u_apple]),
|
||||||
statuscache=ANY),
|
|
||||||
call(users_to_push=[self.u_apple],
|
|
||||||
observed_user=self.u_clementine,
|
|
||||||
statuscache=ANY),
|
|
||||||
], any_order=True)
|
], any_order=True)
|
||||||
|
|
||||||
# Gut-wrenching tests
|
# Gut-wrenching tests
|
||||||
|
@ -1121,13 +1114,8 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
|
||||||
|
|
||||||
# apple and banana should now both see each other online
|
# apple and banana should now both see each other online
|
||||||
self.mock_update_client.assert_has_calls([
|
self.mock_update_client.assert_has_calls([
|
||||||
call(users_to_push=set([self.u_apple]),
|
call(users_to_push=set([self.u_apple]), room_ids=[]),
|
||||||
observed_user=self.u_banana,
|
call(users_to_push=[self.u_banana]),
|
||||||
room_ids=[],
|
|
||||||
statuscache=ANY),
|
|
||||||
call(users_to_push=[self.u_banana],
|
|
||||||
observed_user=self.u_apple,
|
|
||||||
statuscache=ANY),
|
|
||||||
], any_order=True)
|
], any_order=True)
|
||||||
|
|
||||||
self.assertTrue("apple" in self.handler._local_pushmap)
|
self.assertTrue("apple" in self.handler._local_pushmap)
|
||||||
|
@ -1143,10 +1131,7 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
|
||||||
|
|
||||||
# banana should now be told apple is offline
|
# banana should now be told apple is offline
|
||||||
self.mock_update_client.assert_has_calls([
|
self.mock_update_client.assert_has_calls([
|
||||||
call(users_to_push=set([self.u_banana, self.u_apple]),
|
call(users_to_push=set([self.u_banana, self.u_apple]), room_ids=[]),
|
||||||
observed_user=self.u_apple,
|
|
||||||
room_ids=[],
|
|
||||||
statuscache=ANY),
|
|
||||||
], any_order=True)
|
], any_order=True)
|
||||||
|
|
||||||
self.assertFalse("banana" in self.handler._local_pushmap)
|
self.assertFalse("banana" in self.handler._local_pushmap)
|
||||||
|
|
|
@ -101,8 +101,8 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
|
||||||
self.datastore.get_profile_avatar_url = get_profile_avatar_url
|
self.datastore.get_profile_avatar_url = get_profile_avatar_url
|
||||||
|
|
||||||
self.presence_list = [
|
self.presence_list = [
|
||||||
{"observed_user_id": "@banana:test"},
|
{"observed_user_id": "@banana:test", "accepted": True},
|
||||||
{"observed_user_id": "@clementine:test"},
|
{"observed_user_id": "@clementine:test", "accepted": True},
|
||||||
]
|
]
|
||||||
def get_presence_list(user_localpart, accepted=None):
|
def get_presence_list(user_localpart, accepted=None):
|
||||||
return defer.succeed(self.presence_list)
|
return defer.succeed(self.presence_list)
|
||||||
|
@ -144,8 +144,8 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_set_my_state(self):
|
def test_set_my_state(self):
|
||||||
self.presence_list = [
|
self.presence_list = [
|
||||||
{"observed_user_id": "@banana:test"},
|
{"observed_user_id": "@banana:test", "accepted": True},
|
||||||
{"observed_user_id": "@clementine:test"},
|
{"observed_user_id": "@clementine:test", "accepted": True},
|
||||||
]
|
]
|
||||||
|
|
||||||
mocked_set = self.datastore.set_presence_state
|
mocked_set = self.datastore.set_presence_state
|
||||||
|
@ -167,8 +167,8 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
|
||||||
self.mock_get_joined.side_effect = get_joined
|
self.mock_get_joined.side_effect = get_joined
|
||||||
|
|
||||||
self.presence_list = [
|
self.presence_list = [
|
||||||
{"observed_user_id": "@banana:test"},
|
{"observed_user_id": "@banana:test", "accepted": True},
|
||||||
{"observed_user_id": "@clementine:test"},
|
{"observed_user_id": "@clementine:test", "accepted": True},
|
||||||
]
|
]
|
||||||
|
|
||||||
self.datastore.set_presence_state.return_value = defer.succeed(
|
self.datastore.set_presence_state.return_value = defer.succeed(
|
||||||
|
@ -203,26 +203,20 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
|
||||||
"presence": ONLINE,
|
"presence": ONLINE,
|
||||||
"last_active_ago": 0,
|
"last_active_ago": 0,
|
||||||
"displayname": "Frank",
|
"displayname": "Frank",
|
||||||
"avatar_url": "http://foo"},
|
"avatar_url": "http://foo",
|
||||||
|
"accepted": True},
|
||||||
{"observed_user": self.u_clementine,
|
{"observed_user": self.u_clementine,
|
||||||
"presence": OFFLINE}
|
"presence": OFFLINE,
|
||||||
|
"accepted": True}
|
||||||
], presence)
|
], presence)
|
||||||
|
|
||||||
self.mock_update_client.assert_has_calls([
|
self.mock_update_client.assert_has_calls([
|
||||||
call(users_to_push=set([self.u_apple, self.u_banana, self.u_clementine]),
|
call(
|
||||||
room_ids=[],
|
users_to_push={self.u_apple, self.u_banana, self.u_clementine},
|
||||||
observed_user=self.u_apple,
|
room_ids=[]
|
||||||
statuscache=ANY), # self-reflection
|
),
|
||||||
], any_order=True)
|
], any_order=True)
|
||||||
|
|
||||||
statuscache = self.mock_update_client.call_args[1]["statuscache"]
|
|
||||||
self.assertEquals({
|
|
||||||
"presence": ONLINE,
|
|
||||||
"last_active": 1000000, # MockClock
|
|
||||||
"displayname": "Frank",
|
|
||||||
"avatar_url": "http://foo",
|
|
||||||
}, statuscache.state)
|
|
||||||
|
|
||||||
self.mock_update_client.reset_mock()
|
self.mock_update_client.reset_mock()
|
||||||
|
|
||||||
self.datastore.set_profile_displayname.return_value = defer.succeed(
|
self.datastore.set_profile_displayname.return_value = defer.succeed(
|
||||||
|
@ -232,25 +226,16 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
|
||||||
self.u_apple, "I am an Apple")
|
self.u_apple, "I am an Apple")
|
||||||
|
|
||||||
self.mock_update_client.assert_has_calls([
|
self.mock_update_client.assert_has_calls([
|
||||||
call(users_to_push=set([self.u_apple, self.u_banana, self.u_clementine]),
|
call(
|
||||||
|
users_to_push={self.u_apple, self.u_banana, self.u_clementine},
|
||||||
room_ids=[],
|
room_ids=[],
|
||||||
observed_user=self.u_apple,
|
),
|
||||||
statuscache=ANY), # self-reflection
|
|
||||||
], any_order=True)
|
], any_order=True)
|
||||||
|
|
||||||
statuscache = self.mock_update_client.call_args[1]["statuscache"]
|
|
||||||
self.assertEquals({
|
|
||||||
"presence": ONLINE,
|
|
||||||
"last_active": 1000000, # MockClock
|
|
||||||
"displayname": "I am an Apple",
|
|
||||||
"avatar_url": "http://foo",
|
|
||||||
}, statuscache.state)
|
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_push_remote(self):
|
def test_push_remote(self):
|
||||||
self.presence_list = [
|
self.presence_list = [
|
||||||
{"observed_user_id": "@potato:remote"},
|
{"observed_user_id": "@potato:remote", "accepted": True},
|
||||||
]
|
]
|
||||||
|
|
||||||
self.datastore.set_presence_state.return_value = defer.succeed(
|
self.datastore.set_presence_state.return_value = defer.succeed(
|
||||||
|
@ -314,13 +299,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
|
||||||
self.mock_update_client.assert_called_with(
|
self.mock_update_client.assert_called_with(
|
||||||
users_to_push=set([self.u_apple]),
|
users_to_push=set([self.u_apple]),
|
||||||
room_ids=[],
|
room_ids=[],
|
||||||
observed_user=self.u_potato,
|
)
|
||||||
statuscache=ANY)
|
|
||||||
|
|
||||||
statuscache = self.mock_update_client.call_args[1]["statuscache"]
|
|
||||||
self.assertEquals({"presence": ONLINE,
|
|
||||||
"displayname": "Frank",
|
|
||||||
"avatar_url": "http://foo"}, statuscache.state)
|
|
||||||
|
|
||||||
state = yield self.handlers.presence_handler.get_state(self.u_potato,
|
state = yield self.handlers.presence_handler.get_state(self.u_potato,
|
||||||
self.u_apple)
|
self.u_apple)
|
||||||
|
|
|
@ -87,6 +87,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.send_message.return_value = (True, 0)
|
||||||
|
|
||||||
|
self.datastore.persist_event.return_value = (1,1)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_invite(self):
|
def test_invite(self):
|
||||||
room_id = "!foo:red"
|
room_id = "!foo:red"
|
||||||
|
@ -160,7 +162,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
event, context=context,
|
event, context=context,
|
||||||
)
|
)
|
||||||
self.notifier.on_new_room_event.assert_called_once_with(
|
self.notifier.on_new_room_event.assert_called_once_with(
|
||||||
event, extra_users=[UserID.from_string(target_user_id)]
|
event, 1, 1, extra_users=[UserID.from_string(target_user_id)]
|
||||||
)
|
)
|
||||||
self.assertFalse(self.datastore.get_room.called)
|
self.assertFalse(self.datastore.get_room.called)
|
||||||
self.assertFalse(self.datastore.store_room.called)
|
self.assertFalse(self.datastore.store_room.called)
|
||||||
|
@ -226,7 +228,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
event, context=context
|
event, context=context
|
||||||
)
|
)
|
||||||
self.notifier.on_new_room_event.assert_called_once_with(
|
self.notifier.on_new_room_event.assert_called_once_with(
|
||||||
event, extra_users=[user]
|
event, 1, 1, extra_users=[user]
|
||||||
)
|
)
|
||||||
|
|
||||||
join_signal_observer.assert_called_with(
|
join_signal_observer.assert_called_with(
|
||||||
|
@ -304,7 +306,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
event, context=context
|
event, context=context
|
||||||
)
|
)
|
||||||
self.notifier.on_new_room_event.assert_called_once_with(
|
self.notifier.on_new_room_event.assert_called_once_with(
|
||||||
event, extra_users=[user]
|
event, 1, 1, extra_users=[user]
|
||||||
)
|
)
|
||||||
|
|
||||||
leave_signal_observer.assert_called_with(
|
leave_signal_observer.assert_called_with(
|
||||||
|
|
|
@ -183,7 +183,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.on_new_user_event.assert_has_calls([
|
self.on_new_user_event.assert_has_calls([
|
||||||
call(rooms=[self.room_id]),
|
call('typing_key', 1, rooms=[self.room_id]),
|
||||||
])
|
])
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
|
@ -246,7 +246,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.on_new_user_event.assert_has_calls([
|
self.on_new_user_event.assert_has_calls([
|
||||||
call(rooms=[self.room_id]),
|
call('typing_key', 1, rooms=[self.room_id]),
|
||||||
])
|
])
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
|
@ -300,7 +300,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.on_new_user_event.assert_has_calls([
|
self.on_new_user_event.assert_has_calls([
|
||||||
call(rooms=[self.room_id]),
|
call('typing_key', 1, rooms=[self.room_id]),
|
||||||
])
|
])
|
||||||
|
|
||||||
yield put_json.await_calls()
|
yield put_json.await_calls()
|
||||||
|
@ -332,7 +332,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.on_new_user_event.assert_has_calls([
|
self.on_new_user_event.assert_has_calls([
|
||||||
call(rooms=[self.room_id]),
|
call('typing_key', 1, rooms=[self.room_id]),
|
||||||
])
|
])
|
||||||
self.on_new_user_event.reset_mock()
|
self.on_new_user_event.reset_mock()
|
||||||
|
|
||||||
|
@ -352,7 +352,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
self.clock.advance_time(11)
|
self.clock.advance_time(11)
|
||||||
|
|
||||||
self.on_new_user_event.assert_has_calls([
|
self.on_new_user_event.assert_has_calls([
|
||||||
call(rooms=[self.room_id]),
|
call('typing_key', 2, rooms=[self.room_id]),
|
||||||
])
|
])
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 2)
|
self.assertEquals(self.event_source.get_current_key(), 2)
|
||||||
|
@ -378,7 +378,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.on_new_user_event.assert_has_calls([
|
self.on_new_user_event.assert_has_calls([
|
||||||
call(rooms=[self.room_id]),
|
call('typing_key', 3, rooms=[self.room_id]),
|
||||||
])
|
])
|
||||||
self.on_new_user_event.reset_mock()
|
self.on_new_user_event.reset_mock()
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,9 @@ from synapse.handlers.presence import PresenceHandler
|
||||||
from synapse.rest.client.v1 import presence
|
from synapse.rest.client.v1 import presence
|
||||||
from synapse.rest.client.v1 import events
|
from synapse.rest.client.v1 import events
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
from synapse.util.async import run_on_reactor
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
|
||||||
OFFLINE = PresenceState.OFFLINE
|
OFFLINE = PresenceState.OFFLINE
|
||||||
|
@ -180,7 +183,7 @@ class PresenceListTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_my_list(self):
|
def test_get_my_list(self):
|
||||||
self.datastore.get_presence_list.return_value = defer.succeed(
|
self.datastore.get_presence_list.return_value = defer.succeed(
|
||||||
[{"observed_user_id": "@banana:test"}],
|
[{"observed_user_id": "@banana:test", "accepted": True}],
|
||||||
)
|
)
|
||||||
|
|
||||||
(code, response) = yield self.mock_resource.trigger("GET",
|
(code, response) = yield self.mock_resource.trigger("GET",
|
||||||
|
@ -188,7 +191,7 @@ class PresenceListTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEquals(200, code)
|
self.assertEquals(200, code)
|
||||||
self.assertEquals([
|
self.assertEquals([
|
||||||
{"user_id": "@banana:test", "presence": OFFLINE},
|
{"user_id": "@banana:test", "presence": OFFLINE, "accepted": True},
|
||||||
], response)
|
], response)
|
||||||
|
|
||||||
self.datastore.get_presence_list.assert_called_with(
|
self.datastore.get_presence_list.assert_called_with(
|
||||||
|
@ -264,11 +267,13 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
||||||
datastore=Mock(spec=[
|
datastore=Mock(spec=[
|
||||||
"set_presence_state",
|
"set_presence_state",
|
||||||
"get_presence_list",
|
"get_presence_list",
|
||||||
|
"get_rooms_for_user",
|
||||||
]),
|
]),
|
||||||
clock=Mock(spec=[
|
clock=Mock(spec=[
|
||||||
"call_later",
|
"call_later",
|
||||||
"cancel_call_later",
|
"cancel_call_later",
|
||||||
"time_msec",
|
"time_msec",
|
||||||
|
"looping_call",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -292,12 +297,21 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
|
hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
|
||||||
|
hs.handlers.room_member_handler.get_room_members = (
|
||||||
|
lambda r: self.room_members if r == "a-room" else []
|
||||||
|
)
|
||||||
|
|
||||||
self.mock_datastore = hs.get_datastore()
|
self.mock_datastore = hs.get_datastore()
|
||||||
self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
|
self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
|
||||||
self.mock_datastore.get_app_service_by_user_id = Mock(
|
self.mock_datastore.get_app_service_by_user_id = Mock(
|
||||||
return_value=defer.succeed(None)
|
return_value=defer.succeed(None)
|
||||||
)
|
)
|
||||||
|
self.mock_datastore.get_rooms_for_user = (
|
||||||
|
lambda u: [
|
||||||
|
namedtuple("Room", "room_id")(r)
|
||||||
|
for r in get_rooms_for_user(UserID.from_string(u))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def get_profile_displayname(user_id):
|
def get_profile_displayname(user_id):
|
||||||
return defer.succeed("Frank")
|
return defer.succeed("Frank")
|
||||||
|
@ -350,19 +364,19 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
||||||
self.mock_datastore.set_presence_state.return_value = defer.succeed(
|
self.mock_datastore.set_presence_state.return_value = defer.succeed(
|
||||||
{"state": ONLINE}
|
{"state": ONLINE}
|
||||||
)
|
)
|
||||||
self.mock_datastore.get_presence_list.return_value = defer.succeed(
|
self.mock_datastore.get_presence_list.return_value = defer.succeed([])
|
||||||
[]
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self.presence.set_state(self.u_banana, self.u_banana,
|
yield self.presence.set_state(self.u_banana, self.u_banana,
|
||||||
state={"presence": ONLINE}
|
state={"presence": ONLINE}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
(code, response) = yield self.mock_resource.trigger("GET",
|
(code, response) = yield self.mock_resource.trigger("GET",
|
||||||
"/events?from=0_1_0&timeout=0", None)
|
"/events?from=s0_1_0&timeout=0", None)
|
||||||
|
|
||||||
self.assertEquals(200, code)
|
self.assertEquals(200, code)
|
||||||
self.assertEquals({"start": "0_1_0", "end": "0_2_0", "chunk": [
|
self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [
|
||||||
{"type": "m.presence",
|
{"type": "m.presence",
|
||||||
"content": {
|
"content": {
|
||||||
"user_id": "@banana:test",
|
"user_id": "@banana:test",
|
||||||
|
|
|
@ -33,8 +33,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.db_pool = Mock(spec=["runInteraction"])
|
self.db_pool = Mock(spec=["runInteraction"])
|
||||||
self.mock_txn = Mock()
|
self.mock_txn = Mock()
|
||||||
self.mock_conn = Mock(spec_set=["cursor"])
|
self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
|
||||||
self.mock_conn.cursor.return_value = self.mock_txn
|
self.mock_conn.cursor.return_value = self.mock_txn
|
||||||
|
self.mock_conn.rollback.return_value = None
|
||||||
# Our fake runInteraction just runs synchronously inline
|
# Our fake runInteraction just runs synchronously inline
|
||||||
|
|
||||||
def runInteraction(func, *args, **kwargs):
|
def runInteraction(func, *args, **kwargs):
|
||||||
|
|
|
@ -197,6 +197,9 @@ class MockClock(object):
|
||||||
|
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
def looping_call(self, function, interval):
|
||||||
|
pass
|
||||||
|
|
||||||
def cancel_call_later(self, timer):
|
def cancel_call_later(self, timer):
|
||||||
if timer[2]:
|
if timer[2]:
|
||||||
raise Exception("Cannot cancel an expired timer")
|
raise Exception("Cannot cancel an expired timer")
|
||||||
|
@ -355,7 +358,7 @@ class MemoryDataStore(object):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_room_events_max_id(self):
|
def get_room_events_max_id(self):
|
||||||
return 0 # TODO (erikj)
|
return "s0" # TODO (erikj)
|
||||||
|
|
||||||
def get_send_event_level(self, room_id):
|
def get_send_event_level(self, room_id):
|
||||||
return defer.succeed(0)
|
return defer.succeed(0)
|
||||||
|
|
Loading…
Reference in a new issue