0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 05:53:45 +01:00

Merge pull request #3638 from matrix-org/rav/refactor_federation_client_exception_handling

Factor out exception handling in federation_client
This commit is contained in:
Richard van der Hoff 2018-08-02 17:37:46 +01:00 committed by GitHub
commit bdae8f2e68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 149 additions and 129 deletions

1
changelog.d/3638.misc Normal file
View file

@ -0,0 +1 @@
Factor out exception handling in federation_client

View file

@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
PDU_RETRY_TIME_MS = 1 * 60 * 1000
class InvalidResponseError(RuntimeError):
"""Helper for _try_destination_list: indicates that the server returned a response
we couldn't parse
"""
pass
class FederationClient(FederationBase):
def __init__(self, hs):
super(FederationClient, self).__init__(hs)
@ -458,6 +465,61 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth)
@defer.inlineCallbacks
def _try_destination_list(self, description, destinations, callback):
"""Try an operation on a series of servers, until it succeeds
Args:
description (unicode): description of the operation we're doing, for logging
destinations (Iterable[unicode]): list of server_names to try
callback (callable): Function to run for each server. Passed a single
argument: the server_name to try. May return a deferred.
If the callback raises a CodeMessageException with a 300/400 code,
attempts to perform the operation stop immediately and the exception is
reraised.
Otherwise, if the callback raises an Exception the error is logged and the
next server tried. Normally the stacktrace is logged but this is
suppressed if the exception is an InvalidResponseError.
Returns:
The [Deferred] result of callback, if it succeeds
Raises:
CodeMessageException if the chosen remote server returns a 300/400 code.
RuntimeError if no servers were reachable.
"""
for destination in destinations:
if destination == self.server_name:
continue
try:
res = yield callback(destination)
defer.returnValue(res)
except InvalidResponseError as e:
logger.warn(
"Failed to %s via %s: %s",
description, destination, e,
)
except CodeMessageException as e:
if not 500 <= e.code < 600:
raise
else:
logger.warn(
"Failed to %s via %s: %i %s",
description, destination, e.code, e.message,
)
except Exception:
logger.warn(
"Failed to %s via %s",
description, destination, exc_info=1,
)
raise RuntimeError("Failed to %s via any server", description)
def make_membership_event(self, destinations, room_id, user_id, membership,
content={},):
"""
@ -492,50 +554,35 @@ class FederationClient(FederationBase):
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
for destination in destinations:
if destination == self.server_name:
continue
try:
ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership
)
@defer.inlineCallbacks
def send_request(destination):
ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership
)
pdu_dict = ret["event"]
pdu_dict = ret["event"]
logger.debug("Got response to make_%s: %s", membership, pdu_dict)
logger.debug("Got response to make_%s: %s", membership, pdu_dict)
pdu_dict["content"].update(content)
pdu_dict["content"].update(content)
# The protoevent received over the JSON wire may not have all
# the required fields. Lets just gloss over that because
# there's some we never care about
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
# The protoevent received over the JSON wire may not have all
# the required fields. Lets just gloss over that because
# there's some we never care about
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
ev = builder.EventBuilder(pdu_dict)
ev = builder.EventBuilder(pdu_dict)
defer.returnValue(
(destination, ev)
)
break
except CodeMessageException as e:
if not 500 <= e.code < 600:
raise
else:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
except Exception as e:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
defer.returnValue(
(destination, ev)
)
raise RuntimeError("Failed to send to any server.")
return self._try_destination_list(
"make_" + membership, destinations, send_request,
)
@defer.inlineCallbacks
def send_join(self, destinations, pdu):
"""Sends a join event to one of a list of homeservers.
@ -558,87 +605,70 @@ class FederationClient(FederationBase):
Fails with a ``RuntimeError`` if no servers were reachable.
"""
for destination in destinations:
if destination == self.server_name:
continue
@defer.inlineCallbacks
def send_request(destination):
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
logger.debug("Got content: %s", content)
state = [
event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
state = [
event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
auth_chain = [
event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
pdus = {
p.event_id: p
for p in itertools.chain(state, auth_chain)
}
pdus = {
p.event_id: p
for p in itertools.chain(state, auth_chain)
}
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, list(pdus.values()),
outlier=True,
)
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, list(pdus.values()),
outlier=True,
)
valid_pdus_map = {
p.event_id: p
for p in valid_pdus
}
valid_pdus_map = {
p.event_id: p
for p in valid_pdus
}
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
signed_state = [
copy.copy(valid_pdus_map[p.event_id])
for p in state
if p.event_id in valid_pdus_map
]
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
signed_state = [
copy.copy(valid_pdus_map[p.event_id])
for p in state
if p.event_id in valid_pdus_map
]
signed_auth = [
valid_pdus_map[p.event_id]
for p in auth_chain
if p.event_id in valid_pdus_map
]
signed_auth = [
valid_pdus_map[p.event_id]
for p in auth_chain
if p.event_id in valid_pdus_map
]
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
auth_chain.sort(key=lambda e: e.depth)
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
})
except CodeMessageException as e:
if not 500 <= e.code < 600:
raise
else:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
except Exception as e:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
defer.returnValue({
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
})
return self._try_destination_list("send_join", destinations, send_request)
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
@ -663,7 +693,6 @@ class FederationClient(FederationBase):
defer.returnValue(pdu)
@defer.inlineCallbacks
def send_leave(self, destinations, pdu):
"""Sends a leave event to one of a list of homeservers.
@ -681,34 +710,24 @@ class FederationClient(FederationBase):
Deferred: resolves to None.
Fails with a ``CodeMessageException`` if the chosen remote server
returns a non-200 code.
returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable.
"""
for destination in destinations:
if destination == self.server_name:
continue
@defer.inlineCallbacks
def send_request(destination):
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
defer.returnValue(None)
logger.debug("Got content: %s", content)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_leave via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
return self._try_destination_list("send_leave", destinations, send_request)
def get_public_rooms(self, destination, limit=None, since_token=None,
search_filter=None, include_all_networks=False,