feat: Add a websocket party (#152)

This adds functionality to connect to remote server thru the party
This commit is contained in:
Wanjohi
2025-01-05 23:45:41 +03:00
committed by GitHub
parent c15657a0d1
commit 56b877fa27
10 changed files with 384 additions and 252 deletions

View File

@@ -3,6 +3,8 @@ package party
import (
"fmt"
"nestrilabs/cli/internal/machine"
"nestrilabs/cli/internal/resource"
"net/http"
"net/url"
"time"
@@ -48,6 +50,9 @@ func (p *Party) Connect() {
wsURL := baseURL + "?" + params.Encode()
retryDelay := initialRetryDelay
header := http.Header{}
bearer := fmt.Sprintf("Bearer %s", resource.Resource.AuthFingerprintKey.Value)
header.Add("Authorization", bearer)
for {
select {
@@ -55,7 +60,7 @@ func (p *Party) Connect() {
log.Info("Shutting down connection")
return
default:
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
conn, _, err := websocket.DefaultDialer.Dial(wsURL, header)
if err != nil {
log.Error("Failed to connect to party server", "err", err)
time.Sleep(retryDelay)
@@ -66,6 +71,7 @@ func (p *Party) Connect() {
}
continue
}
log.Info("Connection to server", "url", wsURL)
// Reset retry delay on successful connection
retryDelay = initialRetryDelay
@@ -77,10 +83,10 @@ func (p *Party) Connect() {
defer conn.Close()
// Send initial message
if err := conn.WriteMessage(websocket.TextMessage, []byte("hello there")); err != nil {
log.Error("Failed to send initial message", "err", err)
return
}
// if err := conn.WriteMessage(websocket.TextMessage, []byte("hello there")); err != nil {
// log.Error("Failed to send initial message", "err", err)
// return
// }
// Read messages loop
for {

View File

@@ -1,10 +1,7 @@
package main
import (
"context"
"nestrilabs/cli/internal/session"
"github.com/charmbracelet/log"
"nestrilabs/cli/internal/party"
)
func main() {
@@ -13,46 +10,49 @@ func main() {
// log.Error("Error running the cmd command", "err", err)
// }
ctx := context.Background()
// ctx := context.Background()
config := &session.SessionConfig{
Room: "victortest",
Resolution: "1920x1080",
Framerate: "60",
RelayURL: "https://relay.dathorse.com",
Params: "--verbose=true --video-codec=h264 --video-bitrate=4000 --video-bitrate-max=6000 --gpu-card-path=/dev/dri/card1",
GamePath: "/path/to/your/game",
}
// config := &session.SessionConfig{
// Room: "victortest",
// Resolution: "1920x1080",
// Framerate: "60",
// RelayURL: "https://relay.dathorse.com",
// Params: "--verbose=true --video-codec=h264 --video-bitrate=4000 --video-bitrate-max=6000 --gpu-card-path=/dev/dri/card1",
// GamePath: "/path/to/your/game",
// }
sess, err := session.NewSession(config)
if err != nil {
log.Error("Failed to create session", "err", err)
}
// sess, err := session.NewSession(config)
// if err != nil {
// log.Error("Failed to create session", "err", err)
// }
// Start the session
if err := sess.Start(ctx); err != nil {
log.Error("Failed to start session", "err", err)
}
// // Start the session
// if err := sess.Start(ctx); err != nil {
// log.Error("Failed to start session", "err", err)
// }
// Check if it's running
if sess.IsRunning() {
log.Info("Session is running with container ID", "containerId", sess.GetContainerID())
}
// // Check if it's running
// if sess.IsRunning() {
// log.Info("Session is running with container ID", "containerId", sess.GetContainerID())
// }
env, err := sess.GetEnvironment(ctx)
if err != nil {
log.Printf("Failed to get environment: %v", err)
} else {
for key, value := range env {
log.Info("Found this environment variables", key, value)
}
}
// env, err := sess.GetEnvironment(ctx)
// if err != nil {
// log.Printf("Failed to get environment: %v", err)
// } else {
// for key, value := range env {
// log.Info("Found this environment variables", key, value)
// }
// }
// Let it run for a while
// time.Sleep(time.Second * 50)
// // Let it run for a while
// // time.Sleep(time.Second * 50)
// Stop the session
if err := sess.Stop(ctx); err != nil {
log.Error("Failed to stop session", "err", err)
}
// // Stop the session
// if err := sess.Stop(ctx); err != nil {
// log.Error("Failed to stop session", "err", err)
// }
party := party.NewParty()
party.Connect()
}

View File

@@ -133,7 +133,7 @@ app.get(
title: "Nestri API",
description:
"The Nestri API gives you the power to run your own customized cloud gaming platform.",
version: "0.0.3",
version: "0.3.0",
},
components: {
securitySchemes: {

View File

@@ -101,7 +101,7 @@ export default {
const hostname = url.hostname;
if (hostname.endsWith("nestri.io")) return true;
if (hostname === "localhost") return true;
return true;
return false;
},
success: async (ctx, value) => {
if (value.provider === "device") {

View File

@@ -1,82 +0,0 @@
import { z } from "zod";
import { Hono } from "hono";
import { Result } from "../common"
import { describeRoute } from "hono-openapi";
import type * as Party from "partykit/server";
import { validator, resolver } from "hono-openapi/zod";
const paramsObj = z.object({
code: z.string(),
state: z.string()
})
export module AuthApi {
export const route = new Hono()
.get("/:connection",
describeRoute({
tags: ["Auth"],
summary: "Authenticate the remote device",
description: "This is a callback function to authenticate the remote device.",
responses: {
200: {
content: {
"application/json": {
schema: Result(z.literal("Device authenticated successfully"))
},
},
description: "Authentication successful.",
},
404: {
content: {
"application/json": {
schema: resolver(z.object({ error: z.string() })),
},
},
description: "This device does not exist.",
},
},
}),
validator(
"param",
z.object({
connection: z.string().openapi({
description: "The hostname of the device to login to.",
example: "desktopeuo8vsf",
}),
}),
),
async (c) => {
const param = c.req.valid("param");
const env = c.env as any
const room = env.room as Party.Room
// const connection = room.getConnection(param.connection)
// if (!connection) {
// return c.json({ error: "This device does not exist." }, 404);
// }
// const authParams = getUrlParams(new URL(c.req.url))
// const res = paramsObj.safeParse(authParams)
// if (res.error) {
// return c.json({ error: "Expected url params are missing" })
// }
// connection.send(JSON.stringify({ ...authParams, type: "auth" }))
// FIXME:We just assume the authentication was successful, might wanna do some questioning in the future
return c.text("Device authenticated successfully")
}
)
}
function getUrlParams(url: URL) {
const urlString = url.toString()
const hash = urlString.substring(urlString.indexOf('?') + 1); // Extract the part after the #
const params = new URLSearchParams(hash);
const paramsObj = {} as any;
for (const [key, value] of params.entries()) {
paramsObj[key] = decodeURIComponent(value);
}
return paramsObj;
}

View File

@@ -1,116 +1,65 @@
import "zod-openapi/extend";
import type * as Party from "partykit/server";
// import { Resource } from "sst";
import { ZodError } from "zod";
import { Hono } from "hono";
import { logger } from "hono/logger";
// import { subjects } from "../subjects";
import { VisibleError } from "../error";
// import { ActorContext } from '@nestri/core/actor';
import { Hono, type MiddlewareHandler } from "hono";
import { HTTPException } from "hono/http-exception";
import { AuthApi } from "./auth";
import type { HonoBindings } from "./types";
import { ApiSession } from "./session";
import { openAPISpecs } from "hono-openapi";
const app = new Hono().basePath('/parties/main/:id');
// const auth: MiddlewareHandler = async (c, next) => {
// const client = createClient({
// clientID: "api",
// issuer: "http://auth.nestri.io" //Resource.Urls.auth
// });
// const authHeader =
// c.req.query("authorization") ?? c.req.header("authorization");
// if (authHeader) {
// const match = authHeader.match(/^Bearer (.+)$/);
// if (!match || !match[1]) {
// throw new VisibleError(
// "input",
// "auth.token",
// "Bearer token not found or improperly formatted",
// );
// }
// const bearerToken = match[1];
// const result = await client.verify(subjects, bearerToken!);
// if (result.err)
// throw new VisibleError("input", "auth.invalid", "Invalid bearer token");
// if (result.subject.type === "user") {
// // return ActorContext.with(
// // {
// // type: "user",
// // properties: {
// // accessToken: result.subject.properties.accessToken,
// // userID: result.subject.properties.userID,
// // auth: {
// // type: "oauth",
// // clientID: result.aud,
// // },
// // },
// // },
// // next,
// // );
// }
// }
// }
const app = new Hono<{ Bindings: HonoBindings }>().basePath('/parties/main/:room');
app
.use(logger(), async (c, next) => {
c.header("Cache-Control", "no-store");
return next();
})
// .use(auth)
app
.route("/auth", AuthApi.route)
// .get("/parties/main/:id", (c) => {
// const id = c.req.param();
// const env = c.env as any
// const party = env.room as Party.Room
// party.broadcast("hello from hono")
// return c.text(`Hello there, ${id.id} 👋🏾`)
// })
.onError((error, c) => {
console.error(error);
if (error instanceof VisibleError) {
try {
await next();
} catch (e: any) {
return c.json(
{
code: error.code,
message: error.message,
},
error.kind === "auth" ? 401 : 400,
);
}
if (error instanceof ZodError) {
const e = error.errors[0];
if (e) {
return c.json(
{
code: e?.code,
message: e?.message,
error: {
message: e.message || "Internal Server Error",
status: e.status || 500,
},
400,
);
}
}
if (error instanceof HTTPException) {
return c.json(
{
code: "request",
message: "Invalid request",
},
400,
e.status || 500
);
}
return c.json(
{
code: "internal",
message: "Internal server error",
})
const routes = app
.get("/health", (c) => {
return c.json({
status: "healthy",
timestamp: new Date().toISOString(),
});
})
.route("/session", ApiSession.route)
app.get(
"/doc",
openAPISpecs(routes, {
documentation: {
info: {
title: "Nestri Realtime API",
description:
"The Nestri realtime API gives you the power to connect to your remote machine and relays from a single station",
version: "0.3.0",
},
500,
);
});
components: {
securitySchemes: {
Bearer: {
type: "http",
scheme: "bearer",
bearerFormat: "JWT",
},
},
},
security: [{ Bearer: [] }],
servers: [
{ description: "Production", url: "https://api.nestri.io" },
],
},
}),
);
export type Routes = typeof routes;
export default app

View File

@@ -1,37 +1,47 @@
import type * as Party from "partykit/server";
import app from "./hono"
import type * as Party from "partykit/server";
import { tryAuthentication } from "./utils";
export default class Server implements Party.Server {
constructor(readonly room: Party.Room) { }
onRequest(request: Party.Request): Response | Promise<Response> {
static async onBeforeRequest(req: Party.Request, lobby: Party.Lobby) {
const docs = new URL(req.url).toString().endsWith("/doc")
if (docs) {
return req
}
return app.fetch(request as any, { room: this.room })
try {
return await tryAuthentication(req, lobby)
} catch (e: any) {
// authentication failed!
return new Response(e, { status: 401 });
}
}
getConnectionTags(
conn: Party.Connection,
ctx: Party.ConnectionContext
) {
console.log("Tagging", conn.id)
// const country = (ctx.request.cf?.country as string) ?? "unknown";
// return [country];
return [conn.id]
// return ["AF"]
static async onBeforeConnect(request: Party.Request, lobby: Party.Lobby) {
try {
return await tryAuthentication(request, lobby)
} catch (e: any) {
// authentication failed!
return new Response(e, { status: 401 });
}
}
onConnect(conn: Party.Connection, ctx: Party.ConnectionContext) {
// A websocket just connected!
onRequest(req: Party.Request): Response | Promise<Response> {
return app.fetch(req as any, { room: this.room })
}
getConnectionTags(conn: Party.Connection, ctx: Party.ConnectionContext) {
return [conn.id, ctx.request.cf?.country as any]
}
onConnect(conn: Party.Connection, ctx: Party.ConnectionContext): void | Promise<void> {
console.log(`Connected:, id:${conn.id}, room: ${this.room.id}, url: ${new URL(ctx.request.url).pathname}`);
this.getConnectionTags(conn, ctx)
console.log(
`Connected:
id: ${conn.id}
room: ${this.room.id}
url: ${new URL(ctx.request.url).pathname}`
);
// let's send a message to the connection
// conn.send("hello from server");
}
onMessage(message: string, sender: Party.Connection) {

View File

@@ -0,0 +1,217 @@
import { z } from "zod";
import { Hono } from "hono";
import { Result } from "../common"
import { describeRoute } from "hono-openapi";
import type { HonoBindings, WSMessage } from "./types";
import { validator, resolver } from "hono-openapi/zod";
export module ApiSession {
export const route = new Hono<{ Bindings: HonoBindings }>()
.post("/:sessionID/start",
describeRoute({
tags: ["Session"],
summary: "Start a session",
description: "Start a session on this machine",
responses: {
200: {
content: {
"application/json": {
schema: Result(z.object({
success: z.boolean(),
message: z.string(),
sessionID: z.string()
}))
},
},
description: "Session started successfully",
},
500: {
content: {
"application/json": {
schema: resolver(z.object({ error: z.string(), details: z.string() })),
},
},
description: "There was a problem trying to start your session",
},
},
}),
validator(
"param",
z.object({
sessionID: z.string().openapi({
description: "The session ID to start",
example: "18d8b4b5-29ba-4a62-8cf9-7059449907a7",
}),
}),
),
async (c) => {
const param = c.req.valid("param");
const room = c.env.room
const message: WSMessage = {
type: "START_GAME",
sessionID: param.sessionID,
};
try {
room.broadcast(JSON.stringify(message));
return c.json({
success: true,
message: "Game start signal sent",
"sessionID": param.sessionID,
});
} catch (error: any) {
return c.json(
{
error: {
message: "Failed to start game session",
details: error.message,
},
},
500
);
}
}
)
.post("/:sessionID/end",
describeRoute({
tags: ["Session"],
summary: "End a session",
description: "End a session on this machine",
responses: {
200: {
content: {
"application/json": {
schema: Result(z.object({
success: z.boolean(),
message: z.string(),
sessionID: z.string()
}))
},
},
description: "Session successfully ended",
},
500: {
content: {
"application/json": {
schema: resolver(z.object({ error: z.string(), details: z.string() })),
},
},
description: "There was a problem trying to end your session",
},
},
}),
validator(
"param",
z.object({
sessionID: z.string().openapi({
description: "The session ID to end",
example: "18d8b4b5-29ba-4a62-8cf9-7059449907a7",
}),
}),
),
async (c) => {
const param = c.req.valid("param");
const room = c.env.room
const message: WSMessage = {
type: "END_GAME",
sessionID: param.sessionID,
};
try {
room.broadcast(JSON.stringify(message));
return c.json({
success: true,
message: "Game end signal sent",
"sessionID": param.sessionID,
});
} catch (error: any) {
return c.json(
{
error: {
message: "Failed to end game session",
details: error.message,
},
},
500
);
}
}
)
.post("/:sessionID/status",
describeRoute({
tags: ["Session"],
summary: "Get the status of a session",
description: "Get the status of a session on this machine",
responses: {
200: {
content: {
"application/json": {
schema: Result(z.object({
success: z.boolean(),
message: z.string(),
sessionID: z.string()
}))
},
},
description: "Session status query was successful"
},
500: {
content: {
"application/json": {
schema: resolver(z.object({ error: z.string(), details: z.string() })),
},
},
description: "There was a problem trying to querying the status of your session",
},
},
}),
validator(
"param",
z.object({
sessionID: z.string().openapi({
description: "The session ID to query",
example: "18d8b4b5-29ba-4a62-8cf9-7059449907a7",
}),
}),
),
async (c) => {
const param = c.req.valid("param");
const room = c.env.room
const message: WSMessage = {
type: "END_GAME",
sessionID: param.sessionID,
};
try {
room.broadcast(JSON.stringify(message));
return c.json({
success: true,
message: "Game end signal sent",
"sessionID": param.sessionID,
});
} catch (error: any) {
return c.json(
{
error: {
message: "Failed to end game session",
details: error.message,
},
},
500
);
}
}
)
}

View File

@@ -0,0 +1,11 @@
import type * as Party from "partykit/server";
export interface HonoBindings {
room: Party.Room;
}
export type WSMessage = {
type: "START_GAME" | "END_GAME" | "GAME_STATUS";
sessionID: string;
payload?: any;
};

View File

@@ -0,0 +1,21 @@
import type * as Party from "partykit/server";
export async function tryAuthentication(req: Party.Request, lobby: Party.Lobby) {
const authHeader = req.headers.get("authorization") ?? new URL(req.url).searchParams.get("authorization")
if (authHeader) {
const match = authHeader.match(/^Bearer (.+)$/);
if (!match || !match[1]) {
throw new Error("Bearer token not found or improperly formatted");
}
const bearerToken = match[1];
if (bearerToken !== lobby.env.AUTH_FINGERPRINT) {
throw new Error("Invalid authorization token");
}
return req// app.fetch(req as any, { room: this.room })
}
throw new Error("You are not authorized to be here")
}