60 lines
2.1 KiB
Python
60 lines
2.1 KiB
Python
# Copyright 2017 Google Inc. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
#
|
|
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Utility functions specifically for NMT."""
|
|
from __future__ import print_function
|
|
|
|
import codecs
|
|
import time
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from utils import misc_utils as utils
|
|
|
|
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["get_translation"]
|
|
|
|
|
|
def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
  """Given batch decoding outputs, select a sentence and turn it into text.

  Args:
    nmt_outputs: 2-D indexable batch of decoded tokens; row ``sent_id`` is
      taken with ``nmt_outputs[sent_id, :]`` and must support ``.tolist()``
      (e.g. a numpy array). Presumably each token is a UTF-8 ``bytes``
      value, since the end-of-sentence symbol is compared as bytes --
      TODO confirm against callers.
    sent_id: Index of the sentence (row) to select from ``nmt_outputs``.
    tgt_eos: End-of-sentence symbol, as text or already-encoded bytes.
      A falsy value (``None`` / empty) disables truncation.
    subword_option: ``"bpe"`` or ``"spm"`` to undo that subword
      segmentation; any other value formats the tokens as plain text.

  Returns:
    A tuple ``(translation, num_tokens)``. ``translation`` is whatever the
    selected ``utils.format_*_text`` helper returns. ``num_tokens`` is the
    number of tokens kept after truncating at the first end-of-sentence
    symbol.
  """
  # Tokens are compared against a bytes symbol. Only encode when the caller
  # handed us text: ``bytes`` has no ``encode`` on Python 3, and on Python 2
  # encoding a non-ASCII byte string would implicitly ASCII-decode it first.
  if tgt_eos and not isinstance(tgt_eos, bytes):
    tgt_eos = tgt_eos.encode("utf-8")

  # Select a sentence.
  output = nmt_outputs[sent_id, :].tolist()

  # If there is an eos symbol in outputs, cut them at its first occurrence.
  if tgt_eos and tgt_eos in output:
    output = output[:output.index(tgt_eos)]

  if subword_option == "bpe":  # BPE
    translation = utils.format_bpe_text(output)
  elif subword_option == "spm":  # SPM
    translation = utils.format_spm_text(output)
  else:
    translation = utils.format_text(output)

  return translation, len(output)
|