[ResNet/TF] Fix gradient calculation for sync variable

This commit is contained in:
hXl3s 2021-04-26 15:00:10 +02:00 committed by GitHub
parent 01201316f8
commit 2d555548b6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -239,7 +239,8 @@ class ResnetModel(object):
with tf.device("/cpu:0"):
if hvd_utils.is_using_hvd():
sync_var = tf.Variable(initial_value=[0], dtype=tf.int32, name="signal_handler_var")
sync_var = tf.Variable(initial_value=[0], dtype=tf.int32, name="signal_handler_var",
trainable=False)
sync_var_assing = sync_var.assign([1], name="signal_handler_var_set")
sync_var_reset = sync_var.assign([0], name="signal_handler_var_reset")
sync_op = hvd.allreduce(sync_var, op=hvd.Sum, name="signal_handler_all_reduce")