# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

# Root directory from which the default dataset and checkpoint paths of the
# command-line options (train/eval data patterns, model_dir) are derived.
DEFAULT_DIR: str = "/outbrain"


def parse_args(args=None):
    """Build the Wide & Deep command-line parser and parse the arguments.

    Options are organised into four groups: dataset/output locations,
    training parameters, model construction, and run-mode parameters.
    Because of ``ArgumentDefaultsHelpFormatter``, ``--help`` shows every
    option's default value.

    Args:
        args: Optional sequence of argument strings to parse. With ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as the
            original zero-argument version did. Passing an explicit list lets
            callers such as tests or notebooks parse arguments without
            touching ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Raises:
        SystemExit: raised by argparse after printing help for ``--help``,
            or after printing an error for invalid input (for example a
            non-integer ``--global_batch_size`` or an ``--affinity`` value
            outside its allowed choices).
    """
    parser = argparse.ArgumentParser(
        description="Tensorflow2 WideAndDeep Model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=True,
    )

    # ---- Dataset, checkpoint and output locations -------------------------
    locations = parser.add_argument_group("location of datasets")

    locations.add_argument(
        "--train_data_pattern",
        type=str,
        default=f"{DEFAULT_DIR}/data/train/*.parquet",
        help="Pattern of training file names. For example if training files are "
        "part_0.parquet, part_1.parquet then --train_data_pattern is *.parquet",
    )

    locations.add_argument(
        "--eval_data_pattern",
        type=str,
        default=f"{DEFAULT_DIR}/data/valid/*.parquet",
        help="Pattern of eval file names. For example if eval files are "
        "part_0.parquet, part_1.parquet then --eval_data_pattern is *.parquet",
    )

    # store_true flags already default to False, so no explicit default is set.
    locations.add_argument(
        "--use_checkpoint",
        action="store_true",
        help="Use checkpoint stored in model_dir path",
    )

    locations.add_argument(
        "--model_dir",
        type=str,
        default=f"{DEFAULT_DIR}/checkpoints",
        help="Destination where model checkpoint will be saved",
    )

    locations.add_argument(
        "--results_dir",
        type=str,
        default="/results",
        help="Directory to store training results",
    )

    locations.add_argument(
        "--log_filename",
        type=str,
        default="log.json",
        help="Name of the file to store dllogger output",
    )

    # ---- Optimisation / training loop -------------------------------------
    training_params = parser.add_argument_group("training parameters")

    training_params.add_argument(
        "--global_batch_size",
        type=int,
        default=131072,
        help="Total size of training batch",
    )

    training_params.add_argument(
        "--eval_batch_size",
        type=int,
        default=131072,
        help="Total size of evaluation batch",
    )

    training_params.add_argument(
        "--num_epochs", type=int, default=20, help="Number of training epochs"
    )

    training_params.add_argument(
        "--cpu", action="store_true", help="Run computations on the CPU"
    )

    training_params.add_argument(
        "--amp",
        action="store_true",
        help="Enable automatic mixed precision conversion",
    )

    training_params.add_argument(
        "--xla", action="store_true", help="Enable XLA conversion"
    )

    training_params.add_argument(
        "--linear_learning_rate",
        type=float,
        default=0.02,
        help="Learning rate for linear model",
    )

    training_params.add_argument(
        "--deep_learning_rate",
        type=float,
        default=0.00012,
        help="Learning rate for deep model",
    )

    # argparse applies `type` only to string defaults, so the default is
    # written as a float here. The attribute is then a float whether or not
    # the flag is passed on the command line.
    training_params.add_argument(
        "--deep_warmup_epochs",
        type=float,
        default=6.0,
        help="Number of learning rate warmup epochs for deep model",
    )

    # ---- Network architecture ---------------------------------------------
    model_construction = parser.add_argument_group("model construction")

    # The parser is rebuilt on every call, so this list literal is a fresh
    # object each time. A caller mutating the returned list cannot leak into
    # a later parse.
    model_construction.add_argument(
        "--deep_hidden_units",
        type=int,
        default=[1024, 1024, 1024, 1024, 1024],
        nargs="+",
        help="Hidden units per layer for deep model, separated by spaces",
    )

    model_construction.add_argument(
        "--deep_dropout",
        type=float,
        default=0.1,
        help="Dropout regularization for deep model",
    )

    # ---- Run mode / benchmarking / CPU affinity ---------------------------
    run_params = parser.add_argument_group("run mode parameters")

    run_params.add_argument(
        "--evaluate",
        action="store_true",
        help="Only perform an evaluation on the validation dataset, don't train",
    )

    run_params.add_argument(
        "--benchmark",
        action="store_true",
        help="Run training or evaluation benchmark to collect performance metrics",
    )

    run_params.add_argument(
        "--benchmark_warmup_steps",
        type=int,
        default=500,
        help="Number of warmup steps before start of the benchmark",
    )

    run_params.add_argument(
        "--benchmark_steps",
        type=int,
        default=1000,
        help="Number of steps for performance benchmark",
    )

    run_params.add_argument(
        "--affinity",
        type=str,
        default="socket_unique_interleaved",
        choices=[
            "socket",
            "single",
            "single_unique",
            "socket_unique_interleaved",
            "socket_unique_continuous",
            "disabled",
        ],
        help="Type of CPU affinity",
    )

    # args=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(args)