Merge pull request #868 from hXl3s/rn50/fix_bs_1
Fixed wrong shapes with BS=1
This commit is contained in:
commit
f5bda9aa7a
|
@ -187,12 +187,9 @@ class ResnetModel(object):
|
|||
reuse=False,
|
||||
use_final_conv=params['use_final_conv']
|
||||
)
|
||||
|
||||
if mode!=tf.estimator.ModeKeys.PREDICT:
|
||||
logits = tf.squeeze(logits)
|
||||
|
||||
if mode!=tf.estimator.ModeKeys.PREDICT:
|
||||
logits = tf.squeeze(logits)
|
||||
if params['use_final_conv']:
|
||||
logits = tf.squeeze(logits, axis=[-2, -1])
|
||||
|
||||
y_preds = tf.argmax(logits, axis=1, output_type=tf.int32)
|
||||
|
||||
|
|
Loading…
Reference in a new issue