[ConvNets/PyT] TorchScriptable ConvNets
This commit is contained in:
parent
3d3250a3ae
commit
4f2c6922bd
2
PyTorch/Classification/ConvNets/.dockerignore
Normal file
2
PyTorch/Classification/ConvNets/.dockerignore
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
*.pth.tar
|
||||||
|
*.log
|
|
@ -22,6 +22,7 @@ def add_parser_arguments(parser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--weight-path", metavar="<path>", help="name of file in which to store weights"
|
"--weight-path", metavar="<path>", help="name of file in which to store weights"
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--ema", action="store_true", default=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -30,12 +31,13 @@ if __name__ == "__main__":
|
||||||
add_parser_arguments(parser)
|
add_parser_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
checkpoint = torch.load(args.checkpoint_path, map_location=torch.device('cpu'))
|
checkpoint = torch.load(args.checkpoint_path, map_location=torch.device("cpu"))
|
||||||
|
|
||||||
|
key = "state_dict" if not args.ema else "ema_state_dict"
|
||||||
model_state_dict = {
|
model_state_dict = {
|
||||||
k[len("module.") :] if "module." in k else k: v
|
k[len("module.") :] if "module." in k else k: v
|
||||||
for k, v in checkpoint["state_dict"].items()
|
for k, v in checkpoint["state_dict"].items()
|
||||||
}
|
}
|
||||||
print(f"Loaded model, acc : {checkpoint['best_prec1']}")
|
print(f"Loaded model, acc : {checkpoint['best_prec1']}")
|
||||||
|
|
||||||
torch.save(model_state_dict, args.weight_path)
|
torch.save(model_state_dict, args.weight_path)
|
||||||
|
|
|
@ -234,32 +234,29 @@ class Logger(object):
|
||||||
def log_metric(self, metric_name, val, n=1):
|
def log_metric(self, metric_name, val, n=1):
|
||||||
self.metrics[metric_name]["meter"].record(val, n=n)
|
self.metrics[metric_name]["meter"].record(val, n=n)
|
||||||
|
|
||||||
def start_iteration(self, mode='train'):
|
def start_iteration(self, mode="train"):
|
||||||
if mode == 'val':
|
if mode == "val":
|
||||||
self.val_iteration += 1
|
self.val_iteration += 1
|
||||||
elif mode == 'train':
|
elif mode == "train":
|
||||||
self.iteration += 1
|
self.iteration += 1
|
||||||
elif mode == 'calib':
|
elif mode == "calib":
|
||||||
self.calib_iteration += 1
|
self.calib_iteration += 1
|
||||||
|
|
||||||
def end_iteration(self, mode='train'):
|
def end_iteration(self, mode="train"):
|
||||||
if mode == 'val':
|
if mode == "val":
|
||||||
it = self.val_iteration
|
it = self.val_iteration
|
||||||
elif mode == 'train':
|
elif mode == "train":
|
||||||
it = self.iteration
|
it = self.iteration
|
||||||
elif mode == 'calib':
|
elif mode == "calib":
|
||||||
it = self.calib_iteration
|
it = self.calib_iteration
|
||||||
|
if it % self.print_interval == 0 or mode == "calib":
|
||||||
if it % self.print_interval == 0 or mode == 'calib':
|
metrics = {n: m for n, m in self.metrics.items() if n.startswith(mode)}
|
||||||
metrics = {
|
if mode == "train":
|
||||||
n: m for n, m in self.metrics.items() if n.startswith(mode)
|
|
||||||
}
|
|
||||||
if mode == 'train':
|
|
||||||
step = (self.epoch, self.iteration)
|
step = (self.epoch, self.iteration)
|
||||||
elif mode == 'val':
|
elif mode == "val":
|
||||||
step = (self.epoch, self.iteration, self.val_iteration)
|
step = (self.epoch, self.iteration, self.val_iteration)
|
||||||
elif mode == 'calib':
|
elif mode == "calib":
|
||||||
step = ('Calibration', self.calib_iteration)
|
step = ("Calibration", self.calib_iteration)
|
||||||
|
|
||||||
verbositys = {m["level"] for _, m in metrics.items()}
|
verbositys = {m["level"] for _, m in metrics.items()}
|
||||||
for ll in verbositys:
|
for ll in verbositys:
|
||||||
|
@ -282,12 +279,12 @@ class Logger(object):
|
||||||
self.val_iteration = 0
|
self.val_iteration = 0
|
||||||
|
|
||||||
for n, m in self.metrics.items():
|
for n, m in self.metrics.items():
|
||||||
if not n.startswith('calib'):
|
if not n.startswith("calib"):
|
||||||
m["meter"].reset_epoch()
|
m["meter"].reset_epoch()
|
||||||
|
|
||||||
def end_epoch(self):
|
def end_epoch(self):
|
||||||
for n, m in self.metrics.items():
|
for n, m in self.metrics.items():
|
||||||
if not n.startswith('calib'):
|
if not n.startswith("calib"):
|
||||||
m["meter"].reset_iteration()
|
m["meter"].reset_iteration()
|
||||||
|
|
||||||
verbositys = {m["level"] for _, m in self.metrics.items()}
|
verbositys = {m["level"] for _, m in self.metrics.items()}
|
||||||
|
@ -302,12 +299,12 @@ class Logger(object):
|
||||||
self.calib_iteration = 0
|
self.calib_iteration = 0
|
||||||
|
|
||||||
for n, m in self.metrics.items():
|
for n, m in self.metrics.items():
|
||||||
if n.startswith('calib'):
|
if n.startswith("calib"):
|
||||||
m["meter"].reset_epoch()
|
m["meter"].reset_epoch()
|
||||||
|
|
||||||
def end_calibration(self):
|
def end_calibration(self):
|
||||||
for n, m in self.metrics.items():
|
for n, m in self.metrics.items():
|
||||||
if n.startswith('calib'):
|
if n.startswith("calib"):
|
||||||
m["meter"].reset_iteration()
|
m["meter"].reset_iteration()
|
||||||
|
|
||||||
def end(self):
|
def end(self):
|
||||||
|
@ -326,7 +323,7 @@ class Logger(object):
|
||||||
|
|
||||||
dllogger.flush()
|
dllogger.flush()
|
||||||
|
|
||||||
def iteration_generator_wrapper(self, gen, mode='train'):
|
def iteration_generator_wrapper(self, gen, mode="train"):
|
||||||
for g in gen:
|
for g in gen:
|
||||||
self.start_iteration(mode=mode)
|
self.start_iteration(mode=mode)
|
||||||
yield g
|
yield g
|
||||||
|
@ -337,3 +334,155 @@ class Logger(object):
|
||||||
self.start_epoch()
|
self.start_epoch()
|
||||||
yield g
|
yield g
|
||||||
self.end_epoch()
|
self.end_epoch()
|
||||||
|
|
||||||
|
|
||||||
|
class Metrics:
|
||||||
|
ACC_METADATA = {"unit": "%", "format": ":.2f"}
|
||||||
|
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
|
||||||
|
TIME_METADATA = {"unit": "s", "format": ":.5f"}
|
||||||
|
LOSS_METADATA = {"format": ":.5f"}
|
||||||
|
LR_METADATA = {"format": ":.5f"}
|
||||||
|
|
||||||
|
def __init__(self, logger):
|
||||||
|
self.logger = logger
|
||||||
|
self.map = {}
|
||||||
|
|
||||||
|
def log(self, **kwargs):
|
||||||
|
if self.logger is None:
|
||||||
|
return
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
tks = self.map.get(k, [k])
|
||||||
|
for tk in tks:
|
||||||
|
if isinstance(v, tuple):
|
||||||
|
self.logger.log_metric(tk, v[0], v[1])
|
||||||
|
else:
|
||||||
|
self.logger.log_metric(tk, v)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingMetrics(Metrics):
|
||||||
|
def __init__(self, logger):
|
||||||
|
super().__init__(logger)
|
||||||
|
if self.logger is not None:
|
||||||
|
self.map = {
|
||||||
|
"loss": ["train.loss"],
|
||||||
|
"compute_ips": ["train.compute_ips"],
|
||||||
|
"total_ips": ["train.total_ips"],
|
||||||
|
"data_time": ["train.data_time"],
|
||||||
|
"compute_time": ["train.compute_time"],
|
||||||
|
"lr": ["train.lr"],
|
||||||
|
}
|
||||||
|
logger.register_metric(
|
||||||
|
"train.loss",
|
||||||
|
LOSS_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.DEFAULT,
|
||||||
|
metadata=Metrics.LOSS_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
"train.compute_ips",
|
||||||
|
PERF_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.DEFAULT,
|
||||||
|
metadata=Metrics.IPS_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
"train.total_ips",
|
||||||
|
PERF_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.DEFAULT,
|
||||||
|
metadata=Metrics.IPS_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
"train.data_time",
|
||||||
|
PERF_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.VERBOSE,
|
||||||
|
metadata=Metrics.TIME_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
"train.compute_time",
|
||||||
|
PERF_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.VERBOSE,
|
||||||
|
metadata=Metrics.TIME_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
"train.lr",
|
||||||
|
LR_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.DEFAULT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationMetrics(Metrics):
|
||||||
|
def __init__(self, logger, prefix):
|
||||||
|
super().__init__(logger)
|
||||||
|
if self.logger is not None:
|
||||||
|
self.map = {
|
||||||
|
"loss": [f"{prefix}.loss"],
|
||||||
|
"top1": [f"{prefix}.top1"],
|
||||||
|
"top5": [f"{prefix}.top5"],
|
||||||
|
"compute_ips": [f"{prefix}.compute_ips"],
|
||||||
|
"total_ips": [f"{prefix}.total_ips"],
|
||||||
|
"data_time": [f"{prefix}.data_time"],
|
||||||
|
"compute_time": [
|
||||||
|
f"{prefix}.compute_latency",
|
||||||
|
f"{prefix}.compute_latency_at100",
|
||||||
|
f"{prefix}.compute_latency_at99",
|
||||||
|
f"{prefix}.compute_latency_at95",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
logger.register_metric(
|
||||||
|
f"{prefix}.top1",
|
||||||
|
ACC_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.DEFAULT,
|
||||||
|
metadata=Metrics.ACC_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
f"{prefix}.top5",
|
||||||
|
ACC_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.DEFAULT,
|
||||||
|
metadata=Metrics.ACC_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
f"{prefix}.loss",
|
||||||
|
LOSS_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.DEFAULT,
|
||||||
|
metadata=Metrics.LOSS_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
f"{prefix}.compute_ips",
|
||||||
|
PERF_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.DEFAULT,
|
||||||
|
metadata=Metrics.IPS_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
f"{prefix}.total_ips",
|
||||||
|
PERF_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.DEFAULT,
|
||||||
|
metadata=Metrics.IPS_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
f"{prefix}.data_time",
|
||||||
|
PERF_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.VERBOSE,
|
||||||
|
metadata=Metrics.TIME_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
f"{prefix}.compute_latency",
|
||||||
|
PERF_METER(),
|
||||||
|
verbosity=dllogger.Verbosity.DEFAULT,
|
||||||
|
metadata=Metrics.TIME_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
f"{prefix}.compute_latency_at100",
|
||||||
|
LAT_100(),
|
||||||
|
verbosity=dllogger.Verbosity.VERBOSE,
|
||||||
|
metadata=Metrics.TIME_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
f"{prefix}.compute_latency_at99",
|
||||||
|
LAT_99(),
|
||||||
|
verbosity=dllogger.Verbosity.VERBOSE,
|
||||||
|
metadata=Metrics.TIME_METADATA,
|
||||||
|
)
|
||||||
|
logger.register_metric(
|
||||||
|
f"{prefix}.compute_latency_at95",
|
||||||
|
LAT_95(),
|
||||||
|
verbosity=dllogger.Verbosity.VERBOSE,
|
||||||
|
metadata=Metrics.TIME_METADATA,
|
||||||
|
)
|
||||||
|
|
|
@ -5,6 +5,7 @@ from typing import Optional
|
||||||
import torch
|
import torch
|
||||||
import warnings
|
import warnings
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from pytorch_quantization import nn as quant_nn
|
from pytorch_quantization import nn as quant_nn
|
||||||
|
@ -143,30 +144,44 @@ class LambdaLayer(nn.Module):
|
||||||
|
|
||||||
# SqueezeAndExcitation {{{
|
# SqueezeAndExcitation {{{
|
||||||
class SqueezeAndExcitation(nn.Module):
|
class SqueezeAndExcitation(nn.Module):
|
||||||
def __init__(self, in_channels, squeeze, activation, use_conv=False):
|
def __init__(self, in_channels, squeeze, activation):
|
||||||
super(SqueezeAndExcitation, self).__init__()
|
super(SqueezeAndExcitation, self).__init__()
|
||||||
if use_conv:
|
self.squeeze = nn.Linear(in_channels, squeeze)
|
||||||
self.pooling = nn.AdaptiveAvgPool2d(1)
|
self.expand = nn.Linear(squeeze, in_channels)
|
||||||
self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
|
|
||||||
self.expand = nn.Conv2d(squeeze, in_channels, 1)
|
|
||||||
else:
|
|
||||||
self.squeeze = nn.Linear(in_channels, squeeze)
|
|
||||||
self.expand = nn.Linear(squeeze, in_channels)
|
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
self.use_conv = use_conv
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.use_conv:
|
return self._attention(x)
|
||||||
out = self.pooling(x)
|
|
||||||
else:
|
def _attention(self, x):
|
||||||
out = torch.mean(x, [2, 3])
|
out = torch.mean(x, [2, 3])
|
||||||
|
out = self.squeeze(out)
|
||||||
|
out = self.activation(out)
|
||||||
|
out = self.expand(out)
|
||||||
|
out = self.sigmoid(out)
|
||||||
|
out = out.unsqueeze(2).unsqueeze(3)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class SqueezeAndExcitationTRT(nn.Module):
|
||||||
|
def __init__(self, in_channels, squeeze, activation):
|
||||||
|
super(SqueezeAndExcitationTRT, self).__init__()
|
||||||
|
self.pooling = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
|
||||||
|
self.expand = nn.Conv2d(squeeze, in_channels, 1)
|
||||||
|
self.activation = activation
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self._attention(x)
|
||||||
|
|
||||||
|
def _attention(self, x):
|
||||||
|
out = self.pooling(x)
|
||||||
out = self.squeeze(out)
|
out = self.squeeze(out)
|
||||||
out = self.activation(out)
|
out = self.activation(out)
|
||||||
out = self.expand(out)
|
out = self.expand(out)
|
||||||
out = self.sigmoid(out)
|
out = self.sigmoid(out)
|
||||||
if not self.use_conv:
|
|
||||||
out = out.unsqueeze(2).unsqueeze(3)
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@ -174,18 +189,9 @@ class SqueezeAndExcitation(nn.Module):
|
||||||
|
|
||||||
# EMA {{{
|
# EMA {{{
|
||||||
class EMA:
|
class EMA:
|
||||||
def __init__(self, mu):
|
def __init__(self, mu, module_ema):
|
||||||
self.mu = mu
|
self.mu = mu
|
||||||
self.shadow = {}
|
self.module_ema = module_ema
|
||||||
|
|
||||||
def state_dict(self):
|
|
||||||
return copy.deepcopy(self.shadow)
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
|
||||||
self.shadow = state_dict
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.shadow)
|
|
||||||
|
|
||||||
def __call__(self, module, step=None):
|
def __call__(self, module, step=None):
|
||||||
if step is None:
|
if step is None:
|
||||||
|
@ -193,12 +199,17 @@ class EMA:
|
||||||
else:
|
else:
|
||||||
mu = min(self.mu, (1.0 + step) / (10 + step))
|
mu = min(self.mu, (1.0 + step) / (10 + step))
|
||||||
|
|
||||||
for name, x in module.state_dict().items():
|
def strip_module(s: str) -> str:
|
||||||
if name in self.shadow:
|
return s
|
||||||
new_average = (1.0 - mu) * x + mu * self.shadow[name]
|
|
||||||
self.shadow[name] = new_average.clone()
|
mesd = self.module_ema.state_dict()
|
||||||
else:
|
with torch.no_grad():
|
||||||
self.shadow[name] = x.clone()
|
for name, x in module.state_dict().items():
|
||||||
|
if name.endswith("num_batches_tracked"):
|
||||||
|
continue
|
||||||
|
n = strip_module(name)
|
||||||
|
mesd[n].mul_(mu)
|
||||||
|
mesd[n].add_((1.0 - mu) * x)
|
||||||
|
|
||||||
|
|
||||||
# }}}
|
# }}}
|
||||||
|
@ -218,10 +229,8 @@ class ONNXSiLU(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
|
class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
|
||||||
def __init__(
|
def __init__(self, in_channels, squeeze, activation, quantized=False):
|
||||||
self, in_channels, squeeze, activation, quantized=False, use_conv=False
|
super().__init__(in_channels, squeeze, activation)
|
||||||
):
|
|
||||||
super().__init__(in_channels, squeeze, activation, use_conv=use_conv)
|
|
||||||
self.quantized = quantized
|
self.quantized = quantized
|
||||||
if quantized:
|
if quantized:
|
||||||
assert quant_nn is not None, "pytorch_quantization is not available"
|
assert quant_nn is not None, "pytorch_quantization is not available"
|
||||||
|
@ -231,10 +240,63 @@ class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
|
||||||
self.mul_b_quantizer = quant_nn.TensorQuantizer(
|
self.mul_b_quantizer = quant_nn.TensorQuantizer(
|
||||||
quant_nn.QuantConv2d.default_quant_desc_input
|
quant_nn.QuantConv2d.default_quant_desc_input
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.mul_a_quantizer = nn.Identity()
|
||||||
|
self.mul_b_quantizer = nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
out = self._attention(x)
|
||||||
if not self.quantized:
|
if not self.quantized:
|
||||||
return super().forward(x) * x
|
return out * x
|
||||||
else:
|
else:
|
||||||
x_quant = self.mul_a_quantizer(super().forward(x))
|
x_quant = self.mul_a_quantizer(out)
|
||||||
return x_quant * self.mul_b_quantizer(x)
|
return x_quant * self.mul_b_quantizer(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SequentialSqueezeAndExcitationTRT(SqueezeAndExcitationTRT):
|
||||||
|
def __init__(self, in_channels, squeeze, activation, quantized=False):
|
||||||
|
super().__init__(in_channels, squeeze, activation)
|
||||||
|
self.quantized = quantized
|
||||||
|
if quantized:
|
||||||
|
assert quant_nn is not None, "pytorch_quantization is not available"
|
||||||
|
self.mul_a_quantizer = quant_nn.TensorQuantizer(
|
||||||
|
quant_nn.QuantConv2d.default_quant_desc_input
|
||||||
|
)
|
||||||
|
self.mul_b_quantizer = quant_nn.TensorQuantizer(
|
||||||
|
quant_nn.QuantConv2d.default_quant_desc_input
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mul_a_quantizer = nn.Identity()
|
||||||
|
self.mul_b_quantizer = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self._attention(x)
|
||||||
|
if not self.quantized:
|
||||||
|
return out * x
|
||||||
|
else:
|
||||||
|
x_quant = self.mul_a_quantizer(out)
|
||||||
|
return x_quant * self.mul_b_quantizer(x)
|
||||||
|
|
||||||
|
|
||||||
|
class StochasticDepthResidual(nn.Module):
|
||||||
|
def __init__(self, survival_prob: float):
|
||||||
|
super().__init__()
|
||||||
|
self.survival_prob = survival_prob
|
||||||
|
self.register_buffer("mask", torch.ones(()), persistent=False)
|
||||||
|
|
||||||
|
def forward(self, residual: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not self.training:
|
||||||
|
return torch.add(residual, other=x)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
F.dropout(
|
||||||
|
self.mask,
|
||||||
|
p=1 - self.survival_prob,
|
||||||
|
training=self.training,
|
||||||
|
inplace=False,
|
||||||
|
)
|
||||||
|
return torch.addcmul(residual, self.mask, x)
|
||||||
|
|
||||||
|
class Flatten(nn.Module):
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return x.squeeze(-1).squeeze(-1)
|
||||||
|
|
|
@ -31,11 +31,11 @@ except ImportError as e:
|
||||||
|
|
||||||
|
|
||||||
from .common import (
|
from .common import (
|
||||||
SqueezeAndExcitation,
|
|
||||||
ONNXSiLU,
|
|
||||||
SequentialSqueezeAndExcitation,
|
SequentialSqueezeAndExcitation,
|
||||||
|
SequentialSqueezeAndExcitationTRT,
|
||||||
LayerBuilder,
|
LayerBuilder,
|
||||||
LambdaLayer,
|
StochasticDepthResidual,
|
||||||
|
Flatten,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .model import (
|
from .model import (
|
||||||
|
@ -206,6 +206,7 @@ class EfficientNet(nn.Module):
|
||||||
out_channels = arch.stem_channels
|
out_channels = arch.stem_channels
|
||||||
|
|
||||||
plc = 0
|
plc = 0
|
||||||
|
layers = []
|
||||||
for i, (k, s, r, e, c) in arch.enumerate():
|
for i, (k, s, r, e, c) in arch.enumerate():
|
||||||
layer, out_channels = self._make_layer(
|
layer, out_channels = self._make_layer(
|
||||||
block=arch.block,
|
block=arch.block,
|
||||||
|
@ -220,8 +221,8 @@ class EfficientNet(nn.Module):
|
||||||
trt=trt,
|
trt=trt,
|
||||||
)
|
)
|
||||||
plc = plc + r
|
plc = plc + r
|
||||||
setattr(self, f"layer{i+1}", layer)
|
layers.append(layer)
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
self.features = self._make_features(out_channels, arch.feature_channels)
|
self.features = self._make_features(out_channels, arch.feature_channels)
|
||||||
self.classifier = self._make_classifier(
|
self.classifier = self._make_classifier(
|
||||||
arch.feature_channels, num_classes, dropout
|
arch.feature_channels, num_classes, dropout
|
||||||
|
@ -229,11 +230,7 @@ class EfficientNet(nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
|
x = self.layers(x)
|
||||||
for i in range(self.num_layers):
|
|
||||||
fn = getattr(self, f"layer{i+1}")
|
|
||||||
x = fn(x)
|
|
||||||
|
|
||||||
x = self.features(x)
|
x = self.features(x)
|
||||||
x = self.classifier(x)
|
x = self.classifier(x)
|
||||||
|
|
||||||
|
@ -241,27 +238,34 @@ class EfficientNet(nn.Module):
|
||||||
|
|
||||||
def extract_features(self, x, layers=None):
|
def extract_features(self, x, layers=None):
|
||||||
if layers is None:
|
if layers is None:
|
||||||
layers = [f"layer{i+1}" for i in range(self.num_layers)]
|
layers = [f"layer{i+1}" for i in range(self.num_layers)] + [
|
||||||
|
"features",
|
||||||
|
"classifier",
|
||||||
|
]
|
||||||
|
|
||||||
run = [
|
run = [
|
||||||
f"layer{i+1}"
|
i
|
||||||
for i in range(self.num_layers)
|
for i in range(self.num_layers)
|
||||||
if "classifier" in layers
|
if "classifier" in layers
|
||||||
or "features" in layers
|
or "features" in layers
|
||||||
or any([f"layer{j+1}" in layers for j in range(i, self.num_layers)])
|
or any([f"layer{j+1}" in layers for j in range(i, self.num_layers)])
|
||||||
]
|
]
|
||||||
if "features" in layers or "classifier" in layers:
|
|
||||||
run.append("features")
|
|
||||||
if "classifier" in layers:
|
|
||||||
run.append("classifier")
|
|
||||||
|
|
||||||
output = {}
|
output = {}
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
for l in run:
|
for l in run:
|
||||||
fn = getattr(self, l)
|
fn = self.layers[l]
|
||||||
x = fn(x)
|
x = fn(x)
|
||||||
if l in layers:
|
if f"layer{l+1}" in layers:
|
||||||
output[l] = x
|
output[f"layer{l+1}"] = x
|
||||||
|
|
||||||
|
if "features" in layers or "classifier" in layers:
|
||||||
|
x = self.features(x)
|
||||||
|
if "features" in layers:
|
||||||
|
output["features"] = x
|
||||||
|
|
||||||
|
if "classifier" in layers:
|
||||||
|
output["classifier"] = self.classifier(x)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -298,7 +302,7 @@ class EfficientNet(nn.Module):
|
||||||
OrderedDict(
|
OrderedDict(
|
||||||
[
|
[
|
||||||
("pooling", nn.AdaptiveAvgPool2d(1)),
|
("pooling", nn.AdaptiveAvgPool2d(1)),
|
||||||
("squeeze", LambdaLayer(lambda x: x.squeeze(-1).squeeze(-1))),
|
("squeeze", Flatten()),
|
||||||
("dropout", nn.Dropout(dropout)),
|
("dropout", nn.Dropout(dropout)),
|
||||||
("fc", nn.Linear(num_features, num_classes)),
|
("fc", nn.Linear(num_features, num_classes)),
|
||||||
]
|
]
|
||||||
|
@ -353,11 +357,33 @@ class EfficientNet(nn.Module):
|
||||||
layers.append((f"block{idx}", blk))
|
layers.append((f"block{idx}", blk))
|
||||||
return nn.Sequential(OrderedDict(layers)), out_channels
|
return nn.Sequential(OrderedDict(layers)), out_channels
|
||||||
|
|
||||||
|
def ngc_checkpoint_remap(self, url=None, version=None):
|
||||||
|
if version is None:
|
||||||
|
version = url.split("/")[8]
|
||||||
|
|
||||||
|
def to_sequential_remap(s):
|
||||||
|
splited = s.split(".")
|
||||||
|
if splited[0].startswith("layer"):
|
||||||
|
return ".".join(
|
||||||
|
["layers." + str(int(splited[0][len("layer") :]) - 1)] + splited[1:]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return s
|
||||||
|
|
||||||
|
def no_remap(s):
|
||||||
|
return s
|
||||||
|
|
||||||
|
return {"20.12.0": to_sequential_remap, "21.03.0": to_sequential_remap}.get(
|
||||||
|
version, no_remap
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# }}}
|
# }}}
|
||||||
|
|
||||||
# MBConvBlock {{{
|
# MBConvBlock {{{
|
||||||
class MBConvBlock(nn.Module):
|
class MBConvBlock(nn.Module):
|
||||||
|
__constants__ = ["quantized"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
builder: LayerBuilder,
|
builder: LayerBuilder,
|
||||||
|
@ -366,7 +392,7 @@ class MBConvBlock(nn.Module):
|
||||||
out_channels: int,
|
out_channels: int,
|
||||||
expand_ratio: int,
|
expand_ratio: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
squeeze_excitation_ratio: int,
|
squeeze_excitation_ratio: float,
|
||||||
squeeze_hidden=False,
|
squeeze_hidden=False,
|
||||||
survival_prob: float = 1.0,
|
survival_prob: float = 1.0,
|
||||||
quantized: bool = False,
|
quantized: bool = False,
|
||||||
|
@ -387,25 +413,31 @@ class MBConvBlock(nn.Module):
|
||||||
self.depsep = builder.convDepSep(
|
self.depsep = builder.convDepSep(
|
||||||
depsep_kernel_size, hidden_dim, hidden_dim, stride, bn=True, act=True
|
depsep_kernel_size, hidden_dim, hidden_dim, stride, bn=True, act=True
|
||||||
)
|
)
|
||||||
self.se = SequentialSqueezeAndExcitation(
|
if trt or self.quantized:
|
||||||
hidden_dim, squeeze_dim, builder.activation(), self.quantized, use_conv=trt
|
# Need TRT mode for quantized in order to automatically insert quantization before pooling
|
||||||
)
|
self.se: nn.Module = SequentialSqueezeAndExcitationTRT(
|
||||||
|
hidden_dim, squeeze_dim, builder.activation(), self.quantized
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.se: nn.Module = SequentialSqueezeAndExcitation(
|
||||||
|
hidden_dim, squeeze_dim, builder.activation(), self.quantized
|
||||||
|
)
|
||||||
|
|
||||||
self.proj = builder.conv1x1(hidden_dim, out_channels, bn=True)
|
self.proj = builder.conv1x1(hidden_dim, out_channels, bn=True)
|
||||||
|
|
||||||
self.survival_prob = survival_prob
|
if survival_prob == 1.0:
|
||||||
|
self.residual_add = torch.add
|
||||||
|
else:
|
||||||
|
self.residual_add = StochasticDepthResidual(survival_prob=survival_prob)
|
||||||
if self.quantized and self.residual:
|
if self.quantized and self.residual:
|
||||||
assert quant_nn is not None, "pytorch_quantization is not available"
|
assert quant_nn is not None, "pytorch_quantization is not available"
|
||||||
self.residual_quantizer = quant_nn.TensorQuantizer(
|
self.residual_quantizer = quant_nn.TensorQuantizer(
|
||||||
quant_nn.QuantConv2d.default_quant_desc_input
|
quant_nn.QuantConv2d.default_quant_desc_input
|
||||||
) # TODO QuantConv2d ?!?
|
) # TODO QuantConv2d ?!?
|
||||||
|
else:
|
||||||
|
self.residual_quantizer = nn.Identity()
|
||||||
|
|
||||||
def drop(self):
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
if self.survival_prob == 1.0:
|
|
||||||
return False
|
|
||||||
return random.uniform(0.0, 1.0) > self.survival_prob
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if not self.residual:
|
if not self.residual:
|
||||||
return self.proj(
|
return self.proj(
|
||||||
self.se(self.depsep(x if self.expand is None else self.expand(x)))
|
self.se(self.depsep(x if self.expand is None else self.expand(x)))
|
||||||
|
@ -414,16 +446,10 @@ class MBConvBlock(nn.Module):
|
||||||
b = self.proj(
|
b = self.proj(
|
||||||
self.se(self.depsep(x if self.expand is None else self.expand(x)))
|
self.se(self.depsep(x if self.expand is None else self.expand(x)))
|
||||||
)
|
)
|
||||||
if self.training:
|
|
||||||
if self.drop():
|
|
||||||
multiplication_factor = 0.0
|
|
||||||
else:
|
|
||||||
multiplication_factor = 1.0 / self.survival_prob
|
|
||||||
else:
|
|
||||||
multiplication_factor = 1.0
|
|
||||||
if self.quantized:
|
if self.quantized:
|
||||||
x = self.residual_quantizer(x)
|
x = self.residual_quantizer(x)
|
||||||
return torch.add(x, alpha=multiplication_factor, other=b)
|
|
||||||
|
return self.residual_add(x, b)
|
||||||
|
|
||||||
|
|
||||||
def original_mbconv(
|
def original_mbconv(
|
||||||
|
@ -436,7 +462,7 @@ def original_mbconv(
|
||||||
squeeze_excitation_ratio: int,
|
squeeze_excitation_ratio: int,
|
||||||
survival_prob: float,
|
survival_prob: float,
|
||||||
quantized: bool,
|
quantized: bool,
|
||||||
trt: bool
|
trt: bool,
|
||||||
):
|
):
|
||||||
return MBConvBlock(
|
return MBConvBlock(
|
||||||
builder,
|
builder,
|
||||||
|
|
|
@ -1,8 +1,15 @@
|
||||||
from dataclasses import dataclass, asdict, replace
|
from dataclasses import dataclass, asdict, replace
|
||||||
|
from .common import (
|
||||||
|
SequentialSqueezeAndExcitationTRT,
|
||||||
|
SequentialSqueezeAndExcitation,
|
||||||
|
SqueezeAndExcitation,
|
||||||
|
SqueezeAndExcitationTRT,
|
||||||
|
)
|
||||||
from typing import Optional, Callable
|
from typing import Optional, Callable
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -37,7 +44,13 @@ class EntryPoint:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def __call__(self, pretrained=False, pretrained_from_file=None, **kwargs):
|
def __call__(
|
||||||
|
self,
|
||||||
|
pretrained=False,
|
||||||
|
pretrained_from_file=None,
|
||||||
|
state_dict_key_map_fn=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
assert not (pretrained and (pretrained_from_file is not None))
|
assert not (pretrained and (pretrained_from_file is not None))
|
||||||
params = replace(self.model.params, **kwargs)
|
params = replace(self.model.params, **kwargs)
|
||||||
|
|
||||||
|
@ -66,7 +79,7 @@ class EntryPoint:
|
||||||
pretrained_from_file
|
pretrained_from_file
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Temporary fix to allow NGC checkpoint loading
|
|
||||||
if state_dict is not None:
|
if state_dict is not None:
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k[len("module.") :] if k.startswith("module.") else k: v
|
k[len("module.") :] if k.startswith("module.") else k: v
|
||||||
|
@ -85,12 +98,32 @@ class EntryPoint:
|
||||||
else:
|
else:
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
if state_dict_key_map_fn is not None:
|
||||||
|
state_dict = {
|
||||||
|
state_dict_key_map_fn(k): v for k, v in state_dict.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasattr(model, "ngc_checkpoint_remap"):
|
||||||
|
remap_fn = model.ngc_checkpoint_remap(url=self.model.checkpoint_url)
|
||||||
|
state_dict = {remap_fn(k): v for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
def _se_layer_uses_conv(m):
|
||||||
|
return any(
|
||||||
|
map(
|
||||||
|
partial(isinstance, m),
|
||||||
|
[
|
||||||
|
SqueezeAndExcitationTRT,
|
||||||
|
SequentialSqueezeAndExcitationTRT,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k: reshape(
|
k: reshape(
|
||||||
v,
|
v,
|
||||||
conv=dict(model.named_modules())[
|
conv=_se_layer_uses_conv(
|
||||||
".".join(k.split(".")[:-2])
|
dict(model.named_modules())[".".join(k.split(".")[:-2])]
|
||||||
].use_conv,
|
),
|
||||||
)
|
)
|
||||||
if is_se_weight(k, v)
|
if is_se_weight(k, v)
|
||||||
else v
|
else v
|
||||||
|
@ -123,7 +156,8 @@ class EntryPoint:
|
||||||
|
|
||||||
|
|
||||||
def is_se_weight(key, value):
|
def is_se_weight(key, value):
|
||||||
return (key.endswith("squeeze.weight") or key.endswith("expand.weight"))
|
return key.endswith("squeeze.weight") or key.endswith("expand.weight")
|
||||||
|
|
||||||
|
|
||||||
def create_entrypoint(m: Model):
|
def create_entrypoint(m: Model):
|
||||||
def _ep(**kwargs):
|
def _ep(**kwargs):
|
||||||
|
|
|
@ -36,14 +36,16 @@ from typing import List, Dict, Callable, Any, Type
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .common import SqueezeAndExcitation, LayerBuilder, LambdaLayer
|
from .common import (
|
||||||
|
SqueezeAndExcitation,
|
||||||
|
LayerBuilder,
|
||||||
|
SqueezeAndExcitationTRT,
|
||||||
|
)
|
||||||
|
|
||||||
from .model import (
|
from .model import (
|
||||||
Model,
|
Model,
|
||||||
ModelParams,
|
ModelParams,
|
||||||
ModelArch,
|
ModelArch,
|
||||||
OptimizerParams,
|
|
||||||
create_entrypoint,
|
|
||||||
EntryPoint,
|
EntryPoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -128,11 +130,18 @@ class Bottleneck(nn.Module):
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
self.fused_se = fused_se
|
self.fused_se = fused_se
|
||||||
self.squeeze = (
|
if se:
|
||||||
SqueezeAndExcitation(planes * expansion, se_squeeze, builder.activation(), use_conv=trt)
|
self.squeeze = (
|
||||||
if se
|
SqueezeAndExcitation(
|
||||||
else None
|
planes * expansion, se_squeeze, builder.activation()
|
||||||
)
|
)
|
||||||
|
if not trt
|
||||||
|
else SqueezeAndExcitationTRT(
|
||||||
|
planes * expansion, se_squeeze, builder.activation()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.squeeze = None
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
|
@ -215,6 +224,7 @@ class ResNet(nn.Module):
|
||||||
last_bn_0_init: bool = False
|
last_bn_0_init: bool = False
|
||||||
conv_init: str = "fan_in"
|
conv_init: str = "fan_in"
|
||||||
trt: bool = False
|
trt: bool = False
|
||||||
|
fused_se: bool = True
|
||||||
|
|
||||||
def parser(self, name):
|
def parser(self, name):
|
||||||
p = super().parser(name)
|
p = super().parser(name)
|
||||||
|
@ -240,6 +250,10 @@ class ResNet(nn.Module):
|
||||||
help="initialization mode for convolutional layers, see https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_",
|
help="initialization mode for convolutional layers, see https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_",
|
||||||
)
|
)
|
||||||
p.add_argument("--trt", metavar="True|False", default=self.trt, type=bool)
|
p.add_argument("--trt", metavar="True|False", default=self.trt, type=bool)
|
||||||
|
p.add_argument(
|
||||||
|
"--fused_se", metavar="True|False", default=self.fused_se, type=bool
|
||||||
|
)
|
||||||
|
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -249,6 +263,7 @@ class ResNet(nn.Module):
|
||||||
last_bn_0_init: bool = False,
|
last_bn_0_init: bool = False,
|
||||||
conv_init: str = "fan_in",
|
conv_init: str = "fan_in",
|
||||||
trt: bool = False,
|
trt: bool = False,
|
||||||
|
fused_se: bool = True,
|
||||||
):
|
):
|
||||||
|
|
||||||
super(ResNet, self).__init__()
|
super(ResNet, self).__init__()
|
||||||
|
@ -265,6 +280,7 @@ class ResNet(nn.Module):
|
||||||
inplanes = arch.stem_width
|
inplanes = arch.stem_width
|
||||||
assert len(arch.widths) == len(arch.layers)
|
assert len(arch.widths) == len(arch.layers)
|
||||||
self.num_layers = len(arch.widths)
|
self.num_layers = len(arch.widths)
|
||||||
|
layers = []
|
||||||
for i, (w, l) in enumerate(zip(arch.widths, arch.layers)):
|
for i, (w, l) in enumerate(zip(arch.widths, arch.layers)):
|
||||||
layer, inplanes = self._make_layer(
|
layer, inplanes = self._make_layer(
|
||||||
arch.block,
|
arch.block,
|
||||||
|
@ -275,9 +291,11 @@ class ResNet(nn.Module):
|
||||||
cardinality=arch.cardinality,
|
cardinality=arch.cardinality,
|
||||||
stride=1 if i == 0 else 2,
|
stride=1 if i == 0 else 2,
|
||||||
trt=trt,
|
trt=trt,
|
||||||
|
fused_se=fused_se,
|
||||||
)
|
)
|
||||||
setattr(self, f"layer{i+1}", layer)
|
layers.append(layer)
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||||
self.fc = nn.Linear(arch.widths[-1] * arch.expansion, num_classes)
|
self.fc = nn.Linear(arch.widths[-1] * arch.expansion, num_classes)
|
||||||
|
|
||||||
|
@ -297,13 +315,8 @@ class ResNet(nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
|
x = self.layers(x)
|
||||||
for i in range(self.num_layers):
|
|
||||||
fn = getattr(self, f"layer{i+1}")
|
|
||||||
x = fn(x)
|
|
||||||
|
|
||||||
x = self.classifier(x)
|
x = self.classifier(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def extract_features(self, x, layers=None):
|
def extract_features(self, x, layers=None):
|
||||||
|
@ -311,7 +324,7 @@ class ResNet(nn.Module):
|
||||||
layers = [f"layer{i+1}" for i in range(self.num_layers)] + ["classifier"]
|
layers = [f"layer{i+1}" for i in range(self.num_layers)] + ["classifier"]
|
||||||
|
|
||||||
run = [
|
run = [
|
||||||
f"layer{i+1}"
|
i
|
||||||
for i in range(self.num_layers)
|
for i in range(self.num_layers)
|
||||||
if "classifier" in layers
|
if "classifier" in layers
|
||||||
or any([f"layer{j+1}" in layers for j in range(i, self.num_layers)])
|
or any([f"layer{j+1}" in layers for j in range(i, self.num_layers)])
|
||||||
|
@ -320,10 +333,10 @@ class ResNet(nn.Module):
|
||||||
output = {}
|
output = {}
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
for l in run:
|
for l in run:
|
||||||
fn = getattr(self, l)
|
fn = self.layers[l]
|
||||||
x = fn(x)
|
x = fn(x)
|
||||||
if l in layers:
|
if f"layer{l+1}" in layers:
|
||||||
output[l] = x
|
output[f"layer{l+1}"] = x
|
||||||
|
|
||||||
if "classifier" in layers:
|
if "classifier" in layers:
|
||||||
output["classifier"] = self.classifier(x)
|
output["classifier"] = self.classifier(x)
|
||||||
|
@ -332,7 +345,16 @@ class ResNet(nn.Module):
|
||||||
|
|
||||||
# helper functions {{{
|
# helper functions {{{
|
||||||
def _make_layer(
|
def _make_layer(
|
||||||
self, block, expansion, inplanes, planes, blocks, stride=1, cardinality=1, trt=False,
|
self,
|
||||||
|
block,
|
||||||
|
expansion,
|
||||||
|
inplanes,
|
||||||
|
planes,
|
||||||
|
blocks,
|
||||||
|
stride=1,
|
||||||
|
cardinality=1,
|
||||||
|
trt=False,
|
||||||
|
fused_se=True,
|
||||||
):
|
):
|
||||||
downsample = None
|
downsample = None
|
||||||
if stride != 1 or inplanes != planes * expansion:
|
if stride != 1 or inplanes != planes * expansion:
|
||||||
|
@ -354,15 +376,33 @@ class ResNet(nn.Module):
|
||||||
stride=stride if i == 0 else 1,
|
stride=stride if i == 0 else 1,
|
||||||
cardinality=cardinality,
|
cardinality=cardinality,
|
||||||
downsample=downsample if i == 0 else None,
|
downsample=downsample if i == 0 else None,
|
||||||
fused_se=True,
|
fused_se=fused_se,
|
||||||
last_bn_0_init=self.last_bn_0_init,
|
last_bn_0_init=self.last_bn_0_init,
|
||||||
trt = trt,
|
trt=trt,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
inplanes = planes * expansion
|
inplanes = planes * expansion
|
||||||
|
|
||||||
return nn.Sequential(*layers), inplanes
|
return nn.Sequential(*layers), inplanes
|
||||||
|
|
||||||
|
def ngc_checkpoint_remap(self, url=None, version=None):
|
||||||
|
if version is None:
|
||||||
|
version = url.split("/")[8]
|
||||||
|
|
||||||
|
def to_sequential_remap(s):
|
||||||
|
splited = s.split(".")
|
||||||
|
if splited[0].startswith("layer"):
|
||||||
|
return ".".join(
|
||||||
|
["layers." + str(int(splited[0][len("layer") :]) - 1)] + splited[1:]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return s
|
||||||
|
|
||||||
|
def no_remap(s):
|
||||||
|
return s
|
||||||
|
|
||||||
|
return {"20.06.0": to_sequential_remap}.get(version, no_remap)
|
||||||
|
|
||||||
# }}}
|
# }}}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,39 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import optim
|
from torch import optim
|
||||||
|
|
||||||
def get_sgd_optimizer(parameters, lr, momentum, weight_decay, nesterov=False, bn_weight_decay=False):
|
|
||||||
|
def get_optimizer(parameters, lr, args, state=None):
|
||||||
|
if args.optimizer == "sgd":
|
||||||
|
optimizer = get_sgd_optimizer(
|
||||||
|
parameters,
|
||||||
|
lr,
|
||||||
|
momentum=args.momentum,
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
nesterov=args.nesterov,
|
||||||
|
bn_weight_decay=args.bn_weight_decay,
|
||||||
|
)
|
||||||
|
elif args.optimizer == "rmsprop":
|
||||||
|
optimizer = get_rmsprop_optimizer(
|
||||||
|
parameters,
|
||||||
|
lr,
|
||||||
|
alpha=args.rmsprop_alpha,
|
||||||
|
momentum=args.momentum,
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
eps=args.rmsprop_eps,
|
||||||
|
bn_weight_decay=args.bn_weight_decay,
|
||||||
|
)
|
||||||
|
if not state is None:
|
||||||
|
optimizer.load_state_dict(state)
|
||||||
|
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_sgd_optimizer(
|
||||||
|
parameters, lr, momentum, weight_decay, nesterov=False, bn_weight_decay=False
|
||||||
|
):
|
||||||
if bn_weight_decay:
|
if bn_weight_decay:
|
||||||
print(" ! Weight decay applied to BN parameters ")
|
print(" ! Weight decay applied to BN parameters ")
|
||||||
params = [v for n, v in parameters]
|
params = [v for n, v in parameters]
|
||||||
|
@ -17,20 +49,112 @@ def get_sgd_optimizer(parameters, lr, momentum, weight_decay, nesterov=False, bn
|
||||||
{"params": rest_params, "weight_decay": weight_decay},
|
{"params": rest_params, "weight_decay": weight_decay},
|
||||||
]
|
]
|
||||||
|
|
||||||
optimizer = torch.optim.SGD(params, lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
|
optimizer = torch.optim.SGD(
|
||||||
|
params, lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov
|
||||||
|
)
|
||||||
|
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
def get_rmsprop_optimizer(parameters, lr, alpha, weight_decay, momentum, eps, bn_weight_decay=False):
|
def get_rmsprop_optimizer(
|
||||||
|
parameters, lr, alpha, weight_decay, momentum, eps, bn_weight_decay=False
|
||||||
|
):
|
||||||
bn_params = [v for n, v in parameters if "bn" in n]
|
bn_params = [v for n, v in parameters if "bn" in n]
|
||||||
rest_params = [v for n, v in parameters if not "bn" in n]
|
rest_params = [v for n, v in parameters if not "bn" in n]
|
||||||
|
|
||||||
params = [
|
params = [
|
||||||
{"params": bn_params, "weight_decay": weight_decay if bn_weight_decay else 0},
|
{"params": bn_params, "weight_decay": weight_decay if bn_weight_decay else 0},
|
||||||
{"params": rest_params, "weight_decay": weight_decay},
|
{"params": rest_params, "weight_decay": weight_decay},
|
||||||
]
|
]
|
||||||
|
|
||||||
optimizer = torch.optim.RMSprop(params, lr=lr, alpha=alpha, weight_decay=weight_decay, momentum=momentum, eps=eps)
|
optimizer = torch.optim.RMSprop(
|
||||||
|
params,
|
||||||
|
lr=lr,
|
||||||
|
alpha=alpha,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
momentum=momentum,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def lr_policy(lr_fn):
|
||||||
|
def _alr(optimizer, iteration, epoch):
|
||||||
|
lr = lr_fn(iteration, epoch)
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group["lr"] = lr
|
||||||
|
|
||||||
|
return lr
|
||||||
|
|
||||||
|
return _alr
|
||||||
|
|
||||||
|
|
||||||
|
def lr_step_policy(base_lr, steps, decay_factor, warmup_length):
|
||||||
|
def _lr_fn(iteration, epoch):
|
||||||
|
if epoch < warmup_length:
|
||||||
|
lr = base_lr * (epoch + 1) / warmup_length
|
||||||
|
else:
|
||||||
|
lr = base_lr
|
||||||
|
for s in steps:
|
||||||
|
if epoch >= s:
|
||||||
|
lr *= decay_factor
|
||||||
|
return lr
|
||||||
|
|
||||||
|
return lr_policy(_lr_fn)
|
||||||
|
|
||||||
|
|
||||||
|
def lr_linear_policy(base_lr, warmup_length, epochs):
|
||||||
|
def _lr_fn(iteration, epoch):
|
||||||
|
if epoch < warmup_length:
|
||||||
|
lr = base_lr * (epoch + 1) / warmup_length
|
||||||
|
else:
|
||||||
|
e = epoch - warmup_length
|
||||||
|
es = epochs - warmup_length
|
||||||
|
lr = base_lr * (1 - (e / es))
|
||||||
|
return lr
|
||||||
|
|
||||||
|
return lr_policy(_lr_fn)
|
||||||
|
|
||||||
|
|
||||||
|
def lr_cosine_policy(base_lr, warmup_length, epochs, end_lr=0):
|
||||||
|
def _lr_fn(iteration, epoch):
|
||||||
|
if epoch < warmup_length:
|
||||||
|
lr = base_lr * (epoch + 1) / warmup_length
|
||||||
|
else:
|
||||||
|
e = epoch - warmup_length
|
||||||
|
es = epochs - warmup_length
|
||||||
|
lr = end_lr + (0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - end_lr))
|
||||||
|
return lr
|
||||||
|
|
||||||
|
return lr_policy(_lr_fn)
|
||||||
|
|
||||||
|
|
||||||
|
def lr_exponential_policy(
|
||||||
|
base_lr,
|
||||||
|
warmup_length,
|
||||||
|
epochs,
|
||||||
|
final_multiplier=0.001,
|
||||||
|
decay_factor=None,
|
||||||
|
decay_step=1,
|
||||||
|
logger=None,
|
||||||
|
):
|
||||||
|
"""Exponential lr policy. Setting decay factor parameter overrides final_multiplier"""
|
||||||
|
es = epochs - warmup_length
|
||||||
|
|
||||||
|
if decay_factor is not None:
|
||||||
|
epoch_decay = decay_factor
|
||||||
|
else:
|
||||||
|
epoch_decay = np.power(
|
||||||
|
2, np.log2(final_multiplier) / math.floor(es / decay_step)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _lr_fn(iteration, epoch):
|
||||||
|
if epoch < warmup_length:
|
||||||
|
lr = base_lr * (epoch + 1) / warmup_length
|
||||||
|
else:
|
||||||
|
e = epoch - warmup_length
|
||||||
|
lr = base_lr * (epoch_decay ** math.floor(e / decay_step))
|
||||||
|
return lr
|
||||||
|
|
||||||
|
return lr_policy(_lr_fn, logger=logger)
|
||||||
|
|
|
@ -27,279 +27,212 @@
|
||||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
import math
|
|
||||||
import time
|
import time
|
||||||
|
from copy import deepcopy
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.autograd import Variable
|
|
||||||
from . import logger as log
|
|
||||||
from . import models
|
|
||||||
from . import utils
|
|
||||||
import dllogger
|
|
||||||
|
|
||||||
from .optimizers import get_sgd_optimizer, get_rmsprop_optimizer
|
|
||||||
from .models.common import EMA
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
ACC_METADATA = {"unit": "%", "format": ":.2f"}
|
from . import logger as log
|
||||||
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
|
from . import utils
|
||||||
TIME_METADATA = {"unit": "s", "format": ":.5f"}
|
from .logger import TrainingMetrics, ValidationMetrics
|
||||||
LOSS_METADATA = {"format": ":.5f"}
|
from .models.common import EMA
|
||||||
|
|
||||||
|
|
||||||
class ModelAndLoss(nn.Module):
|
class Executor:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model: nn.Module,
|
||||||
loss,
|
loss: Optional[nn.Module],
|
||||||
cuda=True,
|
cuda: bool = True,
|
||||||
memory_format=torch.contiguous_format,
|
memory_format: torch.memory_format = torch.contiguous_format,
|
||||||
|
amp: bool = False,
|
||||||
|
scaler: Optional[torch.cuda.amp.GradScaler] = None,
|
||||||
|
divide_loss: int = 1,
|
||||||
|
ts_script: bool = False,
|
||||||
):
|
):
|
||||||
super(ModelAndLoss, self).__init__()
|
assert not (amp and scaler is None), "Gradient Scaler is needed for AMP"
|
||||||
|
|
||||||
if cuda:
|
def xform(m: nn.Module) -> nn.Module:
|
||||||
model = model.cuda().to(memory_format=memory_format)
|
if cuda:
|
||||||
|
m = m.cuda()
|
||||||
|
m.to(memory_format=memory_format)
|
||||||
|
return m
|
||||||
|
|
||||||
# define loss function (criterion) and optimizer
|
self.model = xform(model)
|
||||||
criterion = loss()
|
if ts_script:
|
||||||
|
self.model = torch.jit.script(self.model)
|
||||||
if cuda:
|
self.ts_script = ts_script
|
||||||
criterion = criterion.cuda()
|
self.loss = xform(loss) if loss is not None else None
|
||||||
|
self.amp = amp
|
||||||
self.model = model
|
self.scaler = scaler
|
||||||
self.loss = criterion
|
self.is_distributed = False
|
||||||
|
self.divide_loss = divide_loss
|
||||||
def forward(self, data, target):
|
self._fwd_bwd = None
|
||||||
output = self.model(data)
|
self._forward = None
|
||||||
loss = self.loss(output, target)
|
|
||||||
|
|
||||||
return loss, output
|
|
||||||
|
|
||||||
def distributed(self, gpu_id):
|
def distributed(self, gpu_id):
|
||||||
self.model = DDP(self.model, device_ids=[gpu_id], output_device=gpu_id)
|
self.is_distributed = True
|
||||||
|
s = torch.cuda.Stream()
|
||||||
|
s.wait_stream(torch.cuda.current_stream())
|
||||||
|
with torch.cuda.stream(s):
|
||||||
|
self.model = DDP(self.model, device_ids=[gpu_id], output_device=gpu_id)
|
||||||
|
torch.cuda.current_stream().wait_stream(s)
|
||||||
|
|
||||||
def load_model_state(self, state):
|
def _fwd_bwd_fn(
|
||||||
if not state is None:
|
self,
|
||||||
self.model.load_state_dict(state)
|
input: torch.Tensor,
|
||||||
|
target: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
with autocast(enabled=self.amp):
|
||||||
|
loss = self.loss(self.model(input), target)
|
||||||
|
loss /= self.divide_loss
|
||||||
|
|
||||||
|
self.scaler.scale(loss).backward()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def _forward_fn(
|
||||||
|
self, input: torch.Tensor, target: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
with torch.no_grad(), autocast(enabled=self.amp):
|
||||||
|
output = self.model(input)
|
||||||
|
loss = None if self.loss is None else self.loss(output, target)
|
||||||
|
|
||||||
|
return output if loss is None else loss, output
|
||||||
|
|
||||||
|
def optimize(self, fn):
|
||||||
|
return fn
|
||||||
|
|
||||||
|
@property
|
||||||
|
def forward_backward(self):
|
||||||
|
if self._fwd_bwd is None:
|
||||||
|
if self.loss is None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Loss must not be None for forward+backward step"
|
||||||
|
)
|
||||||
|
self._fwd_bwd = self.optimize(self._fwd_bwd_fn)
|
||||||
|
return self._fwd_bwd
|
||||||
|
|
||||||
|
@property
|
||||||
|
def forward(self):
|
||||||
|
if self._forward is None:
|
||||||
|
self._forward = self.optimize(self._forward_fn)
|
||||||
|
return self._forward
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(parameters, lr, args, state=None):
|
class Trainer:
|
||||||
if args.optimizer == 'sgd':
|
def __init__(
|
||||||
optimizer = get_sgd_optimizer(parameters, lr, momentum=args.momentum,
|
self,
|
||||||
weight_decay=args.weight_decay, nesterov=args.nesterov,
|
executor: Executor,
|
||||||
bn_weight_decay=args.bn_weight_decay)
|
optimizer: torch.optim.Optimizer,
|
||||||
elif args.optimizer == 'rmsprop':
|
grad_acc_steps: int,
|
||||||
optimizer = get_rmsprop_optimizer(parameters, lr, alpha=args.rmsprop_alpha, momentum=args.momentum,
|
ema: Optional[float] = None,
|
||||||
weight_decay=args.weight_decay,
|
):
|
||||||
eps=args.rmsprop_eps,
|
self.executor = executor
|
||||||
bn_weight_decay=args.bn_weight_decay)
|
self.optimizer = optimizer
|
||||||
if not state is None:
|
self.grad_acc_steps = grad_acc_steps
|
||||||
optimizer.load_state_dict(state)
|
self.use_ema = False
|
||||||
|
if ema is not None:
|
||||||
|
self.ema_executor = deepcopy(self.executor)
|
||||||
|
self.ema = EMA(ema, self.ema_executor.model)
|
||||||
|
self.use_ema = True
|
||||||
|
|
||||||
return optimizer
|
self.optimizer.zero_grad(set_to_none=True)
|
||||||
|
self.steps_since_update = 0
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
self.executor.model.train()
|
||||||
|
|
||||||
def lr_policy(lr_fn, logger=None):
|
def eval(self):
|
||||||
if logger is not None:
|
self.executor.model.eval()
|
||||||
logger.register_metric(
|
if self.use_ema:
|
||||||
"lr", log.LR_METER(), verbosity=dllogger.Verbosity.VERBOSE
|
self.executor.model.eval()
|
||||||
)
|
|
||||||
|
|
||||||
def _alr(optimizer, iteration, epoch):
|
def train_step(self, input, target, step=None):
|
||||||
lr = lr_fn(iteration, epoch)
|
loss = self.executor.forward_backward(input, target)
|
||||||
|
|
||||||
if logger is not None:
|
self.steps_since_update += 1
|
||||||
logger.log_metric("lr", lr)
|
|
||||||
for param_group in optimizer.param_groups:
|
|
||||||
param_group["lr"] = lr
|
|
||||||
|
|
||||||
return _alr
|
if self.steps_since_update == self.grad_acc_steps:
|
||||||
|
if self.executor.scaler is not None:
|
||||||
|
self.executor.scaler.step(self.optimizer)
|
||||||
def lr_step_policy(base_lr, steps, decay_factor, warmup_length, logger=None):
|
self.executor.scaler.update()
|
||||||
def _lr_fn(iteration, epoch):
|
|
||||||
if epoch < warmup_length:
|
|
||||||
lr = base_lr * (epoch + 1) / warmup_length
|
|
||||||
else:
|
|
||||||
lr = base_lr
|
|
||||||
for s in steps:
|
|
||||||
if epoch >= s:
|
|
||||||
lr *= decay_factor
|
|
||||||
return lr
|
|
||||||
|
|
||||||
return lr_policy(_lr_fn, logger=logger)
|
|
||||||
|
|
||||||
|
|
||||||
def lr_linear_policy(base_lr, warmup_length, epochs, logger=None):
|
|
||||||
def _lr_fn(iteration, epoch):
|
|
||||||
if epoch < warmup_length:
|
|
||||||
lr = base_lr * (epoch + 1) / warmup_length
|
|
||||||
else:
|
|
||||||
e = epoch - warmup_length
|
|
||||||
es = epochs - warmup_length
|
|
||||||
lr = base_lr * (1 - (e / es))
|
|
||||||
return lr
|
|
||||||
|
|
||||||
return lr_policy(_lr_fn, logger=logger)
|
|
||||||
|
|
||||||
|
|
||||||
def lr_cosine_policy(base_lr, warmup_length, epochs, end_lr = 0, logger=None):
|
|
||||||
def _lr_fn(iteration, epoch):
|
|
||||||
if epoch < warmup_length:
|
|
||||||
lr = base_lr * (epoch + 1) / warmup_length
|
|
||||||
else:
|
|
||||||
e = epoch - warmup_length
|
|
||||||
es = epochs - warmup_length
|
|
||||||
lr = end_lr + (0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - end_lr))
|
|
||||||
return lr
|
|
||||||
|
|
||||||
return lr_policy(_lr_fn, logger=logger)
|
|
||||||
|
|
||||||
|
|
||||||
def lr_exponential_policy(
|
|
||||||
base_lr, warmup_length, epochs, final_multiplier=0.001, decay_factor=None, decay_step=1, logger=None
|
|
||||||
):
|
|
||||||
"""Exponential lr policy. Setting decay factor parameter overrides final_multiplier"""
|
|
||||||
es = epochs - warmup_length
|
|
||||||
|
|
||||||
if decay_factor is not None:
|
|
||||||
epoch_decay = decay_factor
|
|
||||||
else:
|
|
||||||
epoch_decay = np.power(2, np.log2(final_multiplier) / math.floor(es/decay_step))
|
|
||||||
|
|
||||||
def _lr_fn(iteration, epoch):
|
|
||||||
if epoch < warmup_length:
|
|
||||||
lr = base_lr * (epoch + 1) / warmup_length
|
|
||||||
else:
|
|
||||||
e = epoch - warmup_length
|
|
||||||
lr = base_lr * (epoch_decay ** math.floor(e/decay_step))
|
|
||||||
return lr
|
|
||||||
|
|
||||||
return lr_policy(_lr_fn, logger=logger)
|
|
||||||
|
|
||||||
|
|
||||||
def get_train_step(
|
|
||||||
model_and_loss, optimizer, scaler, use_amp=False, batch_size_multiplier=1
|
|
||||||
):
|
|
||||||
def _step(input, target, optimizer_step=True):
|
|
||||||
input_var = Variable(input)
|
|
||||||
target_var = Variable(target)
|
|
||||||
|
|
||||||
with autocast(enabled=use_amp):
|
|
||||||
loss, output = model_and_loss(input_var, target_var)
|
|
||||||
loss /= batch_size_multiplier
|
|
||||||
if torch.distributed.is_initialized():
|
|
||||||
reduced_loss = utils.reduce_tensor(loss.data)
|
|
||||||
else:
|
else:
|
||||||
reduced_loss = loss.data
|
self.optimizer.step()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
scaler.scale(loss).backward()
|
self.steps_since_update = 0
|
||||||
|
|
||||||
if optimizer_step:
|
|
||||||
scaler.step(optimizer)
|
|
||||||
scaler.update()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
return reduced_loss
|
if self.use_ema:
|
||||||
|
self.ema(self.executor.model, step=step)
|
||||||
|
|
||||||
return _step
|
return loss
|
||||||
|
|
||||||
|
def validation_steps(self) -> Dict[str, Callable]:
|
||||||
|
vsd: Dict[str, Callable] = {"val": self.executor.forward}
|
||||||
|
if self.use_ema:
|
||||||
|
vsd["val_ema"] = self.ema_executor.forward
|
||||||
|
return vsd
|
||||||
|
|
||||||
|
def state_dict(self) -> dict:
|
||||||
|
res = {
|
||||||
|
"state_dict": self.executor.model.state_dict(),
|
||||||
|
"optimizer": self.optimizer.state_dict(),
|
||||||
|
}
|
||||||
|
if self.use_ema:
|
||||||
|
res["state_dict_ema"] = self.ema_executor.model.state_dict()
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
|
train_step,
|
||||||
train_loader,
|
train_loader,
|
||||||
model_and_loss,
|
|
||||||
optimizer,
|
|
||||||
scaler,
|
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
logger,
|
log_fn,
|
||||||
epoch,
|
|
||||||
steps_per_epoch,
|
|
||||||
timeout_handler,
|
timeout_handler,
|
||||||
ema=None,
|
|
||||||
use_amp=False,
|
|
||||||
prof=-1,
|
prof=-1,
|
||||||
batch_size_multiplier=1,
|
step=0,
|
||||||
register_metrics=True,
|
|
||||||
):
|
):
|
||||||
interrupted = False
|
interrupted = False
|
||||||
if register_metrics and logger is not None:
|
|
||||||
logger.register_metric(
|
|
||||||
"train.loss",
|
|
||||||
log.LOSS_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.DEFAULT,
|
|
||||||
metadata=LOSS_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
"train.compute_ips",
|
|
||||||
log.PERF_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.VERBOSE,
|
|
||||||
metadata=IPS_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
"train.total_ips",
|
|
||||||
log.PERF_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.DEFAULT,
|
|
||||||
metadata=IPS_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
"train.data_time",
|
|
||||||
log.PERF_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.VERBOSE,
|
|
||||||
metadata=TIME_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
"train.compute_time",
|
|
||||||
log.PERF_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.VERBOSE,
|
|
||||||
metadata=TIME_METADATA,
|
|
||||||
)
|
|
||||||
|
|
||||||
step = get_train_step(
|
|
||||||
model_and_loss,
|
|
||||||
optimizer,
|
|
||||||
scaler=scaler,
|
|
||||||
use_amp=use_amp,
|
|
||||||
batch_size_multiplier=batch_size_multiplier,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_and_loss.train()
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
data_iter = enumerate(train_loader)
|
data_iter = enumerate(train_loader)
|
||||||
if logger is not None:
|
|
||||||
data_iter = logger.iteration_generator_wrapper(data_iter, mode='train')
|
|
||||||
|
|
||||||
for i, (input, target) in data_iter:
|
for i, (input, target) in data_iter:
|
||||||
bs = input.size(0)
|
bs = input.size(0)
|
||||||
lr_scheduler(optimizer, i, epoch)
|
lr = lr_scheduler(i)
|
||||||
data_time = time.time() - end
|
data_time = time.time() - end
|
||||||
|
|
||||||
optimizer_step = ((i + 1) % batch_size_multiplier) == 0
|
loss = train_step(input, target, step=step + i)
|
||||||
loss = step(input, target, optimizer_step=optimizer_step)
|
|
||||||
if ema is not None:
|
|
||||||
ema(model_and_loss, epoch*steps_per_epoch+i)
|
|
||||||
|
|
||||||
it_time = time.time() - end
|
it_time = time.time() - end
|
||||||
|
|
||||||
if logger is not None:
|
with torch.no_grad():
|
||||||
logger.log_metric("train.loss", loss.item(), bs)
|
if torch.distributed.is_initialized():
|
||||||
logger.log_metric("train.compute_ips", utils.calc_ips(bs, it_time - data_time))
|
reduced_loss = utils.reduce_tensor(loss.detach())
|
||||||
logger.log_metric("train.total_ips", utils.calc_ips(bs, it_time))
|
else:
|
||||||
logger.log_metric("train.data_time", data_time)
|
reduced_loss = loss.detach()
|
||||||
logger.log_metric("train.compute_time", it_time - data_time)
|
|
||||||
|
log_fn(
|
||||||
|
compute_ips=utils.calc_ips(bs, it_time - data_time),
|
||||||
|
total_ips=utils.calc_ips(bs, it_time),
|
||||||
|
data_time=data_time,
|
||||||
|
compute_time=it_time - data_time,
|
||||||
|
lr=lr,
|
||||||
|
loss=reduced_loss.item(),
|
||||||
|
)
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
if prof > 0 and (i + 1 >= prof):
|
if prof > 0 and (i + 1 >= prof):
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
break
|
break
|
||||||
if ((i+1) % 20 == 0) and timeout_handler.interrupted:
|
if ((i + 1) % 20 == 0) and timeout_handler.interrupted:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
interrupted = True
|
interrupted = True
|
||||||
break
|
break
|
||||||
|
@ -307,134 +240,58 @@ def train(
|
||||||
return interrupted
|
return interrupted
|
||||||
|
|
||||||
|
|
||||||
def get_val_step(model_and_loss, use_amp=False):
|
def validate(infer_fn, val_loader, log_fn, prof=-1, with_loss=True):
|
||||||
def _step(input, target):
|
|
||||||
input_var = Variable(input)
|
|
||||||
target_var = Variable(target)
|
|
||||||
|
|
||||||
with torch.no_grad(), autocast(enabled=use_amp):
|
|
||||||
loss, output = model_and_loss(input_var, target_var)
|
|
||||||
|
|
||||||
prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5))
|
|
||||||
|
|
||||||
if torch.distributed.is_initialized():
|
|
||||||
reduced_loss = utils.reduce_tensor(loss.data)
|
|
||||||
prec1 = utils.reduce_tensor(prec1)
|
|
||||||
prec5 = utils.reduce_tensor(prec5)
|
|
||||||
else:
|
|
||||||
reduced_loss = loss.data
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
return reduced_loss, prec1, prec5
|
|
||||||
|
|
||||||
return _step
|
|
||||||
|
|
||||||
|
|
||||||
def validate(
|
|
||||||
val_loader,
|
|
||||||
model_and_loss,
|
|
||||||
logger,
|
|
||||||
epoch,
|
|
||||||
use_amp=False,
|
|
||||||
prof=-1,
|
|
||||||
register_metrics=True,
|
|
||||||
prefix="val",
|
|
||||||
):
|
|
||||||
if register_metrics and logger is not None:
|
|
||||||
logger.register_metric(
|
|
||||||
f"{prefix}.top1",
|
|
||||||
log.ACC_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.DEFAULT,
|
|
||||||
metadata=ACC_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
f"{prefix}.top5",
|
|
||||||
log.ACC_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.DEFAULT,
|
|
||||||
metadata=ACC_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
f"{prefix}.loss",
|
|
||||||
log.LOSS_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.DEFAULT,
|
|
||||||
metadata=LOSS_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
f"{prefix}.compute_ips",
|
|
||||||
log.PERF_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.VERBOSE,
|
|
||||||
metadata=IPS_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
f"{prefix}.total_ips",
|
|
||||||
log.PERF_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.DEFAULT,
|
|
||||||
metadata=IPS_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
f"{prefix}.data_time",
|
|
||||||
log.PERF_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.VERBOSE,
|
|
||||||
metadata=TIME_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
f"{prefix}.compute_latency",
|
|
||||||
log.PERF_METER(),
|
|
||||||
verbosity=dllogger.Verbosity.VERBOSE,
|
|
||||||
metadata=TIME_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
f"{prefix}.compute_latency_at100",
|
|
||||||
log.LAT_100(),
|
|
||||||
verbosity=dllogger.Verbosity.VERBOSE,
|
|
||||||
metadata=TIME_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
f"{prefix}.compute_latency_at99",
|
|
||||||
log.LAT_99(),
|
|
||||||
verbosity=dllogger.Verbosity.VERBOSE,
|
|
||||||
metadata=TIME_METADATA,
|
|
||||||
)
|
|
||||||
logger.register_metric(
|
|
||||||
f"{prefix}.compute_latency_at95",
|
|
||||||
log.LAT_95(),
|
|
||||||
verbosity=dllogger.Verbosity.VERBOSE,
|
|
||||||
metadata=TIME_METADATA,
|
|
||||||
)
|
|
||||||
|
|
||||||
step = get_val_step(model_and_loss, use_amp=use_amp)
|
|
||||||
|
|
||||||
top1 = log.AverageMeter()
|
top1 = log.AverageMeter()
|
||||||
# switch to evaluate mode
|
# switch to evaluate mode
|
||||||
model_and_loss.eval()
|
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
|
|
||||||
data_iter = enumerate(val_loader)
|
data_iter = enumerate(val_loader)
|
||||||
if not logger is None:
|
|
||||||
data_iter = logger.iteration_generator_wrapper(data_iter, mode='val')
|
|
||||||
|
|
||||||
for i, (input, target) in data_iter:
|
for i, (input, target) in data_iter:
|
||||||
bs = input.size(0)
|
bs = input.size(0)
|
||||||
data_time = time.time() - end
|
data_time = time.time() - end
|
||||||
|
|
||||||
loss, prec1, prec5 = step(input, target)
|
if with_loss:
|
||||||
|
loss, output = infer_fn(input, target)
|
||||||
|
else:
|
||||||
|
output = infer_fn(input)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5))
|
||||||
|
|
||||||
|
if torch.distributed.is_initialized():
|
||||||
|
if with_loss:
|
||||||
|
reduced_loss = utils.reduce_tensor(loss.detach())
|
||||||
|
prec1 = utils.reduce_tensor(prec1)
|
||||||
|
prec5 = utils.reduce_tensor(prec5)
|
||||||
|
else:
|
||||||
|
if with_loss:
|
||||||
|
reduced_loss = loss.detach()
|
||||||
|
|
||||||
|
prec1 = prec1.item()
|
||||||
|
prec5 = prec5.item()
|
||||||
|
infer_result = {
|
||||||
|
"top1": (prec1, bs),
|
||||||
|
"top5": (prec5, bs),
|
||||||
|
}
|
||||||
|
|
||||||
|
if with_loss:
|
||||||
|
infer_result["loss"] = (reduced_loss.item(), bs)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
it_time = time.time() - end
|
it_time = time.time() - end
|
||||||
|
|
||||||
top1.record(prec1.item(), bs)
|
top1.record(prec1, bs)
|
||||||
if logger is not None:
|
|
||||||
logger.log_metric(f"{prefix}.top1", prec1.item(), bs)
|
log_fn(
|
||||||
logger.log_metric(f"{prefix}.top5", prec5.item(), bs)
|
compute_ips=utils.calc_ips(bs, it_time - data_time),
|
||||||
logger.log_metric(f"{prefix}.loss", loss.item(), bs)
|
total_ips=utils.calc_ips(bs, it_time),
|
||||||
logger.log_metric(f"{prefix}.compute_ips", utils.calc_ips(bs, it_time - data_time))
|
data_time=data_time,
|
||||||
logger.log_metric(f"{prefix}.total_ips", utils.calc_ips(bs, it_time))
|
compute_time=it_time - data_time,
|
||||||
logger.log_metric(f"{prefix}.data_time", data_time)
|
**infer_result,
|
||||||
logger.log_metric(f"{prefix}.compute_latency", it_time - data_time)
|
)
|
||||||
logger.log_metric(f"{prefix}.compute_latency_at95", it_time - data_time)
|
|
||||||
logger.log_metric(f"{prefix}.compute_latency_at99", it_time - data_time)
|
|
||||||
logger.log_metric(f"{prefix}.compute_latency_at100", it_time - data_time)
|
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
if (prof > 0) and (i + 1 >= prof):
|
if (prof > 0) and (i + 1 >= prof):
|
||||||
|
@ -445,22 +302,14 @@ def validate(
|
||||||
|
|
||||||
|
|
||||||
# Train loop {{{
|
# Train loop {{{
|
||||||
|
|
||||||
|
|
||||||
def train_loop(
|
def train_loop(
|
||||||
model_and_loss,
|
trainer: Trainer,
|
||||||
optimizer,
|
|
||||||
scaler,
|
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
train_loader,
|
train_loader,
|
||||||
|
train_loader_len,
|
||||||
val_loader,
|
val_loader,
|
||||||
logger,
|
logger,
|
||||||
should_backup_checkpoint,
|
should_backup_checkpoint,
|
||||||
steps_per_epoch,
|
|
||||||
ema=None,
|
|
||||||
model_ema=None,
|
|
||||||
use_amp=False,
|
|
||||||
batch_size_multiplier=1,
|
|
||||||
best_prec1=0,
|
best_prec1=0,
|
||||||
start_epoch=0,
|
start_epoch=0,
|
||||||
end_epoch=0,
|
end_epoch=0,
|
||||||
|
@ -472,14 +321,22 @@ def train_loop(
|
||||||
checkpoint_dir="./",
|
checkpoint_dir="./",
|
||||||
checkpoint_filename="checkpoint.pth.tar",
|
checkpoint_filename="checkpoint.pth.tar",
|
||||||
):
|
):
|
||||||
|
train_metrics = TrainingMetrics(logger)
|
||||||
|
val_metrics = {
|
||||||
|
k: ValidationMetrics(logger, k) for k in trainer.validation_steps().keys()
|
||||||
|
}
|
||||||
|
training_step = trainer.train_step
|
||||||
|
|
||||||
prec1 = -1
|
prec1 = -1
|
||||||
use_ema = (model_ema is not None) and (ema is not None)
|
|
||||||
|
|
||||||
if early_stopping_patience > 0:
|
if early_stopping_patience > 0:
|
||||||
epochs_since_improvement = 0
|
epochs_since_improvement = 0
|
||||||
backup_prefix = checkpoint_filename[:-len("checkpoint.pth.tar")] if \
|
backup_prefix = (
|
||||||
checkpoint_filename.endswith("checkpoint.pth.tar") else ""
|
checkpoint_filename[: -len("checkpoint.pth.tar")]
|
||||||
|
if checkpoint_filename.endswith("checkpoint.pth.tar")
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
print(f"RUNNING EPOCHS FROM {start_epoch} TO {end_epoch}")
|
print(f"RUNNING EPOCHS FROM {start_epoch} TO {end_epoch}")
|
||||||
with utils.TimeoutHandler() as timeout_handler:
|
with utils.TimeoutHandler() as timeout_handler:
|
||||||
interrupted = False
|
interrupted = False
|
||||||
|
@ -487,73 +344,71 @@ def train_loop(
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.start_epoch()
|
logger.start_epoch()
|
||||||
if not skip_training:
|
if not skip_training:
|
||||||
|
if logger is not None:
|
||||||
|
data_iter = logger.iteration_generator_wrapper(
|
||||||
|
train_loader, mode="train"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data_iter = train_loader
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
interrupted = train(
|
interrupted = train(
|
||||||
train_loader,
|
training_step,
|
||||||
model_and_loss,
|
data_iter,
|
||||||
optimizer,
|
lambda i: lr_scheduler(trainer.optimizer, i, epoch),
|
||||||
scaler,
|
train_metrics.log,
|
||||||
lr_scheduler,
|
|
||||||
logger,
|
|
||||||
epoch,
|
|
||||||
steps_per_epoch,
|
|
||||||
timeout_handler,
|
timeout_handler,
|
||||||
ema=ema,
|
|
||||||
use_amp=use_amp,
|
|
||||||
prof=prof,
|
prof=prof,
|
||||||
register_metrics=epoch == start_epoch,
|
step=epoch * train_loader_len,
|
||||||
batch_size_multiplier=batch_size_multiplier,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not skip_validation:
|
if not skip_validation:
|
||||||
prec1, nimg = validate(
|
trainer.eval()
|
||||||
val_loader,
|
for k, infer_fn in trainer.validation_steps().items():
|
||||||
model_and_loss,
|
if logger is not None:
|
||||||
logger,
|
data_iter = logger.iteration_generator_wrapper(
|
||||||
epoch,
|
val_loader, mode="val"
|
||||||
use_amp=use_amp,
|
)
|
||||||
prof=prof,
|
else:
|
||||||
register_metrics=epoch == start_epoch,
|
data_iter = val_loader
|
||||||
)
|
|
||||||
if use_ema:
|
step_prec1, _ = validate(
|
||||||
model_ema.load_state_dict({k.replace('module.', ''): v for k, v in ema.state_dict().items()})
|
infer_fn,
|
||||||
prec1, nimg = validate(
|
data_iter,
|
||||||
val_loader,
|
val_metrics[k].log,
|
||||||
model_ema,
|
|
||||||
logger,
|
|
||||||
epoch,
|
|
||||||
prof=prof,
|
prof=prof,
|
||||||
register_metrics=epoch == start_epoch,
|
|
||||||
prefix='val_ema'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if k == "val":
|
||||||
|
prec1 = step_prec1
|
||||||
|
|
||||||
if prec1 > best_prec1:
|
if prec1 > best_prec1:
|
||||||
is_best = True
|
is_best = True
|
||||||
best_prec1 = prec1
|
best_prec1 = prec1
|
||||||
else:
|
else:
|
||||||
is_best = False
|
is_best = False
|
||||||
else:
|
else:
|
||||||
is_best = True
|
is_best = False
|
||||||
best_prec1 = 0
|
best_prec1 = 0
|
||||||
|
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.end_epoch()
|
logger.end_epoch()
|
||||||
|
|
||||||
if save_checkpoints and (
|
if save_checkpoints and (
|
||||||
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
not torch.distributed.is_initialized()
|
||||||
|
or torch.distributed.get_rank() == 0
|
||||||
):
|
):
|
||||||
if should_backup_checkpoint(epoch):
|
if should_backup_checkpoint(epoch):
|
||||||
backup_filename = "{}checkpoint-{}.pth.tar".format(backup_prefix, epoch + 1)
|
backup_filename = "{}checkpoint-{}.pth.tar".format(
|
||||||
|
backup_prefix, epoch + 1
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
backup_filename = None
|
backup_filename = None
|
||||||
checkpoint_state = {
|
checkpoint_state = {
|
||||||
"epoch": epoch + 1,
|
"epoch": epoch + 1,
|
||||||
"state_dict": model_and_loss.model.state_dict(),
|
|
||||||
"best_prec1": best_prec1,
|
"best_prec1": best_prec1,
|
||||||
"optimizer": optimizer.state_dict(),
|
**trainer.state_dict(),
|
||||||
}
|
}
|
||||||
if use_ema:
|
|
||||||
checkpoint_state["state_dict_ema"] = ema.state_dict()
|
|
||||||
|
|
||||||
utils.save_checkpoint(
|
utils.save_checkpoint(
|
||||||
checkpoint_state,
|
checkpoint_state,
|
||||||
is_best,
|
is_best,
|
||||||
|
@ -561,6 +416,7 @@ def train_loop(
|
||||||
backup_filename=backup_filename,
|
backup_filename=backup_filename,
|
||||||
filename=checkpoint_filename,
|
filename=checkpoint_filename,
|
||||||
)
|
)
|
||||||
|
|
||||||
if early_stopping_patience > 0:
|
if early_stopping_patience > 0:
|
||||||
if not is_best:
|
if not is_best:
|
||||||
epochs_since_improvement += 1
|
epochs_since_improvement += 1
|
||||||
|
@ -570,4 +426,6 @@ def train_loop(
|
||||||
break
|
break
|
||||||
if interrupted:
|
if interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
# }}}
|
# }}}
|
||||||
|
|
|
@ -97,7 +97,7 @@ def accuracy(output, target, topk=(1,)):
|
||||||
|
|
||||||
|
|
||||||
def reduce_tensor(tensor):
|
def reduce_tensor(tensor):
|
||||||
rt = tensor.clone()
|
rt = tensor.clone().detach()
|
||||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
||||||
rt /= (
|
rt /= (
|
||||||
torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
|
torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
|
||||||
|
@ -114,6 +114,7 @@ class TimeoutHandler:
|
||||||
def __init__(self, sig=signal.SIGTERM):
|
def __init__(self, sig=signal.SIGTERM):
|
||||||
self.sig = sig
|
self.sig = sig
|
||||||
self.device = torch.device("cuda")
|
self.device = torch.device("cuda")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def interrupted(self):
|
def interrupted(self):
|
||||||
if not dist.is_initialized():
|
if not dist.is_initialized():
|
||||||
|
|
|
@ -58,6 +58,12 @@ from image_classification.models import (
|
||||||
efficientnet_widese_b0,
|
efficientnet_widese_b0,
|
||||||
efficientnet_widese_b4,
|
efficientnet_widese_b4,
|
||||||
)
|
)
|
||||||
|
from image_classification.optimizers import (
|
||||||
|
get_optimizer,
|
||||||
|
lr_cosine_policy,
|
||||||
|
lr_linear_policy,
|
||||||
|
lr_step_policy,
|
||||||
|
)
|
||||||
import dllogger
|
import dllogger
|
||||||
|
|
||||||
|
|
||||||
|
@ -102,7 +108,9 @@ def add_parser_arguments(parser, skip_arch=False):
|
||||||
metavar="ARCH",
|
metavar="ARCH",
|
||||||
default="resnet50",
|
default="resnet50",
|
||||||
choices=model_names,
|
choices=model_names,
|
||||||
help="model architecture: " + " | ".join(model_names) + " (default: resnet50)",
|
help="model architecture: "
|
||||||
|
+ " | ".join(model_names)
|
||||||
|
+ " (default: resnet50)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -290,6 +298,13 @@ def add_parser_arguments(parser, skip_arch=False):
|
||||||
dest="save_checkpoints",
|
dest="save_checkpoints",
|
||||||
help="do not store any checkpoints, useful for benchmarking",
|
help="do not store any checkpoints, useful for benchmarking",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit",
|
||||||
|
type=str,
|
||||||
|
default = "no",
|
||||||
|
choices=["no", "script"],
|
||||||
|
help="no -> do not use torch.jit; script -> use torch.jit.script"
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--checkpoint-filename", default="checkpoint.pth.tar", type=str)
|
parser.add_argument("--checkpoint-filename", default="checkpoint.pth.tar", type=str)
|
||||||
|
|
||||||
|
@ -320,7 +335,7 @@ def add_parser_arguments(parser, skip_arch=False):
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
help="number of classes"
|
help="number of classes",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -432,13 +447,25 @@ def prepare_for_training(args, model_args, model_arch):
|
||||||
if args.image_size is not None
|
if args.image_size is not None
|
||||||
else model.arch.default_image_size
|
else model.arch.default_image_size
|
||||||
)
|
)
|
||||||
model_and_loss = ModelAndLoss(model, loss, cuda=True, memory_format=memory_format)
|
|
||||||
if args.use_ema is not None:
|
scaler = torch.cuda.amp.GradScaler(
|
||||||
model_ema = deepcopy(model_and_loss)
|
init_scale=args.static_loss_scale,
|
||||||
ema = EMA(args.use_ema)
|
growth_factor=2,
|
||||||
else:
|
backoff_factor=0.5,
|
||||||
model_ema = None
|
growth_interval=100 if args.dynamic_loss_scale else 1000000000,
|
||||||
ema = None
|
enabled=args.amp,
|
||||||
|
)
|
||||||
|
|
||||||
|
executor = Executor(
|
||||||
|
model,
|
||||||
|
loss(),
|
||||||
|
cuda=True,
|
||||||
|
memory_format=memory_format,
|
||||||
|
amp=args.amp,
|
||||||
|
scaler=scaler,
|
||||||
|
divide_loss=batch_size_multiplier,
|
||||||
|
ts_script = args.jit == "script",
|
||||||
|
)
|
||||||
|
|
||||||
# Create data loaders and optimizers as needed
|
# Create data loaders and optimizers as needed
|
||||||
if args.data_backend == "pytorch":
|
if args.data_backend == "pytorch":
|
||||||
|
@ -463,7 +490,7 @@ def prepare_for_training(args, model_args, model_arch):
|
||||||
args.batch_size,
|
args.batch_size,
|
||||||
model_args.num_classes,
|
model_args.num_classes,
|
||||||
args.mixup > 0.0,
|
args.mixup > 0.0,
|
||||||
interpolation = args.interpolation,
|
interpolation=args.interpolation,
|
||||||
augmentation=args.augmentation,
|
augmentation=args.augmentation,
|
||||||
start_epoch=start_epoch,
|
start_epoch=start_epoch,
|
||||||
workers=args.workers,
|
workers=args.workers,
|
||||||
|
@ -478,7 +505,7 @@ def prepare_for_training(args, model_args, model_arch):
|
||||||
args.batch_size,
|
args.batch_size,
|
||||||
model_args.num_classes,
|
model_args.num_classes,
|
||||||
False,
|
False,
|
||||||
interpolation = args.interpolation,
|
interpolation=args.interpolation,
|
||||||
workers=args.workers,
|
workers=args.workers,
|
||||||
memory_format=memory_format,
|
memory_format=memory_format,
|
||||||
)
|
)
|
||||||
|
@ -508,41 +535,38 @@ def prepare_for_training(args, model_args, model_arch):
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer = get_optimizer(
|
optimizer = get_optimizer(
|
||||||
list(model_and_loss.model.named_parameters()),
|
list(executor.model.named_parameters()),
|
||||||
args.lr,
|
args.lr,
|
||||||
args=args,
|
args=args,
|
||||||
state=optimizer_state,
|
state=optimizer_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.lr_schedule == "step":
|
if args.lr_schedule == "step":
|
||||||
lr_policy = lr_step_policy(
|
lr_policy = lr_step_policy(args.lr, [30, 60, 80], 0.1, args.warmup)
|
||||||
args.lr, [30, 60, 80], 0.1, args.warmup, logger=logger
|
|
||||||
)
|
|
||||||
elif args.lr_schedule == "cosine":
|
elif args.lr_schedule == "cosine":
|
||||||
lr_policy = lr_cosine_policy(
|
lr_policy = lr_cosine_policy(
|
||||||
args.lr, args.warmup, args.epochs, end_lr=args.end_lr, logger=logger
|
args.lr, args.warmup, args.epochs, end_lr=args.end_lr
|
||||||
)
|
)
|
||||||
elif args.lr_schedule == "linear":
|
elif args.lr_schedule == "linear":
|
||||||
lr_policy = lr_linear_policy(args.lr, args.warmup, args.epochs, logger=logger)
|
lr_policy = lr_linear_policy(args.lr, args.warmup, args.epochs)
|
||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler(
|
|
||||||
init_scale=args.static_loss_scale,
|
|
||||||
growth_factor=2,
|
|
||||||
backoff_factor=0.5,
|
|
||||||
growth_interval=100 if args.dynamic_loss_scale else 1000000000,
|
|
||||||
enabled=args.amp,
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
model_and_loss.distributed(args.gpu)
|
executor.distributed(args.gpu)
|
||||||
|
|
||||||
model_and_loss.load_model_state(model_state)
|
if model_state is not None:
|
||||||
if (ema is not None) and (model_state_ema is not None):
|
executor.model.load_state_dict(model_state)
|
||||||
print("load ema")
|
|
||||||
ema.load_state_dict(model_state_ema)
|
|
||||||
|
|
||||||
return (model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema,
|
trainer = Trainer(
|
||||||
train_loader_len, batch_size_multiplier, start_epoch)
|
executor,
|
||||||
|
optimizer,
|
||||||
|
grad_acc_steps=batch_size_multiplier,
|
||||||
|
ema=args.use_ema,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (args.use_ema is not None) and (model_state_ema is not None):
|
||||||
|
trainer.ema_executor.model.load_state_dict(model_state_ema)
|
||||||
|
|
||||||
|
return (trainer, lr_policy, train_loader, train_loader_len, val_loader, logger, start_epoch)
|
||||||
|
|
||||||
|
|
||||||
def main(args, model_args, model_arch):
|
def main(args, model_args, model_arch):
|
||||||
|
@ -550,23 +574,24 @@ def main(args, model_args, model_arch):
|
||||||
global best_prec1
|
global best_prec1
|
||||||
best_prec1 = 0
|
best_prec1 = 0
|
||||||
|
|
||||||
model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema, train_loader_len, \
|
(
|
||||||
batch_size_multiplier, start_epoch = prepare_for_training(args, model_args, model_arch)
|
trainer,
|
||||||
|
|
||||||
train_loop(
|
|
||||||
model_and_loss,
|
|
||||||
optimizer,
|
|
||||||
scaler,
|
|
||||||
lr_policy,
|
lr_policy,
|
||||||
train_loader,
|
train_loader,
|
||||||
|
train_loader_len,
|
||||||
|
val_loader,
|
||||||
|
logger,
|
||||||
|
start_epoch,
|
||||||
|
) = prepare_for_training(args, model_args, model_arch)
|
||||||
|
|
||||||
|
train_loop(
|
||||||
|
trainer,
|
||||||
|
lr_policy,
|
||||||
|
train_loader,
|
||||||
|
train_loader_len,
|
||||||
val_loader,
|
val_loader,
|
||||||
logger,
|
logger,
|
||||||
should_backup_checkpoint(args),
|
should_backup_checkpoint(args),
|
||||||
ema=ema,
|
|
||||||
model_ema=model_ema,
|
|
||||||
steps_per_epoch=train_loader_len,
|
|
||||||
use_amp=args.amp,
|
|
||||||
batch_size_multiplier=batch_size_multiplier,
|
|
||||||
start_epoch=start_epoch,
|
start_epoch=start_epoch,
|
||||||
end_epoch=min((start_epoch + args.run_epochs), args.epochs)
|
end_epoch=min((start_epoch + args.run_epochs), args.epochs)
|
||||||
if args.run_epochs != -1
|
if args.run_epochs != -1
|
||||||
|
@ -587,7 +612,6 @@ def main(args, model_args, model_arch):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
epilog = [
|
epilog = [
|
||||||
"Based on the architecture picked by --arch flag, you may use the following options:\n"
|
"Based on the architecture picked by --arch flag, you may use the following options:\n"
|
||||||
]
|
]
|
||||||
|
@ -603,7 +627,7 @@ if __name__ == "__main__":
|
||||||
add_parser_arguments(parser)
|
add_parser_arguments(parser)
|
||||||
|
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
|
|
||||||
model_arch = available_models()[args.arch]
|
model_arch = available_models()[args.arch]
|
||||||
model_args, rest = model_arch.parser().parse_known_args(rest)
|
model_args, rest = model_arch.parser().parse_known_args(rest)
|
||||||
print(model_args)
|
print(model_args)
|
||||||
|
|
|
@ -68,9 +68,11 @@ def parse_quantization(parser):
|
||||||
metavar="ARCH",
|
metavar="ARCH",
|
||||||
default="efficientnet-quant-b0",
|
default="efficientnet-quant-b0",
|
||||||
choices=model_names,
|
choices=model_names,
|
||||||
help="model architecture: " + " | ".join(model_names) + " (default: efficientnet-quant-b0)",
|
help="model architecture: "
|
||||||
|
+ " | ".join(model_names)
|
||||||
|
+ " (default: efficientnet-quant-b0)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--skip-calibration",
|
"--skip-calibration",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
@ -80,6 +82,7 @@ def parse_quantization(parser):
|
||||||
|
|
||||||
def parse_training_args(parser):
|
def parse_training_args(parser):
|
||||||
from main import add_parser_arguments
|
from main import add_parser_arguments
|
||||||
|
|
||||||
return add_parser_arguments(parser)
|
return add_parser_arguments(parser)
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,28 +95,29 @@ def main(args, model_args, model_arch):
|
||||||
|
|
||||||
select_default_calib_method()
|
select_default_calib_method()
|
||||||
|
|
||||||
model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema, train_loader_len, \
|
(
|
||||||
batch_size_multiplier, start_epoch = prepare_for_training(args, model_args, model_arch)
|
trainer,
|
||||||
|
|
||||||
print(f"RUNNING QUANTIZATION")
|
|
||||||
|
|
||||||
if not skip_calibration:
|
|
||||||
calibrate(model_and_loss.model, train_loader, logger, calib_iter=10)
|
|
||||||
|
|
||||||
train_loop(
|
|
||||||
model_and_loss,
|
|
||||||
optimizer,
|
|
||||||
scaler,
|
|
||||||
lr_policy,
|
lr_policy,
|
||||||
train_loader,
|
train_loader,
|
||||||
|
train_loader_len,
|
||||||
|
val_loader,
|
||||||
|
logger,
|
||||||
|
start_epoch,
|
||||||
|
) = prepare_for_training(args, model_args, model_arch)
|
||||||
|
|
||||||
|
print(f"RUNNING QUANTIZATION")
|
||||||
|
|
||||||
|
if not skip_calibration:
|
||||||
|
calibrate(trainer.model_and_loss.model, train_loader, logger, calib_iter=10)
|
||||||
|
|
||||||
|
train_loop(
|
||||||
|
trainer,
|
||||||
|
lr_policy,
|
||||||
|
train_loader,
|
||||||
|
train_loader_len,
|
||||||
val_loader,
|
val_loader,
|
||||||
logger,
|
logger,
|
||||||
should_backup_checkpoint(args),
|
should_backup_checkpoint(args),
|
||||||
ema=ema,
|
|
||||||
model_ema=model_ema,
|
|
||||||
steps_per_epoch=train_loader_len,
|
|
||||||
use_amp=args.amp,
|
|
||||||
batch_size_multiplier=batch_size_multiplier,
|
|
||||||
start_epoch=start_epoch,
|
start_epoch=start_epoch,
|
||||||
end_epoch=min((start_epoch + args.run_epochs), args.epochs)
|
end_epoch=min((start_epoch + args.run_epochs), args.epochs)
|
||||||
if args.run_epochs != -1
|
if args.run_epochs != -1
|
||||||
|
@ -124,7 +128,7 @@ def main(args, model_args, model_arch):
|
||||||
skip_validation=args.training_only,
|
skip_validation=args.training_only,
|
||||||
save_checkpoints=args.save_checkpoints,
|
save_checkpoints=args.save_checkpoints,
|
||||||
checkpoint_dir=args.workspace,
|
checkpoint_dir=args.workspace,
|
||||||
checkpoint_filename='quantized_' + args.checkpoint_filename,
|
checkpoint_filename="quantized_" + args.checkpoint_filename,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
|
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
|
||||||
|
@ -149,7 +153,7 @@ if __name__ == "__main__":
|
||||||
parse_training(parser, skip_arch=True)
|
parse_training(parser, skip_arch=True)
|
||||||
|
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
|
|
||||||
model_arch = available_models()[args.arch]
|
model_arch = available_models()[args.arch]
|
||||||
model_args, rest = model_arch.parser().parse_known_args(rest)
|
model_args, rest = model_arch.parser().parse_known_args(rest)
|
||||||
print(model_args)
|
print(model_args)
|
||||||
|
@ -159,4 +163,3 @@ if __name__ == "__main__":
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
|
|
||||||
main(args, model_args, model_arch)
|
main(args, model_args, model_arch)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue