Merge: [Tacotron2/PyT] Using native amp

commit 0b7816ea4d
Krzysztof Kudrynski, 2021-11-08 14:23:51 -08:00


@@ -39,7 +39,7 @@ from torch.nn.parameter import Parameter
 import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler

-from apex.parallel import DistributedDataParallel as DDP
+from torch.nn.parallel import DistributedDataParallel as DDP

 import models
 import loss_functions
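The import swap above replaces Apex's DistributedDataParallel wrapper with the one shipped in torch.nn.parallel. As a rough sketch of how the native wrapper is typically set up (the torchrun launch, the LOCAL_RANK handling and the toy model below are assumptions for illustration, not part of this commit):

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Assumes a torchrun launch, which sets LOCAL_RANK for every worker process.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    model = torch.nn.Linear(80, 80).cuda(local_rank)  # stand-in for the real model
    # Native DDP is told explicitly which GPU this replica owns.
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

For single-GPU-per-process training the native wrapper is usually given device_ids and output_device, which is what the later hunk in main() adds.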
@@ -51,10 +51,6 @@ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
 from scipy.io.wavfile import write as write_wav

-from apex import amp
-amp.lists.functional_overrides.FP32_FUNCS.remove('softmax')
-amp.lists.functional_overrides.FP16_FUNCS.append('softmax')

 def parse_args(parser):
     """
@@ -389,16 +385,13 @@ def main():
                              cpu_run=False,
                              uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)

-    if not args.amp and distributed_run:
-        model = DDP(model)
+    if distributed_run:
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)

     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                  weight_decay=args.weight_decay)

-    if args.amp:
-        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
-        if distributed_run:
-            model = DDP(model)
+    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

     try:
         sigma = args.sigma
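Constructing the scaler once with enabled=args.amp keeps a single code path for FP32 and mixed-precision training: when the flag is off, scale(), step() and update() behave as plain pass-throughs. A small sketch of that property, with a placeholder flag and model standing in for args.amp and the Tacotron 2 model:

    import torch

    use_amp = False  # stands in for args.amp
    model = torch.nn.Linear(128, 10).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # inert when disabled

    x = torch.randn(32, 128, device="cuda")
    y = torch.randint(0, 10, (32,), device="cuda")

    with torch.cuda.amp.autocast(enabled=use_amp):        # also a no-op when disabled
        loss = torch.nn.functional.cross_entropy(model(x), y)

    scaler.scale(loss).backward()  # returns the loss unchanged when the scaler is disabled
    scaler.step(optimizer)         # equivalent to optimizer.step() when disabled
    scaler.update()                # does nothing when disabled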
@@ -480,8 +473,10 @@ def main():
             model.zero_grad()
             x, y, num_items = batch_to_gpu(batch)

-            y_pred = model(x)
-            loss = criterion(y_pred, y)
+            #AMP upstream autocast
+            with torch.cuda.amp.autocast(enabled=args.amp):
+                y_pred = model(x)
+                loss = criterion(y_pred, y)

             if distributed_run:
                 reduced_loss = reduce_tensor(loss.data, world_size).item()
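Note that the autocast region covers only the forward pass and the loss; the backward pass later in the loop runs outside the context and reuses the dtypes autocast chose during the forward, which is the recommended pattern. A compact sketch of that split, with a placeholder model and loss rather than the Tacotron 2 ones:

    import torch

    amp_enabled = True  # stands in for args.amp
    model = torch.nn.Linear(64, 1).cuda()
    criterion = torch.nn.MSELoss()
    x = torch.randn(16, 64, device="cuda")
    y = torch.randn(16, 1, device="cuda")

    # Forward and loss under autocast: ops run in float16 or float32 per its policy.
    with torch.cuda.amp.autocast(enabled=amp_enabled):
        y_pred = model(x)
        loss = criterion(y_pred, y)

    # Backward outside the context; gradients follow the dtypes used in the forward.
    loss.backward()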
@@ -500,16 +495,19 @@ def main():
             reduced_num_items_epoch += reduced_num_items

             if args.amp:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
+                scaler.scale(loss).backward()
+                scaler.unscale_(optimizer)
                 grad_norm = torch.nn.utils.clip_grad_norm_(
-                    amp.master_params(optimizer), args.grad_clip_thresh)
+                    model.parameters(), args.grad_clip_thresh)
+                scaler.step(optimizer)
+                scaler.update()
             else:
                 loss.backward()
                 grad_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(), args.grad_clip_thresh)
+                optimizer.step()

-            optimizer.step()
+            model.zero_grad(set_to_none=True)

             torch.cuda.synchronize()
             iter_stop_time = time.perf_counter()
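After scaler.scale(loss).backward(), the gradients are still multiplied by the loss scale, so the hunk calls scaler.unscale_(optimizer) before clip_grad_norm_ so that clipping is applied against the true threshold; scaler.step() detects the explicit unscale and does not unscale again, and scaler.update() then adjusts the scale factor for the next iteration. A minimal sketch of that clip-with-unscale pattern, with placeholder values in place of args.grad_clip_thresh and the real model:

    import torch

    model = torch.nn.Linear(32, 32).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()
    grad_clip_thresh = 1.0  # placeholder for args.grad_clip_thresh

    x = torch.randn(8, 32, device="cuda")
    with torch.cuda.amp.autocast():
        loss = model(x).pow(2).mean()

    scaler.scale(loss).backward()   # gradients are scaled by the current loss scale
    scaler.unscale_(optimizer)      # restore their true magnitude before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
    scaler.step(optimizer)          # skips the step if inf/nan gradients were found
    scaler.update()
    model.zero_grad(set_to_none=True)  # mirrors the commit's zero_grad placement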