// Mirror of https://github.com/nestriness/nestri.git
// (synced 2025-12-11 00:05:36 +02:00; 128 lines, 2.9 KiB, Go).

package common
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
gen "relay/internal/proto"
|
|
"sync"
|
|
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
)
|
|
|
|
// readUvarint decodes one unsigned varint from r, consuming it byte by byte.
// It returns the decoded value, or whatever error binary.ReadUvarint reports.
func readUvarint(r io.ByteReader) (uint64, error) {
	value, err := binary.ReadUvarint(r)
	return value, err
}
|
|
|
|
// writeUvarint encodes x as an unsigned varint and hands the encoded bytes to
// w in a single Write call. It returns the error from that Write, if any.
func writeUvarint(w io.Writer, x uint64) error {
	// A varint never needs more than MaxVarintLen64 bytes.
	var scratch [binary.MaxVarintLen64]byte
	used := binary.PutUvarint(scratch[:], x)
	if _, err := w.Write(scratch[:used]); err != nil {
		return err
	}
	return nil
}
|
|
|
|
// SafeBufioRW pairs a bufio.ReadWriter with a mutex so that length-prefixed
// protobuf messages can be sent and received through it from more than one
// goroutine. Use NewSafeBufioRW to construct one.
type SafeBufioRW struct {
	// brw is the wrapped buffered stream.
	brw *bufio.ReadWriter
	// mutex is held by SendProto and ReceiveProto while they use brw.
	mutex sync.RWMutex
}
|
|
|
|
func NewSafeBufioRW(brw *bufio.ReadWriter) *SafeBufioRW {
|
|
return &SafeBufioRW{brw: brw}
|
|
}
|
|
|
|
func (bu *SafeBufioRW) SendProto(msg proto.Message) error {
|
|
bu.mutex.Lock()
|
|
defer bu.mutex.Unlock()
|
|
|
|
protoData, err := proto.Marshal(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Write varint length prefix
|
|
if err := writeUvarint(bu.brw, uint64(len(protoData))); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Write the Protobuf data
|
|
if _, err := bu.brw.Write(protoData); err != nil {
|
|
return err
|
|
}
|
|
|
|
return bu.brw.Flush()
|
|
}
|
|
|
|
func (bu *SafeBufioRW) ReceiveProto(msg proto.Message) error {
|
|
bu.mutex.RLock()
|
|
defer bu.mutex.RUnlock()
|
|
|
|
// Read varint length prefix
|
|
length, err := readUvarint(bu.brw)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Read the Protobuf data
|
|
data := make([]byte, length)
|
|
if _, err := io.ReadFull(bu.brw, data); err != nil {
|
|
return err
|
|
}
|
|
|
|
return proto.Unmarshal(data, msg)
|
|
}
|
|
|
|
type CreateMessageOptions struct {
|
|
SequenceID string
|
|
Latency *gen.ProtoLatencyTracker
|
|
}
|
|
|
|
func CreateMessage(payload proto.Message, payloadType string, opts *CreateMessageOptions) (*gen.ProtoMessage, error) {
|
|
msg := &gen.ProtoMessage{
|
|
MessageBase: &gen.ProtoMessageBase{
|
|
PayloadType: payloadType,
|
|
},
|
|
}
|
|
|
|
if opts != nil {
|
|
if opts.Latency != nil {
|
|
msg.MessageBase.Latency = opts.Latency
|
|
} else if opts.SequenceID != "" {
|
|
msg.MessageBase.Latency = &gen.ProtoLatencyTracker{
|
|
SequenceId: opts.SequenceID,
|
|
Timestamps: []*gen.ProtoTimestampEntry{
|
|
{
|
|
Stage: "created",
|
|
Time: timestamppb.Now(),
|
|
},
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
// Use reflection to set the oneof field automatically
|
|
msgReflect := msg.ProtoReflect()
|
|
payloadReflect := payload.ProtoReflect()
|
|
|
|
oneofDesc := msgReflect.Descriptor().Oneofs().ByName("payload")
|
|
if oneofDesc == nil {
|
|
return nil, errors.New("payload oneof not found")
|
|
}
|
|
|
|
fields := oneofDesc.Fields()
|
|
for i := 0; i < fields.Len(); i++ {
|
|
field := fields.Get(i)
|
|
if field.Message() != nil && field.Message().FullName() == payloadReflect.Descriptor().FullName() {
|
|
msgReflect.Set(field, protoreflect.ValueOfMessage(payloadReflect))
|
|
return msg, nil
|
|
}
|
|
}
|
|
|
|
return nil, errors.New("payload type not found in oneof")
|
|
}
|