
Refactor _get_events

Erik Johnston 2015-05-14 13:31:55 +01:00
parent 36ea26c5c0
commit cdb3757942
3 changed files with 131 additions and 261 deletions


@@ -17,6 +17,7 @@ import logging
 from synapse.api.errors import StoreError
 from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
+from synapse.util import unwrap_deferred
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache
@@ -28,7 +29,6 @@ from twisted.internet import defer
 from collections import namedtuple, OrderedDict
 import functools
-import itertools
 import simplejson as json
 import sys
 import time
@@ -870,35 +870,43 @@ class SQLBaseStore(object):
     @defer.inlineCallbacks
     def _get_events(self, event_ids, check_redacted=True,
-                    get_prev_content=False, desc="_get_events"):
-        N = 50  # Only fetch 100 events at a time.
-
-        ds = [
-            self._fetch_events(
-                event_ids[i*N:(i+1)*N],
-                check_redacted=check_redacted,
-                get_prev_content=get_prev_content,
-            )
-            for i in range(1 + len(event_ids) / N)
-        ]
-
-        res = yield defer.gatherResults(ds, consumeErrors=True)
-
-        defer.returnValue(
-            list(itertools.chain(*res))
-        )
+                    get_prev_content=False, allow_rejected=False, txn=None):
+        if not event_ids:
+            defer.returnValue([])
+
+        event_map = self._get_events_from_cache(
+            event_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        missing_events = [e for e in event_ids if e not in event_map]
+
+        missing_events = yield self._fetch_events(
+            txn,
+            missing_events,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        event_map.update(missing_events)
+
+        defer.returnValue([
+            event_map[e_id] for e_id in event_ids
+            if e_id in event_map and event_map[e_id]
+        ])
 
     def _get_events_txn(self, txn, event_ids, check_redacted=True,
-                        get_prev_content=False):
-        N = 50  # Only fetch 100 events at a time.
-
-        return list(itertools.chain(*[
-            self._fetch_events_txn(
-                txn, event_ids[i*N:(i+1)*N],
-                check_redacted=check_redacted,
-                get_prev_content=get_prev_content,
-            )
-            for i in range(1 + len(event_ids) / N)
-        ]))
+                        get_prev_content=False, allow_rejected=False):
+        return unwrap_deferred(self._get_events(
+            event_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+            txn=txn,
+        ))
 
     def _invalidate_get_event_cache(self, event_id):
         for check_redacted in (False, True):
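
Note: the refactored _get_events above boils down to a cache lookup followed by a single batched fetch for the misses, with results handed back in the caller's requested order. The following is a minimal standalone sketch of that pattern, not Synapse code; the names cache, fetch_from_db and get_events are illustrative stand-ins for the LRU cache, _fetch_events and _get_events.

# Minimal sketch of the cache-then-batched-fetch flow; plain dicts stand in
# for the LRU cache and the database. All names here are illustrative.

cache = {
    "$a:example.org": {"event_id": "$a:example.org", "type": "m.room.message"},
}


def fetch_from_db(event_ids):
    # Stand-in for the single batched SELECT that _fetch_events performs.
    return {e_id: {"event_id": e_id, "type": "m.room.member"} for e_id in event_ids}


def get_events(event_ids):
    if not event_ids:
        return []

    # 1. Serve whatever the cache already holds.
    event_map = {e_id: cache[e_id] for e_id in event_ids if e_id in cache}

    # 2. Fetch only the misses, in one batch, and merge them in.
    missing = [e_id for e_id in event_ids if e_id not in event_map]
    event_map.update(fetch_from_db(missing))

    # 3. Return results in the order the caller asked for, dropping gaps.
    return [event_map[e_id] for e_id in event_ids if event_map.get(e_id)]


print(get_events(["$a:example.org", "$b:example.org"]))

The real method additionally threads check_redacted, get_prev_content, allow_rejected and an optional txn through both steps.
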
@@ -909,68 +917,24 @@ class SQLBaseStore(object):
     def _get_event_txn(self, txn, event_id, check_redacted=True,
                        get_prev_content=False, allow_rejected=False):
-        start_time = time.time() * 1000
-
-        def update_counter(desc, last_time):
-            curr_time = self._get_event_counters.update(desc, last_time)
-            sql_getevents_timer.inc_by(curr_time - last_time, desc)
-            return curr_time
-
-        try:
-            ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
-
-            if allow_rejected or not ret.rejected_reason:
-                return ret
-            else:
-                return None
-        except KeyError:
-            pass
-        finally:
-            start_time = update_counter("event_cache", start_time)
-
-        sql = (
-            "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
-            "FROM event_json as e "
-            "LEFT JOIN rejections as rej USING (event_id) "
-            "LEFT JOIN redactions as r ON e.event_id = r.redacts "
-            "WHERE e.event_id = ? "
-            "LIMIT 1 "
-        )
-
-        txn.execute(sql, (event_id,))
-
-        res = txn.fetchone()
-
-        if not res:
-            return None
-
-        internal_metadata, js, redacted, rejected_reason = res
-
-        start_time = update_counter("select_event", start_time)
-
-        result = self._get_event_from_row_txn(
-            txn, internal_metadata, js, redacted,
+        events = self._get_events_txn(
+            txn, [event_id],
             check_redacted=check_redacted,
             get_prev_content=get_prev_content,
-            rejected_reason=rejected_reason,
+            allow_rejected=allow_rejected,
         )
-        self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
 
-        if allow_rejected or not rejected_reason:
-            return result
-        else:
-            return None
+        return events[0] if events else None
 
-    def _fetch_events_txn(self, txn, events, check_redacted=True,
-                          get_prev_content=False, allow_rejected=False):
-        if not events:
-            return []
-
+    def _get_events_from_cache(self, events, check_redacted, get_prev_content,
+                               allow_rejected):
         event_map = {}
 
         for event_id in events:
             try:
-                ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
+                ret = self._get_event_cache.get(
+                    event_id, check_redacted, get_prev_content
+                )
 
                 if allow_rejected or not ret.rejected_reason:
                     event_map[event_id] = ret
@@ -979,136 +943,82 @@ class SQLBaseStore(object):
             except KeyError:
                 pass
 
-        missing_events = [
-            e for e in events
-            if e not in event_map
-        ]
-
-        if missing_events:
-            sql = (
-                "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
-                " FROM event_json as e"
-                " LEFT JOIN rejections as rej USING (event_id)"
-                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
-                " WHERE e.event_id IN (%s)"
-            ) % (",".join(["?"]*len(missing_events)),)
-
-            txn.execute(sql, missing_events)
-
-            rows = txn.fetchall()
-
-            res = [
-                self._get_event_from_row_txn(
-                    txn, row[0], row[1], row[2],
-                    check_redacted=check_redacted,
-                    get_prev_content=get_prev_content,
-                    rejected_reason=row[3],
-                )
-                for row in rows
-            ]
-
-            event_map.update({
-                e.event_id: e
-                for e in res if e
-            })
-
-            for e in res:
-                self._get_event_cache.prefill(
-                    e.event_id, check_redacted, get_prev_content, e
-                )
-
-        return [
-            event_map[e_id] for e_id in events
-            if e_id in event_map and event_map[e_id]
-        ]
+        return event_map
 
     @defer.inlineCallbacks
-    def _fetch_events(self, events, check_redacted=True,
+    def _fetch_events(self, txn, events, check_redacted=True,
                       get_prev_content=False, allow_rejected=False):
         if not events:
-            defer.returnValue([])
+            defer.returnValue({})
 
-        event_map = {}
-
-        for event_id in events:
-            try:
-                ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
-
-                if allow_rejected or not ret.rejected_reason:
-                    event_map[event_id] = ret
-                else:
-                    event_map[event_id] = None
-            except KeyError:
-                pass
-
-        missing_events = [
-            e for e in events
-            if e not in event_map
-        ]
-
-        if missing_events:
+        rows = []
+        N = 2
+        for i in range(1 + len(events) / N):
+            evs = events[i*N:(i + 1)*N]
+            if not evs:
+                break
+
             sql = (
                 "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
                 " FROM event_json as e"
                 " LEFT JOIN rejections as rej USING (event_id)"
                 " LEFT JOIN redactions as r ON e.event_id = r.redacts"
                 " WHERE e.event_id IN (%s)"
-            ) % (",".join(["?"]*len(missing_events)),)
+            ) % (",".join(["?"]*len(evs)),)
 
-            rows = yield self._execute(
-                "_fetch_events",
-                None,
-                sql,
-                *missing_events
-            )
+            if txn:
+                txn.execute(sql, evs)
+                rows.extend(txn.fetchall())
+            else:
+                res = yield self._execute("_fetch_events", None, sql, *evs)
+                rows.extend(res)
 
-            res_ds = [
-                self._get_event_from_row(
-                    row[0], row[1], row[2],
-                    check_redacted=check_redacted,
-                    get_prev_content=get_prev_content,
-                    rejected_reason=row[3],
-                )
-                for row in rows
-            ]
-
-            res = yield defer.gatherResults(res_ds, consumeErrors=True)
-
-            event_map.update({
-                e.event_id: e
-                for e in res if e
-            })
-
-            for e in res:
-                self._get_event_cache.prefill(
-                    e.event_id, check_redacted, get_prev_content, e
-                )
+        res = []
+        for row in rows:
+            e = yield self._get_event_from_row(
+                txn,
+                row[0], row[1], row[2],
+                check_redacted=check_redacted,
+                get_prev_content=get_prev_content,
+                rejected_reason=row[3],
+            )
+            res.append(e)
 
-        defer.returnValue([
-            event_map[e_id] for e_id in events
-            if e_id in event_map and event_map[e_id]
-        ])
+        for e in res:
+            self._get_event_cache.prefill(
+                e.event_id, check_redacted, get_prev_content, e
+            )
+
+        defer.returnValue({
+            e.event_id: e
+            for e in res if e
+        })
 
     @defer.inlineCallbacks
-    def _get_event_from_row(self, internal_metadata, js, redacted,
+    def _get_event_from_row(self, txn, internal_metadata, js, redacted,
                             check_redacted=True, get_prev_content=False,
                             rejected_reason=None):
-        start_time = time.time() * 1000
-
-        def update_counter(desc, last_time):
-            curr_time = self._get_event_counters.update(desc, last_time)
-            sql_getevents_timer.inc_by(curr_time - last_time, desc)
-            return curr_time
-
         d = json.loads(js)
-        start_time = update_counter("decode_json", start_time)
-
         internal_metadata = json.loads(internal_metadata)
-        start_time = update_counter("decode_internal", start_time)
+
+        def select(txn, *args, **kwargs):
+            if txn:
+                return self._simple_select_one_onecol_txn(txn, *args, **kwargs)
+            else:
+                return self._simple_select_one_onecol(
+                    *args,
+                    desc="_get_event_from_row", **kwargs
+                )
+
+        def get_event(txn, *args, **kwargs):
+            if txn:
+                return self._get_event_txn(txn, *args, **kwargs)
+            else:
+                return self.get_event(*args, **kwargs)
 
         if rejected_reason:
-            rejected_reason = yield self._simple_select_one_onecol(
-                desc="_get_event_from_row",
+            rejected_reason = yield select(
+                txn,
                 table="rejections",
                 keyvalues={"event_id": rejected_reason},
                 retcol="reason",
@@ -1119,13 +1029,12 @@ class SQLBaseStore(object):
             internal_metadata_dict=internal_metadata,
             rejected_reason=rejected_reason,
         )
-        start_time = update_counter("build_frozen_event", start_time)
 
         if check_redacted and redacted:
             ev = prune_event(ev)
 
-            redaction_id = yield self._simple_select_one_onecol(
-                desc="_get_event_from_row",
+            redaction_id = yield select(
+                txn,
                 table="redactions",
                 keyvalues={"redacts": ev.event_id},
                 retcol="event_id",
@@ -1134,93 +1043,26 @@ class SQLBaseStore(object):
             ev.unsigned["redacted_by"] = redaction_id
 
             # Get the redaction event.
-            because = yield self.get_event_txn(
+            because = yield get_event(
+                txn,
                 redaction_id,
                 check_redacted=False
             )
 
             if because:
                 ev.unsigned["redacted_because"] = because
-            start_time = update_counter("redact_event", start_time)
 
         if get_prev_content and "replaces_state" in ev.unsigned:
-            prev = yield self.get_event(
+            prev = yield get_event(
+                txn,
                 ev.unsigned["replaces_state"],
                 get_prev_content=False,
             )
             if prev:
                 ev.unsigned["prev_content"] = prev.get_dict()["content"]
-            start_time = update_counter("get_prev_content", start_time)
 
         defer.returnValue(ev)
 
-    def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
-                                check_redacted=True, get_prev_content=False,
-                                rejected_reason=None):
-        start_time = time.time() * 1000
-
-        def update_counter(desc, last_time):
-            curr_time = self._get_event_counters.update(desc, last_time)
-            sql_getevents_timer.inc_by(curr_time - last_time, desc)
-            return curr_time
-
-        d = json.loads(js)
-        start_time = update_counter("decode_json", start_time)
-
-        internal_metadata = json.loads(internal_metadata)
-        start_time = update_counter("decode_internal", start_time)
-
-        if rejected_reason:
-            rejected_reason = self._simple_select_one_onecol_txn(
-                txn,
-                table="rejections",
-                keyvalues={"event_id": rejected_reason},
-                retcol="reason",
-            )
-
-        ev = FrozenEvent(
-            d,
-            internal_metadata_dict=internal_metadata,
-            rejected_reason=rejected_reason,
-        )
-        start_time = update_counter("build_frozen_event", start_time)
-
-        if check_redacted and redacted:
-            ev = prune_event(ev)
-
-            redaction_id = self._simple_select_one_onecol_txn(
-                txn,
-                table="redactions",
-                keyvalues={"redacts": ev.event_id},
-                retcol="event_id",
-            )
-
-            ev.unsigned["redacted_by"] = redaction_id
-
-            # Get the redaction event.
-            because = self._get_event_txn(
-                txn,
-                redaction_id,
-                check_redacted=False
-            )
-
-            if because:
-                ev.unsigned["redacted_because"] = because
-            start_time = update_counter("redact_event", start_time)
-
-        if get_prev_content and "replaces_state" in ev.unsigned:
-            prev = self._get_event_txn(
-                txn,
-                ev.unsigned["replaces_state"],
-                get_prev_content=False,
-            )
-            if prev:
-                ev.unsigned["prev_content"] = prev.get_dict()["content"]
-            start_time = update_counter("get_prev_content", start_time)
-
-        return ev
-
     def _parse_events(self, rows):
         return self.runInteraction(
             "_parse_events", self._parse_events_txn, rows


@@ -85,7 +85,7 @@ class StateStore(SQLBaseStore):
         @defer.inlineCallbacks
         def c(vals):
-            vals[:] = yield self._fetch_events(vals, get_prev_content=False)
+            vals[:] = yield self._get_events(vals, get_prev_content=False)
 
         yield defer.gatherResults(
             [


@@ -29,6 +29,34 @@ def unwrapFirstError(failure):
     return failure.value.subFailure
 
 
+def unwrap_deferred(d):
+    """Given a deferred that we know has completed, return its value or raise
+    the failure as an exception
+    """
+    if not d.called:
+        raise RuntimeError("deferred has not finished")
+
+    res = []
+
+    def f(r):
+        res.append(r)
+        return r
+    d.addCallback(f)
+
+    if res:
+        return res[0]
+
+    def f(r):
+        res.append(r)
+        return r
+    d.addErrback(f)
+
+    if res:
+        res[0].raiseException()
+    else:
+        raise RuntimeError("deferred did not call callbacks")
+
+
 class Clock(object):
     """A small utility that obtains current time-of-day so that time may be
     mocked during unit-tests.
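
Note: unwrap_deferred only makes sense for a deferred that has already fired synchronously, which is exactly how _get_events_txn uses it above (on the txn code path nothing yields to the reactor). A small usage sketch, assuming Twisted and synapse are importable:

# Usage sketch for unwrap_deferred; illustrative, not part of this commit.

from twisted.internet import defer

from synapse.util import unwrap_deferred


@defer.inlineCallbacks
def compute():
    # Every deferred yielded here fires immediately, so no reactor is needed
    # and compute() returns an already-completed deferred.
    value = yield defer.succeed(41)
    defer.returnValue(value + 1)


print(unwrap_deferred(compute()))  # -> 42

try:
    unwrap_deferred(defer.Deferred())  # has not fired yet
except RuntimeError as e:
    print(e)  # "deferred has not finished"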