Skip to content

Commit

Permalink
[EBPF-601] gpu: add function to retrieve visible devices for a process (
Browse files Browse the repository at this point in the history
  • Loading branch information
gjulianm authored Nov 4, 2024
1 parent 1b31902 commit 3580727
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 0 deletions.
126 changes: 126 additions & 0 deletions pkg/gpu/cuda/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2024-present Datadog, Inc.

//go:build linux_bpf

package cuda

import (
"fmt"
"strconv"
"strings"

"github.com/NVIDIA/go-nvml/pkg/nvml"

"github.com/DataDog/datadog-agent/pkg/util/kernel"
)

const cudaVisibleDevicesEnvVar = "CUDA_VISIBLE_DEVICES"

// GetVisibleDevicesForProcess modifies the list of GPU devices according to the
// value of the CUDA_VISIBLE_DEVICES environment variable for the specified
// process. Reference:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars.
//
// As a summary, the CUDA_VISIBLE_DEVICES environment variable should be a comma
// separated list of GPU identifiers. These can be either the index of the GPU
// (0, 1, 2) or the UUID of the GPU (GPU-<UUID>, or
// MIG-GPU-<UUID>/<instance-index>/<compute-index for multi-instance GPUs). UUID
// identifiers do not need to be the full UUID, it is enough with specifying the
// prefix that uniquely identifies the GPU.
//
// Invalid device indexes are ignored, and anything that comes after that is
// invisible, following the spec: "If one of the indices is invalid, only the
// devices whose index precedes the invalid index are visible to CUDA
// applications." If an invalid index is found, an error is returned together
// with the list of valid devices found up until that point.
func GetVisibleDevicesForProcess(systemDevices []nvml.Device, pid int, procfs string) ([]nvml.Device, error) {
cudaVisibleDevices, err := kernel.GetProcessEnvVariable(pid, procfs, cudaVisibleDevicesEnvVar)
if err != nil {
return nil, fmt.Errorf("cannot get env var %s for process %d: %w", cudaVisibleDevicesEnvVar, pid, err)
}

return getVisibleDevices(systemDevices, cudaVisibleDevices)
}

// getVisibleDevices processes the list of GPU devices according to the value of
// the CUDA_VISIBLE_DEVICES environment variable
func getVisibleDevices(systemDevices []nvml.Device, cudaVisibleDevices string) ([]nvml.Device, error) {
if cudaVisibleDevices == "" {
return systemDevices, nil
}

var filteredDevices []nvml.Device
visibleDevicesList := strings.Split(cudaVisibleDevices, ",")

for _, visibleDevice := range visibleDevicesList {
var matchingDevice nvml.Device
var err error
switch {
case strings.HasPrefix(visibleDevice, "GPU-"):
matchingDevice, err = getDeviceWithMatchingUUIDPrefix(systemDevices, visibleDevice)
if err != nil {
return filteredDevices, err
}
case strings.HasPrefix(visibleDevice, "MIG-GPU"):
// MIG (Multi Instance GPUs) devices require extra parsing and data
// about the MIG instance assignment, which is not supported yet.
return filteredDevices, fmt.Errorf("MIG devices are not supported")
default:
matchingDevice, err = getDeviceWithIndex(systemDevices, visibleDevice)
if err != nil {
return filteredDevices, err
}
}

filteredDevices = append(filteredDevices, matchingDevice)
}

return filteredDevices, nil
}

// getDeviceWithMatchingUUIDPrefix returns the first device with a UUID that
// matches the given prefix. If there are multiple devices with the same prefix
// or the device is not found, an error is returned.
func getDeviceWithMatchingUUIDPrefix(systemDevices []nvml.Device, uuidPrefix string) (nvml.Device, error) {
var matchingDevice nvml.Device
var matchingDeviceUUID string

for _, device := range systemDevices {
uuid, ret := device.GetUUID()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("cannot get UUID for device: %s", nvml.ErrorString(ret))
}

if strings.HasPrefix(uuid, uuidPrefix) {
if matchingDevice != nil {
return nil, fmt.Errorf("non-unique UUID prefix %s, found UUIDs %s and %s", uuidPrefix, matchingDeviceUUID, uuid)
}
matchingDevice = device
matchingDeviceUUID = uuid
}
}

if matchingDevice == nil {
return nil, fmt.Errorf("device with UUID prefix %s not found", uuidPrefix)
}

return matchingDevice, nil
}

// getDeviceWithIndex returns the device with the given index. If the index is
// out of range or the index is not a number, an error is returned.
func getDeviceWithIndex(systemDevices []nvml.Device, visibleDevice string) (nvml.Device, error) {
idx, err := strconv.Atoi(visibleDevice)
if err != nil {
return nil, fmt.Errorf("invalid device index %s: %w", visibleDevice, err)
}

if idx < 0 || idx >= len(systemDevices) {
return nil, fmt.Errorf("device index %d is out of range [0, %d]", idx, len(systemDevices)-1)
}

return systemDevices[idx], nil
}
115 changes: 115 additions & 0 deletions pkg/gpu/cuda/env_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2024-present Datadog, Inc.

//go:build linux_bpf

package cuda

import (
"testing"

"github.com/NVIDIA/go-nvml/pkg/nvml"
nvmlmock "github.com/NVIDIA/go-nvml/pkg/nvml/mock"
"github.com/stretchr/testify/require"
)

func TestGetVisibleDevices(t *testing.T) {
commonPrefix := "GPU-89"
uuid1 := commonPrefix + "32f937-d72c-4106-c12f-20bd9faed9f6"
uuid2 := commonPrefix + "02f078-a8da-4036-a78f-4032bbddeaf2"

dev1 := &nvmlmock.Device{
GetUUIDFunc: func() (string, nvml.Return) {
return uuid1, nvml.SUCCESS
},
}

dev2 := &nvmlmock.Device{
GetUUIDFunc: func() (string, nvml.Return) {
return uuid2, nvml.SUCCESS
},
}

devList := []nvml.Device{dev1, dev2}
cases := []struct {
name string
visibleDevices string
expectedDevices []nvml.Device
expectsError bool
}{
{
name: "no visible devices",
visibleDevices: "",
expectedDevices: devList,
expectsError: false,
},
{
name: "UUIDs",
visibleDevices: uuid1,
expectedDevices: []nvml.Device{devList[0]},
expectsError: false,
},
{
name: "Index",
visibleDevices: "1",
expectedDevices: []nvml.Device{devList[1]},
expectsError: false,
},
{
name: "IndexOutOfRange",
visibleDevices: "2",
expectedDevices: nil,
expectsError: true,
},
{
name: "InvalidIndex",
visibleDevices: "a",
expectedDevices: nil,
expectsError: true,
},
{
name: "MIGDevices",
visibleDevices: "MIG-GPU-1",
expectedDevices: nil,
expectsError: true,
},
{name: "UnorderedIndexes",
visibleDevices: "1,0",
expectedDevices: []nvml.Device{devList[1], devList[0]},
expectsError: false,
},
{
name: "MixedIndexesAndUUIDs",
visibleDevices: "0," + uuid2,
expectedDevices: []nvml.Device{devList[0], devList[1]},
expectsError: false,
},
{
name: "InvalidIndexInMiddle",
visibleDevices: "0,235,1",
expectedDevices: []nvml.Device{devList[0]},
expectsError: true,
},
{
name: "SharedPrefix",
visibleDevices: commonPrefix,
expectedDevices: nil,
expectsError: true,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
devices, err := getVisibleDevices(devList, tc.visibleDevices)
if tc.expectsError {
require.Error(t, err)
} else {
require.NoError(t, err)
}

require.Equal(t, tc.expectedDevices, devices)
})
}
}
58 changes: 58 additions & 0 deletions pkg/util/kernel/proc.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@
package kernel

import (
"bufio"
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"strconv"
"strings"
)

// AllPidsProcs will return all pids under procRoot
Expand Down Expand Up @@ -49,3 +55,55 @@ func WithAllProcs(procRoot string, fn func(int) error) error {
}
return nil
}

// scanNullString is a SplitFunc for a Scanner that returns each null-terminated
// string as a token. Receives the data from the scanner that's yet to be
// processed into tokens, and whether the scanner has reached EOF.
//
// Returns the number of bytes to advance the scanner, the token that was
// detected and an error in case of failure
func scanNullStrings(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\x00'); i >= 0 {
// We have a full null-terminated line.
return i + 1, data[0:i], nil
}
// If we're at EOF, we have a final, non-terminated line. Return it.
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}

func getEnvVariableFromBuffer(reader io.Reader, envVar string) string {
scanner := bufio.NewScanner(reader)
scanner.Split(scanNullStrings)
for scanner.Scan() {
parts := strings.SplitN(scanner.Text(), "=", 2)
if len(parts) != 2 {
continue
}

if parts[0] == envVar {
return parts[1]
}
}

return ""
}

// GetProcessEnvVariable retrieves the given environment variable for the specified process ID, without
// loading the entire environment into memory. Will return an empty string if the variable is not found.
func GetProcessEnvVariable(pid int, procRoot string, envVar string) (string, error) {
envPath := filepath.Join(procRoot, strconv.Itoa(pid), "environ")
envFile, err := os.Open(envPath)
if err != nil {
return "", fmt.Errorf("cannot open %s: %w", envPath, err)
}
defer envFile.Close()

return getEnvVariableFromBuffer(envFile, envVar), nil
}
44 changes: 44 additions & 0 deletions pkg/util/kernel/proc_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package kernel

import (
"bytes"
"os"
"runtime"
"strconv"
Expand Down Expand Up @@ -77,3 +78,46 @@ func BenchmarkAllPidsProcs(b *testing.B) {
}
runtime.KeepAlive(pids)
}

func TestGetEnvVariableFromBuffer(t *testing.T) {
cases := []struct {
name string
contents string
envVar string
expected string
}{
{
name: "NonExistent",
contents: "PATH=/usr/bin\x00HOME=/home/user\x00",
envVar: "NONEXISTENT",
expected: "",
},
{
name: "Exists",
contents: "PATH=/usr/bin\x00MY_VAR=myvar\x00HOME=/home/user\x00",
envVar: "MY_VAR",
expected: "myvar",
},
{
name: "Empty",
contents: "PATH=/usr/bin\x00MY_VAR=\x00HOME=/home/user\x00",
envVar: "MY_VAR",
expected: "",
},
{
name: "PrefixVarNotSelected",
contents: "PATH=/usr/bin\x00MY_VAR_BUT_NOT_THIS=nope\x00MY_VAR=myvar\x00HOME=/home/user\x00",
envVar: "MY_VAR",
expected: "myvar",
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
actual := getEnvVariableFromBuffer(bytes.NewBufferString(tc.contents), tc.envVar)
if actual != tc.expected {
t.Fatalf("Expected %s, got %s", tc.expected, actual)
}
})
}
}

0 comments on commit 3580727

Please sign in to comment.