# coding=utf-8
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import numpy as np


def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True,
                                    *args, **kwargs):
    """Custom variable getter that stores trainable variables in float32.

    Trainable variables are requested from ``getter`` with a float32 storage
    dtype and the result is then cast to the requested ``dtype``.
    Non-trainable variables are requested with ``dtype`` as given and are
    returned untouched.

    Args:
        getter: Callable that actually creates or fetches the variable. It is
            invoked as ``getter(name, shape, dtype=..., initializer=...,
            regularizer=..., trainable=..., *args, **kwargs)``.
        name: Variable name, forwarded to ``getter``.
        shape: Variable shape, forwarded to ``getter``.
        dtype: Requested dtype of the returned value. ``None`` means no
            particular dtype was requested.
        initializer: Forwarded to ``getter``.
        regularizer: Forwarded to ``getter``.
        trainable: Whether the variable is trainable; only trainable
            variables get the float32-storage treatment.
        *args: Extra positional arguments forwarded to ``getter``.
        **kwargs: Extra keyword arguments forwarded to ``getter``.

    Returns:
        The value returned by ``getter``, passed through ``tf.cast`` when the
        variable is trainable and ``dtype`` is set to something other than
        ``tf.float32``.
    """
    storage_dtype = tf.float32 if trainable else dtype
    variable = getter(name, shape, dtype=storage_dtype,
                      initializer=initializer, regularizer=regularizer,
                      trainable=trainable,
                      *args, **kwargs)
    # Only cast when a concrete dtype was requested: ``dtype`` defaults to
    # None, and ``tf.cast(variable, None)`` would raise instead of returning
    # the float32 variable.
    if trainable and dtype is not None and dtype != tf.float32:
        variable = tf.cast(variable, dtype)
    return variable


def get_custom_getter(compute_type):
    """Return the custom variable getter to use for ``compute_type``.

    Args:
        compute_type: Dtype used for computation.

    Returns:
        ``float32_variable_storage_getter`` when ``compute_type`` equals
        ``tf.float16``, otherwise ``None``.
    """
    if compute_type == tf.float16:
        return float32_variable_storage_getter
    return None