0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 17:03:50 +01:00

Merge pull request #155 from matrix-org/erikj/perf

Bulk and batch retrieval of events.
This commit is contained in:
Erik Johnston 2015-05-21 14:54:40 +01:00
commit a551c5dad7
13 changed files with 595 additions and 203 deletions

View file

@ -222,7 +222,7 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ]
if pdu_list: if pdu_list and pdu_list[0]:
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
@ -255,7 +255,7 @@ class FederationClient(FederationBase):
) )
continue continue
if self._get_pdu_cache is not None: if self._get_pdu_cache is not None and pdu:
self._get_pdu_cache[event_id] = pdu self._get_pdu_cache[event_id] = pdu
defer.returnValue(pdu) defer.returnValue(pdu)
@ -561,7 +561,7 @@ class FederationClient(FederationBase):
res = yield defer.DeferredList(deferreds, consumeErrors=True) res = yield defer.DeferredList(deferreds, consumeErrors=True)
for (result, val), (e_id, _) in zip(res, ordered_missing): for (result, val), (e_id, _) in zip(res, ordered_missing):
if result: if result and val:
signed_events.append(val) signed_events.append(val)
else: else:
failed_to_fetch.add(e_id) failed_to_fetch.add(e_id)

View file

@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 18 SCHEMA_VERSION = 19
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -25,6 +25,7 @@ from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
import functools import functools
import sys import sys
import time import time
@ -45,7 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
caches_by_name = {} caches_by_name = {}
cache_counter = metrics.register_cache( cache_counter = metrics.register_cache(
@ -298,6 +298,12 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
self._pending_ds = []
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator() self._stream_id_gen = StreamIdGenerator()
@ -337,6 +343,75 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000) self._clock.looping_call(loop, 10000)
def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
start = time.time() * 1000
txn_id = self._TXN_ID
# We don't really need these to be unique, so lets stop it from
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
name = "%s-%x" % (desc, txn_id, )
transaction_logger.debug("[TXN START] {%s}", name)
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks
)
r = func(txn, *args, **kwargs)
conn.commit()
return r
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warn(
"[TXN OPERROR] {%s} %s %d/%d",
name, e, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
duration = end - start
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc)
@defer.inlineCallbacks @defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs): def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool.""" """Wraps the .runInteraction() method on the underlying db_pool."""
@ -348,75 +423,16 @@ class SQLBaseStore(object):
def inner_func(conn, *args, **kwargs): def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context: with LoggingContext("runInteraction") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
if self.database_engine.is_connection_closed(conn): if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection") logger.debug("Reconnecting closed database connection")
conn.reconnect() conn.reconnect()
current_context.copy_to(context) current_context.copy_to(context)
start = time.time() * 1000 return self._new_transaction(
txn_id = self._TXN_ID conn, desc, after_callbacks, func, *args, **kwargs
)
# We don't really need these to be unique, so lets stop it from
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
name = "%s-%x" % (desc, txn_id, )
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks
)
return func(txn, *args, **kwargs)
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warn(
"[TXN OPERROR] {%s} %s %d/%d",
name, e, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
duration = end - start
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc)
result = yield preserve_context_over_fn( result = yield preserve_context_over_fn(
self._db_pool.runWithConnection, self._db_pool.runWithConnection,
@ -427,6 +443,32 @@ class SQLBaseStore(object):
after_callback(*after_args) after_callback(*after_args)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
current_context.copy_to(context)
return func(conn, *args, **kwargs)
result = yield preserve_context_over_fn(
self._db_pool.runWithConnection,
inner_func, *args, **kwargs
)
defer.returnValue(result)
def cursor_to_dict(self, cursor): def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts. """Converts a SQL cursor into an list of dicts.

View file

@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup
class PostgresEngine(object): class PostgresEngine(object):
single_threaded = False
def __init__(self, database_module): def __init__(self, database_module):
self.module = database_module self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE) self.module.extensions.register_type(self.module.extensions.UNICODE)

View file

@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database
class Sqlite3Engine(object): class Sqlite3Engine(object):
single_threaded = True
def __init__(self, database_module): def __init__(self, database_module):
self.module = database_module self.module = database_module

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cached
from syutil.base64util import encode_base64 from syutil.base64util import encode_base64
@ -33,16 +35,7 @@ class EventFederationStore(SQLBaseStore):
""" """
def get_auth_chain(self, event_ids): def get_auth_chain(self, event_ids):
return self.runInteraction( return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
"get_auth_chain",
self._get_auth_chain_txn,
event_ids
)
def _get_auth_chain_txn(self, txn, event_ids):
results = self._get_auth_chain_ids_txn(txn, event_ids)
return self._get_events_txn(txn, results)
def get_auth_chain_ids(self, event_ids): def get_auth_chain_ids(self, event_ids):
return self.runInteraction( return self.runInteraction(
@ -369,7 +362,7 @@ class EventFederationStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"get_backfill_events", "get_backfill_events",
self._get_backfill_events, room_id, event_list, limit self._get_backfill_events, room_id, event_list, limit
) ).addCallback(self._get_events)
def _get_backfill_events(self, txn, room_id, event_list, limit): def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug( logger.debug(
@ -415,16 +408,26 @@ class EventFederationStore(SQLBaseStore):
front = new_front front = new_front
event_results += new_front event_results += new_front
return self._get_events_txn(txn, event_results) return event_results
@defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events, def get_missing_events(self, room_id, earliest_events, latest_events,
limit, min_depth): limit, min_depth):
return self.runInteraction( ids = yield self.runInteraction(
"get_missing_events", "get_missing_events",
self._get_missing_events, self._get_missing_events,
room_id, earliest_events, latest_events, limit, min_depth room_id, earliest_events, latest_events, limit, min_depth
) )
events = yield self._get_events(ids)
events = sorted(
[ev for ev in events if ev.depth >= min_depth],
key=lambda e: e.depth,
)
defer.returnValue(events[:limit])
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
limit, min_depth): limit, min_depth):
@ -456,14 +459,7 @@ class EventFederationStore(SQLBaseStore):
front = new_front front = new_front
event_results |= new_front event_results |= new_front
events = self._get_events_txn(txn, event_results) return event_results
events = sorted(
[ev for ev in events if ev.depth >= min_depth],
key=lambda e: e.depth,
)
return events[:limit]
def clean_room_for_join(self, room_id): def clean_room_for_join(self, room_id):
return self.runInteraction( return self.runInteraction(

View file

@ -15,11 +15,12 @@
from _base import SQLBaseStore, _RollbackButIsFineException from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer from twisted.internet import defer, reactor
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash from synapse.crypto.event_signing import compute_event_reference_hash
@ -34,6 +35,16 @@ import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# These values are used in the `enqueus_event` and `_do_fetch` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thing air to make initial sync run faster
# on jki.re
# TODO: Make these configurable.
EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
class EventsStore(SQLBaseStore): class EventsStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -91,18 +102,17 @@ class EventsStore(SQLBaseStore):
Returns: Returns:
Deferred : A FrozenEvent. Deferred : A FrozenEvent.
""" """
event = yield self.runInteraction( events = yield self._get_events(
"get_event", self._get_event_txn, [event_id],
event_id,
check_redacted=check_redacted, check_redacted=check_redacted,
get_prev_content=get_prev_content, get_prev_content=get_prev_content,
allow_rejected=allow_rejected, allow_rejected=allow_rejected,
) )
if not event and not allow_none: if not events and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,)) raise RuntimeError("Could not find event %s" % (event_id,))
defer.returnValue(event) defer.returnValue(events[0] if events else None)
@log_function @log_function
def _persist_event_txn(self, txn, event, context, backfilled, def _persist_event_txn(self, txn, event, context, backfilled,
@ -401,28 +411,75 @@ class EventsStore(SQLBaseStore):
"have_events", f, "have_events", f,
) )
@defer.inlineCallbacks
def _get_events(self, event_ids, check_redacted=True, def _get_events(self, event_ids, check_redacted=True,
get_prev_content=False): get_prev_content=False, allow_rejected=False):
return self.runInteraction( if not event_ids:
"_get_events", self._get_events_txn, event_ids, defer.returnValue([])
check_redacted=check_redacted, get_prev_content=get_prev_content,
event_map = self._get_events_from_cache(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
) )
missing_events_ids = [e for e in event_ids if e not in event_map]
if not missing_events_ids:
defer.returnValue([
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
])
missing_events = yield self._enqueue_events(
missing_events_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
event_map.update(missing_events)
defer.returnValue([
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
])
def _get_events_txn(self, txn, event_ids, check_redacted=True, def _get_events_txn(self, txn, event_ids, check_redacted=True,
get_prev_content=False): get_prev_content=False, allow_rejected=False):
if not event_ids: if not event_ids:
return [] return []
events = [ event_map = self._get_events_from_cache(
self._get_event_txn( event_ids,
txn, event_id, check_redacted=check_redacted,
check_redacted=check_redacted, get_prev_content=get_prev_content,
get_prev_content=get_prev_content allow_rejected=allow_rejected,
) )
for event_id in event_ids
]
return [e for e in events if e] missing_events_ids = [e for e in event_ids if e not in event_map]
if not missing_events_ids:
return [
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
]
missing_events = self._fetch_events_txn(
txn,
missing_events_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
event_map.update(missing_events)
return [
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
]
def _invalidate_get_event_cache(self, event_id): def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True): for check_redacted in (False, True):
@ -433,54 +490,217 @@ class EventsStore(SQLBaseStore):
def _get_event_txn(self, txn, event_id, check_redacted=True, def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False): get_prev_content=False, allow_rejected=False):
try: events = self._get_events_txn(
ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content) txn, [event_id],
if allow_rejected or not ret.rejected_reason:
return ret
else:
return None
except KeyError:
pass
sql = (
"SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
"FROM event_json as e "
"LEFT JOIN redactions as r ON e.event_id = r.redacts "
"LEFT JOIN rejections as rej on rej.event_id = e.event_id "
"WHERE e.event_id = ? "
"LIMIT 1 "
)
txn.execute(sql, (event_id,))
res = txn.fetchone()
if not res:
return None
internal_metadata, js, redacted, rejected_reason = res
result = self._get_event_from_row_txn(
txn, internal_metadata, js, redacted,
check_redacted=check_redacted, check_redacted=check_redacted,
get_prev_content=get_prev_content, get_prev_content=get_prev_content,
rejected_reason=rejected_reason, allow_rejected=allow_rejected,
) )
self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
if allow_rejected or not rejected_reason: return events[0] if events else None
return result
else:
return None
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, def _get_events_from_cache(self, events, check_redacted, get_prev_content,
check_redacted=True, get_prev_content=False, allow_rejected):
rejected_reason=None): event_map = {}
for event_id in events:
try:
ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content
)
if allow_rejected or not ret.rejected_reason:
event_map[event_id] = ret
else:
event_map[event_id] = None
except KeyError:
pass
return event_map
def _do_fetch(self, conn):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
event_list = []
i = 0
while True:
try:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []
if not event_list:
single_threaded = self.database_engine.single_threaded
if single_threaded or i > EVENT_QUEUE_ITERATIONS:
self._event_fetch_ongoing -= 1
return
else:
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0
event_id_lists = zip(*event_list)[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
rows = self._new_transaction(
conn, "do_fetch", [], self._fetch_event_rows, event_ids
)
row_dict = {
r["event_id"]: r
for r in rows
}
# We only want to resolve deferreds from the main thread
def fire(lst, res):
for ids, d in lst:
if not d.called:
try:
d.callback([
res[i]
for i in ids
if i in res
])
except:
logger.exception("Failed to callback")
reactor.callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
def fire(evs):
for _, d in evs:
if not d.called:
d.errback(e)
if event_list:
reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True,
get_prev_content=False, allow_rejected=False):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
"""
if not events:
defer.returnValue({})
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, events_d)
)
self._event_fetch_lock.notify()
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
should_start = True
else:
should_start = False
if should_start:
self.runWithConnection(
self._do_fetch
)
rows = yield preserve_context_over_deferred(events_d)
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults(
[
self._get_event_from_row(
row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
rejected_reason=row["rejects"],
)
for row in rows
],
consumeErrors=True
)
defer.returnValue({
e.event_id: e
for e in res if e
})
def _fetch_event_rows(self, txn, events):
rows = []
N = 200
for i in range(1 + len(events) / N):
evs = events[i*N:(i + 1)*N]
if not evs:
break
sql = (
"SELECT "
" e.event_id as event_id, "
" e.internal_metadata,"
" e.json,"
" r.redacts as redacts,"
" rej.event_id as rejects "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"]*len(evs)),)
txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn))
return rows
def _fetch_events_txn(self, txn, events, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not events:
return {}
rows = self._fetch_event_rows(
txn, events,
)
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = [
self._get_event_from_row_txn(
txn,
row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
rejected_reason=row["rejects"],
)
for row in rows
]
return {
r.event_id: r
for r in res
}
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False,
rejected_reason=None):
d = json.loads(js) d = json.loads(js)
internal_metadata = json.loads(internal_metadata) internal_metadata = json.loads(internal_metadata)
if rejected_reason:
rejected_reason = yield self._simple_select_one_onecol(
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
desc="_get_event_from_row",
)
ev = FrozenEvent( ev = FrozenEvent(
d, d,
internal_metadata_dict=internal_metadata, internal_metadata_dict=internal_metadata,
@ -490,12 +710,74 @@ class EventsStore(SQLBaseStore):
if check_redacted and redacted: if check_redacted and redacted:
ev = prune_event(ev) ev = prune_event(ev)
ev.unsigned["redacted_by"] = redacted redaction_id = yield self._simple_select_one_onecol(
table="redactions",
keyvalues={"redacts": ev.event_id},
retcol="event_id",
desc="_get_event_from_row",
)
ev.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = yield self.get_event(
redaction_id,
check_redacted=False
)
if because:
ev.unsigned["redacted_because"] = because
if get_prev_content and "replaces_state" in ev.unsigned:
prev = yield self.get_event(
ev.unsigned["replaces_state"],
get_prev_content=False,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev
)
defer.returnValue(ev)
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False,
rejected_reason=None):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
if rejected_reason:
rejected_reason = self._simple_select_one_onecol_txn(
txn,
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
)
ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
if check_redacted and redacted:
ev = prune_event(ev)
redaction_id = self._simple_select_one_onecol_txn(
txn,
table="redactions",
keyvalues={"redacts": ev.event_id},
retcol="event_id",
)
ev.unsigned["redacted_by"] = redaction_id
# Get the redaction event. # Get the redaction event.
because = self._get_event_txn( because = self._get_event_txn(
txn, txn,
redacted, redaction_id,
check_redacted=False check_redacted=False
) )
@ -511,6 +793,10 @@ class EventsStore(SQLBaseStore):
if prev: if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev
)
return ev return ev
def _parse_events(self, rows): def _parse_events(self, rows):

View file

@ -76,16 +76,16 @@ class RoomMemberStore(SQLBaseStore):
Returns: Returns:
Deferred: Results in a MembershipEvent or None. Deferred: Results in a MembershipEvent or None.
""" """
def f(txn): return self.runInteraction(
events = self._get_members_events_txn( "get_room_member",
txn, self._get_members_events_txn,
room_id, room_id,
user_id=user_id, user_id=user_id,
) ).addCallback(
self._get_events
return events[0] if events else None ).addCallback(
lambda events: events[0] if events else None
return self.runInteraction("get_room_member", f) )
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
def f(txn): def f(txn):
@ -110,15 +110,12 @@ class RoomMemberStore(SQLBaseStore):
Returns: Returns:
list of namedtuples representing the members in this room. list of namedtuples representing the members in this room.
""" """
return self.runInteraction(
def f(txn): "get_room_members",
return self._get_members_events_txn( self._get_members_events_txn,
txn, room_id,
room_id, membership=membership,
membership=membership, ).addCallback(self._get_events)
)
return self.runInteraction("get_room_members", f)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list): def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user """ Get all the rooms for this user where the membership for this user
@ -190,14 +187,14 @@ class RoomMemberStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"get_members_query", self._get_members_events_txn, "get_members_query", self._get_members_events_txn,
where_clause, where_values where_clause, where_values
) ).addCallbacks(self._get_events)
def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
rows = self._get_members_rows_txn( rows = self._get_members_rows_txn(
txn, txn,
room_id, membership, user_id, room_id, membership, user_id,
) )
return self._get_events_txn(txn, [r["event_id"] for r in rows]) return [r["event_id"] for r in rows]
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?" where_clause = "c.room_id = ?"

View file

@ -0,0 +1,19 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX events_order_topo_stream_room ON events(
topological_ordering, stream_ordering, room_id
);

View file

@ -43,6 +43,7 @@ class StateStore(SQLBaseStore):
* `state_groups_state`: Maps state group to state events. * `state_groups_state`: Maps state group to state events.
""" """
@defer.inlineCallbacks
def get_state_groups(self, event_ids): def get_state_groups(self, event_ids):
""" Get the state groups for the given list of event_ids """ Get the state groups for the given list of event_ids
@ -71,17 +72,29 @@ class StateStore(SQLBaseStore):
retcol="event_id", retcol="event_id",
) )
state = self._get_events_txn(txn, state_ids) res[group] = state_ids
res[group] = state
return res return res
return self.runInteraction( states = yield self.runInteraction(
"get_state_groups", "get_state_groups",
f, f,
) )
@defer.inlineCallbacks
def c(vals):
vals[:] = yield self._get_events(vals, get_prev_content=False)
yield defer.gatherResults(
[
c(vals)
for vals in states.values()
],
consumeErrors=True,
)
defer.returnValue(states)
def _store_state_groups_txn(self, txn, event, context): def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None: if context.current_state is None:
return return
@ -146,11 +159,12 @@ class StateStore(SQLBaseStore):
args = (room_id, ) args = (room_id, )
txn.execute(sql, args) txn.execute(sql, args)
results = self.cursor_to_dict(txn) results = txn.fetchall()
return self._parse_events_txn(txn, results) return [r[0] for r in results]
events = yield self.runInteraction("get_current_state", f) event_ids = yield self.runInteraction("get_current_state", f)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)

View file

@ -224,7 +224,7 @@ class StreamStore(SQLBaseStore):
return self.runInteraction("get_room_events_stream", f) return self.runInteraction("get_room_events_stream", f)
@log_function @defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None, def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1, direction='b', limit=-1,
with_feedback=False): with_feedback=False):
@ -286,18 +286,20 @@ class StreamStore(SQLBaseStore):
# TODO (erikj): We should work out what to do here instead. # TODO (erikj): We should work out what to do here instead.
next_token = to_key if to_key else from_key next_token = to_key if to_key else from_key
events = self._get_events_txn( return rows, next_token,
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(events, rows) rows, token = yield self.runInteraction("paginate_room_events", f)
return events, next_token, events = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
return self.runInteraction("paginate_room_events", f) self._set_before_and_after(events, rows)
defer.returnValue((events, token))
@defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token, def get_recent_events_for_room(self, room_id, limit, end_token,
with_feedback=False, from_token=None): with_feedback=False, from_token=None):
# TODO (erikj): Handle compressed feedback # TODO (erikj): Handle compressed feedback
@ -349,20 +351,23 @@ class StreamStore(SQLBaseStore):
else: else:
token = (str(end_token), str(end_token)) token = (str(end_token), str(end_token))
events = self._get_events_txn( return rows, token
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(events, rows) rows, token = yield self.runInteraction(
return events, token
return self.runInteraction(
"get_recent_events_for_room", get_recent_events_for_room_txn "get_recent_events_for_room", get_recent_events_for_room_txn
) )
logger.debug("stream before")
events = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
logger.debug("stream after")
self._set_before_and_after(events, rows)
defer.returnValue((events, token))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'): def get_room_events_max_id(self, direction='f'):
token = yield self._stream_id_gen.get_max_token(self) token = yield self._stream_id_gen.get_max_token(self)

View file

@ -29,6 +29,34 @@ def unwrapFirstError(failure):
return failure.value.subFailure return failure.value.subFailure
def unwrap_deferred(d):
"""Given a deferred that we know has completed, return its value or raise
the failure as an exception
"""
if not d.called:
raise RuntimeError("deferred has not finished")
res = []
def f(r):
res.append(r)
return r
d.addCallback(f)
if res:
return res[0]
def f(r):
res.append(r)
return r
d.addErrback(f)
if res:
res[0].raiseException()
else:
raise RuntimeError("deferred did not call callbacks")
class Clock(object): class Clock(object):
"""A small utility that obtains current time-of-day so that time may be """A small utility that obtains current time-of-day so that time may be
mocked during unit-tests. mocked during unit-tests.
@ -52,16 +80,16 @@ class Clock(object):
def stop_looping_call(self, loop): def stop_looping_call(self, loop):
loop.stop() loop.stop()
def call_later(self, delay, callback): def call_later(self, delay, callback, *args, **kwargs):
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
def wrapped_callback(): def wrapped_callback(*args, **kwargs):
with PreserveLoggingContext(): with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context LoggingContext.thread_local.current_context = current_context
callback() callback(*args, **kwargs)
with PreserveLoggingContext(): with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback) return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer): def cancel_call_later(self, timer):
timer.cancel() timer.cancel()

View file

@ -33,8 +33,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.db_pool = Mock(spec=["runInteraction"]) self.db_pool = Mock(spec=["runInteraction"])
self.mock_txn = Mock() self.mock_txn = Mock()
self.mock_conn = Mock(spec_set=["cursor"]) self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
self.mock_conn.cursor.return_value = self.mock_txn self.mock_conn.cursor.return_value = self.mock_txn
self.mock_conn.rollback.return_value = None
# Our fake runInteraction just runs synchronously inline # Our fake runInteraction just runs synchronously inline
def runInteraction(func, *args, **kwargs): def runInteraction(func, *args, **kwargs):