mirror of
https://github.com/nestriness/nestri.git
synced 2025-12-12 08:45:38 +02:00
✨ feat: Add auth flow (#146)
This adds a simple way to incorporate a centralized authentication flow. The idea is to have the user, API and SSH (for machine authentication) all in one place using `openauthjs` + `SST` We also have a database now :) > We are using InstantDB as it allows us to authenticate a use with just the email. Plus it is super simple simple to use _of course after the initial fumbles trying to design the db and relationships_
This commit is contained in:
26
packages/cli/internal/api/api.go
Normal file
26
packages/cli/internal/api/api.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"nestrilabs/cli/internal/resource"
|
||||
|
||||
"github.com/nestrilabs/nestri-go-sdk"
|
||||
"github.com/nestrilabs/nestri-go-sdk/option"
|
||||
)
|
||||
|
||||
func RegisterMachine(token string) {
|
||||
client := nestri.NewClient(
|
||||
option.WithBearerToken(token),
|
||||
option.WithBaseURL(resource.Resource.Api.Url),
|
||||
)
|
||||
|
||||
machine, err := client.Machines.New(
|
||||
context.TODO(),
|
||||
nestri.MachineNewParams{})
|
||||
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
fmt.Printf("%+v\n", machine.Data)
|
||||
}
|
||||
44
packages/cli/internal/auth/auth.go
Normal file
44
packages/cli/internal/auth/auth.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"nestrilabs/cli/internal/machine"
|
||||
"nestrilabs/cli/internal/resource"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type UserCredentials struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
func FetchUserCredentials() (*UserCredentials, error) {
|
||||
m := machine.NewMachine()
|
||||
fingerprint := m.GetMachineID()
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "client_credentials")
|
||||
data.Set("client_id", "device")
|
||||
data.Set("client_secret", resource.Resource.AuthFingerprintKey.Value)
|
||||
data.Set("hostname", m.Hostname)
|
||||
data.Set("fingerprint", fingerprint)
|
||||
data.Set("provider", "device")
|
||||
resp, err := http.PostForm(resource.Resource.Auth.Url+"/token", data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
fmt.Println(string(body))
|
||||
return nil, fmt.Errorf("failed to auth: " + string(body))
|
||||
}
|
||||
credentials := UserCredentials{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &credentials, nil
|
||||
}
|
||||
202
packages/cli/internal/machine/machine.go
Normal file
202
packages/cli/internal/machine/machine.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package machine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
type Machine struct {
|
||||
OperatingSystem string
|
||||
Arch string
|
||||
Kernel string
|
||||
Virtualization string
|
||||
Hostname string
|
||||
}
|
||||
|
||||
func NewMachine() *Machine {
|
||||
var OS string
|
||||
var architecture string
|
||||
var kernel string
|
||||
var virtualisation string
|
||||
var hostname string
|
||||
|
||||
output, _ := exec.Command("hostnamectl", "status").Output()
|
||||
os := regexp.MustCompile(`Operating System:\s+(.*)`)
|
||||
matchingOS := os.FindStringSubmatch(string(output))
|
||||
if len(matchingOS) > 1 {
|
||||
OS = matchingOS[1]
|
||||
}
|
||||
|
||||
arch := regexp.MustCompile(`Architecture:\s+(\w+)`)
|
||||
matchingArch := arch.FindStringSubmatch(string(output))
|
||||
if len(matchingArch) > 1 {
|
||||
architecture = matchingArch[1]
|
||||
}
|
||||
|
||||
kern := regexp.MustCompile(`Kernel:\s+(.*)`)
|
||||
matchingKernel := kern.FindStringSubmatch(string(output))
|
||||
if len(matchingKernel) > 1 {
|
||||
kernel = matchingKernel[1]
|
||||
}
|
||||
|
||||
virt := regexp.MustCompile(`Virtualization:\s+(\w+)`)
|
||||
matchingVirt := virt.FindStringSubmatch(string(output))
|
||||
if len(matchingVirt) > 1 {
|
||||
virtualisation = matchingVirt[1]
|
||||
}
|
||||
|
||||
host := regexp.MustCompile(`Static hostname:\s+(.*)`)
|
||||
matchingHost := host.FindStringSubmatch(string(output))
|
||||
if len(matchingHost) > 1 {
|
||||
hostname = matchingHost[1]
|
||||
}
|
||||
|
||||
return &Machine{
|
||||
OperatingSystem: OS,
|
||||
Arch: architecture,
|
||||
Kernel: kernel,
|
||||
Virtualization: virtualisation,
|
||||
Hostname: hostname,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Machine) GetOS() string {
|
||||
if m.OperatingSystem != "" {
|
||||
return m.OperatingSystem
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (m *Machine) GetArchitecture() string {
|
||||
|
||||
if m.Arch != "" {
|
||||
return m.Arch
|
||||
}
|
||||
return "unknown"
|
||||
|
||||
}
|
||||
|
||||
func (m *Machine) GetKernel() string {
|
||||
if m.Kernel != "" {
|
||||
return m.Kernel
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (m *Machine) GetVirtualization() string {
|
||||
if m.Virtualization != "" {
|
||||
return m.Virtualization
|
||||
}
|
||||
return "none"
|
||||
}
|
||||
|
||||
func (m *Machine) GetHostname() string {
|
||||
if m.Hostname != "" {
|
||||
return m.Hostname
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (m *Machine) GetMachineID() string {
|
||||
id, err := os.ReadFile("/etc/machine-id")
|
||||
if err != nil {
|
||||
log.Error("Error getting your machine's ID", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return strings.TrimSpace(string(id))
|
||||
}
|
||||
|
||||
func (m *Machine) GPUInfo() (string, string, error) {
|
||||
// The command for GPU information varies depending on the system and drivers.
|
||||
// lshw is a good general-purpose tool, but might need adjustments for specific hardware.
|
||||
output, err := exec.Command("lshw", "-C", "display").Output()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get GPU information: %w", err)
|
||||
}
|
||||
|
||||
gpuType := ""
|
||||
gpuSize := ""
|
||||
|
||||
// Regular expressions for extracting product and size information. These might need to be
|
||||
// adapted based on the output of lshw on your specific system.
|
||||
typeRegex := regexp.MustCompile(`product:\s+(.*)`)
|
||||
sizeRegex := regexp.MustCompile(`size:\s+(\d+MiB)`) // Example: extracts size in MiB
|
||||
|
||||
typeMatch := typeRegex.FindStringSubmatch(string(output))
|
||||
if len(typeMatch) > 1 {
|
||||
gpuType = typeMatch[1]
|
||||
}
|
||||
|
||||
sizeMatch := sizeRegex.FindStringSubmatch(string(output))
|
||||
if len(sizeMatch) > 1 {
|
||||
gpuSize = sizeMatch[1]
|
||||
}
|
||||
|
||||
if gpuType == "" && gpuSize == "" {
|
||||
return "", "", fmt.Errorf("could not parse GPU information using lshw")
|
||||
}
|
||||
|
||||
return gpuType, gpuSize, nil
|
||||
}
|
||||
|
||||
func (m *Machine) GetCPUInfo() (string, string, error) {
|
||||
output, err := exec.Command("lscpu").Output()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get CPU information: %w", err)
|
||||
}
|
||||
|
||||
cpuType := ""
|
||||
cpuSize := "" // This will store the number of cores
|
||||
|
||||
typeRegex := regexp.MustCompile(`Model name:\s+(.*)`)
|
||||
coresRegex := regexp.MustCompile(`CPU\(s\):\s+(\d+)`)
|
||||
|
||||
typeMatch := typeRegex.FindStringSubmatch(string(output))
|
||||
if len(typeMatch) > 1 {
|
||||
cpuType = typeMatch[1]
|
||||
}
|
||||
|
||||
coresMatch := coresRegex.FindStringSubmatch(string(output))
|
||||
if len(coresMatch) > 1 {
|
||||
cpuSize = coresMatch[1]
|
||||
}
|
||||
|
||||
if cpuType == "" && cpuSize == "" {
|
||||
return "", "", fmt.Errorf("could not parse CPU information using lscpu")
|
||||
}
|
||||
|
||||
return cpuType, cpuSize, nil
|
||||
|
||||
}
|
||||
|
||||
func (m *Machine) GetRAMSize() (string, error) {
|
||||
output, err := exec.Command("free", "-h", "--si").Output() // Using -h for human-readable and --si for base-10 units
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get RAM information: %w", err)
|
||||
}
|
||||
|
||||
ramSize := ""
|
||||
|
||||
ramRegex := regexp.MustCompile(`Mem:\s+(\S+)`) // Matches the total memory size
|
||||
|
||||
ramMatch := ramRegex.FindStringSubmatch(string(output))
|
||||
if len(ramMatch) > 1 {
|
||||
ramSize = ramMatch[1]
|
||||
} else {
|
||||
return "", fmt.Errorf("could not parse RAM information from free command")
|
||||
}
|
||||
|
||||
return ramSize, nil
|
||||
}
|
||||
|
||||
// func cleanString(s string) string {
|
||||
// s = strings.ToLower(s)
|
||||
|
||||
// reg := regexp.MustCompile("[^a-z0-9]+") // Matches one or more non-alphanumeric characters
|
||||
// return reg.ReplaceAllString(s, "")
|
||||
// }
|
||||
112
packages/cli/internal/party/client.go
Normal file
112
packages/cli/internal/party/client.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package party
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"nestrilabs/cli/internal/machine"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
// Initial retry delay
|
||||
initialRetryDelay = 1 * time.Second
|
||||
// Maximum retry delay
|
||||
maxRetryDelay = 30 * time.Second
|
||||
// Factor to increase delay by after each attempt
|
||||
backoffFactor = 2
|
||||
)
|
||||
|
||||
type Party struct {
|
||||
// Channel to signal shutdown
|
||||
done chan struct{}
|
||||
fingerprint string
|
||||
hostname string
|
||||
}
|
||||
|
||||
func NewParty() *Party {
|
||||
m := machine.NewMachine()
|
||||
fingerpint := m.GetMachineID()
|
||||
return &Party{
|
||||
done: make(chan struct{}),
|
||||
fingerprint: fingerpint,
|
||||
hostname: m.Hostname,
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully closes the connection
|
||||
func (p *Party) Shutdown() {
|
||||
close(p.done)
|
||||
}
|
||||
|
||||
func (p *Party) Connect() {
|
||||
baseURL := fmt.Sprintf("ws://localhost:1999/parties/main/%s", p.fingerprint)
|
||||
params := url.Values{}
|
||||
params.Add("_pk", p.hostname)
|
||||
wsURL := baseURL + "?" + params.Encode()
|
||||
|
||||
retryDelay := initialRetryDelay
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.done:
|
||||
log.Info("Shutting down connection")
|
||||
return
|
||||
default:
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
log.Error("Failed to connect to party server", "err", err)
|
||||
time.Sleep(retryDelay)
|
||||
// Increase retry delay exponentially, but cap it
|
||||
retryDelay = time.Duration(float64(retryDelay) * backoffFactor)
|
||||
if retryDelay > maxRetryDelay {
|
||||
retryDelay = maxRetryDelay
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Reset retry delay on successful connection
|
||||
retryDelay = initialRetryDelay
|
||||
|
||||
// Handle connection in a separate goroutine
|
||||
connectionClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer close(connectionClosed)
|
||||
defer conn.Close()
|
||||
|
||||
// Send initial message
|
||||
if err := conn.WriteMessage(websocket.TextMessage, []byte("hello there")); err != nil {
|
||||
log.Error("Failed to send initial message", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Read messages loop
|
||||
for {
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
default:
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
log.Error("Error reading message", "err", err)
|
||||
return
|
||||
}
|
||||
log.Info("Received message from party server", "message", string(message))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for either connection to close or shutdown signal
|
||||
select {
|
||||
case <-connectionClosed:
|
||||
log.Warn("Connection closed, attempting to reconnect...")
|
||||
time.Sleep(retryDelay)
|
||||
case <-p.done:
|
||||
log.Info("Shutting down connection")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
125
packages/cli/internal/party/retry.go
Normal file
125
packages/cli/internal/party/retry.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package party
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"nestrilabs/cli/internal/machine"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// RetryConfig holds configuration for retry behavior
|
||||
type RetryConfig struct {
|
||||
InitialDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
BackoffFactor float64
|
||||
MaxAttempts int // use 0 for infinite retries
|
||||
}
|
||||
|
||||
// DefaultRetryConfig provides sensible default values
|
||||
var DefaultRetryConfig = RetryConfig{
|
||||
InitialDelay: time.Second,
|
||||
MaxDelay: 30 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
MaxAttempts: 0, // infinite retries
|
||||
}
|
||||
|
||||
// RetryFunc is a function that will be retried
|
||||
type RetryFunc[T any] func() (T, error)
|
||||
|
||||
// Retry executes the given function with retries based on the config
|
||||
func Retry[T any](config RetryConfig, operation RetryFunc[T]) (T, error) {
|
||||
var result T
|
||||
currentDelay := config.InitialDelay
|
||||
attempts := 0
|
||||
|
||||
for {
|
||||
if config.MaxAttempts > 0 && attempts >= config.MaxAttempts {
|
||||
return result, fmt.Errorf("max retry attempts (%d) exceeded", config.MaxAttempts)
|
||||
}
|
||||
|
||||
result, err := operation()
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
log.Warn("Operation failed, retrying...",
|
||||
"attempt", attempts+1,
|
||||
"delay", currentDelay,
|
||||
"error", err)
|
||||
|
||||
time.Sleep(currentDelay)
|
||||
|
||||
// Increase delay for next attempt
|
||||
currentDelay = time.Duration(float64(currentDelay) * config.BackoffFactor)
|
||||
if currentDelay > config.MaxDelay {
|
||||
currentDelay = config.MaxDelay
|
||||
}
|
||||
|
||||
attempts++
|
||||
}
|
||||
}
|
||||
|
||||
// MessageHandler processes a message and returns true if it's the expected type
|
||||
type MessageHandler[T any] func(msg T) bool
|
||||
|
||||
type TypeListener[T any] struct {
|
||||
retryConfig RetryConfig
|
||||
handler MessageHandler[T]
|
||||
fingerprint string
|
||||
hostname string
|
||||
}
|
||||
|
||||
func NewTypeListener[T any](handler MessageHandler[T]) *TypeListener[T] {
|
||||
m := machine.NewMachine()
|
||||
fingerprint := m.GetMachineID()
|
||||
|
||||
return &TypeListener[T]{
|
||||
retryConfig: DefaultRetryConfig,
|
||||
handler: handler,
|
||||
fingerprint: fingerprint,
|
||||
hostname: m.Hostname,
|
||||
}
|
||||
}
|
||||
|
||||
// SetRetryConfig allows customizing the retry behavior
|
||||
func (t *TypeListener[T]) SetRetryConfig(config RetryConfig) {
|
||||
t.retryConfig = config
|
||||
}
|
||||
|
||||
func (t *TypeListener[T]) ConnectUntilMessage() (T, error) {
|
||||
baseURL := fmt.Sprintf("ws://localhost:1999/parties/main/%s", t.fingerprint)
|
||||
params := url.Values{}
|
||||
params.Add("_pk", t.hostname)
|
||||
wsURL := baseURL + "?" + params.Encode()
|
||||
|
||||
return Retry(t.retryConfig, func() (T, error) {
|
||||
var result T
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("connection failed: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Read messages until we get the one we want
|
||||
for {
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("read error: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(message, &result); err != nil {
|
||||
// log.Error("Failed to unmarshal message", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if t.handler(result) {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
38
packages/cli/internal/resource/resource.go
Normal file
38
packages/cli/internal/resource/resource.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package resource
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type resource struct {
|
||||
Api struct {
|
||||
Url string `json:"url"`
|
||||
}
|
||||
Auth struct {
|
||||
Url string `json:"url"`
|
||||
}
|
||||
AuthFingerprintKey struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
}
|
||||
|
||||
var Resource resource
|
||||
|
||||
func init() {
|
||||
val := reflect.ValueOf(&Resource).Elem()
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := val.Field(i)
|
||||
typeField := val.Type().Field(i)
|
||||
envVarName := fmt.Sprintf("SST_RESOURCE_%s", typeField.Name)
|
||||
envValue, exists := os.LookupEnv(envVarName)
|
||||
if !exists {
|
||||
panic(fmt.Sprintf("Environment variable %s is required", envVarName))
|
||||
}
|
||||
if err := json.Unmarshal([]byte(envValue), field.Addr().Interface()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
286
packages/cli/internal/session/start.go
Normal file
286
packages/cli/internal/session/start.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/image"
|
||||
"github.com/docker/docker/client"
|
||||
)
|
||||
|
||||
// GPUType represents the type of GPU available
|
||||
type GPUType int
|
||||
|
||||
const (
|
||||
GPUNone GPUType = iota
|
||||
GPUNvidia
|
||||
GPUIntelAMD
|
||||
)
|
||||
|
||||
// Session represents a Docker container session
|
||||
type Session struct {
|
||||
client *client.Client
|
||||
containerID string
|
||||
imageName string
|
||||
config *SessionConfig
|
||||
mu sync.RWMutex
|
||||
isRunning bool
|
||||
}
|
||||
|
||||
// SessionConfig holds the configuration for the session
|
||||
type SessionConfig struct {
|
||||
Room string
|
||||
Resolution string
|
||||
Framerate string
|
||||
RelayURL string
|
||||
Params string
|
||||
GamePath string
|
||||
}
|
||||
|
||||
// NewSession creates a new Docker session
|
||||
func NewSession(config *SessionConfig) (*Session, error) {
|
||||
cli, err := client.NewClientWithOpts(client.FromEnv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Docker client: %v", err)
|
||||
}
|
||||
|
||||
return &Session{
|
||||
client: cli,
|
||||
imageName: "archlinux", //"ghcr.io/datcaptainhorse/nestri-cachyos:latest-noavx2",
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start initiates the Docker container session
|
||||
func (s *Session) Start(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.isRunning {
|
||||
return fmt.Errorf("session is already running")
|
||||
}
|
||||
|
||||
// Detect GPU type
|
||||
gpuType := detectGPU()
|
||||
if gpuType == GPUNone {
|
||||
return fmt.Errorf("no supported GPU detected")
|
||||
}
|
||||
|
||||
// Get GPU-specific configurations
|
||||
deviceRequests, err := getGPUDeviceRequests(gpuType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
devices := getGPUDevices(gpuType)
|
||||
|
||||
// Check if image exists locally
|
||||
_, _, err = s.client.ImageInspectWithRaw(ctx, s.imageName)
|
||||
if err != nil {
|
||||
// Pull the image if it doesn't exist
|
||||
reader, err := s.client.ImagePull(ctx, s.imageName, image.PullOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pull image: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Copy pull output to stdout
|
||||
io.Copy(os.Stdout, reader)
|
||||
}
|
||||
|
||||
// Create container
|
||||
resp, err := s.client.ContainerCreate(ctx, &container.Config{
|
||||
Image: s.imageName,
|
||||
Env: []string{
|
||||
fmt.Sprintf("NESTRI_ROOM=%s", s.config.Room),
|
||||
fmt.Sprintf("RESOLUTION=%s", s.config.Resolution),
|
||||
fmt.Sprintf("NESTRI_PARAMS=%s", s.config.Params),
|
||||
fmt.Sprintf("FRAMERATE=%s", s.config.Framerate),
|
||||
fmt.Sprintf("RELAY_URL=%s", s.config.RelayURL),
|
||||
},
|
||||
}, &container.HostConfig{
|
||||
Binds: []string{
|
||||
fmt.Sprintf("%s:/home/nestri/.steam/", s.config.GamePath),
|
||||
},
|
||||
Resources: container.Resources{
|
||||
DeviceRequests: deviceRequests,
|
||||
Devices: devices,
|
||||
},
|
||||
SecurityOpt: []string{"label=disable"},
|
||||
ShmSize: 5368709120, // 5GB
|
||||
// ShmSize: 1073741824, // 1GB
|
||||
}, nil, nil, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create container: %v", err)
|
||||
}
|
||||
|
||||
// Start container
|
||||
if err := s.client.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
|
||||
return fmt.Errorf("failed to start container: %v", err)
|
||||
}
|
||||
|
||||
// Store container ID and update state
|
||||
s.containerID = resp.ID
|
||||
s.isRunning = true
|
||||
|
||||
// Start logging in a goroutine
|
||||
go s.streamLogs(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the Docker container session
|
||||
func (s *Session) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.isRunning {
|
||||
return fmt.Errorf("session is not running")
|
||||
}
|
||||
|
||||
timeout := 30 // seconds
|
||||
if err := s.client.ContainerStop(ctx, s.containerID, container.StopOptions{Timeout: &timeout}); err != nil {
|
||||
return fmt.Errorf("failed to stop container: %v", err)
|
||||
}
|
||||
|
||||
if err := s.client.ContainerRemove(ctx, s.containerID, container.RemoveOptions{}); err != nil {
|
||||
return fmt.Errorf("failed to remove container: %v", err)
|
||||
}
|
||||
|
||||
s.isRunning = false
|
||||
s.containerID = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsRunning returns the current state of the session
|
||||
func (s *Session) IsRunning() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.isRunning
|
||||
}
|
||||
|
||||
// GetContainerID returns the current container ID
|
||||
func (s *Session) GetContainerID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.containerID
|
||||
}
|
||||
|
||||
// streamLogs streams container logs to stdout
|
||||
func (s *Session) streamLogs(ctx context.Context) {
|
||||
opts := container.LogsOptions{
|
||||
ShowStdout: true,
|
||||
ShowStderr: true,
|
||||
Follow: true,
|
||||
}
|
||||
|
||||
logs, err := s.client.ContainerLogs(ctx, s.containerID, opts)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error getting container logs: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer logs.Close()
|
||||
|
||||
_, err = io.Copy(os.Stdout, logs)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error streaming logs: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyEnvironment checks if all expected environment variables are set correctly in the container
|
||||
func (s *Session) VerifyEnvironment(ctx context.Context) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if !s.isRunning {
|
||||
return fmt.Errorf("session is not running")
|
||||
}
|
||||
|
||||
// Get container info to verify it's actually running
|
||||
inspect, err := s.client.ContainerInspect(ctx, s.containerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to inspect container: %v", err)
|
||||
}
|
||||
|
||||
if !inspect.State.Running {
|
||||
return fmt.Errorf("container is not in running state")
|
||||
}
|
||||
|
||||
// Expected environment variables
|
||||
expectedEnv := map[string]string{
|
||||
"NESTRI_ROOM": s.config.Room,
|
||||
"RESOLUTION": s.config.Resolution,
|
||||
"FRAMERATE": s.config.Framerate,
|
||||
"RELAY_URL": s.config.RelayURL,
|
||||
"NESTRI_PARAMS": s.config.Params,
|
||||
}
|
||||
|
||||
// Get actual environment variables from container
|
||||
containerEnv := make(map[string]string)
|
||||
for _, env := range inspect.Config.Env {
|
||||
parts := strings.SplitN(env, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
containerEnv[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
// Check each expected variable
|
||||
var missingVars []string
|
||||
var mismatchedVars []string
|
||||
|
||||
for key, expectedValue := range expectedEnv {
|
||||
actualValue, exists := containerEnv[key]
|
||||
if !exists {
|
||||
missingVars = append(missingVars, key)
|
||||
} else if actualValue != expectedValue {
|
||||
mismatchedVars = append(mismatchedVars, fmt.Sprintf("%s (expected: %s, got: %s)",
|
||||
key, expectedValue, actualValue))
|
||||
}
|
||||
}
|
||||
|
||||
// Build error message if there are any issues
|
||||
if len(missingVars) > 0 || len(mismatchedVars) > 0 {
|
||||
var errorMsg strings.Builder
|
||||
if len(missingVars) > 0 {
|
||||
errorMsg.WriteString(fmt.Sprintf("Missing environment variables: %s\n",
|
||||
strings.Join(missingVars, ", ")))
|
||||
}
|
||||
if len(mismatchedVars) > 0 {
|
||||
errorMsg.WriteString(fmt.Sprintf("Mismatched environment variables: %s",
|
||||
strings.Join(mismatchedVars, ", ")))
|
||||
}
|
||||
return fmt.Errorf(errorMsg.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEnvironment returns all environment variables in the container
|
||||
func (s *Session) GetEnvironment(ctx context.Context) (map[string]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if !s.isRunning {
|
||||
return nil, fmt.Errorf("session is not running")
|
||||
}
|
||||
|
||||
inspect, err := s.client.ContainerInspect(ctx, s.containerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to inspect container: %v", err)
|
||||
}
|
||||
|
||||
env := make(map[string]string)
|
||||
for _, e := range inspect.Config.Env {
|
||||
parts := strings.SplitN(e, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
env[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
return env, nil
|
||||
}
|
||||
76
packages/cli/internal/session/steam.go
Normal file
76
packages/cli/internal/session/steam.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
)
|
||||
|
||||
// ExecResult holds the output from a container command
|
||||
type ExecResult struct {
|
||||
ExitCode int
|
||||
Stdout string
|
||||
Stderr string
|
||||
}
|
||||
|
||||
func (s *Session) execInContainer(ctx context.Context, cmd []string) (*ExecResult, error) {
|
||||
execConfig := container.ExecOptions{
|
||||
Cmd: cmd,
|
||||
AttachStdout: true,
|
||||
AttachStderr: true,
|
||||
}
|
||||
|
||||
execID, err := s.client.ContainerExecCreate(ctx, s.containerID, execConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := s.client.ContainerExecAttach(ctx, execID.ID, container.ExecAttachOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
_, err = io.Copy(&outBuf, resp.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inspect, err := s.client.ContainerExecInspect(ctx, execID.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ExecResult{
|
||||
ExitCode: inspect.ExitCode,
|
||||
Stdout: outBuf.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CheckSteamGames returns the list of installed games in the container
|
||||
func (s *Session) CheckInstalledSteamGames(ctx context.Context) ([]uint64, error) {
|
||||
result, err := s.execInContainer(ctx, []string{
|
||||
"sh", "-c",
|
||||
"find /home/nestri/.steam/steam/steamapps -name '*.acf' -exec grep -H '\"appid\"' {} \\;",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check steam games: %v", err)
|
||||
}
|
||||
|
||||
var gameIDs []uint64
|
||||
for _, line := range strings.Split(result.Stdout, "\n") {
|
||||
if strings.Contains(line, "appid") {
|
||||
var id uint64
|
||||
if _, err := fmt.Sscanf(line, `"appid" "%d"`, &id); err == nil {
|
||||
gameIDs = append(gameIDs, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return gameIDs, nil
|
||||
}
|
||||
72
packages/cli/internal/session/utils.go
Normal file
72
packages/cli/internal/session/utils.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
)
|
||||
|
||||
// detectGPU checks for available GPU type
|
||||
func detectGPU() GPUType {
|
||||
// First check for NVIDIA
|
||||
cmd := exec.Command("nvidia-smi")
|
||||
if err := cmd.Run(); err == nil {
|
||||
return GPUNvidia
|
||||
}
|
||||
|
||||
// Check for Intel/AMD GPU by looking for DRI devices
|
||||
if _, err := os.Stat("/dev/dri"); err == nil {
|
||||
return GPUIntelAMD
|
||||
}
|
||||
|
||||
return GPUNone
|
||||
}
|
||||
|
||||
// getGPUDeviceRequests returns appropriate device configuration based on GPU type
|
||||
func getGPUDeviceRequests(gpuType GPUType) ([]container.DeviceRequest, error) {
|
||||
switch gpuType {
|
||||
case GPUNvidia:
|
||||
return []container.DeviceRequest{
|
||||
{
|
||||
Driver: "nvidia",
|
||||
Count: 1,
|
||||
DeviceIDs: []string{"0"},
|
||||
Capabilities: [][]string{{"gpu"}},
|
||||
},
|
||||
}, nil
|
||||
case GPUIntelAMD:
|
||||
return []container.DeviceRequest{}, nil // Empty as we'll handle this in Devices
|
||||
default:
|
||||
return nil, fmt.Errorf("no supported GPU detected")
|
||||
}
|
||||
}
|
||||
|
||||
// getGPUDevices returns appropriate device mappings based on GPU type
|
||||
func getGPUDevices(gpuType GPUType) []container.DeviceMapping {
|
||||
if gpuType == GPUIntelAMD {
|
||||
devices := []container.DeviceMapping{}
|
||||
// Only look for card and renderD nodes
|
||||
for _, pattern := range []string{"card[0-9]*", "renderD[0-9]*"} {
|
||||
matches, err := filepath.Glob(fmt.Sprintf("/dev/dri/%s", pattern))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
// Verify it's a device file
|
||||
if info, err := os.Stat(match); err == nil && (info.Mode()&os.ModeDevice) != 0 {
|
||||
devices = append(devices, container.DeviceMapping{
|
||||
PathOnHost: match,
|
||||
PathInContainer: match,
|
||||
CgroupPermissions: "rwm",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return devices
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user