Merge: [Tacotron2/PyT] Using native amp
This commit is contained in:
commit
0b7816ea4d
|
@ -39,7 +39,7 @@ from torch.nn.parameter import Parameter
|
|||
import torch.distributed as dist
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from apex.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import models
|
||||
import loss_functions
|
||||
|
@ -51,10 +51,6 @@ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
|
|||
|
||||
from scipy.io.wavfile import write as write_wav
|
||||
|
||||
from apex import amp
|
||||
amp.lists.functional_overrides.FP32_FUNCS.remove('softmax')
|
||||
amp.lists.functional_overrides.FP16_FUNCS.append('softmax')
|
||||
|
||||
|
||||
def parse_args(parser):
|
||||
"""
|
||||
|
@ -389,16 +385,13 @@ def main():
|
|||
cpu_run=False,
|
||||
uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)
|
||||
|
||||
if not args.amp and distributed_run:
|
||||
model = DDP(model)
|
||||
if distributed_run:
|
||||
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
if args.amp:
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
|
||||
if distributed_run:
|
||||
model = DDP(model)
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
||||
|
||||
try:
|
||||
sigma = args.sigma
|
||||
|
@ -480,8 +473,10 @@ def main():
|
|||
model.zero_grad()
|
||||
x, y, num_items = batch_to_gpu(batch)
|
||||
|
||||
y_pred = model(x)
|
||||
loss = criterion(y_pred, y)
|
||||
# AMP: run the forward pass and loss computation under upstream torch.cuda.amp autocast
|
||||
with torch.cuda.amp.autocast(enabled=args.amp):
|
||||
y_pred = model(x)
|
||||
loss = criterion(y_pred, y)
|
||||
|
||||
if distributed_run:
|
||||
reduced_loss = reduce_tensor(loss.data, world_size).item()
|
||||
|
@ -500,16 +495,19 @@ def main():
|
|||
reduced_num_items_epoch += reduced_num_items
|
||||
|
||||
if args.amp:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
amp.master_params(optimizer), args.grad_clip_thresh)
|
||||
model.parameters(), args.grad_clip_thresh)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), args.grad_clip_thresh)
|
||||
optimizer.step()
|
||||
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
iter_stop_time = time.perf_counter()
|
||||
|
|
Loading…
Reference in a new issue