nvidia-docker/src/nvidia/nvidia.go

62 lines
1.2 KiB
Go
Raw Normal View History

2016-01-06 00:45:23 +01:00
// Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
package nvidia
import (
2016-03-19 01:00:03 +01:00
"errors"
2015-12-12 11:26:16 +01:00
"os"
"os/exec"
"github.com/NVIDIA/nvidia-docker/src/cuda"
"github.com/NVIDIA/nvidia-docker/src/nvml"
)
// DockerPlugin holds the name string "nvidia-docker" (presumably the name the
// plugin is known by to Docker — confirm against its callers).
const DockerPlugin = "nvidia-docker"

// Paths of the NVIDIA control and UVM device nodes. GetControlDevicePaths
// always reports DeviceCtl and DeviceUVM, and reports DeviceUVMTools only
// when that node exists on the host.
const (
	DeviceCtl      = "/dev/nvidiactl"
	DeviceUVM      = "/dev/nvidia-uvm"
	DeviceUVMTools = "/dev/nvidia-uvm-tools"
)
func Init() error {
2016-06-16 23:55:12 +02:00
if err := os.Setenv("CUDA_DISABLE_UNIFIED_MEMORY", "1"); err != nil {
return err
}
2016-04-13 02:33:54 +02:00
if err := os.Setenv("CUDA_CACHE_DISABLE", "1"); err != nil {
return err
}
if err := os.Unsetenv("CUDA_VISIBLE_DEVICES"); err != nil {
return err
}
return nvml.Init()
}
// Shutdown is the counterpart to Init. It forwards to nvml.Shutdown and
// returns that call's error unchanged.
func Shutdown() error {
	err := nvml.Shutdown()
	return err
}
// LoadUVM runs `nvidia-modprobe -u -c=0` to load the NVIDIA UVM kernel module
// (and, presumably, create its minor-0 device node — see nvidia-modprobe(1)).
//
// It returns nil when the command succeeds. On failure, the error message
// starts with the hint "Could not load UVM kernel module. Is nvidia-modprobe
// installed?" and is followed by the underlying command error in parentheses,
// e.g. executable not found or a non-zero exit status.
func LoadUVM() error {
	if err := exec.Command("nvidia-modprobe", "-u", "-c=0").Run(); err != nil {
		// Keep the original user-facing hint as a prefix; the appended cause
		// distinguishes "not installed" from "installed but failed".
		return errors.New("Could not load UVM kernel module. Is nvidia-modprobe installed? (" + err.Error() + ")")
	}
	return nil
}
// GetDriverVersion reports the NVIDIA driver version string as returned by
// nvml.GetDriverVersion. Both the string and the error are passed through
// unchanged.
func GetDriverVersion() (string, error) {
	version, err := nvml.GetDriverVersion()
	return version, err
}
// GetCUDAVersion reports the CUDA version string obtained from
// cuda.GetDriverVersion. Both the string and the error are passed through
// unchanged.
func GetCUDAVersion() (string, error) {
	version, err := cuda.GetDriverVersion()
	return version, err
}
2016-06-17 01:29:45 +02:00
func GetControlDevicePaths() ([]string, error) {
devs := []string{DeviceCtl, DeviceUVM}
_, err := os.Stat(DeviceUVMTools)
if os.IsNotExist(err) {
return devs, nil
}
return append(devs, DeviceUVMTools), err
}