[BERT/PyT] fix onnx export (#689)

This commit is contained in:
Sharath T S 2020-09-15 12:44:10 -07:00 committed by GitHub
parent 437b950d5b
commit 482fe9ac8a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -287,7 +287,9 @@ class BertNonFusedLayerNorm(nn.Module):
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
s = (x - u)
s = s * s
s = s.mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
@ -323,7 +325,9 @@ class BertLayerNorm(Module):
x = self.fused_layer_norm(x)
else:
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
s = (x - u)
s = s * s
s = s.mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight * x + self.bias
return x