Merge pull request #67 from GrzegorzKarchNV/master
Update models.py - fix fp16 inference
This commit is contained in:
commit
2619f172c7
|
@ -74,8 +74,7 @@ def get_model(model_name, model_config, to_fp16, to_cuda, training=True):
|
|||
raise NotImplementedError(model_name)
|
||||
if to_fp16:
|
||||
model = batchnorm_to_float(model.half())
|
||||
if training:
|
||||
model = lstmcell_to_float(model)
|
||||
model = lstmcell_to_float(model)
|
||||
if model_name == "WaveGlow":
|
||||
for k in model.convinv:
|
||||
k.float()
|
||||
|
|
Loading…
Reference in a new issue