Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add imex mode to CDI spec generation #807

Merged
merged 3 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions cmd/nvidia-ctk/cdi/generate/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,12 @@ func (m command) build() *cli.Command {
Destination: &opts.format,
},
&cli.StringFlag{
Name: "mode",
Aliases: []string{"discovery-mode"},
Usage: "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. If mode is set to 'auto' the mode will be determined based on the system configuration.",
Value: nvcdi.ModeAuto,
Name: "mode",
Aliases: []string{"discovery-mode"},
Usage: "The mode to use when discovering the available entities. " +
"One of [" + strings.Join(nvcdi.AllModes[string](), " | ") + "]. " +
"If mode is set to 'auto' the mode will be determined based on the system configuration.",
Value: string(nvcdi.ModeAuto),
Destination: &opts.mode,
},
&cli.StringFlag{
Expand Down Expand Up @@ -184,13 +186,7 @@ func (m command) validateFlags(c *cli.Context, opts *options) error {
}

opts.mode = strings.ToLower(opts.mode)
cdesiniotis marked this conversation as resolved.
Show resolved Hide resolved
switch opts.mode {
case nvcdi.ModeAuto:
case nvcdi.ModeCSV:
case nvcdi.ModeNvml:
case nvcdi.ModeWsl:
case nvcdi.ModeManagement:
default:
if !nvcdi.IsValidMode(opts.mode) {
return fmt.Errorf("invalid discovery mode: %v", opts.mode)
}

Expand Down
18 changes: 0 additions & 18 deletions pkg/nvcdi/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,6 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
)

const (
// ModeAuto configures the CDI spec generator to automatically detect the system configuration
ModeAuto = "auto"
// ModeNvml configures the CDI spec generator to use the NVML library.
ModeNvml = "nvml"
// ModeWsl configures the CDI spec generator to generate a WSL spec.
ModeWsl = "wsl"
// ModeManagement configures the CDI spec generator to generate a management spec.
ModeManagement = "management"
// ModeGds configures the CDI spec generator to generate a GDS spec.
ModeGds = "gds"
// ModeMofed configures the CDI spec generator to generate a MOFED spec.
ModeMofed = "mofed"
// ModeCSV configures the CDI spec generator to generate a spec based on the contents of CSV
// mountspec files.
ModeCSV = "csv"
)

// Interface defines the API for the nvcdi package
type Interface interface {
GetSpec() (spec.Interface, error)
Expand Down
118 changes: 118 additions & 0 deletions pkg/nvcdi/lib-imex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nvcdi

import (
"fmt"
"path/filepath"
"strconv"
"strings"

"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go"

"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
)

type imexlib nvcdilib

var _ Interface = (*imexlib)(nil)

const (
classImexChannel = "imex-channel"
)

// GetSpec should not be called for imexlib.
func (l *imexlib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("unexpected call to imexlib.GetSpec()")
}

// GetAllDeviceSpecs returns the device specs for all available devices.
func (l *imexlib) GetAllDeviceSpecs() ([]specs.Device, error) {
channelsDiscoverer := discover.NewCharDeviceDiscoverer(
l.logger,
l.devRoot,
[]string{"/dev/nvidia-caps-imex-channels/channel*"},
)

channels, err := channelsDiscoverer.Devices()
if err != nil {
return nil, err
}

var channelIDs []string
for _, channel := range channels {
channelIDs = append(channelIDs, filepath.Base(channel.Path))
}

return l.GetDeviceSpecsByID(channelIDs...)
}

// GetCommonEdits returns an empty set of edits for IMEX devices.
func (l *imexlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
return edits.FromDiscoverer(discover.None{})
}

// GetDeviceSpecsByID returns the CDI device specs for the IMEX channels specified.
func (l *imexlib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
var deviceSpecs []specs.Device
for _, id := range ids {
trimmed := strings.TrimPrefix(id, "channel")
_, err := strconv.ParseUint(trimmed, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid channel ID %v: %w", id, err)
}
path := "/dev/nvidia-caps-imex-channels/channel" + trimmed
deviceSpec := specs.Device{
Name: trimmed,
ContainerEdits: specs.ContainerEdits{
DeviceNodes: []*specs.DeviceNode{
{
Path: path,
HostPath: filepath.Join(l.devRoot, path),
},
},
},
}
deviceSpecs = append(deviceSpecs, deviceSpec)
}
return deviceSpecs, nil
}

// GetGPUDeviceEdits is unsupported for the imexlib specs
func (l *imexlib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) {
return nil, fmt.Errorf("GetGPUDeviceEdits is not supported")
}

// GetGPUDeviceSpecs is unsupported for the imexlib specs
func (l *imexlib) GetGPUDeviceSpecs(int, device.Device) ([]specs.Device, error) {
return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported")
}

// GetMIGDeviceEdits is unsupported for the imexlib specs
func (l *imexlib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) {
return nil, fmt.Errorf("GetMIGDeviceEdits is not supported")
}

// GetMIGDeviceSpecs is unsupported for the imexlib specs
func (l *imexlib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) {
return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported")
}
80 changes: 80 additions & 0 deletions pkg/nvcdi/lib-imex_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nvcdi

import (
"bytes"
"path/filepath"
"strings"
"testing"

testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"

"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
)

func TestImexMode(t *testing.T) {
t.Setenv("__NVCT_TESTING_DEVICES_ARE_FILES", "true")

logger, _ := testlog.NewNullLogger()

moduleRoot, err := test.GetModuleRoot()
require.NoError(t, err)
hostRoot := filepath.Join(moduleRoot, "testdata", "lookup", "rootfs-1")

expectedSpec := `---
cdiVersion: 0.5.0
containerEdits:
env:
- NVIDIA_VISIBLE_DEVICES=void
cdesiniotis marked this conversation as resolved.
Show resolved Hide resolved
devices:
- containerEdits:
deviceNodes:
- hostPath: {{ .hostRoot }}/dev/nvidia-caps-imex-channels/channel0
path: /dev/nvidia-caps-imex-channels/channel0
name: "0"
- containerEdits:
deviceNodes:
- hostPath: {{ .hostRoot }}/dev/nvidia-caps-imex-channels/channel1
path: /dev/nvidia-caps-imex-channels/channel1
name: "1"
- containerEdits:
deviceNodes:
- hostPath: {{ .hostRoot }}/dev/nvidia-caps-imex-channels/channel2047
path: /dev/nvidia-caps-imex-channels/channel2047
name: "2047"
kind: nvidia.com/imex-channel
`
expectedSpec = strings.ReplaceAll(expectedSpec, "{{ .hostRoot }}", hostRoot)

lib, err := New(
WithLogger(logger),
WithMode(ModeImex),
WithDriverRoot(hostRoot),
)
require.NoError(t, err)

spec, err := lib.GetSpec()
require.NoError(t, err)

var b bytes.Buffer

_, err = spec.WriteTo(&b)
require.NoError(t, err)
require.Equal(t, expectedSpec, b.String())
}
2 changes: 1 addition & 1 deletion pkg/nvcdi/lib-nvml.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ var _ Interface = (*nvmllib)(nil)

// GetSpec should not be called for nvmllib
func (l *nvmllib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("Unexpected call to nvmllib.GetSpec()")
return nil, fmt.Errorf("unexpected call to nvmllib.GetSpec()")
}

// GetAllDeviceSpecs returns the device specs for all available devices.
Expand Down
29 changes: 6 additions & 23 deletions pkg/nvcdi/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type nvcdilib struct {
logger logger.Interface
nvmllib nvml.Interface
nvsandboxutilslib nvsandboxutils.Interface
mode string
mode Mode
devicelib device.Interface
deviceNamers DeviceNamers
driverRoot string
Expand Down Expand Up @@ -161,6 +161,11 @@ func New(opts ...Option) (Interface, error) {
l.class = "mofed"
}
lib = (*mofedlib)(l)
case ModeImex:
if l.class == "" {
l.class = classImexChannel
}
lib = (*imexlib)(l)
default:
return nil, fmt.Errorf("unknown mode %q", l.mode)
}
Expand Down Expand Up @@ -206,28 +211,6 @@ func (m *wrapper) GetCommonEdits() (*cdi.ContainerEdits, error) {
return edits, nil
}

// resolveMode resolves the mode for CDI spec generation based on the current system.
func (l *nvcdilib) resolveMode() (rmode string) {
if l.mode != ModeAuto {
return l.mode
}
defer func() {
l.logger.Infof("Auto-detected mode as '%v'", rmode)
}()

platform := l.infolib.ResolvePlatform()
switch platform {
case info.PlatformNVML:
return ModeNvml
case info.PlatformTegra:
return ModeCSV
case info.PlatformWSL:
return ModeWsl
}
l.logger.Warningf("Unsupported platform detected: %v; assuming %v", platform, ModeNvml)
return ModeNvml
}

// getCudaVersion returns the CUDA version of the current system.
func (l *nvcdilib) getCudaVersion() (string, error) {
version, err := l.getCudaVersionNvsandboxutils()
Expand Down
Loading