Skip to content

Commit

Permalink
Add custom filters for processing accelerators
Browse files Browse the repository at this point in the history
  • Loading branch information
mlorenzofr committed Oct 28, 2024
1 parent c1bfc6e commit 106b5a7
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 13 deletions.
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,16 @@ Each `ghw.AcceleratorDevice` struct contains the following fields:
describing the processing accelerator card. This may be `nil` if no PCI device
information could be determined for the card.

#### filters
The `ghw.Accelerator()` function accepts a slice of filters, of type string, as parameter
in format `[<vendor>]:[<device>][:<class>]`, (same is the _lspci_ command).

Some filter examples:
* `::0302`. Select 3D controller cards.
* `10de::0302`. Select Nvidia (`10de`) 3D controller cards (`0302`).
* `1da3:1060:1200`. Select Habana Labs (`1da3`) Gaudi3 (`1060`) processing accelerator cards (`1200`).
* `1002::`. Select AMD ATI hardware.

```go
package main

Expand All @@ -976,7 +986,11 @@ import (
)

func main() {
accel, err := ghw.Accelerator()
filter := make([]string, 0)
// example of a filter to detect 3D controllers
// filter = append(filter, "::0302")

accel, err := ghw.Accelerator(filter)
if err != nil {
fmt.Printf("Error getting processing accelerator info: %v", err)
}
Expand Down
4 changes: 3 additions & 1 deletion cmd/ghwc/commands/accelerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ var acceleratorCmd = &cobra.Command{

// showAccelerator show processing accelerators information for the host system.
func showAccelerator(cmd *cobra.Command, args []string) error {
accel, err := ghw.Accelerator()
filter := make([]string, 0)

accel, err := ghw.Accelerator(filter)
if err != nil {
return errors.Wrap(err, "error getting Accelerator info")
}
Expand Down
2 changes: 1 addition & 1 deletion host.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func Host(opts ...*WithOption) (*HostInfo, error) {
if err != nil {
return nil, err
}
acceleratorInfo, err := accelerator.New(opts...)
acceleratorInfo, err := accelerator.New([]string{}, opts...)
if err != nil {
return nil, err
}
Expand Down
12 changes: 8 additions & 4 deletions pkg/accelerator/accelerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,19 @@ func (dev *AcceleratorDevice) String() string {
}

type Info struct {
ctx *context.Context
Devices []*AcceleratorDevice `json:"devices"`
ctx *context.Context
Devices []*AcceleratorDevice `json:"devices"`
DiscoveryFilters []string
}

// New returns a pointer to an Info struct that contains information about the
// accelerator devices on the host system
func New(opts ...*option.Option) (*Info, error) {
func New(filter []string, opts ...*option.Option) (*Info, error) {
ctx := context.New(opts...)
info := &Info{ctx: ctx}
info := &Info{
ctx: ctx,
DiscoveryFilters: filter,
}

if err := ctx.Do(info.load); err != nil {
return nil, err
Expand Down
60 changes: 55 additions & 5 deletions pkg/accelerator/accelerator_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
package accelerator

import (
"github.com/samber/lo"
"fmt"
"strings"

"github.com/jaypipes/ghw/pkg/context"
"github.com/jaypipes/ghw/pkg/pci"
"github.com/samber/lo"
)

// PCI IDs list available at https://admin.pci-ids.ucw.cz/read/PD
Expand Down Expand Up @@ -60,13 +62,61 @@ func (i *Info) load() error {
if !isAccelerator(device) {
continue
}
accelDev := &AcceleratorDevice{
Address: device.Address,
PCIDevice: device,
for _, filter := range i.DiscoveryFilters {
if validate(filter, device) {
accelDev := &AcceleratorDevice{
Address: device.Address,
PCIDevice: device,
}
accelDevices = append(accelDevices, accelDev)
break
}
}
accelDevices = append(accelDevices, accelDev)
}

i.Devices = accelDevices
return nil
}

// validate checks if a given PCI device matches the provided filter string.
//
// The filter string is expected to be in the format "VendorID:ProductID:Class+Subclass".
// Each part of the filter (VendorID, ProductID, Class+Subclass) is optional and can be
// left empty, in which case the corresponding attribute is ignored during validation.
//
// Parameters:
// - filter: A string in the form "VendorID:ProductID:Class+Subclass", where
// any part of the string may be empty to represent a wildcard match.
// - device: A pointer to a `pci.Device` structure.
//
// Returns:
// - true: If the device matches the filter criteria (wildcards are supported).
// - false: If the device does not match the filter criteria.
//
// Matching criteria:
// - VendorID must match `device.Vendor.ID` if provided.
// - ProductID must match `device.Product.ID` if provided.
// - Class and Subclass must match the concatenated result of `device.Class.ID` and `device.Subclass.ID` if provided.
//
// Example:
//
// filter := "8086:1234:1200"
// device := pci.Device{Vendor: Vendor{ID: "8086"}, Product: Product{ID: "1234"}, Class: Class{ID: "12"}, Subclass: Subclass{ID: "00"}}
// isValid := validate(filter, &device) // returns true
//
// filter := "8086::1200" // Wildcard for ProductID
// isValid := validate(filter, &device) // returns true
//
// filter := "::1200" // Wildcard for ProductID and VendorID
// isValid := validate(filter, &device) // returns true
func validate(filter string, device *pci.Device) bool {
ids := strings.Split(filter, ":")

if (ids[0] == "" || ids[0] == device.Vendor.ID) &&
(len(ids) < 2 || (ids[1] == "" || ids[1] == device.Product.ID)) &&
(len(ids) < 3 || (ids[2] == "" || ids[2] == fmt.Sprintf("%s%s", device.Class.ID, device.Subclass.ID))) {
return true
}

return false
}
2 changes: 1 addition & 1 deletion pkg/accelerator/accelerator_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func testScenario(t *testing.T, filename string, expectedDevs int) {
_ = snapshot.Cleanup(tmpRoot)
}()

info, err := accelerator.New(option.WithChroot(tmpRoot))
info, err := accelerator.New([]string{}, option.WithChroot(tmpRoot))
if err != nil {
t.Fatalf("Expected nil err, but got %v", err)
}
Expand Down

0 comments on commit 106b5a7

Please sign in to comment.