Provide NMT gRPC models via a directory instead of individual files (#2773)

* Provide models via a directory instead of individual files

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
This commit is contained in:
Sandeep Subramanian 2021-09-09 22:58:45 -07:00 committed by GitHub
parent 66102b885c
commit 3a419ac1f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -28,10 +28,7 @@ from nemo.utils import logging
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
required=True,
action="append",
help="List of .nemo files specified by using --model xyz.nemo multiple times.",
"--model_dir", required=True, type=str, help="Path to a folder containing .nemo translation model files.",
)
parser.add_argument(
"--punctuation_model",
@ -40,12 +37,6 @@ def get_args():
help="Optionally provide a path a .nemo file for punctation and capitalization (recommend if working with Riva speech recognition outputs)",
)
parser.add_argument("--port", default=50052, type=int, required=False)
parser.add_argument(
"--lang_directions",
required=False,
action="append",
help="Use this arg if any of your models don't have the src_language or tgt_language attributes.",
)
parser.add_argument("--batch_size", type=int, default=256, help="Maximum number of batches to process")
parser.add_argument("--beam_size", type=int, default=1, help="Beam Size")
parser.add_argument("--len_pen", type=float, default=0.6, help="Length Penalty")
@ -65,14 +56,7 @@ class RivaTranslateServicer(nmtsrv.RivaTranslateServicer):
"""Provides methods that implement functionality of route guide server."""
def __init__(
self,
model_paths,
punctuation_model_path,
beam_size=1,
len_pen=0.6,
max_delta_length=5,
batch_size=256,
lang_directions=[],
self, model_dir, punctuation_model_path, beam_size=1, len_pen=0.6, max_delta_length=5, batch_size=256,
):
self._models = {}
self._beam_size = beam_size
@ -80,31 +64,14 @@ class RivaTranslateServicer(nmtsrv.RivaTranslateServicer):
self._max_delta_length = max_delta_length
self._batch_size = batch_size
self._punctuation_model_path = punctuation_model_path
self._model_paths = model_paths
self._lang_directions = lang_directions
self._model_dir = model_dir
# __TODO__ make this easier to use in the future.
# If lang direction is provided for one model, it must be provided for all of them.
model_paths = [os.path.join(model_dir, fname) for fname in os.listdir(model_dir) if fname.endswith('.nemo')]
if len(self._lang_directions) != 0:
if len(self._lang_directions) != len(self._model_paths):
raise ValueError(
f"Found a different number of models ({len(self._model_paths)}) and language directions ({len(self._lang_directions)})"
)
for idx, model_path in enumerate(self._model_paths):
for idx, model_path in enumerate(model_paths):
assert os.path.exists(model_path)
logging.info(f"Loading model {model_path}")
if len(self._lang_directions) > 0:
if len(self._lang_directions[idx].split('-')) != 2:
raise ValueError(
f"Language direction must of the form lang1-lang2, got {self._lang_directions[idx]}"
)
src_language = self._lang_directions[idx].split('-')[0]
tgt_language = self._lang_directions[idx].split('-')[1]
else:
src_language, tgt_language = None, None
self._load_model(model_path, src_language, tgt_language)
self._load_model(model_path)
if self._punctuation_model_path != "":
assert os.path.exists(punctuation_model_path)
@ -125,7 +92,7 @@ class RivaTranslateServicer(nmtsrv.RivaTranslateServicer):
if torch.cuda.is_available():
self.punctuation_model = self.punctuation_model.cuda()
def _load_model(self, model_path, src_language=None, tgt_language=None):
def _load_model(self, model_path):
if model_path.endswith(".nemo"):
logging.info("Attempting to initialize from .nemo file")
model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=model_path)
@ -138,19 +105,13 @@ class RivaTranslateServicer(nmtsrv.RivaTranslateServicer):
else:
raise NotImplemented(f"Only support .nemo files, but got: {model_path}")
if (not hasattr(model, "src_language") or not hasattr(model, "tgt_language")) and (
src_language is None or tgt_language is None
):
if not hasattr(model, "src_language") or not hasattr(model, "tgt_language"):
raise ValueError(
f"Could not find src_language and tgt_language in model attributes nor in --lang_directions. Please specify --lang_directions for all models. Ex: --lang_directions en-es --lang_directions en-de etc."
f"Could not find src_language and tgt_language in model attributes. If using NeMo rc1 checkpoints, please edit the config files to add model.src_language and model.tgt_language"
)
if src_language is not None and tgt_language is not None:
model.src_language = src_language
model.tgt_language = tgt_language
else:
src_language = model.src_language
tgt_language = model.tgt_language
src_language = model.src_language
tgt_language = model.tgt_language
if src_language not in self._models:
self._models[src_language] = {}
@ -196,13 +157,12 @@ def serve():
args = get_args()
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
servicer = RivaTranslateServicer(
model_paths=args.model,
model_dir=args.model_dir,
punctuation_model_path=args.punctuation_model,
beam_size=args.beam_size,
len_pen=args.len_pen,
batch_size=args.batch_size,
max_delta_length=args.max_delta_length,
lang_directions=args.lang_directions,
)
nmtsrv.add_RivaTranslateServicer_to_server(servicer, server)
server.add_insecure_port('[::]:' + str(args.port))