From b3b8614d38c433ff9824a7ce2860f406c8e379a5 Mon Sep 17 00:00:00 2001 From: Marc Brooks Date: Sat, 27 Sep 2025 01:28:28 -0500 Subject: [PATCH] Add a time limit for each message type/queue. --- hidrpc.go | 11 +++++--- internal/hidrpc/hidrpc.go | 45 +++++++++++++++++---------------- internal/hidrpc/message.go | 29 +++++++++++++-------- jsonrpc.go | 33 ++++++++++++------------ webrtc.go | 52 ++++++++++++++++++++++---------------- 5 files changed, 96 insertions(+), 74 deletions(-) diff --git a/hidrpc.go b/hidrpc.go index b2233673..9e993016 100644 --- a/hidrpc.go +++ b/hidrpc.go @@ -53,13 +53,13 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { rpcCancelKeyboardMacro() return - case hidrpc.TypeCancelKeyboardMacroByTokenReport: - token, err := message.KeyboardMacroToken() + case hidrpc.TypeKeyboardMacroTokenState: + tokenState, err := message.KeyboardMacroTokenState() if err != nil { logger.Warn().Err(err).Msg("failed to get keyboard macro token") return } - rpcCancelKeyboardMacroByToken(token) + rpcCancelKeyboardMacroByToken(tokenState.Token) return case hidrpc.TypeKeypressKeepAliveReport: @@ -96,6 +96,7 @@ func onHidMessage(msg hidQueueMessage, session *Session) { scopedLogger := hidRPCLogger.With(). Str("channel", msg.channel). + Dur("timelimit", msg.timelimit). Int("data_len", dataLen). Bytes("data", data[:min(dataLen, 32)]). Logger() @@ -125,7 +126,7 @@ func onHidMessage(msg hidQueueMessage, session *Session) { r <- nil }() select { - case <-time.After(1 * time.Second): + case <-time.After(msg.timelimit * time.Second): scopedLogger.Warn().Msg("HID RPC message timed out") case <-r: scopedLogger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC message handled") @@ -241,6 +242,8 @@ func reportHidRPC(params any, session *Session) { message, err = hidrpc.NewKeydownStateMessage(params).Marshal() case hidrpc.KeyboardMacroState: message, err = hidrpc.NewKeyboardMacroStateMessage(params.State, params.IsPaste).Marshal() + case hidrpc.KeyboardMacroTokenState: + message, err = hidrpc.NewKeyboardMacroTokenMessage(params.Token).Marshal() default: err = fmt.Errorf("unknown HID RPC message type: %T", params) } diff --git a/internal/hidrpc/hidrpc.go b/internal/hidrpc/hidrpc.go index f8844478..6851d078 100644 --- a/internal/hidrpc/hidrpc.go +++ b/internal/hidrpc/hidrpc.go @@ -2,6 +2,7 @@ package hidrpc import ( "fmt" + "time" "github.com/google/uuid" "github.com/jetkvm/kvm/internal/usbgadget" @@ -11,19 +12,19 @@ import ( type MessageType byte const ( - TypeHandshake MessageType = 0x01 - TypeKeyboardReport MessageType = 0x02 - TypePointerReport MessageType = 0x03 - TypeWheelReport MessageType = 0x04 - TypeKeypressReport MessageType = 0x05 - TypeKeypressKeepAliveReport MessageType = 0x09 - TypeMouseReport MessageType = 0x06 - TypeKeyboardMacroReport MessageType = 0x07 - TypeCancelKeyboardMacroReport MessageType = 0x08 - TypeKeyboardLedState MessageType = 0x32 - TypeKeydownState MessageType = 0x33 - TypeKeyboardMacroState MessageType = 0x34 - TypeCancelKeyboardMacroByTokenReport MessageType = 0x35 + TypeHandshake MessageType = 0x01 + TypeKeyboardReport MessageType = 0x02 + TypePointerReport MessageType = 0x03 + TypeWheelReport MessageType = 0x04 + TypeKeypressReport MessageType = 0x05 + TypeKeypressKeepAliveReport MessageType = 0x09 + TypeMouseReport MessageType = 0x06 + TypeKeyboardMacroReport MessageType = 0x07 + TypeCancelKeyboardMacroReport MessageType = 0x08 + TypeKeyboardLedState MessageType = 0x32 + TypeKeydownState MessageType = 0x33 + TypeKeyboardMacroState MessageType = 0x34 + TypeKeyboardMacroTokenState MessageType = 0x35 ) type QueueIndex int @@ -31,26 +32,26 @@ type QueueIndex int const ( Version byte = 0x01 // Version of the HID RPC protocol HandshakeQueue int = 0 // Queue index for handshake messages - KeyboardQueue int = 1 // Queue index for keyboard and macro messages + KeyboardQueue int = 1 // Queue index for keyboard messages MouseQueue int = 2 // Queue index for mouse messages - MacroQueue int = 3 // Queue index for macro cancel messages + MacroQueue int = 3 // Queue index for macro messages OtherQueue int = 4 // Queue index for other messages ) // GetQueueIndex returns the index of the queue to which the message should be enqueued. -func GetQueueIndex(messageType MessageType) int { +func GetQueueIndex(messageType MessageType) (int, time.Duration) { switch messageType { case TypeHandshake: - return HandshakeQueue + return HandshakeQueue, 1 case TypeKeyboardReport, TypeKeypressReport, TypeKeyboardLedState, TypeKeydownState, TypeKeyboardMacroState: - return KeyboardQueue + return KeyboardQueue, 1 case TypePointerReport, TypeMouseReport, TypeWheelReport: - return MouseQueue + return MouseQueue, 1 // we don't want to block the queue for these messages - case TypeKeyboardMacroReport, TypeCancelKeyboardMacroReport, TypeCancelKeyboardMacroByTokenReport: - return MacroQueue + case TypeKeyboardMacroReport, TypeCancelKeyboardMacroReport, TypeKeyboardMacroTokenState: + return MacroQueue, 60 // 1 minute timeout default: - return OtherQueue + return OtherQueue, 5 } } diff --git a/internal/hidrpc/message.go b/internal/hidrpc/message.go index 88ff6602..381801f4 100644 --- a/internal/hidrpc/message.go +++ b/internal/hidrpc/message.go @@ -69,11 +69,11 @@ func (m *Message) String() string { return fmt.Sprintf("CancelKeyboardMacroReport{Malformed: %v}", m.d) } return "CancelKeyboardMacroReport" - case TypeCancelKeyboardMacroByTokenReport: + case TypeKeyboardMacroTokenState: if len(m.d) != 16 { - return fmt.Sprintf("CancelKeyboardMacroByTokenReport{Malformed: %v}", m.d) + return fmt.Sprintf("KeyboardMacroTokenState{Malformed: %v}", m.d) } - return fmt.Sprintf("CancelKeyboardMacroByTokenReport{Token: %s}", uuid.Must(uuid.FromBytes(m.d)).String()) + return fmt.Sprintf("KeyboardMacroTokenState{Token: %s}", uuid.Must(uuid.FromBytes(m.d)).String()) case TypeKeyboardLedState: if len(m.d) < 1 { return fmt.Sprintf("KeyboardLedState{Malformed: %v}", m.d) @@ -246,19 +246,28 @@ func (m *Message) KeyboardMacroState() (KeyboardMacroState, error) { }, nil } -// KeyboardMacroToken returns the keyboard macro token UUID from the message. -func (m *Message) KeyboardMacroToken() (uuid.UUID, error) { - if m.t != TypeCancelKeyboardMacroByTokenReport { - return uuid.Nil, fmt.Errorf("invalid message type: %d", m.t) +type KeyboardMacroTokenState struct { + Token uuid.UUID +} + +// KeyboardMacroTokenState returns the keyboard macro token UUID from the message. +func (m *Message) KeyboardMacroTokenState() (KeyboardMacroTokenState, error) { + if m.t != TypeKeyboardMacroTokenState { + return KeyboardMacroTokenState{}, fmt.Errorf("invalid message type: %d", m.t) } if len(m.d) == 0 { - return uuid.Nil, nil + return KeyboardMacroTokenState{Token: uuid.Nil}, nil } if len(m.d) != 16 { - return uuid.Nil, fmt.Errorf("invalid UUID length: %d", len(m.d)) + return KeyboardMacroTokenState{}, fmt.Errorf("invalid UUID length: %d", len(m.d)) } - return uuid.FromBytes(m.d) + token, err := uuid.FromBytes(m.d) + if err != nil { + return KeyboardMacroTokenState{}, fmt.Errorf("invalid UUID: %v", err) + } + + return KeyboardMacroTokenState{Token: token}, nil } diff --git a/jsonrpc.go b/jsonrpc.go index 4a343951..7e0b9d24 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -1091,10 +1091,10 @@ func getKeyboardMacroCancelMap() map[uuid.UUID]RunningMacro { func addKeyboardMacro(isPaste bool, cancel context.CancelFunc) uuid.UUID { keyboardMacroLock.Lock() defer keyboardMacroLock.Unlock() - keyboardMacroCancelMap := getKeyboardMacroCancelMap() + cancelMap := getKeyboardMacroCancelMap() token := uuid.New() // Generate a unique token - keyboardMacroCancelMap[token] = RunningMacro{ + cancelMap[token] = RunningMacro{ isPaste: isPaste, cancel: cancel, } @@ -1104,19 +1104,19 @@ func addKeyboardMacro(isPaste bool, cancel context.CancelFunc) uuid.UUID { func removeRunningKeyboardMacro(token uuid.UUID) { keyboardMacroLock.Lock() defer keyboardMacroLock.Unlock() - keyboardMacroCancelMap := getKeyboardMacroCancelMap() + cancelMap := getKeyboardMacroCancelMap() - delete(keyboardMacroCancelMap, token) + delete(cancelMap, token) } func cancelRunningKeyboardMacro(token uuid.UUID) { keyboardMacroLock.Lock() defer keyboardMacroLock.Unlock() - keyboardMacroCancelMap := getKeyboardMacroCancelMap() + cancelMap := getKeyboardMacroCancelMap() - if runningMacro, exists := keyboardMacroCancelMap[token]; exists { + if runningMacro, exists := cancelMap[token]; exists { runningMacro.cancel() - delete(keyboardMacroCancelMap, token) + delete(cancelMap, token) logger.Info().Interface("token", token).Msg("canceled keyboard macro by token") } else { logger.Debug().Interface("token", token).Msg("no running keyboard macro found for token") @@ -1126,11 +1126,11 @@ func cancelRunningKeyboardMacro(token uuid.UUID) { func cancelAllRunningKeyboardMacros() { keyboardMacroLock.Lock() defer keyboardMacroLock.Unlock() - keyboardMacroCancelMap := getKeyboardMacroCancelMap() + cancelMap := getKeyboardMacroCancelMap() - for token, runningMacro := range keyboardMacroCancelMap { + for token, runningMacro := range cancelMap { runningMacro.cancel() - delete(keyboardMacroCancelMap, token) + delete(cancelMap, token) logger.Info().Interface("token", token).Msg("cancelled keyboard macro") } } @@ -1139,12 +1139,10 @@ func reportRunningMacrosState() { if currentSession != nil { keyboardMacroLock.Lock() defer keyboardMacroLock.Unlock() - keyboardMacroCancelMap := getKeyboardMacroCancelMap() + cancelMap := getKeyboardMacroCancelMap() isPaste := false - anyRunning := false - for _, runningMacro := range keyboardMacroCancelMap { - anyRunning = true + for _, runningMacro := range cancelMap { if runningMacro.isPaste { isPaste = true break @@ -1152,7 +1150,7 @@ func reportRunningMacrosState() { } state := hidrpc.KeyboardMacroState{ - State: anyRunning, + State: len(cancelMap) > 0, IsPaste: isPaste, } @@ -1194,7 +1192,10 @@ func rpcCancelKeyboardMacroByToken(token uuid.UUID) { } func executeKeyboardMacro(ctx context.Context, isPaste bool, macro []hidrpc.KeyboardMacroStep) error { - logger.Debug().Int("macro_steps", len(macro)).Msg("Executing keyboard macro") + logger.Debug(). + Int("macro_steps", len(macro)). + Bool("isPaste", isPaste). + Msg("Executing keyboard macro") // don't report keyboard state changes while executing the macro gadget.SuspendKeyDownMessages() diff --git a/webrtc.go b/webrtc.go index 73879bdf..c911e258 100644 --- a/webrtc.go +++ b/webrtc.go @@ -34,7 +34,7 @@ type Session struct { lastTimerResetTime time.Time // Track when auto-release timer was last reset keepAliveJitterLock sync.Mutex // Protect jitter compensation timing state hidQueueLock sync.Mutex - hidQueue []chan hidQueueMessage + hidQueues []chan hidQueueMessage keysDownStateQueue chan usbgadget.KeysDownState } @@ -48,7 +48,8 @@ func (s *Session) resetKeepAliveTime() { type hidQueueMessage struct { webrtc.DataChannelMessage - channel string + channel string + timelimit time.Duration } type SessionConfig struct { @@ -93,19 +94,20 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) { return base64.StdEncoding.EncodeToString(localDescription), nil } -func (s *Session) initQueues() { +func (s *Session) initHidQueues() { s.hidQueueLock.Lock() defer s.hidQueueLock.Unlock() - s.hidQueue = make([]chan hidQueueMessage, 0) - for i := 0; i <= hidrpc.OtherQueue; i++ { - q := make(chan hidQueueMessage, 256) - s.hidQueue = append(s.hidQueue, q) - } + s.hidQueues = make([]chan hidQueueMessage, hidrpc.OtherQueue+1) + s.hidQueues[hidrpc.HandshakeQueue] = make(chan hidQueueMessage, 2) // we don't really want to queue many handshake messages + s.hidQueues[hidrpc.KeyboardQueue] = make(chan hidQueueMessage, 256) + s.hidQueues[hidrpc.MouseQueue] = make(chan hidQueueMessage, 256) + s.hidQueues[hidrpc.MacroQueue] = make(chan hidQueueMessage, 10) // macros can be long, but we don't want to queue too many + s.hidQueues[hidrpc.OtherQueue] = make(chan hidQueueMessage, 256) } -func (s *Session) handleQueues(index int) { - for msg := range s.hidQueue[index] { +func (s *Session) handleQueue(queue chan hidQueueMessage) { + for msg := range queue { onHidMessage(msg, s) } } @@ -160,17 +162,18 @@ func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, chan l.Trace().Msg("received data in HID RPC message handler") // Enqueue to ensure ordered processing - queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) - if queueIndex >= len(session.hidQueue) || queueIndex < 0 { + queueIndex, timelimit := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) + if queueIndex >= len(session.hidQueues) || queueIndex < 0 { l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") queueIndex = hidrpc.OtherQueue } - queue := session.hidQueue[queueIndex] + queue := session.hidQueues[queueIndex] if queue != nil { queue <- hidQueueMessage{ DataChannelMessage: msg, channel: channel, + timelimit: timelimit, } } else { l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") @@ -220,7 +223,7 @@ func newSession(config SessionConfig) (*Session, error) { session := &Session{peerConnection: peerConnection} session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) - session.initQueues() + session.initHidQueues() session.initKeysDownStateQueue() go func() { @@ -230,8 +233,8 @@ func newSession(config SessionConfig) (*Session, error) { } }() - for i := 0; i < len(session.hidQueue); i++ { - go session.handleQueues(i) + for queue := range session.hidQueues { + go session.handleQueue(session.hidQueues[queue]) } peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { @@ -256,7 +259,11 @@ func newSession(config SessionConfig) (*Session, error) { session.RPCChannel = d d.OnMessage(func(msg webrtc.DataChannelMessage) { // Enqueue to ensure ordered processing - session.rpcQueue <- msg + if session.rpcQueue != nil { + session.rpcQueue <- msg + } else { + scopedLogger.Warn().Msg("RPC message received but rpcQueue is nil") + } }) triggerOTAStateUpdate() triggerVideoStateUpdate() @@ -325,22 +332,23 @@ func newSession(config SessionConfig) (*Session, error) { _ = peerConnection.Close() } if connectionState == webrtc.ICEConnectionStateClosed { - scopedLogger.Debug().Msg("ICE Connection State is closed, unmounting virtual media") + scopedLogger.Debug().Msg("ICE Connection State is closed, tearing down session") if session == currentSession { // Cancel any ongoing keyboard report multi when session closes cancelAllRunningKeyboardMacros() currentSession = nil } + // Stop RPC processor if session.rpcQueue != nil { close(session.rpcQueue) session.rpcQueue = nil } - // Stop HID RPC processor - for i := 0; i < len(session.hidQueue); i++ { - close(session.hidQueue[i]) - session.hidQueue[i] = nil + // Stop HID RPC processors + for i := 0; i < len(session.hidQueues); i++ { + close(session.hidQueues[i]) + session.hidQueues[i] = nil } close(session.keysDownStateQueue)