Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip injection of firmware by default #695

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
16 changes: 8 additions & 8 deletions cmd/nvidia-container-runtime-hook/container_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func getDevicesFromMounts(mounts []Mount) *string {
return &ret
}

func getDevices(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) *string {
func (hookConfig *hookConfig) getDevices(image image.CUDA, mounts []Mount, privileged bool) *string {
// If enabled, try and get the device list from volume mounts first
if hookConfig.AcceptDeviceListAsVolumeMounts {
devices := getDevicesFromMounts(mounts)
Expand Down Expand Up @@ -284,10 +284,10 @@ func getImexChannels(image image.CUDA) *string {
return &chans
}

func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities {
func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities {
// We use the default driver capabilities by default. This is filtered to only include the
// supported capabilities
supportedDriverCapabilities := image.NewDriverCapabilities(c.SupportedDriverCapabilities)
supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities)

capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)

Expand All @@ -311,11 +311,11 @@ func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage boo
return capabilities
}

func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) *nvidiaConfig {
func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, mounts []Mount, privileged bool) *nvidiaConfig {
legacyImage := image.IsLegacy()

var devices string
if d := getDevices(hookConfig, image, mounts, privileged); d != nil {
if d := hookConfig.getDevices(image, mounts, privileged); d != nil {
devices = *d
} else {
// 'nil' devices means this is not a GPU container.
Expand Down Expand Up @@ -360,7 +360,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, p
}
}

func getContainerConfig(hook HookConfig) (config containerConfig) {
func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
var h HookState
d := json.NewDecoder(os.Stdin)
if err := d.Decode(&h); err != nil {
Expand All @@ -376,7 +376,7 @@ func getContainerConfig(hook HookConfig) (config containerConfig) {

image, err := image.New(
image.WithEnv(s.Process.Env),
image.WithDisableRequire(hook.DisableRequire),
image.WithDisableRequire(hookConfig.DisableRequire),
)
if err != nil {
log.Panicln(err)
Expand All @@ -387,6 +387,6 @@ func getContainerConfig(hook HookConfig) (config containerConfig) {
Pid: h.Pid,
Rootfs: s.Root.Path,
Image: image,
Nvidia: getNvidiaConfig(&hook, image, s.Mounts, privileged),
Nvidia: hookConfig.getNvidiaConfig(image, s.Mounts, privileged),
}
}
78 changes: 46 additions & 32 deletions cmd/nvidia-container-runtime-hook/container_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/stretchr/testify/require"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
)

Expand All @@ -15,7 +16,7 @@ func TestGetNvidiaConfig(t *testing.T) {
description string
env map[string]string
privileged bool
hookConfig *HookConfig
hookConfig *hookConfig
expectedConfig *nvidiaConfig
expectedPanic bool
}{
Expand Down Expand Up @@ -394,8 +395,10 @@ func TestGetNvidiaConfig(t *testing.T) {
envNVDriverCapabilities: "all",
},
privileged: true,
hookConfig: &HookConfig{
SupportedDriverCapabilities: "video,display",
hookConfig: &hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: "video,display",
},
},
expectedConfig: &nvidiaConfig{
Devices: "all",
Expand All @@ -409,8 +412,10 @@ func TestGetNvidiaConfig(t *testing.T) {
envNVDriverCapabilities: "video,display",
},
privileged: true,
hookConfig: &HookConfig{
SupportedDriverCapabilities: "video,display,compute,utility",
hookConfig: &hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: "video,display,compute,utility",
},
},
expectedConfig: &nvidiaConfig{
Devices: "all",
Expand All @@ -423,8 +428,10 @@ func TestGetNvidiaConfig(t *testing.T) {
envNVVisibleDevices: "all",
},
privileged: true,
hookConfig: &HookConfig{
SupportedDriverCapabilities: "video,display,utility,compute",
hookConfig: &hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: "video,display,utility,compute",
},
},
expectedConfig: &nvidiaConfig{
Devices: "all",
Expand All @@ -438,9 +445,11 @@ func TestGetNvidiaConfig(t *testing.T) {
"DOCKER_SWARM_RESOURCE": "GPU1,GPU2",
},
privileged: true,
hookConfig: &HookConfig{
SwarmResource: "DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
hookConfig: &hookConfig{
Config: &config.Config{
SwarmResource: "DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
},
},
expectedConfig: &nvidiaConfig{
Devices: "GPU1,GPU2",
Expand All @@ -454,9 +463,11 @@ func TestGetNvidiaConfig(t *testing.T) {
"DOCKER_SWARM_RESOURCE": "GPU1,GPU2",
},
privileged: true,
hookConfig: &HookConfig{
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
hookConfig: &hookConfig{
Config: &config.Config{
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
},
},
expectedConfig: &nvidiaConfig{
Devices: "GPU1,GPU2",
Expand All @@ -470,14 +481,14 @@ func TestGetNvidiaConfig(t *testing.T) {
image.WithEnvMap(tc.env),
)
// Wrap the call to getNvidiaConfig() in a closure.
var config *nvidiaConfig
var cfg *nvidiaConfig
getConfig := func() {
hookConfig := tc.hookConfig
if hookConfig == nil {
defaultConfig, _ := getDefaultHookConfig()
hookConfig = &defaultConfig
hookCfg := tc.hookConfig
if hookCfg == nil {
defaultConfig, _ := config.GetDefault()
hookCfg = &hookConfig{defaultConfig}
}
config = getNvidiaConfig(hookConfig, image, nil, tc.privileged)
cfg = hookCfg.getNvidiaConfig(image, nil, tc.privileged)
}

// For any tests that are expected to panic, make sure they do.
Expand All @@ -491,18 +502,18 @@ func TestGetNvidiaConfig(t *testing.T) {

// And start comparing the test results to the expected results.
if tc.expectedConfig == nil {
require.Nil(t, config, tc.description)
require.Nil(t, cfg, tc.description)
return
}

require.NotNil(t, config, tc.description)
require.NotNil(t, cfg, tc.description)

require.Equal(t, tc.expectedConfig.Devices, config.Devices)
require.Equal(t, tc.expectedConfig.MigConfigDevices, config.MigConfigDevices)
require.Equal(t, tc.expectedConfig.MigMonitorDevices, config.MigMonitorDevices)
require.Equal(t, tc.expectedConfig.DriverCapabilities, config.DriverCapabilities)
require.Equal(t, tc.expectedConfig.Devices, cfg.Devices)
require.Equal(t, tc.expectedConfig.MigConfigDevices, cfg.MigConfigDevices)
require.Equal(t, tc.expectedConfig.MigMonitorDevices, cfg.MigMonitorDevices)
require.Equal(t, tc.expectedConfig.DriverCapabilities, cfg.DriverCapabilities)

require.ElementsMatch(t, tc.expectedConfig.Requirements, config.Requirements)
require.ElementsMatch(t, tc.expectedConfig.Requirements, cfg.Requirements)
})
}
}
Expand Down Expand Up @@ -689,10 +700,11 @@ func TestDeviceListSourcePriority(t *testing.T) {
},
),
)
hookConfig, _ := getDefaultHookConfig()
hookConfig.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
hookConfig.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
devices = getDevices(&hookConfig, image, tc.mountDevices, tc.privileged)
defaultConfig, _ := config.GetDefault()
cfg := &hookConfig{defaultConfig}
cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
devices = cfg.getDevices(image, tc.mountDevices, tc.privileged)
}

// For all other tests, just grab the devices and check the results
Expand Down Expand Up @@ -1028,8 +1040,10 @@ func TestGetDriverCapabilities(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
var capabilities string

c := HookConfig{
SupportedDriverCapabilities: tc.supportedCapabilities,
c := hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: tc.supportedCapabilities,
},
}

image, _ := image.New(
Expand Down
22 changes: 8 additions & 14 deletions cmd/nvidia-container-runtime-hook/hook_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,10 @@ const (
driverPath = "/run/nvidia/driver"
)

// HookConfig : options for the nvidia-container-runtime-hook.
type HookConfig config.Config

func getDefaultHookConfig() (HookConfig, error) {
defaultCfg, err := config.GetDefault()
if err != nil {
return HookConfig{}, err
}

return *(*HookConfig)(defaultCfg), nil
// hookConfig wraps the toolkit config.
// This allows for functions to be defined on the local type.
// The *config.Config is embedded, so its fields (for example
// SupportedDriverCapabilities and SwarmResource) are accessed directly on a
// hookConfig value.
type hookConfig struct {
*config.Config
}

// loadConfig loads the required paths for the hook config.
Expand Down Expand Up @@ -56,12 +50,12 @@ func loadConfig() (*config.Config, error) {
return config.GetDefault()
}

func getHookConfig() (*HookConfig, error) {
func getHookConfig() (*hookConfig, error) {
cfg, err := loadConfig()
if err != nil {
return nil, fmt.Errorf("failed to load config: %v", err)
}
config := (*HookConfig)(cfg)
config := &hookConfig{cfg}

allSupportedDriverCapabilities := image.SupportedDriverCapabilities
if config.SupportedDriverCapabilities == "all" {
Expand All @@ -79,7 +73,7 @@ func getHookConfig() (*HookConfig, error) {

// getConfigOption returns the toml config option associated with the
// specified struct field.
func (c HookConfig) getConfigOption(fieldName string) string {
func (c hookConfig) getConfigOption(fieldName string) string {
t := reflect.TypeOf(c)
f, ok := t.FieldByName(fieldName)
if !ok {
Expand All @@ -93,7 +87,7 @@ func (c HookConfig) getConfigOption(fieldName string) string {
}

// getSwarmResourceEnvvars returns the swarm resource envvars for the config.
func (c *HookConfig) getSwarmResourceEnvvars() []string {
func (c *hookConfig) getSwarmResourceEnvvars() []string {
if c.SwarmResource == "" {
return nil
}
Expand Down
13 changes: 8 additions & 5 deletions cmd/nvidia-container-runtime-hook/hook_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/stretchr/testify/require"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
)

Expand Down Expand Up @@ -89,10 +90,10 @@ func TestGetHookConfig(t *testing.T) {
}
}

var config HookConfig
var cfg hookConfig
getHookConfig := func() {
c, _ := getHookConfig()
config = *c
cfg = *c
}

if tc.expectedPanic {
Expand All @@ -102,7 +103,7 @@ func TestGetHookConfig(t *testing.T) {

getHookConfig()

require.EqualValues(t, tc.expectedDriverCapabilities, config.SupportedDriverCapabilities)
require.EqualValues(t, tc.expectedDriverCapabilities, cfg.SupportedDriverCapabilities)
})
}
}
Expand Down Expand Up @@ -144,8 +145,10 @@ func TestGetSwarmResourceEnvvars(t *testing.T) {

for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
c := &HookConfig{
SwarmResource: tc.value,
c := &hookConfig{
Config: &config.Config{
SwarmResource: tc.value,
},
}

envvars := c.getSwarmResourceEnvvars()
Expand Down
12 changes: 11 additions & 1 deletion cmd/nvidia-container-runtime-hook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func doPrestart() {
}
cli := hook.NVIDIAContainerCLIConfig

container := getContainerConfig(*hook)
container := hook.getContainerConfig()
nvidia := container.Nvidia
if nvidia == nil {
// Not a GPU container, nothing to do.
Expand All @@ -89,6 +89,16 @@ func doPrestart() {
rootfs := getRootfsPath(container)

args := []string{getCLIPath(cli)}

// Only include GSP firmware if explicitly enabled.
if !hook.Features.IncludeGSPFirmware.IsEnabled() {
args = append(args, "--no-gsp-firmware")
}
// Only include the nvidia-persistenced socket if it is explicitly enabled.
if !hook.Features.IncludePersistencedSocket.IsEnabled() {
args = append(args, "--no-persistenced")
}

if cli.Root != "" {
args = append(args, fmt.Sprintf("--root=%s", cli.Root))
}
Expand Down
15 changes: 15 additions & 0 deletions cmd/nvidia-ctk/cdi/generate/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ type options struct {
files cli.StringSlice
ignorePatterns cli.StringSlice
}

includeGSPFirmware bool
includePersistencedSocket bool
}

// NewCommand constructs a generate-cdi command with the specified logger
Expand Down Expand Up @@ -169,6 +172,16 @@ func (m command) build() *cli.Command {
Usage: "Specify a pattern the CSV mount specifications.",
Destination: &opts.csv.ignorePatterns,
},
&cli.BoolFlag{
Name: "include-gsp-firmware",
Usage: "Include the GSP firmware in the generated CDI specification.",
Destination: &opts.includeGSPFirmware,
},
&cli.BoolFlag{
Name: "include-persistenced-socket",
Usage: "Include the nvidia-persistenced socket in the generated CDI specification.",
Destination: &opts.includePersistencedSocket,
},
}

return &c
Expand Down Expand Up @@ -273,6 +286,8 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
nvcdi.WithLibrarySearchPaths(opts.librarySearchPaths.Value()),
nvcdi.WithCSVFiles(opts.csv.files.Value()),
nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns.Value()),
nvcdi.WithOptInFeature("include-gsp-firmware", opts.includeGSPFirmware),
nvcdi.WithOptInFeature("include-persistenced-socket", opts.includePersistencedSocket),
)
if err != nil {
return nil, fmt.Errorf("failed to create CDI library: %v", err)
Expand Down
Loading