From a24bc5b2dc3a5d81cdfbe7be367dbb461d85b999 Mon Sep 17 00:00:00 2001 From: David Baker Date: Mon, 23 May 2016 18:33:51 +0100 Subject: [PATCH 001/116] Add GET /notifications API --- synapse/rest/__init__.py | 2 + synapse/rest/client/v2_alpha/notifications.py | 100 ++++++++++++++++++ synapse/storage/event_push_actions.py | 28 +++++ synapse/storage/receipts.py | 25 +++++ 4 files changed, 155 insertions(+) create mode 100644 synapse/rest/client/v2_alpha/notifications.py diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 8b223e032..c729dee47 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import ( account_data, report_event, openid, + notifications, ) from synapse.http.server import JsonResource @@ -90,3 +91,4 @@ class ClientRestResource(JsonResource): account_data.register_servlets(hs, client_resource) report_event.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource) + notifications.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py new file mode 100644 index 000000000..505e99839 --- /dev/null +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import ( + RestServlet, parse_string, parse_integer +) +from synapse.events.utils import ( + serialize_event, format_event_for_client_v2_without_room_id, +) + +from ._base import client_v2_patterns + +import logging + +logger = logging.getLogger(__name__) + + +class NotificationsServlet(RestServlet): + PATTERNS = client_v2_patterns("/notifications$", releases=()) + + def __init__(self, hs): + super(NotificationsServlet, self).__init__() + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + from_token = parse_string(request, "from", required=False) + limit = parse_integer(request, "limit", default=50) + + limit = min(limit, 500) + + push_actions = yield self.store.get_push_actions_for_user( + user_id, from_token, limit + ) + + receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( + user_id, 'm.read' + ) + + notif_event_ids = [pa["event_id"] for pa in push_actions] + notif_events = yield self.store.get_events(notif_event_ids) + + returned_push_actions = [] + + next_token = None + + for pa in push_actions: + returned_pa = { + "room_id": pa["room_id"], + "profile_tag": pa["profile_tag"], + "actions": pa["actions"], + "event": serialize_event( + notif_events[pa["event_id"]], + self.clock.time_msec(), + event_format=format_event_for_client_v2_without_room_id, + ), + } + + if pa["room_id"] not in receipts_by_room: + returned_pa["read"] = False + else: + receipt = receipts_by_room[pa["room_id"]] + + returned_pa["read"] = ( + pa["topological_ordering"] > receipt["topological_ordering"] + or ( + pa["topological_ordering"] == receipt["topological_ordering"] + and pa["stream_ordering"] > receipt["stream_ordering"] + ) + ) + returned_push_actions.append(returned_pa) + next_token = pa["stream_ordering"] + + defer.returnValue((200, { + "notifications": returned_push_actions, + "next_token": next_token, + })) + + +def register_servlets(hs, http_server): + NotificationsServlet(hs).register(http_server) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 4dae51a17..a9cb042b5 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -191,6 +191,34 @@ class EventPushActionsStore(SQLBaseStore): } for row in after_read_receipt + no_read_receipt ]) + @defer.inlineCallbacks + def get_push_actions_for_user(self, user_id, before=None, limit=50): + def f(txn): + before_clause = "" + if before: + before_clause = "AND stream_ordering < ?" + args = [user_id, before, limit] + else: + args = [user_id, limit] + sql = ( + "SELECT event_id, room_id, stream_ordering, topological_ordering," + " actions, profile_tag" + " FROM event_push_actions" + " WHERE user_id = ? %s" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + % (before_clause,) + ) + txn.execute(sql, args) + return self.cursor_to_dict(txn) + + push_actions = yield self.runInteraction( + "get_push_actions_for_user", f + ) + for pa in push_actions: + pa["actions"] = json.loads(pa["actions"]) + defer.returnValue(push_actions) + @defer.inlineCallbacks def get_time_of_last_push_action_before(self, stream_ordering): def f(txn): diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index f1774f0e4..d147a6060 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -74,6 +74,31 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue({row["room_id"]: row["event_id"] for row in rows}) + @defer.inlineCallbacks + def get_receipts_for_user_with_orderings(self, user_id, receipt_type): + def f(txn): + sql = ( + "SELECT rl.room_id, rl.event_id," + " e.topological_ordering, e.stream_ordering" + " FROM receipts_linearized rl," + " events e" + " WHERE rl.room_id = e.room_id" + " AND rl.event_id = e.event_id" + " AND user_id = ?" + ) + txn.execute(sql, (user_id,)) + return txn.fetchall() + rows = yield self.runInteraction( + "get_receipts_for_user_with_orderings", f + ) + defer.returnValue({ + row[0]: { + "event_id": row[1], + "topological_ordering": row[2], + "stream_ordering": row[3], + } for row in rows + }) + @defer.inlineCallbacks def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): """Get receipts for multiple rooms for sending to clients. From b791a530da1d89e36297cb626950cb42a7ea9226 Mon Sep 17 00:00:00 2001 From: David Baker Date: Mon, 23 May 2016 18:48:02 +0100 Subject: [PATCH 002/116] Actually make the 'read' flag correct --- synapse/rest/client/v2_alpha/notifications.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 505e99839..960025696 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -81,10 +81,9 @@ class NotificationsServlet(RestServlet): receipt = receipts_by_room[pa["room_id"]] returned_pa["read"] = ( - pa["topological_ordering"] > receipt["topological_ordering"] - or ( - pa["topological_ordering"] == receipt["topological_ordering"] - and pa["stream_ordering"] > receipt["stream_ordering"] + receipt["topological_ordering"] >= pa["topological_ordering"] or ( + receipt["topological_ordering"] == pa["topological_ordering"] and + receipt["stream_ordering"] >= pa["stream_ordering"] ) ) returned_push_actions.append(returned_pa) From 37b7e846200f00a36c6084d426ab73ee5d0e0218 Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 24 May 2016 11:33:32 +0100 Subject: [PATCH 003/116] Include the ts the notif was received at --- synapse/rest/client/v2_alpha/notifications.py | 1 + synapse/storage/event_push_actions.py | 12 +++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 960025696..4d84230e6 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -68,6 +68,7 @@ class NotificationsServlet(RestServlet): "room_id": pa["room_id"], "profile_tag": pa["profile_tag"], "actions": pa["actions"], + "ts": pa["received_ts"], "event": serialize_event( notif_events[pa["event_id"]], self.clock.time_msec(), diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index a9cb042b5..5123072c4 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -201,11 +201,13 @@ class EventPushActionsStore(SQLBaseStore): else: args = [user_id, limit] sql = ( - "SELECT event_id, room_id, stream_ordering, topological_ordering," - " actions, profile_tag" - " FROM event_push_actions" - " WHERE user_id = ? %s" - " ORDER BY stream_ordering DESC" + "SELECT epa.event_id, epa.room_id," + " epa.stream_ordering, epa.topological_ordering," + " epa.actions, epa.profile_tag, e.received_ts" + " FROM event_push_actions epa, events e" + " WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id" + " AND epa.user_id = ? %s" + " ORDER BY epa.stream_ordering DESC" " LIMIT ?" % (before_clause,) ) From 6fe6a6f0299c97086a552eda75570eaa66ff2598 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Aug 2016 16:34:07 +0100 Subject: [PATCH 004/116] Fix login with m.login.token login with token (as used by CAS auth) was broken by 067596d, such that it always returned a 401. --- synapse/api/auth.py | 45 ++++++++++++++++++++----------- synapse/handlers/auth.py | 17 +++--------- synapse/server.pyi | 4 +++ tests/handlers/test_auth.py | 53 ++++++++++++++++++++++++++++++++++--- 4 files changed, 87 insertions(+), 32 deletions(-) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 59db76deb..0db26fcfd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -675,27 +675,18 @@ class Auth(object): try: macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) - user_prefix = "user_id = " - user = None - user_id = None - guest = False - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - user_id = caveat.caveat_id[len(user_prefix):] - user = UserID.from_string(user_id) - elif caveat.caveat_id == "guest = true": - guest = True + user_id = self.get_user_id_from_macaroon(macaroon) + user = UserID.from_string(user_id) self.validate_macaroon( macaroon, rights, self.hs.config.expire_access_token, user_id=user_id, ) - if user is None: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN - ) + guest = False + for caveat in macaroon.caveats: + if caveat.caveat_id == "guest = true": + guest = True if guest: ret = { @@ -743,6 +734,29 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) + def get_user_id_from_macaroon(self, macaroon): + """Retrieve the user_id given by the caveats on the macaroon. + + Does *not* validate the macaroon. + + Args: + macaroon (pymacaroons.Macaroon): The macaroon to validate + + Returns: + (str) user id + + Raises: + AuthError if there is no user_id caveat in the macaroon + """ + user_prefix = "user_id = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(user_prefix): + return caveat.caveat_id[len(user_prefix):] + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): """ validate that a Macaroon is understood by and was signed by this server. @@ -754,6 +768,7 @@ class Auth(object): verify_expiry(bool): Whether to verify whether the macaroon has expired. This should really always be True, but no clients currently implement token refresh, so we can't enforce expiry yet. + user_id (str): The user_id required """ v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2e138f328..1d3641b7a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -720,10 +720,11 @@ class AuthHandler(BaseHandler): def validate_short_term_login_token_and_get_user_id(self, login_token): try: - macaroon = pymacaroons.Macaroon.deserialize(login_token) auth_api = self.hs.get_auth() - auth_api.validate_macaroon(macaroon, "login", True) - return self.get_user_from_macaroon(macaroon) + macaroon = pymacaroons.Macaroon.deserialize(login_token) + user_id = auth_api.get_user_id_from_macaroon(macaroon) + auth_api.validate_macaroon(macaroon, "login", True, user_id) + return user_id except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) @@ -736,16 +737,6 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - def get_user_from_macaroon(self, macaroon): - user_prefix = "user_id = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix):] - raise AuthError( - self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token", - errcode=Codes.UNKNOWN_TOKEN - ) - @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) diff --git a/synapse/server.pyi b/synapse/server.pyi index c0aa868c4..9570df553 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,4 @@ +import synapse.api.auth import synapse.handlers import synapse.handlers.auth import synapse.handlers.device @@ -6,6 +7,9 @@ import synapse.storage import synapse.state class HomeServer(object): + def get_auth(self) -> synapse.api.auth.Auth: + pass + def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: pass diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 21077cbe9..366cf97ae 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -14,12 +14,13 @@ # limitations under the License. import pymacaroons +from twisted.internet import defer +import synapse +import synapse.api.errors from synapse.handlers.auth import AuthHandler from tests import unittest from tests.utils import setup_test_homeserver -from twisted.internet import defer - class AuthHandlers(object): def __init__(self, hs): @@ -31,11 +32,12 @@ class AuthTestCase(unittest.TestCase): def setUp(self): self.hs = yield setup_test_homeserver(handlers=None) self.hs.handlers = AuthHandlers(self.hs) + self.auth_handler = self.hs.handlers.auth_handler def test_token_is_a_macaroon(self): self.hs.config.macaroon_secret_key = "this key is a huge secret" - token = self.hs.handlers.auth_handler.generate_access_token("some_user") + token = self.auth_handler.generate_access_token("some_user") # Check that we can parse the thing with pymacaroons macaroon = pymacaroons.Macaroon.deserialize(token) # The most basic of sanity checks @@ -46,7 +48,7 @@ class AuthTestCase(unittest.TestCase): self.hs.config.macaroon_secret_key = "this key is a massive secret" self.hs.clock.now = 5000 - token = self.hs.handlers.auth_handler.generate_access_token("a_user") + token = self.auth_handler.generate_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) def verify_gen(caveat): @@ -67,3 +69,46 @@ class AuthTestCase(unittest.TestCase): v.satisfy_general(verify_type) v.satisfy_general(verify_expiry) v.verify(macaroon, self.hs.config.macaroon_secret_key) + + def test_short_term_login_token_gives_user_id(self): + self.hs.clock.now = 1000 + + token = self.auth_handler.generate_short_term_login_token( + "a_user", 5000 + ) + + self.assertEqual( + "a_user", + self.auth_handler.validate_short_term_login_token_and_get_user_id( + token + ) + ) + + # when we advance the clock, the token should be rejected + self.hs.clock.now = 6000 + with self.assertRaises(synapse.api.errors.AuthError): + self.auth_handler.validate_short_term_login_token_and_get_user_id( + token + ) + + def test_short_term_login_token_cannot_replace_user_id(self): + token = self.auth_handler.generate_short_term_login_token( + "a_user", 5000 + ) + macaroon = pymacaroons.Macaroon.deserialize(token) + + self.assertEqual( + "a_user", + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) + ) + + # add another "user_id" caveat, which might allow us to override the + # user_id. + macaroon.add_first_party_caveat("user_id = b_user") + + with self.assertRaises(synapse.api.errors.AuthError): + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) From 0682ca04b3ac0a3e148633d020b3248dbe98f13d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Aug 2016 17:01:30 +0100 Subject: [PATCH 005/116] Fix CAS login Attempting to log in with CAS was giving a 500 error. --- synapse/rest/client/v1/login.py | 1 + 1 file changed, 1 insertion(+) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 92fcae674..d8c76a346 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -427,6 +427,7 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_server_url = hs.config.cas_server_url self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_GET(self, request): From 65666fedd5f60ec65fd86d9bbdff40fa67469025 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Aug 2016 17:17:25 +0100 Subject: [PATCH 006/116] Clean up CAS login code Remove some apparently unused code. Clean up parse_cas_response, mostly to catch the exception if the CAS response isn't valid XML. --- synapse/rest/client/v1/login.py | 158 +++++++------------------------- 1 file changed, 33 insertions(+), 125 deletions(-) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 92fcae674..fef7910c4 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -54,10 +54,6 @@ class LoginRestServlet(ClientV1RestServlet): self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm self.cas_enabled = hs.config.cas_enabled - self.cas_server_url = hs.config.cas_server_url - self.cas_required_attributes = hs.config.cas_required_attributes - self.servername = hs.config.server_name - self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() self.device_handler = self.hs.get_device_handler() @@ -110,17 +106,6 @@ class LoginRestServlet(ClientV1RestServlet): LoginRestServlet.JWT_TYPE): result = yield self.do_jwt_login(login_submission) defer.returnValue(result) - # TODO Delete this after all CAS clients switch to token login instead - elif self.cas_enabled and (login_submission["type"] == - LoginRestServlet.CAS_TYPE): - uri = "%s/proxyValidate" % (self.cas_server_url,) - args = { - "ticket": login_submission["ticket"], - "service": login_submission["service"] - } - body = yield self.http_client.get_raw(uri, args) - result = yield self.do_cas_login(body) - defer.returnValue(result) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) defer.returnValue(result) @@ -191,51 +176,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - @defer.inlineCallbacks - def do_cas_login(self, cas_response_body): - user, attributes = self.parse_cas_response(cas_response_body) - - for required_attribute, required_value in self.cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.auth_handler - registered_user_id = yield auth_handler.check_user_exists(user_id) - if registered_user_id: - access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id( - registered_user_id - ) - ) - result = { - "user_id": registered_user_id, # may have changed - "access_token": access_token, - "refresh_token": refresh_token, - "home_server": self.hs.hostname, - } - - else: - user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } - - defer.returnValue((200, result)) - @defer.inlineCallbacks def do_jwt_login(self, login_submission): token = login_submission.get("token", None) @@ -293,33 +233,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) - def _register_device(self, user_id, login_submission): """Register a device for a user. @@ -384,18 +297,6 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue((200, {"status": "not_authenticated"})) -# TODO Delete this after all CAS clients switch to token login instead -class CasRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/login/cas", releases=()) - - def __init__(self, hs): - super(CasRestServlet, self).__init__(hs) - self.cas_server_url = hs.config.cas_server_url - - def on_GET(self, request): - return (200, {"serverUrl": self.cas_server_url}) - - class CasRedirectServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) @@ -479,30 +380,39 @@ class CasTicketServlet(ClientV1RestServlet): return urlparse.urlunparse(url_parts) def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) + user = None + attributes = None + try: + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise Exception("root of CAS response is not serviceResponse") + success = (root[0].tag.endswith("authenticationSuccess")) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. + tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + if user is None: + raise Exception("CAS response does not contain user") + if attributes is None: + raise Exception("CAS response does not contain attributes") + except Exception: + logger.error("Error parsing CAS response", exc_info=1) + raise LoginError(401, "Invalid CAS response", + errcode=Codes.UNAUTHORIZED) + if not success: + raise LoginError(401, "Unsuccessful CAS response", + errcode=Codes.UNAUTHORIZED) + return user, attributes def register_servlets(hs, http_server): @@ -512,5 +422,3 @@ def register_servlets(hs, http_server): if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) - CasRestServlet(hs).register(http_server) - # TODO PasswordResetRestServlet(hs).register(http_server) From a8bcc7274d42e27c9d2966e908396cda91c379f4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Aug 2016 17:20:38 +0100 Subject: [PATCH 007/116] PEP8 --- tests/handlers/test_auth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 366cf97ae..4a8cd19ac 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -22,6 +22,7 @@ from synapse.handlers.auth import AuthHandler from tests import unittest from tests.utils import setup_test_homeserver + class AuthHandlers(object): def __init__(self, hs): self.auth_handler = AuthHandler(hs) From 79ebfbe7c62400a2b63d67fb65b1abce29d8bf38 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 9 Aug 2016 16:29:28 +0100 Subject: [PATCH 008/116] /login: Respond with a 403 when we get an invalid m.login.token --- synapse/handlers/auth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1d3641b7a..82998a81c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -719,14 +719,14 @@ class AuthHandler(BaseHandler): return macaroon.serialize() def validate_short_term_login_token_and_get_user_id(self, login_token): + auth_api = self.hs.get_auth() try: - auth_api = self.hs.get_auth() macaroon = pymacaroons.Macaroon.deserialize(login_token) user_id = auth_api.get_user_id_from_macaroon(macaroon) auth_api.validate_macaroon(macaroon, "login", True, user_id) return user_id - except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): - raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) + except Exception: + raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( From fa1ce4d8adb307f9b325b4e8c125f78a673c4987 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 10:44:37 +0100 Subject: [PATCH 009/116] Don't print stack traces when failing to get remote keys --- synapse/crypto/keyring.py | 28 ++++++++++++---------- synapse/rest/key/v2/remote_key_resource.py | 4 +++- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 5012c10ee..7cd11cfae 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -61,6 +61,10 @@ Attributes: """ +class KeyLookupError(ValueError): + pass + + class Keyring(object): def __init__(self, hs): self.store = hs.get_datastore() @@ -363,7 +367,7 @@ class Keyring(object): ) except Exception as e: logger.info( - "Unable to getting key %r for %r directly: %s %s", + "Unable to get key %r for %r directly: %s %s", key_ids, server_name, type(e).__name__, str(e.message), ) @@ -425,7 +429,7 @@ class Keyring(object): for response in responses: if (u"signatures" not in response or perspective_name not in response[u"signatures"]): - raise ValueError( + raise KeyLookupError( "Key response not signed by perspective server" " %r" % (perspective_name,) ) @@ -448,7 +452,7 @@ class Keyring(object): list(response[u"signatures"][perspective_name]), list(perspective_keys) ) - raise ValueError( + raise KeyLookupError( "Response not signed with a known key for perspective" " server %r" % (perspective_name,) ) @@ -491,10 +495,10 @@ class Keyring(object): if (u"signatures" not in response or server_name not in response[u"signatures"]): - raise ValueError("Key response not signed by remote server") + raise KeyLookupError("Key response not signed by remote server") if "tls_fingerprints" not in response: - raise ValueError("Key response missing TLS fingerprints") + raise KeyLookupError("Key response missing TLS fingerprints") certificate_bytes = crypto.dump_certificate( crypto.FILETYPE_ASN1, tls_certificate @@ -508,7 +512,7 @@ class Keyring(object): response_sha256_fingerprints.add(fingerprint[u"sha256"]) if sha256_fingerprint_b64 not in response_sha256_fingerprints: - raise ValueError("TLS certificate not allowed by fingerprints") + raise KeyLookupError("TLS certificate not allowed by fingerprints") response_keys = yield self.process_v2_response( from_server=server_name, @@ -560,14 +564,14 @@ class Keyring(object): server_name = response_json["server_name"] if only_from_server: if server_name != from_server: - raise ValueError( + raise KeyLookupError( "Expected a response for server %r not %r" % ( from_server, server_name ) ) for key_id in response_json["signatures"].get(server_name, {}): if key_id not in response_json["verify_keys"]: - raise ValueError( + raise KeyLookupError( "Key response must include verification keys for all" " signatures" ) @@ -635,15 +639,15 @@ class Keyring(object): if ("signatures" not in response or server_name not in response["signatures"]): - raise ValueError("Key response not signed by remote server") + raise KeyLookupError("Key response not signed by remote server") if "tls_certificate" not in response: - raise ValueError("Key response missing TLS certificate") + raise KeyLookupError("Key response missing TLS certificate") tls_certificate_b64 = response["tls_certificate"] if encode_base64(x509_certificate_bytes) != tls_certificate_b64: - raise ValueError("TLS certificate doesn't match") + raise KeyLookupError("TLS certificate doesn't match") # Cache the result in the datastore. @@ -659,7 +663,7 @@ class Keyring(object): for key_id in response["signatures"][server_name]: if key_id not in response["verify_keys"]: - raise ValueError( + raise KeyLookupError( "Key response must include verification keys for all" " signatures" ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 7209d5a37..9fe201365 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,6 +15,7 @@ from synapse.http.server import request_handler, respond_with_json_bytes from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.api.errors import SynapseError, Codes +from synapse.crypto.keyring import KeyLookupError from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -210,9 +211,10 @@ class RemoteKey(Resource): yield self.keyring.get_server_verify_key_v2_direct( server_name, key_ids ) + except KeyLookupError as e: + logger.info("Failed to fetch key: %s", e) except: logger.exception("Failed to get key for %r", server_name) - pass yield self.query_keys( request, query, query_remote_on_cache_miss=False ) From 3bc9629be5510a1f90af97568eb871f0ae22e5f8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 10:56:38 +0100 Subject: [PATCH 010/116] Measure federation send transaction resources --- synapse/federation/transaction_queue.py | 12 +++++++----- synapse/util/metrics.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 5787f854d..bbeec61f7 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -26,6 +26,7 @@ from synapse.util.logcontext import PreserveLoggingContext from synapse.util.retryutils import ( get_retry_limiter, NotRetryingDestination, ) +from synapse.util.metrics import measure_func import synapse.metrics import logging @@ -51,7 +52,7 @@ class TransactionQueue(object): self.transport_layer = transport_layer - self._clock = hs.get_clock() + self.clock = hs.get_clock() # Is a mapping from destinations -> deferreds. Used to keep track # of which destinations have transactions in flight and when they are @@ -82,7 +83,7 @@ class TransactionQueue(object): self.pending_failures_by_dest = {} # HACK to get unique tx id - self._next_txn_id = int(self._clock.time_msec()) + self._next_txn_id = int(self.clock.time_msec()) def can_send_to(self, destination): """Can we send messages to the given server? @@ -197,6 +198,7 @@ class TransactionQueue(object): yield deferred + @measure_func("attempt_new_transaction") @defer.inlineCallbacks @log_function def _attempt_new_transaction(self, destination): @@ -246,7 +248,7 @@ class TransactionQueue(object): limiter = yield get_retry_limiter( destination, - self._clock, + self.clock, self.store, ) @@ -262,7 +264,7 @@ class TransactionQueue(object): logger.debug("TX [%s] Persisting transaction...", destination) transaction = Transaction.create_new( - origin_server_ts=int(self._clock.time_msec()), + origin_server_ts=int(self.clock.time_msec()), transaction_id=txn_id, origin=self.server_name, destination=destination, @@ -293,7 +295,7 @@ class TransactionQueue(object): # keys work def json_data_cb(): data = transaction.get_dict() - now = int(self._clock.time_msec()) + now = int(self.clock.time_msec()) if "pdus" in data: for p in data["pdus"]: if "age_ts" in p: diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 0b944d3e6..4d7fee886 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer from synapse.util.logcontext import LoggingContext import synapse.metrics +from functools import wraps import logging @@ -47,6 +49,18 @@ block_db_txn_duration = metrics.register_distribution( ) +def measure_func(name): + def wrapper(func): + @wraps(func) + @defer.inlineCallbacks + def measured_func(self, *args, **kwargs): + with Measure(self.clock, name): + r = yield func(self, *args, **kwargs) + defer.returnValue(r) + return measured_func + return wrapper + + class Measure(object): __slots__ = [ "clock", "name", "start_context", "start", "new_context", "ru_utime", From f91df1f761b1e9e4da184560b0e7d9557129d064 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 11:31:46 +0100 Subject: [PATCH 011/116] Store if we fail to fetch an event from a destination --- synapse/federation/federation_client.py | 37 ++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index da95c2ad6..baa672c4a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -51,10 +51,34 @@ sent_edus_counter = metrics.register_counter("sent_edus") sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) +PDU_RETRY_TIME_MS = 1 * 60 * 1000 + + class FederationClient(FederationBase): def __init__(self, hs): super(FederationClient, self).__init__(hs) + self.pdu_destination_tried = {} + self._clock.looping_call( + self._clear_tried_cache, 60 * 1000, + ) + + def _clear_tried_cache(self): + """Clear pdu_destination_tried cache""" + now = self._clock.time_msec() + + old_dict = self.pdu_destination_tried + self.pdu_destination_tried = {} + + for event_id, destination_dict in old_dict.items(): + destination_dict = { + dest: time + for dest, time in destination_dict.items() + if time + PDU_RETRY_TIME_MS > now + } + if destination_dict: + self.pdu_destination_tried[event_id] = destination_dict + def start_get_pdu_cache(self): self._get_pdu_cache = ExpiringCache( cache_name="get_pdu_cache", @@ -240,8 +264,15 @@ class FederationClient(FederationBase): if ev: defer.returnValue(ev) + pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) + pdu = None for destination in destinations: + now = self._clock.time_msec() + last_attempt = pdu_attempts.get(destination, 0) + if last_attempt + PDU_RETRY_TIME_MS > now: + continue + try: limiter = yield get_retry_limiter( destination, @@ -276,9 +307,11 @@ class FederationClient(FederationBase): ) continue except CodeMessageException as e: - if 400 <= e.code < 500: + if 400 <= e.code < 500 and e.code != 404: raise + pdu_attempts[destination] = now + logger.info( "Failed to get PDU %s from %s because %s", event_id, destination, e, @@ -288,6 +321,8 @@ class FederationClient(FederationBase): logger.info(e.message) continue except Exception as e: + pdu_attempts[destination] = now + logger.info( "Failed to get PDU %s from %s because %s", event_id, destination, e, From 2510db3e760320b8eeffd9b9a0a0a193ce49f5ba Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Wed, 10 Aug 2016 12:57:30 +0100 Subject: [PATCH 012/116] Don't change status_msg on /sync --- synapse/handlers/presence.py | 9 ++++++--- synapse/rest/client/v2_alpha/sync.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6b70fa381..2293b5fdf 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -672,7 +672,7 @@ class PresenceHandler(object): ]) @defer.inlineCallbacks - def set_state(self, target_user, state): + def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ status_msg = state.get("status_msg", None) @@ -689,10 +689,13 @@ class PresenceHandler(object): prev_state = yield self.current_state_for_user(user_id) new_fields = { - "state": presence, - "status_msg": status_msg if presence != PresenceState.OFFLINE else None + "state": presence } + if not ignore_status_msg: + msg = status_msg if presence != PresenceState.OFFLINE else None + new_fields["status_msg"] = msg + if presence == PresenceState.ONLINE: new_fields["last_active_ts"] = self.clock.time_msec() diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 43d8e0bf3..b11acdbea 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet): affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state(user, {"presence": set_presence}) + yield self.presence_handler.set_state(user, {"presence": set_presence}, True) context = yield self.presence_handler.user_syncing( user.to_string(), affect_presence=affect_presence, From 11fdfaf03b67810f3d289241f772d3177e3c6b7e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 11:55:15 +0100 Subject: [PATCH 013/116] Only resign our own events --- synapse/handlers/federation.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 618cb5362..55d11122b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1093,16 +1093,17 @@ class FederationHandler(BaseHandler): ) if event: - # FIXME: This is a temporary work around where we occasionally - # return events slightly differently than when they were - # originally signed - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + if self.hs.is_mine_id(event.event_id): + # FIXME: This is a temporary work around where we occasionally + # return events slightly differently than when they were + # originally signed + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) if do_auth: in_room = yield self.auth.check_host_in_room( From 7f41bcbeec64da874c10a70f8a0b8e3f9f0626d8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 13:22:20 +0100 Subject: [PATCH 014/116] Correctly auth /event/ requests --- synapse/handlers/federation.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 55d11122b..2f8959db9 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -249,7 +249,7 @@ class FederationHandler(BaseHandler): if ev.type != EventTypes.Member: continue try: - domain = UserID.from_string(ev.state_key).domain + domain = get_domain_from_id(ev.state_key) except: continue @@ -1106,13 +1106,14 @@ class FederationHandler(BaseHandler): ) if do_auth: - in_room = yield self.auth.check_host_in_room( - event.room_id, - origin + events = yield self._filter_events_for_server( + origin, event.room_id, [event] ) - if not in_room: + if not events: raise AuthError(403, "Host not in room.") + event = events[0] + defer.returnValue(event) else: defer.returnValue(None) From ea8c4094dbaa9cec30c543a03f451d2555d1d23d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 13:26:13 +0100 Subject: [PATCH 015/116] Also pull out rejected events --- synapse/federation/federation_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index baa672c4a..56115a87d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -441,7 +441,7 @@ class FederationClient(FederationBase): events and the second is a list of event ids that we failed to fetch. """ if return_local: - seen_events = yield self.store.get_events(event_ids) + seen_events = yield self.store.get_events(event_ids, allow_rejected=True) signed_events = seen_events.values() else: seen_events = yield self.store.have_events(event_ids) From 739ea29d1ecbdf414db1f5062c8a2aeaa519f4ff Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 13:32:23 +0100 Subject: [PATCH 016/116] Also check if server is in the room --- synapse/handlers/federation.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2f8959db9..ff6bb475b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1106,11 +1106,16 @@ class FederationHandler(BaseHandler): ) if do_auth: + in_room = yield self.auth.check_host_in_room( + event.room_id, + origin + ) + if not in_room: + raise AuthError(403, "Host not in room.") + events = yield self._filter_events_for_server( origin, event.room_id, [event] ) - if not events: - raise AuthError(403, "Host not in room.") event = events[0] From 487bc49bf8dcaadd6abd9cee1ef762f1bf0d35a7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 13:39:12 +0100 Subject: [PATCH 017/116] Don't stop on 4xx series errors --- synapse/federation/federation_client.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 56115a87d..9ba315171 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -300,23 +300,13 @@ class FederationClient(FederationBase): break + pdu_attempts[destination] = now + except SynapseError as e: logger.info( "Failed to get PDU %s from %s because %s", event_id, destination, e, ) - continue - except CodeMessageException as e: - if 400 <= e.code < 500 and e.code != 404: - raise - - pdu_attempts[destination] = now - - logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, - ) - continue except NotRetryingDestination as e: logger.info(e.message) continue From ca8abfbf306ac1ecfe6927e99f0ce1f9eb5c9971 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 14:21:10 +0100 Subject: [PATCH 018/116] Clean up TransactionQueue --- synapse/federation/transaction_queue.py | 381 ++++++++++-------------- synapse/http/matrixfederationclient.py | 4 +- synapse/util/logcontext.py | 1 - synapse/util/metrics.py | 3 +- 4 files changed, 165 insertions(+), 224 deletions(-) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index bbeec61f7..e3212b482 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -21,8 +21,7 @@ from .units import Transaction from synapse.api.errors import HttpResponseException from synapse.util.async import run_on_reactor -from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import preserve_context_over_fn from synapse.util.retryutils import ( get_retry_limiter, NotRetryingDestination, ) @@ -120,267 +119,213 @@ class TransactionQueue(object): if not destinations: return - deferreds = [] - for destination in destinations: - deferred = defer.Deferred() self.pending_pdus_by_dest.setdefault(destination, []).append( - (pdu, deferred, order) + (pdu, order) ) - def chain(failure): - if not deferred.called: - deferred.errback(failure) + preserve_context_over_fn( + self._attempt_new_transaction, destination + ) - def log_failure(f): - logger.warn("Failed to send pdu to %s: %s", destination, f.value) - - deferred.addErrback(log_failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(chain) - - deferreds.append(deferred) - - # NO inlineCallbacks def enqueue_edu(self, edu): destination = edu.destination if not self.can_send_to(destination): return - deferred = defer.Deferred() - self.pending_edus_by_dest.setdefault(destination, []).append( - (edu, deferred) + self.pending_edus_by_dest.setdefault(destination, []).append(edu) + + preserve_context_over_fn( + self._attempt_new_transaction, destination ) - def chain(failure): - if not deferred.called: - deferred.errback(failure) - - def log_failure(f): - logger.warn("Failed to send edu to %s: %s", destination, f.value) - - deferred.addErrback(log_failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(chain) - - return deferred - - @defer.inlineCallbacks def enqueue_failure(self, failure, destination): if destination == self.server_name or destination == "localhost": return - deferred = defer.Deferred() - if not self.can_send_to(destination): return self.pending_failures_by_dest.setdefault( destination, [] - ).append( - (failure, deferred) + ).append(failure) + + preserve_context_over_fn( + self._attempt_new_transaction, destination ) - def chain(f): - if not deferred.called: - deferred.errback(f) - - def log_failure(f): - logger.warn("Failed to send failure to %s: %s", destination, f.value) - - deferred.addErrback(log_failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(chain) - - yield deferred - - @measure_func("attempt_new_transaction") @defer.inlineCallbacks - @log_function def _attempt_new_transaction(self, destination): yield run_on_reactor() + while True: + # list of (pending_pdu, deferred, order) + if destination in self.pending_transactions: + # XXX: pending_transactions can get stuck on by a never-ending + # request at which point pending_pdus_by_dest just keeps growing. + # we need application-layer timeouts of some flavour of these + # requests + logger.debug( + "TX [%s] Transaction already in progress", + destination + ) + return - # list of (pending_pdu, deferred, order) - if destination in self.pending_transactions: - # XXX: pending_transactions can get stuck on by a never-ending - # request at which point pending_pdus_by_dest just keeps growing. - # we need application-layer timeouts of some flavour of these - # requests - logger.debug( - "TX [%s] Transaction already in progress", - destination + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_failures = self.pending_failures_by_dest.pop(destination, []) + + if pending_pdus: + logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) + + if not pending_pdus and not pending_edus and not pending_failures: + logger.debug("TX [%s] Nothing to send", destination) + return + + yield self._send_new_transaction( + destination, pending_pdus, pending_edus, pending_failures ) - return - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_failures = self.pending_failures_by_dest.pop(destination, []) - - if pending_pdus: - logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) - - if not pending_pdus and not pending_edus and not pending_failures: - logger.debug("TX [%s] Nothing to send", destination) - return - - try: - self.pending_transactions[destination] = 1 - - logger.debug("TX [%s] _attempt_new_transaction", destination) + @measure_func("_send_new_transaction") + @defer.inlineCallbacks + def _send_new_transaction(self, destination, pending_pdus, pending_edus, + pending_failures): # Sort based on the order field - pending_pdus.sort(key=lambda t: t[2]) - + pending_pdus.sort(key=lambda t: t[1]) pdus = [x[0] for x in pending_pdus] - edus = [x[0] for x in pending_edus] - failures = [x[0].get_dict() for x in pending_failures] - deferreds = [ - x[1] - for x in pending_pdus + pending_edus + pending_failures - ] + edus = pending_edus + failures = [x.get_dict() for x in pending_failures] - txn_id = str(self._next_txn_id) + try: + self.pending_transactions[destination] = 1 - limiter = yield get_retry_limiter( - destination, - self.clock, - self.store, - ) + logger.debug("TX [%s] _attempt_new_transaction", destination) - logger.debug( - "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", - destination, txn_id, - len(pending_pdus), - len(pending_edus), - len(pending_failures) - ) + txn_id = str(self._next_txn_id) - logger.debug("TX [%s] Persisting transaction...", destination) - - transaction = Transaction.create_new( - origin_server_ts=int(self.clock.time_msec()), - transaction_id=txn_id, - origin=self.server_name, - destination=destination, - pdus=pdus, - edus=edus, - pdu_failures=failures, - ) - - self._next_txn_id += 1 - - yield self.transaction_actions.prepare_to_send(transaction) - - logger.debug("TX [%s] Persisted transaction", destination) - logger.info( - "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d, failures: %d)", - destination, txn_id, - transaction.transaction_id, - len(pending_pdus), - len(pending_edus), - len(pending_failures), - ) - - with limiter: - # Actually send the transaction - - # FIXME (erikj): This is a bit of a hack to make the Pdu age - # keys work - def json_data_cb(): - data = transaction.get_dict() - now = int(self.clock.time_msec()) - if "pdus" in data: - for p in data["pdus"]: - if "age_ts" in p: - unsigned = p.setdefault("unsigned", {}) - unsigned["age"] = now - int(p["age_ts"]) - del p["age_ts"] - return data - - try: - response = yield self.transport_layer.send_transaction( - transaction, json_data_cb - ) - code = 200 - - if response: - for e_id, r in response.get("pdus", {}).items(): - if "error" in r: - logger.warn( - "Transaction returned error for %s: %s", - e_id, r, - ) - except HttpResponseException as e: - code = e.code - response = e.response - - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code + limiter = yield get_retry_limiter( + destination, + self.clock, + self.store, ) - logger.debug("TX [%s] Sent transaction", destination) - logger.debug("TX [%s] Marking as delivered...", destination) + logger.debug( + "TX [%s] {%s} Attempting new transaction" + " (pdus: %d, edus: %d, failures: %d)", + destination, txn_id, + len(pending_pdus), + len(pending_edus), + len(pending_failures) + ) - yield self.transaction_actions.delivered( - transaction, code, response - ) + logger.debug("TX [%s] Persisting transaction...", destination) - logger.debug("TX [%s] Marked as delivered", destination) + transaction = Transaction.create_new( + origin_server_ts=int(self.clock.time_msec()), + transaction_id=txn_id, + origin=self.server_name, + destination=destination, + pdus=pdus, + edus=edus, + pdu_failures=failures, + ) - logger.debug("TX [%s] Yielding to callbacks...", destination) + self._next_txn_id += 1 - for deferred in deferreds: - if code == 200: - deferred.callback(None) - else: - deferred.errback(RuntimeError("Got status %d" % code)) + yield self.transaction_actions.prepare_to_send(transaction) - # Ensures we don't continue until all callbacks on that - # deferred have fired - try: - yield deferred - except: - pass + logger.debug("TX [%s] Persisted transaction", destination) + logger.info( + "TX [%s] {%s} Sending transaction [%s]," + " (PDUs: %d, EDUs: %d, failures: %d)", + destination, txn_id, + transaction.transaction_id, + len(pending_pdus), + len(pending_edus), + len(pending_failures), + ) - logger.debug("TX [%s] Yielded to callbacks", destination) - except NotRetryingDestination: - logger.info( - "TX [%s] not ready for retry yet - " - "dropping transaction for now", - destination, - ) - except RuntimeError as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. - logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) - except Exception as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. - logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) + with limiter: + # Actually send the transaction - for deferred in deferreds: - if not deferred.called: - deferred.errback(e) + # FIXME (erikj): This is a bit of a hack to make the Pdu age + # keys work + def json_data_cb(): + data = transaction.get_dict() + now = int(self.clock.time_msec()) + if "pdus" in data: + for p in data["pdus"]: + if "age_ts" in p: + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] + return data - finally: - # We want to be *very* sure we delete this after we stop processing - self.pending_transactions.pop(destination, None) + try: + response = yield self.transport_layer.send_transaction( + transaction, json_data_cb + ) + code = 200 - # Check to see if there is anything else to send. - self._attempt_new_transaction(destination) + if response: + for e_id, r in response.get("pdus", {}).items(): + if "error" in r: + logger.warn( + "Transaction returned error for %s: %s", + e_id, r, + ) + except HttpResponseException as e: + code = e.code + response = e.response + + logger.info( + "TX [%s] {%s} got %d response", + destination, txn_id, code + ) + + logger.debug("TX [%s] Sent transaction", destination) + logger.debug("TX [%s] Marking as delivered...", destination) + + yield self.transaction_actions.delivered( + transaction, code, response + ) + + logger.debug("TX [%s] Marked as delivered", destination) + + if code != 200: + for p in pdus: + logger.info("Failed to send event %s to %s", p.event_id, destination) + except NotRetryingDestination: + logger.info( + "TX [%s] not ready for retry yet - " + "dropping transaction for now", + destination, + ) + except RuntimeError as e: + # We capture this here as there as nothing actually listens + # for this finishing functions deferred. + logger.warn( + "TX [%s] Problem in _attempt_transaction: %s", + destination, + e, + ) + + for p in pdus: + logger.info("Failed to send event %s to %s", p.event_id, destination) + except Exception as e: + # We capture this here as there as nothing actually listens + # for this finishing functions deferred. + logger.warn( + "TX [%s] Problem in _attempt_transaction: %s", + destination, + e, + ) + + for p in pdus: + logger.info("Failed to send event %s to %s", p.event_id, destination) + + finally: + # We want to be *very* sure we delete this after we stop processing + self.pending_transactions.pop(destination, None) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c3589534f..f93093dd8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -155,9 +155,7 @@ class MatrixFederationHttpClient(object): time_out=timeout / 1000. if timeout else 60, ) - response = yield preserve_context_over_fn( - send_request, - ) + response = yield preserve_context_over_fn(send_request) log_result = "%d %s" % (response.code, response.phrase,) break diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 5316259d1..7a87045f8 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -317,7 +317,6 @@ def preserve_fn(f): def g(*args, **kwargs): with PreserveLoggingContext(current): return f(*args, **kwargs) - return g diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4d7fee886..76f301f54 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -78,7 +78,6 @@ class Measure(object): self.start = self.clock.time_msec() self.start_context = LoggingContext.current_context() if not self.start_context: - logger.warn("Entered Measure without log context: %s", self.name) self.start_context = LoggingContext("Measure") self.start_context.__enter__() self.created_context = True @@ -99,7 +98,7 @@ class Measure(object): if context != self.start_context: logger.warn( "Context has unexpectedly changed from '%s' to '%s'. (%r)", - context, self.start_context, self.name + self.start_context, context, self.name ) return From c315922b5f21957ddabd4245d67b7177386250aa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 16:34:10 +0100 Subject: [PATCH 019/116] PEP8 --- synapse/federation/transaction_queue.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index e3212b482..cb2ef0210 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -296,7 +296,9 @@ class TransactionQueue(object): if code != 200: for p in pdus: - logger.info("Failed to send event %s to %s", p.event_id, destination) + logger.info( + "Failed to send event %s to %s", p.event_id, destination + ) except NotRetryingDestination: logger.info( "TX [%s] not ready for retry yet - " From 1d40373c9d68bd948a9829283dc16903d462ae62 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 11 Aug 2016 10:24:41 +0100 Subject: [PATCH 020/116] Include prev_content in redacted state events --- synapse/events/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index aab18d7f7..0e9fd902a 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -88,6 +88,8 @@ def prune_event(event): if "age_ts" in event.unsigned: allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] + if "replaces_state" in event.unsigned: + allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"] return type(event)( allowed_fields, From 5b5148b7ec2a2bdfe5c3045ade80b29ee3183abd Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Thu, 11 Aug 2016 11:48:30 +0100 Subject: [PATCH 021/116] Synced up synchrotron set_state with PresenceHandler set_state --- synapse/app/synchrotron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 215ccfd52..48bc97636 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -119,7 +119,7 @@ class SynchrotronPresence(object): reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) - def set_state(self, user, state): + def set_state(self, user, state, ignore_status_msg=False): # TODO Hows this supposed to work? pass From 448ac6cf0d357de33fb653f76e48f260ed6f66ae Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Aug 2016 09:32:19 +0100 Subject: [PATCH 022/116] Only process one local membership event per room at a time --- synapse/handlers/room_member.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8cec8fc4e..4709112a0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -141,7 +141,7 @@ class RoomMemberHandler(BaseHandler): third_party_signed=None, ratelimit=True, ): - key = (target, room_id,) + key = (room_id,) with (yield self.member_linearizer.queue(key)): result = yield self._update_membership( From 866a5320de439ab2019251aa8f8697c74aeeef8c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Aug 2016 10:03:19 +0100 Subject: [PATCH 023/116] Dont invoke get_handlers fromClientV1RestServlet hs.get_handlers() can not be invoked from split out processes. Moving the invocations down a level means that we can slowly split out individual servlets. --- synapse/rest/client/v1/admin.py | 8 +++++ synapse/rest/client/v1/base.py | 1 - synapse/rest/client/v1/directory.py | 5 +++ synapse/rest/client/v1/events.py | 4 +++ synapse/rest/client/v1/initial_sync.py | 4 +++ synapse/rest/client/v1/login.py | 3 ++ synapse/rest/client/v1/profile.py | 12 +++++++ synapse/rest/client/v1/register.py | 2 ++ synapse/rest/client/v1/room.py | 48 ++++++++++++++++++++++++++ 9 files changed, 86 insertions(+), 1 deletion(-) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index b0cb31a44..af21661d7 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -28,6 +28,10 @@ logger = logging.getLogger(__name__) class WhoisRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/whois/(?P[^/]*)") + def __init__(self, hs): + super(WhoisRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) @@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): "/admin/purge_history/(?P[^/]*)/(?P[^/]*)" ) + def __init__(self, hs): + super(PurgeHistoryRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_POST(self, request, room_id, event_id): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 96b49b01f..c2a844786 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet): hs (synapse.server.HomeServer): """ self.hs = hs - self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() self.auth = hs.get_v1auth() self.txns = HttpTransactionStore() diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 8ac09419d..09d083159 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -36,6 +36,10 @@ def register_servlets(hs, http_server): class ClientDirectoryServer(ClientV1RestServlet): PATTERNS = client_path_patterns("/directory/room/(?P[^/]*)$") + def __init__(self, hs): + super(ClientDirectoryServer, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) @@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet): def __init__(self, hs): super(ClientDirectoryListServer, self).__init__(hs) self.store = hs.get_datastore() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, room_id): diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 498bb9e18..998b115bb 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 + def __init__(self, hs): + super(EventStreamRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req( diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 36c352056..113a49e53 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns class InitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/initialSync$") + def __init__(self, hs): + super(InitialSyncRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index b31e27f7b..6c0eec8fb 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -56,6 +56,7 @@ class LoginRestServlet(ClientV1RestServlet): self.cas_enabled = hs.config.cas_enabled self.auth_handler = self.hs.get_auth_handler() self.device_handler = self.hs.get_device_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): flows = [] @@ -260,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet): def __init__(self, hs): super(SAML2RestServlet, self).__init__(hs) self.sp_config = hs.config.saml2_config_path + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_POST(self, request): @@ -329,6 +331,7 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 65c4e2ebe..355e82474 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request class ProfileDisplaynameRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P[^/]*)/displayname") + def __init__(self, hs): + super(ProfileDisplaynameRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P[^/]*)/avatar_url") + def __init__(self, hs): + super(ProfileAvatarURLRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P[^/]*)") + def __init__(self, hs): + super(ProfileRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 2383b9df8..71d58c8e8 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet): self.sessions = {} self.enable_registration = hs.config.enable_registration self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet): super(CreateUserRestServlet, self).__init__(hs) self.store = hs.get_datastore() self.direct_user_creation_max_duration = hs.config.user_creation_max_duration + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_POST(self, request): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 866a1e912..89c389511 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -35,6 +35,10 @@ logger = logging.getLogger(__name__) class RoomCreateRestServlet(ClientV1RestServlet): # No PATTERN; we have custom dispatch rules here + def __init__(self, hs): + super(RoomCreateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) @@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomStateEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /room/$roomid/state/$eventtype no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" @@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomSendEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = ("/rooms/(?P[^/]*)/send/(?P[^/]*)") @@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ClientV1RestServlet): + def __init__(self, hs): + super(JoinRoomAliasServlet, self).__init__(hs) + self.handlers = hs.get_handlers() def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -296,6 +311,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/members$") + def __init__(self, hs): + super(RoomMemberListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) @@ -322,6 +341,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/messages$") + def __init__(self, hs): + super(RoomMessageListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -351,6 +374,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet): class RoomStateRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/state$") + def __init__(self, hs): + super(RoomStateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -368,6 +395,10 @@ class RoomStateRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/initialSync$") + def __init__(self, hs): + super(RoomInitialSyncRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -388,6 +419,7 @@ class RoomEventContext(ClientV1RestServlet): def __init__(self, hs): super(RoomEventContext, self).__init__(hs) self.clock = hs.get_clock() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -424,6 +456,10 @@ class RoomEventContext(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomForgetRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = ("/rooms/(?P[^/]*)/forget") register_txn_path(self, PATTERNS, http_server) @@ -462,6 +498,10 @@ class RoomForgetRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomMembershipRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/[invite|join|leave] PATTERNS = ("/rooms/(?P[^/]*)/" @@ -542,6 +582,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomRedactEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = ("/rooms/(?P[^/]*)/redact/(?P[^/]*)") register_txn_path(self, PATTERNS, http_server) @@ -624,6 +668,10 @@ class SearchRestServlet(ClientV1RestServlet): "/search$" ) + def __init__(self, hs): + super(SearchRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) From 4e1cebd56f9688d49a51929264c095356005f9a3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Aug 2016 15:31:44 +0100 Subject: [PATCH 024/116] Make synchrotron accept /events --- synapse/app/synchrotron.py | 36 ++++++++++++++++++++++++++++++-- synapse/handlers/__init__.py | 3 --- synapse/handlers/presence.py | 27 +++++++++++++++++------- synapse/rest/client/v1/events.py | 9 ++++---- synapse/server.py | 9 ++++++++ 5 files changed, 66 insertions(+), 18 deletions(-) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 48bc97636..3dca1c37a 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite from synapse.http.server import JsonResource from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v1 import events from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore @@ -89,17 +90,23 @@ class SynchrotronSlavedStore( get_presence_list_accepted = PresenceStore.__dict__[ "get_presence_list_accepted" ] + get_presence_list_observers_accepted = PresenceStore.__dict__[ + "get_presence_list_observers_accepted" + ] + UPDATE_SYNCING_USERS_MS = 10 * 1000 class SynchrotronPresence(object): def __init__(self, hs): + self.is_mine_id = hs.is_mine_id self.http_client = hs.get_simple_http_client() self.store = hs.get_datastore() self.user_to_num_current_syncs = {} self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.clock = hs.get_clock() + self.notifier = hs.get_notifier() active_presence = self.store.take_presence_startup_info() self.user_to_current_state = { @@ -124,6 +131,8 @@ class SynchrotronPresence(object): pass get_states = PresenceHandler.get_states.__func__ + get_state = PresenceHandler.get_state.__func__ + _get_interested_parties = PresenceHandler._get_interested_parties.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__ @defer.inlineCallbacks @@ -194,19 +203,39 @@ class SynchrotronPresence(object): self._need_to_send_sync = False yield self._send_syncing_users_now() + @defer.inlineCallbacks + def notify_from_replication(self, states, stream_id): + parties = yield self._get_interested_parties( + states, calculate_remote_hosts=False + ) + room_ids_to_states, users_to_states, _ = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=users_to_states.keys() + ) + + @defer.inlineCallbacks def process_replication(self, result): stream = result.get("presence", {"rows": []}) + states = [] for row in stream["rows"]: ( position, user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active ) = row - self.user_to_current_state[user_id] = UserPresenceState( + state = UserPresenceState( user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active ) + self.user_to_current_state[user_id] = state + states.append(state) + + if states and "position" in stream: + stream_id = int(stream["position"]) + yield self.notify_from_replication(states, stream_id) class SynchrotronTyping(object): @@ -266,10 +295,12 @@ class SynchrotronServer(HomeServer): elif name == "client": resource = JsonResource(self, canonical_json=False) sync.register_servlets(self, resource) + events.register_servlets(self, resource) resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, }) root_resource = create_resource_tree(resources, Resource()) @@ -315,6 +346,7 @@ class SynchrotronServer(HomeServer): def expire_broken_caches(): store.who_forgot_in_room.invalidate_all() store.get_presence_list_accepted.invalidate_all() + store.get_presence_list_observers_accepted.invalidate_all() def notify_from_stream( result, stream_name, stream_key, room=None, user=None @@ -392,7 +424,7 @@ class SynchrotronServer(HomeServer): ) yield store.process_replication(result) typing_handler.process_replication(result) - presence_handler.process_replication(result) + yield presence_handler.process_replication(result) notify(result) except: logger.exception("Error replicating from %r", replication_url) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 1a50a2ec9..63d05f253 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -19,7 +19,6 @@ from .room import ( ) from .room_member import RoomMemberHandler from .message import MessageHandler -from .events import EventStreamHandler, EventHandler from .federation import FederationHandler from .profile import ProfileHandler from .directory import DirectoryHandler @@ -53,8 +52,6 @@ class Handlers(object): self.message_handler = MessageHandler(hs) self.room_creation_handler = RoomCreationHandler(hs) self.room_member_handler = RoomMemberHandler(hs) - self.event_stream_handler = EventStreamHandler(hs) - self.event_handler = EventHandler(hs) self.federation_handler = FederationHandler(hs) self.profile_handler = ProfileHandler(hs) self.directory_handler = DirectoryHandler(hs) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2293b5fdf..6a1fe76c8 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -503,7 +503,7 @@ class PresenceHandler(object): defer.returnValue(states) @defer.inlineCallbacks - def _get_interested_parties(self, states): + def _get_interested_parties(self, states, calculate_remote_hosts=True): """Given a list of states return which entities (rooms, users, servers) are interested in the given states. @@ -526,14 +526,15 @@ class PresenceHandler(object): users_to_states.setdefault(state.user_id, []).append(state) hosts_to_states = {} - for room_id, states in room_ids_to_states.items(): - local_states = filter(lambda s: self.is_mine_id(s.user_id), states) - if not local_states: - continue + if calculate_remote_hosts: + for room_id, states in room_ids_to_states.items(): + local_states = filter(lambda s: self.is_mine_id(s.user_id), states) + if not local_states: + continue - hosts = yield self.store.get_joined_hosts_for_room(room_id) - for host in hosts: - hosts_to_states.setdefault(host, []).extend(local_states) + hosts = yield self.store.get_joined_hosts_for_room(room_id) + for host in hosts: + hosts_to_states.setdefault(host, []).extend(local_states) for user_id, states in users_to_states.items(): local_states = filter(lambda s: self.is_mine_id(s.user_id), states) @@ -565,6 +566,16 @@ class PresenceHandler(object): self._push_to_remotes(hosts_to_states) + @defer.inlineCallbacks + def notify_for_states(self, state, stream_id): + parties = yield self._get_interested_parties([state]) + room_ids_to_states, users_to_states, hosts_to_states = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states.keys()] + ) + def _push_to_remotes(self, hosts_to_states): """Sends state updates to remote servers. diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 998b115bb..701b6f549 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet): def __init__(self, hs): super(EventStreamRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.event_stream_handler = hs.get_event_stream_handler() @defer.inlineCallbacks def on_GET(self, request): @@ -50,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet): if "room_id" in request.args: room_id = request.args["room_id"][0] - handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if "timeout" in request.args: @@ -61,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args - chunk = yield handler.get_stream( + chunk = yield self.event_stream_handler.get_stream( requester.user.to_string(), pagin_config, timeout=timeout, @@ -84,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(EventRestServlet, self).__init__(hs) self.clock = hs.get_clock() + self.event_handler = hs.get_event_handler() @defer.inlineCallbacks def on_GET(self, request, event_id): requester = yield self.auth.get_user_by_req(request) - handler = self.handlers.event_handler - event = yield handler.get_event(requester.user, event_id) + event = yield self.event_handler.get_event(requester.user, event_id) time_now = self.clock.time_msec() if event: diff --git a/synapse/server.py b/synapse/server.py index 6bb498830..af3246504 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler from synapse.handlers.room import RoomListHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler +from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier @@ -94,6 +95,8 @@ class HomeServer(object): 'auth_handler', 'device_handler', 'e2e_keys_handler', + 'event_handler', + 'event_stream_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -214,6 +217,12 @@ class HomeServer(object): def build_application_service_handler(self): return ApplicationServicesHandler(self) + def build_event_handler(self): + return EventHandler(self) + + def build_event_stream_handler(self): + return EventStreamHandler(self) + def build_event_sources(self): return EventSources(self) From e380538b594d4fa55e98818df0061efbdf3cb35b Mon Sep 17 00:00:00 2001 From: Daniel Ehlers Date: Sun, 14 Aug 2016 02:30:54 +0200 Subject: [PATCH 025/116] Fix AttributeError when bind_dn is not defined. In case one does not define bind_dn in ldap configuration, filter attribute is not declared. Since auth code only uses ldap_filter attribute when according LDAP mode is selected, it is safe to only declare the attribute in that case. Signed-off-by: Daniel Ehlers --- synapse/handlers/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 82998a81c..c9307bd44 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -70,11 +70,11 @@ class AuthHandler(BaseHandler): self.ldap_uri = hs.config.ldap_uri self.ldap_start_tls = hs.config.ldap_start_tls self.ldap_base = hs.config.ldap_base - self.ldap_filter = hs.config.ldap_filter self.ldap_attributes = hs.config.ldap_attributes if self.ldap_mode == LDAPMode.SEARCH: self.ldap_bind_dn = hs.config.ldap_bind_dn self.ldap_bind_password = hs.config.ldap_bind_password + self.ldap_filter = hs.config.ldap_filter self.hs = hs # FIXME better possibility to access registrationHandler later? self.device_handler = hs.get_device_handler() From dfaf0fee317d54131f96e6f7b43ecbcab8b0a687 Mon Sep 17 00:00:00 2001 From: Daniel Ehlers Date: Sun, 14 Aug 2016 13:32:24 +0200 Subject: [PATCH 026/116] Log the value which is observed in the first place. The name 'result' is of bool type and has no len property, resulting in a TypeError. Futhermore in the flow control conn.response is observed and hence should be reported. Signed-off-by: Daniel Ehlers --- synapse/handlers/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c9307bd44..a582d6334 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -660,7 +660,7 @@ class AuthHandler(BaseHandler): else: logger.warn( "ldap registration failed: unexpected (%d!=1) amount of results", - len(result) + len(conn.response) ) defer.returnValue(False) From 8a57cc3123d36606fa2cf014b541e39c68cff72c Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Sun, 14 Aug 2016 11:50:22 -0700 Subject: [PATCH 027/116] Add missing database corruption recovery case Signed-off-by: Benjamin Saunders --- synapse/storage/events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index d2feee8db..ad026b5e0 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -600,7 +600,8 @@ class EventsStore(SQLBaseStore): "rejections", "redactions", "room_memberships", - "state_events" + "state_events", + "topics" ): txn.executemany( "DELETE FROM %s WHERE event_id = ?" % (table,), From 99bbd90b0da01690bd1193898138c64008ed71ad Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Aug 2016 09:45:44 +0100 Subject: [PATCH 028/116] Always run txn.after_callbacks --- synapse/storage/_base.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 0117fdc63..e516b6de3 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -305,13 +305,14 @@ class SQLBaseStore(object): func, *args, **kwargs ) - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) - - for after_callback, after_args in after_callbacks: - after_callback(*after_args) + try: + with PreserveLoggingContext(): + result = yield self._db_pool.runWithConnection( + inner_func, *args, **kwargs + ) + finally: + for after_callback, after_args in after_callbacks: + after_callback(*after_args) defer.returnValue(result) @defer.inlineCallbacks From 4d70d1f80ea688304abdcbbf3ee01f6ab932abc7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Aug 2016 10:21:25 +0100 Subject: [PATCH 029/116] Add some invalidations to a cache_stream --- synapse/storage/__init__.py | 3 ++ synapse/storage/_base.py | 18 ++++++++ synapse/storage/directory.py | 37 +++++++++------- synapse/storage/prepare_database.py | 2 +- synapse/storage/presence.py | 32 +++++++++----- synapse/storage/roommember.py | 12 ++--- .../storage/schema/delta/34/cache_stream.py | 44 +++++++++++++++++++ 7 files changed, 117 insertions(+), 31 deletions(-) create mode 100644 synapse/storage/schema/delta/34/cache_stream.py diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 73fb334dd..a0c029a2f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -122,6 +122,9 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")], ) + self._cache_id_gen = StreamIdGenerator( + db_conn, "cache_stream", "stream_id", + ) events_max = self._stream_id_gen.get_current_token() event_cache_prefill, min_event_val = self._get_cache_dict( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e516b6de3..02d9098dd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -861,6 +861,24 @@ class SQLBaseStore(object): return cache, min_val + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + txn.call_after(cache_func.invalidate, keys) + + ctx = self._cache_id_gen.get_next() + stream_id = ctx.__enter__() + txn.call_after(ctx.__exit__, None, None, None) + + self._simple_insert_txn( + txn, + table="cache_stream", + values={ + "stream_id": stream_id, + "cache_func": cache_func.__name__, + "keys": list(keys), + "invalidation_ts": self.clock.time_msec(), + } + ) + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index ef231a04d..9caaf81f2 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore): Returns: Deferred """ - try: - yield self._simple_insert( + def alias_txn(txn): + self._simple_insert_txn( + txn, "room_aliases", { "room_alias": room_alias.to_string(), "room_id": room_id, "creator": creator, }, - desc="create_room_alias_association", + ) + + self._simple_insert_many_txn( + txn, + table="room_alias_servers", + values=[{ + "room_alias": room_alias.to_string(), + "server": server, + } for server in servers], + ) + + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (room_id,) + ) + + try: + ret = yield self.runInteraction( + "create_room_alias_association", alias_txn ) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() ) - - for server in servers: - # TODO(erikj): Fix this to bulk insert - yield self._simple_insert( - "room_alias_servers", - { - "room_alias": room_alias.to_string(), - "server": server, - }, - desc="create_room_alias_association", - ) - self.get_aliases_for_room.invalidate((room_id,)) + defer.returnValue(ret) def get_room_alias_creator(self, room_alias): return self._simple_select_one_onecol( diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 8801669a6..b94ce7bea 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 33 +SCHEMA_VERSION = 34 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index d03f7c541..21d069664 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore): desc="add_presence_list_pending", ) - @defer.inlineCallbacks def set_presence_list_accepted(self, observer_localpart, observed_userid): - result = yield self._simple_update_one( - table="presence_list", - keyvalues={"user_id": observer_localpart, - "observed_user_id": observed_userid}, - updatevalues={"accepted": True}, - desc="set_presence_list_accepted", + def update_presence_list_txn(txn): + result = self._simple_update_one_txn( + txn, + table="presence_list", + keyvalues={ + "user_id": observer_localpart, + "observed_user_id": observed_userid + }, + updatevalues={"accepted": True}, + ) + + self._invalidate_cache_and_stream( + txn, self.get_presence_list_accepted, (observer_localpart,) + ) + self._invalidate_cache_and_stream( + txn, self.get_presence_list_observers_accepted, (observed_userid,) + ) + + return result + + return self.runInteraction( + "set_presence_list_accepted", update_presence_list_txn, ) - self.get_presence_list_accepted.invalidate((observer_localpart,)) - self.get_presence_list_observers_accepted.invalidate((observed_userid,)) - defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): if accepted: diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8bd693be7..a422ddf63 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore): user_id, membership_list=[Membership.JOIN], ) - @defer.inlineCallbacks def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore): " room_id = ?" ) txn.execute(sql, (user_id, room_id)) - yield self.runInteraction("forget_membership", f) - self.was_forgotten_at.invalidate_all() - self.who_forgot_in_room.invalidate_all() - self.did_forget.invalidate((user_id, room_id)) + + txn.call_after(self.was_forgotten_at.invalidate_all) + txn.call_after(self.did_forget.invalidate, (user_id, room_id)) + self._invalidate_cache_and_stream( + txn, self.who_forgot_in_room, (room_id,) + ) + return self.runInteraction("forget_membership", f) @cachedInlineCallbacks(num_args=2) def did_forget(self, user_id, room_id): diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py new file mode 100644 index 000000000..4c350bfb1 --- /dev/null +++ b/synapse/storage/schema/delta/34/cache_stream.py @@ -0,0 +1,44 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine + +import logging + +logger = logging.getLogger(__name__) + + +CREATE_TABLE = """ +CREATE TABLE cache_stream ( + stream_id BIGINT, + cache_func TEXT, + keys TEXT[], + invalidation_ts BIGINT +); + +CREATE INDEX cache_stream_id ON cache_stream(stream_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + for statement in get_statements(CREATE_TABLE.splitlines()): + cur.execute(statement) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass From 64e7e1185392972fd85718bfa55248b041f56b82 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Aug 2016 11:16:45 +0100 Subject: [PATCH 030/116] Implement cache replication stream --- synapse/app/synchrotron.py | 13 ------ synapse/replication/resource.py | 21 +++++++++- synapse/replication/slave/storage/_base.py | 30 +++++++++++++- synapse/storage/__init__.py | 11 +++-- synapse/storage/_base.py | 47 ++++++++++++++++------ 5 files changed, 92 insertions(+), 30 deletions(-) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 3dca1c37a..207a75d89 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -338,16 +338,10 @@ class SynchrotronServer(HomeServer): http_client = self.get_simple_http_client() store = self.get_datastore() replication_url = self.config.worker_replication_url - clock = self.get_clock() notifier = self.get_notifier() presence_handler = self.get_presence_handler() typing_handler = self.get_typing_handler() - def expire_broken_caches(): - store.who_forgot_in_room.invalidate_all() - store.get_presence_list_accepted.invalidate_all() - store.get_presence_list_observers_accepted.invalidate_all() - def notify_from_stream( result, stream_name, stream_key, room=None, user=None ): @@ -409,19 +403,12 @@ class SynchrotronServer(HomeServer): result, "typing", "typing_key", room="room_id" ) - next_expire_broken_caches_ms = 0 while True: try: args = store.stream_positions() args.update(typing_handler.stream_positions()) args["timeout"] = 30000 result = yield http_client.get_json(replication_url, args=args) - now_ms = clock.time_msec() - if now_ms > next_expire_broken_caches_ms: - expire_broken_caches() - next_expire_broken_caches_ms = ( - now_ms + store.BROKEN_CACHE_EXPIRY_MS - ) yield store.process_replication(result) typing_handler.process_replication(result) yield presence_handler.process_replication(result) diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 8c2d487ff..84993b33b 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -41,6 +41,7 @@ STREAM_NAMES = ( ("push_rules",), ("pushers",), ("state",), + ("caches",), ) @@ -70,6 +71,7 @@ class ReplicationResource(Resource): * "backfill": Old events that have been backfilled from other servers. * "push_rules": Per user changes to push rules. * "pushers": Per user changes to their pushers. + * "caches": Cache invalidations. The API takes two additional query parameters: @@ -129,6 +131,7 @@ class ReplicationResource(Resource): push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() pushers_token = self.store.get_pushers_stream_token() state_token = self.store.get_state_stream_token() + caches_token = self.store.get_cache_stream_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -140,6 +143,7 @@ class ReplicationResource(Resource): push_rules_token, pushers_token, state_token, + caches_token, )) @request_handler() @@ -188,6 +192,7 @@ class ReplicationResource(Resource): yield self.push_rules(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams) yield self.state(writer, current_token, limit, request_streams) + yield self.caches(writer, current_token, limit, request_streams) self.streams(writer, current_token, request_streams) logger.info("Replicated %d rows", writer.total) @@ -379,6 +384,20 @@ class ReplicationResource(Resource): "position", "type", "state_key", "event_id" )) + @defer.inlineCallbacks + def caches(self, writer, current_token, limit, request_streams): + current_position = current_token.caches + + caches = request_streams.get("caches") + + if caches is not None: + updated_caches = yield self.store.get_all_updated_caches( + caches, current_position, limit + ) + writer.write_header_and_rows("caches", updated_caches, ( + "position", "cache_func", "keys", "invalidation_ts" + )) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -407,7 +426,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules", "pushers", "state" + "push_rules", "pushers", "state", "caches", ))): __slots__ = [] diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 46e43ce1c..24c9946d6 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,15 +14,43 @@ # limitations under the License. from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine from twisted.internet import defer +from ._slaved_id_tracker import SlavedIdTracker + +import logging + +logger = logging.getLogger(__name__) + class BaseSlavedStore(SQLBaseStore): def __init__(self, db_conn, hs): super(BaseSlavedStore, self).__init__(hs) + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = SlavedIdTracker( + db_conn, "cache_stream", "stream_id", + ) + else: + self._cache_id_gen = None def stream_positions(self): - return {} + pos = {} + if self._cache_id_gen: + pos["caches"] = self._cache_id_gen.get_current_token() + return pos def process_replication(self, result): + stream = result.get("caches") + if stream: + for row in stream["rows"]: + ( + position, cache_func, keys, invalidation_ts, + ) = row + + try: + getattr(self, cache_func).invalidate(tuple(keys)) + except AttributeError: + logger.warn("Got unexpected cache_func: %r", cache_func) + self._cache_id_gen.advance(int(stream["position"])) return defer.succeed(None) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a0c029a2f..8af492b69 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -50,6 +50,7 @@ from .openid import OpenIdStore from .client_ips import ClientIpStore from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator +from .engines import PostgresEngine from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -122,9 +123,13 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")], ) - self._cache_id_gen = StreamIdGenerator( - db_conn, "cache_stream", "stream_id", - ) + + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = StreamIdGenerator( + db_conn, "cache_stream", "stream_id", + ) + else: + self._cache_id_gen = None events_max = self._stream_id_gen.get_current_token() event_cache_prefill, min_event_val = self._get_cache_dict( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 02d9098dd..e3edc2cde 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.descriptors import Cache from synapse.util.caches import intern_dict +from synapse.storage.engines import PostgresEngine import synapse.metrics @@ -864,21 +865,43 @@ class SQLBaseStore(object): def _invalidate_cache_and_stream(self, txn, cache_func, keys): txn.call_after(cache_func.invalidate, keys) - ctx = self._cache_id_gen.get_next() - stream_id = ctx.__enter__() - txn.call_after(ctx.__exit__, None, None, None) + if isinstance(self.database_engine, PostgresEngine): + ctx = self._cache_id_gen.get_next() + stream_id = ctx.__enter__() + txn.call_after(ctx.__exit__, None, None, None) - self._simple_insert_txn( - txn, - table="cache_stream", - values={ - "stream_id": stream_id, - "cache_func": cache_func.__name__, - "keys": list(keys), - "invalidation_ts": self.clock.time_msec(), - } + self._simple_insert_txn( + txn, + table="cache_stream", + values={ + "stream_id": stream_id, + "cache_func": cache_func.__name__, + "keys": list(keys), + "invalidation_ts": self.clock.time_msec(), + } + ) + + def get_all_updated_caches(self, last_id, current_id, limit): + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts FROM cache_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit,)) + return txn.fetchall() + return self.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn ) + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying From 0be963472bc4ecaafa42872c9fffef5b8f3426f3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Aug 2016 11:24:12 +0100 Subject: [PATCH 031/116] Use cached version of get_aliases_for_room --- synapse/replication/slave/storage/directory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 5fbe3a303..7301d885f 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore class DirectoryStore(BaseSlavedStore): get_aliases_for_room = DirectoryStore.__dict__[ "get_aliases_for_room" - ].orig + ] From 784a2d4f2c3c90804bf81d85bf23a671d204dc94 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Aug 2016 11:25:48 +0100 Subject: [PATCH 032/116] Remove broken cache stuff --- synapse/app/pusher.py | 16 ---------------- synapse/app/synchrotron.py | 5 ----- 2 files changed, 21 deletions(-) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index c8dde0fcb..8d755a4b3 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -80,11 +80,6 @@ class PusherSlaveStore( DataStore.get_profile_displayname.__func__ ) - # XXX: This is a bit broken because we don't persist forgotten rooms - # in a way that they can be streamed. This means that we don't have a - # way to invalidate the forgotten rooms cache correctly. - # For now we expire the cache every 10 minutes. - BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000 who_forgot_in_room = ( RoomMemberStore.__dict__["who_forgot_in_room"] ) @@ -168,7 +163,6 @@ class PusherServer(HomeServer): store = self.get_datastore() replication_url = self.config.worker_replication_url pusher_pool = self.get_pusherpool() - clock = self.get_clock() def stop_pusher(user_id, app_id, pushkey): key = "%s:%s" % (app_id, pushkey) @@ -220,21 +214,11 @@ class PusherServer(HomeServer): min_stream_id, max_stream_id, affected_room_ids ) - def expire_broken_caches(): - store.who_forgot_in_room.invalidate_all() - - next_expire_broken_caches_ms = 0 while True: try: args = store.stream_positions() args["timeout"] = 30000 result = yield http_client.get_json(replication_url, args=args) - now_ms = clock.time_msec() - if now_ms > next_expire_broken_caches_ms: - expire_broken_caches() - next_expire_broken_caches_ms = ( - now_ms + store.BROKEN_CACHE_EXPIRY_MS - ) yield store.process_replication(result) poke_pushers(result) except: diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 207a75d89..e3173533e 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -75,11 +75,6 @@ class SynchrotronSlavedStore( BaseSlavedStore, ClientIpStore, # After BaseSlavedStore because the constructor is different ): - # XXX: This is a bit broken because we don't persist forgotten rooms - # in a way that they can be streamed. This means that we don't have a - # way to invalidate the forgotten rooms cache correctly. - # For now we expire the cache every 10 minutes. - BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000 who_forgot_in_room = ( RoomMemberStore.__dict__["who_forgot_in_room"] ) From d9664344ecf481a73b2189f5bc65f4f784222969 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Aug 2016 11:45:57 +0100 Subject: [PATCH 033/116] Rename table. Add docs. --- synapse/replication/slave/storage/_base.py | 2 +- synapse/storage/__init__.py | 2 +- synapse/storage/_base.py | 12 ++++++++++-- synapse/storage/schema/delta/34/cache_stream.py | 6 ++++-- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 24c9946d6..d839d169a 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -29,7 +29,7 @@ class BaseSlavedStore(SQLBaseStore): super(BaseSlavedStore, self).__init__(hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( - db_conn, "cache_stream", "stream_id", + db_conn, "cache_invalidation_stream", "stream_id", ) else: self._cache_id_gen = None diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 8af492b69..7efc5bfee 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -126,7 +126,7 @@ class DataStore(RoomMemberStore, RoomStore, if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = StreamIdGenerator( - db_conn, "cache_stream", "stream_id", + db_conn, "cache_invalidation_stream", "stream_id", ) else: self._cache_id_gen = None diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e3edc2cde..c55776994 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -863,6 +863,13 @@ class SQLBaseStore(object): return cache, min_val def _invalidate_cache_and_stream(self, txn, cache_func, keys): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ txn.call_after(cache_func.invalidate, keys) if isinstance(self.database_engine, PostgresEngine): @@ -872,7 +879,7 @@ class SQLBaseStore(object): self._simple_insert_txn( txn, - table="cache_stream", + table="cache_invalidation_stream", values={ "stream_id": stream_id, "cache_func": cache_func.__name__, @@ -887,7 +894,8 @@ class SQLBaseStore(object): # send across cache invalidations as quickly as possible. Cache # invalidations are idempotent, so duplicates are fine. sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts FROM cache_stream" + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, limit,)) diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py index 4c350bfb1..3b63a1562 100644 --- a/synapse/storage/schema/delta/34/cache_stream.py +++ b/synapse/storage/schema/delta/34/cache_stream.py @@ -20,15 +20,17 @@ import logging logger = logging.getLogger(__name__) +# This stream is used to notify replication slaves that some caches have +# been invalidated that they cannot infer from the other streams. CREATE_TABLE = """ -CREATE TABLE cache_stream ( +CREATE TABLE cache_invalidation_stream ( stream_id BIGINT, cache_func TEXT, keys TEXT[], invalidation_ts BIGINT ); -CREATE INDEX cache_stream_id ON cache_stream(stream_id); +CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id); """ From 89e786bd85331b73204294e6066dc538f1c9869c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Aug 2016 13:45:26 +0100 Subject: [PATCH 034/116] Doc get_next() context manager usage --- synapse/storage/_base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index c55776994..b0923a9ca 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -873,6 +873,10 @@ class SQLBaseStore(object): txn.call_after(cache_func.invalidate, keys) if isinstance(self.database_engine, PostgresEngine): + # get_next() returns a context manager which is designed to wrap + # the transaction. However, we want to only get an ID when we want + # to use it, here, so we need to call __enter__ manually, and have + # __exit__ called after the transaction finishes. ctx = self._cache_id_gen.get_next() stream_id = ctx.__enter__() txn.call_after(ctx.__exit__, None, None, None) From dc3a00f24f301ab08750fdd8ca6ae040ba290e1e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Aug 2016 17:04:39 +0100 Subject: [PATCH 035/116] Refactor user_delete_access_tokens. Invalidate get_user_by_access_token to slaves. --- synapse/handlers/auth.py | 6 +-- synapse/push/pusherpool.py | 8 ++-- synapse/storage/registration.py | 78 +++++++++++++++------------------ 3 files changed, 43 insertions(+), 49 deletions(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a582d6334..6986930c0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -741,7 +741,7 @@ class AuthHandler(BaseHandler): def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) - except_access_token_ids = [requester.access_token_id] if requester else [] + except_access_token_id = requester.access_token_id if requester else None try: yield self.store.user_set_password_hash(user_id, password_hash) @@ -750,10 +750,10 @@ class AuthHandler(BaseHandler): raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise e yield self.store.user_delete_access_tokens( - user_id, except_access_token_ids + user_id, except_access_token_id ) yield self.hs.get_pusherpool().remove_pushers_by_user( - user_id, except_access_token_ids + user_id, except_access_token_id ) @defer.inlineCallbacks diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 5853ec36a..54c0f1b84 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -102,14 +102,14 @@ class PusherPool: yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks - def remove_pushers_by_user(self, user_id, except_token_ids=[]): + def remove_pushers_by_user(self, user_id, except_access_token_id=None): all = yield self.store.get_all_pushers() logger.info( - "Removing all pushers for user %s except access tokens ids %r", - user_id, except_token_ids + "Removing all pushers for user %s except access tokens id %r", + user_id, except_access_token_id ) for p in all: - if p['user_name'] == user_id and p['access_token'] not in except_token_ids: + if p['user_name'] == user_id and p['access_token'] != except_access_token_id: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", p['app_id'], p['pushkey'], p['user_name'] diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 7e7d32eb6..19cb3b31c 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -251,7 +251,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): self.get_user_by_id.invalidate((user_id,)) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids=[], + def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None, delete_refresh_tokens=False): """ @@ -259,7 +259,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): Args: user_id (str): ID of user the tokens belong to - except_token_ids (list[str]): list of access_tokens which should + except_token_id (str): list of access_tokens IDs which should *not* be deleted device_id (str|None): ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will @@ -269,53 +269,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): Returns: defer.Deferred: """ - def f(txn, table, except_tokens, call_after_delete): - sql = "SELECT token FROM %s WHERE user_id = ?" % table - clauses = [user_id] - + def f(txn): + keyvalues = { + "user_id": user_id, + } if device_id is not None: - sql += " AND device_id = ?" - clauses.append(device_id) + keyvalues["device_id"] = device_id - if except_tokens: - sql += " AND id NOT IN (%s)" % ( - ",".join(["?" for _ in except_tokens]), - ) - clauses += except_tokens - - txn.execute(sql, clauses) - - rows = txn.fetchall() - - n = 100 - chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] - for chunk in chunks: - if call_after_delete: - for row in chunk: - txn.call_after(call_after_delete, (row[0],)) - - txn.execute( - "DELETE FROM %s WHERE token in (%s)" % ( - table, - ",".join(["?" for _ in chunk]), - ), [r[0] for r in chunk] + if delete_refresh_tokens: + self._simple_delete_txn( + txn, + table="refresh_tokens", + keyvalues=keyvalues, ) - # delete refresh tokens first, to stop new access tokens being - # allocated while our backs are turned - if delete_refresh_tokens: - yield self.runInteraction( - "user_delete_access_tokens", f, - table="refresh_tokens", - except_tokens=[], - call_after_delete=None, + items = keyvalues.items() + where_clause = " AND ".join(k + " = ?" for k, _ in items) + values = [v for _, v in items] + if except_token_id: + where_clause += " AND id != ?" + values.append(except_token_id) + + txn.execute( + "SELECT token FROM access_tokens WHERE %s" % where_clause, + values + ) + rows = self.cursor_to_dict(txn) + + for row in rows: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (row["token"],) + ) + + txn.execute( + "DELETE FROM access_tokens WHERE %s" % where_clause, + values ) yield self.runInteraction( "user_delete_access_tokens", f, - table="access_tokens", - except_tokens=except_token_ids, - call_after_delete=self.get_user_by_access_token.invalidate, ) def delete_access_token(self, access_token): @@ -328,7 +320,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): }, ) - txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (access_token,) + ) return self.runInteraction("delete_access_token", f) From 1c7c317df1124ff262d40b40966956a952d1e03b Mon Sep 17 00:00:00 2001 From: David Baker Date: Mon, 15 Aug 2016 18:34:53 +0100 Subject: [PATCH 036/116] Move display name rule As per https://github.com/matrix-org/matrix-doc/pull/373 and comment --- synapse/push/baserules.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 024c14904..9d60f7696 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -217,6 +217,27 @@ BASE_APPEND_OVERRIDE_RULES = [ 'dont_notify' ] }, + # This was changed from underride to override so it's closer in priority + # to the content rules where the user name highlight rule lives. This + # way a room rule is lower priority than both but a custom override rule + # is higher priority than both. + { + 'rule_id': 'global/underride/.m.rule.contains_display_name', + 'conditions': [ + { + 'kind': 'contains_display_name' + } + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'default' + }, { + 'set_tweak': 'highlight' + } + ] + }, ] @@ -242,23 +263,6 @@ BASE_APPEND_UNDERRIDE_RULES = [ } ] }, - { - 'rule_id': 'global/underride/.m.rule.contains_display_name', - 'conditions': [ - { - 'kind': 'contains_display_name' - } - ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight' - } - ] - }, { 'rule_id': 'global/underride/.m.rule.room_one_to_one', 'conditions': [ From a2427981b798eb157d48b4e695a0d025357a2b01 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Aug 2016 18:10:54 +0100 Subject: [PATCH 037/116] Use cached get_user_by_access_token in slaves --- synapse/replication/slave/storage/_base.py | 2 +- synapse/replication/slave/storage/registration.py | 2 +- synapse/storage/_base.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index d839d169a..f19540d6b 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -51,6 +51,6 @@ class BaseSlavedStore(SQLBaseStore): try: getattr(self, cache_func).invalidate(tuple(keys)) except AttributeError: - logger.warn("Got unexpected cache_func: %r", cache_func) + logger.info("Got unexpected cache_func: %r", cache_func) self._cache_id_gen.advance(int(stream["position"])) return defer.succeed(None) diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index 307833f9e..38b78b97f 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -25,6 +25,6 @@ class SlavedRegistrationStore(BaseSlavedStore): # TODO: use the cached version and invalidate deleted tokens get_user_by_access_token = RegistrationStore.__dict__[ "get_user_by_access_token" - ].orig + ] _query_for_auth = DataStore._query_for_auth.__func__ diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b0923a9ca..0a2e78fd8 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -880,6 +880,7 @@ class SQLBaseStore(object): ctx = self._cache_id_gen.get_next() stream_id = ctx.__enter__() txn.call_after(ctx.__exit__, None, None, None) + txn.call_after(self.hs.get_notifier().on_new_replication_data) self._simple_insert_txn( txn, From 2ee1bd124c2d9ac7f11d0b32c9e40a842db14fe9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 16 Aug 2016 11:34:36 +0100 Subject: [PATCH 038/116] Limit number of extremeties in backfill request This works around a bug where if we make a backfill request with too many extremeties it causes the request URI to be too long. --- synapse/handlers/federation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ff6bb475b..328f8f484 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -274,7 +274,7 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def backfill(self, dest, room_id, limit, extremities=[]): + def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. This may return @@ -284,9 +284,6 @@ class FederationHandler(BaseHandler): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - if not extremities: - extremities = yield self.store.get_oldest_events_in_room(room_id) - events = yield self.replication_layer.backfill( dest, room_id, @@ -455,6 +452,10 @@ class FederationHandler(BaseHandler): ) max_depth = sorted_extremeties_tuple[0][1] + # We don't want to specify too many extremities as it causes the backfill + # request URI to be too long. + extremities = dict(sorted_extremeties_tuple[:5]) + if current_depth > max_depth: logger.debug( "Not backfilling as we don't need to. %d < %d", From 48b5829aea007bce620ad3db7bfd1dc25cbf9837 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 16 Aug 2016 14:53:18 +0100 Subject: [PATCH 039/116] Fix up preview URL API. Add tests. This includes: - Splitting out methods of a class into stand alone functions, to make them easier to test. - Adding unit tests to split out functions, testing HTML -> preview. - Handle the fact that elements in lxml may have tail text. --- synapse/rest/media/v1/preview_url_resource.py | 353 ++++++++++-------- tests/test_preview.py | 80 +++- 2 files changed, 275 insertions(+), 158 deletions(-) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index bdd0e60c5..beafc8e73 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -163,7 +163,7 @@ class PreviewUrlResource(Resource): logger.debug("got media_info of '%s'" % media_info) - if self._is_media(media_info['media_type']): + if _is_media(media_info['media_type']): dims = yield self.media_repo._generate_local_thumbnails( media_info['filesystem_id'], media_info ) @@ -184,11 +184,9 @@ class PreviewUrlResource(Resource): logger.warn("Couldn't get dims for %s" % url) # define our OG response for this media - elif self._is_html(media_info['media_type']): + elif _is_html(media_info['media_type']): # TODO: somehow stop a big HTML tree from exploding synapse's RAM - from lxml import etree - file = open(media_info['filename']) body = file.read() file.close() @@ -199,17 +197,35 @@ class PreviewUrlResource(Resource): match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I) encoding = match.group(1) if match else "utf-8" - try: - parser = etree.HTMLParser(recover=True, encoding=encoding) - tree = etree.fromstring(body, parser) - og = yield self._calc_og(tree, media_info, requester) - except UnicodeDecodeError: - # blindly try decoding the body as utf-8, which seems to fix - # the charset mismatches on https://google.com - parser = etree.HTMLParser(recover=True, encoding=encoding) - tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser) - og = yield self._calc_og(tree, media_info, requester) + og = decode_and_calc_og(body, media_info['uri'], encoding) + # pre-cache the image for posterity + # FIXME: it might be cleaner to use the same flow as the main /preview_url + # request itself and benefit from the same caching etc. But for now we + # just rely on the caching on the master request to speed things up. + if 'og:image' in og and og['og:image']: + image_info = yield self._download_url( + _rebase_url(og['og:image'], media_info['uri']), requester.user + ) + + if _is_media(image_info['media_type']): + # TODO: make sure we don't choke on white-on-transparent images + dims = yield self.media_repo._generate_local_thumbnails( + image_info['filesystem_id'], image_info + ) + if dims: + og["og:image:width"] = dims['width'] + og["og:image:height"] = dims['height'] + else: + logger.warn("Couldn't get dims for %s" % og["og:image"]) + + og["og:image"] = "mxc://%s/%s" % ( + self.server_name, image_info['filesystem_id'] + ) + og["og:image:type"] = image_info['media_type'] + og["matrix:image:size"] = image_info['media_length'] + else: + del og["og:image"] else: logger.warn("Failed to find any OG data in %s", url) og = {} @@ -232,139 +248,6 @@ class PreviewUrlResource(Resource): respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) - @defer.inlineCallbacks - def _calc_og(self, tree, media_info, requester): - # suck our tree into lxml and define our OG response. - - # if we see any image URLs in the OG response, then spider them - # (although the client could choose to do this by asking for previews of those - # URLs to avoid DoSing the server) - - # "og:type" : "video", - # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", - # "og:site_name" : "YouTube", - # "og:video:type" : "application/x-shockwave-flash", - # "og:description" : "Fun stuff happening here", - # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", - # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", - # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", - # "og:video:width" : "1280" - # "og:video:height" : "720", - # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", - - og = {} - for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - if 'content' in tag.attrib: - og[tag.attrib['property']] = tag.attrib['content'] - - # TODO: grab article: meta tags too, e.g.: - - # "article:publisher" : "https://www.facebook.com/thethudonline" /> - # "article:author" content="https://www.facebook.com/thethudonline" /> - # "article:tag" content="baby" /> - # "article:section" content="Breaking News" /> - # "article:published_time" content="2016-03-31T19:58:24+00:00" /> - # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> - - if 'og:title' not in og: - # do some basic spidering of the HTML - title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") - og['og:title'] = title[0].text.strip() if title else None - - if 'og:image' not in og: - # TODO: extract a favicon failing all else - meta_image = tree.xpath( - "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" - ) - if meta_image: - og['og:image'] = self._rebase_url(meta_image[0], media_info['uri']) - else: - # TODO: consider inlined CSS styles as well as width & height attribs - images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") - images = sorted(images, key=lambda i: ( - -1 * float(i.attrib['width']) * float(i.attrib['height']) - )) - if not images: - images = tree.xpath("//img[@src]") - if images: - og['og:image'] = images[0].attrib['src'] - - # pre-cache the image for posterity - # FIXME: it might be cleaner to use the same flow as the main /preview_url - # request itself and benefit from the same caching etc. But for now we - # just rely on the caching on the master request to speed things up. - if 'og:image' in og and og['og:image']: - image_info = yield self._download_url( - self._rebase_url(og['og:image'], media_info['uri']), requester.user - ) - - if self._is_media(image_info['media_type']): - # TODO: make sure we don't choke on white-on-transparent images - dims = yield self.media_repo._generate_local_thumbnails( - image_info['filesystem_id'], image_info - ) - if dims: - og["og:image:width"] = dims['width'] - og["og:image:height"] = dims['height'] - else: - logger.warn("Couldn't get dims for %s" % og["og:image"]) - - og["og:image"] = "mxc://%s/%s" % ( - self.server_name, image_info['filesystem_id'] - ) - og["og:image:type"] = image_info['media_type'] - og["matrix:image:size"] = image_info['media_length'] - else: - del og["og:image"] - - if 'og:description' not in og: - meta_description = tree.xpath( - "//*/meta" - "[translate(@name, 'DESCRIPTION', 'description')='description']" - "/@content") - if meta_description: - og['og:description'] = meta_description[0] - else: - # grab any text nodes which are inside the tag... - # unless they are within an HTML5 semantic markup tag... - #
,