DeepLearningExamples/PyTorch/Recommendation/NCF/ncf.py

# Copyright (c) 2018, deepakn94, codyaustun, robieta. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import math
import time
from argparse import ArgumentParser

import torch
import torch.jit
import torch.nn as nn

import utils
import dataloading
from neumf import NeuMF

from logger.logger import LOGGER, timed_block, timed_function
from logger import tags
from logger.autologging import log_hardware, log_args

from apex.optimizers import FusedAdam
from apex.parallel import DistributedDataParallel as DDP
from apex import amp

LOGGER.model = 'ncf'


def parse_args():
    parser = ArgumentParser(description="Train a Neural Collaborative"
                                        " Filtering model")
    parser.add_argument('--data', type=str,
                        help='Path to test and training data files')
    parser.add_argument('-e', '--epochs', type=int, default=30,
                        help='Number of epochs for training')
    parser.add_argument('-b', '--batch_size', type=int, default=2**20,
                        help='Number of examples for each iteration')
    parser.add_argument('--valid_batch_size', type=int, default=2**20,
                        help='Number of examples in each validation chunk')
    parser.add_argument('-f', '--factors', type=int, default=64,
                        help='Number of predictive factors')
    parser.add_argument('--layers', nargs='+', type=int,
                        default=[256, 256, 128, 64],
                        help='Sizes of hidden layers for MLP')
    parser.add_argument('-n', '--negative_samples', type=int, default=4,
                        help='Number of negative examples per interaction')
    parser.add_argument('-l', '--learning_rate', type=float, default=0.0045,
                        help='Learning rate for optimizer')
    parser.add_argument('-k', '--topk', type=int, default=10,
                        help='Rank for test examples to be considered a hit')
    parser.add_argument('--seed', '-s', type=int, default=1,
                        help='Manually set random seed for torch')
    parser.add_argument('--threshold', '-t', type=float, default=1.0,
                        help='Stop training early at threshold')
    parser.add_argument('--valid_negative', type=int, default=100,
                        help='Number of negative samples for each positive test example')
    parser.add_argument('--beta1', '-b1', type=float, default=0.25,
                        help='Beta1 for Adam')
    parser.add_argument('--beta2', '-b2', type=float, default=0.5,
                        help='Beta2 for Adam')
    parser.add_argument('--eps', type=float, default=1e-8,
                        help='Epsilon for Adam')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='Dropout probability; if equal to 0, dropout is not used at all')
    parser.add_argument('--checkpoint_dir', default='/data/checkpoints/', type=str,
                        help='Path to the directory storing the checkpoint file')
    parser.add_argument('--mode', choices=['train', 'test'], default='train', type=str,
                        help='Passing "test" will only run a single evaluation; otherwise full training will be performed')
    parser.add_argument('--grads_accumulated', default=1, type=int,
                        help='Number of gradients to accumulate before performing an optimization step')
    parser.add_argument('--opt_level', default='O2', type=str,
                        help='Optimization level for Automatic Mixed Precision',
                        choices=['O0', 'O2'])
    parser.add_argument('--local_rank', default=0, type=int,
                        help='Necessary for multi-GPU training')
    return parser.parse_args()
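

# init_distributed() below reads WORLD_SIZE from the environment, so the script
# expects to be started either with `python -m torch.distributed.launch` (which
# sets WORLD_SIZE, RANK, MASTER_ADDR and MASTER_PORT and passes --local_rank to
# each worker) or with WORLD_SIZE=1 exported for single-process runs;
# init_method='env://' picks these variables up when the NCCL process group is
# created.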
def init_distributed(local_rank=0):
    distributed = int(os.environ['WORLD_SIZE']) > 1

    if distributed:
        '''
        Set cuda device so everything is done on the right GPU.
        THIS MUST BE DONE AS SOON AS POSSIBLE.
        '''
        torch.cuda.set_device(local_rank)

        '''Initialize distributed communication'''
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logging_logger = logging.getLogger('mlperf_compliance')
    if local_rank > 0:
        sys.stdout = open('/dev/null', 'w')
        sys.stderr = open('/dev/null', 'w')
        logging_logger.setLevel(logging.ERROR)

    logging_nvlogger = logging.getLogger('nv_dl_logger')
    if local_rank > 0:
        sys.stdout = open('/dev/null', 'w')
        sys.stderr = open('/dev/null', 'w')
        logging_nvlogger.setLevel(logging.ERROR)

    return distributed, int(os.environ['WORLD_SIZE'])
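

# Evaluation scheme: every test user comes with one positive item and
# `valid_negative` sampled negatives (samples_per_user = valid_negative + 1).
# The model scores all of them, duplicates are masked out with dup_mask, and
# HR@K counts the users whose positive item (its position is given by
# real_indices) lands in the top-K, while NDCG@K credits such a hit with
# log(2) / log(position + 2), where position is the 0-based rank in the top-K.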
def val_epoch(model, x, y, dup_mask, real_indices, K, samples_per_user, num_user,
              epoch=None, distributed=False):
    model.eval()

    with torch.no_grad():
        p = []
        for u, n in zip(x, y):
            p.append(model(u, n, sigmoid=True).detach())

        temp = torch.cat(p).view(-1, samples_per_user)
        del x, y, p

        # set duplicate results for the same item to -1 before topk
        temp[dup_mask] = -1
        out = torch.topk(temp, K)[1]
        # topk in PyTorch is stable (if not sorted):
        # key (item) : value (prediction) pairs keep the original item order,
        # so we use the first position of the real item (stored in real_indices)
        # to check whether it is in the top-K
        ifzero = (out == real_indices.view(-1, 1))
        hits = ifzero.sum()
        ndcg = (math.log(2) / (torch.nonzero(ifzero)[:, 1].view(-1).to(torch.float) + 2).log_()).sum()

    LOGGER.log(key=tags.EVAL_SIZE, value={"epoch": epoch, "value": num_user * samples_per_user})
    LOGGER.log(key=tags.EVAL_HP_NUM_USERS, value=num_user)
    LOGGER.log(key=tags.EVAL_HP_NUM_NEG, value=samples_per_user - 1)

    if distributed:
        torch.distributed.all_reduce(hits, op=torch.distributed.reduce_op.SUM)
        torch.distributed.all_reduce(ndcg, op=torch.distributed.reduce_op.SUM)

    hr = hits.item() / num_user
    ndcg = ndcg.item() / num_user

    model.train()
    return hr, ndcg
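

# main() drives the full run: it loads the preprocessed rating tensors, builds
# the NeuMF model, optionally enables AMP (O2) and DistributedDataParallel,
# trains for --epochs epochs with freshly sampled negatives each epoch, and
# evaluates HR/NDCG after every epoch, checkpointing the best model so far.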
def main():
    log_hardware()
    args = parse_args()
    args.distributed, args.world_size = init_distributed(args.local_rank)
    log_args(args)

    main_start_time = time.time()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    print("Saving results to {}".format(args.checkpoint_dir))
    if not os.path.exists(args.checkpoint_dir) and args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.checkpoint_dir, 'model.pth')

    LOGGER.log(key=tags.PREPROC_HP_NUM_EVAL, value=args.valid_negative)
    # np.random.choice defaults to replace=True, and so does PyTorch's random_()
    LOGGER.log(key=tags.PREPROC_HP_SAMPLE_EVAL_REPLACEMENT, value=True)
    LOGGER.log(key=tags.INPUT_HP_SAMPLE_TRAIN_REPLACEMENT, value=True)
    LOGGER.log(key=tags.INPUT_STEP_EVAL_NEG_GEN)

    # sync workers before timing
    if args.distributed:
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
    torch.cuda.synchronize()

    LOGGER.log(key=tags.RUN_START)
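
    # train_ratings.pt / test_ratings.pt are produced by the data preparation
    # step: two-column tensors of (user_id, item_id) positive interactions,
    # loaded directly onto this rank's GPU.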
    train_ratings = torch.load(args.data + '/train_ratings.pt',
                               map_location=torch.device('cuda:{}'.format(args.local_rank)))
    test_ratings = torch.load(args.data + '/test_ratings.pt',
                              map_location=torch.device('cuda:{}'.format(args.local_rank)))

    nb_maxs = torch.max(train_ratings, 0)[0]
    nb_users = nb_maxs[0].item() + 1
    nb_items = nb_maxs[1].item() + 1
    LOGGER.log(key=tags.INPUT_SIZE, value=len(train_ratings))

    all_test_users = test_ratings.shape[0]

    test_users, test_items, dup_mask, real_indices = dataloading.create_test_data(train_ratings, test_ratings, args)

    # make pytorch memory behavior more consistent later
    torch.cuda.empty_cache()

    LOGGER.log(key=tags.INPUT_BATCH_SIZE, value=args.batch_size)
    LOGGER.log(key=tags.INPUT_ORDER)  # we shuffle later with randperm

    # Create model
    model = NeuMF(nb_users, nb_items,
                  mf_dim=args.factors,
                  mlp_layer_sizes=args.layers,
                  dropout=args.dropout)

    optimizer = FusedAdam(model.parameters(), lr=args.learning_rate,
                          betas=(args.beta1, args.beta2), eps=args.eps, eps_inside_sqrt=False)

    criterion = nn.BCEWithLogitsLoss(reduction='none')  # use torch.mean() with dim later to avoid copy to host

    # Move model and loss to GPU
    model = model.cuda()
    criterion = criterion.cuda()
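
    # With --opt_level O2, apex amp runs the forward/backward pass in FP16
    # while keeping FP32 master weights and using dynamic loss scaling. In
    # distributed mode each of the world_size workers processes
    # batch_size // world_size examples per step, and the BCE criterion is
    # traced once at that local batch shape so TorchScript can optimize the
    # elementwise loss computation.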
    if args.opt_level == "O2":
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level,
                                          keep_batchnorm_fp32=False, loss_scale='dynamic')

    if args.distributed:
        model = DDP(model)

    local_batch = args.batch_size // args.world_size
    traced_criterion = torch.jit.trace(criterion.forward,
                                       (torch.rand(local_batch, 1), torch.rand(local_batch, 1)))

    print(model)
    print("{} parameters".format(utils.count_parameters(model)))

    LOGGER.log(key=tags.OPT_LR, value=args.learning_rate)
    LOGGER.log(key=tags.OPT_NAME, value="Adam")
    LOGGER.log(key=tags.OPT_HP_ADAM_BETA1, value=args.beta1)
    LOGGER.log(key=tags.OPT_HP_ADAM_BETA2, value=args.beta2)
    LOGGER.log(key=tags.OPT_HP_ADAM_EPSILON, value=args.eps)
    LOGGER.log(key=tags.MODEL_HP_LOSS_FN, value=tags.VALUE_BCE)

    if args.mode == 'test':
        state_dict = torch.load(checkpoint_path)
        model.load_state_dict(state_dict)
        hr, ndcg = val_epoch(model, test_users, test_items, dup_mask, real_indices, args.topk,
                             samples_per_user=args.valid_negative + 1,
                             num_user=all_test_users, distributed=args.distributed)
        print('HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f}'
              .format(K=args.topk, hit_rate=hr, ndcg=ndcg))
        return

    success = False
    max_hr = 0
    train_throughputs, eval_throughputs = [], []

    LOGGER.log(key=tags.TRAIN_LOOP)
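
    # Each epoch: resample `negative_samples` negatives per positive
    # interaction, run the training batches (accumulating gradients over
    # `grads_accumulated` batches per optimizer step), then evaluate HR/NDCG
    # on the held-out data and checkpoint whenever HR improves.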
    for epoch in range(args.epochs):

        LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)
        LOGGER.log(key=tags.INPUT_HP_NUM_NEG, value=args.negative_samples)
        LOGGER.log(key=tags.INPUT_STEP_TRAIN_NEG_GEN)

        begin = time.time()
        epoch_users, epoch_items, epoch_label = dataloading.prepare_epoch_train_data(train_ratings, nb_items, args)
        num_batches = len(epoch_users)
        for i in range(num_batches // args.grads_accumulated):
            for j in range(args.grads_accumulated):
                batch_idx = (args.grads_accumulated * i) + j
                user = epoch_users[batch_idx]
                item = epoch_items[batch_idx]
                label = epoch_label[batch_idx].view(-1, 1)

                outputs = model(user, item)
                loss = traced_criterion(outputs, label).float()
                loss = torch.mean(loss.view(-1), 0)

                if args.opt_level == "O2":
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
            optimizer.step()

            for p in model.parameters():
                p.grad = None

        del epoch_users, epoch_items, epoch_label
        train_time = time.time() - begin
        begin = time.time()

        epoch_samples = len(train_ratings) * (args.negative_samples + 1)
        train_throughput = epoch_samples / train_time
        train_throughputs.append(train_throughput)
        LOGGER.log(key='train_throughput', value=train_throughput)
        LOGGER.log(key=tags.TRAIN_EPOCH_STOP, value=epoch)
        LOGGER.log(key=tags.EVAL_START, value=epoch)

        hr, ndcg = val_epoch(model, test_users, test_items, dup_mask, real_indices, args.topk,
                             samples_per_user=args.valid_negative + 1,
                             num_user=all_test_users, epoch=epoch, distributed=args.distributed)

        val_time = time.time() - begin
        print('Epoch {epoch}: HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f},'
              ' train_time = {train_time:.2f}, val_time = {val_time:.2f}'
              .format(epoch=epoch, K=args.topk, hit_rate=hr,
                      ndcg=ndcg, train_time=train_time,
                      val_time=val_time))

        LOGGER.log(key=tags.EVAL_ACCURACY, value={"epoch": epoch, "value": hr})
        LOGGER.log(key=tags.EVAL_TARGET, value=args.threshold)
        LOGGER.log(key=tags.EVAL_STOP, value=epoch)

        eval_size = all_test_users * (args.valid_negative + 1)
        eval_throughput = eval_size / val_time
        eval_throughputs.append(eval_throughput)
        LOGGER.log(key='eval_throughput', value=eval_throughput)

        if hr > max_hr and args.local_rank == 0:
            max_hr = hr
            print("New best hr! Saving the model to: ", checkpoint_path)
            torch.save(model.state_dict(), checkpoint_path)

        if args.threshold is not None:
            if hr >= args.threshold:
                print("Hit threshold of {}".format(args.threshold))
                success = True
                break

    LOGGER.log(key='best_train_throughput', value=max(train_throughputs))
    LOGGER.log(key='best_eval_throughput', value=max(eval_throughputs))
    LOGGER.log(key='best_accuracy', value=max_hr)
    LOGGER.log(key='time_to_target', value=time.time() - main_start_time)
    LOGGER.log(key=tags.RUN_STOP, value={"success": success})
    LOGGER.log(key=tags.RUN_FINAL)


if __name__ == '__main__':
    main()
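
# Example invocations (the dataset path is a placeholder for the directory
# holding the preprocessed train_ratings.pt / test_ratings.pt files):
#
#   single GPU:
#     WORLD_SIZE=1 python ncf.py --data /data/cache/ml-20m
#
#   8 GPUs:
#     python -m torch.distributed.launch --nproc_per_node=8 ncf.py --data /data/cache/ml-20m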