0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-22 20:20:02 +01:00

Convert Transaction and Edu object to attrs (#10542)

Instead of wrapping the JSON into an object, this creates concrete
instances for Transaction and Edu. This allows for improved type
hints and simplified code.
This commit is contained in:
Patrick Cloke 2021-08-06 09:39:59 -04:00 committed by GitHub
parent 60f0534b6e
commit 1de26b3467
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 74 additions and 193 deletions

1
changelog.d/10542.misc Normal file
View file

@ -0,0 +1 @@
Convert `Transaction` and `Edu` objects to attrs.

View file

@ -195,13 +195,17 @@ class FederationServer(FederationBase):
origin, room_id, versions, limit origin, room_id, versions, limit
) )
res = self._transaction_from_pdus(pdus).get_dict() res = self._transaction_dict_from_pdus(pdus)
return 200, res return 200, res
async def on_incoming_transaction( async def on_incoming_transaction(
self, origin: str, transaction_data: JsonDict self,
) -> Tuple[int, Dict[str, Any]]: origin: str,
transaction_id: str,
destination: str,
transaction_data: JsonDict,
) -> Tuple[int, JsonDict]:
# If we receive a transaction we should make sure that kick off handling # If we receive a transaction we should make sure that kick off handling
# any old events in the staging area. # any old events in the staging area.
if not self._started_handling_of_staged_events: if not self._started_handling_of_staged_events:
@ -212,8 +216,14 @@ class FederationServer(FederationBase):
# accurate as possible. # accurate as possible.
request_time = self._clock.time_msec() request_time = self._clock.time_msec()
transaction = Transaction(**transaction_data) transaction = Transaction(
transaction_id = transaction.transaction_id # type: ignore transaction_id=transaction_id,
destination=destination,
origin=origin,
origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore
pdus=transaction_data.get("pdus"), # type: ignore
edus=transaction_data.get("edus"),
)
if not transaction_id: if not transaction_id:
raise Exception("Transaction missing transaction_id") raise Exception("Transaction missing transaction_id")
@ -221,9 +231,7 @@ class FederationServer(FederationBase):
logger.debug("[%s] Got transaction", transaction_id) logger.debug("[%s] Got transaction", transaction_id)
# Reject malformed transactions early: reject if too many PDUs/EDUs # Reject malformed transactions early: reject if too many PDUs/EDUs
if len(transaction.pdus) > 50 or ( # type: ignore if len(transaction.pdus) > 50 or len(transaction.edus) > 100:
hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
):
logger.info("Transaction PDU or EDU count too large. Returning 400") logger.info("Transaction PDU or EDU count too large. Returning 400")
return 400, {} return 400, {}
@ -263,7 +271,7 @@ class FederationServer(FederationBase):
# CRITICAL SECTION: the first thing we must do (before awaiting) is # CRITICAL SECTION: the first thing we must do (before awaiting) is
# add an entry to _active_transactions. # add an entry to _active_transactions.
assert origin not in self._active_transactions assert origin not in self._active_transactions
self._active_transactions[origin] = transaction.transaction_id # type: ignore self._active_transactions[origin] = transaction.transaction_id
try: try:
result = await self._handle_incoming_transaction( result = await self._handle_incoming_transaction(
@ -291,11 +299,11 @@ class FederationServer(FederationBase):
if response: if response:
logger.debug( logger.debug(
"[%s] We've already responded to this request", "[%s] We've already responded to this request",
transaction.transaction_id, # type: ignore transaction.transaction_id,
) )
return response return response
logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore logger.debug("[%s] Transaction is new", transaction.transaction_id)
# We process PDUs and EDUs in parallel. This is important as we don't # We process PDUs and EDUs in parallel. This is important as we don't
# want to block things like to device messages from reaching clients # want to block things like to device messages from reaching clients
@ -334,7 +342,7 @@ class FederationServer(FederationBase):
report back to the sending server. report back to the sending server.
""" """
received_pdus_counter.inc(len(transaction.pdus)) # type: ignore received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
@ -342,7 +350,7 @@ class FederationServer(FederationBase):
newest_pdu_ts = 0 newest_pdu_ts = 0
for p in transaction.pdus: # type: ignore for p in transaction.pdus:
# FIXME (richardv): I don't think this works: # FIXME (richardv): I don't think this works:
# https://github.com/matrix-org/synapse/issues/8429 # https://github.com/matrix-org/synapse/issues/8429
if "unsigned" in p: if "unsigned" in p:
@ -436,10 +444,10 @@ class FederationServer(FederationBase):
return pdu_results return pdu_results
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None:
"""Process the EDUs in a received transaction.""" """Process the EDUs in a received transaction."""
async def _process_edu(edu_dict): async def _process_edu(edu_dict: JsonDict) -> None:
received_edus_counter.inc() received_edus_counter.inc()
edu = Edu( edu = Edu(
@ -452,7 +460,7 @@ class FederationServer(FederationBase):
await concurrently_execute( await concurrently_execute(
_process_edu, _process_edu,
getattr(transaction, "edus", []), transaction.edus,
TRANSACTION_CONCURRENCY_LIMIT, TRANSACTION_CONCURRENCY_LIMIT,
) )
@ -538,7 +546,7 @@ class FederationServer(FederationBase):
pdu = await self.handler.get_persisted_pdu(origin, event_id) pdu = await self.handler.get_persisted_pdu(origin, event_id)
if pdu: if pdu:
return 200, self._transaction_from_pdus([pdu]).get_dict() return 200, self._transaction_dict_from_pdus([pdu])
else: else:
return 404, "" return 404, ""
@ -879,18 +887,20 @@ class FederationServer(FederationBase):
ts_now_ms = self._clock.time_msec() ts_now_ms = self._clock.time_msec()
return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction: def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict:
"""Returns a new Transaction containing the given PDUs suitable for """Returns a new Transaction containing the given PDUs suitable for
transmission. transmission.
""" """
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
pdus = [p.get_pdu_json(time_now) for p in pdu_list] pdus = [p.get_pdu_json(time_now) for p in pdu_list]
return Transaction( return Transaction(
# Just need a dummy transaction ID and destination since it won't be used.
transaction_id="",
origin=self.server_name, origin=self.server_name,
pdus=pdus, pdus=pdus,
origin_server_ts=int(time_now), origin_server_ts=int(time_now),
destination=None, destination="",
) ).get_dict()
async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
"""Process a PDU received in a federation /send/ transaction. """Process a PDU received in a federation /send/ transaction.

View file

@ -45,7 +45,7 @@ class TransactionActions:
`None` if we have not previously responded to this transaction or a `None` if we have not previously responded to this transaction or a
2-tuple of `(int, dict)` representing the response code and response body. 2-tuple of `(int, dict)` representing the response code and response body.
""" """
transaction_id = transaction.transaction_id # type: ignore transaction_id = transaction.transaction_id
if not transaction_id: if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id") raise RuntimeError("Cannot persist a transaction with no transaction_id")
@ -56,7 +56,7 @@ class TransactionActions:
self, origin: str, transaction: Transaction, code: int, response: JsonDict self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None: ) -> None:
"""Persist how we responded to a transaction.""" """Persist how we responded to a transaction."""
transaction_id = transaction.transaction_id # type: ignore transaction_id = transaction.transaction_id
if not transaction_id: if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id") raise RuntimeError("Cannot persist a transaction with no transaction_id")

View file

@ -27,6 +27,7 @@ from synapse.logging.opentracing import (
tags, tags,
whitelisted_homeserver, whitelisted_homeserver,
) )
from synapse.types import JsonDict
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
@ -104,13 +105,13 @@ class TransactionManager:
len(edus), len(edus),
) )
transaction = Transaction.create_new( transaction = Transaction(
origin_server_ts=int(self.clock.time_msec()), origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id, transaction_id=txn_id,
origin=self._server_name, origin=self._server_name,
destination=destination, destination=destination,
pdus=pdus, pdus=[p.get_pdu_json() for p in pdus],
edus=edus, edus=[edu.get_dict() for edu in edus],
) )
self._next_txn_id += 1 self._next_txn_id += 1
@ -131,7 +132,7 @@ class TransactionManager:
# FIXME (richardv): I also believe it no longer works. We (now?) store # FIXME (richardv): I also believe it no longer works. We (now?) store
# "age_ts" in "unsigned" rather than at the top level. See # "age_ts" in "unsigned" rather than at the top level. See
# https://github.com/matrix-org/synapse/issues/8429. # https://github.com/matrix-org/synapse/issues/8429.
def json_data_cb(): def json_data_cb() -> JsonDict:
data = transaction.get_dict() data = transaction.get_dict()
now = int(self.clock.time_msec()) now = int(self.clock.time_msec())
if "pdus" in data: if "pdus" in data:

View file

@ -143,7 +143,7 @@ class TransportLayerClient:
"""Sends the given Transaction to its destination """Sends the given Transaction to its destination
Args: Args:
transaction (Transaction) transaction
Returns: Returns:
Succeeds when we get a 2xx HTTP response. The result Succeeds when we get a 2xx HTTP response. The result

View file

@ -450,21 +450,12 @@ class FederationSendServlet(BaseFederationServerServlet):
len(transaction_data.get("edus", [])), len(transaction_data.get("edus", [])),
) )
# We should ideally be getting this from the security layer.
# origin = body["origin"]
# Add some extra data to the transaction dict that isn't included
# in the request body.
transaction_data.update(
transaction_id=transaction_id, destination=self.server_name
)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
return 400, {"error": "Invalid transaction"} return 400, {"error": "Invalid transaction"}
code, response = await self.handler.on_incoming_transaction( code, response = await self.handler.on_incoming_transaction(
origin, transaction_data origin, transaction_id, self.server_name, transaction_data
) )
return code, response return code, response

View file

@ -17,18 +17,17 @@ server protocol.
""" """
import logging import logging
from typing import Optional from typing import List, Optional
import attr import attr
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.jsonobject import JsonEncodedObject
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(slots=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class Edu(JsonEncodedObject): class Edu:
"""An Edu represents a piece of data sent from one homeserver to another. """An Edu represents a piece of data sent from one homeserver to another.
In comparison to Pdus, Edus are not persisted for a long time on disk, are In comparison to Pdus, Edus are not persisted for a long time on disk, are
@ -36,10 +35,10 @@ class Edu(JsonEncodedObject):
internal ID or previous references graph. internal ID or previous references graph.
""" """
edu_type = attr.ib(type=str) edu_type: str
content = attr.ib(type=dict) content: dict
origin = attr.ib(type=str) origin: str
destination = attr.ib(type=str) destination: str
def get_dict(self) -> JsonDict: def get_dict(self) -> JsonDict:
return { return {
@ -55,14 +54,21 @@ class Edu(JsonEncodedObject):
"destination": self.destination, "destination": self.destination,
} }
def get_context(self): def get_context(self) -> str:
return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")
def strip_context(self): def strip_context(self) -> None:
getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}" getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"
class Transaction(JsonEncodedObject): def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]:
if edus is None:
return []
return edus
@attr.s(slots=True, frozen=True, auto_attribs=True)
class Transaction:
"""A transaction is a list of Pdus and Edus to be sent to a remote home """A transaction is a list of Pdus and Edus to be sent to a remote home
server with some extra metadata. server with some extra metadata.
@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject):
""" """
valid_keys = [ # Required keys.
"transaction_id", transaction_id: str
"origin", origin: str
"destination", destination: str
"origin_server_ts", origin_server_ts: int
"previous_ids", pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
"pdus", edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
"edus",
]
internal_keys = ["transaction_id", "destination"] def get_dict(self) -> JsonDict:
"""A JSON-ready dictionary of valid keys which aren't internal."""
required_keys = [ result = {
"transaction_id", "origin": self.origin,
"origin", "origin_server_ts": self.origin_server_ts,
"destination", "pdus": self.pdus,
"origin_server_ts", }
"pdus", if self.edus:
] result["edus"] = self.edus
return result
def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs):
"""If we include a list of pdus then we decode then as PDU's
automatically.
"""
# If there's no EDUs then remove the arg
if "edus" in kwargs and not kwargs["edus"]:
del kwargs["edus"]
super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs)
@staticmethod
def create_new(pdus, **kwargs):
"""Used to create a new transaction. Will auto fill out
transaction_id and origin_server_ts keys.
"""
if "origin_server_ts" not in kwargs:
raise KeyError("Require 'origin_server_ts' to construct a Transaction")
if "transaction_id" not in kwargs:
raise KeyError("Require 'transaction_id' to construct a Transaction")
kwargs["pdus"] = [p.get_pdu_json() for p in pdus]
return Transaction(**kwargs)

View file

@ -1,102 +0,0 @@
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class JsonEncodedObject:
"""A common base class for defining protocol units that are represented
as JSON.
Attributes:
unrecognized_keys (dict): A dict containing all the key/value pairs we
don't recognize.
"""
valid_keys = [] # keys we will store
"""A list of strings that represent keys we know about
and can handle. If we have values for these keys they will be
included in the `dictionary` instance variable.
"""
internal_keys = [] # keys to ignore while building dict
"""A list of strings that should *not* be encoded into JSON.
"""
required_keys = []
"""A list of strings that we require to exist. If they are not given upon
construction it raises an exception.
"""
def __init__(self, **kwargs):
"""Takes the dict of `kwargs` and loads all keys that are *valid*
(i.e., are included in the `valid_keys` list) into the dictionary`
instance variable.
Any keys that aren't recognized are added to the `unrecognized_keys`
attribute.
Args:
**kwargs: Attributes associated with this protocol unit.
"""
for required_key in self.required_keys:
if required_key not in kwargs:
raise RuntimeError("Key %s is required" % required_key)
self.unrecognized_keys = {} # Keys we were given not listed as valid
for k, v in kwargs.items():
if k in self.valid_keys or k in self.internal_keys:
self.__dict__[k] = v
else:
self.unrecognized_keys[k] = v
def get_dict(self):
"""Converts this protocol unit into a :py:class:`dict`, ready to be
encoded as JSON.
The keys it encodes are: `valid_keys` - `internal_keys`
Returns
dict
"""
d = {
k: _encode(v)
for (k, v) in self.__dict__.items()
if k in self.valid_keys and k not in self.internal_keys
}
d.update(self.unrecognized_keys)
return d
def get_internal_dict(self):
d = {
k: _encode(v, internal=True)
for (k, v) in self.__dict__.items()
if k in self.valid_keys
}
d.update(self.unrecognized_keys)
return d
def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
def _encode(obj, internal=False):
if type(obj) is list:
return [_encode(o, internal=internal) for o in obj]
if isinstance(obj, JsonEncodedObject):
if internal:
return obj.get_internal_dict()
else:
return obj.get_dict()
return obj