From 482fe9ac8a117a1f5eb05c4767d82e45a8adff2d Mon Sep 17 00:00:00 2001 From: Sharath T S Date: Tue, 15 Sep 2020 12:44:10 -0700 Subject: [PATCH] [BERT/PyT] fix onnx export (#689) --- PyTorch/LanguageModeling/BERT/modeling.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/PyTorch/LanguageModeling/BERT/modeling.py b/PyTorch/LanguageModeling/BERT/modeling.py index 27d7e291..0a1bd8f1 100755 --- a/PyTorch/LanguageModeling/BERT/modeling.py +++ b/PyTorch/LanguageModeling/BERT/modeling.py @@ -287,7 +287,9 @@ class BertNonFusedLayerNorm(nn.Module): def forward(self, x): u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) + s = (x - u) + s = s * s + s = s.mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.variance_epsilon) return self.weight * x + self.bias @@ -323,7 +325,9 @@ class BertLayerNorm(Module): x = self.fused_layer_norm(x) else: u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) + s = (x - u) + s = s * s + s = s.mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight * x + self.bias return x