0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-19 08:24:25 +01:00

Merge pull request #2307 from matrix-org/erikj/user_ip_batch

Batch upsert user ips
This commit is contained in:
Erik Johnston 2017-06-27 15:08:32 +01:00 committed by GitHub
commit 816605a137
5 changed files with 101 additions and 48 deletions

View file

@ -23,7 +23,6 @@ from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes
from synapse.types import UserID from synapse.types import UserID
from synapse.util import logcontext
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -200,7 +199,7 @@ class Auth(object):
default=[""] default=[""]
)[0] )[0]
if user and access_token and ip_addr: if user and access_token and ip_addr:
logcontext.preserve_fn(self.store.insert_client_ip)( self.store.insert_client_ip(
user=user, user=user,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,

View file

@ -106,7 +106,7 @@ class DeviceHandler(BaseHandler):
device_map = yield self.store.get_devices_by_user(user_id) device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device( ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id) for device_id in device_map.keys()) user_id, device_id=None
) )
devices = device_map.values() devices = device_map.values()
@ -133,7 +133,7 @@ class DeviceHandler(BaseHandler):
except errors.StoreError: except errors.StoreError:
raise errors.NotFoundError raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device( ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id),) user_id, device_id,
) )
_update_device_from_client_ips(device, ips) _update_device_from_client_ips(device, ips)
defer.returnValue(device) defer.returnValue(device)

View file

@ -304,16 +304,6 @@ class DataStore(RoomMemberStore, RoomStore,
ret = yield self.runInteraction("count_users", _count_users) ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret) defer.returnValue(ret)
def get_user_ip_and_agents(self, user):
return self._simple_select_list(
table="user_ips",
keyvalues={"user_id": user.to_string()},
retcols=[
"access_token", "ip", "user_agent", "last_seen"
],
desc="get_user_ip_and_agents",
)
def get_users(self): def get_users(self):
"""Function to reterive a list of users in users table. """Function to reterive a list of users in users table.

View file

@ -15,7 +15,7 @@
import logging import logging
from twisted.internet import defer from twisted.internet import defer, reactor
from ._base import Cache from ._base import Cache
from . import background_updates from . import background_updates
@ -50,7 +50,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id", "last_seen"], columns=["user_id", "device_id", "last_seen"],
) )
@defer.inlineCallbacks # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
def insert_client_ip(self, user, access_token, ip, user_agent, device_id): def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())
key = (user.to_string(), access_token, ip) key = (user.to_string(), access_token, ip)
@ -62,34 +69,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# Rate-limited inserts # Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None) return
self.client_ip_last_seen.prefill(key, now) self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint, self._batch_row_update[key] = (user_agent, device_id, now)
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
yield self._simple_upsert( def _update_client_ips_batch(self):
"user_ips", to_update = self._batch_row_update
self._batch_row_update = {}
return self.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
def _update_client_ips_batch_txn(self, txn, to_update):
self.database_engine.lock_table(txn, "user_ips")
for entry in to_update.iteritems():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
self._simple_upsert_txn(
txn,
table="user_ips",
keyvalues={ keyvalues={
"user_id": user.to_string(), "user_id": user_id,
"access_token": access_token, "access_token": access_token,
"ip": ip, "ip": ip,
"user_agent": user_agent, "user_agent": user_agent,
"device_id": device_id, "device_id": device_id,
}, },
values={ values={
"last_seen": now, "last_seen": last_seen,
}, },
desc="insert_client_ip",
lock=False, lock=False,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_last_client_ip_by_device(self, devices): def get_last_client_ip_by_device(self, user_id, device_id):
"""For each device_id listed, give the user_ip it was last seen on """For each device_id listed, give the user_ip it was last seen on
Args: Args:
devices (iterable[(str, str)]): list of (user_id, device_id) pairs user_id (str)
device_id (str): If None fetches all devices for the user
Returns: Returns:
defer.Deferred: resolves to a dict, where the keys defer.Deferred: resolves to a dict, where the keys
@ -100,6 +121,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
res = yield self.runInteraction( res = yield self.runInteraction(
"get_last_client_ip_by_device", "get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn, self._get_last_client_ip_by_device_txn,
user_id, device_id,
retcols=( retcols=(
"user_id", "user_id",
"access_token", "access_token",
@ -108,19 +130,30 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id", "device_id",
"last_seen", "last_seen",
), ),
devices=devices
) )
ret = {(d["user_id"], d["device_id"]): d for d in res} ret = {(d["user_id"], d["device_id"]): d for d in res}
for key in self._batch_row_update:
uid, access_token, ip = key
if uid == user_id:
user_agent, did, last_seen = self._batch_row_update[key]
if not device_id or did == device_id:
ret[(user_id, device_id)] = {
"user_id": user_id,
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"device_id": did,
"last_seen": last_seen,
}
defer.returnValue(ret) defer.returnValue(ret)
@classmethod @classmethod
def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols): def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
where_clauses = [] where_clauses = []
bindings = [] bindings = []
for (user_id, device_id) in devices:
if device_id is None: if device_id is None:
where_clauses.append("(user_id = ? AND device_id IS NULL)") where_clauses.append("user_id = ?")
bindings.extend((user_id, )) bindings.extend((user_id, ))
else: else:
where_clauses.append("(user_id = ? AND device_id = ?)") where_clauses.append("(user_id = ? AND device_id = ?)")
@ -152,3 +185,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
txn.execute(sql, bindings) txn.execute(sql, bindings)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def get_user_ip_and_agents(self, user):
user_id = user.to_string()
results = {}
for key in self._batch_row_update:
uid, access_token, ip = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
rows = yield self._simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=[
"access_token", "ip", "user_agent", "last_seen"
],
desc="get_user_ip_and_agents",
)
results.update(
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
for row in rows
)
defer.returnValue(list(
{
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"last_seen": last_seen,
}
for (access_token, ip), (user_agent, last_seen) in results.iteritems()
))

View file

@ -43,10 +43,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
"access_token", "ip", "user_agent", "device_id", "access_token", "ip", "user_agent", "device_id",
) )
# deliberately use an iterable here to make sure that the lookup result = yield self.store.get_last_client_ip_by_device(user_id, "device_id")
# method doesn't iterate it twice
device_list = iter(((user_id, "device_id"),))
result = yield self.store.get_last_client_ip_by_device(device_list)
r = result[(user_id, "device_id")] r = result[(user_id, "device_id")]
self.assertDictContainsSubset( self.assertDictContainsSubset(