[BERT/TF] bug fix in beta bias correction terms (#395)

Swetha Mandava 2020-02-06 22:08:20 -08:00 committed by GitHub
parent ce73b32068
commit 0845aaa901

@@ -301,12 +301,12 @@ class LAMBOptimizer(tf.compat.v1.train.Optimizer):
     self.beta_2 = beta_2
     self.epsilon = epsilon
     self.exclude_from_weight_decay = exclude_from_weight_decay
-    self.steps = 0
 
-  def apply_gradients(self, grads_and_vars, global_step=None, name=None,
+  def apply_gradients(self, grads_and_vars, global_step, name=None,
       manual_fp16=False):
     """See base class."""
     assignments = []
+    steps = tf.cast(global_step, tf.float32)
     for (grad, param) in grads_and_vars:
       if grad is None or param is None:
         continue
@@ -343,9 +343,8 @@ class LAMBOptimizer(tf.compat.v1.train.Optimizer):
       next_v = (
           tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
                                                     tf.square(grad)))
-      self.steps += 1
-      beta1_correction = (1 - self.beta_1 ** self.steps)
-      beta2_correction = (1 - self.beta_2 ** self.steps)
+      beta1_correction = (1 - self.beta_1 ** steps)
+      beta2_correction = (1 - self.beta_2 ** steps)
 
       next_m_unbiased = next_m / beta1_correction
       next_v_unbiased = next_v / beta2_correction
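
Why the fix matters (editor's note, not part of the original commit): the removed self.steps was a plain Python int incremented inside the for loop over grads_and_vars while the graph was being built, so it advanced once per parameter rather than once per training step, and whatever value it reached at graph-construction time was frozen into the graph as a constant. The Adam-style bias-correction terms 1 - beta^t therefore never tracked the real step count. The fix derives t from the global_step tensor instead, so the corrections are recomputed in-graph on every step; this also explains the signature change, since global_step is no longer optional once the bias correction depends on it. A minimal sketch of the corrected pattern, assuming TF 1.x-compat graph mode (the names below are illustrative, not taken from the diff):

    import tensorflow as tf

    beta_1, beta_2 = 0.9, 0.999

    # global_step is a graph tensor that the training loop increments once
    # per step, so values derived from it stay in sync with training and
    # survive checkpoint restore (a Python counter would reset to zero on
    # restart).
    global_step = tf.compat.v1.train.get_or_create_global_step()
    steps = tf.cast(global_step, tf.float32)

    # The corrections are now tensors re-evaluated at every step, not Python
    # floats baked in at graph-construction time. They assume steps >= 1,
    # since 1 - beta ** 0 == 0 would divide the moment estimates by zero.
    beta1_correction = 1.0 - beta_1 ** steps
    beta2_correction = 1.0 - beta_2 ** steps

Here beta_1 ** steps mirrors the diff's use of Python's power operator on a tensor, which TensorFlow dispatches to tf.pow.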