TN updates (#2983)

* Moses tokenizer added

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* Updates to make evaluation with Moses work

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* clean up

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* clean up

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* fix init

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* review

Signed-off-by: ekmb <ebakhturina@nvidia.com>

Co-authored-by: Yang Zhang <yzhang123@users.noreply.github.com>


@ -120,7 +120,6 @@ data:
data_path: train.tsv # provide the full path to the file
batch_size: 64
shuffle: true
do_basic_tokenize: false
max_insts: -1 # Maximum number of instances (-1 means no limit)
# Refer to the text_normalization doc for more information about data augmentation
tagger_data_augmentation: false
@ -137,7 +136,6 @@ data:
data_path: dev.tsv # provide the full path to the file. Provide multiple paths to run evaluation on multiple datasets
batch_size: 64
shuffle: false
do_basic_tokenize: false
max_insts: -1 # Maximum number of instances (-1 means no limit)
use_cache: false # uses a cache to store the processed dataset, you may use it for large datasets for speed up (especially when using multi GPUs)
num_workers: 3
@ -148,7 +146,6 @@ data:
data_path: test.tsv # provide the full path to the file
batch_size: 64
shuffle: false
do_basic_tokenize: false
use_cache: false # uses a cache to store the processed dataset, you may use it for large datasets for speed up (especially when using multi GPUs)
num_workers: 3
pin_memory: false


@ -126,7 +126,6 @@ def _write_batches_to_tarfiles(
max_len=max_seq_len,
decoder_data_augmentation=decoder_data_augmentation,
lang=lang,
do_basic_tokenize=False,
use_cache=False,
max_insts=-1,
do_tokenize=False,
@ -187,12 +186,12 @@ if __name__ == '__main__':
parser.add_argument(
'--num_batches_per_tarfile',
type=int,
default=2,
default=1000,
help='Number batches, i.e., pickle files, included in a single .tar file.',
)
parser.add_argument('--n_jobs', type=int, default=-2, help='The maximum number of concurrently running jobs.')
parser.add_argument(
'--batch_size', type=int, default=16, help='Batch size, i.e., number of examples in a single pickle file'
'--batch_size', type=int, default=32, help='Batch size, i.e., number of examples in a single pickle file'
)
parser.add_argument(
'--factor', default=8, type=int, help='The final number of tar files will be divisible by the "factor" value'


@ -60,7 +60,7 @@ from nemo.utils import logging
@hydra_runner(config_path="conf", config_name="duplex_tn_config")
def main(cfg: DictConfig) -> None:
logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
logging.debug(f'Config Params: {OmegaConf.to_yaml(cfg)}')
lang = cfg.lang
if cfg.decoder_pretrained_model is None or cfg.tagger_pretrained_model is None:
@ -88,9 +88,7 @@ def main(cfg: DictConfig) -> None:
for i, line in enumerate(lines):
batch.append(line.strip())
if len(batch) == batch_size or i == len(lines) - 1:
outputs = tn_model._infer(
batch, [constants.DIRECTIONS_TO_MODE[mode]] * len(batch), do_basic_tokenization=True,
)
outputs = tn_model._infer(batch, [constants.DIRECTIONS_TO_MODE[mode]] * len(batch),)
all_preds.extend([x for x in outputs[-1]])
batch = []
assert len(all_preds) == len(lines)
@ -124,7 +122,7 @@ def main(cfg: DictConfig) -> None:
if cfg.mode in ['tn', 'joint']:
directions.append(constants.DIRECTIONS_TO_MODE[constants.TN_MODE])
inputs.append(test_input)
outputs = tn_model._infer(inputs, directions, do_basic_tokenization=True)[-1]
outputs = tn_model._infer(inputs, directions)[-1]
if cfg.mode in ['joint', 'itn']:
print(f'Prediction (ITN): {outputs[0]}')
if cfg.mode in ['joint', 'tn']:
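
For reference, a condensed usage sketch of the simplified call above; lines, batch_size, mode, tn_model and the constants module are assumed to be in scope as in the script, and the do_basic_tokenization flag is gone because Moses tokenization now happens inside the model.

all_preds, batch = [], []
for i, line in enumerate(lines):
    batch.append(line.strip())
    if len(batch) == batch_size or i == len(lines) - 1:
        outputs = tn_model._infer(batch, [constants.DIRECTIONS_TO_MODE[mode]] * len(batch))
        all_preds.extend(outputs[-1])  # the last element holds the final predictions
        batch = []
assert len(all_preds) == len(lines)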


@ -27,8 +27,9 @@ from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
from nemo.collections.nlp.data.text_normalization import constants
from nemo.collections.nlp.data.text_normalization.utils import basic_tokenize, read_data_file
from nemo.collections.nlp.data.text_normalization.utils import read_data_file
from nemo.core.classes import Dataset
from nemo.utils import logging
from nemo.utils.decorators.experimental import experimental
@ -58,7 +59,6 @@ class TextNormalizationDecoderDataset(Dataset):
instances that may help the decoder become more robust against the tagger's errors.
Refer to the doc for more info.
lang: language of the dataset
do_basic_tokenize: a flag indicates whether to do some basic tokenization for the inputs
use_cache: Enables caching to use pickle format to store and read data from
max_insts: Maximum number of instances (-1 means no limit)
do_tokenize: Tokenize each instance (set to False for Tarred dataset)
@ -75,7 +75,6 @@ class TextNormalizationDecoderDataset(Dataset):
max_len: int = 512,
decoder_data_augmentation: bool = False,
lang: str = "en",
do_basic_tokenize: bool = False,
use_cache: bool = False,
max_insts: int = -1,
do_tokenize: bool = True,
@ -95,7 +94,7 @@ class TextNormalizationDecoderDataset(Dataset):
data_dir, filename = os.path.split(input_file)
tokenizer_name_normalized = tokenizer_name.replace('/', '_')
cached_data_file = os.path.join(
data_dir, f'cached_decoder_{filename}_{tokenizer_name_normalized}_{lang}_{max_insts}_{mode}_{max_len}.pkl'
data_dir, f'cached_decoder_{filename}_{tokenizer_name_normalized}_{lang}_{max_insts}_{mode}_{max_len}.pkl',
)
if use_cache and os.path.exists(cached_data_file):
@ -117,7 +116,7 @@ class TextNormalizationDecoderDataset(Dataset):
logging.debug(f"Converting raw instances to DecoderDataInstance for {input_file}...")
self.insts, all_semiotic_classes = self.__process_raw_entries(
raw_instances, decoder_data_augmentation=decoder_data_augmentation, do_basic_tokenize=do_basic_tokenize
raw_instances, decoder_data_augmentation=decoder_data_augmentation
)
logging.debug(
f"Extracted {len(self.insts)} DecoderDateInstances out of {len(raw_instances)} raw instances."
@ -134,7 +133,7 @@ class TextNormalizationDecoderDataset(Dataset):
logging.debug(f'Processing samples, total number: {len(self.insts)}')
self.__tokenize_samples(use_cache=use_cache, cached_data_file=cached_data_file)
def __process_raw_entries(self, raw_instances: List[Tuple[str]], decoder_data_augmentation, do_basic_tokenize):
def __process_raw_entries(self, raw_instances: List[Tuple[str]], decoder_data_augmentation):
"""
Converts raw instances to DecoderDataInstance
@ -142,7 +141,6 @@ class TextNormalizationDecoderDataset(Dataset):
decoder_data_augmentation (bool): a flag indicates whether to augment the dataset with additional data
instances that may help the decoder become more robust against the tagger's errors.
Refer to the doc for more info.
do_basic_tokenize: a flag indicates whether to do some basic tokenization for the inputs
Returns:
converted instances and all semiotic classes present in the data
@ -161,14 +159,7 @@ class TextNormalizationDecoderDataset(Dataset):
continue
# Create a DecoderDataInstance
inst = DecoderDataInstance(
w_words,
s_words,
inst_dir,
start_idx=ix,
end_idx=ix + 1,
lang=self.lang,
semiotic_class=_class,
do_basic_tokenize=do_basic_tokenize,
w_words, s_words, inst_dir, start_idx=ix, end_idx=ix + 1, lang=self.lang, semiotic_class=_class
)
insts.append(inst)
@ -183,7 +174,6 @@ class TextNormalizationDecoderDataset(Dataset):
end_idx=ix + 1 + noise_right,
semiotic_class=_class,
lang=self.lang,
do_basic_tokenize=do_basic_tokenize,
)
insts.append(inst)
@ -358,7 +348,6 @@ class DecoderDataInstance:
end_idx: The ending index of the input span (exclusively)
lang: Language of the instance
semiotic_class: The semiotic class of the input span (can be set to None if not available)
do_basic_tokenize: a flag indicates whether to do some basic tokenization for the inputs
"""
def __init__(
@ -370,8 +359,8 @@ class DecoderDataInstance:
end_idx: int,
lang: str,
semiotic_class: str = None,
do_basic_tokenize: bool = False,
):
processor = MosesProcessor(lang_id=lang)
start_idx = max(start_idx, 0)
end_idx = min(end_idx, len(w_words))
ctx_size = constants.DECODE_CTX_SIZE
@ -409,11 +398,16 @@ class DecoderDataInstance:
c_s_words[jx] = c_w_words[jx]
# Extract input_words and output_words
if do_basic_tokenize:
c_w_words = basic_tokenize(' '.join(c_w_words), lang)
c_s_words = basic_tokenize(' '.join(c_s_words), lang)
c_w_words = processor.tokenize(' '.join(c_w_words)).split()
c_s_words = processor.tokenize(' '.join(c_s_words)).split()
# for cases when nearby words are actually multiple tokens, e.g. '1974,'
w_left = processor.tokenize(' '.join(w_left)).split()[-constants.DECODE_CTX_SIZE :]
w_right = processor.tokenize(' '.join(w_right)).split()[: constants.DECODE_CTX_SIZE]
w_input = w_left + [extra_id_0] + c_w_words + [extra_id_1] + w_right
s_input = s_left + [extra_id_0] + c_s_words + [extra_id_1] + s_right
if inst_dir == constants.INST_BACKWARD:
input_center_words = c_s_words
input_words = [constants.ITN_PREFIX] + s_input
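
A minimal sketch (not part of the diff) of the MosesProcessor behavior the decoder dataset now relies on instead of basic_tokenize; the example sentence is made up, and the exact token boundaries depend on the Moses rules for the chosen language.

from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor

processor = MosesProcessor(lang_id="en")

# tokenize() returns a space-separated string, which is why the dataset code
# above calls .split() on it, e.g. "1974," becomes the tokens "1974" and ","
tokens = processor.tokenize("It happened in 1974, don't forget.").split()

# detokenize() is the inverse and is used later when assembling final predictions
restored = processor.detokenize(tokens)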


@ -18,8 +18,9 @@ import pickle
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
from nemo.collections.nlp.data.text_normalization import constants
from nemo.collections.nlp.data.text_normalization.utils import basic_tokenize, read_data_file
from nemo.collections.nlp.data.text_normalization.utils import read_data_file
from nemo.core.classes import Dataset
from nemo.utils import logging
from nemo.utils.decorators.experimental import experimental
@ -41,7 +42,6 @@ class TextNormalizationTaggerDataset(Dataset):
tokenizer: tokenizer of the model that will be trained on the dataset
tokenizer_name: name of the tokenizer,
mode: should be one of the values ['tn', 'itn', 'joint']. `tn` mode is for TN only. `itn` mode is for ITN only. `joint` is for training a system that can do both TN and ITN at the same time.
do_basic_tokenize: a flag indicates whether to do some basic tokenization before using the tokenizer of the model
tagger_data_augmentation (bool): a flag indicates whether to augment the dataset with additional data instances
lang: language of the dataset
use_cache: Enables caching to use pickle format to store and read data from,
@ -54,7 +54,6 @@ class TextNormalizationTaggerDataset(Dataset):
tokenizer: PreTrainedTokenizerBase,
tokenizer_name: str,
mode: str,
do_basic_tokenize: bool,
tagger_data_augmentation: bool,
lang: str,
max_seq_length: int,
@ -72,7 +71,7 @@ class TextNormalizationTaggerDataset(Dataset):
data_dir, filename = os.path.split(input_file)
tokenizer_name_normalized = tokenizer_name.replace('/', '_')
cached_data_file = os.path.join(
data_dir, f'cached_tagger_{filename}_{tokenizer_name_normalized}_{lang}_{max_insts}_{max_seq_length}.pkl'
data_dir, f'cached_tagger_{filename}_{tokenizer_name_normalized}_{lang}_{max_insts}_{max_seq_length}.pkl',
)
if use_cache and os.path.exists(cached_data_file):
@ -110,7 +109,7 @@ class TextNormalizationTaggerDataset(Dataset):
continue
# Create a new TaggerDataInstance
inst = TaggerDataInstance(w_words, s_words, inst_dir, do_basic_tokenize)
inst = TaggerDataInstance(w_words, s_words, inst_dir, lang=self.lang)
insts.append(inst)
# Data Augmentation (if enabled)
if tagger_data_augmentation:
@ -120,7 +119,7 @@ class TextNormalizationTaggerDataset(Dataset):
filtered_w_words.append(w)
filtered_s_words.append(s)
if len(filtered_s_words) > 1:
inst = TaggerDataInstance(filtered_w_words, filtered_s_words, inst_dir)
inst = TaggerDataInstance(filtered_w_words, filtered_s_words, inst_dir, lang)
insts.append(inst)
self.insts = insts
@ -189,10 +188,13 @@ class TaggerDataInstance:
w_words: List of words in a sentence in the written form
s_words: List of words in a sentence in the spoken form
direction: Indicates the direction of the instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN).
do_basic_tokenize: a flag indicates whether to do some basic tokenization before using the tokenizer of the model
lang: Language
"""
def __init__(self, w_words, s_words, direction, do_basic_tokenize=False):
def __init__(self, w_words, s_words, direction, lang):
# moses tokenization before LM tokenization
# e.g., don't -> don 't, 12/3 -> 12 / 3
processor = MosesProcessor(lang_id=lang)
# Build input_words and labels
input_words, labels = [], []
# Task Prefix
@ -203,11 +205,9 @@ class TaggerDataInstance:
labels.append(constants.TASK_TAG)
# Main Content
for w_word, s_word in zip(w_words, s_words):
# Basic tokenization (if enabled)
if do_basic_tokenize:
w_word = ' '.join(basic_tokenize(w_word, self.lang))
if not s_word in constants.SPECIAL_WORDS:
s_word = ' '.join(basic_tokenize(s_word, self.lang))
w_word = processor.tokenize(w_word)
if not s_word in constants.SPECIAL_WORDS:
s_word = processor.tokenize(s_word)
# Update input_words and labels
if s_word == constants.SIL_WORD and direction == constants.INST_BACKWARD:
continue
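
An illustrative sketch of the per-word Moses pre-tokenization now performed inside TaggerDataInstance; the literal words below are made up, and special spoken-side tags are left untouched, as in the hunk above.

from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
from nemo.collections.nlp.data.text_normalization import constants

processor = MosesProcessor(lang_id="en")
w_word, s_word = "12/3", "december third"
w_word = processor.tokenize(w_word)          # e.g. "12 / 3"
if s_word not in constants.SPECIAL_WORDS:    # skip special tags such as the silence marker
    s_word = processor.tokenize(s_word)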


@ -13,15 +13,11 @@
# limitations under the License.
from collections import defaultdict
from typing import DefaultDict, List
from typing import List
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
from nemo.collections.nlp.data.text_normalization import constants
from nemo.collections.nlp.data.text_normalization.utils import (
basic_tokenize,
normalize_str,
read_data_file,
remove_puncts,
)
from nemo.collections.nlp.data.text_normalization.utils import normalize_str, read_data_file, remove_puncts
from nemo.utils.decorators.experimental import experimental
__all__ = ['TextNormalizationTestDataset']
@ -41,7 +37,7 @@ class TextNormalizationTestDataset:
def __init__(self, input_file: str, mode: str, lang: str):
self.lang = lang
insts = read_data_file(input_file, lang=lang)
processor = MosesProcessor(lang_id=lang)
# Build inputs and targets
self.directions, self.inputs, self.targets, self.classes, self.nb_spans, self.span_starts, self.span_ends = (
[],
@ -58,7 +54,6 @@ class TextNormalizationTestDataset:
if direction == constants.INST_BACKWARD:
if mode == constants.TN_MODE:
continue
# ITN mode
(
processed_w_words,
@ -77,10 +72,13 @@ class TextNormalizationTestDataset:
else:
processed_s_words.append(s_word)
s_word_last = processor.tokenize(processed_s_words.pop()).split()
processed_s_words.append(" ".join(s_word_last))
num_tokens = len(s_word_last)
processed_nb_spans += 1
processed_classes.append(cls)
processed_s_span_starts.append(s_word_idx)
s_word_idx += len(basic_tokenize(processed_s_words[-1], lang=self.lang))
s_word_idx += num_tokens
processed_s_span_ends.append(s_word_idx)
processed_w_words.append(w_word)
@ -88,15 +86,13 @@ class TextNormalizationTestDataset:
self.span_ends.append(processed_s_span_ends)
self.classes.append(processed_classes)
self.nb_spans.append(processed_nb_spans)
# Basic tokenization
input_words = basic_tokenize(' '.join(processed_s_words), lang)
input_words = ' '.join(processed_s_words)
# Update self.directions, self.inputs, self.targets
self.directions.append(direction)
self.inputs.append(' '.join(input_words))
self.inputs.append(input_words)
self.targets.append(
processed_w_words
) # is list of lists where inner list contains target tokens (not words)
# TN mode
elif direction == constants.INST_FORWARD:
if mode == constants.ITN_MODE:
@ -111,29 +107,29 @@ class TextNormalizationTestDataset:
) = ([], [], [], 0, [], [])
w_word_idx = 0
for cls, w_word, s_word in zip(classes, w_words, s_words):
# TN forward mode
# this is done for cases like `do n't`, this w_word will be treated as 2 tokens
w_word = processor.tokenize(w_word).split()
num_tokens = len(w_word)
if s_word in constants.SPECIAL_WORDS:
processed_s_words.append(w_word)
processed_s_words.append(" ".join(w_word))
else:
processed_s_words.append(s_word)
w_span_starts.append(w_word_idx)
w_word_idx += len(basic_tokenize(w_word, lang=self.lang))
w_word_idx += num_tokens
w_span_ends.append(w_word_idx)
processed_nb_spans += 1
processed_classes.append(cls)
processed_w_words.append(w_word)
processed_w_words.extend(w_word)
self.span_starts.append(w_span_starts)
self.span_ends.append(w_span_ends)
self.classes.append(processed_classes)
self.nb_spans.append(processed_nb_spans)
# Basic tokenization
input_words = basic_tokenize(' '.join(processed_w_words), lang)
input_words = ' '.join(processed_w_words)
# Update self.directions, self.inputs, self.targets
self.directions.append(direction)
self.inputs.append(' '.join(input_words))
self.inputs.append(input_words)
self.targets.append(
processed_s_words
) # is list of lists where inner list contains target tokens (not words)
@ -157,7 +153,7 @@ class TextNormalizationTestDataset:
return len(self.inputs)
@staticmethod
def is_same(pred: str, target: str, inst_dir: str, lang: str):
def is_same(pred: str, target: str, inst_dir: str):
"""
Function for checking whether the predicted string can be considered
the same as the target string
@ -166,18 +162,17 @@ class TextNormalizationTestDataset:
pred: Predicted string
target: Target string
inst_dir: Direction of the instance (i.e., INST_BACKWARD or INST_FORWARD).
lang: Language
Return: an int value (0/1) indicating whether pred and target are the same.
"""
if inst_dir == constants.INST_BACKWARD:
pred = remove_puncts(pred)
target = remove_puncts(target)
pred = normalize_str(pred, lang)
target = normalize_str(target, lang)
pred = normalize_str(pred)
target = normalize_str(target)
return int(pred == target)
@staticmethod
def compute_sent_accuracy(preds: List[str], targets: List[str], inst_directions: List[str], lang: str):
def compute_sent_accuracy(preds: List[str], targets: List[str], inst_directions: List[str]):
"""
Compute the sentence accuracy metric.
@ -185,7 +180,6 @@ class TextNormalizationTestDataset:
preds: List of predicted strings.
targets: List of target strings.
inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD or INST_FORWARD).
lang: Language
Return: the sentence accuracy score
"""
assert len(preds) == len(targets)
@ -194,7 +188,7 @@ class TextNormalizationTestDataset:
# Sentence Accuracy
correct_count = 0
for inst_dir, pred, target in zip(inst_directions, preds, targets):
correct_count += TextNormalizationTestDataset.is_same(pred, target, inst_dir, lang)
correct_count += TextNormalizationTestDataset.is_same(pred, target, inst_dir)
sent_accuracy = correct_count / len(targets)
return sent_accuracy
@ -208,9 +202,7 @@ class TextNormalizationTestDataset:
output_spans: List[List[str]],
classes: List[List[str]],
nb_spans: List[int],
span_starts: List[List[int]],
span_ends: List[List[int]],
lang: str,
) -> dict:
"""
Compute the class based accuracy metric. This uses model's predicted tags.
@ -223,9 +215,7 @@ class TextNormalizationTestDataset:
output_spans: A list of lists where each inner list contains the decoded spans for the corresponding input sentence
classes: A list of lists where inner list contains the class for each semiotic token in input sentence
nb_spans: A list that contains the number of tokens in the input
span_starts: A list of lists where inner list contains the start word index of the current token
span_ends: A list of lists where inner list contains the end word index of the current token
lang: Language
Return: the class accuracy scores as dict
"""
@ -233,7 +223,7 @@ class TextNormalizationTestDataset:
return 'NA'
class2stats, class2correct = defaultdict(int), defaultdict(int)
for ix, (sent, tags) in enumerate(zip(inputs, tag_preds)):
assert len(inputs) == len(tag_preds)
assert len(sent) == len(tags)
cur_words = [[] for _ in range(nb_spans[ix])]
jx, span_idx = 0, 0
cur_spans = output_spans[ix]
@ -261,9 +251,10 @@ class TextNormalizationTestDataset:
jx += 1
target_token_idx = 0
# assert len(cur_words) == len(targets[ix])
for class_idx in range(nb_spans[ix]):
correct = TextNormalizationTestDataset.is_same(
" ".join(cur_words[class_idx]), targets[ix][target_token_idx], inst_directions[ix], lang
" ".join(cur_words[class_idx]), targets[ix][target_token_idx], inst_directions[ix]
)
class2correct[classes[ix][class_idx]] += correct
target_token_idx += 1
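
A quick, hedged check of the now language-independent is_same(): for the ITN direction punctuation is stripped before the case-insensitive comparison, while the TN direction only normalizes case and spacing.

from nemo.collections.nlp.data.text_normalization import TextNormalizationTestDataset, constants

pred, target = "December 3, 1974", "december 3 1974"
TextNormalizationTestDataset.is_same(pred, target, constants.INST_BACKWARD)  # 1: comma removed, lower-cased
TextNormalizationTestDataset.is_same(pred, target, constants.INST_FORWARD)   # 0: the comma still differs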


@ -37,7 +37,7 @@ def input_preprocessing(sent: str, lang: str):
such as Δ or λ (if any).
Args:
sents: input text.
sent: input text.
lang: language
Returns: preprocessed input text.
@ -138,12 +138,9 @@ def process_url(tokens: List[str], outputs: List[str], lang: str):
return outputs
def normalize_str(input_str, lang):
def normalize_str(input_str):
""" Normalize an input string """
input_str_tokens = basic_tokenize(input_str.strip().lower(), lang)
input_str = ' '.join(input_str_tokens)
input_str = input_str.replace('  ', ' ')
return input_str
return input_str.strip().lower().replace("  ", " ")
def remove_puncts(input_str):
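
A tiny sanity check for the simplified normalize_str() above (assuming the double-space collapse is intended):

from nemo.collections.nlp.data.text_normalization.utils import normalize_str

print(normalize_str("  It  Was 1974 "))  # expected: "it was 1974"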


@ -26,6 +26,7 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSe
import nemo.collections.nlp.data.text_normalization.constants as constants
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
from nemo.collections.nlp.data.text_normalization import TextNormalizationTestDataset
from nemo.collections.nlp.data.text_normalization.decoder_dataset import (
TarredTextNormalizationDecoderDataset,
TextNormalizationDecoderDataset,
@ -153,6 +154,7 @@ class DuplexDecoderModel(NLPModel):
)
input_centers = self._tokenizer.batch_decode(batch['input_center'], skip_special_tokens=True)
direction = [x[0].item() for x in batch['direction']]
direction_str = [constants.DIRECTIONS_ID_TO_NAME[x] for x in direction]
# apply post_processing
@ -161,9 +163,12 @@ class DuplexDecoderModel(NLPModel):
for idx, class_id in enumerate(batch['semiotic_class_id']):
direction = constants.TASK_ID_TO_MODE[batch['direction'][idx][0].item()]
class_name = self._val_id_to_class[dataloader_idx][class_id[0].item()]
results[f"correct_{class_name}_{direction}"] += torch.tensor(
labels_str[idx] == generated_texts[idx], dtype=torch.int
).to(self.device)
pred_result = TextNormalizationTestDataset.is_same(
generated_texts[idx], labels_str[idx], constants.DIRECTIONS_TO_MODE[direction]
)
results[f"correct_{class_name}_{direction}"] += torch.tensor(pred_result, dtype=torch.int).to(self.device)
results[f"total_{class_name}_{direction}"] += torch.tensor(1).to(self.device)
results[f"{split}_loss"] = val_loss
@ -309,10 +314,6 @@ class DuplexDecoderModel(NLPModel):
if sum(nb_spans) == 0:
return [[]] * len(sents)
model, tokenizer = self.model, self._tokenizer
try:
model_max_len = model.config.n_positions
except AttributeError:
model_max_len = 512
ctx_size = constants.DECODE_CTX_SIZE
extra_id_0 = constants.EXTRA_ID_0
extra_id_1 = constants.EXTRA_ID_1
@ -320,7 +321,7 @@ class DuplexDecoderModel(NLPModel):
"""
Build all_inputs - extracted spans to be transformed by the decoder model
Inputs for TN direction have "0" prefix, while the backward, ITN direction, has prefix "1"
"input_centers" - List[str] - ground-truth labels for the span #TODO: rename
"input_centers" - List[str] - ground-truth labels for the span
"""
input_centers, input_dirs, all_inputs = [], [], []
for ix, sent in enumerate(sents):
@ -351,11 +352,12 @@ class DuplexDecoderModel(NLPModel):
input_ids = batch['input_ids'].to(self.device)
generated_texts, generated_ids, sequence_toks_scores = self._generate_predictions(
input_ids=input_ids, model_max_len=model_max_len
input_ids=input_ids, model_max_len=self.max_sequence_len
)
# Use covering grammars (if enabled)
if self.use_cg:
# Compute sequence probabilities
sequence_probs = torch.ones(len(all_inputs)).to(self.device)
for ix, cur_toks_scores in enumerate(sequence_toks_scores):
@ -528,7 +530,6 @@ class DuplexDecoderModel(NLPModel):
if data_split == "train"
else False,
lang=self.lang,
do_basic_tokenize=cfg.do_basic_tokenize,
use_cache=cfg.get('use_cache', False),
max_insts=cfg.get('max_insts', -1),
do_tokenize=True,


@ -131,7 +131,7 @@ class DuplexTaggerModel(NLPModel):
# Functions for inference
@torch.no_grad()
def _infer(self, sents: List[List[str]], inst_directions: List[str], do_basic_tokenization=True):
def _infer(self, sents: List[List[str]], inst_directions: List[str]):
""" Main function for Inference
Args:
@ -144,10 +144,8 @@ class DuplexTaggerModel(NLPModel):
nb_spans: A list of ints where each int indicates the number of semiotic spans in input words.
span_starts: A list of lists where each list contains the starting locations of semiotic spans in input words.
span_ends: A list of lists where each list contains the ending locations of semiotic spans in input words.
do_basic_tokenization: whether to do a pre-processing to separate punctuation marks, recommended to set to True
"""
self.eval()
# Append prefix
texts = []
for ix, sent in enumerate(sents):
@ -155,19 +153,11 @@ class DuplexTaggerModel(NLPModel):
prefix = constants.ITN_PREFIX
elif inst_directions[ix] == constants.INST_FORWARD:
prefix = constants.TN_PREFIX
if do_basic_tokenization:
texts.append([prefix] + sent)
else:
texts.append(prefix + " " + sent)
texts.append([prefix] + sent)
# Apply the model
if do_basic_tokenization:
is_split_into_words = True
else:
is_split_into_words = False
encodings = self._tokenizer(
texts, is_split_into_words=is_split_into_words, padding=True, truncation=True, return_tensors='pt'
texts, is_split_into_words=True, padding=True, truncation=True, return_tensors='pt'
)
inputs = encodings
@ -175,10 +165,7 @@ class DuplexTaggerModel(NLPModel):
# check that the length of the 'input_ids' equals as least the length of the original input
# if an input symbol is missing in the tokenizer's vocabulary (such as emoji or a Chinese character), it could be skipped
if do_basic_tokenization:
len_texts = [len(x) for x in texts]
else:
len_texts = [len(x.split()) for x in texts]
len_texts = [len(x) for x in texts]
len_ids = [
len(self._tokenizer.convert_ids_to_tokens(x, skip_special_tokens=True)) for x in encodings['input_ids']
]
@ -343,7 +330,6 @@ class DuplexTaggerModel(NLPModel):
tokenizer=self._tokenizer,
tokenizer_name=self.transformer_name,
mode=self.mode,
do_basic_tokenize=cfg.do_basic_tokenize,
tagger_data_augmentation=tagger_data_augmentation,
lang=self.lang,
max_seq_length=self.max_sequence_len,
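
A minimal sketch of the tagger's now-unconditional pre-tokenized input path; the tokenizer name below is only a stand-in for illustration, since the model's own tokenizer is used in practice.

import nemo.collections.nlp.data.text_normalization.constants as constants
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")    # stand-in tokenizer
sent = ["on", "12", "/", "3", "don", "'t", "forget"]              # already Moses-tokenized words
texts = [[constants.TN_PREFIX] + sent]                            # task prefix prepended as a word
encodings = tokenizer(texts, is_split_into_words=True, padding=True, truncation=True, return_tensors="pt")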


@ -92,7 +92,11 @@ class DuplexTextNormalizationModel(nn.Module):
) = zip(*batch_insts)
# Inference and Running Time Measurement
batch_start_time = perf_counter()
batch_tag_preds, batch_output_spans, batch_final_preds = self._infer(batch_inputs, batch_dirs)
batch_tag_preds, batch_output_spans, batch_final_preds = self._infer(
batch_inputs, batch_dirs, processed=True
)
batch_run_time = (perf_counter() - batch_start_time) * 1000 # milliseconds
all_run_times.append(batch_run_time)
# Update all_dirs, all_inputs, all_tag_preds, all_final_preds and all_targets
@ -149,20 +153,18 @@ class DuplexTextNormalizationModel(nn.Module):
cur_targets_sent = [" ".join(x) for x in cur_targets]
sent_accuracy = TextNormalizationTestDataset.compute_sent_accuracy(
cur_final_preds, cur_targets_sent, cur_dirs, self.lang
cur_final_preds, cur_targets_sent, cur_dirs
)
class_accuracy = TextNormalizationTestDataset.compute_class_accuracy(
[basic_tokenize(x, lang=self.lang) for x in cur_inputs],
[x.split() for x in cur_inputs],
cur_targets,
cur_tag_preds,
cur_dirs,
cur_output_spans,
cur_classes,
cur_nb_spans,
cur_span_starts,
cur_span_ends,
self.lang,
)
if verbose:
logging.info(f'\n============ Direction {direction} ============')
@ -185,13 +187,14 @@ class DuplexTextNormalizationModel(nn.Module):
for _input, tag_pred, final_pred, target, classes in zip(
cur_inputs, cur_tag_preds, cur_final_preds, cur_targets_sent, cur_classes
):
if not TextNormalizationTestDataset.is_same(final_pred, target, direction, self.lang):
if not TextNormalizationTestDataset.is_same(final_pred, target, direction):
if direction == constants.INST_BACKWARD:
error_f.write('Backward Problem (ITN)\n')
itn_error_ctx += 1
elif direction == constants.INST_FORWARD:
error_f.write('Forward Problem (TN)\n')
tn_error_ctx += 1
formatted_input_str = get_formatted_string(basic_tokenize(_input, lang=self.lang))
formatted_tag_pred_str = get_formatted_string(tag_pred)
class_str = " ".join(classes)
@ -217,7 +220,7 @@ class DuplexTextNormalizationModel(nn.Module):
return results
# Functions for inference
def _infer(self, sents: List[str], inst_directions: List[str], do_basic_tokenization=True):
def _infer(self, sents: List[str], inst_directions: List[str], processed=False):
"""
Main function for Inference
@ -228,8 +231,8 @@ class DuplexTextNormalizationModel(nn.Module):
sents: A list of input texts.
inst_directions: A list of str where each str indicates the direction of the corresponding instance \
(i.e., constants.INST_BACKWARD for ITN or constants.INST_FORWARD for TN).
do_basic_tokenization: whether to do a pre-processing to separate punctuation marks,
recommended to set to True
processed: Set to True when used with TextNormalizationTestDataset, the data is already tokenized with moses,
repetitive moses tokenization could lead to the number of tokens and class span mismatch
Returns:
tag_preds: A list of lists where the inner list contains the tag predictions from the tagger for each word in the input text.
@ -238,19 +241,16 @@ class DuplexTextNormalizationModel(nn.Module):
"""
original_sents = [s for s in sents]
# Separate into words
if do_basic_tokenization:
if not processed:
sents = [self.decoder.processor.tokenize(x).split() for x in sents]
else:
sents = [x.split() for x in sents]
# Tagging
# span_ends included, returns index wrt to words in input without auxiliary words
tag_preds, nb_spans, span_starts, span_ends = self.tagger._infer(
sents, inst_directions, do_basic_tokenization=do_basic_tokenization
)
tag_preds, nb_spans, span_starts, span_ends = self.tagger._infer(sents, inst_directions)
output_spans = self.decoder._infer(sents, nb_spans, span_starts, span_ends, inst_directions)
if not do_basic_tokenization:
sents = [x.split() for x in sents]
# Prepare final outputs
final_outputs = []
for ix, (sent, tags) in enumerate(zip(sents, tag_preds)):
@ -268,8 +268,15 @@ class DuplexTextNormalizationModel(nn.Module):
span_idx += 1
while jx < len(sent) and tags[jx] == constants.I_PREFIX + constants.TRANSFORM_TAG:
jx += 1
cur_output_str = self.decoder.processor.detokenize(cur_words)
cur_output_str = post_process_punct(input=original_sents[ix], nn_output=cur_output_str)
if processed:
# for Class-based evaluation, don't apply Moses detokenization
cur_output_str = " ".join(cur_words)
else:
# detokenize the output with Moses and fix punctuation marks to match the input
# for interactive inference or inference from a file
cur_output_str = self.decoder.processor.detokenize(cur_words)
cur_output_str = post_process_punct(input=original_sents[ix], nn_output=cur_output_str)
final_outputs.append(cur_output_str)
except IndexError:
logging.warning(f"Input sent is too long and will be skipped - {original_sents[ix]}")
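
Finally, a hedged end-to-end usage sketch of the updated _infer (the instance name tn_model is assumed): raw text is Moses-tokenized internally, and only the evaluation path, whose data is already tokenized, passes processed=True so that the span indices are not shifted by a second tokenization.

sents = ["It happened on 12/3."]
directions = [constants.INST_FORWARD]  # TN direction
tag_preds, output_spans, final_preds = tn_model._infer(sents, directions, processed=False)
print(final_preds[0])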