Merge: [Tacotron2/PyT] Using native amp
Commit 0b7816ea4d
@@ -39,7 +39,7 @@ from torch.nn.parameter import Parameter
 import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler

-from apex.parallel import DistributedDataParallel as DDP
+from torch.nn.parallel import DistributedDataParallel as DDP

 import models
 import loss_functions
@@ -51,10 +51,6 @@ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity

 from scipy.io.wavfile import write as write_wav

-from apex import amp
-amp.lists.functional_overrides.FP32_FUNCS.remove('softmax')
-amp.lists.functional_overrides.FP16_FUNCS.append('softmax')
-

 def parse_args(parser):
     """
@@ -389,16 +385,13 @@ def main():
                              cpu_run=False,
                              uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)

-    if not args.amp and distributed_run:
-        model = DDP(model)
+    if distributed_run:
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)

     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                  weight_decay=args.weight_decay)

-    if args.amp:
-        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
-        if distributed_run:
-            model = DDP(model)
+    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

     try:
         sigma = args.sigma
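Note: for reference, a minimal, self-contained sketch of the setup this hunk moves to. The launcher assumptions (a distributed launcher such as torchrun setting WORLD_SIZE and LOCAL_RANK), the toy Linear model, and the hyperparameter values are illustrative only, not part of the patch.

import os
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

use_amp = True  # stands in for args.amp
distributed_run = int(os.environ.get("WORLD_SIZE", "1")) > 1
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(16, 4).cuda()  # stand-in for models.get_model(...)

if distributed_run:
    torch.distributed.init_process_group(backend="nccl")
    # Native DDP handles gradient all-reduce in backward(); no apex
    # amp.initialize() call is needed before wrapping the model.
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

# GradScaler(enabled=False) is a no-op, so one object and one code path
# cover both FP32 and mixed-precision runs.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)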
@@ -480,8 +473,10 @@ def main():
             model.zero_grad()
             x, y, num_items = batch_to_gpu(batch)

-            y_pred = model(x)
-            loss = criterion(y_pred, y)
+            #AMP upstream autocast
+            with torch.cuda.amp.autocast(enabled=args.amp):
+                y_pred = model(x)
+                loss = criterion(y_pred, y)

             if distributed_run:
                 reduced_loss = reduce_tensor(loss.data, world_size).item()
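Note: a small check (not from the patch) of what autocast does to op dtypes inside the region added above. It also shows why the apex FP32_FUNCS/FP16_FUNCS overrides removed earlier have no counterpart here: torch.cuda.amp.autocast decides per-op precision from its built-in cast lists, which keep softmax in float32. Requires a CUDA device.

import torch

x = torch.randn(8, 8, device="cuda")
w = torch.randn(8, 8, device="cuda")

with torch.cuda.amp.autocast(enabled=True):
    y = x @ w                     # matmul autocasts to float16
    p = torch.softmax(y, dim=-1)  # softmax autocasts to float32

print(y.dtype, p.dtype)  # torch.float16 torch.float32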
@@ -500,16 +495,19 @@ def main():
             reduced_num_items_epoch += reduced_num_items

             if args.amp:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
+                scaler.scale(loss).backward()
+                scaler.unscale_(optimizer)
                 grad_norm = torch.nn.utils.clip_grad_norm_(
-                    amp.master_params(optimizer), args.grad_clip_thresh)
+                    model.parameters(), args.grad_clip_thresh)
+                scaler.step(optimizer)
+                scaler.update()
             else:
                 loss.backward()
                 grad_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(), args.grad_clip_thresh)
+                optimizer.step()

-            optimizer.step()
+            model.zero_grad(set_to_none=True)

             torch.cuda.synchronize()
             iter_stop_time = time.perf_counter()
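Note: putting the pieces together, a self-contained sketch of the training step this patch converges on. The toy model, data, and names (net, opt, use_amp) are illustrative; only the autocast/GradScaler flow mirrors the patched train loop. Requires a CUDA device.

import torch

use_amp = True          # mirrors args.amp
grad_clip_thresh = 1.0  # mirrors args.grad_clip_thresh

net = torch.nn.Linear(16, 4).cuda()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn(32, 16, device="cuda")
target = torch.randint(0, 4, (32,), device="cuda")

net.zero_grad()
with torch.cuda.amp.autocast(enabled=use_amp):
    logits = net(x)
    loss = torch.nn.functional.cross_entropy(logits, target)

if use_amp:
    scaler.scale(loss).backward()
    # Unscale before clipping so the threshold applies to the true gradients.
    scaler.unscale_(opt)
    grad_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), grad_clip_thresh)
    scaler.step(opt)   # skips the optimizer step if any gradient overflowed
    scaler.update()
else:
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), grad_clip_thresh)
    opt.step()

net.zero_grad(set_to_none=True)

Because scaler.step() only calls optimizer.step() when no inf/NaN gradients are found, the unconditional optimizer.step() from the old code moves into the non-AMP branch, as in the hunk above.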