forked from MirrorHub/synapse
Include hashes of previous pdus when referencing them
Commit bb04447c44 (parent 66104da10c)
11 changed files with 95 additions and 31 deletions

@@ -65,13 +65,13 @@ class SynapseEvent(JsonEncodedObject):
     internal_keys = [
         "is_state",
-        "prev_events",
         "depth",
         "destinations",
         "origin",
         "outlier",
         "power_level",
         "redacted",
+        "prev_pdus",
     ]
 
     required_keys = [

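Note: with this rename, an event's parents move from `prev_events` (encoded event-id strings) to `prev_pdus`. As the hunks below show, each entry becomes a `(pdu_id, origin, hashes)` triple, where `hashes` maps algorithm names to base64-encoded digests. A minimal sketch of the new shape, with made-up identifiers and hash values:

    # Illustrative only: each previous-PDU reference now carries
    # per-algorithm content hashes, base64-encoded for the wire.
    prev_pdus = [
        ("pdu_id_1", "example.com", {"sha256": "AQID"}),
        ("pdu_id_2", "example.org", {"sha256": "BAUG"}),
    ]
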
@@ -45,9 +45,7 @@ class PduCodec(object):
         kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
         kwargs["room_id"] = pdu.context
         kwargs["etype"] = pdu.pdu_type
-        kwargs["prev_events"] = [
-            encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
-        ]
+        kwargs["prev_pdus"] = pdu.prev_pdus
 
         if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
             kwargs["prev_state"] = encode_event_id(

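Note: the codec previously rewrote every `(pdu_id, origin)` pair into an event-id string on the way in and decoded it on the way out; events now carry the raw PDU references. A sketch of the round-trip this removes, assuming the `pdu_id@origin` format that appears in the `Snapshot` hunk further down (the real `encode_event_id`/`decode_event_id` helpers live elsewhere in the tree):

    # Rough stand-ins for the helpers this hunk stops calling; the exact
    # format is an assumption based on the "%s@%s" literal deleted below.
    def encode_event_id(pdu_id, origin):
        return "%s@%s" % (pdu_id, origin)

    def decode_event_id(event_id, server_name):
        pdu_id, _, origin = event_id.partition("@")
        return (pdu_id, origin or server_name)

    result = decode_event_id(encode_event_id("p1", "example.com"), "me")
    assert result == ("p1", "example.com")
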
@@ -78,11 +76,8 @@ class PduCodec(object):
         d["context"] = event.room_id
         d["pdu_type"] = event.type
 
-        if hasattr(event, "prev_events"):
-            d["prev_pdus"] = [
-                decode_event_id(e, self.server_name)
-                for e in event.prev_events
-            ]
+        if hasattr(event, "prev_pdus"):
+            d["prev_pdus"] = event.prev_pdus
 
         if hasattr(event, "prev_state"):
             d["prev_state_id"], d["prev_state_origin"] = (

@@ -95,7 +90,7 @@ class PduCodec(object):
         kwargs = copy.deepcopy(event.unrecognized_keys)
         kwargs.update({
             k: v for k, v in d.items()
-            if k not in ["event_id", "room_id", "type", "prev_events"]
+            if k not in ["event_id", "room_id", "type"]
         })
 
         if "ts" not in kwargs:

@@ -443,7 +443,7 @@ class ReplicationLayer(object):
         min_depth = yield self.store.get_min_depth_for_context(pdu.context)
 
         if min_depth and pdu.depth > min_depth:
-            for pdu_id, origin in pdu.prev_pdus:
+            for pdu_id, origin, hashes in pdu.prev_pdus:
                 exists = yield self._get_persisted_pdu(pdu_id, origin)
 
                 if not exists:

@@ -141,8 +141,16 @@ class Pdu(JsonEncodedObject):
                 for kid, sig in pdu_tuple.signatures.items()
             }
 
+            prev_pdus = []
+            for prev_pdu in pdu_tuple.prev_pdu_list:
+                prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {})
+                prev_hashes = {
+                    alg: encode_base64(hsh) for alg, hsh in prev_hashes.items()
+                }
+                prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes))
+
             return Pdu(
-                prev_pdus=pdu_tuple.prev_pdu_list,
+                prev_pdus=prev_pdus,
                 **args
             )
         else:

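Note: `from_pdu_tuple` now decorates each previous-PDU reference with its stored edge hashes, base64-encoding the raw bytes for the JSON wire format. A self-contained sketch of the `syutil.base64util` helpers used here and in the storage hunk below; the unpadded-base64 behaviour is an assumption of this sketch:

    import base64

    def encode_base64(input_bytes):
        # Assumed behaviour: standard base64 with padding stripped.
        return base64.b64encode(input_bytes).rstrip(b"=").decode("ascii")

    def decode_base64(input_string):
        # Re-add whatever padding is missing before decoding.
        return base64.b64decode(input_string + "=" * (-len(input_string) % 4))

    raw = b"\x01\x02\x03"
    assert encode_base64(raw) == "AQID"
    assert decode_base64(encode_base64(raw)) == raw
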
@@ -72,10 +72,6 @@ class StateHandler(object):
 
         snapshot.fill_out_prev_events(event)
 
-        event.prev_events = [
-            e for e in event.prev_events if e != event.event_id
-        ]
-
         current_state = snapshot.prev_state_pdu
 
         if current_state:

@@ -177,6 +177,14 @@ class DataStore(RoomMemberStore, RoomStore,
                 txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
             )
 
+        for prev_pdu_id, prev_origin, prev_hashes in pdu.prev_pdus:
+            for alg, hash_base64 in prev_hashes.items():
+                hash_bytes = decode_base64(hash_base64)
+                self._store_prev_pdu_hash_txn(
+                    txn, pdu.pdu_id, pdu.origin, prev_pdu_id, prev_origin, alg,
+                    hash_bytes
+                )
+
         if pdu.is_state:
             self._persist_state_txn(txn, pdu.prev_pdus, cols)
         else:

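Note: on the write side each base64 hash is decoded back to raw bytes and lands as one `pdu_edge_hashes` row per (previous PDU, algorithm) pair; `_store_prev_pdu_hash_txn` itself is added in the signatures hunk below. A sketch of the fan-out with made-up values:

    # One stored row per referenced PDU and hash algorithm.
    pdu_id, origin = "p2", "example.com"
    prev_pdus = [("p1", "example.com", {"sha256": "AQID"})]

    rows = [
        (pdu_id, origin, prev_pdu_id, prev_origin, alg)
        for prev_pdu_id, prev_origin, hashes in prev_pdus
        for alg in hashes
    ]
    print(rows)  # [('p2', 'example.com', 'p1', 'example.com', 'sha256')]
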
@@ -352,6 +360,7 @@ class DataStore(RoomMemberStore, RoomStore,
         prev_pdus = self._get_latest_pdus_in_context(
             txn, room_id
         )
+
         if state_type is not None and state_key is not None:
             prev_state_pdu = self._get_current_state_pdu(
                 txn, room_id, state_type, state_key

@@ -401,17 +410,16 @@ class Snapshot(object):
         self.prev_state_pdu = prev_state_pdu
 
     def fill_out_prev_events(self, event):
-        if hasattr(event, "prev_events"):
+        if hasattr(event, "prev_pdus"):
             return
 
-        es = [
-            "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
+        event.prev_pdus = [
+            (p_id, origin, hashes)
+            for p_id, origin, hashes, _ in self.prev_pdus
         ]
 
-        event.prev_events = [e for e in es if e != event.event_id]
-
         if self.prev_pdus:
-            event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
+            event.depth = max([int(v) for _, _, _, v in self.prev_pdus]) + 1
         else:
             event.depth = 0
 
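Note: `fill_out_prev_events` now populates `prev_pdus` triples directly (the self-reference filter goes away along with the event-id strings) and reads depth from the fourth element of the extremity tuples. A worked example of the depth rule, one more than the deepest parent:

    # (pdu_id, origin, hashes, depth) tuples, the shape
    # _get_latest_pdus_in_context returns after this commit (values made up).
    prev_pdus = [
        ("p1", "example.com", {"sha256": "AQID"}, 2),
        ("p2", "example.org", {"sha256": "BAUG"}, 5),
    ]

    event_prev_pdus = [
        (p_id, origin, hashes) for p_id, origin, hashes, _ in prev_pdus
    ]
    depth = max([int(v) for _, _, _, v in prev_pdus]) + 1
    assert depth == 6
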
@@ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper
 from synapse.federation.units import Pdu
 from synapse.util.logutils import log_function
 
+from syutil.base64util import encode_base64
+
 from collections import namedtuple
 
 import logging
 
 
 logger = logging.getLogger(__name__)

@@ -64,6 +67,8 @@ class PduStore(SQLBaseStore):
             for r in PduEdgesTable.decode_results(txn.fetchall())
         ]
 
+        edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
+
         hashes = self._get_pdu_hashes_txn(txn, pdu_id, origin)
         signatures = self._get_pdu_origin_signatures_txn(
             txn, pdu_id, origin

@@ -86,7 +91,7 @@ class PduStore(SQLBaseStore):
             row = txn.fetchone()
             if row:
                 results.append(PduTuple(
-                    PduEntry(*row), edges, hashes, signatures
+                    PduEntry(*row), edges, hashes, signatures, edge_hashes
                 ))
 
         return results

@@ -310,9 +315,14 @@ class PduStore(SQLBaseStore):
             (context, )
         )
 
-        results = txn.fetchall()
+        results = []
+        for pdu_id, origin, depth in txn.fetchall():
+            hashes = self._get_pdu_hashes_txn(txn, pdu_id, origin)
+            sha256_bytes = hashes["sha256"]
+            prev_hashes = {"sha256": encode_base64(sha256_bytes)}
+            results.append((pdu_id, origin, prev_hashes, depth))
 
-        return [(row[0], row[1], row[2]) for row in results]
+        return results
 
     @defer.inlineCallbacks
     def get_oldest_pdus_in_context(self, context):

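Note: `_get_latest_pdus_in_context` now attaches each forward extremity's own base64 sha256, which is what feeds the triples above. A sketch of the reshaping with the database reads replaced by literals (it assumes, as the code does, that `pdu_hashes` already holds a raw sha256 for every PDU):

    import base64

    fetched = [("p1", "example.com", 3)]  # (pdu_id, origin, depth) rows
    stored = {("p1", "example.com"): {"sha256": b"\x01\x02\x03"}}

    results = []
    for pdu_id, origin, depth in fetched:
        sha256_bytes = stored[(pdu_id, origin)]["sha256"]
        b64 = base64.b64encode(sha256_bytes).rstrip(b"=").decode("ascii")
        results.append((pdu_id, origin, {"sha256": b64}, depth))

    print(results)  # [('p1', 'example.com', {'sha256': 'AQID'}, 3)]
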
@@ -431,7 +441,7 @@ class PduStore(SQLBaseStore):
             "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
             % PduForwardExtremitiesTable.table_name
         )
-        txn.executemany(query, prev_pdus)
+        txn.executemany(query, list(p[:2] for p in prev_pdus))
 
         # We only insert as a forward extremety the new pdu if there are no
         # other pdus that reference it as a prev pdu

@@ -454,7 +464,7 @@ class PduStore(SQLBaseStore):
         # deleted in a second if they're incorrect anyway.
         txn.executemany(
             PduBackwardExtremitiesTable.insert_statement(),
-            [(i, o, context) for i, o in prev_pdus]
+            [(i, o, context) for i, o, _ in prev_pdus]
         )
 
         # Also delete from the backwards extremities table all ones that

@@ -915,7 +925,7 @@ This does not include a prev_pdus key.
 
 PduTuple = namedtuple(
     "PduTuple",
-    ("pdu_entry", "prev_pdu_list", "hashes", "signatures")
+    ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
 )
 """ This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
 the `prev_pdus` key of a PDU.

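Note: `PduTuple` grows a fifth field, `edge_hashes`, keyed the way `_get_prev_pdu_hashes_txn` (added below) returns it: `(prev_pdu_id, prev_origin) -> {algorithm: hash_bytes}`. A quick sketch of constructing one, with `pdu_entry` elided and made-up values:

    from collections import namedtuple

    PduTuple = namedtuple(
        "PduTuple",
        ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
    )

    pt = PduTuple(
        pdu_entry=None,
        prev_pdu_list=[("p1", "example.com")],
        hashes={},
        signatures={},
        edge_hashes={("p1", "example.com"): {"sha256": b"\x01\x02\x03"}},
    )
    print(pt.edge_hashes[("p1", "example.com")])  # {'sha256': b'\x01\x02\x03'}
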
@@ -34,3 +34,19 @@ CREATE TABLE IF NOT EXISTS pdu_origin_signatures (
 CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures (
     pdu_id, origin
 );
+
+CREATE TABLE IF NOT EXISTS pdu_edge_hashes(
+    pdu_id TEXT,
+    origin TEXT,
+    prev_pdu_id TEXT,
+    prev_origin TEXT,
+    algorithm TEXT,
+    hash BLOB,
+    CONSTRAINT uniqueness UNIQUE (
+        pdu_id, origin, prev_pdu_id, prev_origin, algorithm
+    )
+);
+
+CREATE INDEX IF NOT EXISTS pdu_edge_hashes_id ON pdu_edge_hashes(
+    pdu_id, origin
+);

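Note: the `uniqueness` constraint caps the table at one hash per (pdu, prev pdu, algorithm) edge. A minimal standalone sqlite3 check of the new schema:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("""
        CREATE TABLE IF NOT EXISTS pdu_edge_hashes(
            pdu_id TEXT, origin TEXT, prev_pdu_id TEXT, prev_origin TEXT,
            algorithm TEXT, hash BLOB,
            CONSTRAINT uniqueness UNIQUE (
                pdu_id, origin, prev_pdu_id, prev_origin, algorithm
            )
        )
    """)
    row = ("p2", "example.com", "p1", "example.com", "sha256", b"\x01\x02\x03")
    conn.execute("INSERT INTO pdu_edge_hashes VALUES (?, ?, ?, ?, ?, ?)", row)
    try:
        conn.execute("INSERT INTO pdu_edge_hashes VALUES (?, ?, ?, ?, ?, ?)", row)
    except sqlite3.IntegrityError as e:
        print("duplicate edge hash rejected:", e)
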
@@ -88,3 +88,34 @@ class SignatureStore(SQLBaseStore):
             "signature": buffer(signature_bytes),
         })
 
+    def _get_prev_pdu_hashes_txn(self, txn, pdu_id, origin):
+        """Get all the hashes for previous PDUs of a PDU
+        Args:
+            txn (cursor):
+            pdu_id (str): Id of the PDU.
+            origin (str): Origin of the PDU.
+        Returns:
+            dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
+        """
+        query = (
+            "SELECT prev_pdu_id, prev_origin, algorithm, hash"
+            " FROM pdu_edge_hashes"
+            " WHERE pdu_id = ? and origin = ?"
+        )
+        txn.execute(query, (pdu_id, origin))
+        results = {}
+        for prev_pdu_id, prev_origin, algorithm, hash_bytes in txn.fetchall():
+            hashes = results.setdefault((prev_pdu_id, prev_origin), {})
+            hashes[algorithm] = hash_bytes
+        return results
+
+    def _store_prev_pdu_hash_txn(self, txn, pdu_id, origin, prev_pdu_id,
+                                 prev_origin, algorithm, hash_bytes):
+        self._simple_insert_txn(txn, "pdu_edge_hashes", {
+            "pdu_id": pdu_id,
+            "origin": origin,
+            "prev_pdu_id": prev_pdu_id,
+            "prev_origin": prev_origin,
+            "algorithm": algorithm,
+            "hash": buffer(hash_bytes),
+        })

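Note: `buffer(...)` (also used for signatures above) dates this as Python 2; it wraps the raw bytes so sqlite stores a BLOB. Reading back groups the rows per previous PDU, exactly as `_get_prev_pdu_hashes_txn` does; this continues the sqlite3 sketch above and reuses its `conn`:

    results = {}
    for prev_pdu_id, prev_origin, algorithm, hash_bytes in conn.execute(
        "SELECT prev_pdu_id, prev_origin, algorithm, hash FROM pdu_edge_hashes"
        " WHERE pdu_id = ? AND origin = ?", ("p2", "example.com")
    ):
        results.setdefault((prev_pdu_id, prev_origin), {})[algorithm] = hash_bytes

    print(results)  # {('p1', 'example.com'): {'sha256': b'\x01\x02\x03'}}
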
@@ -41,7 +41,7 @@ def make_pdu(prev_pdus=[], **kwargs):
     }
     pdu_fields.update(kwargs)
 
-    return PduTuple(PduEntry(**pdu_fields), prev_pdus, {}, {})
+    return PduTuple(PduEntry(**pdu_fields), prev_pdus, {}, {}, {})
 
 
 class FederationTestCase(unittest.TestCase):

@@ -88,7 +88,7 @@ class PduCodecTestCase(unittest.TestCase):
         self.assertEquals(pdu.context, event.room_id)
         self.assertEquals(pdu.is_state, event.is_state)
         self.assertEquals(pdu.depth, event.depth)
-        self.assertEquals(["alice@bob.com"], event.prev_events)
+        self.assertEquals(pdu.prev_pdus, event.prev_pdus)
         self.assertEquals(pdu.content, event.content)
 
     def test_pdu_from_event(self):

@@ -144,7 +144,7 @@ class PduCodecTestCase(unittest.TestCase):
         self.assertEquals(pdu.context, event.room_id)
         self.assertEquals(pdu.is_state, event.is_state)
         self.assertEquals(pdu.depth, event.depth)
-        self.assertEquals(["alice@bob.com"], event.prev_events)
+        self.assertEquals(pdu.prev_pdus, event.prev_pdus)
         self.assertEquals(pdu.content, event.content)
         self.assertEquals(pdu.state_key, event.state_key)