delete nltk download for TN (#3028)

* delete nltk

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* style fix

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* remove unused import

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>
This commit is contained in:
Yang Zhang 2021-10-20 14:51:23 -07:00 committed by GitHub
parent db47f5dcd6
commit 620e8a8986
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 2 additions and 9 deletions

View file

@@ -17,7 +17,6 @@ import os
from collections import defaultdict
from typing import Dict, List, Optional, Union
import nltk
import torch
import wordninja
from omegaconf import DictConfig
@@ -44,8 +43,6 @@ try:
except (ModuleNotFoundError, ImportError):
PYNINI_AVAILABLE = False
nltk.download('punkt')
__all__ = ['DuplexDecoderModel']

View file

@@ -15,7 +15,6 @@
from time import perf_counter
from typing import List, Optional
import nltk
import torch
from omegaconf import DictConfig
from pytorch_lightning import Trainer
@@ -31,9 +30,6 @@ from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging
from nemo.utils.decorators.experimental import experimental
nltk.download('punkt')
__all__ = ['DuplexTaggerModel']

View file

@@ -21,7 +21,7 @@ import torch.nn as nn
from tqdm import tqdm
from nemo.collections.nlp.data.text_normalization import TextNormalizationTestDataset, constants
from nemo.collections.nlp.data.text_normalization.utils import basic_tokenize, post_process_punct
from nemo.collections.nlp.data.text_normalization.utils import post_process_punct
from nemo.collections.nlp.models.duplex_text_normalization.utils import get_formatted_string
from nemo.utils import logging
from nemo.utils.decorators.experimental import experimental
@@ -195,7 +195,7 @@ class DuplexTextNormalizationModel(nn.Module):
error_f.write('Forward Problem (TN)\n')
tn_error_ctx += 1
formatted_input_str = get_formatted_string(basic_tokenize(_input, lang=self.lang))
formatted_input_str = get_formatted_string(self.decoder.processor.tokenize(_input).split())
formatted_tag_pred_str = get_formatted_string(tag_pred)
class_str = " ".join(classes)
error_f.write(f'Original Input : {_input}\n')