NMT timing and tokenizer stats utils (#3004)

* 1. Fixed logging of log_var_q_z_given_x.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Fixed style.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated timing of eval_step.
2. Added caching of encoder values when sampling.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Fixed using cached values in eval_step.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated timing names to include "_timing".

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added support for a sliding mean (see the sketch after this list).

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Fixed style.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added the tokenizer stats script.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Fixed style.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Moved duplicated code into a method.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added logging of detailed NMT timing for non-bottleneck and bottleneck models.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Fixed style.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added logging of both src and tgt length when measuring timing.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added plotting script.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added usage.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Fixed style.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added missing copyright.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added post-processing to collect timing data from a list of dicts into a dict of lists (see the sketch after this list).

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added warmup batch when measuring timing.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated plot script.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Debugging.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Fixed a typo.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added control over plotting.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Fixed style.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added control over font size.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Added control over tick font size.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Fixed style.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Moved scripts.

Signed-off-by: Micha Livne <mlivne@nvidia.com>
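For illustration, two of the items above in sketch form. First, the sliding mean: the helper itself is not part of the hunks below, so this is only a minimal sketch of the idea under assumed names (a running mean over the most recent measurements):

from collections import deque

class SlidingMean:
    """Hypothetical running mean over the most recent `window` values."""

    def __init__(self, window=100):
        self.values = deque(maxlen=window)  # oldest values fall off automatically

    def update(self, x):
        self.values.append(x)

    def mean(self):
        return sum(self.values) / len(self.values) if self.values else 0.0

Second, the timing post-processing (implemented near the end of nmt_transformer_infer.py below) collects a list of per-batch dicts into a dict of lists; with illustrative values:

all_timing = [{"encoder": 0.10, "sampler": 0.92}, {"encoder": 0.12, "sampler": 1.05}]
timing_dict = {k: [t[k] for t in all_timing] for k in all_timing[0]} if all_timing else {}
# timing_dict == {"encoder": [0.10, 0.12], "sampler": [0.92, 1.05]}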

Co-authored-by: Micha Livne <mlivne@nvidia.com>
Micha Livne, 2021-10-22 13:51:04 -04:00, committed by GitHub
parent 1f36f32ee6
commit 4452162b93
5 changed files with 412 additions and 89 deletions
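Taken together, the intended flow is: run inference with --write_timing, then plot the dumped JSON. A hedged sketch of that flow (file names are placeholders and the --model flag name is assumed; only --srctext, --tgtout, --write_timing, and the .timing.json suffix come from the diff below):

import subprocess

# translate and record per-batch timing
subprocess.run(
    [
        "python", "nmt_transformer_infer.py",
        "--model", "model.nemo",  # flag name assumed
        "--srctext", "src.txt",
        "--tgtout", "tgt.txt",
        "--write_timing",
    ],
    check=True,
)
# plot encoder/sampler time vs. src/tgt length from tgt.txt.timing.json
subprocess.run(["python", "plot_detailed_timing.py", "tgt.txt.timing.json"], check=True)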


@@ -22,6 +22,7 @@ USAGE Example:
"""
import json
from argparse import ArgumentParser
import torch
@@ -35,6 +36,62 @@ from nemo.collections.nlp.modules.common.transformer import (
from nemo.utils import logging
def translate_text(
models, args, src_text, tgt_text, tgt_text_all, src_texts, all_scores, all_timing, ensemble_generator
):
if len(models) > 1:
src_ids, src_mask = models[0].prepare_inference_batch(src_text)
best_translations = ensemble_generator(src_ids, src_mask, return_beam_scores=args.write_scores)
if args.write_scores:
all_results, scores, best_translations = (
best_translations[0],
best_translations[1],
best_translations[2],
)
scores = scores.view(-1).data.cpu().numpy().tolist()
all_scores += scores
src_texts += [item for item in src_text for i in range(args.beam_size)]
all_results = models[0].ids_to_postprocessed_text(
all_results, models[0].decoder_tokenizer, models[0].target_processor
)
tgt_text_all += all_results
best_translations = models[0].ids_to_postprocessed_text(
best_translations, models[0].decoder_tokenizer, models[0].target_processor
)
tgt_text += best_translations
else:
model = models[0]
best_translations = model.translate(
text=src_text,
source_lang=args.source_lang,
target_lang=args.target_lang,
return_beam_scores=args.write_scores,
log_timing=args.write_timing,
)
if args.write_timing:
*best_translations, timing_dict = best_translations
all_timing.append(timing_dict)
else:
best_translations = (best_translations,)
if args.write_scores:
all_results, scores, best_translations = (
best_translations[0],
best_translations[1],
best_translations[2],
)
all_scores += scores
src_texts += [item for item in src_text for i in range(args.beam_size)]
tgt_text_all += all_results
else:
best_translations = best_translations[0]
tgt_text += best_translations
print(f"Translated {len(tgt_text)} sentences")
def main():
parser = ArgumentParser()
parser.add_argument(
@@ -71,6 +128,11 @@ def main():
action="store_true",
help="Whether to write a separate file with scores not including length penalties corresponding to each beam hypothesis (.score suffix)",
)
parser.add_argument(
"--write_timing",
action="store_true",
help="Whether to write a separate file with detailed timing info (.timing.json suffix)",
)
# shallow fusion specific parameters
parser.add_argument(
"--lm_model",
@@ -92,11 +154,15 @@ def main():
model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=model_path).eval()
models.append(model)
if (len(models) > 1) and (args.write_timing):
raise RuntimeError("Cannot measure timing when more than 1 model is used")
src_text = []
tgt_text = []
tgt_text_all = []
src_texts = []
all_scores = []
all_timing = []
if torch.cuda.is_available():
models = [model.cuda() for model in models]
@@ -124,6 +190,7 @@ def main():
)
else:
model = models[0]
ensemble_generator = None
if lm_model is not None:
model.beam_search = BeamSearchSequenceGeneratorWithLanguageModel(
embedding=model.decoder.embedding,
@@ -155,91 +222,49 @@
logging.info(f"Translating: {args.srctext}")
count = 0
with open(args.srctext, 'r') as src_f:
    for line in src_f:
        src_text.append(line.strip())
        if len(src_text) == args.batch_size:
            # warmup when measuring timing
            if args.write_timing and not all_timing:
                print("running a warmup batch")
                translate_text(
                    models=models,
                    args=args,
                    src_text=src_text,
                    tgt_text=[],
                    tgt_text_all=[],
                    src_texts=[],
                    all_scores=[],
                    all_timing=[],
                    ensemble_generator=ensemble_generator,
                )
            translate_text(
                models=models,
                args=args,
                src_text=src_text,
                tgt_text=tgt_text,
                tgt_text_all=tgt_text_all,
                src_texts=src_texts,
                all_scores=all_scores,
                all_timing=all_timing,
                ensemble_generator=ensemble_generator,
            )
            src_text = []
        count += 1

    if len(src_text) > 0:
        translate_text(
            models=models,
            args=args,
            src_text=src_text,
            tgt_text=tgt_text,
            tgt_text_all=tgt_text_all,
            src_texts=src_texts,
            all_scores=all_scores,
            all_timing=all_timing,
            ensemble_generator=ensemble_generator,
        )
with open(args.tgtout, 'w') as tgt_f:
for line in tgt_text:
@@ -250,6 +275,16 @@ def main():
for line, score, inp in zip(tgt_text_all, all_scores, src_texts):
tgt_f_scores.write(inp + "\t" + line + "\t" + str(score) + "\n")
if args.write_timing:
# collect list of dicts to a dict of lists
timing_dict = {}
if len(all_timing):
for k in all_timing[0].keys():
timing_dict[k] = [t[k] for t in all_timing]
with open(args.tgtout + '.timing.json', 'w') as timing_fh:
json.dump(timing_dict, timing_fh)
if __name__ == '__main__':
main() # noqa pylint: disable=no-value-for-parameter
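Under --write_timing, the script above dumps timing_dict to <tgtout>.timing.json, with one list entry per translated batch. A minimal sketch of consuming that file (key names taken from the code above and the plotting script below; the path is a placeholder):

import json

with open("tgt.txt.timing.json") as fh:  # placeholder path
    timing = json.load(fh)

for enc_t, src_len in zip(timing["encoder"], timing["mean_src_length"]):
    print(f"encoder: {enc_t:.4f}s at mean src length {src_len:.1f}")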


@@ -317,13 +317,18 @@ class MTBottleneckModel(MTEncDecModel):
inputs: a list of string containing detokenized inputs
"""
mode = self.training
timer = cache.get("timer", None)
try:
self.eval()
# build posterior distribution q(z|x)
if ("z" not in cache) or ("z_mean" not in cache) or ("z_mask" not in cache):
if timer is not None:
timer.start("encoder")
enc_hiddens, enc_mask = self.encoder(input_ids=src, encoder_mask=src_mask, return_mask=True)
z, z_mean, _ = self.encode_latent(hidden=enc_hiddens)
if timer is not None:
timer.stop("encoder")
else:
enc_mask = cache["z_mask"]
z = cache["z"]
@@ -332,8 +337,8 @@ class MTBottleneckModel(MTEncDecModel):
if getattr(self, "deterministic_translate", True):
z = z_mean
if cache.get("timer", None) is not None:
cache["timer"].start("sampler")
if timer is not None:
timer.start("sampler")
# decoding cross attention context
context_hiddens = self.latent2hidden(z)
@@ -342,8 +347,8 @@ class MTBottleneckModel(MTEncDecModel):
encoder_input_mask=enc_mask,
return_beam_scores=return_beam_scores,
)
if cache.get("timer", None) is not None:
cache["timer"].stop("sampler")
if timer is not None:
timer.stop("sampler")
if return_beam_scores:
all_translations, scores, best_translations = best_translations
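The cache plumbing above lets a caller run the bottleneck encoder once and then sample several decodings from the same latent inputs. A hedged sketch of that calling pattern (the encoder and encode_latent calls mirror the ones above; sample_translations and num_samples are assumed names):

import torch

@torch.no_grad()
def sample_translations(model, src, src_mask, num_samples):
    """Encode once, then decode several times; `model` is assumed to be an MTBottleneckModel."""
    enc_hiddens, enc_mask = model.encoder(input_ids=src, encoder_mask=src_mask, return_mask=True)
    z, z_mean, _ = model.encode_latent(hidden=enc_hiddens)
    # these keys match the ones batch_translate reads from `cache` above
    cache = {"z": z, "z_mean": z_mean, "z_mask": enc_mask, "timer": None}
    return [model.batch_translate(src, src_mask, cache=cache) for _ in range(num_samples)]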


@@ -46,7 +46,7 @@ from nemo.collections.nlp.modules.common.lm_utils import get_transformer
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.collections.nlp.modules.common.transformer import BeamSearchSequenceGenerator, TopKSequenceGenerator
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.utils import logging, model_utils, timers
__all__ = ['MTEncDecModel']
@@ -765,7 +765,9 @@ class MTEncDecModel(EncDecNLPModel):
return translations
@torch.no_grad()
def batch_translate(
    self, src: torch.LongTensor, src_mask: torch.LongTensor, return_beam_scores: bool = False, cache={}
):
"""
Translates a minibatch of inputs from source language to target language.
Args:
@@ -776,12 +778,20 @@ class MTEncDecModel(EncDecNLPModel):
inputs: a list of string containing detokenized inputs
"""
mode = self.training
timer = cache.get("timer", None)
try:
self.eval()
if timer is not None:
timer.start("encoder")
src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask)
if timer is not None:
timer.stop("encoder")
timer.start("sampler")
best_translations = self.beam_search(
encoder_hidden_states=src_hiddens, encoder_input_mask=src_mask, return_beam_scores=return_beam_scores
)
if timer is not None:
timer.stop("sampler")
if return_beam_scores:
all_translations, scores, best_translations = best_translations
scores = scores.view(-1)
@@ -827,7 +837,12 @@ class MTEncDecModel(EncDecNLPModel):
# TODO: We should drop source/target_lang arguments in favor of using self.src/tgt_language
@torch.no_grad()
def translate(
self,
text: List[str],
source_lang: str = None,
target_lang: str = None,
return_beam_scores: bool = False,
log_timing: bool = False,
) -> List[str]:
"""
Translates list of sentences from source language to target language.
@@ -855,19 +870,41 @@ class MTEncDecModel(EncDecNLPModel):
elif tgt_symbol in self.multilingual_ids:
prepend_ids = [tgt_symbol]
if log_timing:
timer = timers.NamedTimer()
else:
timer = None
cache = {
"timer": timer,
}
try:
self.eval()
src, src_mask = self.prepare_inference_batch(text, prepend_ids)
if return_beam_scores:
_, all_translations, scores, best_translations = self.batch_translate(
src, src_mask, return_beam_scores=True, cache=cache,
)
return_val = all_translations, scores, best_translations
else:
_, best_translations = self.batch_translate(src, src_mask, return_beam_scores=False, cache=cache)
return_val = best_translations
finally:
self.train(mode=mode)
if log_timing:
timing = timer.export()
timing["mean_src_length"] = src_mask.sum().cpu().item() / src_mask.shape[0]
tgt, tgt_mask = self.prepare_inference_batch(best_translations, prepend_ids)
timing["mean_tgt_length"] = tgt_mask.sum().cpu().item() / tgt_mask.shape[0]
if type(return_val) is tuple:
return_val = return_val + (timing,)
else:
return_val = (return_val, timing)
return return_val
@classmethod
def list_available_models(cls) -> Optional[Dict[str, str]]:
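For reference, the timer driven in translate()/batch_translate() above is nemo.utils.timers.NamedTimer. A minimal sketch of the start/stop/export cycle as used in this diff (the sleeps stand in for real work; the exact schema of export() is whatever NamedTimer returns):

import time

from nemo.utils import timers

timer = timers.NamedTimer()
timer.start("encoder")
time.sleep(0.1)  # stand-in for the encoder forward pass
timer.stop("encoder")
timer.start("sampler")
time.sleep(0.2)  # stand-in for beam search
timer.stop("sampler")
timing = timer.export()  # a dict; translate() above adds mean_src_length/mean_tgt_length to it
print(timing)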


@@ -0,0 +1,147 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import multiprocessing as mp
import os
from functools import partial
import numpy as np
from matplotlib import pyplot as plt
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
# =============================================================================#
# Auxiliary methods
# =============================================================================#
def read_batch(fh, batch_size):
    """
    Reads a batch (or a smaller, final) chunk of lines.
    """
    lines = []
    for _ in range(batch_size):
        line = fh.readline()
        # readline() returns an empty string at EOF (it never returns None)
        if not line:
            break
        lines.append(line.strip())
    return lines
def tokenize_line(line, tokenizer):
"""
Returns a tokenized line
"""
tokens = tokenizer.text_to_ids(line)
return tokens
def line_len(line, tokenizer):
"""
Returns a tokenized length of a text line
"""
tokens = tokenize_line(line, tokenizer)
return len(tokens)
# =============================================================================#
# Main script
# =============================================================================#
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Collects statistics over tokenized dataset')
parser.add_argument('input_files', metavar='N', type=str, nargs='+', help='Input files to parse')
parser.add_argument(
    '--tokenizer_library', type=str, required=True, help='NeMo tokenizer library to use (e.g., sentencepiece, yttm)'
)
parser.add_argument(
'--tokenizer_model', type=str, required=True, help='Path to pre-trained nemo-supported tokenizer model'
)
parser.add_argument(
'--num_workers', type=int, default=mp.cpu_count(), help='Number of workers (default to number of CPUs)'
)
parser.add_argument('--max_lines', type=int, default=-1, help='Max number of lines to parse')
parser.add_argument('--batch_size', type=int, default=10000000, help='Batch size to parse in parallel')
parser.add_argument('--out_dir', type=str, default="", help='Path to store data and plots')
args = parser.parse_args()
tokenizer = get_nmt_tokenizer(library=args.tokenizer_library, tokenizer_model=args.tokenizer_model,)
all_len = []
for fn in args.input_files:
print(f"Parsing fn = {fn}")
# read file
fh = open(fn)
# read all batches
while True:
lines = read_batch(fh, args.batch_size)
# move to next file when no lines are read
if not lines:
break
# tokenize lines
with mp.Pool(args.num_workers) as p:
all_len.extend(p.map(partial(line_len, tokenizer=tokenizer), lines))
print(f"{fn}: Parsed {len(all_len)} lines")
    # early stop, if required
    if (args.max_lines > 0) and (len(all_len) >= args.max_lines):
        all_len = all_len[: args.max_lines]
        break

# early stop, if required
if (args.max_lines > 0) and (len(all_len) >= args.max_lines):
    all_len = all_len[: args.max_lines]
    break
# compute stats
stats = {
"samples": int(len(all_len)),
"mean": float(np.mean(all_len)),
"stdev": float(np.std(all_len)),
"min": float(np.min(all_len)),
"max": float(np.max(all_len)),
"median": float(np.median(all_len)),
}
print(f"stats = \n{stats}")
# save all results
if args.out_dir:
    os.makedirs(args.out_dir, exist_ok=True)
    with open(os.path.join(args.out_dir, "lengths.txt"), "w") as fh:
        fh.writelines([f"{l}\n" for l in all_len])
    with open(os.path.join(args.out_dir, "stats.json"), "w") as fh:
        json.dump(stats, fh)
    plt.hist(all_len)
    plt.savefig(os.path.join(args.out_dir, "lengths_hist.pdf"))
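The script writes lengths.txt, stats.json, and lengths_hist.pdf into --out_dir. A quick, hedged check of the stats output (the directory name is a placeholder):

import json

with open("out_dir/stats.json") as fh:  # placeholder --out_dir
    stats = json.load(fh)

# keys written by the script above
print(stats["samples"], stats["mean"], stats["stdev"], stats["min"], stats["max"], stats["median"])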


@@ -0,0 +1,99 @@
#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input XXXX.json files
(i.e., the output of nmt_transformer_infer.py --write_timing)
and creates plots XXXX.PLOT_NAME.pdf (see PLOTS_EXT) at the same path.
"""
import json
import os
import sys
from matplotlib import pyplot as plt
# =============================================================================#
# Control Variables
# =============================================================================#
PLOTS_EXT = "pdf"
PLOT_TITLE = False
PLOT_XLABEL = True
PLOT_YLABEL = True
PLOT_LABEL_FONT_SIZE = 16
PLOT_GRID = True
# =============================================================================#
# Helper functions
# =============================================================================#
def plot_timing(lengths, timings, lengths_name, timings_name, fig=None):
if fig is None:
fig = plt.figure()
plt.scatter(lengths, timings, label=timings_name)
if PLOT_XLABEL:
plt.xlabel(f"{lengths_name} [tokens]", fontsize=PLOT_LABEL_FONT_SIZE)
if PLOT_YLABEL:
plt.ylabel(f"{timings_name} [sec]", fontsize=PLOT_LABEL_FONT_SIZE)
if PLOT_GRID:
plt.grid(True)
if PLOT_TITLE:
plt.title(f"{timings_name} vs. {lengths_name}")
plt.xticks(fontsize=PLOT_LABEL_FONT_SIZE)
plt.yticks(fontsize=PLOT_LABEL_FONT_SIZE)
plt.tight_layout()
return fig
# =============================================================================#
# Main script
# =============================================================================#
if __name__ == "__main__":
print("Usage: plot_detailed_timing.py <JSON FILE> <SJON FILE> ...")
for timing_fn in sys.argv[1:]:
# load data
print(f"Parsing file = {timing_fn}")
with open(timing_fn) as json_fh:
    data = json.load(json_fh)
# plot data
figs_dict = {}
figs_dict["encoder-src_len"] = plot_timing(
lengths=data["mean_src_length"],
timings=data["encoder"],
lengths_name="src length",
timings_name="encoder",
)
gifs_dict["sampler-src_len"] = plot_timing(
lengths=data["mean_src_length"],
timings=data["sampler"],
lengths_name="src length",
timings_name="sampler",
)
gifs_dict["sampler-tgt_len"] = plot_timing(
lengths=data["mean_tgt_length"],
timings=data["sampler"],
lengths_name="tgt length",
timings_name="sampler",
)
# save data
base_fn = os.path.splitext(timing_fn)[0]
for name, fig in figs_dict.items():
plot_fn = f"{base_fn}.{name}.{PLOTS_EXT}"
print(f"Saving pot = {plot_fn}")
fig.savefig(plot_fn)
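A hedged usage example for plot_timing above, with toy data (values are illustrative only; the module name is assumed from the usage string above):

from plot_detailed_timing import plot_timing

fig = plot_timing(
    lengths=[5, 10, 20],
    timings=[0.01, 0.02, 0.05],
    lengths_name="src length",
    timings_name="encoder",
)
fig.savefig("toy.encoder-src_len.pdf")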