[Tacotron2/PyT] Using native amp

Grzegorz Karch 2021-11-08 14:23:50 -08:00 committed by Krzysztof Kudrynski
parent dcd3bbac09
commit 4fe3b4e4ae


@@ -39,7 +39,7 @@ from torch.nn.parameter import Parameter
 import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler
-from apex.parallel import DistributedDataParallel as DDP
+from torch.nn.parallel import DistributedDataParallel as DDP
 import models
 import loss_functions
@@ -51,10 +51,6 @@ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
 from scipy.io.wavfile import write as write_wav
-from apex import amp
-amp.lists.functional_overrides.FP32_FUNCS.remove('softmax')
-amp.lists.functional_overrides.FP16_FUNCS.append('softmax')
 def parse_args(parser):
     """
@@ -389,16 +385,13 @@ def main():
                              cpu_run=False,
                              uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)
-    if not args.amp and distributed_run:
-        model = DDP(model)
+    if distributed_run:
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                  weight_decay=args.weight_decay)
-    if args.amp:
-        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
-        if distributed_run:
-            model = DDP(model)
+    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
     try:
         sigma = args.sigma
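
The setup-side change above boils down to: wrap the model with torch.nn.parallel.DistributedDataParallel unconditionally, and create a single torch.cuda.amp.GradScaler whose enabled flag mirrors --amp. As a reference point, here is a minimal standalone sketch of that pattern; the model, optimizer, and use_amp names are placeholders rather than the script's actual objects, and the distributed wrapping is omitted so the snippet runs on a single GPU.

# Minimal native-AMP setup sketch (hypothetical names, single GPU).
import torch

use_amp = True                                      # stands in for args.amp
model = torch.nn.Linear(80, 80).cuda()              # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

# One GradScaler for the whole run; with enabled=False every scaler call
# becomes a no-op, so the same loop serves both FP32 and mixed-precision runs.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)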
@@ -480,8 +473,10 @@ def main():
             model.zero_grad()
             x, y, num_items = batch_to_gpu(batch)
-            y_pred = model(x)
-            loss = criterion(y_pred, y)
+            #AMP upstream autocast
+            with torch.cuda.amp.autocast(enabled=args.amp):
+                y_pred = model(x)
+                loss = criterion(y_pred, y)
             if distributed_run:
                 reduced_loss = reduce_tensor(loss.data, world_size).item()
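
In the hunk above, only the forward pass and loss computation move under torch.cuda.amp.autocast; autocast chooses per-op precision on its own, and with enabled=False it is a no-op, so the FP32 code path is untouched. Continuing the sketch from the previous hunk (x, y, and criterion are placeholder data, not the real batch pipeline):

# Forward pass under autocast (placeholder tensors).
x = torch.randn(16, 80, device="cuda")
y = torch.randn(16, 80, device="cuda")
criterion = torch.nn.MSELoss()

with torch.cuda.amp.autocast(enabled=use_amp):
    y_pred = model(x)            # eligible ops run in FP16 when AMP is on
    loss = criterion(y_pred, y)  # precision-sensitive ops stay in FP32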
@@ -500,16 +495,19 @@
             reduced_num_items_epoch += reduced_num_items
             if args.amp:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
+                scaler.scale(loss).backward()
+                scaler.unscale_(optimizer)
                 grad_norm = torch.nn.utils.clip_grad_norm_(
-                    amp.master_params(optimizer), args.grad_clip_thresh)
+                    model.parameters(), args.grad_clip_thresh)
+                scaler.step(optimizer)
+                scaler.update()
             else:
                 loss.backward()
                 grad_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(), args.grad_clip_thresh)
-            optimizer.step()
+                optimizer.step()
+            model.zero_grad(set_to_none=True)
             torch.cuda.synchronize()
             iter_stop_time = time.perf_counter()
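
Taken together, the backward/step hunk follows the standard native-AMP recipe: scale the loss before backward, unscale the gradients so clipping sees their true magnitudes, then let the scaler step the optimizer and update the loss scale. A sketch of one full iteration, continuing the placeholders above (grad_clip_thresh stands in for args.grad_clip_thresh):

# One training iteration with native AMP (sketch, not the repository's exact loop).
grad_clip_thresh = 1.0

with torch.cuda.amp.autocast(enabled=use_amp):
    y_pred = model(x)
    loss = criterion(y_pred, y)

if use_amp:
    scaler.scale(loss).backward()       # backward on the scaled loss
    scaler.unscale_(optimizer)          # unscale before gradient clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), grad_clip_thresh)
    scaler.step(optimizer)              # skipped internally if inf/NaN gradients appear
    scaler.update()                     # adjust the loss scale for the next step
else:
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), grad_clip_thresh)
    optimizer.step()

model.zero_grad(set_to_none=True)       # release gradient memory between iterations

Because both GradScaler and autocast honor their enabled flags, the --amp switch no longer needs the separate apex initialize/scale_loss code paths that the removed lines provided.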