# Copyright (c) 2018, deepakn94, codyaustun, robieta. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import torch
import tqdm


class _TestNegSampler:
    def __init__(self, train_ratings, nb_neg):
        self.nb_neg = nb_neg
        self.nb_users = int(train_ratings[:, 0].max()) + 1
        self.nb_items = int(train_ratings[:, 1].max()) + 1

        # encode each (user, item) pair as a unique id so a hash set can be
        # built for fast membership lookups
        ids = (train_ratings[:, 0] * self.nb_items) + train_ratings[:, 1]
        self.set = set(ids)
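        # e.g. with nb_items = 100 (a made-up value), the pair (user=3,
        # item=42) encodes to 3 * 100 + 42 = 342, so testing whether a pair
        # occurs in the training set is a single O(1) set lookup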

    def generate(self, batch_size=128 * 1024):
        # lay out each user id nb_neg times: [0, 0, ..., 1, 1, ..., 2, ...]
        users = torch.arange(0, self.nb_users).reshape([1, -1]).repeat([self.nb_neg, 1]).transpose(0, 1).reshape(-1)

        items = [-1] * len(users)

        # draw random candidates in large batches and rejection-sample:
        # any candidate that collides with a training interaction is redrawn
        random_items = torch.LongTensor(batch_size).random_(0, self.nb_items).tolist()
        print('Generating validation negatives...')
        for idx, u in enumerate(tqdm.tqdm(users.tolist())):
            if not random_items:
                random_items = torch.LongTensor(batch_size).random_(0, self.nb_items).tolist()
            j = random_items.pop()
            while u * self.nb_items + j in self.set:
                if not random_items:
                    random_items = torch.LongTensor(batch_size).random_(0, self.nb_items).tolist()
                j = random_items.pop()
            items[idx] = j

        items = torch.LongTensor(items)
        return items
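
# A minimal usage sketch (illustrative only; the ratings below are made up,
# each row being one (user, item) training interaction):
#
#     train_ratings = torch.LongTensor([[0, 1], [0, 2], [1, 0], [2, 2]])
#     sampler = _TestNegSampler(train_ratings.numpy(), nb_neg=4)
#     negatives = sampler.generate()  # flat tensor of nb_users * nb_neg items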


def create_test_data(train_ratings, test_ratings, args):
    test_users = test_ratings[:, 0]
    test_pos = test_ratings[:, 1].reshape(-1, 1)

    begin = time.time()
    sampler = _TestNegSampler(train_ratings.cpu().numpy(), args.valid_negative)
    test_negs = sampler.generate().cuda()
    end = time.time()
    print('Generating validation negatives took: ', end - begin)
    del train_ratings

    # lay out the candidate items so that the real (positive) sample sits in
    # the last position of each row
    test_users = test_users.reshape(-1, 1).repeat(1, 1 + args.valid_negative)
    test_items = torch.cat((test_negs.reshape(-1, args.valid_negative), test_pos), dim=1)
    del test_ratings, test_negs

    # generate the duplicate mask and the real-sample indices so that
    # duplicates are handled exactly as in the reference implementation;
    # this needs a stable sort (one that keeps the order of duplicates)
    sorted_items, indices = torch.sort(test_items)                               # [1, 1, 1, 2], [3, 1, 0, 2]
    sum_item_indices = sorted_items.float() + indices.float() / len(indices[0])  # [1.75, 1.25, 1.0, 2.5]
    indices_order = torch.sort(sum_item_indices)[1]                              # [2, 1, 0, 3]
    stable_indices = torch.gather(indices, 1, indices_order)                     # [0, 1, 3, 2]
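    # how the stable-sort emulation works (a sketch with assumed values, the
    # same ones as in the comments above): torch.sort() makes no stability
    # guarantee, so each element's original index, scaled into [0, 1), is
    # added as a tiebreaker before sorting a second time:
    #
    #     test_items = torch.tensor([[1, 1, 2, 1]])
    #     sorted_items, indices = torch.sort(test_items)  # [[1, 1, 1, 2]], e.g. [[3, 1, 0, 2]]
    #     keys = sorted_items.float() + indices.float() / 4  # [[1.75, 1.25, 1.0, 2.5]]
    #     order = torch.sort(keys)[1]                        # [[2, 1, 0, 3]]
    #     torch.gather(indices, 1, order)                    # [[0, 1, 3, 2]]: ties keep original order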
    # produce the -1 mask, flagging repeated items within a row
    dup_mask = (sorted_items[:, 0:-1] == sorted_items[:, 1:])
    dup_mask = torch.cat((torch.zeros_like(test_pos, dtype=torch.uint8), dup_mask), dim=1)
    dup_mask = torch.gather(dup_mask, 1, stable_indices.sort()[1])
    # produce the real-sample indices, checked later against the top-k results
    sorted_items, indices = (test_items != test_pos).sort()
    sum_item_indices = sorted_items.float() + indices.float() / len(indices[0])
    indices_order = torch.sort(sum_item_indices)[1]
    stable_indices = torch.gather(indices, 1, indices_order)
    real_indices = stable_indices[:, 0]

    if args.distributed:
        # each rank keeps only its own shard of the validation data
        test_users = torch.chunk(test_users, args.world_size)[args.local_rank]
        test_items = torch.chunk(test_items, args.world_size)[args.local_rank]
        dup_mask = torch.chunk(dup_mask, args.world_size)[args.local_rank]
        real_indices = torch.chunk(real_indices, args.world_size)[args.local_rank]

    test_users = test_users.view(-1).split(args.valid_batch_size)
    test_items = test_items.view(-1).split(args.valid_batch_size)

    return test_users, test_items, dup_mask, real_indices
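
# A hedged usage sketch (illustrative only; `args` is assumed to carry the
# fields read above: valid_negative, valid_batch_size, distributed,
# world_size and local_rank):
#
#     test_users, test_items, dup_mask, real_indices = create_test_data(
#         train_ratings, test_ratings, args)
#     # test_users/test_items are tuples of flat tensors, one per validation
#     # batch; dup_mask flags duplicated negatives and real_indices locates
#     # the positive sample within each candidate row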


def prepare_epoch_train_data(train_ratings, nb_items, args):
    # create labels: ones for the true interactions, zeros for the negatives
    train_label = torch.ones_like(train_ratings[:, 0], dtype=torch.float32)
    neg_label = torch.zeros_like(train_label, dtype=torch.float32)
    neg_label = neg_label.repeat(args.negative_samples)
    train_label = torch.cat((train_label, neg_label))
    del neg_label

    train_users = train_ratings[:, 0]
    train_items = train_ratings[:, 1]

    # contiguous shard boundaries for this worker (counted in samples)
    train_users_per_worker = len(train_label) / args.world_size
    train_users_begin = int(train_users_per_worker * args.local_rank)
    train_users_end = int(train_users_per_worker * (args.local_rank + 1))

    # prepare data for the epoch: draw negative items uniformly at random,
    # args.negative_samples per positive interaction
    neg_users = train_users.repeat(args.negative_samples)
    neg_items = torch.empty_like(neg_users, dtype=torch.int64).random_(0, nb_items)

    epoch_users = torch.cat((train_users, neg_users))
    epoch_items = torch.cat((train_items, neg_items))

    del neg_users, neg_items

    # shuffle the prepared data and split it into batches; each worker
    # shuffles only its own contiguous shard of the epoch
    epoch_indices = torch.randperm(train_users_end - train_users_begin, device='cuda:{}'.format(args.local_rank))
    epoch_indices += train_users_begin

    epoch_users = epoch_users[epoch_indices]
    epoch_items = epoch_items[epoch_indices]
    epoch_label = train_label[epoch_indices]

    if args.distributed:
        local_batch = args.batch_size // args.world_size
    else:
        local_batch = args.batch_size

    epoch_users = epoch_users.split(local_batch)
    epoch_items = epoch_items.split(local_batch)
    epoch_label = epoch_label.split(local_batch)

    # the last batch will almost certainly be smaller, drop it
    epoch_users = epoch_users[:-1]
    epoch_items = epoch_items[:-1]
    epoch_label = epoch_label[:-1]

    return epoch_users, epoch_items, epoch_label
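
# A hedged usage sketch (illustrative only; train_step and args.epochs are
# hypothetical names, and `args` is assumed to provide negative_samples,
# batch_size, distributed, world_size and local_rank):
#
#     for epoch in range(args.epochs):
#         epoch_users, epoch_items, epoch_label = prepare_epoch_train_data(
#             train_ratings, nb_items, args)
#         for users, items, label in zip(epoch_users, epoch_items, epoch_label):
#             loss = train_step(model, users, items, label)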