[ConvNets/PyT] TorchScriptable ConvNets

Andrzej Sulecki 2021-11-09 13:42:18 -08:00 committed by Krzysztof Kudrynski
parent 3d3250a3ae
commit 4f2c6922bd
12 changed files with 915 additions and 590 deletions

View file

@ -0,0 +1,2 @@
*.pth.tar
*.log

View file

@ -22,6 +22,7 @@ def add_parser_arguments(parser):
    parser.add_argument(
        "--weight-path", metavar="<path>", help="name of file in which to store weights"
    )
    parser.add_argument("--ema", action="store_true", default=False)


if __name__ == "__main__":
@ -30,12 +31,13 @@ if __name__ == "__main__":
    add_parser_arguments(parser)
    args = parser.parse_args()

    checkpoint = torch.load(args.checkpoint_path, map_location=torch.device("cpu"))

    key = "state_dict" if not args.ema else "ema_state_dict"
    model_state_dict = {
        k[len("module.") :] if "module." in k else k: v
        for k, v in checkpoint["state_dict"].items()
    }

    print(f"Loaded model, acc : {checkpoint['best_prec1']}")
    torch.save(model_state_dict, args.weight_path)
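The new --ema flag together with the key variable indicates that this script is meant to export the EMA weights when requested. A minimal sketch of that intent, assuming the checkpoint holds both "state_dict" and "ema_state_dict" (the select_state_dict helper below is illustrative, not part of the commit):

import torch

def select_state_dict(checkpoint: dict, use_ema: bool) -> dict:
    # Hypothetical helper: pick the EMA weights when asked for and available,
    # fall back to the regular weights otherwise, stripping any "module." prefix.
    key = "ema_state_dict" if use_ema and "ema_state_dict" in checkpoint else "state_dict"
    return {
        k[len("module.") :] if k.startswith("module.") else k: v
        for k, v in checkpoint[key].items()
    }

# Usage sketch:
# ckpt = torch.load("checkpoint.pth.tar", map_location="cpu")
# torch.save(select_state_dict(ckpt, use_ema=True), "weights.pth.tar")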

View file

@ -234,32 +234,29 @@ class Logger(object):
    def log_metric(self, metric_name, val, n=1):
        self.metrics[metric_name]["meter"].record(val, n=n)

    def start_iteration(self, mode="train"):
        if mode == "val":
            self.val_iteration += 1
        elif mode == "train":
            self.iteration += 1
        elif mode == "calib":
            self.calib_iteration += 1

    def end_iteration(self, mode="train"):
        if mode == "val":
            it = self.val_iteration
        elif mode == "train":
            it = self.iteration
        elif mode == "calib":
            it = self.calib_iteration

        if it % self.print_interval == 0 or mode == "calib":
            metrics = {n: m for n, m in self.metrics.items() if n.startswith(mode)}
            if mode == "train":
                step = (self.epoch, self.iteration)
            elif mode == "val":
                step = (self.epoch, self.iteration, self.val_iteration)
            elif mode == "calib":
                step = ("Calibration", self.calib_iteration)

            verbositys = {m["level"] for _, m in metrics.items()}
            for ll in verbositys:
@ -282,12 +279,12 @@ class Logger(object):
        self.val_iteration = 0
        for n, m in self.metrics.items():
            if not n.startswith("calib"):
                m["meter"].reset_epoch()

    def end_epoch(self):
        for n, m in self.metrics.items():
            if not n.startswith("calib"):
                m["meter"].reset_iteration()

        verbositys = {m["level"] for _, m in self.metrics.items()}
@ -302,12 +299,12 @@ class Logger(object):
        self.calib_iteration = 0
        for n, m in self.metrics.items():
            if n.startswith("calib"):
                m["meter"].reset_epoch()

    def end_calibration(self):
        for n, m in self.metrics.items():
            if n.startswith("calib"):
                m["meter"].reset_iteration()

    def end(self):
@ -326,7 +323,7 @@ class Logger(object):
        dllogger.flush()

    def iteration_generator_wrapper(self, gen, mode="train"):
        for g in gen:
            self.start_iteration(mode=mode)
            yield g
@ -337,3 +334,155 @@ class Logger(object):
            self.start_epoch()
            yield g
            self.end_epoch()
class Metrics:
ACC_METADATA = {"unit": "%", "format": ":.2f"}
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
TIME_METADATA = {"unit": "s", "format": ":.5f"}
LOSS_METADATA = {"format": ":.5f"}
LR_METADATA = {"format": ":.5f"}
def __init__(self, logger):
self.logger = logger
self.map = {}
def log(self, **kwargs):
if self.logger is None:
return
for k, v in kwargs.items():
tks = self.map.get(k, [k])
for tk in tks:
if isinstance(v, tuple):
self.logger.log_metric(tk, v[0], v[1])
else:
self.logger.log_metric(tk, v)
class TrainingMetrics(Metrics):
def __init__(self, logger):
super().__init__(logger)
if self.logger is not None:
self.map = {
"loss": ["train.loss"],
"compute_ips": ["train.compute_ips"],
"total_ips": ["train.total_ips"],
"data_time": ["train.data_time"],
"compute_time": ["train.compute_time"],
"lr": ["train.lr"],
}
logger.register_metric(
"train.loss",
LOSS_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.LOSS_METADATA,
)
logger.register_metric(
"train.compute_ips",
PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.IPS_METADATA,
)
logger.register_metric(
"train.total_ips",
PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.IPS_METADATA,
)
logger.register_metric(
"train.data_time",
PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
"train.compute_time",
PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
"train.lr",
LR_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
)
class ValidationMetrics(Metrics):
def __init__(self, logger, prefix):
super().__init__(logger)
if self.logger is not None:
self.map = {
"loss": [f"{prefix}.loss"],
"top1": [f"{prefix}.top1"],
"top5": [f"{prefix}.top5"],
"compute_ips": [f"{prefix}.compute_ips"],
"total_ips": [f"{prefix}.total_ips"],
"data_time": [f"{prefix}.data_time"],
"compute_time": [
f"{prefix}.compute_latency",
f"{prefix}.compute_latency_at100",
f"{prefix}.compute_latency_at99",
f"{prefix}.compute_latency_at95",
],
}
logger.register_metric(
f"{prefix}.top1",
ACC_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.ACC_METADATA,
)
logger.register_metric(
f"{prefix}.top5",
ACC_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.ACC_METADATA,
)
logger.register_metric(
f"{prefix}.loss",
LOSS_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.LOSS_METADATA,
)
logger.register_metric(
f"{prefix}.compute_ips",
PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.IPS_METADATA,
)
logger.register_metric(
f"{prefix}.total_ips",
PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.IPS_METADATA,
)
logger.register_metric(
f"{prefix}.data_time",
PERF_METER(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency",
PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency_at100",
LAT_100(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency_at99",
LAT_99(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
logger.register_metric(
f"{prefix}.compute_latency_at95",
LAT_95(),
verbosity=dllogger.Verbosity.VERBOSE,
metadata=Metrics.TIME_METADATA,
)
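The Metrics subclasses translate generic keyword arguments (loss, top1, compute_ips, ...) into the prefixed dllogger metric names registered above, and (value, count) tuples into weighted records. A condensed, standalone restatement of Metrics.log with a stub logger (StubLogger is illustrative only, not part of the commit):

class StubLogger:
    # Illustrative stand-in for image_classification.logger.Logger:
    # Metrics.log only needs log_metric(name, value[, n]).
    def log_metric(self, name, value, n=1):
        print(f"{name}: {value} (n={n})")

class MetricsSketch:
    def __init__(self, logger):
        self.logger = logger
        self.map = {}

    def log(self, **kwargs):
        if self.logger is None:
            return
        for k, v in kwargs.items():
            for tk in self.map.get(k, [k]):
                if isinstance(v, tuple):   # (value, count) pairs become weighted records
                    self.logger.log_metric(tk, v[0], v[1])
                else:
                    self.logger.log_metric(tk, v)

m = MetricsSketch(StubLogger())
m.map = {"loss": ["train.loss"], "top1": ["val.top1"]}
m.log(loss=0.42, top1=(76.1, 256))   # -> train.loss: 0.42, val.top1: 76.1 (n=256)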

View file

@ -5,6 +5,7 @@ from typing import Optional
import torch
import warnings
from torch import nn
import torch.nn.functional as F

try:
    from pytorch_quantization import nn as quant_nn
@ -143,30 +144,44 @@ class LambdaLayer(nn.Module):
# SqueezeAndExcitation {{{
class SqueezeAndExcitation(nn.Module):
    def __init__(self, in_channels, squeeze, activation):
        super(SqueezeAndExcitation, self).__init__()
        self.squeeze = nn.Linear(in_channels, squeeze)
        self.expand = nn.Linear(squeeze, in_channels)
        self.activation = activation
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self._attention(x)

    def _attention(self, x):
        out = torch.mean(x, [2, 3])
        out = self.squeeze(out)
        out = self.activation(out)
        out = self.expand(out)
        out = self.sigmoid(out)
        out = out.unsqueeze(2).unsqueeze(3)
        return out


class SqueezeAndExcitationTRT(nn.Module):
    def __init__(self, in_channels, squeeze, activation):
        super(SqueezeAndExcitationTRT, self).__init__()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
        self.expand = nn.Conv2d(squeeze, in_channels, 1)
        self.activation = activation
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self._attention(x)

    def _attention(self, x):
        out = self.pooling(x)
        out = self.squeeze(out)
        out = self.activation(out)
        out = self.expand(out)
        out = self.sigmoid(out)
        return out
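The Linear-based SqueezeAndExcitation and the conv-based SqueezeAndExcitationTRT compute the same attention map; the TRT variant keeps the pooling and 1x1 convolutions as explicit graph nodes so TensorRT and the quantization tooling can see them. A quick equivalence check, under the assumption that the 1x1 conv weights are the linear weights reshaped to (out, in, 1, 1):

import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(2, 16, 8, 8)

squeeze, expand = nn.Linear(16, 4), nn.Linear(4, 16)
squeeze_c, expand_c = nn.Conv2d(16, 4, 1), nn.Conv2d(4, 16, 1)
# Share weights: a 1x1 conv with weight (out, in, 1, 1) is the same linear map.
squeeze_c.weight.data = squeeze.weight.data.view(4, 16, 1, 1).clone()
squeeze_c.bias.data = squeeze.bias.data.clone()
expand_c.weight.data = expand.weight.data.view(16, 4, 1, 1).clone()
expand_c.bias.data = expand.bias.data.clone()

act, sig = nn.ReLU(), nn.Sigmoid()

linear_out = sig(expand(act(squeeze(x.mean([2, 3]))))).unsqueeze(2).unsqueeze(3)
conv_out = sig(expand_c(act(squeeze_c(nn.AdaptiveAvgPool2d(1)(x)))))
print(torch.allclose(linear_out, conv_out, atol=1e-6))  # True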
@ -174,18 +189,9 @@ class SqueezeAndExcitation(nn.Module):
# EMA {{{
class EMA:
    def __init__(self, mu, module_ema):
        self.mu = mu
        self.module_ema = module_ema

    def __call__(self, module, step=None):
        if step is None:
@ -193,12 +199,17 @@ class EMA:
        else:
            mu = min(self.mu, (1.0 + step) / (10 + step))

        def strip_module(s: str) -> str:
            return s

        mesd = self.module_ema.state_dict()
        with torch.no_grad():
            for name, x in module.state_dict().items():
                if name.endswith("num_batches_tracked"):
                    continue
                n = strip_module(name)
                mesd[n].mul_(mu)
                mesd[n].add_((1.0 - mu) * x)


# }}}
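EMA no longer keeps a dict of shadow tensors; it holds a full copy of the model (module_ema) and updates that copy's state in place with shadow = mu * shadow + (1 - mu) * param, skipping BatchNorm's integer num_batches_tracked buffers. A minimal sketch of the same update rule on a standalone pair of modules (ema_update is an illustrative helper, not the class above):

import copy
import torch
from torch import nn

model = nn.Linear(4, 2)
model_ema = copy.deepcopy(model)

def ema_update(module, module_ema, mu=0.999, step=None):
    # Same schedule as above: mu is warmed up during the first steps.
    if step is not None:
        mu = min(mu, (1.0 + step) / (10 + step))
    esd = module_ema.state_dict()
    with torch.no_grad():
        for name, x in module.state_dict().items():
            if name.endswith("num_batches_tracked"):
                continue
            esd[name].mul_(mu).add_((1.0 - mu) * x)

# After each optimizer step:
ema_update(model, model_ema, mu=0.999, step=10)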
@ -218,10 +229,8 @@ class ONNXSiLU(nn.Module):
class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
    def __init__(self, in_channels, squeeze, activation, quantized=False):
        super().__init__(in_channels, squeeze, activation)
        self.quantized = quantized
        if quantized:
            assert quant_nn is not None, "pytorch_quantization is not available"
@ -231,10 +240,63 @@ class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
            self.mul_b_quantizer = quant_nn.TensorQuantizer(
                quant_nn.QuantConv2d.default_quant_desc_input
            )
        else:
            self.mul_a_quantizer = nn.Identity()
            self.mul_b_quantizer = nn.Identity()

    def forward(self, x):
        out = self._attention(x)
        if not self.quantized:
            return out * x
        else:
            x_quant = self.mul_a_quantizer(out)
            return x_quant * self.mul_b_quantizer(x)
class SequentialSqueezeAndExcitationTRT(SqueezeAndExcitationTRT):
    def __init__(self, in_channels, squeeze, activation, quantized=False):
        super().__init__(in_channels, squeeze, activation)
        self.quantized = quantized
        if quantized:
            assert quant_nn is not None, "pytorch_quantization is not available"
            self.mul_a_quantizer = quant_nn.TensorQuantizer(
                quant_nn.QuantConv2d.default_quant_desc_input
            )
            self.mul_b_quantizer = quant_nn.TensorQuantizer(
                quant_nn.QuantConv2d.default_quant_desc_input
            )
        else:
            self.mul_a_quantizer = nn.Identity()
            self.mul_b_quantizer = nn.Identity()

    def forward(self, x):
        out = self._attention(x)
        if not self.quantized:
            return out * x
        else:
            x_quant = self.mul_a_quantizer(out)
            return x_quant * self.mul_b_quantizer(x)
class StochasticDepthResidual(nn.Module):
    def __init__(self, survival_prob: float):
        super().__init__()
        self.survival_prob = survival_prob
        self.register_buffer("mask", torch.ones(()), persistent=False)

    def forward(self, residual: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return torch.add(residual, other=x)
        else:
            with torch.no_grad():
                mask = F.dropout(
                    self.mask,
                    p=1 - self.survival_prob,
                    training=self.training,
                    inplace=False,
                )
            return torch.addcmul(residual, mask, x)


class Flatten(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.squeeze(-1).squeeze(-1)
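StochasticDepthResidual replaces the old Python-side coin flip with a dropout-drawn mask, so the stochastic depth op stays inside the graph and the module remains scriptable; the mask is 0 with probability 1 - survival_prob and 1/survival_prob otherwise, which keeps E[residual + mask * x] = residual + x. A small sanity check of that scaling:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
survival_prob = 0.8
mask = torch.ones(())

# One dropout draw: 0 with prob 0.2, 1/0.8 = 1.25 otherwise, so the expectation is 1.
draws = torch.stack(
    [F.dropout(mask, p=1 - survival_prob, training=True) for _ in range(10000)]
)
print(draws.unique())   # values 0.0 and 1.25 (= 1 / survival_prob)
print(draws.mean())     # close to 1.0

residual, x = torch.randn(4), torch.randn(4)
out = torch.addcmul(residual, draws[0], x)   # residual + mask * x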

View file

@ -31,11 +31,11 @@ except ImportError as e:
from .common import (
    SequentialSqueezeAndExcitation,
    SequentialSqueezeAndExcitationTRT,
    LayerBuilder,
    StochasticDepthResidual,
    Flatten,
)

from .model import (
@ -206,6 +206,7 @@ class EfficientNet(nn.Module):
        out_channels = arch.stem_channels
        plc = 0
        layers = []
        for i, (k, s, r, e, c) in arch.enumerate():
            layer, out_channels = self._make_layer(
                block=arch.block,
@ -220,8 +221,8 @@ class EfficientNet(nn.Module):
                trt=trt,
            )
            plc = plc + r
            layers.append(layer)
        self.layers = nn.Sequential(*layers)

        self.features = self._make_features(out_channels, arch.feature_channels)
        self.classifier = self._make_classifier(
            arch.feature_channels, num_classes, dropout
@ -229,11 +230,7 @@ class EfficientNet(nn.Module):

    def forward(self, x):
        x = self.stem(x)
        x = self.layers(x)
        x = self.features(x)
        x = self.classifier(x)
@ -241,27 +238,34 @@ class EfficientNet(nn.Module):
    def extract_features(self, x, layers=None):
        if layers is None:
            layers = [f"layer{i+1}" for i in range(self.num_layers)] + [
                "features",
                "classifier",
            ]

        run = [
            i
            for i in range(self.num_layers)
            if "classifier" in layers
            or "features" in layers
            or any([f"layer{j+1}" in layers for j in range(i, self.num_layers)])
        ]

        output = {}
        x = self.stem(x)
        for l in run:
            fn = self.layers[l]
            x = fn(x)
            if f"layer{l+1}" in layers:
                output[f"layer{l+1}"] = x

        if "features" in layers or "classifier" in layers:
            x = self.features(x)
            if "features" in layers:
                output["features"] = x
            if "classifier" in layers:
                output["classifier"] = self.classifier(x)

        return output
@ -298,7 +302,7 @@ class EfficientNet(nn.Module):
            OrderedDict(
                [
                    ("pooling", nn.AdaptiveAvgPool2d(1)),
                    ("squeeze", Flatten()),
                    ("dropout", nn.Dropout(dropout)),
                    ("fc", nn.Linear(num_features, num_classes)),
                ]
@ -353,11 +357,33 @@ class EfficientNet(nn.Module):
            layers.append((f"block{idx}", blk))
        return nn.Sequential(OrderedDict(layers)), out_channels
    def ngc_checkpoint_remap(self, url=None, version=None):
        if version is None:
            version = url.split("/")[8]

        def to_sequential_remap(s):
            splited = s.split(".")
            if splited[0].startswith("layer"):
                return ".".join(
                    ["layers." + str(int(splited[0][len("layer") :]) - 1)] + splited[1:]
                )
            else:
                return s

        def no_remap(s):
            return s

        return {"20.12.0": to_sequential_remap, "21.03.0": to_sequential_remap}.get(
            version, no_remap
        )


# }}}
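ngc_checkpoint_remap translates keys from older NGC checkpoints, where blocks were stored as attributes layer1..layerN, to the new nn.Sequential container self.layers (index shifted down by one). A standalone, slightly defensive restatement of to_sequential_remap:

def to_sequential_remap(s: str) -> str:
    # "layer4.0.conv.weight" -> "layers.3.0.conv.weight"; other keys pass through.
    head, *rest = s.split(".")
    if head.startswith("layer") and head[len("layer"):].isdigit():
        return ".".join([f"layers.{int(head[len('layer'):]) - 1}", *rest])
    return s

print(to_sequential_remap("layer1.block0.depsep.weight"))  # layers.0.block0.depsep.weight
print(to_sequential_remap("features.conv.weight"))         # unchanged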
# MBConvBlock {{{
class MBConvBlock(nn.Module):
    __constants__ = ["quantized"]

    def __init__(
        self,
        builder: LayerBuilder,
@ -366,7 +392,7 @@ class MBConvBlock(nn.Module):
        out_channels: int,
        expand_ratio: int,
        stride: int,
        squeeze_excitation_ratio: float,
        squeeze_hidden=False,
        survival_prob: float = 1.0,
        quantized: bool = False,
@ -387,25 +413,31 @@ class MBConvBlock(nn.Module):
        self.depsep = builder.convDepSep(
            depsep_kernel_size, hidden_dim, hidden_dim, stride, bn=True, act=True
        )
        if trt or self.quantized:
            # Need TRT mode for quantized in order to automatically insert quantization before pooling
            self.se: nn.Module = SequentialSqueezeAndExcitationTRT(
                hidden_dim, squeeze_dim, builder.activation(), self.quantized
            )
        else:
            self.se: nn.Module = SequentialSqueezeAndExcitation(
                hidden_dim, squeeze_dim, builder.activation(), self.quantized
            )

        self.proj = builder.conv1x1(hidden_dim, out_channels, bn=True)

        if survival_prob == 1.0:
            self.residual_add = torch.add
        else:
            self.residual_add = StochasticDepthResidual(survival_prob=survival_prob)
        if self.quantized and self.residual:
            assert quant_nn is not None, "pytorch_quantization is not available"
            self.residual_quantizer = quant_nn.TensorQuantizer(
                quant_nn.QuantConv2d.default_quant_desc_input
            )  # TODO QuantConv2d ?!?
        else:
            self.residual_quantizer = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.residual:
            return self.proj(
                self.se(self.depsep(x if self.expand is None else self.expand(x)))
@ -414,16 +446,10 @@ class MBConvBlock(nn.Module):
        b = self.proj(
            self.se(self.depsep(x if self.expand is None else self.expand(x)))
        )
        if self.quantized:
            x = self.residual_quantizer(x)
        return self.residual_add(x, b)
def original_mbconv( def original_mbconv(
@ -436,7 +462,7 @@ def original_mbconv(
    squeeze_excitation_ratio: int,
    survival_prob: float,
    quantized: bool,
    trt: bool,
):
    return MBConvBlock(
        builder,

View file

@ -1,8 +1,15 @@
from dataclasses import dataclass, asdict, replace
from .common import (
    SequentialSqueezeAndExcitationTRT,
    SequentialSqueezeAndExcitation,
    SqueezeAndExcitation,
    SqueezeAndExcitationTRT,
)
from typing import Optional, Callable
import os
import torch
import argparse
from functools import partial


@dataclass
@ -37,7 +44,13 @@ class EntryPoint:
        self.name = name
        self.model = model

    def __call__(
        self,
        pretrained=False,
        pretrained_from_file=None,
        state_dict_key_map_fn=None,
        **kwargs,
    ):
        assert not (pretrained and (pretrained_from_file is not None))

        params = replace(self.model.params, **kwargs)
@ -66,7 +79,7 @@ class EntryPoint:
                        pretrained_from_file
                    )
                )
        # Temporary fix to allow NGC checkpoint loading
        if state_dict is not None:
            state_dict = {
                k[len("module.") :] if k.startswith("module.") else k: v
@ -85,12 +98,32 @@ class EntryPoint:
                else:
                    return t

            if state_dict_key_map_fn is not None:
                state_dict = {
                    state_dict_key_map_fn(k): v for k, v in state_dict.items()
                }

            if hasattr(model, "ngc_checkpoint_remap"):
                remap_fn = model.ngc_checkpoint_remap(url=self.model.checkpoint_url)
                state_dict = {remap_fn(k): v for k, v in state_dict.items()}

            def _se_layer_uses_conv(m):
                return any(
                    map(
                        partial(isinstance, m),
                        [
                            SqueezeAndExcitationTRT,
                            SequentialSqueezeAndExcitationTRT,
                        ],
                    )
                )

            state_dict = {
                k: reshape(
                    v,
                    conv=_se_layer_uses_conv(
                        dict(model.named_modules())[".".join(k.split(".")[:-2])]
                    ),
                )
                if is_se_weight(k, v)
                else v
@ -123,7 +156,8 @@ class EntryPoint:
def is_se_weight(key, value):
    return key.endswith("squeeze.weight") or key.endswith("expand.weight")


def create_entrypoint(m: Model):
    def _ep(**kwargs):
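Pretrained squeeze/expand weights may come from the Linear variant or from the conv (TRT) variant of the SE block, so the entry point reshapes them to match whichever module the freshly built model uses. The reshape helper itself is not shown in this hunk; the sketch below is an assumption about its behaviour based on the shapes involved (Linear stores (out, in), a 1x1 conv stores (out, in, 1, 1)):

import torch

def reshape_se_weight(v: torch.Tensor, conv: bool) -> torch.Tensor:
    # Assumed behaviour of the reshape() helper referenced above.
    if conv and v.dim() == 2:
        return v.view(*v.shape, 1, 1)      # Linear weight -> 1x1 conv weight
    if not conv and v.dim() == 4:
        return v.view(*v.shape[:2])        # 1x1 conv weight -> Linear weight
    return v

w = torch.randn(8, 32)                         # Linear-style squeeze weight
print(reshape_se_weight(w, conv=True).shape)   # torch.Size([8, 32, 1, 1])
print(reshape_se_weight(w, conv=False).shape)  # torch.Size([8, 32])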

View file

@ -36,14 +36,16 @@ from typing import List, Dict, Callable, Any, Type
import torch
import torch.nn as nn

from .common import (
    SqueezeAndExcitation,
    LayerBuilder,
    SqueezeAndExcitationTRT,
)

from .model import (
    Model,
    ModelParams,
    ModelArch,
    EntryPoint,
)
@ -128,11 +130,18 @@ class Bottleneck(nn.Module):
        self.stride = stride
        self.fused_se = fused_se

        if se:
            self.squeeze = (
                SqueezeAndExcitation(
                    planes * expansion, se_squeeze, builder.activation()
                )
                if not trt
                else SqueezeAndExcitationTRT(
                    planes * expansion, se_squeeze, builder.activation()
                )
            )
        else:
            self.squeeze = None

    def forward(self, x):
        residual = x
@ -215,6 +224,7 @@ class ResNet(nn.Module):
        last_bn_0_init: bool = False
        conv_init: str = "fan_in"
        trt: bool = False
        fused_se: bool = True

        def parser(self, name):
            p = super().parser(name)
@ -240,6 +250,10 @@ class ResNet(nn.Module):
                help="initialization mode for convolutional layers, see https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_",
            )
            p.add_argument("--trt", metavar="True|False", default=self.trt, type=bool)
            p.add_argument(
                "--fused_se", metavar="True|False", default=self.fused_se, type=bool
            )
            return p

    def __init__(
@ -249,6 +263,7 @@ class ResNet(nn.Module):
        last_bn_0_init: bool = False,
        conv_init: str = "fan_in",
        trt: bool = False,
        fused_se: bool = True,
    ):
        super(ResNet, self).__init__()
@ -265,6 +280,7 @@ class ResNet(nn.Module):
        inplanes = arch.stem_width
        assert len(arch.widths) == len(arch.layers)
        self.num_layers = len(arch.widths)
        layers = []
        for i, (w, l) in enumerate(zip(arch.widths, arch.layers)):
            layer, inplanes = self._make_layer(
                arch.block,
@ -275,9 +291,11 @@ class ResNet(nn.Module):
                cardinality=arch.cardinality,
                stride=1 if i == 0 else 2,
                trt=trt,
                fused_se=fused_se,
            )
            layers.append(layer)
        self.layers = nn.Sequential(*layers)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(arch.widths[-1] * arch.expansion, num_classes)
@ -297,13 +315,8 @@ class ResNet(nn.Module):
    def forward(self, x):
        x = self.stem(x)
        x = self.layers(x)
        x = self.classifier(x)
        return x

    def extract_features(self, x, layers=None):
@ -311,7 +324,7 @@ class ResNet(nn.Module):
            layers = [f"layer{i+1}" for i in range(self.num_layers)] + ["classifier"]

        run = [
            i
            for i in range(self.num_layers)
            if "classifier" in layers
            or any([f"layer{j+1}" in layers for j in range(i, self.num_layers)])
@ -320,10 +333,10 @@ class ResNet(nn.Module):
        output = {}
        x = self.stem(x)
        for l in run:
            fn = self.layers[l]
            x = fn(x)
            if f"layer{l+1}" in layers:
                output[f"layer{l+1}"] = x

        if "classifier" in layers:
            output["classifier"] = self.classifier(x)
@ -332,7 +345,16 @@ class ResNet(nn.Module):
    # helper functions {{{
    def _make_layer(
        self,
        block,
        expansion,
        inplanes,
        planes,
        blocks,
        stride=1,
        cardinality=1,
        trt=False,
        fused_se=True,
    ):
        downsample = None
        if stride != 1 or inplanes != planes * expansion:
@ -354,15 +376,33 @@ class ResNet(nn.Module):
                    stride=stride if i == 0 else 1,
                    cardinality=cardinality,
                    downsample=downsample if i == 0 else None,
                    fused_se=fused_se,
                    last_bn_0_init=self.last_bn_0_init,
                    trt=trt,
                )
            )
            inplanes = planes * expansion
        return nn.Sequential(*layers), inplanes
    def ngc_checkpoint_remap(self, url=None, version=None):
        if version is None:
            version = url.split("/")[8]

        def to_sequential_remap(s):
            splited = s.split(".")
            if splited[0].startswith("layer"):
                return ".".join(
                    ["layers." + str(int(splited[0][len("layer") :]) - 1)] + splited[1:]
                )
            else:
                return s

        def no_remap(s):
            return s

        return {"20.06.0": to_sequential_remap}.get(version, no_remap)


# }}}

View file

@ -1,7 +1,39 @@
import math
import numpy as np
import torch
from torch import optim


def get_optimizer(parameters, lr, args, state=None):
    if args.optimizer == "sgd":
        optimizer = get_sgd_optimizer(
            parameters,
            lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
            bn_weight_decay=args.bn_weight_decay,
        )
    elif args.optimizer == "rmsprop":
        optimizer = get_rmsprop_optimizer(
            parameters,
            lr,
            alpha=args.rmsprop_alpha,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            eps=args.rmsprop_eps,
            bn_weight_decay=args.bn_weight_decay,
        )
    if not state is None:
        optimizer.load_state_dict(state)

    return optimizer


def get_sgd_optimizer(
    parameters, lr, momentum, weight_decay, nesterov=False, bn_weight_decay=False
):
    if bn_weight_decay:
        print(" ! Weight decay applied to BN parameters ")
        params = [v for n, v in parameters]
@ -17,20 +49,112 @@ def get_sgd_optimizer(parameters, lr, momentum, weight_decay, nesterov=False, bn
            {"params": rest_params, "weight_decay": weight_decay},
        ]

    optimizer = torch.optim.SGD(
        params, lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov
    )

    return optimizer


def get_rmsprop_optimizer(
    parameters, lr, alpha, weight_decay, momentum, eps, bn_weight_decay=False
):
    bn_params = [v for n, v in parameters if "bn" in n]
    rest_params = [v for n, v in parameters if not "bn" in n]

    params = [
        {"params": bn_params, "weight_decay": weight_decay if bn_weight_decay else 0},
        {"params": rest_params, "weight_decay": weight_decay},
    ]

    optimizer = torch.optim.RMSprop(
        params,
        lr=lr,
        alpha=alpha,
        weight_decay=weight_decay,
        momentum=momentum,
        eps=eps,
    )

    return optimizer
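get_optimizer dispatches on args.optimizer, forwards the matching hyperparameters and optionally restores optimizer state. A minimal usage sketch with an argparse.Namespace standing in for the parsed arguments (assuming the module is importable as image_classification.optimizers, as in main.py):

import argparse
from torch import nn
from image_classification.optimizers import get_optimizer

model = nn.Linear(10, 2)
args = argparse.Namespace(
    optimizer="sgd",
    momentum=0.9,
    weight_decay=1e-4,
    nesterov=False,
    bn_weight_decay=False,
)
# named_parameters() yields the (name, tensor) pairs the helpers filter on.
optimizer = get_optimizer(list(model.named_parameters()), lr=0.1, args=args)
print(type(optimizer).__name__)  # SGD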
def lr_policy(lr_fn):
def _alr(optimizer, iteration, epoch):
lr = lr_fn(iteration, epoch)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
return lr
return _alr
def lr_step_policy(base_lr, steps, decay_factor, warmup_length):
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
lr = base_lr
for s in steps:
if epoch >= s:
lr *= decay_factor
return lr
return lr_policy(_lr_fn)
def lr_linear_policy(base_lr, warmup_length, epochs):
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
e = epoch - warmup_length
es = epochs - warmup_length
lr = base_lr * (1 - (e / es))
return lr
return lr_policy(_lr_fn)
def lr_cosine_policy(base_lr, warmup_length, epochs, end_lr=0):
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
e = epoch - warmup_length
es = epochs - warmup_length
lr = end_lr + (0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - end_lr))
return lr
return lr_policy(_lr_fn)
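The cosine policy warms up linearly for warmup_length epochs and then decays from base_lr to end_lr over half a cosine period. For example, base_lr=0.1, warmup_length=5, epochs=25 gives 0.1 at epoch 5, 0.05 at epoch 15 and roughly 6e-4 at epoch 24; a plain-function restatement for quick inspection:

import numpy as np

def cosine_lr(base_lr, warmup_length, epochs, epoch, end_lr=0.0):
    # Same formula as _lr_fn above, written out for direct evaluation.
    if epoch < warmup_length:
        return base_lr * (epoch + 1) / warmup_length
    e, es = epoch - warmup_length, epochs - warmup_length
    return end_lr + 0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - end_lr)

for epoch in (0, 4, 5, 15, 24):
    print(epoch, round(cosine_lr(0.1, 5, 25, epoch), 4))
# 0 0.02, 4 0.1, 5 0.1, 15 0.05, 24 0.0006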
def lr_exponential_policy(
base_lr,
warmup_length,
epochs,
final_multiplier=0.001,
decay_factor=None,
decay_step=1,
logger=None,
):
"""Exponential lr policy. Setting decay factor parameter overrides final_multiplier"""
es = epochs - warmup_length
if decay_factor is not None:
epoch_decay = decay_factor
else:
epoch_decay = np.power(
2, np.log2(final_multiplier) / math.floor(es / decay_step)
)
def _lr_fn(iteration, epoch):
if epoch < warmup_length:
lr = base_lr * (epoch + 1) / warmup_length
else:
e = epoch - warmup_length
lr = base_lr * (epoch_decay ** math.floor(e / decay_step))
return lr
return lr_policy(_lr_fn, logger=logger)

View file

@ -27,279 +27,212 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import time
from copy import deepcopy
from functools import wraps
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP

from . import logger as log
from . import utils
from .logger import TrainingMetrics, ValidationMetrics
from .models.common import EMA


class Executor:
    def __init__(
        self,
        model: nn.Module,
        loss: Optional[nn.Module],
        cuda: bool = True,
        memory_format: torch.memory_format = torch.contiguous_format,
        amp: bool = False,
        scaler: Optional[torch.cuda.amp.GradScaler] = None,
        divide_loss: int = 1,
        ts_script: bool = False,
    ):
        assert not (amp and scaler is None), "Gradient Scaler is needed for AMP"

        def xform(m: nn.Module) -> nn.Module:
            if cuda:
                m = m.cuda()
            m.to(memory_format=memory_format)
            return m

        self.model = xform(model)
        if ts_script:
            self.model = torch.jit.script(self.model)
        self.ts_script = ts_script
        self.loss = xform(loss) if loss is not None else None
        self.amp = amp
        self.scaler = scaler
        self.is_distributed = False
        self.divide_loss = divide_loss
        self._fwd_bwd = None
        self._forward = None

    def distributed(self, gpu_id):
        self.is_distributed = True
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            self.model = DDP(self.model, device_ids=[gpu_id], output_device=gpu_id)
        torch.cuda.current_stream().wait_stream(s)

    def _fwd_bwd_fn(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        with autocast(enabled=self.amp):
            loss = self.loss(self.model(input), target)
            loss /= self.divide_loss

        self.scaler.scale(loss).backward()
        return loss

    def _forward_fn(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad(), autocast(enabled=self.amp):
            output = self.model(input)
            loss = None if self.loss is None else self.loss(output, target)

        return output if loss is None else loss, output

    def optimize(self, fn):
        return fn

    @property
    def forward_backward(self):
        if self._fwd_bwd is None:
            if self.loss is None:
                raise NotImplementedError(
                    "Loss must not be None for forward+backward step"
                )
            self._fwd_bwd = self.optimize(self._fwd_bwd_fn)
        return self._fwd_bwd

    @property
    def forward(self):
        if self._forward is None:
            self._forward = self.optimize(self._forward_fn)
        return self._forward
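Executor owns the model, the loss and the AMP machinery and exposes forward_backward (training) and forward (inference) as lazily built callables. A minimal construction sketch on a toy model; the toy sizes and the CPU fallback are illustrative only:

import torch
from torch import nn
# Assumes the class above is importable, e.g.:
# from image_classification.training import Executor

model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10)
)
use_cuda = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=False)

executor = Executor(
    model,
    nn.CrossEntropyLoss(),
    cuda=use_cuda,          # CPU fallback keeps the sketch runnable anywhere
    amp=False,
    scaler=scaler,
    divide_loss=1,
    ts_script=True,         # torch.jit.script the model up front
)

device = "cuda" if use_cuda else "cpu"
images = torch.randn(4, 3, 32, 32, device=device)
labels = torch.randint(0, 10, (4,), device=device)

loss = executor.forward_backward(images, labels)   # forward pass + scaled backward
print(loss.item())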
class Trainer:
    def __init__(
        self,
        executor: Executor,
        optimizer: torch.optim.Optimizer,
        grad_acc_steps: int,
        ema: Optional[float] = None,
    ):
        self.executor = executor
        self.optimizer = optimizer
        self.grad_acc_steps = grad_acc_steps
        self.use_ema = False
        if ema is not None:
            self.ema_executor = deepcopy(self.executor)
            self.ema = EMA(ema, self.ema_executor.model)
            self.use_ema = True

        self.optimizer.zero_grad(set_to_none=True)
        self.steps_since_update = 0

    def train(self):
        self.executor.model.train()

    def eval(self):
        self.executor.model.eval()
        if self.use_ema:
            self.executor.model.eval()

    def train_step(self, input, target, step=None):
        loss = self.executor.forward_backward(input, target)
        self.steps_since_update += 1

        if self.steps_since_update == self.grad_acc_steps:
            if self.executor.scaler is not None:
                self.executor.scaler.step(self.optimizer)
                self.executor.scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()
            self.steps_since_update = 0

        torch.cuda.synchronize()

        if self.use_ema:
            self.ema(self.executor.model, step=step)

        return loss

    def validation_steps(self) -> Dict[str, Callable]:
        vsd: Dict[str, Callable] = {"val": self.executor.forward}
        if self.use_ema:
            vsd["val_ema"] = self.ema_executor.forward
        return vsd

    def state_dict(self) -> dict:
        res = {
            "state_dict": self.executor.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        if self.use_ema:
            res["state_dict_ema"] = self.ema_executor.model.state_dict()

        return res
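Trainer.train_step accumulates gradients over grad_acc_steps micro-batches before stepping the optimizer, while Executor divides each loss by the same factor (divide_loss), so the update matches one large batch. The accumulation pattern in isolation:

import torch
from torch import nn

model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
grad_acc_steps, steps_since_update = 2, 0

for micro_batch in range(4):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y) / grad_acc_steps  # mirrors divide_loss
    loss.backward()                                              # gradients accumulate
    steps_since_update += 1
    if steps_since_update == grad_acc_steps:
        opt.step()           # one optimizer update per grad_acc_steps micro-batches
        opt.zero_grad()
        steps_since_update = 0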
def train(
    train_step,
    train_loader,
    lr_scheduler,
    log_fn,
    timeout_handler,
    prof=-1,
    step=0,
):
    interrupted = False

    end = time.time()

    data_iter = enumerate(train_loader)

    for i, (input, target) in data_iter:
        bs = input.size(0)
        lr = lr_scheduler(i)
        data_time = time.time() - end

        loss = train_step(input, target, step=step + i)
        it_time = time.time() - end

        with torch.no_grad():
            if torch.distributed.is_initialized():
                reduced_loss = utils.reduce_tensor(loss.detach())
            else:
                reduced_loss = loss.detach()

        log_fn(
            compute_ips=utils.calc_ips(bs, it_time - data_time),
            total_ips=utils.calc_ips(bs, it_time),
            data_time=data_time,
            compute_time=it_time - data_time,
            lr=lr,
            loss=reduced_loss.item(),
        )

        end = time.time()
        if prof > 0 and (i + 1 >= prof):
            time.sleep(5)
            break
        if ((i + 1) % 20 == 0) and timeout_handler.interrupted:
            time.sleep(5)
            interrupted = True
            break
@ -307,134 +240,58 @@ def train(
    return interrupted


def validate(infer_fn, val_loader, log_fn, prof=-1, with_loss=True):
    top1 = log.AverageMeter()
    # switch to evaluate mode

    end = time.time()

    data_iter = enumerate(val_loader)

    for i, (input, target) in data_iter:
        bs = input.size(0)
        data_time = time.time() - end

        if with_loss:
            loss, output = infer_fn(input, target)
        else:
            output = infer_fn(input)

        with torch.no_grad():
            prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5))

            if torch.distributed.is_initialized():
                if with_loss:
                    reduced_loss = utils.reduce_tensor(loss.detach())
                prec1 = utils.reduce_tensor(prec1)
                prec5 = utils.reduce_tensor(prec5)
            else:
                if with_loss:
                    reduced_loss = loss.detach()

        prec1 = prec1.item()
        prec5 = prec5.item()
        infer_result = {
            "top1": (prec1, bs),
            "top5": (prec5, bs),
        }

        if with_loss:
            infer_result["loss"] = (reduced_loss.item(), bs)

        torch.cuda.synchronize()

        it_time = time.time() - end

        top1.record(prec1, bs)

        log_fn(
            compute_ips=utils.calc_ips(bs, it_time - data_time),
            total_ips=utils.calc_ips(bs, it_time),
            data_time=data_time,
            compute_time=it_time - data_time,
            **infer_result,
        )

        end = time.time()
        if (prof > 0) and (i + 1 >= prof):
@ -445,22 +302,14 @@ def validate(
# Train loop {{{
def train_loop(
    trainer: Trainer,
    lr_scheduler,
    train_loader,
    train_loader_len,
    val_loader,
    logger,
    should_backup_checkpoint,
    best_prec1=0,
    start_epoch=0,
    end_epoch=0,
@ -472,14 +321,22 @@ def train_loop(
    checkpoint_dir="./",
    checkpoint_filename="checkpoint.pth.tar",
):
    train_metrics = TrainingMetrics(logger)
    val_metrics = {
        k: ValidationMetrics(logger, k) for k in trainer.validation_steps().keys()
    }
    training_step = trainer.train_step

    prec1 = -1

    if early_stopping_patience > 0:
        epochs_since_improvement = 0

    backup_prefix = (
        checkpoint_filename[: -len("checkpoint.pth.tar")]
        if checkpoint_filename.endswith("checkpoint.pth.tar")
        else ""
    )

    print(f"RUNNING EPOCHS FROM {start_epoch} TO {end_epoch}")
    with utils.TimeoutHandler() as timeout_handler:
        interrupted = False
@ -487,73 +344,71 @@ def train_loop(
            if logger is not None:
                logger.start_epoch()
            if not skip_training:
                if logger is not None:
                    data_iter = logger.iteration_generator_wrapper(
                        train_loader, mode="train"
                    )
                else:
                    data_iter = train_loader

                trainer.train()
                interrupted = train(
                    training_step,
                    data_iter,
                    lambda i: lr_scheduler(trainer.optimizer, i, epoch),
                    train_metrics.log,
                    timeout_handler,
                    prof=prof,
                    step=epoch * train_loader_len,
                )

            if not skip_validation:
                trainer.eval()
                for k, infer_fn in trainer.validation_steps().items():
                    if logger is not None:
                        data_iter = logger.iteration_generator_wrapper(
                            val_loader, mode="val"
                        )
                    else:
                        data_iter = val_loader

                    step_prec1, _ = validate(
                        infer_fn,
                        data_iter,
                        val_metrics[k].log,
                        prof=prof,
                    )

                    if k == "val":
                        prec1 = step_prec1

                if prec1 > best_prec1:
                    is_best = True
                    best_prec1 = prec1
                else:
                    is_best = False
            else:
                is_best = False
                best_prec1 = 0

            if logger is not None:
                logger.end_epoch()

            if save_checkpoints and (
                not torch.distributed.is_initialized()
                or torch.distributed.get_rank() == 0
            ):
                if should_backup_checkpoint(epoch):
                    backup_filename = "{}checkpoint-{}.pth.tar".format(
                        backup_prefix, epoch + 1
                    )
                else:
                    backup_filename = None
                checkpoint_state = {
                    "epoch": epoch + 1,
                    "best_prec1": best_prec1,
                    **trainer.state_dict(),
                }
                utils.save_checkpoint(
                    checkpoint_state,
                    is_best,
@ -561,6 +416,7 @@ def train_loop(
                    backup_filename=backup_filename,
                    filename=checkpoint_filename,
                )

            if early_stopping_patience > 0:
                if not is_best:
                    epochs_since_improvement += 1
@ -570,4 +426,6 @@ def train_loop(
                    break
        if interrupted:
            break

View file

@ -97,7 +97,7 @@ def accuracy(output, target, topk=(1,)):
def reduce_tensor(tensor):
    rt = tensor.clone().detach()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= (
        torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
@ -114,6 +114,7 @@ class TimeoutHandler:
    def __init__(self, sig=signal.SIGTERM):
        self.sig = sig
        self.device = torch.device("cuda")

    @property
    def interrupted(self):
        if not dist.is_initialized():

View file

@ -58,6 +58,12 @@ from image_classification.models import (
    efficientnet_widese_b0,
    efficientnet_widese_b4,
)
from image_classification.optimizers import (
    get_optimizer,
    lr_cosine_policy,
    lr_linear_policy,
    lr_step_policy,
)
import dllogger import dllogger
@ -102,7 +108,9 @@ def add_parser_arguments(parser, skip_arch=False):
        metavar="ARCH",
        default="resnet50",
        choices=model_names,
        help="model architecture: "
        + " | ".join(model_names)
        + " (default: resnet50)",
    )

    parser.add_argument(
@ -290,6 +298,13 @@ def add_parser_arguments(parser, skip_arch=False):
        dest="save_checkpoints",
        help="do not store any checkpoints, useful for benchmarking",
    )
    parser.add_argument(
        "--jit",
        type=str,
        default="no",
        choices=["no", "script"],
        help="no -> do not use torch.jit; script -> use torch.jit.script",
    )

    parser.add_argument("--checkpoint-filename", default="checkpoint.pth.tar", type=str)
@ -320,7 +335,7 @@ def add_parser_arguments(parser, skip_arch=False):
        type=int,
        default=None,
        required=False,
        help="number of classes",
    )
@ -432,13 +447,25 @@ def prepare_for_training(args, model_args, model_arch):
if args.image_size is not None if args.image_size is not None
else model.arch.default_image_size else model.arch.default_image_size
) )
model_and_loss = ModelAndLoss(model, loss, cuda=True, memory_format=memory_format)
if args.use_ema is not None: scaler = torch.cuda.amp.GradScaler(
model_ema = deepcopy(model_and_loss) init_scale=args.static_loss_scale,
ema = EMA(args.use_ema) growth_factor=2,
else: backoff_factor=0.5,
model_ema = None growth_interval=100 if args.dynamic_loss_scale else 1000000000,
ema = None enabled=args.amp,
)
executor = Executor(
model,
loss(),
cuda=True,
memory_format=memory_format,
amp=args.amp,
scaler=scaler,
divide_loss=batch_size_multiplier,
ts_script = args.jit == "script",
)
# Create data loaders and optimizers as needed # Create data loaders and optimizers as needed
if args.data_backend == "pytorch": if args.data_backend == "pytorch":
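The GradScaler built above encodes both loss-scaling modes: with a dynamic scale it grows every 100 steps, while a "static" scale is approximated by an astronomically large growth_interval so the scale effectively never changes. How the Executor consumes the scaler is not visible in this hunk; the sketch below is the standard torch.cuda.amp recipe it presumably follows, and it assumes a CUDA device is available.

import torch

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=True)

data = torch.randn(32, 16, device="cuda")
target = torch.randn(32, 4, device="cuda")

for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=True):
        loss = torch.nn.functional.mse_loss(model(data), target)
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)          # unscales grads; skips the step on inf/nan
    scaler.update()                 # adjust the scale for the next iteration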
@@ -463,7 +490,7 @@ def prepare_for_training(args, model_args, model_arch):
             args.batch_size,
             model_args.num_classes,
             args.mixup > 0.0,
-            interpolation = args.interpolation,
+            interpolation=args.interpolation,
             augmentation=args.augmentation,
             start_epoch=start_epoch,
             workers=args.workers,
@@ -478,7 +505,7 @@ def prepare_for_training(args, model_args, model_arch):
             args.batch_size,
             model_args.num_classes,
             False,
-            interpolation = args.interpolation,
+            interpolation=args.interpolation,
             workers=args.workers,
             memory_format=memory_format,
         )
@@ -508,41 +535,38 @@ def prepare_for_training(args, model_args, model_arch):
     )
     optimizer = get_optimizer(
-        list(model_and_loss.model.named_parameters()),
+        list(executor.model.named_parameters()),
         args.lr,
         args=args,
         state=optimizer_state,
     )
     if args.lr_schedule == "step":
-        lr_policy = lr_step_policy(
-            args.lr, [30, 60, 80], 0.1, args.warmup, logger=logger
-        )
+        lr_policy = lr_step_policy(args.lr, [30, 60, 80], 0.1, args.warmup)
     elif args.lr_schedule == "cosine":
         lr_policy = lr_cosine_policy(
-            args.lr, args.warmup, args.epochs, end_lr=args.end_lr, logger=logger
+            args.lr, args.warmup, args.epochs, end_lr=args.end_lr
         )
     elif args.lr_schedule == "linear":
-        lr_policy = lr_linear_policy(args.lr, args.warmup, args.epochs, logger=logger)
+        lr_policy = lr_linear_policy(args.lr, args.warmup, args.epochs)
-    scaler = torch.cuda.amp.GradScaler(
-        init_scale=args.static_loss_scale,
-        growth_factor=2,
-        backoff_factor=0.5,
-        growth_interval=100 if args.dynamic_loss_scale else 1000000000,
-        enabled=args.amp,
-    )
     if args.distributed:
-        model_and_loss.distributed(args.gpu)
+        executor.distributed(args.gpu)
-    model_and_loss.load_model_state(model_state)
-    if (ema is not None) and (model_state_ema is not None):
-        print("load ema")
-        ema.load_state_dict(model_state_ema)
+    if model_state is not None:
+        executor.model.load_state_dict(model_state)
-    return (model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema,
-            train_loader_len, batch_size_multiplier, start_epoch)
+    trainer = Trainer(
+        executor,
+        optimizer,
+        grad_acc_steps=batch_size_multiplier,
+        ema=args.use_ema,
+    )
+    if (args.use_ema is not None) and (model_state_ema is not None):
+        trainer.ema_executor.model.load_state_dict(model_state_ema)
+    return (trainer, lr_policy, train_loader, train_loader_len, val_loader, logger, start_epoch)
 def main(args, model_args, model_arch):
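Trainer now owns the EMA bookkeeping (ema=args.use_ema, with the averaged weights living in trainer.ema_executor). The implementation itself is not shown in this hunk; the ModelEMA class below is only a generic sketch of an exponential-moving-average weight update, not the repository's code.

import copy
import torch
import torch.nn as nn

class ModelEMA:
    """Generic exponential moving average of model weights (illustrative only)."""

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.ema = copy.deepcopy(model).eval()
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module):
        # ema <- decay * ema + (1 - decay) * current
        for e, p in zip(self.ema.state_dict().values(), model.state_dict().values()):
            if e.dtype.is_floating_point:
                e.mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)
            else:
                e.copy_(p)

model = nn.Linear(8, 8)
ema = ModelEMA(model, decay=0.99)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(5):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(model)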
@@ -550,23 +574,24 @@ def main(args, model_args, model_arch):
     global best_prec1
     best_prec1 = 0
-    model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema, train_loader_len, \
-        batch_size_multiplier, start_epoch = prepare_for_training(args, model_args, model_arch)
-    train_loop(
-        model_and_loss,
-        optimizer,
-        scaler,
-        lr_policy,
-        train_loader,
-        val_loader,
-        logger,
-        should_backup_checkpoint(args),
-        ema=ema,
-        model_ema=model_ema,
-        steps_per_epoch=train_loader_len,
-        use_amp=args.amp,
-        batch_size_multiplier=batch_size_multiplier,
+    (
+        trainer,
+        lr_policy,
+        train_loader,
+        train_loader_len,
+        val_loader,
+        logger,
+        start_epoch,
+    ) = prepare_for_training(args, model_args, model_arch)
+    train_loop(
+        trainer,
+        lr_policy,
+        train_loader,
+        train_loader_len,
+        val_loader,
+        logger,
+        should_backup_checkpoint(args),
         start_epoch=start_epoch,
         end_epoch=min((start_epoch + args.run_epochs), args.epochs)
         if args.run_epochs != -1
@@ -587,7 +612,6 @@ def main(args, model_args, model_arch):
 if __name__ == "__main__":
     epilog = [
         "Based on the architecture picked by --arch flag, you may use the following options:\n"
     ]
@@ -603,7 +627,7 @@ if __name__ == "__main__":
     add_parser_arguments(parser)
     args, rest = parser.parse_known_args()
     model_arch = available_models()[args.arch]
     model_args, rest = model_arch.parser().parse_known_args(rest)
     print(model_args)
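The lr_*_policy helpers lose their logger argument in this commit and become plain closures from epoch index to learning rate. As an illustration of the shape such a policy computes, here is a cosine schedule with linear warmup; cosine_lr and its signature are assumptions for this sketch, not the repository's lr_cosine_policy.

import math

def cosine_lr(base_lr, warmup_epochs, total_epochs, end_lr=0.0):
    """Return a function epoch -> learning rate: linear warmup, then cosine decay."""
    def policy(epoch):
        if epoch < warmup_epochs:
            return base_lr * (epoch + 1) / warmup_epochs
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return end_lr + (base_lr - end_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
    return policy

policy = cosine_lr(base_lr=0.256, warmup_epochs=5, total_epochs=90)
print([round(policy(e), 4) for e in (0, 4, 5, 45, 89)])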

View file

@@ -68,9 +68,11 @@ def parse_quantization(parser):
         metavar="ARCH",
         default="efficientnet-quant-b0",
         choices=model_names,
-        help="model architecture: " + " | ".join(model_names) + " (default: efficientnet-quant-b0)",
+        help="model architecture: "
+        + " | ".join(model_names)
+        + " (default: efficientnet-quant-b0)",
     )
     parser.add_argument(
         "--skip-calibration",
         action="store_true",
@@ -80,6 +82,7 @@ def parse_quantization(parser):
 def parse_training_args(parser):
     from main import add_parser_arguments
     return add_parser_arguments(parser)
@@ -92,28 +95,29 @@ def main(args, model_args, model_arch):
     select_default_calib_method()
-    model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema, train_loader_len, \
-        batch_size_multiplier, start_epoch = prepare_for_training(args, model_args, model_arch)
-    print(f"RUNNING QUANTIZATION")
-    if not skip_calibration:
-        calibrate(model_and_loss.model, train_loader, logger, calib_iter=10)
-    train_loop(
-        model_and_loss,
-        optimizer,
-        scaler,
-        lr_policy,
-        train_loader,
-        val_loader,
-        logger,
-        should_backup_checkpoint(args),
-        ema=ema,
-        model_ema=model_ema,
-        steps_per_epoch=train_loader_len,
-        use_amp=args.amp,
-        batch_size_multiplier=batch_size_multiplier,
+    (
+        trainer,
+        lr_policy,
+        train_loader,
+        train_loader_len,
+        val_loader,
+        logger,
+        start_epoch,
+    ) = prepare_for_training(args, model_args, model_arch)
+    print(f"RUNNING QUANTIZATION")
+    if not skip_calibration:
+        calibrate(trainer.model_and_loss.model, train_loader, logger, calib_iter=10)
+    train_loop(
+        trainer,
+        lr_policy,
+        train_loader,
+        train_loader_len,
+        val_loader,
+        logger,
+        should_backup_checkpoint(args),
         start_epoch=start_epoch,
         end_epoch=min((start_epoch + args.run_epochs), args.epochs)
         if args.run_epochs != -1
@@ -124,7 +128,7 @@ def main(args, model_args, model_arch):
         skip_validation=args.training_only,
         save_checkpoints=args.save_checkpoints,
         checkpoint_dir=args.workspace,
-        checkpoint_filename='quantized_' + args.checkpoint_filename,
+        checkpoint_filename="quantized_" + args.checkpoint_filename,
     )
     if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
@@ -149,7 +153,7 @@ if __name__ == "__main__":
     parse_training(parser, skip_arch=True)
     args, rest = parser.parse_known_args()
     model_arch = available_models()[args.arch]
     model_args, rest = model_arch.parser().parse_known_args(rest)
     print(model_args)
@@ -159,4 +163,3 @@ if __name__ == "__main__":
     cudnn.benchmark = True
     main(args, model_args, model_arch)
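For context on the calibrate(...) call in the quantization entry point above: the helper itself is defined elsewhere in this commit and is not visible here. The sketch below follows the calibration recipe from NVIDIA's pytorch-quantization documentation and is only an approximation of what such a helper might do; calibrate_sketch, num_batches and amax_kwargs are illustrative names, and the use of the private _calibrator attribute mirrors the upstream examples.

import torch
from pytorch_quantization import calib
from pytorch_quantization import nn as quant_nn

def calibrate_sketch(model, data_loader, num_batches=10, **amax_kwargs):
    """Collect activation ranges on a few batches, then freeze them as amax values."""
    # 1) Put every TensorQuantizer into calibration mode.
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()
    # 2) Run a handful of batches; the calibrators record statistics.
    with torch.no_grad():
        for i, (images, _) in enumerate(data_loader):
            model(images.cuda())
            if i + 1 >= num_batches:
                break
    # 3) Turn the statistics into amax values and switch back to quantized execution.
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**amax_kwargs)  # e.g. method="percentile"
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()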