
# (file listing metadata: 117 lines, 4.2 KiB)

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import torch.nn as nn
# Normalization layer classes keyed by the ``norm`` name accepted by ``get_norm``.
# Group normalization is intentionally absent: ``get_norm`` builds an
# ``nn.GroupNorm`` itself for any name containing "groupnorm".
normalizations = {
    "instancenorm3d": nn.InstanceNorm3d,
    "instancenorm2d": nn.InstanceNorm2d,
    "batchnorm3d": nn.BatchNorm3d,
    "batchnorm2d": nn.BatchNorm2d,
}
# Convolution classes keyed by "Conv{dim}d" / "ConvTranspose{dim}d".
# ``get_conv`` and ``get_transp_conv`` build these keys from their ``dim``
# argument, so only dim=2 and dim=3 are supported.
convolutions = {
    "Conv2d": nn.Conv2d,
    "Conv3d": nn.Conv3d,
    "ConvTranspose2d": nn.ConvTranspose2d,
    "ConvTranspose3d": nn.ConvTranspose3d,
}
def get_norm(name, out_channels):
    """Create a normalization layer with learnable affine parameters.

    Args:
        name: Normalization identifier. Any name containing the substring
            "groupnorm" yields ``nn.GroupNorm`` with 32 groups. PyTorch
            requires ``out_channels`` to be divisible by 32 in that case.
            Every other name is looked up in the module-level
            ``normalizations`` table, and unknown names raise ``KeyError``.
        out_channels: Number of channels the layer normalizes.

    Returns:
        The constructed ``nn.Module``.
    """
    if "groupnorm" not in name:
        norm_cls = normalizations[name]
        return norm_cls(out_channels, affine=True)
    return nn.GroupNorm(32, out_channels, affine=True)
def get_conv(in_channels, out_channels, kernel_size, stride, dim, bias=False):
    """Build an ``nn.Conv{dim}d`` whose padding comes from ``get_padding``.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        kernel_size: Kernel size, either a scalar or a per-axis sequence.
        stride: Stride, either a scalar or a per-axis sequence.
        dim: Spatial dimensionality. It selects "Conv{dim}d" from ``convolutions``.
        bias: Whether the convolution has a bias term. The default is no bias.

    Returns:
        The convolution module.
    """
    conv_cls = convolutions["Conv{}d".format(dim)]
    return conv_cls(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        get_padding(kernel_size, stride),
        bias=bias,
    )
def get_transp_conv(in_channels, out_channels, kernel_size, stride, dim):
    """Build an ``nn.ConvTranspose{dim}d`` with bias and matched padding.

    ``padding`` comes from ``get_padding``. ``output_padding`` comes from
    ``get_output_padding``, which computes 2 * padding + stride - kernel_size.
    Under PyTorch's transposed-conv output-size formula (dilation 1), that
    choice makes each spatial axis grow by exactly a factor of ``stride``.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        kernel_size: Kernel size, either a scalar or a per-axis sequence.
        stride: Stride, either a scalar or a per-axis sequence.
        dim: Spatial dimensionality. It selects "ConvTranspose{dim}d" from ``convolutions``.

    Returns:
        The transposed-convolution module.
    """
    transp_cls = convolutions["ConvTranspose{}d".format(dim)]
    pad = get_padding(kernel_size, stride)
    out_pad = get_output_padding(kernel_size, stride, pad)
    return transp_cls(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        pad,
        out_pad,
        bias=True,
    )
def get_padding(kernel_size, stride):
    """Compute convolution padding as int((kernel_size - stride + 1) / 2) per axis.

    For stride 1 and an odd kernel this is (kernel_size - 1) / 2, i.e.
    "same" padding. ``int`` truncates toward zero, so a negative
    intermediate value such as -0.5 becomes 0.

    Args:
        kernel_size: Scalar or per-axis sequence. It broadcasts against ``stride``.
        stride: Scalar or per-axis sequence.

    Returns:
        A plain ``int`` when the result has a single element. Otherwise a
        tuple of ints, one per axis.
    """
    kernels = np.atleast_1d(kernel_size)
    strides = np.atleast_1d(stride)
    halves = (kernels - strides + 1) / 2
    pads = tuple(map(int, halves))
    if len(pads) <= 1:
        return pads[0]
    return pads
def get_output_padding(kernel_size, stride, padding):
    """Compute transposed-conv ``output_padding`` as 2 * padding + stride - kernel_size.

    Args:
        kernel_size: Scalar or per-axis sequence.
        stride: Scalar or per-axis sequence.
        padding: Scalar or per-axis sequence, typically from ``get_padding``.

    Returns:
        A plain ``int`` when the result has a single element. Otherwise a
        tuple of ints, one per axis.
    """
    kernels, strides, pads = (
        np.atleast_1d(value) for value in (kernel_size, stride, padding)
    )
    extra = tuple(int(v) for v in 2 * pads + strides - kernels)
    if len(extra) <= 1:
        return extra[0]
    return extra
class ConvLayer(nn.Module):
    """Convolution, then normalization, then an in-place LeakyReLU.

    Required keyword arguments:
        dim: Spatial dimensionality, forwarded to ``get_conv``.
        norm: Normalization name, forwarded to ``get_norm``.
        negative_slope: Negative slope of the ``nn.LeakyReLU``.

    Other keyword arguments are accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
        super().__init__()
        spatial_dim = kwargs["dim"]
        self.conv = get_conv(in_channels, out_channels, kernel_size, stride, spatial_dim)
        norm_name = kwargs["norm"]
        self.norm = get_norm(norm_name, out_channels)
        slope = kwargs["negative_slope"]
        self.lrelu = nn.LeakyReLU(negative_slope=slope, inplace=True)

    def forward(self, data):
        return self.lrelu(self.norm(self.conv(data)))
class ConvBlock(nn.Module):
    """Two stacked ``ConvLayer``s. Only the first one applies ``stride``.

    The first layer maps ``in_channels`` to ``out_channels``. The second keeps
    ``out_channels`` and uses stride 1. ``**kwargs`` is forwarded unchanged
    to both layers.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
        super().__init__()
        self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, **kwargs)
        self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, 1, **kwargs)

    def forward(self, input_data):
        hidden = self.conv1(input_data)
        return self.conv2(hidden)
class UpsampleBlock(nn.Module):
    """Upsample with a transposed convolution, concatenate a skip tensor, then refine.

    Args:
        in_channels: Channel count of ``input_data``.
        out_channels: Channel count produced by the transposed convolution and
            by the block as a whole.
        kernel_size: Kernel size of the refining ``ConvBlock``.
        stride: Used as both kernel size and stride of the transposed
            convolution, i.e. the upsampling factor.
        **kwargs: Must contain ``dim``. Forwarded to ``ConvBlock``, which
            reads ``dim``, ``norm`` and ``negative_slope``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
        super(UpsampleBlock, self).__init__()
        self.transp_conv = get_transp_conv(in_channels, out_channels, stride, stride, kwargs["dim"])
        # conv_block consumes the channel-wise concatenation of the upsampled
        # tensor (out_channels) and skip_data, hence 2 * out_channels. That
        # means skip_data must itself carry out_channels channels.
        self.conv_block = ConvBlock(2 * out_channels, out_channels, kernel_size, 1, **kwargs)

    def forward(self, input_data, skip_data):
        out = self.transp_conv(input_data)
        # Concatenate along dim 1, the channel axis of PyTorch conv tensors.
        # The upsampled tensor comes first and the skip tensor second.
        out = torch.cat((out, skip_data), dim=1)
        out = self.conv_block(out)
        return out
class OutputBlock(nn.Module):
    """Final pointwise (kernel size 1, stride 1) convolution with bias initialized to zero.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        dim: Spatial dimensionality, forwarded to ``get_conv``.
    """

    def __init__(self, in_channels, out_channels, dim):
        super().__init__()
        pointwise = get_conv(in_channels, out_channels, kernel_size=1, stride=1, dim=dim, bias=True)
        nn.init.constant_(pointwise.bias, 0)
        self.conv = pointwise

    def forward(self, input_data):
        projected = self.conv(input_data)
        return projected