[MaskRCNN/PyT] Update AMP API for inference (#810)
This commit is contained in:
parent
2badf6e8e4
commit
0cbadd7d49
|
@ -96,13 +96,14 @@ def main():
|
|||
model = build_detection_model(cfg)
|
||||
model.to(cfg.MODEL.DEVICE)
|
||||
|
||||
# Initialize mixed-precision if necessary
|
||||
# Initialize mixed-precision
|
||||
if args.fp16:
|
||||
use_mixed_precision = True
|
||||
else:
|
||||
use_mixed_precision = cfg.DTYPE == "float16"
|
||||
amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)
|
||||
|
||||
amp_opt_level = 'O1' if use_mixed_precision else 'O0'
|
||||
model = amp.initialize(model, opt_level=amp_opt_level)
|
||||
|
||||
output_dir = cfg.OUTPUT_DIR
|
||||
checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
|
||||
_ = checkpointer.load(cfg.MODEL.WEIGHT)
|
||||
|
|
Loading…
Reference in a new issue