Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EBPF-601] gpu: add function to retrieve visible devices for a process #30510

Merged
merged 14 commits into from
Nov 4, 2024
126 changes: 126 additions & 0 deletions pkg/gpu/cuda/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2024-present Datadog, Inc.

//go:build linux_bpf

package cuda

import (
"fmt"
"strconv"
"strings"

"github.com/NVIDIA/go-nvml/pkg/nvml"

"github.com/DataDog/datadog-agent/pkg/util/kernel"
)

// cudaVisibleDevicesEnvVar is the name of the environment variable that is
// looked up in the target process to decide which GPU devices it can see.
const cudaVisibleDevicesEnvVar = "CUDA_VISIBLE_DEVICES"
val06 marked this conversation as resolved.
Show resolved Hide resolved

// GetVisibleDevicesForProcess modifies the list of GPU devices according to the
val06 marked this conversation as resolved.
Show resolved Hide resolved
// value of the CUDA_VISIBLE_DEVICES environment variable for the specified
// process. Reference:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars.
//
// As a summary, the CUDA_VISIBLE_DEVICES environment variable should be a comma
// separated list of GPU identifiers. These can be either the index of the GPU
// (0, 1, 2) or the UUID of the GPU (GPU-<UUID>, or
// MIG-GPU-<UUID>/<instance-index>/<compute-index for multi-instance GPUs). UUID
// identifiers do not need to be the full UUID, it is enough with specifying the
// prefix that uniquely identifies the GPU.
//
// Invalid device indexes are ignored, and anything that comes after that is
// invisible, following the spec: "If one of the indices is invalid, only the
// devices whose index precedes the invalid index are visible to CUDA
// applications." If an invalid index is found, an error is returned together
// with the list of valid devices found up until that point.
func GetVisibleDevicesForProcess(systemDevices []nvml.Device, pid int, procfs string) ([]nvml.Device, error) {
cudaVisibleDevices, err := kernel.GetProcessEnvVariable(pid, procfs, cudaVisibleDevicesEnvVar)
if err != nil {
return nil, fmt.Errorf("cannot get env var %s for process %d: %w", cudaVisibleDevicesEnvVar, pid, err)
}

return getVisibleDevices(systemDevices, cudaVisibleDevices)
}

// getVisibleDevices processes the list of GPU devices according to the value of
// the CUDA_VISIBLE_DEVICES environment variable
func getVisibleDevices(systemDevices []nvml.Device, cudaVisibleDevices string) ([]nvml.Device, error) {
if cudaVisibleDevices == "" {
return systemDevices, nil
}

var filteredDevices []nvml.Device
visibleDevicesList := strings.Split(cudaVisibleDevices, ",")

for _, visibleDevice := range visibleDevicesList {
val06 marked this conversation as resolved.
Show resolved Hide resolved
var matchingDevice nvml.Device
var err error
switch {
case strings.HasPrefix(visibleDevice, "GPU-"):
matchingDevice, err = getDeviceWithMatchingUUIDPrefix(systemDevices, visibleDevice)
if err != nil {
return filteredDevices, err
}
case strings.HasPrefix(visibleDevice, "MIG-GPU"):
// MIG (Multi Instance GPUs) devices require extra parsing and data
// about the MIG instance assignment, which is not supported yet.
return filteredDevices, fmt.Errorf("MIG devices are not supported")
val06 marked this conversation as resolved.
Show resolved Hide resolved
default:
matchingDevice, err = getDeviceWithIndex(systemDevices, visibleDevice)
if err != nil {
return filteredDevices, err
}
}

filteredDevices = append(filteredDevices, matchingDevice)
}

return filteredDevices, nil
}

// getDeviceWithMatchingUUIDPrefix returns the first device with a UUID that
// matches the given prefix. If there are multiple devices with the same prefix
// or the device is not found, an error is returned.
func getDeviceWithMatchingUUIDPrefix(systemDevices []nvml.Device, uuidPrefix string) (nvml.Device, error) {
var matchingDevice nvml.Device
var matchingDeviceUUID string

for _, device := range systemDevices {
uuid, ret := device.GetUUID()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("cannot get UUID for device: %s", nvml.ErrorString(ret))
}

if strings.HasPrefix(uuid, uuidPrefix) {
val06 marked this conversation as resolved.
Show resolved Hide resolved
if matchingDevice != nil {
return nil, fmt.Errorf("non-unique UUID prefix %s, found UUIDs %s and %s", uuidPrefix, matchingDeviceUUID, uuid)
}
matchingDevice = device
matchingDeviceUUID = uuid
}
}

if matchingDevice == nil {
return nil, fmt.Errorf("device with UUID prefix %s not found", uuidPrefix)
}

return matchingDevice, nil
}

// getDeviceWithIndex resolves visibleDevice as a decimal index into
// systemDevices and returns the device at that position. An error is returned
// when visibleDevice is not a number or when the index falls outside of
// systemDevices.
func getDeviceWithIndex(systemDevices []nvml.Device, visibleDevice string) (nvml.Device, error) {
	position, parseErr := strconv.Atoi(visibleDevice)
	if parseErr != nil {
		return nil, fmt.Errorf("invalid device index %s: %w", visibleDevice, parseErr)
	}

	lastIndex := len(systemDevices) - 1
	if position < 0 || position > lastIndex {
		return nil, fmt.Errorf("device index %d is out of range [0, %d]", position, lastIndex)
	}

	return systemDevices[position], nil
}
115 changes: 115 additions & 0 deletions pkg/gpu/cuda/env_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2024-present Datadog, Inc.

//go:build linux_bpf

package cuda

import (
"testing"

"github.com/NVIDIA/go-nvml/pkg/nvml"
nvmlmock "github.com/NVIDIA/go-nvml/pkg/nvml/mock"
"github.com/stretchr/testify/require"
)

func TestGetVisibleDevices(t *testing.T) {
commonPrefix := "GPU-89"
uuid1 := commonPrefix + "32f937-d72c-4106-c12f-20bd9faed9f6"
uuid2 := commonPrefix + "02f078-a8da-4036-a78f-4032bbddeaf2"

dev1 := &nvmlmock.Device{
GetUUIDFunc: func() (string, nvml.Return) {
return uuid1, nvml.SUCCESS
},
}

dev2 := &nvmlmock.Device{
GetUUIDFunc: func() (string, nvml.Return) {
return uuid2, nvml.SUCCESS
},
}

devList := []nvml.Device{dev1, dev2}
cases := []struct {
name string
visibleDevices string
expectedDevices []nvml.Device
expectsError bool
}{
{
name: "no visible devices",
visibleDevices: "",
expectedDevices: devList,
expectsError: false,
},
{
name: "UUIDs",
visibleDevices: uuid1,
expectedDevices: []nvml.Device{devList[0]},
expectsError: false,
},
{
name: "Index",
visibleDevices: "1",
expectedDevices: []nvml.Device{devList[1]},
expectsError: false,
},
{
name: "IndexOutOfRange",
visibleDevices: "2",
expectedDevices: nil,
expectsError: true,
},
{
name: "InvalidIndex",
visibleDevices: "a",
expectedDevices: nil,
expectsError: true,
},
{
name: "MIGDevices",
visibleDevices: "MIG-GPU-1",
expectedDevices: nil,
expectsError: true,
},
{name: "UnorderedIndexes",
visibleDevices: "1,0",
expectedDevices: []nvml.Device{devList[1], devList[0]},
expectsError: false,
},
{
name: "MixedIndexesAndUUIDs",
val06 marked this conversation as resolved.
Show resolved Hide resolved
visibleDevices: "0," + uuid2,
expectedDevices: []nvml.Device{devList[0], devList[1]},
expectsError: false,
},
{
name: "InvalidIndexInMiddle",
visibleDevices: "0,235,1",
expectedDevices: []nvml.Device{devList[0]},
expectsError: true,
},
{
name: "SharedPrefix",
visibleDevices: commonPrefix,
expectedDevices: nil,
expectsError: true,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
devices, err := getVisibleDevices(devList, tc.visibleDevices)
if tc.expectsError {
require.Error(t, err)
} else {
require.NoError(t, err)
}

require.Equal(t, tc.expectedDevices, devices)
})
}
}
58 changes: 58 additions & 0 deletions pkg/util/kernel/proc.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@
package kernel

import (
"bufio"
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"strconv"
"strings"
)

// AllPidsProcs will return all pids under procRoot
Expand Down Expand Up @@ -49,3 +55,55 @@ func WithAllProcs(procRoot string, fn func(int) error) error {
}
return nil
}

// scanNullStrings is a split function for a bufio.Scanner that yields each
// null-terminated string as a token, without the terminating null byte. It
// receives the data the scanner has not turned into tokens yet, and whether
// the scanner has reached EOF.
//
// It returns the number of bytes the scanner should advance, the token that
// was found (nil when more data is needed or nothing is left) and an error,
// which is always nil.
func scanNullStrings(data []byte, atEOF bool) (advance int, token []byte, err error) {
	const terminator byte = 0

	end := bytes.IndexByte(data, terminator)
	switch {
	case end >= 0:
		// A complete null-terminated string is available: consume it along
		// with its terminator.
		return end + 1, data[:end], nil
	case atEOF && len(data) > 0:
		// No more input will arrive, so whatever is left is a final string
		// that lacks its terminator.
		return len(data), data, nil
	default:
		// Either everything has been consumed or the string is still
		// incomplete: ask the scanner for more data.
		return 0, nil, nil
	}
}

// getEnvVariableFromBuffer scans reader, which must contain a sequence of
// null-separated NAME=VALUE entries, and returns the value of the first entry
// whose name is exactly envVar. Entries without a '=' are skipped. An empty
// string is returned when the variable is not found, which cannot be told
// apart from a variable that is set to the empty string.
//
// The input is consumed entry by entry instead of being loaded as a whole.
func getEnvVariableFromBuffer(reader io.Reader, envVar string) string {
	const (
		// initialBufferSize is the starting capacity of the scan buffer; it
		// only grows when an entry does not fit.
		initialBufferSize = 4 * 1024
		// maxEntrySize bounds the size of a single entry. By default a
		// bufio.Scanner gives up on tokens larger than
		// bufio.MaxScanTokenSize (64 KiB), which would silently hide every
		// variable located after one oversized entry.
		// NOTE(review): 2 MiB is meant to cover Linux' MAX_ARG_STRLEN
		// (32 pages) even with 64 KiB pages - confirm this is the desired cap.
		maxEntrySize = 2 * 1024 * 1024
	)

	scanner := bufio.NewScanner(reader)
	scanner.Buffer(make([]byte, 0, initialBufferSize), maxEntrySize)
	scanner.Split(scanNullStrings)

	for scanner.Scan() {
		entry := scanner.Text()

		// Split on the first '=' only, as the value itself may contain '='.
		sep := strings.IndexByte(entry, '=')
		if sep < 0 {
			continue
		}

		if entry[:sep] == envVar {
			return entry[sep+1:]
		}
	}

	// The signature has no way of reporting scanner.Err(), so a read failure
	// or an entry above maxEntrySize is reported as "not found".
	return ""
}

// GetProcessEnvVariable retrieves the given environment variable for the specified process ID, without
// loading the entire environment into memory. Will return an empty string if the variable is not found.
func GetProcessEnvVariable(pid int, procRoot string, envVar string) (string, error) {
val06 marked this conversation as resolved.
Show resolved Hide resolved
envPath := filepath.Join(procRoot, strconv.Itoa(pid), "environ")
envFile, err := os.Open(envPath)
if err != nil {
return "", fmt.Errorf("cannot open %s: %w", envPath, err)
}
defer envFile.Close()

return getEnvVariableFromBuffer(envFile, envVar), nil
}
44 changes: 44 additions & 0 deletions pkg/util/kernel/proc_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package kernel

import (
"bytes"
"os"
"runtime"
"strconv"
Expand Down Expand Up @@ -77,3 +78,46 @@ func BenchmarkAllPidsProcs(b *testing.B) {
}
runtime.KeepAlive(pids)
}

// TestGetEnvVariableFromBuffer checks the lookup of a variable in a
// null-separated environment buffer: missing variable, present variable,
// empty value, and a variable whose name is a prefix of another one.
func TestGetEnvVariableFromBuffer(t *testing.T) {
	tests := []struct {
		name    string
		environ string
		lookup  string
		want    string
	}{
		{
			name:    "NonExistent",
			environ: "PATH=/usr/bin\x00HOME=/home/user\x00",
			lookup:  "NONEXISTENT",
			want:    "",
		},
		{
			name:    "Exists",
			environ: "PATH=/usr/bin\x00MY_VAR=myvar\x00HOME=/home/user\x00",
			lookup:  "MY_VAR",
			want:    "myvar",
		},
		{
			name:    "Empty",
			environ: "PATH=/usr/bin\x00MY_VAR=\x00HOME=/home/user\x00",
			lookup:  "MY_VAR",
			want:    "",
		},
		{
			name:    "PrefixVarNotSelected",
			environ: "PATH=/usr/bin\x00MY_VAR_BUT_NOT_THIS=nope\x00MY_VAR=myvar\x00HOME=/home/user\x00",
			lookup:  "MY_VAR",
			want:    "myvar",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			input := bytes.NewBufferString(tt.environ)

			got := getEnvVariableFromBuffer(input, tt.lookup)
			if got == tt.want {
				return
			}

			t.Fatalf("Expected %s, got %s", tt.want, got)
		})
	}
}
Loading