Merge pull request #67 from GrzegorzKarchNV/master

Update models.py - fix fp16 inference
This commit is contained in:
nvpstr 2019-06-05 14:45:27 +02:00 committed by GitHub
commit 2619f172c7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -74,8 +74,7 @@ def get_model(model_name, model_config, to_fp16, to_cuda, training=True):
raise NotImplementedError(model_name)
if to_fp16:
model = batchnorm_to_float(model.half())
if training:
model = lstmcell_to_float(model)
model = lstmcell_to_float(model)
if model_name == "WaveGlow":
for k in model.convinv:
k.float()