70 lines
2.5 KiB
Python
70 lines
2.5 KiB
Python
# Copyright 2017-2018 The Apache Software Foundation
|
|
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
#
|
|
# -----------------------------------------------------------------------
|
|
#
|
|
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
import sys
|
|
import argparse
|
|
import logging
|
|
import mxnet as mx
|
|
import numpy as np
|
|
|
|
import data, dali
|
|
import fit
|
|
import models
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for ImageNet classification training.

    A single ``argparse.ArgumentParser`` is built and handed, in turn, to the
    ``add_*_args`` helpers of the ``models``, ``fit``, ``data`` and ``dali``
    modules (presumably each registers its own option group on it).

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, i.e. the original
            behavior; passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Train classification models on ImageNet",
        # Append each option's default value to its --help text.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    models.add_model_args(parser)
    fit.add_fit_args(parser)
    data.add_data_args(parser)
    dali.add_dali_args(parser)
    data.add_data_aug_args(parser)
    return parser.parse_args(argv)
|
|
|
|
def setup_logging(args):
    """Configure root logging to stderr and, optionally, a log file.

    Uses ``logging.basicConfig`` with DEBUG level and the ``{``-style format
    ``'{asctime}:{levelname}: {message}'``, then logs the parsed arguments.

    Args:
        args: Parsed command-line namespace. ``args.log`` is the path of the
            log file; when it is empty or ``None`` only stderr is used
            (previously this raised from ``logging.FileHandler``).

    Note:
        ``logging.basicConfig`` is a no-op when the root logger already has
        handlers, so this should run before any other logging configuration.
    """
    head = '{asctime}:{levelname}: {message}'
    handlers = [logging.StreamHandler(sys.stderr)]
    # Only attach a file handler when a path was actually given; an empty
    # path would make FileHandler fail to open the file.
    if args.log:
        handlers.append(logging.FileHandler(args.log))
    logging.basicConfig(level=logging.DEBUG, format=head, style='{',
                        handlers=handlers)
    # Lazy %-style argument: message formatting is left to the logging
    # framework (style='{' above only affects the record format string).
    logging.info('Start with arguments %s', args)
|
|
|
|
if __name__ == '__main__':
    # Step 1: read the command line, then configure logging before any other
    # stage runs, so anything those stages log reaches the configured handlers.
    args = parse_args()
    setup_logging(args)

    # Step 2: build the network; the model factory receives every parsed
    # option as a keyword argument.
    model = models.get_model(**vars(args))

    # Step 3: build the input pipeline from the same Namespace.
    data_loader = data.get_data_loader(args)

    # Step 4: hand the options, network and data to the training routine.
    fit.fit(args, model, data_loader)
|