[BERT/PyT] Triton Inference Server support

Przemek Strzelczyk 2020-04-02 14:39:24 +02:00
parent c007e8e6dc
commit 26c2676104
33 changed files with 2723 additions and 1036 deletions

@ -15,6 +15,10 @@ checkpoints/
results
results/*
#Editor
.idea
.idea/*
# Distribution / packaging
.Python
build/

@ -10,7 +10,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:19.12-py3
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.03-py3
FROM nvcr.io/nvidia/tritonserver:20.03-py3-clientsdk as trt
FROM ${FROM_IMAGE_NAME}
RUN apt-get update && apt-get install -y pbzip2 pv bzip2 cabextract
@ -20,10 +22,17 @@ WORKDIR /workspace
RUN git clone https://github.com/attardi/wikiextractor.git
RUN git clone https://github.com/soskek/bookcorpus.git
# Copy the perf_client over
COPY --from=trt /workspace/install/ /workspace/install/
ENV LD_LIBRARY_PATH /workspace/install/lib:${LD_LIBRARY_PATH}
# Install trt python api
RUN pip install /workspace/install/python/tensorrtserver-1.*-py3-none-linux_x86_64.whl
WORKDIR /workspace/bert
RUN pip install --upgrade --no-cache-dir pip \
&& pip install --no-cache-dir \
tqdm boto3 requests six ipdb h5py html2text nltk progressbar onnxruntime\
tqdm boto3 requests six ipdb h5py html2text nltk progressbar onnxruntime \
git+https://github.com/NVIDIA/dllogger
RUN apt-get install -y iputils-ping
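Not part of the commit, but a quick way to verify what the Dockerfile changes add: a hypothetical smoke test that only checks the Triton client wheel (tensorrtserver) and the pip-installed packages are importable inside the built image; it does not contact a running server.

```python
# Hypothetical smoke test for the image built above. Module names follow the
# Dockerfile; nothing here talks to Triton, it only resolves the imports.
import importlib.util

for name in ("tensorrtserver", "onnxruntime", "dllogger", "h5py"):
    found = importlib.util.find_spec(name) is not None
    print(f"{name}: {'ok' if found else 'MISSING'}")
```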

File diff suppressed because it is too large.

@ -249,6 +249,10 @@ def create_instances_from_document(
if random_document_index != document_index:
break
# If the picked random document is the same as the current document
if random_document_index == document_index:
is_random_next = False
random_document = all_documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):

@ -35,6 +35,7 @@ from types import SimpleNamespace
from file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from tokenization import (BasicTokenizer, BertTokenizer, whitespace_tokenize)
from run_squad import _get_best_indices, _compute_softmax, get_valid_prelim_predictions, get_answer_text
if sys.version_info[0] == 2:
import cPickle as pickle
@ -133,290 +134,86 @@ def preprocess_tokenized_text(doc_tokens, query_tokens, tokenizer,
RawResult = collections.namedtuple("RawResult", ["start_logits", "end_logits"])
def get_predictions(doc_tokens, tokens_for_postprocessing,
start_logits, end_logits, n_best_size,
max_answer_length, do_lower_case,
can_give_negative_answer, null_score_diff_threshold):
""" Write final predictions to the json file and log-odds of null if needed. """
def get_answer(doc_tokens, tokens_for_postprocessing,
start_logits, end_logits, args):
result = RawResult(start_logits=start_logits, end_logits=end_logits)
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction",
["start_index", "end_index", "start_logit", "end_logit"])
predictions = []
Prediction = collections.namedtuple('Prediction', ['text', 'start_logit', 'end_logit'])
prelim_predictions = []
# keep track of the minimum score of null start+end of position 0
score_null = 1000000 # large and positive
null_start_logit = 0 # the start logit at the slice with min null score
null_end_logit = 0 # the end logit at the slice with min null score
if args.version_2_with_negative:
null_val = (float("inf"), 0, 0)
start_indices = _get_indices_of_largest_logits(result.start_logits)
end_indices = _get_indices_of_largest_logits(result.end_logits)
# if we could have irrelevant answers, get the min score of irrelevant
if can_give_negative_answer:
feature_null_score = result.start_logits[0] + result.end_logits[0]
if feature_null_score < score_null:
score_null = feature_null_score
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
for start_index in start_indices:
for end_index in end_indices:
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index >= len(tokens_for_postprocessing.tokens):
continue
if end_index >= len(tokens_for_postprocessing.tokens):
continue
if start_index not in tokens_for_postprocessing.token_to_orig_map:
continue
if end_index not in tokens_for_postprocessing.token_to_orig_map:
continue
if not tokens_for_postprocessing.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
start_index=start_index,
end_index=end_index,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]
)
)
if can_give_negative_answer:
prelim_predictions.append(
_PrelimPrediction(
start_index=0,
end_index=0,
start_logit=null_start_logit,
end_logit=null_end_logit
)
)
start_indices = _get_best_indices(result.start_logits, args.n_best_size)
end_indices = _get_best_indices(result.end_logits, args.n_best_size)
prelim_predictions = get_valid_prelim_predictions(start_indices, end_indices,
tokens_for_postprocessing, result, args)
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True
)
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True
)
if args.version_2_with_negative:
score = result.start_logits[0] + result.end_logits[0]
if score < null_val[0]:
null_val = (score, result.start_logits[0], result.end_logits[0])
_NbestPrediction = collections.namedtuple("NbestPrediction", ["text", "start_logit", "end_logit"])
seen_predictions = {}
nbest = []
doc_tokens_obj = {
'doc_tokens': doc_tokens,
}
doc_tokens_obj = SimpleNamespace(**doc_tokens_obj)
curr_predictions = []
seen_predictions = []
for pred in prelim_predictions:
if len(nbest) >= n_best_size:
if len(curr_predictions) == args.n_best_size:
break
if pred.start_index > 0: # this is a non-null prediction
tok_tokens = tokens_for_postprocessing.tokens[pred.start_index:(pred.end_index + 1)]
orig_doc_start = tokens_for_postprocessing.token_to_orig_map[pred.start_index]
orig_doc_end = tokens_for_postprocessing.token_to_orig_map[pred.end_index]
orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
# de-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
# clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
# get final text
final_text = get_final_text(tok_text, orig_text, do_lower_case)
if pred.end_index > 0: # this is a non-null prediction
final_text = get_answer_text(doc_tokens_obj, tokens_for_postprocessing, pred, args)
if final_text in seen_predictions:
continue
# mark it
seen_predictions[final_text] = True
else: # this is a null prediction
else:
final_text = ""
seen_predictions[final_text] = True
nbest.append(
_NbestPrediction(
text=final_text,
start_logit=pred.start_logit,
end_logit=pred.end_logit
)
)
# if we didn't include the empty option in the n-best, include it
if can_give_negative_answer:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
text="",
start_logit=null_start_logit,
end_logit=null_end_logit
)
)
# In very rare edge cases we could have only a single null prediction.
# So we just create a nonce prediction in this case to avoid failure.
if len(nbest) == 1:
nbest.insert(0, _NbestPrediction(text="", start_logit=0.0, end_logit=0.0))
seen_predictions.append(final_text)
curr_predictions.append(Prediction(final_text, pred.start_logit, pred.end_logit))
predictions += curr_predictions
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
nbest.append(_NbestPrediction(text="", start_logit=0.0, end_logit=0.0))
# add empty prediction
if args.version_2_with_negative:
predictions.append(Prediction('', null_val[1], null_val[2]))
assert len(nbest) >= 1
nbest_answers = []
answer = None
nbest = sorted(predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True)[:args.n_best_size]
# scoring
total_scores = []
best_non_null_entry = None
for entry in nbest:
total_scores.append(entry.start_logit + entry.end_logit)
if not best_non_null_entry:
if entry.text:
best_non_null_entry = entry
# get probabilities
if not best_non_null_entry and entry.text:
best_non_null_entry = entry
probs = _compute_softmax(total_scores)
# nbest predictions into json format
nbest_json = []
for (i, entry) in enumerate(nbest):
output = collections.OrderedDict()
output["text"] = entry.text
output["probability"] = probs[i]
output["start_logit"] = entry.start_logit
output["end_logit"] = entry.end_logit
nbest_json.append(output)
nbest_answers.append(output)
if args.version_2_with_negative:
score_diff = null_val[0] - best_non_null_entry.start_logit - best_non_null_entry.end_logit
if score_diff > args.null_score_diff_threshold:
answer = ""
else:
answer = best_non_null_entry.text
else:
answer = nbest_answers[0]['text']
assert len(nbest_json) >= 1
if can_give_negative_answer:
# predict "unknown" iff ((score_null - score_of_best_non-null_entry) > threshold)
score = best_non_null_entry.start_logit + best_non_null_entry.end_logit
score_diff = score_null - score
if score_diff > null_score_diff_threshold:
nbest_json[0]['text'] = "unknown"
# best_non_null_entry.text = "unknown"
#
return nbest_json
def get_final_text(pred_text, orig_text, do_lower_case):
"""Project the tokenized prediction back to the original text."""
# When we created the data, we kept track of the alignment between original
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
# now `orig_text` contains the span of our original text corresponding to the
# span that we predicted.
#
# However, `orig_text` may contain extra characters that we don't want in
# our prediction.
#
# For example, let's say:
# pred_text = steve smith
# orig_text = Steve Smith's
#
# We don't want to return `orig_text` because it contains the extra "'s".
#
# We don't want to return `pred_text` because it's already been normalized
# (the SQuAD eval script also does punctuation stripping/lower casing but
# our tokenizer does additional normalization like stripping accent
# characters).
#
# What we really want to return is "Steve Smith".
#
# Therefore, we have to apply a semi-complicated alignment heuristic between
# `pred_text` and `orig_text` to get a character-to-character alignment. This
# can fail in certain cases in which case we just return `orig_text`.
def _strip_spaces(text):
ns_chars = []
ns_to_s_map = collections.OrderedDict()
for (i, c) in enumerate(text):
if c == " ":
continue
ns_to_s_map[len(ns_chars)] = i
ns_chars.append(c)
ns_text = "".join(ns_chars)
return (ns_text, ns_to_s_map)
# We first tokenize `orig_text`, strip whitespace from the result
# and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned.
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
tok_text = " ".join(tokenizer.tokenize(orig_text))
start_position = tok_text.find(pred_text)
if start_position == -1:
return orig_text
end_position = start_position + len(pred_text) - 1
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
if len(orig_ns_text) != len(tok_ns_text):
return orig_text
# We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment.
tok_s_to_ns_map = {}
for (i, tok_index) in tok_ns_to_s_map.items():
tok_s_to_ns_map[tok_index] = i
orig_start_position = None
if start_position in tok_s_to_ns_map:
ns_start_position = tok_s_to_ns_map[start_position]
if ns_start_position in orig_ns_to_s_map:
orig_start_position = orig_ns_to_s_map[ns_start_position]
if orig_start_position is None:
return orig_text
orig_end_position = None
if end_position in tok_s_to_ns_map:
ns_end_position = tok_s_to_ns_map[end_position]
if ns_end_position in orig_ns_to_s_map:
orig_end_position = orig_ns_to_s_map[ns_end_position]
if orig_end_position is None:
return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
return output_text
def _compute_softmax(scores):
"""Compute softmax probability over raw logits."""
if not scores:
return []
max_score = None
for score in scores:
if max_score is None or score > max_score:
max_score = score
exp_scores = []
total_sum = 0.0
for score in scores:
x = math.exp(score - max_score)
exp_scores.append(x)
total_sum += x
probs = []
for score in exp_scores:
probs.append(score / total_sum)
return probs
def _get_indices_of_largest_logits(logits):
""" sort logits and return the indices of the sorted array """
indices_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
indices = map(lambda x: x[0], indices_and_score)
indices = list(indices)
return indices
return answer, nbest_answers
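For readers following the refactor, a self-contained sketch of the scoring step that get_answer now performs on its preliminary predictions: sort by start_logit + end_logit, softmax the scores, and return the empty answer when the null-score gap exceeds the threshold. Helper names and example values are illustrative, not the repository's API.

```python
import collections
import math

Prediction = collections.namedtuple("Prediction", ["text", "start_logit", "end_logit"])

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def pick_answer(predictions, null_score, null_score_diff_threshold=0.0):
    # Rank candidate spans by the sum of their start/end logits.
    nbest = sorted(predictions, key=lambda p: p.start_logit + p.end_logit, reverse=True)
    probs = softmax([p.start_logit + p.end_logit for p in nbest])
    best = next((p for p in nbest if p.text), None)
    if best is None:
        return "", nbest, probs
    # "Unanswerable" wins if the null score beats the best span by more than the threshold.
    score_diff = null_score - (best.start_logit + best.end_logit)
    answer = "" if score_diff > null_score_diff_threshold else best.text
    return answer, nbest, probs

candidates = [
    Prediction("bacteria", 4.1, 3.7),
    Prediction("viruses", 2.0, 1.5),
    Prediction("", 0.3, 0.2),
]
print(pick_answer(candidates, null_score=0.5))
```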
def main():
@ -434,6 +231,8 @@ def main():
help="The checkpoint file from pretraining")
## Other parameters
parser.add_argument("--verbose_logging", action='store_true',
help="If true, all of the warnings related to data processing will be printed. ")
parser.add_argument("--seed", default=1, type=int)
parser.add_argument("--question", default="Most antibiotics target bacteria and don't affect what class of organisms? ",
type=str, help="question")
@ -456,7 +255,7 @@ def main():
parser.add_argument("--do_lower_case",
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument('--can_give_negative_answer',
parser.add_argument('--version_2_with_negative',
action='store_true',
help='If true, then the model can reply with "unknown". ')
parser.add_argument('--null_score_diff_threshold',
@ -535,14 +334,14 @@ def main():
# post-processing
start_logits = start_logits[0].detach().cpu().tolist()
end_logits = end_logits[0].detach().cpu().tolist()
answer = get_predictions(doc_tokens, tokens_for_postprocessing,
start_logits, end_logits, args.n_best_size,
args.max_answer_length, args.do_lower_case,
args.can_give_negative_answer,
args.null_score_diff_threshold)
answer, answers = get_answer(doc_tokens, tokens_for_postprocessing,
start_logits, end_logits, args)
# print result
print(json.dumps(answer, indent=4))
print()
print(answer)
print()
print(json.dumps(answers, indent=4))
if __name__ == "__main__":

@ -33,6 +33,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils import checkpoint
sys.path.append('/workspace/bert/')
from file_utils import cached_path
from torch.nn import Module
@ -115,34 +116,21 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
pointer.data = torch.from_numpy(array)
return model
@torch.jit.script
def f_gelu(x):
def gelu(x):
return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
return torch.nn.functional.gelu(x)  # x * 0.5 * (1.0 + torch.erf(x / 1.41421))
@torch.jit.script
def bias_tanh(bias, y):
x = bias + y
return torch.tanh(x)
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
Also see https://arxiv.org/abs/1606.08415
"""
return f_gelu(x)
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
ACT2FN = {"gelu": torch.nn.functional.gelu, "bias_gelu": bias_gelu, "bias_tanh": bias_tanh, "relu": torch.nn.functional.relu, "swish": swish}
class LinearActivation(Module):
r"""Fused Linear and activation Module.
@ -153,13 +141,14 @@ class LinearActivation(Module):
super(LinearActivation, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.fused_gelu = False
self.fused_tanh = False
if isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode)):
if bias and act == 'gelu':
self.fused_gelu = True
elif bias and act == 'tanh':
self.fused_tanh = True
self.act_fn = nn.Identity() #
self.biased_act_fn = None #
self.bias = None #
if isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode)): # For TorchScript
if bias and not 'bias' in act: # compatibility
act = 'bias_' + act #
self.biased_act_fn = ACT2FN[act] #
else:
self.act_fn = ACT2FN[act]
else:
@ -179,10 +168,8 @@ class LinearActivation(Module):
init.uniform_(self.bias, -bound, bound)
def forward(self, input):
if self.fused_gelu:
return bias_gelu(self.bias, F.linear(input, self.weight, None))
elif self.fused_tanh:
return bias_tanh(self.bias, F.linear(input, self.weight, None))
if not self.bias is None:
return self.biased_act_fn(self.bias, F.linear(input, self.weight, None))
else:
return self.act_fn(F.linear(input, self.weight, self.bias))
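A minimal sketch of the fused bias+GELU pattern this hunk introduces, assuming only stock PyTorch and none of the repository's modules: the bias add and the activation are compiled together with TorchScript, and the forward dispatches on whether a bias is present, mirroring LinearActivation.forward.

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def bias_gelu(bias, y):
    # Adding the bias inside the scripted function lets the JIT fuse it with GELU.
    x = bias + y
    return torch.nn.functional.gelu(x)

def linear_activation(inp, weight, bias):
    # Same dispatch as LinearActivation.forward: fused path when a bias exists.
    if bias is not None:
        return bias_gelu(bias, F.linear(inp, weight, None))
    return F.gelu(F.linear(inp, weight, None))

x = torch.randn(2, 8)
w = torch.randn(16, 8)
b = torch.randn(16)
fused = linear_activation(x, w, b)
reference = F.gelu(F.linear(x, w, b))
print(torch.allclose(fused, reference, atol=1e-6))
```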
@ -206,7 +193,8 @@ class BertConfig(object):
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02):
initializer_range=0.02,
output_all_encoded_layers=False):
"""Constructs BertConfig.
Args:
@ -249,6 +237,7 @@ class BertConfig(object):
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.output_all_encoded_layers = output_all_encoded_layers
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@ -280,28 +269,57 @@ class BertConfig(object):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class BertNonFusedLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertNonFusedLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
try:
import apex
#apex.amp.register_half_function(apex.normalization.fused_layer_norm, 'FusedLayerNorm')
import apex.normalization
from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
#apex.amp.register_float_function(apex.normalization.FusedLayerNorm, 'forward')
BertLayerNorm = apex.normalization.FusedLayerNorm
#BertLayerNorm = apex.normalization.FusedLayerNorm
APEX_IS_AVAILABLE = True
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
#BertLayerNorm = BertNonFusedLayerNorm
APEX_IS_AVAILABLE = False
class BertLayerNorm(Module):
def __init__(self, hidden_size, eps=1e-12):
super(BertLayerNorm, self).__init__()
self.shape = torch.Size((hidden_size,))
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.apex_enabled = APEX_IS_AVAILABLE
def forward(self, x):
@torch.jit.unused
def fused_layer_norm(self, x):
return FusedLayerNormAffineFunction.apply(
x, self.weight, self.bias, self.shape, self.eps)
def forward(self, x):
if self.apex_enabled and not torch.jit.is_scripting():
x = self.fused_layer_norm(x)
else:
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight * x + self.bias
return x
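A standalone sketch, with plain PyTorch only, of the eager fallback that the reworked BertLayerNorm takes when apex is unavailable or the module is being scripted; it should agree with nn.LayerNorm up to numerical tolerance.

```python
import torch
import torch.nn as nn

class ManualLayerNorm(nn.Module):
    """TF-style layer norm (epsilon inside the square root), like the eager branch above."""
    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight * x + self.bias

x = torch.randn(4, 10, 16)
print(torch.allclose(ManualLayerNorm(16)(x), nn.LayerNorm(16, eps=1e-12)(x), atol=1e-5))
```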
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
@ -317,12 +335,10 @@ class BertEmbeddings(nn.Module):
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
def forward(self, input_ids, token_type_ids):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
@ -350,16 +366,15 @@ class BertSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = torch.reshape(x, new_x_shape)
return x.permute(0, 2, 1, 3)
def transpose_key_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
x = torch.reshape(x, new_x_shape)
return x.permute(0, 2, 3, 1)
def forward(self, hidden_states, attention_mask):
@ -378,7 +393,7 @@ class BertSelfAttention(nn.Module):
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = self.softmax(attention_scores)
attention_probs = F.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@ -387,7 +402,7 @@ class BertSelfAttention(nn.Module):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = torch.reshape(context_layer, new_context_layer_shape)
return context_layer
@ -457,20 +472,12 @@ class BertLayer(nn.Module):
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
layer = BertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
self.output_all_encoded_layers = config.output_all_encoded_layers
self._checkpoint_activations = False
# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
# all_encoder_layers = []
# for layer_module in self.layer:
# hidden_states = layer_module(hidden_states, attention_mask)
# if output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# if not output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# return all_encoder_layers
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False):
all_encoder_layers = []
@torch.jit.unused
def checkpointed_forward(self, hidden_states, attention_mask):
def custom(start, end):
def custom_forward(*inputs):
layers = self.layer[start:end]
@ -480,42 +487,31 @@ class BertEncoder(nn.Module):
return x_
return custom_forward
if checkpoint_activations:
l = 0
num_layers = len(self.layer)
chunk_length = math.ceil(math.sqrt(num_layers))
while l < num_layers:
hidden_states = checkpoint.checkpoint(custom(l, l+chunk_length), hidden_states, attention_mask*1)
l += chunk_length
# decoder layers
l = 0
num_layers = len(self.layer)
chunk_length = math.ceil(math.sqrt(num_layers))
while l < num_layers:
hidden_states = checkpoint.checkpoint(custom(l, l+chunk_length), hidden_states, attention_mask*1)
l += chunk_length
return hidden_states
def forward(self, hidden_states, attention_mask):
all_encoder_layers = []
if self._checkpoint_activations:
hidden_states = self.checkpointed_forward(hidden_states, attention_mask)
else:
for i,layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
if self.output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers or checkpoint_activations:
if not self.output_all_encoded_layers or self._checkpoint_activations:
all_encoder_layers.append(hidden_states)
return all_encoder_layers
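The checkpointed path above splits the encoder into roughly sqrt(N)-sized groups of layers and recomputes each group during backward. A toy sketch of that pattern with plain nn.Linear layers (names are illustrative, not the repository's modules):

```python
import math
import torch
import torch.nn as nn
from torch.utils import checkpoint

layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(9)])

def custom(start, end):
    # Returns a callable running layers[start:end], as in BertEncoder.
    def custom_forward(x):
        for layer in layers[start:end]:
            x = layer(x)
        return x
    return custom_forward

def checkpointed_forward(hidden_states):
    # Checkpoint sqrt(N)-sized chunks: activations inside a chunk are freed
    # after the forward pass and recomputed during backward.
    num_layers = len(layers)
    chunk = math.ceil(math.sqrt(num_layers))
    l = 0
    while l < num_layers:
        hidden_states = checkpoint.checkpoint(custom(l, l + chunk), hidden_states)
        l += chunk
    return hidden_states

x = torch.randn(2, 16, requires_grad=True)
checkpointed_forward(x).sum().backward()
print(x.grad.shape)
```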
#class BertEncoder(nn.Module):
# def __init__(self, config):
# super(BertEncoder, self).__init__()
# layer = BertLayer(config)
# self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
#
# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
# all_encoder_layers = []
# for layer_module in self.layer:
# hidden_states = layer_module(hidden_states, attention_mask)
# if output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# if not output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# return all_encoder_layers
class BertPooler(nn.Module):
def __init__(self, config):
super(BertPooler, self).__init__()
@ -556,9 +552,7 @@ class BertLMPredictionHead(nn.Module):
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
torch.cuda.nvtx.range_push("decoder input.size() = {}, weight.size() = {}".format(hidden_states.size(), self.decoder.weight.size()))
hidden_states = self.decoder(hidden_states) + self.bias
torch.cuda.nvtx.range_pop()
return hidden_states
@ -622,6 +616,17 @@ class BertPreTrainedModel(nn.Module):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def checkpoint_activations(self, val):
def _apply_flag(module):
if hasattr(module, "_checkpoint_activations"):
module._checkpoint_activations=val
self.apply(_apply_flag)
def enable_apex(self, val):
def _apply_flag(module):
if hasattr(module, "apex_enabled"):
module.apex_enabled=val
self.apply(_apply_flag)
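The two new helpers rely on nn.Module.apply to flip a flag on every submodule that defines it. A minimal illustration with made-up module names:

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self._checkpoint_activations = False
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block() for _ in range(3)])

    def checkpoint_activations(self, val):
        # Same pattern as above: walk every submodule, set the flag where it exists.
        def _apply_flag(module):
            if hasattr(module, "_checkpoint_activations"):
                module._checkpoint_activations = val
        self.apply(_apply_flag)

net = Net()
net.checkpoint_activations(True)
print([b._checkpoint_activations for b in net.blocks])  # [True, True, True]
```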
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
from_tf=False, *inputs, **kwargs):
@ -763,7 +768,6 @@ class BertModel(BertPreTrainedModel):
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
@ -796,13 +800,9 @@ class BertModel(BertPreTrainedModel):
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
self.output_all_encoded_layers = config.output_all_encoded_layers
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, checkpoint_activations=False):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
def forward(self, input_ids, token_type_ids, attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
@ -815,17 +815,15 @@ class BertModel(BertPreTrainedModel):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=self.embeddings.word_embeddings.weight.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers, checkpoint_activations=checkpoint_activations)
encoded_layers = self.encoder(embedding_output, extended_attention_mask)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
if not self.output_all_encoded_layers:
encoded_layers = encoded_layers[-1:]
return encoded_layers, pooled_output
@ -885,20 +883,12 @@ class BertForPreTraining(BertPreTrainedModel):
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, checkpoint_activations=False):
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
def forward(self, input_ids, token_type_ids, attention_mask):
encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
sequence_output = encoded_layers[-1]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
#print("loss is {} {}".format(masked_lm_loss, next_sentence_loss))
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
else:
return prediction_scores, seq_relationship_score
return prediction_scores, seq_relationship_score
class BertForMaskedLM(BertPreTrainedModel):
@ -949,9 +939,9 @@ class BertForMaskedLM(BertPreTrainedModel):
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
sequence_output = encoded_layers[-1]
prediction_scores = self.cls(sequence_output)
if masked_lm_labels is not None:
@ -1011,9 +1001,8 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.cls = BertOnlyNSPHead(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, checkpoint_activations=False):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
seq_relationship_score = self.cls( pooled_output)
if next_sentence_label is not None:
@ -1077,8 +1066,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
@ -1142,11 +1131,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, self.num_choices)
@ -1212,8 +1201,9 @@ class BertForTokenClassification(BertPreTrainedModel):
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
sequence_output = encoded_layers[-1]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
@ -1251,19 +1241,10 @@ class BertForQuestionAnswering(BertPreTrainedModel):
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
Outputs:
if `start_positions` and `end_positions` are not `None`:
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
position tokens of shape [batch_size, sequence_length].
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
position tokens of shape [batch_size, sequence_length].
Example usage:
```python
@ -1287,29 +1268,11 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
def forward(self, input_ids, token_type_ids, attention_mask):
encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
sequence_output = encoded_layers[-1]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
return total_loss
else:
return start_logits, end_logits
return start_logits, end_logits
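With the loss branch gone, the QA head is just a linear projection whose output is split into start and end logits. A tiny sketch of that shape handling with dummy tensors:

```python
import torch

sequence_output = torch.randn(2, 8, 16)   # [batch, seq_len, hidden], stand-in for the encoder output
qa_outputs = torch.nn.Linear(16, 2)       # maps hidden -> (start, end)

logits = qa_outputs(sequence_output)               # [2, 8, 2]
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)            # [2, 8]
end_logits = end_logits.squeeze(-1)                # [2, 8]
print(start_logits.shape, end_logits.shape)
```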

@ -41,6 +41,9 @@ from apex import amp
from sklearn.metrics import matthews_corrcoef, f1_score
from utils import is_main_process
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)

@ -24,7 +24,6 @@ from __future__ import print_function
import csv
import os
import time
import logging
import argparse
import random
import h5py
@ -39,7 +38,7 @@ from apex import amp
import multiprocessing
from tokenization import BertTokenizer
from modeling import BertForPreTraining, BertConfig
import modeling
from apex.optimizers import FusedLAMB
from schedulers import PolyWarmUpScheduler
@ -55,14 +54,25 @@ from apex.amp import _amp_state
import dllogger
from concurrent.futures import ProcessPoolExecutor
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
skipped_steps = 0
def create_pretraining_dataset(input_file, max_pred_length, shared_list, args):
# Workaround because Python lambdas/closures are not picklable
class WorkerInitObj(object):
def __init__(self, seed):
self.seed = seed
def __call__(self, id):
np.random.seed(seed=self.seed + id)
random.seed(self.seed + id)
def create_pretraining_dataset(input_file, max_pred_length, shared_list, args, worker_init):
train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler,
batch_size=args.train_batch_size * args.n_gpu, num_workers=4,
batch_size=args.train_batch_size * args.n_gpu,
num_workers=4, worker_init_fn=worker_init,
pin_memory=True)
return train_dataloader, input_file
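A runnable sketch of the worker_init_fn hookup shown above, using a dummy TensorDataset; WorkerInitObj mirrors the class added in this file, everything else is illustrative.

```python
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

class WorkerInitObj:
    """Picklable callable that reseeds numpy and random in each DataLoader worker."""
    def __init__(self, seed):
        self.seed = seed

    def __call__(self, worker_id):
        np.random.seed(seed=self.seed + worker_id)
        random.seed(self.seed + worker_id)

if __name__ == "__main__":
    dataset = TensorDataset(torch.arange(32, dtype=torch.float32))
    loader = DataLoader(dataset, batch_size=8, num_workers=2,
                        worker_init_fn=WorkerInitObj(seed=42))
    for (batch,) in loader:
        print(batch.tolist())
```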
@ -97,6 +107,17 @@ class pretraining_dataset(Dataset):
return [input_ids, segment_ids, input_mask,
masked_lm_labels, next_sentence_labels]
class BertPretrainingCriterion(torch.nn.Module):
def __init__(self, vocab_size):
super(BertPretrainingCriterion, self).__init__()
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
self.vocab_size = vocab_size
def forward(self, prediction_scores, seq_relationship_score, masked_lm_labels, next_sentence_labels):
masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1))
next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
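A quick standalone check of the new BertPretrainingCriterion with dummy tensors (shapes are illustrative; this is not the training pipeline):

```python
import torch

class BertPretrainingCriterion(torch.nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
        self.vocab_size = vocab_size

    def forward(self, prediction_scores, seq_relationship_score,
                masked_lm_labels, next_sentence_labels):
        masked_lm_loss = self.loss_fn(
            prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1))
        next_sentence_loss = self.loss_fn(
            seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1))
        return masked_lm_loss + next_sentence_loss

batch, seq_len, vocab = 2, 8, 100
criterion = BertPretrainingCriterion(vocab)
prediction_scores = torch.randn(batch, seq_len, vocab)         # MLM logits
seq_relationship_score = torch.randn(batch, 2)                 # NSP logits
masked_lm_labels = torch.randint(-1, vocab, (batch, seq_len))  # -1 = not masked, ignored
next_sentence_labels = torch.randint(0, 2, (batch,))
print(criterion(prediction_scores, seq_relationship_score,
                masked_lm_labels, next_sentence_labels))
```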
def parse_arguments():
@ -224,12 +245,16 @@ def parse_arguments():
default=False,
action='store_true',
help="Whether to run training.")
parser.add_argument('--json-summary', type=str, default="dllogger.json",
parser.add_argument('--json-summary', type=str, default="results/dllogger.json",
help='If provided, the json summary will be written to'
'the specified file.')
parser.add_argument("--use_env",
action='store_true',
help="Whether to read local rank from ENVVAR")
parser.add_argument('--disable_progress_bar',
default=False,
action='store_true',
help='Disable tqdm progress bar')
args = parser.parse_args()
return args
@ -274,7 +299,7 @@ def setup_training(args):
os.listdir(args.output_dir) and any([i.startswith('ckpt') for i in os.listdir(args.output_dir)])):
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
if not args.resume_from_checkpoint or not os.path.exists(args.output_dir):
if (not args.resume_from_checkpoint or not os.path.exists(args.output_dir)) and is_main_process():
os.makedirs(args.output_dir, exist_ok=True)
return device, args
@ -282,12 +307,14 @@ def setup_training(args):
def prepare_model_and_optimizer(args, device):
# Prepare model
config = BertConfig.from_json_file(args.config_file)
config = modeling.BertConfig.from_json_file(args.config_file)
# Padding for divisibility by 8
if config.vocab_size % 8 != 0:
config.vocab_size += 8 - (config.vocab_size % 8)
model = BertForPreTraining(config)
modeling.ACT2FN["bias_gelu"] = torch.jit.script(modeling.ACT2FN["bias_gelu"])
model = modeling.BertForPreTraining(config)
checkpoint = None
if not args.resume_from_checkpoint:
@ -327,11 +354,13 @@ def prepare_model_and_optimizer(args, device):
if args.fp16:
if args.loss_scale == 0:
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic", cast_model_outputs=torch.float16)
else:
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale, cast_model_outputs=torch.float16)
amp._amp_state.loss_scalers[0]._loss_scale = 2**20
model.checkpoint_activations(args.checkpoint_activations)
if args.resume_from_checkpoint:
if args.phase2 or args.init_checkpoint:
keys = list(checkpoint['optimizer']['state'].keys())
@ -361,7 +390,10 @@ def prepare_model_and_optimizer(args, device):
elif args.n_gpu > 1:
model = torch.nn.DataParallel(model)
return model, optimizer, lr_scheduler, checkpoint, global_step
criterion = BertPretrainingCriterion(config.vocab_size)
return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
global skipped_steps
@ -429,15 +461,17 @@ def main():
if args.use_env and 'LOCAL_RANK' in os.environ:
args.local_rank = int(os.environ['LOCAL_RANK'])
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
random.seed(args.seed + args.local_rank)
np.random.seed(args.seed + args.local_rank)
torch.manual_seed(args.seed + args.local_rank)
torch.cuda.manual_seed(args.seed + args.local_rank)
worker_init = WorkerInitObj(args.seed + args.local_rank)
device, args = setup_training(args)
dllogger.log(step="PARAMETER", data={"Config": [str(args)]})
# Prepare optimizer
model, optimizer, lr_scheduler, checkpoint, global_step = prepare_model_and_optimizer(args, device)
model, optimizer, lr_scheduler, checkpoint, global_step, criterion = prepare_model_and_optimizer(args, device)
if is_main_process():
dllogger.log(step="PARAMETER", data={"SEED": args.seed})
@ -487,7 +521,8 @@ def main():
train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler,
batch_size=args.train_batch_size * args.n_gpu, num_workers=4,
batch_size=args.train_batch_size * args.n_gpu,
num_workers=4, worker_init_fn=worker_init,
pin_memory=True)
# shared_file_list["0"] = (train_dataloader, data_file)
@ -507,17 +542,16 @@ def main():
previous_file = data_file
dataset_future = pool.submit(create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args)
dataset_future = pool.submit(create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args, worker_init)
train_iter = tqdm(train_dataloader, desc="Iteration") if is_main_process() else train_dataloader
train_iter = tqdm(train_dataloader, desc="Iteration", disable=args.disable_progress_bar) if is_main_process() else train_dataloader
for step, batch in enumerate(train_iter):
training_steps += 1
batch = [t.to(device) for t in batch]
input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
loss = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask,
masked_lm_labels=masked_lm_labels, next_sentence_label=next_sentence_labels,
checkpoint_activations=args.checkpoint_activations)
prediction_scores, seq_relationship_score = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
loss = criterion(prediction_scores, seq_relationship_score, masked_lm_labels, next_sentence_labels)
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
@ -549,7 +583,7 @@ def main():
torch.distributed.all_reduce(average_loss)
final_loss = average_loss.item()
if is_main_process():
dllogger.log(step=(epoch, training_steps / args.gradient_accumulation_steps, ), data={"final_loss": final_loss})
dllogger.log(step=(epoch, global_step, ), data={"final_loss": final_loss})
elif training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
if is_main_process():
dllogger.log(step=(epoch, global_step, ), data={"average_loss": average_loss / (args.log_freq * divisor),

@ -43,6 +43,9 @@ from tokenization import (BasicTokenizer, BertTokenizer, whitespace_tokenize)
from utils import is_main_process, format_step
import dllogger, time
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
if sys.version_info[0] == 2:
import cPickle as pickle
else:
@ -321,30 +324,6 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
if is_training and example.is_impossible:
start_position = 0
end_position = 0
if example_index < 20:
logger.info("*** Example ***")
logger.info("unique_id: %s" % (unique_id))
logger.info("example_index: %s" % (example_index))
logger.info("doc_span_index: %s" % (doc_span_index))
logger.info("tokens: %s" % " ".join(tokens))
logger.info("token_to_orig_map: %s" % " ".join([
"%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
logger.info("token_is_max_context: %s" % " ".join([
"%d:%s" % (x, y) for (x, y) in token_is_max_context.items()
]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info(
"input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info(
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
if is_training and example.is_impossible:
logger.info("impossible example")
if is_training and not example.is_impossible:
answer_text = " ".join(tokens[start_position:(end_position + 1)])
logger.info("start_position: %d" % (start_position))
logger.info("end_position: %d" % (end_position))
logger.info(
"answer: %s" % (answer_text))
features.append(
InputFeatures(
@ -443,197 +422,143 @@ RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
def write_predictions(all_examples, all_features, all_results, n_best_size,
max_answer_length, do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, verbose_logging,
version_2_with_negative, null_score_diff_threshold):
"""Write final predictions to the json file and log-odds of null if needed."""
logger.info("Writing predictions to: %s" % (output_prediction_file))
logger.info("Writing nbest to: %s" % (output_nbest_file))
def get_answers(examples, features, results, args):
predictions = collections.defaultdict(list) #it is possible that one example corresponds to multiple features
Prediction = collections.namedtuple('Prediction', ['text', 'start_logit', 'end_logit'])
example_index_to_features = collections.defaultdict(list)
for feature in all_features:
example_index_to_features[feature.example_index].append(feature)
unique_id_to_result = {}
for result in all_results:
unique_id_to_result[result.unique_id] = result
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction",
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
scores_diff_json = collections.OrderedDict()
for (example_index, example) in enumerate(all_examples):
features = example_index_to_features[example_index]
prelim_predictions = []
# keep track of the minimum score of null start+end of position 0
score_null = 1000000 # large and positive
min_null_feature_index = 0 # the paragraph slice with min null score
null_start_logit = 0 # the start logit at the slice with min null score
null_end_logit = 0 # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features):
result = unique_id_to_result[feature.unique_id]
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
# if we could have irrelevant answers, get the min score of irrelevant
if version_2_with_negative:
feature_null_score = result.start_logits[0] + result.end_logits[0]
if feature_null_score < score_null:
score_null = feature_null_score
min_null_feature_index = feature_index
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
for start_index in start_indexes:
for end_index in end_indexes:
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index >= len(feature.tokens):
continue
if end_index >= len(feature.tokens):
continue
if start_index not in feature.token_to_orig_map:
continue
if end_index not in feature.token_to_orig_map:
continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
feature_index=feature_index,
start_index=start_index,
end_index=end_index,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]))
if version_2_with_negative:
prelim_predictions.append(
_PrelimPrediction(
feature_index=min_null_feature_index,
start_index=0,
end_index=0,
start_logit=null_start_logit,
end_logit=null_end_logit))
if args.version_2_with_negative:
null_vals = collections.defaultdict(lambda: (float("inf"),0,0))
for ex, feat, result in match_results(examples, features, results):
start_indices = _get_best_indices(result.start_logits, args.n_best_size)
end_indices = _get_best_indices(result.end_logits, args.n_best_size)
prelim_predictions = get_valid_prelim_predictions(start_indices, end_indices, feat, result, args)
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True)
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True)
if args.version_2_with_negative:
score = result.start_logits[0] + result.end_logits[0]
if score < null_vals[ex.qas_id][0]:
null_vals[ex.qas_id] = (score, result.start_logits[0], result.end_logits[0])
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_logit", "end_logit"])
seen_predictions = {}
nbest = []
curr_predictions = []
seen_predictions = []
for pred in prelim_predictions:
if len(nbest) >= n_best_size:
if len(curr_predictions) == args.n_best_size:
break
feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
# De-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
# Clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
if pred.start_index > 0: # this is a non-null prediction TODO: this probably is irrelevant
final_text = get_answer_text(ex, feat, pred, args)
if final_text in seen_predictions:
continue
seen_predictions[final_text] = True
else:
final_text = ""
seen_predictions[final_text] = True
nbest.append(
_NbestPrediction(
text=final_text,
start_logit=pred.start_logit,
end_logit=pred.end_logit))
# if we didn't include the empty option in the n-best, include it
if version_2_with_negative:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
text="",
start_logit=null_start_logit,
end_logit=null_end_logit))
seen_predictions.append(final_text)
curr_predictions.append(Prediction(final_text, pred.start_logit, pred.end_logit))
predictions[ex.qas_id] += curr_predictions
# In very rare edge cases we could have only a single null prediction.
# So we just create a nonce prediction in this case to avoid failure.
if len(nbest) == 1:
nbest.insert(0,
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
#Add empty prediction
if args.version_2_with_negative:
for qas_id in predictions.keys():
predictions[qas_id].append(Prediction('',
null_vals[ex.qas_id][1],
null_vals[ex.qas_id][2]))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
nbest.append(
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1
nbest_answers = collections.defaultdict(list)
answers = {}
for qas_id, preds in predictions.items():
nbest = sorted(
preds,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True)[:args.n_best_size]
total_scores = []
best_non_null_entry = None
for entry in nbest:
total_scores.append(entry.start_logit + entry.end_logit)
if not best_non_null_entry:
if entry.text:
best_non_null_entry = entry
if not best_non_null_entry and entry.text:
best_non_null_entry = entry
probs = _compute_softmax(total_scores)
nbest_json = []
for (i, entry) in enumerate(nbest):
output = collections.OrderedDict()
output["text"] = entry.text
output["probability"] = probs[i]
output["start_logit"] = entry.start_logit
output["end_logit"] = entry.end_logit
nbest_json.append(output)
assert len(nbest_json) >= 1
if not version_2_with_negative:
all_predictions[example.qas_id] = nbest_json[0]["text"]
else:
# predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
nbest_answers[qas_id].append(output)
if args.version_2_with_negative:
score_diff = null_vals[qas_id][0] - best_non_null_entry.start_logit - best_non_null_entry.end_logit
if score_diff > args.null_score_diff_threshold:
answers[qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
all_nbest_json[example.qas_id] = nbest_json
answers[qas_id] = best_non_null_entry.text
else:
answers[qas_id] = nbest_answers[qas_id][0]['text']
with open(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
return answers, nbest_answers
with open(output_nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
def get_answer_text(example, feature, pred, args):
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
if version_2_with_negative:
with open(output_null_log_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
# De-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
# Clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
final_text = get_final_text(tok_text, orig_text, args.do_lower_case, args.verbose_logging)
return final_text
def get_valid_prelim_predictions(start_indices, end_indices, feature, result, args):
_PrelimPrediction = collections.namedtuple(
"PrelimPrediction",
["start_index", "end_index", "start_logit", "end_logit"])
prelim_predictions = []
for start_index in start_indices:
for end_index in end_indices:
if start_index >= len(feature.tokens):
continue
if end_index >= len(feature.tokens):
continue
if start_index not in feature.token_to_orig_map:
continue
if end_index not in feature.token_to_orig_map:
continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > args.max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
start_index=start_index,
end_index=end_index,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]))
return prelim_predictions
def match_results(examples, features, results):
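''' match features to results by unique_id and yield (example, feature, result) triples '''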
unique_f_ids = set([f.unique_id for f in features])
unique_r_ids = set([r.unique_id for r in results])
matching_ids = unique_f_ids & unique_r_ids
features = [f for f in features if f.unique_id in matching_ids]
results = [r for r in results if r.unique_id in matching_ids]
features.sort(key=lambda x: x.unique_id)
results.sort(key=lambda x: x.unique_id)
for f, r in zip(features, results): #original code assumes strict ordering of examples. TODO: rewrite this
yield examples[f.example_index], f, r
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
"""Project the tokenized prediction back to the original text."""
@ -732,16 +657,16 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
return output_text
def _get_best_indexes(logits, n_best_size):
def _get_best_indices(logits, n_best_size):
"""Get the n-best logits from a list."""
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
best_indexes = []
best_indices = []
for i in range(len(index_and_score)):
if i >= n_best_size:
break
best_indexes.append(index_and_score[i][0])
return best_indexes
best_indices.append(index_and_score[i][0])
return best_indices
def _compute_softmax(scores):
@ -885,7 +810,7 @@ def main():
parser.add_argument('--log_freq',
type=int, default=50,
help='frequency of logging loss.')
parser.add_argument('--json-summary', type=str, default="dllogger.json",
parser.add_argument('--json-summary', type=str, default="results/dllogger.json",
help='If provided, the json summary will be written to'
'the specified file.')
parser.add_argument("--eval_script",
@ -902,6 +827,19 @@ def main():
default=False,
action='store_true',
help="Whether to save checkpoints")
parser.add_argument('--disable-progress-bar',
default=False,
action='store_true',
help='Disable tqdm progress bar')
parser.add_argument("--skip_cache",
default=False,
action='store_true',
help="Whether to cache train features")
parser.add_argument("--cache_dir",
default=None,
type=str,
help="Location to cache train feaures. Will default to the dataset directory")
args = parser.parse_args()
if args.use_env and 'LOCAL_RANK' in os.environ:
@ -957,7 +895,7 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and os.listdir(args.output_dir)!=['logfile.txt']:
print("WARNING: Output directory {} already exists and is not empty.".format(args.output_dir), os.listdir(args.output_dir))
if not os.path.exists(args.output_dir):
if not os.path.exists(args.output_dir) and is_main_process():
os.makedirs(args.output_dir)
tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512) # for bert large
@ -986,6 +924,8 @@ def main():
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')["model"], strict=False)
dllogger.log(step="PARAMETER", data={"loaded_checkpoint": True})
model.to(device)
num_weights = sum([p.numel() for p in model.parameters() if p.requires_grad])
dllogger.log(step="PARAMETER", data={"model_weights_num":num_weights})
# Prepare optimizer
param_optimizer = list(model.named_parameters())
@ -1037,9 +977,16 @@ def main():
global_step = 0
if args.do_train:
cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}'.format(
list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride),
str(args.max_query_length))
if args.cache_dir is None:
cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}'.format(
list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride),
str(args.max_query_length))
else:
cached_train_features_file = args.cache_dir.strip('/') + '/' + args.train_file.split('/')[-1] + '_{0}_{1}_{2}_{3}'.format(
list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride),
str(args.max_query_length))
train_features = None
try:
with open(cached_train_features_file, "rb") as reader:
@ -1052,10 +999,12 @@ def main():
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=True)
if args.local_rank == -1 or is_main_process():
if not args.skip_cache and is_main_process():
dllogger.log(step="PARAMETER", data={"Cached_train features_file": cached_train_features_file})
with open(cached_train_features_file, "wb") as writer:
pickle.dump(train_features, writer)
dllogger.log(step="PARAMETER", data={"train_start": True})
dllogger.log(step="PARAMETER", data={"training_samples": len(train_examples)})
dllogger.log(step="PARAMETER", data={"training_features": len(train_features)})
@ -1079,7 +1028,7 @@ def main():
final_loss = None
train_start = time.time()
for epoch in range(int(args.num_train_epochs)):
train_iter = tqdm(train_dataloader, desc="Iteration") if is_main_process() else train_dataloader
train_iter = tqdm(train_dataloader, desc="Iteration", disable=args.disable_progress_bar) if is_main_process() else train_dataloader
for step, batch in enumerate(train_iter):
# Terminate early for benchmarking
@ -1089,7 +1038,21 @@ def main():
if n_gpu == 1:
batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
input_ids, input_mask, segment_ids, start_positions, end_positions = batch
loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
start_logits, end_logits = model(input_ids, segment_ids, input_mask)
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
loss = (start_loss + end_loss) / 2
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
@ -1141,6 +1104,7 @@ def main():
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=False)
dllogger.log(step="PARAMETER", data={"infer_start": True})
dllogger.log(step="PARAMETER", data={"eval_samples": len(eval_examples)})
dllogger.log(step="PARAMETER", data={"eval_features": len(eval_features)})
@ -1159,7 +1123,7 @@ def main():
model.eval()
all_results = []
dllogger.log(step="PARAMETER", data={"eval_start": True})
for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating", disable=args.disable_progress_bar):
if len(all_results) % 1000 == 0:
dllogger.log(step="PARAMETER", data={"sample_number": len(all_results)})
input_ids = input_ids.to(device)
@ -1179,12 +1143,19 @@ def main():
time_to_infer = time.time() - infer_start
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
write_predictions(eval_examples, eval_features, all_results,
args.n_best_size, args.max_answer_length,
args.do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold)
answers, nbest_answers = get_answers(eval_examples, eval_features, all_results, args)
with open(output_prediction_file, "w") as f:
f.write(json.dumps(answers, indent=4) + "\n")
with open(output_nbest_file, "w") as f:
f.write(json.dumps(nbest_answers, indent=4) + "\n")
# output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
# write_predictions(eval_examples, eval_features, all_results,
# args.n_best_size, args.max_answer_length,
# args.do_lower_case, output_prediction_file,
# output_nbest_file, output_null_log_odds_file, args.verbose_logging,
# args.version_2_with_negative, args.null_score_diff_threshold)
if args.do_eval and is_main_process():
import sys

View file

@ -35,6 +35,9 @@ from modeling import BertForMultipleChoice, BertConfig, WEIGHTS_NAME, CONFIG_NAM
from optimization import BertAdam, warmup_linear
from tokenization import BertTokenizer
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)

View file

@ -1,2 +1,2 @@
#!/bin/bash
docker build . --rm -t bert
docker build --network=host . --rm --pull --no-cache -t bert

View file

@ -1,9 +1,16 @@
#!/bin/bash
CMD=${1:-/bin/bash}
NV_VISIBLE_DEVICES=${2:-"all"}
DOCKER_BRIDGE=${3:-"host"}
docker run -it --rm \
--runtime=nvidia \
--gpus device=$NV_VISIBLE_DEVICES \
--net=$DOCKER_BRIDGE \
--shm-size=1g \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
-v ${PWD}:/workspace/bert \
bert bash
-e LD_LIBRARY_PATH='/workspace/install/lib/' \
-v $PWD:/workspace/bert \
-v $PWD/results:/results \
bert $CMD

View file

@ -25,7 +25,7 @@ resume_training=${8:-"false"}
create_logfile=${9:-"true"}
accumulate_gradients=${10:-"true"}
gradient_accumulation_steps=${11:-128}
seed=${12:-$RANDOM}
seed=${12:-42}
job_name=${13:-"bert_lamb_pretraining"}
allreduce_post_accumulation=${14:-"true"}
allreduce_post_accumulation_fp16=${15:-"true"}
@ -34,7 +34,7 @@ learning_rate_phase2=${18:-"4e-3"}
warmup_proportion_phase2=${19:-"0.128"}
train_steps_phase2=${20:-1563}
gradient_accumulation_steps_phase2=${21:-512}
DATASET=hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus # change this for other datasets
DATASET=hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5_shard_1472_test_split_10/books_wiki_en_corpus/training # change this for other datasets
DATA_DIR_PHASE1=${22:-$BERT_PREP_WORKING_DIR/${DATASET}/}
BERT_CONFIG=bert_config.json
CODEDIR=${24:-"/workspace/bert"}
@ -119,6 +119,7 @@ CMD+=" $ALL_REDUCE_POST_ACCUMULATION"
CMD+=" $ALL_REDUCE_POST_ACCUMULATION_FP16"
CMD+=" $INIT_CHECKPOINT"
CMD+=" --do_train"
CMD+=" --json-summary ${RESULTS_DIR}/dllogger.json "
CMD="python3 -m torch.distributed.launch --nproc_per_node=$num_gpus $CMD"
@ -146,7 +147,7 @@ echo "finished pretraining"
#Start Phase2
DATASET=hdf5_lower_case_1_seq_len_512_max_pred_80_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus # change this for other datasets
DATASET=hdf5_lower_case_1_seq_len_512_max_pred_80_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5_shard_1472_test_split_10/books_wiki_en_corpus/training # change this for other datasets
DATA_DIR_PHASE2=${23:-$BERT_PREP_WORKING_DIR/${DATASET}/}
PREC=""
@ -195,6 +196,7 @@ CMD+=" $CHECKPOINT"
CMD+=" $ALL_REDUCE_POST_ACCUMULATION"
CMD+=" $ALL_REDUCE_POST_ACCUMULATION_FP16"
CMD+=" --do_train --phase2 --resume_from_checkpoint --phase1_end_step=$train_steps"
CMD+=" --json-summary ${RESULTS_DIR}/dllogger.json "
CMD="python3 -m torch.distributed.launch --nproc_per_node=$num_gpus $CMD"
@ -217,4 +219,4 @@ fi
set +x
echo "finished phase2"
echo "finished phase2"

View file

@ -13,13 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#OUT_DIR=/results/SQuAD
echo "Container nvidia build = " $NVIDIA_BUILD_ID
init_checkpoint=${1:-"/workspace/bert/checkpoints/bert_uncased.pt"}
epochs=${2:-"2.0"}
batch_size=${3:-"3"}
batch_size=${3:-"4"}
learning_rate=${4:-"3e-5"}
precision=${5:-"fp16"}
num_gpu=${6:-"8"}
@ -94,4 +92,4 @@ CMD+=" $use_fp16"
LOGFILE=$OUT_DIR/logfile.txt
echo "$CMD |& tee $LOGFILE"
time $CMD |& tee $LOGFILE
time $CMD |& tee $LOGFILE

View file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View file

@ -0,0 +1,115 @@
# Deploying the BERT model using Triton Inference Server
## Solution overview
The [NVIDIA Triton Inference Server](https://github.com/NVIDIA/trtis-inference-server) provides a datacenter and cloud inferencing solution optimized for NVIDIA GPUs. The server provides an inference service via an HTTP or gRPC endpoint, allowing remote clients to request inferencing for any number of GPU or CPU models being managed by the server.
This folder contains a detailed performance analysis as well as scripts to deploy a SQuAD fine-tuned BERT model and run inference with Triton Inference Server.
## Setup
The first step is to train BERT for question answering. The process is the same as in the main readme.
1. Download the SQuAD dataset with `cd [bert folder]/data/squad/ && bash ./squad_download.sh`.
2. Build the Docker container with `bash ./scripts/docker/build.sh`.
3. [train](https://gitlab-master.nvidia.com/dl/JoC/bert_pyt#training-process) your own checkpoint and fine-tune it, or [download](https://ngc.nvidia.com/catalog/models/nvidia:bert_large_pyt_amp_ckpt_squad_qa1_1/files) the already trained and fine-tuned checkpoint from the [NGC](https://ngc.nvidia.com/catalog/models/nvidia:bert_large_pyt_amp_ckpt_squad_qa1_1/files) model repository.
The checkpoint should be placed in `[bert folder]/checkpoints/<checkpoint>`. By default, the scripts assume `<checkpoint>` is `bert_qa.pt`, therefore, you might have to rename the trained or downloaded models as necessary.
Note: The following instructions are run from outside the container and call `docker run` commands as required. \
Unless stated otherwise, all the commands below have to be executed from `[bert folder]`.
## Quick Start Guide
### Deploying the model
The following command exports the checkpoint to `torchscript` and deploys it to the Triton model repository.
`bash ./triton/export_model.sh`
The deployed Triton model repository will be in `[bert folder]/results/triton_models`.
To deploy BERT in ONNX format instead, edit `[bert folder]/triton/export_model.sh`:
change the value of `EXPORT_FORMAT` from `ts-script` to `onnx`, and change the value of `triton_model_name` from `bertQA-ts` to `bertQA-onnx`.
You may also set `precision` to either `fp32` or `fp16`.
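Under the hood, the export is handled by `triton/deployer.py`. As a rough, illustrative sketch only (the exact flags and paths used by `export_model.sh` may differ), a direct TorchScript export run from inside the BERT container could look like: \
`python triton/deployer.py --ts-script --save-dir ./results/triton_models --triton-model-name bertQA-ts -- --checkpoint ./checkpoints/bert_qa.pt --config_file bert_config.json --vocab_file ./vocab/vocab --predict_file ./data/squad/v1.1/dev-v1.1.json --do_lower_case --batch_size 8`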
### Running the Triton server
To launch the Triton server, execute the following command.
`docker run --rm --gpus device=0 --ipc=host --network=host -p 8000:8000 -p 8001:8001 -p 8002:8002 -v $PWD/results/triton_models:/models nvcr.io/nvidia/tritonserver:20.03-py3 trtserver --model-store=/models --log-verbose=1`
Here `--gpus device=0` exposes only GPU 0 to the server. To expose several GPUs, list them explicitly, for example `device=0,1,2,3`; with `device=all` the server will see all the available GPUs.
By default, the server expects the model repository to be in `[bert folder]/results/triton_models`.
### Running the custom Triton client
The custom Triton client is found in `[bert folder]/triton/client.py`.
It may be used once BERT is deployed and the Triton server is running. To try it, do the following steps.
1. Start the BERT docker container with the following command: \
`docker run -it --rm --ipc=host --network=host -v $PWD/vocab:/workspace/bert/vocab bert:latest` \
Note that the client itself does not require GPU support.
2. Move to the triton folder with the following command: \
`cd /workspace/bert/triton/`
3. Run the client with the following command: \
`python client.py --do_lower_case --version_2_with_negative --vocab_file=../vocab/vocab --triton-model-name=bertQA-ts-script`
This will send a request to the running Triton server, which will process it and return the result to the client. The response is printed to the screen.
You may send your own question and context for processing by using the `--question` and `--context` flags of `client.py`.
To query the model deployed in ONNX format instead, select it with the `--triton-model-name` flag.
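If you want to script requests yourself, the minimal sketch below mirrors what `client.py` does, assuming the `tensorrtserver` Python API from the 20.03 client SDK, a server listening on `localhost:8000`, and a deployed TorchScript model (here assumed to be named `bertQA-ts`); the zero-filled arrays stand in for real tokenizer output.

```python
import numpy as np
from tensorrtserver.api import InferContext, ProtocolType

# Connect to the HTTP endpoint of the running Triton server (-1 = latest model version).
ctx = InferContext("localhost:8000", ProtocolType.from_str("http"), "bertQA-ts", -1)

seq_len = 384  # must match the sequence length the model was exported with
input_ids = np.zeros(seq_len, dtype=np.int64)    # placeholder token ids
segment_ids = np.zeros(seq_len, dtype=np.int64)  # placeholder segment ids
input_mask = np.zeros(seq_len, dtype=np.int64)   # placeholder attention mask

# One tuple entry per sample in the batch; RAW returns the output tensors as numpy arrays.
result = ctx.run(
    {"input__0": (input_ids,), "input__1": (segment_ids,), "input__2": (input_mask,)},
    {"output__0": InferContext.ResultFormat.RAW, "output__1": InferContext.ResultFormat.RAW},
    1)

start_logits = result["output__0"][0]
end_logits = result["output__1"][0]
```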
### Evaluating the deployed model on SQuAD v1.1
To deploy and evaluate your model, run the following command.
`bash ./triton/evaluate.sh`
By default, this will deploy BERT in `torchscript` format and evaluate it on SQuAD v1.1.
You may change the deployment format by editing `[bert folder]/triton/evaluate.sh`:
change the value of `EXPORT_FORMAT` from `ts-script` to `onnx`. You may also set `precision` to either `fp32` or `fp16`.
### Generating performance data
To collect performance data, run the following command.
`bash ./triton/generate_figures.sh`
By default, this will deploy BERT in `torchscript` format, launch the server, run the perf client, collect statistics, and place them in `[bert folder]/results/triton_models/perf_client`.
You may change the deployment format by editing `./triton/generate_figures.sh`: change the value of `EXPORT_FORMAT` from `ts-script` to `onnx`.
You may also set `precision` to either `fp32` or `fp16`.
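If you prefer to run the perf client by hand once the server is up, an invocation along the lines of `perf_client -m bertQA-ts -u localhost:8000 -b 8` should work; treat the exact flags as an assumption and consult the perf_client usage message, since `generate_figures.sh` drives the tool with its own settings.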
## Advanced
### Other scripts
To launch the Triton server in a detached state, run the following command.
`bash ./triton/launch_triton_server.sh`
By default, the Triton server is expecting the model repository in `[bert folder]/results/triton_models`.
To block until the server is initialized and the model is ready for inference, run the following command.
`bash ./triton/wait_for_triton_server.sh`
## Performance
The numbers below are averages, measured on Triton, with [static batching](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/model_configuration.html#scheduling-and-batching).
| Format | GPUs | Batch size | Sequence length | Throughput - FP32(sequences/sec) | Throughput - mixed precision(sequences/sec) | Throughput speedup (mixed precision/FP32) |
|--------|------|------------|-----------------|----------------------------------|---------------------------------------------|--------------------------------------------|
|pytorch | 1 | 1 | 384 | 30.1 | 28.0 | 0.93x |
|pytorch | 1 | 8 | 384 | 36.0 | 116.8 | 3.24x |
|torchscript | 1 | 1 | 384 | 32.20 | 38.40 | 1.19x |
|torchscript | 1 | 8 | 384 | 40.00 | 134.40 | 3.36x |
|onnx | 1 | 1 | 384 | 33.30 | 92.00 | 2.76x |
|onnx | 1 | 8 | 384 | 42.60 | 165.30 | 3.88x |

View file

@ -0,0 +1,174 @@
#!/usr/bin/python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import json
import argparse
import numpy as np
from builtins import range
from tensorrtserver.api import *
#
import sys
sys.path.append('../')
from inference import preprocess_tokenized_text, get_answer
from tokenization import BertTokenizer
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=1,
help='batch size for inference. default: 1')
parser.add_argument("--triton-model-name", type=str, default="model_name",
help="the name of the model used for inference")
parser.add_argument("--triton-model-version", type=int, default=-1,
help="the version of the model used for inference")
parser.add_argument("--triton-server-url", type=str, default="localhost:8000",
help="Inference server URL. Default is localhost:8000.")
parser.add_argument('-v', '--verbose', action="store_true", required=False, default=False,
help='Enable verbose output')
parser.add_argument('--protocol', type=str, required=False, default='http',
help='Protocol ("http"/"grpc") used to ' +
'communicate with inference service. Default is "http".')
parser.add_argument('-H', dest='http_headers', metavar="HTTP_HEADER",
required=False, action='append',
help='HTTP headers to add to inference server requests. ' +
'Format is -H"Header:Value".')
## pre- and postprocessing parameters
parser.add_argument("--verbose_logging", action='store_true',
help="If true, all of the warnings related to data processing will be printed. ")
parser.add_argument("--max_seq_length", default=384, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--max_query_length", default=64, type=int,
help="The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length.")
parser.add_argument("--n_best_size", default=1, type=int,
help="The total number of n-best predictions to generate. ")
parser.add_argument("--max_answer_length", default=30, type=int,
help="The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another.")
parser.add_argument("--do_lower_case",
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument('--version_2_with_negative',
action='store_true',
help='If true, then the model can reply with "unknown". ')
parser.add_argument('--null_score_diff_threshold',
type=float, default=-11.0,
help="If null_score - best_non_null is greater than the threshold predict 'unknown'. ")
parser.add_argument('--vocab_file',
type=str, default=None, required=True,
help="Vocabulary mapping/file BERT was pretrainined on")
# input texts
parser.add_argument("--question", default="Most antibiotics target bacteria and don't affect what class of organisms? ",
type=str, help="question")
parser.add_argument("--context", default="Within the genitourinary and gastrointestinal tracts, commensal flora serve as biological barriers by competing with pathogenic bacteria for food and space and, in some cases, by changing the conditions in their environment, such as pH or available iron. This reduces the probability that pathogens will reach sufficient numbers to cause illness. However, since most antibiotics non-specifically target bacteria and do not affect fungi, oral antibiotics can lead to an overgrowth of fungi and cause conditions such as a vaginal candidiasis (a yeast infection). There is good evidence that re-introduction of probiotic flora, such as pure cultures of the lactobacilli normally found in unpasteurized yogurt, helps restore a healthy balance of microbial populations in intestinal infections in children and encouraging preliminary data in studies on bacterial gastroenteritis, inflammatory bowel diseases, urinary tract infection and post-surgical infections. ",
type=str, help="context")
args = parser.parse_args()
args.protocol = ProtocolType.from_str(args.protocol)
# Create a health context, get the ready and live state of server.
health_ctx = ServerHealthContext(args.triton_server_url, args.protocol,
http_headers=args.http_headers, verbose=args.verbose)
print("Health for model {}".format(args.triton_model_name))
print("Live: {}".format(health_ctx.is_live()))
print("Ready: {}".format(health_ctx.is_ready()))
# Create a status context and get server status
status_ctx = ServerStatusContext(args.triton_server_url, args.protocol, args.triton_model_name,
http_headers=args.http_headers, verbose=args.verbose)
print("Status for model {}".format(args.triton_model_name))
print(status_ctx.get_server_status())
# Create the inference context for the model.
infer_ctx = InferContext(args.triton_server_url, args.protocol, args.triton_model_name, args.triton_model_version,
http_headers=args.http_headers, verbose=args.verbose)
print("question: ", args.question)
print("context: ", args.context)
print()
# pre-processing
tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512) # for bert large
doc_tokens = args.context.split()
query_tokens = tokenizer.tokenize(args.question)
feature = preprocess_tokenized_text(doc_tokens,
query_tokens,
tokenizer,
max_seq_length=args.max_seq_length,
max_query_length=args.max_query_length)
tensors_for_inference, tokens_for_postprocessing = feature
dtype = np.int64
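# the exported Triton model expects INT64 input tensors (see triton/deployer.py)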
input_ids = np.array(tensors_for_inference.input_ids, dtype=dtype)[None,...] # make bs=1
segment_ids = np.array(tensors_for_inference.segment_ids, dtype=dtype)[None,...] # make bs=1
input_mask = np.array(tensors_for_inference.input_mask, dtype=dtype)[None,...] # make bs=1
assert args.batch_size == input_ids.shape[0]
assert args.batch_size == segment_ids.shape[0]
assert args.batch_size == input_mask.shape[0]
# prepare inputs
input_dict = {
"input__0" : tuple(input_ids[i] for i in range(args.batch_size)),
"input__1" : tuple(segment_ids[i] for i in range(args.batch_size)),
"input__2" : tuple(input_mask[i] for i in range(args.batch_size))
}
# prepare outputs
output_keys = [
"output__0",
"output__1"
]
output_dict = {}
for k in output_keys:
output_dict[k] = InferContext.ResultFormat.RAW
# Send inference request to the inference server.
result = infer_ctx.run(input_dict, output_dict, args.batch_size)
# get the result
start_logits = result["output__0"][0].tolist()
end_logits = result["output__1"][0].tolist()
# post-processing
answer, answers = get_answer(doc_tokens, tokens_for_postprocessing,
start_logits, end_logits, args)
# print result
print()
print(answer)
print()
print(json.dumps(answers, indent=4))

View file

@ -0,0 +1,151 @@
#!/usr/bin/python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import pickle
import argparse
import deployer_lib
#
import sys
sys.path.append('../')
sys.path.append('.')
from modeling import BertForQuestionAnswering, BertConfig
from tokenization import BertTokenizer
from run_squad import convert_examples_to_features, read_squad_examples
def get_model_args(model_args):
''' the arguments initialize_model will receive '''
parser = argparse.ArgumentParser()
## Required parameters by the model.
parser.add_argument("--checkpoint",
default=None,
type=str,
required=True,
help="The checkpoint of the model. ")
parser.add_argument('--batch_size',
default=8,
type=int,
help='Batch size for inference')
parser.add_argument("--bert_model", default="bert-large-uncased", type=str,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
"bert-base-multilingual-cased, bert-base-chinese.")
parser.add_argument("--do_lower_case",
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument('--vocab_file',
type=str, default=None, required=True,
help="Vocabulary mapping/file BERT was pretrainined on")
parser.add_argument("--predict_file", default=None, type=str,
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
parser.add_argument('--version_2_with_negative',
action='store_true',
help='If true, the SQuAD examples contain some that do not have an answer.')
parser.add_argument("--max_seq_length", default=384, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--doc_stride", default=128, type=int,
help="When splitting up a long document into chunks, how much stride to take between chunks.")
parser.add_argument("--max_query_length", default=64, type=int,
help="The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length.")
parser.add_argument("--config_file",
default=None,
type=str,
required=True,
help="The BERT model config")
parser.add_argument('--fp16',
action='store_true',
help="use mixed-precision")
parser.add_argument('--nbatches',
default=2,
type=int,
help='Number of batches in the inference dataloader. Default: 2. ')
return parser.parse_args(model_args)
def initialize_model(args):
''' return model, ready to trace '''
config = BertConfig.from_json_file(args.config_file)
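# pad the vocabulary size up to a multiple of 8 (Tensor Core friendly; expected to match the padded vocab of the checkpoint)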
if config.vocab_size % 8 != 0:
config.vocab_size += 8 - (config.vocab_size % 8)
model = BertForQuestionAnswering(config)
model.enable_apex(False)
state_dict = torch.load(args.checkpoint, map_location='cpu')["model"]
model.load_state_dict(state_dict)
if args.fp16:
model.half()
return model
def get_dataloader(args):
''' return dataloader for inference '''
# Preprocess input data
tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512) # for bert large
cached_features_file = args.predict_file + '_{}_{}.bin'.format(args.max_seq_length, args.doc_stride)
try:
with open(cached_features_file, "rb") as reader:
eval_features = pickle.load(reader)
except:
eval_examples = read_squad_examples(
input_file=args.predict_file,
is_training=False,
version_2_with_negative=args.version_2_with_negative)
eval_features = convert_examples_to_features(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=False)
with open(cached_features_file, "wb") as writer:
pickle.dump(eval_features, writer)
data = []
for feature in eval_features:
input_ids = torch.tensor(feature.input_ids, dtype=torch.int64)
input_mask = torch.tensor(feature.input_mask, dtype=torch.int64)
segment_ids = torch.tensor(feature.segment_ids, dtype=torch.int64)
inp = (input_ids, segment_ids, input_mask)
data.append(inp)
if args.nbatches > 0:
data = data[:args.nbatches*args.batch_size]
test_loader = torch.utils.data.DataLoader(
data,
batch_size=args.batch_size,
shuffle=False,
num_workers=1,
pin_memory=True)
return test_loader
if __name__=='__main__':
# don't touch this!
deployer, model_argv = deployer_lib.create_deployer(sys.argv[1:]) # returns a deployer object and the leftover (model-specific) arguments
model_args = get_model_args(model_argv)
model = initialize_model(model_args)
dataloader = get_dataloader(model_args)
deployer.deploy(dataloader, model)

View file

@ -0,0 +1,563 @@
#!/usr/bin/python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import json
import torch
import argparse
import statistics
from collections import Counter
torch_type_to_triton_type = {
torch.bool: 'TYPE_BOOL',
torch.int8: 'TYPE_INT8',
torch.int16: 'TYPE_INT16',
torch.int32: 'TYPE_INT32',
torch.int64: 'TYPE_INT64',
torch.uint8: 'TYPE_UINT8',
torch.float16: 'TYPE_FP16',
torch.float32: 'TYPE_FP32',
torch.float64: 'TYPE_FP64'
}
CONFIG_TEMPLATE = r"""
name: "{model_name}"
platform: "{platform}"
max_batch_size: {max_batch_size}
input [
{spec_inputs}
]
output [
{spec_outputs}
]
{dynamic_batching}
{model_optimizations}
instance_group [
{{
count: {engine_count}
kind: KIND_GPU
gpus: [ {gpu_list} ]
}}
]"""
INPUT_TEMPLATE = r"""
{{
name: "input__{num}"
data_type: {type}
dims: {dims}
{reshape}
}},"""
OUTPUT_TEMPLATE = r"""
{{
name: "output__{num}"
data_type: {type}
dims: {dims}
{reshape}
}},"""
MODEL_OPTIMIZATION_TEMPLATE = r"""
optimization {{
cuda {{
graphs: {capture_cuda_graph}
}}
}}"""
def remove_empty_lines(text):
''' removes empty lines from text, returns the result '''
ret = "".join([s for s in text.strip().splitlines(True) if s.strip()])
return ret
def create_deployer(argv):
''' takes a list of arguments, returns a deployer object and the list of unused arguments '''
parser = argparse.ArgumentParser()
# required args
method = parser.add_mutually_exclusive_group(required=True)
method.add_argument('--ts-script',
action='store_true',
help='convert to torchscript using torch.jit.script')
method.add_argument('--ts-trace',
action='store_true',
help='convert to torchscript using torch.jit.trace')
method.add_argument('--onnx',
action='store_true',
help='convert to onnx using torch.onnx.export')
# triton related args
arguments = parser.add_argument_group('triton related flags')
arguments.add_argument('--triton-no-cuda',
action='store_true',
help='Use the CPU for tracing.')
arguments.add_argument('--triton-model-name',
type=str,
default="model",
help="exports to appropriate directory structure for Triton")
arguments.add_argument("--triton-model-version",
type=int,
default=1,
help="exports to appropriate directory structure for Triton")
arguments.add_argument("--triton-server-url",
type=str,
default="localhost:8001",
help="exports to appropriate directory structure for Triton")
arguments.add_argument("--triton-max-batch-size",
type=int,
default=8,
help="Specifies the 'max_batch_size' in the Triton model config.\
See the Triton documentation for more info.")
arguments.add_argument("--triton-dyn-batching-delay",
type=float,
default=0,
help="Determines the dynamic_batching queue delay in milliseconds(ms) for\
the Triton model config. Use '0' or '-1' to specify static batching.\
See the Triton documentation for more info.")
arguments.add_argument("--triton-engine-count",
type=int,
default=1,
help="Specifies the 'instance_group' count value in the Triton model config.\
See the Triton documentation for more info.")
arguments.add_argument('--save-dir', type=str, default='./triton_models', help='Saved model directory')
# optimization args
arguments = parser.add_argument_group('optimization flags')
arguments.add_argument("--capture-cuda-graph",
type=int,
default=0,
help="capture cuda graph for obtaining speedup. possible values: 0, 1. default: 0 (automatic). ")
# remainder args
arguments.add_argument('model_arguments', nargs=argparse.REMAINDER, help='arguments that will be ignored by deployer lib and will be forwarded to your deployer script')
#
args = parser.parse_args(argv)
deployer = Deployer(args)
#
return deployer, args.model_arguments[1:]
class DeployerLibrary:
def __init__(self, args):
self.args = args
self.platform = None
def set_platform(self, platform):
''' sets the platform
:: platform :: "pytorch_libtorch" or "onnxruntime_onnx" or "tensorrt_plan"
'''
self.platform = platform
def prepare_inputs(self, dataloader, device):
''' load sample inputs to device '''
inputs = []
for batch in dataloader:
if type(batch) is torch.Tensor:
batch_d = batch.to(device)
batch_d = (batch_d,)
inputs.append(batch_d)
else:
batch_d = []
for x in batch:
assert type(x) is torch.Tensor, "input is not a tensor"
batch_d.append(x.to(device))
batch_d = tuple(batch_d)
inputs.append(batch_d)
return inputs
def get_list_of_shapes(self, l, fun):
''' returns the list of min/max shapes, depending on fun
:: l :: list of tuples of tensors
:: fun :: min or max
'''
tensor_tuple = l[0]
shapes = [list(x.shape) for x in tensor_tuple]
for tensor_tuple in l:
assert len(tensor_tuple) == len(shapes), "tensors with varying shape lengths are not supported"
for i,x in enumerate(tensor_tuple):
for j in range(len(x.shape)):
shapes[i][j] = fun(shapes[i][j], x.shape[j])
return shapes # a list of shapes
def get_tuple_of_min_shapes(self, l):
''' returns the tuple of min shapes
:: l :: list of tuples of tensors '''
shapes = self.get_list_of_shapes(l, min)
min_batch = 1
shapes = [[min_batch,*shape[1:]] for shape in shapes]
shapes = tuple(shapes)
return shapes # tuple of min shapes
def get_tuple_of_max_shapes(self, l):
''' returns the tuple of max shapes
:: l :: list of tuples of tensors '''
shapes = self.get_list_of_shapes(l, max)
max_batch = max(2,shapes[0][0])
shapes = [[max_batch,*shape[1:]] for shape in shapes]
shapes = tuple(shapes)
return shapes # tuple of max shapes
def get_tuple_of_opt_shapes(self, l):
''' returns the tuple of opt shapes
:: l :: list of tuples of tensors '''
counter = Counter()
for tensor_tuple in l:
shapes = [tuple(x.shape) for x in tensor_tuple]
shapes = tuple(shapes)
counter[shapes] += 1
shapes = counter.most_common(1)[0][0]
return shapes # tuple of the most commonly occurring shapes
def get_tuple_of_dynamic_shapes(self, l):
''' returns a tuple of dynamic shapes: variable tensor dimensions
(for ex. batch size) occur as -1 in the tuple
:: l :: list of tuples of tensors '''
tensor_tuple = l[0]
shapes = [list(x.shape) for x in tensor_tuple]
for tensor_tuple in l:
err_msg = "tensors with varying shape lengths are not supported"
assert len(tensor_tuple) == len(shapes), err_msg
for i,x in enumerate(tensor_tuple):
for j in range(len(x.shape)):
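# a dimension is marked dynamic (-1) if it varies across the dataset; the batch dimension (j == 0) is always dynamic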
if shapes[i][j] != x.shape[j] or j == 0:
shapes[i][j] = -1
shapes = tuple(shapes)
return shapes # tuple of dynamic shapes
def run_models(self, models, inputs):
''' run the models on inputs, return the outputs and execution times '''
ret = []
for model in models:
torch.cuda.synchronize()
time_start = time.time()
outputs = []
for input in inputs:
with torch.no_grad():
output = model(*input)
if type(output) is torch.Tensor:
output = [output]
outputs.append(output)
torch.cuda.synchronize()
time_end = time.time()
t = time_end - time_start
ret.append(outputs)
ret.append(t)
return ret
def compute_errors(self, outputs_A, outputs_B):
''' returns the list of L_inf errors computed over every single output tensor '''
Linf_errors = []
for output_A,output_B in zip(outputs_A,outputs_B):
for x,y in zip(output_A, output_B):
error = (x - y).norm(float('inf')).item()
Linf_errors.append(error)
return Linf_errors
def print_errors(self, Linf_errors):
''' print various statistics of Linf errors '''
print()
print("conversion correctness test results")
print("-----------------------------------")
print("maximal absolute error over dataset (L_inf): ", max(Linf_errors))
print()
print("average L_inf error over output tensors: ", statistics.mean(Linf_errors))
print("variance of L_inf error over output tensors: ", statistics.variance(Linf_errors))
print("stddev of L_inf error over output tensors: ", statistics.stdev(Linf_errors))
print()
def write_config(self, config_filename,
input_shapes, input_types,
output_shapes, output_types):
''' writes Triton config file
:: config_filename :: the file to write the config file into
:: input_shapes :: tuple of dynamic shapes of the input tensors
:: input_types :: tuple of torch types of the input tensors
:: output_shapes :: tuple of dynamic shapes of the output tensors
:: output_types :: tuple of torch types of the output tensors
'''
assert self.platform is not None, "error - platform is not set"
config_template = CONFIG_TEMPLATE
input_template = INPUT_TEMPLATE
optimization_template = MODEL_OPTIMIZATION_TEMPLATE
spec_inputs = r""""""
for i,(shape,typ) in enumerate(zip(input_shapes,input_types)):
d = {
'num' : str(i),
'type': torch_type_to_triton_type[typ],
'dims': str([1]) if len(shape) == 1 else str(list(shape)[1:]) # first dimension is the batch size
}
d['reshape'] = 'reshape: { shape: [ ] }' if len(shape) == 1 else ''
spec_inputs += input_template.format_map(d)
spec_inputs = spec_inputs[:-1]
output_template = OUTPUT_TEMPLATE
spec_outputs = r""""""
for i,(shape,typ) in enumerate(zip(output_shapes,output_types)):
d = {
'num' : str(i),
'type': torch_type_to_triton_type[typ],
'dims': str([1]) if len(shape) == 1 else str(list(shape)[1:]) # first dimension is the batch size
}
d['reshape'] = 'reshape: { shape: [ ] }' if len(shape) == 1 else ''
spec_outputs += output_template.format_map(d)
spec_outputs = spec_outputs[:-1]
batching_str = ""
max_batch_size = self.args.triton_max_batch_size
if (self.args.triton_dyn_batching_delay > 0):
# Use only full and half full batches
pref_batch_size = [int(max_batch_size / 2.0), max_batch_size]
batching_str = r"""
dynamic_batching {{
preferred_batch_size: [{0}]
max_queue_delay_microseconds: {1}
}}""".format(", ".join([str(x) for x in pref_batch_size]),
int(self.args.triton_dyn_batching_delay * 1000.0))
d = {
"capture_cuda_graph": str(self.args.capture_cuda_graph)
}
optimization_str = optimization_template.format_map(d)
config_values = {
"model_name": self.args.triton_model_name,
"platform": self.platform,
"max_batch_size": max_batch_size,
"spec_inputs": spec_inputs,
"spec_outputs": spec_outputs,
"dynamic_batching": batching_str,
"model_optimizations" : optimization_str,
"gpu_list": ", ".join([str(x) for x in range(torch.cuda.device_count())]),
"engine_count": self.args.triton_engine_count
}
# write config
with open(config_filename, "w") as file:
final_config_str = config_template.format_map(config_values)
final_config_str = remove_empty_lines(final_config_str)
file.write(final_config_str)
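# Note (illustration only, not produced by this file): for the BERT QA model exported
# to TorchScript with the default settings, the rendered config.pbtxt typically declares
# three INT64 inputs of dims [384] (input_ids, segment_ids, input_mask) and two outputs
# of dims [384] for the start and end logits (FP32, or FP16 when exported with --fp16);
# max_batch_size is taken from --triton-max-batch-size.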
class Deployer:
def __init__(self, args):
self.args = args
self.lib = DeployerLibrary(args)
def deploy(self, dataloader, model):
''' deploy the model and test for correctness with dataloader '''
if self.args.ts_script or self.args.ts_trace:
self.lib.set_platform("pytorch_libtorch")
print("deploying model " + self.args.triton_model_name + " in format " + self.lib.platform)
self.to_triton_torchscript(dataloader, model)
elif self.args.onnx:
self.lib.set_platform("onnxruntime_onnx")
print("deploying model " + self.args.triton_model_name + " in format " + self.lib.platform)
self.to_triton_onnx(dataloader, model)
else:
assert False, "error"
print("done")
def to_triton_onnx(self, dataloader, model):
''' export the model to onnx and test correctness on dataloader '''
import onnx
import onnxruntime
# setup device
if self.args.triton_no_cuda:
device = torch.device('cpu')
else:
device = torch.device('cuda')
# prepare model
model.to(device)
model.eval()
assert not model.training, "internal error - model should be in eval() mode! "
# prepare inputs
inputs = self.lib.prepare_inputs(dataloader, device)
# generate outputs
outputs = []
for input in inputs:
with torch.no_grad():
output = model(*input)
if type(output) is torch.Tensor:
output = [output]
outputs.append(output)
# generate input shapes - dynamic tensor shape support
input_shapes = self.lib.get_tuple_of_dynamic_shapes(inputs)
# generate output shapes - dynamic tensor shape support
output_shapes = self.lib.get_tuple_of_dynamic_shapes(outputs)
# generate input types
input_types = [x.dtype for x in inputs[0]]
# generate output types
output_types = [x.dtype for x in outputs[0]]
# get input names
rng = range(len(input_types))
input_names = ["input__" + str(num) for num in rng]
# get output names
rng = range(len(output_types))
output_names = ["output__" + str(num) for num in rng]
# prepare save path
model_folder = os.path.join(self.args.save_dir, self.args.triton_model_name)
version_folder = os.path.join(model_folder, str(self.args.triton_model_version))
if not os.path.exists(version_folder):
os.makedirs(version_folder)
final_model_path = os.path.join(version_folder, 'model.onnx')
# get indices of dynamic input and output shapes
dynamic_axes = {}
for input_name,input_shape in zip(input_names,input_shapes):
dynamic_axes[input_name] = [i for i,x in enumerate(input_shape) if x == -1]
for output_name,output_shape in zip(output_names,output_shapes):
dynamic_axes[output_name] = [i for i,x in enumerate(output_shape) if x == -1]
# export the model
assert not model.training, "internal error - model should be in eval() mode! "
with torch.no_grad():
torch.onnx.export(model, inputs[0], final_model_path, verbose=False,
input_names=input_names, output_names=output_names,
dynamic_axes=dynamic_axes, opset_version=11)
# syntactic error check
converted_model = onnx.load(final_model_path)
# check that the IR is well formed
onnx.checker.check_model(converted_model)
# load the model
session = onnxruntime.InferenceSession(final_model_path, None)
class ONNX_model:
def __init__(self, session, input_names, device):
self.session = session
self.input_names = input_names
def to_numpy(self, tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
def __call__(self, *inputs):
inp = [(input_name, inputs[i]) for i,input_name in enumerate(self.input_names)]
inp = {input_name : self.to_numpy(x) for input_name,x in inp}
outputs = self.session.run(None, inp)
outputs = [torch.from_numpy(output) for output in outputs]
outputs = [output.to(device) for output in outputs]
if len(outputs) == 1:
outputs = outputs[0]
return outputs
# switch to eval mode
model_onnx = ONNX_model(session, input_names, device)
# run both models on inputs
assert not model.training, "internal error - model should be in eval() mode! "
models = (model, model_onnx)
outputs, time_model, outputs_onnx, time_model_onnx = self.lib.run_models(models, inputs)
# check for errors
Linf_errors = self.lib.compute_errors(outputs, outputs_onnx)
self.lib.print_errors(Linf_errors)
print('time of error check of native model: ', time_model, 'seconds')
print('time of error check of onnx model: ', time_model_onnx, 'seconds')
print()
# write Triton config
config_filename = os.path.join(model_folder, "config.pbtxt")
self.lib.write_config(config_filename,
input_shapes, input_types,
output_shapes, output_types)
def to_triton_torchscript(self, dataloader, model):
''' export the model to torchscript and test correctness on dataloader '''
# setup device
if self.args.triton_no_cuda:
device = torch.device('cpu')
else:
device = torch.device('cuda')
# prepare model
model.to(device)
model.eval()
assert not model.training, "internal error - model should be in eval() mode! "
# prepare inputs
inputs = self.lib.prepare_inputs(dataloader, device)
# generate input shapes - dynamic tensor shape support
input_shapes = self.lib.get_tuple_of_dynamic_shapes(inputs)
# generate input types
input_types = [x.dtype for x in inputs[0]]
# prepare save path
model_folder = os.path.join(self.args.save_dir, self.args.triton_model_name)
version_folder = os.path.join(model_folder, str(self.args.triton_model_version))
if not os.path.exists(version_folder):
os.makedirs(version_folder)
final_model_path = os.path.join(version_folder, 'model.pt')
# convert the model
with torch.no_grad():
if self.args.ts_trace: # trace it
model_ts = torch.jit.trace(model, inputs[0])
if self.args.ts_script: # script it
model_ts = torch.jit.script(model)
# save the model
torch.jit.save(model_ts, final_model_path)
# load the model
model_ts = torch.jit.load(final_model_path)
model_ts.eval() # WAR for a bug: the TorchScript module is loaded in training mode by default
# run both models on inputs
assert not model.training, "internal error - model should be in eval() mode! "
assert not model_ts.training, "internal error - converted model should be in eval() mode! "
models = (model, model_ts)
outputs, time_model, outputs_ts, time_model_ts = self.lib.run_models(models, inputs)
# check for errors
Linf_errors = self.lib.compute_errors(outputs, outputs_ts)
self.lib.print_errors(Linf_errors)
print('time of error check of native model: ', time_model, 'seconds')
print('time of error check of ts model: ', time_model_ts, 'seconds')
print()
# generate output shapes - dynamic tensor shape support
output_shapes = self.lib.get_tuple_of_dynamic_shapes(outputs)
# generate output types
output_types = [x.dtype for x in outputs[0]]
# now we build the config for Triton
config_filename = os.path.join(model_folder, "config.pbtxt")
self.lib.write_config(config_filename,
input_shapes, input_types,
output_shapes, output_types)
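The choice between --ts-trace and --ts-script above matters when the model contains data-dependent control flow; a minimal sketch of the difference (a toy module, not the BERT model used here):

import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # data-dependent branch: preserved by scripting, baked in by tracing
        if x.sum() > 0:
            return self.linear(x)
        return x

model = TinyModel().eval()
example = torch.randn(1, 4)

traced = torch.jit.trace(model, example)   # records the single path taken for `example`
scripted = torch.jit.script(model)         # compiles the Python code, keeping the if/else

torch.jit.save(scripted, "model.pt")
reloaded = torch.jit.load("model.pt")
reloaded.eval()  # same workaround as in to_triton_torchscript: call eval() after loading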

View file

@ -0,0 +1,64 @@
#!/bin/bash
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export TRITON_MODEL_OVERWRITE=True
NV_VISIBLE_DEVICES=0
bert_model=${1:-"large"}
precision=${2:-"fp32"}
init_checkpoint=${3:-"/workspace/bert/checkpoints/bert_qa.pt"}
EXPORT_FORMAT=${4:-"ts-script"}
MODEL_NAME="bert_${bert_model}_${precision}"
BERT_DIR="/workspace/bert"
VOCAB_FILE="/workspace/bert/vocab/vocab"
PREDICT_FILE="/workspace/bert/data/squad/v1.1/dev-v1.1.json"
SQUAD_DIR="/workspace/bert/data/squad/v1.1"
OUT_DIR="/results"
BATCH_SIZE="8"
# Create common bridge for client and server
BRIDGE_NAME="tritonnet"
docker network create ${BRIDGE_NAME}
EXPORT_MODEL_ARGS="${BATCH_SIZE} ${BERT_DIR} ${EXPORT_FORMAT} ${precision} 1 ${MODEL_NAME} 0 1"
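# Positional arguments 4-11 expected by triton/export_model.sh:
# <batch_size> <bert_dir> <export_format> <precision> <model_version> <model_name> <dyn_batching_delay> <engine_count>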
# Clean up
cleanup() {
docker kill trt_server_cont
docker network rm ${BRIDGE_NAME}
}
trap cleanup EXIT
trap cleanup SIGTERM
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${EXPORT_MODEL_ARGS} ${TRITON_MODEL_OVERWRITE}
# Start Server
echo Starting server...
SERVER_ID=$( ./triton/launch_triton_server.sh ${BRIDGE_NAME} --NV_VISIBLE_DEVICES=$NV_VISIBLE_DEVICES )
SERVER_IP=$( docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' ${SERVER_ID} )
./triton/wait_for_triton_server.sh
CMD="python triton/run_squad_client.py \
--model_name ${MODEL_NAME} \
--do_lower_case \
--vocab_file ${VOCAB_FILE} \
--output_dir ${OUT_DIR} \
--predict_file ${PREDICT_FILE} \
--batch_size ${BATCH_SIZE}"
bash scripts/docker/launch.sh "${CMD}"
bash scripts/docker/launch.sh "python ${SQUAD_DIR}/evaluate-v1.1.py ${PREDICT_FILE} ${OUT_DIR}/predictions.json"

View file

@ -0,0 +1,53 @@
#!/bin/bash
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
NV_VISIBLE_DEVICES=${1:-"0"}
DOCKER_BRIDGE=${2:-"host"}
checkpoint=${3:-"/workspace/bert/checkpoints/bert_qa.pt"}
batch_size=${4:-"8"}
BERT_DIR=${5:-"/workspace/bert"}
EXPORT_FORMAT=${6:-"ts-script"}
precision=${7:-"fp16"}
triton_model_version=${8:-1}
triton_model_name=${9:-"bertQA-ts-script"}
triton_dyn_batching_delay=${10:-0}
triton_engine_count=${11:-1}
triton_model_overwrite=${12:-"False"}
PREDICT_FILE="/workspace/bert/data/squad/v1.1/dev-v1.1.json"
DEPLOYER="deployer.py"
CMD="python triton/${DEPLOYER} \
--${EXPORT_FORMAT} \
--save-dir /results/triton_models \
--triton-model-name ${triton_model_name} \
--triton-model-version ${triton_model_version} \
--triton-max-batch-size ${batch_size} \
--triton-dyn-batching-delay ${triton_dyn_batching_delay} \
--triton-engine-count ${triton_engine_count} "
CMD+="-- --checkpoint ${checkpoint} \
--config_file ${BERT_DIR}/bert_config.json \
--vocab_file /workspace/bert/vocab/vocab \
--predict_file ${PREDICT_FILE} \
--do_lower_case \
--batch_size=${batch_size} "
if [[ $precision == "fp16" ]]; then
CMD+="--fp16 "
fi
bash scripts/docker/launch.sh "${CMD}" ${NV_VISIBLE_DEVICES} ${DOCKER_BRIDGE}

View file

@ -0,0 +1,157 @@
#!/bin/bash
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export TRITON_MODEL_OVERWRITE=True
NV_VISIBLE_DEVICES=0
bert_model=${1:-"large"}
precision=${2:-"fp32"}
init_checkpoint=${3:-"/workspace/bert/checkpoints/bert_qa.pt"}
EXPORT_FORMAT=${4:-"ts-script"}
PROFILING_DATA="triton/profiling_data_int64"
MODEL_NAME="bert_${bert_model}_${precision}_${EXPORT_FORMAT}"
BERT_DIR="/workspace/bert"
# Create common bridge for client and server
BRIDGE_NAME="tritonnet"
docker network create ${BRIDGE_NAME}
# Start Server
echo Starting server...
#bash triton/launch_triton_server.sh
SERVER_ID=$( ./triton/launch_triton_server.sh ${BRIDGE_NAME} --NV_VISIBLE_DEVICES=$NV_VISIBLE_DEVICES )
SERVER_IP=$( docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' ${SERVER_ID} )
EXPORT_MODEL_ARGS="${BERT_DIR} ${EXPORT_FORMAT} ${precision} 1 ${MODEL_NAME}"
PERF_CLIENT_ARGS="50000 10 20"
# Restart Server
restart_server() {
docker kill trt_server_cont
#bash triton/launch_triton_server.sh
SERVER_ID=$( ./triton/launch_triton_server.sh ${BRIDGE_NAME} --NV_VISIBLE_DEVICES=$NV_VISIBLE_DEVICES )
SERVER_IP=$( docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' ${SERVER_ID} )
}
# Clean up
cleanup() {
docker kill trt_server_cont
docker network rm ${BRIDGE_NAME}
}
trap cleanup EXIT
############## Dynamic Batching Comparison ##############
SERVER_BATCH_SIZE=8
CLIENT_BATCH_SIZE=1
TRITON_ENGINE_COUNT=1
TEST_NAME="DYN_BATCH_"
# Dynamic batching 10 ms
TRITON_DYN_BATCHING_DELAY=10
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${SERVER_BATCH_SIZE} ${EXPORT_MODEL_ARGS} ${TRITON_DYN_BATCHING_DELAY} ${TRITON_ENGINE_COUNT} ${TRITON_MODEL_OVERWRITE}
restart_server
sleep 15
./triton/run_perf_client.sh ${MODEL_NAME} 1 ${precision} ${CLIENT_BATCH_SIZE} ${PERF_CLIENT_ARGS} ${SERVER_IP} ${BRIDGE_NAME} ${TEST_NAME}${TRITON_DYN_BATCHING_DELAY} ${PROFILING_DATA} ${NV_VISIBLE_DEVICES}
# Dynamic batching 5 ms
TRITON_DYN_BATCHING_DELAY=5
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${SERVER_BATCH_SIZE} ${EXPORT_MODEL_ARGS} ${TRITON_DYN_BATCHING_DELAY} ${TRITON_ENGINE_COUNT} ${TRITON_MODEL_OVERWRITE}
restart_server
sleep 15
./triton/run_perf_client.sh ${MODEL_NAME} 1 ${precision} ${CLIENT_BATCH_SIZE} ${PERF_CLIENT_ARGS} ${SERVER_IP} ${BRIDGE_NAME} ${TEST_NAME}${TRITON_DYN_BATCHING_DELAY} ${PROFILING_DATA} ${NV_VISIBLE_DEVICES}
# Dynamic batching 2 ms
TRITON_DYN_BATCHING_DELAY=2
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${SERVER_BATCH_SIZE} ${EXPORT_MODEL_ARGS} ${TRITON_DYN_BATCHING_DELAY} ${TRITON_ENGINE_COUNT} ${TRITON_MODEL_OVERWRITE}
restart_server
sleep 15
./triton/run_perf_client.sh ${MODEL_NAME} 1 ${precision} ${CLIENT_BATCH_SIZE} ${PERF_CLIENT_ARGS} ${SERVER_IP} ${BRIDGE_NAME} ${TEST_NAME}${TRITON_DYN_BATCHING_DELAY} ${PROFILING_DATA} ${NV_VISIBLE_DEVICES}
# Static Batching (i.e. Dynamic batching 0 ms)
TRITON_DYN_BATCHING_DELAY=0
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${SERVER_BATCH_SIZE} ${EXPORT_MODEL_ARGS} ${TRITON_DYN_BATCHING_DELAY} ${TRITON_ENGINE_COUNT} ${TRITON_MODEL_OVERWRITE}
restart_server
sleep 15
./triton/run_perf_client.sh ${MODEL_NAME} 1 ${precision} ${CLIENT_BATCH_SIZE} ${PERF_CLIENT_ARGS} ${SERVER_IP} ${BRIDGE_NAME} ${TEST_NAME}${TRITON_DYN_BATCHING_DELAY} ${PROFILING_DATA} ${NV_VISIBLE_DEVICES}
############## Engine Count Comparison ##############
SERVER_BATCH_SIZE=1
CLIENT_BATCH_SIZE=1
TRITON_DYN_BATCHING_DELAY=0
TEST_NAME="ENGINE_C_"
# Engine Count = 4
TRITON_ENGINE_COUNT=4
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${SERVER_BATCH_SIZE} ${EXPORT_MODEL_ARGS} ${TRITON_DYN_BATCHING_DELAY} ${TRITON_ENGINE_COUNT} ${TRITON_MODEL_OVERWRITE}
restart_server
sleep 15
./triton/run_perf_client.sh ${MODEL_NAME} 1 ${precision} ${CLIENT_BATCH_SIZE} ${PERF_CLIENT_ARGS} ${SERVER_IP} ${BRIDGE_NAME} ${TEST_NAME}${TRITON_ENGINE_COUNT} ${PROFILING_DATA} ${NV_VISIBLE_DEVICES}
# Engine Count = 2
TRITON_ENGINE_COUNT=2
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${SERVER_BATCH_SIZE} ${EXPORT_MODEL_ARGS} ${TRITON_DYN_BATCHING_DELAY} ${TRITON_ENGINE_COUNT} ${TRITON_MODEL_OVERWRITE}
restart_server
sleep 15
./triton/run_perf_client.sh ${MODEL_NAME} 1 ${precision} ${CLIENT_BATCH_SIZE} ${PERF_CLIENT_ARGS} ${SERVER_IP} ${BRIDGE_NAME} ${TEST_NAME}${TRITON_ENGINE_COUNT} ${PROFILING_DATA} ${NV_VISIBLE_DEVICES}
# Engine Count = 1
TRITON_ENGINE_COUNT=1
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${SERVER_BATCH_SIZE} ${EXPORT_MODEL_ARGS} ${TRITON_DYN_BATCHING_DELAY} ${TRITON_ENGINE_COUNT} ${TRITON_MODEL_OVERWRITE}
restart_server
sleep 15
./triton/run_perf_client.sh ${MODEL_NAME} 1 ${precision} ${CLIENT_BATCH_SIZE} ${PERF_CLIENT_ARGS} ${SERVER_IP} ${BRIDGE_NAME} ${TEST_NAME}${TRITON_ENGINE_COUNT} ${PROFILING_DATA} ${NV_VISIBLE_DEVICES}
############## Batch Size Comparison ##############
# BATCH=1 Generate model and perf
SERVER_BATCH_SIZE=1
CLIENT_BATCH_SIZE=1
TRITON_ENGINE_COUNT=1
TRITON_DYN_BATCHING_DELAY=0
TEST_NAME="BATCH_SIZE_"
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${SERVER_BATCH_SIZE} ${EXPORT_MODEL_ARGS} ${TRITON_DYN_BATCHING_DELAY} ${TRITON_ENGINE_COUNT} ${TRITON_MODEL_OVERWRITE}
restart_server
sleep 15
./triton/run_perf_client.sh ${MODEL_NAME} 1 ${precision} ${CLIENT_BATCH_SIZE} 1000 10 64 ${SERVER_IP} ${BRIDGE_NAME} ${TEST_NAME}${SERVER_BATCH_SIZE} ${PROFILING_DATA} ${NV_VISIBLE_DEVICES}
# BATCH=2 Generate model and perf
SERVER_BATCH_SIZE=2
CLIENT_BATCH_SIZE=2
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${SERVER_BATCH_SIZE} ${EXPORT_MODEL_ARGS} ${TRITON_DYN_BATCHING_DELAY} ${TRITON_ENGINE_COUNT} ${TRITON_MODEL_OVERWRITE}
restart_server
sleep 15
./triton/run_perf_client.sh ${MODEL_NAME} 1 ${precision} ${CLIENT_BATCH_SIZE} 1000 10 32 ${SERVER_IP} ${BRIDGE_NAME} ${TEST_NAME}${SERVER_BATCH_SIZE} ${PROFILING_DATA} ${NV_VISIBLE_DEVICES}
# BATCH=4 Generate model and perf
SERVER_BATCH_SIZE=4
CLIENT_BATCH_SIZE=4
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${SERVER_BATCH_SIZE} ${EXPORT_MODEL_ARGS} ${TRITON_DYN_BATCHING_DELAY} ${TRITON_ENGINE_COUNT} ${TRITON_MODEL_OVERWRITE}
restart_server
sleep 15
./triton/run_perf_client.sh ${MODEL_NAME} 1 ${precision} ${CLIENT_BATCH_SIZE} 1000 10 16 ${SERVER_IP} ${BRIDGE_NAME} ${TEST_NAME}${SERVER_BATCH_SIZE} ${PROFILING_DATA} ${NV_VISIBLE_DEVICES}
# BATCH=8 Generate model and perf
SERVER_BATCH_SIZE=8
CLIENT_BATCH_SIZE=8
./triton/export_model.sh ${NV_VISIBLE_DEVICES} ${BRIDGE_NAME} ${init_checkpoint} ${SERVER_BATCH_SIZE} ${EXPORT_MODEL_ARGS} ${TRITON_DYN_BATCHING_DELAY} ${TRITON_ENGINE_COUNT} ${TRITON_MODEL_OVERWRITE}
restart_server
sleep 15
./triton/run_perf_client.sh ${MODEL_NAME} 1 ${precision} ${CLIENT_BATCH_SIZE} 1000 10 8 ${SERVER_IP} ${BRIDGE_NAME} ${TEST_NAME}${SERVER_BATCH_SIZE} ${PROFILING_DATA} ${NV_VISIBLE_DEVICES}

View file

@ -0,0 +1,31 @@
#!/bin/bash
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
DOCKER_BRIDGE=${1:-"bridge"}
NV_VISIBLE_DEVICES=${NV_VISIBLE_DEVICES:-"0"}
# Start TRITON server in detached state
docker run -d --rm \
--gpus device=${NV_VISIBLE_DEVICES} \
--shm-size=1g \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
--network=${DOCKER_BRIDGE} \
-p 8000:8000 \
-p 8001:8001 \
-p 8002:8002 \
--name trt_server_cont \
-v $PWD/results/triton_models:/models \
nvcr.io/nvidia/tritonserver:20.03-py3 trtserver --model-store=/models --log-verbose=1

View file

@ -0,0 +1,82 @@
#!/bin/bash
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
MODEL_NAME=${1:-"bert"}
MODEL_VERSION=${2:-1}
precision=${3:-"fp32"}
BATCH_SIZE=${4:-1}
MAX_LATENCY=${5:-500}
MAX_CLIENT_THREADS=${6:-10}
MAX_CONCURRENCY=${7:-50}
SERVER_HOSTNAME=${8:-"localhost"}
DOCKER_BRIDGE=${9:-"host"}
RESULTS_ID=${10:-""}
PROFILING_DATA=${11:-"triton/profiling_data_int64"}
NV_VISIBLE_DEVICES=${12:-"0"}
if [[ $SERVER_HOSTNAME == *":"* ]]; then
echo "ERROR! Do not include the port when passing the Server Hostname. These scripts require that the TRITON HTTP endpoint is on Port 8000 and the gRPC endpoint is on Port 8001. Exiting..."
exit 1
fi
if [ "$SERVER_HOSTNAME" = "localhost" ]
then
if [ ! "$(docker inspect -f "{{.State.Running}}" trt_server_cont)" = "true" ] ; then
echo "Launching TRITON server"
bash triton/launch_triton_server.sh ${DOCKER_BRIDGE} --NV_VISIBLE_DEVICES=$NV_VISIBLE_DEVICES
SERVER_LAUNCHED=true
function cleanup_server {
docker kill trt_server_cont
}
# Ensure we cleanup the server on exit
# trap "exit" INT TERM
trap cleanup_server EXIT
fi
fi
# Wait until the server is up: curl the health endpoint and sleep until it's ready
bash triton/wait_for_triton_server.sh $SERVER_HOSTNAME
TIMESTAMP=$(date "+%y%m%d_%H%M")
# Create model directory on host (directory /results is mounted)
bash scripts/docker/launch.sh "mkdir -p /results/perf_client/${MODEL_NAME}"
if [ ! -z "${RESULTS_ID}" ];
then
RESULTS_ID="_${RESULTS_ID}"
fi
OUTPUT_FILE_CSV="/results/perf_client/${MODEL_NAME}/results${RESULTS_ID}_${TIMESTAMP}.csv"
ARGS="\
--max-threads ${MAX_CLIENT_THREADS} \
-m ${MODEL_NAME} \
-x ${MODEL_VERSION} \
-p 3000 \
-d \
-v \
-i gRPC \
-u ${SERVER_HOSTNAME}:8001 \
-b ${BATCH_SIZE} \
-l ${MAX_LATENCY} \
-c ${MAX_CONCURRENCY} \
-f ${OUTPUT_FILE_CSV} \
--input-data ${PROFILING_DATA}"
echo "Using args: $(echo "$ARGS" | sed -e 's/ -/\n-/g')"
bash scripts/docker/launch.sh "/workspace/install/bin/perf_client $ARGS" all $DOCKER_BRIDGE

View file

@ -0,0 +1,280 @@
#!/usr/bin/python
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import numpy as np
import os
import sys
from builtins import range
import collections
from tqdm import tqdm
import time
from tensorrtserver.api import *
sys.path.append('.')
from run_squad import get_answers, convert_examples_to_features, read_squad_examples
from tokenization import BertTokenizer
import json
import pickle
from functools import partial  # used by the asynchronous result callback
args = None
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-v', '--verbose', action="store_true", required=False, default=False,
help='Enable verbose output')
parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8000',
help='Inference server URL. Default is localhost:8000.')
parser.add_argument('-i', '--protocol', type=str, required=False, default='http',
help='Protocol ("http"/"grpc") used to ' +
'communicate with inference service. Default is "http".')
parser.add_argument('-H', dest='http_headers', metavar="HTTP_HEADER",
required=False, action='append',
help='HTTP headers to add to inference server requests. ' +
'Format is -H"Header:Value".')
parser.add_argument('--synchronous', action='store_true', help="Wait for previous request to finish before sending next request.")
parser.add_argument("--model_name",
type=str,
default='bert',
help="Specify model to run")
parser.add_argument("--do_lower_case",
action='store_true',
help="Whether to lower case the input text. \
True for uncased models, False for cased models.")
parser.add_argument('--vocab_file',
type=str, default=None, required=True,
help="Vocabulary mapping/file BERT was pretrainined on")
parser.add_argument("--predict_file", default=None, type=str,
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
parser.add_argument('--version_2_with_negative',
action='store_true',
help='If true, the SQuAD examples contain some that do not have an answer.')
parser.add_argument("--max_seq_length", default=384, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--doc_stride", default=128, type=int,
help="When splitting up a long document into chunks, how much stride to take between chunks.")
parser.add_argument("--max_query_length", default=64, type=int,
help="The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length.")
parser.add_argument("--n_best_size", default=20, type=int,
help="The total number of n-best predictions to generate in the nbest_predictions.json "
"output file.")
parser.add_argument("--verbose_logging", action='store_true',
help="If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation.")
parser.add_argument('--null_score_diff_threshold',
type=float, default=0.0,
help="If null_score - best_non_null is greater than the threshold predict null.")
parser.add_argument('--batch_size', default=1, type=int,
help='Maximal number of examples in a batch')
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model checkpoints and predictions will be written.")
parser.add_argument("--max_answer_length", default=30, type=int,
help="The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another.")
args = parser.parse_args()
# TRITON client setup
protocol = ProtocolType.from_str(args.protocol)
model_version = -1
infer_ctx = InferContext(args.url, protocol, args.model_name, model_version,
http_headers=args.http_headers, verbose=args.verbose)
# Preprocess input data
tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512) # for bert large
cached_features_file = args.predict_file + '_{}_{}.bin'.format(args.max_seq_length, args.doc_stride)
eval_examples = read_squad_examples(
input_file=args.predict_file,
is_training=False,
version_2_with_negative=args.version_2_with_negative)
try:
with open(cached_features_file, "rb") as reader:
eval_features = pickle.load(reader)
except Exception:  # cache missing or unreadable - recompute the features
eval_features = convert_examples_to_features(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=False)
with open(cached_features_file, "wb") as writer:
pickle.dump(eval_features, writer)
dtype = np.int64
def batch(iterable, n=1):
length = len(iterable)
for ndx in range(0, length, n):
unique_ids = ()
example_indices = ()
input_ids_data = ()
input_mask_data = ()
segment_ids_data = ()
for i in range(0, min(n, l-ndx)):
unique_ids = unique_ids + (iterable[ndx + i].unique_id,)
example_indices = example_indices + (ndx + i,)
input_ids_data = input_ids_data + (np.array(iterable[ndx + i].input_ids, dtype=dtype),)
input_mask_data = input_mask_data + (np.array(iterable[ndx + i].input_mask, dtype=dtype),)
segment_ids_data = segment_ids_data + (np.array(iterable[ndx + i].segment_ids, dtype=dtype),)
inputs_dict = {'input__0': input_ids_data,
'input__1': segment_ids_data,
'input__2': input_mask_data}
yield inputs_dict, example_indices, unique_ids
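# e.g. batch(eval_features, n=2) yields, per step, a Triton input dict of per-example int64
# arrays ({'input__0': input_ids, 'input__1': segment_ids, 'input__2': input_mask}) together
# with the matching example indices and unique ids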
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
ExampleInfo = collections.namedtuple("ExampleInfo",
["start_time", "batch_size", "example_ids", "unique_ids"])
all_results = []
time_list = []
outstanding = 0
sent_prog = tqdm(desc="Sending Requests", total=len(eval_features), file=sys.stdout, unit='sentences')
recv_prog = tqdm(desc="Processed Requests", total=len(eval_features), file=sys.stdout, unit='sentences')
if args.synchronous:
raw_results = []
def process_result_cb(example_info, ctx, request_id):
global outstanding
result = infer_ctx.get_async_run_results(request_id)
stop = time.time()
outstanding -= 1
time_list.append(stop - example_info.start_time)
batch_count = example_info.batch_size
for i in range(batch_count):
unique_id = int(example_info.unique_ids[i])
start_logits = [float(x) for x in result["output__0"][i].flat]
end_logits = [float(x) for x in result["output__1"][i].flat]
all_results.append(
RawResult(
unique_id=unique_id,
start_logits=start_logits,
end_logits=end_logits))
recv_prog.update(n=batch_count)
all_results_start = time.time()
for input_dict, example_indices, unique_ids in batch(eval_features, args.batch_size):
current_bs = len(input_dict['input__0'])
outputs_dict = {'output__0': InferContext.ResultFormat.RAW,
'output__1': InferContext.ResultFormat.RAW}
start = time.time()
example_info = ExampleInfo(start_time=start,
batch_size=current_bs,
example_ids=example_indices,
unique_ids=unique_ids
)
if not args.synchronous:
outstanding += 1
result_id = infer_ctx.async_run(partial(process_result_cb, example_info),
input_dict,
outputs_dict,
batch_size=current_bs)
else:
result = infer_ctx.run(input_dict, outputs_dict, batch_size=current_bs)
raw_results.append((example_info, result))
sent_prog.update(n=current_bs)
# Make sure that all sent requests have been processed (sleep briefly instead of busy-waiting)
while outstanding > 0:
time.sleep(0.001)
all_results_end = time.time()
all_results_total = (all_results_end - all_results_start) * 1000.0
num_batches = (len(eval_features) + args.batch_size - 1) // args.batch_size
if args.synchronous:
for result in raw_results:
example_info, batch = result
for i in range(example_info.batch_size):
unique_id = int(example_info.unique_ids[i])
start_logits = [float(x) for x in batch["output__0"][i].flat]
end_logits = [float(x) for x in batch["output__1"][i].flat]
all_results.append(
RawResult(
unique_id=unique_id,
start_logits=start_logits,
end_logits=end_logits))
recv_prog.update(n=example_info.batch_size)
print("-----------------------------")
print("Individual Time Runs")
print("Total Time: {} ms".format(all_results_total))
print("-----------------------------")
print("-----------------------------")
print("Total Inference Time = %0.2f for"
"Sentences processed = %d" % (sum(time_list), len(eval_features)))
print("Throughput Average (sentences/sec) = %0.2f" % (len(eval_features) / all_results_total * 1000.0))
print("Throughput Average (batches/sec) = %0.2f" % (num_batches / all_results_total * 1000.0))
print("-----------------------------")
if not args.synchronous:
time_list.sort()
avg = np.mean(time_list)
cf_95 = max(time_list[:int(len(time_list) * 0.95)])
cf_99 = max(time_list[:int(len(time_list) * 0.99)])
cf_100 = max(time_list[:int(len(time_list) * 1)])
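# cf_95/cf_99/cf_100 approximate percentile latencies: time_list is sorted ascending, so the
# max over the first 95%/99%/100% of entries is the corresponding percentile request time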
print("-----------------------------")
print("Summary Statistics")
print("Batch size =", args.batch_size)
print("Sequence Length =", args.max_seq_length)
print("Latency Confidence Level 95 (ms) =", cf_95 * 1000)
print("Latency Confidence Level 99 (ms) =", cf_99 * 1000)
print("Latency Confidence Level 100 (ms) =", cf_100 * 1000)
print("Latency Average (ms) =", avg * 1000)
print("-----------------------------")
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
answers, nbest_answers = get_answers(eval_examples, eval_features, all_results, args)
with open(output_prediction_file, "w") as f:
f.write(json.dumps(answers, indent=4) + "\n")
with open(output_nbest_file, "w") as f:
f.write(json.dumps(nbest_answers, indent=4) + "\n")

View file

@ -0,0 +1,34 @@
#!/bin/bash
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
SERVER_URI=${1:-"localhost"}
echo "Waiting for TRITON Server to be ready at http://$SERVER_URI:8000..."
live_command="curl -i -m 1 -L -s -o /dev/null -w %{http_code} http://$SERVER_URI:8000/api/health/live"
ready_command="curl -i -m 1 -L -s -o /dev/null -w %{http_code} http://$SERVER_URI:8000/api/health/ready"
current_status=$($live_command)
echo $current_status
# First check liveness; if that passes, check readiness. If either check fails, keep looping
while [[ ${current_status} != "200" ]] || [[ $($ready_command) != "200" ]]; do
printf "."
sleep 1
current_status=$($live_command)
done
echo "TRITON Server is ready!"

View file

@ -67,7 +67,7 @@ The examples are organized first by framework, such as TensorFlow, PyTorch, etc.
| [ResNeXt101-32x4d](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnext101-32x4d) |PyTorch | Yes | Yes | Yes | - | - | - | - | - |
| [SE-ResNeXt101-32x4d](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/se-resnext101-32x4d) |PyTorch | Yes | Yes | Yes | - | - | - | - | - |
| [SSD300 v1.1](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Detection/SSD) |PyTorch | Yes | Yes | Yes | - | - | - | - | - |
| [BERT](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT) |PyTorch | N/A | Yes | Yes | Yes | - | - | - | - |
| [BERT](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT) |PyTorch | N/A | Yes | Yes | Yes | - | - | Yes | - |
| [Transformer-XL](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/Transformer-XL) |PyTorch | N/A | Yes | Yes | Yes | - | - | - | - |
| [Neural Collaborative Filtering](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Recommendation/NCF) |PyTorch | N/A | Yes | Yes | - | - |- | - | - |
| [Mask R-CNN](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/MaskRCNN) |PyTorch | N/A | Yes | Yes | - | - | - | - | - |