diff --git a/TensorFlow/Classification/ConvNets/export_frozen_graph.py b/TensorFlow/Classification/ConvNets/export_frozen_graph.py
index 2f817068..64c9ffd3 100644
--- a/TensorFlow/Classification/ConvNets/export_frozen_graph.py
+++ b/TensorFlow/Classification/ConvNets/export_frozen_graph.py
@@ -22,7 +22,7 @@ import os
 import tensorflow as tf
 
-import horovod.tensorflow as hvd
+from utils import hvd_wrapper as hvd
 
 from model import resnet
 
 tf.app.flags.DEFINE_string(
@@ -75,8 +75,6 @@ FLAGS = tf.app.flags.FLAGS
 
 
 def main(_):
-
-    # Initialize Horovod (TODO: Remove dependency of horovod for freezing graphs)
     hvd.init()
 
     if not FLAGS.output_file:
diff --git a/TensorFlow/Classification/ConvNets/inference.py b/TensorFlow/Classification/ConvNets/inference.py
new file mode 100644
index 00000000..8007ef08
--- /dev/null
+++ b/TensorFlow/Classification/ConvNets/inference.py
@@ -0,0 +1,134 @@
+import argparse
+import os
+import pathlib
+import time
+import tempfile
+
+import tensorflow as tf
+import numpy as np
+
+from tensorflow.python.compiler.tensorrt import trt_convert as trt
+
+import dllogger
+
+from runtime import runner_utils
+from runtime import runner
+from model.resnet import model_architectures
+from utils import data_utils
+from utils import hvd_wrapper as hvd
+
+OUTPUT_SAVED_MODEL_PATH = tempfile.mkdtemp(prefix="tftrt-converted")
+LOG_FREQUENCY = 100
+
+
+def argument_parser() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+
+    exclusive_args = parser.add_mutually_exclusive_group()
+    exclusive_args.add_argument("--model", type=str, default=None, help="Saved model location to use for inference")
+    exclusive_args.add_argument("--architecture", type=str, choices=model_architectures.keys())
+
+    parser.add_argument("--log-path", type=str, default="./log.json", help="Path to log file")
+    parser.add_argument("--tf-trt", action="store_true", default=False, help="Use TF-TRT for inference")
+    parser.add_argument("--amp", action="store_true", default=False, help="Use AMP for inference")
+    parser.add_argument("--data-dir", type=str, required=False,
+                        default=None, help="Location of validation data")
+    parser.add_argument("--batch-size", type=int, default=1, help="Batch size for inference")
+
+    return parser.parse_args()
+
+
+def main(args: argparse.Namespace):
+    hvd.init()
+
+    dllogger.init(backends=[
+        dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE, filename=args.log_path),
+        dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)
+    ])
+    dllogger.log(data=vars(args), step='PARAMETER')
+
+    if args.model is None:
+        saved_model_to_load = tempfile.mkdtemp(prefix="tftrt-savedmodel")
+        r = runner.Runner(n_classes=1001, architecture=args.architecture, use_tf_amp=args.amp,
+                          model_dir=saved_model_to_load)
+        r.train("batch", 1, 1, args.batch_size, is_benchmark=True)
+        r.evaluate("batch", 1, args.batch_size, export_dir=saved_model_to_load,
+                   is_benchmark=True)
+
+        saved_model_to_load = r.exported_path.decode("utf-8")
+    else:
+        saved_model_to_load = args.model
+
+    output_tensor_name = "y_preds_ref:0" if not args.tf_trt else "ArgMax:0"
+    batch_size = args.batch_size
+
+    if args.tf_trt:
+        converter = trt.TrtGraphConverter(input_saved_model_dir=str(saved_model_to_load),
+                                          precision_mode="FP16" if args.amp else "FP32")
+        converter.convert()
+        converter.save(OUTPUT_SAVED_MODEL_PATH)
+        saved_model_to_load = OUTPUT_SAVED_MODEL_PATH
+    elif args.amp:
+        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"
+
+    if args.data_dir is not None:
+        filenames, _, num_steps, _, _ = runner_utils.parse_tfrecords_dataset(
+            data_dir=str(args.data_dir),
+            mode="validation",
+            iter_unit="epoch",
+            num_iter=1,
+            global_batch_size=batch_size,
+        )
+
+        dataset = data_utils.get_tfrecords_input_fn(filenames=filenames,
+                                                    batch_size=batch_size,
+                                                    height=224,
+                                                    width=224,
+                                                    training=False,
+                                                    distort_color=False,
+                                                    num_threads=1,
+                                                    deterministic=True)
+        iterator = dataset.make_initializable_iterator()
+        next_item = iterator.get_next()
+    else:
+        num_steps = 60000 / batch_size
+
+    with tf.Session() as sess:
+        if args.data_dir is not None:
+            sess.run(iterator.initializer)
+        tf.saved_model.loader.load(sess,
+                                   [tf.saved_model.tag_constants.SERVING],
+                                   str(saved_model_to_load))
+
+        try:
+            start_time = time.time()
+            last_time = start_time
+            image_processed = 0
+            image_correct = 0
+
+            for samples_processed in range(int(num_steps)):
+                if args.data_dir is not None:
+                    next_batch_image, next_batch_target = sess.run(next_item)
+                else:
+                    if samples_processed == 0:
+                        next_batch_image = np.random.normal(size=(batch_size, 224, 224, 3))
+                        next_batch_target = np.random.randint(0, 1000, size=(batch_size,))
+                output = sess.run([output_tensor_name], feed_dict={"input_tensor:0": next_batch_image})
+                image_processed += args.batch_size
+                image_correct += np.sum(output == next_batch_target)
+
+                if samples_processed % LOG_FREQUENCY == 0 and samples_processed != 0:
+                    current_time = time.time()
+                    current_throughput = LOG_FREQUENCY * batch_size / (current_time - last_time)
+                    dllogger.log(step=(0, samples_processed), data={"throughput": current_throughput})
+                    last_time = current_time
+
+        except tf.errors.OutOfRangeError:
+            pass
+        finally:
+            dllogger.log(step=tuple(),
+                         data={"throughput": image_processed / (last_time - start_time),
+                               "accuracy": image_correct / image_processed})
+
+
+if __name__ == "__main__":
+    main(argument_parser())
\ No newline at end of file
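A note on the TF-TRT path in `inference.py` above: the SavedModel is converted offline, before the benchmark loop runs, using the TF 1.x `TrtGraphConverter` API. A minimal standalone sketch of that conversion step (both directory paths here are hypothetical placeholders):

```python
# Minimal sketch of the TF-TRT conversion step used by inference.py above.
# Both directory paths are hypothetical placeholders.
from tensorflow.python.compiler.tensorrt import trt_convert as trt

converter = trt.TrtGraphConverter(
    input_saved_model_dir="/models/rn50_savedmodel",  # exported SavedModel
    precision_mode="FP16")                            # "FP32" when --amp is not set
converter.convert()                                   # build the TensorRT-optimized graph
converter.save("/models/rn50_tftrt")                  # write the converted SavedModel
```

The conversion produces a rewritten graph, which is why `inference.py` switches `output_tensor_name` from `y_preds_ref:0` to `ArgMax:0` when `--tf-trt` is set.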
diff --git a/TensorFlow/Classification/ConvNets/main.py b/TensorFlow/Classification/ConvNets/main.py
index ae0abc33..916b0602 100755
--- a/TensorFlow/Classification/ConvNets/main.py
+++ b/TensorFlow/Classification/ConvNets/main.py
@@ -22,10 +22,9 @@ warnings.simplefilter("ignore")
 import tensorflow as tf
 
-import horovod.tensorflow as hvd
+from utils import hvd_wrapper as hvd
 
 import dllogger
 
-from utils import hvd_utils
 from runtime import Runner
 
 from model.resnet import model_architectures
@@ -36,7 +35,7 @@ if __name__ == "__main__":
     tf.logging.set_verbosity(tf.logging.ERROR)
 
     FLAGS = parse_cmdline(model_architectures.keys())
-    hvd.init()
+    hvd.init(True)
 
     if hvd.rank() == 0:
         log_path = os.path.join(FLAGS.results_dir, FLAGS.log_filename)
@@ -100,11 +99,10 @@ if __name__ == "__main__":
 
     if FLAGS.mode in ["train_and_evaluate", 'evaluate', 'inference_benchmark']:
 
-        if FLAGS.mode == 'inference_benchmark' and hvd_utils.is_using_hvd():
+        if FLAGS.mode == 'inference_benchmark' and hvd.size() > 1:
             raise NotImplementedError("Only single GPU inference is implemented.")
 
-        elif not hvd_utils.is_using_hvd() or hvd.rank() == 0:
-
+        elif hvd.rank() == 0:
             runner.evaluate(iter_unit=FLAGS.iter_unit if FLAGS.mode != "train_and_evaluate" else "epoch",
                             num_iter=FLAGS.num_iter if FLAGS.mode != "train_and_evaluate" else 1,
                             warmup_steps=FLAGS.warmup_steps,
@@ -124,10 +122,10 @@
         if not os.path.isfile(FLAGS.to_predict):
             raise ValueError("Only prediction on single images is supported!")
 
-        if hvd_utils.is_using_hvd():
+        if hvd.size() > 1:
             raise NotImplementedError("Only single GPU inference is implemented.")
 
-        elif not hvd_utils.is_using_hvd() or hvd.rank() == 0:
+        else:
             runner.predict(FLAGS.to_predict,
                            quantize=FLAGS.quantize,
                            symmetric=FLAGS.symmetric,
diff --git a/TensorFlow/Classification/ConvNets/model/blocks/conv2d_block.py b/TensorFlow/Classification/ConvNets/model/blocks/conv2d_block.py
index 00e73dfe..7f2f3f10 100644
--- a/TensorFlow/Classification/ConvNets/model/blocks/conv2d_block.py
+++ b/TensorFlow/Classification/ConvNets/model/blocks/conv2d_block.py
@@ -64,10 +64,10 @@ def conv2d_block(
             trainable=is_training,
             dtype=tf.float32)
         net = tf.nn.conv2d(inputs,
-            group_filter,
-            strides=strides,
-            padding='SAME',
-            data_format=data_format)
+                           group_filter,
+                           strides=strides,
+                           padding='SAME',
+                           data_format=data_format)
         if use_batch_norm:
             net = layers.batch_norm(
                 net,
diff --git a/TensorFlow/Classification/ConvNets/model/resnet.py b/TensorFlow/Classification/ConvNets/model/resnet.py
index 6ea2fb90..689ac1a5 100755
--- a/TensorFlow/Classification/ConvNets/model/resnet.py
+++ b/TensorFlow/Classification/ConvNets/model/resnet.py
@@ -19,15 +19,13 @@ from __future__ import print_function
 
 import tensorflow as tf
 
-import horovod.tensorflow as hvd
+from utils import hvd_wrapper as hvd
 import dllogger
 
 from model import layers
 from model import blocks
 
 from utils import var_storage
-from utils import hvd_utils
-
 from utils.data_utils import normalized_inputs
 
 from utils.learning_rate import learning_rate_scheduler
@@ -337,8 +335,8 @@ class ResnetModel(object):
             if params["apply_loss_scaling"]:
                 optimizer = FixedLossScalerOptimizer(optimizer, scale=params["loss_scale"])
 
-            if hvd_utils.is_using_hvd():
-                optimizer = hvd.DistributedOptimizer(optimizer)
+            if hvd.size() > 1:
+                optimizer = hvd.hvd_global_object.DistributedOptimizer(optimizer)
 
             update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
             if mode != tf.estimator.ModeKeys.TRAIN:
diff --git a/TensorFlow/Classification/ConvNets/resnet50v1.5/README.md b/TensorFlow/Classification/ConvNets/resnet50v1.5/README.md
index 5d3be4d4..a0a181e1 100644
--- a/TensorFlow/Classification/ConvNets/resnet50v1.5/README.md
+++ b/TensorFlow/Classification/ConvNets/resnet50v1.5/README.md
@@ -276,7 +276,6 @@ The `utils/` directory contains the following utility modules:
 - `cmdline_helper.py`: helper module for command line processing
 - `data_utils.py`: module defining input data pipelines
 - `dali_utils.py`: helper module for DALI
-- `hvd_utils.py`: helper module for Horovod
 - `image_processing.py`: image processing and data augmentation functions
 - `learning_rate.py`: definition of used learning rate schedule
 - `optimizers.py`: definition of used custom optimizers
@@ -447,7 +446,11 @@ To run inference on a single example with a checkpoint and a model script, use:
 
 `python main.py --mode predict --model_dir <path to model> --to_predict <path to image> --results_dir <path to results>`
 
-The optional `--xla` and `--amp` flags control XLA and AMP during inference.
+The optional `--xla` and `--amp` flags control XLA and AMP during inference. To run inference using TF-TRT, use the following command:
+
+`python inference.py --model <path to saved model> --tf-trt --batch-size <batch size> --data-dir <path to validation data>`
+
+The optional `--amp` flag controls AMP during inference.
 
 ## Performance
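The pattern introduced in `main.py` and `model/resnet.py` above recurs throughout the patch: boolean predicates move onto the wrapper (`hvd.size() > 1` replaces `hvd_utils.is_using_hvd()`, and `hvd.rank() == 0` gates rank-0-only work), while Horovod-only constructors are reached through `hvd.hvd_global_object`, which is bound to the real `horovod.tensorflow` module only after `init(True)`. A condensed sketch of the optimizer hunk; the momentum optimizer stands in for whatever params-driven optimizer the model function builds:

```python
# Condensed sketch of the DistributedOptimizer gating from model/resnet.py.
# MomentumOptimizer here is a stand-in, not the exact model configuration.
import tensorflow as tf
from utils import hvd_wrapper as hvd

optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
if hvd.size() > 1:
    # Only reachable after hvd.init(True), so hvd_global_object is the
    # real horovod.tensorflow module and DistributedOptimizer exists.
    optimizer = hvd.hvd_global_object.DistributedOptimizer(optimizer)
```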
diff --git a/TensorFlow/Classification/ConvNets/resnext101-32x4d/README.md b/TensorFlow/Classification/ConvNets/resnext101-32x4d/README.md
index cb7b7c65..3602f18d 100644
--- a/TensorFlow/Classification/ConvNets/resnext101-32x4d/README.md
+++ b/TensorFlow/Classification/ConvNets/resnext101-32x4d/README.md
@@ -283,7 +283,6 @@ The `utils/` directory contains the following utility modules:
 - `cmdline_helper.py`: helper module for command line processing
 - `data_utils.py`: module defining input data pipelines
 - `dali_utils.py`: helper module for DALI
-- `hvd_utils.py`: helper module for Horovod
 - `image_processing.py`: image processing and data augmentation functions
 - `learning_rate.py`: definition of used learning rate schedule
 - `optimizers.py`: definition of used custom optimizers
diff --git a/TensorFlow/Classification/ConvNets/runtime/runner.py b/TensorFlow/Classification/ConvNets/runtime/runner.py
index 1da8c011..7950b35e 100755
--- a/TensorFlow/Classification/ConvNets/runtime/runner.py
+++ b/TensorFlow/Classification/ConvNets/runtime/runner.py
@@ -21,13 +21,11 @@ import warnings
 import tensorflow as tf
 import numpy as np
 
-import horovod.tensorflow as hvd
-
 from model import resnet
 
 from utils import hooks
 from utils import data_utils
-from utils import hvd_utils
+from utils import hvd_wrapper as hvd
 
 from runtime import runner_utils
 
@@ -142,8 +140,8 @@ class Runner(object):
             gpu_id=gpu_id)
 
         run_config_additional = tf.contrib.training.HParams(
-            model_dir=model_dir, #if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None,
-            log_dir=log_dir if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None,
+            model_dir=model_dir,
+            log_dir=log_dir if hvd.rank() == 0 else None,
             data_dir=data_dir,
             data_idx_dir=data_idx_dir,
             num_preprocessing_threads=num_preprocessing_threads)
@@ -196,11 +194,7 @@ class Runner(object):
 
     @staticmethod
     def _get_global_batch_size(worker_batch_size):
-
-        if hvd_utils.is_using_hvd():
-            return worker_batch_size * hvd.size()
-        else:
-            return worker_batch_size
+        return worker_batch_size * hvd.size()
 
     @staticmethod
     def _get_session_config(mode, use_xla, use_dali, use_cpu, gpu_memory_fraction, gpu_id=0):
@@ -225,7 +219,7 @@ class Runner(object):
             config.gpu_options.visible_device_list = str(gpu_id)
             config.gpu_options.force_gpu_compatible = True  # Force pinned memory
 
-            if hvd_utils.is_using_hvd():
+            if hvd.size() > 1:
                 config.gpu_options.visible_device_list = str(hvd.local_rank())
                 config.gpu_options.force_gpu_compatible = True  # Force pinned memory
 
@@ -248,10 +242,7 @@ class Runner(object):
             mode)
 
         if seed is not None:
-            if hvd_utils.is_using_hvd():
-                tf_random_seed = 2 * (seed + hvd.rank())
-            else:
-                tf_random_seed = 2 * seed
+            tf_random_seed = 2 * (seed + hvd.rank())
         else:
             tf_random_seed = None
 
@@ -277,11 +268,8 @@ class Runner(object):
             experimental_distribute=None)
 
         if mode == 'train':
-            if hvd_utils.is_using_hvd():
-                config = config.replace(save_checkpoints_steps=1000 if hvd.rank() == 0 else None,
-                                        keep_checkpoint_every_n_hours=3)
-            else:
-                config = config.replace(save_checkpoints_steps=1000, keep_checkpoint_every_n_hours=3)
+            config = config.replace(save_checkpoints_steps=1000 if hvd.rank() == 0 else None,
+                                    keep_checkpoint_every_n_hours=3)
 
         return config
 
@@ -343,7 +331,7 @@ class Runner(object):
         else:
             use_static_loss_scaling = False  # Make sure it hasn't been set to True on FP32 training
 
-        num_gpus = 1 if not hvd_utils.is_using_hvd() else hvd.size()
+        num_gpus = hvd.size()
         global_batch_size = batch_size * num_gpus
 
         if self.run_hparams.data_dir is not None:
@@ -402,8 +390,8 @@ class Runner(object):
         )
         training_hooks.append(self.training_logging_hook)
 
-        if hvd_utils.is_using_hvd():
-            bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
+        if hvd.size() > 1:
+            bcast_hook = hvd.hvd_global_object.BroadcastGlobalVariablesHook(0)
             training_hooks.append(bcast_hook)
 
         training_hooks.append(hooks.PrefillStagingAreasHook())
@@ -527,7 +515,7 @@ class Runner(object):
         if self.run_hparams.data_dir is None and not is_benchmark:
             raise ValueError('`data_dir` must be specified for evaluation!')
 
-        if hvd_utils.is_using_hvd() and hvd.rank() != 0:
+        if hvd.rank() != 0:
             raise RuntimeError('Multi-GPU inference is not supported')
 
         estimator_params = {'quantize': quantize,
@@ -620,18 +608,25 @@ class Runner(object):
             )
 
             eval_throughput = self.eval_logging_hook.mean_throughput.value()
-            eval_latencies = np.array(self.eval_logging_hook.latencies) * 1000
-            eval_latencies_q = np.quantile(eval_latencies, q=[0.9, 0.95, 0.99])
-            eval_latencies_mean = np.mean(eval_latencies)
+            if len(self.eval_logging_hook.latencies) > 0:
+                eval_latencies = np.array(self.eval_logging_hook.latencies) * 1000
+                eval_latencies_q = np.quantile(eval_latencies, q=[0.9, 0.95, 0.99])
+                eval_latencies_mean = np.mean(eval_latencies)
+                additional_metrics = {
+                    'eval_latency_avg': eval_latencies_mean,
+                    'eval_latency_p90': eval_latencies_q[0],
+                    'eval_latency_p95': eval_latencies_q[1],
+                    'eval_latency_p99': eval_latencies_q[2],
+                }
+            else:
+                additional_metrics = {}
+
             dllogger.log(data={
                 'top1_accuracy': float(eval_results['top1_accuracy']),
                 'top5_accuracy': float(eval_results['top5_accuracy']),
                 'eval_throughput': eval_throughput,
-                'eval_latency_avg': eval_latencies_mean,
-                'eval_latency_p90': eval_latencies_q[0],
-                'eval_latency_p95': eval_latencies_q[1],
-                'eval_latency_p99': eval_latencies_q[2],
+                **additional_metrics
             }, step=tuple())
@@ -644,7 +639,7 @@ class Runner(object):
                     data_format=self.run_hparams.input_format,
                     dtype=self.run_hparams.dtype)
 
-                image_classifier.export_savedmodel(export_dir, input_receiver_fn)
+                self.exported_path = image_classifier.export_savedmodel(export_dir, input_receiver_fn)
 
         except KeyboardInterrupt:
             print("Keyboard interrupt")
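The guarded latency block in the `evaluate` hunk above exists because `np.quantile` raises on an empty array, and the logging hook's latency list can be empty (e.g. on very short benchmark runs); the percentile metrics are now emitted only when samples exist. A self-contained illustration, with made-up sample values:

```python
# Standalone illustration of the guarded latency metrics from runner.py.
# `latencies` stands in for eval_logging_hook.latencies (seconds per batch);
# the values are made up.
import numpy as np

latencies = [0.012, 0.011, 0.013]
if len(latencies) > 0:
    lat_ms = np.array(latencies) * 1000  # seconds -> milliseconds
    p90, p95, p99 = np.quantile(lat_ms, q=[0.9, 0.95, 0.99])
    metrics = {"eval_latency_avg": float(np.mean(lat_ms)),
               "eval_latency_p90": p90,
               "eval_latency_p95": p95,
               "eval_latency_p99": p99}
else:
    metrics = {}  # np.quantile raises on empty input, so skip the block
```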
diff --git a/TensorFlow/Classification/ConvNets/se-resnext101-32x4d/README.md b/TensorFlow/Classification/ConvNets/se-resnext101-32x4d/README.md
index bdf29515..c05d7a9d 100644
--- a/TensorFlow/Classification/ConvNets/se-resnext101-32x4d/README.md
+++ b/TensorFlow/Classification/ConvNets/se-resnext101-32x4d/README.md
@@ -278,7 +278,6 @@ The `utils/` directory contains the following utility modules:
 - `cmdline_helper.py`: helper module for command line processing
 - `data_utils.py`: module defining input data pipelines
 - `dali_utils.py`: helper module for DALI
-- `hvd_utils.py`: helper module for Horovod
 - `image_processing.py`: image processing and data augmentation functions
 - `learning_rate.py`: definition of used learning rate schedule
 - `optimizers.py`: definition of used custom optimizers
diff --git a/TensorFlow/Classification/ConvNets/triton/rn50_model.py b/TensorFlow/Classification/ConvNets/triton/rn50_model.py
index 29325bf8..e2942b67 100644
--- a/TensorFlow/Classification/ConvNets/triton/rn50_model.py
+++ b/TensorFlow/Classification/ConvNets/triton/rn50_model.py
@@ -23,8 +23,8 @@ def get_model(
     use_dali: bool = False,
     gpu_memory_fraction=0.7,
 ):
-    import horovod.tensorflow as hvd
     from runtime import Runner
+    from utils import hvd_wrapper as hvd
 
     hvd.init()
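Capturing `self.exported_path` in the last `runner.py` hunk above is what makes the export locatable by callers: `export_savedmodel` returns the timestamped export directory as a byte string, and `inference.py` consumes it when it trains a throwaway model. A short sketch of that consumer side; the constructor and `evaluate` arguments mirror the `inference.py` call site, and the directories are hypothetical:

```python
# Sketch of consuming Runner.exported_path, as inference.py does.
# Arguments mirror the inference.py call site; the paths are hypothetical.
from runtime import runner

r = runner.Runner(n_classes=1001, architecture="resnet50",
                  model_dir="/tmp/rn50_ckpt")
r.evaluate("batch", 1, 1, export_dir="/tmp/rn50_export", is_benchmark=True)

# export_savedmodel returns bytes, hence the decode.
saved_model_dir = r.exported_path.decode("utf-8")
```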
diff --git a/TensorFlow/Classification/ConvNets/utils/dali_utils.py b/TensorFlow/Classification/ConvNets/utils/dali_utils.py
index ff8ab00a..9ab9faf1 100644
--- a/TensorFlow/Classification/ConvNets/utils/dali_utils.py
+++ b/TensorFlow/Classification/ConvNets/utils/dali_utils.py
@@ -18,10 +18,9 @@ import sys
 
 import tensorflow as tf
 
-import horovod.tensorflow as hvd
 
 from utils import image_processing
-from utils import hvd_utils
+from utils import hvd_wrapper as hvd
 
 from nvidia import dali
 import nvidia.dali.plugin.tf as dali_tf
diff --git a/TensorFlow/Classification/ConvNets/utils/data_utils.py b/TensorFlow/Classification/ConvNets/utils/data_utils.py
index 8f9bda74..8099815a 100644
--- a/TensorFlow/Classification/ConvNets/utils/data_utils.py
+++ b/TensorFlow/Classification/ConvNets/utils/data_utils.py
@@ -18,11 +18,10 @@ import sys
 
 import tensorflow as tf
 
-import horovod.tensorflow as hvd
 
 from utils import image_processing
-from utils import hvd_utils
 from utils import dali_utils
+from utils import hvd_wrapper as hvd
 
 __all__ = ["get_synth_input_fn", "normalized_inputs"]
 
@@ -82,16 +81,13 @@ def get_tfrecords_input_fn(filenames, batch_size, height, width, training, disto
     shuffle_buffer_size = 4096
 
     if deterministic:
-        if hvd_utils.is_using_hvd():
-            seed = 13 * (1 + hvd.rank())
-        else:
-            seed = 13
+        seed = 13 * hvd.rank()
     else:
         seed = None
 
     ds = tf.data.Dataset.from_tensor_slices(filenames)
 
-    if hvd_utils.is_using_hvd() and training:
+    if hvd.size() > 1 and training:
         ds = ds.shard(hvd.size(), hvd.rank())
 
     ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=10, block_length=8)
diff --git a/TensorFlow/Classification/ConvNets/utils/hooks/training_hooks.py b/TensorFlow/Classification/ConvNets/utils/hooks/training_hooks.py
index edee651a..bf2f00fd 100755
--- a/TensorFlow/Classification/ConvNets/utils/hooks/training_hooks.py
+++ b/TensorFlow/Classification/ConvNets/utils/hooks/training_hooks.py
@@ -21,8 +21,7 @@ import tensorflow as tf
 import dllogger
 import signal
 
-import horovod.tensorflow as hvd
-from utils.hvd_utils import is_using_hvd
+from utils import hvd_wrapper as hvd
 
 __all__ = ['TrainingLoggingHook', 'TrainingPartitionHook']
 
@@ -125,17 +124,17 @@ class TrainingPartitionHook(tf.estimator.SessionRunHook):
         signal.signal(signal.SIGTERM, self._signal_handler)
 
     def begin(self):
-        if is_using_hvd():
+        if hvd.size() > 1:
             with tf.device("/cpu:0"):
                 self.input_op = tf.placeholder(tf.int32, shape=())
-                self.allreduce_op = hvd.allreduce(self.input_op, op=hvd.Sum,
-                                                  name="signal_handler_all_reduce")
+                self.allreduce_op = hvd.hvd_global_object.allreduce(
+                    self.input_op, op=hvd.hvd_global_object.Sum, name="signal_handler_all_reduce")
 
     def before_run(self, run_context):
         fetches = [tf.train.get_global_step()]
         feed_dict = None
 
-        if is_using_hvd() and (self.global_step % self.sync_freq) == 0:
+        if hvd.size() > 1 and (self.global_step % self.sync_freq) == 0:
             fetches += [self.allreduce_op]
             feed_dict = {self.input_op: int(self.signal_recieved)}
 
@@ -144,7 +143,7 @@ class TrainingPartitionHook(tf.estimator.SessionRunHook):
     def after_run(self, run_context, run_values):
         self.global_step = run_values.results[0] + 1
 
-        if is_using_hvd() and len(run_values.results) == 2:
+        if hvd.size() > 1 and len(run_values.results) == 2:
             if run_values.results[1] > 0:
                 run_context.request_stop()
             elif self.signal_recieved:
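For the `TrainingPartitionHook` changes above: the hook sums each rank's local stop flag with an allreduce, so every rank observes a positive total as soon as any rank caught SIGTERM, and all ranks stop on the same step. A stripped-down sketch of that handshake; it only makes sense when launched on every rank (e.g. via `horovodrun`) after `hvd.init(True)`:

```python
# Stripped-down sketch of the allreduce stop-signal handshake used by
# TrainingPartitionHook. Assumes a multi-rank launch with hvd.init(True);
# every rank must execute the allreduce or it will block.
import tensorflow as tf
from utils import hvd_wrapper as hvd

hvd.init(True)
with tf.device("/cpu:0"):
    local_flag = tf.placeholder(tf.int32, shape=())
    # Sum across ranks: positive iff at least one rank set its flag.
    stop_vote = hvd.hvd_global_object.allreduce(
        local_flag, op=hvd.hvd_global_object.Sum, name="stop_vote")

with tf.Session() as sess:
    should_stop = sess.run(stop_vote, feed_dict={local_flag: 0}) > 0
```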
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import horovod.tensorflow as hvd
-
-
-def is_using_hvd():
-    return hvd.size() > 1
-    rank_env = ['HOROVOD_RANK', 'OMPI_COMM_WORLD_RANK', 'PMI_RANK']
-    size_env = ['HOROVOD_SIZE', 'OMPI_COMM_WORLD_SIZE', 'PMI_SIZE']
-
-    for r_var, s_var in zip(rank_env, size_env):
-        if r_var in os.environ and s_var in os.environ:
-            return int(s_var) > 1
-    return False
diff --git a/TensorFlow/Classification/ConvNets/utils/hvd_wrapper.py b/TensorFlow/Classification/ConvNets/utils/hvd_wrapper.py
new file mode 100644
index 00000000..a326c2a4
--- /dev/null
+++ b/TensorFlow/Classification/ConvNets/utils/hvd_wrapper.py
@@ -0,0 +1,32 @@
+hvd_global_object = None
+
+def init(use_horovod: bool = False):
+    global hvd_global_object
+    if use_horovod:
+        import horovod.tensorflow as hvd
+        hvd.init()
+        hvd_global_object = hvd
+    else:
+        class _DummyWrapper:
+            def rank(self): return 0
+            def size(self): return 1
+            def local_rank(self): return 0
+            def local_size(self): return 1
+        hvd_global_object = _DummyWrapper()
+
+
+def size():
+    global hvd_global_object
+    return hvd_global_object.size()
+
+def rank():
+    global hvd_global_object
+    return hvd_global_object.rank()
+
+def local_rank():
+    global hvd_global_object
+    return hvd_global_object.local_rank()
+
+def local_size():
+    global hvd_global_object
+    return hvd_global_object.local_size()
\ No newline at end of file
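Finally, a usage sketch for the new `utils/hvd_wrapper.py`: multi-process entry points such as `main.py` call `init(True)` so the wrapper is backed by the real `horovod.tensorflow` module, while single-process tools (`export_frozen_graph.py`, `inference.py`, `triton/rn50_model.py`) call plain `init()` and get the dummy object, making `horovod` an optional dependency for them:

```python
# Usage sketch for utils/hvd_wrapper.py.
from utils import hvd_wrapper as hvd

# Single-process tools: the dummy backend, no horovod import required.
hvd.init()
assert hvd.size() == 1 and hvd.rank() == 0

# Distributed training would instead call hvd.init(True) under a launcher
# such as horovodrun; size()/rank() then reflect the job geometry, and
# hvd.hvd_global_object exposes the full horovod.tensorflow API.
```

One consequence of this design: `init()` must run before any accessor, since `size()`, `rank()`, `local_rank()`, and `local_size()` dereference `hvd_global_object` unconditionally and would fail on the initial `None`.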