Provide NMT gRPC models via a directory instead of individual files (#2773)
* Provide models via a directory instead of individual files Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca> * Style fixes Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca> Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
This commit is contained in:
parent
66102b885c
commit
3a419ac1f8
|
@ -28,10 +28,7 @@ from nemo.utils import logging
|
|||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
required=True,
|
||||
action="append",
|
||||
help="List of .nemo files specified by using --model xyz.nemo multiple times.",
|
||||
"--model_dir", required=True, type=str, help="Path to a folder containing .nemo translation model files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--punctuation_model",
|
||||
|
@ -40,12 +37,6 @@ def get_args():
|
|||
help="Optionally provide a path a .nemo file for punctation and capitalization (recommend if working with Riva speech recognition outputs)",
|
||||
)
|
||||
parser.add_argument("--port", default=50052, type=int, required=False)
|
||||
parser.add_argument(
|
||||
"--lang_directions",
|
||||
required=False,
|
||||
action="append",
|
||||
help="Use this arg if any of your models don't have the src_language or tgt_language attributes.",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=256, help="Maximum number of batches to process")
|
||||
parser.add_argument("--beam_size", type=int, default=1, help="Beam Size")
|
||||
parser.add_argument("--len_pen", type=float, default=0.6, help="Length Penalty")
|
||||
|
@ -65,14 +56,7 @@ class RivaTranslateServicer(nmtsrv.RivaTranslateServicer):
|
|||
"""Provides methods that implement functionality of route guide server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_paths,
|
||||
punctuation_model_path,
|
||||
beam_size=1,
|
||||
len_pen=0.6,
|
||||
max_delta_length=5,
|
||||
batch_size=256,
|
||||
lang_directions=[],
|
||||
self, model_dir, punctuation_model_path, beam_size=1, len_pen=0.6, max_delta_length=5, batch_size=256,
|
||||
):
|
||||
self._models = {}
|
||||
self._beam_size = beam_size
|
||||
|
@ -80,31 +64,14 @@ class RivaTranslateServicer(nmtsrv.RivaTranslateServicer):
|
|||
self._max_delta_length = max_delta_length
|
||||
self._batch_size = batch_size
|
||||
self._punctuation_model_path = punctuation_model_path
|
||||
self._model_paths = model_paths
|
||||
self._lang_directions = lang_directions
|
||||
self._model_dir = model_dir
|
||||
|
||||
# __TODO__ make this easier to use in the future.
|
||||
# If lang direction is provided for one model, it must be provided for all of them.
|
||||
model_paths = [os.path.join(model_dir, fname) for fname in os.listdir(model_dir) if fname.endswith('.nemo')]
|
||||
|
||||
if len(self._lang_directions) != 0:
|
||||
if len(self._lang_directions) != len(self._model_paths):
|
||||
raise ValueError(
|
||||
f"Found a different number of models ({len(self._model_paths)}) and language directions ({len(self._lang_directions)})"
|
||||
)
|
||||
|
||||
for idx, model_path in enumerate(self._model_paths):
|
||||
for idx, model_path in enumerate(model_paths):
|
||||
assert os.path.exists(model_path)
|
||||
logging.info(f"Loading model {model_path}")
|
||||
if len(self._lang_directions) > 0:
|
||||
if len(self._lang_directions[idx].split('-')) != 2:
|
||||
raise ValueError(
|
||||
f"Language direction must of the form lang1-lang2, got {self._lang_directions[idx]}"
|
||||
)
|
||||
src_language = self._lang_directions[idx].split('-')[0]
|
||||
tgt_language = self._lang_directions[idx].split('-')[1]
|
||||
else:
|
||||
src_language, tgt_language = None, None
|
||||
self._load_model(model_path, src_language, tgt_language)
|
||||
self._load_model(model_path)
|
||||
|
||||
if self._punctuation_model_path != "":
|
||||
assert os.path.exists(punctuation_model_path)
|
||||
|
@ -125,7 +92,7 @@ class RivaTranslateServicer(nmtsrv.RivaTranslateServicer):
|
|||
if torch.cuda.is_available():
|
||||
self.punctuation_model = self.punctuation_model.cuda()
|
||||
|
||||
def _load_model(self, model_path, src_language=None, tgt_language=None):
|
||||
def _load_model(self, model_path):
|
||||
if model_path.endswith(".nemo"):
|
||||
logging.info("Attempting to initialize from .nemo file")
|
||||
model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=model_path)
|
||||
|
@ -138,19 +105,13 @@ class RivaTranslateServicer(nmtsrv.RivaTranslateServicer):
|
|||
else:
|
||||
raise NotImplemented(f"Only support .nemo files, but got: {model_path}")
|
||||
|
||||
if (not hasattr(model, "src_language") or not hasattr(model, "tgt_language")) and (
|
||||
src_language is None or tgt_language is None
|
||||
):
|
||||
if not hasattr(model, "src_language") or not hasattr(model, "tgt_language"):
|
||||
raise ValueError(
|
||||
f"Could not find src_language and tgt_language in model attributes nor in --lang_directions. Please specify --lang_directions for all models. Ex: --lang_directions en-es --lang_directions en-de etc."
|
||||
f"Could not find src_language and tgt_language in model attributes. If using NeMo rc1 checkpoints, please edit the config files to add model.src_language and model.tgt_language"
|
||||
)
|
||||
|
||||
if src_language is not None and tgt_language is not None:
|
||||
model.src_language = src_language
|
||||
model.tgt_language = tgt_language
|
||||
else:
|
||||
src_language = model.src_language
|
||||
tgt_language = model.tgt_language
|
||||
src_language = model.src_language
|
||||
tgt_language = model.tgt_language
|
||||
|
||||
if src_language not in self._models:
|
||||
self._models[src_language] = {}
|
||||
|
@ -196,13 +157,12 @@ def serve():
|
|||
args = get_args()
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
servicer = RivaTranslateServicer(
|
||||
model_paths=args.model,
|
||||
model_dir=args.model_dir,
|
||||
punctuation_model_path=args.punctuation_model,
|
||||
beam_size=args.beam_size,
|
||||
len_pen=args.len_pen,
|
||||
batch_size=args.batch_size,
|
||||
max_delta_length=args.max_delta_length,
|
||||
lang_directions=args.lang_directions,
|
||||
)
|
||||
nmtsrv.add_RivaTranslateServicer_to_server(servicer, server)
|
||||
server.add_insecure_port('[::]:' + str(args.port))
|
||||
|
|
Loading…
Reference in a new issue