# Copyright 2017-2018 The Apache Software Foundation
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import mxnet as mx
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn


def add_model_args(parser):
    model = parser.add_argument_group('Model')
    model.add_argument('--arch', default='resnetv15',
                       choices=['resnetv1', 'resnetv15',
                                'resnextv1', 'resnextv15',
                                'xception'],
                       help='model architecture')
    model.add_argument('--num-layers', type=int, default=50,
                       help='number of layers in the neural network, '
                            'required by some networks such as resnet')
    model.add_argument('--num-groups', type=int, default=32,
                       help='number of groups for grouped convolutions, '
                            'required by some networks such as resnext')
    model.add_argument('--num-classes', type=int, default=1000,
                       help='the number of classes')
    model.add_argument('--batchnorm-eps', type=float, default=1e-5,
                       help='the amount added to the batchnorm variance to prevent output explosion.')
    model.add_argument('--batchnorm-mom', type=float, default=0.9,
                       help='the leaky-integrator factor controlling the batchnorm mean and variance.')
    model.add_argument('--fuse-bn-relu', type=int, default=0,
                       help='have batchnorm kernel perform activation relu')
    model.add_argument('--fuse-bn-add-relu', type=int, default=0,
                       help='have batchnorm kernel perform add followed by activation relu')
    return model
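# Usage sketch (illustrative addition, not part of the original file): the
# argument group above is meant to be attached to a training script's
# top-level parser, e.g.:
#
#   import argparse
#   parser = argparse.ArgumentParser(description='Train an image classifier')
#   add_model_args(parser)
#   args = parser.parse_args(['--arch', 'resnetv15', '--num-layers', '50'])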
class Builder:
    def __init__(self, dtype, input_layout, conv_layout, bn_layout,
                 pooling_layout, bn_eps, bn_mom, fuse_bn_relu, fuse_bn_add_relu):
        self.dtype = dtype
        self.input_layout = input_layout
        self.conv_layout = conv_layout
        self.bn_layout = bn_layout
        self.pooling_layout = pooling_layout
        self.bn_eps = bn_eps
        self.bn_mom = bn_mom
        self.fuse_bn_relu = fuse_bn_relu
        self.fuse_bn_add_relu = fuse_bn_add_relu
        self.act_type = 'relu'
        # The last batchnorm of a residual branch is zero-initialized so that
        # each block starts out close to an identity mapping.
        self.bn_gamma_initializer = lambda last: 'zeros' if last else 'ones'
        self.linear_initializer = lambda groups=1: mx.init.Xavier(
            rnd_type='gaussian', factor_type="in",
            magnitude=2 * (groups ** 0.5))
        # Tracks the layout of the data as layers are created, so that
        # Transpose blocks are only inserted where the layout actually changes.
        self.last_layout = self.input_layout

    def copy(self):
        return copy.copy(self)

    def batchnorm(self, last=False):
        gamma_initializer = self.bn_gamma_initializer(last)
        bn_axis = 3 if self.bn_layout == 'NHWC' else 1
        return self.sequence(
            self.transpose(self.bn_layout),
            nn.BatchNorm(axis=bn_axis,
                         momentum=self.bn_mom,
                         epsilon=self.bn_eps,
                         gamma_initializer=gamma_initializer,
                         running_variance_initializer=gamma_initializer)
        )

    def batchnorm_add_relu(self, last=False):
        gamma_initializer = self.bn_gamma_initializer(last)
        if self.fuse_bn_add_relu:
            bn_axis = 3 if self.bn_layout == 'NHWC' else 1
            return self.sequence(
                self.transpose(self.bn_layout),
                BatchNormAddRelu(axis=bn_axis,
                                 momentum=self.bn_mom,
                                 epsilon=self.bn_eps,
                                 act_type=self.act_type,
                                 gamma_initializer=gamma_initializer,
                                 running_variance_initializer=gamma_initializer)
            )
        return NonFusedBatchNormAddRelu(self, last=last)

    def batchnorm_relu(self, last=False):
        gamma_initializer = self.bn_gamma_initializer(last)
        if self.fuse_bn_relu:
            bn_axis = 3 if self.bn_layout == 'NHWC' else 1
            # The act_type argument is a fused-kernel extension available in
            # NVIDIA's MXNet builds; with --fuse-bn-relu 0 (the default) the
            # non-fused BatchNorm + Activation pair below is used instead.
            return self.sequence(
                self.transpose(self.bn_layout),
                nn.BatchNorm(axis=bn_axis,
                             momentum=self.bn_mom,
                             epsilon=self.bn_eps,
                             act_type=self.act_type,
                             gamma_initializer=gamma_initializer,
                             running_variance_initializer=gamma_initializer)
            )
        return self.sequence(self.batchnorm(last=last), self.activation())

    def activation(self):
        return nn.Activation(self.act_type)

    def global_avg_pool(self):
        return self.sequence(
            self.transpose(self.pooling_layout),
            nn.GlobalAvgPool2D(layout=self.pooling_layout)
        )

    def max_pool(self, pool_size, strides=1, padding=True):
        padding = pool_size // 2 if padding is True else int(padding)
        return self.sequence(
            self.transpose(self.pooling_layout),
            nn.MaxPool2D(pool_size, strides=strides, padding=padding,
                         layout=self.pooling_layout)
        )

    def conv(self, channels, kernel_size, padding=True, strides=1, groups=1,
             in_channels=0):
        padding = kernel_size // 2 if padding is True else int(padding)
        initializer = self.linear_initializer(groups=groups)
        return self.sequence(
            self.transpose(self.conv_layout),
            nn.Conv2D(channels, kernel_size=kernel_size, strides=strides,
                      padding=padding, use_bias=False, groups=groups,
                      in_channels=in_channels, layout=self.conv_layout,
                      weight_initializer=initializer)
        )

    def separable_conv(self, channels, kernel_size, in_channels, padding=True,
                       strides=1):
        # Depthwise conv (groups == in_channels) followed by a 1x1 pointwise conv.
        return self.sequence(
            self.conv(in_channels, kernel_size, padding=padding,
                      strides=strides, groups=in_channels,
                      in_channels=in_channels),
            self.conv(channels, 1, in_channels=in_channels)
        )

    def dense(self, units, in_units=0):
        return nn.Dense(units, in_units=in_units,
                        weight_initializer=self.linear_initializer())

    def transpose(self, to_layout):
        # Returns None (an elided transpose) when the data is already in the
        # requested layout.
        if self.last_layout == to_layout:
            return None
        ret = Transpose(self.last_layout, to_layout)
        self.last_layout = to_layout
        return ret

    def sequence(self, *seq):
        # Drop elided transposes and collapse single-element sequences.
        seq = list(filter(lambda x: x is not None, seq))
        if len(seq) == 1:
            return seq[0]
        ret = nn.HybridSequential()
        ret.add(*seq)
        return ret
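# Illustrative Builder usage (assumption, not part of the original file): a
# Builder configured with NCHW input but NHWC convolutions inserts a single
# Transpose in front of the first conv and tracks the layout thereafter:
#
#   b = Builder(dtype='float16', input_layout='NCHW', conv_layout='NHWC',
#               bn_layout='NHWC', pooling_layout='NHWC', bn_eps=1e-5,
#               bn_mom=0.9, fuse_bn_relu=0, fuse_bn_add_relu=0)
#   block = b.conv(64, 3)  # HybridSequential: Transpose(NCHW -> NHWC), Conv2D
#   b.conv(64, 3)          # subsequent convs get no Transpose: already NHWC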
class Transpose(HybridBlock):
    def __init__(self, from_layout, to_layout):
        super().__init__()
        supported_layouts = ['NCHW', 'NHWC']
        if from_layout not in supported_layouts:
            raise ValueError('Not prepared to handle layout: {}'.format(from_layout))
        if to_layout not in supported_layouts:
            raise ValueError('Not prepared to handle layout: {}'.format(to_layout))
        self.from_layout = from_layout
        self.to_layout = to_layout

    def hybrid_forward(self, F, x):
        # Insert a transpose only if from_layout and to_layout differ.
        if self.from_layout == 'NCHW' and self.to_layout == 'NHWC':
            return F.transpose(x, axes=(0, 2, 3, 1))
        elif self.from_layout == 'NHWC' and self.to_layout == 'NCHW':
            return F.transpose(x, axes=(0, 3, 1, 2))
        else:
            return x

    def __repr__(self):
        s = '{name}({content})'
        if self.from_layout == self.to_layout:
            content = 'passthrough ' + self.from_layout
        else:
            content = self.from_layout + ' -> ' + self.to_layout
        return s.format(name=self.__class__.__name__, content=content)


class LayoutWrapper(HybridBlock):
    def __init__(self, op, io_layout, op_layout, **kwargs):
        super(LayoutWrapper, self).__init__(**kwargs)
        with self.name_scope():
            self.layout1 = Transpose(io_layout, op_layout)
            self.op = op
            self.layout2 = Transpose(op_layout, io_layout)

    def hybrid_forward(self, F, *x):
        return self.layout2(self.op(*(self.layout1(y) for y in x)))


class BatchNormAddRelu(nn.BatchNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self._kwargs.pop('act_type') != 'relu':
            raise ValueError('BatchNormAddRelu can be used only with ReLU as activation')

    def hybrid_forward(self, F, x, y, gamma, beta, running_mean, running_var):
        # F.BatchNormAddRelu is a fused operator shipped with NVIDIA's MXNet
        # builds; this path is only reached when --fuse-bn-add-relu is enabled.
        return F.BatchNormAddRelu(data=x, addend=y, gamma=gamma, beta=beta,
                                  moving_mean=running_mean, moving_var=running_var,
                                  name='fwd', **self._kwargs)


class NonFusedBatchNormAddRelu(HybridBlock):
    def __init__(self, builder, **kwargs):
        super().__init__()
        self.bn = builder.batchnorm(**kwargs)
        self.act = builder.activation()

    def hybrid_forward(self, F, x, y):
        return self.act(self.bn(x) + y)


# Blocks

class ResNetBasicBlock(HybridBlock):
    def __init__(self, builder, channels, stride, downsample=False, in_channels=0,
                 version='1', resnext_groups=None, **kwargs):
        super().__init__()
        assert not resnext_groups

        self.transpose = builder.transpose(builder.conv_layout)
        # The residual branch and the downsample branch both start from the
        # same layout, so snapshot the builder's layout state here and restore
        # it before building the downsample path.
        builder_copy = builder.copy()

        body = [
            builder.conv(channels, 3, strides=stride, in_channels=in_channels),
            builder.batchnorm_relu(),
            builder.conv(channels, 3),
        ]

        self.body = builder.sequence(*body)
        self.bn_add_relu = builder.batchnorm_add_relu(last=True)

        builder = builder_copy
        if downsample:
            self.downsample = builder.sequence(
                builder.conv(channels, 1, strides=stride, in_channels=in_channels),
                builder.batchnorm()
            )
        else:
            self.downsample = None

    def hybrid_forward(self, F, x):
        if self.transpose is not None:
            x = self.transpose(x)
        residual = x
        x = self.body(x)
        if self.downsample:
            residual = self.downsample(residual)
        x = self.bn_add_relu(x, residual)
        return x


class ResNetBottleNeck(HybridBlock):
    def __init__(self, builder, channels, stride, downsample=False, in_channels=0,
                 version='1', resnext_groups=None):
        super().__init__()
        # v1 puts the downsampling stride on the first 1x1 conv; v1.5 moves it
        # to the 3x3 conv.
        stride1 = stride if version == '1' else 1
        stride2 = 1 if version == '1' else stride

        mult = 2 if resnext_groups else 1
        groups = resnext_groups or 1

        self.transpose = builder.transpose(builder.conv_layout)
        builder_copy = builder.copy()

        body = [
            builder.conv(channels * mult // 4, 1, strides=stride1,
                         in_channels=in_channels),
            builder.batchnorm_relu(),
            # groups was computed but never passed in the original; the 3x3
            # conv needs it for the resnext variants to actually be grouped.
            builder.conv(channels * mult // 4, 3, strides=stride2, groups=groups),
            builder.batchnorm_relu(),
            builder.conv(channels, 1)
        ]

        self.body = builder.sequence(*body)
        self.bn_add_relu = builder.batchnorm_add_relu(last=True)

        builder = builder_copy
        if downsample:
            self.downsample = builder.sequence(
                builder.conv(channels, 1, strides=stride, in_channels=in_channels),
                builder.batchnorm()
            )
        else:
            self.downsample = None

    def hybrid_forward(self, F, x):
        if self.transpose is not None:
            x = self.transpose(x)
        residual = x
        x = self.body(x)
        if self.downsample:
            residual = self.downsample(residual)
        x = self.bn_add_relu(x, residual)
        return x
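# Illustrative block construction (assumption, not part of the original file):
# a strided v1.5 bottleneck halves the spatial size in its 3x3 conv while the
# 1x1 shortcut conv matches both channels and stride:
#
#   b = Builder(dtype='float32', input_layout='NCHW', conv_layout='NCHW',
#               bn_layout='NCHW', pooling_layout='NCHW', bn_eps=1e-5,
#               bn_mom=0.9, fuse_bn_relu=0, fuse_bn_add_relu=0)
#   blk = ResNetBottleNeck(b, channels=512, stride=2, downsample=True,
#                          in_channels=256, version='1.5')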
class XceptionBlock(HybridBlock):
    def __init__(self, builder, definition, in_channels, relu_at_beginning=True):
        super().__init__()
        self.transpose = builder.transpose(builder.conv_layout)
        builder_copy = builder.copy()

        body = []
        if relu_at_beginning:
            body.append(builder.activation())
        last_channels = in_channels
        # A positive entry in the definition adds a separable conv with that
        # many channels; a non-positive entry stands for a strided max-pool
        # and forces a strided 1x1-conv shortcut below.
        for channels1, channels2 in zip(definition, definition[1:] + [0]):
            if channels1 > 0:
                body.append(builder.separable_conv(channels1, 3,
                                                   in_channels=last_channels))
                if channels2 > 0:
                    body.append(builder.batchnorm_relu())
                else:
                    body.append(builder.batchnorm(last=True))
                last_channels = channels1
            else:
                body.append(builder.max_pool(3, 2))
        self.body = builder.sequence(*body)

        builder = builder_copy
        if any(map(lambda x: x <= 0, definition)):
            self.shortcut = builder.sequence(
                builder.conv(last_channels, 1, strides=2, in_channels=in_channels),
                builder.batchnorm(),
            )
        else:
            self.shortcut = builder.sequence()

    def hybrid_forward(self, F, x):
        # Apply the layout transpose (if any) before both branches; the
        # original created self.transpose but never used it in the forward.
        if self.transpose is not None:
            x = self.transpose(x)
        return self.shortcut(x) + self.body(x)


# Nets

class ResNet(HybridBlock):
    def __init__(self, builder, block, layers, channels, classes=1000,
                 version='1', resnext_groups=None):
        super().__init__()
        assert len(layers) == len(channels) - 1
        self.version = version
        with self.name_scope():
            features = [
                builder.conv(channels[0], 7, strides=2),
                builder.batchnorm_relu(),
                builder.max_pool(3, 2),
            ]

            for i, num_layer in enumerate(layers):
                stride = 1 if i == 0 else 2
                features.append(self.make_layer(builder, block, num_layer,
                                                channels[i + 1], stride,
                                                in_channels=channels[i],
                                                resnext_groups=resnext_groups))

            features.append(builder.global_avg_pool())

            self.features = builder.sequence(*features)
            self.output = builder.dense(classes, in_units=channels[-1])

    def make_layer(self, builder, block, layers, channels, stride,
                   in_channels=0, resnext_groups=None):
        layer = []
        # Only the first block of a stage downsamples and changes channels.
        layer.append(block(builder, channels, stride, channels != in_channels,
                           in_channels=in_channels, version=self.version,
                           resnext_groups=resnext_groups))
        for _ in range(layers - 1):
            layer.append(block(builder, channels, 1, False,
                               in_channels=channels, version=self.version,
                               resnext_groups=resnext_groups))
        return builder.sequence(*layer)

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.output(x)
        return x


class Xception(HybridBlock):
    def __init__(self, builder,
                 definition=([32, 64],
                             [[128, 128, 0], [256, 256, 0], [728, 728, 0],
                              *([[728, 728, 728]] * 8), [728, 1024, 0]],
                             [1536, 2048]),
                 classes=1000):
        super().__init__()

        definition1, definition2, definition3 = definition

        with self.name_scope():
            features = []
            last_channels = 0
            # Entry flow: plain convolutions.
            for i, channels in enumerate(definition1):
                features += [
                    builder.conv(channels, 3, strides=(2 if i == 0 else 1),
                                 in_channels=last_channels),
                    builder.batchnorm_relu(),
                ]
                last_channels = channels
            # Middle flow: residual separable-conv blocks.
            for i, block_definition in enumerate(definition2):
                features.append(XceptionBlock(builder, block_definition,
                                              in_channels=last_channels,
                                              relu_at_beginning=(i != 0)))
                last_channels = list(filter(lambda x: x > 0, block_definition))[-1]
            # Exit flow: separable convolutions.
            for i, channels in enumerate(definition3):
                features += [
                    builder.separable_conv(channels, 3, in_channels=last_channels),
                    builder.batchnorm_relu(),
                ]
                last_channels = channels

            features.append(builder.global_avg_pool())

            self.features = builder.sequence(*features)
            self.output = builder.dense(classes, in_units=last_channels)

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.output(x)
        return x


# (block class, number of blocks per stage, channels: stem + per-stage outputs)
resnet_spec = {18:  (ResNetBasicBlock, [2, 2, 2, 2],  [64, 64, 128, 256, 512]),
               34:  (ResNetBasicBlock, [3, 4, 6, 3],  [64, 64, 128, 256, 512]),
               50:  (ResNetBottleNeck, [3, 4, 6, 3],  [64, 256, 512, 1024, 2048]),
               101: (ResNetBottleNeck, [3, 4, 23, 3], [64, 256, 512, 1024, 2048]),
               152: (ResNetBottleNeck, [3, 8, 36, 3], [64, 256, 512, 1024, 2048])}
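# Illustrative reading of resnet_spec (comment added for clarity): for
# ResNet-50, [3, 4, 6, 3] bottleneck blocks per stage operate on stage outputs
# [256, 512, 1024, 2048] after a 64-channel stem; 16 blocks of 3 convs each
# plus the stem conv and the final dense layer give the "50".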
def create_resnet(builder, version, num_layers=50, resnext=False, num_groups=32,
                  classes=1000):
    assert num_layers in resnet_spec, \
        "Invalid number of layers: {}. Options are {}".format(
            num_layers, str(resnet_spec.keys()))
    block_class, layers, channels = resnet_spec[num_layers]
    assert not resnext or num_layers >= 50, \
        "Cannot create resnext with fewer than 50 layers"
    # The original referenced a module-level `args` that is undefined here;
    # the group count is taken as an explicit num_groups parameter instead.
    net = ResNet(builder, block_class, layers, channels, version=version,
                 resnext_groups=num_groups if resnext else None)
    return net


class fp16_model(mx.gluon.block.HybridBlock):
    def __init__(self, net, **kwargs):
        super(fp16_model, self).__init__(**kwargs)
        with self.name_scope():
            self._net = net

    def hybrid_forward(self, F, x):
        y = self._net(x)
        # Cast the float16 network output back to float32 so that the loss is
        # computed in full precision.
        y = F.cast(y, dtype='float32')
        return y


def get_model(arch, num_classes, num_layers, image_shape, dtype, amp,
              input_layout, conv_layout, batchnorm_layout, pooling_layout,
              batchnorm_eps, batchnorm_mom, fuse_bn_relu, fuse_bn_add_relu,
              **kwargs):
    builder = Builder(
        dtype=dtype,
        input_layout=input_layout,
        conv_layout=conv_layout,
        bn_layout=batchnorm_layout,
        pooling_layout=pooling_layout,
        bn_eps=batchnorm_eps,
        bn_mom=batchnorm_mom,
        fuse_bn_relu=fuse_bn_relu,
        fuse_bn_add_relu=fuse_bn_add_relu,
    )

    if arch.startswith('resnet') or arch.startswith('resnext'):
        version = '1' if arch in {'resnetv1', 'resnextv1'} else '1.5'
        net = create_resnet(
            builder=builder,
            version=version,
            resnext=arch.startswith('resnext'),
            num_layers=num_layers,
            num_groups=kwargs.get('num_groups', 32),
            classes=num_classes,
        )
    elif arch == 'xception':
        net = Xception(builder, classes=num_classes)
    else:
        raise ValueError('Unknown architecture: {}'.format(arch))

    net.hybridize(static_shape=True, static_alloc=True)

    if not amp:
        net.cast(dtype)
        if dtype == 'float16':
            net = fp16_model(net)

    return net
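
if __name__ == '__main__':
    # Minimal smoke test (illustrative addition, not part of the original
    # file): build a ResNet-50 v1.5 in float32 with NCHW layouts everywhere,
    # which avoids the NVIDIA fused kernels, and run one forward pass.
    net = get_model(arch='resnetv15', num_classes=1000, num_layers=50,
                    image_shape=(3, 224, 224), dtype='float32', amp=False,
                    input_layout='NCHW', conv_layout='NCHW',
                    batchnorm_layout='NCHW', pooling_layout='NCHW',
                    batchnorm_eps=1e-5, batchnorm_mom=0.9,
                    fuse_bn_relu=0, fuse_bn_add_relu=0)
    net.initialize(mx.init.Xavier())
    x = mx.nd.random.uniform(shape=(1, 3, 224, 224))
    print(net(x).shape)  # expected: (1, 1000)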