mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 14:24:05 +01:00
Initial implementation of auth conflict resolution
This commit is contained in:
parent
5a3a15f5c1
commit
78015948a7
8 changed files with 211 additions and 82 deletions
|
@ -45,12 +45,14 @@ def prune_event(event):
|
|||
"membership",
|
||||
]
|
||||
|
||||
event_dict = event.get_dict()
|
||||
|
||||
new_content = {}
|
||||
|
||||
def add_fields(*fields):
|
||||
for field in fields:
|
||||
if field in event.content:
|
||||
new_content[field] = event.content[field]
|
||||
new_content[field] = event_dict["content"][field]
|
||||
|
||||
if event_type == EventTypes.Member:
|
||||
add_fields("membership")
|
||||
|
@ -75,7 +77,7 @@ def prune_event(event):
|
|||
|
||||
allowed_fields = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
for k, v in event_dict.items()
|
||||
if k in allowed_keys
|
||||
}
|
||||
|
||||
|
|
|
@ -345,7 +345,7 @@ class FederationClient(object):
|
|||
"auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
|
||||
}
|
||||
|
||||
code, content = yield self.transport_layer.send_invite(
|
||||
code, content = yield self.transport_layer.send_query_auth(
|
||||
destination=destination,
|
||||
room_id=room_id,
|
||||
event_id=event_id,
|
||||
|
|
|
@ -230,6 +230,39 @@ class FederationServer(object):
|
|||
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
|
||||
}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_query_auth_request(self, origin, content, event_id):
|
||||
auth_chain = [
|
||||
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
|
||||
for e in content["auth_chain"]
|
||||
]
|
||||
|
||||
missing = [
|
||||
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
|
||||
for e in content.get("missing", [])
|
||||
]
|
||||
|
||||
ret = yield self.handler.on_query_auth(
|
||||
origin, event_id, auth_chain, content.get("rejects", []), missing
|
||||
)
|
||||
|
||||
time_now = self._clock.time_msec()
|
||||
send_content = {
|
||||
"auth_chain": [
|
||||
e.get_pdu_json(time_now)
|
||||
for e in ret["auth_chain"]
|
||||
],
|
||||
"rejects": content.get("rejects", []),
|
||||
"missing": [
|
||||
e.get_pdu_json(time_now)
|
||||
for e in ret.get("missing", [])
|
||||
],
|
||||
}
|
||||
|
||||
defer.returnValue(
|
||||
(200, send_content)
|
||||
)
|
||||
|
||||
@log_function
|
||||
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
|
||||
""" Get a PDU from the database with given origin and id.
|
||||
|
|
|
@ -213,3 +213,19 @@ class TransportLayerClient(object):
|
|||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_query_auth(self, destination, room_id, event_id, content):
|
||||
path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
|
||||
|
||||
code, content = yield self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
)
|
||||
|
||||
if not 200 <= code < 300:
|
||||
raise RuntimeError("Got %d from send_invite", code)
|
||||
|
||||
defer.returnValue(json.loads(content))
|
||||
|
|
|
@ -42,7 +42,7 @@ class TransportLayerServer(object):
|
|||
content = None
|
||||
origin = None
|
||||
|
||||
if request.method == "PUT":
|
||||
if request.method in ["PUT", "POST"]:
|
||||
# TODO: Handle other method types? other content types?
|
||||
try:
|
||||
content_bytes = request.content.read()
|
||||
|
@ -234,6 +234,16 @@ class TransportLayerServer(object):
|
|||
)
|
||||
)
|
||||
)
|
||||
self.server.register_path(
|
||||
"POST",
|
||||
re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
|
||||
self._with_authentication(
|
||||
lambda origin, content, query, context, event_id:
|
||||
self._on_query_auth_request(
|
||||
origin, content, event_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -325,3 +335,12 @@ class TransportLayerServer(object):
|
|||
)
|
||||
|
||||
defer.returnValue((200, content))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _on_query_auth_request(self, origin, content, event_id):
|
||||
new_content = yield self.request_handler.on_query_auth_request(
|
||||
origin, content, event_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
|
|
@ -126,7 +126,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
if not state:
|
||||
state, auth_chain = yield replication.get_state_for_room(
|
||||
origin, context=event.room_id, event_id=event.event_id,
|
||||
origin, room_id=event.room_id, event_id=event.event_id,
|
||||
)
|
||||
|
||||
if not auth_chain:
|
||||
|
@ -139,7 +139,7 @@ class FederationHandler(BaseHandler):
|
|||
for e in auth_chain:
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e, fetch_auth_from=origin)
|
||||
yield self._handle_new_event(origin, e)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to handle auth event %s",
|
||||
|
@ -152,7 +152,7 @@ class FederationHandler(BaseHandler):
|
|||
for e in state:
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e)
|
||||
yield self._handle_new_event(origin, e)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to handle state event %s",
|
||||
|
@ -161,6 +161,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
try:
|
||||
yield self._handle_new_event(
|
||||
origin,
|
||||
event,
|
||||
state=state,
|
||||
backfilled=backfilled,
|
||||
|
@ -363,7 +364,14 @@ class FederationHandler(BaseHandler):
|
|||
for e in auth_chain:
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e)
|
||||
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||
auth = {
|
||||
(e.type, e.state_key): e for e in auth_chain
|
||||
if e.event_id in auth_ids
|
||||
}
|
||||
yield self._handle_new_event(
|
||||
target_host, e, auth_events=auth
|
||||
)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to handle auth event %s",
|
||||
|
@ -374,8 +382,13 @@ class FederationHandler(BaseHandler):
|
|||
# FIXME: Auth these.
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||
auth = {
|
||||
(e.type, e.state_key): e for e in auth_chain
|
||||
if e.event_id in auth_ids
|
||||
}
|
||||
yield self._handle_new_event(
|
||||
e, fetch_auth_from=target_host
|
||||
target_host, e, auth_events=auth
|
||||
)
|
||||
except:
|
||||
logger.exception(
|
||||
|
@ -384,6 +397,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
yield self._handle_new_event(
|
||||
target_host,
|
||||
new_event,
|
||||
state=state,
|
||||
current_state=state,
|
||||
|
@ -450,7 +464,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
event.internal_metadata.outlier = False
|
||||
|
||||
context = yield self._handle_new_event(event)
|
||||
context = yield self._handle_new_event(origin, event)
|
||||
|
||||
logger.debug(
|
||||
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
|
||||
|
@ -651,11 +665,12 @@ class FederationHandler(BaseHandler):
|
|||
waiters.pop().callback(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_new_event(self, event, state=None, backfilled=False,
|
||||
current_state=None, fetch_auth_from=None):
|
||||
@log_function
|
||||
def _handle_new_event(self, origin, event, state=None, backfilled=False,
|
||||
current_state=None, auth_events=None):
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: Before annotate: %s, sigs: %s",
|
||||
"_handle_new_event: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
|
@ -663,62 +678,34 @@ class FederationHandler(BaseHandler):
|
|||
event, old_state=state
|
||||
)
|
||||
|
||||
if not auth_events:
|
||||
auth_events = context.auth_events
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: Before auth fetch: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
"_handle_new_event: %s, auth_events: %s",
|
||||
event.event_id, auth_events,
|
||||
)
|
||||
|
||||
is_new_state = not event.internal_metadata.is_outlier()
|
||||
|
||||
known_ids = set(
|
||||
[s.event_id for s in context.auth_events.values()]
|
||||
)
|
||||
|
||||
for e_id, _ in event.auth_events:
|
||||
if e_id not in known_ids:
|
||||
e = yield self.store.get_event(e_id, allow_none=True)
|
||||
|
||||
if not e and fetch_auth_from is not None:
|
||||
# Grab the auth_chain over federation if we are missing
|
||||
# auth events.
|
||||
auth_chain = yield self.replication_layer.get_event_auth(
|
||||
fetch_auth_from, event.event_id, event.room_id
|
||||
)
|
||||
for auth_event in auth_chain:
|
||||
yield self._handle_new_event(auth_event)
|
||||
e = yield self.store.get_event(e_id, allow_none=True)
|
||||
|
||||
if not e:
|
||||
# TODO: Do some conflict res to make sure that we're
|
||||
# not the ones who are wrong.
|
||||
logger.info(
|
||||
"Rejecting %s as %s not in db or %s",
|
||||
event.event_id, e_id, known_ids,
|
||||
)
|
||||
# FIXME: How does raising AuthError work with federation?
|
||||
raise AuthError(403, "Cannot find auth event")
|
||||
|
||||
context.auth_events[(e.type, e.state_key)] = e
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: Before hack: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
# This is a hack to fix some old rooms where the initial join event
|
||||
# didn't reference the create event in its auth events.
|
||||
if event.type == EventTypes.Member and not event.auth_events:
|
||||
if len(event.prev_events) == 1:
|
||||
c = yield self.store.get_event(event.prev_events[0][0])
|
||||
if c.type == EventTypes.Create:
|
||||
context.auth_events[(c.type, c.state_key)] = c
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: Before auth check: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
auth_events[(c.type, c.state_key)] = c
|
||||
|
||||
try:
|
||||
self.auth.check(event, auth_events=context.auth_events)
|
||||
except AuthError:
|
||||
yield self.do_auth(
|
||||
origin, event, context, auth_events=auth_events
|
||||
)
|
||||
except AuthError as e:
|
||||
logger.warn(
|
||||
"Rejecting %s because %s",
|
||||
event.event_id, e.msg
|
||||
)
|
||||
|
||||
# TODO: Store rejection.
|
||||
context.rejected = RejectedReason.AUTH_ERROR
|
||||
|
||||
|
@ -731,11 +718,6 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
raise
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: Before persist_event: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
|
@ -744,25 +726,73 @@ class FederationHandler(BaseHandler):
|
|||
current_state=current_state,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: After persist_event: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
defer.returnValue(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_auth(self, origin, event, context):
|
||||
for e_id, _ in event.auth_events:
|
||||
pass
|
||||
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
|
||||
missing):
|
||||
# Just go through and process each event in `remote_auth_chain`. We
|
||||
# don't want to fall into the trap of `missing` being wrong.
|
||||
for e in remote_auth_chain:
|
||||
try:
|
||||
yield self._handle_new_event(origin, e)
|
||||
except AuthError:
|
||||
pass
|
||||
|
||||
auth_events = set(e_id for e_id, _ in event.auth_events)
|
||||
current_state = set(e.event_id for e in context.auth_events.values())
|
||||
# Now get the current auth_chain for the event.
|
||||
local_auth_chain = yield self.store.get_auth_chain([event_id])
|
||||
|
||||
missing_auth = auth_events - current_state
|
||||
# TODO: Check if we would now reject event_id. If so we need to tell
|
||||
# everyone.
|
||||
|
||||
ret = yield self.construct_auth_difference(
|
||||
local_auth_chain, remote_auth_chain
|
||||
)
|
||||
|
||||
logger.debug("on_query_auth reutrning: %s", ret)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def do_auth(self, origin, event, context, auth_events):
|
||||
# Check if we have all the auth events.
|
||||
res = yield self.store.have_events(
|
||||
[e_id for e_id, _ in event.auth_events]
|
||||
)
|
||||
|
||||
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
||||
seen_events = set(res.keys())
|
||||
|
||||
missing_auth = event_auth_events - seen_events
|
||||
|
||||
if missing_auth:
|
||||
logger.debug("Missing auth: %s", missing_auth)
|
||||
# If we don't have all the auth events, we need to get them.
|
||||
remote_auth_chain = yield self.replication_layer.get_event_auth(
|
||||
origin, event.room_id, event.event_id
|
||||
)
|
||||
|
||||
for e in remote_auth_chain:
|
||||
try:
|
||||
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||
auth = {
|
||||
(e.type, e.state_key): e for e in remote_auth_chain
|
||||
if e.event_id in auth_ids
|
||||
}
|
||||
yield self._handle_new_event(
|
||||
origin, e, auth_events=auth
|
||||
)
|
||||
auth_events[(e.type, e.state_key)] = e
|
||||
except AuthError:
|
||||
pass
|
||||
|
||||
current_state = set(e.event_id for e in auth_events.values())
|
||||
different_auth = event_auth_events - current_state
|
||||
|
||||
if different_auth and not event.internal_metadata.is_outlier():
|
||||
# Do auth conflict res.
|
||||
logger.debug("Different auth: %s", different_auth)
|
||||
|
||||
# 1. Get what we think is the auth chain.
|
||||
auth_ids = self.auth.compute_auth_events(event, context)
|
||||
|
@ -778,14 +808,24 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# 3. Process any remote auth chain events we haven't seen.
|
||||
for e in result.get("missing", []):
|
||||
# TODO.
|
||||
pass
|
||||
try:
|
||||
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||
auth = {
|
||||
(e.type, e.state_key): e for e in result["auth_chain"]
|
||||
if e.event_id in auth_ids
|
||||
}
|
||||
yield self._handle_new_event(
|
||||
origin, e, auth_events=auth
|
||||
)
|
||||
auth_events[(e.type, e.state_key)] = e
|
||||
except AuthError:
|
||||
pass
|
||||
|
||||
# 4. Look at rejects and their proofs.
|
||||
# TODO.
|
||||
|
||||
try:
|
||||
self.auth.check(event, auth_events=context.auth_events)
|
||||
self.auth.check(event, auth_events=auth_events)
|
||||
except AuthError:
|
||||
raise
|
||||
|
||||
|
@ -802,12 +842,16 @@ class FederationHandler(BaseHandler):
|
|||
dict
|
||||
"""
|
||||
|
||||
logger.debug("construct_auth_difference Start!")
|
||||
|
||||
# TODO: Make sure we are OK with local_auth or remote_auth having more
|
||||
# auth events in them than strictly necessary.
|
||||
|
||||
def sort_fun(ev):
|
||||
return ev.depth, ev.event_id
|
||||
|
||||
logger.debug("construct_auth_difference after sort_fun!")
|
||||
|
||||
# We find the differences by starting at the "bottom" of each list
|
||||
# and iterating up on both lists. The lists are ordered by depth and
|
||||
# then event_id, we iterate up both lists until we find the event ids
|
||||
|
@ -823,11 +867,18 @@ class FederationHandler(BaseHandler):
|
|||
local_iter = iter(local_list)
|
||||
remote_iter = iter(remote_list)
|
||||
|
||||
current_local = local_iter.next()
|
||||
current_remote = remote_iter.next()
|
||||
logger.debug("construct_auth_difference before get_next!")
|
||||
|
||||
def get_next(it, opt=None):
|
||||
return it.next() if it.has_next() else opt
|
||||
try:
|
||||
return it.next()
|
||||
except:
|
||||
return opt
|
||||
|
||||
current_local = get_next(local_iter)
|
||||
current_remote = get_next(remote_iter)
|
||||
|
||||
logger.debug("construct_auth_difference before while")
|
||||
|
||||
missing_remotes = []
|
||||
missing_locals = []
|
||||
|
@ -867,6 +918,8 @@ class FederationHandler(BaseHandler):
|
|||
current_remote = get_next(remote_iter)
|
||||
continue
|
||||
|
||||
logger.debug("construct_auth_difference after while")
|
||||
|
||||
# missing locals should be sent to the server
|
||||
# We should find why we are missing remotes, as they will have been
|
||||
# rejected.
|
||||
|
@ -886,6 +939,7 @@ class FederationHandler(BaseHandler):
|
|||
reason = yield self.store.get_rejection_reason(e.event_id)
|
||||
if reason is None:
|
||||
# FIXME: ERRR?!
|
||||
logger.warn("Could not find reason for %s", e.event_id)
|
||||
raise RuntimeError("")
|
||||
|
||||
reason_map[e.event_id] = reason
|
||||
|
@ -899,7 +953,10 @@ class FederationHandler(BaseHandler):
|
|||
# TODO: Get proof.
|
||||
pass
|
||||
|
||||
logger.debug("construct_auth_difference returning")
|
||||
|
||||
defer.returnValue({
|
||||
"auth_chain": local_auth,
|
||||
"rejects": {
|
||||
e.event_id: {
|
||||
"reason": reason_map[e.event_id],
|
||||
|
|
|
@ -28,12 +28,12 @@ class RejectionsStore(SQLBaseStore):
|
|||
values={
|
||||
"event_id": event_id,
|
||||
"reason": reason,
|
||||
"last_failure": self._clock.time_msec(),
|
||||
"last_check": self._clock.time_msec(),
|
||||
}
|
||||
)
|
||||
|
||||
def get_rejection_reason(self, event_id):
|
||||
self._simple_select_one_onecol(
|
||||
return self._simple_select_one_onecol(
|
||||
table="rejections",
|
||||
retcol="reason",
|
||||
keyvalues={
|
||||
|
|
|
@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
|
|||
"get_room",
|
||||
"get_destination_retry_timings",
|
||||
"set_destination_retry_timings",
|
||||
"have_events",
|
||||
]),
|
||||
resource_for_federation=NonCallableMock(),
|
||||
http_client=NonCallableMock(spec_set=[]),
|
||||
|
@ -90,6 +91,7 @@ class FederationTestCase(unittest.TestCase):
|
|||
self.datastore.persist_event.return_value = defer.succeed(None)
|
||||
self.datastore.get_room.return_value = defer.succeed(True)
|
||||
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
||||
self.datastore.have_events.return_value = defer.succeed({})
|
||||
|
||||
def annotate(ev, old_state=None):
|
||||
context = Mock()
|
||||
|
|
Loading…
Reference in a new issue