0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 17:03:50 +01:00

Merge pull request #3638 from matrix-org/rav/refactor_federation_client_exception_handling

Factor out exception handling in federation_client
This commit is contained in:
Richard van der Hoff 2018-08-02 17:37:46 +01:00 committed by GitHub
commit bdae8f2e68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 149 additions and 129 deletions

1
changelog.d/3638.misc Normal file
View file

@ -0,0 +1 @@
Factor out exception handling in federation_client

View file

@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
PDU_RETRY_TIME_MS = 1 * 60 * 1000 PDU_RETRY_TIME_MS = 1 * 60 * 1000
class InvalidResponseError(RuntimeError):
"""Helper for _try_destination_list: indicates that the server returned a response
we couldn't parse
"""
pass
class FederationClient(FederationBase): class FederationClient(FederationBase):
def __init__(self, hs): def __init__(self, hs):
super(FederationClient, self).__init__(hs) super(FederationClient, self).__init__(hs)
@ -458,6 +465,61 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth) defer.returnValue(signed_auth)
@defer.inlineCallbacks @defer.inlineCallbacks
def _try_destination_list(self, description, destinations, callback):
"""Try an operation on a series of servers, until it succeeds
Args:
description (unicode): description of the operation we're doing, for logging
destinations (Iterable[unicode]): list of server_names to try
callback (callable): Function to run for each server. Passed a single
argument: the server_name to try. May return a deferred.
If the callback raises a CodeMessageException with a 300/400 code,
attempts to perform the operation stop immediately and the exception is
reraised.
Otherwise, if the callback raises an Exception the error is logged and the
next server tried. Normally the stacktrace is logged but this is
suppressed if the exception is an InvalidResponseError.
Returns:
The [Deferred] result of callback, if it succeeds
Raises:
CodeMessageException if the chosen remote server returns a 300/400 code.
RuntimeError if no servers were reachable.
"""
for destination in destinations:
if destination == self.server_name:
continue
try:
res = yield callback(destination)
defer.returnValue(res)
except InvalidResponseError as e:
logger.warn(
"Failed to %s via %s: %s",
description, destination, e,
)
except CodeMessageException as e:
if not 500 <= e.code < 600:
raise
else:
logger.warn(
"Failed to %s via %s: %i %s",
description, destination, e.code, e.message,
)
except Exception:
logger.warn(
"Failed to %s via %s",
description, destination, exc_info=1,
)
raise RuntimeError("Failed to %s via any server", description)
def make_membership_event(self, destinations, room_id, user_id, membership, def make_membership_event(self, destinations, room_id, user_id, membership,
content={},): content={},):
""" """
@ -492,50 +554,35 @@ class FederationClient(FederationBase):
"make_membership_event called with membership='%s', must be one of %s" % "make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships)) (membership, ",".join(valid_memberships))
) )
for destination in destinations:
if destination == self.server_name:
continue
try: @defer.inlineCallbacks
ret = yield self.transport_layer.make_membership_event( def send_request(destination):
destination, room_id, user_id, membership ret = yield self.transport_layer.make_membership_event(
) destination, room_id, user_id, membership
)
pdu_dict = ret["event"] pdu_dict = ret["event"]
logger.debug("Got response to make_%s: %s", membership, pdu_dict) logger.debug("Got response to make_%s: %s", membership, pdu_dict)
pdu_dict["content"].update(content) pdu_dict["content"].update(content)
# The protoevent received over the JSON wire may not have all # The protoevent received over the JSON wire may not have all
# the required fields. Lets just gloss over that because # the required fields. Lets just gloss over that because
# there's some we never care about # there's some we never care about
if "prev_state" not in pdu_dict: if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = [] pdu_dict["prev_state"] = []
ev = builder.EventBuilder(pdu_dict) ev = builder.EventBuilder(pdu_dict)
defer.returnValue( defer.returnValue(
(destination, ev) (destination, ev)
) )
break
except CodeMessageException as e:
if not 500 <= e.code < 600:
raise
else:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
except Exception as e:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
raise RuntimeError("Failed to send to any server.") return self._try_destination_list(
"make_" + membership, destinations, send_request,
)
@defer.inlineCallbacks
def send_join(self, destinations, pdu): def send_join(self, destinations, pdu):
"""Sends a join event to one of a list of homeservers. """Sends a join event to one of a list of homeservers.
@ -558,87 +605,70 @@ class FederationClient(FederationBase):
Fails with a ``RuntimeError`` if no servers were reachable. Fails with a ``RuntimeError`` if no servers were reachable.
""" """
for destination in destinations: @defer.inlineCallbacks
if destination == self.server_name: def send_request(destination):
continue time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
try: logger.debug("Got content: %s", content)
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content) state = [
event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
state = [ auth_chain = [
event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, outlier=True)
for p in content.get("state", []) for p in content.get("auth_chain", [])
] ]
auth_chain = [ pdus = {
event_from_pdu_json(p, outlier=True) p.event_id: p
for p in content.get("auth_chain", []) for p in itertools.chain(state, auth_chain)
] }
pdus = { valid_pdus = yield self._check_sigs_and_hash_and_fetch(
p.event_id: p destination, list(pdus.values()),
for p in itertools.chain(state, auth_chain) outlier=True,
} )
valid_pdus = yield self._check_sigs_and_hash_and_fetch( valid_pdus_map = {
destination, list(pdus.values()), p.event_id: p
outlier=True, for p in valid_pdus
) }
valid_pdus_map = { # NB: We *need* to copy to ensure that we don't have multiple
p.event_id: p # references being passed on, as that causes... issues.
for p in valid_pdus signed_state = [
} copy.copy(valid_pdus_map[p.event_id])
for p in state
if p.event_id in valid_pdus_map
]
# NB: We *need* to copy to ensure that we don't have multiple signed_auth = [
# references being passed on, as that causes... issues. valid_pdus_map[p.event_id]
signed_state = [ for p in auth_chain
copy.copy(valid_pdus_map[p.event_id]) if p.event_id in valid_pdus_map
for p in state ]
if p.event_id in valid_pdus_map
]
signed_auth = [ # NB: We *need* to copy to ensure that we don't have multiple
valid_pdus_map[p.event_id] # references being passed on, as that causes... issues.
for p in auth_chain for s in signed_state:
if p.event_id in valid_pdus_map s.internal_metadata = copy.deepcopy(s.internal_metadata)
]
# NB: We *need* to copy to ensure that we don't have multiple auth_chain.sort(key=lambda e: e.depth)
# references being passed on, as that causes... issues.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
auth_chain.sort(key=lambda e: e.depth) defer.returnValue({
"state": signed_state,
defer.returnValue({ "auth_chain": signed_auth,
"state": signed_state, "origin": destination,
"auth_chain": signed_auth, })
"origin": destination, return self._try_destination_list("send_join", destinations, send_request)
})
except CodeMessageException as e:
if not 500 <= e.code < 600:
raise
else:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
except Exception as e:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu): def send_invite(self, destination, room_id, event_id, pdu):
@ -663,7 +693,6 @@ class FederationClient(FederationBase):
defer.returnValue(pdu) defer.returnValue(pdu)
@defer.inlineCallbacks
def send_leave(self, destinations, pdu): def send_leave(self, destinations, pdu):
"""Sends a leave event to one of a list of homeservers. """Sends a leave event to one of a list of homeservers.
@ -681,34 +710,24 @@ class FederationClient(FederationBase):
Deferred: resolves to None. Deferred: resolves to None.
Fails with a ``CodeMessageException`` if the chosen remote server Fails with a ``CodeMessageException`` if the chosen remote server
returns a non-200 code. returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable. Fails with a ``RuntimeError`` if no servers were reachable.
""" """
for destination in destinations: @defer.inlineCallbacks
if destination == self.server_name: def send_request(destination):
continue time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
try: logger.debug("Got content: %s", content)
time_now = self._clock.time_msec() defer.returnValue(None)
_, content = yield self.transport_layer.send_leave(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content) return self._try_destination_list("send_leave", destinations, send_request)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_leave via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
def get_public_rooms(self, destination, limit=None, since_token=None, def get_public_rooms(self, destination, limit=None, since_token=None,
search_filter=None, include_all_networks=False, search_filter=None, include_all_networks=False,