Split package "plugin" into "nvidia" and "nvidia-docker-plugin"

This commit is contained in:
Jonathan Calmels 2015-12-07 22:08:22 -08:00
parent d8c9afc50a
commit 341af23818
10 changed files with 49 additions and 25 deletions

View file

@ -31,4 +31,4 @@ ARG UID
RUN useradd --uid $UID build
USER build
CMD go get -v -ldflags="-s" plugin
CMD go get -v -ldflags="-s" nvidia-docker-plugin

View file

@ -7,7 +7,7 @@ USER_ID := $(shell id -u)
IMAGE := nvdocker-build
PREFIX := /usr/local/nvidia
TARGET := nvidia-docker-plugin
PLUGIN := $(BIN_DIR)/plugin
PLUGIN := $(BIN_DIR)/nvidia-docker-plugin
.PHONY: all install clean

View file

@ -6,9 +6,8 @@ import (
"flag"
"log"
"os"
"os/exec"
"nvml"
"nvidia"
)
var (
@ -16,12 +15,12 @@ var (
VolumePrefix string
SocketPath string
Devices []Device
Volumes VolumeMap
Devices []nvidia.Device
Volumes nvidia.VolumeMap
)
func init() {
log.SetPrefix("nvidia-docker-plugin | ")
log.SetPrefix(os.Args[0] + " | ")
flag.StringVar(&ListenAddr, "l", "localhost:3476", "Server listen address")
flag.StringVar(&VolumePrefix, "p", "", "Volumes prefix path (default is to use mktemp)")
@ -42,10 +41,6 @@ func exit() {
os.Exit(code)
}
func modprobeUVM() error {
return exec.Command("nvidia-modprobe", "-u", "-c=0").Run()
}
func main() {
var err error
@ -53,14 +48,14 @@ func main() {
defer exit()
log.Println("Loading NVIDIA management library")
assert(nvml.Init())
defer func() { assert(nvml.Shutdown()) }()
assert(nvidia.Init())
defer func() { assert(nvidia.Shutdown()) }()
log.Println("Loading NVIDIA unified memory module")
assert(modprobeUVM())
log.Println("Loading NVIDIA unified memory")
assert(nvidia.LoadUVM())
log.Println("Discovering GPU devices")
Devices, err = GetDevices()
Devices, err = nvidia.GetDevices()
assert(err)
if VolumePrefix == "" {
@ -68,7 +63,7 @@ func main() {
} else {
log.Println("Creating volumes at", VolumePrefix)
}
Volumes, err = GetVolumes(VolumePrefix)
Volumes, err = nvidia.GetVolumes(VolumePrefix)
assert(err)
plugin := NewPluginAPI(SocketPath)

View file

@ -12,8 +12,7 @@ import (
"text/tabwriter"
"text/template"
"cuda"
"nvml"
"nvidia"
)
type remoteV10 struct{}
@ -22,8 +21,8 @@ func (r *remoteV10) version() string { return "v1.0" }
func (r *remoteV10) gpuInfo(resp http.ResponseWriter, req *http.Request) {
const tpl = `
Driver version: {{nvmlDriverVersion}}
Supported CUDA version: {{cudaDriverVersion}}
Driver version: {{driverVersion}}
Supported CUDA version: {{cudaVersion}}
{{range $i, $e := .}}
Device #{{$i}}
Name: {{.Name}}
@ -54,8 +53,8 @@ func (r *remoteV10) gpuInfo(resp http.ResponseWriter, req *http.Request) {
{{end}}
`
m := template.FuncMap{
"nvmlDriverVersion": nvml.GetDriverVersion,
"cudaDriverVersion": cuda.GetDriverVersion,
"driverVersion": nvidia.GetDriverVersion,
"cudaVersion": nvidia.GetCUDAVersion,
}
t := template.Must(template.New("").Funcs(m).Parse(tpl))
w := tabwriter.NewWriter(resp, 0, 4, 0, ' ', 0)

View file

@ -1,6 +1,6 @@
// Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
package main
package nvidia
import (
"sort"

View file

@ -0,0 +1,30 @@
// Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
package nvidia
import (
"os/exec"
"cuda"
"nvml"
)
func Init() error {
return nvml.Init()
}
func Shutdown() error {
return nvml.Shutdown()
}
func LoadUVM() error {
return exec.Command("nvidia-modprobe", "-u", "-c=0").Run()
}
func GetDriverVersion() (string, error) {
return nvml.GetDriverVersion()
}
func GetCUDAVersion() (string, error) {
return cuda.GetDriverVersion()
}

View file

@ -1,6 +1,6 @@
// Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
package main
package nvidia
import (
"bufio"