mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 10:13:48 +01:00
Claim local one-time-keys in bulk (#16565)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
parent
91aa52c911
commit
de981ae567
4 changed files with 308 additions and 114 deletions
1
changelog.d/16565.feature
Normal file
1
changelog.d/16565.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve the performance of claiming encryption keys.
|
|
@ -753,6 +753,16 @@ class E2eKeysHandler:
|
||||||
async def upload_keys_for_user(
|
async def upload_keys_for_user(
|
||||||
self, user_id: str, device_id: str, keys: JsonDict
|
self, user_id: str, device_id: str, keys: JsonDict
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
user_id: user whose keys are being uploaded.
|
||||||
|
device_id: device whose keys are being uploaded.
|
||||||
|
keys: the body of a /keys/upload request.
|
||||||
|
|
||||||
|
Returns a dictionary with one field:
|
||||||
|
"one_time_keys": A mapping from algorithm to number of keys for that
|
||||||
|
algorithm, including those previously persisted.
|
||||||
|
"""
|
||||||
# This can only be called from the main process.
|
# This can only be called from the main process.
|
||||||
assert isinstance(self.device_handler, DeviceHandler)
|
assert isinstance(self.device_handler, DeviceHandler)
|
||||||
|
|
||||||
|
|
|
@ -1111,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
...
|
...
|
||||||
|
|
||||||
async def claim_e2e_one_time_keys(
|
async def claim_e2e_one_time_keys(
|
||||||
self, query_list: Iterable[Tuple[str, str, str, int]]
|
self, query_list: Collection[Tuple[str, str, str, int]]
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
|
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
|
||||||
]:
|
]:
|
||||||
|
@ -1121,131 +1121,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple pf:
|
A tuple (results, missing) of:
|
||||||
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||||
|
|
||||||
A copy of the input which has not been fulfilled.
|
A copy of the input which has not been fulfilled. The returned counts
|
||||||
|
may be less than the input counts. In this case, the returned counts
|
||||||
|
are the number of claims that were not fulfilled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@trace
|
|
||||||
def _claim_e2e_one_time_key_simple(
|
|
||||||
txn: LoggingTransaction,
|
|
||||||
user_id: str,
|
|
||||||
device_id: str,
|
|
||||||
algorithm: str,
|
|
||||||
count: int,
|
|
||||||
) -> List[Tuple[str, str]]:
|
|
||||||
"""Claim OTK for device for DBs that don't support RETURNING.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple of key name (algorithm + key ID) and key JSON, if an
|
|
||||||
OTK was found.
|
|
||||||
"""
|
|
||||||
|
|
||||||
sql = """
|
|
||||||
SELECT key_id, key_json FROM e2e_one_time_keys_json
|
|
||||||
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
|
||||||
LIMIT ?
|
|
||||||
"""
|
|
||||||
|
|
||||||
txn.execute(sql, (user_id, device_id, algorithm, count))
|
|
||||||
otk_rows = list(txn)
|
|
||||||
if not otk_rows:
|
|
||||||
return []
|
|
||||||
|
|
||||||
self.db_pool.simple_delete_many_txn(
|
|
||||||
txn,
|
|
||||||
table="e2e_one_time_keys_json",
|
|
||||||
column="key_id",
|
|
||||||
values=[otk_row[0] for otk_row in otk_rows],
|
|
||||||
keyvalues={
|
|
||||||
"user_id": user_id,
|
|
||||||
"device_id": device_id,
|
|
||||||
"algorithm": algorithm,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self._invalidate_cache_and_stream(
|
|
||||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
|
|
||||||
]
|
|
||||||
|
|
||||||
@trace
|
|
||||||
def _claim_e2e_one_time_key_returning(
|
|
||||||
txn: LoggingTransaction,
|
|
||||||
user_id: str,
|
|
||||||
device_id: str,
|
|
||||||
algorithm: str,
|
|
||||||
count: int,
|
|
||||||
) -> List[Tuple[str, str]]:
|
|
||||||
"""Claim OTK for device for DBs that support RETURNING.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple of key name (algorithm + key ID) and key JSON, if an
|
|
||||||
OTK was found.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# We can use RETURNING to do the fetch and DELETE in once step.
|
|
||||||
sql = """
|
|
||||||
DELETE FROM e2e_one_time_keys_json
|
|
||||||
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
|
||||||
AND key_id IN (
|
|
||||||
SELECT key_id FROM e2e_one_time_keys_json
|
|
||||||
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
|
||||||
LIMIT ?
|
|
||||||
)
|
|
||||||
RETURNING key_id, key_json
|
|
||||||
"""
|
|
||||||
|
|
||||||
txn.execute(
|
|
||||||
sql,
|
|
||||||
(user_id, device_id, algorithm, user_id, device_id, algorithm, count),
|
|
||||||
)
|
|
||||||
otk_rows = list(txn)
|
|
||||||
if not otk_rows:
|
|
||||||
return []
|
|
||||||
|
|
||||||
self._invalidate_cache_and_stream(
|
|
||||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
|
|
||||||
]
|
|
||||||
|
|
||||||
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
missing: List[Tuple[str, str, str, int]] = []
|
missing: List[Tuple[str, str, str, int]] = []
|
||||||
for user_id, device_id, algorithm, count in query_list:
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
if self.database_engine.supports_returning:
|
# If we can use execute_values we can use a single batch query
|
||||||
# If we support RETURNING clause we can use a single query that
|
# in autocommit mode.
|
||||||
# allows us to use autocommit mode.
|
unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
|
||||||
_claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
|
for user_id, device_id, algorithm, count in query_list:
|
||||||
db_autocommit = True
|
unfulfilled_claim_counts[user_id, device_id, algorithm] = count
|
||||||
else:
|
|
||||||
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
|
|
||||||
db_autocommit = False
|
|
||||||
|
|
||||||
claim_rows = await self.db_pool.runInteraction(
|
bulk_claims = await self.db_pool.runInteraction(
|
||||||
"claim_e2e_one_time_keys",
|
"claim_e2e_one_time_keys",
|
||||||
_claim_e2e_one_time_key,
|
self._claim_e2e_one_time_keys_bulk,
|
||||||
user_id,
|
query_list,
|
||||||
device_id,
|
db_autocommit=True,
|
||||||
algorithm,
|
|
||||||
count,
|
|
||||||
db_autocommit=db_autocommit,
|
|
||||||
)
|
)
|
||||||
if claim_rows:
|
|
||||||
|
for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
|
||||||
device_results = results.setdefault(user_id, {}).setdefault(
|
device_results = results.setdefault(user_id, {}).setdefault(
|
||||||
device_id, {}
|
device_id, {}
|
||||||
)
|
)
|
||||||
for claim_row in claim_rows:
|
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
|
||||||
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
|
unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
|
||||||
|
|
||||||
# Did we get enough OTKs?
|
# Did we get enough OTKs?
|
||||||
count -= len(claim_rows)
|
missing = [
|
||||||
if count:
|
(user, device, alg, count)
|
||||||
missing.append((user_id, device_id, algorithm, count))
|
for (user, device, alg), count in unfulfilled_claim_counts.items()
|
||||||
|
if count > 0
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
for user_id, device_id, algorithm, count in query_list:
|
||||||
|
claim_rows = await self.db_pool.runInteraction(
|
||||||
|
"claim_e2e_one_time_keys",
|
||||||
|
self._claim_e2e_one_time_key_simple,
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
algorithm,
|
||||||
|
count,
|
||||||
|
db_autocommit=False,
|
||||||
|
)
|
||||||
|
if claim_rows:
|
||||||
|
device_results = results.setdefault(user_id, {}).setdefault(
|
||||||
|
device_id, {}
|
||||||
|
)
|
||||||
|
for claim_row in claim_rows:
|
||||||
|
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
|
||||||
|
# Did we get enough OTKs?
|
||||||
|
count -= len(claim_rows)
|
||||||
|
if count:
|
||||||
|
missing.append((user_id, device_id, algorithm, count))
|
||||||
|
|
||||||
return results, missing
|
return results, missing
|
||||||
|
|
||||||
|
@ -1362,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@trace
|
||||||
|
def _claim_e2e_one_time_key_simple(
|
||||||
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
algorithm: str,
|
||||||
|
count: int,
|
||||||
|
) -> List[Tuple[str, str]]:
|
||||||
|
"""Claim OTK for device for DBs that don't support RETURNING.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of key name (algorithm + key ID) and key JSON, if an
|
||||||
|
OTK was found.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT key_id, key_json FROM e2e_one_time_keys_json
|
||||||
|
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (user_id, device_id, algorithm, count))
|
||||||
|
otk_rows = list(txn)
|
||||||
|
if not otk_rows:
|
||||||
|
return []
|
||||||
|
|
||||||
|
self.db_pool.simple_delete_many_txn(
|
||||||
|
txn,
|
||||||
|
table="e2e_one_time_keys_json",
|
||||||
|
column="key_id",
|
||||||
|
values=[otk_row[0] for otk_row in otk_rows],
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
"algorithm": algorithm,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
|
||||||
|
|
||||||
|
@trace
|
||||||
|
def _claim_e2e_one_time_keys_bulk(
|
||||||
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
query_list: Iterable[Tuple[str, str, str, int]],
|
||||||
|
) -> List[Tuple[str, str, str, str, str]]:
|
||||||
|
"""Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_list: Collection of tuples (user_id, device_id, algorithm, count)
|
||||||
|
as passed to claim_e2e_one_time_keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
|
||||||
|
for each OTK claimed.
|
||||||
|
"""
|
||||||
|
sql = """
|
||||||
|
WITH claims(user_id, device_id, algorithm, claim_count) AS (
|
||||||
|
VALUES ?
|
||||||
|
), ranked_keys AS (
|
||||||
|
SELECT
|
||||||
|
user_id, device_id, algorithm, key_id, claim_count,
|
||||||
|
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
|
||||||
|
FROM e2e_one_time_keys_json
|
||||||
|
JOIN claims USING (user_id, device_id, algorithm)
|
||||||
|
)
|
||||||
|
DELETE FROM e2e_one_time_keys_json k
|
||||||
|
WHERE (user_id, device_id, algorithm, key_id) IN (
|
||||||
|
SELECT user_id, device_id, algorithm, key_id
|
||||||
|
FROM ranked_keys
|
||||||
|
WHERE r <= claim_count
|
||||||
|
)
|
||||||
|
RETURNING user_id, device_id, algorithm, key_id, key_json;
|
||||||
|
"""
|
||||||
|
otk_rows = cast(
|
||||||
|
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
|
||||||
|
)
|
||||||
|
|
||||||
|
seen_user_device: Set[Tuple[str, str]] = set()
|
||||||
|
for user_id, device_id, _, _, _ in otk_rows:
|
||||||
|
if (user_id, device_id) in seen_user_device:
|
||||||
|
continue
|
||||||
|
seen_user_device.add((user_id, device_id))
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
return otk_rows
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -174,6 +174,164 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_claim_one_time_key_bulk(self) -> None:
|
||||||
|
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
|
||||||
|
# Apologies to the reader. This test is a little too verbose. It is particularly
|
||||||
|
# tricky to make assertions neatly with all these nested dictionaries in play.
|
||||||
|
|
||||||
|
# Three users with two devices each. Each device uses two algorithms.
|
||||||
|
# Each algorithm is invoked with two keys.
|
||||||
|
alice = f"@alice:{self.hs.hostname}"
|
||||||
|
brian = f"@brian:{self.hs.hostname}"
|
||||||
|
chris = f"@chris:{self.hs.hostname}"
|
||||||
|
one_time_keys = {
|
||||||
|
alice: {
|
||||||
|
"alice_dev_1": {
|
||||||
|
"alg1:k1": {"dummy_id": 1},
|
||||||
|
"alg1:k2": {"dummy_id": 2},
|
||||||
|
"alg2:k3": {"dummy_id": 3},
|
||||||
|
"alg2:k4": {"dummy_id": 4},
|
||||||
|
},
|
||||||
|
"alice_dev_2": {
|
||||||
|
"alg1:k5": {"dummy_id": 5},
|
||||||
|
"alg1:k6": {"dummy_id": 6},
|
||||||
|
"alg2:k7": {"dummy_id": 7},
|
||||||
|
"alg2:k8": {"dummy_id": 8},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
brian: {
|
||||||
|
"brian_dev_1": {
|
||||||
|
"alg1:k9": {"dummy_id": 9},
|
||||||
|
"alg1:k10": {"dummy_id": 10},
|
||||||
|
"alg2:k11": {"dummy_id": 11},
|
||||||
|
"alg2:k12": {"dummy_id": 12},
|
||||||
|
},
|
||||||
|
"brian_dev_2": {
|
||||||
|
"alg1:k13": {"dummy_id": 13},
|
||||||
|
"alg1:k14": {"dummy_id": 14},
|
||||||
|
"alg2:k15": {"dummy_id": 15},
|
||||||
|
"alg2:k16": {"dummy_id": 16},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
chris: {
|
||||||
|
"chris_dev_1": {
|
||||||
|
"alg1:k17": {"dummy_id": 17},
|
||||||
|
"alg1:k18": {"dummy_id": 18},
|
||||||
|
"alg2:k19": {"dummy_id": 19},
|
||||||
|
"alg2:k20": {"dummy_id": 20},
|
||||||
|
},
|
||||||
|
"chris_dev_2": {
|
||||||
|
"alg1:k21": {"dummy_id": 21},
|
||||||
|
"alg1:k22": {"dummy_id": 22},
|
||||||
|
"alg2:k23": {"dummy_id": 23},
|
||||||
|
"alg2:k24": {"dummy_id": 24},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for user_id, devices in one_time_keys.items():
|
||||||
|
for device_id, keys_dict in devices.items():
|
||||||
|
counts = self.get_success(
|
||||||
|
self.handler.upload_keys_for_user(
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
{"one_time_keys": keys_dict},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# The upload should report 2 keys per algorithm.
|
||||||
|
expected_counts = {
|
||||||
|
"one_time_key_counts": {
|
||||||
|
# See count_e2e_one_time_keys for why this is hardcoded.
|
||||||
|
"signed_curve25519": 0,
|
||||||
|
"alg1": 2,
|
||||||
|
"alg2": 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.assertEqual(counts, expected_counts)
|
||||||
|
|
||||||
|
# Claim a variety of keys.
|
||||||
|
# Raw format, easier to make test assertions about.
|
||||||
|
claims_to_make = {
|
||||||
|
(alice, "alice_dev_1", "alg1"): 1,
|
||||||
|
(alice, "alice_dev_1", "alg2"): 2,
|
||||||
|
(alice, "alice_dev_2", "alg2"): 1,
|
||||||
|
(brian, "brian_dev_1", "alg1"): 2,
|
||||||
|
(brian, "brian_dev_2", "alg2"): 9001,
|
||||||
|
(chris, "chris_dev_2", "alg2"): 1,
|
||||||
|
}
|
||||||
|
# Convert to the format the handler wants.
|
||||||
|
query: Dict[str, Dict[str, Dict[str, int]]] = {}
|
||||||
|
for (user_id, device_id, algorithm), count in claims_to_make.items():
|
||||||
|
query.setdefault(user_id, {}).setdefault(device_id, {})[algorithm] = count
|
||||||
|
claim_res = self.get_success(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
query,
|
||||||
|
self.requester,
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# No failures, please!
|
||||||
|
self.assertEqual(claim_res["failures"], {})
|
||||||
|
|
||||||
|
# Check that we get exactly the (user, device, algorithm)s we asked for.
|
||||||
|
got_otks = claim_res["one_time_keys"]
|
||||||
|
claimed_user_device_algorithms = {
|
||||||
|
(user_id, device_id, alg_key_id.split(":")[0])
|
||||||
|
for user_id, devices in got_otks.items()
|
||||||
|
for device_id, key_dict in devices.items()
|
||||||
|
for alg_key_id in key_dict
|
||||||
|
}
|
||||||
|
self.assertEqual(claimed_user_device_algorithms, set(claims_to_make))
|
||||||
|
|
||||||
|
# Now check the keys we got are what we expected.
|
||||||
|
def assertExactlyOneOtk(
|
||||||
|
user_id: str, device_id: str, *alg_key_pairs: str
|
||||||
|
) -> None:
|
||||||
|
key_dict = got_otks[user_id][device_id]
|
||||||
|
found = 0
|
||||||
|
for alg_key in alg_key_pairs:
|
||||||
|
if alg_key in key_dict:
|
||||||
|
expected_key_json = one_time_keys[user_id][device_id][alg_key]
|
||||||
|
self.assertEqual(key_dict[alg_key], expected_key_json)
|
||||||
|
found += 1
|
||||||
|
self.assertEqual(found, 1)
|
||||||
|
|
||||||
|
def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None:
|
||||||
|
key_dict = got_otks[user_id][device_id]
|
||||||
|
for alg_key in alg_key_pairs:
|
||||||
|
expected_key_json = one_time_keys[user_id][device_id][alg_key]
|
||||||
|
self.assertEqual(key_dict[alg_key], expected_key_json)
|
||||||
|
|
||||||
|
# Expect a single arbitrary key to be returned.
|
||||||
|
assertExactlyOneOtk(alice, "alice_dev_1", "alg1:k1", "alg1:k2")
|
||||||
|
assertExactlyOneOtk(alice, "alice_dev_2", "alg2:k7", "alg2:k8")
|
||||||
|
assertExactlyOneOtk(chris, "chris_dev_2", "alg2:k23", "alg2:k24")
|
||||||
|
|
||||||
|
assertAllOtks(alice, "alice_dev_1", "alg2:k3", "alg2:k4")
|
||||||
|
assertAllOtks(brian, "brian_dev_1", "alg1:k9", "alg1:k10")
|
||||||
|
assertAllOtks(brian, "brian_dev_2", "alg2:k15", "alg2:k16")
|
||||||
|
|
||||||
|
# Now check the unused key counts.
|
||||||
|
for user_id, devices in one_time_keys.items():
|
||||||
|
for device_id in devices:
|
||||||
|
counts_by_alg = self.get_success(
|
||||||
|
self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||||
|
)
|
||||||
|
# Somewhat fiddley to compute the expected count dict.
|
||||||
|
expected_counts_by_alg = {
|
||||||
|
"signed_curve25519": 0,
|
||||||
|
}
|
||||||
|
for alg in ["alg1", "alg2"]:
|
||||||
|
claim_count = claims_to_make.get((user_id, device_id, alg), 0)
|
||||||
|
remaining_count = max(0, 2 - claim_count)
|
||||||
|
if remaining_count > 0:
|
||||||
|
expected_counts_by_alg[alg] = remaining_count
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
|
||||||
|
)
|
||||||
|
|
||||||
def test_fallback_key(self) -> None:
|
def test_fallback_key(self) -> None:
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
|
|
Loading…
Reference in a new issue