Remove deadcode in TensorRT builder

Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
This commit is contained in:
Rajeev Rao 2020-03-30 12:00:09 -07:00
parent 899c9988f7
commit af69862cfb

View file

@ -251,28 +251,20 @@ def transformer_layer_opt(prefix, config, init_dict, network, input_tensor, imas
# FC1 + GELU
B_mid = init_dict[prefix + B_MID]
W_mid = init_dict[prefix + W_MID]
if False:
mid_dense = network.add_fully_connected(attention_ln, config.intermediate_size, W_mid, B_mid)
mid_dense_out = mid_dense.get_output(0)
gelu_layer = make_gelu_layer(prefix, config, network, mid_dense_out)
intermediate_act = gelu_layer.get_output(0)
set_tensor_name(intermediate_act, prefix, "gelu")
else:
W_midT = init_dict[prefix + W_MID + '_notrans']
mid_dense = my_fc(config, network, attention_ln, config.intermediate_size, W_midT)
mid_dense_out = mid_dense.get_output(0)
W_midT = init_dict[prefix + W_MID + '_notrans']
mid_dense = my_fc(config, network, attention_ln, config.intermediate_size, W_midT)
mid_dense_out = mid_dense.get_output(0)
pf_type = trt.PluginField("type_id", np.array([1 if config.use_fp16 else 0], np.int32), trt.PluginFieldType.INT32)
pf_bias = trt.PluginField("bias", B_mid.numpy(), trt.PluginFieldType.FLOAT32)
pfc = trt.PluginFieldCollection([pf_type, pf_bias])
pf_type = trt.PluginField("type_id", np.array([1 if config.use_fp16 else 0], np.int32), trt.PluginFieldType.INT32)
pf_bias = trt.PluginField("bias", B_mid.numpy(), trt.PluginFieldType.FLOAT32)
pfc = trt.PluginFieldCollection([pf_type, pf_bias])
plug = gelu_plg_creator.create_plugin("gelu", pfc)
plug = gelu_plg_creator.create_plugin("gelu", pfc)
gelu_layer = network.add_plugin_v2([mid_dense_out], plug)
gelu_layer = network.add_plugin_v2([mid_dense_out], plug)
intermediate_act = gelu_layer.get_output(0)
set_tensor_name(intermediate_act, prefix, "gelu")
intermediate_act = gelu_layer.get_output(0)
set_tensor_name(intermediate_act, prefix, "gelu")
if config.use_int8 and config.use_strict:
intermediate_act.set_dynamic_range(-10, 10)