DeepLearningExamples/TensorFlow/Translation/GNMT/variable_mgr/batch_allreduce.py

613 lines
26 KiB
Python

# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains classes and functions for doing a single-machine batch all-reduce.
An all-reduce is taking the reduction (typically a sum) of a list of tensors,
each on a different device. The result must end up back on each device, which is
where the word "all" comes from. In summary, each device starts with a single
tensor, and ends up with the reduction of all tensors.
A batch all-reduce is doing several independent all-reduces. When doing a batch
all-reduce, care is taken to evenly distribute the reduction computations
across devices and inter-device tensor transfers across device links.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# TODO(reedwm): Support distributed all-reduces in this file.
# TODO(reedwm): Merge this code with allreduce.py, which contains some batch
# all-reduce code that this file calls. allreduce.py also supports distributed
# batch-reduce while this file only supports single-machine all-reduce.
import abc
from collections import namedtuple
import six
import tensorflow as tf
from tensorflow.python.ops import gradients_impl
from variable_mgr import allreduce
from variable_mgr import constants
def _all_reduce_using_copy(tensors_across_devices, use_mean):
  """Reduces a list of per-device tensors on the current device.

  The tensors are copied to the device of the enclosing `tf.device` scope and
  summed there; optionally the sum is scaled into a mean.

  Args:
    tensors_across_devices: A non-empty list of tensors, each on a different
      device. Only the type of the first element is inspected to decide
      between the `tf.IndexedSlices` path and the dense path.
    use_mean: If True, the sum is multiplied by 1 / len(tensors_across_devices).

  Returns:
    The reduced tensor on the current device (a `tf.IndexedSlices` when the
    inputs are `tf.IndexedSlices`).
  """
  assert tensors_across_devices
  num_tensors = len(tensors_across_devices)

  if not isinstance(tensors_across_devices[0], tf.IndexedSlices):
    # Dense tensors: a plain element-wise sum.
    total = tf.add_n(tensors_across_devices)
    if use_mean:
      total *= 1. / num_tensors
    return total

  # Sparse gradients: aggregate the IndexedSlices without densifying them.
  total = gradients_impl._AggregateIndexedSlicesGradients(  # pylint: disable=protected-access
      tensors_across_devices)
  if not use_mean:
    return total
  scaled_values = tf.multiply(total.values, float(1. / num_tensors))
  return tf.IndexedSlices(scaled_values, total.indices, total.dense_shape)
@six.add_metaclass(abc.ABCMeta)
class BatchAllReduceAlgorithm(object):
  """Represents an algorithm for performing a batch all-reduce operation.

  Subclasses implement `_do_batch_all_reduce`. `batch_all_reduce` wraps it
  with optional repacking, fp16 compaction and deferral of the tensors.
  """

  def batch_all_reduce(self, all_device_tensors, num_splits, compact_tensors,
                       defer_tensors):
    """Performs a batch all-reduce.

    The reduction done is a sum, unless the concrete algorithm is configured
    otherwise (e.g. `CopyToDeviceAlgorithm(..., use_mean=True)` takes the
    mean).

    `all_device_tensors` is a list of list of tensors that will be batch
    all-reduced. All tensors within a single inner list must be on the same
    device. The nth element in each list, for any n, will be reduced together.
    The return value is in the same form as `all_device_tensors`, except that
    each tensor is reduced.

    For example, if `all_device_tensors` is:
    [[ A, B ],  # A and B are on GPU 0
     [ C, D ]]  # C and D are on GPU 1

    Then the return value will be:
    [[ A+C, B+D ],  # These two tensors are on GPU 0
     [ A+C, B+D ]]  # These two tensors are on GPU 1

    Arguments:
      all_device_tensors: A list of list of tensors. `all_device_tensors[i][j]`
        is a tensor where `i` is the device index and `j` is the tensor index.
      num_splits: If truthy, tensors will be concatenated and split into this
        many pieces during the all-reduce, then split back into their original
        shapes afterwards. None and 0 both disable this repacking. Has no
        impact on correctness and can improve performance. Requires all
        tensors to be the same type.
      compact_tensors: If True, tensors are cast to fp16 before being
        all-reduced and cast back to their original dtype afterwards. Improves
        performance, but hurts numerical stability.
      defer_tensors: If True, every time the return value
        `reduced_all_device_tensors` is evaluated, the result will be the
        reduced tensors values of `all_device_tensors` from the previous session
        run instead of the current session run, or zero on the first session
        run. This can improve performance. When training neural networks,
        deferring gradients often does not harm training, so this can be used to
        improve performance.

    Returns:
      reduced_all_device_tensors: A list in the same form as
        `all_device_tensors`, except each tensor has been reduced.
      warmup_ops: Ops needed to be run once before the all-reduce can occur.
        An empty list when `defer_tensors` is False; otherwise a list holding
        one inner list of warmup ops per device.
    """
    # A falsy num_splits (None or 0) means "no repacking". Normalize it to 0 so
    # that every helper below, including _add_put_op_control_deps, receives the
    # same "no repacking" value.
    num_splits = num_splits or 0

    # Before all-reducing tensors, we do several preprocessing functions that
    # can speed up the all-reduce. We undo these functions after all-reducing
    # the tensors.
    warmup_ops = []
    if num_splits:
      packer = _TensorPacker(num_splits)
      all_device_tensors = packer.concat_all_device_tensors(all_device_tensors)
    # If enabled, we compact and defer tensors in between concatenating them
    # and splitting them, because it is faster to do operations on a single
    # concatenated tensor than on multiple smaller tensors.
    if compact_tensors:
      all_device_tensors_before_compact = all_device_tensors
      all_device_tensors = _compact_all_device_tensors(all_device_tensors)
    if defer_tensors:
      all_device_tensors, put_ops, warmup_ops = _defer_all_device_tensors(
          all_device_tensors)
    if num_splits:
      all_device_tensors = packer.split_all_device_tensors(all_device_tensors)

    all_device_tensors = self._do_batch_all_reduce(all_device_tensors)

    # Undo the preprocessing operations in opposite order as we applied them.
    if num_splits:
      all_device_tensors = packer.undo_split_all_device_tensors(
          all_device_tensors)
    # Note: There is no undo operation for deferring tensors. But we do need to
    # call _add_put_op_control_deps at the end if we deferred the tensors.
    if compact_tensors:
      all_device_tensors = _undo_compact_all_device_tensors(
          all_device_tensors, all_device_tensors_before_compact)
    if num_splits:
      all_device_tensors = packer.undo_concat_all_device_tensors(
          all_device_tensors)
    if defer_tensors:
      all_device_tensors = _add_put_op_control_deps(all_device_tensors,
                                                    num_splits, put_ops)
    return all_device_tensors, warmup_ops

  @abc.abstractmethod
  def _do_batch_all_reduce(self, all_device_tensors):
    """Performs a batch all-reduce.

    Unlike `self.batch_all_reduce`, this does not do any preprocessing of the
    tensors.

    Args:
      all_device_tensors: A list of list of tensors. `all_device_tensors[i][j]`
        is a tensor where `i` is the device index and `j` is the tensor index.

    Returns:
      reduced_all_device_tensors: A list in the same form as
        `all_device_tensors`, except each tensor has been reduced.
    """
    pass
class CopyToDeviceAlgorithm(BatchAllReduceAlgorithm):
  """An algorithm that copies tensors to be reduced to a specific device.

  The j-th tensor of every device is reduced on
  `devices_to_reduce_on[j % len(devices_to_reduce_on)]`, so the reductions are
  distributed round-robin over the given devices.
  """

  def __init__(self, devices_to_reduce_on, use_mean=False):
    """Initializer for CopyToDeviceAlgorithm.

    Args:
      devices_to_reduce_on: List of device strings on which the reductions are
        placed.
      use_mean: If True, the tensors are averaged instead of summed.
    """
    self._devices = devices_to_reduce_on
    self._use_mean = use_mean

  def _do_batch_all_reduce(self, all_device_tensors):
    reduced_tensors = []
    for i, tensors_across_devices in enumerate(zip(*all_device_tensors)):
      with tf.device(self._devices[i % len(self._devices)]):
        reduced_tensors.append(
            _all_reduce_using_copy(tensors_across_devices, self._use_mean))
    # The tensors will be brought back to each device once they are used.
    # Every device gets its own list object (holding the same tensors) so that
    # mutating one device's list cannot silently affect the others.
    return [list(reduced_tensors) for _ in range(len(all_device_tensors))]
class HierarchicalCopyAlgorithm(BatchAllReduceAlgorithm):
  """An algorithm that uses hierarchical copies.

  The devices are split into two halves. For each tensor index:
  - each half is reduced onto a "main" device;
  - the two partial results are combined on the first main device;
  - the total is broadcast back to every device through the two main devices.

  This is only optimized for eight devices connected in NetworkTopology.DGX1
  or NetworkTopology.GCP_V100 topology.
  """

  def __init__(self, network_topology):
    """Initializer for HierarchicalCopyAlgorithm.

    Args:
      network_topology: An instance of Enum class constants.NetworkTopology.
    """
    self._network_topology = network_topology

  def _do_batch_all_reduce(self, all_device_tensors):
    avail_devices = [device_tensors[0].device
                     for device_tensors in all_device_tensors]
    num_devices = len(avail_devices)
    if num_devices == 1:
      # With one device a half would be empty and could not be reduced.
      raise ValueError(
          'HierarchicalCopy requires at least two devices, got %d.' %
          num_devices)
    # Devices [0, half) form the lower half and [half, num_devices) the upper
    # half. With an odd device count the upper half holds the extra device.
    # Slicing to the end, rather than taking exactly `half` tensors, ensures
    # that no device is left out of the reduction. For odd counts a main
    # device may fall outside its group's half. That only affects where the
    # work is placed, not the result.
    half = num_devices // 2
    reduced_tensors = []
    for i, tensors_across_devices in enumerate(zip(*all_device_tensors)):
      group_0_main_device, group_1_main_device = self.__get_main_devices(
          i, num_devices)
      # Group 0 is whichever half contains group_0_main_device.
      group_0_is_lower = group_0_main_device < half
      lower_tensors = tensors_across_devices[:half]
      upper_tensors = tensors_across_devices[half:]
      if group_0_is_lower:
        group_0_tensors, group_1_tensors = lower_tensors, upper_tensors
      else:
        group_0_tensors, group_1_tensors = upper_tensors, lower_tensors

      # Reduce the first group.
      with tf.device(avail_devices[group_0_main_device]):
        group_0_reduced_tensor = _all_reduce_using_copy(group_0_tensors, False)
      # Reduce the second group.
      with tf.device(avail_devices[group_1_main_device]):
        group_1_reduced_tensor = _all_reduce_using_copy(group_1_tensors, False)
      # Reduce between the groups.
      with tf.device(avail_devices[group_0_main_device]):
        total_reduced_tensor = _all_reduce_using_copy(
            [group_0_reduced_tensor, group_1_reduced_tensor], False)
      # Broadcast the result back into the root of each group.
      with tf.device(avail_devices[group_0_main_device]):
        group_0_reduced_tensor_bcast = tf.identity(total_reduced_tensor)
      with tf.device(avail_devices[group_1_main_device]):
        group_1_reduced_tensor_bcast = tf.identity(total_reduced_tensor)

      reduced_tensors_bcast = []
      for j in range(num_devices):
        with tf.device(avail_devices[j]):
          # Broadcast the result back to each member in the group from the root.
          if group_0_is_lower == (j < half):
            src_device_tensor = group_0_reduced_tensor_bcast
          else:
            src_device_tensor = group_1_reduced_tensor_bcast
          reduced_tensors_bcast.append(tf.identity(src_device_tensor))
      reduced_tensors.append(reduced_tensors_bcast)
    # Transpose [tensor_index][device_index] into [device_index][tensor_index],
    # returning lists to match the form of `all_device_tensors`.
    return [list(device_tensors) for device_tensors in zip(*reduced_tensors)]

  def __get_main_devices(self, tensor_index, num_devices):
    """Returns the pair of main devices to use for initial reduction.

    Args:
      tensor_index: Index of the current tensor in the list of tensors to copy.
      num_devices: Total number of devices.

    Returns:
      A tuple containing pair of main device indices for the initial
      reduction. Then, the first element of the tuple should be used for the
      final reduction.

    Raises:
      ValueError: Invalid input arguments.
    """
    if self._network_topology == constants.NetworkTopology.DGX1:
      return tensor_index % num_devices, (tensor_index +
                                          (num_devices // 2)) % num_devices
    elif self._network_topology == constants.NetworkTopology.GCP_V100:
      if num_devices != 8:
        raise ValueError('HierarchicalCopy only supports eight devices in %s.' %
                         self._network_topology)
      # TODO(hinsu): Generalize main device indices to handle any other
      # isomorphic connection graph that connects two cliques using connections
      # other than 0-5 and 2-7.
      main_device_pairs = [(0, 5), (2, 7), (5, 0), (7, 2)]
      return main_device_pairs[tensor_index % len(main_device_pairs)]
    else:
      # TODO(reedwm): make this logic more general for arbitrary topology.
      raise ValueError(
          'HierarchicalCopy is not supported for %s network topology.' %
          self._network_topology)
class AllReduceSpecAlgorithm(BatchAllReduceAlgorithm):
  """A batch all-reduce driven by an all-reduce spec string.

  The spec is parsed by `allreduce.parse_all_reduce_spec`. The reduction
  itself is delegated to `allreduce.sum_gradients_all_reduce`.
  """

  def __init__(self, all_reduce_spec, gpu_indices, agg_small_grads_max_bytes,
               agg_small_grads_max_group):
    """Parses `all_reduce_spec` and stores the aggregation settings.

    Raises:
      ValueError: If the spec parses into more than one (hybrid) strategy.
    """
    parsed_specs = allreduce.parse_all_reduce_spec(all_reduce_spec)
    # Only a single, non-hybrid spec can be used in replicated mode.
    if len(parsed_specs) != 1:
      raise ValueError(
          'Replicated mode does not support hybrid all-reduce strategies')
    self._all_reduce_spec = parsed_specs[0]
    self._gpu_indices = gpu_indices
    self._agg_small_grads_max_bytes = agg_small_grads_max_bytes
    self._agg_small_grads_max_group = agg_small_grads_max_group

  def _do_batch_all_reduce(self, all_device_tensors):
    # TODO(reedwm): Merge allreduce.sum_gradients_all_reduce with the other
    # gradient aggregation code, since gradient aggregation is doing an all
    # reduce. Currently, we do gradient repacking in two different places.
    # TODO(reedwm): Change the allreduce code to reduce tensors instead of
    # tower_grads.
    # sum_gradients_all_reduce works on (gradient, variable) pairs, so each
    # tensor is paired with a None placeholder and unpaired afterwards.
    tower_grads = [[(tensor, None) for tensor in per_device_tensors]
                   for per_device_tensors in all_device_tensors]
    spec = self._all_reduce_spec
    aggregated_device_grads = allreduce.sum_gradients_all_reduce(
        False,  # single_session
        ['/job:localhost'],
        tower_grads,
        1,
        spec.alg,
        spec.shards,
        self._gpu_indices,
        agg_small_grads_max_bytes=self._agg_small_grads_max_bytes,
        agg_small_grads_max_group=self._agg_small_grads_max_group)
    return [[grad for grad, _ in device_pairs]
            for device_pairs in aggregated_device_grads]
def algorithm_from_params(params):
  """Builds the BatchAllReduceAlgorithm selected by a Params tuple.

  Selection order: an all-reduce spec takes precedence, then hierarchical
  copy. Otherwise a copy-to-device algorithm is used, reducing either on
  every GPU or on the CPU.
  """
  if params.all_reduce_spec:
    gpu_indices = ([int(index) for index in params.gpu_indices.split(',')]
                   if params.gpu_indices else list(range(params.num_gpus)))
    return AllReduceSpecAlgorithm(params.all_reduce_spec, gpu_indices,
                                  params.agg_small_grads_max_bytes,
                                  params.agg_small_grads_max_group)

  if params.hierarchical_copy:
    return HierarchicalCopyAlgorithm(params.network_topology)

  if params.local_parameter_device == 'gpu':
    devices_to_reduce_on = ['/gpu:%d' % gpu for gpu in range(params.num_gpus)]
  else:
    devices_to_reduce_on = ['/cpu:0']
  # NOTE: The original code marked use_mean=True as "made only for adam
  # optimizer": gradients are averaged across devices instead of summed.
  return CopyToDeviceAlgorithm(devices_to_reduce_on, use_mean=True)
def _apply_to_all_device_tensors(all_device_tensors, apply_func, colocate=True):
  """Maps `apply_func` over every tensor of every device.

  `all_device_tensors` itself is left untouched; a fresh list of lists is
  built. `apply_func` is called device by device, and within a device tensor
  by tensor.

  Args:
    all_device_tensors: A list of list of tensors. `all_device_tensors[i][j]` is
      a tensor where `i` is the device index and `j` is the tensor index.
    apply_func: Callable `(tensor, device_index, tensor_index) -> new_tensor`,
      where `tensor` is `all_device_tensors[device_index][tensor_index]`.
    colocate: If True, each `apply_func` call runs inside
      `tf.colocate_with(tensor)`, so new ops are placed with their input.

  Returns:
    A list in the same form as `all_device_tensors` holding the results of
    `apply_func`.
  """
  def _apply_one(tensor, device_index, tensor_index):
    if not colocate:
      return apply_func(tensor, device_index, tensor_index)
    with tf.colocate_with(tensor):
      return apply_func(tensor, device_index, tensor_index)

  return [[_apply_one(tensor, device_index, tensor_index)
           for tensor_index, tensor in enumerate(device_tensors)]
          for device_index, device_tensors in enumerate(all_device_tensors)]
def _defer_tensor(tensor):
  """Delays `tensor` by one Session.run() call using a StagingArea.

  `tensor` is put into a fresh StagingArea. The returned tensor is a get()
  from that area, i.e. whatever was staged during the previous
  Session.run() call.

  Args:
    tensor: The tensor to defer for one step.

  Returns:
    A tuple (deferred_tensor, put_op, warmup_op):
      deferred_tensor: The value of `tensor` staged on the previous step.
      put_op: Stages the current `tensor`. Must be run every step that
        `deferred_tensor` is run.
      warmup_op: Stages an all-zeros tensor with `tensor`'s shape and dtype;
        should be run once before the first step.
  """
  staging_area = tf.contrib.staging.StagingArea([tensor.dtype], [tensor.shape])
  put_op = staging_area.put([tensor])
  zeros = tf.zeros(tensor.shape, dtype=tensor.dtype)
  warmup_op = staging_area.put([zeros])
  # Fetch the value staged on the previous step.
  deferred_tensor, = staging_area.get()
  return deferred_tensor, put_op, warmup_op
def _defer_all_device_tensors(all_device_tensors):
  """Defers every tensor in `all_device_tensors` by one step.

  Returns:
    A tuple (deferred_all_device_tensors, put_ops, warmup_ops). The first item
    has the same form as `all_device_tensors`. `put_ops` and `warmup_ops` are
    lists with one inner list per device, holding that device's ops from
    `_defer_tensor` in tensor order.
  """
  put_ops = [[] for _ in all_device_tensors]
  warmup_ops = [[] for _ in all_device_tensors]

  def _defer_and_record(tensor, device_index, tensor_index):
    del tensor_index  # Unused; ops are recorded in call order per device.
    deferred, put_op, warmup_op = _defer_tensor(tensor)
    put_ops[device_index].append(put_op)
    warmup_ops[device_index].append(warmup_op)
    return deferred

  deferred_all_device_tensors = _apply_to_all_device_tensors(
      all_device_tensors, _defer_and_record)
  return deferred_all_device_tensors, put_ops, warmup_ops
def _add_put_op_control_deps(all_device_tensors, num_splits, put_ops):
  """Add control dependencies from `put_ops` to `all_device_tensors`.

  This should only be called when deferred tensors are being used.
  The control dependencies are added so that the put ops are run whenever
  `all_device_tensors` is run. That way, the caller does not have to explicitly
  run the put ops.

  Args:
    all_device_tensors: A list of list of tensors. `all_device_tensors[i][j]` is
      a tensor where `i` is the device index and `j` is the tensor index.
    num_splits: The number of splits that were used for the all-reduce. A
      falsy value (0 or None) means the tensors were not repacked, so every
      tensor has its own put op. Otherwise each device has exactly one put op,
      for its single concatenated tensor.
    put_ops: Per-device lists of put ops from deferring the tensors
      (`put_ops[i]` holds the put ops of device `i`).

  Returns:
    A list in the same form as `all_device_tensors`, except each tensor has a
    control dependency on an op in `put_ops`.
  """
  def apply_func(tensor, device_index, tensor_index):
    # `not num_splits` (rather than `== 0`) so that None, which callers use
    # for "no repacking", selects the per-tensor branch too.
    if not num_splits:
      deps = [put_ops[device_index][tensor_index]]
    else:
      # Only the concatenated tensor was deferred, so there is one put op.
      deps = put_ops[device_index]
      assert len(deps) == 1
    with tf.control_dependencies(deps):
      return tf.identity(tensor, name='control_dependency')
  return _apply_to_all_device_tensors(all_device_tensors, apply_func)
def _compact_all_device_tensors(all_device_tensors):
  """Casts every tensor to fp16, with each cast colocated with its input."""
  def _to_fp16(tensor, unused_device_index, unused_tensor_index):
    return tf.cast(tensor, tf.float16)

  return _apply_to_all_device_tensors(all_device_tensors, _to_fp16)
def _undo_compact_all_device_tensors(all_device_tensors,
                                     orig_all_device_tensors):
  """Casts each tensor back to the dtype of its pre-compaction counterpart.

  Args:
    all_device_tensors: Compacted tensors, in the usual list-of-lists form.
    orig_all_device_tensors: The tensors as they were before compaction, in
      the same form. They provide both the target dtype and the placement.

  Returns:
    A list in the same form as `all_device_tensors` with original dtypes.
  """
  def _restore_dtype(tensor, device_index, tensor_index):
    reference = orig_all_device_tensors[device_index][tensor_index]
    # Place the cast next to the original tensor rather than wherever the
    # reduced tensor happened to be produced.
    with tf.colocate_with(reference):
      return tf.cast(tensor, reference.dtype)

  return _apply_to_all_device_tensors(
      all_device_tensors, _restore_dtype, colocate=False)
class _TensorPacker(object):
  """Packs and unpacks tensors into groups.

  This class first concatenates a set of tensors, then splits the concatenated
  tensor into a small number of chunks. This is useful for all-reducing
  tensors, as doing a small number of all-reduces on large tensors can be
  faster than doing a large number of all-reduces on small tensors.

  An instance is single-use. Its public methods must be called exactly once
  each, in this order: concat_all_device_tensors, split_all_device_tensors,
  undo_split_all_device_tensors, undo_concat_all_device_tensors.
  """

  _concat_tensor_state = namedtuple('_concat_tensor_state',
                                    ['orig_shapes', 'orig_sizes'])

  def __init__(self, num_splits):
    """Initializes the _TensorPacker.

    Args:
      num_splits: The number of tensors to split the concatenated tensor into.
        The batch all-reduce will consist of `num_splits` all-reduces.

    Raises:
      ValueError: If `num_splits` is not positive.
    """
    # Raise rather than assert so the check is not stripped under `python -O`.
    if num_splits <= 0:
      raise ValueError('num_splits must be positive, got %r' % (num_splits,))
    self._num_splits = num_splits
    # Name of the step allowed to run next; enforces the call order.
    self._next_method = 'concat'
    # Per-device shape/size bookkeeping, set by concat_all_device_tensors.
    self._tensor_states = None
    # Per-device concatenated tensors, set by split_all_device_tensors.
    self._orig_concat_all_device_tensors = None

  @staticmethod
  def _compute_split_sizes(total_size, num_splits):
    """Returns `num_splits` piece sizes that sum to `total_size`.

    Every piece gets `total_size // num_splits` elements except the last,
    which also receives the remainder.
    """
    split_size = total_size // num_splits
    split_size_last = total_size - split_size * (num_splits - 1)
    return [split_size] * (num_splits - 1) + [split_size_last]

  def _concat_tensors(self, device_tensors):
    """Flattens and concatenates tensors into a single 1-D tensor.

    Raises:
      ValueError: If a tensor's static shape is not fully defined.
    """
    orig_shapes = [t.shape for t in device_tensors]
    orig_sizes = [s.num_elements() for s in orig_shapes]
    # The static sizes are needed to split the result back apart later.
    if None in orig_sizes:
      raise ValueError('All tensors must have fully defined shapes to be '
                       'packed, got shapes %s' % (orig_shapes,))
    flat_tensors = [tf.reshape(t, [-1]) for t in device_tensors]
    concatenated_grad = tf.concat(flat_tensors, 0)
    return concatenated_grad, self._concat_tensor_state(orig_shapes, orig_sizes)

  def _split_tensors(self, concatenated_tensor):
    """Splits concatenated tensor into `num_splits` pieces."""
    # TODO(zhengxq): it is possible to optimize away the additional
    # data movement by copying along the original tensor boundary.
    # TODO(zhengxq): it is also possible to optimize away all the concat
    # as well.
    total_tensor_size = concatenated_tensor.shape.num_elements()
    split_sizes = self._compute_split_sizes(total_tensor_size,
                                            self._num_splits)
    return tf.split(concatenated_tensor, split_sizes)

  def _undo_split_tensors(self, tensor_packs):
    """Undoes self._split_tensors()."""
    return tf.concat(tensor_packs, 0)

  def _undo_concat_tensors(self, concatenated_tensor, concat_tensor_state):
    """Undoes self._concat_tensors()."""
    tensors_with_sizes = tf.split(concatenated_tensor,
                                  concat_tensor_state.orig_sizes)
    tensors_with_shapes = [
        tf.reshape(grad, shape)
        for grad, shape in zip(tensors_with_sizes,
                               concat_tensor_state.orig_shapes)
    ]
    return tensors_with_shapes

  def concat_all_device_tensors(self, all_device_tensors):
    """For each device, concatenate the device's tensors into a single tensor.

    Args:
      all_device_tensors: A list of list of tensors. `all_device_tensors[i][j]`
        is a tensor where `i` is the device index and `j` is the tensor index.

    Returns:
      A list of list of tensors in a similar form as all_device_tensors, except
      the tensors on each device have been concatenated. Each inner list
      consists of a single concatenated tensor.
    """
    assert self._next_method == 'concat'
    new_all_device_tensors = []
    tensor_states = []
    for device_tensors in all_device_tensors:
      with tf.colocate_with(device_tensors[0]):
        concat_tensor, tensor_state = self._concat_tensors(device_tensors)
        new_all_device_tensors.append([concat_tensor])
        tensor_states.append(tensor_state)
    self._tensor_states = tensor_states
    self._next_method = 'split'
    return new_all_device_tensors

  def split_all_device_tensors(self, all_device_tensors):
    """Splits concatenated tensors into `num_splits` pieces.

    `num_splits` is specified in the constructor. In the case where the total
    size of a concatenated tensor is not divisible by `num_splits`, the last
    split tensor gets more elements.

    Args:
      all_device_tensors: A list of list of tensors. `all_device_tensors[i][j]`
        is a tensor where `i` is the device index and `j` is the tensor index.
        For each i, `all_device_tensors[i]` must be a list of length 1 of a
        single concatenated tensor.

    Returns:
      A list of list of tensors in a similar form as all_device_tensors, except
      the concatenated tensor on each device have been split. Each inner list
      is a list of length `num_splits`.
    """
    assert self._next_method == 'split'
    new_all_device_tensors = []
    for [concat_tensor] in all_device_tensors:
      with tf.colocate_with(concat_tensor):
        new_all_device_tensors.append(self._split_tensors(concat_tensor))
    self._orig_concat_all_device_tensors = all_device_tensors
    self._next_method = 'undo_split'
    return new_all_device_tensors

  def undo_split_all_device_tensors(self, all_device_tensors):
    """Undoes the effects of `split_all_device_tensors`."""
    assert self._next_method == 'undo_split'
    new_all_device_tensors = []
    for i, device_tensors in enumerate(all_device_tensors):
      [orig_tensor] = self._orig_concat_all_device_tensors[i]
      with tf.colocate_with(orig_tensor):
        new_all_device_tensors.append(
            [self._undo_split_tensors(device_tensors)])
    self._next_method = 'undo_concat'
    return new_all_device_tensors

  def undo_concat_all_device_tensors(self, all_device_tensors):
    """Undoes the effects of `concat_all_device_tensors`."""
    assert self._next_method == 'undo_concat'
    new_all_device_tensors = []
    for [concat_tensor], tensor_state in zip(all_device_tensors,
                                             self._tensor_states):
      with tf.colocate_with(concat_tensor):
        new_all_device_tensors.append(self._undo_concat_tensors(concat_tensor,
                                                                tensor_state))
    # The packer is exhausted; any further call fails its order assertion.
    self._next_method = None
    return new_all_device_tensors