diff --git a/PyTorch/SpeechSynthesis/Tacotron2/train.py b/PyTorch/SpeechSynthesis/Tacotron2/train.py
index 17fd4859..732c9b66 100644
--- a/PyTorch/SpeechSynthesis/Tacotron2/train.py
+++ b/PyTorch/SpeechSynthesis/Tacotron2/train.py
@@ -39,7 +39,7 @@ from torch.nn.parameter import Parameter
 import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler
 
-from apex.parallel import DistributedDataParallel as DDP
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 import models
 import loss_functions
@@ -51,10 +51,6 @@ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
 
 from scipy.io.wavfile import write as write_wav
 
-from apex import amp
-amp.lists.functional_overrides.FP32_FUNCS.remove('softmax')
-amp.lists.functional_overrides.FP16_FUNCS.append('softmax')
-
 
 def parse_args(parser):
     """
@@ -389,16 +385,13 @@ def main():
                             cpu_run=False,
                             uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)
 
-    if not args.amp and distributed_run:
-        model = DDP(model)
+    if distributed_run:
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                  weight_decay=args.weight_decay)
 
-    if args.amp:
-        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
-        if distributed_run:
-            model = DDP(model)
+    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
 
     try:
         sigma = args.sigma
@@ -480,8 +473,10 @@ def main():
             model.zero_grad()
             x, y, num_items = batch_to_gpu(batch)
 
-            y_pred = model(x)
-            loss = criterion(y_pred, y)
+            #AMP upstream autocast
+            with torch.cuda.amp.autocast(enabled=args.amp):
+                y_pred = model(x)
+                loss = criterion(y_pred, y)
 
             if distributed_run:
                 reduced_loss = reduce_tensor(loss.data, world_size).item()
@@ -500,16 +495,19 @@ def main():
                 reduced_num_items_epoch += reduced_num_items
 
             if args.amp:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
+                scaler.scale(loss).backward()
+                scaler.unscale_(optimizer)
                 grad_norm = torch.nn.utils.clip_grad_norm_(
-                    amp.master_params(optimizer), args.grad_clip_thresh)
+                    model.parameters(), args.grad_clip_thresh)
+                scaler.step(optimizer)
+                scaler.update()
             else:
                 loss.backward()
                 grad_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(), args.grad_clip_thresh)
+                optimizer.step()
 
-            optimizer.step()
+            model.zero_grad(set_to_none=True)
 
             torch.cuda.synchronize()
             iter_stop_time = time.perf_counter()
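
For reference, below is a minimal, self-contained sketch of the native-AMP training step that this patch switches to: autocast wraps the forward pass and loss, GradScaler scales the backward pass, and unscale_ is called before gradient clipping so the clip threshold applies to unscaled gradients. The model, loss function, and hyperparameters here are hypothetical placeholders for illustration, not Tacotron2 code.

import torch

# Placeholder model and optimizer; a CUDA device is required for autocast/GradScaler.
model = torch.nn.Linear(80, 80).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=True)

def train_step(x, y, grad_clip_thresh=1.0, amp=True):
    # Forward pass and loss computed under autocast (mixed precision when amp=True).
    with torch.cuda.amp.autocast(enabled=amp):
        y_pred = model(x)
        loss = torch.nn.functional.mse_loss(y_pred, y)

    if amp:
        scaler.scale(loss).backward()      # backward on the scaled loss
        scaler.unscale_(optimizer)         # unscale grads so clipping sees true magnitudes
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), grad_clip_thresh)
        scaler.step(optimizer)             # skips the step if inf/NaN grads were found
        scaler.update()                    # adjusts the loss scale for the next iteration
    else:
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), grad_clip_thresh)
        optimizer.step()

    model.zero_grad(set_to_none=True)
    return loss.item(), grad_norm

# Example usage (requires a CUDA device):
#   x = torch.randn(16, 80, device="cuda")
#   y = torch.randn(16, 80, device="cuda")
#   loss, grad_norm = train_step(x, y)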