DeepLearningExamples/PyTorch/Classification/ConvNets/image_classification/models/common.py
2021-11-09 13:42:18 -08:00

303 lines
8.9 KiB
Python

# Standard library.
import copy
import warnings
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional

# Third party.
import torch
import torch.nn.functional as F
from torch import nn

# pytorch_quantization is an optional dependency: when it is missing we fall
# back to a None sentinel so every non-quantized code path keeps working.
try:
    from pytorch_quantization import nn as quant_nn
except ImportError:
    warnings.warn(
        "pytorch_quantization module not found, quantization will not be available"
    )
    quant_nn = None
# LayerBuilder {{{
class LayerBuilder(object):
    """Factory for Conv2d / BatchNorm2d / activation layers, all configured
    through a single shared ``LayerBuilder.Config`` instance."""

    @dataclass
    class Config:
        activation: str = "relu"  # key into the table built in activation()
        conv_init: str = "fan_in"  # ``mode`` argument for kaiming_normal_
        bn_momentum: Optional[float] = None  # None -> BatchNorm2d default
        bn_epsilon: Optional[float] = None  # None -> BatchNorm2d default

    def __init__(self, config: "LayerBuilder.Config"):
        self.config = config

    def conv(
        self,
        kernel_size,
        in_planes,
        out_planes,
        groups=1,
        stride=1,
        bn=False,
        zero_init_bn=False,
        act=False,
    ):
        """Build a bias-free Conv2d with "same"-style padding for odd kernels.

        Returns the bare ``nn.Conv2d`` when neither ``bn`` nor ``act`` is
        requested, otherwise an ``nn.Sequential`` with named children
        ``conv`` [, ``bn``] [, ``act``].
        """
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            groups=groups,
            stride=stride,
            padding=int((kernel_size - 1) / 2),
            bias=False,
        )
        # The init gain is always the ReLU gain, even when the configured
        # activation is SiLU; this matches the reference implementation.
        nn.init.kaiming_normal_(
            conv.weight, mode=self.config.conv_init, nonlinearity="relu"
        )
        layers = [("conv", conv)]
        if bn:
            layers.append(("bn", self.batchnorm(out_planes, zero_init_bn)))
        if act:
            layers.append(("act", self.activation()))
        if bn or act:
            return nn.Sequential(OrderedDict(layers))
        else:
            return conv

    def convDepSep(
        self, kernel_size, in_planes, out_planes, stride=1, bn=False, act=False
    ):
        """Depthwise convolution (groups == in_planes) with padding.

        The kernel size is a parameter, not fixed to 3x3 as the original
        docstring claimed.
        """
        c = self.conv(
            kernel_size,
            in_planes,
            out_planes,
            groups=in_planes,
            stride=stride,
            bn=bn,
            act=act,
        )
        return c

    def conv3x3(self, in_planes, out_planes, stride=1, groups=1, bn=False, act=False):
        """3x3 convolution with padding 1."""
        c = self.conv(
            3, in_planes, out_planes, groups=groups, stride=stride, bn=bn, act=act
        )
        return c

    def conv1x1(self, in_planes, out_planes, stride=1, groups=1, bn=False, act=False):
        """1x1 convolution (padding is 0 for a 1x1 kernel)."""
        c = self.conv(
            1, in_planes, out_planes, groups=groups, stride=stride, bn=bn, act=act
        )
        return c

    def conv7x7(self, in_planes, out_planes, stride=1, groups=1, bn=False, act=False):
        """7x7 convolution with padding 3."""
        c = self.conv(
            7, in_planes, out_planes, groups=groups, stride=stride, bn=bn, act=act
        )
        return c

    def conv5x5(self, in_planes, out_planes, stride=1, groups=1, bn=False, act=False):
        """5x5 convolution with padding 2."""
        c = self.conv(
            5, in_planes, out_planes, groups=groups, stride=stride, bn=bn, act=act
        )
        return c

    def batchnorm(self, planes, zero_init=False):
        """BatchNorm2d honoring the configured momentum/eps.

        ``zero_init`` zeroes gamma so a residual branch starts as identity.
        """
        bn_cfg = {}
        if self.config.bn_momentum is not None:
            bn_cfg["momentum"] = self.config.bn_momentum
        if self.config.bn_epsilon is not None:
            bn_cfg["eps"] = self.config.bn_epsilon
        bn = nn.BatchNorm2d(planes, **bn_cfg)
        gamma_init_val = 0 if zero_init else 1
        nn.init.constant_(bn.weight, gamma_init_val)
        nn.init.constant_(bn.bias, 0)
        return bn

    def activation(self):
        """Instantiate the configured activation; raises KeyError for an
        unknown ``config.activation`` name."""
        return {
            "silu": lambda: nn.SiLU(inplace=True),
            "relu": lambda: nn.ReLU(inplace=True),
            "onnx-silu": ONNXSiLU,
        }[self.config.activation]()
# LayerBuilder }}}
# LambdaLayer {{{
class LambdaLayer(nn.Module):
    """Adapter wrapping an arbitrary callable as an ``nn.Module`` so it can
    be placed inside module containers such as ``nn.Sequential``."""

    def __init__(self, lmbd):
        super().__init__()
        self.lmbd = lmbd  # the wrapped callable

    def forward(self, x):
        # Delegate straight to the stored callable.
        fn = self.lmbd
        return fn(x)
# }}}
# SqueezeAndExcitation {{{
class SqueezeAndExcitation(nn.Module):
    """Squeeze-and-Excitation gate: global average pool over H,W, a
    squeeze/expand bottleneck, then sigmoid.

    ``forward`` returns the per-channel scale tensor of shape (N, C, 1, 1);
    it does NOT multiply the input itself (see the Sequential variants).
    """

    def __init__(self, in_channels, squeeze, activation):
        super(SqueezeAndExcitation, self).__init__()
        self.squeeze = nn.Linear(in_channels, squeeze)
        self.expand = nn.Linear(squeeze, in_channels)
        self.activation = activation
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self._attention(x)

    def _attention(self, x):
        # Global average pool over the spatial dims -> (N, C).
        pooled = torch.mean(x, [2, 3])
        gate = self.sigmoid(self.expand(self.activation(self.squeeze(pooled))))
        # Restore singleton spatial dims so the gate broadcasts against x.
        return gate.unsqueeze(2).unsqueeze(3)
class SqueezeAndExcitationTRT(nn.Module):
    """TensorRT-friendly Squeeze-and-Excitation gate.

    Same contract as ``SqueezeAndExcitation`` (returns an (N, C, 1, 1) scale
    tensor) but implemented with AdaptiveAvgPool2d and 1x1 convolutions,
    which export/deploy more cleanly than mean + Linear.
    """

    def __init__(self, in_channels, squeeze, activation):
        super(SqueezeAndExcitationTRT, self).__init__()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
        self.expand = nn.Conv2d(squeeze, in_channels, 1)
        self.activation = activation
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self._attention(x)

    def _attention(self, x):
        # Pool to (N, C, 1, 1), run the bottleneck, squash to (0, 1).
        gate = self.pooling(x)
        gate = self.activation(self.squeeze(gate))
        gate = self.sigmoid(self.expand(gate))
        return gate
# }}}
# EMA {{{
class EMA:
    """Exponential moving average of a module's parameters and buffers.

    Calling the instance blends ``module``'s current state into
    ``module_ema`` in place: ema <- mu * ema + (1 - mu) * current.
    """

    def __init__(self, mu, module_ema):
        self.mu = mu  # smoothing factor
        self.module_ema = module_ema  # module holding the averaged state

    def __call__(self, module, step=None):
        # Warm-up schedule: when a step counter is given, cap the decay so
        # early updates track the raw weights more closely.
        decay = self.mu if step is None else min(self.mu, (1.0 + step) / (10 + step))

        def strip_module(s: str) -> str:
            # Hook for stripping a wrapper prefix from names; currently a no-op.
            return s

        ema_state = self.module_ema.state_dict()
        with torch.no_grad():
            for key, value in module.state_dict().items():
                # Integer BN counters are not averaged.
                if key.endswith("num_batches_tracked"):
                    continue
                target = ema_state[strip_module(key)]
                target.mul_(decay)
                target.add_((1.0 - decay) * value)
# }}}
# ONNXSiLU {{{
# torch.nn.SiLU is not supported by ONNX export, so exported models use this
# explicit x * sigmoid(x) formulation (needs 15-20% more GPU memory).
class ONNXSiLU(nn.Module):
    """ONNX-exportable SiLU/Swish activation: ``x * sigmoid(x)``."""

    def __init__(self, *args, **kwargs):
        # Accept and ignore any arguments for drop-in compatibility with
        # activation constructors like nn.SiLU(inplace=True).
        super(ONNXSiLU, self).__init__()

    def forward(self, x):
        return torch.sigmoid(x) * x
# }}}
class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
    """SE block whose ``forward`` returns the gated feature map
    (``gate * x``) rather than the bare gate.

    With ``quantized=True`` (requires pytorch_quantization), both multiply
    operands pass through TensorQuantizer fake-quant nodes.
    """

    def __init__(self, in_channels, squeeze, activation, quantized=False):
        super().__init__(in_channels, squeeze, activation)
        self.quantized = quantized
        if not quantized:
            # Plain float path: identity pass-throughs keep forward() uniform.
            self.mul_a_quantizer = nn.Identity()
            self.mul_b_quantizer = nn.Identity()
        else:
            assert quant_nn is not None, "pytorch_quantization is not available"
            desc = quant_nn.QuantConv2d.default_quant_desc_input
            self.mul_a_quantizer = quant_nn.TensorQuantizer(desc)
            self.mul_b_quantizer = quant_nn.TensorQuantizer(desc)

    def forward(self, x):
        gate = self._attention(x)
        if self.quantized:
            return self.mul_a_quantizer(gate) * self.mul_b_quantizer(x)
        return gate * x
class SequentialSqueezeAndExcitationTRT(SqueezeAndExcitationTRT):
    """TRT-friendly SE block whose ``forward`` returns the gated feature map
    (``gate * x``) rather than the bare gate.

    With ``quantized=True`` (requires pytorch_quantization), both multiply
    operands pass through TensorQuantizer fake-quant nodes.
    """

    def __init__(self, in_channels, squeeze, activation, quantized=False):
        super().__init__(in_channels, squeeze, activation)
        self.quantized = quantized
        if not quantized:
            # Plain float path: identity pass-throughs keep forward() uniform.
            self.mul_a_quantizer = nn.Identity()
            self.mul_b_quantizer = nn.Identity()
        else:
            assert quant_nn is not None, "pytorch_quantization is not available"
            desc = quant_nn.QuantConv2d.default_quant_desc_input
            self.mul_a_quantizer = quant_nn.TensorQuantizer(desc)
            self.mul_b_quantizer = quant_nn.TensorQuantizer(desc)

    def forward(self, x):
        gate = self._attention(x)
        if self.quantized:
            return self.mul_a_quantizer(gate) * self.mul_b_quantizer(x)
        return gate * x
class StochasticDepthResidual(nn.Module):
    """Residual add with stochastic depth: during training the branch ``x``
    is dropped (zeroed) with probability ``1 - survival_prob`` and scaled by
    ``1 / survival_prob`` when kept; at eval time it is a plain sum.

    BUG FIX: the original called ``F.dropout(self.mask, ..., inplace=False)``
    and discarded the return value, so the buffered mask stayed all-ones and
    the branch was never actually dropped. The dropout output is now captured
    and used in the ``addcmul``.
    """

    def __init__(self, survival_prob: float):
        super().__init__()
        self.survival_prob = survival_prob
        # Scalar all-ones template the dropout mask is sampled from; not
        # saved in checkpoints.
        self.register_buffer("mask", torch.ones(()), persistent=False)

    def forward(self, residual: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Inference: deterministic residual sum.
            return torch.add(residual, other=x)
        else:
            with torch.no_grad():
                # Sample a fresh mask (0 or 1/survival_prob) from the ones
                # template; inplace=False keeps the template intact.
                mask = F.dropout(
                    self.mask,
                    p=1 - self.survival_prob,
                    training=self.training,
                    inplace=False,
                )
            # residual + mask * x
            return torch.addcmul(residual, mask, x)
class Flatten(nn.Module):
    """Drop trailing singleton spatial dims, e.g. (N, C, 1, 1) -> (N, C).

    Dimensions of size != 1 are left untouched (squeeze semantics).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.squeeze(x, dim=-1)
        return torch.squeeze(out, dim=-1)