Merge pull request #868 from hXl3s/rn50/fix_bs_1

Fixed wrong shapes with BS=1
This commit is contained in:
nv-kkudrynski 2021-03-11 10:38:12 +01:00 committed by GitHub
commit f5bda9aa7a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -187,12 +187,9 @@ class ResnetModel(object):
reuse=False,
use_final_conv=params['use_final_conv']
)
if mode!=tf.estimator.ModeKeys.PREDICT:
logits = tf.squeeze(logits)
if mode!=tf.estimator.ModeKeys.PREDICT:
logits = tf.squeeze(logits)
if params['use_final_conv']:
logits = tf.squeeze(logits, axis=[-2, -1])
y_preds = tf.argmax(logits, axis=1, output_type=tf.int32)