[BERT/PyT] fix onnx export (#689)
This commit is contained in:
parent
437b950d5b
commit
482fe9ac8a
|
@ -287,7 +287,9 @@ class BertNonFusedLayerNorm(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
s = (x - u)
|
||||
s = s * s
|
||||
s = s.mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
return self.weight * x + self.bias
|
||||
|
||||
|
@ -323,7 +325,9 @@ class BertLayerNorm(Module):
|
|||
x = self.fused_layer_norm(x)
|
||||
else:
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
s = (x - u)
|
||||
s = s * s
|
||||
s = s.mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = self.weight * x + self.bias
|
||||
return x
|
||||
|
|
Loading…
Reference in a new issue