TN updates (#2983)
* moses added
* updates to make eval with moses work
* clean up
* clean up
* fix init
* review

Signed-off-by: ekmb <ebakhturina@nvidia.com>
Co-authored-by: Yang Zhang <yzhang123@users.noreply.github.com>
parent 5b603fb80c
commit a0e8018fd7
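The central change in this diff is the switch from the custom `basic_tokenize` helper to Moses tokenization/detokenization via the `MosesProcessor` wrapper imported throughout the changed files. A minimal sketch of that round trip, assuming NeMo's `MosesProcessor` behaves as the calls in this diff suggest (tokenize returns a space-separated string, detokenize takes a list of tokens); the example strings are illustrative:

```python
# Sketch only: the Moses tokenize/detokenize round trip this commit adopts
# in place of basic_tokenize. Import path and call pattern are taken from the diff.
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor

processor = MosesProcessor(lang_id="en")

# Moses splits clitics and punctuation, e.g. "don't" -> "don 't", "12/3" -> "12 / 3"
tokens = processor.tokenize("don't call before 12/3, okay?").split()
print(tokens)

# Detokenization restores readable text for final predictions
print(processor.detokenize(tokens))
```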
@@ -120,7 +120,6 @@ data:
data_path: train.tsv # provide the full path to the file
batch_size: 64
shuffle: true
do_basic_tokenize: false
max_insts: -1 # Maximum number of instances (-1 means no limit)
# Refer to the text_normalization doc for more information about data augmentation
tagger_data_augmentation: false

@@ -137,7 +136,6 @@ data:
data_path: dev.tsv # provide the full path to the file. Provide multiple paths to run evaluation on multiple datasets
batch_size: 64
shuffle: false
do_basic_tokenize: false
max_insts: -1 # Maximum number of instances (-1 means no limit)
use_cache: false # uses a cache to store the processed dataset, you may use it for large datasets for speed up (especially when using multi GPUs)
num_workers: 3

@@ -148,7 +146,6 @@ data:
data_path: test.tsv # provide the full path to the file
batch_size: 64
shuffle: false
do_basic_tokenize: false
use_cache: false # uses a cache to store the processed dataset, you may use it for large datasets for speed up (especially when using multi GPUs)
num_workers: 3
pin_memory: false
@@ -126,7 +126,6 @@ def _write_batches_to_tarfiles(
max_len=max_seq_len,
decoder_data_augmentation=decoder_data_augmentation,
lang=lang,
do_basic_tokenize=False,
use_cache=False,
max_insts=-1,
do_tokenize=False,

@@ -187,12 +186,12 @@ if __name__ == '__main__':
parser.add_argument(
'--num_batches_per_tarfile',
type=int,
default=2,
default=1000,
help='Number batches, i.e., pickle files, included in a single .tar file.',
)
parser.add_argument('--n_jobs', type=int, default=-2, help='The maximum number of concurrently running jobs.')
parser.add_argument(
'--batch_size', type=int, default=16, help='Batch size, i.e., number of examples in a single pickle file'
'--batch_size', type=int, default=32, help='Batch size, i.e., number of examples in a single pickle file'
)
parser.add_argument(
'--factor', default=8, type=int, help='The final number of tar files will be divisible by the "factor" value'
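The tarred-dataset script defaults change here: `--num_batches_per_tarfile` goes from 2 to 1000 and `--batch_size` from 16 to 32, while `--factor` keeps the final tar-file count divisible by 8. A small sketch of the resulting bookkeeping; the variable names and dataset size below are illustrative only, and the exact rounding the script applies for `--factor` is not shown in this hunk:

```python
import math

# Illustrative arithmetic only; names and numbers are not from the script.
num_examples = 1_000_000          # hypothetical dataset size
batch_size = 32                   # new default (was 16)
num_batches_per_tarfile = 1000    # new default (was 2)
factor = 8                        # final tar-file count stays divisible by this

num_batches = math.ceil(num_examples / batch_size)               # pickle files
num_tarfiles = math.ceil(num_batches / num_batches_per_tarfile)  # .tar files
# One simple way to satisfy the divisibility constraint from the --factor help text:
num_tarfiles_rounded = (num_tarfiles // factor) * factor or factor
print(num_batches, num_tarfiles, num_tarfiles_rounded)
```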
@@ -60,7 +60,7 @@ from nemo.utils import logging

@hydra_runner(config_path="conf", config_name="duplex_tn_config")
def main(cfg: DictConfig) -> None:
logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
logging.debug(f'Config Params: {OmegaConf.to_yaml(cfg)}')
lang = cfg.lang

if cfg.decoder_pretrained_model is None or cfg.tagger_pretrained_model is None:

@@ -88,9 +88,7 @@ def main(cfg: DictConfig) -> None:
for i, line in enumerate(lines):
batch.append(line.strip())
if len(batch) == batch_size or i == len(lines) - 1:
outputs = tn_model._infer(
batch, [constants.DIRECTIONS_TO_MODE[mode]] * len(batch), do_basic_tokenization=True,
)
outputs = tn_model._infer(batch, [constants.DIRECTIONS_TO_MODE[mode]] * len(batch),)
all_preds.extend([x for x in outputs[-1]])
batch = []
assert len(all_preds) == len(lines)

@@ -124,7 +122,7 @@ def main(cfg: DictConfig) -> None:
if cfg.mode in ['tn', 'joint']:
directions.append(constants.DIRECTIONS_TO_MODE[constants.TN_MODE])
inputs.append(test_input)
outputs = tn_model._infer(inputs, directions, do_basic_tokenization=True)[-1]
outputs = tn_model._infer(inputs, directions)[-1]
if cfg.mode in ['joint', 'itn']:
print(f'Prediction (ITN): {outputs[0]}')
if cfg.mode in ['joint', 'tn']:
@@ -27,8 +27,9 @@ from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
from nemo.collections.nlp.data.text_normalization import constants
from nemo.collections.nlp.data.text_normalization.utils import basic_tokenize, read_data_file
from nemo.collections.nlp.data.text_normalization.utils import read_data_file
from nemo.core.classes import Dataset
from nemo.utils import logging
from nemo.utils.decorators.experimental import experimental

@@ -58,7 +59,6 @@ class TextNormalizationDecoderDataset(Dataset):
instances that may help the decoder become more robust against the tagger's errors.
Refer to the doc for more info.
lang: language of the dataset
do_basic_tokenize: a flag indicates whether to do some basic tokenization for the inputs
use_cache: Enables caching to use pickle format to store and read data from
max_insts: Maximum number of instances (-1 means no limit)
do_tokenize: Tokenize each instance (set to False for Tarred dataset)

@@ -75,7 +75,6 @@ class TextNormalizationDecoderDataset(Dataset):
max_len: int = 512,
decoder_data_augmentation: bool = False,
lang: str = "en",
do_basic_tokenize: bool = False,
use_cache: bool = False,
max_insts: int = -1,
do_tokenize: bool = True,

@@ -95,7 +94,7 @@ class TextNormalizationDecoderDataset(Dataset):
data_dir, filename = os.path.split(input_file)
tokenizer_name_normalized = tokenizer_name.replace('/', '_')
cached_data_file = os.path.join(
data_dir, f'cached_decoder_{filename}_{tokenizer_name_normalized}_{lang}_{max_insts}_{mode}_72,358.pkl'
data_dir, f'cached_decoder_{filename}_{tokenizer_name_normalized}_{lang}_{max_insts}_{mode}_72,358.pkl',
)

if use_cache and os.path.exists(cached_data_file):

@@ -117,7 +116,7 @@ class TextNormalizationDecoderDataset(Dataset):

logging.debug(f"Converting raw instances to DecoderDataInstance for {input_file}...")
self.insts, all_semiotic_classes = self.__process_raw_entries(
raw_instances, decoder_data_augmentation=decoder_data_augmentation, do_basic_tokenize=do_basic_tokenize
raw_instances, decoder_data_augmentation=decoder_data_augmentation
)
logging.debug(
f"Extracted {len(self.insts)} DecoderDateInstances out of {len(raw_instances)} raw instances."

@@ -134,7 +133,7 @@ class TextNormalizationDecoderDataset(Dataset):
logging.debug(f'Processing samples, total number: {len(self.insts)}')
self.__tokenize_samples(use_cache=use_cache, cached_data_file=cached_data_file)

def __process_raw_entries(self, raw_instances: List[Tuple[str]], decoder_data_augmentation, do_basic_tokenize):
def __process_raw_entries(self, raw_instances: List[Tuple[str]], decoder_data_augmentation):
"""
Converts raw instances to DecoderDataInstance

@@ -142,7 +141,6 @@ class TextNormalizationDecoderDataset(Dataset):
decoder_data_augmentation (bool): a flag indicates whether to augment the dataset with additional data
instances that may help the decoder become more robust against the tagger's errors.
Refer to the doc for more info.
do_basic_tokenize: a flag indicates whether to do some basic tokenization for the inputs

Returns:
converted instances and all semiotic classes present in the data

@@ -161,14 +159,7 @@ class TextNormalizationDecoderDataset(Dataset):
continue
# Create a DecoderDataInstance
inst = DecoderDataInstance(
w_words,
s_words,
inst_dir,
start_idx=ix,
end_idx=ix + 1,
lang=self.lang,
semiotic_class=_class,
do_basic_tokenize=do_basic_tokenize,
w_words, s_words, inst_dir, start_idx=ix, end_idx=ix + 1, lang=self.lang, semiotic_class=_class
)
insts.append(inst)

@@ -183,7 +174,6 @@ class TextNormalizationDecoderDataset(Dataset):
end_idx=ix + 1 + noise_right,
semiotic_class=_class,
lang=self.lang,
do_basic_tokenize=do_basic_tokenize,
)
insts.append(inst)

@@ -358,7 +348,6 @@ class DecoderDataInstance:
end_idx: The ending index of the input span (exclusively)
lang: Language of the instance
semiotic_class: The semiotic class of the input span (can be set to None if not available)
do_basic_tokenize: a flag indicates whether to do some basic tokenization for the inputs
"""

def __init__(

@@ -370,8 +359,8 @@ class DecoderDataInstance:
end_idx: int,
lang: str,
semiotic_class: str = None,
do_basic_tokenize: bool = False,
):
processor = MosesProcessor(lang_id=lang)
start_idx = max(start_idx, 0)
end_idx = min(end_idx, len(w_words))
ctx_size = constants.DECODE_CTX_SIZE

@@ -409,11 +398,16 @@ class DecoderDataInstance:
c_s_words[jx] = c_w_words[jx]

# Extract input_words and output_words
if do_basic_tokenize:
c_w_words = basic_tokenize(' '.join(c_w_words), lang)
c_s_words = basic_tokenize(' '.join(c_s_words), lang)
c_w_words = processor.tokenize(' '.join(c_w_words)).split()
c_s_words = processor.tokenize(' '.join(c_s_words)).split()

# for cases when nearby words are actually multiple tokens, e.g. '1974,'
w_left = processor.tokenize(' '.join(w_left)).split()[-constants.DECODE_CTX_SIZE :]
w_right = processor.tokenize(' '.join(w_right)).split()[: constants.DECODE_CTX_SIZE]

w_input = w_left + [extra_id_0] + c_w_words + [extra_id_1] + w_right
s_input = s_left + [extra_id_0] + c_s_words + [extra_id_1] + s_right

if inst_dir == constants.INST_BACKWARD:
input_center_words = c_s_words
input_words = [constants.ITN_PREFIX] + s_input
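The decoder hunks above replace `basic_tokenize` with `processor.tokenize(...).split()` when building the context window around a semiotic span (left context, `extra_id_0`, center, `extra_id_1`, right context). A standalone sketch of that construction; `DECODE_CTX_SIZE` and the extra-id strings below are placeholder values, not the actual constants:

```python
# Sketch of the decoder input construction shown in the DecoderDataInstance hunk.
# DECODE_CTX_SIZE and extra_id_0/extra_id_1 values are assumptions for illustration.
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor

DECODE_CTX_SIZE = 3
extra_id_0, extra_id_1 = "<extra_id_0>", "<extra_id_1>"

processor = MosesProcessor(lang_id="en")

w_words = ["the", "meeting", "was", "on", "1974,", "in", "May"]
start_idx, end_idx = 4, 5                      # span covering "1974,"
c_w_words = w_words[start_idx:end_idx]

w_left = w_words[max(0, start_idx - DECODE_CTX_SIZE):start_idx]
w_right = w_words[end_idx:end_idx + DECODE_CTX_SIZE]

# Moses may split one "word" into several tokens (e.g. "1974," -> "1974 ,"),
# so the center and context windows are re-tokenized and re-trimmed, as in the hunk above.
c_w_words = processor.tokenize(" ".join(c_w_words)).split()
w_left = processor.tokenize(" ".join(w_left)).split()[-DECODE_CTX_SIZE:]
w_right = processor.tokenize(" ".join(w_right)).split()[:DECODE_CTX_SIZE]

w_input = w_left + [extra_id_0] + c_w_words + [extra_id_1] + w_right
print(" ".join(w_input))
```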
@@ -18,8 +18,9 @@ import pickle
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
from nemo.collections.nlp.data.text_normalization import constants
from nemo.collections.nlp.data.text_normalization.utils import basic_tokenize, read_data_file
from nemo.collections.nlp.data.text_normalization.utils import read_data_file
from nemo.core.classes import Dataset
from nemo.utils import logging
from nemo.utils.decorators.experimental import experimental

@@ -41,7 +42,6 @@ class TextNormalizationTaggerDataset(Dataset):
tokenizer: tokenizer of the model that will be trained on the dataset
tokenizer_name: name of the tokenizer,
mode: should be one of the values ['tn', 'itn', 'joint']. `tn` mode is for TN only. `itn` mode is for ITN only. `joint` is for training a system that can do both TN and ITN at the same time.
do_basic_tokenize: a flag indicates whether to do some basic tokenization before using the tokenizer of the model
tagger_data_augmentation (bool): a flag indicates whether to augment the dataset with additional data instances
lang: language of the dataset
use_cache: Enables caching to use pickle format to store and read data from,

@@ -54,7 +54,6 @@ class TextNormalizationTaggerDataset(Dataset):
tokenizer: PreTrainedTokenizerBase,
tokenizer_name: str,
mode: str,
do_basic_tokenize: bool,
tagger_data_augmentation: bool,
lang: str,
max_seq_length: int,

@@ -72,7 +71,7 @@ class TextNormalizationTaggerDataset(Dataset):
data_dir, filename = os.path.split(input_file)
tokenizer_name_normalized = tokenizer_name.replace('/', '_')
cached_data_file = os.path.join(
data_dir, f'cached_tagger_{filename}_{tokenizer_name_normalized}_{lang}_{max_insts}_{max_seq_length}.pkl'
data_dir, f'cached_tagger_{filename}_{tokenizer_name_normalized}_{lang}_{max_insts}_{max_seq_length}.pkl',
)

if use_cache and os.path.exists(cached_data_file):

@@ -110,7 +109,7 @@ class TextNormalizationTaggerDataset(Dataset):
continue

# Create a new TaggerDataInstance
inst = TaggerDataInstance(w_words, s_words, inst_dir, do_basic_tokenize)
inst = TaggerDataInstance(w_words, s_words, inst_dir, lang=self.lang)
insts.append(inst)
# Data Augmentation (if enabled)
if tagger_data_augmentation:

@@ -120,7 +119,7 @@ class TextNormalizationTaggerDataset(Dataset):
filtered_w_words.append(w)
filtered_s_words.append(s)
if len(filtered_s_words) > 1:
inst = TaggerDataInstance(filtered_w_words, filtered_s_words, inst_dir)
inst = TaggerDataInstance(filtered_w_words, filtered_s_words, inst_dir, lang)
insts.append(inst)

self.insts = insts

@@ -189,10 +188,13 @@ class TaggerDataInstance:
w_words: List of words in a sentence in the written form
s_words: List of words in a sentence in the spoken form
direction: Indicates the direction of the instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN).
do_basic_tokenize: a flag indicates whether to do some basic tokenization before using the tokenizer of the model
lang: Language
"""

def __init__(self, w_words, s_words, direction, do_basic_tokenize=False):
def __init__(self, w_words, s_words, direction, lang):
# moses tokenization before LM tokenization
# e.g., don't -> don 't, 12/3 -> 12 / 3
processor = MosesProcessor(lang_id=lang)
# Build input_words and labels
input_words, labels = [], []
# Task Prefix

@@ -203,11 +205,9 @@ class TaggerDataInstance:
labels.append(constants.TASK_TAG)
# Main Content
for w_word, s_word in zip(w_words, s_words):
# Basic tokenization (if enabled)
if do_basic_tokenize:
w_word = ' '.join(basic_tokenize(w_word, self.lang))
if not s_word in constants.SPECIAL_WORDS:
s_word = ' '.join(basic_tokenize(s_word, self.lang))
w_word = processor.tokenize(w_word)
if not s_word in constants.SPECIAL_WORDS:
s_word = processor.tokenize(s_word)
# Update input_words and labels
if s_word == constants.SIL_WORD and direction == constants.INST_BACKWARD:
continue
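In the tagger dataset, `TaggerDataInstance` now takes `lang` and runs Moses tokenization on each written/spoken word before the LM tokenizer sees it. A small sketch of that per-word step; the `SPECIAL_WORDS` values below are stand-ins for illustration, the real ones live in `text_normalization.constants`:

```python
# Sketch of the per-word Moses pre-tokenization done in TaggerDataInstance.__init__.
# SPECIAL_WORDS here is a stand-in; real values come from text_normalization.constants.
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor

SPECIAL_WORDS = {"sil", "<self>"}   # assumption for illustration only

processor = MosesProcessor(lang_id="en")

w_words = ["don't", "12/3"]
s_words = ["<self>", "december third"]

for w_word, s_word in zip(w_words, s_words):
    w_word = processor.tokenize(w_word)          # e.g. "don't" -> "don 't"
    if s_word not in SPECIAL_WORDS:
        s_word = processor.tokenize(s_word)
    print(w_word, "->", s_word)
```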
@@ -13,15 +13,11 @@
# limitations under the License.

from collections import defaultdict
from typing import DefaultDict, List
from typing import List

from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
from nemo.collections.nlp.data.text_normalization import constants
from nemo.collections.nlp.data.text_normalization.utils import (
basic_tokenize,
normalize_str,
read_data_file,
remove_puncts,
)
from nemo.collections.nlp.data.text_normalization.utils import normalize_str, read_data_file, remove_puncts
from nemo.utils.decorators.experimental import experimental

__all__ = ['TextNormalizationTestDataset']

@@ -41,7 +37,7 @@ class TextNormalizationTestDataset:
def __init__(self, input_file: str, mode: str, lang: str):
self.lang = lang
insts = read_data_file(input_file, lang=lang)

processor = MosesProcessor(lang_id=lang)
# Build inputs and targets
self.directions, self.inputs, self.targets, self.classes, self.nb_spans, self.span_starts, self.span_ends = (
[],

@@ -58,7 +54,6 @@ class TextNormalizationTestDataset:
if direction == constants.INST_BACKWARD:
if mode == constants.TN_MODE:
continue

# ITN mode
(
processed_w_words,

@@ -77,10 +72,13 @@ class TextNormalizationTestDataset:
else:
processed_s_words.append(s_word)

s_word_last = processor.tokenize(processed_s_words.pop()).split()
processed_s_words.append(" ".join(s_word_last))
num_tokens = len(s_word_last)
processed_nb_spans += 1
processed_classes.append(cls)
processed_s_span_starts.append(s_word_idx)
s_word_idx += len(basic_tokenize(processed_s_words[-1], lang=self.lang))
s_word_idx += num_tokens
processed_s_span_ends.append(s_word_idx)
processed_w_words.append(w_word)

@@ -88,15 +86,13 @@ class TextNormalizationTestDataset:
self.span_ends.append(processed_s_span_ends)
self.classes.append(processed_classes)
self.nb_spans.append(processed_nb_spans)
# Basic tokenization
input_words = basic_tokenize(' '.join(processed_s_words), lang)
input_words = ' '.join(processed_s_words)
# Update self.directions, self.inputs, self.targets
self.directions.append(direction)
self.inputs.append(' '.join(input_words))
self.inputs.append(input_words)
self.targets.append(
processed_w_words
) # is list of lists where inner list contains target tokens (not words)

# TN mode
elif direction == constants.INST_FORWARD:
if mode == constants.ITN_MODE:

@@ -111,29 +107,29 @@ class TextNormalizationTestDataset:
) = ([], [], [], 0, [], [])
w_word_idx = 0
for cls, w_word, s_word in zip(classes, w_words, s_words):

# TN forward mode
# this is done for cases like `do n't`, this w_word will be treated as 2 tokens
w_word = processor.tokenize(w_word).split()
num_tokens = len(w_word)
if s_word in constants.SPECIAL_WORDS:
processed_s_words.append(w_word)
processed_s_words.append(" ".join(w_word))
else:
processed_s_words.append(s_word)

w_span_starts.append(w_word_idx)
w_word_idx += len(basic_tokenize(w_word, lang=self.lang))
w_word_idx += num_tokens
w_span_ends.append(w_word_idx)
processed_nb_spans += 1
processed_classes.append(cls)
processed_w_words.append(w_word)
processed_w_words.extend(w_word)

self.span_starts.append(w_span_starts)
self.span_ends.append(w_span_ends)
self.classes.append(processed_classes)
self.nb_spans.append(processed_nb_spans)
# Basic tokenization
input_words = basic_tokenize(' '.join(processed_w_words), lang)
input_words = ' '.join(processed_w_words)
# Update self.directions, self.inputs, self.targets
self.directions.append(direction)
self.inputs.append(' '.join(input_words))
self.inputs.append(input_words)
self.targets.append(
processed_s_words
) # is list of lists where inner list contains target tokens (not words)

@@ -157,7 +153,7 @@ class TextNormalizationTestDataset:
return len(self.inputs)

@staticmethod
def is_same(pred: str, target: str, inst_dir: str, lang: str):
def is_same(pred: str, target: str, inst_dir: str):
"""
Function for checking whether the predicted string can be considered
the same as the target string

@@ -166,18 +162,17 @@ class TextNormalizationTestDataset:
pred: Predicted string
target: Target string
inst_dir: Direction of the instance (i.e., INST_BACKWARD or INST_FORWARD).
lang: Language
Return: an int value (0/1) indicating whether pred and target are the same.
"""
if inst_dir == constants.INST_BACKWARD:
pred = remove_puncts(pred)
target = remove_puncts(target)
pred = normalize_str(pred, lang)
target = normalize_str(target, lang)
pred = normalize_str(pred)
target = normalize_str(target)
return int(pred == target)

@staticmethod
def compute_sent_accuracy(preds: List[str], targets: List[str], inst_directions: List[str], lang: str):
def compute_sent_accuracy(preds: List[str], targets: List[str], inst_directions: List[str]):
"""
Compute the sentence accuracy metric.

@@ -185,7 +180,6 @@ class TextNormalizationTestDataset:
preds: List of predicted strings.
targets: List of target strings.
inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD or INST_FORWARD).
lang: Language
Return: the sentence accuracy score
"""
assert len(preds) == len(targets)

@@ -194,7 +188,7 @@ class TextNormalizationTestDataset:
# Sentence Accuracy
correct_count = 0
for inst_dir, pred, target in zip(inst_directions, preds, targets):
correct_count += TextNormalizationTestDataset.is_same(pred, target, inst_dir, lang)
correct_count += TextNormalizationTestDataset.is_same(pred, target, inst_dir)
sent_accuracy = correct_count / len(targets)

return sent_accuracy

@@ -208,9 +202,7 @@ class TextNormalizationTestDataset:
output_spans: List[List[str]],
classes: List[List[str]],
nb_spans: List[int],
span_starts: List[List[int]],
span_ends: List[List[int]],
lang: str,
) -> dict:
"""
Compute the class based accuracy metric. This uses model's predicted tags.

@@ -223,9 +215,7 @@ class TextNormalizationTestDataset:
output_spans: A list of lists where each inner list contains the decoded spans for the corresponding input sentence
classes: A list of lists where inner list contains the class for each semiotic token in input sentence
nb_spans: A list that contains the number of tokens in the input
span_starts: A list of lists where inner list contains the start word index of the current token
span_ends: A list of lists where inner list contains the end word index of the current token
lang: Language
Return: the class accuracy scores as dict
"""

@@ -233,7 +223,7 @@ class TextNormalizationTestDataset:
return 'NA'
class2stats, class2correct = defaultdict(int), defaultdict(int)
for ix, (sent, tags) in enumerate(zip(inputs, tag_preds)):
assert len(inputs) == len(tag_preds)
assert len(sent) == len(tags)
cur_words = [[] for _ in range(nb_spans[ix])]
jx, span_idx = 0, 0
cur_spans = output_spans[ix]

@@ -261,9 +251,10 @@ class TextNormalizationTestDataset:
jx += 1

target_token_idx = 0
# assert len(cur_words) == len(targets[ix])
for class_idx in range(nb_spans[ix]):
correct = TextNormalizationTestDataset.is_same(
" ".join(cur_words[class_idx]), targets[ix][target_token_idx], inst_directions[ix], lang
" ".join(cur_words[class_idx]), targets[ix][target_token_idx], inst_directions[ix]
)
class2correct[classes[ix][class_idx]] += correct
target_token_idx += 1
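Span start/end indices in the test dataset are now derived from the Moses token count of each word instead of re-running `basic_tokenize`, so a written word that Moses splits (the hunk's `do n't` example) advances the index by the number of resulting tokens. A compact sketch of that bookkeeping for the TN direction; the example words are illustrative:

```python
# Sketch of the span-index bookkeeping from the TN branch of the test dataset.
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor

processor = MosesProcessor(lang_id="en")

w_words = ["don't", "panic", "12/3"]
w_span_starts, w_span_ends, processed_w_words = [], [], []
w_word_idx = 0

for w_word in w_words:
    # e.g. "don't" becomes ["don", "'t"], so this word occupies two token slots
    tokens = processor.tokenize(w_word).split()
    w_span_starts.append(w_word_idx)
    w_word_idx += len(tokens)
    w_span_ends.append(w_word_idx)
    processed_w_words.extend(tokens)

print(processed_w_words)
print(list(zip(w_span_starts, w_span_ends)))
```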
@@ -37,7 +37,7 @@ def input_preprocessing(sent: str, lang: str):
such as Δ or λ (if any).

Args:
sents: input text.
sent: input text.
lang: language

Returns: preprocessed input text.

@@ -138,12 +138,9 @@ def process_url(tokens: List[str], outputs: List[str], lang: str):
return outputs


def normalize_str(input_str, lang):
def normalize_str(input_str):
""" Normalize an input string """
input_str_tokens = basic_tokenize(input_str.strip().lower(), lang)
input_str = ' '.join(input_str_tokens)
input_str = input_str.replace('  ', ' ')
return input_str
return input_str.strip().lower().replace("  ", " ")


def remove_puncts(input_str):
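With this change, `normalize_str` no longer re-tokenizes with `basic_tokenize`; it just strips, lowercases, and collapses double spaces, and `is_same` drops its `lang` argument accordingly. A sketch of the simplified comparison; `remove_puncts` and the direction constant below are simplified stand-ins, not the actual NeMo implementations:

```python
# Sketch of the simplified prediction/target comparison from the test dataset.
# remove_puncts and INST_BACKWARD are simplified stand-ins for illustration.
import string

INST_BACKWARD = "itn"   # stand-in for constants.INST_BACKWARD

def remove_puncts(s: str) -> str:
    return s.translate(str.maketrans("", "", string.punctuation))

def normalize_str(input_str: str) -> str:
    return input_str.strip().lower().replace("  ", " ")

def is_same(pred: str, target: str, inst_dir: str) -> int:
    # ITN outputs are compared with punctuation stripped, as in the diff
    if inst_dir == INST_BACKWARD:
        pred, target = remove_puncts(pred), remove_puncts(target)
    return int(normalize_str(pred) == normalize_str(target))

print(is_same("On May 3rd, 2021", "on may 3rd 2021", INST_BACKWARD))  # -> 1
```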
@@ -26,6 +26,7 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSe

import nemo.collections.nlp.data.text_normalization.constants as constants
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
from nemo.collections.nlp.data.text_normalization import TextNormalizationTestDataset
from nemo.collections.nlp.data.text_normalization.decoder_dataset import (
TarredTextNormalizationDecoderDataset,
TextNormalizationDecoderDataset,

@@ -153,6 +154,7 @@ class DuplexDecoderModel(NLPModel):
)

input_centers = self._tokenizer.batch_decode(batch['input_center'], skip_special_tokens=True)

direction = [x[0].item() for x in batch['direction']]
direction_str = [constants.DIRECTIONS_ID_TO_NAME[x] for x in direction]
# apply post_processing

@@ -161,9 +163,12 @@ class DuplexDecoderModel(NLPModel):
for idx, class_id in enumerate(batch['semiotic_class_id']):
direction = constants.TASK_ID_TO_MODE[batch['direction'][idx][0].item()]
class_name = self._val_id_to_class[dataloader_idx][class_id[0].item()]
results[f"correct_{class_name}_{direction}"] += torch.tensor(
labels_str[idx] == generated_texts[idx], dtype=torch.int
).to(self.device)

pred_result = TextNormalizationTestDataset.is_same(
generated_texts[idx], labels_str[idx], constants.DIRECTIONS_TO_MODE[direction]
)

results[f"correct_{class_name}_{direction}"] += torch.tensor(pred_result, dtype=torch.int).to(self.device)
results[f"total_{class_name}_{direction}"] += torch.tensor(1).to(self.device)

results[f"{split}_loss"] = val_loss

@@ -309,10 +314,6 @@ class DuplexDecoderModel(NLPModel):
if sum(nb_spans) == 0:
return [[]] * len(sents)
model, tokenizer = self.model, self._tokenizer
try:
model_max_len = model.config.n_positions
except AttributeError:
model_max_len = 512
ctx_size = constants.DECODE_CTX_SIZE
extra_id_0 = constants.EXTRA_ID_0
extra_id_1 = constants.EXTRA_ID_1

@@ -320,7 +321,7 @@ class DuplexDecoderModel(NLPModel):
"""
Build all_inputs - extracted spans to be transformed by the decoder model
Inputs for TN direction have "0" prefix, while the backward, ITN direction, has prefix "1"
"input_centers" - List[str] - ground-truth labels for the span #TODO: rename
"input_centers" - List[str] - ground-truth labels for the span
"""
input_centers, input_dirs, all_inputs = [], [], []
for ix, sent in enumerate(sents):

@@ -351,11 +352,12 @@ class DuplexDecoderModel(NLPModel):
input_ids = batch['input_ids'].to(self.device)

generated_texts, generated_ids, sequence_toks_scores = self._generate_predictions(
input_ids=input_ids, model_max_len=model_max_len
input_ids=input_ids, model_max_len=self.max_sequence_len
)

# Use covering grammars (if enabled)
if self.use_cg:

# Compute sequence probabilities
sequence_probs = torch.ones(len(all_inputs)).to(self.device)
for ix, cur_toks_scores in enumerate(sequence_toks_scores):

@@ -528,7 +530,6 @@ class DuplexDecoderModel(NLPModel):
if data_split == "train"
else False,
lang=self.lang,
do_basic_tokenize=cfg.do_basic_tokenize,
use_cache=cfg.get('use_cache', False),
max_insts=cfg.get('max_insts', -1),
do_tokenize=True,
@@ -131,7 +131,7 @@ class DuplexTaggerModel(NLPModel):

# Functions for inference
@torch.no_grad()
def _infer(self, sents: List[List[str]], inst_directions: List[str], do_basic_tokenization=True):
def _infer(self, sents: List[List[str]], inst_directions: List[str]):
""" Main function for Inference

Args:

@@ -144,10 +144,8 @@ class DuplexTaggerModel(NLPModel):
nb_spans: A list of ints where each int indicates the number of semiotic spans in input words.
span_starts: A list of lists where each list contains the starting locations of semiotic spans in input words.
span_ends: A list of lists where each list contains the ending locations of semiotic spans in input words.
do_basic_tokenization: whether to do a pre-processing to separate punctuation marks, recommended to set to True
"""
self.eval()

# Append prefix
texts = []
for ix, sent in enumerate(sents):

@@ -155,19 +153,11 @@ class DuplexTaggerModel(NLPModel):
prefix = constants.ITN_PREFIX
elif inst_directions[ix] == constants.INST_FORWARD:
prefix = constants.TN_PREFIX
if do_basic_tokenization:
texts.append([prefix] + sent)
else:
texts.append(prefix + " " + sent)
texts.append([prefix] + sent)

# Apply the model
if do_basic_tokenization:
is_split_into_words = True
else:
is_split_into_words = False

encodings = self._tokenizer(
texts, is_split_into_words=is_split_into_words, padding=True, truncation=True, return_tensors='pt'
texts, is_split_into_words=True, padding=True, truncation=True, return_tensors='pt'
)

inputs = encodings

@@ -175,10 +165,7 @@ class DuplexTaggerModel(NLPModel):

# check that the length of the 'input_ids' equals as least the length of the original input
# if an input symbol is missing in the tokenizer's vocabulary (such as emoji or a Chinese character), it could be skipped
if do_basic_tokenization:
len_texts = [len(x) for x in texts]
else:
len_texts = [len(x.split()) for x in texts]
len_texts = [len(x) for x in texts]
len_ids = [
len(self._tokenizer.convert_ids_to_tokens(x, skip_special_tokens=True)) for x in encodings['input_ids']
]

@@ -343,7 +330,6 @@ class DuplexTaggerModel(NLPModel):
tokenizer=self._tokenizer,
tokenizer_name=self.transformer_name,
mode=self.mode,
do_basic_tokenize=cfg.do_basic_tokenize,
tagger_data_augmentation=tagger_data_augmentation,
lang=self.lang,
max_seq_length=self.max_sequence_len,
|
|||
) = zip(*batch_insts)
|
||||
# Inference and Running Time Measurement
|
||||
batch_start_time = perf_counter()
|
||||
batch_tag_preds, batch_output_spans, batch_final_preds = self._infer(batch_inputs, batch_dirs)
|
||||
|
||||
batch_tag_preds, batch_output_spans, batch_final_preds = self._infer(
|
||||
batch_inputs, batch_dirs, processed=True
|
||||
)
|
||||
|
||||
batch_run_time = (perf_counter() - batch_start_time) * 1000 # milliseconds
|
||||
all_run_times.append(batch_run_time)
|
||||
# Update all_dirs, all_inputs, all_tag_preds, all_final_preds and all_targets
|
||||
|
@ -149,20 +153,18 @@ class DuplexTextNormalizationModel(nn.Module):
|
|||
cur_targets_sent = [" ".join(x) for x in cur_targets]
|
||||
|
||||
sent_accuracy = TextNormalizationTestDataset.compute_sent_accuracy(
|
||||
cur_final_preds, cur_targets_sent, cur_dirs, self.lang
|
||||
cur_final_preds, cur_targets_sent, cur_dirs
|
||||
)
|
||||
|
||||
class_accuracy = TextNormalizationTestDataset.compute_class_accuracy(
|
||||
[basic_tokenize(x, lang=self.lang) for x in cur_inputs],
|
||||
[x.split() for x in cur_inputs],
|
||||
cur_targets,
|
||||
cur_tag_preds,
|
||||
cur_dirs,
|
||||
cur_output_spans,
|
||||
cur_classes,
|
||||
cur_nb_spans,
|
||||
cur_span_starts,
|
||||
cur_span_ends,
|
||||
self.lang,
|
||||
)
|
||||
if verbose:
|
||||
logging.info(f'\n============ Direction {direction} ============')
|
||||
|
@ -185,13 +187,14 @@ class DuplexTextNormalizationModel(nn.Module):
|
|||
for _input, tag_pred, final_pred, target, classes in zip(
|
||||
cur_inputs, cur_tag_preds, cur_final_preds, cur_targets_sent, cur_classes
|
||||
):
|
||||
if not TextNormalizationTestDataset.is_same(final_pred, target, direction, self.lang):
|
||||
if not TextNormalizationTestDataset.is_same(final_pred, target, direction):
|
||||
if direction == constants.INST_BACKWARD:
|
||||
error_f.write('Backward Problem (ITN)\n')
|
||||
itn_error_ctx += 1
|
||||
elif direction == constants.INST_FORWARD:
|
||||
error_f.write('Forward Problem (TN)\n')
|
||||
tn_error_ctx += 1
|
||||
|
||||
formatted_input_str = get_formatted_string(basic_tokenize(_input, lang=self.lang))
|
||||
formatted_tag_pred_str = get_formatted_string(tag_pred)
|
||||
class_str = " ".join(classes)
|
||||
|
@ -217,7 +220,7 @@ class DuplexTextNormalizationModel(nn.Module):
|
|||
return results
|
||||
|
||||
# Functions for inference
|
||||
def _infer(self, sents: List[str], inst_directions: List[str], do_basic_tokenization=True):
|
||||
def _infer(self, sents: List[str], inst_directions: List[str], processed=False):
|
||||
"""
|
||||
Main function for Inference
|
||||
|
||||
|
@ -228,8 +231,8 @@ class DuplexTextNormalizationModel(nn.Module):
|
|||
sents: A list of input texts.
|
||||
inst_directions: A list of str where each str indicates the direction of the corresponding instance \
|
||||
(i.e., constants.INST_BACKWARD for ITN or constants.INST_FORWARD for TN).
|
||||
do_basic_tokenization: whether to do a pre-processing to separate punctuation marks,
|
||||
recommended to set to True
|
||||
processed: Set to True when used with TextNormalizationTestDataset, the data is already tokenized with moses,
|
||||
repetitive moses tokenization could lead to the number of tokens and class span mismatch
|
||||
|
||||
Returns:
|
||||
tag_preds: A list of lists where the inner list contains the tag predictions from the tagger for each word in the input text.
|
||||
|
@ -238,19 +241,16 @@ class DuplexTextNormalizationModel(nn.Module):
|
|||
"""
|
||||
original_sents = [s for s in sents]
|
||||
# Separate into words
|
||||
if do_basic_tokenization:
|
||||
if not processed:
|
||||
sents = [self.decoder.processor.tokenize(x).split() for x in sents]
|
||||
else:
|
||||
sents = [x.split() for x in sents]
|
||||
|
||||
# Tagging
|
||||
# span_ends included, returns index wrt to words in input without auxiliary words
|
||||
tag_preds, nb_spans, span_starts, span_ends = self.tagger._infer(
|
||||
sents, inst_directions, do_basic_tokenization=do_basic_tokenization
|
||||
)
|
||||
tag_preds, nb_spans, span_starts, span_ends = self.tagger._infer(sents, inst_directions)
|
||||
output_spans = self.decoder._infer(sents, nb_spans, span_starts, span_ends, inst_directions)
|
||||
|
||||
if not do_basic_tokenization:
|
||||
sents = [x.split() for x in sents]
|
||||
|
||||
# Prepare final outputs
|
||||
final_outputs = []
|
||||
for ix, (sent, tags) in enumerate(zip(sents, tag_preds)):
|
||||
|
@ -268,8 +268,15 @@ class DuplexTextNormalizationModel(nn.Module):
|
|||
span_idx += 1
|
||||
while jx < len(sent) and tags[jx] == constants.I_PREFIX + constants.TRANSFORM_TAG:
|
||||
jx += 1
|
||||
cur_output_str = self.decoder.processor.detokenize(cur_words)
|
||||
cur_output_str = post_process_punct(input=original_sents[ix], nn_output=cur_output_str)
|
||||
|
||||
if processed:
|
||||
# for Class-based evaluation, don't apply Moses detokenization
|
||||
cur_output_str = " ".join(cur_words)
|
||||
else:
|
||||
# detokenize the output with Moses and fix punctuation marks to match the input
|
||||
# for interactive inference or inference from a file
|
||||
cur_output_str = self.decoder.processor.detokenize(cur_words)
|
||||
cur_output_str = post_process_punct(input=original_sents[ix], nn_output=cur_output_str)
|
||||
final_outputs.append(cur_output_str)
|
||||
except IndexError:
|
||||
logging.warning(f"Input sent is too long and will be skipped - {original_sents[ix]}")
|
||||
|
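In `DuplexTextNormalizationModel._infer`, the new `processed` flag distinguishes evaluation input that is already Moses-tokenized (from `TextNormalizationTestDataset`) from raw text: raw text is tokenized once with the decoder's `MosesProcessor`, and only raw text is detokenized and punctuation-fixed on the way out. A schematic sketch of that branching; the `infer` function, `decoder_processor`, and the span handling below are placeholders, not the real NeMo methods:

```python
# Schematic sketch of the `processed` branching added to
# DuplexTextNormalizationModel._infer; `infer` and `decoder_processor`
# are placeholders, and the tagger/decoder steps are omitted.
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor

decoder_processor = MosesProcessor(lang_id="en")

def infer(sents, processed=False):
    original_sents = list(sents)
    if not processed:
        # raw text: tokenize once with Moses so tagger/decoder see word lists
        word_lists = [decoder_processor.tokenize(s).split() for s in sents]
    else:
        # already Moses-tokenized (e.g. TextNormalizationTestDataset):
        # re-tokenizing would change token counts and break the class spans
        word_lists = [s.split() for s in sents]

    outputs = []
    for words, original in zip(word_lists, original_sents):
        cur_words = words  # stand-in for the decoder's predicted span words
        if processed:
            # class-based evaluation compares space-joined tokens directly
            outputs.append(" ".join(cur_words))
        else:
            # interactive/file inference gets readable, detokenized text
            outputs.append(decoder_processor.detokenize(cur_words))
    return outputs

print(infer(["don't call before 12/3"]))
```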