From e3eb8330fec64ab7911970787476f4b5886959af Mon Sep 17 00:00:00 2001 From: Adam Shiervani Date: Tue, 16 Sep 2025 01:48:13 +0200 Subject: [PATCH] refactor: simplify HID RPC keyboard input handling and improve key state management - Updated `handleHidRPCKeyboardInput` to return errors directly instead of keys down state. - Refactored `rpcKeyboardReport` and `rpcKeypressReport` to return errors instead of states. - Introduced a queue for managing key down state updates in the `Session` struct to prevent input handling stalls. - Adjusted the `UpdateKeysDown` method to handle state changes more efficiently. - Removed unnecessary logging and commented-out code for clarity. --- hidrpc.go | 20 ++-- internal/usbgadget/hid_keyboard.go | 163 ++++++++++------------------- jsonrpc.go | 2 - ui/src/components/WebRTCVideo.tsx | 3 +- ui/src/hooks/useHidRpc.ts | 129 +++++++++++++---------- usb.go | 6 +- webrtc.go | 30 ++++++ 7 files changed, 174 insertions(+), 179 deletions(-) diff --git a/hidrpc.go b/hidrpc.go index 46a8ced7..ffcfd241 100644 --- a/hidrpc.go +++ b/hidrpc.go @@ -27,11 +27,7 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { } session.hidRPCAvailable = true case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: - keysDownState, err := handleHidRPCKeyboardInput(message) - if keysDownState != nil { - session.reportHidRPCKeysDownState(*keysDownState) - } - rpcErr = err + rpcErr = handleHidRPCKeyboardInput(message) case hidrpc.TypeKeyboardMacroReport: keyboardMacroReport, err := message.KeyboardMacroReport() if err != nil { @@ -107,27 +103,25 @@ func onHidMessage(msg hidQueueMessage, session *Session) { } } -func handleHidRPCKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState, error) { +func handleHidRPCKeyboardInput(message hidrpc.Message) error { switch message.Type() { case hidrpc.TypeKeypressReport: keypressReport, err := message.KeypressReport() if err != nil { logger.Warn().Err(err).Msg("failed to get keypress report") - return nil, err + return err } - keysDownState, rpcError := rpcKeypressReport(keypressReport.Key, keypressReport.Press) - return &keysDownState, rpcError + return rpcKeypressReport(keypressReport.Key, keypressReport.Press) case hidrpc.TypeKeyboardReport: keyboardReport, err := message.KeyboardReport() if err != nil { logger.Warn().Err(err).Msg("failed to get keyboard report") - return nil, err + return err } - keysDownState, rpcError := rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) - return &keysDownState, rpcError + return rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) } - return nil, fmt.Errorf("unknown HID RPC message type: %d", message.Type()) + return fmt.Errorf("unknown HID RPC message type: %d", message.Type()) } func reportHidRPC(params any, session *Session) { diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index 53e232c0..fa3b2a0c 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "os" + "sync" "time" "github.com/rs/xid" @@ -153,36 +154,11 @@ func (u *UsbGadget) GetKeysDownState() KeysDownState { return u.keysDownState } -func (u *UsbGadget) updateKeyDownState(state KeysDownState) { - u.log.Trace().Interface("old", u.keysDownState).Interface("new", state).Msg("acquiring keyboardStateLock for updateKeyDownState") - - // this is intentional to unlock keyboard state lock before onKeysDownChange callback - { - u.keyboardStateLock.Lock() - defer u.keyboardStateLock.Unlock() - - if u.keysDownState.Modifier == state.Modifier && - bytes.Equal(u.keysDownState.Keys, state.Keys) { - return // No change in key down state - } - - u.log.Trace().Interface("old", u.keysDownState).Interface("new", state).Msg("keysDownState updated") - u.keysDownState = state - } - - if u.onKeysDownChange != nil { - u.log.Trace().Interface("state", state).Msg("calling onKeysDownChange") - (*u.onKeysDownChange)(state) - u.log.Trace().Interface("state", state).Msg("onKeysDownChange called") - } -} - func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) { u.onKeysDownChange = &f } -func (u *UsbGadget) scheduleAutoRelease(key byte) { - u.log.Trace().Msg("scheduling autoRelease") +func (u *UsbGadget) scheduleAutoRelease() { u.kbdAutoReleaseLock.Lock() defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease scheduled") @@ -196,7 +172,6 @@ func (u *UsbGadget) scheduleAutoRelease(key byte) { } func (u *UsbGadget) cancelAutoRelease() { - u.log.Trace().Msg("cancelling autoRelease") u.kbdAutoReleaseLock.Lock() defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease cancelled") @@ -206,7 +181,6 @@ func (u *UsbGadget) cancelAutoRelease() { } func (u *UsbGadget) DelayAutoRelease() { - u.log.Trace().Msg("delaying autoRelease") u.kbdAutoReleaseLock.Lock() defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease delayed") @@ -215,33 +189,37 @@ func (u *UsbGadget) DelayAutoRelease() { } u.kbdAutoReleaseTimer.Reset(autoReleaseKeyboardInterval) - - u.log.Trace().Msg("auto-release timer reset") } func (u *UsbGadget) performAutoRelease() { - u.log.Trace().Msg("performing autoRelease") - u.kbdAutoReleaseLock.Lock() - defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease unlocked") - - key := u.kbdAutoReleaseLastKey - select { case <-u.keyboardStateCtx.Done(): return default: } - // we just reset the keyboard state to 0 no matter what - u.log.Trace().Uint8("key", key).Msg("auto-releasing keyboard key") - _, err := u.keypressReport(key, false, false) - if err != nil { - u.log.Warn().Uint8("key", key).Msg("failed to auto-release keyboard key") + u.kbdAutoReleaseLock.Lock() + + key := u.kbdAutoReleaseLastKey + + // Skip if already released + state := u.GetKeysDownState() + alreadyReleased := true + for i := range state.Keys { + if state.Keys[i] == key { + alreadyReleased = false + break + } + } + + if alreadyReleased { + return } u.kbdAutoReleaseTimer = nil + u.kbdAutoReleaseLock.Unlock() - u.log.Trace().Uint8("key", key).Msg("auto release performed") + u.keypressReport(key, false) // autoRelease the ket } func (u *UsbGadget) listenKeyboardEvents() { @@ -313,7 +291,11 @@ func (u *UsbGadget) OpenKeyboardHidFile() error { return u.openKeyboardHidFile() } +var keyboardWriteHidFileLock sync.Mutex + func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { + keyboardWriteHidFileLock.Lock() + defer keyboardWriteHidFileLock.Unlock() if err := u.openKeyboardHidFile(); err != nil { return err } @@ -329,7 +311,7 @@ func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { return nil } -func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { +func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) { // if we just reported an error roll over, we should clear the keys if keys[0] == hidErrorRollOver { for i := range keys { @@ -337,17 +319,29 @@ func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { } } - downState := KeysDownState{ + state := KeysDownState{ Modifier: modifier, Keys: []byte(keys[:]), } - u.updateKeyDownState(downState) - return downState + + u.keyboardStateLock.Lock() + + if u.keysDownState.Modifier == state.Modifier && + bytes.Equal(u.keysDownState.Keys, state.Keys) { + u.keyboardStateLock.Unlock() + return // No change in key down state + } + + u.keysDownState = state + u.keyboardStateLock.Unlock() + + if u.onKeysDownChange != nil { + (*u.onKeysDownChange)(state) // this enques to the outgoing hidrpc queue via usb.go → currentSession.enqueueKeysDownState(...) + } + return } -func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) (KeysDownState, error) { - u.keyboardLock.Lock() - defer u.keyboardLock.Unlock() +func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) error { defer u.resetUserInputTime() if len(keys) > hidKeyBufferSize { @@ -362,7 +356,8 @@ func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) (KeysDownState, e u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keyboard report to hidg0") } - return u.UpdateKeysDown(modifier, keys), err + u.UpdateKeysDown(modifier, keys) + return err } const ( @@ -402,10 +397,10 @@ var KeyCodeToMaskMap = map[byte]byte{ RightSuper: ModifierMaskRightSuper, } -func (u *UsbGadget) keypressReportNonThreadSafe(key byte, press bool, autoRelease bool) (KeysDownState, error) { +func (u *UsbGadget) keypressReport(key byte, press bool) error { defer u.resetUserInputTime() - l := u.log.With().Uint8("key", key).Bool("press", press).Bool("autoRelease", autoRelease).Logger() + l := u.log.With().Uint8("key", key).Bool("press", press).Logger() if l.GetLevel() <= zerolog.DebugLevel { requestID := xid.New() l = l.With().Str("requestID", requestID.String()).Logger() @@ -470,64 +465,22 @@ func (u *UsbGadget) keypressReportNonThreadSafe(key byte, press bool, autoReleas } } - if l.GetLevel() <= zerolog.DebugLevel { - l = l.With().Uint8("modifier", modifier).Uints8("keys", keys).Logger() - } - - l.Trace().Msg("writing keypress report to hidg0") - err := u.keyboardWriteHidFile(modifier, keys) - if err != nil { - l.Warn().Msg("Could not write keypress report to hidg0") - } + u.UpdateKeysDown(modifier, keys) + return err +} - l.Trace().Msg("keypress report written to hidg0") +func (u *UsbGadget) KeypressReport(key byte, press bool) error { + u.kbdAutoReleaseLock.Lock() + u.kbdAutoReleaseLastKey = key + u.kbdAutoReleaseLock.Unlock() if press { - { - l.Trace().Msg("acquiring kbdAutoReleaseLock to update last key") - u.kbdAutoReleaseLock.Lock() - u.kbdAutoReleaseLastKey = key - unlockWithLog(&u.kbdAutoReleaseLock, u.log, "last key updated") - } - - if autoRelease { - u.scheduleAutoRelease(key) - } + u.scheduleAutoRelease() } else { - if autoRelease { - u.cancelAutoRelease() - } + u.cancelAutoRelease() } - return u.UpdateKeysDown(modifier, keys), err -} - -type keypressReportResult struct { - KeysDownState KeysDownState - Error error -} - -func (u *UsbGadget) keypressReport(key byte, press bool, autoRelease bool) (KeysDownState, error) { - u.keyboardLock.Lock() - defer u.keyboardLock.Unlock() - - r := make(chan keypressReportResult) - go func() { - state, err := u.keypressReportNonThreadSafe(key, press, autoRelease) - r <- keypressReportResult{KeysDownState: state, Error: err} - }() - - select { - case <-time.After(1 * time.Second): - u.log.Warn().Msg("keypressReport timed out, possibly stuck") - return u.keysDownState, fmt.Errorf("keypressReport timed out, possibly stuck") - case ret := <-r: - u.log.Debug().Msg("keypressReport handled") - return ret.KeysDownState, ret.Error - } -} - -func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) { - return u.keypressReport(key, press, true) + err := u.keypressReport(key, press) + return err } diff --git a/jsonrpc.go b/jsonrpc.go index 6be8633b..759325b4 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -1169,9 +1169,7 @@ var rpcHandlers = map[string]RPCHandler{ "getNetworkSettings": {Func: rpcGetNetworkSettings}, "setNetworkSettings": {Func: rpcSetNetworkSettings, Params: []string{"settings"}}, "renewDHCPLease": {Func: rpcRenewDHCPLease}, - "keyboardReport": {Func: rpcKeyboardReport, Params: []string{"modifier", "keys"}}, "getKeyboardLedState": {Func: rpcGetKeyboardLedState}, - "keypressReport": {Func: rpcKeypressReport, Params: []string{"key", "press"}}, "getKeyDownState": {Func: rpcGetKeysDownState}, "absMouseReport": {Func: rpcAbsMouseReport, Params: []string{"x", "y", "buttons"}}, "relMouseReport": {Func: rpcRelMouseReport, Params: []string{"dx", "dy", "buttons"}}, diff --git a/ui/src/components/WebRTCVideo.tsx b/ui/src/components/WebRTCVideo.tsx index bc6897ea..64452bf8 100644 --- a/ui/src/components/WebRTCVideo.tsx +++ b/ui/src/components/WebRTCVideo.tsx @@ -190,7 +190,7 @@ export default function WebRTCVideo() { if (!isFullscreenEnabled || !videoElm.current) return; // per https://wicg.github.io/keyboard-lock/#system-key-press-handler - // If keyboard lock is activated after fullscreen is already in effect, then the user my + // If keyboard lock is activated after fullscreen is already in effect, then the user my // see multiple messages about how to exit fullscreen. For this reason, we recommend that // developers call lock() before they enter fullscreen: await requestKeyboardLock(); @@ -237,6 +237,7 @@ export default function WebRTCVideo() { const keyDownHandler = useCallback( (e: KeyboardEvent) => { e.preventDefault(); + if (e.repeat) return; const code = getAdjustedKeyCode(e); const hidKey = keys[code]; diff --git a/ui/src/hooks/useHidRpc.ts b/ui/src/hooks/useHidRpc.ts index 57a54e96..e0664e52 100644 --- a/ui/src/hooks/useHidRpc.ts +++ b/ui/src/hooks/useHidRpc.ts @@ -40,11 +40,16 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { }, [rpcHidChannel, rpcHidProtocolVersion, hidRpcDisabled]); const rpcHidUnreliableReady = useMemo(() => { - return rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null; + return ( + rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null + ); }, [rpcHidUnreliableChannel, rpcHidProtocolVersion]); const rpcHidUnreliableNonOrderedReady = useMemo(() => { - return rpcHidUnreliableNonOrderedChannel?.readyState === "open" && rpcHidProtocolVersion !== null; + return ( + rpcHidUnreliableNonOrderedChannel?.readyState === "open" && + rpcHidProtocolVersion !== null + ); }, [rpcHidUnreliableNonOrderedChannel, rpcHidProtocolVersion]); const rpcHidStatus = useMemo(() => { @@ -56,42 +61,53 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { return `ready (v${rpcHidProtocolVersion}${rpcHidUnreliableReady ? "+u" : ""})`; }, [rpcHidChannel, rpcHidUnreliableReady, rpcHidProtocolVersion, hidRpcDisabled]); - const sendMessage = useCallback((message: RpcMessage, { ignoreHandshakeState, useUnreliableChannel, requireOrdered = true }: sendMessageParams = {}) => { - if (hidRpcDisabled) return; + const sendMessage = useCallback( + ( + message: RpcMessage, + { + ignoreHandshakeState, + useUnreliableChannel, + requireOrdered = true, + }: sendMessageParams = {}, + ) => { + if (hidRpcDisabled) return; if (rpcHidChannel?.readyState !== "open") return; - if (!rpcHidReady && !ignoreHandshakeState) return; + if (!rpcHidReady && !ignoreHandshakeState) return; - let data: Uint8Array | undefined; - try { - data = message.marshal(); - } catch (e) { - console.error("Failed to send HID RPC message", e); - } - if (!data) return; - - if (useUnreliableChannel) { - if (requireOrdered && rpcHidUnreliableReady) { - rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer); - } else if (!requireOrdered && rpcHidUnreliableNonOrderedReady) { - rpcHidUnreliableNonOrderedChannel?.send(data as unknown as ArrayBuffer); + let data: Uint8Array | undefined; + try { + data = message.marshal(); + } catch (e) { + console.error("Failed to send HID RPC message", e); } - return; - } + if (!data) return; - rpcHidChannel?.send(data as unknown as ArrayBuffer); - }, [ - rpcHidChannel, - rpcHidUnreliableChannel, - hidRpcDisabled, rpcHidUnreliableNonOrderedChannel, - rpcHidReady, - rpcHidUnreliableReady, - rpcHidUnreliableNonOrderedReady, - ]); + if (useUnreliableChannel) { + if (requireOrdered && rpcHidUnreliableReady) { + rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer); + } else if (!requireOrdered && rpcHidUnreliableNonOrderedReady) { + rpcHidUnreliableNonOrderedChannel?.send(data as unknown as ArrayBuffer); + } + return; + } + + rpcHidChannel?.send(data as unknown as ArrayBuffer); + }, + [ + rpcHidChannel, + rpcHidUnreliableChannel, + hidRpcDisabled, rpcHidUnreliableNonOrderedChannel, + rpcHidReady, + rpcHidUnreliableReady, + rpcHidUnreliableNonOrderedReady, + ], + ); const reportKeyboardEvent = useCallback( (keys: number[], modifier: number) => { sendMessage(new KeyboardReportMessage(keys, modifier)); - }, [sendMessage], + }, + [sendMessage], ); const reportKeypressEvent = useCallback( @@ -103,7 +119,9 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { const reportAbsMouseEvent = useCallback( (x: number, y: number, buttons: number) => { - sendMessage(new PointerReportMessage(x, y, buttons), { useUnreliableChannel: true }); + sendMessage(new PointerReportMessage(x, y, buttons), { + useUnreliableChannel: true, + }); }, [sendMessage], ); @@ -130,7 +148,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { ); const reportKeypressKeepAlive = useCallback(() => { - sendMessage(KEEPALIVE_MESSAGE, { useUnreliableChannel: true, requireOrdered: false }); + sendMessage(KEEPALIVE_MESSAGE); }, [sendMessage]); const sendHandshake = useCallback(() => { @@ -141,24 +159,27 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { sendMessage(new HandshakeMessage(HID_RPC_VERSION), { ignoreHandshakeState: true }); }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage, hidRpcDisabled]); - const handleHandshake = useCallback((message: HandshakeMessage) => { - if (hidRpcDisabled) return; + const handleHandshake = useCallback( + (message: HandshakeMessage) => { + if (hidRpcDisabled) return; if (!message.version) { - console.error("Received handshake message without version", message); - return; - } + console.error("Received handshake message without version", message); + return; + } - if (message.version > HID_RPC_VERSION) { - // we assume that the UI is always using the latest version of the HID RPC protocol - // so we can't support this - // TODO: use capabilities to determine rather than version number - console.error("Server is using a newer HID RPC version than the client", message); - return; - } + if (message.version > HID_RPC_VERSION) { + // we assume that the UI is always using the latest version of the HID RPC protocol + // so we can't support this + // TODO: use capabilities to determine rather than version number + console.error("Server is using a newer HID RPC version than the client", message); + return; + } - setRpcHidProtocolVersion(message.version); - }, [setRpcHidProtocolVersion, hidRpcDisabled]); + setRpcHidProtocolVersion(message.version); + }, + [setRpcHidProtocolVersion, hidRpcDisabled], + ); useEffect(() => { if (!rpcHidChannel) return; @@ -211,16 +232,14 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { rpcHidChannel.removeEventListener("close", closeHandler); rpcHidChannel.removeEventListener("open", openHandler); }; - }, - [ - rpcHidChannel, - onHidRpcMessage, - setRpcHidProtocolVersion, - sendHandshake, - handleHandshake, + }, [ + rpcHidChannel, + onHidRpcMessage, + setRpcHidProtocolVersion, + sendHandshake, + handleHandshake, hidRpcDisabled, - ], - ); + ]); return { reportKeyboardEvent, diff --git a/usb.go b/usb.go index 131cd517..f63a6aab 100644 --- a/usb.go +++ b/usb.go @@ -33,7 +33,7 @@ func initUsbGadget() { gadget.SetOnKeysDownChange(func(state usbgadget.KeysDownState) { if currentSession != nil { - currentSession.reportHidRPCKeysDownState(state) + currentSession.enqueueKeysDownState(state) } }) @@ -43,11 +43,11 @@ func initUsbGadget() { } } -func rpcKeyboardReport(modifier byte, keys []byte) (usbgadget.KeysDownState, error) { +func rpcKeyboardReport(modifier byte, keys []byte) error { return gadget.KeyboardReport(modifier, keys) } -func rpcKeypressReport(key byte, press bool) (usbgadget.KeysDownState, error) { +func rpcKeypressReport(key byte, press bool) error { return gadget.KeypressReport(key, press) } diff --git a/webrtc.go b/webrtc.go index 707a6a9d..9d5c49d4 100644 --- a/webrtc.go +++ b/webrtc.go @@ -13,6 +13,7 @@ import ( "github.com/gin-gonic/gin" "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/usbgadget" "github.com/pion/webrtc/v4" "github.com/rs/zerolog" ) @@ -30,6 +31,8 @@ type Session struct { hidRPCAvailable bool hidQueueLock sync.Mutex hidQueue []chan hidQueueMessage + + keysDownStateQueue chan usbgadget.KeysDownState } type hidQueueMessage struct { @@ -96,6 +99,32 @@ func (s *Session) handleQueues(index int) { } } +const keysDownStateQueueSize = 256 + +func (s *Session) initKeysDownStateQueue() { + // serialise outbound key state reports so unreliable links can't stall input handling + s.keysDownStateQueue = make(chan usbgadget.KeysDownState, keysDownStateQueueSize) + go s.handleKeysDownStateQueue() +} + +func (s *Session) handleKeysDownStateQueue() { + for state := range s.keysDownStateQueue { + s.reportHidRPCKeysDownState(state) + } +} + +func (s *Session) enqueueKeysDownState(state usbgadget.KeysDownState) { + if s == nil || s.keysDownStateQueue == nil { + return + } + + select { + case s.keysDownStateQueue <- state: + default: + hidRPCLogger.Warn().Msg("dropping keys down state update; queue full") + } +} + func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, channel string) func(msg webrtc.DataChannelMessage) { return func(msg webrtc.DataChannelMessage) { l := scopedLogger.With(). @@ -181,6 +210,7 @@ func newSession(config SessionConfig) (*Session, error) { session := &Session{peerConnection: peerConnection} session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) session.initQueues() + session.initKeysDownStateQueue() go func() { for msg := range session.rpcQueue {