Remove race condition

This commit is contained in:
Erik Johnston 2015-05-14 16:54:35 +01:00
parent ef3d8754f5
commit 1d566edb81
4 changed files with 161 additions and 100 deletions

View file

@ -26,6 +26,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
import contextlib
import functools import functools
import sys import sys
import time import time
@ -299,7 +301,7 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._event_fetch_lock = threading.Lock() self._event_fetch_lock = threading.Condition()
self._event_fetch_list = [] self._event_fetch_list = []
self._event_fetch_ongoing = 0 self._event_fetch_ongoing = 0
@ -342,6 +344,84 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000) self._clock.looping_call(loop, 10000)
@contextlib.contextmanager
def _new_transaction(self, conn, desc, after_callbacks):
    """Context manager yielding a LoggingTransaction on ``conn``.

    Opens a cursor on ``conn`` and wraps it in a LoggingTransaction. If that
    fails with an OperationalError or a deadlock, the connection is rolled
    back and the attempt is retried, up to five times. The wrapped cursor is
    then yielded; if the ``with`` body finishes normally the connection is
    committed, otherwise it is rolled back and the exception is re-raised.
    Timing for the whole transaction is recorded on exit either way.

    Args:
        conn: database connection; must provide cursor(), commit() and
            rollback().
        desc: short description, used in the transaction's log name and as
            the key for the performance counters.
        after_callbacks: passed through to LoggingTransaction.

    Yields:
        The LoggingTransaction wrapping the new cursor.
    """
    start = time.time() * 1000
    txn_id = self._TXN_ID

    # We don't really need these to be unique, so lets stop it from
    # growing really large.
    self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)

    name = "%s-%x" % (desc, txn_id, )

    def try_rollback(exc_types):
        # Best-effort rollback: failures of the given types are logged
        # rather than raised. This lives in its own function so that, on
        # Python 2, handling a rollback error here does not replace the
        # exception that a bare ``raise`` in the caller is about to
        # re-raise.
        try:
            conn.rollback()
        except exc_types as e1:
            logger.warn(
                "[TXN EROLL] {%s} %s",
                name, e1,
            )

    transaction_logger.debug("[TXN START] {%s}", name)

    try:
        i = 0
        N = 5
        while True:
            try:
                txn = conn.cursor()
                txn = LoggingTransaction(
                    txn, name, self.database_engine, after_callbacks
                )
            except self.database_engine.module.OperationalError as e:
                # This can happen if the database disappears mid
                # transaction.
                logger.warn(
                    "[TXN OPERROR] {%s} %s %d/%d",
                    name, e, i, N
                )
                if i < N:
                    i += 1
                    try_rollback(self.database_engine.module.Error)
                    continue
                raise
            except self.database_engine.module.DatabaseError as e:
                if self.database_engine.is_deadlock(e):
                    logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
                    if i < N:
                        i += 1
                        try_rollback(self.database_engine.module.Error)
                        continue
                raise

            try:
                yield txn
                conn.commit()
                return
            except:
                # Bare on purpose: whatever interrupted the ``with`` body,
                # roll back and then re-raise it. A failing rollback is
                # logged instead of being silently discarded, and must not
                # hide the original error.
                try_rollback(Exception)
                raise
    except Exception as e:
        logger.debug("[TXN FAIL] {%s} %s", name, e)
        raise
    finally:
        end = time.time() * 1000
        duration = end - start

        transaction_logger.debug("[TXN END] {%s} %f", name, duration)

        self._current_txn_total_time += duration
        self._txn_perf_counters.update(desc, start, end)
        sql_txn_timer.inc_by(duration, desc)
@defer.inlineCallbacks @defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs): def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool.""" """Wraps the .runInteraction() method on the underlying db_pool."""
@ -353,75 +433,15 @@ class SQLBaseStore(object):
def inner_func(conn, *args, **kwargs): def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context: with LoggingContext("runInteraction") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
if self.database_engine.is_connection_closed(conn): if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection") logger.debug("Reconnecting closed database connection")
conn.reconnect() conn.reconnect()
current_context.copy_to(context) current_context.copy_to(context)
start = time.time() * 1000 with self._new_transaction(conn, desc, after_callbacks) as txn:
txn_id = self._TXN_ID return func(txn, *args, **kwargs)
# We don't really need these to be unique, so lets stop it from
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
name = "%s-%x" % (desc, txn_id, )
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks
)
return func(txn, *args, **kwargs)
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warn(
"[TXN OPERROR] {%s} %s %d/%d",
name, e, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
duration = end - start
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc)
result = yield preserve_context_over_fn( result = yield preserve_context_over_fn(
self._db_pool.runWithConnection, self._db_pool.runWithConnection,
@ -432,6 +452,32 @@ class SQLBaseStore(object):
after_callback(*after_args) after_callback(*after_args)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
    """Wraps the .runWithConnection() method on the underlying db_pool.

    ``func`` is handed the database connection itself; this wrapper does
    not open a transaction or cursor on it.

    Args:
        func: callable invoked as ``func(conn, *args, **kwargs)``.
        *args: extra positional arguments passed through to ``func``.
        **kwargs: extra keyword arguments passed through to ``func``.

    Returns:
        Deferred which fires with whatever ``func`` returned.
    """
    current_context = LoggingContext.current_context()

    start_time = time.time() * 1000

    def inner_func(conn, *args, **kwargs):
        with LoggingContext("runWithConnection") as context:
            # Record how long the call waited between being requested and
            # actually starting to run.
            sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)

            if self.database_engine.is_connection_closed(conn):
                logger.debug("Reconnecting closed database connection")
                conn.reconnect()

            # Carry the caller's logging context into this one.
            current_context.copy_to(context)

            return func(conn, *args, **kwargs)

    result = yield preserve_context_over_fn(
        self._db_pool.runWithConnection,
        inner_func, *args, **kwargs
    )

    defer.returnValue(result)
def cursor_to_dict(self, cursor): def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts. """Converts a SQL cursor into an list of dicts.

View file

@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup
class PostgresEngine(object): class PostgresEngine(object):
single_threaded = False
def __init__(self, database_module): def __init__(self, database_module):
self.module = database_module self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE) self.module.extensions.register_type(self.module.extensions.UNICODE)

View file

@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database
class Sqlite3Engine(object): class Sqlite3Engine(object):
single_threaded = True
def __init__(self, database_module): def __init__(self, database_module):
self.module = database_module self.module = database_module

View file

@ -504,23 +504,26 @@ class EventsStore(SQLBaseStore):
if not events: if not events:
defer.returnValue({}) defer.returnValue({})
def do_fetch(txn): def do_fetch(conn):
event_list = [] event_list = []
while True: while True:
try: try:
with self._event_fetch_lock: with self._event_fetch_lock:
event_list = self._event_fetch_list i = 0
self._event_fetch_list = [] while not self._event_fetch_list:
if not event_list:
self._event_fetch_ongoing -= 1 self._event_fetch_ongoing -= 1
return return
event_list = self._event_fetch_list
self._event_fetch_list = []
event_id_lists = zip(*event_list)[0] event_id_lists = zip(*event_list)[0]
event_ids = [ event_ids = [
item for sublist in event_id_lists for item in sublist item for sublist in event_id_lists for item in sublist
] ]
rows = self._fetch_event_rows(txn, event_ids)
with self._new_transaction(conn, "do_fetch", []) as txn:
rows = self._fetch_event_rows(txn, event_ids)
row_dict = { row_dict = {
r["event_id"]: r r["event_id"]: r
@ -528,22 +531,44 @@ class EventsStore(SQLBaseStore):
} }
for ids, d in event_list: for ids, d in event_list:
reactor.callFromThread( def fire():
d.callback, if not d.called:
[ d.callback(
row_dict[i] for i in ids [
if i in row_dict row_dict[i]
] for i in ids
) if i in row_dict
]
)
reactor.callFromThread(fire)
except Exception as e: except Exception as e:
logger.exception("do_fetch")
for _, d in event_list: for _, d in event_list:
try: if not d.called:
reactor.callFromThread(d.errback, e) reactor.callFromThread(d.errback, e)
except:
pass
def cb(rows): with self._event_fetch_lock:
return defer.gatherResults([ self._event_fetch_ongoing -= 1
return
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, events_d)
)
self._event_fetch_lock.notify_all()
# if self._event_fetch_ongoing < 5:
self._event_fetch_ongoing += 1
self.runWithConnection(
do_fetch
)
rows = yield events_d
res = yield defer.gatherResults(
[
self._get_event_from_row( self._get_event_from_row(
None, None,
row["internal_metadata"], row["json"], row["redacts"], row["internal_metadata"], row["json"], row["redacts"],
@ -552,23 +577,9 @@ class EventsStore(SQLBaseStore):
rejected_reason=row["rejects"], rejected_reason=row["rejects"],
) )
for row in rows for row in rows
]) ],
consumeErrors=True
d = defer.Deferred() )
d.addCallback(cb)
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, d)
)
if self._event_fetch_ongoing < 3:
self._event_fetch_ongoing += 1
self.runInteraction(
"do_fetch",
do_fetch
)
res = yield d
defer.returnValue({ defer.returnValue({
e.event_id: e e.event_id: e