0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-18 10:38:21 +02:00

Refactor user_delete_access_tokens. Invalidate get_user_by_access_token to slaves.

This commit is contained in:
Erik Johnston 2016-08-15 17:04:39 +01:00
parent 75299af4fc
commit dc3a00f24f
3 changed files with 43 additions and 49 deletions

View file

@ -741,7 +741,7 @@ class AuthHandler(BaseHandler):
def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else []
except_access_token_id = requester.access_token_id if requester else None
try:
yield self.store.user_set_password_hash(user_id, password_hash)
@ -750,10 +750,10 @@ class AuthHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.store.user_delete_access_tokens(
user_id, except_access_token_ids
user_id, except_access_token_id
)
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_ids
user_id, except_access_token_id
)
@defer.inlineCallbacks

View file

@ -102,14 +102,14 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user(self, user_id, except_token_ids=[]):
def remove_pushers_by_user(self, user_id, except_access_token_id=None):
all = yield self.store.get_all_pushers()
logger.info(
"Removing all pushers for user %s except access tokens ids %r",
user_id, except_token_ids
"Removing all pushers for user %s except access tokens id %r",
user_id, except_access_token_id
)
for p in all:
if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']

View file

@ -251,7 +251,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[],
def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None,
delete_refresh_tokens=False):
"""
@ -259,7 +259,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Args:
user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should
except_token_id (str): list of access_tokens IDs which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
@ -269,53 +269,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Returns:
defer.Deferred:
"""
def f(txn, table, except_tokens, call_after_delete):
sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id]
def f(txn):
keyvalues = {
"user_id": user_id,
}
if device_id is not None:
sql += " AND device_id = ?"
clauses.append(device_id)
keyvalues["device_id"] = device_id
if except_tokens:
sql += " AND id NOT IN (%s)" % (
",".join(["?" for _ in except_tokens]),
)
clauses += except_tokens
txn.execute(sql, clauses)
rows = txn.fetchall()
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
if call_after_delete:
for row in chunk:
txn.call_after(call_after_delete, (row[0],))
txn.execute(
"DELETE FROM %s WHERE token in (%s)" % (
table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
if delete_refresh_tokens:
self._simple_delete_txn(
txn,
table="refresh_tokens",
keyvalues=keyvalues,
)
# delete refresh tokens first, to stop new access tokens being
# allocated while our backs are turned
if delete_refresh_tokens:
yield self.runInteraction(
"user_delete_access_tokens", f,
table="refresh_tokens",
except_tokens=[],
call_after_delete=None,
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values = [v for _, v in items]
if except_token_id:
where_clause += " AND id != ?"
values.append(except_token_id)
txn.execute(
"SELECT token FROM access_tokens WHERE %s" % where_clause,
values
)
rows = self.cursor_to_dict(txn)
for row in rows:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (row["token"],)
)
txn.execute(
"DELETE FROM access_tokens WHERE %s" % where_clause,
values
)
yield self.runInteraction(
"user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
)
def delete_access_token(self, access_token):
@ -328,7 +320,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
},
)
txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (access_token,)
)
return self.runInteraction("delete_access_token", f)