[MaskRCNN/TF] Update extract_RN50_weights.py (#597)
* Update extract_RN50_weights.py removing contrib - not compatible with TF2.0 * Update extract_RN50_weights.py adding compat with tf1 and tf2
This commit is contained in:
parent
0ef5568a9d
commit
9ba22d1f29
|
@ -20,6 +20,7 @@ import sys
|
|||
import getopt
|
||||
import logging
|
||||
import tensorflow as tf
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
"""
|
||||
python weights/extract_RN50_weights.py \
|
||||
|
@ -44,7 +45,12 @@ def rename(checkpoint_dir, save_to, dry_run, verbose):
|
|||
|
||||
total_vars_loaded = 0
|
||||
|
||||
for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
|
||||
if LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
|
||||
file_list = tf.contrib.framework.list_variables(checkpoint_dir)
|
||||
else:
|
||||
file_list = tf.train.list_variables(checkpoint_dir)
|
||||
|
||||
for var_name, _ in file_list:
|
||||
|
||||
if "resnet50" in var_name:
|
||||
# Load the variable
|
||||
|
|
Loading…
Reference in a new issue