Begin converting things to use the new Event structure

This commit is contained in:
Erik Johnston 2014-12-04 11:27:59 +00:00
parent 75b4329aaa
commit 5d7c9ab789
8 changed files with 96 additions and 39 deletions

View file

@@ -393,18 +393,11 @@ class Auth(object):
if member_event.content["membership"] == Membership.JOIN: if member_event.content["membership"] == Membership.JOIN:
auth_events.append(member_event.event_id) auth_events.append(member_event.event_id)
hashes = yield self.store.get_event_reference_hashes( auth_events = yield self.store.add_event_hashes(
auth_events auth_events
) )
hashes = [
{
k: encode_base64(v) for k, v in h.items()
if k == "sha256"
}
for h in hashes
]
defer.returnValue(zip(auth_events, hashes)) defer.returnValue(auth_events)
@log_function @log_function
def _can_send_event(self, event, auth_events): def _can_send_event(self, event, auth_events):

View file

@@ -81,6 +81,9 @@ class EventBase(object):
type = _event_dict_property("type") type = _event_dict_property("type")
user_id = _event_dict_property("sender") user_id = _event_dict_property("sender")
def is_state(self):
return hasattr(self, "state_key")
def get_dict(self): def get_dict(self):
d = dict(self._original) d = dict(self._original)
d.update({ d.update({

View file

@@ -112,7 +112,7 @@ class ReplicationLayer(object):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
@log_function @log_function
def send_pdu(self, pdu): def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the """Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others. home server that should be transmitted to others.
@@ -131,7 +131,7 @@ class ReplicationLayer(object):
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc. # TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, order) self._transaction_queue.enqueue_pdu(pdu, destinations, order)
logger.debug( logger.debug(
"[%s] transaction_layer.enqueue_pdu... done", "[%s] transaction_layer.enqueue_pdu... done",
@@ -705,15 +705,13 @@ class _TransactionQueue(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def enqueue_pdu(self, pdu, order): def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have # We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later. # table and we'll get back to it later.
destinations = set([ destinations = set(destinations)
d for d in pdu.destinations destinations.remove(self.server_name)
if d != self.server_name
])
logger.debug("Sending to: %s", str(destinations)) logger.debug("Sending to: %s", str(destinations))

View file

@@ -15,11 +15,11 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError, SynapseError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.events.room import RoomMemberEvent from synapse.api.events.room import RoomMemberEvent
from synapse.api.constants import Membership from synapse.api.constants import Membership, EventTypes
from synapse.events.snapshot import EventSnapshot, EventContext from synapse.events.snapshot import EventSnapshot, EventContext
@@ -59,7 +59,7 @@ class BaseHandler(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_client_event(self, builder): def _create_new_client_event(self, builder):
latest_ret = yield self.store.get_latest_events_in_room( latest_ret = yield self.store.get_latest_events_in_room(
builder.room_id, builder.room_id,
) )
@@ -67,17 +67,28 @@ class BaseHandler(object):
depth = max([d for _, _, d in latest_ret]) depth = max([d for _, _, d in latest_ret])
prev_events = [(e, h) for e, h, _ in latest_ret] prev_events = [(e, h) for e, h, _ in latest_ret]
group, curr_state = yield self.state_handler.resolve_state_groups( state_handler = self.state_handler
[e for e, _ in prev_events] if builder.is_state():
ret = yield state_handler.resolve_state_groups(
[e for e, _ in prev_events],
event_type=builder.event_type,
state_key=builder.state_key,
) )
snapshot = EventSnapshot( group, curr_state, prev_state = ret
prev_events=prev_events,
depth=depth, prev_state = yield self.store.add_event_hashes(
current_state=curr_state, prev_state
current_state_group=group,
) )
builder.prev_state = prev_state
else:
group, curr_state, _ = yield state_handler.resolve_state_groups(
[e for e, _ in prev_events],
)
builder.internal_metadata.state_group = group
builder.prev_events = prev_events builder.prev_events = prev_events
builder.depth = depth builder.depth = depth
@@ -103,9 +114,39 @@ class BaseHandler(object):
auth_events=curr_auth_events, auth_events=curr_auth_events,
) )
defer.returnValue(
(event, context,)
)
@defer.inlineCallbacks
def _handle_new_client_event(self, event, context):
# We now need to go and hit out to wherever we need to hit out to.
self.auth.check(event, auth_events=context.auth_events) self.auth.check(event, auth_events=context.auth_events)
yield self.store.persist_event(event)
destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(
self.hs.parse_userid(s.state_key).domain
)
except SynapseError:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
yield self.notifier.on_new_room_event(event)
federation_handler = self.hs.get_handlers().federation_handler
yield federation_handler.handle_new_event(
event,
None,
destinations=destinations,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[], def _on_new_room_event(self, event, snapshot, extra_destinations=[],

View file

@@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_event(self, event, snapshot): def handle_new_event(self, event, snapshot, destinations):
""" Takes in an event from the client to server side, that has already """ Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any been authed and handled by the state module, and sends it to any
remote home servers that may be interested. remote home servers that may be interested.
@@ -92,12 +92,7 @@ class FederationHandler(BaseHandler):
yield run_on_reactor() yield run_on_reactor()
pdu = event yield self.replication_layer.send_pdu(event, destinations)
if not hasattr(pdu, "destinations") or not pdu.destinations:
pdu.destinations = []
yield self.replication_layer.send_pdu(pdu)
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@@ -378,6 +378,7 @@ class RoomMemberHandler(BaseHandler):
else: else:
# This is not a JOIN, so we can handle it normally. # This is not a JOIN, so we can handle it normally.
# FIXME: This isn't idempotency.
if prev_state and prev_state.membership == event.membership: if prev_state and prev_state.membership == event.membership:
# double same action, treat this event as a NOOP. # double same action, treat this event as a NOOP.
defer.returnValue({}) defer.returnValue({})

View file

@@ -89,7 +89,7 @@ class StateHandler(object):
ids = [e for e, _ in event.prev_events] ids = [e for e, _ in event.prev_events]
ret = yield self.resolve_state_groups(ids) ret = yield self.resolve_state_groups(ids)
state_group, new_state = ret state_group, new_state, _ = ret
event.old_state_events = copy.deepcopy(new_state) event.old_state_events = copy.deepcopy(new_state)
@@ -137,7 +137,7 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def resolve_state_groups(self, event_ids): def resolve_state_groups(self, event_ids, event_type=None, state_key=""):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
@@ -156,7 +156,10 @@ class StateHandler(object):
(e.type, e.state_key): e (e.type, e.state_key): e
for e in state_list for e in state_list
} }
defer.returnValue((name, state)) prev_state = state.get((event_type, state_key), None)
if prev_state:
prev_state = prev_state.event_id
defer.returnValue((name, state, [prev_state]))
state = {} state = {}
for group, g_state in state_groups.items(): for group, g_state in state_groups.items():
@@ -177,6 +180,13 @@ class StateHandler(object):
if len(v.values()) > 1 if len(v.values()) > 1
} }
if event_type:
prev_states = conflicted_state.get(
(event_type, state_key), {}
).keys()
else:
prev_states = []
try: try:
new_state = {} new_state = {}
new_state.update(unconflicted_state) new_state.update(unconflicted_state)
@@ -186,7 +196,7 @@ class StateHandler(object):
logger.exception("Failed to resolve state") logger.exception("Failed to resolve state")
raise raise
defer.returnValue((None, new_state)) defer.returnValue((None, new_state, prev_states))
def _get_power_level_from_event_state(self, event, user_id): def _get_power_level_from_event_state(self, event, user_id):
if hasattr(event, "old_state_events") and event.old_state_events: if hasattr(event, "old_state_events") and event.old_state_events:

View file

@@ -15,6 +15,8 @@
from _base import SQLBaseStore from _base import SQLBaseStore
from syutil.base64util import encode_base64
class SignatureStore(SQLBaseStore): class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes""" """Persistence for event signatures and hashes"""
@@ -67,6 +69,20 @@ class SignatureStore(SQLBaseStore):
f f
) )
def add_event_hashes(self, event_ids):
hashes = yield self.store.get_event_reference_hashes(
event_ids
)
hashes = [
{
k: encode_base64(v) for k, v in h.items()
if k == "sha256"
}
for h in hashes
]
defer.returnValue(zip(event_ids, hashes))
def _get_event_reference_hashes_txn(self, txn, event_id): def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU. """Get all the hashes for a given PDU.
Args: Args: