0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-17 07:21:37 +01:00

Merge branch 'release-v0.9.1' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2015-05-26 16:03:32 +01:00
commit 6d1dea337b
64 changed files with 2529 additions and 1282 deletions

View file

@ -1,3 +1,30 @@
Changes in synapse v0.9.1 (2015-05-26)
======================================
General:
* Add support for backfilling when a client paginates. This allows servers to
request history for a room from remote servers when a client tries to
paginate history the server does not have - SYN-36
* Fix bug where you couldn't disable non-default pushrules - SYN-378
* Fix ``register_new_user`` script - SYN-359
* Improve performance of fetching events from the database, this improves both
initialSync and sending of events.
* Improve performance of event streams, allowing synapse to handle more
simultaneous connected clients.
Federation:
* Fix bug with existing backfill implementation where it returned the wrong
selection of events in some circumstances.
* Improve performance of joining remote rooms.
Configuration:
* Add support for changing the bind host of the metrics listener via the
``metrics_bind_host`` option.
Changes in synapse v0.9.0-r5 (2015-05-21) Changes in synapse v0.9.0-r5 (2015-05-21)
========================================= =========================================

View file

@ -117,7 +117,7 @@ Installing prerequisites on Mac OS X::
To install the synapse homeserver run:: To install the synapse homeserver run::
$ virtualenv ~/.synapse $ virtualenv -p python2.7 ~/.synapse
$ source ~/.synapse/bin/activate $ source ~/.synapse/bin/activate
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master $ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master

View file

@ -31,6 +31,7 @@ for port in 8080 8081 8082; do
#rm $DIR/etc/$port.config #rm $DIR/etc/$port.config
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--generate-config \ --generate-config \
--enable_registration \
-H "localhost:$https_port" \ -H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \

View file

@ -0,0 +1,116 @@
import psycopg2
import yaml
import sys
import json
import time
import hashlib
from syutil.base64util import encode_base64
from syutil.crypto.signing_key import read_signing_keys
from syutil.crypto.jsonsign import sign_json
from syutil.jsonutil import encode_canonical_json
def select_v1_keys(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, key_id, verify_key FROM server_signature_keys")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, key_id, verify_key in rows:
results.setdefault(server_name, {})[key_id] = encode_base64(verify_key)
return results
def select_v1_certs(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, tls_certificate FROM server_tls_certificates")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, tls_certificate in rows:
results[server_name] = tls_certificate
return results
def select_v2_json(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, key_id, key_json FROM server_keys_json")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, key_id, key_json in rows:
results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8"))
return results
def convert_v1_to_v2(server_name, valid_until, keys, certificate):
return {
"old_verify_keys": {},
"server_name": server_name,
"verify_keys": {
key_id: {"key": key}
for key_id, key in keys.items()
},
"valid_until_ts": valid_until,
"tls_fingerprints": [fingerprint(certificate)],
}
def fingerprint(certificate):
finger = hashlib.sha256(certificate)
return {"sha256": encode_base64(finger.digest())}
def rows_v2(server, json):
valid_until = json["valid_until_ts"]
key_json = encode_canonical_json(json)
for key_id in json["verify_keys"]:
yield (server, key_id, "-", valid_until, valid_until, buffer(key_json))
def main():
config = yaml.load(open(sys.argv[1]))
valid_until = int(time.time() / (3600 * 24)) * 1000 * 3600 * 24
server_name = config["server_name"]
signing_key = read_signing_keys(open(config["signing_key_path"]))[0]
database = config["database"]
assert database["name"] == "psycopg2", "Can only convert for postgresql"
args = database["args"]
args.pop("cp_max")
args.pop("cp_min")
connection = psycopg2.connect(**args)
keys = select_v1_keys(connection)
certificates = select_v1_certs(connection)
json = select_v2_json(connection)
result = {}
for server in keys:
if not server in json:
v2_json = convert_v1_to_v2(
server, valid_until, keys[server], certificates[server]
)
v2_json = sign_json(v2_json, server_name, signing_key)
result[server] = v2_json
yaml.safe_dump(result, sys.stdout, default_flow_style=False)
rows = list(
row for server, json in result.items()
for row in rows_v2(server, json)
)
cursor = connection.cursor()
cursor.executemany(
"INSERT INTO server_keys_json ("
" server_name, key_id, from_server,"
" ts_added_ms, ts_valid_until_ms, key_json"
") VALUES (%s, %s, %s, %s, %s, %s)",
rows
)
connection.commit()
if __name__ == '__main__':
main()

View file

@ -33,9 +33,10 @@ def request_registration(user, password, server_location, shared_secret):
).hexdigest() ).hexdigest()
data = { data = {
"username": user, "user": user,
"password": password, "password": password,
"mac": mac, "mac": mac,
"type": "org.matrix.login.shared_secret",
} }
server_location = server_location.rstrip("/") server_location = server_location.rstrip("/")
@ -43,7 +44,7 @@ def request_registration(user, password, server_location, shared_secret):
print "Sending registration request..." print "Sending registration request..."
req = urllib2.Request( req = urllib2.Request(
"%s/_matrix/client/v2_alpha/register" % (server_location,), "%s/_matrix/client/api/v1/register" % (server_location,),
data=json.dumps(data), data=json.dumps(data),
headers={'Content-Type': 'application/json'} headers={'Content-Type': 'application/json'}
) )

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.9.0-r5" __version__ = "0.9.1"

View file

@ -32,9 +32,9 @@ from synapse.server import HomeServer
from twisted.internet import reactor from twisted.internet import reactor
from twisted.application import service from twisted.application import service
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.web.resource import Resource from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site from twisted.web.server import Site, GzipEncoderFactory
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
@ -69,16 +69,26 @@ import subprocess
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
class GzipFile(File):
def getChild(self, path, request):
child = File.getChild(self, path, request)
return EncodingResourceWrapper(child, [GzipEncoderFactory()])
def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
def build_http_client(self): def build_http_client(self):
return MatrixFederationHttpClient(self) return MatrixFederationHttpClient(self)
def build_resource_for_client(self): def build_resource_for_client(self):
return ClientV1RestResource(self) return gz_wrap(ClientV1RestResource(self))
def build_resource_for_client_v2_alpha(self): def build_resource_for_client_v2_alpha(self):
return ClientV2AlphaRestResource(self) return gz_wrap(ClientV2AlphaRestResource(self))
def build_resource_for_federation(self): def build_resource_for_federation(self):
return JsonResource(self) return JsonResource(self)
@ -87,9 +97,16 @@ class SynapseHomeServer(HomeServer):
import syweb import syweb
syweb_path = os.path.dirname(syweb.__file__) syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient") webclient_path = os.path.join(syweb_path, "webclient")
# GZip is disabled here due to
# https://twistedmatrix.com/trac/ticket/7678
# (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight
# after and so do not trigger the bug.
# return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable?
def build_resource_for_static_content(self): def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip
return File("static") return File("static")
def build_resource_for_content_repo(self): def build_resource_for_content_repo(self):
@ -260,9 +277,12 @@ class SynapseHomeServer(HomeServer):
config, config,
metrics_resource, metrics_resource,
), ),
interface="127.0.0.1", interface=config.metrics_bind_host,
)
logger.info(
"Metrics now running on %s port %d",
config.metrics_bind_host, config.metrics_port,
) )
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
def run_startup_checks(self, db_conn, database_engine): def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain( all_users_native = are_all_users_on_domain(

View file

@ -148,8 +148,8 @@ class ApplicationService(object):
and self.is_interested_in_user(event.state_key)): and self.is_interested_in_user(event.state_key)):
return True return True
# check joined member events # check joined member events
for member in member_list: for user_id in member_list:
if self.is_interested_in_user(member.state_key): if self.is_interested_in_user(user_id):
return True return True
return False return False
@ -173,7 +173,7 @@ class ApplicationService(object):
restrict_to(str): The namespace to restrict regex tests to. restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for aliases_for_event(list): A list of all the known room aliases for
this event. this event.
member_list(list): A list of all joined room members in this room. member_list(list): A list of all joined user_ids in this room.
Returns: Returns:
bool: True if this service would like to know about this event. bool: True if this service would like to know about this event.
""" """

View file

@ -20,6 +20,7 @@ class MetricsConfig(Config):
def read_config(self, config): def read_config(self, config):
self.enable_metrics = config["enable_metrics"] self.enable_metrics = config["enable_metrics"]
self.metrics_port = config.get("metrics_port") self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name):
return """\ return """\
@ -28,6 +29,9 @@ class MetricsConfig(Config):
# Enable collection and rendering of performance metrics # Enable collection and rendering of performance metrics
enable_metrics: False enable_metrics: False
# Separate port to accept metrics requests on (on localhost) # Separate port to accept metrics requests on
# metrics_port: 8081 # metrics_port: 8081
# Which host to bind the metric listener to
# metrics_bind_host: 127.0.0.1
""" """

View file

@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import (
preserve_context_over_fn, preserve_context_over_deferred
)
import simplejson as json import simplejson as json
import logging import logging
@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
for i in range(5): for i in range(5):
try: try:
with PreserveLoggingContext(): protocol = yield preserve_context_over_fn(
protocol = yield endpoint.connect(factory) endpoint.connect, factory
server_response, server_certificate = yield protocol.remote_key )
defer.returnValue((server_response, server_certificate)) server_response, server_certificate = yield preserve_context_over_deferred(
return protocol.remote_key
)
defer.returnValue((server_response, server_certificate))
return
except SynapseKeyClientError as e: except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,)) logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"): if e.status.startswith("4"):

View file

@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util.async import create_observer from synapse.util.async import ObservableDeferred
from OpenSSL import crypto from OpenSSL import crypto
@ -111,6 +111,10 @@ class Keyring(object):
if download is None: if download is None:
download = self._get_server_verify_key_impl(server_name, key_ids) download = self._get_server_verify_key_impl(server_name, key_ids)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.key_downloads[server_name] = download self.key_downloads[server_name] = download
@download.addBoth @download.addBoth
@ -118,30 +122,31 @@ class Keyring(object):
del self.key_downloads[server_name] del self.key_downloads[server_name]
return ret return ret
r = yield create_observer(download) r = yield download.observe()
defer.returnValue(r) defer.returnValue(r)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids): def _get_server_verify_key_impl(self, server_name, key_ids):
keys = None keys = None
perspective_results = [] @defer.inlineCallbacks
for perspective_name, perspective_keys in self.perspective_servers.items(): def get_key(perspective_name, perspective_keys):
@defer.inlineCallbacks try:
def get_key(): result = yield self.get_server_verify_key_v2_indirect(
try: server_name, key_ids, perspective_name, perspective_keys
result = yield self.get_server_verify_key_v2_indirect( )
server_name, key_ids, perspective_name, perspective_keys defer.returnValue(result)
) except Exception as e:
defer.returnValue(result) logging.info(
except: "Unable to getting key %r for %r from %r: %s %s",
logging.info( key_ids, server_name, perspective_name,
"Unable to getting key %r for %r from %r", type(e).__name__, str(e.message),
key_ids, server_name, perspective_name, )
)
perspective_results.append(get_key())
perspective_results = yield defer.gatherResults(perspective_results) perspective_results = yield defer.gatherResults([
get_key(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
])
for results in perspective_results: for results in perspective_results:
if results is not None: if results is not None:
@ -154,17 +159,22 @@ class Keyring(object):
) )
with limiter: with limiter:
if keys is None: if not keys:
try: try:
keys = yield self.get_server_verify_key_v2_direct( keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids server_name, key_ids
) )
except: except Exception as e:
pass logging.info(
"Unable to getting key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
keys = yield self.get_server_verify_key_v1_direct( if not keys:
server_name, key_ids keys = yield self.get_server_verify_key_v1_direct(
) server_name, key_ids
)
for key_id in key_ids: for key_id in key_ids:
if key_id in keys: if key_id in keys:
@ -184,7 +194,7 @@ class Keyring(object):
# TODO(mark): Set the minimum_valid_until_ts to that needed by # TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating # the events being validated or the current time if validating
# an incoming request. # an incoming request.
responses = yield self.client.post_json( query_response = yield self.client.post_json(
destination=perspective_name, destination=perspective_name,
path=b"/_matrix/key/v2/query", path=b"/_matrix/key/v2/query",
data={ data={
@ -200,6 +210,8 @@ class Keyring(object):
keys = {} keys = {}
responses = query_response["server_keys"]
for response in responses: for response in responses:
if (u"signatures" not in response if (u"signatures" not in response
or perspective_name not in response[u"signatures"]): or perspective_name not in response[u"signatures"]):
@ -323,7 +335,7 @@ class Keyring(object):
verify_key.time_added = time_now_ms verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key old_verify_keys[key_id] = verify_key
for key_id in response_json["signatures"][server_name]: for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]: if key_id not in response_json["verify_keys"]:
raise ValueError( raise ValueError(
"Key response must include verification keys for all" "Key response must include verification keys for all"

View file

@ -24,6 +24,8 @@ from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError
import logging import logging
@ -78,6 +80,7 @@ class FederationBase(object):
destinations=[pdu.origin], destinations=[pdu.origin],
event_id=pdu.event_id, event_id=pdu.event_id,
outlier=outlier, outlier=outlier,
timeout=10000,
) )
if new_pdu: if new_pdu:
@ -94,7 +97,7 @@ class FederationBase(object):
yield defer.gatherResults( yield defer.gatherResults(
[do(pdu) for pdu in pdus], [do(pdu) for pdu in pdus],
consumeErrors=True consumeErrors=True
) ).addErrback(unwrapFirstError)
defer.returnValue(signed_pdus) defer.returnValue(signed_pdus)

View file

@ -22,6 +22,7 @@ from .units import Edu
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
) )
from synapse.util import unwrapFirstError
from synapse.util.expiringcache import ExpiringCache from synapse.util.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
@ -164,16 +165,17 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ]
for i, pdu in enumerate(pdus): # FIXME: We should handle signature failures more gracefully.
pdus[i] = yield self._check_sigs_and_hash(pdu) pdus[:] = yield defer.gatherResults(
[self._check_sigs_and_hash(pdu) for pdu in pdus],
# FIXME: We should handle signature failures more gracefully. consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(pdus) defer.returnValue(pdus)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_pdu(self, destinations, event_id, outlier=False): def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
"""Requests the PDU with given origin and ID from the remote home """Requests the PDU with given origin and ID from the remote home
servers. servers.
@ -189,6 +191,8 @@ class FederationClient(FederationBase):
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False` of the current block of PDUs. Defaults to `False`
timeout (int): How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
Returns: Returns:
Deferred: Results in the requested PDU. Deferred: Results in the requested PDU.
@ -212,7 +216,7 @@ class FederationClient(FederationBase):
with limiter: with limiter:
transaction_data = yield self.transport_layer.get_event( transaction_data = yield self.transport_layer.get_event(
destination, event_id destination, event_id, timeout=timeout,
) )
logger.debug("transaction_data %r", transaction_data) logger.debug("transaction_data %r", transaction_data)
@ -222,7 +226,7 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ]
if pdu_list: if pdu_list and pdu_list[0]:
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
@ -255,7 +259,7 @@ class FederationClient(FederationBase):
) )
continue continue
if self._get_pdu_cache is not None: if self._get_pdu_cache is not None and pdu:
self._get_pdu_cache[event_id] = pdu self._get_pdu_cache[event_id] = pdu
defer.returnValue(pdu) defer.returnValue(pdu)
@ -370,13 +374,17 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", []) for p in content.get("auth_chain", [])
] ]
signed_state = yield self._check_sigs_and_hash_and_fetch( signed_state, signed_auth = yield defer.gatherResults(
destination, state, outlier=True [
) self._check_sigs_and_hash_and_fetch(
destination, state, outlier=True
signed_auth = yield self._check_sigs_and_hash_and_fetch( ),
destination, auth_chain, outlier=True self._check_sigs_and_hash_and_fetch(
) destination, auth_chain, outlier=True
)
],
consumeErrors=True
).addErrback(unwrapFirstError)
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
@ -518,7 +526,7 @@ class FederationClient(FederationBase):
# Are we missing any? # Are we missing any?
seen_events = set(earliest_events_ids) seen_events = set(earliest_events_ids)
seen_events.update(e.event_id for e in signed_events) seen_events.update(e.event_id for e in signed_events if e)
missing_events = {} missing_events = {}
for e in itertools.chain(latest_events, signed_events): for e in itertools.chain(latest_events, signed_events):
@ -561,7 +569,7 @@ class FederationClient(FederationBase):
res = yield defer.DeferredList(deferreds, consumeErrors=True) res = yield defer.DeferredList(deferreds, consumeErrors=True)
for (result, val), (e_id, _) in zip(res, ordered_missing): for (result, val), (e_id, _) in zip(res, ordered_missing):
if result: if result and val:
signed_events.append(val) signed_events.append(val)
else: else:
failed_to_fetch.add(e_id) failed_to_fetch.add(e_id)

View file

@ -20,7 +20,6 @@ from .federation_base import FederationBase
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
@ -123,29 +122,28 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext(): results = []
results = []
for pdu in pdu_list: for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu) d = self._handle_new_pdu(transaction.origin, pdu)
try: try:
yield d yield d
results.append({}) results.append({})
except FederationError as e: except FederationError as e:
self.send_failure(e, transaction.origin) self.send_failure(e, transaction.origin)
results.append({"error": str(e)}) results.append({"error": str(e)})
except Exception as e: except Exception as e:
results.append({"error": str(e)}) results.append({"error": str(e)})
logger.exception("Failed to handle PDU") logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]: for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu( self.received_edu(
transaction.origin, transaction.origin,
edu.edu_type, edu.edu_type,
edu.content edu.content
) )
for failure in getattr(transaction, "pdu_failures", []): for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure) logger.info("Got failure %r", failure)

View file

@ -207,13 +207,13 @@ class TransactionQueue(object):
# request at which point pending_pdus_by_dest just keeps growing. # request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these # we need application-layer timeouts of some flavour of these
# requests # requests
logger.info( logger.debug(
"TX [%s] Transaction already in progress", "TX [%s] Transaction already in progress",
destination destination
) )
return return
logger.info("TX [%s] _attempt_new_transaction", destination) logger.debug("TX [%s] _attempt_new_transaction", destination)
# list of (pending_pdu, deferred, order) # list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
@ -221,11 +221,11 @@ class TransactionQueue(object):
pending_failures = self.pending_failures_by_dest.pop(destination, []) pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus: if pending_pdus:
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d", logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus)) destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures: if not pending_pdus and not pending_edus and not pending_failures:
logger.info("TX [%s] Nothing to send", destination) logger.debug("TX [%s] Nothing to send", destination)
return return
# Sort based on the order field # Sort based on the order field
@ -242,6 +242,8 @@ class TransactionQueue(object):
try: try:
self.pending_transactions[destination] = 1 self.pending_transactions[destination] = 1
txn_id = str(self._next_txn_id)
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, destination,
self._clock, self._clock,
@ -249,9 +251,9 @@ class TransactionQueue(object):
) )
logger.debug( logger.debug(
"TX [%s] Attempting new transaction" "TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)", " (pdus: %d, edus: %d, failures: %d)",
destination, destination, txn_id,
len(pending_pdus), len(pending_pdus),
len(pending_edus), len(pending_edus),
len(pending_failures) len(pending_failures)
@ -261,7 +263,7 @@ class TransactionQueue(object):
transaction = Transaction.create_new( transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()), origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id), transaction_id=txn_id,
origin=self.server_name, origin=self.server_name,
destination=destination, destination=destination,
pdus=pdus, pdus=pdus,
@ -275,9 +277,13 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisted transaction", destination) logger.debug("TX [%s] Persisted transaction", destination)
logger.info( logger.info(
"TX [%s] Sending transaction [%s]", "TX [%s] {%s} Sending transaction [%s],"
destination, " (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id, transaction.transaction_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures),
) )
with limiter: with limiter:
@ -313,7 +319,10 @@ class TransactionQueue(object):
code = e.code code = e.code
response = e.response response = e.response
logger.info("TX [%s] got %d response", destination, code) logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
logger.debug("TX [%s] Sent transaction", destination) logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination) logger.debug("TX [%s] Marking as delivered...", destination)

View file

@ -50,13 +50,15 @@ class TransportLayerClient(object):
) )
@log_function @log_function
def get_event(self, destination, event_id): def get_event(self, destination, event_id, timeout=None):
""" Requests the pdu with give id and origin from the given server. """ Requests the pdu with give id and origin from the given server.
Args: Args:
destination (str): The host name of the remote home server we want destination (str): The host name of the remote home server we want
to get the state from. to get the state from.
event_id (str): The id of the event being requested. event_id (str): The id of the event being requested.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns: Returns:
Deferred: Results in a dict received from the remote homeserver. Deferred: Results in a dict received from the remote homeserver.
@ -65,7 +67,7 @@ class TransportLayerClient(object):
destination, event_id) destination, event_id)
path = PREFIX + "/event/%s/" % (event_id, ) path = PREFIX + "/event/%s/" % (event_id, )
return self.client.get_json(destination, path=path) return self.client.get_json(destination, path=path, timeout=timeout)
@log_function @log_function
def backfill(self, destination, room_id, event_tuples, limit): def backfill(self, destination, room_id, event_tuples, limit):

View file

@ -196,6 +196,14 @@ class FederationSendServlet(BaseFederationServlet):
transaction_id, str(transaction_data) transaction_id, str(transaction_data)
) )
logger.info(
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
transaction_id, origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
len(transaction_data.get("failures", [])),
)
# We should ideally be getting this from the security layer. # We should ideally be getting this from the security layer.
# origin = body["origin"] # origin = body["origin"]

View file

@ -20,6 +20,8 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID from synapse.types import UserID
from synapse.util.logcontext import PreserveLoggingContext
import logging import logging
@ -103,7 +105,9 @@ class BaseHandler(object):
if not suppress_auth: if not suppress_auth:
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self.store.persist_event(event, context=context) (event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
federation_handler = self.hs.get_handlers().federation_handler federation_handler = self.hs.get_handlers().federation_handler
@ -137,10 +141,12 @@ class BaseHandler(object):
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
) )
# Don't block waiting on waking up all the listeners. with PreserveLoggingContext():
notify_d = self.notifier.on_new_room_event( # Don't block waiting on waking up all the listeners.
event, extra_users=extra_users notify_d = self.notifier.on_new_room_event(
) event, event_stream_id, max_stream_id,
extra_users=extra_users
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(

View file

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.types import UserID from synapse.types import UserID
@ -147,10 +147,7 @@ class ApplicationServicesHandler(object):
) )
# We need to know the members associated with this event.room_id, # We need to know the members associated with this event.room_id,
# if any. # if any.
member_list = yield self.store.get_room_members( member_list = yield self.store.get_users_in_room(event.room_id)
room_id=event.room_id,
membership=Membership.JOIN
)
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
interested_list = [ interested_list = [

View file

@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes
from synapse.types import RoomAlias from synapse.types import RoomAlias
import logging import logging
import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,6 +41,10 @@ class DirectoryHandler(BaseHandler):
def _create_association(self, room_alias, room_id, servers=None): def _create_association(self, room_alias, room_id, servers=None):
# general association creation for both human users and app services # general association creation for both human users and app services
for wchar in string.whitespace:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this. # TODO(erikj): Change this.

View file

@ -15,7 +15,6 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
@ -81,10 +80,9 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
with PreserveLoggingContext(): events, tokens = yield self.notifier.get_events_for(
events, tokens = yield self.notifier.get_events_for( auth_user, room_ids, pagin_config, timeout
auth_user, room_ids, pagin_config, timeout )
)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -18,9 +18,11 @@
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, FederationError, StoreError, AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
) )
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
@ -29,6 +31,8 @@ from synapse.crypto.event_signing import (
) )
from synapse.types import UserID from synapse.types import UserID
from synapse.util.retryutils import NotRetryingDestination
from twisted.internet import defer from twisted.internet import defer
import itertools import itertools
@ -156,7 +160,7 @@ class FederationHandler(BaseHandler):
) )
try: try:
yield self._handle_new_event( _, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, origin,
event, event,
state=state, state=state,
@ -197,9 +201,11 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
event, extra_users=extra_users d = self.notifier.on_new_room_event(
) event, event_stream_id, max_stream_id,
extra_users=extra_users
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -218,36 +224,209 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def backfill(self, dest, room_id, limit): def backfill(self, dest, room_id, limit, extremities=[]):
""" Trigger a backfill request to `dest` for the given `room_id` """ Trigger a backfill request to `dest` for the given `room_id`
""" """
extremities = yield self.store.get_oldest_events_in_room(room_id) if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
pdus = yield self.replication_layer.backfill( events = yield self.replication_layer.backfill(
dest, dest,
room_id, room_id,
limit, limit=limit,
extremities=extremities, extremities=extremities,
) )
events = [] event_map = {e.event_id: e for e in events}
for pdu in pdus: event_ids = set(e.event_id for e in events)
event = pdu
# FIXME (erikj): Not sure this actually works :/ edges = [
context = yield self.state_handler.compute_event_context(event) ev.event_id
for ev in events
if set(e_id for e_id, _ in ev.prev_events) - event_ids
]
events.append((event, context)) # For each edge get the current state.
yield self.store.persist_event( auth_events = {}
event, events_to_state = {}
context=context, for e_id in edges:
backfilled=True state, auth = yield self.replication_layer.get_state_for_room(
destination=dest,
room_id=room_id,
event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
events_to_state[e_id] = state
yield defer.gatherResults(
[
self._handle_new_event(dest, a)
for a in auth_events.values()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
yield defer.gatherResults(
[
self._handle_new_event(
dest, event_map[e_id],
state=events_to_state[e_id],
backfilled=True,
)
for e_id in events_to_state
],
consumeErrors=True
).addErrback(unwrapFirstError)
events.sort(key=lambda e: e.depth)
for event in events:
if event in events_to_state:
continue
yield self._handle_new_event(
dest, event,
backfilled=True,
) )
defer.returnValue(events) defer.returnValue(events)
@defer.inlineCallbacks
def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
extremities = yield self.store.get_oldest_events_with_depth_in_room(
room_id
)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
return
# Check if we reached a point where we should start backfilling.
sorted_extremeties_tuple = sorted(
extremities.items(),
key=lambda e: -int(e[1])
)
max_depth = sorted_extremeties_tuple[0][1]
if current_depth > max_depth:
logger.debug(
"Not backfilling as we don't need to. %d < %d",
max_depth, current_depth,
)
return
# Now we need to decide which hosts to hit first.
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
curr_state = yield self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
joined_users = [
(state_key, int(event.depth))
for (e_type, state_key), event in state.items()
if e_type == EventTypes.Member
and event.membership == Membership.JOIN
]
joined_domains = {}
for u, d in joined_users:
try:
dom = UserID.from_string(u).domain
old_d = joined_domains.get(dom)
if old_d:
joined_domains[dom] = min(d, old_d)
else:
joined_domains[dom] = d
except:
pass
return sorted(joined_domains.items(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains
if domain is not self.server_name
]
@defer.inlineCallbacks
def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
events = yield self.backfill(
dom, room_id,
limit=100,
extremities=[e for e in extremities.keys()]
)
except SynapseError:
logger.info(
"Failed to backfill from %s because %s",
dom, e,
)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
logger.info(
"Failed to backfill from %s because %s",
dom, e,
)
continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
logger.exception(
"Failed to backfill from %s because %s",
dom, e,
)
continue
if events:
defer.returnValue(True)
defer.returnValue(False)
success = yield try_backfill(likely_domains)
if success:
defer.returnValue(True)
# Huh, well *those* domains didn't work out. Lets try some domains
# from the time.
tried_domains = set(likely_domains)
tried_domains.add(self.server_name)
event_ids = list(extremities.keys())
states = yield defer.gatherResults([
self.state_handler.resolve_state_groups([e])
for e in event_ids
])
states = dict(zip(event_ids, [s[1] for s in states]))
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill([
dom for dom in likely_domains
if dom not in tried_domains
])
if success:
defer.returnValue(True)
tried_domains.update(likely_domains)
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, target_host, event): def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing. """ Sends the invite to the remote server for signing.
@ -376,30 +555,14 @@ class FederationHandler(BaseHandler):
# FIXME # FIXME
pass pass
for e in auth_chain: yield self._handle_auth_events(
e.internal_metadata.outlier = True origin, [e for e in auth_chain if e.event_id != event.event_id]
)
@defer.inlineCallbacks
def handle_state(e):
if e.event_id == event.event_id: if e.event_id == event.event_id:
continue return
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, e, auth_events=auth
)
except:
logger.exception(
"Failed to handle auth event %s",
e.event_id,
)
for e in state:
if e.event_id == event.event_id:
continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
@ -417,13 +580,15 @@ class FederationHandler(BaseHandler):
e.event_id, e.event_id,
) )
yield defer.DeferredList([handle_state(e) for e in state])
auth_ids = [e_id for e_id, _ in event.auth_events] auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = { auth_events = {
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids
} }
yield self._handle_new_event( _, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, origin,
new_event, new_event,
state=state, state=state,
@ -431,9 +596,11 @@ class FederationHandler(BaseHandler):
auth_events=auth_events, auth_events=auth_events,
) )
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
new_event, extra_users=[joinee] d = self.notifier.on_new_room_event(
) new_event, event_stream_id, max_stream_id,
extra_users=[joinee]
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -498,7 +665,9 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False event.internal_metadata.outlier = False
context = yield self._handle_new_event(origin, event) context, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, event
)
logger.debug( logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s", "on_send_join_request: After _handle_new_event: %s, sigs: %s",
@ -512,9 +681,10 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
event, extra_users=extra_users d = self.notifier.on_new_room_event(
) event, event_stream_id, max_stream_id, extra_users=extra_users
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -587,16 +757,18 @@ class FederationHandler(BaseHandler):
context = yield self.state_handler.compute_event_context(event) context = yield self.state_handler.compute_event_context(event)
yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
backfilled=False, backfilled=False,
) )
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
event, extra_users=[target_user], d = self.notifier.on_new_room_event(
) event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -745,9 +917,12 @@ class FederationHandler(BaseHandler):
# This is a hack to fix some old rooms where the initial join event # This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events. # didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events: if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1: if len(event.prev_events) == 1 and event.depth < 5:
c = yield self.store.get_event(event.prev_events[0][0]) c = yield self.store.get_event(
if c.type == EventTypes.Create: event.prev_events[0][0],
allow_none=True,
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c auth_events[(c.type, c.state_key)] = c
try: try:
@ -773,7 +948,7 @@ class FederationHandler(BaseHandler):
) )
raise raise
yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
backfilled=backfilled, backfilled=backfilled,
@ -781,7 +956,7 @@ class FederationHandler(BaseHandler):
current_state=current_state, current_state=current_state,
) )
defer.returnValue(context) defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
@ -921,7 +1096,7 @@ class FederationHandler(BaseHandler):
if d in have_events and not have_events[d] if d in have_events and not have_events[d]
], ],
consumeErrors=True consumeErrors=True
) ).addErrback(unwrapFirstError)
if different_events: if different_events:
local_view = dict(auth_events) local_view = dict(auth_events)
@ -1166,3 +1341,52 @@ class FederationHandler(BaseHandler):
}, },
"missing": [e.event_id for e in missing_locals], "missing": [e.event_id for e in missing_locals],
}) })
@defer.inlineCallbacks
def _handle_auth_events(self, origin, auth_events):
auth_ids_to_deferred = {}
def process_auth_ev(ev):
auth_ids = [e_id for e_id, _ in ev.auth_events]
prev_ds = [
auth_ids_to_deferred[i]
for i in auth_ids
if i in auth_ids_to_deferred
]
d = defer.Deferred()
auth_ids_to_deferred[ev.event_id] = d
@defer.inlineCallbacks
def f(*_):
ev.internal_metadata.outlier = True
try:
auth = {
(e.type, e.state_key): e for e in auth_events
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, ev, auth_events=auth
)
except:
logger.exception(
"Failed to handle auth event %s",
ev.event_id,
)
d.callback(None)
if prev_ds:
dx = defer.DeferredList(prev_ds)
dx.addBoth(f)
else:
f()
for e in auth_events:
process_auth_ev(e)
yield defer.DeferredList(auth_ids_to_deferred.values())

View file

@ -20,8 +20,9 @@ from synapse.api.errors import RoomError, SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID from synapse.types import UserID, RoomStreamToken
from ._base import BaseHandler from ._base import BaseHandler
@ -89,9 +90,19 @@ class MessageHandler(BaseHandler):
if not pagin_config.from_token: if not pagin_config.from_token:
pagin_config.from_token = ( pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token() yield self.hs.get_event_sources().get_current_token(
direction='b'
)
) )
room_token = RoomStreamToken.parse(pagin_config.from_token.room_key)
if room_token.topological is None:
raise SynapseError(400, "Invalid token")
yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological
)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
events, next_key = yield data_source.get_pagination_rows( events, next_key = yield data_source.get_pagination_rows(
@ -303,7 +314,7 @@ class MessageHandler(BaseHandler):
event.room_id event.room_id
), ),
] ]
) ).addErrback(unwrapFirstError)
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1]) end_token = now_token.copy_and_replace("room_key", token[1])
@ -328,7 +339,7 @@ class MessageHandler(BaseHandler):
yield defer.gatherResults( yield defer.gatherResults(
[handle_room(e) for e in room_list], [handle_room(e) for e in room_list],
consumeErrors=True consumeErrors=True
) ).addErrback(unwrapFirstError)
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,

View file

@ -18,8 +18,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
import synapse.metrics import synapse.metrics
@ -146,6 +146,10 @@ class PresenceHandler(BaseHandler):
self._user_cachemap = {} self._user_cachemap = {}
self._user_cachemap_latest_serial = 0 self._user_cachemap_latest_serial = 0
# map room_ids to the latest presence serial for a member of that
# room
self._room_serials = {}
metrics.register_callback( metrics.register_callback(
"userCachemap:size", "userCachemap:size",
lambda: len(self._user_cachemap), lambda: len(self._user_cachemap),
@ -278,15 +282,14 @@ class PresenceHandler(BaseHandler):
now_online = state["presence"] != PresenceState.OFFLINE now_online = state["presence"] != PresenceState.OFFLINE
was_polling = target_user in self._user_cachemap was_polling = target_user in self._user_cachemap
with PreserveLoggingContext(): if now_online and not was_polling:
if now_online and not was_polling: self.start_polling_presence(target_user, state=state)
self.start_polling_presence(target_user, state=state) elif not now_online and was_polling:
elif not now_online and was_polling: self.stop_polling_presence(target_user)
self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so # TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time # we don't have to do this all the time
self.changed_presencelike_data(target_user, state) self.changed_presencelike_data(target_user, state)
def bump_presence_active_time(self, user, now=None): def bump_presence_active_time(self, user, now=None):
if now is None: if now is None:
@ -298,13 +301,34 @@ class PresenceHandler(BaseHandler):
self.changed_presencelike_data(user, {"last_active": now}) self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user):
"""Get the list of rooms a user is joined to.
Args:
user(UserID): The user.
Returns:
A Deferred of a list of room id strings.
"""
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_joined_rooms_for_user(user)
def get_joined_users_for_room_id(self, room_id):
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_room_members(room_id)
@defer.inlineCallbacks
def changed_presencelike_data(self, user, state): def changed_presencelike_data(self, user, state):
statuscache = self._get_or_make_usercache(user) """Updates the presence state of a local user.
Args:
user(UserID): The user being updated.
state(dict): The new presence state for the user.
Returns:
A Deferred
"""
self._user_cachemap_latest_serial += 1 self._user_cachemap_latest_serial += 1
statuscache.update(state, serial=self._user_cachemap_latest_serial) statuscache = yield self.update_presence_cache(user, state)
yield self.push_presence(user, statuscache=statuscache)
return self.push_presence(user, statuscache=statuscache)
@log_function @log_function
def started_user_eventstream(self, user): def started_user_eventstream(self, user):
@ -318,14 +342,21 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def user_joined_room(self, user, room_id): def user_joined_room(self, user, room_id):
if self.hs.is_mine(user): """Called via the distributor whenever a user joins a room.
statuscache = self._get_or_make_usercache(user) Notifies the new member of the presence of the current members.
Notifies the current members of the room of the new member's presence.
Args:
user(UserID): The user who joined the room.
room_id(str): The room id the user joined.
"""
if self.hs.is_mine(user):
# No actual update but we need to bump the serial anyway for the # No actual update but we need to bump the serial anyway for the
# event source # event source
self._user_cachemap_latest_serial += 1 self._user_cachemap_latest_serial += 1
statuscache.update({}, serial=self._user_cachemap_latest_serial) statuscache = yield self.update_presence_cache(
user, room_ids=[room_id]
)
self.push_update_to_local_and_remote( self.push_update_to_local_and_remote(
observed_user=user, observed_user=user,
room_ids=[room_id], room_ids=[room_id],
@ -333,18 +364,22 @@ class PresenceHandler(BaseHandler):
) )
# We also want to tell them about current presence of people. # We also want to tell them about current presence of people.
rm_handler = self.homeserver.get_handlers().room_member_handler curr_users = yield self.get_joined_users_for_room_id(room_id)
curr_users = yield rm_handler.get_room_members(room_id)
for local_user in [c for c in curr_users if self.hs.is_mine(c)]: for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
statuscache = yield self.update_presence_cache(
local_user, room_ids=[room_id], add_to_cache=False
)
self.push_update_to_local_and_remote( self.push_update_to_local_and_remote(
observed_user=local_user, observed_user=local_user,
users_to_push=[user], users_to_push=[user],
statuscache=self._get_or_offline_usercache(local_user), statuscache=statuscache,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, observer_user, observed_user): def send_invite(self, observer_user, observed_user):
"""Request the presence of a local or remote user for a local user"""
if not self.hs.is_mine(observer_user): if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
@ -379,6 +414,15 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def invite_presence(self, observed_user, observer_user): def invite_presence(self, observed_user, observer_user):
"""Handles a m.presence_invite EDU. A remote or local user has
requested presence updates for a local user. If the invite is accepted
then allow the local or remote user to see the presence of the local
user.
Args:
observed_user(UserID): The local user whose presence is requested.
observer_user(UserID): The remote or local user requesting presence.
"""
accept = yield self._should_accept_invite(observed_user, observer_user) accept = yield self._should_accept_invite(observed_user, observer_user)
if accept: if accept:
@ -405,16 +449,34 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def accept_presence(self, observed_user, observer_user): def accept_presence(self, observed_user, observer_user):
"""Handles a m.presence_accept EDU. Mark a presence invite from a
local or remote user as accepted in a local user's presence list.
Starts polling for presence updates from the local or remote user.
Args:
observed_user(UserID): The user to update in the presence list.
observer_user(UserID): The owner of the presence list to update.
"""
yield self.store.set_presence_list_accepted( yield self.store.set_presence_list_accepted(
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
with PreserveLoggingContext():
self.start_polling_presence( self.start_polling_presence(
observer_user, target_user=observed_user observer_user, target_user=observed_user
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def deny_presence(self, observed_user, observer_user): def deny_presence(self, observed_user, observer_user):
"""Handle a m.presence_deny EDU. Removes a local or remote user from a
local user's presence list.
Args:
observed_user(UserID): The local or remote user to remove from the
list.
observer_user(UserID): The local owner of the presence list.
Returns:
A Deferred.
"""
yield self.store.del_presence_list( yield self.store.del_presence_list(
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
@ -423,6 +485,16 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def drop(self, observed_user, observer_user): def drop(self, observed_user, observer_user):
"""Remove a local or remote user from a local user's presence list and
unsubscribe the local user from updates that user.
Args:
observed_user(UserId): The local or remote user to remove from the
list.
observer_user(UserId): The local owner of the presence list.
Returns:
A Deferred.
"""
if not self.hs.is_mine(observer_user): if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
@ -430,34 +502,66 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
with PreserveLoggingContext(): self.stop_polling_presence(
self.stop_polling_presence( observer_user, target_user=observed_user
observer_user, target_user=observed_user )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None): def get_presence_list(self, observer_user, accepted=None):
"""Get the presence list for a local user. The retured list includes
the current presence state for each user listed.
Args:
observer_user(UserID): The local user whose presence list to fetch.
accepted(bool or None): If not none then only include users who
have or have not accepted the presence invite request.
Returns:
A Deferred list of presence state events.
"""
if not self.hs.is_mine(observer_user): if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
presence = yield self.store.get_presence_list( presence_list = yield self.store.get_presence_list(
observer_user.localpart, accepted=accepted observer_user.localpart, accepted=accepted
) )
for p in presence: results = []
observed_user = UserID.from_string(p.pop("observed_user_id")) for row in presence_list:
p["observed_user"] = observed_user observed_user = UserID.from_string(row["observed_user_id"])
p.update(self._get_or_offline_usercache(observed_user).get_state()) result = {
if "last_active" in p: "observed_user": observed_user, "accepted": row["accepted"]
p["last_active_ago"] = int( }
self.clock.time_msec() - p.pop("last_active") result.update(
self._get_or_offline_usercache(observed_user).get_state()
)
if "last_active" in result:
result["last_active_ago"] = int(
self.clock.time_msec() - result.pop("last_active")
) )
results.append(result)
defer.returnValue(presence) defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def start_polling_presence(self, user, target_user=None, state=None): def start_polling_presence(self, user, target_user=None, state=None):
"""Subscribe a local user to presence updates from a local or remote
user. If no target_user is supplied then subscribe to all users stored
in the presence list for the local user.
Additonally this pushes the current presence state of this user to all
target_users. That state can be provided directly or will be read from
the stored state for the local user.
Also this attempts to notify the local user of the current state of
any local target users.
Args:
user(UserID): The local user that whishes for presence updates.
target_user(UserID): The local or remote user whose updates are
wanted.
state(dict): Optional presence state for the local user.
"""
logger.debug("Start polling for presence from %s", user) logger.debug("Start polling for presence from %s", user)
if target_user: if target_user:
@ -473,8 +577,7 @@ class PresenceHandler(BaseHandler):
# Also include people in all my rooms # Also include people in all my rooms
rm_handler = self.homeserver.get_handlers().room_member_handler room_ids = yield self.get_joined_rooms_for_user(user)
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if state is None: if state is None:
state = yield self.store.get_presence_state(user.localpart) state = yield self.store.get_presence_state(user.localpart)
@ -498,9 +601,7 @@ class PresenceHandler(BaseHandler):
# We want to tell the person that just came online # We want to tell the person that just came online
# presence state of people they are interested in? # presence state of people they are interested in?
self.push_update_to_clients( self.push_update_to_clients(
observed_user=target_user,
users_to_push=[user], users_to_push=[user],
statuscache=self._get_or_offline_usercache(target_user),
) )
deferreds = [] deferreds = []
@ -517,6 +618,12 @@ class PresenceHandler(BaseHandler):
yield defer.DeferredList(deferreds, consumeErrors=True) yield defer.DeferredList(deferreds, consumeErrors=True)
def _start_polling_local(self, user, target_user): def _start_polling_local(self, user, target_user):
"""Subscribe a local user to presence updates for a local user
Args:
user(UserId): The local user that wishes for updates.
target_user(UserId): The local users whose updates are wanted.
"""
target_localpart = target_user.localpart target_localpart = target_user.localpart
if target_localpart not in self._local_pushmap: if target_localpart not in self._local_pushmap:
@ -525,6 +632,17 @@ class PresenceHandler(BaseHandler):
self._local_pushmap[target_localpart].add(user) self._local_pushmap[target_localpart].add(user)
def _start_polling_remote(self, user, domain, remoteusers): def _start_polling_remote(self, user, domain, remoteusers):
"""Subscribe a local user to presence updates for remote users on a
given remote domain.
Args:
user(UserID): The local user that wishes for updates.
domain(str): The remote server the local user wants updates from.
remoteusers(UserID): The remote users that local user wants to be
told about.
Returns:
A Deferred.
"""
to_poll = set() to_poll = set()
for u in remoteusers: for u in remoteusers:
@ -545,6 +663,17 @@ class PresenceHandler(BaseHandler):
@log_function @log_function
def stop_polling_presence(self, user, target_user=None): def stop_polling_presence(self, user, target_user=None):
"""Unsubscribe a local user from presence updates from a local or
remote user. If no target user is supplied then unsubscribe the user
from all presence updates that the user had subscribed to.
Args:
user(UserID): The local user that no longer wishes for updates.
target_user(UserID or None): The user whose updates are no longer
wanted.
Returns:
A Deferred.
"""
logger.debug("Stop polling for presence from %s", user) logger.debug("Stop polling for presence from %s", user)
if not target_user or self.hs.is_mine(target_user): if not target_user or self.hs.is_mine(target_user):
@ -573,6 +702,13 @@ class PresenceHandler(BaseHandler):
return defer.DeferredList(deferreds, consumeErrors=True) return defer.DeferredList(deferreds, consumeErrors=True)
def _stop_polling_local(self, user, target_user): def _stop_polling_local(self, user, target_user):
"""Unsubscribe a local user from presence updates from a local user on
this server.
Args:
user(UserID): The local user that no longer wishes for updates.
target_user(UserID): The user whose updates are no longer wanted.
"""
for localpart in self._local_pushmap.keys(): for localpart in self._local_pushmap.keys():
if target_user and localpart != target_user.localpart: if target_user and localpart != target_user.localpart:
continue continue
@ -585,6 +721,17 @@ class PresenceHandler(BaseHandler):
@log_function @log_function
def _stop_polling_remote(self, user, domain, remoteusers): def _stop_polling_remote(self, user, domain, remoteusers):
"""Unsubscribe a local user from presence updates from remote users on
a given domain.
Args:
user(UserID): The local user that no longer wishes for updates.
domain(str): The remote server to unsubscribe from.
remoteusers([UserID]): The users on that remote server that the
local user no longer wishes to be updated about.
Returns:
A Deferred.
"""
to_unpoll = set() to_unpoll = set()
for u in remoteusers: for u in remoteusers:
@ -606,6 +753,19 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def push_presence(self, user, statuscache): def push_presence(self, user, statuscache):
"""
Notify local and remote users of a change in presence of a local user.
Pushes the update to local clients and remote domains that are directly
subscribed to the presence of the local user.
Also pushes that update to any local user or remote domain that shares
a room with the local user.
Args:
user(UserID): The local user whose presence was updated.
statuscache(UserPresenceCache): Cache of the user's presence state
Returns:
A Deferred.
"""
assert(self.hs.is_mine(user)) assert(self.hs.is_mine(user))
logger.debug("Pushing presence update from %s", user) logger.debug("Pushing presence update from %s", user)
@ -617,8 +777,7 @@ class PresenceHandler(BaseHandler):
# and also user is informed of server-forced pushes # and also user is informed of server-forced pushes
localusers.add(user) localusers.add(user)
rm_handler = self.homeserver.get_handlers().room_member_handler room_ids = yield self.get_joined_rooms_for_user(user)
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if not localusers and not room_ids: if not localusers and not room_ids:
defer.returnValue(None) defer.returnValue(None)
@ -632,45 +791,24 @@ class PresenceHandler(BaseHandler):
) )
yield self.distributor.fire("user_presence_changed", user, statuscache) yield self.distributor.fire("user_presence_changed", user, statuscache)
@defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None):
if state is None:
state = yield self.store.get_presence_state(user.localpart)
del state["mtime"]
state["presence"] = state.pop("state")
if user in self._user_cachemap:
state["last_active"] = (
self._user_cachemap[user].get_state()["last_active"]
)
yield self.distributor.fire(
"collect_presencelike_data", user, state
)
if "last_active" in state:
state = dict(state)
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
user_state = {
"user_id": user.to_string(),
}
user_state.update(**state)
yield self.federation.send_edu(
destination=destination,
edu_type="m.presence",
content={
"push": [
user_state,
],
}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def incoming_presence(self, origin, content): def incoming_presence(self, origin, content):
"""Handle an incoming m.presence EDU.
For each presence update in the "push" list update our local cache and
notify the appropriate local clients. Only clients that share a room
or are directly subscribed to the presence for a user should be
notified of the update.
For each subscription request in the "poll" list start pushing presence
updates to the remote server.
For unsubscribe request in the "unpoll" list stop pushing presence
updates to the remote server.
Args:
orgin(str): The source of this m.presence EDU.
content(dict): The content of this m.presence EDU.
Returns:
A Deferred.
"""
deferreds = [] deferreds = []
for push in content.get("push", []): for push in content.get("push", []):
@ -684,8 +822,7 @@ class PresenceHandler(BaseHandler):
" | %d interested local observers %r", len(observers), observers " | %d interested local observers %r", len(observers), observers
) )
rm_handler = self.homeserver.get_handlers().room_member_handler room_ids = yield self.get_joined_rooms_for_user(user)
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if room_ids: if room_ids:
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids) logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
@ -704,20 +841,15 @@ class PresenceHandler(BaseHandler):
self.clock.time_msec() - state.pop("last_active_ago") self.clock.time_msec() - state.pop("last_active_ago")
) )
statuscache = self._get_or_make_usercache(user)
self._user_cachemap_latest_serial += 1 self._user_cachemap_latest_serial += 1
statuscache.update(state, serial=self._user_cachemap_latest_serial) yield self.update_presence_cache(user, state, room_ids=room_ids)
if not observers and not room_ids: if not observers and not room_ids:
logger.debug(" | no interested observers or room IDs") logger.debug(" | no interested observers or room IDs")
continue continue
self.push_update_to_clients( self.push_update_to_clients(
observed_user=user, users_to_push=observers, room_ids=room_ids
users_to_push=observers,
room_ids=room_ids,
statuscache=statuscache,
) )
user_id = user.to_string() user_id = user.to_string()
@ -766,13 +898,58 @@ class PresenceHandler(BaseHandler):
if not self._remote_sendmap[user]: if not self._remote_sendmap[user]:
del self._remote_sendmap[user] del self._remote_sendmap[user]
with PreserveLoggingContext(): yield defer.DeferredList(deferreds, consumeErrors=True)
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def update_presence_cache(self, user, state={}, room_ids=None,
add_to_cache=True):
"""Update the presence cache for a user with a new state and bump the
serial to the latest value.
Args:
user(UserID): The user being updated
state(dict): The presence state being updated
room_ids(None or list of str): A list of room_ids to update. If
room_ids is None then fetch the list of room_ids the user is
joined to.
add_to_cache: Whether to add an entry to the presence cache if the
user isn't already in the cache.
Returns:
A Deferred UserPresenceCache for the user being updated.
"""
if room_ids is None:
room_ids = yield self.get_joined_rooms_for_user(user)
for room_id in room_ids:
self._room_serials[room_id] = self._user_cachemap_latest_serial
if add_to_cache:
statuscache = self._get_or_make_usercache(user)
else:
statuscache = self._get_or_offline_usercache(user)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
defer.returnValue(statuscache)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache, def push_update_to_local_and_remote(self, observed_user, statuscache,
users_to_push=[], room_ids=[], users_to_push=[], room_ids=[],
remote_domains=[]): remote_domains=[]):
"""Notify local clients and remote servers of a change in the presence
of a user.
Args:
observed_user(UserID): The user to push the presence state for.
statuscache(UserPresenceCache): The cache for the presence state to
push.
users_to_push([UserID]): A list of local and remote users to
notify.
room_ids([str]): Notify the local and remote occupants of these
rooms.
remote_domains([str]): A list of remote servers to notify in
addition to those implied by the users_to_push and the
room_ids.
Returns:
A Deferred.
"""
localusers, remoteusers = partitionbool( localusers, remoteusers = partitionbool(
users_to_push, users_to_push,
@ -782,10 +959,7 @@ class PresenceHandler(BaseHandler):
localusers = set(localusers) localusers = set(localusers)
self.push_update_to_clients( self.push_update_to_clients(
observed_user=observed_user, users_to_push=localusers, room_ids=room_ids
users_to_push=localusers,
room_ids=room_ids,
statuscache=statuscache,
) )
remote_domains = set(remote_domains) remote_domains = set(remote_domains)
@ -810,11 +984,65 @@ class PresenceHandler(BaseHandler):
defer.returnValue((localusers, remote_domains)) defer.returnValue((localusers, remote_domains))
def push_update_to_clients(self, observed_user, users_to_push=[], def push_update_to_clients(self, users_to_push=[], room_ids=[]):
room_ids=[], statuscache=None): """Notify clients of a new presence event.
self.notifier.on_new_user_event(
users_to_push, Args:
room_ids, users_to_push([UserID]): List of users to notify.
room_ids([str]): List of room_ids to notify.
"""
with PreserveLoggingContext():
self.notifier.on_new_user_event(
"presence_key",
self._user_cachemap_latest_serial,
users_to_push,
room_ids,
)
@defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None):
"""Push a user's presence to a remote server. If a presence state event
that event is sent. Otherwise a new state event is constructed from the
stored presence state.
The last_active is replaced with last_active_ago in case the wallclock
time on the remote server is different to the time on this server.
Sends an EDU to the remote server with the current presence state.
Args:
user(UserID): The user to push the presence state for.
destination(str): The remote server to send state to.
state(dict): The state to push, or None to use the current stored
state.
Returns:
A Deferred.
"""
if state is None:
state = yield self.store.get_presence_state(user.localpart)
del state["mtime"]
state["presence"] = state.pop("state")
if user in self._user_cachemap:
state["last_active"] = (
self._user_cachemap[user].get_state()["last_active"]
)
yield self.distributor.fire(
"collect_presencelike_data", user, state
)
if "last_active" in state:
state = dict(state)
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
user_state = {"user_id": user.to_string(), }
user_state.update(state)
yield self.federation.send_edu(
destination=destination,
edu_type="m.presence",
content={"push": [user_state, ], }
) )
@ -823,39 +1051,11 @@ class PresenceEventSource(object):
self.hs = hs self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
@defer.inlineCallbacks
def is_visible(self, observer_user, observed_user):
if observer_user == observed_user:
defer.returnValue(True)
presence = self.hs.get_handlers().presence_handler
if (yield presence.store.user_rooms_intersect(
[u.to_string() for u in observer_user, observed_user])):
defer.returnValue(True)
if self.hs.is_mine(observed_user):
pushmap = presence._local_pushmap
defer.returnValue(
observed_user.localpart in pushmap and
observer_user in pushmap[observed_user.localpart]
)
else:
recvmap = presence._remote_recvmap
defer.returnValue(
observed_user in recvmap and
observer_user in recvmap[observed_user]
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_new_events_for_user(self, user, from_key, limit): def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key) from_key = int(from_key)
observer_user = user
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap cachemap = presence._user_cachemap
@ -864,17 +1064,27 @@ class PresenceEventSource(object):
clock = self.clock clock = self.clock
latest_serial = 0 latest_serial = 0
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] > from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = [] updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency. for observed_user in user_ids_to_check & set(cachemap):
for observed_user in cachemap.keys():
cached = cachemap[observed_user] cached = cachemap[observed_user]
if cached.serial <= from_key or cached.serial > max_serial: if cached.serial <= from_key or cached.serial > max_serial:
continue continue
if not (yield self.is_visible(observer_user, observed_user)):
continue
latest_serial = max(cached.serial, latest_serial) latest_serial = max(cached.serial, latest_serial)
updates.append(cached.make_event(user=observed_user, clock=clock)) updates.append(cached.make_event(user=observed_user, clock=clock))
@ -911,8 +1121,6 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key): def get_pagination_rows(self, user, pagination_config, key):
# TODO (erikj): Does this make sense? Ordering? # TODO (erikj): Does this make sense? Ordering?
observer_user = user
from_key = int(pagination_config.from_key) from_key = int(pagination_config.from_key)
if pagination_config.to_key: if pagination_config.to_key:
@ -923,14 +1131,26 @@ class PresenceEventSource(object):
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap cachemap = presence._user_cachemap
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] >= from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = [] updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency. for observed_user in user_ids_to_check & set(cachemap):
for observed_user in cachemap.keys():
if not (to_key < cachemap[observed_user].serial <= from_key): if not (to_key < cachemap[observed_user].serial <= from_key):
continue continue
if (yield self.is_visible(observer_user, observed_user)): updates.append((observed_user, cachemap[observed_user]))
updates.append((observed_user, cachemap[observed_user]))
# TODO(paul): limit # TODO(paul): limit

View file

@ -17,8 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID from synapse.types import UserID
from synapse.util import unwrapFirstError
from ._base import BaseHandler from ._base import BaseHandler
@ -88,6 +88,9 @@ class ProfileHandler(BaseHandler):
if target_user != auth_user: if target_user != auth_user:
raise AuthError(400, "Cannot set another user's displayname") raise AuthError(400, "Cannot set another user's displayname")
if new_displayname == '':
new_displayname = None
yield self.store.set_profile_displayname( yield self.store.set_profile_displayname(
target_user.localpart, new_displayname target_user.localpart, new_displayname
) )
@ -154,14 +157,13 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
defer.returnValue(None) defer.returnValue(None)
with PreserveLoggingContext(): (displayname, avatar_url) = yield defer.gatherResults(
(displayname, avatar_url) = yield defer.gatherResults( [
[ self.store.get_profile_displayname(user.localpart),
self.store.get_profile_displayname(user.localpart), self.store.get_profile_avatar_url(user.localpart),
self.store.get_profile_avatar_url(user.localpart), ],
], consumeErrors=True
consumeErrors=True ).addErrback(unwrapFirstError)
)
state["displayname"] = displayname state["displayname"] = displayname
state["avatar_url"] = avatar_url state["avatar_url"] = avatar_url

View file

@ -21,11 +21,12 @@ from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
import logging import logging
import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -50,6 +51,10 @@ class RoomCreationHandler(BaseHandler):
self.ratelimit(user_id) self.ratelimit(user_id)
if "room_alias_name" in config: if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias.create( room_alias = RoomAlias.create(
config["room_alias_name"], config["room_alias_name"],
self.hs.hostname, self.hs.hostname,
@ -535,7 +540,7 @@ class RoomListHandler(BaseHandler):
for room in chunk for room in chunk
], ],
consumeErrors=True, consumeErrors=True,
) ).addErrback(unwrapFirstError)
for i, room in enumerate(chunk): for i, room in enumerate(chunk):
room["num_joined_members"] = len(results[i]) room["num_joined_members"] = len(results[i])
@ -575,8 +580,8 @@ class RoomEventSource(object):
defer.returnValue((events, end_key)) defer.returnValue((events, end_key))
def get_current_key(self): def get_current_key(self, direction='f'):
return self.store.get_room_events_max_id() return self.store.get_room_events_max_id(direction)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_pagination_rows(self, user, config, key): def get_pagination_rows(self, user, config, key):

View file

@ -92,7 +92,7 @@ class SyncHandler(BaseHandler):
result = yield self.current_sync_for_user(sync_config, since_token) result = yield self.current_sync_for_user(sync_config, since_token)
defer.returnValue(result) defer.returnValue(result)
else: else:
def current_sync_callback(): def current_sync_callback(before_token, after_token):
return self.current_sync_for_user(sync_config, since_token) return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler

View file

@ -18,6 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID from synapse.types import UserID
import logging import logging
@ -216,7 +217,10 @@ class TypingNotificationHandler(BaseHandler):
self._latest_room_serial += 1 self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial self._room_serials[room_id] = self._latest_room_serial
self.notifier.on_new_user_event(rooms=[room_id]) with PreserveLoggingContext():
self.notifier.on_new_user_event(
"typing_key", self._latest_room_serial, rooms=[room_id]
)
class TypingNotificationEventSource(object): class TypingNotificationEventSource(object):

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.util.logcontext import preserve_context_over_fn
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
import synapse.metrics import synapse.metrics
@ -61,7 +62,10 @@ class SimpleHttpClient(object):
# A small wrapper around self.agent.request() so we can easily attach # A small wrapper around self.agent.request() so we can easily attach
# counters to it # counters to it
outgoing_requests_counter.inc(method) outgoing_requests_counter.inc(method)
d = self.agent.request(method, *args, **kwargs) d = preserve_context_over_fn(
self.agent.request,
method, *args, **kwargs
)
def _cb(response): def _cb(response):
incoming_responses_counter.inc(method, response.code) incoming_responses_counter.inc(method, response.code)

View file

@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import preserve_context_over_fn
import synapse.metrics import synapse.metrics
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
@ -110,7 +110,8 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes, def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"", body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True): query_bytes=b"", retry_on_dns_fail=True,
timeout=None):
""" Creates and sends a request to the given url """ Creates and sends a request to the given url
""" """
headers_dict[b"User-Agent"] = [self.version_string] headers_dict[b"User-Agent"] = [self.version_string]
@ -144,22 +145,22 @@ class MatrixFederationHttpClient(object):
producer = body_callback(method, url_bytes, headers_dict) producer = body_callback(method, url_bytes, headers_dict)
try: try:
with PreserveLoggingContext(): request_deferred = preserve_context_over_fn(
request_deferred = self.agent.request( self.agent.request,
destination, destination,
endpoint, endpoint,
method, method,
path_bytes, path_bytes,
param_bytes, param_bytes,
query_bytes, query_bytes,
Headers(headers_dict), Headers(headers_dict),
producer producer
) )
response = yield self.clock.time_bound_deferred( response = yield self.clock.time_bound_deferred(
request_deferred, request_deferred,
time_out=60, time_out=timeout/1000. if timeout else 60,
) )
logger.debug("Got response to %s", method) logger.debug("Got response to %s", method)
break break
@ -181,7 +182,7 @@ class MatrixFederationHttpClient(object):
_flatten_response_never_received(e), _flatten_response_never_received(e),
) )
if retries_left: if retries_left and not timeout:
yield sleep(2 ** (5 - retries_left)) yield sleep(2 ** (5 - retries_left))
retries_left -= 1 retries_left -= 1
else: else:
@ -334,7 +335,8 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True): def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
timeout=None):
""" GETs some json from the given host homeserver and path """ GETs some json from the given host homeserver and path
Args: Args:
@ -343,6 +345,9 @@ class MatrixFederationHttpClient(object):
path (str): The HTTP path. path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to args (dict): A dictionary used to create query strings, defaults to
None. None.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will
be retried.
Returns: Returns:
Deferred: Succeeds when we get *any* HTTP response. Deferred: Succeeds when we get *any* HTTP response.
@ -370,7 +375,8 @@ class MatrixFederationHttpClient(object):
path.encode("ascii"), path.encode("ascii"),
query_bytes=query_bytes, query_bytes=query_bytes,
body_callback=body_callback, body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:

View file

@ -17,7 +17,7 @@
from synapse.api.errors import ( from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
) )
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
import synapse.metrics import synapse.metrics
from syutil.jsonutil import ( from syutil.jsonutil import (
@ -85,7 +85,9 @@ def request_handler(request_handler):
"Received request: %s %s", "Received request: %s %s",
request.method, request.path request.method, request.path
) )
yield request_handler(self, request) d = request_handler(self, request)
with PreserveLoggingContext():
yield d
code = request.code code = request.code
except CodeMessageException as e: except CodeMessageException as e:
code = e.code code = e.code

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.async import run_on_reactor
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
@ -42,63 +42,78 @@ def count(func, l):
class _NotificationListener(object): class _NotificationListener(object):
""" This represents a single client connection to the events stream. """ This represents a single client connection to the events stream.
The events stream handler will have yielded to the deferred, so to The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred. notify the handler it is sufficient to resolve the deferred.
"""
def __init__(self, deferred):
self.deferred = deferred
def notified(self):
return self.deferred.called
def notify(self, token):
""" Inform whoever is listening about the new events.
"""
try:
self.deferred.callback(token)
except defer.AlreadyCalledError:
pass
class _NotifierUserStream(object):
"""This represents a user connected to the event stream.
It tracks the most recent stream token for that user.
At a given point a user may have a number of streams listening for
events.
This listener will also keep track of which rooms it is listening in This listener will also keep track of which rooms it is listening in
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user, rooms, from_token, limit, timeout, deferred, def __init__(self, user, rooms, current_token, time_now_ms,
appservice=None): appservice=None):
self.user = user self.user = str(user)
self.appservice = appservice self.appservice = appservice
self.from_token = from_token self.listeners = set()
self.limit = limit self.rooms = set(rooms)
self.timeout = timeout self.current_token = current_token
self.deferred = deferred self.last_notified_ms = time_now_ms
self.rooms = rooms
self.timer = None
def notified(self): def notify(self, stream_key, stream_id, time_now_ms):
return self.deferred.called """Notify any listeners for this user of a new event from an
event source.
Args:
stream_key(str): The stream the event came from.
stream_id(str): The new id for the stream the event came from.
time_now_ms(int): The current time in milliseconds.
"""
self.current_token = self.current_token.copy_and_advance(
stream_key, stream_id
)
if self.listeners:
self.last_notified_ms = time_now_ms
listeners = self.listeners
self.listeners = set()
for listener in listeners:
listener.notify(self.current_token)
def notify(self, notifier, events, start_token, end_token): def remove(self, notifier):
""" Inform whoever is listening about the new events. This will """ Remove this listener from all the indexes in the Notifier
also remove this listener from all the indexes in the Notifier
it knows about. it knows about.
""" """
result = (events, (start_token, end_token))
try:
self.deferred.callback(result)
notified_events_counter.inc_by(len(events))
except defer.AlreadyCalledError:
pass
# Should the following be done be using intrusively linked lists?
# -- erikj
for room in self.rooms: for room in self.rooms:
lst = notifier.room_to_listeners.get(room, set()) lst = notifier.room_to_user_streams.get(room, set())
lst.discard(self) lst.discard(self)
notifier.user_to_listeners.get(self.user, set()).discard(self) notifier.user_to_user_stream.pop(self.user)
if self.appservice: if self.appservice:
notifier.appservice_to_listeners.get( notifier.appservice_to_user_streams.get(
self.appservice, set() self.appservice, set()
).discard(self) ).discard(self)
# Cancel the timeout for this notifer if one exists.
if self.timer is not None:
try:
notifier.clock.cancel_call_later(self.timer)
except:
logger.warn("Failed to cancel notifier timer")
class Notifier(object): class Notifier(object):
""" This class is responsible for notifying any listeners when there are """ This class is responsible for notifying any listeners when there are
@ -107,14 +122,18 @@ class Notifier(object):
Primarily used from the /events stream. Primarily used from the /events stream.
""" """
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.room_to_listeners = {} self.user_to_user_stream = {}
self.user_to_listeners = {} self.room_to_user_streams = {}
self.appservice_to_listeners = {} self.appservice_to_user_streams = {}
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -122,45 +141,80 @@ class Notifier(object):
"user_joined_room", self._user_joined_room "user_joined_room", self._user_joined_room
) )
self.clock.looping_call(
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
)
# This is not a very cheap test to perform, but it's only executed # This is not a very cheap test to perform, but it's only executed
# when rendering the metrics page, which is likely once per minute at # when rendering the metrics page, which is likely once per minute at
# most when scraping it. # most when scraping it.
def count_listeners(): def count_listeners():
all_listeners = set() all_user_streams = set()
for x in self.room_to_listeners.values(): for x in self.room_to_user_streams.values():
all_listeners |= x all_user_streams |= x
for x in self.user_to_listeners.values(): for x in self.user_to_user_stream.values():
all_listeners |= x all_user_streams.add(x)
for x in self.appservice_to_listeners.values(): for x in self.appservice_to_user_streams.values():
all_listeners |= x all_user_streams |= x
return len(all_listeners) return sum(len(stream.listeners) for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners) metrics.register_callback("listeners", count_listeners)
metrics.register_callback( metrics.register_callback(
"rooms", "rooms",
lambda: count(bool, self.room_to_listeners.values()), lambda: count(bool, self.room_to_user_streams.values()),
) )
metrics.register_callback( metrics.register_callback(
"users", "users",
lambda: count(bool, self.user_to_listeners.values()), lambda: len(self.user_to_user_stream),
) )
metrics.register_callback( metrics.register_callback(
"appservices", "appservices",
lambda: count(bool, self.appservice_to_listeners.values()), lambda: count(bool, self.appservice_to_user_streams.values()),
) )
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_new_room_event(self, event, extra_users=[]): def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
in the room, room event wise. in the room, room event wise.
This triggers the notifier to wake up any listeners that are This triggers the notifier to wake up any listeners that are
listening to the room, and any listeners for the users in the listening to the room, and any listeners for the users in the
`extra_users` param. `extra_users` param.
The events can be peristed out of order. The notifier will wait
until all previous events have been persisted before notifying
the client streams.
""" """
yield run_on_reactor()
self.pending_new_room_events.append((
room_stream_id, event, extra_users
))
self._notify_pending_new_room_events(max_room_stream_id)
def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
Args:
max_room_stream_id(int): The highest stream_id below which all
events have been persisted.
"""
pending = self.pending_new_room_events
self.pending_new_room_events = []
for room_stream_id, event, extra_users in pending:
if room_stream_id > max_room_stream_id:
self.pending_new_room_events.append((
room_stream_id, event, extra_users
))
else:
self._on_new_room_event(event, room_stream_id, extra_users)
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
self.hs.get_handlers().appservice_handler.notify_interested_services( self.hs.get_handlers().appservice_handler.notify_interested_services(
event event
@ -168,192 +222,129 @@ class Notifier(object):
room_id = event.room_id room_id = event.room_id
room_source = self.event_sources.sources["room"] room_user_streams = self.room_to_user_streams.get(room_id, set())
room_listeners = self.room_to_listeners.get(room_id, set()) user_streams = room_user_streams.copy()
_discard_if_notified(room_listeners)
listeners = room_listeners.copy()
for user in extra_users: for user in extra_users:
user_listeners = self.user_to_listeners.get(user, set()) user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
_discard_if_notified(user_listeners) for appservice in self.appservice_to_user_streams:
listeners |= user_listeners
for appservice in self.appservice_to_listeners:
# TODO (kegan): Redundant appservice listener checks? # TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_listeners set, but # App services will already be in the room_to_user_streams set, but
# that isn't enough. They need to be checked here in order to # that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this # receive *invites* for users they are interested in. Does this
# make the room_to_listeners check somewhat obselete? # make the room_to_user_streams check somewhat obselete?
if appservice.is_interested(event): if appservice.is_interested(event):
app_listeners = self.appservice_to_listeners.get( app_user_streams = self.appservice_to_user_streams.get(
appservice, set() appservice, set()
) )
user_streams |= app_user_streams
_discard_if_notified(app_listeners) logger.debug("on_new_room_event listeners %s", user_streams)
listeners |= app_listeners time_now_ms = self.clock.time_msec()
for user_stream in user_streams:
logger.debug("on_new_room_event listeners %s", listeners) try:
user_stream.notify(
# TODO (erikj): Can we make this more efficient by hitting the "room_key", "s%d" % (room_stream_id,), time_now_ms
# db once?
@defer.inlineCallbacks
def notify(listener):
events, end_key = yield room_source.get_new_events_for_user(
listener.user,
listener.from_token.room_key,
listener.limit,
)
if events:
end_token = listener.from_token.copy_and_replace(
"room_key", end_key
) )
except:
listener.notify( logger.exception("Failed to notify listener")
self, events, listener.from_token, end_token
)
def eb(failure):
logger.exception("Failed to notify listener", failure)
with PreserveLoggingContext():
yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_new_user_event(self, users=[], rooms=[]): def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend """ Used to inform listeners that something has happend
presence/user event wise. presence/user event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
# TODO(paul): This is horrible, having to manually list every event yield run_on_reactor()
# source here individually user_streams = set()
presence_source = self.event_sources.sources["presence"]
typing_source = self.event_sources.sources["typing"]
listeners = set()
for user in users: for user in users:
user_listeners = self.user_to_listeners.get(user, set()) user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
_discard_if_notified(user_listeners) user_streams.add(user_stream)
listeners |= user_listeners
for room in rooms: for room in rooms:
room_listeners = self.room_to_listeners.get(room, set()) user_streams |= self.room_to_user_streams.get(room, set())
_discard_if_notified(room_listeners) time_now_ms = self.clock.time_msec()
for user_stream in user_streams:
listeners |= room_listeners try:
user_stream.notify(stream_key, new_token, time_now_ms)
@defer.inlineCallbacks except:
def notify(listener): logger.exception("Failed to notify listener")
presence_events, presence_end_key = (
yield presence_source.get_new_events_for_user(
listener.user,
listener.from_token.presence_key,
listener.limit,
)
)
typing_events, typing_end_key = (
yield typing_source.get_new_events_for_user(
listener.user,
listener.from_token.typing_key,
listener.limit,
)
)
if presence_events or typing_events:
end_token = listener.from_token.copy_and_replace(
"presence_key", presence_end_key
).copy_and_replace(
"typing_key", typing_end_key
)
listener.notify(
self,
presence_events + typing_events,
listener.from_token,
end_token
)
def eb(failure):
logger.error(
"Failed to notify listener",
exc_info=(
failure.type,
failure.value,
failure.getTracebackObject())
)
with PreserveLoggingContext():
yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, rooms, filter, timeout, callback): def wait_for_events(self, user, rooms, timeout, callback,
from_token=StreamToken("s0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
deferred = defer.Deferred() deferred = defer.Deferred()
time_now_ms = self.clock.time_msec()
from_token = StreamToken("s0", "0", "0") user = str(user)
user_stream = self.user_to_user_stream.get(user)
if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user)
current_token = yield self.event_sources.get_current_token()
rooms = yield self.store.get_rooms_for_user(user)
rooms = [room.room_id for room in rooms]
user_stream = _NotifierUserStream(
user=user,
rooms=rooms,
appservice=appservice,
current_token=current_token,
time_now_ms=time_now_ms,
)
self._register_with_keys(user_stream)
else:
current_token = user_stream.current_token
listener = [_NotificationListener( listener = [_NotificationListener(deferred)]
user=user,
rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred,
)]
if timeout: if timeout and not current_token.is_after(from_token):
self._register_with_keys(listener[0]) user_stream.listeners.add(listener[0])
if current_token.is_after(from_token):
result = yield callback(from_token, current_token)
else:
result = None
result = yield callback()
timer = [None] timer = [None]
if result:
user_stream.listeners.discard(listener[0])
defer.returnValue(result)
return
if timeout: if timeout:
timed_out = [False] timed_out = [False]
def _timeout_listener(): def _timeout_listener():
timed_out[0] = True timed_out[0] = True
timer[0] = None timer[0] = None
listener[0].notify(self, [], from_token, from_token) user_stream.listeners.discard(listener[0])
listener[0].notify(current_token)
# We create multiple notification listeners so we have to manage # We create multiple notification listeners so we have to manage
# canceling the timeout ourselves. # canceling the timeout ourselves.
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener) timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
while not result and not timed_out[0]: while not result and not timed_out[0]:
yield deferred new_token = yield deferred
deferred = defer.Deferred() deferred = defer.Deferred()
listener[0] = _NotificationListener( listener[0] = _NotificationListener(deferred)
user=user, user_stream.listeners.add(listener[0])
rooms=rooms, result = yield callback(current_token, new_token)
from_token=from_token, current_token = new_token
limit=1,
timeout=timeout,
deferred=deferred,
)
self._register_with_keys(listener[0])
result = yield callback()
if timer[0] is not None: if timer[0] is not None:
try: try:
@ -363,125 +354,79 @@ class Notifier(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def get_events_for(self, user, rooms, pagination_config, timeout): def get_events_for(self, user, rooms, pagination_config, timeout):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning. new events to happen before returning.
""" """
deferred = defer.Deferred() from_token = pagination_config.from_token
self._get_events(
deferred, user, rooms, pagination_config.from_token,
pagination_config.limit, timeout
).addErrback(deferred.errback)
return deferred
@defer.inlineCallbacks
def _get_events(self, deferred, user, rooms, from_token, limit, timeout):
if not from_token: if not from_token:
from_token = yield self.event_sources.get_current_token() from_token = yield self.event_sources.get_current_token()
appservice = yield self.hs.get_datastore().get_app_service_by_user_id( limit = pagination_config.limit
user.to_string()
)
listener = _NotificationListener( @defer.inlineCallbacks
user, def check_for_updates(before_token, after_token):
rooms, events = []
from_token, end_token = from_token
limit, for name, source in self.event_sources.sources.items():
timeout, keyname = "%s_key" % name
deferred, before_id = getattr(before_token, keyname)
appservice=appservice after_id = getattr(after_token, keyname)
) if before_id == after_id:
continue
def _timeout_listener(): stuff, new_key = yield source.get_new_events_for_user(
# TODO (erikj): We should probably set to_token to the current user, getattr(from_token, keyname), limit,
# max rather than reusing from_token.
# Remove the timer from the listener so we don't try to cancel it.
listener.timer = None
listener.notify(
self,
[],
listener.from_token,
listener.from_token,
)
if timeout:
self._register_with_keys(listener)
yield self._check_for_updates(listener)
if not timeout:
_timeout_listener()
else:
# Only add the timer if the listener hasn't been notified
if not listener.notified():
listener.timer = self.clock.call_later(
timeout/1000.0, _timeout_listener
) )
return events.extend(stuff)
end_token = end_token.copy_and_replace(keyname, new_key)
if events:
defer.returnValue((events, (from_token, end_token)))
else:
defer.returnValue(None)
result = yield self.wait_for_events(
user, rooms, timeout, check_for_updates, from_token=from_token
)
if result is None:
result = ([], (from_token, from_token))
defer.returnValue(result)
@log_function @log_function
def _register_with_keys(self, listener): def remove_expired_streams(self):
for room in listener.rooms: time_now_ms = self.clock.time_msec()
s = self.room_to_listeners.setdefault(room, set()) expired_streams = []
s.add(listener) expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
for stream in self.user_to_user_stream.values():
if stream.listeners:
continue
if stream.last_notified_ms < expire_before_ts:
expired_streams.append(stream)
self.user_to_listeners.setdefault(listener.user, set()).add(listener) for expired_stream in expired_streams:
expired_stream.remove(self)
if listener.appservice:
self.appservice_to_listeners.setdefault(
listener.appservice, set()
).add(listener)
@defer.inlineCallbacks
@log_function @log_function
def _check_for_updates(self, listener): def _register_with_keys(self, user_stream):
# TODO (erikj): We need to think about limits across multiple sources self.user_to_user_stream[user_stream.user] = user_stream
events = []
from_token = listener.from_token for room in user_stream.rooms:
limit = listener.limit s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)
# TODO (erikj): DeferredList? if user_stream.appservice:
for name, source in self.event_sources.sources.items(): self.appservice_to_user_stream.setdefault(
keyname = "%s_key" % name user_stream.appservice, set()
).add(user_stream)
stuff, new_key = yield source.get_new_events_for_user(
listener.user,
getattr(from_token, keyname),
limit,
)
events.extend(stuff)
from_token = from_token.copy_and_replace(keyname, new_key)
end_token = from_token
if events:
listener.notify(self, events, listener.from_token, end_token)
defer.returnValue(listener)
def _user_joined_room(self, user, room_id): def _user_joined_room(self, user, room_id):
new_listeners = self.user_to_listeners.get(user, set()) user = str(user)
new_user_stream = self.user_to_user_stream.get(user)
listeners = self.room_to_listeners.setdefault(room_id, set()) if new_user_stream is not None:
listeners |= new_listeners room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream)
for l in new_listeners: new_user_stream.rooms.add(room_id)
l.rooms.add(room_id)
def _discard_if_notified(listener_set):
"""Remove any 'stale' listeners from the given set.
"""
to_discard = set()
for l in listener_set:
if l.notified():
to_discard.add(l)
listener_set -= to_discard

View file

@ -74,15 +74,18 @@ class Pusher(object):
rawrules = yield self.store.get_push_rules_for_user(self.user_name) rawrules = yield self.store.get_push_rules_for_user(self.user_name)
for r in rawrules: rules = []
r['conditions'] = json.loads(r['conditions']) for rawrule in rawrules:
r['actions'] = json.loads(r['actions']) rule = dict(rawrule)
rule['conditions'] = json.loads(rawrule['conditions'])
rule['actions'] = json.loads(rawrule['actions'])
rules.append(rule)
enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name) enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name)
user = UserID.from_string(self.user_name) user = UserID.from_string(self.user_name)
rules = baserules.list_with_base_rules(rawrules, user) rules = baserules.list_with_base_rules(rules, user)
room_id = ev['room_id'] room_id = ev['room_id']

View file

@ -118,11 +118,14 @@ class PushRuleRestServlet(ClientV1RestServlet):
user.to_string() user.to_string()
) )
for r in rawrules: ruleslist = []
r["conditions"] = json.loads(r["conditions"]) for rawrule in rawrules:
r["actions"] = json.loads(r["actions"]) rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
ruleslist = baserules.list_with_base_rules(rawrules, user) ruleslist = baserules.list_with_base_rules(ruleslist, user)
rules = {'global': {}, 'device': {}} rules = {'global': {}, 'device': {}}

View file

@ -82,8 +82,10 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
] ]
result = None
if service: if service:
is_application_server = True is_application_server = True
params = body
elif 'mac' in body: elif 'mac' in body:
# Check registration-specific shared secret auth # Check registration-specific shared secret auth
if 'username' not in body: if 'username' not in body:
@ -92,6 +94,7 @@ class RegisterRestServlet(RestServlet):
body['username'], body['mac'] body['username'], body['mac']
) )
is_using_shared_secret = True is_using_shared_secret = True
params = body
else: else:
authed, result, params = yield self.auth_handler.check_auth( authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request) flows, body, self.hs.get_ip_from_request(request)
@ -118,7 +121,7 @@ class RegisterRestServlet(RestServlet):
password=new_password password=new_password
) )
if LoginType.EMAIL_IDENTITY in result: if result and LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY] threepid = result[LoginType.EMAIL_IDENTITY]
for reqd in ['medium', 'address', 'validated_at']: for reqd in ['medium', 'address', 'validated_at']:

View file

@ -25,7 +25,7 @@ from twisted.internet import defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.util.async import create_observer from synapse.util.async import ObservableDeferred
import os import os
@ -83,13 +83,17 @@ class BaseMediaResource(Resource):
download = self.downloads.get(key) download = self.downloads.get(key)
if download is None: if download is None:
download = self._get_remote_media_impl(server_name, media_id) download = self._get_remote_media_impl(server_name, media_id)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.downloads[key] = download self.downloads[key] = download
@download.addBoth @download.addBoth
def callback(media_info): def callback(media_info):
del self.downloads[key] del self.downloads[key]
return media_info return media_info
return create_observer(download) return download.observe()
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id): def _get_remote_media_impl(self, server_name, media_id):

View file

@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 18 SCHEMA_VERSION = 19
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -15,10 +15,8 @@
import logging import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.lrucache import LruCache
import synapse.metrics import synapse.metrics
@ -27,8 +25,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
import functools import functools
import simplejson as json
import sys import sys
import time import time
import threading import threading
@ -48,7 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
caches_by_name = {} caches_by_name = {}
cache_counter = metrics.register_cache( cache_counter = metrics.register_cache(
@ -307,6 +304,12 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
self._pending_ds = []
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator() self._stream_id_gen = StreamIdGenerator()
@ -315,6 +318,7 @@ class SQLBaseStore(object):
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
def start_profiling(self): def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec() self._previous_loop_ts = self._clock.time_msec()
@ -345,6 +349,75 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000) self._clock.looping_call(loop, 10000)
def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
start = time.time() * 1000
txn_id = self._TXN_ID
# We don't really need these to be unique, so lets stop it from
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
name = "%s-%x" % (desc, txn_id, )
transaction_logger.debug("[TXN START] {%s}", name)
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks
)
r = func(txn, *args, **kwargs)
conn.commit()
return r
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warn(
"[TXN OPERROR] {%s} %s %d/%d",
name, e, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
duration = end - start
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc)
@defer.inlineCallbacks @defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs): def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool.""" """Wraps the .runInteraction() method on the underlying db_pool."""
@ -356,84 +429,52 @@ class SQLBaseStore(object):
def inner_func(conn, *args, **kwargs): def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context: with LoggingContext("runInteraction") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
if self.database_engine.is_connection_closed(conn): if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection") logger.debug("Reconnecting closed database connection")
conn.reconnect() conn.reconnect()
current_context.copy_to(context) current_context.copy_to(context)
start = time.time() * 1000 return self._new_transaction(
txn_id = self._TXN_ID conn, desc, after_callbacks, func, *args, **kwargs
)
# We don't really need these to be unique, so lets stop it from result = yield preserve_context_over_fn(
# growing really large. self._db_pool.runWithConnection,
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) inner_func, *args, **kwargs
)
name = "%s-%x" % (desc, txn_id, )
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks
)
return func(txn, *args, **kwargs)
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warn(
"[TXN OPERROR] {%s} %s %d/%d",
name, e, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
duration = end - start
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc)
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
for after_callback, after_args in after_callbacks: for after_callback, after_args in after_callbacks:
after_callback(*after_args) after_callback(*after_args)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
current_context.copy_to(context)
return func(conn, *args, **kwargs)
result = yield preserve_context_over_fn(
self._db_pool.runWithConnection,
inner_func, *args, **kwargs
)
defer.returnValue(result)
def cursor_to_dict(self, cursor): def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts. """Converts a SQL cursor into an list of dicts.
@ -871,158 +912,6 @@ class SQLBaseStore(object):
return self.runInteraction("_simple_max_id", func) return self.runInteraction("_simple_max_id", func)
def _get_events(self, event_ids, check_redacted=True,
get_prev_content=False):
return self.runInteraction(
"_get_events", self._get_events_txn, event_ids,
check_redacted=check_redacted, get_prev_content=get_prev_content,
)
def _get_events_txn(self, txn, event_ids, check_redacted=True,
get_prev_content=False):
if not event_ids:
return []
events = [
self._get_event_txn(
txn, event_id,
check_redacted=check_redacted,
get_prev_content=get_prev_content
)
for event_id in event_ids
]
return [e for e in events if e]
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted,
get_prev_content)
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
start_time = time.time() * 1000
def update_counter(desc, last_time):
curr_time = self._get_event_counters.update(desc, last_time)
sql_getevents_timer.inc_by(curr_time - last_time, desc)
return curr_time
try:
ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
if allow_rejected or not ret.rejected_reason:
return ret
else:
return None
except KeyError:
pass
finally:
start_time = update_counter("event_cache", start_time)
sql = (
"SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
"FROM event_json as e "
"LEFT JOIN redactions as r ON e.event_id = r.redacts "
"LEFT JOIN rejections as rej on rej.event_id = e.event_id "
"WHERE e.event_id = ? "
"LIMIT 1 "
)
txn.execute(sql, (event_id,))
res = txn.fetchone()
if not res:
return None
internal_metadata, js, redacted, rejected_reason = res
start_time = update_counter("select_event", start_time)
result = self._get_event_from_row_txn(
txn, internal_metadata, js, redacted,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
rejected_reason=rejected_reason,
)
self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
if allow_rejected or not rejected_reason:
return result
else:
return None
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False,
rejected_reason=None):
start_time = time.time() * 1000
def update_counter(desc, last_time):
curr_time = self._get_event_counters.update(desc, last_time)
sql_getevents_timer.inc_by(curr_time - last_time, desc)
return curr_time
d = json.loads(js)
start_time = update_counter("decode_json", start_time)
internal_metadata = json.loads(internal_metadata)
start_time = update_counter("decode_internal", start_time)
ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
start_time = update_counter("build_frozen_event", start_time)
if check_redacted and redacted:
ev = prune_event(ev)
ev.unsigned["redacted_by"] = redacted
# Get the redaction event.
because = self._get_event_txn(
txn,
redacted,
check_redacted=False
)
if because:
ev.unsigned["redacted_because"] = because
start_time = update_counter("redact_event", start_time)
if get_prev_content and "replaces_state" in ev.unsigned:
prev = self._get_event_txn(
txn,
ev.unsigned["replaces_state"],
get_prev_content=False,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
start_time = update_counter("get_prev_content", start_time)
return ev
def _parse_events(self, rows):
return self.runInteraction(
"_parse_events", self._parse_events_txn, rows
)
def _parse_events_txn(self, txn, rows):
event_ids = [r["event_id"] for r in rows]
return self._get_events_txn(txn, event_ids)
def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
txn.execute(sql, (event.event_id,))
result = txn.fetchone()
return result[0] if result else None
def get_next_stream_id(self): def get_next_stream_id(self):
with self._next_stream_id_lock: with self._next_stream_id_lock:
i = self._next_stream_id i = self._next_stream_id

View file

@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup
class PostgresEngine(object): class PostgresEngine(object):
single_threaded = False
def __init__(self, database_module): def __init__(self, database_module):
self.module = database_module self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE) self.module.extensions.register_type(self.module.extensions.UNICODE)

View file

@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database
class Sqlite3Engine(object): class Sqlite3Engine(object):
single_threaded = True
def __init__(self, database_module): def __init__(self, database_module):
self.module = database_module self.module = database_module

View file

@ -13,10 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cached
from syutil.base64util import encode_base64 from syutil.base64util import encode_base64
import logging import logging
from Queue import PriorityQueue, Empty
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,16 +36,7 @@ class EventFederationStore(SQLBaseStore):
""" """
def get_auth_chain(self, event_ids): def get_auth_chain(self, event_ids):
return self.runInteraction( return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
"get_auth_chain",
self._get_auth_chain_txn,
event_ids
)
def _get_auth_chain_txn(self, txn, event_ids):
results = self._get_auth_chain_ids_txn(txn, event_ids)
return self._get_events_txn(txn, results)
def get_auth_chain_ids(self, event_ids): def get_auth_chain_ids(self, event_ids):
return self.runInteraction( return self.runInteraction(
@ -79,6 +73,28 @@ class EventFederationStore(SQLBaseStore):
room_id, room_id,
) )
def get_oldest_events_with_depth_in_room(self, room_id):
return self.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
)
def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
sql = (
"SELECT b.event_id, MAX(e.depth) FROM events as e"
" INNER JOIN event_edges as g"
" ON g.event_id = e.event_id AND g.room_id = e.room_id"
" INNER JOIN event_backward_extremities as b"
" ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
" WHERE b.room_id = ? AND g.is_state is ?"
" GROUP BY b.event_id"
)
txn.execute(sql, (room_id, False,))
return dict(txn.fetchall())
def _get_oldest_events_in_room_txn(self, txn, room_id): def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn( return self._simple_select_onecol_txn(
txn, txn,
@ -247,11 +263,13 @@ class EventFederationStore(SQLBaseStore):
do_insert = depth < min_depth if min_depth else True do_insert = depth < min_depth if min_depth else True
if do_insert: if do_insert:
self._simple_insert_txn( self._simple_upsert_txn(
txn, txn,
table="room_depth", table="room_depth",
values={ keyvalues={
"room_id": room_id, "room_id": room_id,
},
values={
"min_depth": depth, "min_depth": depth,
}, },
) )
@ -306,31 +324,28 @@ class EventFederationStore(SQLBaseStore):
txn.execute(query, (event_id, room_id)) txn.execute(query, (event_id, room_id))
# Insert all the prev_events as a backwards thing, they'll get query = (
# deleted in a second if they're incorrect anyway. "INSERT INTO event_backward_extremities (event_id, room_id)"
self._simple_insert_many_txn( " SELECT ?, ? WHERE NOT EXISTS ("
txn, " SELECT 1 FROM event_backward_extremities"
table="event_backward_extremities", " WHERE event_id = ? AND room_id = ?"
values=[ " )"
{ " AND NOT EXISTS ("
"event_id": e_id, " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
"room_id": room_id, " AND outlier = ?"
} " )"
for e_id, _ in prev_events
],
) )
# Also delete from the backwards extremities table all ones that txn.executemany(query, [
# reference events that we have already seen (e_id, room_id, e_id, room_id, e_id, room_id, False)
for e_id, _ in prev_events
])
query = ( query = (
"DELETE FROM event_backward_extremities WHERE EXISTS (" "DELETE FROM event_backward_extremities"
"SELECT 1 FROM events " " WHERE event_id = ? AND room_id = ?"
"WHERE "
"event_backward_extremities.event_id = events.event_id "
"AND not events.outlier "
")"
) )
txn.execute(query) txn.execute(query, (event_id, room_id))
txn.call_after( txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id self.get_latest_event_ids_in_room.invalidate, room_id
@ -349,6 +364,10 @@ class EventFederationStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"get_backfill_events", "get_backfill_events",
self._get_backfill_events, room_id, event_list, limit self._get_backfill_events, room_id, event_list, limit
).addCallback(
self._get_events
).addCallback(
lambda l: sorted(l, key=lambda e: -e.depth)
) )
def _get_backfill_events(self, txn, room_id, event_list, limit): def _get_backfill_events(self, txn, room_id, event_list, limit):
@ -357,54 +376,75 @@ class EventFederationStore(SQLBaseStore):
room_id, repr(event_list), limit room_id, repr(event_list), limit
) )
event_results = event_list event_results = set()
front = event_list # We want to make sure that we do a breadth-first, "depth" ordered
# search.
query = ( query = (
"SELECT prev_event_id FROM event_edges " "SELECT depth, prev_event_id FROM event_edges"
"WHERE room_id = ? AND event_id = ? " " INNER JOIN events"
"LIMIT ?" " ON prev_event_id = events.event_id"
" AND event_edges.room_id = events.room_id"
" WHERE event_edges.room_id = ? AND event_edges.event_id = ?"
" AND event_edges.is_state = ?"
" LIMIT ?"
) )
# We iterate through all event_ids in `front` to select their previous queue = PriorityQueue()
# events. These are dumped in `new_front`.
# We continue until we reach the limit *or* new_front is empty (i.e.,
# we've run out of things to select
while front and len(event_results) < limit:
new_front = [] for event_id in event_list:
for event_id in front: depth = self._simple_select_one_onecol_txn(
logger.debug( txn,
"_backfill_interaction: id=%s", table="events",
event_id keyvalues={
) "event_id": event_id,
},
retcol="depth"
)
txn.execute( queue.put((-depth, event_id))
query,
(room_id, event_id, limit - len(event_results))
)
for row in txn.fetchall(): while not queue.empty() and len(event_results) < limit:
logger.debug( try:
"_backfill_interaction: got id=%s", _, event_id = queue.get_nowait()
*row except Empty:
) break
new_front.append(row[0])
front = new_front if event_id in event_results:
event_results += new_front continue
return self._get_events_txn(txn, event_results) event_results.add(event_id)
txn.execute(
query,
(room_id, event_id, False, limit - len(event_results))
)
for row in txn.fetchall():
if row[1] not in event_results:
queue.put((-row[0], row[1]))
return event_results
@defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events, def get_missing_events(self, room_id, earliest_events, latest_events,
limit, min_depth): limit, min_depth):
return self.runInteraction( ids = yield self.runInteraction(
"get_missing_events", "get_missing_events",
self._get_missing_events, self._get_missing_events,
room_id, earliest_events, latest_events, limit, min_depth room_id, earliest_events, latest_events, limit, min_depth
) )
events = yield self._get_events(ids)
events = sorted(
[ev for ev in events if ev.depth >= min_depth],
key=lambda e: e.depth,
)
defer.returnValue(events[:limit])
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
limit, min_depth): limit, min_depth):
@ -436,14 +476,7 @@ class EventFederationStore(SQLBaseStore):
front = new_front front = new_front
event_results |= new_front event_results |= new_front
events = self._get_events_txn(txn, event_results) return event_results
events = sorted(
[ev for ev in events if ev.depth >= min_depth],
key=lambda e: e.depth,
)
return events[:limit]
def clean_room_for_join(self, room_id): def clean_room_for_join(self, room_id):
return self.runInteraction( return self.runInteraction(
@ -456,3 +489,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?" query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,)) txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)

View file

@ -15,20 +15,36 @@
from _base import SQLBaseStore, _RollbackButIsFineException from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer from twisted.internet import defer, reactor
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64 from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
from contextlib import contextmanager
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# These values are used in the `enqueus_event` and `_do_fetch` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thing air to make initial sync run faster
# on jki.re
# TODO: Make these configurable.
EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
class EventsStore(SQLBaseStore): class EventsStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -41,20 +57,32 @@ class EventsStore(SQLBaseStore):
self.min_token -= 1 self.min_token -= 1
stream_ordering = self.min_token stream_ordering = self.min_token
if stream_ordering is None:
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
else:
@contextmanager
def stream_ordering_manager():
yield stream_ordering
stream_ordering_manager = stream_ordering_manager()
try: try:
yield self.runInteraction( with stream_ordering_manager as stream_ordering:
"persist_event", yield self.runInteraction(
self._persist_event_txn, "persist_event",
event=event, self._persist_event_txn,
context=context, event=event,
backfilled=backfilled, context=context,
stream_ordering=stream_ordering, backfilled=backfilled,
is_new_state=is_new_state, stream_ordering=stream_ordering,
current_state=current_state, is_new_state=is_new_state,
) current_state=current_state,
)
except _RollbackButIsFineException: except _RollbackButIsFineException:
pass pass
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True, def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False, get_prev_content=False, allow_rejected=False,
@ -74,18 +102,17 @@ class EventsStore(SQLBaseStore):
Returns: Returns:
Deferred : A FrozenEvent. Deferred : A FrozenEvent.
""" """
event = yield self.runInteraction( events = yield self._get_events(
"get_event", self._get_event_txn, [event_id],
event_id,
check_redacted=check_redacted, check_redacted=check_redacted,
get_prev_content=get_prev_content, get_prev_content=get_prev_content,
allow_rejected=allow_rejected, allow_rejected=allow_rejected,
) )
if not event and not allow_none: if not events and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,)) raise RuntimeError("Could not find event %s" % (event_id,))
defer.returnValue(event) defer.returnValue(events[0] if events else None)
@log_function @log_function
def _persist_event_txn(self, txn, event, context, backfilled, def _persist_event_txn(self, txn, event, context, backfilled,
@ -95,15 +122,6 @@ class EventsStore(SQLBaseStore):
# Remove the any existing cache entries for the event_id # Remove the any existing cache entries for the event_id
txn.call_after(self._invalidate_get_event_cache, event.event_id) txn.call_after(self._invalidate_get_event_cache, event.event_id)
if stream_ordering is None:
with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
return self._persist_event_txn(
txn, event, context, backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state,
current_state=current_state,
)
# We purposefully do this first since if we include a `current_state` # We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table # key, we *want* to update the `current_state_events` table
if current_state: if current_state:
@ -134,19 +152,17 @@ class EventsStore(SQLBaseStore):
outlier = event.internal_metadata.is_outlier() outlier = event.internal_metadata.is_outlier()
if not outlier: if not outlier:
self._store_state_groups_txn(txn, event, context)
self._update_min_depth_for_room_txn( self._update_min_depth_for_room_txn(
txn, txn,
event.room_id, event.room_id,
event.depth event.depth
) )
have_persisted = self._simple_select_one_onecol_txn( have_persisted = self._simple_select_one_txn(
txn, txn,
table="event_json", table="events",
keyvalues={"event_id": event.event_id}, keyvalues={"event_id": event.event_id},
retcol="event_id", retcols=["event_id", "outlier"],
allow_none=True, allow_none=True,
) )
@ -161,7 +177,9 @@ class EventsStore(SQLBaseStore):
# if we are persisting an event that we had persisted as an outlier, # if we are persisting an event that we had persisted as an outlier,
# but is no longer one. # but is no longer one.
if have_persisted: if have_persisted:
if not outlier: if not outlier and have_persisted["outlier"]:
self._store_state_groups_txn(txn, event, context)
sql = ( sql = (
"UPDATE event_json SET internal_metadata = ?" "UPDATE event_json SET internal_metadata = ?"
" WHERE event_id = ?" " WHERE event_id = ?"
@ -181,6 +199,9 @@ class EventsStore(SQLBaseStore):
) )
return return
if not outlier:
self._store_state_groups_txn(txn, event, context)
self._handle_prev_events( self._handle_prev_events(
txn, txn,
outlier=outlier, outlier=outlier,
@ -400,3 +421,407 @@ class EventsStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"have_events", f, "have_events", f,
) )
@defer.inlineCallbacks
def _get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not event_ids:
defer.returnValue([])
event_map = self._get_events_from_cache(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
missing_events_ids = [e for e in event_ids if e not in event_map]
if not missing_events_ids:
defer.returnValue([
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
])
missing_events = yield self._enqueue_events(
missing_events_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
event_map.update(missing_events)
defer.returnValue([
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
])
def _get_events_txn(self, txn, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not event_ids:
return []
event_map = self._get_events_from_cache(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
missing_events_ids = [e for e in event_ids if e not in event_map]
if not missing_events_ids:
return [
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
]
missing_events = self._fetch_events_txn(
txn,
missing_events_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
event_map.update(missing_events)
return [
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
]
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted,
get_prev_content)
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
events = self._get_events_txn(
txn, [event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
return events[0] if events else None
def _get_events_from_cache(self, events, check_redacted, get_prev_content,
allow_rejected):
event_map = {}
for event_id in events:
try:
ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content
)
if allow_rejected or not ret.rejected_reason:
event_map[event_id] = ret
else:
event_map[event_id] = None
except KeyError:
pass
return event_map
def _do_fetch(self, conn):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
event_list = []
i = 0
while True:
try:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []
if not event_list:
single_threaded = self.database_engine.single_threaded
if single_threaded or i > EVENT_QUEUE_ITERATIONS:
self._event_fetch_ongoing -= 1
return
else:
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0
event_id_lists = zip(*event_list)[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
rows = self._new_transaction(
conn, "do_fetch", [], self._fetch_event_rows, event_ids
)
row_dict = {
r["event_id"]: r
for r in rows
}
# We only want to resolve deferreds from the main thread
def fire(lst, res):
for ids, d in lst:
if not d.called:
try:
d.callback([
res[i]
for i in ids
if i in res
])
except:
logger.exception("Failed to callback")
reactor.callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
def fire(evs):
for _, d in evs:
if not d.called:
d.errback(e)
if event_list:
reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True,
get_prev_content=False, allow_rejected=False):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
"""
if not events:
defer.returnValue({})
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, events_d)
)
self._event_fetch_lock.notify()
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
should_start = True
else:
should_start = False
if should_start:
self.runWithConnection(
self._do_fetch
)
rows = yield preserve_context_over_deferred(events_d)
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults(
[
self._get_event_from_row(
row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
rejected_reason=row["rejects"],
)
for row in rows
],
consumeErrors=True
)
defer.returnValue({
e.event_id: e
for e in res if e
})
def _fetch_event_rows(self, txn, events):
rows = []
N = 200
for i in range(1 + len(events) / N):
evs = events[i*N:(i + 1)*N]
if not evs:
break
sql = (
"SELECT "
" e.event_id as event_id, "
" e.internal_metadata,"
" e.json,"
" r.redacts as redacts,"
" rej.event_id as rejects "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"]*len(evs)),)
txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn))
return rows
def _fetch_events_txn(self, txn, events, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not events:
return {}
rows = self._fetch_event_rows(
txn, events,
)
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = [
self._get_event_from_row_txn(
txn,
row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
rejected_reason=row["rejects"],
)
for row in rows
]
return {
r.event_id: r
for r in res
}
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False,
rejected_reason=None):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
if rejected_reason:
rejected_reason = yield self._simple_select_one_onecol(
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
desc="_get_event_from_row",
)
ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
if check_redacted and redacted:
ev = prune_event(ev)
redaction_id = yield self._simple_select_one_onecol(
table="redactions",
keyvalues={"redacts": ev.event_id},
retcol="event_id",
desc="_get_event_from_row",
)
ev.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = yield self.get_event(
redaction_id,
check_redacted=False
)
if because:
ev.unsigned["redacted_because"] = because
if get_prev_content and "replaces_state" in ev.unsigned:
prev = yield self.get_event(
ev.unsigned["replaces_state"],
get_prev_content=False,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev
)
defer.returnValue(ev)
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False,
rejected_reason=None):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
if rejected_reason:
rejected_reason = self._simple_select_one_onecol_txn(
txn,
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
)
ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
if check_redacted and redacted:
ev = prune_event(ev)
redaction_id = self._simple_select_one_onecol_txn(
txn,
table="redactions",
keyvalues={"redacts": ev.event_id},
retcol="event_id",
)
ev.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = self._get_event_txn(
txn,
redaction_id,
check_redacted=False
)
if because:
ev.unsigned["redacted_because"] = because
if get_prev_content and "replaces_state" in ev.unsigned:
prev = self._get_event_txn(
txn,
ev.unsigned["replaces_state"],
get_prev_content=False,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev
)
return ev
def _parse_events(self, rows):
return self.runInteraction(
"_parse_events", self._parse_events_txn, rows
)
def _parse_events_txn(self, txn, rows):
event_ids = [r["event_id"] for r in rows]
return self._get_events_txn(txn, event_ids)
def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
txn.execute(sql, (event.event_id,))
result = txn.fetchone()
return result[0] if result else None

View file

@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore, cached
from twisted.internet import defer
class PresenceStore(SQLBaseStore): class PresenceStore(SQLBaseStore):
@ -87,31 +89,48 @@ class PresenceStore(SQLBaseStore):
desc="add_presence_list_pending", desc="add_presence_list_pending",
) )
@defer.inlineCallbacks
def set_presence_list_accepted(self, observer_localpart, observed_userid): def set_presence_list_accepted(self, observer_localpart, observed_userid):
return self._simple_update_one( result = yield self._simple_update_one(
table="presence_list", table="presence_list",
keyvalues={"user_id": observer_localpart, keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid}, "observed_user_id": observed_userid},
updatevalues={"accepted": True}, updatevalues={"accepted": True},
desc="set_presence_list_accepted", desc="set_presence_list_accepted",
) )
self.get_presence_list_accepted.invalidate(observer_localpart)
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
keyvalues = {"user_id": observer_localpart} if accepted:
if accepted is not None: return self.get_presence_list_accepted(observer_localpart)
keyvalues["accepted"] = accepted else:
keyvalues = {"user_id": observer_localpart}
if accepted is not None:
keyvalues["accepted"] = accepted
return self._simple_select_list(
table="presence_list",
keyvalues=keyvalues,
retcols=["observed_user_id", "accepted"],
desc="get_presence_list",
)
@cached()
def get_presence_list_accepted(self, observer_localpart):
return self._simple_select_list( return self._simple_select_list(
table="presence_list", table="presence_list",
keyvalues=keyvalues, keyvalues={"user_id": observer_localpart, "accepted": True},
retcols=["observed_user_id", "accepted"], retcols=["observed_user_id", "accepted"],
desc="get_presence_list", desc="get_presence_list_accepted",
) )
@defer.inlineCallbacks
def del_presence_list(self, observer_localpart, observed_userid): def del_presence_list(self, observer_localpart, observed_userid):
return self._simple_delete_one( yield self._simple_delete_one(
table="presence_list", table="presence_list",
keyvalues={"user_id": observer_localpart, keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid}, "observed_user_id": observed_userid},
desc="del_presence_list", desc="del_presence_list",
) )
self.get_presence_list_accepted.invalidate(observer_localpart)

View file

@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cached()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_push_rules_for_user(self, user_name): def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
@ -31,6 +32,7 @@ class PushRuleStore(SQLBaseStore):
"user_name": user_name, "user_name": user_name,
}, },
retcols=PushRuleTable.fields, retcols=PushRuleTable.fields,
desc="get_push_rules_enabled_for_user",
) )
rows.sort( rows.sort(
@ -150,6 +152,10 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
)
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, user_name
) )
@ -182,6 +188,9 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority_class'] = priority_class new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio new_rule['priority'] = new_prio
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
)
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, user_name
) )
@ -208,17 +217,34 @@ class PushRuleStore(SQLBaseStore):
{'user_name': user_name, 'rule_id': rule_id}, {'user_name': user_name, 'rule_id': rule_id},
desc="delete_push_rule", desc="delete_push_rule",
) )
self.get_push_rules_for_user.invalidate(user_name)
self.get_push_rules_enabled_for_user.invalidate(user_name) self.get_push_rules_enabled_for_user.invalidate(user_name)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled): def set_push_rule_enabled(self, user_name, rule_id, enabled):
yield self._simple_upsert( ret = yield self.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
user_name, rule_id, enabled
)
defer.returnValue(ret)
def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled):
new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
self._simple_upsert_txn(
txn,
PushRuleEnableTable.table_name, PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id}, {'user_name': user_name, 'rule_id': rule_id},
{'enabled': 1 if enabled else 0}, {'enabled': 1 if enabled else 0},
desc="set_push_rule_enabled", {'id': new_id},
)
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
) )
self.get_push_rules_enabled_for_user.invalidate(user_name)
class RuleNotFoundException(Exception): class RuleNotFoundException(Exception):

View file

@ -77,16 +77,16 @@ class RoomMemberStore(SQLBaseStore):
Returns: Returns:
Deferred: Results in a MembershipEvent or None. Deferred: Results in a MembershipEvent or None.
""" """
def f(txn): return self.runInteraction(
events = self._get_members_events_txn( "get_room_member",
txn, self._get_members_events_txn,
room_id, room_id,
user_id=user_id, user_id=user_id,
) ).addCallback(
self._get_events
return events[0] if events else None ).addCallback(
lambda events: events[0] if events else None
return self.runInteraction("get_room_member", f) )
@cached() @cached()
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
@ -112,15 +112,12 @@ class RoomMemberStore(SQLBaseStore):
Returns: Returns:
list of namedtuples representing the members in this room. list of namedtuples representing the members in this room.
""" """
return self.runInteraction(
def f(txn): "get_room_members",
return self._get_members_events_txn( self._get_members_events_txn,
txn, room_id,
room_id, membership=membership,
membership=membership, ).addCallback(self._get_events)
)
return self.runInteraction("get_room_members", f)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list): def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user """ Get all the rooms for this user where the membership for this user
@ -192,14 +189,14 @@ class RoomMemberStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"get_members_query", self._get_members_events_txn, "get_members_query", self._get_members_events_txn,
where_clause, where_values where_clause, where_values
) ).addCallbacks(self._get_events)
def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
rows = self._get_members_rows_txn( rows = self._get_members_rows_txn(
txn, txn,
room_id, membership, user_id, room_id, membership, user_id,
) )
return self._get_events_txn(txn, [r["event_id"] for r in rows]) return [r["event_id"] for r in rows]
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?" where_clause = "c.room_id = ?"

View file

@ -0,0 +1,19 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX events_order_topo_stream_room ON events(
topological_ordering, stream_ordering, room_id
);

View file

@ -43,6 +43,7 @@ class StateStore(SQLBaseStore):
* `state_groups_state`: Maps state group to state events. * `state_groups_state`: Maps state group to state events.
""" """
@defer.inlineCallbacks
def get_state_groups(self, event_ids): def get_state_groups(self, event_ids):
""" Get the state groups for the given list of event_ids """ Get the state groups for the given list of event_ids
@ -71,17 +72,29 @@ class StateStore(SQLBaseStore):
retcol="event_id", retcol="event_id",
) )
state = self._get_events_txn(txn, state_ids) res[group] = state_ids
res[group] = state
return res return res
return self.runInteraction( states = yield self.runInteraction(
"get_state_groups", "get_state_groups",
f, f,
) )
@defer.inlineCallbacks
def c(vals):
vals[:] = yield self._get_events(vals, get_prev_content=False)
yield defer.gatherResults(
[
c(vals)
for vals in states.values()
],
consumeErrors=True,
)
defer.returnValue(states)
def _store_state_groups_txn(self, txn, event, context): def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None: if context.current_state is None:
return return
@ -152,11 +165,12 @@ class StateStore(SQLBaseStore):
args = (room_id, ) args = (room_id, )
txn.execute(sql, args) txn.execute(sql, args)
results = self.cursor_to_dict(txn) results = txn.fetchall()
return self._parse_events_txn(txn, results) return [r[0] for r in results]
events = yield self.runInteraction("get_current_state", f) event_ids = yield self.runInteraction("get_current_state", f)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
@cached(num_args=3) @cached(num_args=3)

View file

@ -37,11 +37,9 @@ from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from collections import namedtuple
import logging import logging
@ -55,76 +53,26 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological" _TOPOLOGICAL_TOKEN = "topological"
class _StreamToken(namedtuple("_StreamToken", "topological stream")): def lower_bound(token):
"""Tokens are positions between events. The token "s1" comes after event 1. if token.topological is None:
return "(%d < %s)" % (token.stream, "stream_ordering")
else:
return "(%d < %s OR (%d = %s AND %d < %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, "stream_ordering",
)
s0 s1
| |
[0] V [1] V [2]
Tokens can either be a point in the live event stream or a cursor going def upper_bound(token):
through historic events. if token.topological is None:
return "(%d >= %s)" % (token.stream, "stream_ordering")
When traversing the live event stream events are ordered by when they else:
arrived at the homeserver. return "(%d > %s OR (%d = %s AND %d >= %s))" % (
token.topological, "topological_ordering",
When traversing historic events the events are ordered by their depth in token.topological, "topological_ordering",
the event graph "topological_ordering" and then by when they arrived at the token.stream, "stream_ordering",
homeserver "stream_ordering". )
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, follewed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = []
@classmethod
def parse(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
if string[0] == 't':
parts = string[1:].split('-', 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@classmethod
def parse_stream_token(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
def __str__(self):
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
def lower_bound(self):
if self.topological is None:
return "(%d < %s)" % (self.stream, "stream_ordering")
else:
return "(%d < %s OR (%d = %s AND %d < %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
)
def upper_bound(self):
if self.topological is None:
return "(%d >= %s)" % (self.stream, "stream_ordering")
else:
return "(%d > %s OR (%d = %s AND %d >= %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
)
class StreamStore(SQLBaseStore): class StreamStore(SQLBaseStore):
@ -139,8 +87,8 @@ class StreamStore(SQLBaseStore):
limit = MAX_STREAM_SIZE limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering. # From and to keys should be integers from ordering.
from_id = _StreamToken.parse_stream_token(from_key) from_id = RoomStreamToken.parse_stream_token(from_key)
to_id = _StreamToken.parse_stream_token(to_key) to_id = RoomStreamToken.parse_stream_token(to_key)
if from_key == to_key: if from_key == to_key:
defer.returnValue(([], to_key)) defer.returnValue(([], to_key))
@ -234,8 +182,8 @@ class StreamStore(SQLBaseStore):
limit = MAX_STREAM_SIZE limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering. # From and to keys should be integers from ordering.
from_id = _StreamToken.parse_stream_token(from_key) from_id = RoomStreamToken.parse_stream_token(from_key)
to_id = _StreamToken.parse_stream_token(to_key) to_id = RoomStreamToken.parse_stream_token(to_key)
if from_key == to_key: if from_key == to_key:
return defer.succeed(([], to_key)) return defer.succeed(([], to_key))
@ -276,7 +224,7 @@ class StreamStore(SQLBaseStore):
return self.runInteraction("get_room_events_stream", f) return self.runInteraction("get_room_events_stream", f)
@log_function @defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None, def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1, direction='b', limit=-1,
with_feedback=False): with_feedback=False):
@ -288,17 +236,17 @@ class StreamStore(SQLBaseStore):
args = [False, room_id] args = [False, room_id]
if direction == 'b': if direction == 'b':
order = "DESC" order = "DESC"
bounds = _StreamToken.parse(from_key).upper_bound() bounds = upper_bound(RoomStreamToken.parse(from_key))
if to_key: if to_key:
bounds = "%s AND %s" % ( bounds = "%s AND %s" % (
bounds, _StreamToken.parse(to_key).lower_bound() bounds, lower_bound(RoomStreamToken.parse(to_key))
) )
else: else:
order = "ASC" order = "ASC"
bounds = _StreamToken.parse(from_key).lower_bound() bounds = lower_bound(RoomStreamToken.parse(from_key))
if to_key: if to_key:
bounds = "%s AND %s" % ( bounds = "%s AND %s" % (
bounds, _StreamToken.parse(to_key).upper_bound() bounds, upper_bound(RoomStreamToken.parse(to_key))
) )
if int(limit) > 0: if int(limit) > 0:
@ -333,28 +281,30 @@ class StreamStore(SQLBaseStore):
# when we are going backwards so we subtract one from the # when we are going backwards so we subtract one from the
# stream part. # stream part.
toke -= 1 toke -= 1
next_token = str(_StreamToken(topo, toke)) next_token = str(RoomStreamToken(topo, toke))
else: else:
# TODO (erikj): We should work out what to do here instead. # TODO (erikj): We should work out what to do here instead.
next_token = to_key if to_key else from_key next_token = to_key if to_key else from_key
events = self._get_events_txn( return rows, next_token,
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(events, rows) rows, token = yield self.runInteraction("paginate_room_events", f)
return events, next_token, events = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
return self.runInteraction("paginate_room_events", f) self._set_before_and_after(events, rows)
defer.returnValue((events, token))
@defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token, def get_recent_events_for_room(self, room_id, limit, end_token,
with_feedback=False, from_token=None): with_feedback=False, from_token=None):
# TODO (erikj): Handle compressed feedback # TODO (erikj): Handle compressed feedback
end_token = _StreamToken.parse_stream_token(end_token) end_token = RoomStreamToken.parse_stream_token(end_token)
if from_token is None: if from_token is None:
sql = ( sql = (
@ -365,7 +315,7 @@ class StreamStore(SQLBaseStore):
" LIMIT ?" " LIMIT ?"
) )
else: else:
from_token = _StreamToken.parse_stream_token(from_token) from_token = RoomStreamToken.parse_stream_token(from_token)
sql = ( sql = (
"SELECT stream_ordering, topological_ordering, event_id" "SELECT stream_ordering, topological_ordering, event_id"
" FROM events" " FROM events"
@ -395,30 +345,49 @@ class StreamStore(SQLBaseStore):
# stream part. # stream part.
topo = rows[0]["topological_ordering"] topo = rows[0]["topological_ordering"]
toke = rows[0]["stream_ordering"] - 1 toke = rows[0]["stream_ordering"] - 1
start_token = str(_StreamToken(topo, toke)) start_token = str(RoomStreamToken(topo, toke))
token = (start_token, str(end_token)) token = (start_token, str(end_token))
else: else:
token = (str(end_token), str(end_token)) token = (str(end_token), str(end_token))
events = self._get_events_txn( return rows, token
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(events, rows) rows, token = yield self.runInteraction(
return events, token
return self.runInteraction(
"get_recent_events_for_room", get_recent_events_for_room_txn "get_recent_events_for_room", get_recent_events_for_room_txn
) )
logger.debug("stream before")
events = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
logger.debug("stream after")
self._set_before_and_after(events, rows)
defer.returnValue((events, token))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_events_max_id(self): def get_room_events_max_id(self, direction='f'):
token = yield self._stream_id_gen.get_max_token(self) token = yield self._stream_id_gen.get_max_token(self)
defer.returnValue("s%d" % (token,)) if direction != 'b':
defer.returnValue("s%d" % (token,))
else:
topo = yield self.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn
)
defer.returnValue("t%d-%d" % (topo, token))
def _get_max_topological_txn(self, txn):
txn.execute(
"SELECT MAX(topological_ordering) FROM events"
" WHERE outlier = ?",
(False,)
)
rows = txn.fetchall()
return rows[0][0] if rows else 0
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_min_token(self): def _get_min_token(self):
@ -439,5 +408,5 @@ class StreamStore(SQLBaseStore):
stream = row["stream_ordering"] stream = row["stream_ordering"]
topo = event.depth topo = event.depth
internal = event.internal_metadata internal = event.internal_metadata
internal.before = str(_StreamToken(topo, stream - 1)) internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(_StreamToken(topo, stream)) internal.after = str(RoomStreamToken(topo, stream))

View file

@ -78,14 +78,18 @@ class StreamIdGenerator(object):
self._current_max = None self._current_max = None
self._unfinished_ids = deque() self._unfinished_ids = deque()
def get_next_txn(self, txn): @defer.inlineCallbacks
def get_next(self, store):
""" """
Usage: Usage:
with stream_id_gen.get_next_txn(txn) as stream_id: with yield stream_id_gen.get_next as stream_id:
# ... persist event ... # ... persist event ...
""" """
if not self._current_max: if not self._current_max:
self._get_or_compute_current_max(txn) yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock: with self._lock:
self._current_max += 1 self._current_max += 1
@ -101,7 +105,7 @@ class StreamIdGenerator(object):
with self._lock: with self._lock:
self._unfinished_ids.remove(next_id) self._unfinished_ids.remove(next_id)
return manager() defer.returnValue(manager())
@defer.inlineCallbacks @defer.inlineCallbacks
def get_max_token(self, store): def get_max_token(self, store):

View file

@ -31,7 +31,7 @@ class NullSource(object):
def get_new_events_for_user(self, user, from_key, limit): def get_new_events_for_user(self, user, from_key, limit):
return defer.succeed(([], from_key)) return defer.succeed(([], from_key))
def get_current_key(self): def get_current_key(self, direction='f'):
return defer.succeed(0) return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key): def get_pagination_rows(self, user, pagination_config, key):
@ -52,10 +52,10 @@ class EventSources(object):
} }
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_token(self): def get_current_token(self, direction='f'):
token = StreamToken( token = StreamToken(
room_key=( room_key=(
yield self.sources["room"].get_current_key() yield self.sources["room"].get_current_key(direction)
), ),
presence_key=( presence_key=(
yield self.sources["presence"].get_current_key() yield self.sources["presence"].get_current_key()

View file

@ -70,6 +70,8 @@ class DomainSpecificString(
"""Return a string encoding the fields of the structure object.""" """Return a string encoding the fields of the structure object."""
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
__str__ = to_string
@classmethod @classmethod
def create(cls, localpart, domain,): def create(cls, localpart, domain,):
return cls(localpart=localpart, domain=domain) return cls(localpart=localpart, domain=domain)
@ -107,7 +109,6 @@ class StreamToken(
def from_string(cls, string): def from_string(cls, string):
try: try:
keys = string.split(cls._SEPARATOR) keys = string.split(cls._SEPARATOR)
return cls(*keys) return cls(*keys)
except: except:
raise SynapseError(400, "Invalid Token") raise SynapseError(400, "Invalid Token")
@ -115,10 +116,95 @@ class StreamToken(
def to_string(self): def to_string(self):
return self._SEPARATOR.join([str(k) for k in self]) return self._SEPARATOR.join([str(k) for k in self])
@property
def room_stream_id(self):
# TODO(markjh): Awful hack to work around hacks in the presence tests
# which assume that the keys are integers.
if type(self.room_key) is int:
return self.room_key
else:
return int(self.room_key[1:].split("-")[-1])
def is_after(self, other_token):
"""Does this token contain events that the other doesn't?"""
return (
(other_token.room_stream_id < self.room_stream_id)
or (int(other_token.presence_key) < int(self.presence_key))
or (int(other_token.typing_key) < int(self.typing_key))
)
def copy_and_advance(self, key, new_value):
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
"""
new_token = self.copy_and_replace(key, new_value)
if key == "room_key":
new_id = new_token.room_stream_id
old_id = self.room_stream_id
else:
new_id = int(getattr(new_token, key))
old_id = int(getattr(self, key))
if old_id < new_id:
return new_token
else:
return self
def copy_and_replace(self, key, new_value): def copy_and_replace(self, key, new_value):
d = self._asdict() d = self._asdict()
d[key] = new_value d[key] = new_value
return StreamToken(**d) return StreamToken(**d)
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
"""Tokens are positions between events. The token "s1" comes after event 1.
s0 s1
| |
[0] V [1] V [2]
Tokens can either be a point in the live event stream or a cursor going
through historic events.
When traversing the live event stream events are ordered by when they
arrived at the homeserver.
When traversing historic events the events are ordered by their depth in
the event graph "topological_ordering" and then by when they arrived at the
homeserver "stream_ordering".
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, follewed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = []
@classmethod
def parse(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
if string[0] == 't':
parts = string[1:].split('-', 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@classmethod
def parse_stream_token(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
def __str__(self):
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from twisted.internet import defer, reactor, task from twisted.internet import defer, reactor, task
@ -23,6 +23,40 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def unwrapFirstError(failure):
# defer.gatherResults and DeferredLists wrap failures.
failure.trap(defer.FirstError)
return failure.value.subFailure
def unwrap_deferred(d):
"""Given a deferred that we know has completed, return its value or raise
the failure as an exception
"""
if not d.called:
raise RuntimeError("deferred has not finished")
res = []
def f(r):
res.append(r)
return r
d.addCallback(f)
if res:
return res[0]
def f(r):
res.append(r)
return r
d.addErrback(f)
if res:
res[0].raiseException()
else:
raise RuntimeError("deferred did not call callbacks")
class Clock(object): class Clock(object):
"""A small utility that obtains current time-of-day so that time may be """A small utility that obtains current time-of-day so that time may be
mocked during unit-tests. mocked during unit-tests.
@ -46,13 +80,16 @@ class Clock(object):
def stop_looping_call(self, loop): def stop_looping_call(self, loop):
loop.stop() loop.stop()
def call_later(self, delay, callback): def call_later(self, delay, callback, *args, **kwargs):
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
def wrapped_callback(): def wrapped_callback(*args, **kwargs):
LoggingContext.thread_local.current_context = current_context with PreserveLoggingContext():
callback() LoggingContext.thread_local.current_context = current_context
return reactor.callLater(delay, wrapped_callback) callback(*args, **kwargs)
with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer): def cancel_call_later(self, timer):
timer.cancel() timer.cancel()

View file

@ -16,15 +16,13 @@
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from .logcontext import PreserveLoggingContext from .logcontext import preserve_context_over_deferred
@defer.inlineCallbacks
def sleep(seconds): def sleep(seconds):
d = defer.Deferred() d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds) reactor.callLater(seconds, d.callback, seconds)
with PreserveLoggingContext(): return preserve_context_over_deferred(d)
yield d
def run_on_reactor(): def run_on_reactor():
@ -34,20 +32,56 @@ def run_on_reactor():
return sleep(0) return sleep(0)
def create_observer(deferred): class ObservableDeferred(object):
"""Creates a deferred that observes the result or failure of the given """Wraps a deferred object so that we can add observer deferreds. These
deferred *without* affecting the given deferred. observer deferreds do not affect the callback chain of the original
deferred.
If consumeErrors is true errors will be captured from the origin deferred.
""" """
d = defer.Deferred()
def callback(r): __slots__ = ["_deferred", "_observers", "_result"]
d.callback(r)
return r
def errback(f): def __init__(self, deferred, consumeErrors=False):
d.errback(f) object.__setattr__(self, "_deferred", deferred)
return f object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", [])
deferred.addCallbacks(callback, errback) def callback(r):
self._result = (True, r)
while self._observers:
try:
self._observers.pop().callback(r)
except:
pass
return r
return d def errback(f):
self._result = (False, f)
while self._observers:
try:
self._observers.pop().errback(f)
except:
pass
if consumeErrors:
return None
else:
return f
deferred.addCallbacks(callback, errback)
def observe(self):
if not self._result:
d = defer.Deferred()
self._observers.append(d)
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
def __getattr__(self, name):
return getattr(self._deferred, name)
def __setattr__(self, name, value):
setattr(self._deferred, name, value)

View file

@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred,
)
from synapse.util import unwrapFirstError
import logging import logging
@ -93,7 +97,6 @@ class Signal(object):
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
@defer.inlineCallbacks
def fire(self, *args, **kwargs): def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
@ -101,24 +104,28 @@ class Signal(object):
Returns a Deferred that will complete when all the observers have Returns a Deferred that will complete when all the observers have
completed.""" completed."""
with PreserveLoggingContext():
deferreds = []
for observer in self.observers:
d = defer.maybeDeferred(observer, *args, **kwargs)
def eb(failure): def do(observer):
logger.warning( def eb(failure):
"%s signal observer %s failed: %r", logger.warning(
self.name, observer, failure, "%s signal observer %s failed: %r",
exc_info=( self.name, observer, failure,
failure.type, exc_info=(
failure.value, failure.type,
failure.getTracebackObject())) failure.value,
if not self.suppress_failures: failure.getTracebackObject()))
failure.raiseException() if not self.suppress_failures:
deferreds.append(d.addErrback(eb)) return failure
results = [] return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
for deferred in deferreds:
result = yield deferred with PreserveLoggingContext():
results.append(result) deferreds = [
defer.returnValue(results) do(observer)
for observer in self.observers
]
d = defer.gatherResults(deferreds, consumeErrors=True)
d.addErrback(unwrapFirstError)
return preserve_context_over_deferred(d)

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
import threading import threading
import logging import logging
@ -129,3 +131,53 @@ class PreserveLoggingContext(object):
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
"""Restores the current logging context""" """Restores the current logging context"""
LoggingContext.thread_local.current_context = self.current_context LoggingContext.thread_local.current_context = self.current_context
if self.current_context is not LoggingContext.sentinel:
if self.current_context.parent_context is None:
logger.warn(
"Restoring dead context: %s",
self.current_context,
)
def preserve_context_over_fn(fn, *args, **kwargs):
"""Takes a function and invokes it with the given arguments, but removes
and restores the current logging context while doing so.
If the result is a deferred, call preserve_context_over_deferred before
returning it.
"""
with PreserveLoggingContext():
res = fn(*args, **kwargs)
if isinstance(res, defer.Deferred):
return preserve_context_over_deferred(res)
else:
return res
def preserve_context_over_deferred(deferred):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
"""
d = defer.Deferred()
current_context = LoggingContext.current_context()
def cb(res):
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context
res = d.callback(res)
return res
def eb(failure):
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context
res = d.errback(failure)
return res
if deferred.called:
return deferred
deferred.addCallbacks(cb, eb)
return d

View file

@ -217,18 +217,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("@irc_.*") _regex("@irc_.*")
) )
join_list = [ join_list = [
Mock( "@alice:here",
type="m.room.member", room_id="!foo:bar", sender="@alice:here", "@irc_fo:here", # AS user
state_key="@alice:here" "@bob:here",
),
Mock(
type="m.room.member", room_id="!foo:bar", sender="@irc_fo:here",
state_key="@irc_fo:here" # AS user
),
Mock(
type="m.room.member", room_id="!foo:bar", sender="@bob:here",
state_key="@bob:here"
)
] ]
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"

View file

@ -83,7 +83,7 @@ class FederationTestCase(unittest.TestCase):
"hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"}, "hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
}) })
self.datastore.persist_event.return_value = defer.succeed(None) self.datastore.persist_event.return_value = defer.succeed((1,1))
self.datastore.get_room.return_value = defer.succeed(True) self.datastore.get_room.return_value = defer.succeed(True)
self.auth.check_host_in_room.return_value = defer.succeed(True) self.auth.check_host_in_room.return_value = defer.succeed(True)
@ -126,5 +126,5 @@ class FederationTestCase(unittest.TestCase):
self.auth.check.assert_called_once_with(ANY, auth_events={}) self.auth.check.assert_called_once_with(ANY, auth_events={})
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
ANY, extra_users=[] ANY, 1, 1, extra_users=[]
) )

View file

@ -233,7 +233,7 @@ class MockedDatastorePresenceTestCase(PresenceTestCase):
if not user_localpart in self.PRESENCE_LIST: if not user_localpart in self.PRESENCE_LIST:
return defer.succeed([]) return defer.succeed([])
return defer.succeed([ return defer.succeed([
{"observed_user_id": u} for u in {"observed_user_id": u, "accepted": accepted} for u in
self.PRESENCE_LIST[user_localpart]]) self.PRESENCE_LIST[user_localpart]])
datastore.get_presence_list = get_presence_list datastore.get_presence_list = get_presence_list
@ -624,6 +624,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
""" """
PRESENCE_LIST = { PRESENCE_LIST = {
'apple': [ "@banana:test", "@clementine:test" ], 'apple': [ "@banana:test", "@clementine:test" ],
'banana': [ "@apple:test" ],
} }
@defer.inlineCallbacks @defer.inlineCallbacks
@ -733,10 +734,12 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.assertEquals( self.assertEquals(
[ [
{"observed_user": self.u_banana, {"observed_user": self.u_banana,
"presence": OFFLINE}, "presence": OFFLINE,
"accepted": True},
{"observed_user": self.u_clementine, {"observed_user": self.u_clementine,
"presence": OFFLINE}, "presence": OFFLINE,
"accepted": True},
], ],
presence presence
) )
@ -757,9 +760,11 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.assertEquals([ self.assertEquals([
{"observed_user": self.u_banana, {"observed_user": self.u_banana,
"presence": ONLINE, "presence": ONLINE,
"last_active_ago": 2000}, "last_active_ago": 2000,
"accepted": True},
{"observed_user": self.u_clementine, {"observed_user": self.u_clementine,
"presence": OFFLINE}, "presence": OFFLINE,
"accepted": True},
], presence) ], presence)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events_for_user(
@ -836,12 +841,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_recv_remote(self): def test_recv_remote(self):
# TODO(paul): Gut-wrenching self.room_members = [self.u_apple, self.u_banana, self.u_potato]
potato_set = self.handler._remote_recvmap.setdefault(self.u_potato,
set())
potato_set.add(self.u_apple)
self.room_members = [self.u_banana, self.u_potato]
self.assertEquals(self.event_source.get_current_key(), 0) self.assertEquals(self.event_source.get_current_key(), 0)
@ -886,11 +886,8 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_recv_remote_offline(self): def test_recv_remote_offline(self):
""" Various tests relating to SYN-261 """ """ Various tests relating to SYN-261 """
potato_set = self.handler._remote_recvmap.setdefault(self.u_potato,
set())
potato_set.add(self.u_apple)
self.room_members = [self.u_banana, self.u_potato] self.room_members = [self.u_apple, self.u_banana, self.u_potato]
self.assertEquals(self.event_source.get_current_key(), 0) self.assertEquals(self.event_source.get_current_key(), 0)
@ -1097,12 +1094,8 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
# apple should see both banana and clementine currently offline # apple should see both banana and clementine currently offline
self.mock_update_client.assert_has_calls([ self.mock_update_client.assert_has_calls([
call(users_to_push=[self.u_apple], call(users_to_push=[self.u_apple]),
observed_user=self.u_banana, call(users_to_push=[self.u_apple]),
statuscache=ANY),
call(users_to_push=[self.u_apple],
observed_user=self.u_clementine,
statuscache=ANY),
], any_order=True) ], any_order=True)
# Gut-wrenching tests # Gut-wrenching tests
@ -1121,13 +1114,8 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
# apple and banana should now both see each other online # apple and banana should now both see each other online
self.mock_update_client.assert_has_calls([ self.mock_update_client.assert_has_calls([
call(users_to_push=set([self.u_apple]), call(users_to_push=set([self.u_apple]), room_ids=[]),
observed_user=self.u_banana, call(users_to_push=[self.u_banana]),
room_ids=[],
statuscache=ANY),
call(users_to_push=[self.u_banana],
observed_user=self.u_apple,
statuscache=ANY),
], any_order=True) ], any_order=True)
self.assertTrue("apple" in self.handler._local_pushmap) self.assertTrue("apple" in self.handler._local_pushmap)
@ -1143,10 +1131,7 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
# banana should now be told apple is offline # banana should now be told apple is offline
self.mock_update_client.assert_has_calls([ self.mock_update_client.assert_has_calls([
call(users_to_push=set([self.u_banana, self.u_apple]), call(users_to_push=set([self.u_banana, self.u_apple]), room_ids=[]),
observed_user=self.u_apple,
room_ids=[],
statuscache=ANY),
], any_order=True) ], any_order=True)
self.assertFalse("banana" in self.handler._local_pushmap) self.assertFalse("banana" in self.handler._local_pushmap)

View file

@ -101,8 +101,8 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
self.datastore.get_profile_avatar_url = get_profile_avatar_url self.datastore.get_profile_avatar_url = get_profile_avatar_url
self.presence_list = [ self.presence_list = [
{"observed_user_id": "@banana:test"}, {"observed_user_id": "@banana:test", "accepted": True},
{"observed_user_id": "@clementine:test"}, {"observed_user_id": "@clementine:test", "accepted": True},
] ]
def get_presence_list(user_localpart, accepted=None): def get_presence_list(user_localpart, accepted=None):
return defer.succeed(self.presence_list) return defer.succeed(self.presence_list)
@ -144,8 +144,8 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_state(self): def test_set_my_state(self):
self.presence_list = [ self.presence_list = [
{"observed_user_id": "@banana:test"}, {"observed_user_id": "@banana:test", "accepted": True},
{"observed_user_id": "@clementine:test"}, {"observed_user_id": "@clementine:test", "accepted": True},
] ]
mocked_set = self.datastore.set_presence_state mocked_set = self.datastore.set_presence_state
@ -167,8 +167,8 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
self.mock_get_joined.side_effect = get_joined self.mock_get_joined.side_effect = get_joined
self.presence_list = [ self.presence_list = [
{"observed_user_id": "@banana:test"}, {"observed_user_id": "@banana:test", "accepted": True},
{"observed_user_id": "@clementine:test"}, {"observed_user_id": "@clementine:test", "accepted": True},
] ]
self.datastore.set_presence_state.return_value = defer.succeed( self.datastore.set_presence_state.return_value = defer.succeed(
@ -203,26 +203,20 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
"presence": ONLINE, "presence": ONLINE,
"last_active_ago": 0, "last_active_ago": 0,
"displayname": "Frank", "displayname": "Frank",
"avatar_url": "http://foo"}, "avatar_url": "http://foo",
"accepted": True},
{"observed_user": self.u_clementine, {"observed_user": self.u_clementine,
"presence": OFFLINE} "presence": OFFLINE,
"accepted": True}
], presence) ], presence)
self.mock_update_client.assert_has_calls([ self.mock_update_client.assert_has_calls([
call(users_to_push=set([self.u_apple, self.u_banana, self.u_clementine]), call(
room_ids=[], users_to_push={self.u_apple, self.u_banana, self.u_clementine},
observed_user=self.u_apple, room_ids=[]
statuscache=ANY), # self-reflection ),
], any_order=True) ], any_order=True)
statuscache = self.mock_update_client.call_args[1]["statuscache"]
self.assertEquals({
"presence": ONLINE,
"last_active": 1000000, # MockClock
"displayname": "Frank",
"avatar_url": "http://foo",
}, statuscache.state)
self.mock_update_client.reset_mock() self.mock_update_client.reset_mock()
self.datastore.set_profile_displayname.return_value = defer.succeed( self.datastore.set_profile_displayname.return_value = defer.succeed(
@ -232,25 +226,16 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
self.u_apple, "I am an Apple") self.u_apple, "I am an Apple")
self.mock_update_client.assert_has_calls([ self.mock_update_client.assert_has_calls([
call(users_to_push=set([self.u_apple, self.u_banana, self.u_clementine]), call(
users_to_push={self.u_apple, self.u_banana, self.u_clementine},
room_ids=[], room_ids=[],
observed_user=self.u_apple, ),
statuscache=ANY), # self-reflection
], any_order=True) ], any_order=True)
statuscache = self.mock_update_client.call_args[1]["statuscache"]
self.assertEquals({
"presence": ONLINE,
"last_active": 1000000, # MockClock
"displayname": "I am an Apple",
"avatar_url": "http://foo",
}, statuscache.state)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_push_remote(self): def test_push_remote(self):
self.presence_list = [ self.presence_list = [
{"observed_user_id": "@potato:remote"}, {"observed_user_id": "@potato:remote", "accepted": True},
] ]
self.datastore.set_presence_state.return_value = defer.succeed( self.datastore.set_presence_state.return_value = defer.succeed(
@ -314,13 +299,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
self.mock_update_client.assert_called_with( self.mock_update_client.assert_called_with(
users_to_push=set([self.u_apple]), users_to_push=set([self.u_apple]),
room_ids=[], room_ids=[],
observed_user=self.u_potato, )
statuscache=ANY)
statuscache = self.mock_update_client.call_args[1]["statuscache"]
self.assertEquals({"presence": ONLINE,
"displayname": "Frank",
"avatar_url": "http://foo"}, statuscache.state)
state = yield self.handlers.presence_handler.get_state(self.u_potato, state = yield self.handlers.presence_handler.get_state(self.u_potato,
self.u_apple) self.u_apple)

View file

@ -87,6 +87,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
self.datastore.persist_event.return_value = (1,1)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite(self): def test_invite(self):
room_id = "!foo:red" room_id = "!foo:red"
@ -160,7 +162,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
event, context=context, event, context=context,
) )
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[UserID.from_string(target_user_id)] event, 1, 1, extra_users=[UserID.from_string(target_user_id)]
) )
self.assertFalse(self.datastore.get_room.called) self.assertFalse(self.datastore.get_room.called)
self.assertFalse(self.datastore.store_room.called) self.assertFalse(self.datastore.store_room.called)
@ -226,7 +228,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
event, context=context event, context=context
) )
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[user] event, 1, 1, extra_users=[user]
) )
join_signal_observer.assert_called_with( join_signal_observer.assert_called_with(
@ -304,7 +306,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
event, context=context event, context=context
) )
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[user] event, 1, 1, extra_users=[user]
) )
leave_signal_observer.assert_called_with( leave_signal_observer.assert_called_with(

View file

@ -183,7 +183,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -246,7 +246,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -300,7 +300,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
yield put_json.await_calls() yield put_json.await_calls()
@ -332,7 +332,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
self.on_new_user_event.reset_mock() self.on_new_user_event.reset_mock()
@ -352,7 +352,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.clock.advance_time(11) self.clock.advance_time(11)
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 2, rooms=[self.room_id]),
]) ])
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
@ -378,7 +378,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
self.on_new_user_event.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]), call('typing_key', 3, rooms=[self.room_id]),
]) ])
self.on_new_user_event.reset_mock() self.on_new_user_event.reset_mock()

View file

@ -27,6 +27,9 @@ from synapse.handlers.presence import PresenceHandler
from synapse.rest.client.v1 import presence from synapse.rest.client.v1 import presence
from synapse.rest.client.v1 import events from synapse.rest.client.v1 import events
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import run_on_reactor
from collections import namedtuple
OFFLINE = PresenceState.OFFLINE OFFLINE = PresenceState.OFFLINE
@ -180,7 +183,7 @@ class PresenceListTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_list(self): def test_get_my_list(self):
self.datastore.get_presence_list.return_value = defer.succeed( self.datastore.get_presence_list.return_value = defer.succeed(
[{"observed_user_id": "@banana:test"}], [{"observed_user_id": "@banana:test", "accepted": True}],
) )
(code, response) = yield self.mock_resource.trigger("GET", (code, response) = yield self.mock_resource.trigger("GET",
@ -188,7 +191,7 @@ class PresenceListTestCase(unittest.TestCase):
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals([ self.assertEquals([
{"user_id": "@banana:test", "presence": OFFLINE}, {"user_id": "@banana:test", "presence": OFFLINE, "accepted": True},
], response) ], response)
self.datastore.get_presence_list.assert_called_with( self.datastore.get_presence_list.assert_called_with(
@ -264,11 +267,13 @@ class PresenceEventStreamTestCase(unittest.TestCase):
datastore=Mock(spec=[ datastore=Mock(spec=[
"set_presence_state", "set_presence_state",
"get_presence_list", "get_presence_list",
"get_rooms_for_user",
]), ]),
clock=Mock(spec=[ clock=Mock(spec=[
"call_later", "call_later",
"cancel_call_later", "cancel_call_later",
"time_msec", "time_msec",
"looping_call",
]), ]),
) )
@ -292,12 +297,21 @@ class PresenceEventStreamTestCase(unittest.TestCase):
else: else:
return [] return []
hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
hs.handlers.room_member_handler.get_room_members = (
lambda r: self.room_members if r == "a-room" else []
)
self.mock_datastore = hs.get_datastore() self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None) self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
self.mock_datastore.get_app_service_by_user_id = Mock( self.mock_datastore.get_app_service_by_user_id = Mock(
return_value=defer.succeed(None) return_value=defer.succeed(None)
) )
self.mock_datastore.get_rooms_for_user = (
lambda u: [
namedtuple("Room", "room_id")(r)
for r in get_rooms_for_user(UserID.from_string(u))
]
)
def get_profile_displayname(user_id): def get_profile_displayname(user_id):
return defer.succeed("Frank") return defer.succeed("Frank")
@ -350,19 +364,19 @@ class PresenceEventStreamTestCase(unittest.TestCase):
self.mock_datastore.set_presence_state.return_value = defer.succeed( self.mock_datastore.set_presence_state.return_value = defer.succeed(
{"state": ONLINE} {"state": ONLINE}
) )
self.mock_datastore.get_presence_list.return_value = defer.succeed( self.mock_datastore.get_presence_list.return_value = defer.succeed([])
[]
)
yield self.presence.set_state(self.u_banana, self.u_banana, yield self.presence.set_state(self.u_banana, self.u_banana,
state={"presence": ONLINE} state={"presence": ONLINE}
) )
yield run_on_reactor()
(code, response) = yield self.mock_resource.trigger("GET", (code, response) = yield self.mock_resource.trigger("GET",
"/events?from=0_1_0&timeout=0", None) "/events?from=s0_1_0&timeout=0", None)
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals({"start": "0_1_0", "end": "0_2_0", "chunk": [ self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [
{"type": "m.presence", {"type": "m.presence",
"content": { "content": {
"user_id": "@banana:test", "user_id": "@banana:test",

View file

@ -33,8 +33,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.db_pool = Mock(spec=["runInteraction"]) self.db_pool = Mock(spec=["runInteraction"])
self.mock_txn = Mock() self.mock_txn = Mock()
self.mock_conn = Mock(spec_set=["cursor"]) self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
self.mock_conn.cursor.return_value = self.mock_txn self.mock_conn.cursor.return_value = self.mock_txn
self.mock_conn.rollback.return_value = None
# Our fake runInteraction just runs synchronously inline # Our fake runInteraction just runs synchronously inline
def runInteraction(func, *args, **kwargs): def runInteraction(func, *args, **kwargs):

View file

@ -197,6 +197,9 @@ class MockClock(object):
return t return t
def looping_call(self, function, interval):
pass
def cancel_call_later(self, timer): def cancel_call_later(self, timer):
if timer[2]: if timer[2]:
raise Exception("Cannot cancel an expired timer") raise Exception("Cannot cancel an expired timer")
@ -355,7 +358,7 @@ class MemoryDataStore(object):
return [] return []
def get_room_events_max_id(self): def get_room_events_max_id(self):
return 0 # TODO (erikj) return "s0" # TODO (erikj)
def get_send_event_level(self, room_id): def get_send_event_level(self, room_id):
return defer.succeed(0) return defer.succeed(0)