import torch
from torch import nn
import torch.nn.functional as F


class ConvNormLReLU(nn.Sequential):
    """Explicit padding -> Conv2d -> GroupNorm(1 group) -> LeakyReLU(0.2).

    Padding is done by a dedicated padding layer, so the convolution itself
    uses ``padding=0``. Child order is fixed as
    [pad, conv, norm, act] (indices 0..3). Do not reorder it: the indices
    form part of the ``state_dict`` keys.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Passed to the padding layer. This is either an int or a
            4-tuple ``(left, right, top, bottom)``.
        pad_mode: Selects the padding layer:
            ``"zero"`` -> ``ZeroPad2d``,
            ``"same"`` -> ``ReplicationPad2d``,
            ``"reflect"`` -> ``ReflectionPad2d``.
        groups: Convolution groups. Setting it equal to the channel count
            gives a depthwise convolution.
        bias: Whether the convolution has a bias term.

    Raises:
        NotImplementedError: If ``pad_mode`` is not one of the keys above.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1,
                 pad_mode="reflect", groups=1, bias=False):
        pad_layer = {
            "zero": nn.ZeroPad2d,
            "same": nn.ReplicationPad2d,
            "reflect": nn.ReflectionPad2d,
        }
        try:
            pad_cls = pad_layer[pad_mode]
        except KeyError:
            # Keep the original exception type so existing callers still
            # catch it, but say what went wrong.
            raise NotImplementedError(
                f"pad_mode must be one of {sorted(pad_layer)}, got {pad_mode!r}"
            ) from None

        super().__init__(
            pad_cls(padding),
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride,
                      padding=0, groups=groups, bias=bias),
            # A single group normalises each sample over all channels and
            # spatial positions together.
            nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )


class InvertedResBlock(nn.Module):
    """Expand (1x1) -> depthwise 3x3 -> project (1x1) block.

    Adds an identity shortcut when ``in_ch == out_ch``. The projection is
    followed by GroupNorm with no activation.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        expansion_ratio: The bottleneck width is
            ``round(in_ch * expansion_ratio)``. When the ratio is 1, the 1x1
            expansion layer is omitted.
    """

    def __init__(self, in_ch, out_ch, expansion_ratio=2):
        super().__init__()

        self.use_res_connect = in_ch == out_ch
        bottleneck = int(round(in_ch * expansion_ratio))
        layers = []
        if expansion_ratio != 1:
            layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))

        # dw: groups == channels makes this a depthwise convolution.
        layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
        # pw: linear 1x1 projection (no activation afterwards).
        layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
        layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))

        # Attribute name and layer order determine state_dict keys; keep them.
        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the block and add the shortcut when channel counts match."""
        out = self.layers(input)
        if self.use_res_connect:
            out = input + out
        return out


def _upsample(x, size, align_corners):
    """Bilinearly upsample ``x`` during the decoder stages.

    With ``align_corners=True``, the result is resized to exactly ``size``.
    Otherwise ``x`` is scaled by a factor of 2 with ``align_corners=False``,
    and ``size`` is ignored.
    """
    if align_corners:
        return F.interpolate(x, size, mode="bilinear", align_corners=True)
    return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)


class Generator(nn.Module):
    """Image-to-image encoder/decoder that maps 3-channel input to 3-channel output.

    ``block_a`` and ``block_b`` each downsample by stride 2.
    ``block_c`` is a stack of inverted residual blocks.
    ``block_d`` and ``block_e`` each run after a bilinear upsample.
    ``out_layer`` projects to 3 channels and applies ``tanh``, so outputs
    lie in [-1, 1].

    Submodule names and ordering define the ``state_dict`` keys and must
    stay stable for checkpoint compatibility.
    """

    def __init__(self):
        super().__init__()

        # Padding (0, 1, 0, 1) = (left, right, top, bottom): pad only the
        # right and bottom edges before each stride-2 convolution.
        self.block_a = nn.Sequential(
            ConvNormLReLU(3, 32, kernel_size=7, padding=3),
            ConvNormLReLU(32, 64, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(64, 64),
        )

        self.block_b = nn.Sequential(
            ConvNormLReLU(64, 128, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(128, 128),
        )

        self.block_c = nn.Sequential(
            ConvNormLReLU(128, 128),
            InvertedResBlock(128, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            ConvNormLReLU(256, 128),
        )

        self.block_d = nn.Sequential(
            ConvNormLReLU(128, 128),
            ConvNormLReLU(128, 128),
        )

        self.block_e = nn.Sequential(
            ConvNormLReLU(128, 64),
            ConvNormLReLU(64, 64),
            ConvNormLReLU(64, 32, kernel_size=7, padding=3),
        )

        self.out_layer = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Tanh(),
        )

    def forward(self, input, align_corners=True):
        """Run the generator.

        Args:
            input: Image batch passed to ``block_a``. It is presumably shaped
                (N, 3, H, W), since the first conv expects 3 input channels.
            align_corners: When True, each upsample resizes to the exact
                spatial size recorded on the way down, so the output's H and
                W equal the input's. When False, each upsample is a plain 2x
                scale. In that case the output size matches the input only
                when H and W are multiples of 4.

        Returns:
            A 3-channel tensor with values in [-1, 1] (from ``tanh``).
        """
        out = self.block_a(input)
        half_size = out.size()[-2:]
        out = self.block_b(out)
        out = self.block_c(out)

        out = _upsample(out, half_size, align_corners)
        out = self.block_d(out)

        out = _upsample(out, input.size()[-2:], align_corners)
        out = self.block_e(out)

        out = self.out_layer(out)
        return out