0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-22 10:20:10 +01:00

Fix public room joining by making sure replaces_state never points to itself.

This commit is contained in:
Erik Johnston 2014-12-11 15:56:01 +00:00
parent 9191292b0f
commit 0b04369238
3 changed files with 13 additions and 10 deletions

View file

@ -349,7 +349,7 @@ class FederationHandler(BaseHandler):
handled_events = set() handled_events = set()
try: try:
builder.event_id = self.event_factory.create_event_id() builder.event_id = self.event_builder_factory.create_event_id()
builder.origin = self.hs.hostname builder.origin = self.hs.hostname
builder.content = content builder.content = content
@ -593,12 +593,12 @@ class FederationHandler(BaseHandler):
} }
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id)
if hasattr(event, "state_key"): if event and event.is_state():
# Get previous state # Get previous state
if hasattr(event, "replaces_state") and event.replaces_state: if "replaces_state" in event.unsigned:
prev_event = yield self.store.get_event( prev_id = event.unsigned["replaces_state"]
event.replaces_state if prev_id != event.event_id:
) prev_event = yield self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event results[(event.type, event.state_key)] = prev_event
else: else:
del results[(event.type, event.state_key)] del results[(event.type, event.state_key)]

View file

@ -159,6 +159,7 @@ class StateHandler(object):
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state: if key in context.current_state:
replaces = context.current_state[key] replaces = context.current_state[key]
if replaces.event_id != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces.event_id
defer.returnValue([]) defer.returnValue([])

View file

@ -451,7 +451,8 @@ class SQLBaseStore(object):
return events return events
def _get_event_txn(self, txn, event_id, check_redacted=True): def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=True):
sql = ( sql = (
"SELECT json, r.event_id FROM event_json as e " "SELECT json, r.event_id FROM event_json as e "
"LEFT JOIN redactions as r ON e.event_id = r.redacts " "LEFT JOIN redactions as r ON e.event_id = r.redacts "
@ -487,10 +488,11 @@ class SQLBaseStore(object):
if because: if because:
ev.unsigned["redacted_because"] = because ev.unsigned["redacted_because"] = because
if "replaces_state" in ev.unsigned: if get_prev_content and "replaces_state" in ev.unsigned:
ev.unsigned["prev_content"] = self._get_event_txn( ev.unsigned["prev_content"] = self._get_event_txn(
txn, txn,
ev.unsigned["replaces_state"], ev.unsigned["replaces_state"],
get_prev_content=False,
).get_dict()["content"] ).get_dict()["content"]
return ev return ev