import copy
import warnings
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

try:
    from pytorch_quantization import nn as quant_nn
except ImportError as e:
    warnings.warn(
        "pytorch_quantization module not found, quantization will not be available"
    )
    quant_nn = None


# LayerBuilder {{{
class LayerBuilder(object):
    """Factory for convolution, batch-norm and activation layers built from a shared Config."""

    @dataclass
    class Config:
        activation: str = "relu"
        conv_init: str = "fan_in"
        bn_momentum: Optional[float] = None
        bn_epsilon: Optional[float] = None

    def __init__(self, config: "LayerBuilder.Config"):
        self.config = config

    def conv(
        self,
        kernel_size,
        in_planes,
        out_planes,
        groups=1,
        stride=1,
        bn=False,
        zero_init_bn=False,
        act=False,
    ):
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            groups=groups,
            stride=stride,
            padding=int((kernel_size - 1) / 2),
            bias=False,
        )
        nn.init.kaiming_normal_(
            conv.weight, mode=self.config.conv_init, nonlinearity="relu"
        )
        layers = [("conv", conv)]
        if bn:
            layers.append(("bn", self.batchnorm(out_planes, zero_init_bn)))
        if act:
            layers.append(("act", self.activation()))

        if bn or act:
            return nn.Sequential(OrderedDict(layers))
        else:
            return conv

    def convDepSep(
        self, kernel_size, in_planes, out_planes, stride=1, bn=False, act=False
    ):
        """depthwise separable convolution with padding"""
        c = self.conv(
            kernel_size,
            in_planes,
            out_planes,
            groups=in_planes,
            stride=stride,
            bn=bn,
            act=act,
        )
        return c

    def conv3x3(self, in_planes, out_planes, stride=1, groups=1, bn=False, act=False):
        """3x3 convolution with padding"""
        c = self.conv(
            3, in_planes, out_planes, groups=groups, stride=stride, bn=bn, act=act
        )
        return c

    def conv1x1(self, in_planes, out_planes, stride=1, groups=1, bn=False, act=False):
        """1x1 convolution"""
        c = self.conv(
            1, in_planes, out_planes, groups=groups, stride=stride, bn=bn, act=act
        )
        return c

    def conv7x7(self, in_planes, out_planes, stride=1, groups=1, bn=False, act=False):
        """7x7 convolution with padding"""
        c = self.conv(
            7, in_planes, out_planes, groups=groups, stride=stride, bn=bn, act=act
        )
        return c

    def conv5x5(self, in_planes, out_planes, stride=1, groups=1, bn=False, act=False):
        """5x5 convolution with padding"""
        c = self.conv(
            5, in_planes, out_planes, groups=groups, stride=stride, bn=bn, act=act
        )
        return c

    def batchnorm(self, planes, zero_init=False):
        bn_cfg = {}
        if self.config.bn_momentum is not None:
            bn_cfg["momentum"] = self.config.bn_momentum
        if self.config.bn_epsilon is not None:
            bn_cfg["eps"] = self.config.bn_epsilon

        bn = nn.BatchNorm2d(planes, **bn_cfg)
        # zero-init gamma makes the branch start as a no-op
        # (useful for the last BN of a residual block)
        gamma_init_val = 0 if zero_init else 1
        nn.init.constant_(bn.weight, gamma_init_val)
        nn.init.constant_(bn.bias, 0)

        return bn

    def activation(self):
        return {
            "silu": lambda: nn.SiLU(inplace=True),
            "relu": lambda: nn.ReLU(inplace=True),
            "onnx-silu": ONNXSiLU,
        }[self.config.activation]()


# LayerBuilder }}}
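
# LayerBuilder usage sketch {{{
# Illustrative only: a minimal sketch of driving LayerBuilder. The helper name,
# channel sizes and input shape below are arbitrary and not part of the original API.
def _example_layer_builder_usage():
    builder = LayerBuilder(LayerBuilder.Config(activation="relu", conv_init="fan_in"))
    # 7x7 stride-2 stem packed into a single nn.Sequential (conv -> bn -> act)
    stem = builder.conv7x7(3, 64, stride=2, bn=True, act=True)
    # a plain 3x3 conv without bn/act returns the bare nn.Conv2d
    conv = builder.conv3x3(64, 64)
    x = torch.randn(1, 3, 224, 224)
    y = conv(stem(x))  # -> shape (1, 64, 112, 112)
    return y
# }}}
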
# LambdaLayer {{{
class LambdaLayer(nn.Module):
    def __init__(self, lmbd):
        super().__init__()
        self.lmbd = lmbd

    def forward(self, x):
        return self.lmbd(x)


# }}}

# SqueezeAndExcitation {{{
class SqueezeAndExcitation(nn.Module):
    def __init__(self, in_channels, squeeze, activation):
        super(SqueezeAndExcitation, self).__init__()
        self.squeeze = nn.Linear(in_channels, squeeze)
        self.expand = nn.Linear(squeeze, in_channels)
        self.activation = activation
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self._attention(x)

    def _attention(self, x):
        # global average pool -> squeeze -> expand -> per-channel sigmoid gate
        out = torch.mean(x, [2, 3])
        out = self.squeeze(out)
        out = self.activation(out)
        out = self.expand(out)
        out = self.sigmoid(out)
        out = out.unsqueeze(2).unsqueeze(3)
        return out


class SqueezeAndExcitationTRT(nn.Module):
    def __init__(self, in_channels, squeeze, activation):
        super(SqueezeAndExcitationTRT, self).__init__()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
        self.expand = nn.Conv2d(squeeze, in_channels, 1)
        self.activation = activation
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self._attention(x)

    def _attention(self, x):
        # same gating as SqueezeAndExcitation, expressed with 1x1 convs for TensorRT
        out = self.pooling(x)
        out = self.squeeze(out)
        out = self.activation(out)
        out = self.expand(out)
        out = self.sigmoid(out)
        return out


# }}}

# EMA {{{
class EMA:
    """Exponential moving average of a module's parameters, written into `module_ema` in place."""

    def __init__(self, mu, module_ema):
        self.mu = mu
        self.module_ema = module_ema

    def __call__(self, module, step=None):
        if step is None:
            mu = self.mu
        else:
            # warm-up: use a smaller effective decay for the first optimization steps
            mu = min(self.mu, (1.0 + step) / (10 + step))

        def strip_module(s: str) -> str:
            # name-mapping hook; currently a no-op
            return s

        mesd = self.module_ema.state_dict()
        with torch.no_grad():
            for name, x in module.state_dict().items():
                if name.endswith("num_batches_tracked"):
                    continue
                n = strip_module(name)
                mesd[n].mul_(mu)
                mesd[n].add_((1.0 - mu) * x)


# }}}

# ONNXSiLU {{{
# torch.nn.SiLU is not supported by the ONNX exporter, so exported models have to use
# this explicit implementation instead (it needs 15-20% more GPU memory).
class ONNXSiLU(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ONNXSiLU, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


# }}}


class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
    def __init__(self, in_channels, squeeze, activation, quantized=False):
        super().__init__(in_channels, squeeze, activation)
        self.quantized = quantized
        if quantized:
            assert quant_nn is not None, "pytorch_quantization is not available"
            self.mul_a_quantizer = quant_nn.TensorQuantizer(
                quant_nn.QuantConv2d.default_quant_desc_input
            )
            self.mul_b_quantizer = quant_nn.TensorQuantizer(
                quant_nn.QuantConv2d.default_quant_desc_input
            )
        else:
            self.mul_a_quantizer = nn.Identity()
            self.mul_b_quantizer = nn.Identity()

    def forward(self, x):
        out = self._attention(x)
        if not self.quantized:
            return out * x
        else:
            # quantize both inputs of the multiplication for QAT
            x_quant = self.mul_a_quantizer(out)
            return x_quant * self.mul_b_quantizer(x)


class SequentialSqueezeAndExcitationTRT(SqueezeAndExcitationTRT):
    def __init__(self, in_channels, squeeze, activation, quantized=False):
        super().__init__(in_channels, squeeze, activation)
        self.quantized = quantized
        if quantized:
            assert quant_nn is not None, "pytorch_quantization is not available"
            self.mul_a_quantizer = quant_nn.TensorQuantizer(
                quant_nn.QuantConv2d.default_quant_desc_input
            )
            self.mul_b_quantizer = quant_nn.TensorQuantizer(
                quant_nn.QuantConv2d.default_quant_desc_input
            )
        else:
            self.mul_a_quantizer = nn.Identity()
            self.mul_b_quantizer = nn.Identity()

    def forward(self, x):
        out = self._attention(x)
        if not self.quantized:
            return out * x
        else:
            # quantize both inputs of the multiplication for QAT
            x_quant = self.mul_a_quantizer(out)
            return x_quant * self.mul_b_quantizer(x)


class StochasticDepthResidual(nn.Module):
    """Residual addition with stochastic depth: during training the branch `x` is
    dropped (and rescaled) with probability 1 - survival_prob."""

    def __init__(self, survival_prob: float):
        super().__init__()
        self.survival_prob = survival_prob
        self.register_buffer("mask", torch.ones(()), persistent=False)

    def forward(self, residual: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return torch.add(residual, other=x)
        else:
            with torch.no_grad():
                # sample the keep/drop mask from the constant ones buffer
                mask = F.dropout(
                    self.mask,
                    p=1 - self.survival_prob,
                    training=self.training,
                    inplace=False,
                )
            return torch.addcmul(residual, mask, x)


class Flatten(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # drop the trailing 1x1 spatial dimensions left by global pooling
        return x.squeeze(-1).squeeze(-1)
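
# Composition usage sketch {{{
# Illustrative only: one plausible way the SE, stochastic-depth and EMA helpers
# compose inside a residual unit. The helper name, layer sizes, survival probability
# and decay below are arbitrary and not part of the original API.
def _example_residual_and_ema_usage():
    builder = LayerBuilder(LayerBuilder.Config(activation="silu"))
    branch = nn.Sequential(
        OrderedDict(
            [
                ("conv", builder.conv3x3(32, 32, bn=True, act=True)),
                # TRT-friendly SE variant multiplies the attention map inside forward()
                ("se", SequentialSqueezeAndExcitationTRT(32, 8, nn.SiLU(inplace=True))),
            ]
        )
    )
    add = StochasticDepthResidual(survival_prob=0.8)
    x = torch.randn(2, 32, 16, 16)
    y = add(x, branch(x))  # residual + (randomly dropped, rescaled) branch output

    # EMA keeps an exponentially averaged shadow copy of the branch's parameters
    branch_ema = copy.deepcopy(branch)
    ema = EMA(mu=0.999, module_ema=branch_ema)
    ema(branch, step=0)  # early steps use a smaller effective decay
    return y
# }}}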