[nnUNet/PyT] update dlprof
This commit is contained in:
parent
6a642837c4
commit
ee3f3db4c0
|
@ -8,9 +8,15 @@ RUN pip install --upgrade pip
|
|||
RUN pip install --disable-pip-version-check -r requirements.txt
|
||||
RUN pip install --disable-pip-version-check -r triton/requirements.txt
|
||||
RUN pip install pytorch-lightning==1.0.0 --no-dependencies
|
||||
RUN pip install torchtext==0.6.0 --no-dependencies
|
||||
RUN pip install monai==0.4.0 --no-dependencies
|
||||
RUN pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/ nvidia-dali-cuda110==0.30.0
|
||||
RUN pip install torch_optimizer==0.0.1a15 --no-dependencies
|
||||
RUN pip install numpy==1.20.3
|
||||
RUN pip install nvidia-pyindex==1.0.9
|
||||
RUN pip install nvidia-dlprof==1.2.0
|
||||
RUN pip install nvidia_dlprof_pytorch_nvtx==1.2.0
|
||||
|
||||
RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
|
||||
RUN unzip -qq awscliv2.zip
|
||||
RUN ./aws/install
|
||||
|
|
|
@ -28,8 +28,9 @@ if __name__ == "__main__":
|
|||
args = get_main_args()
|
||||
|
||||
if args.profile:
|
||||
import pyprof
|
||||
pyprof.init(enable_function_stack=True)
|
||||
import nvidia_dlprof_pytorch_nvtx
|
||||
|
||||
nvidia_dlprof_pytorch_nvtx.init()
|
||||
print("Profiling enabled")
|
||||
|
||||
if args.affinity != "disabled":
|
||||
|
|
|
@ -65,7 +65,7 @@ class LoggingCallback(Callback):
|
|||
|
||||
return stats
|
||||
|
||||
def log(self):
|
||||
def _log(self):
|
||||
if is_main_process():
|
||||
diffs = list(map(operator.sub, self.timestamps[1:], self.timestamps[:-1]))
|
||||
deltas = np.array(diffs)
|
||||
|
@ -76,8 +76,8 @@ class LoggingCallback(Callback):
|
|||
def on_train_end(self, trainer, pl_module):
|
||||
if self.profile:
|
||||
profiler.stop()
|
||||
self.log()
|
||||
self._log()
|
||||
|
||||
def on_test_end(self, trainer, pl_module):
|
||||
if trainer.current_epoch == 1:
|
||||
self.log()
|
||||
self._log()
|
||||
|
|
Loading…
Reference in a new issue