forked from MirrorHub/synapse
commit
7d897f5bfc
16 changed files with 708 additions and 172 deletions
|
@ -102,6 +102,8 @@ class Auth(object):
|
||||||
def check_host_in_room(self, room_id, host):
|
def check_host_in_room(self, room_id, host):
|
||||||
curr_state = yield self.state.get_current_state(room_id)
|
curr_state = yield self.state.get_current_state(room_id)
|
||||||
|
|
||||||
|
logger.debug("Got curr_state %s", curr_state)
|
||||||
|
|
||||||
for event in curr_state:
|
for event in curr_state:
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
try:
|
try:
|
||||||
|
@ -358,9 +360,23 @@ class Auth(object):
|
||||||
def add_auth_events(self, builder, context):
|
def add_auth_events(self, builder, context):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
if builder.type == EventTypes.Create:
|
auth_ids = self.compute_auth_events(builder, context)
|
||||||
builder.auth_events = []
|
|
||||||
return
|
auth_events_entries = yield self.store.add_event_hashes(
|
||||||
|
auth_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.auth_events = auth_events_entries
|
||||||
|
|
||||||
|
context.auth_events = {
|
||||||
|
k: v
|
||||||
|
for k, v in context.current_state.items()
|
||||||
|
if v.event_id in auth_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
def compute_auth_events(self, event, context):
|
||||||
|
if event.type == EventTypes.Create:
|
||||||
|
return []
|
||||||
|
|
||||||
auth_ids = []
|
auth_ids = []
|
||||||
|
|
||||||
|
@ -373,7 +389,7 @@ class Auth(object):
|
||||||
key = (EventTypes.JoinRules, "", )
|
key = (EventTypes.JoinRules, "", )
|
||||||
join_rule_event = context.current_state.get(key)
|
join_rule_event = context.current_state.get(key)
|
||||||
|
|
||||||
key = (EventTypes.Member, builder.user_id, )
|
key = (EventTypes.Member, event.user_id, )
|
||||||
member_event = context.current_state.get(key)
|
member_event = context.current_state.get(key)
|
||||||
|
|
||||||
key = (EventTypes.Create, "", )
|
key = (EventTypes.Create, "", )
|
||||||
|
@ -387,8 +403,8 @@ class Auth(object):
|
||||||
else:
|
else:
|
||||||
is_public = False
|
is_public = False
|
||||||
|
|
||||||
if builder.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
e_type = builder.content["membership"]
|
e_type = event.content["membership"]
|
||||||
if e_type in [Membership.JOIN, Membership.INVITE]:
|
if e_type in [Membership.JOIN, Membership.INVITE]:
|
||||||
if join_rule_event:
|
if join_rule_event:
|
||||||
auth_ids.append(join_rule_event.event_id)
|
auth_ids.append(join_rule_event.event_id)
|
||||||
|
@ -403,17 +419,7 @@ class Auth(object):
|
||||||
if member_event.content["membership"] == Membership.JOIN:
|
if member_event.content["membership"] == Membership.JOIN:
|
||||||
auth_ids.append(member_event.event_id)
|
auth_ids.append(member_event.event_id)
|
||||||
|
|
||||||
auth_events_entries = yield self.store.add_event_hashes(
|
return auth_ids
|
||||||
auth_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
builder.auth_events = auth_events_entries
|
|
||||||
|
|
||||||
context.auth_events = {
|
|
||||||
k: v
|
|
||||||
for k, v in context.current_state.items()
|
|
||||||
if v.event_id in auth_ids
|
|
||||||
}
|
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _can_send_event(self, event, auth_events):
|
def _can_send_event(self, event, auth_events):
|
||||||
|
|
|
@ -74,3 +74,9 @@ class EventTypes(object):
|
||||||
Message = "m.room.message"
|
Message = "m.room.message"
|
||||||
Topic = "m.room.topic"
|
Topic = "m.room.topic"
|
||||||
Name = "m.room.name"
|
Name = "m.room.name"
|
||||||
|
|
||||||
|
|
||||||
|
class RejectedReason(object):
|
||||||
|
AUTH_ERROR = "auth_error"
|
||||||
|
REPLACED = "replaced"
|
||||||
|
NOT_ANCESTOR = "not_ancestor"
|
||||||
|
|
|
@ -61,9 +61,11 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.remote_key = defer.Deferred()
|
self.remote_key = defer.Deferred()
|
||||||
|
self.host = None
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self):
|
||||||
logger.debug("Connected to %s", self.transport.getHost())
|
self.host = self.transport.getHost()
|
||||||
|
logger.debug("Connected to %s", self.host)
|
||||||
self.sendCommand(b"GET", b"/_matrix/key/v1/")
|
self.sendCommand(b"GET", b"/_matrix/key/v1/")
|
||||||
self.endHeaders()
|
self.endHeaders()
|
||||||
self.timer = reactor.callLater(
|
self.timer = reactor.callLater(
|
||||||
|
@ -92,8 +94,7 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||||
self.timer.cancel()
|
self.timer.cancel()
|
||||||
|
|
||||||
def on_timeout(self):
|
def on_timeout(self):
|
||||||
logger.debug("Timeout waiting for response from %s",
|
logger.debug("Timeout waiting for response from %s", self.host)
|
||||||
self.transport.getHost())
|
|
||||||
self.remote_key.errback(IOError("Timeout waiting for response"))
|
self.remote_key.errback(IOError("Timeout waiting for response"))
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ from synapse.util.frozenutils import freeze, unfreeze
|
||||||
|
|
||||||
class _EventInternalMetadata(object):
|
class _EventInternalMetadata(object):
|
||||||
def __init__(self, internal_metadata_dict):
|
def __init__(self, internal_metadata_dict):
|
||||||
self.__dict__ = internal_metadata_dict
|
self.__dict__ = dict(internal_metadata_dict)
|
||||||
|
|
||||||
def get_dict(self):
|
def get_dict(self):
|
||||||
return dict(self.__dict__)
|
return dict(self.__dict__)
|
||||||
|
|
|
@ -23,14 +23,15 @@ import copy
|
||||||
|
|
||||||
|
|
||||||
class EventBuilder(EventBase):
|
class EventBuilder(EventBase):
|
||||||
def __init__(self, key_values={}):
|
def __init__(self, key_values={}, internal_metadata_dict={}):
|
||||||
signatures = copy.deepcopy(key_values.pop("signatures", {}))
|
signatures = copy.deepcopy(key_values.pop("signatures", {}))
|
||||||
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
|
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
|
||||||
|
|
||||||
super(EventBuilder, self).__init__(
|
super(EventBuilder, self).__init__(
|
||||||
key_values,
|
key_values,
|
||||||
signatures=signatures,
|
signatures=signatures,
|
||||||
unsigned=unsigned
|
unsigned=unsigned,
|
||||||
|
internal_metadata_dict=internal_metadata_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
def build(self):
|
def build(self):
|
||||||
|
|
|
@ -45,12 +45,14 @@ def prune_event(event):
|
||||||
"membership",
|
"membership",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
event_dict = event.get_dict()
|
||||||
|
|
||||||
new_content = {}
|
new_content = {}
|
||||||
|
|
||||||
def add_fields(*fields):
|
def add_fields(*fields):
|
||||||
for field in fields:
|
for field in fields:
|
||||||
if field in event.content:
|
if field in event.content:
|
||||||
new_content[field] = event.content[field]
|
new_content[field] = event_dict["content"][field]
|
||||||
|
|
||||||
if event_type == EventTypes.Member:
|
if event_type == EventTypes.Member:
|
||||||
add_fields("membership")
|
add_fields("membership")
|
||||||
|
@ -75,7 +77,7 @@ def prune_event(event):
|
||||||
|
|
||||||
allowed_fields = {
|
allowed_fields = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in event.get_dict().items()
|
for k, v in event_dict.items()
|
||||||
if k in allowed_keys
|
if k in allowed_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,7 +88,10 @@ def prune_event(event):
|
||||||
if "age_ts" in event.unsigned:
|
if "age_ts" in event.unsigned:
|
||||||
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
|
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
|
||||||
|
|
||||||
return type(event)(allowed_fields)
|
return type(event)(
|
||||||
|
allowed_fields,
|
||||||
|
internal_metadata_dict=event.internal_metadata.get_dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def format_event_raw(d):
|
def format_event_raw(d):
|
||||||
|
|
|
@ -20,6 +20,13 @@ from .units import Edu
|
||||||
|
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
|
from synapse.events.utils import prune_event
|
||||||
|
|
||||||
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
|
||||||
|
from synapse.crypto.event_signing import check_event_content_hash
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -126,6 +133,11 @@ class FederationClient(object):
|
||||||
for p in transaction_data["pdus"]
|
for p in transaction_data["pdus"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
for i, pdu in enumerate(pdus):
|
||||||
|
pdus[i] = yield self._check_sigs_and_hash(pdu)
|
||||||
|
|
||||||
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
|
|
||||||
defer.returnValue(pdus)
|
defer.returnValue(pdus)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -159,12 +171,6 @@ class FederationClient(object):
|
||||||
transaction_data = yield self.transport_layer.get_event(
|
transaction_data = yield self.transport_layer.get_event(
|
||||||
destination, event_id
|
destination, event_id
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.info(
|
|
||||||
"Failed to get PDU %s from %s because %s",
|
|
||||||
event_id, destination, e,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.debug("transaction_data %r", transaction_data)
|
logger.debug("transaction_data %r", transaction_data)
|
||||||
|
|
||||||
|
@ -175,9 +181,19 @@ class FederationClient(object):
|
||||||
|
|
||||||
if pdu_list:
|
if pdu_list:
|
||||||
pdu = pdu_list[0]
|
pdu = pdu_list[0]
|
||||||
# TODO: We need to check signatures here
|
|
||||||
|
# Check signatures are correct.
|
||||||
|
pdu = yield self._check_sigs_and_hash(pdu)
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(
|
||||||
|
"Failed to get PDU %s from %s because %s",
|
||||||
|
event_id, destination, e,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
defer.returnValue(pdu)
|
defer.returnValue(pdu)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -208,6 +224,16 @@ class FederationClient(object):
|
||||||
for p in result.get("auth_chain", [])
|
for p in result.get("auth_chain", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
for i, pdu in enumerate(pdus):
|
||||||
|
pdus[i] = yield self._check_sigs_and_hash(pdu)
|
||||||
|
|
||||||
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
|
|
||||||
|
for i, pdu in enumerate(auth_chain):
|
||||||
|
auth_chain[i] = yield self._check_sigs_and_hash(pdu)
|
||||||
|
|
||||||
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
|
|
||||||
defer.returnValue((pdus, auth_chain))
|
defer.returnValue((pdus, auth_chain))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -222,6 +248,11 @@ class FederationClient(object):
|
||||||
for p in res["auth_chain"]
|
for p in res["auth_chain"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
for i, pdu in enumerate(auth_chain):
|
||||||
|
auth_chain[i] = yield self._check_sigs_and_hash(pdu)
|
||||||
|
|
||||||
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
|
|
||||||
auth_chain.sort(key=lambda e: e.depth)
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
defer.returnValue(auth_chain)
|
defer.returnValue(auth_chain)
|
||||||
|
@ -260,6 +291,16 @@ class FederationClient(object):
|
||||||
for p in content.get("auth_chain", [])
|
for p in content.get("auth_chain", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
for i, pdu in enumerate(state):
|
||||||
|
state[i] = yield self._check_sigs_and_hash(pdu)
|
||||||
|
|
||||||
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
|
|
||||||
|
for i, pdu in enumerate(auth_chain):
|
||||||
|
auth_chain[i] = yield self._check_sigs_and_hash(pdu)
|
||||||
|
|
||||||
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
|
|
||||||
auth_chain.sort(key=lambda e: e.depth)
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
|
@ -281,7 +322,48 @@ class FederationClient(object):
|
||||||
|
|
||||||
logger.debug("Got response to send_invite: %s", pdu_dict)
|
logger.debug("Got response to send_invite: %s", pdu_dict)
|
||||||
|
|
||||||
defer.returnValue(self.event_from_pdu_json(pdu_dict))
|
pdu = self.event_from_pdu_json(pdu_dict)
|
||||||
|
|
||||||
|
# Check signatures are correct.
|
||||||
|
pdu = yield self._check_sigs_and_hash(pdu)
|
||||||
|
|
||||||
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
|
|
||||||
|
defer.returnValue(pdu)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def query_auth(self, destination, room_id, event_id, local_auth):
|
||||||
|
"""
|
||||||
|
Params:
|
||||||
|
destination (str)
|
||||||
|
event_it (str)
|
||||||
|
local_auth (list)
|
||||||
|
"""
|
||||||
|
time_now = self._clock.time_msec()
|
||||||
|
|
||||||
|
send_content = {
|
||||||
|
"auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
|
||||||
|
}
|
||||||
|
|
||||||
|
code, content = yield self.transport_layer.send_query_auth(
|
||||||
|
destination=destination,
|
||||||
|
room_id=room_id,
|
||||||
|
event_id=event_id,
|
||||||
|
content=send_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
auth_chain = [
|
||||||
|
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
|
||||||
|
for e in content["auth_chain"]
|
||||||
|
]
|
||||||
|
|
||||||
|
ret = {
|
||||||
|
"auth_chain": auth_chain,
|
||||||
|
"rejects": content.get("rejects", []),
|
||||||
|
"missing": content.get("missing", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue(ret)
|
||||||
|
|
||||||
def event_from_pdu_json(self, pdu_json, outlier=False):
|
def event_from_pdu_json(self, pdu_json, outlier=False):
|
||||||
event = FrozenEvent(
|
event = FrozenEvent(
|
||||||
|
@ -291,3 +373,37 @@ class FederationClient(object):
|
||||||
event.internal_metadata.outlier = outlier
|
event.internal_metadata.outlier = outlier
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _check_sigs_and_hash(self, pdu):
|
||||||
|
"""Throws a SynapseError if the PDU does not have the correct
|
||||||
|
signatures.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FrozenEvent: Either the given event or it redacted if it failed the
|
||||||
|
content hash check.
|
||||||
|
"""
|
||||||
|
# Check signatures are correct.
|
||||||
|
redacted_event = prune_event(pdu)
|
||||||
|
redacted_pdu_json = redacted_event.get_pdu_json()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.keyring.verify_json_for_server(
|
||||||
|
pdu.origin, redacted_pdu_json
|
||||||
|
)
|
||||||
|
except SynapseError:
|
||||||
|
logger.warn(
|
||||||
|
"Signature check failed for %s redacted to %s",
|
||||||
|
encode_canonical_json(pdu.get_pdu_json()),
|
||||||
|
encode_canonical_json(redacted_pdu_json),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not check_event_content_hash(pdu):
|
||||||
|
logger.warn(
|
||||||
|
"Event content has been tampered, redacting %s, %s",
|
||||||
|
pdu.event_id, encode_canonical_json(pdu.get_dict())
|
||||||
|
)
|
||||||
|
defer.returnValue(redacted_event)
|
||||||
|
|
||||||
|
defer.returnValue(pdu)
|
||||||
|
|
|
@ -21,6 +21,13 @@ from .units import Transaction, Edu
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
|
from synapse.events.utils import prune_event
|
||||||
|
|
||||||
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
|
||||||
|
from synapse.crypto.event_signing import check_event_content_hash
|
||||||
|
|
||||||
|
from synapse.api.errors import FederationError, SynapseError
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -97,8 +104,10 @@ class FederationServer(object):
|
||||||
response = yield self.transaction_actions.have_responded(transaction)
|
response = yield self.transaction_actions.have_responded(transaction)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
logger.debug("[%s] We've already responed to this request",
|
logger.debug(
|
||||||
transaction.transaction_id)
|
"[%s] We've already responed to this request",
|
||||||
|
transaction.transaction_id
|
||||||
|
)
|
||||||
defer.returnValue(response)
|
defer.returnValue(response)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -221,6 +230,54 @@ class FederationServer(object):
|
||||||
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
|
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_query_auth_request(self, origin, content, event_id):
|
||||||
|
"""
|
||||||
|
Content is a dict with keys::
|
||||||
|
auth_chain (list): A list of events that give the auth chain.
|
||||||
|
missing (list): A list of event_ids indicating what the other
|
||||||
|
side (`origin`) think we're missing.
|
||||||
|
rejects (dict): A mapping from event_id to a 2-tuple of reason
|
||||||
|
string and a proof (or None) of why the event was rejected.
|
||||||
|
The keys of this dict give the list of events the `origin` has
|
||||||
|
rejected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin (str)
|
||||||
|
content (dict)
|
||||||
|
event_id (str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: Results in `dict` with the same format as `content`
|
||||||
|
"""
|
||||||
|
auth_chain = [
|
||||||
|
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
|
||||||
|
for e in content["auth_chain"]
|
||||||
|
]
|
||||||
|
|
||||||
|
missing = [
|
||||||
|
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
|
||||||
|
for e in content.get("missing", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
ret = yield self.handler.on_query_auth(
|
||||||
|
origin, event_id, auth_chain, content.get("rejects", []), missing
|
||||||
|
)
|
||||||
|
|
||||||
|
time_now = self._clock.time_msec()
|
||||||
|
send_content = {
|
||||||
|
"auth_chain": [
|
||||||
|
e.get_pdu_json(time_now)
|
||||||
|
for e in ret["auth_chain"]
|
||||||
|
],
|
||||||
|
"rejects": ret.get("rejects", []),
|
||||||
|
"missing": ret.get("missing", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue(
|
||||||
|
(200, send_content)
|
||||||
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
|
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
|
||||||
""" Get a PDU from the database with given origin and id.
|
""" Get a PDU from the database with given origin and id.
|
||||||
|
@ -253,6 +310,9 @@ class FederationServer(object):
|
||||||
origin, pdu.event_id, do_auth=False
|
origin, pdu.event_id, do_auth=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FIXME: Currently we fetch an event again when we already have it
|
||||||
|
# if it has been marked as an outlier.
|
||||||
|
|
||||||
already_seen = (
|
already_seen = (
|
||||||
existing and (
|
existing and (
|
||||||
not existing.internal_metadata.is_outlier()
|
not existing.internal_metadata.is_outlier()
|
||||||
|
@ -264,14 +324,27 @@ class FederationServer(object):
|
||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Check signature.
|
||||||
|
try:
|
||||||
|
pdu = yield self._check_sigs_and_hash(pdu)
|
||||||
|
except SynapseError as e:
|
||||||
|
raise FederationError(
|
||||||
|
"ERROR",
|
||||||
|
e.code,
|
||||||
|
e.msg,
|
||||||
|
affected=pdu.event_id,
|
||||||
|
)
|
||||||
|
|
||||||
state = None
|
state = None
|
||||||
|
|
||||||
auth_chain = []
|
auth_chain = []
|
||||||
|
|
||||||
have_seen = yield self.store.have_events(
|
have_seen = yield self.store.have_events(
|
||||||
[e for e, _ in pdu.prev_events]
|
[ev for ev, _ in pdu.prev_events]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
fetch_state = False
|
||||||
|
|
||||||
# Get missing pdus if necessary.
|
# Get missing pdus if necessary.
|
||||||
if not pdu.internal_metadata.is_outlier():
|
if not pdu.internal_metadata.is_outlier():
|
||||||
# We only backfill backwards to the min depth.
|
# We only backfill backwards to the min depth.
|
||||||
|
@ -308,12 +381,22 @@ class FederationServer(object):
|
||||||
logger.debug("Processed pdu %s", event_id)
|
logger.debug("Processed pdu %s", event_id)
|
||||||
else:
|
else:
|
||||||
logger.warn("Failed to get PDU %s", event_id)
|
logger.warn("Failed to get PDU %s", event_id)
|
||||||
|
fetch_state = True
|
||||||
except:
|
except:
|
||||||
# TODO(erikj): Do some more intelligent retries.
|
# TODO(erikj): Do some more intelligent retries.
|
||||||
logger.exception("Failed to get PDU")
|
logger.exception("Failed to get PDU")
|
||||||
|
fetch_state = True
|
||||||
else:
|
else:
|
||||||
# We need to get the state at this event, since we have reached
|
prevs = {e_id for e_id, _ in pdu.prev_events}
|
||||||
# a backward extremity edge.
|
seen = set(have_seen.keys())
|
||||||
|
if prevs - seen:
|
||||||
|
fetch_state = True
|
||||||
|
else:
|
||||||
|
fetch_state = True
|
||||||
|
|
||||||
|
if fetch_state:
|
||||||
|
# We need to get the state at this event, since we haven't
|
||||||
|
# processed all the prev events.
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"_handle_new_pdu getting state for %s",
|
"_handle_new_pdu getting state for %s",
|
||||||
pdu.room_id
|
pdu.room_id
|
||||||
|
@ -343,3 +426,37 @@ class FederationServer(object):
|
||||||
event.internal_metadata.outlier = outlier
|
event.internal_metadata.outlier = outlier
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _check_sigs_and_hash(self, pdu):
|
||||||
|
"""Throws a SynapseError if the PDU does not have the correct
|
||||||
|
signatures.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FrozenEvent: Either the given event or it redacted if it failed the
|
||||||
|
content hash check.
|
||||||
|
"""
|
||||||
|
# Check signatures are correct.
|
||||||
|
redacted_event = prune_event(pdu)
|
||||||
|
redacted_pdu_json = redacted_event.get_pdu_json()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.keyring.verify_json_for_server(
|
||||||
|
pdu.origin, redacted_pdu_json
|
||||||
|
)
|
||||||
|
except SynapseError:
|
||||||
|
logger.warn(
|
||||||
|
"Signature check failed for %s redacted to %s",
|
||||||
|
encode_canonical_json(pdu.get_pdu_json()),
|
||||||
|
encode_canonical_json(redacted_pdu_json),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not check_event_content_hash(pdu):
|
||||||
|
logger.warn(
|
||||||
|
"Event content has been tampered, redacting %s, %s",
|
||||||
|
pdu.event_id, encode_canonical_json(pdu.get_dict())
|
||||||
|
)
|
||||||
|
defer.returnValue(redacted_event)
|
||||||
|
|
||||||
|
defer.returnValue(pdu)
|
||||||
|
|
|
@ -51,6 +51,8 @@ class ReplicationLayer(FederationClient, FederationServer):
|
||||||
def __init__(self, hs, transport_layer):
|
def __init__(self, hs, transport_layer):
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
|
|
||||||
|
self.keyring = hs.get_keyring()
|
||||||
|
|
||||||
self.transport_layer = transport_layer
|
self.transport_layer = transport_layer
|
||||||
self.transport_layer.register_received_handler(self)
|
self.transport_layer.register_received_handler(self)
|
||||||
self.transport_layer.register_request_handler(self)
|
self.transport_layer.register_request_handler(self)
|
||||||
|
|
|
@ -213,3 +213,19 @@ class TransportLayerClient(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(response)
|
defer.returnValue(response)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def send_query_auth(self, destination, room_id, event_id, content):
|
||||||
|
path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
|
||||||
|
|
||||||
|
code, content = yield self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not 200 <= code < 300:
|
||||||
|
raise RuntimeError("Got %d from send_invite", code)
|
||||||
|
|
||||||
|
defer.returnValue(json.loads(content))
|
||||||
|
|
|
@ -42,7 +42,7 @@ class TransportLayerServer(object):
|
||||||
content = None
|
content = None
|
||||||
origin = None
|
origin = None
|
||||||
|
|
||||||
if request.method == "PUT":
|
if request.method in ["PUT", "POST"]:
|
||||||
# TODO: Handle other method types? other content types?
|
# TODO: Handle other method types? other content types?
|
||||||
try:
|
try:
|
||||||
content_bytes = request.content.read()
|
content_bytes = request.content.read()
|
||||||
|
@ -234,6 +234,16 @@ class TransportLayerServer(object):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.server.register_path(
|
||||||
|
"POST",
|
||||||
|
re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
|
||||||
|
self._with_authentication(
|
||||||
|
lambda origin, content, query, context, event_id:
|
||||||
|
self._on_query_auth_request(
|
||||||
|
origin, content, event_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -325,3 +335,12 @@ class TransportLayerServer(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, content))
|
defer.returnValue((200, content))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def _on_query_auth_request(self, origin, content, event_id):
|
||||||
|
new_content = yield self.request_handler.on_query_auth_request(
|
||||||
|
origin, content, event_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
|
@ -17,19 +17,16 @@
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
from synapse.events.utils import prune_event
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError, FederationError, SynapseError, StoreError,
|
AuthError, FederationError, StoreError,
|
||||||
)
|
)
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.crypto.event_signing import (
|
from synapse.crypto.event_signing import (
|
||||||
compute_event_signature, check_event_content_hash,
|
compute_event_signature, add_hashes_and_signatures,
|
||||||
add_hashes_and_signatures,
|
|
||||||
)
|
)
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from syutil.jsonutil import encode_canonical_json
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -113,33 +110,6 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
logger.debug("Processing event: %s", event.event_id)
|
logger.debug("Processing event: %s", event.event_id)
|
||||||
|
|
||||||
redacted_event = prune_event(event)
|
|
||||||
|
|
||||||
redacted_pdu_json = redacted_event.get_pdu_json()
|
|
||||||
try:
|
|
||||||
yield self.keyring.verify_json_for_server(
|
|
||||||
event.origin, redacted_pdu_json
|
|
||||||
)
|
|
||||||
except SynapseError as e:
|
|
||||||
logger.warn(
|
|
||||||
"Signature check failed for %s redacted to %s",
|
|
||||||
encode_canonical_json(pdu.get_pdu_json()),
|
|
||||||
encode_canonical_json(redacted_pdu_json),
|
|
||||||
)
|
|
||||||
raise FederationError(
|
|
||||||
"ERROR",
|
|
||||||
e.code,
|
|
||||||
e.msg,
|
|
||||||
affected=event.event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not check_event_content_hash(event):
|
|
||||||
logger.warn(
|
|
||||||
"Event content has been tampered, redacting %s, %s",
|
|
||||||
event.event_id, encode_canonical_json(event.get_dict())
|
|
||||||
)
|
|
||||||
event = redacted_event
|
|
||||||
|
|
||||||
logger.debug("Event: %s", event)
|
logger.debug("Event: %s", event)
|
||||||
|
|
||||||
# FIXME (erikj): Awful hack to make the case where we are not currently
|
# FIXME (erikj): Awful hack to make the case where we are not currently
|
||||||
|
@ -149,41 +119,20 @@ class FederationHandler(BaseHandler):
|
||||||
event.room_id,
|
event.room_id,
|
||||||
self.server_name
|
self.server_name
|
||||||
)
|
)
|
||||||
if not is_in_room and not event.internal_metadata.outlier:
|
if not is_in_room and not event.internal_metadata.is_outlier():
|
||||||
logger.debug("Got event for room we're not in.")
|
logger.debug("Got event for room we're not in.")
|
||||||
|
|
||||||
replication = self.replication_layer
|
|
||||||
|
|
||||||
if not state:
|
|
||||||
state, auth_chain = yield replication.get_state_for_room(
|
|
||||||
origin, context=event.room_id, event_id=event.event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not auth_chain:
|
|
||||||
auth_chain = yield replication.get_event_auth(
|
|
||||||
origin,
|
|
||||||
context=event.room_id,
|
|
||||||
event_id=event.event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
for e in auth_chain:
|
|
||||||
e.internal_metadata.outlier = True
|
|
||||||
try:
|
|
||||||
yield self._handle_new_event(e, fetch_auth_from=origin)
|
|
||||||
except:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to handle auth event %s",
|
|
||||||
e.event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
current_state = state
|
current_state = state
|
||||||
|
|
||||||
if state:
|
if state and auth_chain is not None:
|
||||||
for e in state:
|
for e in state:
|
||||||
logging.info("A :) %r", e)
|
|
||||||
e.internal_metadata.outlier = True
|
e.internal_metadata.outlier = True
|
||||||
try:
|
try:
|
||||||
yield self._handle_new_event(e)
|
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||||
|
auth = {
|
||||||
|
(e.type, e.state_key): e for e in auth_chain
|
||||||
|
if e.event_id in auth_ids
|
||||||
|
}
|
||||||
|
yield self._handle_new_event(origin, e, auth_events=auth)
|
||||||
except:
|
except:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to handle state event %s",
|
"Failed to handle state event %s",
|
||||||
|
@ -192,6 +141,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self._handle_new_event(
|
yield self._handle_new_event(
|
||||||
|
origin,
|
||||||
event,
|
event,
|
||||||
state=state,
|
state=state,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
|
@ -394,7 +344,14 @@ class FederationHandler(BaseHandler):
|
||||||
for e in auth_chain:
|
for e in auth_chain:
|
||||||
e.internal_metadata.outlier = True
|
e.internal_metadata.outlier = True
|
||||||
try:
|
try:
|
||||||
yield self._handle_new_event(e)
|
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||||
|
auth = {
|
||||||
|
(e.type, e.state_key): e for e in auth_chain
|
||||||
|
if e.event_id in auth_ids
|
||||||
|
}
|
||||||
|
yield self._handle_new_event(
|
||||||
|
target_host, e, auth_events=auth
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to handle auth event %s",
|
"Failed to handle auth event %s",
|
||||||
|
@ -405,8 +362,13 @@ class FederationHandler(BaseHandler):
|
||||||
# FIXME: Auth these.
|
# FIXME: Auth these.
|
||||||
e.internal_metadata.outlier = True
|
e.internal_metadata.outlier = True
|
||||||
try:
|
try:
|
||||||
|
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||||
|
auth = {
|
||||||
|
(e.type, e.state_key): e for e in auth_chain
|
||||||
|
if e.event_id in auth_ids
|
||||||
|
}
|
||||||
yield self._handle_new_event(
|
yield self._handle_new_event(
|
||||||
e, fetch_auth_from=target_host
|
target_host, e, auth_events=auth
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
|
@ -415,6 +377,7 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._handle_new_event(
|
yield self._handle_new_event(
|
||||||
|
target_host,
|
||||||
new_event,
|
new_event,
|
||||||
state=state,
|
state=state,
|
||||||
current_state=state,
|
current_state=state,
|
||||||
|
@ -481,7 +444,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
event.internal_metadata.outlier = False
|
event.internal_metadata.outlier = False
|
||||||
|
|
||||||
context = yield self._handle_new_event(event)
|
context = yield self._handle_new_event(origin, event)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
|
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
|
||||||
|
@ -682,11 +645,12 @@ class FederationHandler(BaseHandler):
|
||||||
waiters.pop().callback(None)
|
waiters.pop().callback(None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _handle_new_event(self, event, state=None, backfilled=False,
|
@log_function
|
||||||
current_state=None, fetch_auth_from=None):
|
def _handle_new_event(self, origin, event, state=None, backfilled=False,
|
||||||
|
current_state=None, auth_events=None):
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"_handle_new_event: Before annotate: %s, sigs: %s",
|
"_handle_new_event: %s, sigs: %s",
|
||||||
event.event_id, event.signatures,
|
event.event_id, event.signatures,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -694,65 +658,44 @@ class FederationHandler(BaseHandler):
|
||||||
event, old_state=state
|
event, old_state=state
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not auth_events:
|
||||||
|
auth_events = context.auth_events
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"_handle_new_event: Before auth fetch: %s, sigs: %s",
|
"_handle_new_event: %s, auth_events: %s",
|
||||||
event.event_id, event.signatures,
|
event.event_id, auth_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
is_new_state = not event.internal_metadata.is_outlier()
|
is_new_state = not event.internal_metadata.is_outlier()
|
||||||
|
|
||||||
known_ids = set(
|
# This is a hack to fix some old rooms where the initial join event
|
||||||
[s.event_id for s in context.auth_events.values()]
|
# didn't reference the create event in its auth events.
|
||||||
)
|
|
||||||
|
|
||||||
for e_id, _ in event.auth_events:
|
|
||||||
if e_id not in known_ids:
|
|
||||||
e = yield self.store.get_event(e_id, allow_none=True)
|
|
||||||
|
|
||||||
if not e and fetch_auth_from is not None:
|
|
||||||
# Grab the auth_chain over federation if we are missing
|
|
||||||
# auth events.
|
|
||||||
auth_chain = yield self.replication_layer.get_event_auth(
|
|
||||||
fetch_auth_from, event.event_id, event.room_id
|
|
||||||
)
|
|
||||||
for auth_event in auth_chain:
|
|
||||||
yield self._handle_new_event(auth_event)
|
|
||||||
e = yield self.store.get_event(e_id, allow_none=True)
|
|
||||||
|
|
||||||
if not e:
|
|
||||||
# TODO: Do some conflict res to make sure that we're
|
|
||||||
# not the ones who are wrong.
|
|
||||||
logger.info(
|
|
||||||
"Rejecting %s as %s not in db or %s",
|
|
||||||
event.event_id, e_id, known_ids,
|
|
||||||
)
|
|
||||||
# FIXME: How does raising AuthError work with federation?
|
|
||||||
raise AuthError(403, "Cannot find auth event")
|
|
||||||
|
|
||||||
context.auth_events[(e.type, e.state_key)] = e
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"_handle_new_event: Before hack: %s, sigs: %s",
|
|
||||||
event.event_id, event.signatures,
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.type == EventTypes.Member and not event.auth_events:
|
if event.type == EventTypes.Member and not event.auth_events:
|
||||||
if len(event.prev_events) == 1:
|
if len(event.prev_events) == 1:
|
||||||
c = yield self.store.get_event(event.prev_events[0][0])
|
c = yield self.store.get_event(event.prev_events[0][0])
|
||||||
if c.type == EventTypes.Create:
|
if c.type == EventTypes.Create:
|
||||||
context.auth_events[(c.type, c.state_key)] = c
|
auth_events[(c.type, c.state_key)] = c
|
||||||
|
|
||||||
logger.debug(
|
try:
|
||||||
"_handle_new_event: Before auth check: %s, sigs: %s",
|
yield self.do_auth(
|
||||||
event.event_id, event.signatures,
|
origin, event, context, auth_events=auth_events
|
||||||
|
)
|
||||||
|
except AuthError as e:
|
||||||
|
logger.warn(
|
||||||
|
"Rejecting %s because %s",
|
||||||
|
event.event_id, e.msg
|
||||||
)
|
)
|
||||||
|
|
||||||
self.auth.check(event, auth_events=context.auth_events)
|
context.rejected = RejectedReason.AUTH_ERROR
|
||||||
|
|
||||||
logger.debug(
|
yield self.store.persist_event(
|
||||||
"_handle_new_event: Before persist_event: %s, sigs: %s",
|
event,
|
||||||
event.event_id, event.signatures,
|
context=context,
|
||||||
|
backfilled=backfilled,
|
||||||
|
is_new_state=False,
|
||||||
|
current_state=current_state,
|
||||||
)
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
yield self.store.persist_event(
|
yield self.store.persist_event(
|
||||||
event,
|
event,
|
||||||
|
@ -762,9 +705,257 @@ class FederationHandler(BaseHandler):
|
||||||
current_state=current_state,
|
current_state=current_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
defer.returnValue(context)
|
||||||
"_handle_new_event: After persist_event: %s, sigs: %s",
|
|
||||||
event.event_id, event.signatures,
|
@defer.inlineCallbacks
|
||||||
|
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
|
||||||
|
missing):
|
||||||
|
# Just go through and process each event in `remote_auth_chain`. We
|
||||||
|
# don't want to fall into the trap of `missing` being wrong.
|
||||||
|
for e in remote_auth_chain:
|
||||||
|
try:
|
||||||
|
yield self._handle_new_event(origin, e)
|
||||||
|
except AuthError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Now get the current auth_chain for the event.
|
||||||
|
local_auth_chain = yield self.store.get_auth_chain([event_id])
|
||||||
|
|
||||||
|
# TODO: Check if we would now reject event_id. If so we need to tell
|
||||||
|
# everyone.
|
||||||
|
|
||||||
|
ret = yield self.construct_auth_difference(
|
||||||
|
local_auth_chain, remote_auth_chain
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(context)
|
logger.debug("on_query_auth reutrning: %s", ret)
|
||||||
|
|
||||||
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def do_auth(self, origin, event, context, auth_events):
|
||||||
|
# Check if we have all the auth events.
|
||||||
|
res = yield self.store.have_events(
|
||||||
|
[e_id for e_id, _ in event.auth_events]
|
||||||
|
)
|
||||||
|
|
||||||
|
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
||||||
|
seen_events = set(res.keys())
|
||||||
|
|
||||||
|
missing_auth = event_auth_events - seen_events
|
||||||
|
|
||||||
|
if missing_auth:
|
||||||
|
logger.debug("Missing auth: %s", missing_auth)
|
||||||
|
# If we don't have all the auth events, we need to get them.
|
||||||
|
remote_auth_chain = yield self.replication_layer.get_event_auth(
|
||||||
|
origin, event.room_id, event.event_id
|
||||||
|
)
|
||||||
|
|
||||||
|
for e in remote_auth_chain:
|
||||||
|
try:
|
||||||
|
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||||
|
auth = {
|
||||||
|
(e.type, e.state_key): e for e in remote_auth_chain
|
||||||
|
if e.event_id in auth_ids
|
||||||
|
}
|
||||||
|
e.internal_metadata.outlier = True
|
||||||
|
yield self._handle_new_event(
|
||||||
|
origin, e, auth_events=auth
|
||||||
|
)
|
||||||
|
auth_events[(e.type, e.state_key)] = e
|
||||||
|
except AuthError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# FIXME: Assumes we have and stored all the state for all the
|
||||||
|
# prev_events
|
||||||
|
current_state = set(e.event_id for e in auth_events.values())
|
||||||
|
different_auth = event_auth_events - current_state
|
||||||
|
|
||||||
|
if different_auth and not event.internal_metadata.is_outlier():
|
||||||
|
# Do auth conflict res.
|
||||||
|
logger.debug("Different auth: %s", different_auth)
|
||||||
|
|
||||||
|
# 1. Get what we think is the auth chain.
|
||||||
|
auth_ids = self.auth.compute_auth_events(event, context)
|
||||||
|
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
||||||
|
|
||||||
|
# 2. Get remote difference.
|
||||||
|
result = yield self.replication_layer.query_auth(
|
||||||
|
origin,
|
||||||
|
event.room_id,
|
||||||
|
event.event_id,
|
||||||
|
local_auth_chain,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Process any remote auth chain events we haven't seen.
|
||||||
|
for missing_id in result.get("missing", []):
|
||||||
|
try:
|
||||||
|
for e in result["auth_chain"]:
|
||||||
|
if e.event_id == missing_id:
|
||||||
|
ev = e
|
||||||
|
break
|
||||||
|
|
||||||
|
auth_ids = [e_id for e_id, _ in ev.auth_events]
|
||||||
|
auth = {
|
||||||
|
(e.type, e.state_key): e for e in result["auth_chain"]
|
||||||
|
if e.event_id in auth_ids
|
||||||
|
}
|
||||||
|
ev.internal_metadata.outlier = True
|
||||||
|
yield self._handle_new_event(
|
||||||
|
origin, ev, auth_events=auth
|
||||||
|
)
|
||||||
|
|
||||||
|
if ev.event_id in event_auth_events:
|
||||||
|
auth_events[(ev.type, ev.state_key)] = ev
|
||||||
|
except AuthError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 4. Look at rejects and their proofs.
|
||||||
|
# TODO.
|
||||||
|
|
||||||
|
context.current_state.update(auth_events)
|
||||||
|
context.state_group = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.auth.check(event, auth_events=auth_events)
|
||||||
|
except AuthError:
|
||||||
|
raise
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def construct_auth_difference(self, local_auth, remote_auth):
|
||||||
|
""" Given a local and remote auth chain, find the differences. This
|
||||||
|
assumes that we have already processed all events in remote_auth
|
||||||
|
|
||||||
|
Params:
|
||||||
|
local_auth (list)
|
||||||
|
remote_auth (list)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.debug("construct_auth_difference Start!")
|
||||||
|
|
||||||
|
# TODO: Make sure we are OK with local_auth or remote_auth having more
|
||||||
|
# auth events in them than strictly necessary.
|
||||||
|
|
||||||
|
def sort_fun(ev):
|
||||||
|
return ev.depth, ev.event_id
|
||||||
|
|
||||||
|
logger.debug("construct_auth_difference after sort_fun!")
|
||||||
|
|
||||||
|
# We find the differences by starting at the "bottom" of each list
|
||||||
|
# and iterating up on both lists. The lists are ordered by depth and
|
||||||
|
# then event_id, we iterate up both lists until we find the event ids
|
||||||
|
# don't match. Then we look at depth/event_id to see which side is
|
||||||
|
# missing that event, and iterate only up that list. Repeat.
|
||||||
|
|
||||||
|
remote_list = list(remote_auth)
|
||||||
|
remote_list.sort(key=sort_fun)
|
||||||
|
|
||||||
|
local_list = list(local_auth)
|
||||||
|
local_list.sort(key=sort_fun)
|
||||||
|
|
||||||
|
local_iter = iter(local_list)
|
||||||
|
remote_iter = iter(remote_list)
|
||||||
|
|
||||||
|
logger.debug("construct_auth_difference before get_next!")
|
||||||
|
|
||||||
|
def get_next(it, opt=None):
|
||||||
|
try:
|
||||||
|
return it.next()
|
||||||
|
except:
|
||||||
|
return opt
|
||||||
|
|
||||||
|
current_local = get_next(local_iter)
|
||||||
|
current_remote = get_next(remote_iter)
|
||||||
|
|
||||||
|
logger.debug("construct_auth_difference before while")
|
||||||
|
|
||||||
|
missing_remotes = []
|
||||||
|
missing_locals = []
|
||||||
|
while current_local or current_remote:
|
||||||
|
if current_remote is None:
|
||||||
|
missing_locals.append(current_local)
|
||||||
|
current_local = get_next(local_iter)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_local is None:
|
||||||
|
missing_remotes.append(current_remote)
|
||||||
|
current_remote = get_next(remote_iter)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_local.event_id == current_remote.event_id:
|
||||||
|
current_local = get_next(local_iter)
|
||||||
|
current_remote = get_next(remote_iter)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_local.depth < current_remote.depth:
|
||||||
|
missing_locals.append(current_local)
|
||||||
|
current_local = get_next(local_iter)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_local.depth > current_remote.depth:
|
||||||
|
missing_remotes.append(current_remote)
|
||||||
|
current_remote = get_next(remote_iter)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# They have the same depth, so we fall back to the event_id order
|
||||||
|
if current_local.event_id < current_remote.event_id:
|
||||||
|
missing_locals.append(current_local)
|
||||||
|
current_local = get_next(local_iter)
|
||||||
|
|
||||||
|
if current_local.event_id > current_remote.event_id:
|
||||||
|
missing_remotes.append(current_remote)
|
||||||
|
current_remote = get_next(remote_iter)
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.debug("construct_auth_difference after while")
|
||||||
|
|
||||||
|
# missing locals should be sent to the server
|
||||||
|
# We should find why we are missing remotes, as they will have been
|
||||||
|
# rejected.
|
||||||
|
|
||||||
|
# Remove events from missing_remotes if they are referencing a missing
|
||||||
|
# remote. We only care about the "root" rejected ones.
|
||||||
|
missing_remote_ids = [e.event_id for e in missing_remotes]
|
||||||
|
base_remote_rejected = list(missing_remotes)
|
||||||
|
for e in missing_remotes:
|
||||||
|
for e_id, _ in e.auth_events:
|
||||||
|
if e_id in missing_remote_ids:
|
||||||
|
base_remote_rejected.remove(e)
|
||||||
|
|
||||||
|
reason_map = {}
|
||||||
|
|
||||||
|
for e in base_remote_rejected:
|
||||||
|
reason = yield self.store.get_rejection_reason(e.event_id)
|
||||||
|
if reason is None:
|
||||||
|
# FIXME: ERRR?!
|
||||||
|
logger.warn("Could not find reason for %s", e.event_id)
|
||||||
|
raise RuntimeError("")
|
||||||
|
|
||||||
|
reason_map[e.event_id] = reason
|
||||||
|
|
||||||
|
if reason == RejectedReason.AUTH_ERROR:
|
||||||
|
pass
|
||||||
|
elif reason == RejectedReason.REPLACED:
|
||||||
|
# TODO: Get proof
|
||||||
|
pass
|
||||||
|
elif reason == RejectedReason.NOT_ANCESTOR:
|
||||||
|
# TODO: Get proof.
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.debug("construct_auth_difference returning")
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"auth_chain": local_auth,
|
||||||
|
"rejects": {
|
||||||
|
e.event_id: {
|
||||||
|
"reason": reason_map[e.event_id],
|
||||||
|
"proof": None,
|
||||||
|
}
|
||||||
|
for e in base_remote_rejected
|
||||||
|
},
|
||||||
|
"missing": [e.event_id for e in missing_locals],
|
||||||
|
})
|
||||||
|
|
|
@ -244,6 +244,43 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
defer.returnValue((response.code, body))
|
defer.returnValue((response.code, body))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def post_json(self, destination, path, data={}):
|
||||||
|
""" Sends the specifed json data using POST
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): The remote server to send the HTTP request
|
||||||
|
to.
|
||||||
|
path (str): The HTTP path.
|
||||||
|
data (dict): A dict containing the data that will be used as
|
||||||
|
the request body. This will be encoded as JSON.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||||
|
will be the decoded JSON body. On a 4xx or 5xx error response a
|
||||||
|
CodeMessageException is raised.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def body_callback(method, url_bytes, headers_dict):
|
||||||
|
self.sign_request(
|
||||||
|
destination, method, url_bytes, headers_dict, data
|
||||||
|
)
|
||||||
|
return _JsonProducer(data)
|
||||||
|
|
||||||
|
response = yield self._create_request(
|
||||||
|
destination.encode("ascii"),
|
||||||
|
"POST",
|
||||||
|
path.encode("ascii"),
|
||||||
|
body_callback=body_callback,
|
||||||
|
headers_dict={"Content-Type": ["application/json"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Getting resp body")
|
||||||
|
body = yield readBody(response)
|
||||||
|
logger.debug("Got resp body")
|
||||||
|
|
||||||
|
defer.returnValue((response.code, body))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
|
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
|
||||||
""" GETs some json from the given host homeserver and path
|
""" GETs some json from the given host homeserver and path
|
||||||
|
|
|
@ -168,10 +168,17 @@ class StateHandler(object):
|
||||||
first is the name of a state group if one and only one is involved,
|
first is the name of a state group if one and only one is involved,
|
||||||
otherwise `None`.
|
otherwise `None`.
|
||||||
"""
|
"""
|
||||||
|
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||||
|
|
||||||
state_groups = yield self.store.get_state_groups(
|
state_groups = yield self.store.get_state_groups(
|
||||||
event_ids
|
event_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"resolve_state_groups state_groups %s",
|
||||||
|
state_groups.keys()
|
||||||
|
)
|
||||||
|
|
||||||
group_names = set(state_groups.keys())
|
group_names = set(state_groups.keys())
|
||||||
if len(group_names) == 1:
|
if len(group_names) == 1:
|
||||||
name, state_list = state_groups.items().pop()
|
name, state_list = state_groups.items().pop()
|
||||||
|
|
|
@ -28,6 +28,16 @@ class RejectionsStore(SQLBaseStore):
|
||||||
values={
|
values={
|
||||||
"event_id": event_id,
|
"event_id": event_id,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
"last_failure": self._clock.time_msec(),
|
"last_check": self._clock.time_msec(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_rejection_reason(self, event_id):
|
||||||
|
return self._simple_select_one_onecol(
|
||||||
|
table="rejections",
|
||||||
|
retcol="reason",
|
||||||
|
keyvalues={
|
||||||
|
"event_id": event_id,
|
||||||
|
},
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
|
@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
|
||||||
"get_room",
|
"get_room",
|
||||||
"get_destination_retry_timings",
|
"get_destination_retry_timings",
|
||||||
"set_destination_retry_timings",
|
"set_destination_retry_timings",
|
||||||
|
"have_events",
|
||||||
]),
|
]),
|
||||||
resource_for_federation=NonCallableMock(),
|
resource_for_federation=NonCallableMock(),
|
||||||
http_client=NonCallableMock(spec_set=[]),
|
http_client=NonCallableMock(spec_set=[]),
|
||||||
|
@ -90,6 +91,7 @@ class FederationTestCase(unittest.TestCase):
|
||||||
self.datastore.persist_event.return_value = defer.succeed(None)
|
self.datastore.persist_event.return_value = defer.succeed(None)
|
||||||
self.datastore.get_room.return_value = defer.succeed(True)
|
self.datastore.get_room.return_value = defer.succeed(True)
|
||||||
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
||||||
|
self.datastore.have_events.return_value = defer.succeed({})
|
||||||
|
|
||||||
def annotate(ev, old_state=None):
|
def annotate(ev, old_state=None):
|
||||||
context = Mock()
|
context = Mock()
|
||||||
|
|
Loading…
Reference in a new issue