[MaskRCNN/TF] Update extract_RN50_weights.py (#597)

* Update extract_RN50_weights.py

removing contrib - not compatible with TF2.0

* Update extract_RN50_weights.py

adding compat with tf1 and tf2
This commit is contained in:
Amr Ragab 2020-07-22 20:01:27 -04:00 committed by GitHub
parent 0ef5568a9d
commit 9ba22d1f29
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -20,6 +20,7 @@ import sys
import getopt
import logging
import tensorflow as tf
from distutils.version import LooseVersion
"""
python weights/extract_RN50_weights.py \
@ -44,7 +45,12 @@ def rename(checkpoint_dir, save_to, dry_run, verbose):
total_vars_loaded = 0
for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
if LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
file_list = tf.contrib.framework.list_variables(checkpoint_dir)
else:
file_list = tf.train.list_variables(checkpoint_dir)
for var_name, _ in file_list:
if "resnet50" in var_name:
# Load the variable