refactor: simplify HID RPC keyboard input handling and improve key state management

- Updated `handleHidRPCKeyboardInput` to return errors directly instead of keys down state.
- Refactored `rpcKeyboardReport` and `rpcKeypressReport` to return errors instead of states.
- Introduced a queue for managing key down state updates in the `Session` struct to prevent input handling stalls.
- Adjusted the `UpdateKeysDown` method to handle state changes more efficiently.
- Removed unnecessary logging and commented-out code for clarity.
This commit is contained in:
Adam Shiervani 2025-09-16 01:48:13 +02:00 committed by Siyuan Miao
parent e9b252430f
commit e3eb8330fe
7 changed files with 174 additions and 179 deletions

View File

@ -27,11 +27,7 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) {
} }
session.hidRPCAvailable = true session.hidRPCAvailable = true
case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport:
keysDownState, err := handleHidRPCKeyboardInput(message) rpcErr = handleHidRPCKeyboardInput(message)
if keysDownState != nil {
session.reportHidRPCKeysDownState(*keysDownState)
}
rpcErr = err
case hidrpc.TypeKeyboardMacroReport: case hidrpc.TypeKeyboardMacroReport:
keyboardMacroReport, err := message.KeyboardMacroReport() keyboardMacroReport, err := message.KeyboardMacroReport()
if err != nil { if err != nil {
@ -107,27 +103,25 @@ func onHidMessage(msg hidQueueMessage, session *Session) {
} }
} }
func handleHidRPCKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState, error) { func handleHidRPCKeyboardInput(message hidrpc.Message) error {
switch message.Type() { switch message.Type() {
case hidrpc.TypeKeypressReport: case hidrpc.TypeKeypressReport:
keypressReport, err := message.KeypressReport() keypressReport, err := message.KeypressReport()
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("failed to get keypress report") logger.Warn().Err(err).Msg("failed to get keypress report")
return nil, err return err
} }
keysDownState, rpcError := rpcKeypressReport(keypressReport.Key, keypressReport.Press) return rpcKeypressReport(keypressReport.Key, keypressReport.Press)
return &keysDownState, rpcError
case hidrpc.TypeKeyboardReport: case hidrpc.TypeKeyboardReport:
keyboardReport, err := message.KeyboardReport() keyboardReport, err := message.KeyboardReport()
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("failed to get keyboard report") logger.Warn().Err(err).Msg("failed to get keyboard report")
return nil, err return err
} }
keysDownState, rpcError := rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) return rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys)
return &keysDownState, rpcError
} }
return nil, fmt.Errorf("unknown HID RPC message type: %d", message.Type()) return fmt.Errorf("unknown HID RPC message type: %d", message.Type())
} }
func reportHidRPC(params any, session *Session) { func reportHidRPC(params any, session *Session) {

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"sync"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
@ -153,36 +154,11 @@ func (u *UsbGadget) GetKeysDownState() KeysDownState {
return u.keysDownState return u.keysDownState
} }
func (u *UsbGadget) updateKeyDownState(state KeysDownState) {
u.log.Trace().Interface("old", u.keysDownState).Interface("new", state).Msg("acquiring keyboardStateLock for updateKeyDownState")
// this is intentional to unlock keyboard state lock before onKeysDownChange callback
{
u.keyboardStateLock.Lock()
defer u.keyboardStateLock.Unlock()
if u.keysDownState.Modifier == state.Modifier &&
bytes.Equal(u.keysDownState.Keys, state.Keys) {
return // No change in key down state
}
u.log.Trace().Interface("old", u.keysDownState).Interface("new", state).Msg("keysDownState updated")
u.keysDownState = state
}
if u.onKeysDownChange != nil {
u.log.Trace().Interface("state", state).Msg("calling onKeysDownChange")
(*u.onKeysDownChange)(state)
u.log.Trace().Interface("state", state).Msg("onKeysDownChange called")
}
}
func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) { func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) {
u.onKeysDownChange = &f u.onKeysDownChange = &f
} }
func (u *UsbGadget) scheduleAutoRelease(key byte) { func (u *UsbGadget) scheduleAutoRelease() {
u.log.Trace().Msg("scheduling autoRelease")
u.kbdAutoReleaseLock.Lock() u.kbdAutoReleaseLock.Lock()
defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease scheduled") defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease scheduled")
@ -196,7 +172,6 @@ func (u *UsbGadget) scheduleAutoRelease(key byte) {
} }
func (u *UsbGadget) cancelAutoRelease() { func (u *UsbGadget) cancelAutoRelease() {
u.log.Trace().Msg("cancelling autoRelease")
u.kbdAutoReleaseLock.Lock() u.kbdAutoReleaseLock.Lock()
defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease cancelled") defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease cancelled")
@ -206,7 +181,6 @@ func (u *UsbGadget) cancelAutoRelease() {
} }
func (u *UsbGadget) DelayAutoRelease() { func (u *UsbGadget) DelayAutoRelease() {
u.log.Trace().Msg("delaying autoRelease")
u.kbdAutoReleaseLock.Lock() u.kbdAutoReleaseLock.Lock()
defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease delayed") defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease delayed")
@ -215,33 +189,37 @@ func (u *UsbGadget) DelayAutoRelease() {
} }
u.kbdAutoReleaseTimer.Reset(autoReleaseKeyboardInterval) u.kbdAutoReleaseTimer.Reset(autoReleaseKeyboardInterval)
u.log.Trace().Msg("auto-release timer reset")
} }
func (u *UsbGadget) performAutoRelease() { func (u *UsbGadget) performAutoRelease() {
u.log.Trace().Msg("performing autoRelease")
u.kbdAutoReleaseLock.Lock()
defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease unlocked")
key := u.kbdAutoReleaseLastKey
select { select {
case <-u.keyboardStateCtx.Done(): case <-u.keyboardStateCtx.Done():
return return
default: default:
} }
// we just reset the keyboard state to 0 no matter what u.kbdAutoReleaseLock.Lock()
u.log.Trace().Uint8("key", key).Msg("auto-releasing keyboard key")
_, err := u.keypressReport(key, false, false) key := u.kbdAutoReleaseLastKey
if err != nil {
u.log.Warn().Uint8("key", key).Msg("failed to auto-release keyboard key") // Skip if already released
state := u.GetKeysDownState()
alreadyReleased := true
for i := range state.Keys {
if state.Keys[i] == key {
alreadyReleased = false
break
}
}
if alreadyReleased {
return
} }
u.kbdAutoReleaseTimer = nil u.kbdAutoReleaseTimer = nil
u.kbdAutoReleaseLock.Unlock()
u.log.Trace().Uint8("key", key).Msg("auto release performed") u.keypressReport(key, false) // autoRelease the ket
} }
func (u *UsbGadget) listenKeyboardEvents() { func (u *UsbGadget) listenKeyboardEvents() {
@ -313,7 +291,11 @@ func (u *UsbGadget) OpenKeyboardHidFile() error {
return u.openKeyboardHidFile() return u.openKeyboardHidFile()
} }
var keyboardWriteHidFileLock sync.Mutex
func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error {
keyboardWriteHidFileLock.Lock()
defer keyboardWriteHidFileLock.Unlock()
if err := u.openKeyboardHidFile(); err != nil { if err := u.openKeyboardHidFile(); err != nil {
return err return err
} }
@ -329,7 +311,7 @@ func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error {
return nil return nil
} }
func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) {
// if we just reported an error roll over, we should clear the keys // if we just reported an error roll over, we should clear the keys
if keys[0] == hidErrorRollOver { if keys[0] == hidErrorRollOver {
for i := range keys { for i := range keys {
@ -337,17 +319,29 @@ func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState {
} }
} }
downState := KeysDownState{ state := KeysDownState{
Modifier: modifier, Modifier: modifier,
Keys: []byte(keys[:]), Keys: []byte(keys[:]),
} }
u.updateKeyDownState(downState)
return downState u.keyboardStateLock.Lock()
if u.keysDownState.Modifier == state.Modifier &&
bytes.Equal(u.keysDownState.Keys, state.Keys) {
u.keyboardStateLock.Unlock()
return // No change in key down state
}
u.keysDownState = state
u.keyboardStateLock.Unlock()
if u.onKeysDownChange != nil {
(*u.onKeysDownChange)(state) // this enques to the outgoing hidrpc queue via usb.go → currentSession.enqueueKeysDownState(...)
}
return
} }
func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) (KeysDownState, error) { func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) error {
u.keyboardLock.Lock()
defer u.keyboardLock.Unlock()
defer u.resetUserInputTime() defer u.resetUserInputTime()
if len(keys) > hidKeyBufferSize { if len(keys) > hidKeyBufferSize {
@ -362,7 +356,8 @@ func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) (KeysDownState, e
u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keyboard report to hidg0") u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keyboard report to hidg0")
} }
return u.UpdateKeysDown(modifier, keys), err u.UpdateKeysDown(modifier, keys)
return err
} }
const ( const (
@ -402,10 +397,10 @@ var KeyCodeToMaskMap = map[byte]byte{
RightSuper: ModifierMaskRightSuper, RightSuper: ModifierMaskRightSuper,
} }
func (u *UsbGadget) keypressReportNonThreadSafe(key byte, press bool, autoRelease bool) (KeysDownState, error) { func (u *UsbGadget) keypressReport(key byte, press bool) error {
defer u.resetUserInputTime() defer u.resetUserInputTime()
l := u.log.With().Uint8("key", key).Bool("press", press).Bool("autoRelease", autoRelease).Logger() l := u.log.With().Uint8("key", key).Bool("press", press).Logger()
if l.GetLevel() <= zerolog.DebugLevel { if l.GetLevel() <= zerolog.DebugLevel {
requestID := xid.New() requestID := xid.New()
l = l.With().Str("requestID", requestID.String()).Logger() l = l.With().Str("requestID", requestID.String()).Logger()
@ -470,64 +465,22 @@ func (u *UsbGadget) keypressReportNonThreadSafe(key byte, press bool, autoReleas
} }
} }
if l.GetLevel() <= zerolog.DebugLevel {
l = l.With().Uint8("modifier", modifier).Uints8("keys", keys).Logger()
}
l.Trace().Msg("writing keypress report to hidg0")
err := u.keyboardWriteHidFile(modifier, keys) err := u.keyboardWriteHidFile(modifier, keys)
if err != nil { u.UpdateKeysDown(modifier, keys)
l.Warn().Msg("Could not write keypress report to hidg0") return err
} }
l.Trace().Msg("keypress report written to hidg0") func (u *UsbGadget) KeypressReport(key byte, press bool) error {
u.kbdAutoReleaseLock.Lock()
u.kbdAutoReleaseLastKey = key
u.kbdAutoReleaseLock.Unlock()
if press { if press {
{ u.scheduleAutoRelease()
l.Trace().Msg("acquiring kbdAutoReleaseLock to update last key")
u.kbdAutoReleaseLock.Lock()
u.kbdAutoReleaseLastKey = key
unlockWithLog(&u.kbdAutoReleaseLock, u.log, "last key updated")
}
if autoRelease {
u.scheduleAutoRelease(key)
}
} else { } else {
if autoRelease { u.cancelAutoRelease()
u.cancelAutoRelease()
}
} }
return u.UpdateKeysDown(modifier, keys), err err := u.keypressReport(key, press)
} return err
type keypressReportResult struct {
KeysDownState KeysDownState
Error error
}
func (u *UsbGadget) keypressReport(key byte, press bool, autoRelease bool) (KeysDownState, error) {
u.keyboardLock.Lock()
defer u.keyboardLock.Unlock()
r := make(chan keypressReportResult)
go func() {
state, err := u.keypressReportNonThreadSafe(key, press, autoRelease)
r <- keypressReportResult{KeysDownState: state, Error: err}
}()
select {
case <-time.After(1 * time.Second):
u.log.Warn().Msg("keypressReport timed out, possibly stuck")
return u.keysDownState, fmt.Errorf("keypressReport timed out, possibly stuck")
case ret := <-r:
u.log.Debug().Msg("keypressReport handled")
return ret.KeysDownState, ret.Error
}
}
func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) {
return u.keypressReport(key, press, true)
} }

View File

@ -1169,9 +1169,7 @@ var rpcHandlers = map[string]RPCHandler{
"getNetworkSettings": {Func: rpcGetNetworkSettings}, "getNetworkSettings": {Func: rpcGetNetworkSettings},
"setNetworkSettings": {Func: rpcSetNetworkSettings, Params: []string{"settings"}}, "setNetworkSettings": {Func: rpcSetNetworkSettings, Params: []string{"settings"}},
"renewDHCPLease": {Func: rpcRenewDHCPLease}, "renewDHCPLease": {Func: rpcRenewDHCPLease},
"keyboardReport": {Func: rpcKeyboardReport, Params: []string{"modifier", "keys"}},
"getKeyboardLedState": {Func: rpcGetKeyboardLedState}, "getKeyboardLedState": {Func: rpcGetKeyboardLedState},
"keypressReport": {Func: rpcKeypressReport, Params: []string{"key", "press"}},
"getKeyDownState": {Func: rpcGetKeysDownState}, "getKeyDownState": {Func: rpcGetKeysDownState},
"absMouseReport": {Func: rpcAbsMouseReport, Params: []string{"x", "y", "buttons"}}, "absMouseReport": {Func: rpcAbsMouseReport, Params: []string{"x", "y", "buttons"}},
"relMouseReport": {Func: rpcRelMouseReport, Params: []string{"dx", "dy", "buttons"}}, "relMouseReport": {Func: rpcRelMouseReport, Params: []string{"dx", "dy", "buttons"}},

View File

@ -190,7 +190,7 @@ export default function WebRTCVideo() {
if (!isFullscreenEnabled || !videoElm.current) return; if (!isFullscreenEnabled || !videoElm.current) return;
// per https://wicg.github.io/keyboard-lock/#system-key-press-handler // per https://wicg.github.io/keyboard-lock/#system-key-press-handler
// If keyboard lock is activated after fullscreen is already in effect, then the user my // If keyboard lock is activated after fullscreen is already in effect, then the user my
// see multiple messages about how to exit fullscreen. For this reason, we recommend that // see multiple messages about how to exit fullscreen. For this reason, we recommend that
// developers call lock() before they enter fullscreen: // developers call lock() before they enter fullscreen:
await requestKeyboardLock(); await requestKeyboardLock();
@ -237,6 +237,7 @@ export default function WebRTCVideo() {
const keyDownHandler = useCallback( const keyDownHandler = useCallback(
(e: KeyboardEvent) => { (e: KeyboardEvent) => {
e.preventDefault(); e.preventDefault();
if (e.repeat) return;
const code = getAdjustedKeyCode(e); const code = getAdjustedKeyCode(e);
const hidKey = keys[code]; const hidKey = keys[code];

View File

@ -40,11 +40,16 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
}, [rpcHidChannel, rpcHidProtocolVersion, hidRpcDisabled]); }, [rpcHidChannel, rpcHidProtocolVersion, hidRpcDisabled]);
const rpcHidUnreliableReady = useMemo(() => { const rpcHidUnreliableReady = useMemo(() => {
return rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null; return (
rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null
);
}, [rpcHidUnreliableChannel, rpcHidProtocolVersion]); }, [rpcHidUnreliableChannel, rpcHidProtocolVersion]);
const rpcHidUnreliableNonOrderedReady = useMemo(() => { const rpcHidUnreliableNonOrderedReady = useMemo(() => {
return rpcHidUnreliableNonOrderedChannel?.readyState === "open" && rpcHidProtocolVersion !== null; return (
rpcHidUnreliableNonOrderedChannel?.readyState === "open" &&
rpcHidProtocolVersion !== null
);
}, [rpcHidUnreliableNonOrderedChannel, rpcHidProtocolVersion]); }, [rpcHidUnreliableNonOrderedChannel, rpcHidProtocolVersion]);
const rpcHidStatus = useMemo(() => { const rpcHidStatus = useMemo(() => {
@ -56,42 +61,53 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
return `ready (v${rpcHidProtocolVersion}${rpcHidUnreliableReady ? "+u" : ""})`; return `ready (v${rpcHidProtocolVersion}${rpcHidUnreliableReady ? "+u" : ""})`;
}, [rpcHidChannel, rpcHidUnreliableReady, rpcHidProtocolVersion, hidRpcDisabled]); }, [rpcHidChannel, rpcHidUnreliableReady, rpcHidProtocolVersion, hidRpcDisabled]);
const sendMessage = useCallback((message: RpcMessage, { ignoreHandshakeState, useUnreliableChannel, requireOrdered = true }: sendMessageParams = {}) => { const sendMessage = useCallback(
if (hidRpcDisabled) return; (
message: RpcMessage,
{
ignoreHandshakeState,
useUnreliableChannel,
requireOrdered = true,
}: sendMessageParams = {},
) => {
if (hidRpcDisabled) return;
if (rpcHidChannel?.readyState !== "open") return; if (rpcHidChannel?.readyState !== "open") return;
if (!rpcHidReady && !ignoreHandshakeState) return; if (!rpcHidReady && !ignoreHandshakeState) return;
let data: Uint8Array | undefined; let data: Uint8Array | undefined;
try { try {
data = message.marshal(); data = message.marshal();
} catch (e) { } catch (e) {
console.error("Failed to send HID RPC message", e); console.error("Failed to send HID RPC message", e);
}
if (!data) return;
if (useUnreliableChannel) {
if (requireOrdered && rpcHidUnreliableReady) {
rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer);
} else if (!requireOrdered && rpcHidUnreliableNonOrderedReady) {
rpcHidUnreliableNonOrderedChannel?.send(data as unknown as ArrayBuffer);
} }
return; if (!data) return;
}
rpcHidChannel?.send(data as unknown as ArrayBuffer); if (useUnreliableChannel) {
}, [ if (requireOrdered && rpcHidUnreliableReady) {
rpcHidChannel, rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer);
rpcHidUnreliableChannel, } else if (!requireOrdered && rpcHidUnreliableNonOrderedReady) {
hidRpcDisabled, rpcHidUnreliableNonOrderedChannel, rpcHidUnreliableNonOrderedChannel?.send(data as unknown as ArrayBuffer);
rpcHidReady, }
rpcHidUnreliableReady, return;
rpcHidUnreliableNonOrderedReady, }
]);
rpcHidChannel?.send(data as unknown as ArrayBuffer);
},
[
rpcHidChannel,
rpcHidUnreliableChannel,
hidRpcDisabled, rpcHidUnreliableNonOrderedChannel,
rpcHidReady,
rpcHidUnreliableReady,
rpcHidUnreliableNonOrderedReady,
],
);
const reportKeyboardEvent = useCallback( const reportKeyboardEvent = useCallback(
(keys: number[], modifier: number) => { (keys: number[], modifier: number) => {
sendMessage(new KeyboardReportMessage(keys, modifier)); sendMessage(new KeyboardReportMessage(keys, modifier));
}, [sendMessage], },
[sendMessage],
); );
const reportKeypressEvent = useCallback( const reportKeypressEvent = useCallback(
@ -103,7 +119,9 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
const reportAbsMouseEvent = useCallback( const reportAbsMouseEvent = useCallback(
(x: number, y: number, buttons: number) => { (x: number, y: number, buttons: number) => {
sendMessage(new PointerReportMessage(x, y, buttons), { useUnreliableChannel: true }); sendMessage(new PointerReportMessage(x, y, buttons), {
useUnreliableChannel: true,
});
}, },
[sendMessage], [sendMessage],
); );
@ -130,7 +148,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
); );
const reportKeypressKeepAlive = useCallback(() => { const reportKeypressKeepAlive = useCallback(() => {
sendMessage(KEEPALIVE_MESSAGE, { useUnreliableChannel: true, requireOrdered: false }); sendMessage(KEEPALIVE_MESSAGE);
}, [sendMessage]); }, [sendMessage]);
const sendHandshake = useCallback(() => { const sendHandshake = useCallback(() => {
@ -141,24 +159,27 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
sendMessage(new HandshakeMessage(HID_RPC_VERSION), { ignoreHandshakeState: true }); sendMessage(new HandshakeMessage(HID_RPC_VERSION), { ignoreHandshakeState: true });
}, [rpcHidChannel, rpcHidProtocolVersion, sendMessage, hidRpcDisabled]); }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage, hidRpcDisabled]);
const handleHandshake = useCallback((message: HandshakeMessage) => { const handleHandshake = useCallback(
if (hidRpcDisabled) return; (message: HandshakeMessage) => {
if (hidRpcDisabled) return;
if (!message.version) { if (!message.version) {
console.error("Received handshake message without version", message); console.error("Received handshake message without version", message);
return; return;
} }
if (message.version > HID_RPC_VERSION) { if (message.version > HID_RPC_VERSION) {
// we assume that the UI is always using the latest version of the HID RPC protocol // we assume that the UI is always using the latest version of the HID RPC protocol
// so we can't support this // so we can't support this
// TODO: use capabilities to determine rather than version number // TODO: use capabilities to determine rather than version number
console.error("Server is using a newer HID RPC version than the client", message); console.error("Server is using a newer HID RPC version than the client", message);
return; return;
} }
setRpcHidProtocolVersion(message.version); setRpcHidProtocolVersion(message.version);
}, [setRpcHidProtocolVersion, hidRpcDisabled]); },
[setRpcHidProtocolVersion, hidRpcDisabled],
);
useEffect(() => { useEffect(() => {
if (!rpcHidChannel) return; if (!rpcHidChannel) return;
@ -211,16 +232,14 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
rpcHidChannel.removeEventListener("close", closeHandler); rpcHidChannel.removeEventListener("close", closeHandler);
rpcHidChannel.removeEventListener("open", openHandler); rpcHidChannel.removeEventListener("open", openHandler);
}; };
}, }, [
[ rpcHidChannel,
rpcHidChannel, onHidRpcMessage,
onHidRpcMessage, setRpcHidProtocolVersion,
setRpcHidProtocolVersion, sendHandshake,
sendHandshake, handleHandshake,
handleHandshake,
hidRpcDisabled, hidRpcDisabled,
], ]);
);
return { return {
reportKeyboardEvent, reportKeyboardEvent,

6
usb.go
View File

@ -33,7 +33,7 @@ func initUsbGadget() {
gadget.SetOnKeysDownChange(func(state usbgadget.KeysDownState) { gadget.SetOnKeysDownChange(func(state usbgadget.KeysDownState) {
if currentSession != nil { if currentSession != nil {
currentSession.reportHidRPCKeysDownState(state) currentSession.enqueueKeysDownState(state)
} }
}) })
@ -43,11 +43,11 @@ func initUsbGadget() {
} }
} }
func rpcKeyboardReport(modifier byte, keys []byte) (usbgadget.KeysDownState, error) { func rpcKeyboardReport(modifier byte, keys []byte) error {
return gadget.KeyboardReport(modifier, keys) return gadget.KeyboardReport(modifier, keys)
} }
func rpcKeypressReport(key byte, press bool) (usbgadget.KeysDownState, error) { func rpcKeypressReport(key byte, press bool) error {
return gadget.KeypressReport(key, press) return gadget.KeypressReport(key, press)
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/hidrpc"
"github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/logging"
"github.com/jetkvm/kvm/internal/usbgadget"
"github.com/pion/webrtc/v4" "github.com/pion/webrtc/v4"
"github.com/rs/zerolog" "github.com/rs/zerolog"
) )
@ -30,6 +31,8 @@ type Session struct {
hidRPCAvailable bool hidRPCAvailable bool
hidQueueLock sync.Mutex hidQueueLock sync.Mutex
hidQueue []chan hidQueueMessage hidQueue []chan hidQueueMessage
keysDownStateQueue chan usbgadget.KeysDownState
} }
type hidQueueMessage struct { type hidQueueMessage struct {
@ -96,6 +99,32 @@ func (s *Session) handleQueues(index int) {
} }
} }
const keysDownStateQueueSize = 256
func (s *Session) initKeysDownStateQueue() {
// serialise outbound key state reports so unreliable links can't stall input handling
s.keysDownStateQueue = make(chan usbgadget.KeysDownState, keysDownStateQueueSize)
go s.handleKeysDownStateQueue()
}
func (s *Session) handleKeysDownStateQueue() {
for state := range s.keysDownStateQueue {
s.reportHidRPCKeysDownState(state)
}
}
func (s *Session) enqueueKeysDownState(state usbgadget.KeysDownState) {
if s == nil || s.keysDownStateQueue == nil {
return
}
select {
case s.keysDownStateQueue <- state:
default:
hidRPCLogger.Warn().Msg("dropping keys down state update; queue full")
}
}
func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, channel string) func(msg webrtc.DataChannelMessage) { func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, channel string) func(msg webrtc.DataChannelMessage) {
return func(msg webrtc.DataChannelMessage) { return func(msg webrtc.DataChannelMessage) {
l := scopedLogger.With(). l := scopedLogger.With().
@ -181,6 +210,7 @@ func newSession(config SessionConfig) (*Session, error) {
session := &Session{peerConnection: peerConnection} session := &Session{peerConnection: peerConnection}
session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) session.rpcQueue = make(chan webrtc.DataChannelMessage, 256)
session.initQueues() session.initQueues()
session.initKeysDownStateQueue()
go func() { go func() {
for msg := range session.rpcQueue { for msg := range session.rpcQueue {