feat(runner): DMA-BUF support (for NVIDIA) (#181)

Also includes other improvements and hopefully reducing LOC with some
cleanup.

---------

Co-authored-by: DatCaptainHorse <DatCaptainHorse@users.noreply.github.com>
This commit is contained in:
Kristian Ollikainen
2025-02-11 12:03:03 +02:00
committed by GitHub
parent 060718d8b0
commit 7de6e243ed
7 changed files with 429 additions and 534 deletions

View File

@@ -1,6 +1,5 @@
use regex::Regex;
use std::fs;
use std::path::Path;
use std::process::Command;
use std::str;
@@ -49,9 +48,9 @@ impl GPUInfo {
fn get_gpu_vendor(vendor_id: &str) -> GPUVendor {
match vendor_id {
"8086" => GPUVendor::INTEL, // Intel
"10de" => GPUVendor::NVIDIA, // NVIDIA
"1002" => GPUVendor::AMD, // AMD/ATI
"8086" => GPUVendor::INTEL,
"10de" => GPUVendor::NVIDIA,
"1002" => GPUVendor::AMD,
_ => GPUVendor::UNKNOWN,
}
}
@@ -60,174 +59,105 @@ fn get_gpu_vendor(vendor_id: &str) -> GPUVendor {
/// # Returns
/// * `Vec<GPUInfo>` - A vector containing information about each GPU.
pub fn get_gpus() -> Vec<GPUInfo> {
let mut gpus = Vec::new();
// Run lspci to get PCI devices related to GPUs
let lspci_output = Command::new("lspci")
.arg("-mmnn") // Get machine-readable output with IDs
let output = Command::new("lspci")
.args(["-mm", "-nn"])
.output()
.expect("Failed to execute lspci");
let output = str::from_utf8(&lspci_output.stdout).unwrap();
// Filter lines that mention VGA or 3D controller
for line in output.lines() {
if line.to_lowercase().contains("vga compatible controller")
|| line.to_lowercase().contains("3d controller")
|| line.to_lowercase().contains("display controller")
{
if let Some((pci_addr, vendor_id, device_name)) = parse_pci_device(line) {
// Run udevadm to get the device path
if let Some((card_path, render_path)) = get_dri_device_path(&pci_addr) {
let vendor = get_gpu_vendor(&vendor_id);
gpus.push(GPUInfo {
vendor,
card_path,
render_path,
device_name,
});
}
}
}
}
gpus
}
/// Parses a line from the lspci output to extract PCI device information.
/// # Arguments
/// * `line` - A string slice that holds a line from the 'lspci -mmnn' output.
/// # Returns
/// * `Option<(String, String, String)>` - A tuple containing the PCI address, vendor ID, and device name if parsing is successful.
fn parse_pci_device(line: &str) -> Option<(String, String, String)> {
// Define a regex pattern to match PCI device lines
let re = Regex::new(r#"(?P<pci_addr>[0-9a-fA-F]{1,2}:[0-9a-fA-F]{2}\.[0-9]) "[^"]+" "(?P<vendor_name>[^"]+)" "(?P<device_name>[^"]+)"#).unwrap();
// Collect all matched groups
let parts: Vec<(String, String, String)> = re
.captures_iter(line)
.filter_map(|cap| {
Some((
cap["pci_addr"].to_string(),
cap["vendor_name"].to_string(),
cap["device_name"].to_string(),
))
str::from_utf8(&output.stdout)
.unwrap()
.lines()
.filter_map(|line| parse_pci_device(line))
.filter(|(class_id, _, _, _)| matches!(class_id.as_str(), "0300" | "0302" | "0380"))
.filter_map(|(_, vendor_id, device_name, pci_addr)| {
get_dri_device_path(&pci_addr)
.map(|(card, render)| (vendor_id, card, render, device_name))
})
.collect();
if let Some((pci_addr, vendor_name, device_name)) = parts.first() {
// Create a mutable copy of the device name to modify
let mut device_name = device_name.clone();
// If more than 1 square-bracketed item is found, remove last one, otherwise remove all
if let Some(start) = device_name.rfind('[') {
if let Some(_) = device_name.rfind(']') {
device_name = device_name[..start].trim().to_string();
}
}
// Extract vendor ID from vendor name (e.g., "Intel Corporation [8086]" or "Advanced Micro Devices, Inc. [AMD/ATI] [1002]")
let vendor_id = vendor_name
.split_whitespace()
.last()
.unwrap_or_default()
.trim_matches(|c: char| !c.is_ascii_hexdigit())
.to_string();
return Some((pci_addr.clone(), vendor_id, device_name));
}
None
.map(|(vid, card_path, render_path, device_name)| GPUInfo {
vendor: get_gpu_vendor(&vid),
card_path,
render_path,
device_name,
})
.collect()
}
fn parse_pci_device(line: &str) -> Option<(String, String, String, String)> {
let re = Regex::new(
r#"^(?P<pci_addr>\S+)\s+"[^\[]*\[(?P<class_id>[0-9a-f]{4})\].*?"\s+"[^\[]*\[(?P<vendor_id>[0-9a-f]{4})\].*?"\s+"(?P<device_name>[^"]+?)""#,
).unwrap();
let caps = re.captures(line)?;
// Clean device name by removing only the trailing device ID
let device_name = caps.name("device_name")?.as_str().trim();
let clean_re = Regex::new(r"\s+\[[0-9a-f]{4}\]$").unwrap();
let cleaned_name = clean_re.replace(device_name, "").trim().to_string();
Some((
caps.name("class_id")?.as_str().to_lowercase(),
caps.name("vendor_id")?.as_str().to_lowercase(),
cleaned_name,
caps.name("pci_addr")?.as_str().to_string(),
))
}
/// Retrieves the DRI device paths for a given PCI address.
/// Doubles as a way to verify the device is a GPU.
/// # Arguments
/// * `pci_addr` - A string slice that holds the PCI address.
/// # Returns
/// * `Option<(String, String)>` - A tuple containing the card path and render path if found.
fn get_dri_device_path(pci_addr: &str) -> Option<(String, String)> {
// Construct the base directory in /sys/bus/pci/devices to start the search
let base_dir = Path::new("/sys/bus/pci/devices");
let target_dir = format!("0000:{}", pci_addr);
let entries = fs::read_dir("/sys/bus/pci/devices").ok()?;
// Define the target PCI address with "0000:" prefix
let target_addr = format!("0000:{}", pci_addr);
for entry in entries.flatten() {
if !entry.path().to_string_lossy().contains(&target_dir) {
continue;
}
// Search for a matching directory that contains the target PCI address
for entry in fs::read_dir(base_dir).ok()?.flatten() {
let path = entry.path();
let mut card = String::new();
let mut render = String::new();
let drm_path = entry.path().join("drm");
// Check if the path matches the target PCI address
if path.to_string_lossy().contains(&target_addr) {
// Look for any files under the 'drm' subdirectory, like 'card0' or 'renderD128'
let drm_path = path.join("drm");
if drm_path.exists() {
let mut card_path = String::new();
let mut render_path = String::new();
for drm_entry in fs::read_dir(drm_path).ok()?.flatten() {
let file_name = drm_entry.file_name();
if let Some(name) = file_name.to_str() {
if name.starts_with("card") {
card_path = format!("/dev/dri/{}", name);
}
if name.starts_with("renderD") {
render_path = format!("/dev/dri/{}", name);
}
// If both paths are found, break the loop
if !card_path.is_empty() && !render_path.is_empty() {
break;
}
}
}
return Some((card_path, render_path));
for drm_entry in fs::read_dir(drm_path).ok()?.flatten() {
let name = drm_entry.file_name().to_string_lossy().into_owned();
if name.starts_with("card") {
card = format!("/dev/dri/{}", name);
} else if name.starts_with("renderD") {
render = format!("/dev/dri/{}", name);
}
if !card.is_empty() && !render.is_empty() {
break;
}
}
}
None
}
/// Helper method to get GPUs from vector by vendor name (case-insensitive).
/// # Arguments
/// * `gpus` - A vector containing information about each GPU.
/// * `vendor` - A string slice that holds the vendor name.
/// # Returns
/// * `Vec<GPUInfo>` - A vector containing GPUInfo structs if found.
pub fn get_gpus_by_vendor(gpus: &Vec<GPUInfo>, vendor: &str) -> Vec<GPUInfo> {
let vendor = vendor.to_lowercase();
gpus.iter()
.filter(|gpu| gpu.vendor_string().to_lowercase().contains(&vendor))
.cloned()
.collect()
}
/// Helper method to get GPUs from vector by device name substring (case-insensitive).
/// # Arguments
/// * `gpus` - A vector containing information about each GPU.
/// * `device_name` - A string slice that holds the device name substring.
/// # Returns
/// * `Vec<GPUInfo>` - A vector containing GPUInfo structs if found.
pub fn get_gpus_by_device_name(gpus: &Vec<GPUInfo>, device_name: &str) -> Vec<GPUInfo> {
let device_name = device_name.to_lowercase();
gpus.iter()
.filter(|gpu| gpu.device_name.to_lowercase().contains(&device_name))
.cloned()
.collect()
}
/// Helper method to get a GPU from vector of GPUInfo by card path (either /dev/dri/cardX or /dev/dri/renderDX, case-insensitive).
/// # Arguments
/// * `gpus` - A vector containing information about each GPU.
/// * `card_path` - A string slice that holds the card path.
/// # Returns
/// * `Option<GPUInfo>` - A reference to a GPUInfo struct if found.
pub fn get_gpu_by_card_path(gpus: &Vec<GPUInfo>, card_path: &str) -> Option<GPUInfo> {
for gpu in gpus {
if gpu.card_path().to_lowercase() == card_path.to_lowercase()
|| gpu.render_path().to_lowercase() == card_path.to_lowercase()
{
return Some(gpu.clone());
if !card.is_empty() {
return Some((card, render));
}
}
None
}
// Helper functions remain similar with improved readability:
pub fn get_gpus_by_vendor(gpus: &[GPUInfo], vendor: &str) -> Vec<GPUInfo> {
let target = vendor.to_lowercase();
gpus.iter()
.filter(|gpu| gpu.vendor_string().to_lowercase() == target)
.cloned()
.collect()
}
pub fn get_gpus_by_device_name(gpus: &[GPUInfo], substring: &str) -> Vec<GPUInfo> {
let target = substring.to_lowercase();
gpus.iter()
.filter(|gpu| gpu.device_name.to_lowercase().contains(&target))
.cloned()
.collect()
}
pub fn get_gpu_by_card_path(gpus: &[GPUInfo], path: &str) -> Option<GPUInfo> {
gpus.iter()
.find(|gpu| {
gpu.card_path.eq_ignore_ascii_case(path) || gpu.render_path.eq_ignore_ascii_case(path)
})
.cloned()
}