[ConvNets/PyT] TorchScriptable ConvNets

Andrzej Sulecki 2021-11-09 13:42:18 -08:00 committed by Krzysztof Kudrynski
parent 3d3250a3ae
commit 4f2c6922bd
12 changed files with 915 additions and 590 deletions

View File

@ -0,0 +1,2 @@
*.pth.tar
*.log

View File

@ -22,6 +22,7 @@ def add_parser_arguments(parser):
parser.add_argument(
"--weight-path", metavar="<path>", help="name of file in which to store weights"
)
parser.add_argument("--ema", action="store_true", default=False)
if __name__ == "__main__":
@ -30,12 +31,13 @@ if __name__ == "__main__":
add_parser_arguments(parser)
args = parser.parse_args()
checkpoint = torch.load(args.checkpoint_path, map_location=torch.device('cpu'))
checkpoint = torch.load(args.checkpoint_path, map_location=torch.device("cpu"))
key = "state_dict" if not args.ema else "ema_state_dict"
model_state_dict = {
k[len("module.") :] if "module." in k else k: v
for k, v in checkpoint[key].items()
}
print(f"Loaded model, acc : {checkpoint['best_prec1']}")
torch.save(model_state_dict, args.weight_path)
torch.save(model_state_dict, args.weight_path)
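
The script above collapses a training checkpoint into a flat state dict, stripping any DistributedDataParallel "module." prefix before saving. A minimal sketch of consuming the resulting file; weights.pth and model are placeholders for the --weight-path output and a network with a matching architecture:

import torch

weights = torch.load("weights.pth", map_location="cpu")  # file produced by this script
model.load_state_dict(weights)  # model: an instance of the matching architecture
model.eval()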

View File

@ -234,32 +234,29 @@ class Logger(object):
def log_metric(self, metric_name, val, n=1):
self.metrics[metric_name]["meter"].record(val, n=n)
def start_iteration(self, mode='train'):
if mode == 'val':
def start_iteration(self, mode="train"):
if mode == "val":
self.val_iteration += 1
elif mode == 'train':
elif mode == "train":
self.iteration += 1
elif mode == 'calib':
elif mode == "calib":
self.calib_iteration += 1
def end_iteration(self, mode='train'):
if mode == 'val':
def end_iteration(self, mode="train"):
if mode == "val":
it = self.val_iteration
elif mode == 'train':
elif mode == "train":
it = self.iteration
elif mode == 'calib':
elif mode == "calib":
it = self.calib_iteration
if it % self.print_interval == 0 or mode == 'calib':
metrics = {
n: m for n, m in self.metrics.items() if n.startswith(mode)
}
if mode == 'train':
if it % self.print_interval == 0 or mode == "calib":
metrics = {n: m for n, m in self.metrics.items() if n.startswith(mode)}
if mode == "train":
step = (self.epoch, self.iteration)
elif mode == 'val':
elif mode == "val":
step = (self.epoch, self.iteration, self.val_iteration)
elif mode == 'calib':
step = ('Calibration', self.calib_iteration)
elif mode == "calib":
step = ("Calibration", self.calib_iteration)
verbositys = {m["level"] for _, m in metrics.items()}
for ll in verbositys:
@ -282,12 +279,12 @@ class Logger(object):
self.val_iteration = 0
for n, m in self.metrics.items():
if not n.startswith('calib'):
if not n.startswith("calib"):
m["meter"].reset_epoch()
def end_epoch(self):
for n, m in self.metrics.items():
if not n.startswith('calib'):
if not n.startswith("calib"):
m["meter"].reset_iteration()
verbositys = {m["level"] for _, m in self.metrics.items()}
@ -302,12 +299,12 @@ class Logger(object):
self.calib_iteration = 0
for n, m in self.metrics.items():
if n.startswith('calib'):
if n.startswith("calib"):
m["meter"].reset_epoch()
def end_calibration(self):
for n, m in self.metrics.items():
if n.startswith('calib'):
if n.startswith("calib"):
m["meter"].reset_iteration()
def end(self):
@ -326,7 +323,7 @@ class Logger(object):
dllogger.flush()
def iteration_generator_wrapper(self, gen, mode='train'):
def iteration_generator_wrapper(self, gen, mode="train"):
for g in gen:
self.start_iteration(mode=mode)
yield g
@ -337,3 +334,155 @@ class Logger(object):
self.start_epoch()
yield g
self.end_epoch()
class Metrics:
ACC_METADATA = {"unit": "%", "format": ":.2f"}
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
TIME_METADATA = {"unit": "s", "format": ":.5f"}
LOSS_METADATA = {"format": ":.5f"}
LR_METADATA = {"format": ":.5f"}
def __init__(self, logger):
self.logger = logger
self.map = {}
def log(self, **kwargs):
if self.logger is None:
return
for k, v in kwargs.items():
tks = self.map.get(k, [k])
for tk in tks:
if isinstance(v, tuple):
self.logger.log_metric(tk, v[0], v[1])
else:
self.logger.log_metric(tk, v)
class TrainingMetrics(Metrics):
def __init__(self, logger):
super().__init__(logger)
if self.logger is not None:
self.map = {
"loss": ["train.loss"],
"compute_ips": ["train.compute_ips"],
"total_ips": ["train.total_ips"],
"data_time": ["train.data_time"],
"compute_time": ["train.compute_time"],
"lr": ["train.lr"],
}
logger.register_metric(
"train.loss",
LOSS_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.LOSS_METADATA,
)
logger.register_metric(
"train.compute_ips",
PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.IPS_METADATA,
)
logger.register_metric(
"train.total_ips",
PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.IPS_METADATA,
)
logger.register_metric(
"train.data_time",
PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
"train.compute_time",
PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
"train.lr",
LR_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
)
class ValidationMetrics(Metrics):
def __init__(self, logger, prefix):
super().__init__(logger)
if self.logger is not None:
self.map = {
"loss": [f"{prefix}.loss"],
"top1": [f"{prefix}.top1"],
"top5": [f"{prefix}.top5"],
"compute_ips": [f"{prefix}.compute_ips"],
"total_ips": [f"{prefix}.total_ips"],
"data_time": [f"{prefix}.data_time"],
"compute_time": [
f"{prefix}.compute_latency",
f"{prefix}.compute_latency_at100",
f"{prefix}.compute_latency_at99",
f"{prefix}.compute_latency_at95",
],
}
logger.register_metric(
f"{prefix}.top1",
ACC_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.ACC_METADATA,
)
logger.register_metric(
f"{prefix}.top5",
ACC_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.ACC_METADATA,
)
logger.register_metric(
f"{prefix}.loss",
LOSS_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.LOSS_METADATA,
)
logger.register_metric(
f"{prefix}.compute_ips",
PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.IPS_METADATA,
)
logger.register_metric(
f"{prefix}.total_ips",
PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.IPS_METADATA,
)
logger.register_metric(
f"{prefix}.data_time",
PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency",
PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency_at100",
LAT_100(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency_at99",
LAT_99(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency_at95",
LAT_95(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
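
The Metrics wrappers route keyword arguments through self.map to the dllogger metrics registered above; plain values are recorded with n=1, while a tuple is interpreted as (value, count). A rough usage sketch, assuming logger is an already-constructed Logger:

train_metrics = TrainingMetrics(logger)          # registers the train.* metrics
val_metrics = ValidationMetrics(logger, "val")   # registers the val.* metrics

train_metrics.log(loss=0.71, lr=0.256, compute_ips=(1450.0, 256))
val_metrics.log(top1=(76.2, 256), top5=(93.1, 256), compute_time=0.031)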

View File

@ -5,6 +5,7 @@ from typing import Optional
import torch
import warnings
from torch import nn
import torch.nn.functional as F
try:
from pytorch_quantization import nn as quant_nn
@ -143,30 +144,44 @@ class LambdaLayer(nn.Module):
# SqueezeAndExcitation {{{
class SqueezeAndExcitation(nn.Module):
def __init__(self, in_channels, squeeze, activation, use_conv=False):
def __init__(self, in_channels, squeeze, activation):
super(SqueezeAndExcitation, self).__init__()
if use_conv:
self.pooling = nn.AdaptiveAvgPool2d(1)
self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
self.expand = nn.Conv2d(squeeze, in_channels, 1)
else:
self.squeeze = nn.Linear(in_channels, squeeze)
self.expand = nn.Linear(squeeze, in_channels)
self.squeeze = nn.Linear(in_channels, squeeze)
self.expand = nn.Linear(squeeze, in_channels)
self.activation = activation
self.sigmoid = nn.Sigmoid()
self.use_conv = use_conv
def forward(self, x):
if self.use_conv:
out = self.pooling(x)
else:
out = torch.mean(x, [2, 3])
return self._attention(x)
def _attention(self, x):
out = torch.mean(x, [2, 3])
out = self.squeeze(out)
out = self.activation(out)
out = self.expand(out)
out = self.sigmoid(out)
out = out.unsqueeze(2).unsqueeze(3)
return out
class SqueezeAndExcitationTRT(nn.Module):
def __init__(self, in_channels, squeeze, activation):
super(SqueezeAndExcitationTRT, self).__init__()
self.pooling = nn.AdaptiveAvgPool2d(1)
self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
self.expand = nn.Conv2d(squeeze, in_channels, 1)
self.activation = activation
self.sigmoid = nn.Sigmoid()
def forward(self, x):
return self._attention(x)
def _attention(self, x):
out = self.pooling(x)
out = self.squeeze(out)
out = self.activation(out)
out = self.expand(out)
out = self.sigmoid(out)
if not self.use_conv:
out = out.unsqueeze(2).unsqueeze(3)
return out
@ -174,18 +189,9 @@ class SqueezeAndExcitation(nn.Module):
# EMA {{{
class EMA:
def __init__(self, mu):
def __init__(self, mu, module_ema):
self.mu = mu
self.shadow = {}
def state_dict(self):
return copy.deepcopy(self.shadow)
def load_state_dict(self, state_dict):
self.shadow = state_dict
def __len__(self):
return len(self.shadow)
self.module_ema = module_ema
def __call__(self, module, step=None):
if step is None:
@ -193,12 +199,17 @@ class EMA:
else:
mu = min(self.mu, (1.0 + step) / (10 + step))
for name, x in module.state_dict().items():
if name in self.shadow:
new_average = (1.0 - mu) * x + mu * self.shadow[name]
self.shadow[name] = new_average.clone()
else:
self.shadow[name] = x.clone()
def strip_module(s: str) -> str:
return s
mesd = self.module_ema.state_dict()
with torch.no_grad():
for name, x in module.state_dict().items():
if name.endswith("num_batches_tracked"):
continue
n = strip_module(name)
mesd[n].mul_(mu)
mesd[n].add_((1.0 - mu) * x)
# }}}
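
The reworked EMA no longer keeps a shadow dict; it updates a live copy of the model in place after every optimizer step, skipping num_batches_tracked buffers. A minimal sketch, where model, loader, criterion and optimizer stand in for the surrounding training code:

from copy import deepcopy

model_ema = deepcopy(model)               # holds the averaged weights
ema = EMA(mu=0.999, module_ema=model_ema)

for step, (images, labels) in enumerate(loader):
    criterion(model(images), labels).backward()
    optimizer.step()
    optimizer.zero_grad()
    ema(model, step=step)                 # w_ema = mu * w_ema + (1 - mu) * w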
@ -218,10 +229,8 @@ class ONNXSiLU(nn.Module):
class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
def __init__(
self, in_channels, squeeze, activation, quantized=False, use_conv=False
):
super().__init__(in_channels, squeeze, activation, use_conv=use_conv)
def __init__(self, in_channels, squeeze, activation, quantized=False):
super().__init__(in_channels, squeeze, activation)
self.quantized = quantized
if quantized:
assert quant_nn is not None, "pytorch_quantization is not available"
@ -231,10 +240,63 @@ class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
self.mul_b_quantizer = quant_nn.TensorQuantizer(
quant_nn.QuantConv2d.default_quant_desc_input
)
else:
self.mul_a_quantizer = nn.Identity()
self.mul_b_quantizer = nn.Identity()
def forward(self, x):
out = self._attention(x)
if not self.quantized:
return super().forward(x) * x
return out * x
else:
x_quant = self.mul_a_quantizer(super().forward(x))
x_quant = self.mul_a_quantizer(out)
return x_quant * self.mul_b_quantizer(x)
class SequentialSqueezeAndExcitationTRT(SqueezeAndExcitationTRT):
def __init__(self, in_channels, squeeze, activation, quantized=False):
super().__init__(in_channels, squeeze, activation)
self.quantized = quantized
if quantized:
assert quant_nn is not None, "pytorch_quantization is not available"
self.mul_a_quantizer = quant_nn.TensorQuantizer(
quant_nn.QuantConv2d.default_quant_desc_input
)
self.mul_b_quantizer = quant_nn.TensorQuantizer(
quant_nn.QuantConv2d.default_quant_desc_input
)
else:
self.mul_a_quantizer = nn.Identity()
self.mul_b_quantizer = nn.Identity()
def forward(self, x):
out = self._attention(x)
if not self.quantized:
return out * x
else:
x_quant = self.mul_a_quantizer(out)
return x_quant * self.mul_b_quantizer(x)
class StochasticDepthResidual(nn.Module):
def __init__(self, survival_prob: float):
super().__init__()
self.survival_prob = survival_prob
self.register_buffer("mask", torch.ones(()), persistent=False)
def forward(self, residual: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
if not self.training:
return torch.add(residual, other=x)
else:
with torch.no_grad():
mask = F.dropout(
self.mask,
p=1 - self.survival_prob,
training=self.training,
inplace=False,
)
return torch.addcmul(residual, mask, x)
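
StochasticDepthResidual replaces the old drop()/multiplication_factor logic with a TorchScript-friendly module: during training, F.dropout turns the unit mask buffer into either 0 or 1/survival_prob, and addcmul adds the scaled branch output to the residual; in eval mode it is a plain add. A small behavioural sketch with illustrative tensors:

import torch

sd = StochasticDepthResidual(survival_prob=0.8)
residual = torch.randn(2, 8, 4, 4)
x = torch.randn_like(residual)

sd.eval()
y = sd(residual, x)   # inference: residual + x

sd.train()
y = sd(residual, x)   # training: residual + m * x, with m = 0 (prob 0.2) or 1 / 0.8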
class Flatten(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.squeeze(-1).squeeze(-1)

View File

@ -31,11 +31,11 @@ except ImportError as e:
from .common import (
SqueezeAndExcitation,
ONNXSiLU,
SequentialSqueezeAndExcitation,
SequentialSqueezeAndExcitationTRT,
LayerBuilder,
LambdaLayer,
StochasticDepthResidual,
Flatten,
)
from .model import (
@ -206,6 +206,7 @@ class EfficientNet(nn.Module):
out_channels = arch.stem_channels
plc = 0
layers = []
for i, (k, s, r, e, c) in arch.enumerate():
layer, out_channels = self._make_layer(
block=arch.block,
@ -220,8 +221,8 @@ class EfficientNet(nn.Module):
trt=trt,
)
plc = plc + r
setattr(self, f"layer{i+1}", layer)
layers.append(layer)
self.layers = nn.Sequential(*layers)
self.features = self._make_features(out_channels, arch.feature_channels)
self.classifier = self._make_classifier(
arch.feature_channels, num_classes, dropout
@ -229,11 +230,7 @@ class EfficientNet(nn.Module):
def forward(self, x):
x = self.stem(x)
for i in range(self.num_layers):
fn = getattr(self, f"layer{i+1}")
x = fn(x)
x = self.layers(x)
x = self.features(x)
x = self.classifier(x)
@ -241,27 +238,34 @@ class EfficientNet(nn.Module):
def extract_features(self, x, layers=None):
if layers is None:
layers = [f"layer{i+1}" for i in range(self.num_layers)]
layers = [f"layer{i+1}" for i in range(self.num_layers)] + [
"features",
"classifier",
]
run = [
f"layer{i+1}"
i
for i in range(self.num_layers)
if "classifier" in layers
or "features" in layers
or any([f"layer{j+1}" in layers for j in range(i, self.num_layers)])
]
if "features" in layers or "classifier" in layers:
run.append("features")
if "classifier" in layers:
run.append("classifier")
output = {}
x = self.stem(x)
for l in run:
fn = getattr(self, l)
fn = self.layers[l]
x = fn(x)
if l in layers:
output[l] = x
if f"layer{l+1}" in layers:
output[f"layer{l+1}"] = x
if "features" in layers or "classifier" in layers:
x = self.features(x)
if "features" in layers:
output["features"] = x
if "classifier" in layers:
output["classifier"] = self.classifier(x)
return output
@ -298,7 +302,7 @@ class EfficientNet(nn.Module):
OrderedDict(
[
("pooling", nn.AdaptiveAvgPool2d(1)),
("squeeze", LambdaLayer(lambda x: x.squeeze(-1).squeeze(-1))),
("squeeze", Flatten()),
("dropout", nn.Dropout(dropout)),
("fc", nn.Linear(num_features, num_classes)),
]
@ -353,11 +357,33 @@ class EfficientNet(nn.Module):
layers.append((f"block{idx}", blk))
return nn.Sequential(OrderedDict(layers)), out_channels
def ngc_checkpoint_remap(self, url=None, version=None):
if version is None:
version = url.split("/")[8]
def to_sequential_remap(s):
splited = s.split(".")
if splited[0].startswith("layer"):
return ".".join(
["layers." + str(int(splited[0][len("layer") :]) - 1)] + splited[1:]
)
else:
return s
def no_remap(s):
return s
return {"20.12.0": to_sequential_remap, "21.03.0": to_sequential_remap}.get(
version, no_remap
)
# }}}
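
Folding the per-stage attributes (layer1, layer2, ...) into a single nn.Sequential changes the checkpoint keys, so ngc_checkpoint_remap shifts old NGC checkpoint names onto the new layers.<i> indices. A couple of illustrative mappings; model here is an EfficientNet instance and the trailing parameter names are examples only:

remap = model.ngc_checkpoint_remap(version="20.12.0")
remap("layer1.0.depsep.weight")   # -> "layers.0.0.depsep.weight"
remap("layer3.2.proj.weight")     # -> "layers.2.2.proj.weight"
remap("features.conv.weight")     # unchanged: only "layer*" keys are rewritten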
# MBConvBlock {{{
class MBConvBlock(nn.Module):
__constants__ = ["quantized"]
def __init__(
self,
builder: LayerBuilder,
@ -366,7 +392,7 @@ class MBConvBlock(nn.Module):
out_channels: int,
expand_ratio: int,
stride: int,
squeeze_excitation_ratio: int,
squeeze_excitation_ratio: float,
squeeze_hidden=False,
survival_prob: float = 1.0,
quantized: bool = False,
@ -387,25 +413,31 @@ class MBConvBlock(nn.Module):
self.depsep = builder.convDepSep(
depsep_kernel_size, hidden_dim, hidden_dim, stride, bn=True, act=True
)
self.se = SequentialSqueezeAndExcitation(
hidden_dim, squeeze_dim, builder.activation(), self.quantized, use_conv=trt
)
if trt or self.quantized:
# Need TRT mode for quantized in order to automatically insert quantization before pooling
self.se: nn.Module = SequentialSqueezeAndExcitationTRT(
hidden_dim, squeeze_dim, builder.activation(), self.quantized
)
else:
self.se: nn.Module = SequentialSqueezeAndExcitation(
hidden_dim, squeeze_dim, builder.activation(), self.quantized
)
self.proj = builder.conv1x1(hidden_dim, out_channels, bn=True)
self.survival_prob = survival_prob
if survival_prob == 1.0:
self.residual_add = torch.add
else:
self.residual_add = StochasticDepthResidual(survival_prob=survival_prob)
if self.quantized and self.residual:
assert quant_nn is not None, "pytorch_quantization is not available"
self.residual_quantizer = quant_nn.TensorQuantizer(
quant_nn.QuantConv2d.default_quant_desc_input
) # TODO QuantConv2d ?!?
else:
self.residual_quantizer = nn.Identity()
def drop(self):
if self.survival_prob == 1.0:
return False
return random.uniform(0.0, 1.0) > self.survival_prob
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.residual:
return self.proj(
self.se(self.depsep(x if self.expand is None else self.expand(x)))
@ -414,16 +446,10 @@ class MBConvBlock(nn.Module):
b = self.proj(
self.se(self.depsep(x if self.expand is None else self.expand(x)))
)
if self.training:
if self.drop():
multiplication_factor = 0.0
else:
multiplication_factor = 1.0 / self.survival_prob
else:
multiplication_factor = 1.0
if self.quantized:
x = self.residual_quantizer(x)
return torch.add(x, alpha=multiplication_factor, other=b)
return self.residual_add(x, b)
def original_mbconv(
@ -436,7 +462,7 @@ def original_mbconv(
squeeze_excitation_ratio: int,
survival_prob: float,
quantized: bool,
trt: bool
trt: bool,
):
return MBConvBlock(
builder,

View File

@ -1,8 +1,15 @@
from dataclasses import dataclass, asdict, replace
from .common import (
SequentialSqueezeAndExcitationTRT,
SequentialSqueezeAndExcitation,
SqueezeAndExcitation,
SqueezeAndExcitationTRT,
)
from typing import Optional, Callable
import os
import torch
import argparse
from functools import partial
@dataclass
@ -37,7 +44,13 @@ class EntryPoint:
self.name = name
self.model = model
def __call__(self, pretrained=False, pretrained_from_file=None, **kwargs):
def __call__(
self,
pretrained=False,
pretrained_from_file=None,
state_dict_key_map_fn=None,
**kwargs,
):
assert not (pretrained and (pretrained_from_file is not None))
params = replace(self.model.params, **kwargs)
@ -66,7 +79,7 @@ class EntryPoint:
pretrained_from_file
)
)
# Temporary fix to allow NGC checkpoint loading
if state_dict is not None:
state_dict = {
k[len("module.") :] if k.startswith("module.") else k: v
@ -85,12 +98,32 @@ class EntryPoint:
else:
return t
if state_dict_key_map_fn is not None:
state_dict = {
state_dict_key_map_fn(k): v for k, v in state_dict.items()
}
if hasattr(model, "ngc_checkpoint_remap"):
remap_fn = model.ngc_checkpoint_remap(url=self.model.checkpoint_url)
state_dict = {remap_fn(k): v for k, v in state_dict.items()}
def _se_layer_uses_conv(m):
return any(
map(
partial(isinstance, m),
[
SqueezeAndExcitationTRT,
SequentialSqueezeAndExcitationTRT,
],
)
)
state_dict = {
k: reshape(
v,
conv=dict(model.named_modules())[
".".join(k.split(".")[:-2])
].use_conv,
conv=_se_layer_uses_conv(
dict(model.named_modules())[".".join(k.split(".")[:-2])]
),
)
if is_se_weight(k, v)
else v
@ -123,7 +156,8 @@ class EntryPoint:
def is_se_weight(key, value):
return (key.endswith("squeeze.weight") or key.endswith("expand.weight"))
return key.endswith("squeeze.weight") or key.endswith("expand.weight")
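
is_se_weight singles out the squeeze/expand parameters so they can be converted between the Linear-based SqueezeAndExcitation and the Conv2d-based TRT variant, whose weights differ only by two trailing singleton dimensions. A rough sketch of what the reshape helper is expected to do, written as an assumption rather than code taken from this commit:

import torch

def reshape_se_weight(t: torch.Tensor, conv: bool) -> torch.Tensor:
    # Linear stores [out, in]; a 1x1 Conv2d stores [out, in, 1, 1].
    if conv and t.dim() == 2:
        return t.unsqueeze(-1).unsqueeze(-1)
    if not conv and t.dim() == 4:
        return t.squeeze(-1).squeeze(-1)
    return t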
def create_entrypoint(m: Model):
def _ep(**kwargs):

View File

@ -36,14 +36,16 @@ from typing import List, Dict, Callable, Any, Type
import torch
import torch.nn as nn
from .common import SqueezeAndExcitation, LayerBuilder, LambdaLayer
from .common import (
SqueezeAndExcitation,
LayerBuilder,
SqueezeAndExcitationTRT,
)
from .model import (
Model,
ModelParams,
ModelArch,
OptimizerParams,
create_entrypoint,
EntryPoint,
)
@ -128,11 +130,18 @@ class Bottleneck(nn.Module):
self.stride = stride
self.fused_se = fused_se
self.squeeze = (
SqueezeAndExcitation(planes * expansion, se_squeeze, builder.activation(), use_conv=trt)
if se
else None
)
if se:
self.squeeze = (
SqueezeAndExcitation(
planes * expansion, se_squeeze, builder.activation()
)
if not trt
else SqueezeAndExcitationTRT(
planes * expansion, se_squeeze, builder.activation()
)
)
else:
self.squeeze = None
def forward(self, x):
residual = x
@ -215,6 +224,7 @@ class ResNet(nn.Module):
last_bn_0_init: bool = False
conv_init: str = "fan_in"
trt: bool = False
fused_se: bool = True
def parser(self, name):
p = super().parser(name)
@ -240,6 +250,10 @@ class ResNet(nn.Module):
help="initialization mode for convolutional layers, see https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_",
)
p.add_argument("--trt", metavar="True|False", default=self.trt, type=bool)
p.add_argument(
"--fused_se", metavar="True|False", default=self.fused_se, type=bool
)
return p
def __init__(
@ -249,6 +263,7 @@ class ResNet(nn.Module):
last_bn_0_init: bool = False,
conv_init: str = "fan_in",
trt: bool = False,
fused_se: bool = True,
):
super(ResNet, self).__init__()
@ -265,6 +280,7 @@ class ResNet(nn.Module):
inplanes = arch.stem_width
assert len(arch.widths) == len(arch.layers)
self.num_layers = len(arch.widths)
layers = []
for i, (w, l) in enumerate(zip(arch.widths, arch.layers)):
layer, inplanes = self._make_layer(
arch.block,
@ -275,9 +291,11 @@ class ResNet(nn.Module):
cardinality=arch.cardinality,
stride=1 if i == 0 else 2,
trt=trt,
fused_se=fused_se,
)
setattr(self, f"layer{i+1}", layer)
layers.append(layer)
self.layers = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(arch.widths[-1] * arch.expansion, num_classes)
@ -297,13 +315,8 @@ class ResNet(nn.Module):
def forward(self, x):
x = self.stem(x)
for i in range(self.num_layers):
fn = getattr(self, f"layer{i+1}")
x = fn(x)
x = self.layers(x)
x = self.classifier(x)
return x
def extract_features(self, x, layers=None):
@ -311,7 +324,7 @@ class ResNet(nn.Module):
layers = [f"layer{i+1}" for i in range(self.num_layers)] + ["classifier"]
run = [
f"layer{i+1}"
i
for i in range(self.num_layers)
if "classifier" in layers
or any([f"layer{j+1}" in layers for j in range(i, self.num_layers)])
@ -320,10 +333,10 @@ class ResNet(nn.Module):
output = {}
x = self.stem(x)
for l in run:
fn = getattr(self, l)
fn = self.layers[l]
x = fn(x)
if l in layers:
output[l] = x
if f"layer{l+1}" in layers:
output[f"layer{l+1}"] = x
if "classifier" in layers:
output["classifier"] = self.classifier(x)
@ -332,7 +345,16 @@ class ResNet(nn.Module):
# helper functions {{{
def _make_layer(
self, block, expansion, inplanes, planes, blocks, stride=1, cardinality=1, trt=False,
self,
block,
expansion,
inplanes,
planes,
blocks,
stride=1,
cardinality=1,
trt=False,
fused_se=True,
):
downsample = None
if stride != 1 or inplanes != planes * expansion:
@ -354,15 +376,33 @@ class ResNet(nn.Module):
stride=stride if i == 0 else 1,
cardinality=cardinality,
downsample=downsample if i == 0 else None,
fused_se=True,
fused_se=fused_se,
last_bn_0_init=self.last_bn_0_init,
trt = trt,
trt=trt,
)
)
inplanes = planes * expansion
return nn.Sequential(*layers), inplanes
def ngc_checkpoint_remap(self, url=None, version=None):
if version is None:
version = url.split("/")[8]
def to_sequential_remap(s):
splited = s.split(".")
if splited[0].startswith("layer"):
return ".".join(
["layers." + str(int(splited[0][len("layer") :]) - 1)] + splited[1:]
)
else:
return s
def no_remap(s):
return s
return {"20.06.0": to_sequential_remap}.get(version, no_remap)
# }}}
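
As in the EfficientNet case, the stages now live in one nn.Sequential: extract_features indexes self.layers directly while still reporting outputs under the old layer{i+1} names, and old checkpoints are remapped via ngc_checkpoint_remap. A short usage sketch, where model is a constructed ResNet instance and the input shape is illustrative:

import torch

model.eval()
x = torch.randn(1, 3, 224, 224)
feats = model.extract_features(x, layers=["layer2", "layer4", "classifier"])
print(sorted(feats.keys()))   # ['classifier', 'layer2', 'layer4']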

View File

@ -1,7 +1,39 @@
import math
import numpy as np
import torch
from torch import optim
def get_sgd_optimizer(parameters, lr, momentum, weight_decay, nesterov=False, bn_weight_decay=False):
def get_optimizer(parameters, lr, args, state=None):
if args.optimizer == "sgd":
optimizer = get_sgd_optimizer(
parameters,
lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov=args.nesterov,
bn_weight_decay=args.bn_weight_decay,
)
elif args.optimizer == "rmsprop":
optimizer = get_rmsprop_optimizer(
parameters,
lr,
alpha=args.rmsprop_alpha,
momentum=args.momentum,
weight_decay=args.weight_decay,
eps=args.rmsprop_eps,
bn_weight_decay=args.bn_weight_decay,
)
if state is not None:
optimizer.load_state_dict(state)
return optimizer
def get_sgd_optimizer(
parameters, lr, momentum, weight_decay, nesterov=False, bn_weight_decay=False
):
if bn_weight_decay:
print(" ! Weight decay applied to BN parameters ")
params = [v for n, v in parameters]
@ -17,20 +49,112 @@ def get_sgd_optimizer(parameters, lr, momentum, weight_decay, nesterov=False, bn
{"params": rest_params, "weight_decay": weight_decay},
]
optimizer = torch.optim.SGD(params, lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
optimizer = torch.optim.SGD(
params, lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov
)
return optimizer
def get_rmsprop_optimizer(parameters, lr, alpha, weight_decay, momentum, eps, bn_weight_decay=False):
def get_rmsprop_optimizer(
parameters, lr, alpha, weight_decay, momentum, eps, bn_weight_decay=False
):
bn_params = [v for n, v in parameters if "bn" in n]
rest_params = [v for n, v in parameters if not "bn" in n]
params = [
{"params": bn_params, "weight_decay": weight_decay if bn_weight_decay else 0},
{"params": bn_params, "weight_decay": weight_decay if bn_weight_decay else 0},
{"params": rest_params, "weight_decay": weight_decay},
]
optimizer = torch.optim.RMSprop(params, lr=lr, alpha=alpha, weight_decay=weight_decay, momentum=momentum, eps=eps)
optimizer = torch.optim.RMSprop(
params,
lr=lr,
alpha=alpha,
weight_decay=weight_decay,
momentum=momentum,
eps=eps,
)
return optimizer
def lr_policy(lr_fn):
def _alr(optimizer, iteration, epoch):
lr = lr_fn(iteration, epoch)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
return lr
return _alr
def lr_step_policy(base_lr, steps, decay_factor, warmup_length):
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
lr = base_lr
for s in steps:
if epoch >= s:
lr *= decay_factor
return lr
return lr_policy(_lr_fn)
def lr_linear_policy(base_lr, warmup_length, epochs):
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
e = epoch - warmup_length
es = epochs - warmup_length
lr = base_lr * (1 - (e / es))
return lr
return lr_policy(_lr_fn)
def lr_cosine_policy(base_lr, warmup_length, epochs, end_lr=0):
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
e = epoch - warmup_length
es = epochs - warmup_length
lr = end_lr + (0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - end_lr))
return lr
return lr_policy(_lr_fn)
def lr_exponential_policy(
base_lr,
warmup_length,
epochs,
final_multiplier=0.001,
decay_factor=None,
decay_step=1,
logger=None,
):
"""Exponential lr policy. Setting decay factor parameter overrides final_multiplier"""
es = epochs - warmup_length
if decay_factor is not None:
epoch_decay = decay_factor
else:
epoch_decay = np.power(
2, np.log2(final_multiplier) / math.floor(es / decay_step)
)
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
e = epoch - warmup_length
lr = base_lr * (epoch_decay ** math.floor(e / decay_step))
return lr
return lr_policy(_lr_fn)
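
With the optimizer construction and schedules moved here and decoupled from the logger, each lr_*_policy returns a callable that is applied per iteration with the optimizer, iteration index and epoch, and returns the learning rate it just set. A minimal sketch; model, loader and the hyperparameter values are placeholders:

optimizer = get_sgd_optimizer(
    list(model.named_parameters()),
    lr=0.256,
    momentum=0.875,
    weight_decay=3.05e-5,
)
lr_schedule = lr_cosine_policy(base_lr=0.256, warmup_length=5, epochs=90)

for epoch in range(90):
    for i, (images, labels) in enumerate(loader):
        lr = lr_schedule(optimizer, i, epoch)   # updates every param_group["lr"]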

View File

@ -27,279 +27,212 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import math
import time
from copy import deepcopy
from functools import wraps
from typing import Callable, Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from . import logger as log
from . import models
from . import utils
import dllogger
from .optimizers import get_sgd_optimizer, get_rmsprop_optimizer
from .models.common import EMA
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
ACC_METADATA = {"unit": "%", "format": ":.2f"}
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
TIME_METADATA = {"unit": "s", "format": ":.5f"}
LOSS_METADATA = {"format": ":.5f"}
from . import logger as log
from . import utils
from .logger import TrainingMetrics, ValidationMetrics
from .models.common import EMA
class ModelAndLoss(nn.Module):
class Executor:
def __init__(
self,
model,
loss,
cuda=True,
memory_format=torch.contiguous_format,
model: nn.Module,
loss: Optional[nn.Module],
cuda: bool = True,
memory_format: torch.memory_format = torch.contiguous_format,
amp: bool = False,
scaler: Optional[torch.cuda.amp.GradScaler] = None,
divide_loss: int = 1,
ts_script: bool = False,
):
super(ModelAndLoss, self).__init__()
assert not (amp and scaler is None), "Gradient Scaler is needed for AMP"
if cuda:
model = model.cuda().to(memory_format=memory_format)
def xform(m: nn.Module) -> nn.Module:
if cuda:
m = m.cuda()
m.to(memory_format=memory_format)
return m
# define loss function (criterion) and optimizer
criterion = loss()
if cuda:
criterion = criterion.cuda()
self.model = model
self.loss = criterion
def forward(self, data, target):
output = self.model(data)
loss = self.loss(output, target)
return loss, output
self.model = xform(model)
if ts_script:
self.model = torch.jit.script(self.model)
self.ts_script = ts_script
self.loss = xform(loss) if loss is not None else None
self.amp = amp
self.scaler = scaler
self.is_distributed = False
self.divide_loss = divide_loss
self._fwd_bwd = None
self._forward = None
def distributed(self, gpu_id):
self.model = DDP(self.model, device_ids=[gpu_id], output_device=gpu_id)
self.is_distributed = True
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
self.model = DDP(self.model, device_ids=[gpu_id], output_device=gpu_id)
torch.cuda.current_stream().wait_stream(s)
def load_model_state(self, state):
if not state is None:
self.model.load_state_dict(state)
def _fwd_bwd_fn(
self,
input: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
with autocast(enabled=self.amp):
loss = self.loss(self.model(input), target)
loss /= self.divide_loss
self.scaler.scale(loss).backward()
return loss
def _forward_fn(
self, input: torch.Tensor, target: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad(), autocast(enabled=self.amp):
output = self.model(input)
loss = None if self.loss is None else self.loss(output, target)
return output if loss is None else loss, output
def optimize(self, fn):
return fn
@property
def forward_backward(self):
if self._fwd_bwd is None:
if self.loss is None:
raise NotImplementedError(
"Loss must not be None for forward+backward step"
)
self._fwd_bwd = self.optimize(self._fwd_bwd_fn)
return self._fwd_bwd
@property
def forward(self):
if self._forward is None:
self._forward = self.optimize(self._forward_fn)
return self._forward
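
Executor wraps the (optionally TorchScript-ed) model, the loss, AMP autocast and gradient scaling behind two lazily built callables: forward_backward for training steps and forward for inference. A condensed sketch; model, images and labels stand in for the surrounding code:

import torch
import torch.nn as nn

scaler = torch.cuda.amp.GradScaler(enabled=True)
executor = Executor(
    model,
    nn.CrossEntropyLoss(),
    cuda=True,
    amp=True,
    scaler=scaler,
    divide_loss=1,
    ts_script=True,     # model is wrapped with torch.jit.script
)

loss = executor.forward_backward(images, labels)   # autocast + scaled backward
loss, output = executor.forward(images, labels)    # no_grad + autocast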
def get_optimizer(parameters, lr, args, state=None):
if args.optimizer == 'sgd':
optimizer = get_sgd_optimizer(parameters, lr, momentum=args.momentum,
weight_decay=args.weight_decay, nesterov=args.nesterov,
bn_weight_decay=args.bn_weight_decay)
elif args.optimizer == 'rmsprop':
optimizer = get_rmsprop_optimizer(parameters, lr, alpha=args.rmsprop_alpha, momentum=args.momentum,
weight_decay=args.weight_decay,
eps=args.rmsprop_eps,
bn_weight_decay=args.bn_weight_decay)
if not state is None:
optimizer.load_state_dict(state)
class Trainer:
def __init__(
self,
executor: Executor,
optimizer: torch.optim.Optimizer,
grad_acc_steps: int,
ema: Optional[float] = None,
):
self.executor = executor
self.optimizer = optimizer
self.grad_acc_steps = grad_acc_steps
self.use_ema = False
if ema is not None:
self.ema_executor = deepcopy(self.executor)
self.ema = EMA(ema, self.ema_executor.model)
self.use_ema = True
return optimizer
self.optimizer.zero_grad(set_to_none=True)
self.steps_since_update = 0
def train(self):
self.executor.model.train()
def lr_policy(lr_fn, logger=None):
if logger is not None:
logger.register_metric(
"lr", log.LR_METER(), verbosity=dllogger.Verbosity.VERBOSE
)
def eval(self):
self.executor.model.eval()
if self.use_ema:
self.ema_executor.model.eval()
def _alr(optimizer, iteration, epoch):
lr = lr_fn(iteration, epoch)
def train_step(self, input, target, step=None):
loss = self.executor.forward_backward(input, target)
if logger is not None:
logger.log_metric("lr", lr)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
self.steps_since_update += 1
return _alr
def lr_step_policy(base_lr, steps, decay_factor, warmup_length, logger=None):
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
lr = base_lr
for s in steps:
if epoch >= s:
lr *= decay_factor
return lr
return lr_policy(_lr_fn, logger=logger)
def lr_linear_policy(base_lr, warmup_length, epochs, logger=None):
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
e = epoch - warmup_length
es = epochs - warmup_length
lr = base_lr * (1 - (e / es))
return lr
return lr_policy(_lr_fn, logger=logger)
def lr_cosine_policy(base_lr, warmup_length, epochs, end_lr = 0, logger=None):
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
e = epoch - warmup_length
es = epochs - warmup_length
lr = end_lr + (0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - end_lr))
return lr
return lr_policy(_lr_fn, logger=logger)
def lr_exponential_policy(
base_lr, warmup_length, epochs, final_multiplier=0.001, decay_factor=None, decay_step=1, logger=None
):
"""Exponential lr policy. Setting decay factor parameter overrides final_multiplier"""
es = epochs - warmup_length
if decay_factor is not None:
epoch_decay = decay_factor
else:
epoch_decay = np.power(2, np.log2(final_multiplier) / math.floor(es/decay_step))
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
e = epoch - warmup_length
lr = base_lr * (epoch_decay ** math.floor(e/decay_step))
return lr
return lr_policy(_lr_fn, logger=logger)
def get_train_step(
model_and_loss, optimizer, scaler, use_amp=False, batch_size_multiplier=1
):
def _step(input, target, optimizer_step=True):
input_var = Variable(input)
target_var = Variable(target)
with autocast(enabled=use_amp):
loss, output = model_and_loss(input_var, target_var)
loss /= batch_size_multiplier
if torch.distributed.is_initialized():
reduced_loss = utils.reduce_tensor(loss.data)
if self.steps_since_update == self.grad_acc_steps:
if self.executor.scaler is not None:
self.executor.scaler.step(self.optimizer)
self.executor.scaler.update()
else:
reduced_loss = loss.data
scaler.scale(loss).backward()
if optimizer_step:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
self.optimizer.step()
self.optimizer.zero_grad()
self.steps_since_update = 0
torch.cuda.synchronize()
return reduced_loss
if self.use_ema:
self.ema(self.executor.model, step=step)
return _step
return loss
def validation_steps(self) -> Dict[str, Callable]:
vsd: Dict[str, Callable] = {"val": self.executor.forward}
if self.use_ema:
vsd["val_ema"] = self.ema_executor.forward
return vsd
def state_dict(self) -> dict:
res = {
"state_dict": self.executor.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
}
if self.use_ema:
res["state_dict_ema"] = self.ema_executor.model.state_dict()
return res
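
Trainer owns the optimizer, gradient accumulation and the optional EMA copy of the Executor: train_step accumulates gradients and only steps the optimizer every grad_acc_steps calls, while validation_steps() exposes one inference callable per variant ("val", plus "val_ema" when EMA is enabled). A rough sketch of driving it directly; executor, optimizer and the loaders come from the surrounding setup code:

trainer = Trainer(executor, optimizer, grad_acc_steps=2, ema=0.999)

trainer.train()
for step, (images, labels) in enumerate(train_loader):
    loss = trainer.train_step(images, labels, step=step)

trainer.eval()
for name, infer_fn in trainer.validation_steps().items():
    loss, output = infer_fn(images, labels)

checkpoint = trainer.state_dict()   # weights, optimizer state, EMA weights if enabled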
def train(
train_step,
train_loader,
model_and_loss,
optimizer,
scaler,
lr_scheduler,
logger,
epoch,
steps_per_epoch,
log_fn,
timeout_handler,
ema=None,
use_amp=False,
prof=-1,
batch_size_multiplier=1,
register_metrics=True,
step=0,
):
interrupted = False
if register_metrics and logger is not None:
logger.register_metric(
"train.loss",
log.LOSS_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=LOSS_METADATA,
)
logger.register_metric(
"train.compute_ips",
log.PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=IPS_METADATA,
)
logger.register_metric(
"train.total_ips",
log.PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=IPS_METADATA,
)
logger.register_metric(
"train.data_time",
log.PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=TIME_METADATA,
)
logger.register_metric(
"train.compute_time",
log.PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=TIME_METADATA,
)
step = get_train_step(
model_and_loss,
optimizer,
scaler=scaler,
use_amp=use_amp,
batch_size_multiplier=batch_size_multiplier,
)
model_and_loss.train()
end = time.time()
optimizer.zero_grad()
data_iter = enumerate(train_loader)
if logger is not None:
data_iter = logger.iteration_generator_wrapper(data_iter, mode='train')
for i, (input, target) in data_iter:
bs = input.size(0)
lr_scheduler(optimizer, i, epoch)
lr = lr_scheduler(i)
data_time = time.time() - end
optimizer_step = ((i + 1) % batch_size_multiplier) == 0
loss = step(input, target, optimizer_step=optimizer_step)
if ema is not None:
ema(model_and_loss, epoch*steps_per_epoch+i)
loss = train_step(input, target, step=step + i)
it_time = time.time() - end
if logger is not None:
logger.log_metric("train.loss", loss.item(), bs)
logger.log_metric("train.compute_ips", utils.calc_ips(bs, it_time - data_time))
logger.log_metric("train.total_ips", utils.calc_ips(bs, it_time))
logger.log_metric("train.data_time", data_time)
logger.log_metric("train.compute_time", it_time - data_time)
with torch.no_grad():
if torch.distributed.is_initialized():
reduced_loss = utils.reduce_tensor(loss.detach())
else:
reduced_loss = loss.detach()
log_fn(
compute_ips=utils.calc_ips(bs, it_time - data_time),
total_ips=utils.calc_ips(bs, it_time),
data_time=data_time,
compute_time=it_time - data_time,
lr=lr,
loss=reduced_loss.item(),
)
end = time.time()
if prof > 0 and (i + 1 >= prof):
time.sleep(5)
break
if ((i+1) % 20 == 0) and timeout_handler.interrupted:
if ((i + 1) % 20 == 0) and timeout_handler.interrupted:
time.sleep(5)
interrupted = True
break
@ -307,134 +240,58 @@ def train(
return interrupted
def get_val_step(model_and_loss, use_amp=False):
def _step(input, target):
input_var = Variable(input)
target_var = Variable(target)
with torch.no_grad(), autocast(enabled=use_amp):
loss, output = model_and_loss(input_var, target_var)
prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5))
if torch.distributed.is_initialized():
reduced_loss = utils.reduce_tensor(loss.data)
prec1 = utils.reduce_tensor(prec1)
prec5 = utils.reduce_tensor(prec5)
else:
reduced_loss = loss.data
torch.cuda.synchronize()
return reduced_loss, prec1, prec5
return _step
def validate(
val_loader,
model_and_loss,
logger,
epoch,
use_amp=False,
prof=-1,
register_metrics=True,
prefix="val",
):
if register_metrics and logger is not None:
logger.register_metric(
f"{prefix}.top1",
log.ACC_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=ACC_METADATA,
)
logger.register_metric(
f"{prefix}.top5",
log.ACC_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=ACC_METADATA,
)
logger.register_metric(
f"{prefix}.loss",
log.LOSS_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=LOSS_METADATA,
)
logger.register_metric(
f"{prefix}.compute_ips",
log.PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=IPS_METADATA,
)
logger.register_metric(
f"{prefix}.total_ips",
log.PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=IPS_METADATA,
)
logger.register_metric(
f"{prefix}.data_time",
log.PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency",
log.PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency_at100",
log.LAT_100(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency_at99",
log.LAT_99(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency_at95",
log.LAT_95(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=TIME_METADATA,
)
step = get_val_step(model_and_loss, use_amp=use_amp)
def validate(infer_fn, val_loader, log_fn, prof=-1, with_loss=True):
top1 = log.AverageMeter()
# switch to evaluate mode
model_and_loss.eval()
end = time.time()
data_iter = enumerate(val_loader)
if not logger is None:
data_iter = logger.iteration_generator_wrapper(data_iter, mode='val')
for i, (input, target) in data_iter:
bs = input.size(0)
data_time = time.time() - end
loss, prec1, prec5 = step(input, target)
if with_loss:
loss, output = infer_fn(input, target)
else:
output = infer_fn(input)
with torch.no_grad():
prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5))
if torch.distributed.is_initialized():
if with_loss:
reduced_loss = utils.reduce_tensor(loss.detach())
prec1 = utils.reduce_tensor(prec1)
prec5 = utils.reduce_tensor(prec5)
else:
if with_loss:
reduced_loss = loss.detach()
prec1 = prec1.item()
prec5 = prec5.item()
infer_result = {
"top1": (prec1, bs),
"top5": (prec5, bs),
}
if with_loss:
infer_result["loss"] = (reduced_loss.item(), bs)
torch.cuda.synchronize()
it_time = time.time() - end
top1.record(prec1.item(), bs)
if logger is not None:
logger.log_metric(f"{prefix}.top1", prec1.item(), bs)
logger.log_metric(f"{prefix}.top5", prec5.item(), bs)
logger.log_metric(f"{prefix}.loss", loss.item(), bs)
logger.log_metric(f"{prefix}.compute_ips", utils.calc_ips(bs, it_time - data_time))
logger.log_metric(f"{prefix}.total_ips", utils.calc_ips(bs, it_time))
logger.log_metric(f"{prefix}.data_time", data_time)
logger.log_metric(f"{prefix}.compute_latency", it_time - data_time)
logger.log_metric(f"{prefix}.compute_latency_at95", it_time - data_time)
logger.log_metric(f"{prefix}.compute_latency_at99", it_time - data_time)
logger.log_metric(f"{prefix}.compute_latency_at100", it_time - data_time)
top1.record(prec1, bs)
log_fn(
compute_ips=utils.calc_ips(bs, it_time - data_time),
total_ips=utils.calc_ips(bs, it_time),
data_time=data_time,
compute_time=it_time - data_time,
**infer_result,
)
end = time.time()
if (prof > 0) and (i + 1 >= prof):
@ -445,22 +302,14 @@ def validate(
# Train loop {{{
def train_loop(
model_and_loss,
optimizer,
scaler,
trainer: Trainer,
lr_scheduler,
train_loader,
train_loader_len,
val_loader,
logger,
should_backup_checkpoint,
steps_per_epoch,
ema=None,
model_ema=None,
use_amp=False,
batch_size_multiplier=1,
best_prec1=0,
start_epoch=0,
end_epoch=0,
@ -472,14 +321,22 @@ def train_loop(
checkpoint_dir="./",
checkpoint_filename="checkpoint.pth.tar",
):
train_metrics = TrainingMetrics(logger)
val_metrics = {
k: ValidationMetrics(logger, k) for k in trainer.validation_steps().keys()
}
training_step = trainer.train_step
prec1 = -1
use_ema = (model_ema is not None) and (ema is not None)
if early_stopping_patience > 0:
epochs_since_improvement = 0
backup_prefix = checkpoint_filename[:-len("checkpoint.pth.tar")] if \
checkpoint_filename.endswith("checkpoint.pth.tar") else ""
backup_prefix = (
checkpoint_filename[: -len("checkpoint.pth.tar")]
if checkpoint_filename.endswith("checkpoint.pth.tar")
else ""
)
print(f"RUNNING EPOCHS FROM {start_epoch} TO {end_epoch}")
with utils.TimeoutHandler() as timeout_handler:
interrupted = False
@ -487,73 +344,71 @@ def train_loop(
if logger is not None:
logger.start_epoch()
if not skip_training:
if logger is not None:
data_iter = logger.iteration_generator_wrapper(
train_loader, mode="train"
)
else:
data_iter = train_loader
trainer.train()
interrupted = train(
train_loader,
model_and_loss,
optimizer,
scaler,
lr_scheduler,
logger,
epoch,
steps_per_epoch,
training_step,
data_iter,
lambda i: lr_scheduler(trainer.optimizer, i, epoch),
train_metrics.log,
timeout_handler,
ema=ema,
use_amp=use_amp,
prof=prof,
register_metrics=epoch == start_epoch,
batch_size_multiplier=batch_size_multiplier,
step=epoch * train_loader_len,
)
if not skip_validation:
prec1, nimg = validate(
val_loader,
model_and_loss,
logger,
epoch,
use_amp=use_amp,
prof=prof,
register_metrics=epoch == start_epoch,
)
if use_ema:
model_ema.load_state_dict({k.replace('module.', ''): v for k, v in ema.state_dict().items()})
prec1, nimg = validate(
val_loader,
model_ema,
logger,
epoch,
trainer.eval()
for k, infer_fn in trainer.validation_steps().items():
if logger is not None:
data_iter = logger.iteration_generator_wrapper(
val_loader, mode="val"
)
else:
data_iter = val_loader
step_prec1, _ = validate(
infer_fn,
data_iter,
val_metrics[k].log,
prof=prof,
register_metrics=epoch == start_epoch,
prefix='val_ema'
)
if k == "val":
prec1 = step_prec1
if prec1 > best_prec1:
is_best = True
best_prec1 = prec1
else:
is_best = False
else:
is_best = True
is_best = False
best_prec1 = 0
if logger is not None:
logger.end_epoch()
if save_checkpoints and (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
not torch.distributed.is_initialized()
or torch.distributed.get_rank() == 0
):
if should_backup_checkpoint(epoch):
backup_filename = "{}checkpoint-{}.pth.tar".format(backup_prefix, epoch + 1)
backup_filename = "{}checkpoint-{}.pth.tar".format(
backup_prefix, epoch + 1
)
else:
backup_filename = None
checkpoint_state = {
"epoch": epoch + 1,
"state_dict": model_and_loss.model.state_dict(),
"best_prec1": best_prec1,
"optimizer": optimizer.state_dict(),
**trainer.state_dict(),
}
if use_ema:
checkpoint_state["state_dict_ema"] = ema.state_dict()
utils.save_checkpoint(
checkpoint_state,
is_best,
@ -561,6 +416,7 @@ def train_loop(
backup_filename=backup_filename,
filename=checkpoint_filename,
)
if early_stopping_patience > 0:
if not is_best:
epochs_since_improvement += 1
@ -570,4 +426,6 @@ def train_loop(
break
if interrupted:
break
# }}}

View File

@ -97,7 +97,7 @@ def accuracy(output, target, topk=(1,)):
def reduce_tensor(tensor):
rt = tensor.clone()
rt = tensor.clone().detach()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= (
torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
@ -114,6 +114,7 @@ class TimeoutHandler:
def __init__(self, sig=signal.SIGTERM):
self.sig = sig
self.device = torch.device("cuda")
@property
def interrupted(self):
if not dist.is_initialized():

View File

@ -58,6 +58,12 @@ from image_classification.models import (
efficientnet_widese_b0,
efficientnet_widese_b4,
)
from image_classification.optimizers import (
get_optimizer,
lr_cosine_policy,
lr_linear_policy,
lr_step_policy,
)
import dllogger
@ -102,7 +108,9 @@ def add_parser_arguments(parser, skip_arch=False):
metavar="ARCH",
default="resnet50",
choices=model_names,
help="model architecture: " + " | ".join(model_names) + " (default: resnet50)",
help="model architecture: "
+ " | ".join(model_names)
+ " (default: resnet50)",
)
parser.add_argument(
@ -290,6 +298,13 @@ def add_parser_arguments(parser, skip_arch=False):
dest="save_checkpoints",
help="do not store any checkpoints, useful for benchmarking",
)
parser.add_argument(
"--jit",
type=str,
default="no",
choices=["no", "script"],
help="no -> do not use torch.jit; script -> use torch.jit.script"
)
parser.add_argument("--checkpoint-filename", default="checkpoint.pth.tar", type=str)
@ -320,7 +335,7 @@ def add_parser_arguments(parser, skip_arch=False):
type=int,
default=None,
required=False,
help="number of classes"
help="number of classes",
)
@ -432,13 +447,25 @@ def prepare_for_training(args, model_args, model_arch):
if args.image_size is not None
else model.arch.default_image_size
)
model_and_loss = ModelAndLoss(model, loss, cuda=True, memory_format=memory_format)
if args.use_ema is not None:
model_ema = deepcopy(model_and_loss)
ema = EMA(args.use_ema)
else:
model_ema = None
ema = None
scaler = torch.cuda.amp.GradScaler(
init_scale=args.static_loss_scale,
growth_factor=2,
backoff_factor=0.5,
growth_interval=100 if args.dynamic_loss_scale else 1000000000,
enabled=args.amp,
)
executor = Executor(
model,
loss(),
cuda=True,
memory_format=memory_format,
amp=args.amp,
scaler=scaler,
divide_loss=batch_size_multiplier,
ts_script=args.jit == "script",
)
# Create data loaders and optimizers as needed
if args.data_backend == "pytorch":
@ -463,7 +490,7 @@ def prepare_for_training(args, model_args, model_arch):
args.batch_size,
model_args.num_classes,
args.mixup > 0.0,
interpolation = args.interpolation,
interpolation=args.interpolation,
augmentation=args.augmentation,
start_epoch=start_epoch,
workers=args.workers,
@ -478,7 +505,7 @@ def prepare_for_training(args, model_args, model_arch):
args.batch_size,
model_args.num_classes,
False,
interpolation = args.interpolation,
interpolation=args.interpolation,
workers=args.workers,
memory_format=memory_format,
)
@ -508,41 +535,38 @@ def prepare_for_training(args, model_args, model_arch):
)
optimizer = get_optimizer(
list(model_and_loss.model.named_parameters()),
list(executor.model.named_parameters()),
args.lr,
args=args,
state=optimizer_state,
)
if args.lr_schedule == "step":
lr_policy = lr_step_policy(
args.lr, [30, 60, 80], 0.1, args.warmup, logger=logger
)
lr_policy = lr_step_policy(args.lr, [30, 60, 80], 0.1, args.warmup)
elif args.lr_schedule == "cosine":
lr_policy = lr_cosine_policy(
args.lr, args.warmup, args.epochs, end_lr=args.end_lr, logger=logger
args.lr, args.warmup, args.epochs, end_lr=args.end_lr
)
elif args.lr_schedule == "linear":
lr_policy = lr_linear_policy(args.lr, args.warmup, args.epochs, logger=logger)
scaler = torch.cuda.amp.GradScaler(
init_scale=args.static_loss_scale,
growth_factor=2,
backoff_factor=0.5,
growth_interval=100 if args.dynamic_loss_scale else 1000000000,
enabled=args.amp,
)
lr_policy = lr_linear_policy(args.lr, args.warmup, args.epochs)
if args.distributed:
model_and_loss.distributed(args.gpu)
executor.distributed(args.gpu)
model_and_loss.load_model_state(model_state)
if (ema is not None) and (model_state_ema is not None):
print("load ema")
ema.load_state_dict(model_state_ema)
if model_state is not None:
executor.model.load_state_dict(model_state)
return (model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema,
train_loader_len, batch_size_multiplier, start_epoch)
trainer = Trainer(
executor,
optimizer,
grad_acc_steps=batch_size_multiplier,
ema=args.use_ema,
)
if (args.use_ema is not None) and (model_state_ema is not None):
trainer.ema_executor.model.load_state_dict(model_state_ema)
return (trainer, lr_policy, train_loader, train_loader_len, val_loader, logger, start_epoch)
def main(args, model_args, model_arch):
@ -550,23 +574,24 @@ def main(args, model_args, model_arch):
global best_prec1
best_prec1 = 0
model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema, train_loader_len, \
batch_size_multiplier, start_epoch = prepare_for_training(args, model_args, model_arch)
train_loop(
model_and_loss,
optimizer,
scaler,
(
trainer,
lr_policy,
train_loader,
train_loader_len,
val_loader,
logger,
start_epoch,
) = prepare_for_training(args, model_args, model_arch)
train_loop(
trainer,
lr_policy,
train_loader,
train_loader_len,
val_loader,
logger,
should_backup_checkpoint(args),
ema=ema,
model_ema=model_ema,
steps_per_epoch=train_loader_len,
use_amp=args.amp,
batch_size_multiplier=batch_size_multiplier,
start_epoch=start_epoch,
end_epoch=min((start_epoch + args.run_epochs), args.epochs)
if args.run_epochs != -1
@ -587,7 +612,6 @@ def main(args, model_args, model_arch):
if __name__ == "__main__":
epilog = [
"Based on the architecture picked by --arch flag, you may use the following options:\n"
]
@ -603,7 +627,7 @@ if __name__ == "__main__":
add_parser_arguments(parser)
args, rest = parser.parse_known_args()
model_arch = available_models()[args.arch]
model_args, rest = model_arch.parser().parse_known_args(rest)
print(model_args)
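
The new --jit flag controls whether the Executor scripts the model: the default "no" leaves it eager, while "script" passes it through torch.jit.script before optional DDP wrapping. An illustrative invocation, assuming an otherwise standard training command line for this repository:

python ./main.py /path/to/imagenet --arch resnet50 --jit script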

View File

@ -68,9 +68,11 @@ def parse_quantization(parser):
metavar="ARCH",
default="efficientnet-quant-b0",
choices=model_names,
help="model architecture: " + " | ".join(model_names) + " (default: efficientnet-quant-b0)",
help="model architecture: "
+ " | ".join(model_names)
+ " (default: efficientnet-quant-b0)",
)
parser.add_argument(
"--skip-calibration",
action="store_true",
@ -80,6 +82,7 @@ def parse_quantization(parser):
def parse_training_args(parser):
from main import add_parser_arguments
return add_parser_arguments(parser)
@ -92,28 +95,29 @@ def main(args, model_args, model_arch):
select_default_calib_method()
model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema, train_loader_len, \
batch_size_multiplier, start_epoch = prepare_for_training(args, model_args, model_arch)
print(f"RUNNING QUANTIZATION")
if not skip_calibration:
calibrate(model_and_loss.model, train_loader, logger, calib_iter=10)
train_loop(
model_and_loss,
optimizer,
scaler,
(
trainer,
lr_policy,
train_loader,
train_loader_len,
val_loader,
logger,
start_epoch,
) = prepare_for_training(args, model_args, model_arch)
print(f"RUNNING QUANTIZATION")
if not skip_calibration:
calibrate(trainer.executor.model, train_loader, logger, calib_iter=10)
train_loop(
trainer,
lr_policy,
train_loader,
train_loader_len,
val_loader,
logger,
should_backup_checkpoint(args),
ema=ema,
model_ema=model_ema,
steps_per_epoch=train_loader_len,
use_amp=args.amp,
batch_size_multiplier=batch_size_multiplier,
start_epoch=start_epoch,
end_epoch=min((start_epoch + args.run_epochs), args.epochs)
if args.run_epochs != -1
@ -124,7 +128,7 @@ def main(args, model_args, model_arch):
skip_validation=args.training_only,
save_checkpoints=args.save_checkpoints,
checkpoint_dir=args.workspace,
checkpoint_filename='quantized_' + args.checkpoint_filename,
checkpoint_filename="quantized_" + args.checkpoint_filename,
)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
@ -149,7 +153,7 @@ if __name__ == "__main__":
parse_training(parser, skip_arch=True)
args, rest = parser.parse_known_args()
model_arch = available_models()[args.arch]
model_args, rest = model_arch.parser().parse_known_args(rest)
print(model_args)
@ -159,4 +163,3 @@ if __name__ == "__main__":
cudnn.benchmark = True
main(args, model_args, model_arch)