mirror of
https://github.com/nestriness/nestri.git
synced 2025-12-12 16:55:37 +02:00
⭐ feat(runner): DMA-BUF support (for NVIDIA) (#181)
Also includes other improvements and hopefully reducing LOC with some cleanup. --------- Co-authored-by: DatCaptainHorse <DatCaptainHorse@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
060718d8b0
commit
7de6e243ed
@@ -1,6 +1,5 @@
|
||||
use regex::Regex;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
use std::str;
|
||||
|
||||
@@ -49,9 +48,9 @@ impl GPUInfo {
|
||||
|
||||
fn get_gpu_vendor(vendor_id: &str) -> GPUVendor {
|
||||
match vendor_id {
|
||||
"8086" => GPUVendor::INTEL, // Intel
|
||||
"10de" => GPUVendor::NVIDIA, // NVIDIA
|
||||
"1002" => GPUVendor::AMD, // AMD/ATI
|
||||
"8086" => GPUVendor::INTEL,
|
||||
"10de" => GPUVendor::NVIDIA,
|
||||
"1002" => GPUVendor::AMD,
|
||||
_ => GPUVendor::UNKNOWN,
|
||||
}
|
||||
}
|
||||
@@ -60,174 +59,105 @@ fn get_gpu_vendor(vendor_id: &str) -> GPUVendor {
|
||||
/// # Returns
|
||||
/// * `Vec<GPUInfo>` - A vector containing information about each GPU.
|
||||
pub fn get_gpus() -> Vec<GPUInfo> {
|
||||
let mut gpus = Vec::new();
|
||||
|
||||
// Run lspci to get PCI devices related to GPUs
|
||||
let lspci_output = Command::new("lspci")
|
||||
.arg("-mmnn") // Get machine-readable output with IDs
|
||||
let output = Command::new("lspci")
|
||||
.args(["-mm", "-nn"])
|
||||
.output()
|
||||
.expect("Failed to execute lspci");
|
||||
|
||||
let output = str::from_utf8(&lspci_output.stdout).unwrap();
|
||||
|
||||
// Filter lines that mention VGA or 3D controller
|
||||
for line in output.lines() {
|
||||
if line.to_lowercase().contains("vga compatible controller")
|
||||
|| line.to_lowercase().contains("3d controller")
|
||||
|| line.to_lowercase().contains("display controller")
|
||||
{
|
||||
if let Some((pci_addr, vendor_id, device_name)) = parse_pci_device(line) {
|
||||
// Run udevadm to get the device path
|
||||
if let Some((card_path, render_path)) = get_dri_device_path(&pci_addr) {
|
||||
let vendor = get_gpu_vendor(&vendor_id);
|
||||
gpus.push(GPUInfo {
|
||||
vendor,
|
||||
card_path,
|
||||
render_path,
|
||||
device_name,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
gpus
|
||||
}
|
||||
|
||||
/// Parses a line from the lspci output to extract PCI device information.
|
||||
/// # Arguments
|
||||
/// * `line` - A string slice that holds a line from the 'lspci -mmnn' output.
|
||||
/// # Returns
|
||||
/// * `Option<(String, String, String)>` - A tuple containing the PCI address, vendor ID, and device name if parsing is successful.
|
||||
fn parse_pci_device(line: &str) -> Option<(String, String, String)> {
|
||||
// Define a regex pattern to match PCI device lines
|
||||
let re = Regex::new(r#"(?P<pci_addr>[0-9a-fA-F]{1,2}:[0-9a-fA-F]{2}\.[0-9]) "[^"]+" "(?P<vendor_name>[^"]+)" "(?P<device_name>[^"]+)"#).unwrap();
|
||||
|
||||
// Collect all matched groups
|
||||
let parts: Vec<(String, String, String)> = re
|
||||
.captures_iter(line)
|
||||
.filter_map(|cap| {
|
||||
Some((
|
||||
cap["pci_addr"].to_string(),
|
||||
cap["vendor_name"].to_string(),
|
||||
cap["device_name"].to_string(),
|
||||
))
|
||||
str::from_utf8(&output.stdout)
|
||||
.unwrap()
|
||||
.lines()
|
||||
.filter_map(|line| parse_pci_device(line))
|
||||
.filter(|(class_id, _, _, _)| matches!(class_id.as_str(), "0300" | "0302" | "0380"))
|
||||
.filter_map(|(_, vendor_id, device_name, pci_addr)| {
|
||||
get_dri_device_path(&pci_addr)
|
||||
.map(|(card, render)| (vendor_id, card, render, device_name))
|
||||
})
|
||||
.collect();
|
||||
|
||||
if let Some((pci_addr, vendor_name, device_name)) = parts.first() {
|
||||
// Create a mutable copy of the device name to modify
|
||||
let mut device_name = device_name.clone();
|
||||
|
||||
// If more than 1 square-bracketed item is found, remove last one, otherwise remove all
|
||||
if let Some(start) = device_name.rfind('[') {
|
||||
if let Some(_) = device_name.rfind(']') {
|
||||
device_name = device_name[..start].trim().to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Extract vendor ID from vendor name (e.g., "Intel Corporation [8086]" or "Advanced Micro Devices, Inc. [AMD/ATI] [1002]")
|
||||
let vendor_id = vendor_name
|
||||
.split_whitespace()
|
||||
.last()
|
||||
.unwrap_or_default()
|
||||
.trim_matches(|c: char| !c.is_ascii_hexdigit())
|
||||
.to_string();
|
||||
|
||||
return Some((pci_addr.clone(), vendor_id, device_name));
|
||||
}
|
||||
|
||||
None
|
||||
.map(|(vid, card_path, render_path, device_name)| GPUInfo {
|
||||
vendor: get_gpu_vendor(&vid),
|
||||
card_path,
|
||||
render_path,
|
||||
device_name,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn parse_pci_device(line: &str) -> Option<(String, String, String, String)> {
|
||||
let re = Regex::new(
|
||||
r#"^(?P<pci_addr>\S+)\s+"[^\[]*\[(?P<class_id>[0-9a-f]{4})\].*?"\s+"[^\[]*\[(?P<vendor_id>[0-9a-f]{4})\].*?"\s+"(?P<device_name>[^"]+?)""#,
|
||||
).unwrap();
|
||||
|
||||
let caps = re.captures(line)?;
|
||||
|
||||
// Clean device name by removing only the trailing device ID
|
||||
let device_name = caps.name("device_name")?.as_str().trim();
|
||||
let clean_re = Regex::new(r"\s+\[[0-9a-f]{4}\]$").unwrap();
|
||||
let cleaned_name = clean_re.replace(device_name, "").trim().to_string();
|
||||
|
||||
Some((
|
||||
caps.name("class_id")?.as_str().to_lowercase(),
|
||||
caps.name("vendor_id")?.as_str().to_lowercase(),
|
||||
cleaned_name,
|
||||
caps.name("pci_addr")?.as_str().to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Retrieves the DRI device paths for a given PCI address.
|
||||
/// Doubles as a way to verify the device is a GPU.
|
||||
/// # Arguments
|
||||
/// * `pci_addr` - A string slice that holds the PCI address.
|
||||
/// # Returns
|
||||
/// * `Option<(String, String)>` - A tuple containing the card path and render path if found.
|
||||
fn get_dri_device_path(pci_addr: &str) -> Option<(String, String)> {
|
||||
// Construct the base directory in /sys/bus/pci/devices to start the search
|
||||
let base_dir = Path::new("/sys/bus/pci/devices");
|
||||
let target_dir = format!("0000:{}", pci_addr);
|
||||
let entries = fs::read_dir("/sys/bus/pci/devices").ok()?;
|
||||
|
||||
// Define the target PCI address with "0000:" prefix
|
||||
let target_addr = format!("0000:{}", pci_addr);
|
||||
for entry in entries.flatten() {
|
||||
if !entry.path().to_string_lossy().contains(&target_dir) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Search for a matching directory that contains the target PCI address
|
||||
for entry in fs::read_dir(base_dir).ok()?.flatten() {
|
||||
let path = entry.path();
|
||||
let mut card = String::new();
|
||||
let mut render = String::new();
|
||||
let drm_path = entry.path().join("drm");
|
||||
|
||||
// Check if the path matches the target PCI address
|
||||
if path.to_string_lossy().contains(&target_addr) {
|
||||
// Look for any files under the 'drm' subdirectory, like 'card0' or 'renderD128'
|
||||
let drm_path = path.join("drm");
|
||||
if drm_path.exists() {
|
||||
let mut card_path = String::new();
|
||||
let mut render_path = String::new();
|
||||
for drm_entry in fs::read_dir(drm_path).ok()?.flatten() {
|
||||
let file_name = drm_entry.file_name();
|
||||
if let Some(name) = file_name.to_str() {
|
||||
if name.starts_with("card") {
|
||||
card_path = format!("/dev/dri/{}", name);
|
||||
}
|
||||
if name.starts_with("renderD") {
|
||||
render_path = format!("/dev/dri/{}", name);
|
||||
}
|
||||
// If both paths are found, break the loop
|
||||
if !card_path.is_empty() && !render_path.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Some((card_path, render_path));
|
||||
for drm_entry in fs::read_dir(drm_path).ok()?.flatten() {
|
||||
let name = drm_entry.file_name().to_string_lossy().into_owned();
|
||||
|
||||
if name.starts_with("card") {
|
||||
card = format!("/dev/dri/{}", name);
|
||||
} else if name.starts_with("renderD") {
|
||||
render = format!("/dev/dri/{}", name);
|
||||
}
|
||||
|
||||
if !card.is_empty() && !render.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Helper method to get GPUs from vector by vendor name (case-insensitive).
|
||||
/// # Arguments
|
||||
/// * `gpus` - A vector containing information about each GPU.
|
||||
/// * `vendor` - A string slice that holds the vendor name.
|
||||
/// # Returns
|
||||
/// * `Vec<GPUInfo>` - A vector containing GPUInfo structs if found.
|
||||
pub fn get_gpus_by_vendor(gpus: &Vec<GPUInfo>, vendor: &str) -> Vec<GPUInfo> {
|
||||
let vendor = vendor.to_lowercase();
|
||||
gpus.iter()
|
||||
.filter(|gpu| gpu.vendor_string().to_lowercase().contains(&vendor))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Helper method to get GPUs from vector by device name substring (case-insensitive).
|
||||
/// # Arguments
|
||||
/// * `gpus` - A vector containing information about each GPU.
|
||||
/// * `device_name` - A string slice that holds the device name substring.
|
||||
/// # Returns
|
||||
/// * `Vec<GPUInfo>` - A vector containing GPUInfo structs if found.
|
||||
pub fn get_gpus_by_device_name(gpus: &Vec<GPUInfo>, device_name: &str) -> Vec<GPUInfo> {
|
||||
let device_name = device_name.to_lowercase();
|
||||
gpus.iter()
|
||||
.filter(|gpu| gpu.device_name.to_lowercase().contains(&device_name))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Helper method to get a GPU from vector of GPUInfo by card path (either /dev/dri/cardX or /dev/dri/renderDX, case-insensitive).
|
||||
/// # Arguments
|
||||
/// * `gpus` - A vector containing information about each GPU.
|
||||
/// * `card_path` - A string slice that holds the card path.
|
||||
/// # Returns
|
||||
/// * `Option<GPUInfo>` - A reference to a GPUInfo struct if found.
|
||||
pub fn get_gpu_by_card_path(gpus: &Vec<GPUInfo>, card_path: &str) -> Option<GPUInfo> {
|
||||
for gpu in gpus {
|
||||
if gpu.card_path().to_lowercase() == card_path.to_lowercase()
|
||||
|| gpu.render_path().to_lowercase() == card_path.to_lowercase()
|
||||
{
|
||||
return Some(gpu.clone());
|
||||
if !card.is_empty() {
|
||||
return Some((card, render));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
// Helper functions remain similar with improved readability:
|
||||
pub fn get_gpus_by_vendor(gpus: &[GPUInfo], vendor: &str) -> Vec<GPUInfo> {
|
||||
let target = vendor.to_lowercase();
|
||||
gpus.iter()
|
||||
.filter(|gpu| gpu.vendor_string().to_lowercase() == target)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_gpus_by_device_name(gpus: &[GPUInfo], substring: &str) -> Vec<GPUInfo> {
|
||||
let target = substring.to_lowercase();
|
||||
gpus.iter()
|
||||
.filter(|gpu| gpu.device_name.to_lowercase().contains(&target))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_gpu_by_card_path(gpus: &[GPUInfo], path: &str) -> Option<GPUInfo> {
|
||||
gpus.iter()
|
||||
.find(|gpu| {
|
||||
gpu.card_path.eq_ignore_ascii_case(path) || gpu.render_path.eq_ignore_ascii_case(path)
|
||||
})
|
||||
.cloned()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user