delete nltk download for TN (#3028)
* delete nltk — Signed-off-by: Yang Zhang <yangzhang@nvidia.com>
* style fix — Signed-off-by: Yang Zhang <yangzhang@nvidia.com>
* remove unused import — Signed-off-by: Yang Zhang <yangzhang@nvidia.com>
This commit is contained in:
parent
db47f5dcd6
commit
620e8a8986
|
@@ -17,7 +17,6 @@ import os
|
|||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import nltk
|
||||
import torch
|
||||
import wordninja
|
||||
from omegaconf import DictConfig
|
||||
|
@@ -44,8 +43,6 @@ try:
|
|||
except (ModuleNotFoundError, ImportError):
|
||||
PYNINI_AVAILABLE = False
|
||||
|
||||
nltk.download('punkt')
|
||||
|
||||
|
||||
__all__ = ['DuplexDecoderModel']
|
||||
|
||||
|
|
|
@@ -15,7 +15,6 @@
|
|||
from time import perf_counter
|
||||
from typing import List, Optional
|
||||
|
||||
import nltk
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from pytorch_lightning import Trainer
|
||||
|
@@ -31,9 +30,6 @@ from nemo.core.classes.common import PretrainedModelInfo
|
|||
from nemo.utils import logging
|
||||
from nemo.utils.decorators.experimental import experimental
|
||||
|
||||
nltk.download('punkt')
|
||||
|
||||
|
||||
__all__ = ['DuplexTaggerModel']
|
||||
|
||||
|
||||
|
|
|
@@ -21,7 +21,7 @@ import torch.nn as nn
|
|||
from tqdm import tqdm
|
||||
|
||||
from nemo.collections.nlp.data.text_normalization import TextNormalizationTestDataset, constants
|
||||
from nemo.collections.nlp.data.text_normalization.utils import basic_tokenize, post_process_punct
|
||||
from nemo.collections.nlp.data.text_normalization.utils import post_process_punct
|
||||
from nemo.collections.nlp.models.duplex_text_normalization.utils import get_formatted_string
|
||||
from nemo.utils import logging
|
||||
from nemo.utils.decorators.experimental import experimental
|
||||
|
@@ -195,7 +195,7 @@ class DuplexTextNormalizationModel(nn.Module):
|
|||
error_f.write('Forward Problem (TN)\n')
|
||||
tn_error_ctx += 1
|
||||
|
||||
formatted_input_str = get_formatted_string(basic_tokenize(_input, lang=self.lang))
|
||||
formatted_input_str = get_formatted_string(self.decoder.processor.tokenize(_input).split())
|
||||
formatted_tag_pred_str = get_formatted_string(tag_pred)
|
||||
class_str = " ".join(classes)
|
||||
error_f.write(f'Original Input : {_input}\n')
|
||||
|
|
Loading…
Reference in a new issue