mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 18:13:54 +01:00
Make StreamIdGen get_next
and get_next_mult
async (#8161)
This is mainly so that `StreamIdGenerator` and `MultiWriterIdGenerator` will have the same interface, allowing them to be used interchangeably.
This commit is contained in:
parent
74bf8d4d06
commit
2231dffee6
14 changed files with 54 additions and 49 deletions
1
changelog.d/8161.misc
Normal file
1
changelog.d/8161.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface.
|
|
@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||||
"""
|
"""
|
||||||
content_json = json_encoder.encode(content)
|
content_json = json_encoder.encode(content)
|
||||||
|
|
||||||
with self._account_data_id_gen.get_next() as next_id:
|
with await self._account_data_id_gen.get_next() as next_id:
|
||||||
# no need to lock here as room_account_data has a unique constraint
|
# no need to lock here as room_account_data has a unique constraint
|
||||||
# on (user_id, room_id, account_data_type) so simple_upsert will
|
# on (user_id, room_id, account_data_type) so simple_upsert will
|
||||||
# retry if there is a conflict.
|
# retry if there is a conflict.
|
||||||
|
@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||||
"""
|
"""
|
||||||
content_json = json_encoder.encode(content)
|
content_json = json_encoder.encode(content)
|
||||||
|
|
||||||
with self._account_data_id_gen.get_next() as next_id:
|
with await self._account_data_id_gen.get_next() as next_id:
|
||||||
# no need to lock here as account_data has a unique constraint on
|
# no need to lock here as account_data has a unique constraint on
|
||||||
# (user_id, account_data_type) so simple_upsert will retry if
|
# (user_id, account_data_type) so simple_upsert will retry if
|
||||||
# there is a conflict.
|
# there is a conflict.
|
||||||
|
|
|
@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
||||||
rows.append((destination, stream_id, now_ms, edu_json))
|
rows.append((destination, stream_id, now_ms, edu_json))
|
||||||
txn.executemany(sql, rows)
|
txn.executemany(sql, rows)
|
||||||
|
|
||||||
with self._device_inbox_id_gen.get_next() as stream_id:
|
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||||
now_ms = self.clock.time_msec()
|
now_ms = self.clock.time_msec()
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
||||||
|
@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
||||||
txn, stream_id, local_messages_by_user_then_device
|
txn, stream_id, local_messages_by_user_then_device
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._device_inbox_id_gen.get_next() as stream_id:
|
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||||
now_ms = self.clock.time_msec()
|
now_ms = self.clock.time_msec()
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_messages_from_remote_to_device_inbox",
|
"add_messages_from_remote_to_device_inbox",
|
||||||
|
|
|
@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
THe new stream ID.
|
THe new stream ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self._device_list_id_gen.get_next() as stream_id:
|
with await self._device_list_id_gen.get_next() as stream_id:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_user_sig_change_to_streams",
|
"add_user_sig_change_to_streams",
|
||||||
self._add_user_signature_change_txn,
|
self._add_user_signature_change_txn,
|
||||||
|
@ -1146,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
if not device_ids:
|
if not device_ids:
|
||||||
return
|
return
|
||||||
|
|
||||||
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
|
with await self._device_list_id_gen.get_next_mult(
|
||||||
|
len(device_ids)
|
||||||
|
) as stream_ids:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_device_change_to_stream",
|
"add_device_change_to_stream",
|
||||||
self._add_device_change_to_stream_txn,
|
self._add_device_change_to_stream_txn,
|
||||||
|
@ -1159,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
return stream_ids[-1]
|
return stream_ids[-1]
|
||||||
|
|
||||||
context = get_active_span_text_map()
|
context = get_active_span_text_map()
|
||||||
with self._device_list_id_gen.get_next_mult(
|
with await self._device_list_id_gen.get_next_mult(
|
||||||
len(hosts) * len(device_ids)
|
len(hosts) * len(device_ids)
|
||||||
) as stream_ids:
|
) as stream_ids:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
|
|
|
@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
|
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
|
||||||
"""Set a user's cross-signing key.
|
"""Set a user's cross-signing key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
for a master key, 'self_signing' for a self-signing key, or
|
for a master key, 'self_signing' for a self-signing key, or
|
||||||
'user_signing' for a user-signing key
|
'user_signing' for a user-signing key
|
||||||
key (dict): the key data
|
key (dict): the key data
|
||||||
|
stream_id (int)
|
||||||
"""
|
"""
|
||||||
# the 'key' dict will look something like:
|
# the 'key' dict will look something like:
|
||||||
# {
|
# {
|
||||||
|
@ -695,7 +696,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
# and finally, store the key itself
|
# and finally, store the key itself
|
||||||
with self._cross_signing_id_gen.get_next() as stream_id:
|
|
||||||
self.db_pool.simple_insert_txn(
|
self.db_pool.simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
"e2e_cross_signing_keys",
|
"e2e_cross_signing_keys",
|
||||||
|
@ -711,7 +711,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
|
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
async def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
||||||
"""Set a user's cross-signing key.
|
"""Set a user's cross-signing key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -719,12 +719,15 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
key_type (str): the type of cross-signing key to set
|
key_type (str): the type of cross-signing key to set
|
||||||
key (dict): the key data
|
key (dict): the key data
|
||||||
"""
|
"""
|
||||||
return self.db_pool.runInteraction(
|
|
||||||
|
with await self._cross_signing_id_gen.get_next() as stream_id:
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
"add_e2e_cross_signing_key",
|
"add_e2e_cross_signing_key",
|
||||||
self._set_e2e_cross_signing_key_txn,
|
self._set_e2e_cross_signing_key_txn,
|
||||||
user_id,
|
user_id,
|
||||||
key_type,
|
key_type,
|
||||||
key,
|
key,
|
||||||
|
stream_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def store_e2e_cross_signing_signatures(self, user_id, signatures):
|
def store_e2e_cross_signing_signatures(self, user_id, signatures):
|
||||||
|
|
|
@ -153,11 +153,11 @@ class PersistEventsStore:
|
||||||
# Note: Multiple instances of this function cannot be in flight at
|
# Note: Multiple instances of this function cannot be in flight at
|
||||||
# the same time for the same room.
|
# the same time for the same room.
|
||||||
if backfilled:
|
if backfilled:
|
||||||
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
|
stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
|
||||||
len(events_and_contexts)
|
len(events_and_contexts)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
stream_ordering_manager = self._stream_id_gen.get_next_mult(
|
stream_ordering_manager = await self._stream_id_gen.get_next_mult(
|
||||||
len(events_and_contexts)
|
len(events_and_contexts)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1182,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
|
|
||||||
return next_id
|
return next_id
|
||||||
|
|
||||||
with self._group_updates_id_gen.get_next() as next_id:
|
with await self._group_updates_id_gen.get_next() as next_id:
|
||||||
res = await self.db_pool.runInteraction(
|
res = await self.db_pool.runInteraction(
|
||||||
"register_user_group_membership",
|
"register_user_group_membership",
|
||||||
_register_user_group_membership_txn,
|
_register_user_group_membership_txn,
|
||||||
|
|
|
@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter
|
||||||
|
|
||||||
class PresenceStore(SQLBaseStore):
|
class PresenceStore(SQLBaseStore):
|
||||||
async def update_presence(self, presence_states):
|
async def update_presence(self, presence_states):
|
||||||
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
stream_ordering_manager = await self._presence_id_gen.get_next_mult(
|
||||||
len(presence_states)
|
len(presence_states)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
) -> None:
|
) -> None:
|
||||||
conditions_json = json_encoder.encode(conditions)
|
conditions_json = json_encoder.encode(conditions)
|
||||||
actions_json = json_encoder.encode(actions)
|
actions_json = json_encoder.encode(actions)
|
||||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||||
|
|
||||||
if before or after:
|
if before or after:
|
||||||
|
@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
|
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
|
@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
|
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
|
||||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
|
@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
data={"actions": actions_json},
|
data={"actions": actions_json},
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
|
|
|
@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
|
||||||
last_stream_ordering,
|
last_stream_ordering,
|
||||||
profile_tag="",
|
profile_tag="",
|
||||||
) -> None:
|
) -> None:
|
||||||
with self._pushers_id_gen.get_next() as stream_id:
|
with await self._pushers_id_gen.get_next() as stream_id:
|
||||||
# no need to lock because `pushers` has a unique key on
|
# no need to lock because `pushers` has a unique key on
|
||||||
# (app_id, pushkey, user_name) so simple_upsert will retry
|
# (app_id, pushkey, user_name) so simple_upsert will retry
|
||||||
await self.db_pool.simple_upsert(
|
await self.db_pool.simple_upsert(
|
||||||
|
@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._pushers_id_gen.get_next() as stream_id:
|
with await self._pushers_id_gen.get_next() as stream_id:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"delete_pusher", delete_pusher_txn, stream_id
|
"delete_pusher", delete_pusher_txn, stream_id
|
||||||
)
|
)
|
||||||
|
|
|
@ -520,8 +520,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||||
"insert_receipt_conv", graph_to_linear
|
"insert_receipt_conv", graph_to_linear
|
||||||
)
|
)
|
||||||
|
|
||||||
stream_id_manager = self._receipts_id_gen.get_next()
|
with await self._receipts_id_gen.get_next() as stream_id:
|
||||||
with stream_id_manager as stream_id:
|
|
||||||
event_ts = await self.db_pool.runInteraction(
|
event_ts = await self.db_pool.runInteraction(
|
||||||
"insert_linearized_receipt",
|
"insert_linearized_receipt",
|
||||||
self.insert_linearized_receipt_txn,
|
self.insert_linearized_receipt_txn,
|
||||||
|
|
|
@ -1129,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._public_room_id_gen.get_next() as next_id:
|
with await self._public_room_id_gen.get_next() as next_id:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"store_room_txn", store_room_txn, next_id
|
"store_room_txn", store_room_txn, next_id
|
||||||
)
|
)
|
||||||
|
@ -1196,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._public_room_id_gen.get_next() as next_id:
|
with await self._public_room_id_gen.get_next() as next_id:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"set_room_is_public", set_room_is_public_txn, next_id
|
"set_room_is_public", set_room_is_public_txn, next_id
|
||||||
)
|
)
|
||||||
|
@ -1276,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._public_room_id_gen.get_next() as next_id:
|
with await self._public_room_id_gen.get_next() as next_id:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"set_room_is_public_appservice",
|
"set_room_is_public_appservice",
|
||||||
set_room_is_public_appservice_txn,
|
set_room_is_public_appservice_txn,
|
||||||
|
|
|
@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
|
||||||
)
|
)
|
||||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||||
|
|
||||||
with self._account_data_id_gen.get_next() as next_id:
|
with await self._account_data_id_gen.get_next() as next_id:
|
||||||
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
||||||
|
|
||||||
self.get_tags_for_user.invalidate((user_id,))
|
self.get_tags_for_user.invalidate((user_id,))
|
||||||
|
@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
|
||||||
txn.execute(sql, (user_id, room_id, tag))
|
txn.execute(sql, (user_id, room_id, tag))
|
||||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||||
|
|
||||||
with self._account_data_id_gen.get_next() as next_id:
|
with await self._account_data_id_gen.get_next() as next_id:
|
||||||
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
||||||
|
|
||||||
self.get_tags_for_user.invalidate((user_id,))
|
self.get_tags_for_user.invalidate((user_id,))
|
||||||
|
|
|
@ -80,7 +80,7 @@ class StreamIdGenerator(object):
|
||||||
upwards, -1 to grow downwards.
|
upwards, -1 to grow downwards.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
with stream_id_gen.get_next() as stream_id:
|
with await stream_id_gen.get_next() as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -95,10 +95,10 @@ class StreamIdGenerator(object):
|
||||||
)
|
)
|
||||||
self._unfinished_ids = deque() # type: Deque[int]
|
self._unfinished_ids = deque() # type: Deque[int]
|
||||||
|
|
||||||
def get_next(self):
|
async def get_next(self):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with stream_id_gen.get_next() as stream_id:
|
with await stream_id_gen.get_next() as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
@ -117,10 +117,10 @@ class StreamIdGenerator(object):
|
||||||
|
|
||||||
return manager()
|
return manager()
|
||||||
|
|
||||||
def get_next_mult(self, n):
|
async def get_next_mult(self, n):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with stream_id_gen.get_next(n) as stream_ids:
|
with await stream_id_gen.get_next(n) as stream_ids:
|
||||||
# ... persist events ...
|
# ... persist events ...
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
|
Loading…
Reference in a new issue