mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 17:03:50 +01:00
Merge pull request #3638 from matrix-org/rav/refactor_federation_client_exception_handling
Factor out exception handling in federation_client
This commit is contained in:
commit
bdae8f2e68
2 changed files with 149 additions and 129 deletions
1
changelog.d/3638.misc
Normal file
1
changelog.d/3638.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Factor out exception handling in federation_client
|
|
@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
|
||||||
PDU_RETRY_TIME_MS = 1 * 60 * 1000
|
PDU_RETRY_TIME_MS = 1 * 60 * 1000
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidResponseError(RuntimeError):
|
||||||
|
"""Helper for _try_destination_list: indicates that the server returned a response
|
||||||
|
we couldn't parse
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FederationClient(FederationBase):
|
class FederationClient(FederationBase):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(FederationClient, self).__init__(hs)
|
super(FederationClient, self).__init__(hs)
|
||||||
|
@ -458,6 +465,61 @@ class FederationClient(FederationBase):
|
||||||
defer.returnValue(signed_auth)
|
defer.returnValue(signed_auth)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
def _try_destination_list(self, description, destinations, callback):
|
||||||
|
"""Try an operation on a series of servers, until it succeeds
|
||||||
|
|
||||||
|
Args:
|
||||||
|
description (unicode): description of the operation we're doing, for logging
|
||||||
|
|
||||||
|
destinations (Iterable[unicode]): list of server_names to try
|
||||||
|
|
||||||
|
callback (callable): Function to run for each server. Passed a single
|
||||||
|
argument: the server_name to try. May return a deferred.
|
||||||
|
|
||||||
|
If the callback raises a CodeMessageException with a 300/400 code,
|
||||||
|
attempts to perform the operation stop immediately and the exception is
|
||||||
|
reraised.
|
||||||
|
|
||||||
|
Otherwise, if the callback raises an Exception the error is logged and the
|
||||||
|
next server tried. Normally the stacktrace is logged but this is
|
||||||
|
suppressed if the exception is an InvalidResponseError.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The [Deferred] result of callback, if it succeeds
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CodeMessageException if the chosen remote server returns a 300/400 code.
|
||||||
|
|
||||||
|
RuntimeError if no servers were reachable.
|
||||||
|
"""
|
||||||
|
for destination in destinations:
|
||||||
|
if destination == self.server_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
res = yield callback(destination)
|
||||||
|
defer.returnValue(res)
|
||||||
|
except InvalidResponseError as e:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to %s via %s: %s",
|
||||||
|
description, destination, e,
|
||||||
|
)
|
||||||
|
except CodeMessageException as e:
|
||||||
|
if not 500 <= e.code < 600:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to %s via %s: %i %s",
|
||||||
|
description, destination, e.code, e.message,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to %s via %s",
|
||||||
|
description, destination, exc_info=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError("Failed to %s via any server", description)
|
||||||
|
|
||||||
def make_membership_event(self, destinations, room_id, user_id, membership,
|
def make_membership_event(self, destinations, room_id, user_id, membership,
|
||||||
content={},):
|
content={},):
|
||||||
"""
|
"""
|
||||||
|
@ -492,50 +554,35 @@ class FederationClient(FederationBase):
|
||||||
"make_membership_event called with membership='%s', must be one of %s" %
|
"make_membership_event called with membership='%s', must be one of %s" %
|
||||||
(membership, ",".join(valid_memberships))
|
(membership, ",".join(valid_memberships))
|
||||||
)
|
)
|
||||||
for destination in destinations:
|
|
||||||
if destination == self.server_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
@defer.inlineCallbacks
|
||||||
ret = yield self.transport_layer.make_membership_event(
|
def send_request(destination):
|
||||||
destination, room_id, user_id, membership
|
ret = yield self.transport_layer.make_membership_event(
|
||||||
)
|
destination, room_id, user_id, membership
|
||||||
|
)
|
||||||
|
|
||||||
pdu_dict = ret["event"]
|
pdu_dict = ret["event"]
|
||||||
|
|
||||||
logger.debug("Got response to make_%s: %s", membership, pdu_dict)
|
logger.debug("Got response to make_%s: %s", membership, pdu_dict)
|
||||||
|
|
||||||
pdu_dict["content"].update(content)
|
pdu_dict["content"].update(content)
|
||||||
|
|
||||||
# The protoevent received over the JSON wire may not have all
|
# The protoevent received over the JSON wire may not have all
|
||||||
# the required fields. Lets just gloss over that because
|
# the required fields. Lets just gloss over that because
|
||||||
# there's some we never care about
|
# there's some we never care about
|
||||||
if "prev_state" not in pdu_dict:
|
if "prev_state" not in pdu_dict:
|
||||||
pdu_dict["prev_state"] = []
|
pdu_dict["prev_state"] = []
|
||||||
|
|
||||||
ev = builder.EventBuilder(pdu_dict)
|
ev = builder.EventBuilder(pdu_dict)
|
||||||
|
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
(destination, ev)
|
(destination, ev)
|
||||||
)
|
)
|
||||||
break
|
|
||||||
except CodeMessageException as e:
|
|
||||||
if not 500 <= e.code < 600:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
logger.warn(
|
|
||||||
"Failed to make_%s via %s: %s",
|
|
||||||
membership, destination, e.message
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warn(
|
|
||||||
"Failed to make_%s via %s: %s",
|
|
||||||
membership, destination, e.message
|
|
||||||
)
|
|
||||||
|
|
||||||
raise RuntimeError("Failed to send to any server.")
|
return self._try_destination_list(
|
||||||
|
"make_" + membership, destinations, send_request,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def send_join(self, destinations, pdu):
|
def send_join(self, destinations, pdu):
|
||||||
"""Sends a join event to one of a list of homeservers.
|
"""Sends a join event to one of a list of homeservers.
|
||||||
|
|
||||||
|
@ -558,87 +605,70 @@ class FederationClient(FederationBase):
|
||||||
Fails with a ``RuntimeError`` if no servers were reachable.
|
Fails with a ``RuntimeError`` if no servers were reachable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for destination in destinations:
|
@defer.inlineCallbacks
|
||||||
if destination == self.server_name:
|
def send_request(destination):
|
||||||
continue
|
time_now = self._clock.time_msec()
|
||||||
|
_, content = yield self.transport_layer.send_join(
|
||||||
|
destination=destination,
|
||||||
|
room_id=pdu.room_id,
|
||||||
|
event_id=pdu.event_id,
|
||||||
|
content=pdu.get_pdu_json(time_now),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
logger.debug("Got content: %s", content)
|
||||||
time_now = self._clock.time_msec()
|
|
||||||
_, content = yield self.transport_layer.send_join(
|
|
||||||
destination=destination,
|
|
||||||
room_id=pdu.room_id,
|
|
||||||
event_id=pdu.event_id,
|
|
||||||
content=pdu.get_pdu_json(time_now),
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Got content: %s", content)
|
state = [
|
||||||
|
event_from_pdu_json(p, outlier=True)
|
||||||
|
for p in content.get("state", [])
|
||||||
|
]
|
||||||
|
|
||||||
state = [
|
auth_chain = [
|
||||||
event_from_pdu_json(p, outlier=True)
|
event_from_pdu_json(p, outlier=True)
|
||||||
for p in content.get("state", [])
|
for p in content.get("auth_chain", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
auth_chain = [
|
pdus = {
|
||||||
event_from_pdu_json(p, outlier=True)
|
p.event_id: p
|
||||||
for p in content.get("auth_chain", [])
|
for p in itertools.chain(state, auth_chain)
|
||||||
]
|
}
|
||||||
|
|
||||||
pdus = {
|
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||||
p.event_id: p
|
destination, list(pdus.values()),
|
||||||
for p in itertools.chain(state, auth_chain)
|
outlier=True,
|
||||||
}
|
)
|
||||||
|
|
||||||
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
valid_pdus_map = {
|
||||||
destination, list(pdus.values()),
|
p.event_id: p
|
||||||
outlier=True,
|
for p in valid_pdus
|
||||||
)
|
}
|
||||||
|
|
||||||
valid_pdus_map = {
|
# NB: We *need* to copy to ensure that we don't have multiple
|
||||||
p.event_id: p
|
# references being passed on, as that causes... issues.
|
||||||
for p in valid_pdus
|
signed_state = [
|
||||||
}
|
copy.copy(valid_pdus_map[p.event_id])
|
||||||
|
for p in state
|
||||||
|
if p.event_id in valid_pdus_map
|
||||||
|
]
|
||||||
|
|
||||||
# NB: We *need* to copy to ensure that we don't have multiple
|
signed_auth = [
|
||||||
# references being passed on, as that causes... issues.
|
valid_pdus_map[p.event_id]
|
||||||
signed_state = [
|
for p in auth_chain
|
||||||
copy.copy(valid_pdus_map[p.event_id])
|
if p.event_id in valid_pdus_map
|
||||||
for p in state
|
]
|
||||||
if p.event_id in valid_pdus_map
|
|
||||||
]
|
|
||||||
|
|
||||||
signed_auth = [
|
# NB: We *need* to copy to ensure that we don't have multiple
|
||||||
valid_pdus_map[p.event_id]
|
# references being passed on, as that causes... issues.
|
||||||
for p in auth_chain
|
for s in signed_state:
|
||||||
if p.event_id in valid_pdus_map
|
s.internal_metadata = copy.deepcopy(s.internal_metadata)
|
||||||
]
|
|
||||||
|
|
||||||
# NB: We *need* to copy to ensure that we don't have multiple
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
# references being passed on, as that causes... issues.
|
|
||||||
for s in signed_state:
|
|
||||||
s.internal_metadata = copy.deepcopy(s.internal_metadata)
|
|
||||||
|
|
||||||
auth_chain.sort(key=lambda e: e.depth)
|
defer.returnValue({
|
||||||
|
"state": signed_state,
|
||||||
defer.returnValue({
|
"auth_chain": signed_auth,
|
||||||
"state": signed_state,
|
"origin": destination,
|
||||||
"auth_chain": signed_auth,
|
})
|
||||||
"origin": destination,
|
return self._try_destination_list("send_join", destinations, send_request)
|
||||||
})
|
|
||||||
except CodeMessageException as e:
|
|
||||||
if not 500 <= e.code < 600:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to send_join via %s: %s",
|
|
||||||
destination, e.message
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to send_join via %s: %s",
|
|
||||||
destination, e.message
|
|
||||||
)
|
|
||||||
|
|
||||||
raise RuntimeError("Failed to send to any server.")
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_invite(self, destination, room_id, event_id, pdu):
|
def send_invite(self, destination, room_id, event_id, pdu):
|
||||||
|
@ -663,7 +693,6 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
defer.returnValue(pdu)
|
defer.returnValue(pdu)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def send_leave(self, destinations, pdu):
|
def send_leave(self, destinations, pdu):
|
||||||
"""Sends a leave event to one of a list of homeservers.
|
"""Sends a leave event to one of a list of homeservers.
|
||||||
|
|
||||||
|
@ -681,34 +710,24 @@ class FederationClient(FederationBase):
|
||||||
Deferred: resolves to None.
|
Deferred: resolves to None.
|
||||||
|
|
||||||
Fails with a ``CodeMessageException`` if the chosen remote server
|
Fails with a ``CodeMessageException`` if the chosen remote server
|
||||||
returns a non-200 code.
|
returns a 300/400 code.
|
||||||
|
|
||||||
Fails with a ``RuntimeError`` if no servers were reachable.
|
Fails with a ``RuntimeError`` if no servers were reachable.
|
||||||
"""
|
"""
|
||||||
for destination in destinations:
|
@defer.inlineCallbacks
|
||||||
if destination == self.server_name:
|
def send_request(destination):
|
||||||
continue
|
time_now = self._clock.time_msec()
|
||||||
|
_, content = yield self.transport_layer.send_leave(
|
||||||
|
destination=destination,
|
||||||
|
room_id=pdu.room_id,
|
||||||
|
event_id=pdu.event_id,
|
||||||
|
content=pdu.get_pdu_json(time_now),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
logger.debug("Got content: %s", content)
|
||||||
time_now = self._clock.time_msec()
|
defer.returnValue(None)
|
||||||
_, content = yield self.transport_layer.send_leave(
|
|
||||||
destination=destination,
|
|
||||||
room_id=pdu.room_id,
|
|
||||||
event_id=pdu.event_id,
|
|
||||||
content=pdu.get_pdu_json(time_now),
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Got content: %s", content)
|
return self._try_destination_list("send_leave", destinations, send_request)
|
||||||
defer.returnValue(None)
|
|
||||||
except CodeMessageException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to send_leave via %s: %s",
|
|
||||||
destination, e.message
|
|
||||||
)
|
|
||||||
|
|
||||||
raise RuntimeError("Failed to send to any server.")
|
|
||||||
|
|
||||||
def get_public_rooms(self, destination, limit=None, since_token=None,
|
def get_public_rooms(self, destination, limit=None, since_token=None,
|
||||||
search_filter=None, include_all_networks=False,
|
search_filter=None, include_all_networks=False,
|
||||||
|
|
Loading…
Reference in a new issue