[NCF/PyT] Adding BYOD capabilities

This commit is contained in:
Tomasz Cheda 2021-09-03 06:20:12 -07:00 committed by Krzysztof Kudrynski
parent 4fdd014ebf
commit 5d6d417ff5
38 changed files with 2216 additions and 460 deletions

View file

@ -0,0 +1,2 @@
.git
data/

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.06-py3
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.04-py3
FROM ${FROM_IMAGE_NAME}
RUN apt-get update && \

File diff suppressed because it is too large Load diff

View file

@ -14,7 +14,7 @@
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -31,13 +31,21 @@
from argparse import ArgumentParser
import pandas as pd
from load import implicit_load
from feature_spec import FeatureSpec
from neumf_constants import USER_CHANNEL_NAME, ITEM_CHANNEL_NAME, LABEL_CHANNEL_NAME, TEST_SAMPLES_PER_SERIES
import torch
import os
import tqdm
MIN_RATINGS = 20
TEST_1 = 'test_data_1.pt'
TEST_0 = 'test_data_0.pt'
TRAIN_1 = 'train_data_1.pt'
TRAIN_0 = 'train_data_0.pt'
USER_COLUMN = 'user_id'
ITEM_COLUMN = 'item_id'
def parse_args():
parser = ArgumentParser()
parser.add_argument('--path', type=str, default='/data/ml-20m/ratings.csv',
@ -61,7 +69,7 @@ class _TestNegSampler:
ids = (train_ratings[:, 0] * self.nb_items) + train_ratings[:, 1]
self.set = set(ids)
def generate(self, batch_size=128*1024):
def generate(self, batch_size=128 * 1024):
users = torch.arange(0, self.nb_users).reshape([1, -1]).repeat([self.nb_neg, 1]).transpose(0, 1).reshape(-1)
items = [-1] * len(users)
@ -82,6 +90,71 @@ class _TestNegSampler:
return items
def save_feature_spec(user_cardinality, item_cardinality, dtypes, test_negative_samples, output_path,
user_feature_name='user',
item_feature_name='item',
label_feature_name='label'):
feature_spec = {
user_feature_name: {
'dtype': dtypes[user_feature_name],
'cardinality': int(user_cardinality)
},
item_feature_name: {
'dtype': dtypes[item_feature_name],
'cardinality': int(item_cardinality)
},
label_feature_name: {
'dtype': dtypes[label_feature_name],
}
}
metadata = {
TEST_SAMPLES_PER_SERIES: test_negative_samples + 1
}
train_mapping = [
{
'type': 'torch_tensor',
'features': [
user_feature_name,
item_feature_name
],
'files': [TRAIN_0]
},
{
'type': 'torch_tensor',
'features': [
label_feature_name
],
'files': [TRAIN_1]
}
]
test_mapping = [
{
'type': 'torch_tensor',
'features': [
user_feature_name,
item_feature_name
],
'files': [TEST_0],
},
{
'type': 'torch_tensor',
'features': [
label_feature_name
],
'files': [TEST_1],
}
]
channel_spec = {
USER_CHANNEL_NAME: [user_feature_name],
ITEM_CHANNEL_NAME: [item_feature_name],
LABEL_CHANNEL_NAME: [label_feature_name]
}
source_spec = {'train': train_mapping, 'test': test_mapping}
feature_spec = FeatureSpec(feature_spec=feature_spec, metadata=metadata, source_spec=source_spec,
channel_spec=channel_spec, base_directory="")
feature_spec.to_yaml(output_path=output_path)
def main():
args = parse_args()
@ -91,39 +164,54 @@ def main():
print("Loading raw data from {}".format(args.path))
df = implicit_load(args.path, sort=False)
print("Filtering out users with less than {} ratings".format(MIN_RATINGS))
grouped = df.groupby(USER_COLUMN)
df = grouped.filter(lambda x: len(x) >= MIN_RATINGS)
print("Mapping original user and item IDs to new sequential IDs")
df[USER_COLUMN] = pd.factorize(df[USER_COLUMN])[0]
df[ITEM_COLUMN] = pd.factorize(df[ITEM_COLUMN])[0]
user_cardinality = df[USER_COLUMN].max() + 1
item_cardinality = df[ITEM_COLUMN].max() + 1
# Need to sort before popping to get last item
df.sort_values(by='timestamp', inplace=True)
# clean up data
del df['rating'], df['timestamp']
df = df.drop_duplicates() # assuming it keeps order
df = df.drop_duplicates() # assuming it keeps order
# now we have filtered and sorted by time data, we can split test data out
# Test set is the last interaction for a given user
grouped_sorted = df.groupby(USER_COLUMN, group_keys=False)
test_data = grouped_sorted.tail(1).sort_values(by='user_id')
# need to pop for each group
test_data = grouped_sorted.tail(1).sort_values(by=USER_COLUMN)
# Train set is all interactions but the last one
train_data = grouped_sorted.apply(lambda x: x.iloc[:-1])
# Note: no way to keep reference training data ordering because use of python set and multi-process
# It should not matter since it will be later randomized again
# save train and val data that is fixed.
train_ratings = torch.from_numpy(train_data.values)
torch.save(train_ratings, args.output+'/train_ratings.pt')
test_ratings = torch.from_numpy(test_data.values)
torch.save(test_ratings, args.output+'/test_ratings.pt')
sampler = _TestNegSampler(train_ratings.cpu().numpy(), args.valid_negative)
sampler = _TestNegSampler(train_data.values, args.valid_negative)
test_negs = sampler.generate().cuda()
test_negs = test_negs.reshape(-1, args.valid_negative)
torch.save(test_negs, args.output+'/test_negatives.pt')
# Reshape train set into user,item,label tabular and save
train_ratings = torch.from_numpy(train_data.values).cuda()
train_labels = torch.ones_like(train_ratings[:, 0:1], dtype=torch.float32)
torch.save(train_ratings, os.path.join(args.output, TRAIN_0))
torch.save(train_labels, os.path.join(args.output, TRAIN_1))
# Reshape test set into user,item,label tabular and save
# All users have the same number of items, items for a given user appear consecutively
test_ratings = torch.from_numpy(test_data.values).cuda()
test_users_pos = test_ratings[:, 0:1] # slicing instead of indexing to keep dimensions
test_items_pos = test_ratings[:, 1:2]
test_users = test_users_pos.repeat_interleave(args.valid_negative + 1, dim=0)
test_items = torch.cat((test_items_pos.reshape(-1, 1), test_negs), dim=1).reshape(-1, 1)
positive_labels = torch.ones_like(test_users_pos, dtype=torch.float32)
negative_labels = torch.zeros_like(test_users_pos, dtype=torch.float32).repeat(1, args.valid_negative)
test_labels = torch.cat((positive_labels, negative_labels), dim=1).reshape(-1, 1)
dtypes = {'user': str(test_users.dtype), 'item': str(test_items.dtype), 'label': str(test_labels.dtype)}
test_tensor = torch.cat((test_users, test_items), dim=1)
torch.save(test_tensor, os.path.join(args.output, TEST_0))
torch.save(test_labels, os.path.join(args.output, TEST_1))
save_feature_spec(user_cardinality=user_cardinality, item_cardinality=item_cardinality, dtypes=dtypes,
test_negative_samples=args.valid_negative, output_path=args.output + '/feature_spec.yaml')
if __name__ == '__main__':
main()

View file

@ -0,0 +1,158 @@
# Copyright (c) 2018, deepakn94, codyaustun, robieta. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
import pandas as pd
import numpy as np
from load import implicit_load
from convert import save_feature_spec, _TestNegSampler, TEST_0, TEST_1, TRAIN_0, TRAIN_1
import torch
import os
USER_COLUMN = 'user_id'
ITEM_COLUMN = 'item_id'
def parse_args():
parser = ArgumentParser()
parser.add_argument('--path', type=str, default='/data/ml-20m/ratings.csv',
help='Path to reviews CSV file from MovieLens')
parser.add_argument('--output', type=str, default='/data',
help='Output directory for train and test files')
parser.add_argument('--valid_negative', type=int, default=100,
help='Number of negative samples for each positive test example')
parser.add_argument('--seed', '-s', type=int, default=1,
help='Manually set random seed for torch')
parser.add_argument('--test', type=str, help='select modification to be applied to the set')
return parser.parse_args()
def main():
args = parse_args()
if args.seed is not None:
torch.manual_seed(args.seed)
print("Loading raw data from {}".format(args.path))
df = implicit_load(args.path, sort=False)
if args.test == 'less_user':
to_drop = set(list(df[USER_COLUMN].unique())[-100:])
df = df[~df[USER_COLUMN].isin(to_drop)]
if args.test == 'less_item':
to_drop = set(list(df[ITEM_COLUMN].unique())[-100:])
df = df[~df[ITEM_COLUMN].isin(to_drop)]
if args.test == 'more_user':
sample = df.sample(frac=0.2).copy()
sample[USER_COLUMN] = sample[USER_COLUMN] + 10000000
df = df.append(sample)
users = df[USER_COLUMN]
df = df[users.isin(users[users.duplicated(keep=False)])] # make sure something remains in the train set
if args.test == 'more_item':
sample = df.sample(frac=0.2).copy()
sample[ITEM_COLUMN] = sample[ITEM_COLUMN] + 10000000
df = df.append(sample)
print("Mapping original user and item IDs to new sequential IDs")
df[USER_COLUMN] = pd.factorize(df[USER_COLUMN])[0]
df[ITEM_COLUMN] = pd.factorize(df[ITEM_COLUMN])[0]
user_cardinality = df[USER_COLUMN].max() + 1
item_cardinality = df[ITEM_COLUMN].max() + 1
# Need to sort before popping to get last item
df.sort_values(by='timestamp', inplace=True)
# clean up data
del df['rating'], df['timestamp']
df = df.drop_duplicates() # assuming it keeps order
# Test set is the last interaction for a given user
grouped_sorted = df.groupby(USER_COLUMN, group_keys=False)
test_data = grouped_sorted.tail(1).sort_values(by=USER_COLUMN)
# Train set is all interactions but the last one
train_data = grouped_sorted.apply(lambda x: x.iloc[:-1])
sampler = _TestNegSampler(train_data.values, args.valid_negative)
test_negs = sampler.generate().cuda()
if args.valid_negative > 0:
test_negs = test_negs.reshape(-1, args.valid_negative)
else:
test_negs = test_negs.reshape(test_data.shape[0], 0)
if args.test == 'more_pos':
mask = np.random.rand(len(test_data)) < 0.5
sample = test_data[mask].copy()
sample[ITEM_COLUMN] = sample[ITEM_COLUMN] + 5
test_data = test_data.append(sample)
test_negs_copy = test_negs[mask]
test_negs = torch.cat((test_negs, test_negs_copy), dim=0)
if args.test == 'less_pos':
mask = np.random.rand(len(test_data)) < 0.5
test_data = test_data[mask]
test_negs = test_negs[mask]
# Reshape train set into user,item,label tabular and save
train_ratings = torch.from_numpy(train_data.values).cuda()
train_labels = torch.ones_like(train_ratings[:, 0:1], dtype=torch.float32)
torch.save(train_ratings, os.path.join(args.output, TRAIN_0))
torch.save(train_labels, os.path.join(args.output, TRAIN_1))
# Reshape test set into user,item,label tabular and save
# All users have the same number of items, items for a given user appear consecutively
test_ratings = torch.from_numpy(test_data.values).cuda()
test_users_pos = test_ratings[:, 0:1] # slicing instead of indexing to keep dimensions
test_items_pos = test_ratings[:, 1:2]
test_users = test_users_pos.repeat_interleave(args.valid_negative + 1, dim=0)
test_items = torch.cat((test_items_pos.reshape(-1, 1), test_negs), dim=1).reshape(-1, 1)
positive_labels = torch.ones_like(test_users_pos, dtype=torch.float32)
negative_labels = torch.zeros_like(test_users_pos, dtype=torch.float32).repeat(1, args.valid_negative)
test_labels = torch.cat((positive_labels, negative_labels), dim=1).reshape(-1, 1)
dtypes = {'user': str(test_users.dtype), 'item': str(test_items.dtype), 'label': str(test_labels.dtype)}
test_tensor = torch.cat((test_users, test_items), dim=1)
torch.save(test_tensor, os.path.join(args.output, TEST_0))
torch.save(test_labels, os.path.join(args.output, TEST_1))
if args.test == 'other_names':
dtypes = {'user_2': str(test_users.dtype),
'item_2': str(test_items.dtype),
'label_2': str(test_labels.dtype)}
save_feature_spec(user_cardinality=user_cardinality, item_cardinality=item_cardinality, dtypes=dtypes,
test_negative_samples=args.valid_negative, output_path=args.output + '/feature_spec.yaml',
user_feature_name='user_2',
item_feature_name='item_2',
label_feature_name='label_2')
else:
save_feature_spec(user_cardinality=user_cardinality, item_cardinality=item_cardinality, dtypes=dtypes,
test_negative_samples=args.valid_negative, output_path=args.output + '/feature_spec.yaml')
if __name__ == '__main__':
main()

View file

@ -0,0 +1,43 @@
feature_spec:
user:
cardinality: auto
item:
cardinality: auto
label:
metadata:
test_samples_per_series: 3
source_spec:
train:
- type: csv
features: #Each line corresponds to a column in the csv files
- user
- item
- label
files:
- train_data_1.csv # we assume no header
- train_data_2.csv
test:
- type: csv
features:
- user
- item
- label
files:
- test_data_1.csv
channel_spec:
user_ch: # Channel names are model-specific magics (in this model, neumf_constants.py)
- user
item_ch:
- item
label_ch:
- label
# Requirements:
# We assume the ids supplied have already been factorized into 0...N
# In the mapping to be used for validation and testing, candidates for each series (each user) appear consecutively.
# Each series has the same number of items: metadata['test_samples_per_series']

View file

@ -0,0 +1,30 @@
0, 8, 0
0, 18, 0
0, 17, 1
1, 7, 0
1, 6, 0
1, 16, 1
2, 12, 0
2, 13, 0
2, 16, 1
3, 3, 0
3, 1, 0
3, 5, 1
4, 16, 0
4, 3, 0
4, 8, 1
5, 14, 0
5, 12, 0
5, 12, 1
6, 3, 0
6, 3, 0
6, 1, 1
7, 3, 0
7, 18, 0
7, 8, 1
8, 8, 0
8, 8, 0
8, 2, 1
9, 19, 0
9, 9, 0
9, 18, 1
1 0 8 0
2 0 18 0
3 0 17 1
4 1 7 0
5 1 6 0
6 1 16 1
7 2 12 0
8 2 13 0
9 2 16 1
10 3 3 0
11 3 1 0
12 3 5 1
13 4 16 0
14 4 3 0
15 4 8 1
16 5 14 0
17 5 12 0
18 5 12 1
19 6 3 0
20 6 3 0
21 6 1 1
22 7 3 0
23 7 18 0
24 7 8 1
25 8 8 0
26 8 8 0
27 8 2 1
28 9 19 0
29 9 9 0
30 9 18 1

View file

@ -0,0 +1,60 @@
0, 14, 0
0, 3, 0
0, 18, 1
0, 15, 1
0, 2, 0
0, 1, 1
0, 5, 1
0, 7, 0
0, 12, 1
0, 19, 0
1, 9, 1
1, 0, 0
1, 16, 1
1, 2, 0
1, 8, 1
1, 17, 0
1, 17, 1
1, 9, 0
1, 5, 0
1, 12, 1
2, 8, 1
2, 0, 1
2, 1, 1
2, 0, 0
2, 4, 0
2, 17, 1
2, 18, 0
2, 3, 0
2, 10, 0
2, 18, 1
3, 14, 1
3, 4, 0
3, 0, 0
3, 16, 1
3, 6, 0
3, 17, 1
3, 0, 1
3, 0, 0
3, 3, 1
3, 0, 1
4, 13, 1
4, 8, 1
4, 1, 1
4, 14, 0
4, 18, 0
4, 7, 1
4, 19, 1
4, 3, 1
4, 17, 1
4, 17, 0
5, 8, 1
5, 10, 0
5, 4, 0
5, 19, 0
5, 12, 0
5, 3, 1
5, 5, 0
5, 8, 1
5, 19, 1
5, 12, 0
1 0 14 0
2 0 3 0
3 0 18 1
4 0 15 1
5 0 2 0
6 0 1 1
7 0 5 1
8 0 7 0
9 0 12 1
10 0 19 0
11 1 9 1
12 1 0 0
13 1 16 1
14 1 2 0
15 1 8 1
16 1 17 0
17 1 17 1
18 1 9 0
19 1 5 0
20 1 12 1
21 2 8 1
22 2 0 1
23 2 1 1
24 2 0 0
25 2 4 0
26 2 17 1
27 2 18 0
28 2 3 0
29 2 10 0
30 2 18 1
31 3 14 1
32 3 4 0
33 3 0 0
34 3 16 1
35 3 6 0
36 3 17 1
37 3 0 1
38 3 0 0
39 3 3 1
40 3 0 1
41 4 13 1
42 4 8 1
43 4 1 1
44 4 14 0
45 4 18 0
46 4 7 1
47 4 19 1
48 4 3 1
49 4 17 1
50 4 17 0
51 5 8 1
52 5 10 0
53 5 4 0
54 5 19 0
55 5 12 0
56 5 3 1
57 5 5 0
58 5 8 1
59 5 19 1
60 5 12 0

View file

@ -0,0 +1,40 @@
6, 18, 0
6, 19, 0
6, 4, 0
6, 16, 1
6, 19, 0
6, 2, 0
6, 4, 0
6, 2, 1
6, 0, 0
6, 12, 1
7, 3, 0
7, 7, 0
7, 16, 0
7, 4, 0
7, 19, 0
7, 11, 1
7, 10, 1
7, 13, 1
7, 18, 0
7, 4, 0
8, 4, 0
8, 5, 0
8, 12, 0
8, 2, 0
8, 14, 0
8, 19, 1
8, 0, 0
8, 17, 1
8, 19, 1
8, 15, 1
9, 9, 1
9, 17, 0
9, 9, 1
9, 14, 0
9, 11, 0
9, 17, 1
9, 4, 0
9, 1, 0
9, 8, 0
9, 10, 1
1 6 18 0
2 6 19 0
3 6 4 0
4 6 16 1
5 6 19 0
6 6 2 0
7 6 4 0
8 6 2 1
9 6 0 0
10 6 12 1
11 7 3 0
12 7 7 0
13 7 16 0
14 7 4 0
15 7 19 0
16 7 11 1
17 7 10 1
18 7 13 1
19 7 18 0
20 7 4 0
21 8 4 0
22 8 5 0
23 8 12 0
24 8 2 0
25 8 14 0
26 8 19 1
27 8 0 0
28 8 17 1
29 8 19 1
30 8 15 1
31 9 9 1
32 9 17 0
33 9 9 1
34 9 14 0
35 9 11 0
36 9 17 1
37 9 4 0
38 9 1 0
39 9 8 0
40 9 10 1

View file

@ -0,0 +1,50 @@
0, 17, 38308
0, 1, 38302
0, 10, 53558
0, 17, 53042
0, 12, 43899
1, 4, 85239
1, 3, 44884
1, 3, 37412
1, 8, 58416
1, 9, 39814
2, 6, 53985
2, 17, 63080
2, 9, 85791
2, 19, 37194
2, 3, 76871
3, 4, 32445
3, 1, 97224
3, 8, 76409
3, 6, 81547
3, 0, 52471
4, 15, 96242
4, 3, 72309
4, 9, 54815
4, 6, 94187
4, 16, 97208
5, 5, 56902
5, 0, 23414
5, 8, 55770
5, 5, 27745
5, 6, 61599
6, 12, 21675
6, 4, 53968
6, 7, 66164
6, 13, 94933
6, 1, 92957
7, 9, 30137
7, 11, 85128
7, 18, 30088
7, 14, 32186
7, 10, 84664
8, 1, 39714
8, 4, 27987
8, 15, 70023
8, 17, 93690
8, 8, 93827
9, 4, 80146
9, 8, 20896
9, 1, 55230
9, 13, 29631
9, 2, 46368
1 0 17 38308
2 0 1 38302
3 0 10 53558
4 0 17 53042
5 0 12 43899
6 1 4 85239
7 1 3 44884
8 1 3 37412
9 1 8 58416
10 1 9 39814
11 2 6 53985
12 2 17 63080
13 2 9 85791
14 2 19 37194
15 2 3 76871
16 3 4 32445
17 3 1 97224
18 3 8 76409
19 3 6 81547
20 3 0 52471
21 4 15 96242
22 4 3 72309
23 4 9 54815
24 4 6 94187
25 4 16 97208
26 5 5 56902
27 5 0 23414
28 5 8 55770
29 5 5 27745
30 5 6 61599
31 6 12 21675
32 6 4 53968
33 6 7 66164
34 6 13 94933
35 6 1 92957
36 7 9 30137
37 7 11 85128
38 7 18 30088
39 7 14 32186
40 7 10 84664
41 8 1 39714
42 8 4 27987
43 8 15 70023
44 8 17 93690
45 8 8 93827
46 9 4 80146
47 9 8 20896
48 9 1 55230
49 9 13 29631
50 9 2 46368

View file

@ -0,0 +1,52 @@
feature_spec:
user:
dtype: torch.int64
cardinality: 138493
item:
dtype: torch.int64
cardinality: 26744
label:
dtype: torch.float32
metadata:
test_samples_per_series: 101
source_spec:
train:
- type: torch_tensor
features:
# For torch_tensor, each line corresponds to a column. They are ordered
- user
- item
files:
# Loader currently only supports one file per chunk
- train_data_0.pt # Paths are relative to data-spec's directory
- type: torch_tensor
features:
- label
files:
- train_data_1.pt
test:
- type: torch_tensor
features:
- user
- item
files:
- test_data_0.pt
- type: torch_tensor
features:
- label
files:
- test_data_1.pt
channel_spec:
user_ch: # Channel names are model-specific magics (in this model, neumf_constants.py)
- user
item_ch:
- item
label_ch:
- label
# Requirements:
# During validation, for each user we have the same number of samples, supplied consecutively

View file

@ -14,7 +14,7 @@
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -28,95 +28,227 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import torch
import os
from feature_spec import FeatureSpec
from neumf_constants import USER_CHANNEL_NAME, ITEM_CHANNEL_NAME, LABEL_CHANNEL_NAME, TEST_SAMPLES_PER_SERIES
def create_test_data(test_ratings, test_negs, args):
test_users = test_ratings[:,0]
test_pos = test_ratings[:,1].reshape(-1,1)
class TorchTensorDataset:
""" Warning! This dataset/loader uses torch.load. Torch.load implicitly uses pickle. Pickle is insecure.
It is trivial to achieve arbitrary code execution using a prepared pickle payload. Only unpickle data you trust."""
# create items with real sample at last position
num_valid_negative = test_negs.shape[1]
test_users = test_users.reshape(-1,1).repeat(1, 1 + num_valid_negative)
test_items = torch.cat((test_negs, test_pos), dim=1)
del test_ratings, test_negs
def __init__(self, feature_spec: FeatureSpec, mapping_name: str, args):
self.local_rank = args.local_rank
self.mapping_name = mapping_name
self.features = dict()
self.feature_spec = feature_spec
self._load_features()
# generate dup mask and real indices for exact same behavior on duplication compare to reference
# here we need a sort that is stable(keep order of duplicates)
sorted_items, indices = torch.sort(test_items) # [1,1,1,2], [3,1,0,2]
sum_item_indices = sorted_items.float()+indices.float()/len(indices[0]) #[1.75,1.25,1.0,2.5]
indices_order = torch.sort(sum_item_indices)[1] #[2,1,0,3]
stable_indices = torch.gather(indices, 1, indices_order) #[0,1,3,2]
# produce -1 mask
dup_mask = (sorted_items[:,0:-1] == sorted_items[:,1:])
dup_mask = dup_mask.type(torch.uint8)
dup_mask = torch.cat((torch.zeros_like(test_pos, dtype=torch.uint8), dup_mask), dim=1)
dup_mask = torch.gather(dup_mask, 1, stable_indices.sort()[1])
# produce real sample indices to later check in topk
sorted_items, indices = (test_items != test_pos).type(torch.uint8).sort()
sum_item_indices = sorted_items.float()+indices.float()/len(indices[0])
indices_order = torch.sort(sum_item_indices)[1]
stable_indices = torch.gather(indices, 1, indices_order)
real_indices = stable_indices[:,0]
if args.distributed:
test_users = torch.chunk(test_users, args.world_size)[args.local_rank]
test_items = torch.chunk(test_items, args.world_size)[args.local_rank]
dup_mask = torch.chunk(dup_mask, args.world_size)[args.local_rank]
real_indices = torch.chunk(real_indices, args.world_size)[args.local_rank]
test_users = test_users.view(-1).split(args.valid_batch_size)
test_items = test_items.view(-1).split(args.valid_batch_size)
return test_users, test_items, dup_mask, real_indices
def _load_features(self):
chunks = self.feature_spec.source_spec[self.mapping_name]
for chunk in chunks:
assert chunk['type'] == 'torch_tensor', "Only torch_tensor files supported in this loader"
files_list = chunk['files']
assert len(files_list) == 1, "Only one file per chunk supported in this loader"
file_relative_path = files_list[0]
path_to_load = os.path.join(self.feature_spec.base_directory, file_relative_path)
chunk_data = torch.load(path_to_load, map_location=torch.device('cuda:{}'.format(self.local_rank)))
running_pos = 0
for feature_name in chunk['features']:
next_running_pos = running_pos + 1
feature_data = chunk_data[:, running_pos:next_running_pos]
# This is needed because slicing instead of indexing keeps the data 2-dimensional
feature_data = feature_data.reshape(-1, 1)
running_pos = next_running_pos
self.features[feature_name] = feature_data
def prepare_epoch_train_data(train_ratings, nb_items, args):
# create label
train_label = torch.ones_like(train_ratings[:,0], dtype=torch.float32)
neg_label = torch.zeros_like(train_label, dtype=torch.float32)
neg_label = neg_label.repeat(args.negative_samples)
train_label = torch.cat((train_label,neg_label))
del neg_label
class TestDataLoader:
def __init__(self, dataset: TorchTensorDataset, args):
self.dataset = dataset
self.feature_spec = dataset.feature_spec
self.channel_spec = self.feature_spec.channel_spec
self.samples_in_series = self.feature_spec.metadata[TEST_SAMPLES_PER_SERIES]
self.raw_dataset_length = None # First feature loaded sets this. Total length before splitting across cards
self.data = dict()
self.world_size = args.world_size
self.local_rank = args.local_rank
self.batch_size = args.valid_batch_size
train_users = train_ratings[:,0]
train_items = train_ratings[:,1]
self._build_channel_dict()
self._deduplication_augmentation()
self._split_between_devices()
self._split_into_batches()
train_users_per_worker = len(train_label) / args.world_size
train_users_begin = int(train_users_per_worker * args.local_rank)
train_users_end = int(train_users_per_worker * (args.local_rank + 1))
def _build_channel_dict(self):
for channel_name, channel_features in self.channel_spec.items():
channel_tensors = dict()
for feature_name in channel_features:
channel_tensors[feature_name] = self.dataset.features[feature_name]
# prepare data for epoch
neg_users = train_users.repeat(args.negative_samples)
neg_items = torch.empty_like(neg_users, dtype=torch.int64).random_(0, nb_items)
if not self.raw_dataset_length:
self.raw_dataset_length = channel_tensors[feature_name].shape[0]
else:
assert self.raw_dataset_length == channel_tensors[feature_name].shape[0]
epoch_users = torch.cat((train_users, neg_users))
epoch_items = torch.cat((train_items, neg_items))
self.data[channel_name] = channel_tensors
del neg_users, neg_items
def _deduplication_augmentation(self):
# Augmentation
# This adds a duplication mask tensor.
# This is here to exactly replicate the MLPerf training regime. Moving this deduplication to the candidate item
# generation stage increases the real diversity of the candidates, which makes the ranking task harder
# and results in a drop in HR@10 of approx 0.01. This has been deemed unacceptable (May 2021).
# shuffle prepared data and split into batches
epoch_indices = torch.randperm(train_users_end - train_users_begin, device='cuda:{}'.format(args.local_rank))
epoch_indices += train_users_begin
# We need the duplication mask to determine if a given item should be skipped during ranking
# If an item with label 1 is duplicated in the sampled ones, we need to be careful to not mark the one with
# label 1 as a duplicate. If an item appears repeatedly only with label 1, no duplicates are marked.
epoch_users = epoch_users[epoch_indices]
epoch_items = epoch_items[epoch_indices]
epoch_label = train_label[epoch_indices]
# To easily compute candidates, we sort the items. This will impact the distribution of examples between
# devices, but should not influence the numerics or performance meaningfully.
# We need to assure that the positive item, which we don't want to mark as a duplicate, appears first.
# We do this by adding labels as a secondary factor
if args.distributed:
local_batch = args.batch_size // args.world_size
else:
local_batch = args.batch_size
# Reshape the tensors to have items for a given user in a single row
user_feature_name = self.channel_spec[USER_CHANNEL_NAME][0]
item_feature_name = self.channel_spec[ITEM_CHANNEL_NAME][0]
label_feature_name = self.channel_spec[LABEL_CHANNEL_NAME][0]
self.ignore_mask_channel_name = 'mask_ch'
self.ignore_mask_feature_name = 'mask'
epoch_users = epoch_users.split(local_batch)
epoch_items = epoch_items.split(local_batch)
epoch_label = epoch_label.split(local_batch)
items = self.data[ITEM_CHANNEL_NAME][item_feature_name].view(-1, self.samples_in_series)
users = self.data[USER_CHANNEL_NAME][user_feature_name].view(-1, self.samples_in_series)
labels = self.data[LABEL_CHANNEL_NAME][label_feature_name].view(-1, self.samples_in_series)
# the last batch will almost certainly be smaller, drop it
epoch_users = epoch_users[:-1]
epoch_items = epoch_items[:-1]
epoch_label = epoch_label[:-1]
sorting_weights = items.float() - labels.float() * 0.5
_, indices = torch.sort(sorting_weights)
# The gather reorders according to the indices decided by the sort above
sorted_items = torch.gather(items, 1, indices)
sorted_labels = torch.gather(labels, 1, indices)
sorted_users = torch.gather(users, 1, indices)
return epoch_users, epoch_items, epoch_label
dup_mask = sorted_items[:, 0:-1] == sorted_items[:, 1:] # This says if a given item is equal to the next one
dup_mask = dup_mask.type(torch.bool)
# The first item for a given user can never be a duplicate:
dup_mask = torch.cat((torch.zeros_like(dup_mask[:, 0:1]), dup_mask), dim=1)
# Reshape them back
self.data[ITEM_CHANNEL_NAME][item_feature_name] = sorted_items.view(-1, 1)
self.data[USER_CHANNEL_NAME][user_feature_name] = sorted_users.view(-1, 1)
self.data[LABEL_CHANNEL_NAME][label_feature_name] = sorted_labels.view(-1, 1)
self.data[self.ignore_mask_channel_name] = dict()
self.data[self.ignore_mask_channel_name][self.ignore_mask_feature_name] = dup_mask.view(-1, 1)
def _split_between_devices(self):
if self.world_size > 1:
# DO NOT REPLACE WITH torch.chunk (number of returned chunks can silently be lower than requested).
# It would break compatibility with small datasets.
num_test_cases = self.raw_dataset_length / self.samples_in_series
smaller_batch = (int(num_test_cases // self.world_size)) * self.samples_in_series
bigger_batch = smaller_batch + self.samples_in_series
remainder = int(num_test_cases % self.world_size)
samples_per_card = [bigger_batch] * remainder + [smaller_batch] * (self.world_size - remainder)
for channel_name, channel_dict in self.data.items():
for feature_name, feature_tensor in channel_dict.items():
channel_dict[feature_name] = \
channel_dict[feature_name].split(samples_per_card)[self.local_rank]
def _split_into_batches(self):
self.batches = None
# This is the structure of each batch, waiting to be copied and filled in with data
for channel_name, channel_dict in self.data.items():
for feature_name, feature_tensor in channel_dict.items():
feature_batches = feature_tensor.view(-1).split(self.batch_size)
if not self.batches:
self.batches = list(
{channel_name: dict() for channel_name in self.data.keys()} for _ in feature_batches)
for pos, feature_batch_data in enumerate(feature_batches):
self.batches[pos][channel_name][feature_name] = feature_batch_data
def get_epoch_data(self):
return self.batches
def get_ignore_mask(self):
return self.data[self.ignore_mask_channel_name][self.ignore_mask_feature_name]
class TrainDataloader:
def __init__(self, dataset: TorchTensorDataset, args):
self.dataset = dataset
self.local_rank = args.local_rank
if args.distributed:
self.local_batch = args.batch_size // args.world_size
else:
self.local_batch = args.batch_size
self.feature_spec = dataset.feature_spec
self.channel_spec = self.feature_spec.channel_spec
self.negative_samples = args.negative_samples
self.data = dict()
self.raw_dataset_length = None # first feature loaded sets this
self._build_channel_dict()
self.length_after_augmentation = self.raw_dataset_length * (self.negative_samples + 1)
samples_per_worker = self.length_after_augmentation / args.world_size
self.samples_begin = int(samples_per_worker * args.local_rank)
self.samples_end = int(samples_per_worker * (args.local_rank + 1))
def _build_channel_dict(self):
for channel_name, channel_features in self.channel_spec.items():
channel_tensors = dict()
for feature_name in channel_features:
channel_tensors[feature_name] = self.dataset.features[feature_name]
if not self.raw_dataset_length:
self.raw_dataset_length = channel_tensors[feature_name].shape[0]
else:
assert self.raw_dataset_length == channel_tensors[feature_name].shape[0]
self.data[channel_name] = channel_tensors
def get_epoch_data(self):
# Augment, appending args.negative_samples times the original set, now with random items end negative labels
augmented_data = {channel_name: dict() for channel_name in self.data.keys()}
user_feature_name = self.channel_spec[USER_CHANNEL_NAME][0]
item_feature_name = self.channel_spec[ITEM_CHANNEL_NAME][0]
label_feature_name = self.channel_spec[LABEL_CHANNEL_NAME][0]
# USER
user_tensor = self.data[USER_CHANNEL_NAME][user_feature_name]
neg_users = user_tensor.repeat(self.negative_samples, 1)
augmented_users = torch.cat((user_tensor, neg_users))
augmented_data[USER_CHANNEL_NAME][user_feature_name] = augmented_users
del neg_users
# ITEM
item_tensor = self.data[ITEM_CHANNEL_NAME][item_feature_name]
neg_items = torch.empty_like(item_tensor).repeat(self.negative_samples, 1) \
.random_(0, self.feature_spec.feature_spec[item_feature_name]['cardinality'])
augmented_items = torch.cat((item_tensor, neg_items))
augmented_data[ITEM_CHANNEL_NAME][item_feature_name] = augmented_items
del neg_items
# LABEL
label_tensor = self.data[LABEL_CHANNEL_NAME][label_feature_name]
neg_label = torch.zeros_like(label_tensor, dtype=torch.float32).repeat(self.negative_samples, 1)
augmented_labels = torch.cat((label_tensor, neg_label))
del neg_label
augmented_data[LABEL_CHANNEL_NAME][label_feature_name] = augmented_labels
# Labels are not shuffled between cards.
# This replicates previous behaviour.
epoch_indices = torch.randperm(self.samples_end - self.samples_begin, device='cuda:{}'.format(self.local_rank))
epoch_indices += self.samples_begin
batches = None
for channel_name, channel_dict in augmented_data.items():
for feature_name, feature_tensor in channel_dict.items():
# the last batch will almost certainly be smaller, drop it
# Warning: may not work if there's only one
feature_batches = feature_tensor.view(-1)[epoch_indices].split(self.local_batch)[:-1]
if not batches:
batches = list({channel_name: dict() for channel_name in self.data.keys()} for _ in feature_batches)
for pos, feature_batch_data in enumerate(feature_batches):
batches[pos][channel_name][feature_name] = feature_batch_data
return batches

View file

@ -0,0 +1,50 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import os
from typing import List, Dict
class FeatureSpec:
def __init__(self, feature_spec, source_spec, channel_spec, metadata, base_directory):
self.feature_spec: Dict = feature_spec
self.source_spec: Dict = source_spec
self.channel_spec: Dict = channel_spec
self.metadata: Dict = metadata
self.base_directory: str = base_directory
@classmethod
def from_yaml(cls, path):
with open(path, 'r') as feature_spec_file:
base_directory = os.path.dirname(path)
feature_spec = yaml.safe_load(feature_spec_file)
return cls.from_dict(feature_spec, base_directory=base_directory)
@classmethod
def from_dict(cls, source_dict, base_directory):
return cls(base_directory=base_directory, **source_dict)
def to_dict(self) -> Dict:
attributes_to_dump = ['feature_spec', 'source_spec', 'channel_spec', 'metadata']
return {attr: self.__dict__[attr] for attr in attributes_to_dump}
def to_string(self):
return yaml.dump(self.to_dict())
def to_yaml(self, output_path=None):
if not output_path:
output_path = self.base_directory + '/feature_spec.yaml'
with open(output_path, 'w') as output_file:
print(yaml.dump(self.to_dict()), file=output_file)

View file

@ -1,45 +0,0 @@
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class Fp16Optimizer:
def __init__(self, fp16_model, loss_scale=8192.0):
print('Initializing fp16 optimizer')
self.initialize_model(fp16_model)
self.loss_scale = loss_scale
def initialize_model(self, model):
print('Reset fp16 grad')
self.fp16_model = model
for param in self.fp16_model.parameters():
param.grad = None
print('Initializing fp32 clone weights')
self.fp32_params = [param.clone().type(torch.cuda.FloatTensor).detach()
for param in model.parameters()]
for param in self.fp32_params:
param.requires_grad = True
def backward(self, loss):
loss *= self.loss_scale
loss.backward()
def step(self, optimizer):
optimizer.step(grads=[p.grad for p in self.fp16_model.parameters()],
output_params=self.fp16_model.parameters(), scale=self.loss_scale)
for p in self.fp16_model.parameters():
p.grad = None

View file

@ -0,0 +1,25 @@
rm -r /data/cache/ml-20m
## Prepare the standard dataset:
./prepare_dataset.sh
## Prepare the modified dataset:
./test_dataset.sh
## Run on the modified dataset:
./test_cases.sh
## Check featurespec:
python test_featurespec_correctness.py /data/cache/ml-20m/feature_spec.yaml /data/ml-20m/feature_spec_template.yaml
## Other dataset:
rm -r /data/cache/ml-1m
./prepare_dataset.sh ml-1m
python -m torch.distributed.launch --nproc_per_node=1 --use_env ncf.py --data /data/cache/ml-1m --epochs 1

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 42 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 41 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 29 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 78 KiB

View file

@ -82,18 +82,23 @@ def main():
items = torch.cuda.LongTensor(batch_size).random_(0, args.n_items)
latencies = []
for _ in range(args.num_batches):
for i in range(args.num_batches):
torch.cuda.synchronize()
start = time.time()
_ = model(users, items, sigmoid=True)
torch.cuda.synchronize()
latencies.append(time.time() - start)
end_time = time.time()
if i < 10: # warmup iterations
continue
latencies.append(end_time - start)
result_data[f'batch_{batch_size}_mean_throughput'] = batch_size / np.mean(latencies)
result_data[f'batch_{batch_size}_mean_latency'] = np.mean(latencies)
result_data[f'batch_{batch_size}_p90_latency'] = np.percentile(latencies, 0.90)
result_data[f'batch_{batch_size}_p95_latency'] = np.percentile(latencies, 0.95)
result_data[f'batch_{batch_size}_p99_latency'] = np.percentile(latencies, 0.99)
result_data[f'batch_{batch_size}_p90_latency'] = np.percentile(latencies, 90)
result_data[f'batch_{batch_size}_p95_latency'] = np.percentile(latencies, 95)
result_data[f'batch_{batch_size}_p99_latency'] = np.percentile(latencies, 99)
dllogger.log(data=result_data, step=tuple())
dllogger.flush()

View file

@ -11,12 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
import pandas as pd
RatingData = namedtuple('RatingData',
['items', 'users', 'ratings', 'min_date', 'max_date'])
@ -67,6 +81,13 @@ def load_ml_20m(filename, sort=True):
return process_movielens(ratings, sort=sort)
def load_unknown(filename, sort=True):
names = ['user_id', 'item_id', 'timestamp']
ratings = pd.read_csv(filename, names=names, header=0, engine='python')
ratings['rating'] = 5
return process_movielens(ratings, sort=sort)
DATASETS = [k.replace('load_', '') for k in locals().keys() if "load_" in k]
@ -74,7 +95,8 @@ def get_dataset_name(filename):
for dataset in DATASETS:
if dataset in filename.replace('-', '_').lower():
return dataset
raise NotImplementedError
print("Unknown dataset. Expecting `user_id`, `item_id` , and `timestamp`")
return "unknown"
def implicit_load(filename, sort=True):

View file

@ -14,7 +14,7 @@
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -42,23 +42,29 @@ import torch.nn as nn
import utils
import dataloading
from neumf import NeuMF
from feature_spec import FeatureSpec
from neumf_constants import USER_CHANNEL_NAME, ITEM_CHANNEL_NAME, LABEL_CHANNEL_NAME
import dllogger
from apex.parallel import DistributedDataParallel as DDP
from apex import amp
def parse_args():
parser = ArgumentParser(description="Train a Nerual Collaborative"
parser = ArgumentParser(description="Train a Neural Collaborative"
" Filtering model")
parser.add_argument('--data', type=str,
help='Path to test and training data files')
help='Path to the directory containing the feature specification yaml')
parser.add_argument('--feature_spec_file', type=str, default='feature_spec.yaml',
help='Name of the feature specification file or path relative to the data directory.')
parser.add_argument('-e', '--epochs', type=int, default=30,
help='Number of epochs for training')
parser.add_argument('-b', '--batch_size', type=int, default=2**20,
help='Number of examples for each iteration')
parser.add_argument('--valid_batch_size', type=int, default=2**20,
help='Number of examples in each validation chunk')
parser.add_argument('-b', '--batch_size', type=int, default=2 ** 20,
help='Number of examples for each iteration. This will be divided by the number of devices')
parser.add_argument('--valid_batch_size', type=int, default=2 ** 20,
help='Number of examples in each validation chunk. This will be the maximum size of a batch '
'on each device.')
parser.add_argument('-f', '--factors', type=int, default=64,
help='Number of predictive factors')
parser.add_argument('--layers', nargs='+', type=int,
@ -83,11 +89,13 @@ def parse_args():
parser.add_argument('--dropout', type=float, default=0.5,
help='Dropout probability, if equal to 0 will not use dropout at all')
parser.add_argument('--checkpoint_dir', default='', type=str,
help='Path to the directory storing the checkpoint file, passing an empty path disables checkpoint saving')
help='Path to the directory storing the checkpoint file, '
'passing an empty path disables checkpoint saving')
parser.add_argument('--load_checkpoint_path', default=None, type=str,
help='Path to the checkpoint file to be loaded before training/evaluation')
parser.add_argument('--mode', choices=['train', 'test'], default='train', type=str,
help='Passing "test" will only run a single evaluation, otherwise full training will be performed')
help='Passing "test" will only run a single evaluation; '
'otherwise, full training will be performed')
parser.add_argument('--grads_accumulated', default=1, type=int,
help='Number of gradients to accumulate before performing an optimization step')
parser.add_argument('--amp', action='store_true', help='Enable mixed precision training')
@ -116,34 +124,44 @@ def init_distributed(args):
args.local_rank = 0
def val_epoch(model, x, y, dup_mask, real_indices, K, samples_per_user, num_user,
epoch=None, distributed=False):
def val_epoch(model, dataloader: dataloading.TestDataLoader, k, distributed=False):
model.eval()
user_feature_name = dataloader.channel_spec[USER_CHANNEL_NAME][0]
item_feature_name = dataloader.channel_spec[ITEM_CHANNEL_NAME][0]
label_feature_name = dataloader.channel_spec[LABEL_CHANNEL_NAME][0]
with torch.no_grad():
p = []
for u,n in zip(x,y):
p.append(model(u, n, sigmoid=True).detach())
labels_list = []
for batch_dict in dataloader.get_epoch_data():
user_batch = batch_dict[USER_CHANNEL_NAME][user_feature_name]
item_batch = batch_dict[ITEM_CHANNEL_NAME][item_feature_name]
label_batch = batch_dict[LABEL_CHANNEL_NAME][label_feature_name]
temp = torch.cat(p).view(-1,samples_per_user)
del x, y, p
p.append(model(user_batch, item_batch, sigmoid=True).detach())
labels_list.append(label_batch)
# set duplicate results for the same item to -1 before topk
temp[dup_mask] = -1
out = torch.topk(temp,K)[1]
# topk in pytorch is stable(if not sort)
# key(item):value(prediction) pairs are ordered as original key(item) order
# so we need the first position of real item(stored in real_indices) to check if it is in topk
ifzero = (out == real_indices.view(-1,1))
ignore_mask = dataloader.get_ignore_mask().view(-1, dataloader.samples_in_series)
ratings = torch.cat(p).view(-1, dataloader.samples_in_series)
ratings[ignore_mask] = -1
labels = torch.cat(labels_list).view(-1, dataloader.samples_in_series)
del p, labels_list
top_indices = torch.topk(ratings, k)[1]
# Positive items are always first in a given series
labels_of_selected = torch.gather(labels, 1, top_indices)
ifzero = (labels_of_selected == 1)
hits = ifzero.sum()
ndcg = (math.log(2) / (torch.nonzero(ifzero)[:,1].view(-1).to(torch.float)+2).log_()).sum()
ndcg = (math.log(2) / (torch.nonzero(ifzero)[:, 1].view(-1).to(torch.float) + 2).log_()).sum()
# torch.nonzero may cause host-device synchronization
if distributed:
torch.distributed.all_reduce(hits, op=torch.distributed.reduce_op.SUM)
torch.distributed.all_reduce(ndcg, op=torch.distributed.reduce_op.SUM)
torch.distributed.all_reduce(hits, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(ndcg, op=torch.distributed.ReduceOp.SUM)
hr = hits.item() / num_user
ndcg = ndcg.item() / num_user
num_test_cases = dataloader.raw_dataset_length / dataloader.samples_in_series
hr = hits.item() / num_test_cases
ndcg = ndcg.item() / num_test_cases
model.train()
return hr, ndcg
@ -160,6 +178,12 @@ def main():
else:
dllogger.init(backends=[])
dllogger.metadata('train_throughput', {"name": 'train_throughput', 'format': ":.3e"})
dllogger.metadata('hr@10', {"name": 'hr@10', 'format': ":.5f"})
dllogger.metadata('train_epoch_time', {"name": 'train_epoch_time', 'format': ":.3f"})
dllogger.metadata('validation_epoch_time', {"name": 'validation_epoch_time', 'format': ":.3f"})
dllogger.metadata('eval_throughput', {"name": 'eval_throughput', 'format': ":.3e"})
dllogger.log(data=vars(args), step='PARAMETER')
if args.seed is not None:
@ -176,25 +200,22 @@ def main():
main_start_time = time.time()
train_ratings = torch.load(args.data+'/train_ratings.pt', map_location=torch.device('cuda:{}'.format(args.local_rank)))
test_ratings = torch.load(args.data+'/test_ratings.pt', map_location=torch.device('cuda:{}'.format(args.local_rank)))
test_negs = torch.load(args.data+'/test_negatives.pt', map_location=torch.device('cuda:{}'.format(args.local_rank)))
valid_negative = test_negs.shape[1]
nb_maxs = torch.max(train_ratings, 0)[0]
nb_users = nb_maxs[0].item() + 1
nb_items = nb_maxs[1].item() + 1
all_test_users = test_ratings.shape[0]
test_users, test_items, dup_mask, real_indices = dataloading.create_test_data(test_ratings, test_negs, args)
feature_spec_path = os.path.join(args.data, args.feature_spec_file)
feature_spec = FeatureSpec.from_yaml(feature_spec_path)
trainset = dataloading.TorchTensorDataset(feature_spec, mapping_name='train', args=args)
testset = dataloading.TorchTensorDataset(feature_spec, mapping_name='test', args=args)
train_loader = dataloading.TrainDataloader(trainset, args)
test_loader = dataloading.TestDataLoader(testset, args)
# make pytorch memory behavior more consistent later
torch.cuda.empty_cache()
# Create model
model = NeuMF(nb_users, nb_items,
user_feature_name = feature_spec.channel_spec[USER_CHANNEL_NAME][0]
item_feature_name = feature_spec.channel_spec[ITEM_CHANNEL_NAME][0]
label_feature_name = feature_spec.channel_spec[LABEL_CHANNEL_NAME][0]
model = NeuMF(nb_users=feature_spec.feature_spec[user_feature_name]['cardinality'],
nb_items=feature_spec.feature_spec[item_feature_name]['cardinality'],
mf_dim=args.factors,
mlp_layer_sizes=args.layers,
dropout=args.dropout)
@ -202,7 +223,7 @@ def main():
optimizer = FusedAdam(model.parameters(), lr=args.learning_rate,
betas=(args.beta1, args.beta2), eps=args.eps)
criterion = nn.BCEWithLogitsLoss(reduction='none') # use torch.mean() with dim later to avoid copy to host
criterion = nn.BCEWithLogitsLoss(reduction='none') # use torch.mean() with dim later to avoid copy to host
# Move model and loss to GPU
model = model.cuda()
criterion = criterion.cuda()
@ -216,48 +237,56 @@ def main():
local_batch = args.batch_size // args.world_size
traced_criterion = torch.jit.trace(criterion.forward,
(torch.rand(local_batch,1),torch.rand(local_batch,1)))
(torch.rand(local_batch, 1), torch.rand(local_batch, 1)))
print(model)
print("{} parameters".format(utils.count_parameters(model)))
if args.load_checkpoint_path:
state_dict = torch.load(args.load_checkpoint_path)
state_dict = {k.replace('module.', '') : v for k,v in state_dict.items()}
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
if args.mode == 'test':
start = time.time()
hr, ndcg = val_epoch(model, test_users, test_items, dup_mask, real_indices, args.topk,
samples_per_user=valid_negative + 1,
num_user=all_test_users, distributed=args.distributed)
hr, ndcg = val_epoch(model, test_loader, args.topk, distributed=args.distributed)
val_time = time.time() - start
eval_size = all_test_users * (valid_negative + 1)
eval_size = test_loader.raw_dataset_length
eval_throughput = eval_size / val_time
dllogger.log(step=tuple(), data={'best_eval_throughput' : eval_throughput,
'hr@10' : hr})
dllogger.log(step=tuple(), data={'best_eval_throughput': eval_throughput,
'hr@10': hr})
return
# this should always be overridden if hr>0.
# It is theoretically possible for the hit rate to be zero in the first epoch, which would result in referring
# to an uninitialized variable.
max_hr = 0
best_epoch = 0
best_model_timestamp = time.time()
train_throughputs, eval_throughputs = [], []
for epoch in range(args.epochs):
begin = time.time()
epoch_users, epoch_items, epoch_label = dataloading.prepare_epoch_train_data(train_ratings, nb_items, args)
num_batches = len(epoch_users)
batch_dict_list = train_loader.get_epoch_data()
num_batches = len(batch_dict_list)
for i in range(num_batches // args.grads_accumulated):
for j in range(args.grads_accumulated):
batch_idx = (args.grads_accumulated * i) + j
user = epoch_users[batch_idx]
item = epoch_items[batch_idx]
label = epoch_label[batch_idx].view(-1,1)
batch_dict = batch_dict_list[batch_idx]
outputs = model(user, item)
loss = traced_criterion(outputs, label).float()
user_features = batch_dict[USER_CHANNEL_NAME]
item_features = batch_dict[ITEM_CHANNEL_NAME]
user_batch = user_features[user_feature_name]
item_batch = item_features[item_feature_name]
label_features = batch_dict[LABEL_CHANNEL_NAME]
label_batch = label_features[label_feature_name]
outputs = model(user_batch, item_batch)
loss = traced_criterion(outputs, label_batch.view(-1, 1)).float()
loss = torch.mean(loss.view(-1), 0)
if args.amp:
@ -270,31 +299,27 @@ def main():
for p in model.parameters():
p.grad = None
del epoch_users, epoch_items, epoch_label
del batch_dict_list
train_time = time.time() - begin
begin = time.time()
epoch_samples = len(train_ratings) * (args.negative_samples + 1)
epoch_samples = train_loader.length_after_augmentation
train_throughput = epoch_samples / train_time
train_throughputs.append(train_throughput)
hr, ndcg = val_epoch(model, test_users, test_items, dup_mask, real_indices, args.topk,
samples_per_user=valid_negative + 1,
num_user=all_test_users, epoch=epoch, distributed=args.distributed)
hr, ndcg = val_epoch(model, test_loader, args.topk, distributed=args.distributed)
val_time = time.time() - begin
eval_size = all_test_users * (valid_negative + 1)
eval_size = test_loader.raw_dataset_length
eval_throughput = eval_size / val_time
eval_throughputs.append(eval_throughput)
dllogger.log(step=(epoch,),
data = {'train_throughput': train_throughput,
'hr@10': hr,
'train_epoch_time': train_time,
'validation_epoch_time': val_time,
'eval_throughput': eval_throughput})
data={'train_throughput': train_throughput,
'hr@10': hr,
'train_epoch_time': train_time,
'validation_epoch_time': val_time,
'eval_throughput': eval_throughput})
if hr > max_hr and args.local_rank == 0:
max_hr = hr

View file

@ -35,6 +35,7 @@ import torch.nn as nn
import sys
from os.path import abspath, join, dirname
class NeuMF(nn.Module):
def __init__(self, nb_users, nb_items,
mf_dim, mlp_layer_sizes, dropout=0):

View file

@ -0,0 +1,18 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
USER_CHANNEL_NAME = 'user_ch'
ITEM_CHANNEL_NAME = 'item_ch'
LABEL_CHANNEL_NAME = 'label_ch'
TEST_SAMPLES_PER_SERIES = 'test_samples_per_series'

View file

@ -14,7 +14,7 @@
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -34,22 +34,26 @@ set -e
set -x
DATASET_NAME=${1:-'ml-20m'}
RAW_DATADIR=${2:-'/data'}
CACHED_DATADIR=${3:-"${RAW_DATADIR}/cache/${DATASET_NAME}"}
RAW_DATADIR=${2:-"/data/${DATASET_NAME}"}
CACHED_DATADIR=${3:-"/data/cache/${DATASET_NAME}"}
# you can add another option to this case in order to support other datasets
case ${DATASET_NAME} in
'ml-20m')
ZIP_PATH=${RAW_DATADIR}/'ml-20m.zip'
SHOULD_UNZIP=1
RATINGS_PATH=${RAW_DATADIR}'/ml-20m/ratings.csv'
;;
'ml-1m')
ZIP_PATH=${RAW_DATADIR}/'ml-1m.zip'
SHOULD_UNZIP=1
RATINGS_PATH=${RAW_DATADIR}'/ml-1m/ratings.dat'
;;
*)
echo "Unsupported dataset name: $DATASET_NAME"
exit 1
*)
echo "Using unknown dataset: $DATASET_NAME."
RATINGS_PATH=${RAW_DATADIR}'/ratings.csv'
echo "Expecting file at ${RATINGS_PATH}"
SHOULD_UNZIP=0
esac
if [ ! -d ${RAW_DATADIR} ]; then
@ -64,16 +68,21 @@ if [ -f log ]; then
rm -f log
fi
if [ ! -f ${ZIP_PATH} ]; then
echo "Dataset not found. Please download it from: https://grouplens.org/datasets/movielens/20m/ and put it in ${ZIP_PATH}"
exit 1
fi
if [ ! -f ${RATINGS_PATH} ]; then
unzip -u ${ZIP_PATH} -d ${RAW_DATADIR}
if [ $SHOULD_UNZIP == 1 ]; then
if [ ! -f ${ZIP_PATH} ]; then
echo "Dataset not found. Please download it from: https://grouplens.org/datasets/movielens/20m/ and put it in ${ZIP_PATH}"
exit 1
fi
unzip -u ${ZIP_PATH} -d ${RAW_DATADIR}
else
echo "File not found at ${RATINGS_PATH}. Aborting."
exit 1
fi
fi
if [ ! -f ${CACHED_DATADIR}/train_ratings.pt ]; then
if [ ! -f ${CACHED_DATADIR}/feature_spec.yaml ]; then
echo "preprocessing ${RATINGS_PATH} and save to disk"
t0=$(date +%s)
python convert.py --path ${RATINGS_PATH} --output ${CACHED_DATADIR}
@ -84,7 +93,7 @@ else
echo 'Using cached preprocessed data'
fi
echo "Dataset $DATASET_NAME successfully prepared at: $CACHED_DATADIR\n"
echo "Dataset $DATASET_NAME successfully prepared at: $CACHED_DATADIR"
echo "You can now run the training with: python -m torch.distributed.launch --nproc_per_node=<number_of_GPUs> --use_env ncf.py --data ${CACHED_DATADIR}"

View file

@ -0,0 +1,70 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import matplotlib.pyplot as plt
def get_training_data(filename):
with open(filename, 'r') as opened:
line = opened.readlines()[-1]
json_content = line[len("DLLL "):]
data = json.loads(json_content)["data"]
with open(filename, 'r') as opened:
for line in opened.readlines():
d = json.loads(line[len("DLLL "):])
if d.get("step", "") == "PARAMETER":
data['batch_size'] = d["data"]["batch_size"]
return data
a100 = "runs/pytorch_ncf_A100-SXM4-40GBx{numgpus}gpus_{precision}_{num_run}.json"
v16 = "runs/pytorch_ncf_Tesla V100-SXM2-16GBx{numgpus}gpus_{precision}_{num_run}.json"
v32 = "runs/pytorch_ncf_Tesla V100-SXM2-32GBx{numgpus}gpus_{precision}_{num_run}.json"
dgx2 = "runs/pytorch_ncf_Tesla V100-SXM3-32GBx{numgpus}gpus_{precision}_{num_run}.json"
fp32 = "FP32"
amp = "Mixed (AMP)"
tf32 = "TF32"
def get_accs(arch, numgpu, prec):
data = [get_training_data(arch.format(numgpus=numgpu, num_run=num_run, precision=prec)) for num_run in range(1, 21)]
accs = [d["best_accuracy"] for d in data]
return accs
def get_plots():
archs = [dgx2, a100]
gpuranges = [(1, 8, 16), (1, 8)]
titles = ["DGX2 32GB", "DGX A100 40GB"]
fullprecs = [fp32, tf32]
fig, axs = plt.subplots(2, 3, sharey=True, figsize=(8, 8))
plt.subplots_adjust(hspace=0.5)
for x, arch in enumerate(archs):
gpurange = gpuranges[x]
for y, gpu in enumerate(gpurange):
f_data = get_accs(arch, gpu, fullprecs[x])
h_data = get_accs(arch, gpu, amp)
axs[x, y].boxplot([f_data, h_data])
axs[x, y].set_xticklabels([fullprecs[x], amp])
axs[x, y].set_title(f"{gpu} GPUs" if gpu > 1 else "1 GPU")
axs[x, 0].set_ylabel(titles[x])
fig.delaxes(axs[1, 2])
# plt.show()
plt.savefig("box_plots.png")
if __name__ == "__main__":
get_plots()

View file

@ -0,0 +1,113 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tabulate
import numpy as np
def get_training_data(filename):
with open(filename, 'r') as opened:
line = opened.readlines()[-1]
json_content = line[len("DLLL "):]
data = json.loads(json_content)["data"]
with open(filename, 'r') as opened:
for line in opened.readlines():
d = json.loads(line[len("DLLL "):])
if d.get("step", "") == "PARAMETER":
data['batch_size'] = d["data"]["batch_size"]
return data
a100 = "runs/pytorch_ncf_A100-SXM4-40GBx{numgpus}gpus_{precision}_{num_run}.json"
v16 = "runs/pytorch_ncf_Tesla V100-SXM2-16GBx{numgpus}gpus_{precision}_{num_run}.json"
v32 = "runs/pytorch_ncf_Tesla V100-SXM2-32GBx{numgpus}gpus_{precision}_{num_run}.json"
dgx2 = "runs/pytorch_ncf_Tesla V100-SXM3-32GBx{numgpus}gpus_{precision}_{num_run}.json"
fp32 = "FP32"
amp = "Mixed (AMP)"
tf32 = "TF32"
first = a100.format(numgpus=1, precision=fp32, num_run=1)
timevar = 'time_to_target' #"time_to_best_model"
def get_acc_table(arch, numgpus, fullprec):
headers = ["GPUs", "Batch size / GPU", f"Accuracy - {fullprec}", "Accuracy - mixed precision", f"Time to train - {fullprec}", "Time to train - mixed precision", f"Time to train speedup ({fullprec} to mixed precision)"]
table = []
for numgpus in numgpus:
data_full = [get_training_data(arch.format(numgpus=numgpus, num_run=num_run, precision=fullprec)) for num_run in range(1, 21)]
data_mixed = [get_training_data(arch.format(numgpus=numgpus, num_run=num_run, precision=amp)) for num_run in range(1, 21)]
bsize = data_full[0]['batch_size']/numgpus
accs_full = np.mean([d["best_accuracy"] for d in data_full])
accs_mixed = np.mean([d["best_accuracy"] for d in data_mixed])
time_full = np.mean([d[timevar] for d in data_full])
time_mixed = np.mean([d[timevar] for d in data_mixed])
speedup = time_full / time_mixed
row = [numgpus, int(bsize),
"{:.6f}".format(accs_full),
"{:.6f}".format(accs_mixed),
"{:.6f}".format(time_full),
"{:.6f}".format(time_mixed),
"{:.2f}".format(speedup)]
table.append(row)
print(tabulate.tabulate(table, headers, tablefmt='pipe'))
def get_perf_table(arch, numgpus, fullprec):
headers = ["GPUs",
"Batch size / GPU",
f"Throughput - {fullprec} (samples/s)",
"Throughput - mixed precision (samples/s)",
f"Throughput speedup ({fullprec} to mixed precision)",
f"Strong scaling - {fullprec}",
"Strong scaling - mixed precision",
]
table = []
base_full = None
base_mixed = None
for numgpus in numgpus:
data_full = [get_training_data(arch.format(numgpus=numgpus, num_run=num_run, precision=fullprec)) for num_run in range(1, 21)]
data_mixed = [get_training_data(arch.format(numgpus=numgpus, num_run=num_run, precision=amp)) for num_run in range(1, 21)]
bsize = data_full[0]['batch_size']/numgpus
_full = np.mean([d["best_train_throughput"] for d in data_full])
_mixed = np.mean([d["best_train_throughput"] for d in data_mixed])
if numgpus == 1:
base_full = _full
base_mixed = _mixed
scaling_full = _full/ base_full
scaling_mixed = _mixed / base_mixed
time_mixed = np.mean([d[timevar] for d in data_mixed])
speedup = _full / _mixed
row = [numgpus, int(bsize),
"{:.2f}M".format(_full / 10**6),
"{:.2f}M".format(_mixed / 10**6),
"{:.2f}".format(speedup),
"{:.2f}".format(scaling_full),
"{:.2f}".format(scaling_mixed)]
table.append(row)
print(tabulate.tabulate(table, headers, tablefmt='pipe'))
#get_acc_table(a100, (1, 8), tf32)
#get_acc_table(v16, (1, 8), fp32)
#get_acc_table(v32, (1, 8), fp32)
#get_acc_table(dgx2, (1, 8, 16), fp32)
#get_perf_table(a100, (1, 8), tf32)
#get_perf_table(v16, (1, 8), fp32)
#get_perf_table(v32, (1, 8), fp32)
#get_perf_table(dgx2, (1, 8, 16), fp32)

View file

@ -0,0 +1,66 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import matplotlib.pyplot as plt
def get_curve(filename):
hrs = []
with open(filename, 'r') as opened:
for line in opened.readlines():
d = json.loads(line[len("DLLL "):])
try:
hrs.append(d["data"]["hr@10"])
except KeyError:
pass
return hrs
a100 = "runs/pytorch_ncf_A100-SXM4-40GBx{numgpus}gpus_{precision}_{num_run}.json"
v16 = "runs/pytorch_ncf_Tesla V100-SXM2-16GBx{numgpus}gpus_{precision}_{num_run}.json"
v32 = "runs/pytorch_ncf_Tesla V100-SXM2-32GBx{numgpus}gpus_{precision}_{num_run}.json"
dgx2 = "runs/pytorch_ncf_Tesla V100-SXM3-32GBx{numgpus}gpus_{precision}_{num_run}.json"
fp32 = "FP32"
amp = "Mixed (AMP)"
tf32 = "TF32"
def get_accs(arch, numgpu, prec):
data = [get_curve(arch.format(numgpus=numgpu, num_run=num_run, precision=prec)) for num_run in range(1, 21)]
return data[0]
def get_plots():
archs = [dgx2, a100]
titles = ["DGX2 32GB", "DGX A100 40GB"]
fullprecs = [fp32, tf32]
halfprecs = [amp, amp]
gpuranges = [(1, 8, 16), (1, 8)]
fig, axs = plt.subplots(1, 2, sharey=True, figsize=(10, 5))
plt.subplots_adjust(hspace=0.5)
for x, prec in enumerate([fullprecs, halfprecs]):
for i, arch in enumerate(archs):
for numgpu in gpuranges[i]:
d = get_accs(arch, numgpu, prec[i])
axs[x].plot(range(len(d)), d, label=f"{titles[i]} x {numgpu} {prec[i]}")
axs[x].legend()
#plt.show()
plt.savefig("val_curves.png")
if __name__ == "__main__":
get_plots()

View file

@ -0,0 +1,44 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tabulate
archs = ["a100", "v100"]
precs = ["full", "half"]
for arch in archs:
for prec in precs:
filename = f"inference/{arch}_{prec}.log"
with open(filename) as opened:
line = opened.readlines()[-1]
log = json.loads(line[len("DLLL "):])['data']
print(log)
batch_sizes = [1024, 4096, 16384, 65536, 262144, 1048576]
t_avg = "batch_{}_mean_throughput"
l_mean = "batch_{}_mean_latency"
l_90 = "batch_{}_p90_latency"
l_95 = "batch_{}_p95_latency"
l_99 = "batch_{}_p99_latency"
headers = ["Batch size", "Throughput Avg", "Latency Avg", "Latency 90%", "Latency 95%", "Latency 99%"]
table = []
for bsize in batch_sizes:
table.append([bsize,
"{:3.3f}".format(log[t_avg.format(bsize)]),
"{:.6f}".format(log[l_mean.format(bsize)]),
"{:.6f}".format(log[l_90.format(bsize)]),
"{:.6f}".format(log[l_95.format(bsize)]),
"{:.6f}".format(log[l_99.format(bsize)])])
print(filename)
print(tabulate.tabulate(table, headers, tablefmt='pipe'))

View file

@ -1,3 +1,4 @@
pandas
pandas>=0.24.2
tqdm
git+https://github.com/NVIDIA/dllogger#egg=dllogger
pyyaml
git+https://github.com/NVIDIA/dllogger#egg=dllogger

View file

@ -0,0 +1,42 @@
#!/bin/bash
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
set -x
for test_name in more_pos less_pos less_user less_item more_user more_item other_names;
do
CACHED_DATADIR='/data/cache/ml-20m'
NEW_DIR=${CACHED_DATADIR}/${test_name}
echo "Trying to run on modified dataset: $test_name"
python -m torch.distributed.launch --nproc_per_node=1 --use_env ncf.py --data ${NEW_DIR} --epochs 1
echo "Model runs on modified dataset: $test_name"
done
for test_sample in '0' '10' '200';
do
CACHED_DATADIR='/data/cache/ml-20m'
NEW_DIR=${CACHED_DATADIR}/sample_${test_name}
echo "Trying to run on dataset with test sampling: $test_sample"
python -m torch.distributed.launch --nproc_per_node=1 --use_env ncf.py --data ${NEW_DIR} --epochs 1
echo "Model runs on dataset with test sampling: $test_sample"
done
for online_sample in '0' '1' '10';
do
CACHED_DATADIR='/data/cache/ml-20m'
echo "Trying to run with train sampling: $online_sample"
python -m torch.distributed.launch --nproc_per_node=1 --use_env ncf.py --data ${CACHED_DATADIR} --epochs 1 -n ${online_sample}
echo "Model runs with train sampling: $online_sample"
done

View file

@ -0,0 +1,103 @@
# Copyright (c) 2018, deepakn94, robieta. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
set -e
set -x
DATASET_NAME=${1:-'ml-20m'}
RAW_DATADIR=${2:-"/data/${DATASET_NAME}"}
CACHED_DATADIR=${3:-"$/data/cache/${DATASET_NAME}"}
# you can add another option to this case in order to support other datasets
case ${DATASET_NAME} in
'ml-20m')
ZIP_PATH=${RAW_DATADIR}/'ml-20m.zip'
RATINGS_PATH=${RAW_DATADIR}'/ml-20m/ratings.csv'
;;
'ml-1m')
ZIP_PATH=${RAW_DATADIR}/'ml-1m.zip'
RATINGS_PATH=${RAW_DATADIR}'/ml-1m/ratings.dat'
;;
*)
echo "Unsupported dataset name: $DATASET_NAME"
exit 1
esac
if [ ! -d ${RAW_DATADIR} ]; then
mkdir -p ${RAW_DATADIR}
fi
if [ ! -d ${CACHED_DATADIR} ]; then
mkdir -p ${CACHED_DATADIR}
fi
if [ -f log ]; then
rm -f log
fi
if [ ! -f ${ZIP_PATH} ]; then
echo "Dataset not found. Please download it from: https://grouplens.org/datasets/movielens/20m/ and put it in ${ZIP_PATH}"
exit 1
fi
if [ ! -f ${RATINGS_PATH} ]; then
unzip -u ${ZIP_PATH} -d ${RAW_DATADIR}
fi
for test_name in more_pos less_pos less_user less_item more_user more_item other_names;
do
NEW_DIR=${CACHED_DATADIR}/${test_name}
if [ ! -d ${NEW_DIR} ]; then
mkdir -p ${NEW_DIR}
fi
python convert_test.py --path ${RATINGS_PATH} --output $NEW_DIR --test ${test_name}
echo "Generated testing for $test_name"
done
for test_sample in '0' '10' '200';
do
NEW_DIR=${CACHED_DATADIR}/sample_${test_name}
if [ ! -d ${NEW_DIR} ]; then
mkdir -p ${NEW_DIR}
fi
python convert_test.py --path ${RATINGS_PATH} --output $NEW_DIR --valid_negative $test_sample
echo "Generated testing for $test_name"
done
echo "Dataset $DATASET_NAME successfully prepared at: $CACHED_DATADIR"
echo "You can now run the training with: python -m torch.distributed.launch --nproc_per_node=<number_of_GPUs> --use_env ncf.py --data ${CACHED_DATADIR}"

View file

@ -0,0 +1,90 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from feature_spec import FeatureSpec
from neumf_constants import TEST_SAMPLES_PER_SERIES
from dataloading import TorchTensorDataset
import torch
import os
import sys
def test_matches_template(path, template_path):
loaded_featurespec_string = FeatureSpec.from_yaml(path).to_string()
loaded_template_string = FeatureSpec.from_yaml(template_path).to_string()
assert loaded_template_string == loaded_featurespec_string
def mock_args():
class Obj:
pass
args = Obj()
args.__dict__['local_rank'] = 0
return args
def test_dtypes(path):
loaded_featurespec = FeatureSpec.from_yaml(path)
features = loaded_featurespec.feature_spec
declared_dtypes = {name: data['dtype'] for name, data in features.items()}
source_spec = loaded_featurespec.source_spec
for mapping in source_spec.values():
for chunk in mapping:
chunk_dtype = None
for present_feature in chunk['features']:
assert present_feature in declared_dtypes, "unknown feature in mapping"
# Check declared type
feature_dtype = declared_dtypes[present_feature]
if chunk_dtype is None:
chunk_dtype = feature_dtype
else:
assert chunk_dtype == feature_dtype
path_to_load = os.path.join(loaded_featurespec.base_directory, chunk['files'][0])
loaded_data = torch.load(path_to_load)
assert str(loaded_data.dtype) == chunk_dtype
def test_cardinalities(path):
loaded_featurespec = FeatureSpec.from_yaml(path)
features = loaded_featurespec.feature_spec
declared_cardinalities = {name: data['cardinality'] for name, data in features.items() if 'cardinality' in data}
source_spec = loaded_featurespec.source_spec
for mapping_name, mapping in source_spec.items():
dataset = TorchTensorDataset(loaded_featurespec, mapping_name, mock_args())
for feature_name, cardinality in declared_cardinalities.items():
feature_data = dataset.features[feature_name]
biggest_num = feature_data.max().item()
assert biggest_num < cardinality
def test_samples_in_test_series(path):
loaded_featurespec = FeatureSpec.from_yaml(path)
series_length = loaded_featurespec.metadata[TEST_SAMPLES_PER_SERIES]
dataset = TorchTensorDataset(loaded_featurespec, 'test', mock_args())
for feature in dataset.features.values():
assert len(feature) % series_length == 0
if __name__ == '__main__':
tested_spec = sys.argv[1]
template = sys.argv[2]
test_cardinalities(tested_spec)
test_dtypes(tested_spec)
test_samples_in_test_series(tested_spec)
test_matches_template(tested_spec, template)

View file

@ -0,0 +1,127 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
import os
import torch
import pandas as pd
from feature_spec import FeatureSpec
from neumf_constants import USER_CHANNEL_NAME, ITEM_CHANNEL_NAME, LABEL_CHANNEL_NAME
def parse_args():
parser = ArgumentParser()
parser.add_argument('--path', type=str, default='',
help='Path to input data directory')
parser.add_argument('--feature_spec_in', type=str, default='feature_spec.yaml',
help='Name of the input feature specification file, or path relative to data directory.')
parser.add_argument('--output', type=str, default='/data',
help='Path to output data directory')
parser.add_argument('--feature_spec_out', type=str, default='feature_spec.yaml',
help='Name of the output feature specification file, or path relative to data directory.')
return parser.parse_args()
def main():
args = parse_args()
args_output = args.output
args_path = args.path
args_feature_spec_in = args.feature_spec_in
args_feature_spec_out = args.feature_spec_out
feature_spec_path = os.path.join(args_path, args_feature_spec_in)
feature_spec = FeatureSpec.from_yaml(feature_spec_path)
# Only three features are transcoded - this is NCF specific
user_feature_name = feature_spec.channel_spec[USER_CHANNEL_NAME][0]
item_feature_name = feature_spec.channel_spec[ITEM_CHANNEL_NAME][0]
label_feature_name = feature_spec.channel_spec[LABEL_CHANNEL_NAME][0]
categorical_features = [user_feature_name, item_feature_name]
found_cardinalities = {f: 0 for f in categorical_features}
new_source_spec = {}
for mapping_name, mapping in feature_spec.source_spec.items():
# Load all chunks and link into one df
chunk_dfs = []
for chunk in mapping:
assert chunk['type'] == 'csv', "Only csv files supported in this transcoder"
file_dfs = []
for file in chunk['files']:
path_to_load = os.path.join(feature_spec.base_directory, file)
file_dfs.append(pd.read_csv(path_to_load, header=None))
chunk_df = pd.concat(file_dfs, ignore_index=True)
chunk_df.columns = chunk['features']
chunk_df.reset_index(drop=True, inplace=True)
chunk_dfs.append(chunk_df)
mapping_df = pd.concat(chunk_dfs, axis=1) # This takes care of making sure feature names are unique
for feature in categorical_features:
mapping_cardinality = mapping_df[feature].max() + 1
previous_cardinality = found_cardinalities[feature]
found_cardinalities[feature] = max(previous_cardinality, mapping_cardinality)
# We group together users and items, while separating labels. This is because of the target dtypes: ids are int,
# while labels are float to compute loss.
ints_tensor = torch.from_numpy(mapping_df[[user_feature_name, item_feature_name]].values).long()
ints_file = f"{mapping_name}_data_0.pt"
ints_chunk = {"type": "torch_tensor",
"features": [user_feature_name, item_feature_name],
"files": [ints_file]}
torch.save(ints_tensor, os.path.join(args_output, ints_file))
floats_tensor = torch.from_numpy(mapping_df[[label_feature_name]].values).float()
floats_file = f"{mapping_name}_data_1.pt"
floats_chunk = {"type": "torch_tensor",
"features": [label_feature_name],
"files": [floats_file]}
torch.save(floats_tensor, os.path.join(args_output, floats_file))
new_source_spec[mapping_name] = [ints_chunk, floats_chunk]
for feature in categorical_features:
found_cardinality = found_cardinalities[feature]
declared_cardinality = feature_spec.feature_spec[feature].get('cardinality', 'auto')
if declared_cardinality != "auto":
declared = int(declared_cardinality)
assert declared >= found_cardinality, "Specified cardinality conflicts data"
found_cardinalities[feature] = declared
new_inner_feature_spec = {
user_feature_name: {
"dtype": "torch.int64",
"cardinality": int(found_cardinalities[user_feature_name])
},
item_feature_name: {
"dtype": "torch.int64",
"cardinality": int(found_cardinalities[item_feature_name])
},
label_feature_name: {
"dtype": "torch.float32"
}
}
new_feature_spec = FeatureSpec(feature_spec=new_inner_feature_spec,
source_spec=new_source_spec,
channel_spec=feature_spec.channel_spec,
metadata=feature_spec.metadata,
base_directory="")
feature_spec_save_path = os.path.join(args_output, args_feature_spec_out)
new_feature_spec.to_yaml(output_path=feature_spec_save_path)
if __name__ == '__main__':
main()

View file

@ -13,9 +13,8 @@
# limitations under the License.
import os
import json
from functools import reduce
import time
def count_parameters(model):
c = map(lambda p: reduce(lambda x, y: x * y, p.size()), model.parameters())