[Convnets] 21.09 Fixes
This commit is contained in:
parent
88e5298421
commit
4fdd014ebf
|
@ -36,6 +36,7 @@ def add_dali_args(parser):
|
||||||
|
|
||||||
group.add_argument('--dali-nvjpeg-width-hint', type=int, default=5980, help="Width hint value for nvJPEG (in pixels)")
|
group.add_argument('--dali-nvjpeg-width-hint', type=int, default=5980, help="Width hint value for nvJPEG (in pixels)")
|
||||||
group.add_argument('--dali-nvjpeg-height-hint', type=int, default=6430, help="Height hint value for nvJPEG (in pixels)")
|
group.add_argument('--dali-nvjpeg-height-hint', type=int, default=6430, help="Height hint value for nvJPEG (in pixels)")
|
||||||
|
group.add_argument('--dali-dont-use-mmap', default=False, action='store_true', help="Use plain I/O instead of MMAP for datasets")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +48,8 @@ class HybridTrainPipe(Pipeline):
|
||||||
):
|
):
|
||||||
super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id, prefetch_queue_depth = prefetch_queue)
|
super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id, prefetch_queue_depth = prefetch_queue)
|
||||||
self.input = ops.MXNetReader(path=[rec_path], index_path=[idx_path],
|
self.input = ops.MXNetReader(path=[rec_path], index_path=[idx_path],
|
||||||
random_shuffle=True, shard_id=shard_id, num_shards=num_shards)
|
random_shuffle=True, shard_id=shard_id, num_shards=num_shards,
|
||||||
|
dont_use_mmap=args.dali_dont_use_mmap)
|
||||||
|
|
||||||
if dali_cpu:
|
if dali_cpu:
|
||||||
dali_device = "cpu"
|
dali_device = "cpu"
|
||||||
|
@ -101,7 +103,8 @@ class HybridValPipe(Pipeline):
|
||||||
nvjpeg_width_hint=5980, nvjpeg_height_hint=6430):
|
nvjpeg_width_hint=5980, nvjpeg_height_hint=6430):
|
||||||
super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id, prefetch_queue_depth=prefetch_queue)
|
super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id, prefetch_queue_depth=prefetch_queue)
|
||||||
self.input = ops.MXNetReader(path=[rec_path], index_path=[idx_path],
|
self.input = ops.MXNetReader(path=[rec_path], index_path=[idx_path],
|
||||||
random_shuffle=False, shard_id=shard_id, num_shards=num_shards)
|
random_shuffle=False, shard_id=shard_id, num_shards=num_shards,
|
||||||
|
dont_use_mmap=args.dali_dont_use_mmap)
|
||||||
|
|
||||||
if dali_cpu:
|
if dali_cpu:
|
||||||
dali_device = "cpu"
|
dali_device = "cpu"
|
||||||
|
|
Loading…
Reference in a new issue