feat: Add a websocket party (#152)

This adds functionality to connect to remote server thru the party
This commit is contained in:
Wanjohi
2025-01-05 23:45:41 +03:00
committed by GitHub
parent c15657a0d1
commit 56b877fa27
10 changed files with 384 additions and 252 deletions

View File

@@ -3,6 +3,8 @@ package party
import ( import (
"fmt" "fmt"
"nestrilabs/cli/internal/machine" "nestrilabs/cli/internal/machine"
"nestrilabs/cli/internal/resource"
"net/http"
"net/url" "net/url"
"time" "time"
@@ -48,6 +50,9 @@ func (p *Party) Connect() {
wsURL := baseURL + "?" + params.Encode() wsURL := baseURL + "?" + params.Encode()
retryDelay := initialRetryDelay retryDelay := initialRetryDelay
header := http.Header{}
bearer := fmt.Sprintf("Bearer %s", resource.Resource.AuthFingerprintKey.Value)
header.Add("Authorization", bearer)
for { for {
select { select {
@@ -55,7 +60,7 @@ func (p *Party) Connect() {
log.Info("Shutting down connection") log.Info("Shutting down connection")
return return
default: default:
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) conn, _, err := websocket.DefaultDialer.Dial(wsURL, header)
if err != nil { if err != nil {
log.Error("Failed to connect to party server", "err", err) log.Error("Failed to connect to party server", "err", err)
time.Sleep(retryDelay) time.Sleep(retryDelay)
@@ -66,6 +71,7 @@ func (p *Party) Connect() {
} }
continue continue
} }
log.Info("Connection to server", "url", wsURL)
// Reset retry delay on successful connection // Reset retry delay on successful connection
retryDelay = initialRetryDelay retryDelay = initialRetryDelay
@@ -77,10 +83,10 @@ func (p *Party) Connect() {
defer conn.Close() defer conn.Close()
// Send initial message // Send initial message
if err := conn.WriteMessage(websocket.TextMessage, []byte("hello there")); err != nil { // if err := conn.WriteMessage(websocket.TextMessage, []byte("hello there")); err != nil {
log.Error("Failed to send initial message", "err", err) // log.Error("Failed to send initial message", "err", err)
return // return
} // }
// Read messages loop // Read messages loop
for { for {

View File

@@ -1,10 +1,7 @@
package main package main
import ( import (
"context" "nestrilabs/cli/internal/party"
"nestrilabs/cli/internal/session"
"github.com/charmbracelet/log"
) )
func main() { func main() {
@@ -13,46 +10,49 @@ func main() {
// log.Error("Error running the cmd command", "err", err) // log.Error("Error running the cmd command", "err", err)
// } // }
ctx := context.Background() // ctx := context.Background()
config := &session.SessionConfig{ // config := &session.SessionConfig{
Room: "victortest", // Room: "victortest",
Resolution: "1920x1080", // Resolution: "1920x1080",
Framerate: "60", // Framerate: "60",
RelayURL: "https://relay.dathorse.com", // RelayURL: "https://relay.dathorse.com",
Params: "--verbose=true --video-codec=h264 --video-bitrate=4000 --video-bitrate-max=6000 --gpu-card-path=/dev/dri/card1", // Params: "--verbose=true --video-codec=h264 --video-bitrate=4000 --video-bitrate-max=6000 --gpu-card-path=/dev/dri/card1",
GamePath: "/path/to/your/game", // GamePath: "/path/to/your/game",
} // }
sess, err := session.NewSession(config) // sess, err := session.NewSession(config)
if err != nil { // if err != nil {
log.Error("Failed to create session", "err", err) // log.Error("Failed to create session", "err", err)
} // }
// Start the session // // Start the session
if err := sess.Start(ctx); err != nil { // if err := sess.Start(ctx); err != nil {
log.Error("Failed to start session", "err", err) // log.Error("Failed to start session", "err", err)
} // }
// Check if it's running // // Check if it's running
if sess.IsRunning() { // if sess.IsRunning() {
log.Info("Session is running with container ID", "containerId", sess.GetContainerID()) // log.Info("Session is running with container ID", "containerId", sess.GetContainerID())
} // }
env, err := sess.GetEnvironment(ctx) // env, err := sess.GetEnvironment(ctx)
if err != nil { // if err != nil {
log.Printf("Failed to get environment: %v", err) // log.Printf("Failed to get environment: %v", err)
} else { // } else {
for key, value := range env { // for key, value := range env {
log.Info("Found this environment variables", key, value) // log.Info("Found this environment variables", key, value)
} // }
} // }
// Let it run for a while // // Let it run for a while
// time.Sleep(time.Second * 50) // // time.Sleep(time.Second * 50)
// Stop the session // // Stop the session
if err := sess.Stop(ctx); err != nil { // if err := sess.Stop(ctx); err != nil {
log.Error("Failed to stop session", "err", err) // log.Error("Failed to stop session", "err", err)
} // }
party := party.NewParty()
party.Connect()
} }

View File

@@ -133,7 +133,7 @@ app.get(
title: "Nestri API", title: "Nestri API",
description: description:
"The Nestri API gives you the power to run your own customized cloud gaming platform.", "The Nestri API gives you the power to run your own customized cloud gaming platform.",
version: "0.0.3", version: "0.3.0",
}, },
components: { components: {
securitySchemes: { securitySchemes: {

View File

@@ -101,7 +101,7 @@ export default {
const hostname = url.hostname; const hostname = url.hostname;
if (hostname.endsWith("nestri.io")) return true; if (hostname.endsWith("nestri.io")) return true;
if (hostname === "localhost") return true; if (hostname === "localhost") return true;
return true; return false;
}, },
success: async (ctx, value) => { success: async (ctx, value) => {
if (value.provider === "device") { if (value.provider === "device") {

View File

@@ -1,82 +0,0 @@
import { z } from "zod";
import { Hono } from "hono";
import { Result } from "../common"
import { describeRoute } from "hono-openapi";
import type * as Party from "partykit/server";
import { validator, resolver } from "hono-openapi/zod";
const paramsObj = z.object({
code: z.string(),
state: z.string()
})
export module AuthApi {
export const route = new Hono()
.get("/:connection",
describeRoute({
tags: ["Auth"],
summary: "Authenticate the remote device",
description: "This is a callback function to authenticate the remote device.",
responses: {
200: {
content: {
"application/json": {
schema: Result(z.literal("Device authenticated successfully"))
},
},
description: "Authentication successful.",
},
404: {
content: {
"application/json": {
schema: resolver(z.object({ error: z.string() })),
},
},
description: "This device does not exist.",
},
},
}),
validator(
"param",
z.object({
connection: z.string().openapi({
description: "The hostname of the device to login to.",
example: "desktopeuo8vsf",
}),
}),
),
async (c) => {
const param = c.req.valid("param");
const env = c.env as any
const room = env.room as Party.Room
// const connection = room.getConnection(param.connection)
// if (!connection) {
// return c.json({ error: "This device does not exist." }, 404);
// }
// const authParams = getUrlParams(new URL(c.req.url))
// const res = paramsObj.safeParse(authParams)
// if (res.error) {
// return c.json({ error: "Expected url params are missing" })
// }
// connection.send(JSON.stringify({ ...authParams, type: "auth" }))
// FIXME:We just assume the authentication was successful, might wanna do some questioning in the future
return c.text("Device authenticated successfully")
}
)
}
function getUrlParams(url: URL) {
const urlString = url.toString()
const hash = urlString.substring(urlString.indexOf('?') + 1); // Extract the part after the #
const params = new URLSearchParams(hash);
const paramsObj = {} as any;
for (const [key, value] of params.entries()) {
paramsObj[key] = decodeURIComponent(value);
}
return paramsObj;
}

View File

@@ -1,116 +1,65 @@
import "zod-openapi/extend"; import "zod-openapi/extend";
import type * as Party from "partykit/server"; import { Hono } from "hono";
// import { Resource } from "sst";
import { ZodError } from "zod";
import { logger } from "hono/logger"; import { logger } from "hono/logger";
// import { subjects } from "../subjects"; import type { HonoBindings } from "./types";
import { VisibleError } from "../error"; import { ApiSession } from "./session";
// import { ActorContext } from '@nestri/core/actor'; import { openAPISpecs } from "hono-openapi";
import { Hono, type MiddlewareHandler } from "hono";
import { HTTPException } from "hono/http-exception";
import { AuthApi } from "./auth";
const app = new Hono<{ Bindings: HonoBindings }>().basePath('/parties/main/:room');
const app = new Hono().basePath('/parties/main/:id');
// const auth: MiddlewareHandler = async (c, next) => {
// const client = createClient({
// clientID: "api",
// issuer: "http://auth.nestri.io" //Resource.Urls.auth
// });
// const authHeader =
// c.req.query("authorization") ?? c.req.header("authorization");
// if (authHeader) {
// const match = authHeader.match(/^Bearer (.+)$/);
// if (!match || !match[1]) {
// throw new VisibleError(
// "input",
// "auth.token",
// "Bearer token not found or improperly formatted",
// );
// }
// const bearerToken = match[1];
// const result = await client.verify(subjects, bearerToken!);
// if (result.err)
// throw new VisibleError("input", "auth.invalid", "Invalid bearer token");
// if (result.subject.type === "user") {
// // return ActorContext.with(
// // {
// // type: "user",
// // properties: {
// // accessToken: result.subject.properties.accessToken,
// // userID: result.subject.properties.userID,
// // auth: {
// // type: "oauth",
// // clientID: result.aud,
// // },
// // },
// // },
// // next,
// // );
// }
// }
// }
app app
.use(logger(), async (c, next) => { .use(logger(), async (c, next) => {
c.header("Cache-Control", "no-store"); c.header("Cache-Control", "no-store");
return next(); try {
}) await next();
// .use(auth) } catch (e: any) {
app
.route("/auth", AuthApi.route)
// .get("/parties/main/:id", (c) => {
// const id = c.req.param();
// const env = c.env as any
// const party = env.room as Party.Room
// party.broadcast("hello from hono")
// return c.text(`Hello there, ${id.id} 👋🏾`)
// })
.onError((error, c) => {
console.error(error);
if (error instanceof VisibleError) {
return c.json( return c.json(
{ {
code: error.code, error: {
message: error.message, message: e.message || "Internal Server Error",
}, status: e.status || 500,
error.kind === "auth" ? 401 : 400,
);
}
if (error instanceof ZodError) {
const e = error.errors[0];
if (e) {
return c.json(
{
code: e?.code,
message: e?.message,
}, },
400,
);
}
}
if (error instanceof HTTPException) {
return c.json(
{
code: "request",
message: "Invalid request",
}, },
400, e.status || 500
); );
} }
return c.json( })
{
code: "internal", const routes = app
message: "Internal server error", .get("/health", (c) => {
return c.json({
status: "healthy",
timestamp: new Date().toISOString(),
});
})
.route("/session", ApiSession.route)
app.get(
"/doc",
openAPISpecs(routes, {
documentation: {
info: {
title: "Nestri Realtime API",
description:
"The Nestri realtime API gives you the power to connect to your remote machine and relays from a single station",
version: "0.3.0",
}, },
500, components: {
); securitySchemes: {
}); Bearer: {
type: "http",
scheme: "bearer",
bearerFormat: "JWT",
},
},
},
security: [{ Bearer: [] }],
servers: [
{ description: "Production", url: "https://api.nestri.io" },
],
},
}),
);
export type Routes = typeof routes;
export default app export default app

View File

@@ -1,37 +1,47 @@
import type * as Party from "partykit/server";
import app from "./hono" import app from "./hono"
import type * as Party from "partykit/server";
import { tryAuthentication } from "./utils";
export default class Server implements Party.Server { export default class Server implements Party.Server {
constructor(readonly room: Party.Room) { } constructor(readonly room: Party.Room) { }
onRequest(request: Party.Request): Response | Promise<Response> { static async onBeforeRequest(req: Party.Request, lobby: Party.Lobby) {
const docs = new URL(req.url).toString().endsWith("/doc")
if (docs) {
return req
}
return app.fetch(request as any, { room: this.room }) try {
return await tryAuthentication(req, lobby)
} catch (e: any) {
// authentication failed!
return new Response(e, { status: 401 });
}
} }
getConnectionTags( static async onBeforeConnect(request: Party.Request, lobby: Party.Lobby) {
conn: Party.Connection, try {
ctx: Party.ConnectionContext return await tryAuthentication(request, lobby)
) { } catch (e: any) {
console.log("Tagging", conn.id) // authentication failed!
// const country = (ctx.request.cf?.country as string) ?? "unknown"; return new Response(e, { status: 401 });
// return [country]; }
return [conn.id]
// return ["AF"]
} }
onConnect(conn: Party.Connection, ctx: Party.ConnectionContext) { onRequest(req: Party.Request): Response | Promise<Response> {
// A websocket just connected!
return app.fetch(req as any, { room: this.room })
}
getConnectionTags(conn: Party.Connection, ctx: Party.ConnectionContext) {
return [conn.id, ctx.request.cf?.country as any]
}
onConnect(conn: Party.Connection, ctx: Party.ConnectionContext): void | Promise<void> {
console.log(`Connected:, id:${conn.id}, room: ${this.room.id}, url: ${new URL(ctx.request.url).pathname}`);
this.getConnectionTags(conn, ctx) this.getConnectionTags(conn, ctx)
console.log(
`Connected:
id: ${conn.id}
room: ${this.room.id}
url: ${new URL(ctx.request.url).pathname}`
);
// let's send a message to the connection
// conn.send("hello from server");
} }
onMessage(message: string, sender: Party.Connection) { onMessage(message: string, sender: Party.Connection) {

View File

@@ -0,0 +1,217 @@
import { z } from "zod";
import { Hono } from "hono";
import { Result } from "../common"
import { describeRoute } from "hono-openapi";
import type { HonoBindings, WSMessage } from "./types";
import { validator, resolver } from "hono-openapi/zod";
export module ApiSession {
export const route = new Hono<{ Bindings: HonoBindings }>()
.post("/:sessionID/start",
describeRoute({
tags: ["Session"],
summary: "Start a session",
description: "Start a session on this machine",
responses: {
200: {
content: {
"application/json": {
schema: Result(z.object({
success: z.boolean(),
message: z.string(),
sessionID: z.string()
}))
},
},
description: "Session started successfully",
},
500: {
content: {
"application/json": {
schema: resolver(z.object({ error: z.string(), details: z.string() })),
},
},
description: "There was a problem trying to start your session",
},
},
}),
validator(
"param",
z.object({
sessionID: z.string().openapi({
description: "The session ID to start",
example: "18d8b4b5-29ba-4a62-8cf9-7059449907a7",
}),
}),
),
async (c) => {
const param = c.req.valid("param");
const room = c.env.room
const message: WSMessage = {
type: "START_GAME",
sessionID: param.sessionID,
};
try {
room.broadcast(JSON.stringify(message));
return c.json({
success: true,
message: "Game start signal sent",
"sessionID": param.sessionID,
});
} catch (error: any) {
return c.json(
{
error: {
message: "Failed to start game session",
details: error.message,
},
},
500
);
}
}
)
.post("/:sessionID/end",
describeRoute({
tags: ["Session"],
summary: "End a session",
description: "End a session on this machine",
responses: {
200: {
content: {
"application/json": {
schema: Result(z.object({
success: z.boolean(),
message: z.string(),
sessionID: z.string()
}))
},
},
description: "Session successfully ended",
},
500: {
content: {
"application/json": {
schema: resolver(z.object({ error: z.string(), details: z.string() })),
},
},
description: "There was a problem trying to end your session",
},
},
}),
validator(
"param",
z.object({
sessionID: z.string().openapi({
description: "The session ID to end",
example: "18d8b4b5-29ba-4a62-8cf9-7059449907a7",
}),
}),
),
async (c) => {
const param = c.req.valid("param");
const room = c.env.room
const message: WSMessage = {
type: "END_GAME",
sessionID: param.sessionID,
};
try {
room.broadcast(JSON.stringify(message));
return c.json({
success: true,
message: "Game end signal sent",
"sessionID": param.sessionID,
});
} catch (error: any) {
return c.json(
{
error: {
message: "Failed to end game session",
details: error.message,
},
},
500
);
}
}
)
.post("/:sessionID/status",
describeRoute({
tags: ["Session"],
summary: "Get the status of a session",
description: "Get the status of a session on this machine",
responses: {
200: {
content: {
"application/json": {
schema: Result(z.object({
success: z.boolean(),
message: z.string(),
sessionID: z.string()
}))
},
},
description: "Session status query was successful"
},
500: {
content: {
"application/json": {
schema: resolver(z.object({ error: z.string(), details: z.string() })),
},
},
description: "There was a problem trying to querying the status of your session",
},
},
}),
validator(
"param",
z.object({
sessionID: z.string().openapi({
description: "The session ID to query",
example: "18d8b4b5-29ba-4a62-8cf9-7059449907a7",
}),
}),
),
async (c) => {
const param = c.req.valid("param");
const room = c.env.room
const message: WSMessage = {
type: "END_GAME",
sessionID: param.sessionID,
};
try {
room.broadcast(JSON.stringify(message));
return c.json({
success: true,
message: "Game end signal sent",
"sessionID": param.sessionID,
});
} catch (error: any) {
return c.json(
{
error: {
message: "Failed to end game session",
details: error.message,
},
},
500
);
}
}
)
}

View File

@@ -0,0 +1,11 @@
import type * as Party from "partykit/server";
export interface HonoBindings {
room: Party.Room;
}
export type WSMessage = {
type: "START_GAME" | "END_GAME" | "GAME_STATUS";
sessionID: string;
payload?: any;
};

View File

@@ -0,0 +1,21 @@
import type * as Party from "partykit/server";
export async function tryAuthentication(req: Party.Request, lobby: Party.Lobby) {
const authHeader = req.headers.get("authorization") ?? new URL(req.url).searchParams.get("authorization")
if (authHeader) {
const match = authHeader.match(/^Bearer (.+)$/);
if (!match || !match[1]) {
throw new Error("Bearer token not found or improperly formatted");
}
const bearerToken = match[1];
if (bearerToken !== lobby.env.AUTH_FINGERPRINT) {
throw new Error("Invalid authorization token");
}
return req// app.fetch(req as any, { room: this.room })
}
throw new Error("You are not authorized to be here")
}