add support for --cpu-run

maggiezha 2020-05-06 20:50:59 +10:00 committed by GitHub
parent bee3ddfa0e
commit 3a11f10bfe
7 changed files with 86 additions and 47 deletions

Dockerfile

@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.01-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.03-py3
 FROM ${FROM_IMAGE_NAME}
 ADD . /workspace/tacotron2

inference.py

@@ -68,7 +68,8 @@ def parse_args(parser):
                         help='Include warmup')
     parser.add_argument('--stft-hop-length', type=int, default=256,
                         help='STFT hop length for estimating audio length from mel size')
+    parser.add_argument('--cpu-run', action='store_true',
+                        help='Run inference on CPU')
 
     return parser
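
With action='store_true' the flag defaults to False, so existing GPU invocations behave exactly as before. A minimal sketch of how such a flag typically maps to a single torch.device (the device variable is an illustration, not part of this commit):

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument('--cpu-run', action='store_true',
                    help='Run inference on CPU')
args = parser.parse_args(['--cpu-run'])  # simulate passing the flag

# resolve the flag to one torch.device up front instead of
# branching on args.cpu_run at every call site
device = torch.device('cpu' if args.cpu_run else 'cuda')
print(device)  # cpu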
@@ -102,16 +103,18 @@ def unwrap_distributed(state_dict):
     return new_state_dict
 
 
-def load_and_setup_model(model_name, parser, checkpoint, amp_run, forward_is_infer=False):
+def load_and_setup_model(model_name, parser, checkpoint, amp_run, cpu_run, forward_is_infer=False):
     model_parser = models.parse_model_args(model_name, parser, add_help=False)
     model_args, _ = model_parser.parse_known_args()
 
     model_config = models.get_model_config(model_name, model_args)
-    model = models.get_model(model_name, model_config, to_cuda=True,
-                             forward_is_infer=forward_is_infer)
+    model = models.get_model(model_name, model_config, cpu_run, forward_is_infer=forward_is_infer)
 
     if checkpoint is not None:
-        state_dict = torch.load(checkpoint)['state_dict']
+        if cpu_run:
+            state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))['state_dict']
+        else:
+            state_dict = torch.load(checkpoint)['state_dict']
 
         if checkpoint_from_distributed(state_dict):
            state_dict = unwrap_distributed(state_dict)
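
map_location is what makes CPU loading work: it remaps every CUDA tensor in the checkpoint to host memory during deserialization, where the default would try to restore each tensor onto the GPU it was saved from and fail on a machine without CUDA. A standalone sketch (ckpt_path is a hypothetical file):

import torch

ckpt_path = 'checkpoint_Tacotron2_amp'  # hypothetical checkpoint path
cpu_run = True

if cpu_run:
    # deserialize GPU-saved tensors straight into CPU memory
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
else:
    # default: tensors are restored to the device they were saved from
    checkpoint = torch.load(ckpt_path)
state_dict = checkpoint['state_dict']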
@@ -164,23 +167,26 @@ def prepare_input_sequence(texts):
 
 class MeasureTime():
-    def __init__(self, measurements, key):
+    def __init__(self, measurements, key, cpu_run):
         self.measurements = measurements
         self.key = key
+        self.cpu_run = cpu_run
 
     def __enter__(self):
-        torch.cuda.synchronize()
+        if self.cpu_run == False:
+            torch.cuda.synchronize()
         self.t0 = time.perf_counter()
 
     def __exit__(self, exc_type, exc_value, exc_traceback):
-        torch.cuda.synchronize()
+        if self.cpu_run == False:
+            torch.cuda.synchronize()
         self.measurements[self.key] = time.perf_counter() - self.t0
 
 
 def main():
     """
     Launches text to speech (inference).
-    Inference is executed on a single GPU.
+    Inference is executed on a single GPU or CPU.
     """
     parser = argparse.ArgumentParser(
         description='PyTorch Tacotron 2 Inference')
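
The synchronize guard exists because CUDA kernels run asynchronously: without a barrier before and after the timed region, time.perf_counter() would measure little more than kernel launch overhead. On CPU the barrier is unnecessary, and torch.cuda.synchronize() raises an error when no CUDA device is visible. A self-contained sketch of the same logic (timed is a hypothetical helper mirroring MeasureTime, not part of this commit):

import time

import torch

def timed(fn, measurements, key, cpu_run):
    # hypothetical helper with the same guard as MeasureTime
    if not cpu_run:
        torch.cuda.synchronize()  # drain queued kernels before starting the clock
    t0 = time.perf_counter()
    fn()
    if not cpu_run:
        torch.cuda.synchronize()  # wait for the timed kernels to actually finish
    measurements[key] = time.perf_counter() - t0

measurements = {}
timed(lambda: sum(range(10**6)), measurements, 'demo', cpu_run=True)
print(measurements['demo'])  # elapsed seconds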
@@ -195,10 +201,14 @@ def main():
     DLLogger.log(step="PARAMETER", data={'model_name':'Tacotron2_PyT'})
 
     tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
-                                     args.amp_run, forward_is_infer=True)
+                                     args.amp_run, args.cpu_run, forward_is_infer=True)
     waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow,
-                                    args.amp_run, forward_is_infer=True)
-    denoiser = Denoiser(waveglow).cuda()
+                                    args.amp_run, args.cpu_run, forward_is_infer=True)
+    if args.cpu_run:
+        denoiser = Denoiser(waveglow, args.cpu_run)
+    else:
+        denoiser = Denoiser(waveglow, args.cpu_run).cuda()
 
     jitted_tacotron2 = torch.jit.script(tacotron2)
@@ -211,9 +221,14 @@ def main():
         sys.exit(1)
 
     if args.include_warmup:
-        sequence = torch.randint(low=0, high=148, size=(1,50),
-                                 dtype=torch.long).cuda()
-        input_lengths = torch.IntTensor([sequence.size(1)]).cuda().long()
+        if args.cpu_run:
+            sequence = torch.randint(low=0, high=148, size=(1,50),
+                                     dtype=torch.long)
+            input_lengths = torch.IntTensor([sequence.size(1)]).long()
+        else:
+            sequence = torch.randint(low=0, high=148, size=(1,50),
+                                     dtype=torch.long).cuda()
+            input_lengths = torch.IntTensor([sequence.size(1)]).cuda().long()
         for i in range(3):
             with torch.no_grad():
                 mel, mel_lengths, _ = jitted_tacotron2(sequence, input_lengths)
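
The two warmup branches differ only in tensor placement. An equivalent branch-free formulation using the device keyword (a sketch, not the commit's code):

import torch

cpu_run = True  # stands in for args.cpu_run
device = torch.device('cpu' if cpu_run else 'cuda')

# random token ids in the model's symbol range, created on the target device
sequence = torch.randint(low=0, high=148, size=(1, 50),
                         dtype=torch.long, device=device)
input_lengths = torch.tensor([sequence.size(1)], dtype=torch.long, device=device)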
@@ -223,16 +238,17 @@ def main():
     sequences_padded, input_lengths = prepare_input_sequence(texts)
 
-    with torch.no_grad(), MeasureTime(measurements, "tacotron2_time"):
+    with torch.no_grad(), MeasureTime(measurements, "tacotron2_time", args.cpu_run):
         mel, mel_lengths, alignments = jitted_tacotron2(sequences_padded, input_lengths)
 
-    with torch.no_grad(), MeasureTime(measurements, "waveglow_time"):
+    with torch.no_grad(), MeasureTime(measurements, "waveglow_time", args.cpu_run):
         audios = waveglow(mel, sigma=args.sigma_infer)
         audios = audios.float()
         audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)
 
     print("Stopping after",mel.size(2),"decoder steps")
+
-    tacotron2_infer_perf = mel.size(0)*mel.size(2)/measurements['tacotron2_time']
+    tacotron2_infer_perf = mel.size(0)*mel.size(2)/measurements['tacotron2_time']
     waveglow_infer_perf = audios.size(0)*audios.size(1)/measurements['waveglow_time']
 
     DLLogger.log(step=0, data={"tacotron2_items_per_sec": tacotron2_infer_perf})
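
Both throughput figures are items per second: Tacotron 2 counts mel frames (batch size times decoder steps), WaveGlow counts audio samples. A worked example with made-up timings:

measurements = {'tacotron2_time': 2.0, 'waveglow_time': 4.0}  # seconds, illustrative

batch, decoder_steps = 1, 500   # mel has shape (batch, n_mel_channels, steps)
audio_samples = 128000          # audios has shape (batch, samples)

tacotron2_infer_perf = batch * decoder_steps / measurements['tacotron2_time']
waveglow_infer_perf = batch * audio_samples / measurements['waveglow_time']
print(tacotron2_infer_perf)  # 250.0 mel frames per second
print(waveglow_infer_perf)   # 32000.0 audio samples per second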

models.py

@@ -62,7 +62,7 @@ def init_bn(module):
             init_bn(child)
 
 
-def get_model(model_name, model_config, to_cuda,
+def get_model(model_name, model_config, cpu_run,
               uniform_initialize_bn_weight=False, forward_is_infer=False):
     """ Code chooses a model based on name"""
     model = None
@@ -88,7 +88,7 @@ def get_model(model_name, model_config, to_cuda,
     if uniform_initialize_bn_weight:
         init_bn(model)
 
-    if to_cuda:
+    if cpu_run==False:
         model = model.cuda()
     return model
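
Note the flipped polarity: callers that previously passed to_cuda=True now pass cpu_run=False. The guard-plus-.cuda() pattern is equivalent to the .to() idiom sketched below (an alternative formulation, not the repo's code):

import torch

def place_model(model, cpu_run):
    # same effect as: model = model.cuda() when cpu_run is False
    return model.to(torch.device('cpu' if cpu_run else 'cuda'))

model = place_model(torch.nn.Linear(4, 4), cpu_run=True)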

run_latency_tests.sh

@@ -1,4 +1,4 @@
-bash test_infer.sh -bs 1 -il 128 -p amp --num-iters 1003 --tacotron2 ./checkpoints/checkpoint_Tacotron2_amp --waveglow ./checkpoints/checkpoint_WaveGlow_amp
-bash test_infer.sh -bs 4 -il 128 -p amp --num-iters 1003 --tacotron2 ./checkpoints/checkpoint_Tacotron2_amp --waveglow ./checkpoints/checkpoint_WaveGlow_amp
-bash test_infer.sh -bs 1 -il 128 -p fp32 --num-iters 1003 --tacotron2 ./checkpoints/checkpoint_Tacotron2_fp32 --waveglow ./checkpoints/checkpoint_WaveGlow_fp32
-bash test_infer.sh -bs 4 -il 128 -p fp32 --num-iters 1003 --tacotron2 ./checkpoints/checkpoint_Tacotron2_fp32 --waveglow ./checkpoints/checkpoint_WaveGlow_fp32
+bash test_infer.sh -bs 1 -il 128 -p amp --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256
+bash test_infer.sh -bs 4 -il 128 -p amp --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256
+bash test_infer.sh -bs 1 -il 128 -p fp32 --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256
+bash test_infer.sh -bs 4 -il 128 -p fp32 --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256

run_latency_tests_cpu.sh (new file)

@@ -0,0 +1,4 @@
+export CUDA_VISIBLE_DEVICES=
+
+bash test_infer.sh -bs 1 -il 128 -p fp32 --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256 --cpu-run
+bash test_infer.sh -bs 4 -il 128 -p fp32 --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256 --cpu-run
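
Exporting an empty CUDA_VISIBLE_DEVICES hides all GPUs from the process, so any stray .cuda() call fails fast instead of silently landing on a GPU; --cpu-run then keeps the code on the CPU path. A quick sanity check (a sketch; the variable must be set before CUDA is first initialized):

import os

os.environ['CUDA_VISIBLE_DEVICES'] = ''  # must happen before torch touches CUDA
import torch

print(torch.cuda.is_available())  # False: no devices are visible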

test_infer.py

@@ -42,6 +42,8 @@ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
 from apex import amp
 
+from waveglow.denoiser import Denoiser
+
 
 def parse_args(parser):
     """
     Parse commandline arguments.
@@ -51,6 +53,7 @@ def parse_args(parser):
     parser.add_argument('--waveglow', type=str,
                         help='full path to the WaveGlow model checkpoint file')
     parser.add_argument('-s', '--sigma-infer', default=0.6, type=float)
+    parser.add_argument('-d', '--denoising-strength', default=0.01, type=float)
     parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
                         help='Sampling rate')
     parser.add_argument('--amp-run', action='store_true',
@@ -65,23 +68,24 @@ def parse_args(parser):
                         help='Input length')
     parser.add_argument('-bs', '--batch-size', type=int, default=1,
                         help='Batch size')
+    parser.add_argument('--cpu-run', action='store_true',
+                        help='Run inference on CPU')
 
     return parser
 
 
-def load_and_setup_model(model_name, parser, checkpoint, amp_run, to_cuda=True):
+def load_and_setup_model(model_name, parser, checkpoint, amp_run, cpu_run, forward_is_infer=False):
     model_parser = models.parse_model_args(model_name, parser, add_help=False)
     model_args, _ = model_parser.parse_known_args()
 
     model_config = models.get_model_config(model_name, model_args)
-    model = models.get_model(model_name, model_config, to_cuda=to_cuda)
+    model = models.get_model(model_name, model_config, cpu_run, forward_is_infer=forward_is_infer)
 
     if checkpoint is not None:
-        if to_cuda:
-            state_dict = torch.load(checkpoint)['state_dict']
+        if cpu_run:
+            state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))['state_dict']
         else:
-            state_dict = torch.load(checkpoint,map_location='cpu')['state_dict']
+            state_dict = torch.load(checkpoint)['state_dict']
 
         if checkpoint_from_distributed(state_dict):
             state_dict = unwrap_distributed(state_dict)
@@ -141,7 +145,7 @@ def print_stats(measurements_all):
 def main():
     """
     Launches text to speech (inference).
-    Inference is executed on a single GPU.
+    Inference is executed on a single GPU or CPU.
     """
     parser = argparse.ArgumentParser(
         description='PyTorch Tacotron 2 Inference')
@@ -168,8 +172,15 @@ def main():
     print("args:", args, unknown_args)
 
-    tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2, args.amp_run)
-    waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow, args.amp_run)
+    tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2, args.amp_run, args.cpu_run, forward_is_infer=True)
+    waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow, args.amp_run, args.cpu_run)
+
+    if args.cpu_run:
+        denoiser = Denoiser(waveglow, args.cpu_run)
+    else:
+        denoiser = Denoiser(waveglow, args.cpu_run).cuda()
+
+    jitted_tacotron2 = torch.jit.script(tacotron2)
 
     texts = ["The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves. The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves."]
     texts = [texts[0][:args.input_length]]
@@ -181,27 +192,29 @@ def main():
         measurements = {}
 
-        with MeasureTime(measurements, "pre_processing"):
+        with MeasureTime(measurements, "pre_processing", args.cpu_run):
             sequences_padded, input_lengths = prepare_input_sequence(texts)
 
         with torch.no_grad():
-            with MeasureTime(measurements, "latency"):
-                with MeasureTime(measurements, "tacotron2_latency"):
-                    mel, mel_lengths, _ = tacotron2.infer(sequences_padded, input_lengths)
+            with MeasureTime(measurements, "latency", args.cpu_run):
+                with MeasureTime(measurements, "tacotron2_latency", args.cpu_run):
+                    mel, mel_lengths, _ = jitted_tacotron2(sequences_padded, input_lengths)
 
-                with MeasureTime(measurements, "waveglow_latency"):
+                with MeasureTime(measurements, "waveglow_latency", args.cpu_run):
                     audios = waveglow.infer(mel, sigma=args.sigma_infer)
+                    audios = audios.float()
+                    audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)
 
         num_mels = mel.size(0)*mel.size(2)
         num_samples = audios.size(0)*audios.size(1)
 
-        with MeasureTime(measurements, "type_conversion"):
+        with MeasureTime(measurements, "type_conversion", args.cpu_run):
             audios = audios.float()
 
-        with MeasureTime(measurements, "data_transfer"):
+        with MeasureTime(measurements, "data_transfer", args.cpu_run):
             audios = audios.cpu()
 
-        with MeasureTime(measurements, "storage"):
+        with MeasureTime(measurements, "storage", args.cpu_run):
             audios = audios.numpy()
             for i, audio in enumerate(audios):
                 audio_path = "audio_"+str(i)+".wav"
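
Each iteration fills measurements with per-stage timings (pre_processing, latency, tacotron2_latency, waveglow_latency, type_conversion, data_transfer, storage) that print_stats later aggregates across iterations. A sketch of that aggregation (the list-of-dicts shape is illustrative, not the script's exact structure):

# one dict per iteration, as produced by the MeasureTime blocks above
measurements_all = [{'latency': 1.20}, {'latency': 1.18}, {'latency': 1.25}]

latencies = [m['latency'] for m in measurements_all]
print(f"avg latency: {sum(latencies) / len(latencies):.3f} s")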

test_infer.sh

@@ -4,11 +4,12 @@ BATCH_SIZE=1
 INPUT_LENGTH=128
 PRECISION="fp32"
 NUM_ITERS=1003 # extra 3 iterations for warmup
-TACOTRON2_CKPT="checkpoint_Tacotron2_1500_fp32"
-WAVEGLOW_CKPT="checkpoint_WaveGlow_1000_fp32"
+TACOTRON2_CKPT="tacotron2_1032590_6000_amp"
+WAVEGLOW_CKPT="waveglow_1076430_14000_amp"
 AMP_RUN=""
 TEST_PROGRAM="test_infer.py"
-WN_CHANNELS=512
+WN_CHANNELS=256
+CPU_RUN=""
 
 while [ -n "$1" ]
 do
@@ -57,6 +58,10 @@ do
            WN_CHANNELS="$2"
            shift
            ;;
+        --cpu-run)
+            CPU_RUN="--cpu-run"
+            shift
+            ;;
         *)
             echo "Option $1 not recognized"
     esac
@@ -93,6 +98,7 @@ python $TEST_PROGRAM \
        --log-file $NVLOG_FILE \
        --num-iters $NUM_ITERS \
        --wn-channels $WN_CHANNELS \
+       $CPU_RUN \
        |& tee $TMP_LOGFILE
 
 set +x