[ResNet/TF] Fix gradient calculation for sync variable
This commit is contained in:
parent
01201316f8
commit
2d555548b6
|
@ -239,7 +239,8 @@ class ResnetModel(object):
|
|||
|
||||
with tf.device("/cpu:0"):
|
||||
if hvd_utils.is_using_hvd():
|
||||
sync_var = tf.Variable(initial_value=[0], dtype=tf.int32, name="signal_handler_var")
|
||||
sync_var = tf.Variable(initial_value=[0], dtype=tf.int32, name="signal_handler_var",
|
||||
trainable=False)
|
||||
sync_var_assing = sync_var.assign([1], name="signal_handler_var_set")
|
||||
sync_var_reset = sync_var.assign([0], name="signal_handler_var_reset")
|
||||
sync_op = hvd.allreduce(sync_var, op=hvd.Sum, name="signal_handler_all_reduce")
|
||||
|
|
Loading…
Reference in a new issue