NMT timing and tokenizer stats utils (#3004)
* 1. Fixed logging of log_var_q_z_given_x. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Fixed style. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Updated timing of eval_step. 2. Added caching of encoder values when sampling. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Fixed using cached values in eval_step. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Updated timing names to include "_timing" Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added support in sliding mean. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Fixed style. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added the tokenizer stats script. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Fixed style. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Moved duplicated code into a method. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. 
Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added logging of detailed NMT timing for non-bottleneck and bottleneck models. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Fixed style Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added logging of both src and tgt length when measuring timing. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added plotting script. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added usage. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Fixed style. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added missing copyright. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added post-processing for timing data to be collcted from list of dict to a dict of lists. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added warmup batch when measuring timing. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Updated plot script. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Debugging. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Fixed a typo. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added control over plotting. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Fixed style. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added control over font size. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Added control over tick font size. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Fixed style. Signed-off-by: Micha Livne <mlivne@nvidia.com> * 1. Moved scripts. Signed-off-by: Micha Livne <mlivne@nvidia.com> Co-authored-by: Micha Livne <mlivne@nvidia.com>
This commit is contained in:
parent
1f36f32ee6
commit
4452162b93
|
@ -22,6 +22,7 @@ USAGE Example:
|
|||
"""
|
||||
|
||||
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import torch
|
||||
|
@ -35,6 +36,62 @@ from nemo.collections.nlp.modules.common.transformer import (
|
|||
from nemo.utils import logging
|
||||
|
||||
|
||||
def translate_text(
    models, args, src_text, tgt_text, tgt_text_all, src_texts, all_scores, all_timing, ensemble_generator
):
    """Translate one batch of source sentences, appending results to the caller's accumulators.

    Args:
        models: list of NMT models; with several models decoding goes through
            ``ensemble_generator``, with one model through ``model.translate``.
        args: parsed CLI namespace (reads source_lang, target_lang, beam_size,
            write_scores, write_timing).
        src_text: batch of raw source sentences.
        tgt_text: accumulator for best translations (mutated in place).
        tgt_text_all: accumulator for all beam hypotheses (mutated in place).
        src_texts: accumulator of sources repeated once per beam (mutated in place).
        all_scores: accumulator of beam scores (mutated in place).
        all_timing: accumulator of per-batch timing dicts (mutated in place).
        ensemble_generator: beam-search generator over the ensemble, or None
            when a single model is used.
    """
    if len(models) > 1:
        # Ensemble path: tokenize once, decode with the shared generator,
        # then detokenize the generator's raw ids.
        src_ids, src_mask = models[0].prepare_inference_batch(src_text)
        results = ensemble_generator(src_ids, src_mask, return_beam_scores=args.write_scores)
        if args.write_scores:
            hypotheses, scores, results = results[0], results[1], results[2]
            scores = scores.view(-1).data.cpu().numpy().tolist()
            all_scores += scores
            src_texts += [sent for sent in src_text for _ in range(args.beam_size)]
            hypotheses = models[0].ids_to_postprocessed_text(
                hypotheses, models[0].decoder_tokenizer, models[0].target_processor
            )
            tgt_text_all += hypotheses
        results = models[0].ids_to_postprocessed_text(
            results, models[0].decoder_tokenizer, models[0].target_processor
        )
        tgt_text += results
    else:
        # Single-model path: the model handles (de)tokenization internally.
        model = models[0]
        translations = model.translate(
            text=src_text,
            source_lang=args.source_lang,
            target_lang=args.target_lang,
            return_beam_scores=args.write_scores,
            log_timing=args.write_timing,
        )

        if args.write_timing:
            # When timing is logged, the timing dict rides along as the last
            # element of the returned value.
            *translations, timing_dict = translations
            all_timing.append(timing_dict)
        else:
            # Normalize to a tuple so indexing below works in both modes.
            translations = (translations,)

        if args.write_scores:
            hypotheses, scores, best = translations[0], translations[1], translations[2]
            all_scores += scores
            src_texts += [sent for sent in src_text for _ in range(args.beam_size)]
            tgt_text_all += hypotheses
        else:
            best = translations[0]

        tgt_text += best

    print(f"Translated {len(tgt_text)} sentences")
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument(
|
||||
|
@ -71,6 +128,11 @@ def main():
|
|||
action="store_true",
|
||||
help="Whether to write a separate file with scores not including length penalties corresponding to each beam hypothesis (.score suffix)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--write_timing",
|
||||
action="store_true",
|
||||
help="Whether to write a separate file with detailed timing info (.timing.json suffix)",
|
||||
)
|
||||
# shallow fusion specific parameters
|
||||
parser.add_argument(
|
||||
"--lm_model",
|
||||
|
@ -92,11 +154,15 @@ def main():
|
|||
model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=model_path).eval()
|
||||
models.append(model)
|
||||
|
||||
if (len(models) > 1) and (args.write_timing):
|
||||
raise RuntimeError("Cannot measure timing when more than 1 model is used")
|
||||
|
||||
src_text = []
|
||||
tgt_text = []
|
||||
tgt_text_all = []
|
||||
src_texts = []
|
||||
all_scores = []
|
||||
all_timing = []
|
||||
|
||||
if torch.cuda.is_available():
|
||||
models = [model.cuda() for model in models]
|
||||
|
@ -124,6 +190,7 @@ def main():
|
|||
)
|
||||
else:
|
||||
model = models[0]
|
||||
ensemble_generator = None
|
||||
if lm_model is not None:
|
||||
model.beam_search = BeamSearchSequenceGeneratorWithLanguageModel(
|
||||
embedding=model.decoder.embedding,
|
||||
|
@ -155,91 +222,49 @@ def main():
|
|||
|
||||
logging.info(f"Translating: {args.srctext}")
|
||||
|
||||
count = 0
|
||||
with open(args.srctext, 'r') as src_f:
|
||||
for line in src_f:
|
||||
src_text.append(line.strip())
|
||||
if len(src_text) == args.batch_size:
|
||||
if len(models) > 1:
|
||||
src_ids, src_mask = models[0].prepare_inference_batch(src_text)
|
||||
best_translations = ensemble_generator(src_ids, src_mask, return_beam_scores=args.write_scores)
|
||||
if args.write_scores:
|
||||
all_results, scores, best_translations = (
|
||||
best_translations[0],
|
||||
best_translations[1],
|
||||
best_translations[2],
|
||||
)
|
||||
scores = scores.view(-1).data.cpu().numpy().tolist()
|
||||
all_scores += scores
|
||||
src_texts += [item for item in src_text for i in range(args.beam_size)]
|
||||
all_results = models[0].ids_to_postprocessed_text(
|
||||
all_results, models[0].decoder_tokenizer, models[0].target_processor
|
||||
)
|
||||
tgt_text_all += all_results
|
||||
best_translations = models[0].ids_to_postprocessed_text(
|
||||
best_translations, models[0].decoder_tokenizer, models[0].target_processor
|
||||
# warmup when measuring timing
|
||||
if not all_timing:
|
||||
print("running a warmup batch")
|
||||
translate_text(
|
||||
models=models,
|
||||
args=args,
|
||||
src_text=src_text,
|
||||
tgt_text=[],
|
||||
tgt_text_all=[],
|
||||
src_texts=[],
|
||||
all_scores=[],
|
||||
all_timing=[],
|
||||
ensemble_generator=ensemble_generator,
|
||||
)
|
||||
tgt_text += best_translations
|
||||
else:
|
||||
best_translations = model.translate(
|
||||
text=src_text,
|
||||
source_lang=args.source_lang,
|
||||
target_lang=args.target_lang,
|
||||
return_beam_scores=args.write_scores,
|
||||
)
|
||||
if args.write_scores:
|
||||
all_results, scores, best_translations = (
|
||||
best_translations[0],
|
||||
best_translations[1],
|
||||
best_translations[2],
|
||||
)
|
||||
all_scores += scores
|
||||
src_texts += [item for item in src_text for i in range(args.beam_size)]
|
||||
tgt_text_all += all_results
|
||||
tgt_text += best_translations
|
||||
translate_text(
|
||||
models=models,
|
||||
args=args,
|
||||
src_text=src_text,
|
||||
tgt_text=tgt_text,
|
||||
tgt_text_all=tgt_text_all,
|
||||
src_texts=src_texts,
|
||||
all_scores=all_scores,
|
||||
all_timing=all_timing,
|
||||
ensemble_generator=ensemble_generator,
|
||||
)
|
||||
src_text = []
|
||||
print(f"Translated {count + 1} sentences")
|
||||
count += 1
|
||||
|
||||
if len(src_text) > 0:
|
||||
if len(models) > 1:
|
||||
src_ids, src_mask = models[0].prepare_inference_batch(src_text)
|
||||
best_translations = ensemble_generator(src_ids, src_mask, return_beam_scores=args.write_scores)
|
||||
if args.write_scores:
|
||||
all_results, scores, best_translations = (
|
||||
best_translations[0],
|
||||
best_translations[1],
|
||||
best_translations[2],
|
||||
)
|
||||
scores = scores.view(-1).data.cpu().numpy().tolist()
|
||||
all_scores += scores
|
||||
src_texts += [item for item in src_text for i in range(args.beam_size)]
|
||||
all_results = models[0].ids_to_postprocessed_text(
|
||||
all_results, models[0].decoder_tokenizer, models[0].target_processor
|
||||
)
|
||||
tgt_text_all += all_results
|
||||
best_translations = models[0].ids_to_postprocessed_text(
|
||||
best_translations, models[0].decoder_tokenizer, models[0].target_processor
|
||||
)
|
||||
tgt_text += best_translations
|
||||
else:
|
||||
best_translations = model.translate(
|
||||
text=src_text,
|
||||
source_lang=args.source_lang,
|
||||
target_lang=args.target_lang,
|
||||
return_beam_scores=args.write_scores,
|
||||
)
|
||||
if args.write_scores:
|
||||
all_results, scores, best_translations = (
|
||||
best_translations[0],
|
||||
best_translations[1],
|
||||
best_translations[2],
|
||||
)
|
||||
all_scores += scores
|
||||
src_texts += [item for item in src_text for i in range(args.beam_size)]
|
||||
tgt_text_all += all_results
|
||||
tgt_text += best_translations
|
||||
src_text = []
|
||||
print(f"Translated {count} sentences")
|
||||
translate_text(
|
||||
models=models,
|
||||
args=args,
|
||||
src_text=src_text,
|
||||
tgt_text=tgt_text,
|
||||
tgt_text_all=tgt_text_all,
|
||||
src_texts=src_texts,
|
||||
all_scores=all_scores,
|
||||
all_timing=all_timing,
|
||||
ensemble_generator=ensemble_generator,
|
||||
)
|
||||
|
||||
with open(args.tgtout, 'w') as tgt_f:
|
||||
for line in tgt_text:
|
||||
|
@ -250,6 +275,16 @@ def main():
|
|||
for line, score, inp in zip(tgt_text_all, all_scores, src_texts):
|
||||
tgt_f_scores.write(inp + "\t" + line + "\t" + str(score) + "\n")
|
||||
|
||||
if args.write_timing:
|
||||
# collect list of dicts to a dict of lists
|
||||
timing_dict = {}
|
||||
if len(all_timing):
|
||||
for k in all_timing[0].keys():
|
||||
timing_dict[k] = [t[k] for t in all_timing]
|
||||
|
||||
with open(args.tgtout + '.timing.json', 'w') as timing_fh:
|
||||
json.dump(timing_dict, timing_fh)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main() # noqa pylint: disable=no-value-for-parameter
|
||||
|
|
|
@ -317,13 +317,18 @@ class MTBottleneckModel(MTEncDecModel):
|
|||
inputs: a list of string containing detokenized inputs
|
||||
"""
|
||||
mode = self.training
|
||||
timer = cache.get("timer", None)
|
||||
try:
|
||||
self.eval()
|
||||
|
||||
# build posterior distribution q(x|z)
|
||||
if ("z" not in cache) or ("z_mean" not in cache) or ("z_mask" not in cache):
|
||||
if timer is not None:
|
||||
timer.start("encoder")
|
||||
enc_hiddens, enc_mask = self.encoder(input_ids=src, encoder_mask=src_mask, return_mask=True)
|
||||
z, z_mean, _ = self.encode_latent(hidden=enc_hiddens)
|
||||
if timer is not None:
|
||||
timer.stop("encoder")
|
||||
else:
|
||||
enc_mask = cache["z_mask"]
|
||||
z = cache["z"]
|
||||
|
@ -332,8 +337,8 @@ class MTBottleneckModel(MTEncDecModel):
|
|||
if getattr(self, "deterministic_translate", True):
|
||||
z = z_mean
|
||||
|
||||
if cache.get("timer", None) is not None:
|
||||
cache["timer"].start("sampler")
|
||||
if timer is not None:
|
||||
timer.start("sampler")
|
||||
# decoding cross attention context
|
||||
context_hiddens = self.latent2hidden(z)
|
||||
|
||||
|
@ -342,8 +347,8 @@ class MTBottleneckModel(MTEncDecModel):
|
|||
encoder_input_mask=enc_mask,
|
||||
return_beam_scores=return_beam_scores,
|
||||
)
|
||||
if cache.get("timer", None) is not None:
|
||||
cache["timer"].stop("sampler")
|
||||
if timer is not None:
|
||||
timer.stop("sampler")
|
||||
|
||||
if return_beam_scores:
|
||||
all_translations, scores, best_translations = best_translations
|
||||
|
|
|
@ -46,7 +46,7 @@ from nemo.collections.nlp.modules.common.lm_utils import get_transformer
|
|||
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
|
||||
from nemo.collections.nlp.modules.common.transformer import BeamSearchSequenceGenerator, TopKSequenceGenerator
|
||||
from nemo.core.classes.common import PretrainedModelInfo, typecheck
|
||||
from nemo.utils import logging, model_utils
|
||||
from nemo.utils import logging, model_utils, timers
|
||||
|
||||
__all__ = ['MTEncDecModel']
|
||||
|
||||
|
@ -765,7 +765,9 @@ class MTEncDecModel(EncDecNLPModel):
|
|||
return translations
|
||||
|
||||
@torch.no_grad()
|
||||
def batch_translate(self, src: torch.LongTensor, src_mask: torch.LongTensor, return_beam_scores: bool = False):
|
||||
def batch_translate(
|
||||
self, src: torch.LongTensor, src_mask: torch.LongTensor, return_beam_scores: bool = False, cache={}
|
||||
):
|
||||
"""
|
||||
Translates a minibatch of inputs from source language to target language.
|
||||
Args:
|
||||
|
@ -776,12 +778,20 @@ class MTEncDecModel(EncDecNLPModel):
|
|||
inputs: a list of string containing detokenized inputs
|
||||
"""
|
||||
mode = self.training
|
||||
timer = cache.get("timer", None)
|
||||
try:
|
||||
self.eval()
|
||||
if timer is not None:
|
||||
timer.start("encoder")
|
||||
src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask)
|
||||
if timer is not None:
|
||||
timer.stop("encoder")
|
||||
timer.start("sampler")
|
||||
best_translations = self.beam_search(
|
||||
encoder_hidden_states=src_hiddens, encoder_input_mask=src_mask, return_beam_scores=return_beam_scores
|
||||
)
|
||||
if timer is not None:
|
||||
timer.stop("sampler")
|
||||
if return_beam_scores:
|
||||
all_translations, scores, best_translations = best_translations
|
||||
scores = scores.view(-1)
|
||||
|
@ -827,7 +837,12 @@ class MTEncDecModel(EncDecNLPModel):
|
|||
# TODO: We should drop source/target_lang arguments in favor of using self.src/tgt_language
|
||||
@torch.no_grad()
|
||||
def translate(
|
||||
self, text: List[str], source_lang: str = None, target_lang: str = None, return_beam_scores: bool = False
|
||||
self,
|
||||
text: List[str],
|
||||
source_lang: str = None,
|
||||
target_lang: str = None,
|
||||
return_beam_scores: bool = False,
|
||||
log_timing: bool = False,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Translates list of sentences from source language to target language.
|
||||
|
@ -855,19 +870,41 @@ class MTEncDecModel(EncDecNLPModel):
|
|||
elif tgt_symbol in self.multilingual_ids:
|
||||
prepend_ids = [tgt_symbol]
|
||||
|
||||
if log_timing:
|
||||
timer = timers.NamedTimer()
|
||||
else:
|
||||
timer = None
|
||||
|
||||
cache = {
|
||||
"timer": timer,
|
||||
}
|
||||
|
||||
try:
|
||||
self.eval()
|
||||
src, src_mask = self.prepare_inference_batch(text, prepend_ids)
|
||||
if return_beam_scores:
|
||||
_, all_translations, scores, best_translations = self.batch_translate(
|
||||
src, src_mask, return_beam_scores=True
|
||||
src, src_mask, return_beam_scores=True, cache=cache,
|
||||
)
|
||||
return all_translations, scores, best_translations
|
||||
return_val = all_translations, scores, best_translations
|
||||
else:
|
||||
_, translations = self.batch_translate(src, src_mask, return_beam_scores=False)
|
||||
_, best_translations = self.batch_translate(src, src_mask, return_beam_scores=False, cache=cache)
|
||||
return_val = best_translations
|
||||
finally:
|
||||
self.train(mode=mode)
|
||||
return translations
|
||||
|
||||
if log_timing:
|
||||
timing = timer.export()
|
||||
timing["mean_src_length"] = src_mask.sum().cpu().item() / src_mask.shape[0]
|
||||
tgt, tgt_mask = self.prepare_inference_batch(best_translations, prepend_ids)
|
||||
timing["mean_tgt_length"] = tgt_mask.sum().cpu().item() / tgt_mask.shape[0]
|
||||
|
||||
if type(return_val) is tuple:
|
||||
return_val = return_val + (timing,)
|
||||
else:
|
||||
return_val = (return_val, timing)
|
||||
|
||||
return return_val
|
||||
|
||||
@classmethod
|
||||
def list_available_models(cls) -> Optional[Dict[str, str]]:
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
|
||||
|
||||
# =============================================================================#
|
||||
# Auxiliary methods
|
||||
# =============================================================================#
|
||||
|
||||
|
||||
def read_batch(fh, batch_size):
    """Read up to ``batch_size`` stripped lines from an open file handle.

    Returns fewer than ``batch_size`` lines (possibly an empty list) once the
    end of file is reached, so callers can use an empty result as the
    end-of-input signal.

    Args:
        fh: open text file handle positioned at the next line to read.
        batch_size: maximum number of lines to read.

    Returns:
        list of str: stripped lines read from ``fh``.
    """
    lines = []
    for _ in range(batch_size):
        line = fh.readline()
        # Bug fix: readline() returns an empty string (not None) at EOF, so
        # the original `if l is None` check never fired and every batch came
        # back full of empty strings, making the caller's read loop spin
        # forever. An empty string is the correct EOF sentinel.
        if not line:
            break
        lines.append(line.strip())

    return lines
|
||||
|
||||
|
||||
def tokenize_line(line, tokenizer):
    """Convert a text line into its token-id sequence using *tokenizer*."""
    return tokenizer.text_to_ids(line)
|
||||
|
||||
|
||||
def line_len(line, tokenizer):
    """Return the tokenized length (number of token ids) of a text line."""
    # Tokenize directly; equivalent to len(tokenize_line(line, tokenizer)).
    return len(tokenizer.text_to_ids(line))
|
||||
|
||||
|
||||
# =============================================================================#
# Main script
# =============================================================================#
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Collects statistics over tokenized dataset')
    parser.add_argument('input_files', metavar='N', type=str, nargs='+', help='Input files to parse')
    parser.add_argument(
        # Bug fix: help text was copy-pasted from --tokenizer_model.
        '--tokenizer_library',
        type=str,
        required=True,
        help='Library of the nemo-supported tokenizer (passed to get_nmt_tokenizer)',
    )
    parser.add_argument(
        '--tokenizer_model', type=str, required=True, help='Path to pre-trained nemo-supported tokenizer model'
    )
    parser.add_argument(
        '--num_workers', type=int, default=mp.cpu_count(), help='Number of workers (default to number of CPUs)'
    )
    parser.add_argument('--max_lines', type=int, default=-1, help='Max number of lines to parse')
    parser.add_argument('--batch_size', type=int, default=10000000, help='Batch size to parse in parallel')
    parser.add_argument('--out_dir', type=str, default="", help='Path to store data and plots')

    args = parser.parse_args()

    tokenizer = get_nmt_tokenizer(library=args.tokenizer_library, tokenizer_model=args.tokenizer_model,)

    # tokenized length of every parsed line, across all input files
    all_len = []

    for fn in args.input_files:
        print(f"Parsing fn = {fn}")
        # read file in batches so arbitrarily large corpora fit in memory
        # (with-block fixes the previously leaked file handle)
        with open(fn) as fh:
            while True:
                lines = read_batch(fh, args.batch_size)

                # move to next file when no lines are read
                if not lines:
                    break

                # tokenize lines in parallel
                with mp.Pool(args.num_workers) as p:
                    all_len.extend(p.map(partial(line_len, tokenizer=tokenizer), lines))

                print(f"{fn}: Parsed {len(all_len)} lines")

                # early stop, if required
                # (bug fix: truncate the collected lengths, not the raw lines)
                if (args.max_lines > 0) and (len(all_len) >= args.max_lines):
                    all_len = all_len[: args.max_lines]
                    break

        # early stop over remaining files, if required
        if (args.max_lines > 0) and (len(all_len) >= args.max_lines):
            all_len = all_len[: args.max_lines]
            break

    # compute stats
    stats = {
        "samples": int(len(all_len)),
        "mean": float(np.mean(all_len)),
        "stdev": float(np.std(all_len)),
        "min": float(np.min(all_len)),
        "max": float(np.max(all_len)),
        "median": float(np.median(all_len)),
    }

    print(f"stats = \n{stats}")

    # save all results (single consolidated block; the original created the
    # directory twice, once with mkdir and once with makedirs)
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

        with open(os.path.join(args.out_dir, "lengths.txt"), "w") as lengths_fh:
            lengths_fh.writelines(["{l}\n".format(l=l) for l in all_len])

        with open(os.path.join(args.out_dir, "stats.json"), "w") as stats_fh:
            json.dump(stats, stats_fh)

        plt.hist(all_len)
        plt.savefig(os.path.join(args.out_dir, "lengths_hist.pdf"))
|
99
scripts/neural_machine_translation/plot_detailed_timing.py
Executable file
99
scripts/neural_machine_translation/plot_detailed_timing.py
Executable file
|
@ -0,0 +1,99 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script takes as an input XXXX.json files
|
||||
(i.e., the output of nmt_transformer_infer.py --write_timing)
|
||||
and creates plots XXX.PLOT_NAME.png at the same path.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
# =============================================================================#
|
||||
# Control Variables
|
||||
# =============================================================================#
|
||||
|
||||
PLOTS_EXT = "pdf"
|
||||
PLOT_TITLE = False
|
||||
PLOT_XLABEL = True
|
||||
PLOT_YLABEL = True
|
||||
PLOT_LABEL_FONT_SIZE = 16
|
||||
PLOT_GRID = True
|
||||
|
||||
# =============================================================================#
|
||||
# Helper functions
|
||||
# =============================================================================#
|
||||
|
||||
|
||||
def plot_timing(lengths, timings, lengths_name, timings_name, fig=None):
    """Scatter-plot timing measurements against sentence lengths.

    Args:
        lengths: x-axis values (mean lengths, in tokens).
        timings: y-axis values (timings, in seconds).
        lengths_name: description of the lengths, used for the x label/title.
        timings_name: description of the timings, used for the legend,
            y label and title.
        fig: existing matplotlib figure to draw into; a fresh figure is
            created when None.

    Returns:
        The matplotlib figure the data was drawn into.
    """
    fig = plt.figure() if fig is None else fig

    plt.scatter(lengths, timings, label=timings_name)

    # decorations are globally switchable via the module-level PLOT_* flags
    if PLOT_GRID:
        plt.grid(True)
    if PLOT_TITLE:
        plt.title(f"{timings_name} vs. {lengths_name}")
    if PLOT_XLABEL:
        plt.xlabel(f"{lengths_name} [tokens]", fontsize=PLOT_LABEL_FONT_SIZE)
    if PLOT_YLABEL:
        plt.ylabel(f"{timings_name} [sec]", fontsize=PLOT_LABEL_FONT_SIZE)

    plt.xticks(fontsize=PLOT_LABEL_FONT_SIZE)
    plt.yticks(fontsize=PLOT_LABEL_FONT_SIZE)
    plt.tight_layout()

    return fig
|
||||
|
||||
|
||||
# =============================================================================#
# Main script
# =============================================================================#
if __name__ == "__main__":
    # Bug fix: usage message said "<SJON FILE>".
    print("Usage: plot_detailed_timing.py <JSON FILE> <JSON FILE> ...")
    for timing_fn in sys.argv[1:]:
        # load data (with-block fixes the previously leaked file handle)
        print(f"Parsing file = {timing_fn}")
        with open(timing_fn) as json_fh:
            data = json.load(json_fh)

        # plot timing vs. source/target length
        # NOTE(review): assumes the JSON contains the keys written by
        # nmt_transformer_infer.py --write_timing ("encoder", "sampler",
        # "mean_src_length", "mean_tgt_length") — confirm against that script.
        figs_dict = {}
        figs_dict["encoder-src_len"] = plot_timing(
            lengths=data["mean_src_length"],
            timings=data["encoder"],
            lengths_name="src length",
            timings_name="encoder",
        )
        figs_dict["sampler-src_len"] = plot_timing(
            lengths=data["mean_src_length"],
            timings=data["sampler"],
            lengths_name="src length",
            timings_name="sampler",
        )
        figs_dict["sampler-tgt_len"] = plot_timing(
            lengths=data["mean_tgt_length"],
            timings=data["sampler"],
            lengths_name="tgt length",
            timings_name="sampler",
        )

        # save plots next to the input file
        base_fn = os.path.splitext(timing_fn)[0]
        for name, fig in figs_dict.items():
            plot_fn = f"{base_fn}.{name}.{PLOTS_EXT}"
            # Bug fix: message said "Saving pot".
            print(f"Saving plot = {plot_fn}")
            fig.savefig(plot_fn)
|
Loading…
Reference in a new issue