0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-19 01:51:55 +01:00

Review comments

This commit is contained in:
Kegan Dougal 2016-11-11 17:47:03 +00:00
parent f6c48802f5
commit 8ecaff51a1
5 changed files with 119 additions and 158 deletions

View file

@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
logger = logging.getLogger(__name__)
def get_transaction_key(request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = get_access_token_from_request(request)
return request.path + "/" + token
class HttpTransactionCache(object):
def __init__(self):
self.transactions = {
# $txn_key: ObservableDeferred<(res_code, res_json_body)>
}
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
See:
fetch_or_execute
"""
return self.fetch_or_execute(
get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
Args:
txn_key (str): A key to ensure idempotency should fetch_or_execute be
called again at a later point in time.
fn (function): A function which returns a tuple of
(response_code, response_dict)d
*args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn.
Returns:
synapse.util.async.ObservableDeferred which resolves to a tuple
of (response_code, response_dict).
"""
try:
return self.transactions[txn_key]
except KeyError:
pass # execute the function instead.
deferred = fn(*args, **kwargs)
observable = ObservableDeferred(deferred)
self.transactions[txn_key] = observable
return observable

View file

@ -18,7 +18,8 @@
from synapse.http.servlet import RestServlet
from synapse.api.urls import CLIENT_PREFIX
from .transactions import HttpTransactionCache
from synapse.rest.client.transactions import HttpTransactionCache
import re
import logging

View file

@ -22,7 +22,6 @@ from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias
from synapse.util.async import ObservableDeferred
from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer
@ -56,17 +55,11 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request
)
res = yield observable.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(self.on_POST(request))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
@defer.inlineCallbacks
def on_POST(self, request):
@ -217,19 +210,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(
self.on_POST(request, room_id, event_type, txn_id)
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id
)
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
res = yield observable.observe()
defer.returnValue(res)
# TODO: Needs unit testing for room ID + alias joins
@ -288,17 +273,11 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
)
res = yield observable.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(self.on_POST(request, room_identifier, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
# TODO: Needs unit testing
@ -542,17 +521,11 @@ class RoomForgetRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id
)
res = yield observable.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(self.on_POST(request, room_id, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
# TODO: Needs unit testing
@ -626,19 +599,11 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(
self.on_POST(request, room_id, membership_action, txn_id)
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id
)
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
res = yield observable.observe()
defer.returnValue(res)
class RoomRedactEventRestServlet(ClientV1RestServlet):
@ -672,19 +637,11 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_id, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(
self.on_POST(request, room_id, event_id, txn_id)
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id
)
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
res = yield observable.observe()
defer.returnValue(res)
class RoomTypingRestServlet(ClientV1RestServlet):

View file

@ -1,75 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
from synapse.api.auth import get_access_token_from_request
logger = logging.getLogger(__name__)
class HttpTransactionCache(object):
def __init__(self):
# { key : (txn_id, res_observ_defer) }
self.transactions = {}
def _get_response(self, key, txn_id):
try:
(last_txn_id, res_observ_defer) = self.transactions[key]
if txn_id == last_txn_id:
logger.info("get_response: Returning a response for %s", txn_id)
return res_observ_defer
except KeyError:
pass
return None
def _store_response(self, key, txn_id, res_observ_defer):
self.transactions[key] = (txn_id, res_observ_defer)
def store_client_transaction(self, request, txn_id, res_observ_defer):
"""Stores the request/Promise<response> pair of an HTTP transaction.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
res_observ_defer (Promise<tuple>): A tuple of (response code, response dict)
txn_id (str): The transaction ID for this request.
"""
self._store_response(self._get_key(request), txn_id, res_observ_defer)
def get_client_transaction(self, request, txn_id):
"""Retrieves a stored response if there was one.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
txn_id (str): The transaction ID for this request.
Returns:
Promise: Resolves to the response tuple.
Raises:
KeyError if the transaction was not found.
"""
res_observ_defer = self._get_response(self._get_key(request), txn_id)
if res_observ_defer is None:
raise KeyError("Transaction not found.")
return res_observ_defer
def _get_key(self, request):
token = get_access_token_from_request(request)
path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token

View file

@ -19,8 +19,7 @@ from twisted.internet import defer
from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.v1.transactions import HttpTransactionCache
from synapse.util.async import ObservableDeferred
from synapse.rest.client.transactions import HttpTransactionCache
from ._base import client_v2_patterns
@ -46,16 +45,10 @@ class SendToDeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, message_type, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(self._put(request, message_type, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
res = yield res_deferred.observe()
observable = self.txns.fetch_or_execute_request(
request, self._put, request, message_type, txn_id
)
res = yield observable.observe()
defer.returnValue(res)
@defer.inlineCallbacks