[ConvNets/PyT] TorchScriptable ConvNets
This commit is contained in:
parent
3d3250a3ae
commit
4f2c6922bd
2
PyTorch/Classification/ConvNets/.dockerignore
Normal file
2
PyTorch/Classification/ConvNets/.dockerignore
Normal file
|
@ -0,0 +1,2 @@
|
|||
*.pth.tar
|
||||
*.log
|
|
@ -22,6 +22,7 @@ def add_parser_arguments(parser):
|
|||
parser.add_argument(
|
||||
"--weight-path", metavar="<path>", help="name of file in which to store weights"
|
||||
)
|
||||
parser.add_argument("--ema", action="store_true", default=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -30,8 +31,9 @@ if __name__ == "__main__":
|
|||
add_parser_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
checkpoint = torch.load(args.checkpoint_path, map_location=torch.device('cpu'))
|
||||
checkpoint = torch.load(args.checkpoint_path, map_location=torch.device("cpu"))
|
||||
|
||||
key = "state_dict" if not args.ema else "ema_state_dict"
|
||||
model_state_dict = {
|
||||
k[len("module.") :] if "module." in k else k: v
|
||||
for k, v in checkpoint["state_dict"].items()
|
||||
|
|
|
@ -234,32 +234,29 @@ class Logger(object):
|
|||
def log_metric(self, metric_name, val, n=1):
|
||||
self.metrics[metric_name]["meter"].record(val, n=n)
|
||||
|
||||
def start_iteration(self, mode='train'):
|
||||
if mode == 'val':
|
||||
def start_iteration(self, mode="train"):
|
||||
if mode == "val":
|
||||
self.val_iteration += 1
|
||||
elif mode == 'train':
|
||||
elif mode == "train":
|
||||
self.iteration += 1
|
||||
elif mode == 'calib':
|
||||
elif mode == "calib":
|
||||
self.calib_iteration += 1
|
||||
|
||||
def end_iteration(self, mode='train'):
|
||||
if mode == 'val':
|
||||
def end_iteration(self, mode="train"):
|
||||
if mode == "val":
|
||||
it = self.val_iteration
|
||||
elif mode == 'train':
|
||||
elif mode == "train":
|
||||
it = self.iteration
|
||||
elif mode == 'calib':
|
||||
elif mode == "calib":
|
||||
it = self.calib_iteration
|
||||
|
||||
if it % self.print_interval == 0 or mode == 'calib':
|
||||
metrics = {
|
||||
n: m for n, m in self.metrics.items() if n.startswith(mode)
|
||||
}
|
||||
if mode == 'train':
|
||||
if it % self.print_interval == 0 or mode == "calib":
|
||||
metrics = {n: m for n, m in self.metrics.items() if n.startswith(mode)}
|
||||
if mode == "train":
|
||||
step = (self.epoch, self.iteration)
|
||||
elif mode == 'val':
|
||||
elif mode == "val":
|
||||
step = (self.epoch, self.iteration, self.val_iteration)
|
||||
elif mode == 'calib':
|
||||
step = ('Calibration', self.calib_iteration)
|
||||
elif mode == "calib":
|
||||
step = ("Calibration", self.calib_iteration)
|
||||
|
||||
verbositys = {m["level"] for _, m in metrics.items()}
|
||||
for ll in verbositys:
|
||||
|
@ -282,12 +279,12 @@ class Logger(object):
|
|||
self.val_iteration = 0
|
||||
|
||||
for n, m in self.metrics.items():
|
||||
if not n.startswith('calib'):
|
||||
if not n.startswith("calib"):
|
||||
m["meter"].reset_epoch()
|
||||
|
||||
def end_epoch(self):
|
||||
for n, m in self.metrics.items():
|
||||
if not n.startswith('calib'):
|
||||
if not n.startswith("calib"):
|
||||
m["meter"].reset_iteration()
|
||||
|
||||
verbositys = {m["level"] for _, m in self.metrics.items()}
|
||||
|
@ -302,12 +299,12 @@ class Logger(object):
|
|||
self.calib_iteration = 0
|
||||
|
||||
for n, m in self.metrics.items():
|
||||
if n.startswith('calib'):
|
||||
if n.startswith("calib"):
|
||||
m["meter"].reset_epoch()
|
||||
|
||||
def end_calibration(self):
|
||||
for n, m in self.metrics.items():
|
||||
if n.startswith('calib'):
|
||||
if n.startswith("calib"):
|
||||
m["meter"].reset_iteration()
|
||||
|
||||
def end(self):
|
||||
|
@ -326,7 +323,7 @@ class Logger(object):
|
|||
|
||||
dllogger.flush()
|
||||
|
||||
def iteration_generator_wrapper(self, gen, mode='train'):
|
||||
def iteration_generator_wrapper(self, gen, mode="train"):
|
||||
for g in gen:
|
||||
self.start_iteration(mode=mode)
|
||||
yield g
|
||||
|
@ -337,3 +334,155 @@ class Logger(object):
|
|||
self.start_epoch()
|
||||
yield g
|
||||
self.end_epoch()
|
||||
|
||||
|
||||
class Metrics:
|
||||
ACC_METADATA = {"unit": "%", "format": ":.2f"}
|
||||
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
|
||||
TIME_METADATA = {"unit": "s", "format": ":.5f"}
|
||||
LOSS_METADATA = {"format": ":.5f"}
|
||||
LR_METADATA = {"format": ":.5f"}
|
||||
|
||||
def __init__(self, logger):
|
||||
self.logger = logger
|
||||
self.map = {}
|
||||
|
||||
def log(self, **kwargs):
|
||||
if self.logger is None:
|
||||
return
|
||||
for k, v in kwargs.items():
|
||||
tks = self.map.get(k, [k])
|
||||
for tk in tks:
|
||||
if isinstance(v, tuple):
|
||||
self.logger.log_metric(tk, v[0], v[1])
|
||||
else:
|
||||
self.logger.log_metric(tk, v)
|
||||
|
||||
|
||||
class TrainingMetrics(Metrics):
|
||||
def __init__(self, logger):
|
||||
super().__init__(logger)
|
||||
if self.logger is not None:
|
||||
self.map = {
|
||||
"loss": ["train.loss"],
|
||||
"compute_ips": ["train.compute_ips"],
|
||||
"total_ips": ["train.total_ips"],
|
||||
"data_time": ["train.data_time"],
|
||||
"compute_time": ["train.compute_time"],
|
||||
"lr": ["train.lr"],
|
||||
}
|
||||
logger.register_metric(
|
||||
"train.loss",
|
||||
LOSS_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=Metrics.LOSS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
"train.compute_ips",
|
||||
PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=Metrics.IPS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
"train.total_ips",
|
||||
PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=Metrics.IPS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
"train.data_time",
|
||||
PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=Metrics.TIME_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
"train.compute_time",
|
||||
PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=Metrics.TIME_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
"train.lr",
|
||||
LR_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
)
|
||||
|
||||
|
||||
class ValidationMetrics(Metrics):
|
||||
def __init__(self, logger, prefix):
|
||||
super().__init__(logger)
|
||||
if self.logger is not None:
|
||||
self.map = {
|
||||
"loss": [f"{prefix}.loss"],
|
||||
"top1": [f"{prefix}.top1"],
|
||||
"top5": [f"{prefix}.top5"],
|
||||
"compute_ips": [f"{prefix}.compute_ips"],
|
||||
"total_ips": [f"{prefix}.total_ips"],
|
||||
"data_time": [f"{prefix}.data_time"],
|
||||
"compute_time": [
|
||||
f"{prefix}.compute_latency",
|
||||
f"{prefix}.compute_latency_at100",
|
||||
f"{prefix}.compute_latency_at99",
|
||||
f"{prefix}.compute_latency_at95",
|
||||
],
|
||||
}
|
||||
logger.register_metric(
|
||||
f"{prefix}.top1",
|
||||
ACC_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=Metrics.ACC_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.top5",
|
||||
ACC_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=Metrics.ACC_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.loss",
|
||||
LOSS_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=Metrics.LOSS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.compute_ips",
|
||||
PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=Metrics.IPS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.total_ips",
|
||||
PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=Metrics.IPS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.data_time",
|
||||
PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=Metrics.TIME_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.compute_latency",
|
||||
PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=Metrics.TIME_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.compute_latency_at100",
|
||||
LAT_100(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=Metrics.TIME_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.compute_latency_at99",
|
||||
LAT_99(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=Metrics.TIME_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.compute_latency_at95",
|
||||
LAT_95(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=Metrics.TIME_METADATA,
|
||||
)
|
||||
|
|
|
@ -5,6 +5,7 @@ from typing import Optional
|
|||
import torch
|
||||
import warnings
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from pytorch_quantization import nn as quant_nn
|
||||
|
@ -143,49 +144,54 @@ class LambdaLayer(nn.Module):
|
|||
|
||||
# SqueezeAndExcitation {{{
|
||||
class SqueezeAndExcitation(nn.Module):
|
||||
def __init__(self, in_channels, squeeze, activation, use_conv=False):
|
||||
def __init__(self, in_channels, squeeze, activation):
|
||||
super(SqueezeAndExcitation, self).__init__()
|
||||
if use_conv:
|
||||
self.pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
|
||||
self.expand = nn.Conv2d(squeeze, in_channels, 1)
|
||||
else:
|
||||
self.squeeze = nn.Linear(in_channels, squeeze)
|
||||
self.expand = nn.Linear(squeeze, in_channels)
|
||||
self.activation = activation
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.use_conv = use_conv
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_conv:
|
||||
out = self.pooling(x)
|
||||
else:
|
||||
return self._attention(x)
|
||||
|
||||
def _attention(self, x):
|
||||
out = torch.mean(x, [2, 3])
|
||||
out = self.squeeze(out)
|
||||
out = self.activation(out)
|
||||
out = self.expand(out)
|
||||
out = self.sigmoid(out)
|
||||
if not self.use_conv:
|
||||
out = out.unsqueeze(2).unsqueeze(3)
|
||||
return out
|
||||
|
||||
|
||||
class SqueezeAndExcitationTRT(nn.Module):
|
||||
def __init__(self, in_channels, squeeze, activation):
|
||||
super(SqueezeAndExcitationTRT, self).__init__()
|
||||
self.pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
|
||||
self.expand = nn.Conv2d(squeeze, in_channels, 1)
|
||||
self.activation = activation
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
return self._attention(x)
|
||||
|
||||
def _attention(self, x):
|
||||
out = self.pooling(x)
|
||||
out = self.squeeze(out)
|
||||
out = self.activation(out)
|
||||
out = self.expand(out)
|
||||
out = self.sigmoid(out)
|
||||
return out
|
||||
|
||||
|
||||
# }}}
|
||||
|
||||
# EMA {{{
|
||||
class EMA:
|
||||
def __init__(self, mu):
|
||||
def __init__(self, mu, module_ema):
|
||||
self.mu = mu
|
||||
self.shadow = {}
|
||||
|
||||
def state_dict(self):
|
||||
return copy.deepcopy(self.shadow)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.shadow = state_dict
|
||||
|
||||
def __len__(self):
|
||||
return len(self.shadow)
|
||||
self.module_ema = module_ema
|
||||
|
||||
def __call__(self, module, step=None):
|
||||
if step is None:
|
||||
|
@ -193,12 +199,17 @@ class EMA:
|
|||
else:
|
||||
mu = min(self.mu, (1.0 + step) / (10 + step))
|
||||
|
||||
def strip_module(s: str) -> str:
|
||||
return s
|
||||
|
||||
mesd = self.module_ema.state_dict()
|
||||
with torch.no_grad():
|
||||
for name, x in module.state_dict().items():
|
||||
if name in self.shadow:
|
||||
new_average = (1.0 - mu) * x + mu * self.shadow[name]
|
||||
self.shadow[name] = new_average.clone()
|
||||
else:
|
||||
self.shadow[name] = x.clone()
|
||||
if name.endswith("num_batches_tracked"):
|
||||
continue
|
||||
n = strip_module(name)
|
||||
mesd[n].mul_(mu)
|
||||
mesd[n].add_((1.0 - mu) * x)
|
||||
|
||||
|
||||
# }}}
|
||||
|
@ -218,10 +229,8 @@ class ONNXSiLU(nn.Module):
|
|||
|
||||
|
||||
class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
|
||||
def __init__(
|
||||
self, in_channels, squeeze, activation, quantized=False, use_conv=False
|
||||
):
|
||||
super().__init__(in_channels, squeeze, activation, use_conv=use_conv)
|
||||
def __init__(self, in_channels, squeeze, activation, quantized=False):
|
||||
super().__init__(in_channels, squeeze, activation)
|
||||
self.quantized = quantized
|
||||
if quantized:
|
||||
assert quant_nn is not None, "pytorch_quantization is not available"
|
||||
|
@ -231,10 +240,63 @@ class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
|
|||
self.mul_b_quantizer = quant_nn.TensorQuantizer(
|
||||
quant_nn.QuantConv2d.default_quant_desc_input
|
||||
)
|
||||
else:
|
||||
self.mul_a_quantizer = nn.Identity()
|
||||
self.mul_b_quantizer = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
out = self._attention(x)
|
||||
if not self.quantized:
|
||||
return super().forward(x) * x
|
||||
return out * x
|
||||
else:
|
||||
x_quant = self.mul_a_quantizer(super().forward(x))
|
||||
x_quant = self.mul_a_quantizer(out)
|
||||
return x_quant * self.mul_b_quantizer(x)
|
||||
|
||||
|
||||
class SequentialSqueezeAndExcitationTRT(SqueezeAndExcitationTRT):
|
||||
def __init__(self, in_channels, squeeze, activation, quantized=False):
|
||||
super().__init__(in_channels, squeeze, activation)
|
||||
self.quantized = quantized
|
||||
if quantized:
|
||||
assert quant_nn is not None, "pytorch_quantization is not available"
|
||||
self.mul_a_quantizer = quant_nn.TensorQuantizer(
|
||||
quant_nn.QuantConv2d.default_quant_desc_input
|
||||
)
|
||||
self.mul_b_quantizer = quant_nn.TensorQuantizer(
|
||||
quant_nn.QuantConv2d.default_quant_desc_input
|
||||
)
|
||||
else:
|
||||
self.mul_a_quantizer = nn.Identity()
|
||||
self.mul_b_quantizer = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
out = self._attention(x)
|
||||
if not self.quantized:
|
||||
return out * x
|
||||
else:
|
||||
x_quant = self.mul_a_quantizer(out)
|
||||
return x_quant * self.mul_b_quantizer(x)
|
||||
|
||||
|
||||
class StochasticDepthResidual(nn.Module):
|
||||
def __init__(self, survival_prob: float):
|
||||
super().__init__()
|
||||
self.survival_prob = survival_prob
|
||||
self.register_buffer("mask", torch.ones(()), persistent=False)
|
||||
|
||||
def forward(self, residual: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
if not self.training:
|
||||
return torch.add(residual, other=x)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
F.dropout(
|
||||
self.mask,
|
||||
p=1 - self.survival_prob,
|
||||
training=self.training,
|
||||
inplace=False,
|
||||
)
|
||||
return torch.addcmul(residual, self.mask, x)
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.squeeze(-1).squeeze(-1)
|
||||
|
|
|
@ -31,11 +31,11 @@ except ImportError as e:
|
|||
|
||||
|
||||
from .common import (
|
||||
SqueezeAndExcitation,
|
||||
ONNXSiLU,
|
||||
SequentialSqueezeAndExcitation,
|
||||
SequentialSqueezeAndExcitationTRT,
|
||||
LayerBuilder,
|
||||
LambdaLayer,
|
||||
StochasticDepthResidual,
|
||||
Flatten,
|
||||
)
|
||||
|
||||
from .model import (
|
||||
|
@ -206,6 +206,7 @@ class EfficientNet(nn.Module):
|
|||
out_channels = arch.stem_channels
|
||||
|
||||
plc = 0
|
||||
layers = []
|
||||
for i, (k, s, r, e, c) in arch.enumerate():
|
||||
layer, out_channels = self._make_layer(
|
||||
block=arch.block,
|
||||
|
@ -220,8 +221,8 @@ class EfficientNet(nn.Module):
|
|||
trt=trt,
|
||||
)
|
||||
plc = plc + r
|
||||
setattr(self, f"layer{i+1}", layer)
|
||||
|
||||
layers.append(layer)
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.features = self._make_features(out_channels, arch.feature_channels)
|
||||
self.classifier = self._make_classifier(
|
||||
arch.feature_channels, num_classes, dropout
|
||||
|
@ -229,11 +230,7 @@ class EfficientNet(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
x = self.stem(x)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
fn = getattr(self, f"layer{i+1}")
|
||||
x = fn(x)
|
||||
|
||||
x = self.layers(x)
|
||||
x = self.features(x)
|
||||
x = self.classifier(x)
|
||||
|
||||
|
@ -241,27 +238,34 @@ class EfficientNet(nn.Module):
|
|||
|
||||
def extract_features(self, x, layers=None):
|
||||
if layers is None:
|
||||
layers = [f"layer{i+1}" for i in range(self.num_layers)]
|
||||
layers = [f"layer{i+1}" for i in range(self.num_layers)] + [
|
||||
"features",
|
||||
"classifier",
|
||||
]
|
||||
|
||||
run = [
|
||||
f"layer{i+1}"
|
||||
i
|
||||
for i in range(self.num_layers)
|
||||
if "classifier" in layers
|
||||
or "features" in layers
|
||||
or any([f"layer{j+1}" in layers for j in range(i, self.num_layers)])
|
||||
]
|
||||
if "features" in layers or "classifier" in layers:
|
||||
run.append("features")
|
||||
if "classifier" in layers:
|
||||
run.append("classifier")
|
||||
|
||||
output = {}
|
||||
x = self.stem(x)
|
||||
for l in run:
|
||||
fn = getattr(self, l)
|
||||
fn = self.layers[l]
|
||||
x = fn(x)
|
||||
if l in layers:
|
||||
output[l] = x
|
||||
if f"layer{l+1}" in layers:
|
||||
output[f"layer{l+1}"] = x
|
||||
|
||||
if "features" in layers or "classifier" in layers:
|
||||
x = self.features(x)
|
||||
if "features" in layers:
|
||||
output["features"] = x
|
||||
|
||||
if "classifier" in layers:
|
||||
output["classifier"] = self.classifier(x)
|
||||
|
||||
return output
|
||||
|
||||
|
@ -298,7 +302,7 @@ class EfficientNet(nn.Module):
|
|||
OrderedDict(
|
||||
[
|
||||
("pooling", nn.AdaptiveAvgPool2d(1)),
|
||||
("squeeze", LambdaLayer(lambda x: x.squeeze(-1).squeeze(-1))),
|
||||
("squeeze", Flatten()),
|
||||
("dropout", nn.Dropout(dropout)),
|
||||
("fc", nn.Linear(num_features, num_classes)),
|
||||
]
|
||||
|
@ -353,11 +357,33 @@ class EfficientNet(nn.Module):
|
|||
layers.append((f"block{idx}", blk))
|
||||
return nn.Sequential(OrderedDict(layers)), out_channels
|
||||
|
||||
def ngc_checkpoint_remap(self, url=None, version=None):
|
||||
if version is None:
|
||||
version = url.split("/")[8]
|
||||
|
||||
def to_sequential_remap(s):
|
||||
splited = s.split(".")
|
||||
if splited[0].startswith("layer"):
|
||||
return ".".join(
|
||||
["layers." + str(int(splited[0][len("layer") :]) - 1)] + splited[1:]
|
||||
)
|
||||
else:
|
||||
return s
|
||||
|
||||
def no_remap(s):
|
||||
return s
|
||||
|
||||
return {"20.12.0": to_sequential_remap, "21.03.0": to_sequential_remap}.get(
|
||||
version, no_remap
|
||||
)
|
||||
|
||||
|
||||
# }}}
|
||||
|
||||
# MBConvBlock {{{
|
||||
class MBConvBlock(nn.Module):
|
||||
__constants__ = ["quantized"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
builder: LayerBuilder,
|
||||
|
@ -366,7 +392,7 @@ class MBConvBlock(nn.Module):
|
|||
out_channels: int,
|
||||
expand_ratio: int,
|
||||
stride: int,
|
||||
squeeze_excitation_ratio: int,
|
||||
squeeze_excitation_ratio: float,
|
||||
squeeze_hidden=False,
|
||||
survival_prob: float = 1.0,
|
||||
quantized: bool = False,
|
||||
|
@ -387,25 +413,31 @@ class MBConvBlock(nn.Module):
|
|||
self.depsep = builder.convDepSep(
|
||||
depsep_kernel_size, hidden_dim, hidden_dim, stride, bn=True, act=True
|
||||
)
|
||||
self.se = SequentialSqueezeAndExcitation(
|
||||
hidden_dim, squeeze_dim, builder.activation(), self.quantized, use_conv=trt
|
||||
if trt or self.quantized:
|
||||
# Need TRT mode for quantized in order to automatically insert quantization before pooling
|
||||
self.se: nn.Module = SequentialSqueezeAndExcitationTRT(
|
||||
hidden_dim, squeeze_dim, builder.activation(), self.quantized
|
||||
)
|
||||
else:
|
||||
self.se: nn.Module = SequentialSqueezeAndExcitation(
|
||||
hidden_dim, squeeze_dim, builder.activation(), self.quantized
|
||||
)
|
||||
|
||||
self.proj = builder.conv1x1(hidden_dim, out_channels, bn=True)
|
||||
|
||||
self.survival_prob = survival_prob
|
||||
|
||||
if survival_prob == 1.0:
|
||||
self.residual_add = torch.add
|
||||
else:
|
||||
self.residual_add = StochasticDepthResidual(survival_prob=survival_prob)
|
||||
if self.quantized and self.residual:
|
||||
assert quant_nn is not None, "pytorch_quantization is not available"
|
||||
self.residual_quantizer = quant_nn.TensorQuantizer(
|
||||
quant_nn.QuantConv2d.default_quant_desc_input
|
||||
) # TODO QuantConv2d ?!?
|
||||
else:
|
||||
self.residual_quantizer = nn.Identity()
|
||||
|
||||
def drop(self):
|
||||
if self.survival_prob == 1.0:
|
||||
return False
|
||||
return random.uniform(0.0, 1.0) > self.survival_prob
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if not self.residual:
|
||||
return self.proj(
|
||||
self.se(self.depsep(x if self.expand is None else self.expand(x)))
|
||||
|
@ -414,16 +446,10 @@ class MBConvBlock(nn.Module):
|
|||
b = self.proj(
|
||||
self.se(self.depsep(x if self.expand is None else self.expand(x)))
|
||||
)
|
||||
if self.training:
|
||||
if self.drop():
|
||||
multiplication_factor = 0.0
|
||||
else:
|
||||
multiplication_factor = 1.0 / self.survival_prob
|
||||
else:
|
||||
multiplication_factor = 1.0
|
||||
if self.quantized:
|
||||
x = self.residual_quantizer(x)
|
||||
return torch.add(x, alpha=multiplication_factor, other=b)
|
||||
|
||||
return self.residual_add(x, b)
|
||||
|
||||
|
||||
def original_mbconv(
|
||||
|
@ -436,7 +462,7 @@ def original_mbconv(
|
|||
squeeze_excitation_ratio: int,
|
||||
survival_prob: float,
|
||||
quantized: bool,
|
||||
trt: bool
|
||||
trt: bool,
|
||||
):
|
||||
return MBConvBlock(
|
||||
builder,
|
||||
|
|
|
@ -1,8 +1,15 @@
|
|||
from dataclasses import dataclass, asdict, replace
|
||||
from .common import (
|
||||
SequentialSqueezeAndExcitationTRT,
|
||||
SequentialSqueezeAndExcitation,
|
||||
SqueezeAndExcitation,
|
||||
SqueezeAndExcitationTRT,
|
||||
)
|
||||
from typing import Optional, Callable
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
from functools import partial
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -37,7 +44,13 @@ class EntryPoint:
|
|||
self.name = name
|
||||
self.model = model
|
||||
|
||||
def __call__(self, pretrained=False, pretrained_from_file=None, **kwargs):
|
||||
def __call__(
|
||||
self,
|
||||
pretrained=False,
|
||||
pretrained_from_file=None,
|
||||
state_dict_key_map_fn=None,
|
||||
**kwargs,
|
||||
):
|
||||
assert not (pretrained and (pretrained_from_file is not None))
|
||||
params = replace(self.model.params, **kwargs)
|
||||
|
||||
|
@ -66,7 +79,7 @@ class EntryPoint:
|
|||
pretrained_from_file
|
||||
)
|
||||
)
|
||||
# Temporary fix to allow NGC checkpoint loading
|
||||
|
||||
if state_dict is not None:
|
||||
state_dict = {
|
||||
k[len("module.") :] if k.startswith("module.") else k: v
|
||||
|
@ -85,12 +98,32 @@ class EntryPoint:
|
|||
else:
|
||||
return t
|
||||
|
||||
if state_dict_key_map_fn is not None:
|
||||
state_dict = {
|
||||
state_dict_key_map_fn(k): v for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
if hasattr(model, "ngc_checkpoint_remap"):
|
||||
remap_fn = model.ngc_checkpoint_remap(url=self.model.checkpoint_url)
|
||||
state_dict = {remap_fn(k): v for k, v in state_dict.items()}
|
||||
|
||||
def _se_layer_uses_conv(m):
|
||||
return any(
|
||||
map(
|
||||
partial(isinstance, m),
|
||||
[
|
||||
SqueezeAndExcitationTRT,
|
||||
SequentialSqueezeAndExcitationTRT,
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
state_dict = {
|
||||
k: reshape(
|
||||
v,
|
||||
conv=dict(model.named_modules())[
|
||||
".".join(k.split(".")[:-2])
|
||||
].use_conv,
|
||||
conv=_se_layer_uses_conv(
|
||||
dict(model.named_modules())[".".join(k.split(".")[:-2])]
|
||||
),
|
||||
)
|
||||
if is_se_weight(k, v)
|
||||
else v
|
||||
|
@ -123,7 +156,8 @@ class EntryPoint:
|
|||
|
||||
|
||||
def is_se_weight(key, value):
|
||||
return (key.endswith("squeeze.weight") or key.endswith("expand.weight"))
|
||||
return key.endswith("squeeze.weight") or key.endswith("expand.weight")
|
||||
|
||||
|
||||
def create_entrypoint(m: Model):
|
||||
def _ep(**kwargs):
|
||||
|
|
|
@ -36,14 +36,16 @@ from typing import List, Dict, Callable, Any, Type
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .common import SqueezeAndExcitation, LayerBuilder, LambdaLayer
|
||||
from .common import (
|
||||
SqueezeAndExcitation,
|
||||
LayerBuilder,
|
||||
SqueezeAndExcitationTRT,
|
||||
)
|
||||
|
||||
from .model import (
|
||||
Model,
|
||||
ModelParams,
|
||||
ModelArch,
|
||||
OptimizerParams,
|
||||
create_entrypoint,
|
||||
EntryPoint,
|
||||
)
|
||||
|
||||
|
@ -128,11 +130,18 @@ class Bottleneck(nn.Module):
|
|||
self.stride = stride
|
||||
|
||||
self.fused_se = fused_se
|
||||
if se:
|
||||
self.squeeze = (
|
||||
SqueezeAndExcitation(planes * expansion, se_squeeze, builder.activation(), use_conv=trt)
|
||||
if se
|
||||
else None
|
||||
SqueezeAndExcitation(
|
||||
planes * expansion, se_squeeze, builder.activation()
|
||||
)
|
||||
if not trt
|
||||
else SqueezeAndExcitationTRT(
|
||||
planes * expansion, se_squeeze, builder.activation()
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.squeeze = None
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
@ -215,6 +224,7 @@ class ResNet(nn.Module):
|
|||
last_bn_0_init: bool = False
|
||||
conv_init: str = "fan_in"
|
||||
trt: bool = False
|
||||
fused_se: bool = True
|
||||
|
||||
def parser(self, name):
|
||||
p = super().parser(name)
|
||||
|
@ -240,6 +250,10 @@ class ResNet(nn.Module):
|
|||
help="initialization mode for convolutional layers, see https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_",
|
||||
)
|
||||
p.add_argument("--trt", metavar="True|False", default=self.trt, type=bool)
|
||||
p.add_argument(
|
||||
"--fused_se", metavar="True|False", default=self.fused_se, type=bool
|
||||
)
|
||||
|
||||
return p
|
||||
|
||||
def __init__(
|
||||
|
@ -249,6 +263,7 @@ class ResNet(nn.Module):
|
|||
last_bn_0_init: bool = False,
|
||||
conv_init: str = "fan_in",
|
||||
trt: bool = False,
|
||||
fused_se: bool = True,
|
||||
):
|
||||
|
||||
super(ResNet, self).__init__()
|
||||
|
@ -265,6 +280,7 @@ class ResNet(nn.Module):
|
|||
inplanes = arch.stem_width
|
||||
assert len(arch.widths) == len(arch.layers)
|
||||
self.num_layers = len(arch.widths)
|
||||
layers = []
|
||||
for i, (w, l) in enumerate(zip(arch.widths, arch.layers)):
|
||||
layer, inplanes = self._make_layer(
|
||||
arch.block,
|
||||
|
@ -275,9 +291,11 @@ class ResNet(nn.Module):
|
|||
cardinality=arch.cardinality,
|
||||
stride=1 if i == 0 else 2,
|
||||
trt=trt,
|
||||
fused_se=fused_se,
|
||||
)
|
||||
setattr(self, f"layer{i+1}", layer)
|
||||
layers.append(layer)
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Linear(arch.widths[-1] * arch.expansion, num_classes)
|
||||
|
||||
|
@ -297,13 +315,8 @@ class ResNet(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
x = self.stem(x)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
fn = getattr(self, f"layer{i+1}")
|
||||
x = fn(x)
|
||||
|
||||
x = self.layers(x)
|
||||
x = self.classifier(x)
|
||||
|
||||
return x
|
||||
|
||||
def extract_features(self, x, layers=None):
|
||||
|
@ -311,7 +324,7 @@ class ResNet(nn.Module):
|
|||
layers = [f"layer{i+1}" for i in range(self.num_layers)] + ["classifier"]
|
||||
|
||||
run = [
|
||||
f"layer{i+1}"
|
||||
i
|
||||
for i in range(self.num_layers)
|
||||
if "classifier" in layers
|
||||
or any([f"layer{j+1}" in layers for j in range(i, self.num_layers)])
|
||||
|
@ -320,10 +333,10 @@ class ResNet(nn.Module):
|
|||
output = {}
|
||||
x = self.stem(x)
|
||||
for l in run:
|
||||
fn = getattr(self, l)
|
||||
fn = self.layers[l]
|
||||
x = fn(x)
|
||||
if l in layers:
|
||||
output[l] = x
|
||||
if f"layer{l+1}" in layers:
|
||||
output[f"layer{l+1}"] = x
|
||||
|
||||
if "classifier" in layers:
|
||||
output["classifier"] = self.classifier(x)
|
||||
|
@ -332,7 +345,16 @@ class ResNet(nn.Module):
|
|||
|
||||
# helper functions {{{
|
||||
def _make_layer(
|
||||
self, block, expansion, inplanes, planes, blocks, stride=1, cardinality=1, trt=False,
|
||||
self,
|
||||
block,
|
||||
expansion,
|
||||
inplanes,
|
||||
planes,
|
||||
blocks,
|
||||
stride=1,
|
||||
cardinality=1,
|
||||
trt=False,
|
||||
fused_se=True,
|
||||
):
|
||||
downsample = None
|
||||
if stride != 1 or inplanes != planes * expansion:
|
||||
|
@ -354,15 +376,33 @@ class ResNet(nn.Module):
|
|||
stride=stride if i == 0 else 1,
|
||||
cardinality=cardinality,
|
||||
downsample=downsample if i == 0 else None,
|
||||
fused_se=True,
|
||||
fused_se=fused_se,
|
||||
last_bn_0_init=self.last_bn_0_init,
|
||||
trt = trt,
|
||||
trt=trt,
|
||||
)
|
||||
)
|
||||
inplanes = planes * expansion
|
||||
|
||||
return nn.Sequential(*layers), inplanes
|
||||
|
||||
def ngc_checkpoint_remap(self, url=None, version=None):
|
||||
if version is None:
|
||||
version = url.split("/")[8]
|
||||
|
||||
def to_sequential_remap(s):
|
||||
splited = s.split(".")
|
||||
if splited[0].startswith("layer"):
|
||||
return ".".join(
|
||||
["layers." + str(int(splited[0][len("layer") :]) - 1)] + splited[1:]
|
||||
)
|
||||
else:
|
||||
return s
|
||||
|
||||
def no_remap(s):
|
||||
return s
|
||||
|
||||
return {"20.06.0": to_sequential_remap}.get(version, no_remap)
|
||||
|
||||
# }}}
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,39 @@
|
|||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import optim
|
||||
|
||||
def get_sgd_optimizer(parameters, lr, momentum, weight_decay, nesterov=False, bn_weight_decay=False):
|
||||
|
||||
def get_optimizer(parameters, lr, args, state=None):
|
||||
if args.optimizer == "sgd":
|
||||
optimizer = get_sgd_optimizer(
|
||||
parameters,
|
||||
lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
nesterov=args.nesterov,
|
||||
bn_weight_decay=args.bn_weight_decay,
|
||||
)
|
||||
elif args.optimizer == "rmsprop":
|
||||
optimizer = get_rmsprop_optimizer(
|
||||
parameters,
|
||||
lr,
|
||||
alpha=args.rmsprop_alpha,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
eps=args.rmsprop_eps,
|
||||
bn_weight_decay=args.bn_weight_decay,
|
||||
)
|
||||
if not state is None:
|
||||
optimizer.load_state_dict(state)
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
def get_sgd_optimizer(
|
||||
parameters, lr, momentum, weight_decay, nesterov=False, bn_weight_decay=False
|
||||
):
|
||||
if bn_weight_decay:
|
||||
print(" ! Weight decay applied to BN parameters ")
|
||||
params = [v for n, v in parameters]
|
||||
|
@ -17,12 +49,16 @@ def get_sgd_optimizer(parameters, lr, momentum, weight_decay, nesterov=False, bn
|
|||
{"params": rest_params, "weight_decay": weight_decay},
|
||||
]
|
||||
|
||||
optimizer = torch.optim.SGD(params, lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
|
||||
optimizer = torch.optim.SGD(
|
||||
params, lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov
|
||||
)
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
def get_rmsprop_optimizer(parameters, lr, alpha, weight_decay, momentum, eps, bn_weight_decay=False):
|
||||
def get_rmsprop_optimizer(
|
||||
parameters, lr, alpha, weight_decay, momentum, eps, bn_weight_decay=False
|
||||
):
|
||||
bn_params = [v for n, v in parameters if "bn" in n]
|
||||
rest_params = [v for n, v in parameters if not "bn" in n]
|
||||
|
||||
|
@ -31,6 +67,94 @@ def get_rmsprop_optimizer(parameters, lr, alpha, weight_decay, momentum, eps, bn
|
|||
{"params": rest_params, "weight_decay": weight_decay},
|
||||
]
|
||||
|
||||
optimizer = torch.optim.RMSprop(params, lr=lr, alpha=alpha, weight_decay=weight_decay, momentum=momentum, eps=eps)
|
||||
optimizer = torch.optim.RMSprop(
|
||||
params,
|
||||
lr=lr,
|
||||
alpha=alpha,
|
||||
weight_decay=weight_decay,
|
||||
momentum=momentum,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
def lr_policy(lr_fn):
|
||||
def _alr(optimizer, iteration, epoch):
|
||||
lr = lr_fn(iteration, epoch)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
||||
|
||||
return lr
|
||||
|
||||
return _alr
|
||||
|
||||
|
||||
def lr_step_policy(base_lr, steps, decay_factor, warmup_length):
|
||||
def _lr_fn(iteration, epoch):
|
||||
if epoch < warmup_length:
|
||||
lr = base_lr * (epoch + 1) / warmup_length
|
||||
else:
|
||||
lr = base_lr
|
||||
for s in steps:
|
||||
if epoch >= s:
|
||||
lr *= decay_factor
|
||||
return lr
|
||||
|
||||
return lr_policy(_lr_fn)
|
||||
|
||||
|
||||
def lr_linear_policy(base_lr, warmup_length, epochs):
|
||||
def _lr_fn(iteration, epoch):
|
||||
if epoch < warmup_length:
|
||||
lr = base_lr * (epoch + 1) / warmup_length
|
||||
else:
|
||||
e = epoch - warmup_length
|
||||
es = epochs - warmup_length
|
||||
lr = base_lr * (1 - (e / es))
|
||||
return lr
|
||||
|
||||
return lr_policy(_lr_fn)
|
||||
|
||||
|
||||
def lr_cosine_policy(base_lr, warmup_length, epochs, end_lr=0):
|
||||
def _lr_fn(iteration, epoch):
|
||||
if epoch < warmup_length:
|
||||
lr = base_lr * (epoch + 1) / warmup_length
|
||||
else:
|
||||
e = epoch - warmup_length
|
||||
es = epochs - warmup_length
|
||||
lr = end_lr + (0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - end_lr))
|
||||
return lr
|
||||
|
||||
return lr_policy(_lr_fn)
|
||||
|
||||
|
||||
def lr_exponential_policy(
|
||||
base_lr,
|
||||
warmup_length,
|
||||
epochs,
|
||||
final_multiplier=0.001,
|
||||
decay_factor=None,
|
||||
decay_step=1,
|
||||
logger=None,
|
||||
):
|
||||
"""Exponential lr policy. Setting decay factor parameter overrides final_multiplier"""
|
||||
es = epochs - warmup_length
|
||||
|
||||
if decay_factor is not None:
|
||||
epoch_decay = decay_factor
|
||||
else:
|
||||
epoch_decay = np.power(
|
||||
2, np.log2(final_multiplier) / math.floor(es / decay_step)
|
||||
)
|
||||
|
||||
def _lr_fn(iteration, epoch):
|
||||
if epoch < warmup_length:
|
||||
lr = base_lr * (epoch + 1) / warmup_length
|
||||
else:
|
||||
e = epoch - warmup_length
|
||||
lr = base_lr * (epoch_decay ** math.floor(e / decay_step))
|
||||
return lr
|
||||
|
||||
return lr_policy(_lr_fn, logger=logger)
|
||||
|
|
|
@ -27,279 +27,212 @@
|
|||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
import math
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
from . import logger as log
|
||||
from . import models
|
||||
from . import utils
|
||||
import dllogger
|
||||
|
||||
from .optimizers import get_sgd_optimizer, get_rmsprop_optimizer
|
||||
from .models.common import EMA
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
ACC_METADATA = {"unit": "%", "format": ":.2f"}
|
||||
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
|
||||
TIME_METADATA = {"unit": "s", "format": ":.5f"}
|
||||
LOSS_METADATA = {"format": ":.5f"}
|
||||
from . import logger as log
|
||||
from . import utils
|
||||
from .logger import TrainingMetrics, ValidationMetrics
|
||||
from .models.common import EMA
|
||||
|
||||
|
||||
class ModelAndLoss(nn.Module):
|
||||
class Executor:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
loss,
|
||||
cuda=True,
|
||||
memory_format=torch.contiguous_format,
|
||||
model: nn.Module,
|
||||
loss: Optional[nn.Module],
|
||||
cuda: bool = True,
|
||||
memory_format: torch.memory_format = torch.contiguous_format,
|
||||
amp: bool = False,
|
||||
scaler: Optional[torch.cuda.amp.GradScaler] = None,
|
||||
divide_loss: int = 1,
|
||||
ts_script: bool = False,
|
||||
):
|
||||
super(ModelAndLoss, self).__init__()
|
||||
assert not (amp and scaler is None), "Gradient Scaler is needed for AMP"
|
||||
|
||||
def xform(m: nn.Module) -> nn.Module:
|
||||
if cuda:
|
||||
model = model.cuda().to(memory_format=memory_format)
|
||||
m = m.cuda()
|
||||
m.to(memory_format=memory_format)
|
||||
return m
|
||||
|
||||
# define loss function (criterion) and optimizer
|
||||
criterion = loss()
|
||||
|
||||
if cuda:
|
||||
criterion = criterion.cuda()
|
||||
|
||||
self.model = model
|
||||
self.loss = criterion
|
||||
|
||||
def forward(self, data, target):
|
||||
output = self.model(data)
|
||||
loss = self.loss(output, target)
|
||||
|
||||
return loss, output
|
||||
self.model = xform(model)
|
||||
if ts_script:
|
||||
self.model = torch.jit.script(self.model)
|
||||
self.ts_script = ts_script
|
||||
self.loss = xform(loss) if loss is not None else None
|
||||
self.amp = amp
|
||||
self.scaler = scaler
|
||||
self.is_distributed = False
|
||||
self.divide_loss = divide_loss
|
||||
self._fwd_bwd = None
|
||||
self._forward = None
|
||||
|
||||
def distributed(self, gpu_id):
|
||||
self.is_distributed = True
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s):
|
||||
self.model = DDP(self.model, device_ids=[gpu_id], output_device=gpu_id)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
|
||||
def load_model_state(self, state):
|
||||
if not state is None:
|
||||
self.model.load_state_dict(state)
|
||||
def _fwd_bwd_fn(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
with autocast(enabled=self.amp):
|
||||
loss = self.loss(self.model(input), target)
|
||||
loss /= self.divide_loss
|
||||
|
||||
self.scaler.scale(loss).backward()
|
||||
return loss
|
||||
|
||||
def get_optimizer(parameters, lr, args, state=None):
|
||||
if args.optimizer == 'sgd':
|
||||
optimizer = get_sgd_optimizer(parameters, lr, momentum=args.momentum,
|
||||
weight_decay=args.weight_decay, nesterov=args.nesterov,
|
||||
bn_weight_decay=args.bn_weight_decay)
|
||||
elif args.optimizer == 'rmsprop':
|
||||
optimizer = get_rmsprop_optimizer(parameters, lr, alpha=args.rmsprop_alpha, momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
eps=args.rmsprop_eps,
|
||||
bn_weight_decay=args.bn_weight_decay)
|
||||
if not state is None:
|
||||
optimizer.load_state_dict(state)
|
||||
def _forward_fn(
|
||||
self, input: torch.Tensor, target: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
with torch.no_grad(), autocast(enabled=self.amp):
|
||||
output = self.model(input)
|
||||
loss = None if self.loss is None else self.loss(output, target)
|
||||
|
||||
return optimizer
|
||||
return output if loss is None else loss, output
|
||||
|
||||
def optimize(self, fn):
|
||||
return fn
|
||||
|
||||
def lr_policy(lr_fn, logger=None):
|
||||
if logger is not None:
|
||||
logger.register_metric(
|
||||
"lr", log.LR_METER(), verbosity=dllogger.Verbosity.VERBOSE
|
||||
@property
|
||||
def forward_backward(self):
|
||||
if self._fwd_bwd is None:
|
||||
if self.loss is None:
|
||||
raise NotImplementedError(
|
||||
"Loss must not be None for forward+backward step"
|
||||
)
|
||||
self._fwd_bwd = self.optimize(self._fwd_bwd_fn)
|
||||
return self._fwd_bwd
|
||||
|
||||
def _alr(optimizer, iteration, epoch):
|
||||
lr = lr_fn(iteration, epoch)
|
||||
|
||||
if logger is not None:
|
||||
logger.log_metric("lr", lr)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
||||
|
||||
return _alr
|
||||
@property
|
||||
def forward(self):
|
||||
if self._forward is None:
|
||||
self._forward = self.optimize(self._forward_fn)
|
||||
return self._forward
|
||||
|
||||
|
||||
def lr_step_policy(base_lr, steps, decay_factor, warmup_length, logger=None):
|
||||
def _lr_fn(iteration, epoch):
|
||||
if epoch < warmup_length:
|
||||
lr = base_lr * (epoch + 1) / warmup_length
|
||||
class Trainer:
|
||||
def __init__(
|
||||
self,
|
||||
executor: Executor,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
grad_acc_steps: int,
|
||||
ema: Optional[float] = None,
|
||||
):
|
||||
self.executor = executor
|
||||
self.optimizer = optimizer
|
||||
self.grad_acc_steps = grad_acc_steps
|
||||
self.use_ema = False
|
||||
if ema is not None:
|
||||
self.ema_executor = deepcopy(self.executor)
|
||||
self.ema = EMA(ema, self.ema_executor.model)
|
||||
self.use_ema = True
|
||||
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
self.steps_since_update = 0
|
||||
|
||||
def train(self):
|
||||
self.executor.model.train()
|
||||
|
||||
def eval(self):
|
||||
self.executor.model.eval()
|
||||
if self.use_ema:
|
||||
self.executor.model.eval()
|
||||
|
||||
def train_step(self, input, target, step=None):
|
||||
loss = self.executor.forward_backward(input, target)
|
||||
|
||||
self.steps_since_update += 1
|
||||
|
||||
if self.steps_since_update == self.grad_acc_steps:
|
||||
if self.executor.scaler is not None:
|
||||
self.executor.scaler.step(self.optimizer)
|
||||
self.executor.scaler.update()
|
||||
else:
|
||||
lr = base_lr
|
||||
for s in steps:
|
||||
if epoch >= s:
|
||||
lr *= decay_factor
|
||||
return lr
|
||||
|
||||
return lr_policy(_lr_fn, logger=logger)
|
||||
|
||||
|
||||
def lr_linear_policy(base_lr, warmup_length, epochs, logger=None):
|
||||
def _lr_fn(iteration, epoch):
|
||||
if epoch < warmup_length:
|
||||
lr = base_lr * (epoch + 1) / warmup_length
|
||||
else:
|
||||
e = epoch - warmup_length
|
||||
es = epochs - warmup_length
|
||||
lr = base_lr * (1 - (e / es))
|
||||
return lr
|
||||
|
||||
return lr_policy(_lr_fn, logger=logger)
|
||||
|
||||
|
||||
def lr_cosine_policy(base_lr, warmup_length, epochs, end_lr = 0, logger=None):
|
||||
def _lr_fn(iteration, epoch):
|
||||
if epoch < warmup_length:
|
||||
lr = base_lr * (epoch + 1) / warmup_length
|
||||
else:
|
||||
e = epoch - warmup_length
|
||||
es = epochs - warmup_length
|
||||
lr = end_lr + (0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - end_lr))
|
||||
return lr
|
||||
|
||||
return lr_policy(_lr_fn, logger=logger)
|
||||
|
||||
|
||||
def lr_exponential_policy(
|
||||
base_lr, warmup_length, epochs, final_multiplier=0.001, decay_factor=None, decay_step=1, logger=None
|
||||
):
|
||||
"""Exponential lr policy. Setting decay factor parameter overrides final_multiplier"""
|
||||
es = epochs - warmup_length
|
||||
|
||||
if decay_factor is not None:
|
||||
epoch_decay = decay_factor
|
||||
else:
|
||||
epoch_decay = np.power(2, np.log2(final_multiplier) / math.floor(es/decay_step))
|
||||
|
||||
def _lr_fn(iteration, epoch):
|
||||
if epoch < warmup_length:
|
||||
lr = base_lr * (epoch + 1) / warmup_length
|
||||
else:
|
||||
e = epoch - warmup_length
|
||||
lr = base_lr * (epoch_decay ** math.floor(e/decay_step))
|
||||
return lr
|
||||
|
||||
return lr_policy(_lr_fn, logger=logger)
|
||||
|
||||
|
||||
def get_train_step(
|
||||
model_and_loss, optimizer, scaler, use_amp=False, batch_size_multiplier=1
|
||||
):
|
||||
def _step(input, target, optimizer_step=True):
|
||||
input_var = Variable(input)
|
||||
target_var = Variable(target)
|
||||
|
||||
with autocast(enabled=use_amp):
|
||||
loss, output = model_and_loss(input_var, target_var)
|
||||
loss /= batch_size_multiplier
|
||||
if torch.distributed.is_initialized():
|
||||
reduced_loss = utils.reduce_tensor(loss.data)
|
||||
else:
|
||||
reduced_loss = loss.data
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if optimizer_step:
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.steps_since_update = 0
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return reduced_loss
|
||||
if self.use_ema:
|
||||
self.ema(self.executor.model, step=step)
|
||||
|
||||
return _step
|
||||
return loss
|
||||
|
||||
def validation_steps(self) -> Dict[str, Callable]:
|
||||
vsd: Dict[str, Callable] = {"val": self.executor.forward}
|
||||
if self.use_ema:
|
||||
vsd["val_ema"] = self.ema_executor.forward
|
||||
return vsd
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
res = {
|
||||
"state_dict": self.executor.model.state_dict(),
|
||||
"optimizer": self.optimizer.state_dict(),
|
||||
}
|
||||
if self.use_ema:
|
||||
res["state_dict_ema"] = self.ema_executor.model.state_dict()
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def train(
|
||||
train_step,
|
||||
train_loader,
|
||||
model_and_loss,
|
||||
optimizer,
|
||||
scaler,
|
||||
lr_scheduler,
|
||||
logger,
|
||||
epoch,
|
||||
steps_per_epoch,
|
||||
log_fn,
|
||||
timeout_handler,
|
||||
ema=None,
|
||||
use_amp=False,
|
||||
prof=-1,
|
||||
batch_size_multiplier=1,
|
||||
register_metrics=True,
|
||||
step=0,
|
||||
):
|
||||
interrupted = False
|
||||
if register_metrics and logger is not None:
|
||||
logger.register_metric(
|
||||
"train.loss",
|
||||
log.LOSS_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=LOSS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
"train.compute_ips",
|
||||
log.PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=IPS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
"train.total_ips",
|
||||
log.PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=IPS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
"train.data_time",
|
||||
log.PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=TIME_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
"train.compute_time",
|
||||
log.PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=TIME_METADATA,
|
||||
)
|
||||
|
||||
step = get_train_step(
|
||||
model_and_loss,
|
||||
optimizer,
|
||||
scaler=scaler,
|
||||
use_amp=use_amp,
|
||||
batch_size_multiplier=batch_size_multiplier,
|
||||
)
|
||||
|
||||
model_and_loss.train()
|
||||
end = time.time()
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
data_iter = enumerate(train_loader)
|
||||
if logger is not None:
|
||||
data_iter = logger.iteration_generator_wrapper(data_iter, mode='train')
|
||||
|
||||
for i, (input, target) in data_iter:
|
||||
bs = input.size(0)
|
||||
lr_scheduler(optimizer, i, epoch)
|
||||
lr = lr_scheduler(i)
|
||||
data_time = time.time() - end
|
||||
|
||||
optimizer_step = ((i + 1) % batch_size_multiplier) == 0
|
||||
loss = step(input, target, optimizer_step=optimizer_step)
|
||||
if ema is not None:
|
||||
ema(model_and_loss, epoch*steps_per_epoch+i)
|
||||
|
||||
loss = train_step(input, target, step=step + i)
|
||||
it_time = time.time() - end
|
||||
|
||||
if logger is not None:
|
||||
logger.log_metric("train.loss", loss.item(), bs)
|
||||
logger.log_metric("train.compute_ips", utils.calc_ips(bs, it_time - data_time))
|
||||
logger.log_metric("train.total_ips", utils.calc_ips(bs, it_time))
|
||||
logger.log_metric("train.data_time", data_time)
|
||||
logger.log_metric("train.compute_time", it_time - data_time)
|
||||
with torch.no_grad():
|
||||
if torch.distributed.is_initialized():
|
||||
reduced_loss = utils.reduce_tensor(loss.detach())
|
||||
else:
|
||||
reduced_loss = loss.detach()
|
||||
|
||||
log_fn(
|
||||
compute_ips=utils.calc_ips(bs, it_time - data_time),
|
||||
total_ips=utils.calc_ips(bs, it_time),
|
||||
data_time=data_time,
|
||||
compute_time=it_time - data_time,
|
||||
lr=lr,
|
||||
loss=reduced_loss.item(),
|
||||
)
|
||||
|
||||
end = time.time()
|
||||
if prof > 0 and (i + 1 >= prof):
|
||||
time.sleep(5)
|
||||
break
|
||||
if ((i+1) % 20 == 0) and timeout_handler.interrupted:
|
||||
if ((i + 1) % 20 == 0) and timeout_handler.interrupted:
|
||||
time.sleep(5)
|
||||
interrupted = True
|
||||
break
|
||||
|
@ -307,134 +240,58 @@ def train(
|
|||
return interrupted
|
||||
|
||||
|
||||
def get_val_step(model_and_loss, use_amp=False):
|
||||
def _step(input, target):
|
||||
input_var = Variable(input)
|
||||
target_var = Variable(target)
|
||||
|
||||
with torch.no_grad(), autocast(enabled=use_amp):
|
||||
loss, output = model_and_loss(input_var, target_var)
|
||||
|
||||
prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5))
|
||||
|
||||
if torch.distributed.is_initialized():
|
||||
reduced_loss = utils.reduce_tensor(loss.data)
|
||||
prec1 = utils.reduce_tensor(prec1)
|
||||
prec5 = utils.reduce_tensor(prec5)
|
||||
else:
|
||||
reduced_loss = loss.data
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return reduced_loss, prec1, prec5
|
||||
|
||||
return _step
|
||||
|
||||
|
||||
def validate(
|
||||
val_loader,
|
||||
model_and_loss,
|
||||
logger,
|
||||
epoch,
|
||||
use_amp=False,
|
||||
prof=-1,
|
||||
register_metrics=True,
|
||||
prefix="val",
|
||||
):
|
||||
if register_metrics and logger is not None:
|
||||
logger.register_metric(
|
||||
f"{prefix}.top1",
|
||||
log.ACC_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=ACC_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.top5",
|
||||
log.ACC_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=ACC_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.loss",
|
||||
log.LOSS_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=LOSS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.compute_ips",
|
||||
log.PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=IPS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.total_ips",
|
||||
log.PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.DEFAULT,
|
||||
metadata=IPS_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.data_time",
|
||||
log.PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=TIME_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.compute_latency",
|
||||
log.PERF_METER(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=TIME_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.compute_latency_at100",
|
||||
log.LAT_100(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=TIME_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.compute_latency_at99",
|
||||
log.LAT_99(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=TIME_METADATA,
|
||||
)
|
||||
logger.register_metric(
|
||||
f"{prefix}.compute_latency_at95",
|
||||
log.LAT_95(),
|
||||
verbosity=dllogger.Verbosity.VERBOSE,
|
||||
metadata=TIME_METADATA,
|
||||
)
|
||||
|
||||
step = get_val_step(model_and_loss, use_amp=use_amp)
|
||||
|
||||
def validate(infer_fn, val_loader, log_fn, prof=-1, with_loss=True):
|
||||
top1 = log.AverageMeter()
|
||||
# switch to evaluate mode
|
||||
model_and_loss.eval()
|
||||
|
||||
end = time.time()
|
||||
|
||||
data_iter = enumerate(val_loader)
|
||||
if not logger is None:
|
||||
data_iter = logger.iteration_generator_wrapper(data_iter, mode='val')
|
||||
|
||||
for i, (input, target) in data_iter:
|
||||
bs = input.size(0)
|
||||
data_time = time.time() - end
|
||||
|
||||
loss, prec1, prec5 = step(input, target)
|
||||
if with_loss:
|
||||
loss, output = infer_fn(input, target)
|
||||
else:
|
||||
output = infer_fn(input)
|
||||
|
||||
with torch.no_grad():
|
||||
prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5))
|
||||
|
||||
if torch.distributed.is_initialized():
|
||||
if with_loss:
|
||||
reduced_loss = utils.reduce_tensor(loss.detach())
|
||||
prec1 = utils.reduce_tensor(prec1)
|
||||
prec5 = utils.reduce_tensor(prec5)
|
||||
else:
|
||||
if with_loss:
|
||||
reduced_loss = loss.detach()
|
||||
|
||||
prec1 = prec1.item()
|
||||
prec5 = prec5.item()
|
||||
infer_result = {
|
||||
"top1": (prec1, bs),
|
||||
"top5": (prec5, bs),
|
||||
}
|
||||
|
||||
if with_loss:
|
||||
infer_result["loss"] = (reduced_loss.item(), bs)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
it_time = time.time() - end
|
||||
|
||||
top1.record(prec1.item(), bs)
|
||||
if logger is not None:
|
||||
logger.log_metric(f"{prefix}.top1", prec1.item(), bs)
|
||||
logger.log_metric(f"{prefix}.top5", prec5.item(), bs)
|
||||
logger.log_metric(f"{prefix}.loss", loss.item(), bs)
|
||||
logger.log_metric(f"{prefix}.compute_ips", utils.calc_ips(bs, it_time - data_time))
|
||||
logger.log_metric(f"{prefix}.total_ips", utils.calc_ips(bs, it_time))
|
||||
logger.log_metric(f"{prefix}.data_time", data_time)
|
||||
logger.log_metric(f"{prefix}.compute_latency", it_time - data_time)
|
||||
logger.log_metric(f"{prefix}.compute_latency_at95", it_time - data_time)
|
||||
logger.log_metric(f"{prefix}.compute_latency_at99", it_time - data_time)
|
||||
logger.log_metric(f"{prefix}.compute_latency_at100", it_time - data_time)
|
||||
top1.record(prec1, bs)
|
||||
|
||||
log_fn(
|
||||
compute_ips=utils.calc_ips(bs, it_time - data_time),
|
||||
total_ips=utils.calc_ips(bs, it_time),
|
||||
data_time=data_time,
|
||||
compute_time=it_time - data_time,
|
||||
**infer_result,
|
||||
)
|
||||
|
||||
end = time.time()
|
||||
if (prof > 0) and (i + 1 >= prof):
|
||||
|
@ -445,22 +302,14 @@ def validate(
|
|||
|
||||
|
||||
# Train loop {{{
|
||||
|
||||
|
||||
def train_loop(
|
||||
model_and_loss,
|
||||
optimizer,
|
||||
scaler,
|
||||
trainer: Trainer,
|
||||
lr_scheduler,
|
||||
train_loader,
|
||||
train_loader_len,
|
||||
val_loader,
|
||||
logger,
|
||||
should_backup_checkpoint,
|
||||
steps_per_epoch,
|
||||
ema=None,
|
||||
model_ema=None,
|
||||
use_amp=False,
|
||||
batch_size_multiplier=1,
|
||||
best_prec1=0,
|
||||
start_epoch=0,
|
||||
end_epoch=0,
|
||||
|
@ -472,13 +321,21 @@ def train_loop(
|
|||
checkpoint_dir="./",
|
||||
checkpoint_filename="checkpoint.pth.tar",
|
||||
):
|
||||
train_metrics = TrainingMetrics(logger)
|
||||
val_metrics = {
|
||||
k: ValidationMetrics(logger, k) for k in trainer.validation_steps().keys()
|
||||
}
|
||||
training_step = trainer.train_step
|
||||
|
||||
prec1 = -1
|
||||
use_ema = (model_ema is not None) and (ema is not None)
|
||||
|
||||
if early_stopping_patience > 0:
|
||||
epochs_since_improvement = 0
|
||||
backup_prefix = checkpoint_filename[:-len("checkpoint.pth.tar")] if \
|
||||
checkpoint_filename.endswith("checkpoint.pth.tar") else ""
|
||||
backup_prefix = (
|
||||
checkpoint_filename[: -len("checkpoint.pth.tar")]
|
||||
if checkpoint_filename.endswith("checkpoint.pth.tar")
|
||||
else ""
|
||||
)
|
||||
|
||||
print(f"RUNNING EPOCHS FROM {start_epoch} TO {end_epoch}")
|
||||
with utils.TimeoutHandler() as timeout_handler:
|
||||
|
@ -487,73 +344,71 @@ def train_loop(
|
|||
if logger is not None:
|
||||
logger.start_epoch()
|
||||
if not skip_training:
|
||||
if logger is not None:
|
||||
data_iter = logger.iteration_generator_wrapper(
|
||||
train_loader, mode="train"
|
||||
)
|
||||
else:
|
||||
data_iter = train_loader
|
||||
|
||||
trainer.train()
|
||||
interrupted = train(
|
||||
train_loader,
|
||||
model_and_loss,
|
||||
optimizer,
|
||||
scaler,
|
||||
lr_scheduler,
|
||||
logger,
|
||||
epoch,
|
||||
steps_per_epoch,
|
||||
training_step,
|
||||
data_iter,
|
||||
lambda i: lr_scheduler(trainer.optimizer, i, epoch),
|
||||
train_metrics.log,
|
||||
timeout_handler,
|
||||
ema=ema,
|
||||
use_amp=use_amp,
|
||||
prof=prof,
|
||||
register_metrics=epoch == start_epoch,
|
||||
batch_size_multiplier=batch_size_multiplier,
|
||||
step=epoch * train_loader_len,
|
||||
)
|
||||
|
||||
if not skip_validation:
|
||||
prec1, nimg = validate(
|
||||
val_loader,
|
||||
model_and_loss,
|
||||
logger,
|
||||
epoch,
|
||||
use_amp=use_amp,
|
||||
prof=prof,
|
||||
register_metrics=epoch == start_epoch,
|
||||
trainer.eval()
|
||||
for k, infer_fn in trainer.validation_steps().items():
|
||||
if logger is not None:
|
||||
data_iter = logger.iteration_generator_wrapper(
|
||||
val_loader, mode="val"
|
||||
)
|
||||
if use_ema:
|
||||
model_ema.load_state_dict({k.replace('module.', ''): v for k, v in ema.state_dict().items()})
|
||||
prec1, nimg = validate(
|
||||
val_loader,
|
||||
model_ema,
|
||||
logger,
|
||||
epoch,
|
||||
else:
|
||||
data_iter = val_loader
|
||||
|
||||
step_prec1, _ = validate(
|
||||
infer_fn,
|
||||
data_iter,
|
||||
val_metrics[k].log,
|
||||
prof=prof,
|
||||
register_metrics=epoch == start_epoch,
|
||||
prefix='val_ema'
|
||||
)
|
||||
|
||||
if k == "val":
|
||||
prec1 = step_prec1
|
||||
|
||||
if prec1 > best_prec1:
|
||||
is_best = True
|
||||
best_prec1 = prec1
|
||||
else:
|
||||
is_best = False
|
||||
else:
|
||||
is_best = True
|
||||
is_best = False
|
||||
best_prec1 = 0
|
||||
|
||||
if logger is not None:
|
||||
logger.end_epoch()
|
||||
|
||||
if save_checkpoints and (
|
||||
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
||||
not torch.distributed.is_initialized()
|
||||
or torch.distributed.get_rank() == 0
|
||||
):
|
||||
if should_backup_checkpoint(epoch):
|
||||
backup_filename = "{}checkpoint-{}.pth.tar".format(backup_prefix, epoch + 1)
|
||||
backup_filename = "{}checkpoint-{}.pth.tar".format(
|
||||
backup_prefix, epoch + 1
|
||||
)
|
||||
else:
|
||||
backup_filename = None
|
||||
checkpoint_state = {
|
||||
"epoch": epoch + 1,
|
||||
"state_dict": model_and_loss.model.state_dict(),
|
||||
"best_prec1": best_prec1,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
**trainer.state_dict(),
|
||||
}
|
||||
if use_ema:
|
||||
checkpoint_state["state_dict_ema"] = ema.state_dict()
|
||||
|
||||
utils.save_checkpoint(
|
||||
checkpoint_state,
|
||||
is_best,
|
||||
|
@ -561,6 +416,7 @@ def train_loop(
|
|||
backup_filename=backup_filename,
|
||||
filename=checkpoint_filename,
|
||||
)
|
||||
|
||||
if early_stopping_patience > 0:
|
||||
if not is_best:
|
||||
epochs_since_improvement += 1
|
||||
|
@ -570,4 +426,6 @@ def train_loop(
|
|||
break
|
||||
if interrupted:
|
||||
break
|
||||
|
||||
|
||||
# }}}
|
||||
|
|
|
@ -97,7 +97,7 @@ def accuracy(output, target, topk=(1,)):
|
|||
|
||||
|
||||
def reduce_tensor(tensor):
|
||||
rt = tensor.clone()
|
||||
rt = tensor.clone().detach()
|
||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
||||
rt /= (
|
||||
torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
|
||||
|
@ -114,6 +114,7 @@ class TimeoutHandler:
|
|||
def __init__(self, sig=signal.SIGTERM):
|
||||
self.sig = sig
|
||||
self.device = torch.device("cuda")
|
||||
|
||||
@property
|
||||
def interrupted(self):
|
||||
if not dist.is_initialized():
|
||||
|
|
|
@ -58,6 +58,12 @@ from image_classification.models import (
|
|||
efficientnet_widese_b0,
|
||||
efficientnet_widese_b4,
|
||||
)
|
||||
from image_classification.optimizers import (
|
||||
get_optimizer,
|
||||
lr_cosine_policy,
|
||||
lr_linear_policy,
|
||||
lr_step_policy,
|
||||
)
|
||||
import dllogger
|
||||
|
||||
|
||||
|
@ -102,7 +108,9 @@ def add_parser_arguments(parser, skip_arch=False):
|
|||
metavar="ARCH",
|
||||
default="resnet50",
|
||||
choices=model_names,
|
||||
help="model architecture: " + " | ".join(model_names) + " (default: resnet50)",
|
||||
help="model architecture: "
|
||||
+ " | ".join(model_names)
|
||||
+ " (default: resnet50)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
@ -290,6 +298,13 @@ def add_parser_arguments(parser, skip_arch=False):
|
|||
dest="save_checkpoints",
|
||||
help="do not store any checkpoints, useful for benchmarking",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str,
|
||||
default = "no",
|
||||
choices=["no", "script"],
|
||||
help="no -> do not use torch.jit; script -> use torch.jit.script"
|
||||
)
|
||||
|
||||
parser.add_argument("--checkpoint-filename", default="checkpoint.pth.tar", type=str)
|
||||
|
||||
|
@ -320,7 +335,7 @@ def add_parser_arguments(parser, skip_arch=False):
|
|||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="number of classes"
|
||||
help="number of classes",
|
||||
)
|
||||
|
||||
|
||||
|
@ -432,13 +447,25 @@ def prepare_for_training(args, model_args, model_arch):
|
|||
if args.image_size is not None
|
||||
else model.arch.default_image_size
|
||||
)
|
||||
model_and_loss = ModelAndLoss(model, loss, cuda=True, memory_format=memory_format)
|
||||
if args.use_ema is not None:
|
||||
model_ema = deepcopy(model_and_loss)
|
||||
ema = EMA(args.use_ema)
|
||||
else:
|
||||
model_ema = None
|
||||
ema = None
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(
|
||||
init_scale=args.static_loss_scale,
|
||||
growth_factor=2,
|
||||
backoff_factor=0.5,
|
||||
growth_interval=100 if args.dynamic_loss_scale else 1000000000,
|
||||
enabled=args.amp,
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
model,
|
||||
loss(),
|
||||
cuda=True,
|
||||
memory_format=memory_format,
|
||||
amp=args.amp,
|
||||
scaler=scaler,
|
||||
divide_loss=batch_size_multiplier,
|
||||
ts_script = args.jit == "script",
|
||||
)
|
||||
|
||||
# Create data loaders and optimizers as needed
|
||||
if args.data_backend == "pytorch":
|
||||
|
@ -463,7 +490,7 @@ def prepare_for_training(args, model_args, model_arch):
|
|||
args.batch_size,
|
||||
model_args.num_classes,
|
||||
args.mixup > 0.0,
|
||||
interpolation = args.interpolation,
|
||||
interpolation=args.interpolation,
|
||||
augmentation=args.augmentation,
|
||||
start_epoch=start_epoch,
|
||||
workers=args.workers,
|
||||
|
@ -478,7 +505,7 @@ def prepare_for_training(args, model_args, model_arch):
|
|||
args.batch_size,
|
||||
model_args.num_classes,
|
||||
False,
|
||||
interpolation = args.interpolation,
|
||||
interpolation=args.interpolation,
|
||||
workers=args.workers,
|
||||
memory_format=memory_format,
|
||||
)
|
||||
|
@ -508,41 +535,38 @@ def prepare_for_training(args, model_args, model_arch):
|
|||
)
|
||||
|
||||
optimizer = get_optimizer(
|
||||
list(model_and_loss.model.named_parameters()),
|
||||
list(executor.model.named_parameters()),
|
||||
args.lr,
|
||||
args=args,
|
||||
state=optimizer_state,
|
||||
)
|
||||
|
||||
if args.lr_schedule == "step":
|
||||
lr_policy = lr_step_policy(
|
||||
args.lr, [30, 60, 80], 0.1, args.warmup, logger=logger
|
||||
)
|
||||
lr_policy = lr_step_policy(args.lr, [30, 60, 80], 0.1, args.warmup)
|
||||
elif args.lr_schedule == "cosine":
|
||||
lr_policy = lr_cosine_policy(
|
||||
args.lr, args.warmup, args.epochs, end_lr=args.end_lr, logger=logger
|
||||
args.lr, args.warmup, args.epochs, end_lr=args.end_lr
|
||||
)
|
||||
elif args.lr_schedule == "linear":
|
||||
lr_policy = lr_linear_policy(args.lr, args.warmup, args.epochs, logger=logger)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(
|
||||
init_scale=args.static_loss_scale,
|
||||
growth_factor=2,
|
||||
backoff_factor=0.5,
|
||||
growth_interval=100 if args.dynamic_loss_scale else 1000000000,
|
||||
enabled=args.amp,
|
||||
)
|
||||
lr_policy = lr_linear_policy(args.lr, args.warmup, args.epochs)
|
||||
|
||||
if args.distributed:
|
||||
model_and_loss.distributed(args.gpu)
|
||||
executor.distributed(args.gpu)
|
||||
|
||||
model_and_loss.load_model_state(model_state)
|
||||
if (ema is not None) and (model_state_ema is not None):
|
||||
print("load ema")
|
||||
ema.load_state_dict(model_state_ema)
|
||||
if model_state is not None:
|
||||
executor.model.load_state_dict(model_state)
|
||||
|
||||
return (model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema,
|
||||
train_loader_len, batch_size_multiplier, start_epoch)
|
||||
trainer = Trainer(
|
||||
executor,
|
||||
optimizer,
|
||||
grad_acc_steps=batch_size_multiplier,
|
||||
ema=args.use_ema,
|
||||
)
|
||||
|
||||
if (args.use_ema is not None) and (model_state_ema is not None):
|
||||
trainer.ema_executor.model.load_state_dict(model_state_ema)
|
||||
|
||||
return (trainer, lr_policy, train_loader, train_loader_len, val_loader, logger, start_epoch)
|
||||
|
||||
|
||||
def main(args, model_args, model_arch):
|
||||
|
@ -550,23 +574,24 @@ def main(args, model_args, model_arch):
|
|||
global best_prec1
|
||||
best_prec1 = 0
|
||||
|
||||
model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema, train_loader_len, \
|
||||
batch_size_multiplier, start_epoch = prepare_for_training(args, model_args, model_arch)
|
||||
|
||||
train_loop(
|
||||
model_and_loss,
|
||||
optimizer,
|
||||
scaler,
|
||||
(
|
||||
trainer,
|
||||
lr_policy,
|
||||
train_loader,
|
||||
train_loader_len,
|
||||
val_loader,
|
||||
logger,
|
||||
start_epoch,
|
||||
) = prepare_for_training(args, model_args, model_arch)
|
||||
|
||||
train_loop(
|
||||
trainer,
|
||||
lr_policy,
|
||||
train_loader,
|
||||
train_loader_len,
|
||||
val_loader,
|
||||
logger,
|
||||
should_backup_checkpoint(args),
|
||||
ema=ema,
|
||||
model_ema=model_ema,
|
||||
steps_per_epoch=train_loader_len,
|
||||
use_amp=args.amp,
|
||||
batch_size_multiplier=batch_size_multiplier,
|
||||
start_epoch=start_epoch,
|
||||
end_epoch=min((start_epoch + args.run_epochs), args.epochs)
|
||||
if args.run_epochs != -1
|
||||
|
@ -587,7 +612,6 @@ def main(args, model_args, model_arch):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
epilog = [
|
||||
"Based on the architecture picked by --arch flag, you may use the following options:\n"
|
||||
]
|
||||
|
|
|
@ -68,7 +68,9 @@ def parse_quantization(parser):
|
|||
metavar="ARCH",
|
||||
default="efficientnet-quant-b0",
|
||||
choices=model_names,
|
||||
help="model architecture: " + " | ".join(model_names) + " (default: efficientnet-quant-b0)",
|
||||
help="model architecture: "
|
||||
+ " | ".join(model_names)
|
||||
+ " (default: efficientnet-quant-b0)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
@ -80,6 +82,7 @@ def parse_quantization(parser):
|
|||
|
||||
def parse_training_args(parser):
|
||||
from main import add_parser_arguments
|
||||
|
||||
return add_parser_arguments(parser)
|
||||
|
||||
|
||||
|
@ -92,28 +95,29 @@ def main(args, model_args, model_arch):
|
|||
|
||||
select_default_calib_method()
|
||||
|
||||
model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema, train_loader_len, \
|
||||
batch_size_multiplier, start_epoch = prepare_for_training(args, model_args, model_arch)
|
||||
(
|
||||
trainer,
|
||||
lr_policy,
|
||||
train_loader,
|
||||
train_loader_len,
|
||||
val_loader,
|
||||
logger,
|
||||
start_epoch,
|
||||
) = prepare_for_training(args, model_args, model_arch)
|
||||
|
||||
print(f"RUNNING QUANTIZATION")
|
||||
|
||||
if not skip_calibration:
|
||||
calibrate(model_and_loss.model, train_loader, logger, calib_iter=10)
|
||||
calibrate(trainer.model_and_loss.model, train_loader, logger, calib_iter=10)
|
||||
|
||||
train_loop(
|
||||
model_and_loss,
|
||||
optimizer,
|
||||
scaler,
|
||||
trainer,
|
||||
lr_policy,
|
||||
train_loader,
|
||||
train_loader_len,
|
||||
val_loader,
|
||||
logger,
|
||||
should_backup_checkpoint(args),
|
||||
ema=ema,
|
||||
model_ema=model_ema,
|
||||
steps_per_epoch=train_loader_len,
|
||||
use_amp=args.amp,
|
||||
batch_size_multiplier=batch_size_multiplier,
|
||||
start_epoch=start_epoch,
|
||||
end_epoch=min((start_epoch + args.run_epochs), args.epochs)
|
||||
if args.run_epochs != -1
|
||||
|
@ -124,7 +128,7 @@ def main(args, model_args, model_arch):
|
|||
skip_validation=args.training_only,
|
||||
save_checkpoints=args.save_checkpoints,
|
||||
checkpoint_dir=args.workspace,
|
||||
checkpoint_filename='quantized_' + args.checkpoint_filename,
|
||||
checkpoint_filename="quantized_" + args.checkpoint_filename,
|
||||
)
|
||||
|
||||
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
|
||||
|
@ -159,4 +163,3 @@ if __name__ == "__main__":
|
|||
cudnn.benchmark = True
|
||||
|
||||
main(args, model_args, model_arch)
|
||||
|
||||
|
|
Loading…
Reference in a new issue