Files
netris-nestri/packages/relay/internal/common/safebufio.go
2025-10-21 18:41:45 +03:00

128 lines
2.9 KiB
Go

package common
import (
"bufio"
"encoding/binary"
"errors"
"io"
gen "relay/internal/proto"
"sync"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/timestamppb"
)
// readUvarint decodes a single unsigned varint (encoding/binary format) from r.
// The decoded value and any error from binary.ReadUvarint are returned as-is,
// e.g. io.EOF when r is empty and io.ErrUnexpectedEOF on a truncated varint.
func readUvarint(r io.ByteReader) (uint64, error) {
	value, err := binary.ReadUvarint(r)
	return value, err
}
// writeUvarint writes an unsigned varint to the writer
func writeUvarint(w io.Writer, x uint64) error {
buf := make([]byte, binary.MaxVarintLen64)
n := binary.PutUvarint(buf, x)
_, err := w.Write(buf[:n])
return err
}
// maxProtoMessageSize bounds the byte length of a single framed protobuf
// message. ReceiveProto reads the length prefix from the remote peer, so
// without a bound a corrupt or hostile frame could make it attempt an
// arbitrarily large allocation (or panic in make on an out-of-range length).
// SendProto enforces the same bound so that a local caller fails fast instead
// of producing a frame the other side would reject.
// NOTE(review): 32 MiB is a generous guess; tune it to the largest message
// the relay is expected to carry.
const maxProtoMessageSize = 32 << 20

// ErrProtoMessageTooLarge is returned by SendProto and ReceiveProto when a
// frame's length exceeds maxProtoMessageSize.
var ErrProtoMessageTooLarge = errors.New("proto message exceeds maximum size")

// SafeBufioRW wraps a bufio.ReadWriter for sending and receiving
// varint-length-prefixed protobuf messages from multiple goroutines.
//
// Reads and writes are serialized independently. Concurrent SendProto calls
// never interleave their frames, and concurrent ReceiveProto calls never split
// one frame between them. A ReceiveProto blocked waiting for data does not
// stop SendProto from making progress.
//
// NOTE(review): because a read and a write may now run at the same time, the
// underlying reader and writer must tolerate concurrent Read and Write. This
// holds for net.Conn. Confirm it for any other stream types passed in.
type SafeBufioRW struct {
	brw     *bufio.ReadWriter
	readMu  sync.Mutex // serializes all use of brw's Reader
	writeMu sync.Mutex // serializes all use of brw's Writer
}

// NewSafeBufioRW returns a SafeBufioRW that frames messages over brw.
func NewSafeBufioRW(brw *bufio.ReadWriter) *SafeBufioRW {
	return &SafeBufioRW{brw: brw}
}

// SendProto marshals msg and writes it as one frame: an unsigned varint byte
// length followed by the protobuf encoding. The buffered writer is flushed
// before returning. It returns ErrProtoMessageTooLarge if the encoding
// exceeds maxProtoMessageSize. Otherwise it returns any marshal, write or
// flush error.
func (bu *SafeBufioRW) SendProto(msg proto.Message) error {
	// Marshal before taking the lock: it only reads msg, not the stream.
	protoData, err := proto.Marshal(msg)
	if err != nil {
		return err
	}
	if len(protoData) > maxProtoMessageSize {
		return ErrProtoMessageTooLarge
	}

	bu.writeMu.Lock()
	defer bu.writeMu.Unlock()

	// Write varint length prefix
	if err := writeUvarint(bu.brw, uint64(len(protoData))); err != nil {
		return err
	}
	// Write the Protobuf data
	if _, err := bu.brw.Write(protoData); err != nil {
		return err
	}
	return bu.brw.Flush()
}

// ReceiveProto reads one frame (varint length prefix plus protobuf bytes) and
// unmarshals it into msg.
//
// It returns ErrProtoMessageTooLarge if the announced length exceeds
// maxProtoMessageSize. The frame body is not consumed in that case, so the
// stream is no longer aligned to frame boundaries and should be closed.
func (bu *SafeBufioRW) ReceiveProto(msg proto.Message) error {
	// Exclusive lock, not a read lock: reading advances the shared
	// bufio.Reader, so two concurrent receivers would race on its state.
	bu.readMu.Lock()
	defer bu.readMu.Unlock()

	// Read varint length prefix
	length, err := readUvarint(bu.brw)
	if err != nil {
		return err
	}
	if length > maxProtoMessageSize {
		return ErrProtoMessageTooLarge
	}
	// Read the Protobuf data
	data := make([]byte, length)
	if _, err := io.ReadFull(bu.brw, data); err != nil {
		return err
	}
	return proto.Unmarshal(data, msg)
}
// CreateMessageOptions carries optional latency-tracking settings for
// CreateMessage. A nil *CreateMessageOptions is accepted by CreateMessage.
type CreateMessageOptions struct {
// SequenceID, when non-empty and Latency is nil, makes CreateMessage start a
// new latency tracker with this sequence ID and a single "created" timestamp.
SequenceID string
// Latency, when non-nil, is attached to the message unchanged and takes
// precedence over SequenceID.
Latency *gen.ProtoLatencyTracker
}
// CreateMessage wraps payload in a gen.ProtoMessage envelope.
//
// payloadType is stored verbatim in MessageBase.PayloadType. opts may be nil.
// When opts.Latency is set, it is attached as the envelope's latency tracker.
// Otherwise a non-empty opts.SequenceID starts a new tracker holding a single
// "created" timestamp taken now.
//
// The payload is stored in the envelope's "payload" oneof. The target is the
// first oneof field whose message type has the same full name as payload's
// type, so adding new payload types to the schema needs no code change here.
//
// An error is returned in three cases:
//   - payload is nil, either as a nil interface or as a typed nil pointer;
//   - the envelope has no "payload" oneof;
//   - no field in that oneof accepts payload's type.
func CreateMessage(payload proto.Message, payloadType string, opts *CreateMessageOptions) (*gen.ProtoMessage, error) {
	// Reject nil up front. A nil interface would panic on ProtoReflect. A
	// typed nil pointer yields an empty read-only message, and
	// protoreflect.Message.Set panics when given one.
	if payload == nil {
		return nil, errors.New("payload is nil")
	}
	payloadReflect := payload.ProtoReflect()
	if !payloadReflect.IsValid() {
		return nil, errors.New("payload is nil")
	}

	msg := &gen.ProtoMessage{
		MessageBase: &gen.ProtoMessageBase{
			PayloadType: payloadType,
		},
	}

	if opts != nil {
		switch {
		case opts.Latency != nil:
			// A caller-supplied tracker wins over SequenceID.
			msg.MessageBase.Latency = opts.Latency
		case opts.SequenceID != "":
			msg.MessageBase.Latency = &gen.ProtoLatencyTracker{
				SequenceId: opts.SequenceID,
				Timestamps: []*gen.ProtoTimestampEntry{
					{
						Stage: "created",
						Time:  timestamppb.Now(),
					},
				},
			}
		}
	}

	// Use reflection to find and set the matching oneof field automatically.
	msgReflect := msg.ProtoReflect()
	oneofDesc := msgReflect.Descriptor().Oneofs().ByName("payload")
	if oneofDesc == nil {
		return nil, errors.New("payload oneof not found")
	}

	wantName := payloadReflect.Descriptor().FullName()
	fields := oneofDesc.Fields()
	for i := 0; i < fields.Len(); i++ {
		field := fields.Get(i)
		// Scalar oneof members have no message descriptor; skip them.
		if field.Message() != nil && field.Message().FullName() == wantName {
			msgReflect.Set(field, protoreflect.ValueOfMessage(payloadReflect))
			return msg, nil
		}
	}
	return nil, errors.New("payload type not found in oneof")
}