Refactor user_delete_access_tokens. Invalidate get_user_by_access_token to slaves.

This commit is contained in:
Erik Johnston 2016-08-15 17:04:39 +01:00
parent 75299af4fc
commit dc3a00f24f
3 changed files with 43 additions and 49 deletions

View file

@ -741,7 +741,7 @@ class AuthHandler(BaseHandler):
def set_password(self, user_id, newpassword, requester=None): def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else [] except_access_token_id = requester.access_token_id if requester else None
try: try:
yield self.store.user_set_password_hash(user_id, password_hash) yield self.store.user_set_password_hash(user_id, password_hash)
@ -750,10 +750,10 @@ class AuthHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e raise e
yield self.store.user_delete_access_tokens( yield self.store.user_delete_access_tokens(
user_id, except_access_token_ids user_id, except_access_token_id
) )
yield self.hs.get_pusherpool().remove_pushers_by_user( yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_ids user_id, except_access_token_id
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -102,14 +102,14 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user(self, user_id, except_token_ids=[]): def remove_pushers_by_user(self, user_id, except_access_token_id=None):
all = yield self.store.get_all_pushers() all = yield self.store.get_all_pushers()
logger.info( logger.info(
"Removing all pushers for user %s except access tokens ids %r", "Removing all pushers for user %s except access tokens id %r",
user_id, except_token_ids user_id, except_access_token_id
) )
for p in all: for p in all:
if p['user_name'] == user_id and p['access_token'] not in except_token_ids: if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']

View file

@ -251,7 +251,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
self.get_user_by_id.invalidate((user_id,)) self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[], def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None, device_id=None,
delete_refresh_tokens=False): delete_refresh_tokens=False):
""" """
@ -259,7 +259,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Args: Args:
user_id (str): ID of user the tokens belong to user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should except_token_id (str): list of access_tokens IDs which should
*not* be deleted *not* be deleted
device_id (str|None): ID of device the tokens are associated with. device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will If None, tokens associated with any device (or no device) will
@ -269,53 +269,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Returns: Returns:
defer.Deferred: defer.Deferred:
""" """
def f(txn, table, except_tokens, call_after_delete): def f(txn):
sql = "SELECT token FROM %s WHERE user_id = ?" % table keyvalues = {
clauses = [user_id] "user_id": user_id,
}
if device_id is not None: if device_id is not None:
sql += " AND device_id = ?" keyvalues["device_id"] = device_id
clauses.append(device_id)
if except_tokens: if delete_refresh_tokens:
sql += " AND id NOT IN (%s)" % ( self._simple_delete_txn(
",".join(["?" for _ in except_tokens]), txn,
) table="refresh_tokens",
clauses += except_tokens keyvalues=keyvalues,
txn.execute(sql, clauses)
rows = txn.fetchall()
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
if call_after_delete:
for row in chunk:
txn.call_after(call_after_delete, (row[0],))
txn.execute(
"DELETE FROM %s WHERE token in (%s)" % (
table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
) )
# delete refresh tokens first, to stop new access tokens being items = keyvalues.items()
# allocated while our backs are turned where_clause = " AND ".join(k + " = ?" for k, _ in items)
if delete_refresh_tokens: values = [v for _, v in items]
yield self.runInteraction( if except_token_id:
"user_delete_access_tokens", f, where_clause += " AND id != ?"
table="refresh_tokens", values.append(except_token_id)
except_tokens=[],
call_after_delete=None, txn.execute(
"SELECT token FROM access_tokens WHERE %s" % where_clause,
values
)
rows = self.cursor_to_dict(txn)
for row in rows:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (row["token"],)
)
txn.execute(
"DELETE FROM access_tokens WHERE %s" % where_clause,
values
) )
yield self.runInteraction( yield self.runInteraction(
"user_delete_access_tokens", f, "user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
) )
def delete_access_token(self, access_token): def delete_access_token(self, access_token):
@ -328,7 +320,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
}, },
) )
txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (access_token,)
)
return self.runInteraction("delete_access_token", f) return self.runInteraction("delete_access_token", f)