User Cursor.__iter__ instead of fetchall

This prevents unnecessary construction of lists
This commit is contained in:
Erik Johnston 2017-03-23 17:53:49 +00:00
parent 59358cd3e7
commit 00957d1aa4
16 changed files with 41 additions and 42 deletions

View file

@ -73,6 +73,9 @@ class LoggingTransaction(object):
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(self.txn, name, value) setattr(self.txn, name, value)
def __iter__(self):
return self.txn.__iter__()
def execute(self, sql, *args): def execute(self, sql, *args):
self._do_execute(self.txn.execute, sql, *args) self._do_execute(self.txn.execute, sql, *args)
@ -357,7 +360,7 @@ class SQLBaseStore(object):
""" """
col_headers = list(intern(column[0]) for column in cursor.description) col_headers = list(intern(column[0]) for column in cursor.description)
results = list( results = list(
dict(zip(col_headers, row)) for row in cursor.fetchall() dict(zip(col_headers, row)) for row in cursor
) )
return results return results
@ -579,7 +582,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
return [r[0] for r in txn.fetchall()] return [r[0] for r in txn]
def _simple_select_onecol(self, table, keyvalues, retcol, def _simple_select_onecol(self, table, keyvalues, retcol,
desc="_simple_select_onecol"): desc="_simple_select_onecol"):
@ -901,14 +904,14 @@ class SQLBaseStore(object):
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (int(max_value),)) txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = { cache = {
row[0]: int(row[1]) row[0]: int(row[1])
for row in rows for row in txn
} }
txn.close()
if cache: if cache:
min_val = min(cache.values()) min_val = min(cache.values())
else: else:

View file

@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
global_account_data = { global_account_data = {
row[0]: json.loads(row[1]) for row in txn.fetchall() row[0]: json.loads(row[1]) for row in txn
} }
sql = ( sql = (
@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
account_data_by_room = {} account_data_by_room = {}
for row in txn.fetchall(): for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2]) room_account_data[row[1]] = json.loads(row[2])

View file

@ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
message_json = ujson.dumps(messages_by_device["*"]) message_json = ujson.dumps(messages_by_device["*"])
for row in txn.fetchall(): for row in txn:
# Add the message for all devices for this user on this # Add the message for all devices for this user on this
# server. # server.
device = row[0] device = row[0]
@ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# TODO: Maybe this needs to be done in batches if there are # TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user. # too many local devices for a given user.
txn.execute(sql, [user_id] + devices) txn.execute(sql, [user_id] + devices)
for row in txn.fetchall(): for row in txn:
# Only insert into the local inbox if the device exists on # Only insert into the local inbox if the device exists on
# this server # this server
device = row[0] device = row[0]
@ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
user_id, device_id, last_stream_id, current_stream_id, limit user_id, device_id, last_stream_id, current_stream_id, limit
)) ))
messages = [] messages = []
for row in txn.fetchall(): for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(ujson.loads(row[1])) messages.append(ujson.loads(row[1]))
if len(messages) < limit: if len(messages) < limit:
@ -340,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" ORDER BY stream_id ASC" " ORDER BY stream_id ASC"
) )
txn.execute(sql, (last_pos, upper_pos)) txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn.fetchall()) rows.extend(txn)
return rows return rows
@ -384,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
destination, last_stream_id, current_stream_id, limit destination, last_stream_id, current_stream_id, limit
)) ))
messages = [] messages = []
for row in txn.fetchall(): for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(ujson.loads(row[1])) messages.append(ujson.loads(row[1]))
if len(messages) < limit: if len(messages) < limit:

View file

@ -333,13 +333,12 @@ class DeviceStore(SQLBaseStore):
txn.execute( txn.execute(
sql, (destination, from_stream_id, now_stream_id, False) sql, (destination, from_stream_id, now_stream_id, False)
) )
rows = txn.fetchall()
if not rows:
return (now_stream_id, [])
# maps (user_id, device_id) -> stream_id # maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in rows} query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
devices = self._get_e2e_device_keys_txn( devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True txn, query_map.keys(), include_all_devices=True
) )

View file

@ -153,7 +153,7 @@ class EndToEndKeyStore(SQLBaseStore):
) )
txn.execute(sql, (user_id, device_id)) txn.execute(sql, (user_id, device_id))
result = {} result = {}
for algorithm, key_count in txn.fetchall(): for algorithm, key_count in txn:
result[algorithm] = key_count result[algorithm] = key_count
return result return result
return self.runInteraction( return self.runInteraction(
@ -174,7 +174,7 @@ class EndToEndKeyStore(SQLBaseStore):
user_result = result.setdefault(user_id, {}) user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {}) device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm)) txn.execute(sql, (user_id, device_id, algorithm))
for key_id, key_json in txn.fetchall(): for key_id, key_json in txn:
device_result[algorithm + ":" + key_id] = key_json device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id)) delete.append((user_id, device_id, algorithm, key_id))
sql = ( sql = (

View file

@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore):
base_sql % (",".join(["?"] * len(chunk)),), base_sql % (",".join(["?"] * len(chunk)),),
chunk chunk
) )
new_front.update([r[0] for r in txn.fetchall()]) new_front.update([r[0] for r in txn])
new_front -= results new_front -= results
@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, False,)) txn.execute(sql, (room_id, False,))
return dict(txn.fetchall()) return dict(txn)
def _get_oldest_events_in_room_txn(self, txn, room_id): def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn( return self._simple_select_onecol_txn(
@ -152,7 +152,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, )) txn.execute(sql, (room_id, ))
results = [] results = []
for event_id, depth in txn.fetchall(): for event_id, depth in txn:
hashes = self._get_event_reference_hashes_txn(txn, event_id) hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = { prev_hashes = {
k: encode_base64(v) for k, v in hashes.items() k: encode_base64(v) for k, v in hashes.items()
@ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore):
def get_forward_extremeties_for_room_txn(txn): def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id)) txn.execute(sql, (stream_ordering, room_id))
rows = txn.fetchall() return [event_id for event_id, in txn]
return [event_id for event_id, in rows]
return self.runInteraction( return self.runInteraction(
"get_forward_extremeties_for_room", "get_forward_extremeties_for_room",
@ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results)) (room_id, event_id, False, limit - len(event_results))
) )
for row in txn.fetchall(): for row in txn:
if row[1] not in event_results: if row[1] not in event_results:
queue.put((-row[0], row[1])) queue.put((-row[0], row[1]))
@ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results)) (room_id, event_id, False, limit - len(event_results))
) )
for e_id, in txn.fetchall(): for e_id, in txn:
new_front.add(e_id) new_front.add(e_id)
new_front -= earliest_events new_front -= earliest_events

View file

@ -206,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore):
" stream_ordering >= ? AND stream_ordering <= ?" " stream_ordering >= ? AND stream_ordering <= ?"
) )
txn.execute(sql, (min_stream_ordering, max_stream_ordering)) txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn.fetchall()] return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f) ret = yield self.runInteraction("get_push_action_users_in_range", f)
defer.returnValue(ret) defer.returnValue(ret)

View file

@ -834,7 +834,7 @@ class EventsStore(SQLBaseStore):
have_persisted = { have_persisted = {
event_id: outlier event_id: outlier
for event_id, outlier in txn.fetchall() for event_id, outlier in txn
} }
to_remove = set() to_remove = set()

View file

@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine):
), ),
(current_version,) (current_version,)
) )
applied_deltas = [d for d, in txn.fetchall()] applied_deltas = [d for d, in txn]
return current_version, applied_deltas, upgraded return current_version, applied_deltas, upgraded
return None return None

View file

@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore):
) )
txn.execute(sql, (room_id, receipt_type, user_id)) txn.execute(sql, (room_id, receipt_type, user_id))
results = txn.fetchall()
if results and topological_ordering: if topological_ordering:
for to, so, _ in results: for to, so, _ in txn:
if int(to) > topological_ordering: if int(to) > topological_ordering:
return False return False
elif int(to) == topological_ordering and int(so) >= stream_ordering: elif int(to) == topological_ordering and int(so) >= stream_ordering:

View file

@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
" WHERE lower(name) = lower(?)" " WHERE lower(name) = lower(?)"
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return dict(txn.fetchall()) return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f) return self.runInteraction("get_users_by_id_case_insensitive", f)

View file

@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore):
sql % ("AND appservice_id IS NULL",), sql % ("AND appservice_id IS NULL",),
(stream_id,) (stream_id,)
) )
return dict(txn.fetchall()) return dict(txn)
else: else:
# We want to get from all lists, so we need to aggregate the results # We want to get from all lists, so we need to aggregate the results
@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore):
results = {} results = {}
# A room is visible if its visible on any list. # A room is visible if its visible on any list.
for room_id, visibility in txn.fetchall(): for room_id, visibility in txn:
results[room_id] = bool(visibility) or results.get(room_id, False) results[room_id] = bool(visibility) or results.get(room_id, False)
return results return results

View file

@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore):
" WHERE event_id = ?" " WHERE event_id = ?"
) )
txn.execute(query, (event_id, )) txn.execute(query, (event_id, ))
return {k: v for k, v in txn.fetchall()} return {k: v for k, v in txn}
def _store_event_reference_hashes_txn(self, txn, events): def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU """Store a hash for a PDU

View file

@ -373,10 +373,9 @@ class StateStore(SQLBaseStore):
" WHERE state_group = ? %s" % (where_clause,), " WHERE state_group = ? %s" % (where_clause,),
args args
) )
rows = txn.fetchall()
results[group].update({ results[group].update({
(typ, state_key): event_id (typ, state_key): event_id
for typ, state_key, event_id in rows for typ, state_key, event_id in txn
if (typ, state_key) not in results[group] if (typ, state_key) not in results[group]
}) })

View file

@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
for stream_id, user_id, room_id in tag_ids: for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id)) txn.execute(sql, (user_id, room_id))
tags = [] tags = []
for tag, content in txn.fetchall(): for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content) tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}" tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, user_id, room_id, tag_json)) results.append((stream_id, user_id, room_id, tag_json))
@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore):
" WHERE user_id = ? AND stream_id > ?" " WHERE user_id = ? AND stream_id > ?"
) )
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
room_ids = [row[0] for row in txn.fetchall()] room_ids = [row[0] for row in txn]
return room_ids return room_ids
changed = self._account_data_stream_cache.has_entity_changed( changed = self._account_data_stream_cache.has_entity_changed(

View file

@ -89,7 +89,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_one_1col(self): def test_select_one_1col(self):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
self.mock_txn.fetchall.return_value = [("Value",)] self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
value = yield self.datastore._simple_select_one_onecol( value = yield self.datastore._simple_select_one_onecol(
table="tablename", table="tablename",
@ -136,7 +136,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_list(self): def test_select_list(self):
self.mock_txn.rowcount = 3 self.mock_txn.rowcount = 3
self.mock_txn.fetchall.return_value = ((1,), (2,), (3,)) self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = ( self.mock_txn.description = (
("colA", None, None, None, None, None, None), ("colA", None, None, None, None, None, None),
) )