Compare commits

...

2 Commits

Author SHA1 Message Date
Marc Brooks 2dc32c2706
Merge b3b8614d38 into 8f4081a5b1 2025-09-27 06:52:32 +00:00
Marc Brooks b3b8614d38
Add a time limit for each message type/queue. 2025-09-27 01:50:09 -05:00
5 changed files with 96 additions and 74 deletions

View File

@ -53,13 +53,13 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) {
rpcCancelKeyboardMacro() rpcCancelKeyboardMacro()
return return
case hidrpc.TypeCancelKeyboardMacroByTokenReport: case hidrpc.TypeKeyboardMacroTokenState:
token, err := message.KeyboardMacroToken() tokenState, err := message.KeyboardMacroTokenState()
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("failed to get keyboard macro token") logger.Warn().Err(err).Msg("failed to get keyboard macro token")
return return
} }
rpcCancelKeyboardMacroByToken(token) rpcCancelKeyboardMacroByToken(tokenState.Token)
return return
case hidrpc.TypeKeypressKeepAliveReport: case hidrpc.TypeKeypressKeepAliveReport:
@ -96,6 +96,7 @@ func onHidMessage(msg hidQueueMessage, session *Session) {
scopedLogger := hidRPCLogger.With(). scopedLogger := hidRPCLogger.With().
Str("channel", msg.channel). Str("channel", msg.channel).
Dur("timelimit", msg.timelimit).
Int("data_len", dataLen). Int("data_len", dataLen).
Bytes("data", data[:min(dataLen, 32)]). Bytes("data", data[:min(dataLen, 32)]).
Logger() Logger()
@ -125,7 +126,7 @@ func onHidMessage(msg hidQueueMessage, session *Session) {
r <- nil r <- nil
}() }()
select { select {
case <-time.After(1 * time.Second): case <-time.After(msg.timelimit * time.Second):
scopedLogger.Warn().Msg("HID RPC message timed out") scopedLogger.Warn().Msg("HID RPC message timed out")
case <-r: case <-r:
scopedLogger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC message handled") scopedLogger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC message handled")
@ -241,6 +242,8 @@ func reportHidRPC(params any, session *Session) {
message, err = hidrpc.NewKeydownStateMessage(params).Marshal() message, err = hidrpc.NewKeydownStateMessage(params).Marshal()
case hidrpc.KeyboardMacroState: case hidrpc.KeyboardMacroState:
message, err = hidrpc.NewKeyboardMacroStateMessage(params.State, params.IsPaste).Marshal() message, err = hidrpc.NewKeyboardMacroStateMessage(params.State, params.IsPaste).Marshal()
case hidrpc.KeyboardMacroTokenState:
message, err = hidrpc.NewKeyboardMacroTokenMessage(params.Token).Marshal()
default: default:
err = fmt.Errorf("unknown HID RPC message type: %T", params) err = fmt.Errorf("unknown HID RPC message type: %T", params)
} }

View File

@ -2,6 +2,7 @@ package hidrpc
import ( import (
"fmt" "fmt"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jetkvm/kvm/internal/usbgadget" "github.com/jetkvm/kvm/internal/usbgadget"
@ -11,19 +12,19 @@ import (
type MessageType byte type MessageType byte
const ( const (
TypeHandshake MessageType = 0x01 TypeHandshake MessageType = 0x01
TypeKeyboardReport MessageType = 0x02 TypeKeyboardReport MessageType = 0x02
TypePointerReport MessageType = 0x03 TypePointerReport MessageType = 0x03
TypeWheelReport MessageType = 0x04 TypeWheelReport MessageType = 0x04
TypeKeypressReport MessageType = 0x05 TypeKeypressReport MessageType = 0x05
TypeKeypressKeepAliveReport MessageType = 0x09 TypeKeypressKeepAliveReport MessageType = 0x09
TypeMouseReport MessageType = 0x06 TypeMouseReport MessageType = 0x06
TypeKeyboardMacroReport MessageType = 0x07 TypeKeyboardMacroReport MessageType = 0x07
TypeCancelKeyboardMacroReport MessageType = 0x08 TypeCancelKeyboardMacroReport MessageType = 0x08
TypeKeyboardLedState MessageType = 0x32 TypeKeyboardLedState MessageType = 0x32
TypeKeydownState MessageType = 0x33 TypeKeydownState MessageType = 0x33
TypeKeyboardMacroState MessageType = 0x34 TypeKeyboardMacroState MessageType = 0x34
TypeCancelKeyboardMacroByTokenReport MessageType = 0x35 TypeKeyboardMacroTokenState MessageType = 0x35
) )
type QueueIndex int type QueueIndex int
@ -31,26 +32,26 @@ type QueueIndex int
const ( const (
Version byte = 0x01 // Version of the HID RPC protocol Version byte = 0x01 // Version of the HID RPC protocol
HandshakeQueue int = 0 // Queue index for handshake messages HandshakeQueue int = 0 // Queue index for handshake messages
KeyboardQueue int = 1 // Queue index for keyboard and macro messages KeyboardQueue int = 1 // Queue index for keyboard messages
MouseQueue int = 2 // Queue index for mouse messages MouseQueue int = 2 // Queue index for mouse messages
MacroQueue int = 3 // Queue index for macro cancel messages MacroQueue int = 3 // Queue index for macro messages
OtherQueue int = 4 // Queue index for other messages OtherQueue int = 4 // Queue index for other messages
) )
// GetQueueIndex returns the index of the queue to which the message should be enqueued. // GetQueueIndex returns the index of the queue to which the message should be enqueued.
func GetQueueIndex(messageType MessageType) int { func GetQueueIndex(messageType MessageType) (int, time.Duration) {
switch messageType { switch messageType {
case TypeHandshake: case TypeHandshake:
return HandshakeQueue return HandshakeQueue, 1
case TypeKeyboardReport, TypeKeypressReport, TypeKeyboardLedState, TypeKeydownState, TypeKeyboardMacroState: case TypeKeyboardReport, TypeKeypressReport, TypeKeyboardLedState, TypeKeydownState, TypeKeyboardMacroState:
return KeyboardQueue return KeyboardQueue, 1
case TypePointerReport, TypeMouseReport, TypeWheelReport: case TypePointerReport, TypeMouseReport, TypeWheelReport:
return MouseQueue return MouseQueue, 1
// we don't want to block the queue for these messages // we don't want to block the queue for these messages
case TypeKeyboardMacroReport, TypeCancelKeyboardMacroReport, TypeCancelKeyboardMacroByTokenReport: case TypeKeyboardMacroReport, TypeCancelKeyboardMacroReport, TypeKeyboardMacroTokenState:
return MacroQueue return MacroQueue, 60 // 1 minute timeout
default: default:
return OtherQueue return OtherQueue, 5
} }
} }

View File

@ -69,11 +69,11 @@ func (m *Message) String() string {
return fmt.Sprintf("CancelKeyboardMacroReport{Malformed: %v}", m.d) return fmt.Sprintf("CancelKeyboardMacroReport{Malformed: %v}", m.d)
} }
return "CancelKeyboardMacroReport" return "CancelKeyboardMacroReport"
case TypeCancelKeyboardMacroByTokenReport: case TypeKeyboardMacroTokenState:
if len(m.d) != 16 { if len(m.d) != 16 {
return fmt.Sprintf("CancelKeyboardMacroByTokenReport{Malformed: %v}", m.d) return fmt.Sprintf("KeyboardMacroTokenState{Malformed: %v}", m.d)
} }
return fmt.Sprintf("CancelKeyboardMacroByTokenReport{Token: %s}", uuid.Must(uuid.FromBytes(m.d)).String()) return fmt.Sprintf("KeyboardMacroTokenState{Token: %s}", uuid.Must(uuid.FromBytes(m.d)).String())
case TypeKeyboardLedState: case TypeKeyboardLedState:
if len(m.d) < 1 { if len(m.d) < 1 {
return fmt.Sprintf("KeyboardLedState{Malformed: %v}", m.d) return fmt.Sprintf("KeyboardLedState{Malformed: %v}", m.d)
@ -246,19 +246,28 @@ func (m *Message) KeyboardMacroState() (KeyboardMacroState, error) {
}, nil }, nil
} }
// KeyboardMacroToken returns the keyboard macro token UUID from the message. type KeyboardMacroTokenState struct {
func (m *Message) KeyboardMacroToken() (uuid.UUID, error) { Token uuid.UUID
if m.t != TypeCancelKeyboardMacroByTokenReport { }
return uuid.Nil, fmt.Errorf("invalid message type: %d", m.t)
// KeyboardMacroTokenState returns the keyboard macro token UUID from the message.
func (m *Message) KeyboardMacroTokenState() (KeyboardMacroTokenState, error) {
if m.t != TypeKeyboardMacroTokenState {
return KeyboardMacroTokenState{}, fmt.Errorf("invalid message type: %d", m.t)
} }
if len(m.d) == 0 { if len(m.d) == 0 {
return uuid.Nil, nil return KeyboardMacroTokenState{Token: uuid.Nil}, nil
} }
if len(m.d) != 16 { if len(m.d) != 16 {
return uuid.Nil, fmt.Errorf("invalid UUID length: %d", len(m.d)) return KeyboardMacroTokenState{}, fmt.Errorf("invalid UUID length: %d", len(m.d))
} }
return uuid.FromBytes(m.d) token, err := uuid.FromBytes(m.d)
if err != nil {
return KeyboardMacroTokenState{}, fmt.Errorf("invalid UUID: %v", err)
}
return KeyboardMacroTokenState{Token: token}, nil
} }

View File

@ -1091,10 +1091,10 @@ func getKeyboardMacroCancelMap() map[uuid.UUID]RunningMacro {
func addKeyboardMacro(isPaste bool, cancel context.CancelFunc) uuid.UUID { func addKeyboardMacro(isPaste bool, cancel context.CancelFunc) uuid.UUID {
keyboardMacroLock.Lock() keyboardMacroLock.Lock()
defer keyboardMacroLock.Unlock() defer keyboardMacroLock.Unlock()
keyboardMacroCancelMap := getKeyboardMacroCancelMap() cancelMap := getKeyboardMacroCancelMap()
token := uuid.New() // Generate a unique token token := uuid.New() // Generate a unique token
keyboardMacroCancelMap[token] = RunningMacro{ cancelMap[token] = RunningMacro{
isPaste: isPaste, isPaste: isPaste,
cancel: cancel, cancel: cancel,
} }
@ -1104,19 +1104,19 @@ func addKeyboardMacro(isPaste bool, cancel context.CancelFunc) uuid.UUID {
func removeRunningKeyboardMacro(token uuid.UUID) { func removeRunningKeyboardMacro(token uuid.UUID) {
keyboardMacroLock.Lock() keyboardMacroLock.Lock()
defer keyboardMacroLock.Unlock() defer keyboardMacroLock.Unlock()
keyboardMacroCancelMap := getKeyboardMacroCancelMap() cancelMap := getKeyboardMacroCancelMap()
delete(keyboardMacroCancelMap, token) delete(cancelMap, token)
} }
func cancelRunningKeyboardMacro(token uuid.UUID) { func cancelRunningKeyboardMacro(token uuid.UUID) {
keyboardMacroLock.Lock() keyboardMacroLock.Lock()
defer keyboardMacroLock.Unlock() defer keyboardMacroLock.Unlock()
keyboardMacroCancelMap := getKeyboardMacroCancelMap() cancelMap := getKeyboardMacroCancelMap()
if runningMacro, exists := keyboardMacroCancelMap[token]; exists { if runningMacro, exists := cancelMap[token]; exists {
runningMacro.cancel() runningMacro.cancel()
delete(keyboardMacroCancelMap, token) delete(cancelMap, token)
logger.Info().Interface("token", token).Msg("canceled keyboard macro by token") logger.Info().Interface("token", token).Msg("canceled keyboard macro by token")
} else { } else {
logger.Debug().Interface("token", token).Msg("no running keyboard macro found for token") logger.Debug().Interface("token", token).Msg("no running keyboard macro found for token")
@ -1126,11 +1126,11 @@ func cancelRunningKeyboardMacro(token uuid.UUID) {
func cancelAllRunningKeyboardMacros() { func cancelAllRunningKeyboardMacros() {
keyboardMacroLock.Lock() keyboardMacroLock.Lock()
defer keyboardMacroLock.Unlock() defer keyboardMacroLock.Unlock()
keyboardMacroCancelMap := getKeyboardMacroCancelMap() cancelMap := getKeyboardMacroCancelMap()
for token, runningMacro := range keyboardMacroCancelMap { for token, runningMacro := range cancelMap {
runningMacro.cancel() runningMacro.cancel()
delete(keyboardMacroCancelMap, token) delete(cancelMap, token)
logger.Info().Interface("token", token).Msg("cancelled keyboard macro") logger.Info().Interface("token", token).Msg("cancelled keyboard macro")
} }
} }
@ -1139,12 +1139,10 @@ func reportRunningMacrosState() {
if currentSession != nil { if currentSession != nil {
keyboardMacroLock.Lock() keyboardMacroLock.Lock()
defer keyboardMacroLock.Unlock() defer keyboardMacroLock.Unlock()
keyboardMacroCancelMap := getKeyboardMacroCancelMap() cancelMap := getKeyboardMacroCancelMap()
isPaste := false isPaste := false
anyRunning := false for _, runningMacro := range cancelMap {
for _, runningMacro := range keyboardMacroCancelMap {
anyRunning = true
if runningMacro.isPaste { if runningMacro.isPaste {
isPaste = true isPaste = true
break break
@ -1152,7 +1150,7 @@ func reportRunningMacrosState() {
} }
state := hidrpc.KeyboardMacroState{ state := hidrpc.KeyboardMacroState{
State: anyRunning, State: len(cancelMap) > 0,
IsPaste: isPaste, IsPaste: isPaste,
} }
@ -1194,7 +1192,10 @@ func rpcCancelKeyboardMacroByToken(token uuid.UUID) {
} }
func executeKeyboardMacro(ctx context.Context, isPaste bool, macro []hidrpc.KeyboardMacroStep) error { func executeKeyboardMacro(ctx context.Context, isPaste bool, macro []hidrpc.KeyboardMacroStep) error {
logger.Debug().Int("macro_steps", len(macro)).Msg("Executing keyboard macro") logger.Debug().
Int("macro_steps", len(macro)).
Bool("isPaste", isPaste).
Msg("Executing keyboard macro")
// don't report keyboard state changes while executing the macro // don't report keyboard state changes while executing the macro
gadget.SuspendKeyDownMessages() gadget.SuspendKeyDownMessages()

View File

@ -34,7 +34,7 @@ type Session struct {
lastTimerResetTime time.Time // Track when auto-release timer was last reset lastTimerResetTime time.Time // Track when auto-release timer was last reset
keepAliveJitterLock sync.Mutex // Protect jitter compensation timing state keepAliveJitterLock sync.Mutex // Protect jitter compensation timing state
hidQueueLock sync.Mutex hidQueueLock sync.Mutex
hidQueue []chan hidQueueMessage hidQueues []chan hidQueueMessage
keysDownStateQueue chan usbgadget.KeysDownState keysDownStateQueue chan usbgadget.KeysDownState
} }
@ -48,7 +48,8 @@ func (s *Session) resetKeepAliveTime() {
type hidQueueMessage struct { type hidQueueMessage struct {
webrtc.DataChannelMessage webrtc.DataChannelMessage
channel string channel string
timelimit time.Duration
} }
type SessionConfig struct { type SessionConfig struct {
@ -93,19 +94,20 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) {
return base64.StdEncoding.EncodeToString(localDescription), nil return base64.StdEncoding.EncodeToString(localDescription), nil
} }
func (s *Session) initQueues() { func (s *Session) initHidQueues() {
s.hidQueueLock.Lock() s.hidQueueLock.Lock()
defer s.hidQueueLock.Unlock() defer s.hidQueueLock.Unlock()
s.hidQueue = make([]chan hidQueueMessage, 0) s.hidQueues = make([]chan hidQueueMessage, hidrpc.OtherQueue+1)
for i := 0; i <= hidrpc.OtherQueue; i++ { s.hidQueues[hidrpc.HandshakeQueue] = make(chan hidQueueMessage, 2) // we don't really want to queue many handshake messages
q := make(chan hidQueueMessage, 256) s.hidQueues[hidrpc.KeyboardQueue] = make(chan hidQueueMessage, 256)
s.hidQueue = append(s.hidQueue, q) s.hidQueues[hidrpc.MouseQueue] = make(chan hidQueueMessage, 256)
} s.hidQueues[hidrpc.MacroQueue] = make(chan hidQueueMessage, 10) // macros can be long, but we don't want to queue too many
s.hidQueues[hidrpc.OtherQueue] = make(chan hidQueueMessage, 256)
} }
func (s *Session) handleQueues(index int) { func (s *Session) handleQueue(queue chan hidQueueMessage) {
for msg := range s.hidQueue[index] { for msg := range queue {
onHidMessage(msg, s) onHidMessage(msg, s)
} }
} }
@ -160,17 +162,18 @@ func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, chan
l.Trace().Msg("received data in HID RPC message handler") l.Trace().Msg("received data in HID RPC message handler")
// Enqueue to ensure ordered processing // Enqueue to ensure ordered processing
queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) queueIndex, timelimit := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0]))
if queueIndex >= len(session.hidQueue) || queueIndex < 0 { if queueIndex >= len(session.hidQueues) || queueIndex < 0 {
l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found")
queueIndex = hidrpc.OtherQueue queueIndex = hidrpc.OtherQueue
} }
queue := session.hidQueue[queueIndex] queue := session.hidQueues[queueIndex]
if queue != nil { if queue != nil {
queue <- hidQueueMessage{ queue <- hidQueueMessage{
DataChannelMessage: msg, DataChannelMessage: msg,
channel: channel, channel: channel,
timelimit: timelimit,
} }
} else { } else {
l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil")
@ -220,7 +223,7 @@ func newSession(config SessionConfig) (*Session, error) {
session := &Session{peerConnection: peerConnection} session := &Session{peerConnection: peerConnection}
session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) session.rpcQueue = make(chan webrtc.DataChannelMessage, 256)
session.initQueues() session.initHidQueues()
session.initKeysDownStateQueue() session.initKeysDownStateQueue()
go func() { go func() {
@ -230,8 +233,8 @@ func newSession(config SessionConfig) (*Session, error) {
} }
}() }()
for i := 0; i < len(session.hidQueue); i++ { for queue := range session.hidQueues {
go session.handleQueues(i) go session.handleQueue(session.hidQueues[queue])
} }
peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
@ -256,7 +259,11 @@ func newSession(config SessionConfig) (*Session, error) {
session.RPCChannel = d session.RPCChannel = d
d.OnMessage(func(msg webrtc.DataChannelMessage) { d.OnMessage(func(msg webrtc.DataChannelMessage) {
// Enqueue to ensure ordered processing // Enqueue to ensure ordered processing
session.rpcQueue <- msg if session.rpcQueue != nil {
session.rpcQueue <- msg
} else {
scopedLogger.Warn().Msg("RPC message received but rpcQueue is nil")
}
}) })
triggerOTAStateUpdate() triggerOTAStateUpdate()
triggerVideoStateUpdate() triggerVideoStateUpdate()
@ -325,22 +332,23 @@ func newSession(config SessionConfig) (*Session, error) {
_ = peerConnection.Close() _ = peerConnection.Close()
} }
if connectionState == webrtc.ICEConnectionStateClosed { if connectionState == webrtc.ICEConnectionStateClosed {
scopedLogger.Debug().Msg("ICE Connection State is closed, unmounting virtual media") scopedLogger.Debug().Msg("ICE Connection State is closed, tearing down session")
if session == currentSession { if session == currentSession {
// Cancel any ongoing keyboard report multi when session closes // Cancel any ongoing keyboard report multi when session closes
cancelAllRunningKeyboardMacros() cancelAllRunningKeyboardMacros()
currentSession = nil currentSession = nil
} }
// Stop RPC processor // Stop RPC processor
if session.rpcQueue != nil { if session.rpcQueue != nil {
close(session.rpcQueue) close(session.rpcQueue)
session.rpcQueue = nil session.rpcQueue = nil
} }
// Stop HID RPC processor // Stop HID RPC processors
for i := 0; i < len(session.hidQueue); i++ { for i := 0; i < len(session.hidQueues); i++ {
close(session.hidQueue[i]) close(session.hidQueues[i])
session.hidQueue[i] = nil session.hidQueues[i] = nil
} }
close(session.keysDownStateQueue) close(session.keysDownStateQueue)