DeepLearningExamples/PyTorch/Classification/ConvNets/launch.py
2021-04-13 19:05:27 +02:00

54 lines
1.7 KiB
Python

import os
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Any
import yaml
from main import main, add_parser_arguments, available_models
import torch.backends.cudnn as cudnn
import argparse
def get_config_path() -> Path:
    """Return the default location of the YAML configuration file.

    The path is built from the directory that contains this script, so the
    result does not depend on the current working directory.

    Returns:
        Absolute path to ``configs.yml`` in the same directory as this file.
    """
    return Path(os.path.dirname(os.path.abspath(__file__))) / "configs.yml"
if __name__ == "__main__":
    # Stage 1: parse only the options that select which sections of the YAML
    # config to use. add_help=False leaves -h/--help in the remaining argv so
    # it is handled by the full training parser built further down.
    yaml_cfg_parser = argparse.ArgumentParser(add_help=False)
    yaml_cfg_parser.add_argument(
        "--cfg_file",
        default=get_config_path(),
        type=str,
        help="path to yaml config file",
    )
    yaml_cfg_parser.add_argument("--model", default=None, type=str, required=True)
    yaml_cfg_parser.add_argument("--mode", default=None, type=str, required=True)
    yaml_cfg_parser.add_argument("--precision", default=None, type=str, required=True)
    yaml_cfg_parser.add_argument("--platform", default=None, type=str, required=True)

    # Everything not recognised here is kept in `rest` for the later parsers.
    yaml_args, rest = yaml_cfg_parser.parse_known_args()

    # NOTE(review): yaml.FullLoader can construct more than plain data types.
    # If --cfg_file may ever point at an untrusted file, consider
    # yaml.safe_load instead.
    with open(yaml_args.cfg_file, "r") as cfg_file:
        config = yaml.load(cfg_file, Loader=yaml.FullLoader)

    # Merge the selected config sections into one flat mapping. On duplicate
    # keys the later entry wins: precision < platform < model < mode.
    # An unknown --precision/--platform/--model/--mode raises KeyError here.
    cfg = {
        **config["precision"][yaml_args.precision],
        **config["platform"][yaml_args.platform],
        **config["models"][yaml_args.model][yaml_args.platform][yaml_args.precision],
        **config["mode"][yaml_args.mode],
    }

    # Stage 2: the YAML values are installed as parser *defaults*, so any
    # option given explicitly on the command line still overrides them.
    parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
    add_parser_arguments(parser)
    parser.set_defaults(**cfg)
    args, rest = parser.parse_known_args(rest)

    # Stage 3: the chosen architecture supplies its own parser for the
    # arguments that are still unconsumed.
    model_arch = available_models()[args.arch]
    model_args, rest = model_arch.parser().parse_known_args(rest)

    # Reject leftovers explicitly. An `assert` would be stripped under
    # `python -O`, letting mistyped flags pass silently.
    if rest:
        parser.error(f"Unknown args passed: {rest}")

    # Enable the cuDNN benchmark flag before starting the run.
    cudnn.benchmark = True

    main(args, model_args, model_arch)