From b430384ca86c7e71efb416e767aacf97d6393166 Mon Sep 17 00:00:00 2001 From: Adam Shiervani Date: Tue, 16 Sep 2025 01:48:13 +0200 Subject: [PATCH] refactor: simplify HID RPC keyboard input handling and improve key state management - Updated `handleHidRPCKeyboardInput` to return errors directly instead of keys down state. - Refactored `rpcKeyboardReport` and `rpcKeypressReport` to return errors instead of states. - Introduced a queue for managing key down state updates in the `Session` struct to prevent input handling stalls. - Adjusted the `UpdateKeysDown` method to handle state changes more efficiently. - Removed unnecessary logging and commented-out code for clarity. --- hidrpc.go | 20 +-- internal/usbgadget/hid_keyboard.go | 163 ++++++++-------------- jsonrpc.go | 2 - ui/src/components/WebRTCVideo.tsx | 3 +- ui/src/components/popovers/PasteModal.tsx | 78 ++++------- ui/src/hooks/useHidRpc.ts | 129 +++++++++-------- usb.go | 6 +- webrtc.go | 30 ++++ 8 files changed, 198 insertions(+), 233 deletions(-) diff --git a/hidrpc.go b/hidrpc.go index 12824ac6..53d758d8 100644 --- a/hidrpc.go +++ b/hidrpc.go @@ -25,11 +25,7 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { } session.hidRPCAvailable = true case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: - keysDownState, err := handleHidRPCKeyboardInput(message) - if keysDownState != nil { - session.reportHidRPCKeysDownState(*keysDownState) - } - rpcErr = err + rpcErr = handleHidRPCKeyboardInput(message) case hidrpc.TypeKeypressKeepAliveReport: gadget.DelayAutoRelease() case hidrpc.TypePointerReport: @@ -95,27 +91,25 @@ func onHidMessage(msg hidQueueMessage, session *Session) { } } -func handleHidRPCKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState, error) { +func handleHidRPCKeyboardInput(message hidrpc.Message) error { switch message.Type() { case hidrpc.TypeKeypressReport: keypressReport, err := message.KeypressReport() if err != nil { logger.Warn().Err(err).Msg("failed to get keypress report") - return nil, err + return err } - keysDownState, rpcError := rpcKeypressReport(keypressReport.Key, keypressReport.Press) - return &keysDownState, rpcError + return rpcKeypressReport(keypressReport.Key, keypressReport.Press) case hidrpc.TypeKeyboardReport: keyboardReport, err := message.KeyboardReport() if err != nil { logger.Warn().Err(err).Msg("failed to get keyboard report") - return nil, err + return err } - keysDownState, rpcError := rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) - return &keysDownState, rpcError + return rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) } - return nil, fmt.Errorf("unknown HID RPC message type: %d", message.Type()) + return fmt.Errorf("unknown HID RPC message type: %d", message.Type()) } func reportHidRPC(params any, session *Session) { diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index 53e232c0..fa3b2a0c 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "os" + "sync" "time" "github.com/rs/xid" @@ -153,36 +154,11 @@ func (u *UsbGadget) GetKeysDownState() KeysDownState { return u.keysDownState } -func (u *UsbGadget) updateKeyDownState(state KeysDownState) { - u.log.Trace().Interface("old", u.keysDownState).Interface("new", state).Msg("acquiring keyboardStateLock for updateKeyDownState") - - // this is intentional to unlock keyboard state lock before onKeysDownChange callback - { - u.keyboardStateLock.Lock() - defer u.keyboardStateLock.Unlock() - - if u.keysDownState.Modifier == state.Modifier && - bytes.Equal(u.keysDownState.Keys, state.Keys) { - return // No change in key down state - } - - u.log.Trace().Interface("old", u.keysDownState).Interface("new", state).Msg("keysDownState updated") - u.keysDownState = state - } - - if u.onKeysDownChange != nil { - u.log.Trace().Interface("state", state).Msg("calling onKeysDownChange") - (*u.onKeysDownChange)(state) - u.log.Trace().Interface("state", state).Msg("onKeysDownChange called") - } -} - func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) { u.onKeysDownChange = &f } -func (u *UsbGadget) scheduleAutoRelease(key byte) { - u.log.Trace().Msg("scheduling autoRelease") +func (u *UsbGadget) scheduleAutoRelease() { u.kbdAutoReleaseLock.Lock() defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease scheduled") @@ -196,7 +172,6 @@ func (u *UsbGadget) scheduleAutoRelease(key byte) { } func (u *UsbGadget) cancelAutoRelease() { - u.log.Trace().Msg("cancelling autoRelease") u.kbdAutoReleaseLock.Lock() defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease cancelled") @@ -206,7 +181,6 @@ func (u *UsbGadget) cancelAutoRelease() { } func (u *UsbGadget) DelayAutoRelease() { - u.log.Trace().Msg("delaying autoRelease") u.kbdAutoReleaseLock.Lock() defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease delayed") @@ -215,33 +189,37 @@ func (u *UsbGadget) DelayAutoRelease() { } u.kbdAutoReleaseTimer.Reset(autoReleaseKeyboardInterval) - - u.log.Trace().Msg("auto-release timer reset") } func (u *UsbGadget) performAutoRelease() { - u.log.Trace().Msg("performing autoRelease") - u.kbdAutoReleaseLock.Lock() - defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease unlocked") - - key := u.kbdAutoReleaseLastKey - select { case <-u.keyboardStateCtx.Done(): return default: } - // we just reset the keyboard state to 0 no matter what - u.log.Trace().Uint8("key", key).Msg("auto-releasing keyboard key") - _, err := u.keypressReport(key, false, false) - if err != nil { - u.log.Warn().Uint8("key", key).Msg("failed to auto-release keyboard key") + u.kbdAutoReleaseLock.Lock() + + key := u.kbdAutoReleaseLastKey + + // Skip if already released + state := u.GetKeysDownState() + alreadyReleased := true + for i := range state.Keys { + if state.Keys[i] == key { + alreadyReleased = false + break + } + } + + if alreadyReleased { + return } u.kbdAutoReleaseTimer = nil + u.kbdAutoReleaseLock.Unlock() - u.log.Trace().Uint8("key", key).Msg("auto release performed") + u.keypressReport(key, false) // autoRelease the ket } func (u *UsbGadget) listenKeyboardEvents() { @@ -313,7 +291,11 @@ func (u *UsbGadget) OpenKeyboardHidFile() error { return u.openKeyboardHidFile() } +var keyboardWriteHidFileLock sync.Mutex + func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { + keyboardWriteHidFileLock.Lock() + defer keyboardWriteHidFileLock.Unlock() if err := u.openKeyboardHidFile(); err != nil { return err } @@ -329,7 +311,7 @@ func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { return nil } -func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { +func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) { // if we just reported an error roll over, we should clear the keys if keys[0] == hidErrorRollOver { for i := range keys { @@ -337,17 +319,29 @@ func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { } } - downState := KeysDownState{ + state := KeysDownState{ Modifier: modifier, Keys: []byte(keys[:]), } - u.updateKeyDownState(downState) - return downState + + u.keyboardStateLock.Lock() + + if u.keysDownState.Modifier == state.Modifier && + bytes.Equal(u.keysDownState.Keys, state.Keys) { + u.keyboardStateLock.Unlock() + return // No change in key down state + } + + u.keysDownState = state + u.keyboardStateLock.Unlock() + + if u.onKeysDownChange != nil { + (*u.onKeysDownChange)(state) // this enques to the outgoing hidrpc queue via usb.go → currentSession.enqueueKeysDownState(...) + } + return } -func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) (KeysDownState, error) { - u.keyboardLock.Lock() - defer u.keyboardLock.Unlock() +func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) error { defer u.resetUserInputTime() if len(keys) > hidKeyBufferSize { @@ -362,7 +356,8 @@ func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) (KeysDownState, e u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keyboard report to hidg0") } - return u.UpdateKeysDown(modifier, keys), err + u.UpdateKeysDown(modifier, keys) + return err } const ( @@ -402,10 +397,10 @@ var KeyCodeToMaskMap = map[byte]byte{ RightSuper: ModifierMaskRightSuper, } -func (u *UsbGadget) keypressReportNonThreadSafe(key byte, press bool, autoRelease bool) (KeysDownState, error) { +func (u *UsbGadget) keypressReport(key byte, press bool) error { defer u.resetUserInputTime() - l := u.log.With().Uint8("key", key).Bool("press", press).Bool("autoRelease", autoRelease).Logger() + l := u.log.With().Uint8("key", key).Bool("press", press).Logger() if l.GetLevel() <= zerolog.DebugLevel { requestID := xid.New() l = l.With().Str("requestID", requestID.String()).Logger() @@ -470,64 +465,22 @@ func (u *UsbGadget) keypressReportNonThreadSafe(key byte, press bool, autoReleas } } - if l.GetLevel() <= zerolog.DebugLevel { - l = l.With().Uint8("modifier", modifier).Uints8("keys", keys).Logger() - } - - l.Trace().Msg("writing keypress report to hidg0") - err := u.keyboardWriteHidFile(modifier, keys) - if err != nil { - l.Warn().Msg("Could not write keypress report to hidg0") - } + u.UpdateKeysDown(modifier, keys) + return err +} - l.Trace().Msg("keypress report written to hidg0") +func (u *UsbGadget) KeypressReport(key byte, press bool) error { + u.kbdAutoReleaseLock.Lock() + u.kbdAutoReleaseLastKey = key + u.kbdAutoReleaseLock.Unlock() if press { - { - l.Trace().Msg("acquiring kbdAutoReleaseLock to update last key") - u.kbdAutoReleaseLock.Lock() - u.kbdAutoReleaseLastKey = key - unlockWithLog(&u.kbdAutoReleaseLock, u.log, "last key updated") - } - - if autoRelease { - u.scheduleAutoRelease(key) - } + u.scheduleAutoRelease() } else { - if autoRelease { - u.cancelAutoRelease() - } + u.cancelAutoRelease() } - return u.UpdateKeysDown(modifier, keys), err -} - -type keypressReportResult struct { - KeysDownState KeysDownState - Error error -} - -func (u *UsbGadget) keypressReport(key byte, press bool, autoRelease bool) (KeysDownState, error) { - u.keyboardLock.Lock() - defer u.keyboardLock.Unlock() - - r := make(chan keypressReportResult) - go func() { - state, err := u.keypressReportNonThreadSafe(key, press, autoRelease) - r <- keypressReportResult{KeysDownState: state, Error: err} - }() - - select { - case <-time.After(1 * time.Second): - u.log.Warn().Msg("keypressReport timed out, possibly stuck") - return u.keysDownState, fmt.Errorf("keypressReport timed out, possibly stuck") - case ret := <-r: - u.log.Debug().Msg("keypressReport handled") - return ret.KeysDownState, ret.Error - } -} - -func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) { - return u.keypressReport(key, press, true) + err := u.keypressReport(key, press) + return err } diff --git a/jsonrpc.go b/jsonrpc.go index 61f28df5..8136a704 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -1066,9 +1066,7 @@ var rpcHandlers = map[string]RPCHandler{ "getNetworkSettings": {Func: rpcGetNetworkSettings}, "setNetworkSettings": {Func: rpcSetNetworkSettings, Params: []string{"settings"}}, "renewDHCPLease": {Func: rpcRenewDHCPLease}, - "keyboardReport": {Func: rpcKeyboardReport, Params: []string{"modifier", "keys"}}, "getKeyboardLedState": {Func: rpcGetKeyboardLedState}, - "keypressReport": {Func: rpcKeypressReport, Params: []string{"key", "press"}}, "getKeyDownState": {Func: rpcGetKeysDownState}, "absMouseReport": {Func: rpcAbsMouseReport, Params: []string{"x", "y", "buttons"}}, "relMouseReport": {Func: rpcRelMouseReport, Params: []string{"dx", "dy", "buttons"}}, diff --git a/ui/src/components/WebRTCVideo.tsx b/ui/src/components/WebRTCVideo.tsx index bc6897ea..64452bf8 100644 --- a/ui/src/components/WebRTCVideo.tsx +++ b/ui/src/components/WebRTCVideo.tsx @@ -190,7 +190,7 @@ export default function WebRTCVideo() { if (!isFullscreenEnabled || !videoElm.current) return; // per https://wicg.github.io/keyboard-lock/#system-key-press-handler - // If keyboard lock is activated after fullscreen is already in effect, then the user my + // If keyboard lock is activated after fullscreen is already in effect, then the user my // see multiple messages about how to exit fullscreen. For this reason, we recommend that // developers call lock() before they enter fullscreen: await requestKeyboardLock(); @@ -237,6 +237,7 @@ export default function WebRTCVideo() { const keyDownHandler = useCallback( (e: KeyboardEvent) => { e.preventDefault(); + if (e.repeat) return; const code = getAdjustedKeyCode(e); const hidKey = keys[code]; diff --git a/ui/src/components/popovers/PasteModal.tsx b/ui/src/components/popovers/PasteModal.tsx index 077759b7..fc11476f 100644 --- a/ui/src/components/popovers/PasteModal.tsx +++ b/ui/src/components/popovers/PasteModal.tsx @@ -9,20 +9,12 @@ import { TextAreaWithLabel } from "@components/TextArea"; import { SettingsPageHeader } from "@components/SettingsPageheader"; import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; import { useHidStore, useRTCStore, useUiStore, useSettingsStore } from "@/hooks/stores"; -import { keys, modifiers } from "@/keyboardMappings"; -import { KeyStroke } from "@/keyboardLayouts"; +import { keys } from "@/keyboardMappings"; import useKeyboardLayout from "@/hooks/useKeyboardLayout"; -import notifications from "@/notifications"; -const hidKeyboardPayload = (modifier: number, keys: number[]) => { - return { modifier, keys }; -}; +import useKeyboard from "../../hooks/useKeyboard"; + -const modifierCode = (shift?: boolean, altRight?: boolean) => { - return (shift ? modifiers.ShiftLeft : 0) - | (altRight ? modifiers.AltRight : 0) -} -const noModifier = 0 export default function PasteModal() { const TextAreaRef = useRef(null); @@ -34,9 +26,9 @@ export default function PasteModal() { const [invalidChars, setInvalidChars] = useState([]); const close = useClose(); - + const { handleKeyPress } = useKeyboard(); const { setKeyboardLayout } = useSettingsStore(); - const { selectedKeyboard } = useKeyboardLayout(); + const { selectedKeyboard } = useKeyboardLayout(); useEffect(() => { send("getKeyboardLayout", {}, (resp: JsonRpcResponse) => { @@ -58,51 +50,29 @@ export default function PasteModal() { if (rpcDataChannel?.readyState !== "open" || !TextAreaRef.current) return; if (!selectedKeyboard) return; - const text = TextAreaRef.current.value; - try { - for (const char of text) { - const keyprops = selectedKeyboard.chars[char]; - if (!keyprops) continue; - - const { key, shift, altRight, deadKey, accentKey } = keyprops; - if (!key) continue; - - // if this is an accented character, we need to send that accent FIRST - if (accentKey) { - await sendKeystroke({modifier: modifierCode(accentKey.shift, accentKey.altRight), keys: [ keys[accentKey.key] ] }) - } - - // now send the actual key - await sendKeystroke({ modifier: modifierCode(shift, altRight), keys: [ keys[key] ]}); - - // if what was requested was a dead key, we need to send an unmodified space to emit - // just the accent character - if (deadKey) { - await sendKeystroke({ modifier: noModifier, keys: [ keys["Space"] ] }); - } - - // now send a message with no keys down to "release" the keys - await sendKeystroke({ modifier: 0, keys: [] }); + for (let i = 0; i < 5; i++) { + for (let i = 0; i < 26; i++) { + handleKeyPress(keys[`Key${String.fromCharCode(65 + i)}`], true); + await new Promise(resolve => setTimeout(resolve, 50)); + handleKeyPress(keys[`Key${String.fromCharCode(65 + i)}`], false); + await new Promise(resolve => setTimeout(resolve, 50)); } - } catch (error) { - console.error("Failed to paste text:", error); - notifications.error("Failed to paste text"); + await new Promise(resolve => setTimeout(resolve, 50)); + handleKeyPress(keys.Enter, true); + await new Promise(resolve => setTimeout(resolve, 50)); + handleKeyPress(keys.Enter, false); + await new Promise(resolve => setTimeout(resolve, 50)); } - async function sendKeystroke(stroke: KeyStroke) { - await new Promise((resolve, reject) => { - send( - "keyboardReport", - hidKeyboardPayload(stroke.modifier, stroke.keys), - params => { - if ("error" in params) return reject(params.error); - resolve(); - } - ); - }); - } - }, [selectedKeyboard, rpcDataChannel?.readyState, send, setDisableVideoFocusTrap, setPasteModeEnabled]); + + // for (let index = 0; index < 2; index++) { + // handleKeyPress(keys.KeyA, true); + // await new Promise(resolve => setTimeout(resolve, 3000)); + // handleKeyPress(keys.KeyA, false); + // } + + }, [setPasteModeEnabled, setDisableVideoFocusTrap, rpcDataChannel?.readyState, selectedKeyboard, handleKeyPress]); useEffect(() => { if (TextAreaRef.current) { diff --git a/ui/src/hooks/useHidRpc.ts b/ui/src/hooks/useHidRpc.ts index 0670acb8..5480148f 100644 --- a/ui/src/hooks/useHidRpc.ts +++ b/ui/src/hooks/useHidRpc.ts @@ -36,11 +36,16 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { }, [rpcHidChannel, rpcHidProtocolVersion]); const rpcHidUnreliableReady = useMemo(() => { - return rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null; + return ( + rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null + ); }, [rpcHidUnreliableChannel, rpcHidProtocolVersion]); const rpcHidUnreliableNonOrderedReady = useMemo(() => { - return rpcHidUnreliableNonOrderedChannel?.readyState === "open" && rpcHidProtocolVersion !== null; + return ( + rpcHidUnreliableNonOrderedChannel?.readyState === "open" && + rpcHidProtocolVersion !== null + ); }, [rpcHidUnreliableNonOrderedChannel, rpcHidProtocolVersion]); const rpcHidStatus = useMemo(() => { @@ -50,41 +55,52 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { return `ready (v${rpcHidProtocolVersion}${rpcHidUnreliableReady ? "+u" : ""})`; }, [rpcHidChannel, rpcHidUnreliableReady, rpcHidProtocolVersion]); - const sendMessage = useCallback((message: RpcMessage, { ignoreHandshakeState, useUnreliableChannel, requireOrdered = true }: sendMessageParams = {}) => { - if (rpcHidChannel?.readyState !== "open") return; - if (!rpcHidReady && !ignoreHandshakeState) return; + const sendMessage = useCallback( + ( + message: RpcMessage, + { + ignoreHandshakeState, + useUnreliableChannel, + requireOrdered = true, + }: sendMessageParams = {}, + ) => { + if (rpcHidChannel?.readyState !== "open") return; + if (!rpcHidReady && !ignoreHandshakeState) return; - let data: Uint8Array | undefined; - try { - data = message.marshal(); - } catch (e) { - console.error("Failed to send HID RPC message", e); - } - if (!data) return; - - if (useUnreliableChannel) { - if (requireOrdered && rpcHidUnreliableReady) { - rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer); - } else if (!requireOrdered && rpcHidUnreliableNonOrderedReady) { - rpcHidUnreliableNonOrderedChannel?.send(data as unknown as ArrayBuffer); + let data: Uint8Array | undefined; + try { + data = message.marshal(); + } catch (e) { + console.error("Failed to send HID RPC message", e); } - return; - } + if (!data) return; - rpcHidChannel?.send(data as unknown as ArrayBuffer); - }, [ - rpcHidChannel, - rpcHidUnreliableChannel, - rpcHidUnreliableNonOrderedChannel, - rpcHidReady, - rpcHidUnreliableReady, - rpcHidUnreliableNonOrderedReady, - ]); + if (useUnreliableChannel) { + if (requireOrdered && rpcHidUnreliableReady) { + rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer); + } else if (!requireOrdered && rpcHidUnreliableNonOrderedReady) { + rpcHidUnreliableNonOrderedChannel?.send(data as unknown as ArrayBuffer); + } + return; + } + + rpcHidChannel?.send(data as unknown as ArrayBuffer); + }, + [ + rpcHidChannel, + rpcHidUnreliableChannel, + rpcHidUnreliableNonOrderedChannel, + rpcHidReady, + rpcHidUnreliableReady, + rpcHidUnreliableNonOrderedReady, + ], + ); const reportKeyboardEvent = useCallback( (keys: number[], modifier: number) => { sendMessage(new KeyboardReportMessage(keys, modifier)); - }, [sendMessage], + }, + [sendMessage], ); const reportKeypressEvent = useCallback( @@ -96,7 +112,9 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { const reportAbsMouseEvent = useCallback( (x: number, y: number, buttons: number) => { - sendMessage(new PointerReportMessage(x, y, buttons), { useUnreliableChannel: true }); + sendMessage(new PointerReportMessage(x, y, buttons), { + useUnreliableChannel: true, + }); }, [sendMessage], ); @@ -109,7 +127,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { ); const reportKeypressKeepAlive = useCallback(() => { - sendMessage(KEEPALIVE_MESSAGE, { useUnreliableChannel: true, requireOrdered: false }); + sendMessage(KEEPALIVE_MESSAGE); }, [sendMessage]); const sendHandshake = useCallback(() => { @@ -119,22 +137,25 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { sendMessage(new HandshakeMessage(HID_RPC_VERSION), { ignoreHandshakeState: true }); }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage]); - const handleHandshake = useCallback((message: HandshakeMessage) => { - if (!message.version) { - console.error("Received handshake message without version", message); - return; - } + const handleHandshake = useCallback( + (message: HandshakeMessage) => { + if (!message.version) { + console.error("Received handshake message without version", message); + return; + } - if (message.version > HID_RPC_VERSION) { - // we assume that the UI is always using the latest version of the HID RPC protocol - // so we can't support this - // TODO: use capabilities to determine rather than version number - console.error("Server is using a newer HID RPC version than the client", message); - return; - } + if (message.version > HID_RPC_VERSION) { + // we assume that the UI is always using the latest version of the HID RPC protocol + // so we can't support this + // TODO: use capabilities to determine rather than version number + console.error("Server is using a newer HID RPC version than the client", message); + return; + } - setRpcHidProtocolVersion(message.version); - }, [setRpcHidProtocolVersion]); + setRpcHidProtocolVersion(message.version); + }, + [setRpcHidProtocolVersion], + ); useEffect(() => { if (!rpcHidChannel) return; @@ -186,15 +207,13 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { rpcHidChannel.removeEventListener("close", closeHandler); rpcHidChannel.removeEventListener("open", openHandler); }; - }, - [ - rpcHidChannel, - onHidRpcMessage, - setRpcHidProtocolVersion, - sendHandshake, - handleHandshake, - ], - ); + }, [ + rpcHidChannel, + onHidRpcMessage, + setRpcHidProtocolVersion, + sendHandshake, + handleHandshake, + ]); return { reportKeyboardEvent, diff --git a/usb.go b/usb.go index 131cd517..f63a6aab 100644 --- a/usb.go +++ b/usb.go @@ -33,7 +33,7 @@ func initUsbGadget() { gadget.SetOnKeysDownChange(func(state usbgadget.KeysDownState) { if currentSession != nil { - currentSession.reportHidRPCKeysDownState(state) + currentSession.enqueueKeysDownState(state) } }) @@ -43,11 +43,11 @@ func initUsbGadget() { } } -func rpcKeyboardReport(modifier byte, keys []byte) (usbgadget.KeysDownState, error) { +func rpcKeyboardReport(modifier byte, keys []byte) error { return gadget.KeyboardReport(modifier, keys) } -func rpcKeypressReport(key byte, press bool) (usbgadget.KeysDownState, error) { +func rpcKeypressReport(key byte, press bool) error { return gadget.KeypressReport(key, press) } diff --git a/webrtc.go b/webrtc.go index 333a58b8..9593e627 100644 --- a/webrtc.go +++ b/webrtc.go @@ -13,6 +13,7 @@ import ( "github.com/gin-gonic/gin" "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/usbgadget" "github.com/pion/webrtc/v4" "github.com/rs/zerolog" ) @@ -30,6 +31,8 @@ type Session struct { hidRPCAvailable bool hidQueueLock sync.Mutex hidQueue []chan hidQueueMessage + + keysDownStateQueue chan usbgadget.KeysDownState } type hidQueueMessage struct { @@ -96,6 +99,32 @@ func (s *Session) handleQueues(index int) { } } +const keysDownStateQueueSize = 256 + +func (s *Session) initKeysDownStateQueue() { + // serialise outbound key state reports so unreliable links can't stall input handling + s.keysDownStateQueue = make(chan usbgadget.KeysDownState, keysDownStateQueueSize) + go s.handleKeysDownStateQueue() +} + +func (s *Session) handleKeysDownStateQueue() { + for state := range s.keysDownStateQueue { + s.reportHidRPCKeysDownState(state) + } +} + +func (s *Session) enqueueKeysDownState(state usbgadget.KeysDownState) { + if s == nil || s.keysDownStateQueue == nil { + return + } + + select { + case s.keysDownStateQueue <- state: + default: + hidRPCLogger.Warn().Msg("dropping keys down state update; queue full") + } +} + func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, channel string) func(msg webrtc.DataChannelMessage) { return func(msg webrtc.DataChannelMessage) { l := scopedLogger.With(). @@ -181,6 +210,7 @@ func newSession(config SessionConfig) (*Session, error) { session := &Session{peerConnection: peerConnection} session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) session.initQueues() + session.initKeysDownStateQueue() go func() { for msg := range session.rpcQueue {