0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 03:03:50 +01:00

Remove PDU tables.

This commit is contained in:
Erik Johnston 2014-10-31 14:00:14 +00:00
parent 946d02536b
commit bfa36a72b9
6 changed files with 2 additions and 1230 deletions

View file

@ -32,76 +32,6 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PduActions(object):
""" Defines persistence actions that relate to handling PDUs.
"""
def __init__(self, datastore):
self.store = datastore
@log_function
def mark_as_processed(self, pdu):
""" Persist the fact that we have fully processed the given `Pdu`
Returns:
Deferred
"""
return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
@defer.inlineCallbacks
@log_function
def after_transaction(self, transaction_id, destination, origin):
""" Returns all `Pdu`s that we sent to the given remote home server
after a given transaction id.
Returns:
Deferred: Results in a list of `Pdu`s
"""
results = yield self.store.get_pdus_after_transaction(
transaction_id,
destination
)
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
@defer.inlineCallbacks
@log_function
def get_all_pdus_from_context(self, context):
results = yield self.store.get_all_pdus_from_context(context)
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
@defer.inlineCallbacks
@log_function
def backfill(self, context, pdu_list, limit):
""" For a given list of PDU id and origins return the proceeding
`limit` `Pdu`s in the given `context`.
Returns:
Deferred: Results in a list of `Pdu`s.
"""
results = yield self.store.get_backfill(
context, pdu_list, limit
)
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
@log_function
def is_new(self, pdu):
""" When we receive a `Pdu` from a remote home server, we want to
figure out whether it is `new`, i.e. it is not some historic PDU that
we haven't seen simply because we haven't backfilled back that far.
Returns:
Deferred: Results in a `bool`
"""
return self.store.is_pdu_new(
pdu_id=pdu.pdu_id,
origin=pdu.origin,
context=pdu.context,
depth=pdu.depth
)
class TransactionActions(object): class TransactionActions(object):
""" Defines persistence actions that relate to handling Transactions. """ Defines persistence actions that relate to handling Transactions.
""" """

View file

@ -21,7 +21,7 @@ from twisted.internet import defer
from .units import Transaction, Pdu, Edu from .units import Transaction, Pdu, Edu
from .persistence import PduActions, TransactionActions from .persistence import TransactionActions
from synapse.util.logutils import log_function from synapse.util.logutils import log_function

View file

@ -37,7 +37,6 @@ from .registration import RegistrationStore
from .room import RoomStore from .room import RoomStore
from .roommember import RoomMemberStore from .roommember import RoomMemberStore
from .stream import StreamStore from .stream import StreamStore
from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore from .transactions import TransactionStore
from .keys import KeyStore from .keys import KeyStore
from .event_federation import EventFederationStore from .event_federation import EventFederationStore
@ -60,7 +59,6 @@ logger = logging.getLogger(__name__)
SCHEMAS = [ SCHEMAS = [
"transactions", "transactions",
"pdu",
"users", "users",
"profiles", "profiles",
"presence", "presence",
@ -89,7 +87,7 @@ class _RollbackButIsFineException(Exception):
class DataStore(RoomMemberStore, RoomStore, class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, PduStore, StatePduStore, TransactionStore, PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore, DirectoryStore, KeyStore, StateStore, SignatureStore,
EventFederationStore, ): EventFederationStore, ):
@ -150,68 +148,12 @@ class DataStore(RoomMemberStore, RoomStore,
def _persist_pdu_event_txn(self, txn, pdu=None, event=None, def _persist_pdu_event_txn(self, txn, pdu=None, event=None,
backfilled=False, stream_ordering=None, backfilled=False, stream_ordering=None,
is_new_state=True): is_new_state=True):
if pdu is not None:
self._persist_event_pdu_txn(txn, pdu)
if event is not None: if event is not None:
return self._persist_event_txn( return self._persist_event_txn(
txn, event, backfilled, stream_ordering, txn, event, backfilled, stream_ordering,
is_new_state=is_new_state, is_new_state=is_new_state,
) )
def _persist_event_pdu_txn(self, txn, pdu):
cols = dict(pdu.__dict__)
unrec_keys = dict(pdu.unrecognized_keys)
del cols["hashes"]
del cols["signatures"]
del cols["content"]
del cols["prev_pdus"]
cols["content_json"] = json.dumps(pdu.content)
unrec_keys.update({
k: v for k, v in cols.items()
if k not in PdusTable.fields
})
cols["unrecognized_keys"] = json.dumps(unrec_keys)
cols["ts"] = cols.pop("origin_server_ts")
logger.debug("Persisting: %s", repr(cols))
for hash_alg, hash_base64 in pdu.hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_pdu_content_hash_txn(
txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes,
)
signatures = pdu.signatures.get(pdu.origin, {})
for key_id, signature_base64 in signatures.items():
signature_bytes = decode_base64(signature_base64)
self._store_pdu_origin_signature_txn(
txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
)
for prev_pdu_id, prev_origin, prev_hashes in pdu.prev_pdus:
for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_prev_pdu_hash_txn(
txn, pdu.pdu_id, pdu.origin, prev_pdu_id, prev_origin, alg,
hash_bytes
)
(ref_alg, ref_hash_bytes) = compute_pdu_event_reference_hash(pdu)
self._store_pdu_reference_hash_txn(
txn, pdu.pdu_id, pdu.origin, ref_alg, ref_hash_bytes
)
if pdu.is_state:
self._persist_state_txn(txn, pdu.prev_pdus, cols)
else:
self._persist_pdu_txn(txn, pdu.prev_pdus, cols)
self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth)
@log_function @log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
is_new_state=True): is_new_state=True):

View file

@ -1,949 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore, Table, JoinHelper
from synapse.federation.units import Pdu
from synapse.util.logutils import log_function
from syutil.base64util import encode_base64
from collections import namedtuple
import logging
logger = logging.getLogger(__name__)
class PduStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
def get_pdu(self, pdu_id, origin):
"""Given a pdu_id and origin, get a PDU.
Args:
txn
pdu_id (str)
origin (str)
Returns:
PduTuple: If the pdu does not exist in the database, returns None
"""
return self.runInteraction(
"get_pdu", self._get_pdu_tuple, pdu_id, origin
)
def _get_pdu_tuple(self, txn, pdu_id, origin):
res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
return res[0] if res else None
def _get_pdu_tuples(self, txn, pdu_id_tuples):
results = []
for pdu_id, origin in pdu_id_tuples:
txn.execute(
PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
(pdu_id, origin)
)
edges = [
(r.prev_pdu_id, r.prev_origin)
for r in PduEdgesTable.decode_results(txn.fetchall())
]
edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin)
signatures = self._get_pdu_origin_signatures_txn(
txn, pdu_id, origin
)
query = (
"SELECT %(fields)s FROM %(pdus)s as p "
"LEFT JOIN %(state)s as s "
"ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
"WHERE p.pdu_id = ? AND p.origin = ? "
) % {
"fields": _pdu_state_joiner.get_fields(
PdusTable="p", StatePdusTable="s"),
"pdus": PdusTable.table_name,
"state": StatePdusTable.table_name,
}
txn.execute(query, (pdu_id, origin))
row = txn.fetchone()
if row:
results.append(PduTuple(
PduEntry(*row), edges, hashes, signatures, edge_hashes
))
return results
def get_current_state_for_context(self, context):
"""Get a list of PDUs that represent the current state for a given
context
Args:
context (str)
Returns:
list: A list of PduTuples
"""
return self.runInteraction(
"get_current_state_for_context",
self._get_current_state_for_context,
context
)
def _get_current_state_for_context(self, txn, context):
query = (
"SELECT pdu_id, origin FROM %s WHERE context = ?"
% CurrentStateTable.table_name
)
logger.debug("get_current_state %s, Args=%s", query, context)
txn.execute(query, (context,))
res = txn.fetchall()
logger.debug("get_current_state %d results", len(res))
return self._get_pdu_tuples(txn, res)
def _persist_pdu_txn(self, txn, prev_pdus, cols):
"""Inserts a (non-state) PDU into the database.
Args:
txn,
prev_pdus (list)
**cols: The columns to insert into the PdusTable.
"""
entry = PdusTable.EntryType(
**{k: cols.get(k, None) for k in PdusTable.fields}
)
txn.execute(PdusTable.insert_statement(), entry)
self._handle_prev_pdus(
txn, entry.outlier, entry.pdu_id, entry.origin,
prev_pdus, entry.context
)
def mark_pdu_as_processed(self, pdu_id, pdu_origin):
"""Mark a received PDU as processed.
Args:
txn
pdu_id (str)
pdu_origin (str)
"""
return self.runInteraction(
"mark_pdu_as_processed",
self._mark_as_processed, pdu_id, pdu_origin
)
def _mark_as_processed(self, txn, pdu_id, pdu_origin):
txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)
def get_all_pdus_from_context(self, context):
"""Get a list of all PDUs for a given context."""
return self.runInteraction(
"get_all_pdus_from_context",
self._get_all_pdus_from_context, context,
)
def _get_all_pdus_from_context(self, txn, context):
query = (
"SELECT pdu_id, origin FROM %s "
"WHERE context = ?"
) % PdusTable.table_name
txn.execute(query, (context,))
return self._get_pdu_tuples(txn, txn.fetchall())
def get_backfill(self, context, pdu_list, limit):
"""Get a list of Pdus for a given topic that occured before (and
including) the pdus in pdu_list. Return a list of max size `limit`.
Args:
txn
context (str)
pdu_list (list)
limit (int)
Return:
list: A list of PduTuples
"""
return self.runInteraction(
"get_backfill",
self._get_backfill, context, pdu_list, limit
)
def _get_backfill(self, txn, context, pdu_list, limit):
logger.debug(
"backfill: %s, %s, %s",
context, repr(pdu_list), limit
)
# We seed the pdu_results with the things from the pdu_list.
pdu_results = pdu_list
front = pdu_list
query = (
"SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
"WHERE context = ? AND pdu_id = ? AND origin = ? "
"LIMIT ?"
) % {
"edges_table": PduEdgesTable.table_name,
}
# We iterate through all pdu_ids in `front` to select their previous
# pdus. These are dumped in `new_front`. We continue until we reach the
# limit *or* new_front is empty (i.e., we've run out of things to
# select
while front and len(pdu_results) < limit:
new_front = []
for pdu_id, origin in front:
logger.debug(
"_backfill_interaction: i=%s, o=%s",
pdu_id, origin
)
txn.execute(
query,
(context, pdu_id, origin, limit - len(pdu_results))
)
for row in txn.fetchall():
logger.debug(
"_backfill_interaction: got i=%s, o=%s",
*row
)
new_front.append(row)
front = new_front
pdu_results += new_front
# We also want to update the `prev_pdus` attributes before returning.
return self._get_pdu_tuples(txn, pdu_results)
def get_min_depth_for_context(self, context):
"""Get the current minimum depth for a context
Args:
txn
context (str)
"""
return self.runInteraction(
"get_min_depth_for_context",
self._get_min_depth_for_context, context
)
def _get_min_depth_for_context(self, txn, context):
return self._get_min_depth_interaction(txn, context)
def _get_min_depth_interaction(self, txn, context):
txn.execute(
"SELECT min_depth FROM %s WHERE context = ?"
% ContextDepthTable.table_name,
(context,)
)
row = txn.fetchone()
return row[0] if row else None
def _update_min_depth_for_context_txn(self, txn, context, depth):
"""Update the minimum `depth` of the given context, which is the line
on which we stop backfilling backwards.
Args:
context (str)
depth (int)
"""
min_depth = self._get_min_depth_interaction(txn, context)
do_insert = depth < min_depth if min_depth else True
if do_insert:
txn.execute(
"INSERT OR REPLACE INTO %s (context, min_depth) "
"VALUES (?,?)" % ContextDepthTable.table_name,
(context, depth)
)
def get_latest_pdus_in_context(self, context):
return self.runInteraction(
"get_latest_pdus_in_context",
self._get_latest_pdus_in_context,
context
)
def _get_latest_pdus_in_context(self, txn, context):
"""Get's a list of the most current pdus for a given context. This is
used when we are sending a Pdu and need to fill out the `prev_pdus`
key
Args:
txn
context
"""
query = (
"SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
"INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
"AND f.origin = p.origin "
"WHERE f.context = ?"
) % {
"pdus": PdusTable.table_name,
"forward": PduForwardExtremitiesTable.table_name,
}
logger.debug("get_prev query: %s", query)
txn.execute(
query,
(context, )
)
results = []
for pdu_id, origin, depth in txn.fetchall():
hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin)
sha256_bytes = hashes["sha256"]
prev_hashes = {"sha256": encode_base64(sha256_bytes)}
results.append((pdu_id, origin, prev_hashes, depth))
return results
@defer.inlineCallbacks
def get_oldest_pdus_in_context(self, context):
"""Get a list of Pdus that we haven't backfilled beyond yet (and havent
seen). This list is used when we want to backfill backwards and is the
list we send to the remote server.
Args:
txn
context (str)
Returns:
list: A list of PduIdTuple.
"""
results = yield self._execute(
None,
"SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
% {"back": PduBackwardExtremitiesTable.table_name, },
context
)
defer.returnValue([PduIdTuple(i, o) for i, o in results])
def is_pdu_new(self, pdu_id, origin, context, depth):
"""For a given Pdu, try and figure out if it's 'new', i.e., if it's
not something we got randomly from the past, for example when we
request the current state of the room that will probably return a bunch
of pdus from before we joined.
Args:
txn
pdu_id (str)
origin (str)
context (str)
depth (int)
Returns:
bool
"""
return self.runInteraction(
"is_pdu_new",
self._is_pdu_new,
pdu_id=pdu_id,
origin=origin,
context=context,
depth=depth
)
def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
# If depth > min depth in back table, then we classify it as new.
# OR if there is nothing in the back table, then it kinda needs to
# be a new thing.
query = (
"SELECT min(p.depth) FROM %(edges)s as e "
"INNER JOIN %(back)s as b "
"ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
"INNER JOIN %(pdus)s as p "
"ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
"WHERE p.context = ?"
) % {
"pdus": PdusTable.table_name,
"edges": PduEdgesTable.table_name,
"back": PduBackwardExtremitiesTable.table_name,
}
txn.execute(query, (context,))
min_depth, = txn.fetchone()
if not min_depth or depth > int(min_depth):
logger.debug(
"is_new true: id=%s, o=%s, d=%s min_depth=%s",
pdu_id, origin, depth, min_depth
)
return True
# If this pdu is in the forwards table, then it also is a new one
query = (
"SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
) % {
"forward": PduForwardExtremitiesTable.table_name,
}
txn.execute(query, (pdu_id, origin))
# Did we get anything?
if txn.fetchall():
logger.debug(
"is_new true: id=%s, o=%s, d=%s was forward",
pdu_id, origin, depth
)
return True
logger.debug(
"is_new false: id=%s, o=%s, d=%s",
pdu_id, origin, depth
)
# FINE THEN. It's probably old.
return False
@staticmethod
@log_function
def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
context):
txn.executemany(
PduEdgesTable.insert_statement(),
[(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
)
# Update the extremities table if this is not an outlier.
if not outlier:
# First, we delete the new one from the forwards extremities table.
query = (
"DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
% PduForwardExtremitiesTable.table_name
)
txn.executemany(query, list(p[:2] for p in prev_pdus))
# We only insert as a forward extremety the new pdu if there are no
# other pdus that reference it as a prev pdu
query = (
"INSERT INTO %(table)s (pdu_id, origin, context) "
"SELECT ?, ?, ? WHERE NOT EXISTS ("
"SELECT 1 FROM %(pdu_edges)s WHERE "
"prev_pdu_id = ? AND prev_origin = ?"
")"
) % {
"table": PduForwardExtremitiesTable.table_name,
"pdu_edges": PduEdgesTable.table_name
}
logger.debug("query: %s", query)
txn.execute(query, (pdu_id, origin, context, pdu_id, origin))
# Insert all the prev_pdus as a backwards thing, they'll get
# deleted in a second if they're incorrect anyway.
txn.executemany(
PduBackwardExtremitiesTable.insert_statement(),
[(i, o, context) for i, o, _ in prev_pdus]
)
# Also delete from the backwards extremities table all ones that
# reference pdus that we have already seen
query = (
"DELETE FROM %(pdu_back)s WHERE EXISTS ("
"SELECT 1 FROM %(pdus)s AS pdus "
"WHERE "
"%(pdu_back)s.pdu_id = pdus.pdu_id "
"AND %(pdu_back)s.origin = pdus.origin "
"AND not pdus.outlier "
")"
) % {
"pdu_back": PduBackwardExtremitiesTable.table_name,
"pdus": PdusTable.table_name,
}
txn.execute(query)
class StatePduStore(SQLBaseStore):
"""A collection of queries for handling state PDUs.
"""
def _persist_state_txn(self, txn, prev_pdus, cols):
"""Inserts a state PDU into the database
Args:
txn,
prev_pdus (list)
**cols: The columns to insert into the PdusTable and StatePdusTable
"""
pdu_entry = PdusTable.EntryType(
**{k: cols.get(k, None) for k in PdusTable.fields}
)
state_entry = StatePdusTable.EntryType(
**{k: cols.get(k, None) for k in StatePdusTable.fields}
)
logger.debug("Inserting pdu: %s", repr(pdu_entry))
logger.debug("Inserting state: %s", repr(state_entry))
txn.execute(PdusTable.insert_statement(), pdu_entry)
txn.execute(StatePdusTable.insert_statement(), state_entry)
self._handle_prev_pdus(
txn,
pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
pdu_entry.context
)
def get_unresolved_state_tree(self, new_state_pdu):
return self.runInteraction(
"get_unresolved_state_tree",
self._get_unresolved_state_tree, new_state_pdu
)
@log_function
def _get_unresolved_state_tree(self, txn, new_pdu):
current = self._get_current_interaction(
txn,
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
)
ReturnType = namedtuple(
"StateReturnType", ["new_branch", "current_branch"]
)
return_value = ReturnType([new_pdu], [])
if not current:
logger.debug("get_unresolved_state_tree No current state.")
return (return_value, None)
return_value.current_branch.append(current)
enum_branches = self._enumerate_state_branches(
txn, new_pdu, current
)
missing_branch = None
for branch, prev_state, state in enum_branches:
if state:
return_value[branch].append(state)
else:
# We don't have prev_state :(
missing_branch = branch
break
return (return_value, missing_branch)
def update_current_state(self, pdu_id, origin, context, pdu_type,
state_key):
return self.runInteraction(
"update_current_state",
self._update_current_state,
pdu_id, origin, context, pdu_type, state_key
)
def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
state_key):
query = (
"INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
) % {
"curr": CurrentStateTable.table_name,
"fields": CurrentStateTable.get_fields_string(),
"qs": ", ".join(["?"] * len(CurrentStateTable.fields))
}
query_args = CurrentStateTable.EntryType(
pdu_id=pdu_id,
origin=origin,
context=context,
pdu_type=pdu_type,
state_key=state_key
)
txn.execute(query, query_args)
def get_current_state_pdu(self, context, pdu_type, state_key):
"""For a given context, pdu_type, state_key 3-tuple, return what is
currently considered the current state.
Args:
txn
context (str)
pdu_type (str)
state_key (str)
Returns:
PduEntry
"""
return self.runInteraction(
"get_current_state_pdu",
self._get_current_state_pdu, context, pdu_type, state_key
)
def _get_current_state_pdu(self, txn, context, pdu_type, state_key):
return self._get_current_interaction(txn, context, pdu_type, state_key)
def _get_current_interaction(self, txn, context, pdu_type, state_key):
logger.debug(
"_get_current_interaction %s %s %s",
context, pdu_type, state_key
)
fields = _pdu_state_joiner.get_fields(
PdusTable="p", StatePdusTable="s")
current_query = (
"SELECT %(fields)s FROM %(state)s as s "
"INNER JOIN %(pdus)s as p "
"ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
"INNER JOIN %(curr)s as c "
"ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
"WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
) % {
"fields": fields,
"curr": CurrentStateTable.table_name,
"state": StatePdusTable.table_name,
"pdus": PdusTable.table_name,
}
txn.execute(
current_query,
(context, pdu_type, state_key)
)
row = txn.fetchone()
result = PduEntry(*row) if row else None
if not result:
logger.debug("_get_current_interaction not found")
else:
logger.debug(
"_get_current_interaction found %s %s",
result.pdu_id, result.origin
)
return result
def handle_new_state(self, new_pdu):
"""Actually perform conflict resolution on the new_pdu on the
assumption we have all the pdus required to perform it.
Args:
new_pdu
Returns:
bool: True if the new_pdu clobbered the current state, False if not
"""
return self.runInteraction(
"handle_new_state",
self._handle_new_state, new_pdu
)
def _handle_new_state(self, txn, new_pdu):
logger.debug(
"handle_new_state %s %s",
new_pdu.pdu_id, new_pdu.origin
)
current = self._get_current_interaction(
txn,
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
)
is_current = False
if (not current or not current.prev_state_id
or not current.prev_state_origin):
# Oh, we don't have any state for this yet.
is_current = True
elif (current.pdu_id == new_pdu.prev_state_id
and current.origin == new_pdu.prev_state_origin):
# Oh! A direct clobber. Just do it.
is_current = True
else:
##
# Ok, now loop through until we get to a common ancestor.
max_new = int(new_pdu.power_level)
max_current = int(current.power_level)
enum_branches = self._enumerate_state_branches(
txn, new_pdu, current
)
for branch, prev_state, state in enum_branches:
if not state:
raise RuntimeError(
"Could not find state_pdu %s %s" %
(
prev_state.prev_state_id,
prev_state.prev_state_origin
)
)
if branch == 0:
max_new = max(int(state.depth), max_new)
else:
max_current = max(int(state.depth), max_current)
is_current = max_new > max_current
if is_current:
logger.debug("handle_new_state make current")
# Right, this is a new thing, so woo, just insert it.
txn.execute(
"INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
% {
"curr": CurrentStateTable.table_name,
"fields": CurrentStateTable.get_fields_string(),
"qs": ", ".join(["?"] * len(CurrentStateTable.fields))
},
CurrentStateTable.EntryType(
*(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
)
)
else:
logger.debug("handle_new_state not current")
logger.debug("handle_new_state done")
return is_current
@log_function
def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
branch_a = pdu_a
branch_b = pdu_b
while True:
if (branch_a.pdu_id == branch_b.pdu_id
and branch_a.origin == branch_b.origin):
# Woo! We found a common ancestor
logger.debug("_enumerate_state_branches Found common ancestor")
break
do_branch_a = (
hasattr(branch_a, "prev_state_id") and
branch_a.prev_state_id
)
do_branch_b = (
hasattr(branch_b, "prev_state_id") and
branch_b.prev_state_id
)
logger.debug(
"do_branch_a=%s, do_branch_b=%s",
do_branch_a, do_branch_b
)
if do_branch_a and do_branch_b:
do_branch_a = int(branch_a.depth) > int(branch_b.depth)
if do_branch_a:
pdu_tuple = PduIdTuple(
branch_a.prev_state_id,
branch_a.prev_state_origin
)
prev_branch = branch_a
logger.debug("getting branch_a prev %s", pdu_tuple)
branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
if branch_a:
branch_a = Pdu.from_pdu_tuple(branch_a)
logger.debug("branch_a=%s", branch_a)
yield (0, prev_branch, branch_a)
if not branch_a:
break
elif do_branch_b:
pdu_tuple = PduIdTuple(
branch_b.prev_state_id,
branch_b.prev_state_origin
)
prev_branch = branch_b
logger.debug("getting branch_b prev %s", pdu_tuple)
branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
if branch_b:
branch_b = Pdu.from_pdu_tuple(branch_b)
logger.debug("branch_b=%s", branch_b)
yield (1, prev_branch, branch_b)
if not branch_b:
break
else:
break
class PdusTable(Table):
table_name = "pdus"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"ts",
"depth",
"is_state",
"content_json",
"unrecognized_keys",
"outlier",
"have_processed",
]
EntryType = namedtuple("PdusEntry", fields)
class PduDestinationsTable(Table):
table_name = "pdu_destinations"
fields = [
"pdu_id",
"origin",
"destination",
"delivered_ts",
]
EntryType = namedtuple("PduDestinationsEntry", fields)
class PduEdgesTable(Table):
table_name = "pdu_edges"
fields = [
"pdu_id",
"origin",
"prev_pdu_id",
"prev_origin",
"context"
]
EntryType = namedtuple("PduEdgesEntry", fields)
class PduForwardExtremitiesTable(Table):
table_name = "pdu_forward_extremities"
fields = [
"pdu_id",
"origin",
"context",
]
EntryType = namedtuple("PduForwardExtremitiesEntry", fields)
class PduBackwardExtremitiesTable(Table):
table_name = "pdu_backward_extremities"
fields = [
"pdu_id",
"origin",
"context",
]
EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)
class ContextDepthTable(Table):
table_name = "context_depth"
fields = [
"context",
"min_depth",
]
EntryType = namedtuple("ContextDepthEntry", fields)
class StatePdusTable(Table):
table_name = "state_pdus"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"state_key",
"power_level",
"prev_state_id",
"prev_state_origin",
]
EntryType = namedtuple("StatePdusEntry", fields)
class CurrentStateTable(Table):
table_name = "current_state"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"state_key",
]
EntryType = namedtuple("CurrentStateEntry", fields)
_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)
# TODO: These should probably be put somewhere more sensible
PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))
PduEntry = _pdu_state_joiner.EntryType
""" We are always interested in the join of the PdusTable and StatePdusTable,
rather than just the PdusTable.
This does not include a prev_pdus key.
"""
PduTuple = namedtuple(
"PduTuple",
("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
)
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU.
"""

View file

@ -1,106 +0,0 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Stores pdus and their content
CREATE TABLE IF NOT EXISTS pdus(
pdu_id TEXT,
origin TEXT,
context TEXT,
pdu_type TEXT,
ts INTEGER,
depth INTEGER DEFAULT 0 NOT NULL,
is_state BOOL,
content_json TEXT,
unrecognized_keys TEXT,
outlier BOOL NOT NULL,
have_processed BOOL,
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
);
-- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
CREATE TABLE IF NOT EXISTS state_pdus(
pdu_id TEXT,
origin TEXT,
context TEXT,
pdu_type TEXT,
state_key TEXT,
power_level TEXT,
prev_state_id TEXT,
prev_state_origin TEXT,
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
);
CREATE TABLE IF NOT EXISTS current_state(
pdu_id TEXT,
origin TEXT,
context TEXT,
pdu_type TEXT,
state_key TEXT,
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
);
-- Stores where each pdu we want to send should be sent and the delivery status.
create TABLE IF NOT EXISTS pdu_destinations(
pdu_id TEXT,
origin TEXT,
destination TEXT,
delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
);
CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
pdu_id TEXT,
origin TEXT,
context TEXT,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
);
CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
pdu_id TEXT,
origin TEXT,
context TEXT,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
);
CREATE TABLE IF NOT EXISTS pdu_edges(
pdu_id TEXT,
origin TEXT,
prev_pdu_id TEXT,
prev_origin TEXT,
context TEXT,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
);
CREATE TABLE IF NOT EXISTS context_depth(
context TEXT,
min_depth INTEGER,
CONSTRAINT uniqueness UNIQUE (context)
);
CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);
CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);
CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
-- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);
CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);
CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);
CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);

View file

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, Table from ._base import SQLBaseStore, Table
from .pdu import PdusTable
from collections import namedtuple from collections import namedtuple
@ -207,50 +206,6 @@ class TransactionStore(SQLBaseStore):
return ReceivedTransactionsTable.decode_results(txn.fetchall()) return ReceivedTransactionsTable.decode_results(txn.fetchall())
def get_pdus_after_transaction(self, transaction_id, destination):
"""For a given local transaction_id that we sent to a given destination
home server, return a list of PDUs that were sent to that destination
after it.
Args:
txn
transaction_id (str)
destination (str)
Returns
list: A list of PduTuple
"""
return self.runInteraction(
"get_pdus_after_transaction",
self._get_pdus_after_transaction,
transaction_id, destination
)
def _get_pdus_after_transaction(self, txn, transaction_id, destination):
# Query that first get's all transaction_ids with an id greater than
# the one given from the `sent_transactions` table. Then JOIN on this
# from the `tx->pdu` table to get a list of (pdu_id, origin) that
# specify the pdus that were sent in those transactions.
query = (
"SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
"INNER JOIN %(sent_tx)s as st "
"ON tp.transaction_id = st.transaction_id "
"AND tp.destination = st.destination "
"WHERE st.id > ("
"SELECT id FROM %(sent_tx)s "
"WHERE transaction_id = ? AND destination = ?"
) % {
"tx_pdu": TransactionsToPduTable.table_name,
"sent_tx": SentTransactions.table_name,
}
txn.execute(query, (transaction_id, destination))
pdus = PdusTable.decode_results(txn.fetchall())
return self._get_pdu_tuples(txn, pdus)
class ReceivedTransactionsTable(Table): class ReceivedTransactionsTable(Table):
table_name = "received_transactions" table_name = "received_transactions"