0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-19 05:21:55 +01:00

Fix bugs in invite/join dances.

We now do more implement more of the auth on the events so that we
don't reject valid events.
This commit is contained in:
Erik Johnston 2014-11-25 11:31:18 +00:00
parent 3536fd7d60
commit 64fc859dac
8 changed files with 221 additions and 148 deletions

View file

@ -38,27 +38,21 @@ class Auth(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
def check(self, event, raises=False): def check(self, event, auth_events):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Returns: Returns:
True if the auth checks pass. True if the auth checks pass.
Raises:
AuthError if there was a problem authorising this event. This will
be raised only if raises=True.
""" """
try: try:
if hasattr(event, "room_id"): if not hasattr(event, "room_id"):
if event.old_state_events is None: raise AuthError(500, "Event has no room_id: %s" % event)
if auth_events is None:
# Oh, we don't know what the state of the room was, so we # Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now) # are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id) logger.warn("Trusting event: %s", event.event_id)
return True return True
if hasattr(event, "outlier") and event.outlier is True:
# TODO (erikj): Auth for outliers is done differently.
return True
if event.type == RoomCreateEvent.TYPE: if event.type == RoomCreateEvent.TYPE:
# FIXME # FIXME
return True return True
@ -68,49 +62,42 @@ class Auth(object):
return True return True
if event.type == RoomMemberEvent.TYPE: if event.type == RoomMemberEvent.TYPE:
allowed = self.is_membership_change_allowed(event) allowed = self.is_membership_change_allowed(
event, auth_events
)
if allowed: if allowed:
logger.debug("Allowing! %s", event) logger.debug("Allowing! %s", event)
else: else:
logger.debug("Denying! %s", event) logger.debug("Denying! %s", event)
return allowed return allowed
self.check_event_sender_in_room(event) self.check_event_sender_in_room(event, auth_events)
self._can_send_event(event) self._can_send_event(event, auth_events)
if event.type == RoomPowerLevelsEvent.TYPE: if event.type == RoomPowerLevelsEvent.TYPE:
self._check_power_levels(event) self._check_power_levels(event, auth_events)
if event.type == RoomRedactionEvent.TYPE: if event.type == RoomRedactionEvent.TYPE:
self._check_redaction(event) self._check_redaction(event, auth_events)
logger.debug("Allowing! %s", event) logger.debug("Allowing! %s", event)
return True
else:
raise AuthError(500, "Unknown event: %s" % event)
except AuthError as e: except AuthError as e:
logger.info( logger.info(
"Event auth check failed on event %s with msg: %s", "Event auth check failed on event %s with msg: %s",
event, e.msg event, e.msg
) )
logger.info("Denying! %s", event) logger.info("Denying! %s", event)
if raises:
raise raise
return False
@defer.inlineCallbacks @defer.inlineCallbacks
def check_joined_room(self, room_id, user_id): def check_joined_room(self, room_id, user_id):
try: member = yield self.state.get_current_state(
member = yield self.store.get_room_member(
room_id=room_id, room_id=room_id,
user_id=user_id event_type=RoomMemberEvent.TYPE,
state_key=user_id
) )
self._check_joined_room(member, user_id, room_id) self._check_joined_room(member, user_id, room_id)
defer.returnValue(member) defer.returnValue(member)
except AttributeError:
pass
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
@ -130,9 +117,9 @@ class Auth(object):
defer.returnValue(False) defer.returnValue(False)
def check_event_sender_in_room(self, event): def check_event_sender_in_room(self, event, auth_events):
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = event.state_events.get(key) member_event = auth_events.get(key)
return self._check_joined_room( return self._check_joined_room(
member_event, member_event,
@ -147,14 +134,14 @@ class Auth(object):
)) ))
@log_function @log_function
def is_membership_change_allowed(self, event): def is_membership_change_allowed(self, event, auth_events):
membership = event.content["membership"] membership = event.content["membership"]
# Check if this is the room creator joining: # Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership: if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event: # Get room creation event:
key = (RoomCreateEvent.TYPE, "", ) key = (RoomCreateEvent.TYPE, "", )
create = event.old_state_events.get(key) create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id: if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key: if create.content["creator"] == event.state_key:
return True return True
@ -163,19 +150,19 @@ class Auth(object):
# get info about the caller # get info about the caller
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (RoomMemberEvent.TYPE, event.user_id, )
caller = event.old_state_events.get(key) caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target # get info about the target
key = (RoomMemberEvent.TYPE, target_user_id, ) key = (RoomMemberEvent.TYPE, target_user_id, )
target = event.old_state_events.get(key) target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN target_in_room = target and target.membership == Membership.JOIN
key = (RoomJoinRulesEvent.TYPE, "", ) key = (RoomJoinRulesEvent.TYPE, "", )
join_rule_event = event.old_state_events.get(key) join_rule_event = auth_events.get(key)
if join_rule_event: if join_rule_event:
join_rule = join_rule_event.content.get( join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE "join_rule", JoinRules.INVITE
@ -186,11 +173,13 @@ class Auth(object):
user_level = self._get_power_level_from_event_state( user_level = self._get_power_level_from_event_state(
event, event,
event.user_id, event.user_id,
auth_events,
) )
ban_level, kick_level, redact_level = ( ban_level, kick_level, redact_level = (
self._get_ops_level_from_event_state( self._get_ops_level_from_event_state(
event event,
auth_events,
) )
) )
@ -260,9 +249,9 @@ class Auth(object):
return True return True
def _get_power_level_from_event_state(self, event, user_id): def _get_power_level_from_event_state(self, event, user_id, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key) power_level_event = auth_events.get(key)
level = None level = None
if power_level_event: if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id) level = power_level_event.content.get("users", {}).get(user_id)
@ -270,16 +259,16 @@ class Auth(object):
level = power_level_event.content.get("users_default", 0) level = power_level_event.content.get("users_default", 0)
else: else:
key = (RoomCreateEvent.TYPE, "", ) key = (RoomCreateEvent.TYPE, "", )
create_event = event.old_state_events.get(key) create_event = auth_events.get(key)
if (create_event is not None and if (create_event is not None and
create_event.content["creator"] == user_id): create_event.content["creator"] == user_id):
return 100 return 100
return level return level
def _get_ops_level_from_event_state(self, event): def _get_ops_level_from_event_state(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key) power_level_event = auth_events.get(key)
if power_level_event: if power_level_event:
return ( return (
@ -375,6 +364,11 @@ class Auth(object):
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = event.old_state_events.get(key) member_event = event.old_state_events.get(key)
key = (RoomCreateEvent.TYPE, "", )
create_event = event.old_state_events.get(key)
if create_event:
auth_events.append(create_event.event_id)
if join_rule_event: if join_rule_event:
join_rule = join_rule_event.content.get("join_rule") join_rule = join_rule_event.content.get("join_rule")
is_public = join_rule == JoinRules.PUBLIC if join_rule else False is_public = join_rule == JoinRules.PUBLIC if join_rule else False
@ -406,9 +400,9 @@ class Auth(object):
event.auth_events = zip(auth_events, hashes) event.auth_events = zip(auth_events, hashes)
@log_function @log_function
def _can_send_event(self, event): def _can_send_event(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (RoomPowerLevelsEvent.TYPE, "", )
send_level_event = event.old_state_events.get(key) send_level_event = auth_events.get(key)
send_level = None send_level = None
if send_level_event: if send_level_event:
send_level = send_level_event.content.get("events", {}).get( send_level = send_level_event.content.get("events", {}).get(
@ -432,6 +426,7 @@ class Auth(object):
user_level = self._get_power_level_from_event_state( user_level = self._get_power_level_from_event_state(
event, event,
event.user_id, event.user_id,
auth_events,
) )
if user_level: if user_level:
@ -468,14 +463,16 @@ class Auth(object):
return True return True
def _check_redaction(self, event): def _check_redaction(self, event, auth_events):
user_level = self._get_power_level_from_event_state( user_level = self._get_power_level_from_event_state(
event, event,
event.user_id, event.user_id,
auth_events,
) )
_, _, redact_level = self._get_ops_level_from_event_state( _, _, redact_level = self._get_ops_level_from_event_state(
event event,
auth_events,
) )
if user_level < redact_level: if user_level < redact_level:
@ -484,7 +481,7 @@ class Auth(object):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
def _check_power_levels(self, event): def _check_power_levels(self, event, auth_events):
user_list = event.content.get("users", {}) user_list = event.content.get("users", {})
# Validate users # Validate users
for k, v in user_list.items(): for k, v in user_list.items():
@ -499,7 +496,7 @@ class Auth(object):
raise SynapseError(400, "Not a valid power level: %s" % (v,)) raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, ) key = (event.type, event.state_key, )
current_state = event.old_state_events.get(key) current_state = auth_events.get(key)
if not current_state: if not current_state:
return return
@ -507,6 +504,7 @@ class Auth(object):
user_level = self._get_power_level_from_event_state( user_level = self._get_power_level_from_event_state(
event, event,
event.user_id, event.user_id,
auth_events,
) )
# Check other levels: # Check other levels:

View file

@ -125,6 +125,7 @@ class SynapseEvent(JsonEncodedObject):
pdu_json.pop("outlier", None) pdu_json.pop("outlier", None)
pdu_json.pop("replaces_state", None) pdu_json.pop("replaces_state", None)
pdu_json.pop("redacted", None) pdu_json.pop("redacted", None)
pdu_json.pop("prev_content", None)
state_hash = pdu_json.pop("state_hash", None) state_hash = pdu_json.pop("state_hash", None)
if state_hash is not None: if state_hash is not None:
pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash

View file

@ -481,11 +481,17 @@ class ReplicationLayer(object):
# FIXME: We probably want to do something with the auth_chain given # FIXME: We probably want to do something with the auth_chain given
# to us # to us
# auth_chain = [ auth_chain = [
# Pdu(outlier=True, **p) for p in content.get("auth_chain", []) self.event_from_pdu_json(p, outlier=True)
# ] for p in content.get("auth_chain", [])
]
defer.returnValue(state) auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({
"state": state,
"auth_chain": auth_chain,
})
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, destination, context, event_id, pdu): def send_invite(self, destination, context, event_id, pdu):
@ -551,12 +557,26 @@ class ReplicationLayer(object):
) )
if not exists: if not exists:
try:
logger.debug(
"Getting missing auth event %s from %s",
e_id,
origin,
)
yield self.get_pdu( yield self.get_pdu(
origin, origin,
event_id=e_id, event_id=e_id,
outlier=True, outlier=True,
) )
logger.debug("Processed pdu %s", e_id) logger.debug("Processed pdu %s", e_id)
except:
logger.warn(
"Failed to get auth event %s from %s",
e_id,
origin
)
# Get missing pdus if necessary. # Get missing pdus if necessary.
if not pdu.outlier: if not pdu.outlier:
@ -578,7 +598,7 @@ class ReplicationLayer(object):
try: try:
yield self.get_pdu( yield self.get_pdu(
pdu.origin, origin,
event_id=event_id, event_id=event_id,
) )
logger.debug("Processed pdu %s", event_id) logger.debug("Processed pdu %s", event_id)

View file

@ -78,7 +78,7 @@ class BaseHandler(object):
if not suppress_auth: if not suppress_auth:
logger.debug("Authing...") logger.debug("Authing...")
self.auth.check(event, raises=True) self.auth.check(event, auth_events=event.old_state_events)
logger.debug("Authed") logger.debug("Authed")
else: else:
logger.debug("Suppressed auth.") logger.debug("Suppressed auth.")

View file

@ -24,7 +24,8 @@ from synapse.api.constants import Membership
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import (
compute_event_signature, check_event_content_hash compute_event_signature, check_event_content_hash,
add_hashes_and_signatures,
) )
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
@ -141,15 +142,14 @@ class FederationHandler(BaseHandler):
) )
event = redacted_event event = redacted_event
is_new_state = yield self.state_handler.annotate_event_with_state(
event,
old_state=state
)
logger.debug("Event: %s", event) logger.debug("Event: %s", event)
try: try:
self.auth.check(event, raises=True) yield self._handle_new_event(
event,
state=state,
backfilled=backfilled
)
except AuthError as e: except AuthError as e:
raise FederationError( raise FederationError(
"ERROR", "ERROR",
@ -158,17 +158,6 @@ class FederationHandler(BaseHandler):
affected=event.event_id, affected=event.event_id,
) )
is_new_state = is_new_state and not backfilled
# TODO: Implement something in federation that allows us to
# respond to PDU.
yield self.store.persist_event(
event,
backfilled,
is_new_state=is_new_state
)
room = yield self.store.get_room(event.room_id) room = yield self.store.get_room(event.room_id)
if not room: if not room:
@ -276,6 +265,8 @@ class FederationHandler(BaseHandler):
We suspend processing of any received events from this room until we We suspend processing of any received events from this room until we
have finished processing the join. have finished processing the join.
""" """
logger.debug("Joining %s to %s", joinee, room_id)
pdu = yield self.replication_layer.make_join( pdu = yield self.replication_layer.make_join(
target_host, target_host,
room_id, room_id,
@ -298,19 +289,28 @@ class FederationHandler(BaseHandler):
try: try:
event.event_id = self.event_factory.create_event_id() event.event_id = self.event_factory.create_event_id()
event.origin = self.hs.hostname
event.content = content event.content = content
state = yield self.replication_layer.send_join( if not hasattr(event, "signatures"):
event.signatures = {}
add_hashes_and_signatures(
event,
self.hs.hostname,
self.hs.config.signing_key[0],
)
ret = yield self.replication_layer.send_join(
target_host, target_host,
event event
) )
logger.debug("do_invite_join state: %s", state) state = ret["state"]
auth_chain = ret["auth_chain"]
yield self.state_handler.annotate_event_with_state( logger.debug("do_invite_join auth_chain: %s", auth_chain)
event, logger.debug("do_invite_join state: %s", state)
old_state=state
)
logger.debug("do_invite_join event: %s", event) logger.debug("do_invite_join event: %s", event)
@ -324,34 +324,31 @@ class FederationHandler(BaseHandler):
# FIXME # FIXME
pass pass
for e in auth_chain:
e.outlier = True
yield self._handle_new_event(e)
for e in state: for e in state:
# FIXME: Auth these. # FIXME: Auth these.
e.outlier = True e.outlier = True
yield self._handle_new_event(e)
yield self.state_handler.annotate_event_with_state( yield self._handle_new_event(event, state=state)
e,
yield self.notifier.on_new_room_event(
event, extra_users=[joinee]
) )
yield self.store.persist_event( logger.debug("Finished joining %s to %s", joinee, room_id)
e,
backfilled=False,
is_new_state=True
)
yield self.store.persist_event(
event,
backfilled=False,
is_new_state=True
)
finally: finally:
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]
del self.room_queues[room_id] del self.room_queues[room_id]
for p in room_queue: for p in room_queue:
try: try:
yield self.on_receive_pdu(p, backfilled=False) self.on_receive_pdu(p, backfilled=False)
except: except:
pass logger.exception("Couldn't handle pdu")
defer.returnValue(True) defer.returnValue(True)
@ -375,7 +372,7 @@ class FederationHandler(BaseHandler):
yield self.state_handler.annotate_event_with_state(event) yield self.state_handler.annotate_event_with_state(event)
yield self.auth.add_auth_events(event) yield self.auth.add_auth_events(event)
self.auth.check(event, raises=True) self.auth.check(event, auth_events=event.old_state_events)
pdu = event pdu = event
@ -391,17 +388,7 @@ class FederationHandler(BaseHandler):
event.outlier = False event.outlier = False
state_handler = self.state_handler yield self._handle_new_event(event)
is_new_state = yield state_handler.annotate_event_with_state(event)
self.auth.check(event, raises=True)
# FIXME (erikj): All this is duplicated above :(
yield self.store.persist_event(
event,
backfilled=False,
is_new_state=is_new_state
)
extra_users = [] extra_users = []
if event.type == RoomMemberEvent.TYPE: if event.type == RoomMemberEvent.TYPE:
@ -414,7 +401,7 @@ class FederationHandler(BaseHandler):
) )
if event.type == RoomMemberEvent.TYPE: if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = self.hs.parse_userid(event.state_key) user = self.hs.parse_userid(event.state_key)
yield self.distributor.fire( yield self.distributor.fire(
"user_joined_room", user=user, room_id=event.room_id "user_joined_room", user=user, room_id=event.room_id
@ -565,3 +552,56 @@ class FederationHandler(BaseHandler):
) )
while waiters: while waiters:
waiters.pop().callback(None) waiters.pop().callback(None)
@defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False):
is_new_state = yield self.state_handler.annotate_event_with_state(
event,
old_state=state
)
if event.old_state_events:
known_ids = set(
[s.event_id for s in event.old_state_events.values()]
)
for e_id, _ in event.auth_events:
if e_id not in known_ids:
e = yield self.store.get_event(
e_id,
allow_none=True,
)
if not e:
# TODO: Do some conflict res to make sure that we're
# not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in %s",
event.event_id, e_id, known_ids,
)
raise AuthError(403, "Auth events are stale")
auth_events = event.old_state_events
else:
# We need to get the auth events from somewhere.
# TODO: Don't just hit the DBs?
auth_events = {}
for e_id, _ in event.auth_events:
e = yield self.store.get_event(
e_id,
allow_none=True,
)
if not e:
raise AuthError(403, "Can't find auth event.")
auth_events[(e.type, e.state_key)] = e
self.auth.check(event, auth_events=auth_events)
yield self.store.persist_event(
event,
backfilled=backfilled,
is_new_state=(is_new_state and not backfilled)
)

View file

@ -306,7 +306,7 @@ class MessageHandler(BaseHandler):
auth_user = self.hs.parse_userid(user_id) auth_user = self.hs.parse_userid(user_id)
# TODO: These concurrently # TODO: These concurrently
state_tuples = yield self.store.get_current_state(room_id) state_tuples = yield self.state_handler.get_current_state(room_id)
state = [self.hs.serialize_event(x) for x in state_tuples] state = [self.hs.serialize_event(x) for x in state_tuples]
member_event = (yield self.store.get_room_member( member_event = (yield self.store.get_room_member(

View file

@ -83,20 +83,22 @@ class FederationTestCase(unittest.TestCase):
event_id="$a:b", event_id="$a:b",
user_id="@a:b", user_id="@a:b",
origin="b", origin="b",
auth_events=[],
hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"}, hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
) )
self.datastore.persist_event.return_value = defer.succeed(None) self.datastore.persist_event.return_value = defer.succeed(None)
self.datastore.get_room.return_value = defer.succeed(True) self.datastore.get_room.return_value = defer.succeed(True)
self.state_handler.annotate_event_with_state.return_value = ( def annotate(ev, old_state=None):
defer.succeed(False) ev.old_state_events = []
) return defer.succeed(False)
self.state_handler.annotate_event_with_state.side_effect = annotate
yield self.handlers.federation_handler.on_receive_pdu(pdu, False) yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
self.datastore.persist_event.assert_called_once_with( self.datastore.persist_event.assert_called_once_with(
ANY, False, is_new_state=False ANY, is_new_state=False, backfilled=False
) )
self.state_handler.annotate_event_with_state.assert_called_once_with( self.state_handler.annotate_event_with_state.assert_called_once_with(
@ -104,7 +106,7 @@ class FederationTestCase(unittest.TestCase):
old_state=None, old_state=None,
) )
self.auth.check.assert_called_once_with(ANY, raises=True) self.auth.check.assert_called_once_with(ANY, auth_events={})
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
ANY, ANY,

View file

@ -120,7 +120,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.datastore.get_room_member.return_value = defer.succeed(None) self.datastore.get_room_member.return_value = defer.succeed(None)
event.state_events = { event.old_state_events = {
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member( (RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
user_id="@alice:green", user_id="@alice:green",
room_id=room_id, room_id=room_id,
@ -129,9 +129,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
user_id="@bob:red", user_id="@bob:red",
room_id=room_id, room_id=room_id,
), ),
(RoomMemberEvent.TYPE, target_user_id): event,
} }
event.state_events = event.old_state_events
event.state_events[(RoomMemberEvent.TYPE, target_user_id)] = event
# Actual invocation # Actual invocation
yield self.room_member_handler.change_membership(event) yield self.room_member_handler.change_membership(event)
@ -187,6 +189,16 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
(RoomMemberEvent.TYPE, user_id): event, (RoomMemberEvent.TYPE, user_id): event,
} }
event.old_state_events = {
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
}
event.state_events = event.old_state_events
event.state_events[(RoomMemberEvent.TYPE, user_id)] = event
# Actual invocation # Actual invocation
yield self.room_member_handler.change_membership(event) yield self.room_member_handler.change_membership(event)