mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 18:13:54 +01:00
Fix bugs in invite/join dances.
We now do more implement more of the auth on the events so that we don't reject valid events.
This commit is contained in:
parent
3536fd7d60
commit
64fc859dac
8 changed files with 221 additions and 148 deletions
|
@ -38,27 +38,21 @@ class Auth(object):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
def check(self, event, raises=False):
|
def check(self, event, auth_events):
|
||||||
""" Checks if this event is correctly authed.
|
""" Checks if this event is correctly authed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the auth checks pass.
|
True if the auth checks pass.
|
||||||
Raises:
|
|
||||||
AuthError if there was a problem authorising this event. This will
|
|
||||||
be raised only if raises=True.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if hasattr(event, "room_id"):
|
if not hasattr(event, "room_id"):
|
||||||
if event.old_state_events is None:
|
raise AuthError(500, "Event has no room_id: %s" % event)
|
||||||
|
if auth_events is None:
|
||||||
# Oh, we don't know what the state of the room was, so we
|
# Oh, we don't know what the state of the room was, so we
|
||||||
# are trusting that this is allowed (at least for now)
|
# are trusting that this is allowed (at least for now)
|
||||||
logger.warn("Trusting event: %s", event.event_id)
|
logger.warn("Trusting event: %s", event.event_id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if hasattr(event, "outlier") and event.outlier is True:
|
|
||||||
# TODO (erikj): Auth for outliers is done differently.
|
|
||||||
return True
|
|
||||||
|
|
||||||
if event.type == RoomCreateEvent.TYPE:
|
if event.type == RoomCreateEvent.TYPE:
|
||||||
# FIXME
|
# FIXME
|
||||||
return True
|
return True
|
||||||
|
@ -68,49 +62,42 @@ class Auth(object):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if event.type == RoomMemberEvent.TYPE:
|
if event.type == RoomMemberEvent.TYPE:
|
||||||
allowed = self.is_membership_change_allowed(event)
|
allowed = self.is_membership_change_allowed(
|
||||||
|
event, auth_events
|
||||||
|
)
|
||||||
if allowed:
|
if allowed:
|
||||||
logger.debug("Allowing! %s", event)
|
logger.debug("Allowing! %s", event)
|
||||||
else:
|
else:
|
||||||
logger.debug("Denying! %s", event)
|
logger.debug("Denying! %s", event)
|
||||||
return allowed
|
return allowed
|
||||||
|
|
||||||
self.check_event_sender_in_room(event)
|
self.check_event_sender_in_room(event, auth_events)
|
||||||
self._can_send_event(event)
|
self._can_send_event(event, auth_events)
|
||||||
|
|
||||||
if event.type == RoomPowerLevelsEvent.TYPE:
|
if event.type == RoomPowerLevelsEvent.TYPE:
|
||||||
self._check_power_levels(event)
|
self._check_power_levels(event, auth_events)
|
||||||
|
|
||||||
if event.type == RoomRedactionEvent.TYPE:
|
if event.type == RoomRedactionEvent.TYPE:
|
||||||
self._check_redaction(event)
|
self._check_redaction(event, auth_events)
|
||||||
|
|
||||||
logger.debug("Allowing! %s", event)
|
logger.debug("Allowing! %s", event)
|
||||||
return True
|
|
||||||
else:
|
|
||||||
raise AuthError(500, "Unknown event: %s" % event)
|
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Event auth check failed on event %s with msg: %s",
|
"Event auth check failed on event %s with msg: %s",
|
||||||
event, e.msg
|
event, e.msg
|
||||||
)
|
)
|
||||||
logger.info("Denying! %s", event)
|
logger.info("Denying! %s", event)
|
||||||
if raises:
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_joined_room(self, room_id, user_id):
|
def check_joined_room(self, room_id, user_id):
|
||||||
try:
|
member = yield self.state.get_current_state(
|
||||||
member = yield self.store.get_room_member(
|
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
user_id=user_id
|
event_type=RoomMemberEvent.TYPE,
|
||||||
|
state_key=user_id
|
||||||
)
|
)
|
||||||
self._check_joined_room(member, user_id, room_id)
|
self._check_joined_room(member, user_id, room_id)
|
||||||
defer.returnValue(member)
|
defer.returnValue(member)
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
defer.returnValue(None)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_host_in_room(self, room_id, host):
|
def check_host_in_room(self, room_id, host):
|
||||||
|
@ -130,9 +117,9 @@ class Auth(object):
|
||||||
|
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
def check_event_sender_in_room(self, event):
|
def check_event_sender_in_room(self, event, auth_events):
|
||||||
key = (RoomMemberEvent.TYPE, event.user_id, )
|
key = (RoomMemberEvent.TYPE, event.user_id, )
|
||||||
member_event = event.state_events.get(key)
|
member_event = auth_events.get(key)
|
||||||
|
|
||||||
return self._check_joined_room(
|
return self._check_joined_room(
|
||||||
member_event,
|
member_event,
|
||||||
|
@ -147,14 +134,14 @@ class Auth(object):
|
||||||
))
|
))
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def is_membership_change_allowed(self, event):
|
def is_membership_change_allowed(self, event, auth_events):
|
||||||
membership = event.content["membership"]
|
membership = event.content["membership"]
|
||||||
|
|
||||||
# Check if this is the room creator joining:
|
# Check if this is the room creator joining:
|
||||||
if len(event.prev_events) == 1 and Membership.JOIN == membership:
|
if len(event.prev_events) == 1 and Membership.JOIN == membership:
|
||||||
# Get room creation event:
|
# Get room creation event:
|
||||||
key = (RoomCreateEvent.TYPE, "", )
|
key = (RoomCreateEvent.TYPE, "", )
|
||||||
create = event.old_state_events.get(key)
|
create = auth_events.get(key)
|
||||||
if create and event.prev_events[0][0] == create.event_id:
|
if create and event.prev_events[0][0] == create.event_id:
|
||||||
if create.content["creator"] == event.state_key:
|
if create.content["creator"] == event.state_key:
|
||||||
return True
|
return True
|
||||||
|
@ -163,19 +150,19 @@ class Auth(object):
|
||||||
|
|
||||||
# get info about the caller
|
# get info about the caller
|
||||||
key = (RoomMemberEvent.TYPE, event.user_id, )
|
key = (RoomMemberEvent.TYPE, event.user_id, )
|
||||||
caller = event.old_state_events.get(key)
|
caller = auth_events.get(key)
|
||||||
|
|
||||||
caller_in_room = caller and caller.membership == Membership.JOIN
|
caller_in_room = caller and caller.membership == Membership.JOIN
|
||||||
caller_invited = caller and caller.membership == Membership.INVITE
|
caller_invited = caller and caller.membership == Membership.INVITE
|
||||||
|
|
||||||
# get info about the target
|
# get info about the target
|
||||||
key = (RoomMemberEvent.TYPE, target_user_id, )
|
key = (RoomMemberEvent.TYPE, target_user_id, )
|
||||||
target = event.old_state_events.get(key)
|
target = auth_events.get(key)
|
||||||
|
|
||||||
target_in_room = target and target.membership == Membership.JOIN
|
target_in_room = target and target.membership == Membership.JOIN
|
||||||
|
|
||||||
key = (RoomJoinRulesEvent.TYPE, "", )
|
key = (RoomJoinRulesEvent.TYPE, "", )
|
||||||
join_rule_event = event.old_state_events.get(key)
|
join_rule_event = auth_events.get(key)
|
||||||
if join_rule_event:
|
if join_rule_event:
|
||||||
join_rule = join_rule_event.content.get(
|
join_rule = join_rule_event.content.get(
|
||||||
"join_rule", JoinRules.INVITE
|
"join_rule", JoinRules.INVITE
|
||||||
|
@ -186,11 +173,13 @@ class Auth(object):
|
||||||
user_level = self._get_power_level_from_event_state(
|
user_level = self._get_power_level_from_event_state(
|
||||||
event,
|
event,
|
||||||
event.user_id,
|
event.user_id,
|
||||||
|
auth_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
ban_level, kick_level, redact_level = (
|
ban_level, kick_level, redact_level = (
|
||||||
self._get_ops_level_from_event_state(
|
self._get_ops_level_from_event_state(
|
||||||
event
|
event,
|
||||||
|
auth_events,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -260,9 +249,9 @@ class Auth(object):
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _get_power_level_from_event_state(self, event, user_id):
|
def _get_power_level_from_event_state(self, event, user_id, auth_events):
|
||||||
key = (RoomPowerLevelsEvent.TYPE, "", )
|
key = (RoomPowerLevelsEvent.TYPE, "", )
|
||||||
power_level_event = event.old_state_events.get(key)
|
power_level_event = auth_events.get(key)
|
||||||
level = None
|
level = None
|
||||||
if power_level_event:
|
if power_level_event:
|
||||||
level = power_level_event.content.get("users", {}).get(user_id)
|
level = power_level_event.content.get("users", {}).get(user_id)
|
||||||
|
@ -270,16 +259,16 @@ class Auth(object):
|
||||||
level = power_level_event.content.get("users_default", 0)
|
level = power_level_event.content.get("users_default", 0)
|
||||||
else:
|
else:
|
||||||
key = (RoomCreateEvent.TYPE, "", )
|
key = (RoomCreateEvent.TYPE, "", )
|
||||||
create_event = event.old_state_events.get(key)
|
create_event = auth_events.get(key)
|
||||||
if (create_event is not None and
|
if (create_event is not None and
|
||||||
create_event.content["creator"] == user_id):
|
create_event.content["creator"] == user_id):
|
||||||
return 100
|
return 100
|
||||||
|
|
||||||
return level
|
return level
|
||||||
|
|
||||||
def _get_ops_level_from_event_state(self, event):
|
def _get_ops_level_from_event_state(self, event, auth_events):
|
||||||
key = (RoomPowerLevelsEvent.TYPE, "", )
|
key = (RoomPowerLevelsEvent.TYPE, "", )
|
||||||
power_level_event = event.old_state_events.get(key)
|
power_level_event = auth_events.get(key)
|
||||||
|
|
||||||
if power_level_event:
|
if power_level_event:
|
||||||
return (
|
return (
|
||||||
|
@ -375,6 +364,11 @@ class Auth(object):
|
||||||
key = (RoomMemberEvent.TYPE, event.user_id, )
|
key = (RoomMemberEvent.TYPE, event.user_id, )
|
||||||
member_event = event.old_state_events.get(key)
|
member_event = event.old_state_events.get(key)
|
||||||
|
|
||||||
|
key = (RoomCreateEvent.TYPE, "", )
|
||||||
|
create_event = event.old_state_events.get(key)
|
||||||
|
if create_event:
|
||||||
|
auth_events.append(create_event.event_id)
|
||||||
|
|
||||||
if join_rule_event:
|
if join_rule_event:
|
||||||
join_rule = join_rule_event.content.get("join_rule")
|
join_rule = join_rule_event.content.get("join_rule")
|
||||||
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
|
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
|
||||||
|
@ -406,9 +400,9 @@ class Auth(object):
|
||||||
event.auth_events = zip(auth_events, hashes)
|
event.auth_events = zip(auth_events, hashes)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _can_send_event(self, event):
|
def _can_send_event(self, event, auth_events):
|
||||||
key = (RoomPowerLevelsEvent.TYPE, "", )
|
key = (RoomPowerLevelsEvent.TYPE, "", )
|
||||||
send_level_event = event.old_state_events.get(key)
|
send_level_event = auth_events.get(key)
|
||||||
send_level = None
|
send_level = None
|
||||||
if send_level_event:
|
if send_level_event:
|
||||||
send_level = send_level_event.content.get("events", {}).get(
|
send_level = send_level_event.content.get("events", {}).get(
|
||||||
|
@ -432,6 +426,7 @@ class Auth(object):
|
||||||
user_level = self._get_power_level_from_event_state(
|
user_level = self._get_power_level_from_event_state(
|
||||||
event,
|
event,
|
||||||
event.user_id,
|
event.user_id,
|
||||||
|
auth_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_level:
|
if user_level:
|
||||||
|
@ -468,14 +463,16 @@ class Auth(object):
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _check_redaction(self, event):
|
def _check_redaction(self, event, auth_events):
|
||||||
user_level = self._get_power_level_from_event_state(
|
user_level = self._get_power_level_from_event_state(
|
||||||
event,
|
event,
|
||||||
event.user_id,
|
event.user_id,
|
||||||
|
auth_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
_, _, redact_level = self._get_ops_level_from_event_state(
|
_, _, redact_level = self._get_ops_level_from_event_state(
|
||||||
event
|
event,
|
||||||
|
auth_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_level < redact_level:
|
if user_level < redact_level:
|
||||||
|
@ -484,7 +481,7 @@ class Auth(object):
|
||||||
"You don't have permission to redact events"
|
"You don't have permission to redact events"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_power_levels(self, event):
|
def _check_power_levels(self, event, auth_events):
|
||||||
user_list = event.content.get("users", {})
|
user_list = event.content.get("users", {})
|
||||||
# Validate users
|
# Validate users
|
||||||
for k, v in user_list.items():
|
for k, v in user_list.items():
|
||||||
|
@ -499,7 +496,7 @@ class Auth(object):
|
||||||
raise SynapseError(400, "Not a valid power level: %s" % (v,))
|
raise SynapseError(400, "Not a valid power level: %s" % (v,))
|
||||||
|
|
||||||
key = (event.type, event.state_key, )
|
key = (event.type, event.state_key, )
|
||||||
current_state = event.old_state_events.get(key)
|
current_state = auth_events.get(key)
|
||||||
|
|
||||||
if not current_state:
|
if not current_state:
|
||||||
return
|
return
|
||||||
|
@ -507,6 +504,7 @@ class Auth(object):
|
||||||
user_level = self._get_power_level_from_event_state(
|
user_level = self._get_power_level_from_event_state(
|
||||||
event,
|
event,
|
||||||
event.user_id,
|
event.user_id,
|
||||||
|
auth_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check other levels:
|
# Check other levels:
|
||||||
|
|
|
@ -125,6 +125,7 @@ class SynapseEvent(JsonEncodedObject):
|
||||||
pdu_json.pop("outlier", None)
|
pdu_json.pop("outlier", None)
|
||||||
pdu_json.pop("replaces_state", None)
|
pdu_json.pop("replaces_state", None)
|
||||||
pdu_json.pop("redacted", None)
|
pdu_json.pop("redacted", None)
|
||||||
|
pdu_json.pop("prev_content", None)
|
||||||
state_hash = pdu_json.pop("state_hash", None)
|
state_hash = pdu_json.pop("state_hash", None)
|
||||||
if state_hash is not None:
|
if state_hash is not None:
|
||||||
pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash
|
pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash
|
||||||
|
|
|
@ -481,11 +481,17 @@ class ReplicationLayer(object):
|
||||||
# FIXME: We probably want to do something with the auth_chain given
|
# FIXME: We probably want to do something with the auth_chain given
|
||||||
# to us
|
# to us
|
||||||
|
|
||||||
# auth_chain = [
|
auth_chain = [
|
||||||
# Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
|
self.event_from_pdu_json(p, outlier=True)
|
||||||
# ]
|
for p in content.get("auth_chain", [])
|
||||||
|
]
|
||||||
|
|
||||||
defer.returnValue(state)
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"state": state,
|
||||||
|
"auth_chain": auth_chain,
|
||||||
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_invite(self, destination, context, event_id, pdu):
|
def send_invite(self, destination, context, event_id, pdu):
|
||||||
|
@ -551,12 +557,26 @@ class ReplicationLayer(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not exists:
|
if not exists:
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
"Getting missing auth event %s from %s",
|
||||||
|
e_id,
|
||||||
|
origin,
|
||||||
|
)
|
||||||
|
|
||||||
yield self.get_pdu(
|
yield self.get_pdu(
|
||||||
origin,
|
origin,
|
||||||
event_id=e_id,
|
event_id=e_id,
|
||||||
outlier=True,
|
outlier=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Processed pdu %s", e_id)
|
logger.debug("Processed pdu %s", e_id)
|
||||||
|
except:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to get auth event %s from %s",
|
||||||
|
e_id,
|
||||||
|
origin
|
||||||
|
)
|
||||||
|
|
||||||
# Get missing pdus if necessary.
|
# Get missing pdus if necessary.
|
||||||
if not pdu.outlier:
|
if not pdu.outlier:
|
||||||
|
@ -578,7 +598,7 @@ class ReplicationLayer(object):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.get_pdu(
|
yield self.get_pdu(
|
||||||
pdu.origin,
|
origin,
|
||||||
event_id=event_id,
|
event_id=event_id,
|
||||||
)
|
)
|
||||||
logger.debug("Processed pdu %s", event_id)
|
logger.debug("Processed pdu %s", event_id)
|
||||||
|
|
|
@ -78,7 +78,7 @@ class BaseHandler(object):
|
||||||
|
|
||||||
if not suppress_auth:
|
if not suppress_auth:
|
||||||
logger.debug("Authing...")
|
logger.debug("Authing...")
|
||||||
self.auth.check(event, raises=True)
|
self.auth.check(event, auth_events=event.old_state_events)
|
||||||
logger.debug("Authed")
|
logger.debug("Authed")
|
||||||
else:
|
else:
|
||||||
logger.debug("Suppressed auth.")
|
logger.debug("Suppressed auth.")
|
||||||
|
|
|
@ -24,7 +24,8 @@ from synapse.api.constants import Membership
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.crypto.event_signing import (
|
from synapse.crypto.event_signing import (
|
||||||
compute_event_signature, check_event_content_hash
|
compute_event_signature, check_event_content_hash,
|
||||||
|
add_hashes_and_signatures,
|
||||||
)
|
)
|
||||||
from syutil.jsonutil import encode_canonical_json
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
|
||||||
|
@ -141,15 +142,14 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
event = redacted_event
|
event = redacted_event
|
||||||
|
|
||||||
is_new_state = yield self.state_handler.annotate_event_with_state(
|
|
||||||
event,
|
|
||||||
old_state=state
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Event: %s", event)
|
logger.debug("Event: %s", event)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, raises=True)
|
yield self._handle_new_event(
|
||||||
|
event,
|
||||||
|
state=state,
|
||||||
|
backfilled=backfilled
|
||||||
|
)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
raise FederationError(
|
raise FederationError(
|
||||||
"ERROR",
|
"ERROR",
|
||||||
|
@ -158,17 +158,6 @@ class FederationHandler(BaseHandler):
|
||||||
affected=event.event_id,
|
affected=event.event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
is_new_state = is_new_state and not backfilled
|
|
||||||
|
|
||||||
# TODO: Implement something in federation that allows us to
|
|
||||||
# respond to PDU.
|
|
||||||
|
|
||||||
yield self.store.persist_event(
|
|
||||||
event,
|
|
||||||
backfilled,
|
|
||||||
is_new_state=is_new_state
|
|
||||||
)
|
|
||||||
|
|
||||||
room = yield self.store.get_room(event.room_id)
|
room = yield self.store.get_room(event.room_id)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
|
@ -276,6 +265,8 @@ class FederationHandler(BaseHandler):
|
||||||
We suspend processing of any received events from this room until we
|
We suspend processing of any received events from this room until we
|
||||||
have finished processing the join.
|
have finished processing the join.
|
||||||
"""
|
"""
|
||||||
|
logger.debug("Joining %s to %s", joinee, room_id)
|
||||||
|
|
||||||
pdu = yield self.replication_layer.make_join(
|
pdu = yield self.replication_layer.make_join(
|
||||||
target_host,
|
target_host,
|
||||||
room_id,
|
room_id,
|
||||||
|
@ -298,19 +289,28 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event.event_id = self.event_factory.create_event_id()
|
event.event_id = self.event_factory.create_event_id()
|
||||||
|
event.origin = self.hs.hostname
|
||||||
event.content = content
|
event.content = content
|
||||||
|
|
||||||
state = yield self.replication_layer.send_join(
|
if not hasattr(event, "signatures"):
|
||||||
|
event.signatures = {}
|
||||||
|
|
||||||
|
add_hashes_and_signatures(
|
||||||
|
event,
|
||||||
|
self.hs.hostname,
|
||||||
|
self.hs.config.signing_key[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
ret = yield self.replication_layer.send_join(
|
||||||
target_host,
|
target_host,
|
||||||
event
|
event
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("do_invite_join state: %s", state)
|
state = ret["state"]
|
||||||
|
auth_chain = ret["auth_chain"]
|
||||||
|
|
||||||
yield self.state_handler.annotate_event_with_state(
|
logger.debug("do_invite_join auth_chain: %s", auth_chain)
|
||||||
event,
|
logger.debug("do_invite_join state: %s", state)
|
||||||
old_state=state
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("do_invite_join event: %s", event)
|
logger.debug("do_invite_join event: %s", event)
|
||||||
|
|
||||||
|
@ -324,34 +324,31 @@ class FederationHandler(BaseHandler):
|
||||||
# FIXME
|
# FIXME
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
for e in auth_chain:
|
||||||
|
e.outlier = True
|
||||||
|
yield self._handle_new_event(e)
|
||||||
|
|
||||||
for e in state:
|
for e in state:
|
||||||
# FIXME: Auth these.
|
# FIXME: Auth these.
|
||||||
e.outlier = True
|
e.outlier = True
|
||||||
|
yield self._handle_new_event(e)
|
||||||
|
|
||||||
yield self.state_handler.annotate_event_with_state(
|
yield self._handle_new_event(event, state=state)
|
||||||
e,
|
|
||||||
|
yield self.notifier.on_new_room_event(
|
||||||
|
event, extra_users=[joinee]
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.store.persist_event(
|
logger.debug("Finished joining %s to %s", joinee, room_id)
|
||||||
e,
|
|
||||||
backfilled=False,
|
|
||||||
is_new_state=True
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self.store.persist_event(
|
|
||||||
event,
|
|
||||||
backfilled=False,
|
|
||||||
is_new_state=True
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
room_queue = self.room_queues[room_id]
|
room_queue = self.room_queues[room_id]
|
||||||
del self.room_queues[room_id]
|
del self.room_queues[room_id]
|
||||||
|
|
||||||
for p in room_queue:
|
for p in room_queue:
|
||||||
try:
|
try:
|
||||||
yield self.on_receive_pdu(p, backfilled=False)
|
self.on_receive_pdu(p, backfilled=False)
|
||||||
except:
|
except:
|
||||||
pass
|
logger.exception("Couldn't handle pdu")
|
||||||
|
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
@ -375,7 +372,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
yield self.state_handler.annotate_event_with_state(event)
|
yield self.state_handler.annotate_event_with_state(event)
|
||||||
yield self.auth.add_auth_events(event)
|
yield self.auth.add_auth_events(event)
|
||||||
self.auth.check(event, raises=True)
|
self.auth.check(event, auth_events=event.old_state_events)
|
||||||
|
|
||||||
pdu = event
|
pdu = event
|
||||||
|
|
||||||
|
@ -391,17 +388,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
event.outlier = False
|
event.outlier = False
|
||||||
|
|
||||||
state_handler = self.state_handler
|
yield self._handle_new_event(event)
|
||||||
is_new_state = yield state_handler.annotate_event_with_state(event)
|
|
||||||
self.auth.check(event, raises=True)
|
|
||||||
|
|
||||||
# FIXME (erikj): All this is duplicated above :(
|
|
||||||
|
|
||||||
yield self.store.persist_event(
|
|
||||||
event,
|
|
||||||
backfilled=False,
|
|
||||||
is_new_state=is_new_state
|
|
||||||
)
|
|
||||||
|
|
||||||
extra_users = []
|
extra_users = []
|
||||||
if event.type == RoomMemberEvent.TYPE:
|
if event.type == RoomMemberEvent.TYPE:
|
||||||
|
@ -414,7 +401,7 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == RoomMemberEvent.TYPE:
|
if event.type == RoomMemberEvent.TYPE:
|
||||||
if event.membership == Membership.JOIN:
|
if event.content["membership"] == Membership.JOIN:
|
||||||
user = self.hs.parse_userid(event.state_key)
|
user = self.hs.parse_userid(event.state_key)
|
||||||
yield self.distributor.fire(
|
yield self.distributor.fire(
|
||||||
"user_joined_room", user=user, room_id=event.room_id
|
"user_joined_room", user=user, room_id=event.room_id
|
||||||
|
@ -565,3 +552,56 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
while waiters:
|
while waiters:
|
||||||
waiters.pop().callback(None)
|
waiters.pop().callback(None)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_new_event(self, event, state=None, backfilled=False):
|
||||||
|
is_new_state = yield self.state_handler.annotate_event_with_state(
|
||||||
|
event,
|
||||||
|
old_state=state
|
||||||
|
)
|
||||||
|
|
||||||
|
if event.old_state_events:
|
||||||
|
known_ids = set(
|
||||||
|
[s.event_id for s in event.old_state_events.values()]
|
||||||
|
)
|
||||||
|
for e_id, _ in event.auth_events:
|
||||||
|
if e_id not in known_ids:
|
||||||
|
e = yield self.store.get_event(
|
||||||
|
e_id,
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not e:
|
||||||
|
# TODO: Do some conflict res to make sure that we're
|
||||||
|
# not the ones who are wrong.
|
||||||
|
logger.info(
|
||||||
|
"Rejecting %s as %s not in %s",
|
||||||
|
event.event_id, e_id, known_ids,
|
||||||
|
)
|
||||||
|
raise AuthError(403, "Auth events are stale")
|
||||||
|
|
||||||
|
auth_events = event.old_state_events
|
||||||
|
else:
|
||||||
|
# We need to get the auth events from somewhere.
|
||||||
|
|
||||||
|
# TODO: Don't just hit the DBs?
|
||||||
|
|
||||||
|
auth_events = {}
|
||||||
|
for e_id, _ in event.auth_events:
|
||||||
|
e = yield self.store.get_event(
|
||||||
|
e_id,
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not e:
|
||||||
|
raise AuthError(403, "Can't find auth event.")
|
||||||
|
|
||||||
|
auth_events[(e.type, e.state_key)] = e
|
||||||
|
|
||||||
|
self.auth.check(event, auth_events=auth_events)
|
||||||
|
|
||||||
|
yield self.store.persist_event(
|
||||||
|
event,
|
||||||
|
backfilled=backfilled,
|
||||||
|
is_new_state=(is_new_state and not backfilled)
|
||||||
|
)
|
||||||
|
|
|
@ -306,7 +306,7 @@ class MessageHandler(BaseHandler):
|
||||||
auth_user = self.hs.parse_userid(user_id)
|
auth_user = self.hs.parse_userid(user_id)
|
||||||
|
|
||||||
# TODO: These concurrently
|
# TODO: These concurrently
|
||||||
state_tuples = yield self.store.get_current_state(room_id)
|
state_tuples = yield self.state_handler.get_current_state(room_id)
|
||||||
state = [self.hs.serialize_event(x) for x in state_tuples]
|
state = [self.hs.serialize_event(x) for x in state_tuples]
|
||||||
|
|
||||||
member_event = (yield self.store.get_room_member(
|
member_event = (yield self.store.get_room_member(
|
||||||
|
|
|
@ -83,20 +83,22 @@ class FederationTestCase(unittest.TestCase):
|
||||||
event_id="$a:b",
|
event_id="$a:b",
|
||||||
user_id="@a:b",
|
user_id="@a:b",
|
||||||
origin="b",
|
origin="b",
|
||||||
|
auth_events=[],
|
||||||
hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
|
hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore.persist_event.return_value = defer.succeed(None)
|
self.datastore.persist_event.return_value = defer.succeed(None)
|
||||||
self.datastore.get_room.return_value = defer.succeed(True)
|
self.datastore.get_room.return_value = defer.succeed(True)
|
||||||
|
|
||||||
self.state_handler.annotate_event_with_state.return_value = (
|
def annotate(ev, old_state=None):
|
||||||
defer.succeed(False)
|
ev.old_state_events = []
|
||||||
)
|
return defer.succeed(False)
|
||||||
|
self.state_handler.annotate_event_with_state.side_effect = annotate
|
||||||
|
|
||||||
yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
|
yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
|
||||||
|
|
||||||
self.datastore.persist_event.assert_called_once_with(
|
self.datastore.persist_event.assert_called_once_with(
|
||||||
ANY, False, is_new_state=False
|
ANY, is_new_state=False, backfilled=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.state_handler.annotate_event_with_state.assert_called_once_with(
|
self.state_handler.annotate_event_with_state.assert_called_once_with(
|
||||||
|
@ -104,7 +106,7 @@ class FederationTestCase(unittest.TestCase):
|
||||||
old_state=None,
|
old_state=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.auth.check.assert_called_once_with(ANY, raises=True)
|
self.auth.check.assert_called_once_with(ANY, auth_events={})
|
||||||
|
|
||||||
self.notifier.on_new_room_event.assert_called_once_with(
|
self.notifier.on_new_room_event.assert_called_once_with(
|
||||||
ANY,
|
ANY,
|
||||||
|
|
|
@ -120,7 +120,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.datastore.get_room_member.return_value = defer.succeed(None)
|
self.datastore.get_room_member.return_value = defer.succeed(None)
|
||||||
|
|
||||||
event.state_events = {
|
event.old_state_events = {
|
||||||
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
|
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
|
||||||
user_id="@alice:green",
|
user_id="@alice:green",
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
@ -129,9 +129,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
user_id="@bob:red",
|
user_id="@bob:red",
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
),
|
),
|
||||||
(RoomMemberEvent.TYPE, target_user_id): event,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
event.state_events = event.old_state_events
|
||||||
|
event.state_events[(RoomMemberEvent.TYPE, target_user_id)] = event
|
||||||
|
|
||||||
# Actual invocation
|
# Actual invocation
|
||||||
yield self.room_member_handler.change_membership(event)
|
yield self.room_member_handler.change_membership(event)
|
||||||
|
|
||||||
|
@ -187,6 +189,16 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
(RoomMemberEvent.TYPE, user_id): event,
|
(RoomMemberEvent.TYPE, user_id): event,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
event.old_state_events = {
|
||||||
|
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
|
||||||
|
user_id="@alice:green",
|
||||||
|
room_id=room_id,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
event.state_events = event.old_state_events
|
||||||
|
event.state_events[(RoomMemberEvent.TYPE, user_id)] = event
|
||||||
|
|
||||||
# Actual invocation
|
# Actual invocation
|
||||||
yield self.room_member_handler.change_membership(event)
|
yield self.room_member_handler.change_membership(event)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue