From 56b877fa27feb7e4c2f3d1968f0b22572dadebcd Mon Sep 17 00:00:00 2001 From: Wanjohi <71614375+wanjohiryan@users.noreply.github.com> Date: Sun, 5 Jan 2025 23:45:41 +0300 Subject: [PATCH] =?UTF-8?q?=E2=AD=90feat:=20Add=20a=20websocket=20party=20?= =?UTF-8?q?(#152)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds functionality to connect to remote server thru the party --- packages/cli/internal/party/client.go | 16 +- packages/cli/main.go | 78 ++++----- packages/functions/src/api/index.ts | 2 +- packages/functions/src/auth.ts | 2 +- packages/functions/src/party/auth.ts | 82 --------- packages/functions/src/party/hono.ts | 149 ++++++---------- packages/functions/src/party/index.ts | 58 ++++--- packages/functions/src/party/session.ts | 217 ++++++++++++++++++++++++ packages/functions/src/party/types.ts | 11 ++ packages/functions/src/party/utils.ts | 21 +++ 10 files changed, 384 insertions(+), 252 deletions(-) delete mode 100644 packages/functions/src/party/auth.ts create mode 100644 packages/functions/src/party/session.ts create mode 100644 packages/functions/src/party/types.ts create mode 100644 packages/functions/src/party/utils.ts diff --git a/packages/cli/internal/party/client.go b/packages/cli/internal/party/client.go index 24e6fefa..4abffde8 100644 --- a/packages/cli/internal/party/client.go +++ b/packages/cli/internal/party/client.go @@ -3,6 +3,8 @@ package party import ( "fmt" "nestrilabs/cli/internal/machine" + "nestrilabs/cli/internal/resource" + "net/http" "net/url" "time" @@ -48,6 +50,9 @@ func (p *Party) Connect() { wsURL := baseURL + "?" + params.Encode() retryDelay := initialRetryDelay + header := http.Header{} + bearer := fmt.Sprintf("Bearer %s", resource.Resource.AuthFingerprintKey.Value) + header.Add("Authorization", bearer) for { select { @@ -55,7 +60,7 @@ func (p *Party) Connect() { log.Info("Shutting down connection") return default: - conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + conn, _, err := websocket.DefaultDialer.Dial(wsURL, header) if err != nil { log.Error("Failed to connect to party server", "err", err) time.Sleep(retryDelay) @@ -66,6 +71,7 @@ func (p *Party) Connect() { } continue } + log.Info("Connection to server", "url", wsURL) // Reset retry delay on successful connection retryDelay = initialRetryDelay @@ -77,10 +83,10 @@ func (p *Party) Connect() { defer conn.Close() // Send initial message - if err := conn.WriteMessage(websocket.TextMessage, []byte("hello there")); err != nil { - log.Error("Failed to send initial message", "err", err) - return - } + // if err := conn.WriteMessage(websocket.TextMessage, []byte("hello there")); err != nil { + // log.Error("Failed to send initial message", "err", err) + // return + // } // Read messages loop for { diff --git a/packages/cli/main.go b/packages/cli/main.go index 98534501..957992ec 100644 --- a/packages/cli/main.go +++ b/packages/cli/main.go @@ -1,10 +1,7 @@ package main import ( - "context" - "nestrilabs/cli/internal/session" - - "github.com/charmbracelet/log" + "nestrilabs/cli/internal/party" ) func main() { @@ -13,46 +10,49 @@ func main() { // log.Error("Error running the cmd command", "err", err) // } - ctx := context.Background() + // ctx := context.Background() - config := &session.SessionConfig{ - Room: "victortest", - Resolution: "1920x1080", - Framerate: "60", - RelayURL: "https://relay.dathorse.com", - Params: "--verbose=true --video-codec=h264 --video-bitrate=4000 --video-bitrate-max=6000 --gpu-card-path=/dev/dri/card1", - GamePath: "/path/to/your/game", - } + // config := &session.SessionConfig{ + // Room: "victortest", + // Resolution: "1920x1080", + // Framerate: "60", + // RelayURL: "https://relay.dathorse.com", + // Params: "--verbose=true --video-codec=h264 --video-bitrate=4000 --video-bitrate-max=6000 --gpu-card-path=/dev/dri/card1", + // GamePath: "/path/to/your/game", + // } - sess, err := session.NewSession(config) - if err != nil { - log.Error("Failed to create session", "err", err) - } + // sess, err := session.NewSession(config) + // if err != nil { + // log.Error("Failed to create session", "err", err) + // } - // Start the session - if err := sess.Start(ctx); err != nil { - log.Error("Failed to start session", "err", err) - } + // // Start the session + // if err := sess.Start(ctx); err != nil { + // log.Error("Failed to start session", "err", err) + // } - // Check if it's running - if sess.IsRunning() { - log.Info("Session is running with container ID", "containerId", sess.GetContainerID()) - } + // // Check if it's running + // if sess.IsRunning() { + // log.Info("Session is running with container ID", "containerId", sess.GetContainerID()) + // } - env, err := sess.GetEnvironment(ctx) - if err != nil { - log.Printf("Failed to get environment: %v", err) - } else { - for key, value := range env { - log.Info("Found this environment variables", key, value) - } - } + // env, err := sess.GetEnvironment(ctx) + // if err != nil { + // log.Printf("Failed to get environment: %v", err) + // } else { + // for key, value := range env { + // log.Info("Found this environment variables", key, value) + // } + // } - // Let it run for a while - // time.Sleep(time.Second * 50) + // // Let it run for a while + // // time.Sleep(time.Second * 50) - // Stop the session - if err := sess.Stop(ctx); err != nil { - log.Error("Failed to stop session", "err", err) - } + // // Stop the session + // if err := sess.Stop(ctx); err != nil { + // log.Error("Failed to stop session", "err", err) + // } + + party := party.NewParty() + party.Connect() } diff --git a/packages/functions/src/api/index.ts b/packages/functions/src/api/index.ts index b1fd453d..f48e20ea 100644 --- a/packages/functions/src/api/index.ts +++ b/packages/functions/src/api/index.ts @@ -133,7 +133,7 @@ app.get( title: "Nestri API", description: "The Nestri API gives you the power to run your own customized cloud gaming platform.", - version: "0.0.3", + version: "0.3.0", }, components: { securitySchemes: { diff --git a/packages/functions/src/auth.ts b/packages/functions/src/auth.ts index 670ba0ce..c560607c 100644 --- a/packages/functions/src/auth.ts +++ b/packages/functions/src/auth.ts @@ -101,7 +101,7 @@ export default { const hostname = url.hostname; if (hostname.endsWith("nestri.io")) return true; if (hostname === "localhost") return true; - return true; + return false; }, success: async (ctx, value) => { if (value.provider === "device") { diff --git a/packages/functions/src/party/auth.ts b/packages/functions/src/party/auth.ts deleted file mode 100644 index 5ee0912c..00000000 --- a/packages/functions/src/party/auth.ts +++ /dev/null @@ -1,82 +0,0 @@ -import { z } from "zod"; -import { Hono } from "hono"; -import { Result } from "../common" -import { describeRoute } from "hono-openapi"; -import type * as Party from "partykit/server"; -import { validator, resolver } from "hono-openapi/zod"; - -const paramsObj = z.object({ - code: z.string(), - state: z.string() -}) - -export module AuthApi { - export const route = new Hono() - .get("/:connection", - describeRoute({ - tags: ["Auth"], - summary: "Authenticate the remote device", - description: "This is a callback function to authenticate the remote device.", - responses: { - 200: { - content: { - "application/json": { - schema: Result(z.literal("Device authenticated successfully")) - }, - }, - description: "Authentication successful.", - }, - 404: { - content: { - "application/json": { - schema: resolver(z.object({ error: z.string() })), - }, - }, - description: "This device does not exist.", - }, - }, - }), - validator( - "param", - z.object({ - connection: z.string().openapi({ - description: "The hostname of the device to login to.", - example: "desktopeuo8vsf", - }), - }), - ), - async (c) => { - const param = c.req.valid("param"); - const env = c.env as any - const room = env.room as Party.Room - - - // const connection = room.getConnection(param.connection) - // if (!connection) { - // return c.json({ error: "This device does not exist." }, 404); - // } - - // const authParams = getUrlParams(new URL(c.req.url)) - // const res = paramsObj.safeParse(authParams) - // if (res.error) { - // return c.json({ error: "Expected url params are missing" }) - // } - - // connection.send(JSON.stringify({ ...authParams, type: "auth" })) - - // FIXME:We just assume the authentication was successful, might wanna do some questioning in the future - return c.text("Device authenticated successfully") - } - ) -} - -function getUrlParams(url: URL) { - const urlString = url.toString() - const hash = urlString.substring(urlString.indexOf('?') + 1); // Extract the part after the # - const params = new URLSearchParams(hash); - const paramsObj = {} as any; - for (const [key, value] of params.entries()) { - paramsObj[key] = decodeURIComponent(value); - } - return paramsObj; -} \ No newline at end of file diff --git a/packages/functions/src/party/hono.ts b/packages/functions/src/party/hono.ts index da31f286..3f05f60b 100644 --- a/packages/functions/src/party/hono.ts +++ b/packages/functions/src/party/hono.ts @@ -1,116 +1,65 @@ import "zod-openapi/extend"; -import type * as Party from "partykit/server"; -// import { Resource } from "sst"; -import { ZodError } from "zod"; +import { Hono } from "hono"; import { logger } from "hono/logger"; -// import { subjects } from "../subjects"; -import { VisibleError } from "../error"; -// import { ActorContext } from '@nestri/core/actor'; -import { Hono, type MiddlewareHandler } from "hono"; -import { HTTPException } from "hono/http-exception"; -import { AuthApi } from "./auth"; +import type { HonoBindings } from "./types"; +import { ApiSession } from "./session"; +import { openAPISpecs } from "hono-openapi"; - -const app = new Hono().basePath('/parties/main/:id'); -// const auth: MiddlewareHandler = async (c, next) => { -// const client = createClient({ -// clientID: "api", -// issuer: "http://auth.nestri.io" //Resource.Urls.auth -// }); - -// const authHeader = -// c.req.query("authorization") ?? c.req.header("authorization"); -// if (authHeader) { -// const match = authHeader.match(/^Bearer (.+)$/); -// if (!match || !match[1]) { -// throw new VisibleError( -// "input", -// "auth.token", -// "Bearer token not found or improperly formatted", -// ); -// } -// const bearerToken = match[1]; - -// const result = await client.verify(subjects, bearerToken!); -// if (result.err) -// throw new VisibleError("input", "auth.invalid", "Invalid bearer token"); -// if (result.subject.type === "user") { -// // return ActorContext.with( -// // { -// // type: "user", -// // properties: { -// // accessToken: result.subject.properties.accessToken, -// // userID: result.subject.properties.userID, -// // auth: { -// // type: "oauth", -// // clientID: result.aud, -// // }, -// // }, -// // }, -// // next, -// // ); -// } -// } -// } +const app = new Hono<{ Bindings: HonoBindings }>().basePath('/parties/main/:room'); app .use(logger(), async (c, next) => { c.header("Cache-Control", "no-store"); - return next(); - }) -// .use(auth) - - -app - .route("/auth", AuthApi.route) - // .get("/parties/main/:id", (c) => { - // const id = c.req.param(); - // const env = c.env as any - // const party = env.room as Party.Room - // party.broadcast("hello from hono") - - // return c.text(`Hello there, ${id.id} 👋🏾`) - // }) - .onError((error, c) => { - console.error(error); - if (error instanceof VisibleError) { + try { + await next(); + } catch (e: any) { return c.json( { - code: error.code, - message: error.message, - }, - error.kind === "auth" ? 401 : 400, - ); - } - if (error instanceof ZodError) { - const e = error.errors[0]; - if (e) { - return c.json( - { - code: e?.code, - message: e?.message, + error: { + message: e.message || "Internal Server Error", + status: e.status || 500, }, - 400, - ); - } - } - if (error instanceof HTTPException) { - return c.json( - { - code: "request", - message: "Invalid request", }, - 400, + e.status || 500 ); } - return c.json( - { - code: "internal", - message: "Internal server error", + }) + +const routes = app + .get("/health", (c) => { + return c.json({ + status: "healthy", + timestamp: new Date().toISOString(), + }); + }) + .route("/session", ApiSession.route) + +app.get( + "/doc", + openAPISpecs(routes, { + documentation: { + info: { + title: "Nestri Realtime API", + description: + "The Nestri realtime API gives you the power to connect to your remote machine and relays from a single station", + version: "0.3.0", }, - 500, - ); - }); - + components: { + securitySchemes: { + Bearer: { + type: "http", + scheme: "bearer", + bearerFormat: "JWT", + }, + }, + }, + security: [{ Bearer: [] }], + servers: [ + { description: "Production", url: "https://api.nestri.io" }, + ], + }, + }), +); +export type Routes = typeof routes; export default app \ No newline at end of file diff --git a/packages/functions/src/party/index.ts b/packages/functions/src/party/index.ts index 16392c4c..f23057c1 100644 --- a/packages/functions/src/party/index.ts +++ b/packages/functions/src/party/index.ts @@ -1,37 +1,47 @@ -import type * as Party from "partykit/server"; import app from "./hono" +import type * as Party from "partykit/server"; +import { tryAuthentication } from "./utils"; + export default class Server implements Party.Server { constructor(readonly room: Party.Room) { } - onRequest(request: Party.Request): Response | Promise { + static async onBeforeRequest(req: Party.Request, lobby: Party.Lobby) { + const docs = new URL(req.url).toString().endsWith("/doc") + if (docs) { + return req + } - return app.fetch(request as any, { room: this.room }) + try { + return await tryAuthentication(req, lobby) + } catch (e: any) { + // authentication failed! + return new Response(e, { status: 401 }); + } } - getConnectionTags( - conn: Party.Connection, - ctx: Party.ConnectionContext - ) { - console.log("Tagging", conn.id) - // const country = (ctx.request.cf?.country as string) ?? "unknown"; - // return [country]; - return [conn.id] - // return ["AF"] + static async onBeforeConnect(request: Party.Request, lobby: Party.Lobby) { + try { + return await tryAuthentication(request, lobby) + } catch (e: any) { + // authentication failed! + return new Response(e, { status: 401 }); + } } - onConnect(conn: Party.Connection, ctx: Party.ConnectionContext) { - // A websocket just connected! + onRequest(req: Party.Request): Response | Promise { + + return app.fetch(req as any, { room: this.room }) + } + + getConnectionTags(conn: Party.Connection, ctx: Party.ConnectionContext) { + + return [conn.id, ctx.request.cf?.country as any] + } + + onConnect(conn: Party.Connection, ctx: Party.ConnectionContext): void | Promise { + console.log(`Connected:, id:${conn.id}, room: ${this.room.id}, url: ${new URL(ctx.request.url).pathname}`); + this.getConnectionTags(conn, ctx) - - console.log( - `Connected: - id: ${conn.id} - room: ${this.room.id} - url: ${new URL(ctx.request.url).pathname}` - ); - - // let's send a message to the connection - // conn.send("hello from server"); } onMessage(message: string, sender: Party.Connection) { diff --git a/packages/functions/src/party/session.ts b/packages/functions/src/party/session.ts new file mode 100644 index 00000000..49c6ce27 --- /dev/null +++ b/packages/functions/src/party/session.ts @@ -0,0 +1,217 @@ +import { z } from "zod"; +import { Hono } from "hono"; +import { Result } from "../common" +import { describeRoute } from "hono-openapi"; +import type { HonoBindings, WSMessage } from "./types"; +import { validator, resolver } from "hono-openapi/zod"; + +export module ApiSession { + export const route = new Hono<{ Bindings: HonoBindings }>() + .post("/:sessionID/start", + describeRoute({ + tags: ["Session"], + summary: "Start a session", + description: "Start a session on this machine", + responses: { + 200: { + content: { + "application/json": { + schema: Result(z.object({ + success: z.boolean(), + message: z.string(), + sessionID: z.string() + })) + }, + }, + description: "Session started successfully", + }, + 500: { + content: { + "application/json": { + schema: resolver(z.object({ error: z.string(), details: z.string() })), + }, + }, + description: "There was a problem trying to start your session", + }, + }, + }), + validator( + "param", + z.object({ + sessionID: z.string().openapi({ + description: "The session ID to start", + example: "18d8b4b5-29ba-4a62-8cf9-7059449907a7", + }), + }), + ), + async (c) => { + const param = c.req.valid("param"); + const room = c.env.room + + const message: WSMessage = { + type: "START_GAME", + sessionID: param.sessionID, + }; + + try { + + room.broadcast(JSON.stringify(message)); + + return c.json({ + success: true, + message: "Game start signal sent", + "sessionID": param.sessionID, + }); + + } catch (error: any) { + return c.json( + { + error: { + message: "Failed to start game session", + details: error.message, + }, + }, + 500 + ); + } + } + ) + .post("/:sessionID/end", + describeRoute({ + tags: ["Session"], + summary: "End a session", + description: "End a session on this machine", + responses: { + 200: { + content: { + "application/json": { + schema: Result(z.object({ + success: z.boolean(), + message: z.string(), + sessionID: z.string() + })) + }, + }, + description: "Session successfully ended", + }, + 500: { + content: { + "application/json": { + schema: resolver(z.object({ error: z.string(), details: z.string() })), + }, + }, + description: "There was a problem trying to end your session", + }, + }, + }), + validator( + "param", + z.object({ + sessionID: z.string().openapi({ + description: "The session ID to end", + example: "18d8b4b5-29ba-4a62-8cf9-7059449907a7", + }), + }), + ), + async (c) => { + const param = c.req.valid("param"); + const room = c.env.room + + const message: WSMessage = { + type: "END_GAME", + sessionID: param.sessionID, + }; + + try { + + room.broadcast(JSON.stringify(message)); + + return c.json({ + success: true, + message: "Game end signal sent", + "sessionID": param.sessionID, + }); + + } catch (error: any) { + return c.json( + { + error: { + message: "Failed to end game session", + details: error.message, + }, + }, + 500 + ); + } + } + ) + .post("/:sessionID/status", + describeRoute({ + tags: ["Session"], + summary: "Get the status of a session", + description: "Get the status of a session on this machine", + responses: { + 200: { + content: { + "application/json": { + schema: Result(z.object({ + success: z.boolean(), + message: z.string(), + sessionID: z.string() + })) + }, + }, + description: "Session status query was successful" + }, + 500: { + content: { + "application/json": { + schema: resolver(z.object({ error: z.string(), details: z.string() })), + }, + }, + description: "There was a problem trying to querying the status of your session", + }, + }, + }), + validator( + "param", + z.object({ + sessionID: z.string().openapi({ + description: "The session ID to query", + example: "18d8b4b5-29ba-4a62-8cf9-7059449907a7", + }), + }), + ), + async (c) => { + const param = c.req.valid("param"); + const room = c.env.room + + const message: WSMessage = { + type: "END_GAME", + sessionID: param.sessionID, + }; + + try { + + room.broadcast(JSON.stringify(message)); + + return c.json({ + success: true, + message: "Game end signal sent", + "sessionID": param.sessionID, + }); + + } catch (error: any) { + return c.json( + { + error: { + message: "Failed to end game session", + details: error.message, + }, + }, + 500 + ); + } + } + ) +} \ No newline at end of file diff --git a/packages/functions/src/party/types.ts b/packages/functions/src/party/types.ts new file mode 100644 index 00000000..75cef253 --- /dev/null +++ b/packages/functions/src/party/types.ts @@ -0,0 +1,11 @@ +import type * as Party from "partykit/server"; + +export interface HonoBindings { + room: Party.Room; +} + +export type WSMessage = { + type: "START_GAME" | "END_GAME" | "GAME_STATUS"; + sessionID: string; + payload?: any; + }; \ No newline at end of file diff --git a/packages/functions/src/party/utils.ts b/packages/functions/src/party/utils.ts new file mode 100644 index 00000000..87a5c62d --- /dev/null +++ b/packages/functions/src/party/utils.ts @@ -0,0 +1,21 @@ +import type * as Party from "partykit/server"; + +export async function tryAuthentication(req: Party.Request, lobby: Party.Lobby) { + const authHeader = req.headers.get("authorization") ?? new URL(req.url).searchParams.get("authorization") + if (authHeader) { + const match = authHeader.match(/^Bearer (.+)$/); + + if (!match || !match[1]) { + throw new Error("Bearer token not found or improperly formatted"); + } + + const bearerToken = match[1]; + + if (bearerToken !== lobby.env.AUTH_FINGERPRINT) { + throw new Error("Invalid authorization token"); + } + + return req// app.fetch(req as any, { room: this.room }) + } + throw new Error("You are not authorized to be here") +} \ No newline at end of file