mirror of
https://github.com/nestriness/nestri.git
feat(maitred): Update maitred - hookup to the API (#198)
## Description

We are attempting to hook up maitred to the API. Maitred's duties will be:

- [ ] Hook up to the API
- [ ] Wait for a signal (from the API) to start Steam
- [ ] On a stop signal, end the gaming session, clean up Steam... and maybe do the backup

## Summary by CodeRabbit

- **New Features**
  - Introduced Docker-based deployment configurations for both the main and relay applications.
  - Added new API endpoints enabling real-time machine messaging and enhanced IoT operations.
  - Expanded database schema and actor types to support improved machine tracking.
- **Improvements**
  - Enhanced real-time communication and relay management with streamlined room handling.
  - Upgraded dependencies, logging, and error handling for greater stability and performance.

---------

Co-authored-by: DatCaptainHorse <DatCaptainHorse@users.noreply.github.com>
Co-authored-by: Kristian Ollikainen <14197772+DatCaptainHorse@users.noreply.github.com>
This commit adds the following new files:

packages/maitred/internal/auth/auth.go (new file, 45 lines)

package auth

import (
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"nestri/maitred/internal/resource"
	"net/http"
	"net/url"
)

type UserCredentials struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
}

func FetchUserToken(machineID string, resource *resource.Resource) (*UserCredentials, error) {
	data := url.Values{}
	data.Set("grant_type", "client_credentials")
	data.Set("client_id", "maitred")
	data.Set("client_secret", resource.AuthFingerprintKey.Value)
	data.Set("fingerprint", machineID)
	data.Set("provider", "machine")
	resp, err := http.PostForm(resource.Auth.Url+"/token", data)
	if err != nil {
		return nil, err
	}
	defer func(Body io.ReadCloser) {
		err = Body.Close()
		if err != nil {
			slog.Error("Error closing body", "err", err)
		}
	}(resp.Body)
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("failed to auth: %s", body)
	}
	credentials := UserCredentials{}
	err = json.NewDecoder(resp.Body).Decode(&credentials)
	if err != nil {
		return nil, err
	}
	return &credentials, nil
}
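
For reference, a minimal sketch of how this might be exercised from maitred's startup path; the machine ID value is a placeholder, not part of this PR:

```go
package main

import (
	"log"

	"nestri/maitred/internal/auth"
	"nestri/maitred/internal/resource"
)

func main() {
	// Resource values come from the SST_RESOURCE_* environment variables.
	res, err := resource.NewResource()
	if err != nil {
		log.Fatalf("resource: %v", err)
	}
	// "mch_example" is an illustrative machine ID.
	creds, err := auth.FetchUserToken("mch_example", res)
	if err != nil {
		log.Fatalf("auth: %v", err)
	}
	// The access token is what realtime.Run later uses as the MQTT password.
	log.Printf("access token length: %d", len(creds.AccessToken))
}
```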

packages/maitred/internal/containers/containers.go (new file, 38 lines)

package containers

import (
	"context"
	"fmt"
)

// Container represents a container instance
type Container struct {
	ID    string
	Name  string
	State string
	Image string
}

// ContainerEngine defines the common interface for different container engines
type ContainerEngine interface {
	Close() error
	ListContainers(ctx context.Context) ([]Container, error)
	ListContainersByImage(ctx context.Context, img string) ([]Container, error)
	NewContainer(ctx context.Context, img string, envs []string) (string, error)
	StartContainer(ctx context.Context, id string) error
	StopContainer(ctx context.Context, id string) error
	RemoveContainer(ctx context.Context, id string) error
	InspectContainer(ctx context.Context, id string) (*Container, error)
	PullImage(ctx context.Context, img string) error
	Info(ctx context.Context) (string, error)
	LogsContainer(ctx context.Context, id string) (string, error)
}

func NewContainerEngine() (ContainerEngine, error) {
	dockerEngine, err := NewDockerEngine()
	if err == nil {
		return dockerEngine, nil
	}

	return nil, fmt.Errorf("failed to create container engine: %w", err)
}
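
Since callers only see the ContainerEngine interface, a quick consumer sketch looks like this (runnable against a local Docker daemon):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"nestri/maitred/internal/containers"
)

func main() {
	engine, err := containers.NewContainerEngine()
	if err != nil {
		log.Fatalf("engine: %v", err)
	}
	defer engine.Close()

	list, err := engine.ListContainers(context.Background())
	if err != nil {
		log.Fatalf("list: %v", err)
	}
	for _, c := range list {
		fmt.Printf("%s  %s  %s  %s\n", c.ID, c.Name, c.State, c.Image)
	}
}
```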

packages/maitred/internal/containers/docker.go (new file, 299 lines)

package containers

import (
	"context"
	"encoding/json"
	"fmt"
	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/image"
	"github.com/docker/docker/client"
	"io"
	"log/slog"
	"strings"
	"time"
)

// DockerEngine implements the ContainerEngine interface for Docker / Docker-compatible engines
type DockerEngine struct {
	cli *client.Client
}

func NewDockerEngine() (*DockerEngine, error) {
	cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	if err != nil {
		return nil, fmt.Errorf("failed to create Docker client: %w", err)
	}
	return &DockerEngine{cli: cli}, nil
}

func (d *DockerEngine) Close() error {
	return d.cli.Close()
}

func (d *DockerEngine) ListContainers(ctx context.Context) ([]Container, error) {
	containerList, err := d.cli.ContainerList(ctx, container.ListOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to list containers: %w", err)
	}

	var result []Container
	for _, c := range containerList {
		result = append(result, Container{
			ID:    c.ID,
			Name:  strings.TrimPrefix(strings.Join(c.Names, ","), "/"),
			State: c.State,
			Image: c.Image,
		})
	}
	return result, nil
}

func (d *DockerEngine) ListContainersByImage(ctx context.Context, img string) ([]Container, error) {
	if len(img) <= 0 {
		return nil, fmt.Errorf("image name cannot be empty")
	}

	containerList, err := d.cli.ContainerList(ctx, container.ListOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to list containers: %w", err)
	}

	var result []Container
	for _, c := range containerList {
		if c.Image == img {
			result = append(result, Container{
				ID:    c.ID,
				Name:  strings.TrimPrefix(strings.Join(c.Names, ","), "/"),
				State: c.State,
				Image: c.Image,
			})
		}
	}
	return result, nil
}

func (d *DockerEngine) NewContainer(ctx context.Context, img string, envs []string) (string, error) {
	// Create a new container with the given image and environment variables
	resp, err := d.cli.ContainerCreate(ctx, &container.Config{
		Image: img,
		Env:   envs,
	}, &container.HostConfig{
		NetworkMode: "host",
	}, nil, nil, "")
	if err != nil {
		return "", fmt.Errorf("failed to create container: %w", err)
	}

	if len(resp.ID) <= 0 {
		return "", fmt.Errorf("failed to create container, no ID returned")
	}

	return resp.ID, nil
}

func (d *DockerEngine) StartContainer(ctx context.Context, id string) error {
	err := d.cli.ContainerStart(ctx, id, container.StartOptions{})
	if err != nil {
		return fmt.Errorf("failed to start container: %w", err)
	}

	// Wait for the container to start
	if err = d.waitForContainer(ctx, id, "running"); err != nil {
		return fmt.Errorf("container failed to reach running state: %w", err)
	}

	return nil
}

func (d *DockerEngine) StopContainer(ctx context.Context, id string) error {
	// Set up the waiter before issuing the stop, so the "not running" event cannot be missed
	respChan, errChan := d.cli.ContainerWait(ctx, id, container.WaitConditionNotRunning)

	// Stop the container
	err := d.cli.ContainerStop(ctx, id, container.StopOptions{})
	if err != nil {
		return fmt.Errorf("failed to stop container: %w", err)
	}

	select {
	case <-respChan:
		// Container stopped successfully
		break
	case err = <-errChan:
		if err != nil {
			return fmt.Errorf("failed to wait for container to stop: %w", err)
		}
	case <-ctx.Done():
		return fmt.Errorf("context canceled while waiting for container to stop")
	}
	return nil
}

func (d *DockerEngine) RemoveContainer(ctx context.Context, id string) error {
	// Set up the waiter before issuing the removal
	respChan, errChan := d.cli.ContainerWait(ctx, id, container.WaitConditionRemoved)

	err := d.cli.ContainerRemove(ctx, id, container.RemoveOptions{})
	if err != nil {
		return fmt.Errorf("failed to remove container: %w", err)
	}

	select {
	case <-respChan:
		// Container removed successfully
		break
	case err = <-errChan:
		if err != nil {
			return fmt.Errorf("failed to wait for container to be removed: %w", err)
		}
	case <-ctx.Done():
		return fmt.Errorf("context canceled while waiting for container to be removed")
	}
	return nil
}

func (d *DockerEngine) InspectContainer(ctx context.Context, id string) (*Container, error) {
	info, err := d.cli.ContainerInspect(ctx, id)
	if err != nil {
		return nil, fmt.Errorf("failed to inspect container: %w", err)
	}

	return &Container{
		ID:    info.ID,
		Name:  info.Name,
		State: info.State.Status,
		Image: info.Config.Image,
	}, nil
}

func (d *DockerEngine) PullImage(ctx context.Context, img string) error {
	if len(img) <= 0 {
		return fmt.Errorf("image name cannot be empty")
	}

	slog.Info("Starting image pull", "image", img)

	reader, err := d.cli.ImagePull(ctx, img, image.PullOptions{})
	if err != nil {
		return fmt.Errorf("failed to start image pull for %s: %w", img, err)
	}
	defer func(reader io.ReadCloser) {
		err = reader.Close()
		if err != nil {
			slog.Warn("Failed to close reader", "err", err)
		}
	}(reader)

	// Parse the JSON stream for progress
	decoder := json.NewDecoder(reader)
	lastDownloadPercent := 0
	downloadTotals := make(map[string]int64)
	downloadCurrents := make(map[string]int64)

	var msg struct {
		ID             string `json:"id"`
		Status         string `json:"status"`
		ProgressDetail struct {
			Current int64 `json:"current"`
			Total   int64 `json:"total"`
		} `json:"progressDetail"`
	}

	for {
		err = decoder.Decode(&msg)
		if err == io.EOF {
			break // Pull completed
		}
		if err != nil {
			return fmt.Errorf("error decoding pull response for %s: %w", img, err)
		}

		// Skip if no progress details or ID
		if msg.ID == "" || msg.ProgressDetail.Total == 0 {
			continue
		}

		if strings.Contains(strings.ToLower(msg.Status), "downloading") {
			downloadTotals[msg.ID] = msg.ProgressDetail.Total
			downloadCurrents[msg.ID] = msg.ProgressDetail.Current
			var total, current int64
			for _, t := range downloadTotals {
				total += t
			}
			for _, c := range downloadCurrents {
				current += c
			}
			percent := int((float64(current) / float64(total)) * 100)
			if percent >= lastDownloadPercent+10 && percent <= 100 {
				slog.Info("Download progress", "image", img, "percent", percent)
				lastDownloadPercent = percent - (percent % 10)
			}
		}
	}

	slog.Info("Pulled image", "image", img)

	return nil
}

func (d *DockerEngine) Info(ctx context.Context) (string, error) {
	info, err := d.cli.Info(ctx)
	if err != nil {
		return "", fmt.Errorf("failed to get Docker info: %w", err)
	}

	return fmt.Sprintf("Docker Engine Version: %s", info.ServerVersion), nil
}

func (d *DockerEngine) LogsContainer(ctx context.Context, id string) (string, error) {
	reader, err := d.cli.ContainerLogs(ctx, id, container.LogsOptions{ShowStdout: true, ShowStderr: true})
	if err != nil {
		return "", fmt.Errorf("failed to get container logs: %w", err)
	}
	defer func(reader io.ReadCloser) {
		err = reader.Close()
		if err != nil {
			slog.Warn("Failed to close reader", "err", err)
		}
	}(reader)

	logs, err := io.ReadAll(reader)
	if err != nil {
		return "", fmt.Errorf("failed to read container logs: %w", err)
	}

	return string(logs), nil
}

func (d *DockerEngine) waitForContainer(ctx context.Context, id, desiredState string) error {
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	for {
		// Inspect the container to get its current state
		inspection, err := d.cli.ContainerInspect(ctx, id)
		if err != nil {
			return fmt.Errorf("failed to inspect container: %w", err)
		}

		// Check the container's state
		currentState := strings.ToLower(inspection.State.Status)
		switch currentState {
		case desiredState:
			// Container is in the desired state (e.g., "running")
			return nil
		case "exited", "dead", "removing":
			// Container failed or stopped unexpectedly, get logs and return error
			logs, _ := d.LogsContainer(ctx, id)
			return fmt.Errorf("container failed to reach %s state, logs: %s", desiredState, logs)
		}

		// Wait before polling again
		select {
		case <-ctx.Done():
			return fmt.Errorf("timed out after 10s waiting for container to reach %s state", desiredState)
		case <-time.After(1 * time.Second):
			// Continue polling
		}
	}
}
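
End to end, the engine is meant to be driven like this minimal sketch (the env var is a stand-in; the image is assumed to have a long-running entrypoint, since StartContainer waits for the "running" state):

```go
package main

import (
	"context"
	"log"

	"nestri/maitred/internal/containers"
)

func main() {
	engine, err := containers.NewDockerEngine()
	if err != nil {
		log.Fatalf("docker: %v", err)
	}
	defer engine.Close()

	ctx := context.Background()
	img := "ghcr.io/nestrilabs/nestri/runner:nightly"
	if err := engine.PullImage(ctx, img); err != nil {
		log.Fatalf("pull: %v", err)
	}
	id, err := engine.NewContainer(ctx, img, []string{"EXAMPLE_ENV=1"})
	if err != nil {
		log.Fatalf("create: %v", err)
	}
	if err := engine.StartContainer(ctx, id); err != nil {
		log.Fatalf("start: %v", err)
	}
	// Later: tear down, mirroring what CleanupManaged does.
	if err := engine.StopContainer(ctx, id); err != nil {
		log.Printf("stop: %v", err)
	}
	if err := engine.RemoveContainer(ctx, id); err != nil {
		log.Printf("remove: %v", err)
	}
}
```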

packages/maitred/internal/flags.go (new file, 70 lines)

package internal

import (
	"flag"
	"log/slog"
	"os"
	"strconv"
)

var globalFlags *Flags

type Flags struct {
	Verbose   bool // Log everything to console
	Debug     bool // Enable debug mode, implies Verbose - disables SST and MQTT connections
	NoMonitor bool // Disable system monitoring
}

func (flags *Flags) DebugLog() {
	slog.Info("Maitred flags",
		"verbose", flags.Verbose,
		"debug", flags.Debug,
		"no-monitor", flags.NoMonitor,
	)
}

func getEnvAsInt(name string, defaultVal int) int {
	valueStr := os.Getenv(name)
	value, err := strconv.Atoi(valueStr)
	if err != nil {
		return defaultVal
	}
	return value
}

func getEnvAsBool(name string, defaultVal bool) bool {
	valueStr := os.Getenv(name)
	val, err := strconv.ParseBool(valueStr)
	if err != nil {
		return defaultVal
	}
	return val
}

func getEnvAsString(name string, defaultVal string) string {
	valueStr := os.Getenv(name)
	if len(valueStr) == 0 {
		return defaultVal
	}
	return valueStr
}

func InitFlags() {
	// Create Flags struct
	globalFlags = &Flags{}
	// Get flags
	flag.BoolVar(&globalFlags.Verbose, "verbose", getEnvAsBool("VERBOSE", false), "Verbose mode")
	flag.BoolVar(&globalFlags.Debug, "debug", getEnvAsBool("DEBUG", false), "Debug mode")
	flag.BoolVar(&globalFlags.NoMonitor, "no-monitor", getEnvAsBool("NO_MONITOR", false), "Disable system monitoring")
	// Parse flags
	flag.Parse()

	// If debug is enabled, verbose is also enabled
	if globalFlags.Debug {
		globalFlags.Verbose = true
	}
}

func GetFlags() *Flags {
	return globalFlags
}
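
Environment variables only seed the flag defaults here, so a CLI flag wins when both are set. A quick sketch of the expected behavior:

```go
package main

import (
	"nestri/maitred/internal"
)

func main() {
	// With DEBUG=true in the environment, -debug defaults to true;
	// passing -debug=false on the command line overrides it.
	internal.InitFlags()
	internal.GetFlags().DebugLog()
	// e.g. "Maitred flags verbose=true debug=true no-monitor=false"
}
```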

packages/maitred/internal/handler.go (new file, 48 lines)

package internal

import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"strings"
)

type CustomHandler struct {
	Handler slog.Handler
}

func (h *CustomHandler) Enabled(_ context.Context, level slog.Level) bool {
	return h.Handler.Enabled(nil, level)
}

func (h *CustomHandler) Handle(_ context.Context, r slog.Record) error {
	// Format the timestamp as "2006/01/02 15:04:05"
	timestamp := r.Time.Format("2006/01/02 15:04:05")
	// Convert level to uppercase string (e.g., "INFO")
	level := strings.ToUpper(r.Level.String())
	// Build the message
	msg := fmt.Sprintf("%s %s %s", timestamp, level, r.Message)

	// Handle additional attributes if they exist
	var attrs []string
	r.Attrs(func(a slog.Attr) bool {
		attrs = append(attrs, fmt.Sprintf("%s=%v", a.Key, a.Value))
		return true
	})
	if len(attrs) > 0 {
		msg += " " + strings.Join(attrs, " ")
	}

	// Write the formatted message to stdout
	_, err := fmt.Fprintln(os.Stdout, msg)
	return err
}

func (h *CustomHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	return &CustomHandler{Handler: h.Handler.WithAttrs(attrs)}
}

func (h *CustomHandler) WithGroup(name string) slog.Handler {
	return &CustomHandler{Handler: h.Handler.WithGroup(name)}
}
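
A sketch of how this handler would be installed as the process-wide logger (the level choice is illustrative):

```go
package main

import (
	"log/slog"
	"os"

	"nestri/maitred/internal"
)

func main() {
	// Wrap a standard text handler; CustomHandler reuses its level gating
	// but takes over the output formatting itself.
	inner := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})
	slog.SetDefault(slog.New(&internal.CustomHandler{Handler: inner}))

	slog.Info("Maitred starting", "version", "dev")
	// Output: 2025/01/02 15:04:05 INFO Maitred starting version=dev
}
```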

packages/maitred/internal/realtime/managed.go (new file, 366 lines)

package realtime

import (
	"context"
	"fmt"
	"log/slog"
	"nestri/maitred/internal"
	"nestri/maitred/internal/containers"
	"strings"
	"sync"
	"time"
)

var (
	nestriRunnerImage = "ghcr.io/nestrilabs/nestri/runner:nightly"
	nestriRelayImage  = "ghcr.io/nestrilabs/nestri/relay:nightly"
)

type ManagedContainerType int

const (
	// Runner is the nestri runner container
	Runner ManagedContainerType = iota
	// Relay is the nestri relay container
	Relay
)

// ManagedContainer type with extra information fields
type ManagedContainer struct {
	containers.Container
	Type ManagedContainerType
}

// managedContainers is a map of containers that are managed by us (maitred)
var (
	managedContainers      = make(map[string]ManagedContainer)
	managedContainersMutex sync.RWMutex
)

// InitializeManager handles the initialization of the managed containers and pulls their latest images
func InitializeManager(ctx context.Context, ctrEngine containers.ContainerEngine) error {
	// If debug, override the images
	if internal.GetFlags().Debug {
		nestriRunnerImage = "ghcr.io/datcaptainhorse/nestri-cachyos:latest-v3"
		nestriRelayImage = "ghcr.io/datcaptainhorse/nestri-relay:latest"
	}

	// Look for existing stopped runner containers and remove them
	slog.Info("Checking and removing old runner containers")
	oldRunners, err := ctrEngine.ListContainersByImage(ctx, nestriRunnerImage)
	if err != nil {
		return err
	}
	for _, c := range oldRunners {
		// If running, stop first
		if strings.Contains(strings.ToLower(c.State), "running") {
			slog.Info("Stopping old runner container", "id", c.ID)
			if err = ctrEngine.StopContainer(ctx, c.ID); err != nil {
				return err
			}
		}
		slog.Info("Removing old runner container", "id", c.ID)
		if err = ctrEngine.RemoveContainer(ctx, c.ID); err != nil {
			return err
		}
	}

	// Pull the runner image if not in debug mode
	if !internal.GetFlags().Debug {
		slog.Info("Pulling runner image", "image", nestriRunnerImage)
		if err := ctrEngine.PullImage(ctx, nestriRunnerImage); err != nil {
			return fmt.Errorf("failed to pull runner image: %w", err)
		}
	}

	// Look for existing stopped relay containers and remove them
	slog.Info("Checking and removing old relay containers")
	oldRelays, err := ctrEngine.ListContainersByImage(ctx, nestriRelayImage)
	if err != nil {
		return err
	}
	for _, c := range oldRelays {
		// If running, stop first
		if strings.Contains(strings.ToLower(c.State), "running") {
			slog.Info("Stopping old relay container", "id", c.ID)
			if err = ctrEngine.StopContainer(ctx, c.ID); err != nil {
				return err
			}
		}
		slog.Info("Removing old relay container", "id", c.ID)
		if err = ctrEngine.RemoveContainer(ctx, c.ID); err != nil {
			return err
		}
	}

	// Pull the relay image if not in debug mode
	if !internal.GetFlags().Debug {
		slog.Info("Pulling relay image", "image", nestriRelayImage)
		if err := ctrEngine.PullImage(ctx, nestriRelayImage); err != nil {
			return fmt.Errorf("failed to pull relay image: %w", err)
		}
	}

	return nil
}

// CreateRunner creates a new runner image container
func CreateRunner(ctx context.Context, ctrEngine containers.ContainerEngine) (string, error) {
	// For safety, limit to 4 runners
	if CountRunners() >= 4 {
		return "", fmt.Errorf("maximum number of runners reached")
	}

	// Create the container
	containerID, err := ctrEngine.NewContainer(ctx, nestriRunnerImage, nil)
	if err != nil {
		return "", err
	}

	// Add the container to the managed list
	managedContainersMutex.Lock()
	defer managedContainersMutex.Unlock()
	managedContainers[containerID] = ManagedContainer{
		Container: containers.Container{
			ID: containerID,
		},
		Type: Runner,
	}

	return containerID, nil
}

// StartRunner starts a runner container, keeping track of its state
func StartRunner(ctx context.Context, ctrEngine containers.ContainerEngine, id string) error {
	// Verify the container is part of the managed list
	managedContainersMutex.RLock()
	if _, ok := managedContainers[id]; !ok {
		managedContainersMutex.RUnlock()
		return fmt.Errorf("container %s is not managed", id)
	}
	managedContainersMutex.RUnlock()

	// Start the container
	if err := ctrEngine.StartContainer(ctx, id); err != nil {
		return err
	}

	// Check container status in background at 10 second intervals; if it exits, print its logs
	go func() {
		err := monitorContainer(ctx, ctrEngine, id)
		if err != nil {
			slog.Error("failure while monitoring runner container", "id", id, "err", err)
			return
		}
	}()

	return nil
}

// RemoveRunner removes a runner container
func RemoveRunner(ctx context.Context, ctrEngine containers.ContainerEngine, id string) error {
	// Stop the container if it's running
	if strings.Contains(strings.ToLower(managedContainers[id].State), "running") {
		if err := ctrEngine.StopContainer(ctx, id); err != nil {
			return err
		}
	}

	// Remove the container
	if err := ctrEngine.RemoveContainer(ctx, id); err != nil {
		return err
	}

	// Remove the container from the managed list
	managedContainersMutex.Lock()
	defer managedContainersMutex.Unlock()
	delete(managedContainers, id)

	return nil
}

// ListRunners returns a list of all runner containers
func ListRunners() []ManagedContainer {
	managedContainersMutex.RLock()
	defer managedContainersMutex.RUnlock()
	var runners []ManagedContainer
	for _, v := range managedContainers {
		if v.Type == Runner {
			runners = append(runners, v)
		}
	}
	return runners
}

// CountRunners returns the number of runner containers
func CountRunners() int {
	return len(ListRunners())
}

// CreateRelay creates a new relay image container
func CreateRelay(ctx context.Context, ctrEngine containers.ContainerEngine) (string, error) {
	// Limit to 1 relay
	if CountRelays() >= 1 {
		return "", fmt.Errorf("maximum number of relays reached")
	}

	// TODO: Placeholder for control secret, should be generated at runtime
	secretEnv := fmt.Sprintf("CONTROL_SECRET=%s", "1234")

	// Create the container
	containerID, err := ctrEngine.NewContainer(ctx, nestriRelayImage, []string{secretEnv})
	if err != nil {
		return "", err
	}

	// Add the container to the managed list
	managedContainersMutex.Lock()
	defer managedContainersMutex.Unlock()
	managedContainers[containerID] = ManagedContainer{
		Container: containers.Container{
			ID: containerID,
		},
		Type: Relay,
	}

	return containerID, nil
}

// StartRelay starts a relay container, keeping track of its state
func StartRelay(ctx context.Context, ctrEngine containers.ContainerEngine, id string) error {
	// Verify the container is part of the managed list
	managedContainersMutex.RLock()
	if _, ok := managedContainers[id]; !ok {
		managedContainersMutex.RUnlock()
		return fmt.Errorf("container %s is not managed", id)
	}
	managedContainersMutex.RUnlock()

	// Start the container
	if err := ctrEngine.StartContainer(ctx, id); err != nil {
		return err
	}

	// Check container status in background at 10 second intervals; if it exits, print its logs
	go func() {
		err := monitorContainer(ctx, ctrEngine, id)
		if err != nil {
			slog.Error("failure while monitoring relay container", "id", id, "err", err)
			return
		}
	}()

	return nil
}

// RemoveRelay removes a relay container
func RemoveRelay(ctx context.Context, ctrEngine containers.ContainerEngine, id string) error {
	// Stop the container if it's running
	if strings.Contains(strings.ToLower(managedContainers[id].State), "running") {
		if err := ctrEngine.StopContainer(ctx, id); err != nil {
			return err
		}
	}

	// Remove the container
	if err := ctrEngine.RemoveContainer(ctx, id); err != nil {
		return err
	}

	// Remove the container from the managed list
	managedContainersMutex.Lock()
	defer managedContainersMutex.Unlock()
	delete(managedContainers, id)

	return nil
}

// ListRelays returns a list of all relay containers
func ListRelays() []ManagedContainer {
	managedContainersMutex.RLock()
	defer managedContainersMutex.RUnlock()
	var relays []ManagedContainer
	for _, v := range managedContainers {
		if v.Type == Relay {
			relays = append(relays, v)
		}
	}
	return relays
}

// CountRelays returns the number of relay containers
func CountRelays() int {
	return len(ListRelays())
}

// CleanupManaged stops and removes all managed containers
func CleanupManaged(ctx context.Context, ctrEngine containers.ContainerEngine) error {
	if len(managedContainers) <= 0 {
		return nil
	}

	slog.Info("Cleaning up managed containers")
	managedContainersMutex.Lock()
	defer managedContainersMutex.Unlock()
	for id := range managedContainers {
		// If running, stop first
		if strings.Contains(strings.ToLower(managedContainers[id].State), "running") {
			slog.Info("Stopping managed container", "id", id)
			if err := ctrEngine.StopContainer(ctx, id); err != nil {
				return err
			}
		}

		// Remove the container
		slog.Info("Removing managed container", "id", id)
		if err := ctrEngine.RemoveContainer(ctx, id); err != nil {
			return err
		}
		// Remove from the managed list
		delete(managedContainers, id)
	}
	return nil
}

func monitorContainer(ctx context.Context, ctrEngine containers.ContainerEngine, id string) error {
	for {
		select {
		case <-ctx.Done():
			return nil
		default:
			// Check the container status
			ctr, err := ctrEngine.InspectContainer(ctx, id)
			if err != nil {
				return fmt.Errorf("failed to inspect container: %w", err)
			}

			// Update the container state in the managed list, preserving its managed type
			managedContainersMutex.Lock()
			entry := managedContainers[id]
			entry.Container = containers.Container{
				ID:    ctr.ID,
				Name:  ctr.Name,
				State: ctr.State,
				Image: ctr.Image,
			}
			managedContainers[id] = entry
			managedContainersMutex.Unlock()

			if !strings.Contains(strings.ToLower(ctr.State), "running") {
				// Container is not running, print logs
				logs, err := ctrEngine.LogsContainer(ctx, id)
				if err != nil {
					return fmt.Errorf("failed to get container logs: %w", err)
				}
				return fmt.Errorf("container %s stopped running: %s", id, logs)
			}
		}
		// Sleep for 10 seconds
		select {
		case <-ctx.Done():
			return nil
		case <-time.After(10 * time.Second):
		}
	}
}
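
A sketch of the intended orchestration flow, as driven by the MQTT handlers in realtime.go below (error handling trimmed for brevity):

```go
package main

import (
	"context"
	"log"

	"nestri/maitred/internal"
	"nestri/maitred/internal/containers"
	"nestri/maitred/internal/realtime"
)

func main() {
	internal.InitFlags()
	ctx := context.Background()

	engine, err := containers.NewContainerEngine()
	if err != nil {
		log.Fatal(err)
	}
	defer engine.Close()

	// Clean out stale containers and pre-pull the runner/relay images.
	if err := realtime.InitializeManager(ctx, engine); err != nil {
		log.Fatal(err)
	}
	// On shutdown, tear everything down.
	defer realtime.CleanupManaged(ctx, engine)

	// A "create" message would trigger this...
	id, err := realtime.CreateRunner(ctx, engine)
	if err != nil {
		log.Fatal(err)
	}
	// ...and a "start" message this.
	if err := realtime.StartRunner(ctx, engine, id); err != nil {
		log.Fatal(err)
	}
}
```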

packages/maitred/internal/realtime/messages.go (new file, 52 lines)

package realtime

import (
	"encoding/json"
)

// BaseMessage is the generic top-level message structure
type BaseMessage struct {
	Type    string          `json:"type"`
	Payload json.RawMessage `json:"payload"`
}

type CreatePayload struct{}

type StartPayload struct {
	ContainerID string `json:"container_id"`
}

type StopPayload struct {
	ContainerID string `json:"container_id"`
}

// ParseMessage parses a BaseMessage and returns the specific payload
func ParseMessage(data []byte) (BaseMessage, interface{}, error) {
	var base BaseMessage
	if err := json.Unmarshal(data, &base); err != nil {
		return base, nil, err
	}

	switch base.Type {
	case "create":
		var payload CreatePayload
		if err := json.Unmarshal(base.Payload, &payload); err != nil {
			return base, nil, err
		}
		return base, payload, nil
	case "start":
		var payload StartPayload
		if err := json.Unmarshal(base.Payload, &payload); err != nil {
			return base, nil, err
		}
		return base, payload, nil
	case "stop":
		var payload StopPayload
		if err := json.Unmarshal(base.Payload, &payload); err != nil {
			return base, nil, err
		}
		return base, payload, nil
	default:
		return base, base.Payload, nil
	}
}
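
For illustration, the wire format these types imply, using a made-up "start" message and the type-assertion pattern the MQTT handlers use:

```go
package main

import (
	"fmt"
	"log"

	"nestri/maitred/internal/realtime"
)

func main() {
	// A hypothetical "start" message as the API would publish it.
	raw := []byte(`{"type":"start","payload":{"container_id":"abc123"}}`)

	base, payload, err := realtime.ParseMessage(raw)
	if err != nil {
		log.Fatal(err)
	}
	// ParseMessage returns the concrete payload as an interface{};
	// callers type-assert based on base.Type.
	if start, ok := payload.(realtime.StartPayload); ok {
		fmt.Println(base.Type, start.ContainerID) // start abc123
	}
}
```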

packages/maitred/internal/realtime/realtime.go (new file, 182 lines)

package realtime

import (
	"context"
	"fmt"
	"github.com/eclipse/paho.golang/autopaho"
	"github.com/eclipse/paho.golang/paho"
	"log/slog"
	"nestri/maitred/internal/auth"
	"nestri/maitred/internal/containers"
	"nestri/maitred/internal/resource"
	"net/url"
	"os"
	"time"
)

func Run(ctx context.Context, machineID string, containerEngine containers.ContainerEngine, resource *resource.Resource) error {
	var clientID = generateClientID()
	var topic = fmt.Sprintf("%s/%s/%s", resource.App.Name, resource.App.Stage, machineID)
	var serverURL = fmt.Sprintf("wss://%s/mqtt?x-amz-customauthorizer-name=%s", resource.Realtime.Endpoint, resource.Realtime.Authorizer)

	slog.Info("Realtime", "topic", topic)

	userTokens, err := auth.FetchUserToken(machineID, resource)
	if err != nil {
		return err
	}

	slog.Info("Realtime", "token", userTokens.AccessToken)

	u, err := url.Parse(serverURL)
	if err != nil {
		return err
	}

	router := paho.NewStandardRouter()
	router.DefaultHandler(func(p *paho.Publish) {
		slog.Debug("DefaultHandler", "topic", p.Topic, "message", fmt.Sprintf("default handler received message: %s - with topic: %s", p.Payload, p.Topic))
	})

	createTopic := fmt.Sprintf("%s/create", topic)
	slog.Debug("Registering handler", "topic", createTopic)
	router.RegisterHandler(createTopic, func(p *paho.Publish) {
		slog.Debug("Router", "message", fmt.Sprintf("received create message with payload: %s", p.Payload))

		base, _, err := ParseMessage(p.Payload)
		if err != nil {
			slog.Error("Router", "err", fmt.Sprintf("failed to parse message: %s", err))
			return
		}

		if base.Type != "create" {
			slog.Error("Router", "err", "unexpected message type")
			return
		}

		// Create runner container
		containerID, err := CreateRunner(ctx, containerEngine)
		if err != nil {
			slog.Error("Router", "err", fmt.Sprintf("failed to create runner container: %s", err))
			return
		}

		slog.Info("Router", "info", fmt.Sprintf("created runner container: %s", containerID))
	})

	startTopic := fmt.Sprintf("%s/start", topic)
	slog.Debug("Registering handler", "topic", startTopic)
	router.RegisterHandler(startTopic, func(p *paho.Publish) {
		slog.Debug("Router", "message", fmt.Sprintf("received start message with payload: %s", p.Payload))

		base, payload, err := ParseMessage(p.Payload)
		if err != nil {
			slog.Error("Router", "err", fmt.Sprintf("failed to parse message: %s", err))
			return
		}

		if base.Type != "start" {
			slog.Error("Router", "err", "unexpected message type")
			return
		}

		// Get container ID
		startPayload, ok := payload.(StartPayload)
		if !ok {
			slog.Error("Router", "err", "failed to get payload")
			return
		}

		// Start runner container
		if err = containerEngine.StartContainer(ctx, startPayload.ContainerID); err != nil {
			slog.Error("Router", "err", fmt.Sprintf("failed to start runner container: %s", err))
			return
		}

		slog.Info("Router", "info", fmt.Sprintf("started runner container: %s", startPayload.ContainerID))
	})

	stopTopic := fmt.Sprintf("%s/stop", topic)
	slog.Debug("Registering handler", "topic", stopTopic)
	router.RegisterHandler(stopTopic, func(p *paho.Publish) {
		slog.Debug("Router", "message", fmt.Sprintf("received stop message with payload: %s", p.Payload))

		base, payload, err := ParseMessage(p.Payload)
		if err != nil {
			slog.Error("Router", "err", fmt.Sprintf("failed to parse message: %s", err))
			return
		}

		if base.Type != "stop" {
			slog.Error("Router", "err", "unexpected message type")
			return
		}

		// Get container ID
		stopPayload, ok := payload.(StopPayload)
		if !ok {
			slog.Error("Router", "err", "failed to get payload")
			return
		}

		// Stop runner container
		if err = containerEngine.StopContainer(ctx, stopPayload.ContainerID); err != nil {
			slog.Error("Router", "err", fmt.Sprintf("failed to stop runner container: %s", err))
			return
		}

		slog.Info("Router", "info", fmt.Sprintf("stopped runner container: %s", stopPayload.ContainerID))
	})

	legacyLogger := slog.NewLogLogger(slog.NewTextHandler(os.Stdout, nil), slog.LevelError)
	cliCfg := autopaho.ClientConfig{
		ServerUrls:                    []*url.URL{u},
		ConnectUsername:               "",
		ConnectPassword:               []byte(userTokens.AccessToken),
		KeepAlive:                     20,
		CleanStartOnInitialConnection: true,
		SessionExpiryInterval:         60,
		ReconnectBackoff:              autopaho.NewConstantBackoff(time.Second),
		OnConnectionUp: func(cm *autopaho.ConnectionManager, connAck *paho.Connack) {
			slog.Info("Router", "info", "MQTT connection is up and running")
			if _, err = cm.Subscribe(context.Background(), &paho.Subscribe{
				Subscriptions: []paho.SubscribeOptions{
					{Topic: fmt.Sprintf("%s/#", topic), QoS: 1},
				},
			}); err != nil {
				slog.Error("Router", "err", fmt.Sprint("failed to subscribe, likely no messages will be received: ", err))
			}
		},
		Errors: legacyLogger,
		OnConnectError: func(err error) {
			slog.Error("Router", "err", fmt.Sprintf("error whilst attempting connection: %s", err))
		},
		ClientConfig: paho.ClientConfig{
			ClientID: clientID,
			OnPublishReceived: []func(paho.PublishReceived) (bool, error){
				func(pr paho.PublishReceived) (bool, error) {
					router.Route(pr.Packet.Packet())
					return true, nil
				}},
			OnClientError: func(err error) { slog.Error("Router", "err", fmt.Sprintf("client error: %s", err)) },
			OnServerDisconnect: func(d *paho.Disconnect) {
				if d.Properties != nil {
					slog.Info("Router", "info", fmt.Sprintf("server requested disconnect: %s", d.Properties.ReasonString))
				} else {
					slog.Info("Router", "info", fmt.Sprintf("server requested disconnect; reason code: %d", d.ReasonCode))
				}
			},
		},
	}

	c, err := autopaho.NewConnection(ctx, cliCfg)
	if err != nil {
		return err
	}

	if err = c.AwaitConnection(ctx); err != nil {
		return err
	}

	return nil
}
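
Put together, the API drives one machine through three subtopics under a single root, and the client subscribes to the whole subtree with QoS 1. Schematically (topic segments are the resource values, not literals):

```
{app.name}/{app.stage}/{machineID}/create -> ParseMessage "create", CreateRunner
{app.name}/{app.stage}/{machineID}/start  -> ParseMessage "start",  StartContainer(payload.container_id)
{app.name}/{app.stage}/{machineID}/stop   -> ParseMessage "stop",   StopContainer(payload.container_id)
```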

packages/maitred/internal/realtime/utils.go (new file, 17 lines)

package realtime

import (
	"crypto/rand"
	"fmt"
	"github.com/oklog/ulid/v2"
	"time"
)

func generateClientID() string {
	// Create a source of entropy (cryptographically secure)
	entropy := ulid.Monotonic(rand.Reader, 0)
	// Generate a new ULID
	id := ulid.MustNew(ulid.Timestamp(time.Now()), entropy)
	// Create the client ID string
	return fmt.Sprintf("mch_%s", id.String())
}

packages/maitred/internal/resource/resource.go (new file, 46 lines)

package resource

import (
	"encoding/json"
	"fmt"
	"os"
	"reflect"
)

type Resource struct {
	Api struct {
		Url string `json:"url"`
	}
	Auth struct {
		Url string `json:"url"`
	}
	AuthFingerprintKey struct {
		Value string `json:"value"`
	}
	Realtime struct {
		Endpoint   string `json:"endpoint"`
		Authorizer string `json:"authorizer"`
	}
	App struct {
		Name  string `json:"name"`
		Stage string `json:"stage"`
	}
}

func NewResource() (*Resource, error) {
	resource := Resource{}
	val := reflect.ValueOf(&resource).Elem()
	for i := 0; i < val.NumField(); i++ {
		field := val.Field(i)
		typeField := val.Type().Field(i)
		envVarName := fmt.Sprintf("SST_RESOURCE_%s", typeField.Name)
		envValue, exists := os.LookupEnv(envVarName)
		if !exists {
			return nil, fmt.Errorf("missing environment variable %s", envVarName)
		}
		if err := json.Unmarshal([]byte(envValue), field.Addr().Interface()); err != nil {
			return nil, fmt.Errorf("error unmarshalling %s: %w", envVarName, err)
		}
	}
	return &resource, nil
}
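
NewResource expects one SST_RESOURCE_* variable per top-level field, each holding a JSON object. A sketch for local testing, with placeholder values mirroring what SST injects in a real deployment:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"nestri/maitred/internal/resource"
)

func main() {
	// All values below are placeholders.
	os.Setenv("SST_RESOURCE_Api", `{"url":"https://api.example.com"}`)
	os.Setenv("SST_RESOURCE_Auth", `{"url":"https://auth.example.com"}`)
	os.Setenv("SST_RESOURCE_AuthFingerprintKey", `{"value":"secret"}`)
	os.Setenv("SST_RESOURCE_Realtime", `{"endpoint":"example.iot.amazonaws.com","authorizer":"authorizer-name"}`)
	os.Setenv("SST_RESOURCE_App", `{"name":"nestri","stage":"dev"}`)

	res, err := resource.NewResource()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(res.Auth.Url) // https://auth.example.com
}
```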

packages/maitred/internal/system/gpu.go (new file, 184 lines)

package system

import (
	"bytes"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"
)

const (
	pciClassVGA         = 0x0300 // VGA compatible controller
	pciClass3D          = 0x0302 // 3D controller
	pciClassDisplay     = 0x0380 // Display controller
	pciClassCoProcessor = 0x0b40 // Co-processor (e.g., NVIDIA Tesla)
)

type infoPair struct {
	Name string
	ID   int
}

type PCIInfo struct {
	Slot       string
	Class      infoPair
	Vendor     infoPair
	Device     infoPair
	SVendor    infoPair
	SDevice    infoPair
	Rev        string
	ProgIf     string
	Driver     string
	Modules    []string
	IOMMUGroup string
}

const (
	VendorIntel  = 0x8086
	VendorNVIDIA = 0x10de
	VendorAMD    = 0x1002
)

func GetAllGPUInfo() ([]PCIInfo, error) {
	var gpus []PCIInfo

	cmd := exec.Command("lspci", "-mmvvvnnkD")
	output, err := cmd.Output()
	if err != nil {
		return nil, err
	}

	sections := bytes.Split(output, []byte("\n\n"))
	for _, section := range sections {
		var info PCIInfo

		lines := bytes.Split(section, []byte("\n"))
		for _, line := range lines {
			parts := bytes.SplitN(line, []byte(":"), 2)
			if len(parts) < 2 {
				continue
			}

			key := strings.TrimSpace(string(parts[0]))
			value := strings.TrimSpace(string(parts[1]))

			switch key {
			case "Slot":
				info.Slot = value
			case "Class":
				info.Class, err = parseInfoPair(value)
			case "Vendor":
				info.Vendor, err = parseInfoPair(value)
			case "Device":
				info.Device, err = parseInfoPair(value)
			case "SVendor":
				info.SVendor, err = parseInfoPair(value)
			case "SDevice":
				info.SDevice, err = parseInfoPair(value)
			case "Rev":
				info.Rev = value
			case "ProgIf":
				info.ProgIf = value
			case "Driver":
				info.Driver = value
			case "Module":
				info.Modules = append(info.Modules, value)
			case "IOMMUGroup":
				info.IOMMUGroup = value
			}

			if err != nil {
				return nil, err
			}
		}

		// Check if this is a GPU device
		if isGPUClass(info.Class.ID) {
			gpus = append(gpus, info)
		}
	}

	return gpus, nil
}

// parseInfoPair gets an infoPair from "SomeName [SomeID]"
// example: "DG2 [Arc A770] [56a0]" -> Name: "DG2 [Arc A770]", ID: 0x56a0
func parseInfoPair(pair string) (infoPair, error) {
	parts := strings.Split(pair, "[")
	if len(parts) < 2 {
		return infoPair{}, errors.New("invalid info pair")
	}

	id := strings.TrimSuffix(parts[len(parts)-1], "]")
	name := strings.TrimSuffix(pair, "["+id)
	name = strings.TrimSpace(name)
	id = strings.TrimSpace(id)

	// Remove ID including square brackets from name
	name = strings.ReplaceAll(name, "["+id+"]", "")
	name = strings.TrimSpace(name)

	idHex, err := parseHexID(id)
	if err != nil {
		return infoPair{}, err
	}

	return infoPair{
		Name: name,
		ID:   idHex,
	}, nil
}

func parseHexID(id string) (int, error) {
	if strings.HasPrefix(id, "0x") {
		id = id[2:]
	}
	parsed, err := strconv.ParseInt(id, 16, 32)
	if err != nil {
		return 0, err
	}
	return int(parsed), nil
}

func isGPUClass(class int) bool {
	return class == pciClassVGA || class == pciClass3D || class == pciClassDisplay || class == pciClassCoProcessor
}

// GetCardDevices returns the /dev/dri/cardX and /dev/dri/renderDXXX device paths
func (info PCIInfo) GetCardDevices() (cardPath, renderPath string, err error) {
	busID := strings.ToLower(info.Slot)
	if !strings.HasPrefix(busID, "0000:") || len(busID) != 12 || busID[4] != ':' || busID[7] != ':' || busID[10] != '.' {
		return "", "", fmt.Errorf("invalid PCI Bus ID format: %s (expected 0000:XX:YY.Z)", busID)
	}

	byPathDir := "/dev/dri/by-path/"
	entries, err := os.ReadDir(byPathDir)
	if err != nil {
		return "", "", fmt.Errorf("failed to read %s: %v", byPathDir, err)
	}

	for _, entry := range entries {
		name := entry.Name()
		if strings.HasPrefix(name, "pci-"+busID+"-card") {
			cardPath, err = filepath.EvalSymlinks(filepath.Join(byPathDir, name))
			if err != nil {
				return "", "", fmt.Errorf("failed to resolve card symlink %s: %v", name, err)
			}
		}
		if strings.HasPrefix(name, "pci-"+busID+"-render") {
			renderPath, err = filepath.EvalSymlinks(filepath.Join(byPathDir, name))
			if err != nil {
				return "", "", fmt.Errorf("failed to resolve render symlink %s: %v", name, err)
			}
		}
	}

	if cardPath == "" && renderPath == "" {
		return "", "", fmt.Errorf("no DRM devices found for PCI Bus ID: %s", busID)
	}
	return cardPath, renderPath, nil
}
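
A sketch of consuming this API to print each GPU and its DRM nodes (requires lspci and a populated /dev/dri on the host):

```go
package main

import (
	"fmt"
	"log"

	"nestri/maitred/internal/system"
)

func main() {
	gpus, err := system.GetAllGPUInfo()
	if err != nil {
		log.Fatal(err)
	}
	for _, gpu := range gpus {
		card, render, err := gpu.GetCardDevices()
		if err != nil {
			log.Printf("%s: %v", gpu.Slot, err)
			continue
		}
		// e.g. "0000:03:00.0 Intel DG2 [Arc A770] -> /dev/dri/card1 /dev/dri/renderD129"
		fmt.Printf("%s %s %s -> %s %s\n", gpu.Slot, gpu.Vendor.Name, gpu.Device.Name, card, render)
	}
}
```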

packages/maitred/internal/system/gpu_intel.go (new file, 290 lines)

package system

import (
	"bufio"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"syscall"
	"time"
	"unsafe"
)

// FDInfo holds parsed fdinfo data
type FDInfo struct {
	ClientID    string
	EngineTime  uint64 // i915: "drm-engine-render" in ns
	Cycles      uint64 // Xe: "drm-cycles-rcs"
	TotalCycles uint64 // Xe: "drm-total-cycles-rcs"
	MemoryVRAM  uint64 // i915: "drm-memory-vram", Xe: "drm-total-vram0" in bytes
}

// findCardX maps a PCI slot to /dev/dri/cardX
func findCardX(pciSlot string) (string, error) {
	driPath := "/sys/class/drm"
	entries, err := os.ReadDir(driPath)
	if err != nil {
		return "", fmt.Errorf("failed to read /sys/class/drm: %v", err)
	}
	for _, entry := range entries {
		if strings.HasPrefix(entry.Name(), "card") {
			deviceLink := filepath.Join(driPath, entry.Name(), "device")
			target, err := os.Readlink(deviceLink)
			if err != nil {
				continue
			}
			if strings.Contains(target, pciSlot) {
				return entry.Name(), nil
			}
		}
	}
	return "", fmt.Errorf("no cardX found for PCI slot %s", pciSlot)
}

// getDriver retrieves the driver name
func getDriver(cardX string) (string, error) {
	driverLink := filepath.Join("/sys/class/drm", cardX, "device", "driver")
	target, err := os.Readlink(driverLink)
	if err != nil {
		return "", fmt.Errorf("failed to read driver link for %s: %v", cardX, err)
	}
	return filepath.Base(target), nil
}

// collectFDInfo gathers fdinfo data
func collectFDInfo(cardX string) ([]FDInfo, error) {
	var fdInfos []FDInfo
	clientIDs := make(map[string]struct{})

	procDirs, err := os.ReadDir("/proc")
	if err != nil {
		return nil, fmt.Errorf("failed to read /proc: %v", err)
	}

	for _, procDir := range procDirs {
		if !procDir.IsDir() {
			continue
		}
		pid := procDir.Name()
		if _, err := strconv.Atoi(pid); err != nil {
			continue
		}
		fdDir := filepath.Join("/proc", pid, "fd")
		fdEntries, err := os.ReadDir(fdDir)
		if err != nil {
			continue
		}
		for _, fdEntry := range fdEntries {
			fdPath := filepath.Join(fdDir, fdEntry.Name())
			target, err := os.Readlink(fdPath)
			if err != nil {
				continue
			}
			if target == "/dev/dri/"+cardX {
				fdinfoPath := filepath.Join("/proc", pid, "fdinfo", fdEntry.Name())
				file, err := os.Open(fdinfoPath)
				if err != nil {
					continue
				}

				scanner := bufio.NewScanner(file)
				var clientID, engineTime, cycles, totalCycles, memoryVRAM string
				for scanner.Scan() {
					line := scanner.Text()
					parts := strings.SplitN(line, ":", 2)
					if len(parts) < 2 {
						continue
					}
					key := strings.TrimSpace(parts[0])
					value := strings.TrimSpace(parts[1])
					switch key {
					case "drm-client-id":
						clientID = value
					case "drm-engine-render":
						engineTime = value
					case "drm-cycles-rcs":
						cycles = value
					case "drm-total-cycles-rcs":
						totalCycles = value
					case "drm-memory-vram", "drm-total-vram0": // i915 and Xe keys
						memoryVRAM = value
					}
				}
				if clientID == "" || clientID == "0" {
					_ = file.Close()
					continue
				}
				if _, exists := clientIDs[clientID]; exists {
					_ = file.Close()
					continue
				}
				clientIDs[clientID] = struct{}{}

				fdInfo := FDInfo{ClientID: clientID}
				if engineTime != "" {
					fdInfo.EngineTime, _ = strconv.ParseUint(engineTime, 10, 64)
				}
				if cycles != "" {
					fdInfo.Cycles, _ = strconv.ParseUint(cycles, 10, 64)
				}
				if totalCycles != "" {
					fdInfo.TotalCycles, _ = strconv.ParseUint(totalCycles, 10, 64)
				}
				if memoryVRAM != "" {
					if strings.HasSuffix(memoryVRAM, " kB") || strings.HasSuffix(memoryVRAM, " KiB") {
						memKB := strings.TrimSuffix(strings.TrimSuffix(memoryVRAM, " kB"), " KiB")
						if mem, err := strconv.ParseUint(memKB, 10, 64); err == nil {
							fdInfo.MemoryVRAM = mem * 1024 // Convert kB to bytes
						}
					} else {
						fdInfo.MemoryVRAM, _ = strconv.ParseUint(memoryVRAM, 10, 64) // Assume bytes if no unit
					}
				}
				fdInfos = append(fdInfos, fdInfo)
				_ = file.Close()
			}
		}
	}
	return fdInfos, nil
}

// drmIoctl wraps the syscall.Syscall for ioctl
func drmIoctl(fd int, request uintptr, data unsafe.Pointer) error {
	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), request, uintptr(data))
	if errno != 0 {
		return fmt.Errorf("ioctl failed: %v", errno)
	}
	return nil
}

func monitorIntelGPU(device PCIInfo) GPUUsage {
	// Map PCI slot to cardX
	cardX, err := findCardX(device.Slot)
	if err != nil {
		slog.Warn("failed to find cardX for Intel GPU", "slot", device.Slot, "error", err)
		return GPUUsage{}
	}

	// Determine driver
	driver, err := getDriver(cardX)
	if err != nil {
		slog.Warn("failed to get driver", "card", cardX, "error", err)
		return GPUUsage{}
	}
	if driver != "i915" && driver != "xe" {
		slog.Warn("unsupported Intel driver", "driver", driver, "card", cardX)
		return GPUUsage{}
	}

	// PCIInfo also has the driver, let's warn if they don't match
	if device.Driver != driver {
		slog.Warn("driver mismatch", "card", cardX, "lspci driver", device.Driver, "sysfs driver", driver)
	}

	// Open DRM device
	cardPath := "/dev/dri/" + cardX
	fd, err := syscall.Open(cardPath, syscall.O_RDWR, 0)
	if err != nil {
		slog.Error("failed to open DRM device", "path", cardPath, "error", err)
		return GPUUsage{}
	}
	defer func(fd int) {
		_ = syscall.Close(fd)
	}(fd)

	// Get total and used VRAM via ioctl
	var totalVRAM, usedVRAMFromIOCTL uint64
	if driver == "i915" {
		totalVRAM, usedVRAMFromIOCTL, err = getMemoryRegionsI915(fd)
	} else { // xe
		totalVRAM, usedVRAMFromIOCTL, err = queryMemoryRegionsXE(fd)
	}
	if err != nil {
		// Proceed with totalVRAM = 0 if the ioctl fails
		slog.Debug("failed to get memory regions", "card", cardX, "error", err)
	}

	// Collect samples for usage percentage
	firstFDInfos, err := collectFDInfo(cardX)
	if err != nil {
		slog.Warn("failed to collect first FDInfo", "card", cardX, "error", err)
		return GPUUsage{}
	}
	time.Sleep(1 * time.Second)
	secondFDInfos, err := collectFDInfo(cardX)
	if err != nil {
		slog.Warn("failed to collect second FDInfo", "card", cardX, "error", err)
		return GPUUsage{}
	}

	// Calculate usage percentage
	var usagePercent float64
	if driver == "i915" {
		var totalDeltaTime uint64
		for _, second := range secondFDInfos {
			for _, first := range firstFDInfos {
				if second.ClientID == first.ClientID {
					totalDeltaTime += second.EngineTime - first.EngineTime
					break
				}
			}
		}
		if totalDeltaTime > 0 {
			usagePercent = float64(totalDeltaTime) / 1e9 * 100 // ns over a 1s sample to percent
		}
	} else { // xe
		var totalDeltaCycles, deltaTotalCycles uint64
		for i, second := range secondFDInfos {
			for _, first := range firstFDInfos {
				if second.ClientID == first.ClientID {
					deltaCycles := second.Cycles - first.Cycles
					totalDeltaCycles += deltaCycles
					if i == 0 {
						deltaTotalCycles = second.TotalCycles - first.TotalCycles
					}
					break
				}
			}
		}
		if deltaTotalCycles > 0 {
			usagePercent = float64(totalDeltaCycles) / float64(deltaTotalCycles) * 100
		}
	}
	if usagePercent > 100 {
		usagePercent = 100
	}

	// Sum per-process VRAM usage as fallback
	var usedVRAM uint64
	for _, fdInfo := range secondFDInfos {
		usedVRAM += fdInfo.MemoryVRAM
	}

	// Prefer ioctl used VRAM if available and non-zero
	if usedVRAMFromIOCTL != 0 {
		usedVRAM = usedVRAMFromIOCTL
	}

	// Compute VRAM metrics
	var freeVRAM uint64
	var usedPercent float64
	if totalVRAM > 0 {
		if usedVRAM > totalVRAM {
			usedVRAM = totalVRAM
		}
		freeVRAM = totalVRAM - usedVRAM
		usedPercent = float64(usedVRAM) / float64(totalVRAM) * 100
	}

	return GPUUsage{
		Info:         device,
		UsagePercent: usagePercent,
		VRAM: VRAMUsage{
			Total:       totalVRAM,
			Used:        usedVRAM,
			Free:        freeVRAM,
			UsedPercent: usedPercent,
		},
	}
}
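
For context, collectFDInfo is scraping lines like these from /proc/&lt;pid&gt;/fdinfo/&lt;fd&gt;. An illustrative Xe-driver excerpt with made-up values (i915 exposes drm-engine-render and drm-memory-vram instead of the cycle counters):

```
pos:                    0
flags:                  02100002
drm-driver:             xe
drm-client-id:          42
drm-cycles-rcs:         1234567
drm-total-cycles-rcs:   7654321
drm-total-vram0:        524288 KiB
```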
86
packages/maitred/internal/system/gpu_intel_i915.go
Normal file
86
packages/maitred/internal/system/gpu_intel_i915.go
Normal file
@@ -0,0 +1,86 @@
package system

import (
	"fmt"
	"unsafe"
)

// Constants for i915
const (
	DRM_COMMAND_BASE              = 0x40
	DRM_I915_QUERY                = 0x39
	DRM_IOCTL_I915_QUERY          = 0xC0106479 // _IOWR('d', DRM_COMMAND_BASE+DRM_I915_QUERY, 16)
	DRM_I915_QUERY_MEMORY_REGIONS = 4
	I915_MEMORY_CLASS_DEVICE      = 1
)

// drmI915QueryItem mirrors struct drm_i915_query_item
type drmI915QueryItem struct {
	QueryID uintptr
	Length  int32
	Flags   uint32
	DataPtr uintptr
}

// drmI915Query mirrors struct drm_i915_query
type drmI915Query struct {
	NumItems uint32
	Flags    uint32
	ItemsPtr uintptr
}

// drmI915MemoryRegionInfo mirrors struct drm_i915_memory_region_info
type drmI915MemoryRegionInfo struct {
	Region struct {
		MemoryClass    uint16
		MemoryInstance uint16
	}
	Rsvd0           uint32
	ProbedSize      uint64
	UnallocatedSize uint64
	Rsvd1           [8]uint64
}

func getMemoryRegionsI915(fd int) (totalVRAM, usedVRAM uint64, err error) {
	// Step 1: Get the required buffer size
	item := drmI915QueryItem{
		QueryID: DRM_I915_QUERY_MEMORY_REGIONS,
		Length:  0,
	}
	query := drmI915Query{
		NumItems: 1,
		ItemsPtr: uintptr(unsafe.Pointer(&item)),
	}
	if err = drmIoctl(fd, DRM_IOCTL_I915_QUERY, unsafe.Pointer(&query)); err != nil {
		return 0, 0, fmt.Errorf("initial i915 query failed: %v", err)
	}
	if item.Length <= 0 {
		return 0, 0, fmt.Errorf("i915 query returned invalid length: %d", item.Length)
	}

	// Step 2: Allocate buffer and perform the query
	data := make([]byte, item.Length)
	item.DataPtr = uintptr(unsafe.Pointer(&data[0]))
	if err = drmIoctl(fd, DRM_IOCTL_I915_QUERY, unsafe.Pointer(&query)); err != nil {
		return 0, 0, fmt.Errorf("second i915 query failed: %v", err)
	}

	// Step 3: Parse the memory regions
	numRegions := *(*uint32)(unsafe.Pointer(&data[0]))
	headerSize := uint32(16) // num_regions (4) + rsvd[3] (12) = 16 bytes
	regionSize := uint32(88) // sizeof(drm_i915_memory_region_info): 4+4+8+8+64

	for i := uint32(0); i < numRegions; i++ {
		offset := headerSize + i*regionSize
		if offset+regionSize > uint32(len(data)) {
			return 0, 0, fmt.Errorf("data buffer too small for i915 region %d", i)
		}
		mr := (*drmI915MemoryRegionInfo)(unsafe.Pointer(&data[offset]))
		if mr.Region.MemoryClass == I915_MEMORY_CLASS_DEVICE {
			totalVRAM += mr.ProbedSize
			usedVRAM += mr.ProbedSize - mr.UnallocatedSize
		}
	}

	return totalVRAM, usedVRAM, nil
}
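Rather than hand-computing request numbers like 0xC0106479, the kernel's _IOWR encoding can be reproduced in Go. A minimal sketch (the ioWR helper is ours for illustration, not part of this package):

package main

import "fmt"

// ioWR mirrors the kernel's _IOWR macro layout: dir(2 bits) | size(14) | type(8) | nr(8)
func ioWR(typ byte, nr, size uintptr) uintptr {
	const iocWrite, iocRead = 1, 2
	return (iocRead|iocWrite)<<30 | size<<16 | uintptr(typ)<<8 | nr
}

func main() {
	// _IOWR('d', DRM_COMMAND_BASE+DRM_I915_QUERY, sizeof(drm_i915_query))
	fmt.Printf("0x%X\n", ioWR('d', 0x40+0x39, 16)) // 0xC0106479
}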
84
packages/maitred/internal/system/gpu_intel_xe.go
Normal file
@@ -0,0 +1,84 @@
package system

import (
	"fmt"
	"unsafe"
)

// Constants from xe_drm.h
const (
	DRM_XE_DEVICE_QUERY_MEM_REGIONS         = 1
	DRM_XE_MEM_REGION_CLASS_VRAM            = 1
	DRM_XE_DEVICE_QUERY                     = 0x00
	DRM_IOCTL_XE_DEVICE_QUERY       uintptr = 0xC0286440 // _IOWR('d', 0x40+DRM_XE_DEVICE_QUERY, 40)
)

// drmXEDeviceQuery mirrors struct drm_xe_device_query
type drmXEDeviceQuery struct {
	Extensions uint64
	Query      uint32
	Size       uint32
	Data       uint64
	Reserved   [2]uint64
}

// drmXEQueryMemRegions mirrors the struct drm_xe_query_mem_regions header
type drmXEQueryMemRegions struct {
	NumMemRegions uint32
	Pad           uint32
	// mem_regions[] follows
}

// drmXEMemRegion mirrors struct drm_xe_mem_region
type drmXEMemRegion struct {
	MemClass       uint16
	Instance       uint16
	MinPageSize    uint32
	TotalSize      uint64
	Used           uint64
	CPUVisibleSize uint64
	CPUVisibleUsed uint64
	Reserved       [6]uint64
}

func queryMemoryRegionsXE(fd int) (totalVRAM, usedVRAM uint64, err error) {
	// Step 1: Get the required size
	query := drmXEDeviceQuery{
		Query: DRM_XE_DEVICE_QUERY_MEM_REGIONS,
		Size:  0,
	}
	if err = drmIoctl(fd, DRM_IOCTL_XE_DEVICE_QUERY, unsafe.Pointer(&query)); err != nil {
		return 0, 0, fmt.Errorf("initial xe query failed: %v", err)
	}
	if query.Size == 0 {
		return 0, 0, fmt.Errorf("xe query returned zero size")
	}

	// Step 2: Allocate buffer and perform the query
	data := make([]byte, query.Size)
	query.Data = uint64(uintptr(unsafe.Pointer(&data[0])))
	query.Size = uint32(len(data))
	if err = drmIoctl(fd, DRM_IOCTL_XE_DEVICE_QUERY, unsafe.Pointer(&query)); err != nil {
		return 0, 0, fmt.Errorf("second xe query failed: %v", err)
	}

	// Step 3: Parse the memory regions
	header := (*drmXEQueryMemRegions)(unsafe.Pointer(&data[0]))
	numRegions := header.NumMemRegions
	headerSize := unsafe.Sizeof(drmXEQueryMemRegions{})
	regionSize := unsafe.Sizeof(drmXEMemRegion{})

	for i := uint32(0); i < numRegions; i++ {
		offset := headerSize + uintptr(i)*regionSize
		if offset+regionSize > uintptr(len(data)) {
			return 0, 0, fmt.Errorf("data buffer too small for xe region %d", i)
		}
		mr := (*drmXEMemRegion)(unsafe.Pointer(&data[offset]))
		if mr.MemClass == DRM_XE_MEM_REGION_CLASS_VRAM {
			totalVRAM += mr.TotalSize
			usedVRAM += mr.Used
		}
	}

	return totalVRAM, usedVRAM, nil
}
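A hedged call-site sketch for the query above. Since queryMemoryRegionsXE is unexported it has to live in this same package, and the render-node path is an assumption — real code should enumerate /dev/dri (imports: "fmt", "syscall"):

// exampleXEVRAM is illustrative only, not part of this PR.
func exampleXEVRAM() error {
	// Assumed device path; not every machine exposes the Intel GPU here
	fd, err := syscall.Open("/dev/dri/renderD128", syscall.O_RDWR, 0)
	if err != nil {
		return err
	}
	defer syscall.Close(fd)

	total, used, err := queryMemoryRegionsXE(fd)
	if err != nil {
		return err
	}
	fmt.Printf("xe VRAM: %d of %d bytes used\n", used, total)
	return nil
}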
57
packages/maitred/internal/system/gpu_nvidia.go
Normal file
@@ -0,0 +1,57 @@
package system

import (
	"log/slog"
	"os/exec"
	"strconv"
	"strings"
)

// monitorNVIDIAGPU monitors an NVIDIA GPU using nvidia-smi
func monitorNVIDIAGPU(device PCIInfo) GPUUsage {
	// Query nvidia-smi for GPU metrics
	cmd := exec.Command("nvidia-smi", "--query-gpu=pci.bus_id,utilization.gpu,memory.total,memory.used,memory.free", "--format=csv,noheader,nounits")
	output, err := cmd.Output()
	if err != nil {
		slog.Warn("failed to run nvidia-smi", "error", err)
		return GPUUsage{}
	}

	// Parse output and find matching GPU
	lines := strings.Split(strings.TrimSpace(string(output)), "\n")
	for _, line := range lines {
		fields := strings.Split(line, ", ")
		if len(fields) != 5 {
			continue
		}
		busID := fields[0] // e.g., "0000:01:00.0"
		if strings.Contains(busID, device.Slot) || strings.Contains(device.Slot, busID) {
			usagePercent, _ := strconv.ParseFloat(fields[1], 64)
			totalMiB, _ := strconv.ParseUint(fields[2], 10, 64)
			usedMiB, _ := strconv.ParseUint(fields[3], 10, 64)
			freeMiB, _ := strconv.ParseUint(fields[4], 10, 64)

			// Convert MiB to bytes
			total := totalMiB * 1024 * 1024
			used := usedMiB * 1024 * 1024
			free := freeMiB * 1024 * 1024
			usedPercent := float64(0)
			if total > 0 {
				usedPercent = float64(used) / float64(total) * 100
			}

			return GPUUsage{
				Info:         device,
				UsagePercent: usagePercent,
				VRAM: VRAMUsage{
					Total:       total,
					Used:        used,
					Free:        free,
					UsedPercent: usedPercent,
				},
			}
		}
	}
	slog.Warn("No NVIDIA GPU found matching PCI slot", "slot", device.Slot)
	return GPUUsage{}
}
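For reference, one row of the CSV shape this loop expects, and the split it relies on (all values hypothetical):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Hypothetical nvidia-smi row: bus id, GPU %, total/used/free MiB
	line := "00000000:01:00.0, 37, 8192, 2048, 6144"
	fields := strings.Split(line, ", ")
	fmt.Println(len(fields), fields[0]) // 5 00000000:01:00.0
}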
24
packages/maitred/internal/system/id.go
Normal file
@@ -0,0 +1,24 @@
package system

import (
	"os"
	"strings"
)

const (
	dbusPath    = "/var/lib/dbus/machine-id"
	dbusPathEtc = "/etc/machine-id"
)

// GetID returns the machine ID stored at `/var/lib/dbus/machine-id`, falling back to `/etc/machine-id`.
// If neither file can be read, an empty string and the error are returned.
func GetID() (string, error) {
	id, err := os.ReadFile(dbusPath)
	if err != nil {
		id, err = os.ReadFile(dbusPathEtc)
	}
	if err != nil {
		return "", err
	}
	return strings.Trim(string(id), " \n"), nil
}
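A minimal call-site sketch, assuming it runs inside the nestri/maitred module (the package is internal, so outside code cannot import it):

package main

import (
	"fmt"

	"nestri/maitred/internal/system"
)

func main() {
	id, err := system.GetID()
	if err != nil {
		fmt.Println("no machine-id:", err)
		return
	}
	fmt.Println("machine:", id)
}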
405
packages/maitred/internal/system/resources.go
Normal file
@@ -0,0 +1,405 @@
package system

import (
	"bufio"
	"bytes"
	"context"
	"fmt"
	"log/slog"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"sync"
	"time"
)

// CPUInfo contains CPU model information
type CPUInfo struct {
	Vendor string `json:"vendor"` // CPU vendor (e.g., "AMD", "Intel")
	Model  string `json:"model"`  // CPU model name
}

// CPUUsage contains CPU usage metrics
type CPUUsage struct {
	Info    CPUInfo   `json:"info"`     // CPU vendor and model information
	Total   float64   `json:"total"`    // Total CPU usage in percentage (0-100)
	PerCore []float64 `json:"per_core"` // CPU usage per core in percentage (0-100)
}

// MemoryUsage contains memory usage metrics
type MemoryUsage struct {
	Total       uint64  `json:"total"`        // Total memory in bytes
	Used        uint64  `json:"used"`         // Used memory in bytes
	Available   uint64  `json:"available"`    // Available memory in bytes
	Free        uint64  `json:"free"`         // Free memory in bytes
	UsedPercent float64 `json:"used_percent"` // Used memory in percentage (0-100)
}

// FilesystemUsage contains usage metrics for a filesystem path
type FilesystemUsage struct {
	Path        string  `json:"path"`         // Filesystem path
	Total       uint64  `json:"total"`        // Total disk space in bytes
	Used        uint64  `json:"used"`         // Used disk space in bytes
	Free        uint64  `json:"free"`         // Free disk space in bytes
	UsedPercent float64 `json:"used_percent"` // Used disk space in percentage (0-100)
}

// GPUUsage contains GPU usage metrics
type GPUUsage struct {
	Info         PCIInfo   `json:"pci_info"`      // GPU PCI information
	UsagePercent float64   `json:"usage_percent"` // GPU usage in percentage (0-100)
	VRAM         VRAMUsage `json:"vram"`          // GPU memory usage metrics
}

// VRAMUsage contains GPU memory usage metrics
type VRAMUsage struct {
	Total       uint64  `json:"total"`        // Total VRAM in bytes
	Used        uint64  `json:"used"`         // Used VRAM in bytes
	Free        uint64  `json:"free"`         // Free VRAM in bytes
	UsedPercent float64 `json:"used_percent"` // Used VRAM in percentage (0-100)
}

// ResourceUsage contains resource usage metrics
type ResourceUsage struct {
	CPU    CPUUsage        `json:"cpu"`    // CPU usage metrics
	Memory MemoryUsage     `json:"memory"` // Memory usage metrics
	Disk   FilesystemUsage `json:"disk"`   // Disk usage metrics
	GPUs   []GPUUsage      `json:"gpus"`   // Per-GPU usage metrics
}

var (
	lastUsage      ResourceUsage
	lastUsageMutex sync.RWMutex
)
// GetSystemUsage returns the last known system resource usage metrics
func GetSystemUsage() ResourceUsage {
	lastUsageMutex.RLock()
	defer lastUsageMutex.RUnlock()
	return lastUsage
}

// StartMonitoring begins periodic system usage monitoring with the given interval
func StartMonitoring(ctx context.Context, interval time.Duration) {
	slog.Info("Starting system monitoring")
	go func() {
		// Initial sample immediately
		updateUsage()

		// Ticker for periodic updates
		ticker := time.NewTicker(interval)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				slog.Info("Stopping system monitoring")
				return
			case <-ticker.C:
				updateUsage()
			}
		}
	}()
}
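A hedged usage sketch for the two functions above, again assuming a caller inside the nestri/maitred module:

package main

import (
	"context"
	"fmt"
	"time"

	"nestri/maitred/internal/system"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	system.StartMonitoring(ctx, 5*time.Second)

	// Readers never block the sampler; they only see the latest snapshot.
	time.Sleep(6 * time.Second)
	fmt.Println(system.GetSystemUsage().PrettyString())
}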
// updateUsage collects and updates the lastUsage variable
func updateUsage() {
	// Collect CPU usage
	cpu := GetCPUUsage()

	// Collect memory usage
	memory := GetMemoryUsage()

	// Collect root filesystem usage
	rootfs, err := GetFilesystemUsage("/")
	if err != nil {
		slog.Warn("Failed to get root filesystem usage", "error", err)
	}

	// Collect GPU usage
	gpus := GetGPUUsage()

	// Update shared variable safely
	lastUsageMutex.Lock()
	lastUsage = ResourceUsage{
		CPU:    cpu,
		Memory: memory,
		Disk:   rootfs,
		GPUs:   gpus,
	}
	lastUsageMutex.Unlock()
}
// PrettyString returns resource usage metrics as a human-readable string
func (r ResourceUsage) PrettyString() string {
	res := "Resource Usage:\n"
	res += " CPU:\n"
	res += fmt.Sprintf(" Vendor: %s\n", r.CPU.Info.Vendor)
	res += fmt.Sprintf(" Model: %s\n", r.CPU.Info.Model)
	res += fmt.Sprintf(" Total Usage: %.2f%%\n", r.CPU.Total)
	res += " Per-Core Usage:\n"
	res += " ["
	for i, coreUsage := range r.CPU.PerCore {
		res += fmt.Sprintf("%.2f%%", coreUsage)
		if i < len(r.CPU.PerCore)-1 {
			res += ", "
		}
	}
	res += "]\n"

	res += " Memory:\n"
	res += fmt.Sprintf(" Total: %d bytes\n", r.Memory.Total)
	res += fmt.Sprintf(" Used: %d bytes\n", r.Memory.Used)
	res += fmt.Sprintf(" Available: %d bytes\n", r.Memory.Available)
	res += fmt.Sprintf(" Free: %d bytes\n", r.Memory.Free)
	res += fmt.Sprintf(" Used Percent: %.2f%%\n", r.Memory.UsedPercent)

	res += " Filesystem:\n"
	res += fmt.Sprintf(" Path: %s\n", r.Disk.Path)
	res += fmt.Sprintf(" Total: %d bytes\n", r.Disk.Total)
	res += fmt.Sprintf(" Used: %d bytes\n", r.Disk.Used)
	res += fmt.Sprintf(" Free: %d bytes\n", r.Disk.Free)
	res += fmt.Sprintf(" Used Percent: %.2f%%\n", r.Disk.UsedPercent)

	res += " GPUs:\n"
	for i, gpu := range r.GPUs {
		cardDev, renderDev, err := gpu.Info.GetCardDevices()
		if err != nil {
			slog.Warn("Failed to get card and render devices", "error", err)
		}

		res += fmt.Sprintf(" GPU %d:\n", i)
		res += fmt.Sprintf(" Vendor: %s\n", gpu.Info.Vendor.Name)
		res += fmt.Sprintf(" Model: %s\n", gpu.Info.Device.Name)
		res += fmt.Sprintf(" Driver: %s\n", gpu.Info.Driver)
		res += fmt.Sprintf(" Card Device: %s\n", cardDev)
		res += fmt.Sprintf(" Render Device: %s\n", renderDev)
		res += fmt.Sprintf(" Usage Percent: %.2f%%\n", gpu.UsagePercent)
		res += " VRAM:\n"
		res += fmt.Sprintf(" Total: %d bytes\n", gpu.VRAM.Total)
		res += fmt.Sprintf(" Used: %d bytes\n", gpu.VRAM.Used)
		res += fmt.Sprintf(" Free: %d bytes\n", gpu.VRAM.Free)
		res += fmt.Sprintf(" Used Percent: %.2f%%\n", gpu.VRAM.UsedPercent)
	}

	return res
}
// GetCPUUsage gathers CPU usage
func GetCPUUsage() CPUUsage {
	// Helper to read /proc/stat
	readStat := func() (uint64, uint64, []uint64, []uint64) {
		statBytes, err := os.ReadFile("/proc/stat")
		if err != nil {
			slog.Warn("Failed to read /proc/stat", "error", err)
			return 0, 0, nil, nil
		}
		statScanner := bufio.NewScanner(bytes.NewReader(statBytes))
		statScanner.Scan() // Total CPU line
		fields := strings.Fields(statScanner.Text())[1:]
		var total, idle uint64
		for i, field := range fields {
			val, _ := strconv.ParseUint(field, 10, 64)
			total += val
			if i == 3 { // Idle time (iowait and the other fields count as busy here)
				idle = val
			}
		}

		var perCoreTotals, perCoreIdles []uint64
		for statScanner.Scan() {
			line := statScanner.Text()
			if !strings.HasPrefix(line, "cpu") {
				break
			}
			coreFields := strings.Fields(line)[1:]
			var coreTotal, coreIdle uint64
			for i, field := range coreFields {
				val, _ := strconv.ParseUint(field, 10, 64)
				coreTotal += val
				if i == 3 { // Idle time
					coreIdle = val
				}
			}
			perCoreTotals = append(perCoreTotals, coreTotal)
			perCoreIdles = append(perCoreIdles, coreIdle)
		}
		return total, idle, perCoreTotals, perCoreIdles
	}

	// First sample
	prevTotal, prevIdle, prevPerCoreTotals, prevPerCoreIdles := readStat()
	time.Sleep(1 * time.Second) // Delay for an accurate delta
	// Second sample
	currTotal, currIdle, currPerCoreTotals, currPerCoreIdles := readStat()

	// Calculate total CPU usage
	totalDiff := float64(currTotal - prevTotal)
	idleDiff := float64(currIdle - prevIdle)
	var totalUsage float64
	if totalDiff > 0 {
		totalUsage = ((totalDiff - idleDiff) / totalDiff) * 100
	}

	// Calculate per-core usage
	var perCore []float64
	for i := range currPerCoreTotals {
		if i >= len(prevPerCoreTotals) { // Guard against a failed or shorter first sample
			break
		}
		coreTotalDiff := float64(currPerCoreTotals[i] - prevPerCoreTotals[i])
		coreIdleDiff := float64(currPerCoreIdles[i] - prevPerCoreIdles[i])
		if coreTotalDiff > 0 {
			perCoreUsage := ((coreTotalDiff - coreIdleDiff) / coreTotalDiff) * 100
			perCore = append(perCore, perCoreUsage)
		} else {
			perCore = append(perCore, 0)
		}
	}

	// Get CPU info
	cpuInfoBytes, err := os.ReadFile("/proc/cpuinfo")
	if err != nil {
		slog.Warn("Failed to read /proc/cpuinfo", "error", err)
		return CPUUsage{}
	}
	cpuInfo := string(cpuInfoBytes)
	scanner := bufio.NewScanner(strings.NewReader(cpuInfo))
	var vendor, model string
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "vendor_id") {
			vendor = strings.TrimSpace(strings.Split(line, ":")[1])
		} else if strings.HasPrefix(line, "model name") {
			model = strings.TrimSpace(strings.Split(line, ":")[1])
		}
		if vendor != "" && model != "" {
			break
		}
	}

	return CPUUsage{
		Info: CPUInfo{
			Vendor: vendor,
			Model:  model,
		},
		Total:   totalUsage,
		PerCore: perCore,
	}
}
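The delta math above in isolation, with hypothetical cumulative jiffy counts from two /proc/stat reads:

package main

import "fmt"

func main() {
	// Hypothetical /proc/stat totals from two samples 1s apart
	prevTotal, prevIdle := uint64(100_000), uint64(80_000)
	currTotal, currIdle := uint64(101_000), uint64(80_600)
	totalDiff := float64(currTotal - prevTotal) // 1000 jiffies elapsed
	idleDiff := float64(currIdle - prevIdle)    // 600 of them idle
	fmt.Printf("%.0f%%\n", (totalDiff-idleDiff)/totalDiff*100) // 40%
}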
// GetMemoryUsage gathers memory usage from /proc/meminfo
func GetMemoryUsage() MemoryUsage {
	data, err := os.ReadFile("/proc/meminfo")
	if err != nil {
		slog.Warn("Failed to read /proc/meminfo", "error", err)
		return MemoryUsage{}
	}

	scanner := bufio.NewScanner(bytes.NewReader(data))
	var total, free, available uint64

	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "MemTotal:") {
			total = parseMemInfoLine(line)
		} else if strings.HasPrefix(line, "MemFree:") {
			free = parseMemInfoLine(line)
		} else if strings.HasPrefix(line, "MemAvailable:") {
			available = parseMemInfoLine(line)
		}
	}

	used := total - available
	var usedPercent float64
	if total > 0 {
		usedPercent = (float64(used) / float64(total)) * 100
	}

	return MemoryUsage{
		Total:       total * 1024, // Convert from kB to bytes
		Used:        used * 1024,
		Available:   available * 1024,
		Free:        free * 1024,
		UsedPercent: usedPercent,
	}
}

// parseMemInfoLine parses the kB value from a /proc/meminfo line
func parseMemInfoLine(line string) uint64 {
	fields := strings.Fields(line)
	val, _ := strconv.ParseUint(fields[1], 10, 64)
	return val
}
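A worked instance of the used = MemTotal − MemAvailable computation above, with hypothetical kB values:

package main

import "fmt"

func main() {
	// Hypothetical /proc/meminfo values, in kB
	total, available := uint64(16_000_000), uint64(9_500_000)
	used := total - available // 6_500_000 kB
	fmt.Printf("used: %.1f%%\n", float64(used)/float64(total)*100) // 40.6%
}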
// GetFilesystemUsage gathers usage statistics for the specified path
func GetFilesystemUsage(path string) (FilesystemUsage, error) {
	cmd := exec.Command("df", path)
	output, err := cmd.Output()
	if err != nil {
		return FilesystemUsage{}, err
	}

	lines := strings.Split(string(output), "\n")
	if len(lines) < 2 {
		return FilesystemUsage{}, fmt.Errorf("unexpected `df` output format for path: %s", path)
	}

	fields := strings.Fields(lines[1])
	if len(fields) < 5 {
		return FilesystemUsage{}, fmt.Errorf("insufficient fields in `df` output for path: %s", path)
	}

	total, err := strconv.ParseUint(fields[1], 10, 64)
	if err != nil {
		return FilesystemUsage{}, fmt.Errorf("failed to parse total space: %v", err)
	}

	used, err := strconv.ParseUint(fields[2], 10, 64)
	if err != nil {
		return FilesystemUsage{}, fmt.Errorf("failed to parse used space: %v", err)
	}

	free, err := strconv.ParseUint(fields[3], 10, 64)
	if err != nil {
		return FilesystemUsage{}, fmt.Errorf("failed to parse free space: %v", err)
	}

	usedPercent, err := strconv.ParseFloat(strings.TrimSuffix(fields[4], "%"), 64)
	if err != nil {
		return FilesystemUsage{}, fmt.Errorf("failed to parse used percentage: %v", err)
	}

	return FilesystemUsage{
		Path:        path,
		Total:       total * 1024,
		Used:        used * 1024,
		Free:        free * 1024,
		UsedPercent: usedPercent,
	}, nil
}
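Shelling out to df keeps this simple but couples the code to df's output format. A hedged alternative sketch using syscall.Statfs, which reads equivalent numbers without a subprocess (the function name is ours; it would live in this package and import "syscall"):

// filesystemUsageStatfs is a hypothetical drop-in that avoids spawning df.
// Note: df subtracts reserved blocks, so the numbers may differ slightly.
func filesystemUsageStatfs(path string) (FilesystemUsage, error) {
	var st syscall.Statfs_t
	if err := syscall.Statfs(path, &st); err != nil {
		return FilesystemUsage{}, err
	}
	bsize := uint64(st.Bsize)
	total := st.Blocks * bsize
	free := st.Bfree * bsize
	used := total - free
	var usedPercent float64
	if total > 0 {
		usedPercent = float64(used) / float64(total) * 100
	}
	return FilesystemUsage{Path: path, Total: total, Used: used, Free: free, UsedPercent: usedPercent}, nil
}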
// GetGPUUsage gathers GPU usage for all detected GPUs
func GetGPUUsage() []GPUUsage {
	var gpus []GPUUsage

	// Detect all GPUs
	pciInfos, err := GetAllGPUInfo()
	if err != nil {
		slog.Warn("Failed to get GPU info", "error", err)
		return nil
	}

	// Monitor each GPU
	for _, gpu := range pciInfos {
		var gpuUsage GPUUsage
		switch gpu.Vendor.ID {
		case VendorIntel:
			gpuUsage = monitorIntelGPU(gpu)
		case VendorNVIDIA:
			gpuUsage = monitorNVIDIAGPU(gpu)
		case VendorAMD:
			// TODO: Implement if needed
			continue
		default:
			continue
		}
		gpus = append(gpus, gpuUsage)
	}

	return gpus
}