diff --git a/go.mod b/go.mod index 962c3a1b..d07ba239 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/gin-contrib/logger v1.2.6 github.com/gin-gonic/gin v1.10.1 github.com/go-co-op/gocron/v2 v2.16.5 + github.com/google/flatbuffers v25.2.10+incompatible github.com/google/uuid v1.6.0 github.com/guregu/null/v6 v6.0.0 github.com/gwatts/rootcerts v0.0.0-20250901182336-dc5ae18bd79f @@ -23,6 +24,7 @@ require ( github.com/prometheus/common v0.66.0 github.com/prometheus/procfs v0.17.0 github.com/psanford/httpreadat v0.1.0 + github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/sourcegraph/tf-dag v0.2.2-0.20250131204052-3e8ff1477b4f github.com/stretchr/testify v1.11.1 diff --git a/go.sum b/go.sum index e19fa9e6..57576a3a 100644 --- a/go.sum +++ b/go.sum @@ -53,6 +53,8 @@ github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAu github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= +github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -152,6 +154,7 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= diff --git a/hidrpc.go b/hidrpc.go index 74fe687f..12824ac6 100644 --- a/hidrpc.go +++ b/hidrpc.go @@ -6,6 +6,7 @@ import ( "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/usbgadget" + "github.com/rs/zerolog" ) func handleHidRPCMessage(message hidrpc.Message, session *Session) { @@ -29,6 +30,8 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { session.reportHidRPCKeysDownState(*keysDownState) } rpcErr = err + case hidrpc.TypeKeypressKeepAliveReport: + gadget.DelayAutoRelease() case hidrpc.TypePointerReport: pointerReport, err := message.PointerReport() if err != nil { @@ -52,8 +55,13 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { } } -func onHidMessage(data []byte, session *Session) { - scopedLogger := hidRPCLogger.With().Bytes("data", data).Logger() +func onHidMessage(msg hidQueueMessage, session *Session) { + data := msg.Data + + scopedLogger := hidRPCLogger.With(). + Str("channel", msg.channel). + Bytes("data", data). + Logger() scopedLogger.Debug().Msg("HID RPC message received") if len(data) < 1 { @@ -68,7 +76,9 @@ func onHidMessage(data []byte, session *Session) { return } - scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger() + if scopedLogger.GetLevel() <= zerolog.DebugLevel { + scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger() + } t := time.Now() @@ -115,7 +125,10 @@ func reportHidRPC(params any, session *Session) { } if !session.hidRPCAvailable || session.HidChannel == nil { - logger.Warn().Msg("HID RPC is not available, skipping reportHidRPC") + logger.Warn(). + Bool("hidRPCAvailable", session.hidRPCAvailable). + Bool("HidChannel", session.HidChannel != nil). + Msg("HID RPC is not available, skipping reportHidRPC") return } @@ -156,7 +169,9 @@ func (s *Session) reportHidRPCKeyboardLedState(state usbgadget.KeyboardState) { func (s *Session) reportHidRPCKeysDownState(state usbgadget.KeysDownState) { if !s.hidRPCAvailable { + usbLogger.Debug().Interface("state", state).Msg("reporting keys down state") writeJSONRPCEvent("keysDownState", state, s) } + usbLogger.Debug().Interface("state", state).Msg("reporting keys down state, calling reportHidRPC") reportHidRPC(state, s) } diff --git a/internal/hidrpc/hidrpc.go b/internal/hidrpc/hidrpc.go index e9c8c24d..7161dc8f 100644 --- a/internal/hidrpc/hidrpc.go +++ b/internal/hidrpc/hidrpc.go @@ -10,14 +10,15 @@ import ( type MessageType byte const ( - TypeHandshake MessageType = 0x01 - TypeKeyboardReport MessageType = 0x02 - TypePointerReport MessageType = 0x03 - TypeWheelReport MessageType = 0x04 - TypeKeypressReport MessageType = 0x05 - TypeMouseReport MessageType = 0x06 - TypeKeyboardLedState MessageType = 0x32 - TypeKeydownState MessageType = 0x33 + TypeHandshake MessageType = 0x01 + TypeKeyboardReport MessageType = 0x02 + TypePointerReport MessageType = 0x03 + TypeWheelReport MessageType = 0x04 + TypeKeypressReport MessageType = 0x05 + TypeKeypressKeepAliveReport MessageType = 0x09 + TypeMouseReport MessageType = 0x06 + TypeKeyboardLedState MessageType = 0x32 + TypeKeydownState MessageType = 0x33 ) const ( @@ -98,3 +99,11 @@ func NewKeydownStateMessage(state usbgadget.KeysDownState) *Message { d: data, } } + +// NewKeypressKeepAliveMessage creates a new keypress keep alive message. +func NewKeypressKeepAliveMessage() *Message { + return &Message{ + t: TypeKeypressKeepAliveReport, + d: []byte{}, + } +} diff --git a/internal/hidrpc/message.go b/internal/hidrpc/message.go index 84bbda7c..e0f4493e 100644 --- a/internal/hidrpc/message.go +++ b/internal/hidrpc/message.go @@ -43,6 +43,8 @@ func (m *Message) String() string { return fmt.Sprintf("MouseReport{Malformed: %v}", m.d) } return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0], m.d[1], m.d[2]) + case TypeKeypressKeepAliveReport: + return "KeypressKeepAliveReport" default: return fmt.Sprintf("Unknown{Type: %d, Data: %v}", m.t, m.d) } diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index fb710c20..53e232c0 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -6,6 +6,9 @@ import ( "fmt" "os" "time" + + "github.com/rs/xid" + "github.com/rs/zerolog" ) var keyboardConfig = gadgetConfigItem{ @@ -22,6 +25,11 @@ var keyboardConfig = gadgetConfigItem{ reportDesc: keyboardReportDesc, } +// macOS default: 15 * 15 = 225ms https://discussions.apple.com/thread/1316947?sortBy=rank +// Linux default: 250ms https://man.archlinux.org/man/kbdrate.8.en +// Windows default: 1ms `HKEY_CURRENT_USER\Control Panel\Accessibility\Keyboard Response\AutoRepeatDelay` +const autoReleaseKeyboardInterval = time.Millisecond * 100 + // Source: https://www.kernel.org/doc/Documentation/usb/gadget_hid.txt var keyboardReportDesc = []byte{ 0x05, 0x01, /* USAGE_PAGE (Generic Desktop) */ @@ -173,6 +181,69 @@ func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) { u.onKeysDownChange = &f } +func (u *UsbGadget) scheduleAutoRelease(key byte) { + u.log.Trace().Msg("scheduling autoRelease") + u.kbdAutoReleaseLock.Lock() + defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease scheduled") + + if u.kbdAutoReleaseTimer != nil { + u.kbdAutoReleaseTimer.Stop() + } + + u.kbdAutoReleaseTimer = time.AfterFunc(autoReleaseKeyboardInterval, func() { + u.performAutoRelease() + }) +} + +func (u *UsbGadget) cancelAutoRelease() { + u.log.Trace().Msg("cancelling autoRelease") + u.kbdAutoReleaseLock.Lock() + defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease cancelled") + + if u.kbdAutoReleaseTimer != nil { + u.kbdAutoReleaseTimer.Stop() + } +} + +func (u *UsbGadget) DelayAutoRelease() { + u.log.Trace().Msg("delaying autoRelease") + u.kbdAutoReleaseLock.Lock() + defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease delayed") + + if u.kbdAutoReleaseTimer == nil { + return + } + + u.kbdAutoReleaseTimer.Reset(autoReleaseKeyboardInterval) + + u.log.Trace().Msg("auto-release timer reset") +} + +func (u *UsbGadget) performAutoRelease() { + u.log.Trace().Msg("performing autoRelease") + u.kbdAutoReleaseLock.Lock() + defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease unlocked") + + key := u.kbdAutoReleaseLastKey + + select { + case <-u.keyboardStateCtx.Done(): + return + default: + } + + // we just reset the keyboard state to 0 no matter what + u.log.Trace().Uint8("key", key).Msg("auto-releasing keyboard key") + _, err := u.keypressReport(key, false, false) + if err != nil { + u.log.Warn().Uint8("key", key).Msg("failed to auto-release keyboard key") + } + + u.kbdAutoReleaseTimer = nil + + u.log.Trace().Uint8("key", key).Msg("auto release performed") +} + func (u *UsbGadget) listenKeyboardEvents() { var path string if u.keyboardHidFile != nil { @@ -331,17 +402,23 @@ var KeyCodeToMaskMap = map[byte]byte{ RightSuper: ModifierMaskRightSuper, } -func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) { - u.keyboardLock.Lock() - defer u.keyboardLock.Unlock() +func (u *UsbGadget) keypressReportNonThreadSafe(key byte, press bool, autoRelease bool) (KeysDownState, error) { defer u.resetUserInputTime() + l := u.log.With().Uint8("key", key).Bool("press", press).Bool("autoRelease", autoRelease).Logger() + if l.GetLevel() <= zerolog.DebugLevel { + requestID := xid.New() + l = l.With().Str("requestID", requestID.String()).Logger() + } + // IMPORTANT: This code parallels the logic in the kernel's hid-gadget driver // for handling key presses and releases. It ensures that the USB gadget // behaves similarly to a real USB HID keyboard. This logic is paralleled // in the client/browser-side code in useKeyboard.ts so make sure to keep // them in sync. - var state = u.keysDownState + var state = u.GetKeysDownState() + l.Trace().Interface("state", state).Msg("got keys down state") + modifier := state.Modifier keys := append([]byte(nil), state.Keys...) @@ -381,22 +458,76 @@ func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) // If we reach here it means we didn't find an empty slot or the key in the buffer if overrun { if press { - u.log.Error().Uint8("key", key).Msg("keyboard buffer overflow, key not added") + l.Error().Msg("keyboard buffer overflow, key not added") // Fill all key slots with ErrorRollOver (0x01) to indicate overflow for i := range keys { keys[i] = hidErrorRollOver } } else { // If we are releasing a key, and we didn't find it in a slot, who cares? - u.log.Warn().Uint8("key", key).Msg("key not found in buffer, nothing to release") + l.Warn().Msg("key not found in buffer, nothing to release") } } } + if l.GetLevel() <= zerolog.DebugLevel { + l = l.With().Uint8("modifier", modifier).Uints8("keys", keys).Logger() + } + + l.Trace().Msg("writing keypress report to hidg0") + err := u.keyboardWriteHidFile(modifier, keys) if err != nil { - u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keypress report to hidg0") + l.Warn().Msg("Could not write keypress report to hidg0") + } + + l.Trace().Msg("keypress report written to hidg0") + + if press { + { + l.Trace().Msg("acquiring kbdAutoReleaseLock to update last key") + u.kbdAutoReleaseLock.Lock() + u.kbdAutoReleaseLastKey = key + unlockWithLog(&u.kbdAutoReleaseLock, u.log, "last key updated") + } + + if autoRelease { + u.scheduleAutoRelease(key) + } + } else { + if autoRelease { + u.cancelAutoRelease() + } } return u.UpdateKeysDown(modifier, keys), err } + +type keypressReportResult struct { + KeysDownState KeysDownState + Error error +} + +func (u *UsbGadget) keypressReport(key byte, press bool, autoRelease bool) (KeysDownState, error) { + u.keyboardLock.Lock() + defer u.keyboardLock.Unlock() + + r := make(chan keypressReportResult) + go func() { + state, err := u.keypressReportNonThreadSafe(key, press, autoRelease) + r <- keypressReportResult{KeysDownState: state, Error: err} + }() + + select { + case <-time.After(1 * time.Second): + u.log.Warn().Msg("keypressReport timed out, possibly stuck") + return u.keysDownState, fmt.Errorf("keypressReport timed out, possibly stuck") + case ret := <-r: + u.log.Debug().Msg("keypressReport handled") + return ret.KeysDownState, ret.Error + } +} + +func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) { + return u.keypressReport(key, press, true) +} diff --git a/internal/usbgadget/usbgadget.go b/internal/usbgadget/usbgadget.go index 3a01a447..e491a78c 100644 --- a/internal/usbgadget/usbgadget.go +++ b/internal/usbgadget/usbgadget.go @@ -68,6 +68,10 @@ type UsbGadget struct { keyboardState byte // keyboard latched state (NumLock, CapsLock, ScrollLock, Compose, Kana) keysDownState KeysDownState // keyboard dynamic state (modifier keys and pressed keys) + kbdAutoReleaseLock sync.Mutex + kbdAutoReleaseTimer *time.Timer + kbdAutoReleaseLastKey byte + keyboardStateLock sync.Mutex keyboardStateCtx context.Context keyboardStateCancel context.CancelFunc @@ -149,3 +153,35 @@ func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDev return g } + +// Close cleans up resources used by the USB gadget +func (u *UsbGadget) Close() error { + // Cancel keyboard state context + if u.keyboardStateCancel != nil { + u.keyboardStateCancel() + } + + // Stop auto-release timer + u.kbdAutoReleaseLock.Lock() + if u.kbdAutoReleaseTimer != nil { + u.kbdAutoReleaseTimer.Stop() + u.kbdAutoReleaseTimer = nil + } + u.kbdAutoReleaseLock.Unlock() + + // Close HID files + if u.keyboardHidFile != nil { + u.keyboardHidFile.Close() + u.keyboardHidFile = nil + } + if u.absMouseHidFile != nil { + u.absMouseHidFile.Close() + u.absMouseHidFile = nil + } + if u.relMouseHidFile != nil { + u.relMouseHidFile.Close() + u.relMouseHidFile = nil + } + + return nil +} diff --git a/internal/usbgadget/utils.go b/internal/usbgadget/utils.go index d51f9e40..85bf1579 100644 --- a/internal/usbgadget/utils.go +++ b/internal/usbgadget/utils.go @@ -9,6 +9,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" "github.com/rs/zerolog" @@ -120,6 +121,12 @@ func (u *UsbGadget) writeWithTimeout(file *os.File, data []byte) (n int, err err return } + u.log.Trace(). + Str("file", file.Name()). + Bytes("data", data). + Err(err). + Msg("write failed") + if errors.Is(err, os.ErrDeadlineExceeded) { u.logWithSuppression( fmt.Sprintf("writeWithTimeout_%s", file.Name()), @@ -164,3 +171,8 @@ func (u *UsbGadget) resetLogSuppressionCounter(counterName string) { u.logSuppressionCounter[counterName] = 0 } } + +func unlockWithLog(lock *sync.Mutex, logger *zerolog.Logger, msg string, args ...any) { + logger.Trace().Msgf(msg, args...) + lock.Unlock() +} diff --git a/ui/src/hooks/hidRpc.ts b/ui/src/hooks/hidRpc.ts index 20b8a108..d6b4ad96 100644 --- a/ui/src/hooks/hidRpc.ts +++ b/ui/src/hooks/hidRpc.ts @@ -6,6 +6,7 @@ export const HID_RPC_MESSAGE_TYPES = { PointerReport: 0x03, WheelReport: 0x04, KeypressReport: 0x05, + KeypressKeepAliveReport: 0x09, MouseReport: 0x06, KeyboardLedState: 0x32, KeysDownState: 0x33, @@ -278,12 +279,23 @@ export class MouseReportMessage extends RpcMessage { } } +export class KeypressKeepAliveMessage extends RpcMessage { + constructor() { + super(HID_RPC_MESSAGE_TYPES.KeypressKeepAliveReport); + } + + marshal(): Uint8Array { + return new Uint8Array([this.messageType]); + } +} + export const messageRegistry = { [HID_RPC_MESSAGE_TYPES.Handshake]: HandshakeMessage, [HID_RPC_MESSAGE_TYPES.KeysDownState]: KeysDownStateMessage, [HID_RPC_MESSAGE_TYPES.KeyboardLedState]: KeyboardLedStateMessage, [HID_RPC_MESSAGE_TYPES.KeyboardReport]: KeyboardReportMessage, [HID_RPC_MESSAGE_TYPES.KeypressReport]: KeypressReportMessage, + [HID_RPC_MESSAGE_TYPES.KeypressKeepAliveReport]: KeypressKeepAliveMessage, } export const unmarshalHidRpcMessage = (data: Uint8Array): RpcMessage | undefined => { diff --git a/ui/src/hooks/stores.ts b/ui/src/hooks/stores.ts index f99fd07d..9d5117c8 100644 --- a/ui/src/hooks/stores.ts +++ b/ui/src/hooks/stores.ts @@ -106,11 +106,17 @@ export interface RTCState { rpcDataChannel: RTCDataChannel | null; rpcHidProtocolVersion: number | null; - setRpcHidProtocolVersion: (version: number) => void; + setRpcHidProtocolVersion: (version: number | null) => void; rpcHidChannel: RTCDataChannel | null; setRpcHidChannel: (channel: RTCDataChannel) => void; + rpcHidUnreliableChannel: RTCDataChannel | null; + setRpcHidUnreliableChannel: (channel: RTCDataChannel) => void; + + rpcHidUnreliableNonOrderedChannel: RTCDataChannel | null; + setRpcHidUnreliableNonOrderedChannel: (channel: RTCDataChannel) => void; + peerConnectionState: RTCPeerConnectionState | null; setPeerConnectionState: (state: RTCPeerConnectionState) => void; @@ -158,11 +164,17 @@ export const useRTCStore = create(set => ({ setRpcDataChannel: (channel: RTCDataChannel) => set({ rpcDataChannel: channel }), rpcHidProtocolVersion: null, - setRpcHidProtocolVersion: (version: number) => set({ rpcHidProtocolVersion: version }), + setRpcHidProtocolVersion: (version: number | null) => set({ rpcHidProtocolVersion: version }), rpcHidChannel: null, setRpcHidChannel: (channel: RTCDataChannel) => set({ rpcHidChannel: channel }), + rpcHidUnreliableChannel: null, + setRpcHidUnreliableChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableChannel: channel }), + + rpcHidUnreliableNonOrderedChannel: null, + setRpcHidUnreliableNonOrderedChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableNonOrderedChannel: channel }), + transceiver: null, setTransceiver: (transceiver: RTCRtpTransceiver) => set({ transceiver }), diff --git a/ui/src/hooks/useHidRpc.ts b/ui/src/hooks/useHidRpc.ts index ea0c7112..0670acb8 100644 --- a/ui/src/hooks/useHidRpc.ts +++ b/ui/src/hooks/useHidRpc.ts @@ -6,6 +6,7 @@ import { HID_RPC_VERSION, HandshakeMessage, KeyboardReportMessage, + KeypressKeepAliveMessage, KeypressReportMessage, MouseReportMessage, PointerReportMessage, @@ -13,20 +14,43 @@ import { unmarshalHidRpcMessage, } from "./hidRpc"; +const KEEPALIVE_MESSAGE = new KeypressKeepAliveMessage(); + +interface sendMessageParams { + ignoreHandshakeState?: boolean; + useUnreliableChannel?: boolean; + requireOrdered?: boolean; +} + export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { - const { rpcHidChannel, setRpcHidProtocolVersion, rpcHidProtocolVersion } = useRTCStore(); + const { + rpcHidChannel, + rpcHidUnreliableChannel, + rpcHidUnreliableNonOrderedChannel, + setRpcHidProtocolVersion, + rpcHidProtocolVersion, + } = useRTCStore(); + const rpcHidReady = useMemo(() => { return rpcHidChannel?.readyState === "open" && rpcHidProtocolVersion !== null; }, [rpcHidChannel, rpcHidProtocolVersion]); + const rpcHidUnreliableReady = useMemo(() => { + return rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null; + }, [rpcHidUnreliableChannel, rpcHidProtocolVersion]); + + const rpcHidUnreliableNonOrderedReady = useMemo(() => { + return rpcHidUnreliableNonOrderedChannel?.readyState === "open" && rpcHidProtocolVersion !== null; + }, [rpcHidUnreliableNonOrderedChannel, rpcHidProtocolVersion]); + const rpcHidStatus = useMemo(() => { if (!rpcHidChannel) return "N/A"; if (rpcHidChannel.readyState !== "open") return rpcHidChannel.readyState; if (!rpcHidProtocolVersion) return "handshaking"; - return `ready (v${rpcHidProtocolVersion})`; - }, [rpcHidChannel, rpcHidProtocolVersion]); + return `ready (v${rpcHidProtocolVersion}${rpcHidUnreliableReady ? "+u" : ""})`; + }, [rpcHidChannel, rpcHidUnreliableReady, rpcHidProtocolVersion]); - const sendMessage = useCallback((message: RpcMessage, ignoreHandshakeState = false) => { + const sendMessage = useCallback((message: RpcMessage, { ignoreHandshakeState, useUnreliableChannel, requireOrdered = true }: sendMessageParams = {}) => { if (rpcHidChannel?.readyState !== "open") return; if (!rpcHidReady && !ignoreHandshakeState) return; @@ -38,8 +62,24 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { } if (!data) return; + if (useUnreliableChannel) { + if (requireOrdered && rpcHidUnreliableReady) { + rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer); + } else if (!requireOrdered && rpcHidUnreliableNonOrderedReady) { + rpcHidUnreliableNonOrderedChannel?.send(data as unknown as ArrayBuffer); + } + return; + } + rpcHidChannel?.send(data as unknown as ArrayBuffer); - }, [rpcHidChannel, rpcHidReady]); + }, [ + rpcHidChannel, + rpcHidUnreliableChannel, + rpcHidUnreliableNonOrderedChannel, + rpcHidReady, + rpcHidUnreliableReady, + rpcHidUnreliableNonOrderedReady, + ]); const reportKeyboardEvent = useCallback( (keys: number[], modifier: number) => { @@ -56,7 +96,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { const reportAbsMouseEvent = useCallback( (x: number, y: number, buttons: number) => { - sendMessage(new PointerReportMessage(x, y, buttons)); + sendMessage(new PointerReportMessage(x, y, buttons), { useUnreliableChannel: true }); }, [sendMessage], ); @@ -68,11 +108,15 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { [sendMessage], ); + const reportKeypressKeepAlive = useCallback(() => { + sendMessage(KEEPALIVE_MESSAGE, { useUnreliableChannel: true, requireOrdered: false }); + }, [sendMessage]); + const sendHandshake = useCallback(() => { if (rpcHidProtocolVersion) return; if (!rpcHidChannel) return; - sendMessage(new HandshakeMessage(HID_RPC_VERSION), true); + sendMessage(new HandshakeMessage(HID_RPC_VERSION), { ignoreHandshakeState: true }); }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage]); const handleHandshake = useCallback((message: HandshakeMessage) => { @@ -123,10 +167,24 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { onHidRpcMessage?.(message); }; + const openHandler = () => { + console.warn("HID RPC channel opened"); + sendHandshake(); + }; + + const closeHandler = () => { + console.warn("HID RPC channel closed"); + setRpcHidProtocolVersion(null); + }; + rpcHidChannel.addEventListener("message", messageHandler); + rpcHidChannel.addEventListener("close", closeHandler); + rpcHidChannel.addEventListener("open", openHandler); return () => { rpcHidChannel.removeEventListener("message", messageHandler); + rpcHidChannel.removeEventListener("close", closeHandler); + rpcHidChannel.removeEventListener("open", openHandler); }; }, [ @@ -143,6 +201,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { reportKeypressEvent, reportAbsMouseEvent, reportRelMouseEvent, + reportKeypressKeepAlive, rpcHidProtocolVersion, rpcHidReady, rpcHidStatus, diff --git a/ui/src/hooks/useKeyboard.ts b/ui/src/hooks/useKeyboard.ts index 787df9a9..561485fb 100644 --- a/ui/src/hooks/useKeyboard.ts +++ b/ui/src/hooks/useKeyboard.ts @@ -1,4 +1,4 @@ -import { useCallback } from "react"; +import { useCallback, useRef } from "react"; import { hidErrorRollOver, hidKeyBufferSize, KeysDownState, useHidStore, useRTCStore } from "@/hooks/stores"; import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; @@ -11,6 +11,9 @@ export default function useKeyboard() { const { rpcDataChannel } = useRTCStore(); const { keysDownState, setKeysDownState, setKeyboardLedState } = useHidStore(); + // Keepalive timer management + const keepAliveTimerRef = useRef(null); + // INTRODUCTION: The earlier version of the JetKVM device shipped with all keyboard state // being tracked on the browser/client-side. When adding the keyPressReport API to the // device-side code, we have to still support the situation where the browser/client-side code @@ -26,6 +29,7 @@ export default function useKeyboard() { const { reportKeyboardEvent: sendKeyboardEventHidRpc, reportKeypressEvent: sendKeypressEventHidRpc, + reportKeypressKeepAlive: sendKeypressKeepAliveHidRpc, rpcHidReady, } = useHidRpc((message) => { switch (message.constructor) { @@ -70,17 +74,6 @@ export default function useKeyboard() { ], ); - // resetKeyboardState is used to reset the keyboard state to no keys pressed and no modifiers. - // This is useful for macros and when the browser loses focus to ensure that the keyboard state - // is clean. - const resetKeyboardState = useCallback( - async () => { - // Reset the keys buffer to zeros and the modifier state to zero - keysDownState.keys.length = hidKeyBufferSize; - keysDownState.keys.fill(0); - keysDownState.modifier = 0; - sendKeyboardEvent(keysDownState); - }, [keysDownState, sendKeyboardEvent]); // executeMacro is used to execute a macro consisting of multiple steps. // Each step can have multiple keys, multiple modifiers and a delay. @@ -111,12 +104,61 @@ export default function useKeyboard() { } }; + const KEEPALIVE_INTERVAL = 75; // TODO: use an adaptive interval based on RTT later + + const cancelKeepAlive = useCallback(() => { + if (keepAliveTimerRef.current) { + clearInterval(keepAliveTimerRef.current); + keepAliveTimerRef.current = null; + } + }, []); + + const scheduleKeepAlive = useCallback(() => { + // Clear existing timer if it exists + if (keepAliveTimerRef.current) { + clearInterval(keepAliveTimerRef.current); + } + + // Create new interval timer + keepAliveTimerRef.current = setInterval(() => { + sendKeypressKeepAliveHidRpc(); + }, KEEPALIVE_INTERVAL); + }, [sendKeypressKeepAliveHidRpc]); + + // resetKeyboardState is used to reset the keyboard state to no keys pressed and no modifiers. + // This is useful for macros and when the browser loses focus to ensure that the keyboard state + // is clean. + const resetKeyboardState = useCallback(async () => { + // Cancel keepalive since we're resetting the keyboard state + cancelKeepAlive(); + + // Reset the keys buffer to zeros and the modifier state to zero + keysDownState.keys.length = hidKeyBufferSize; + keysDownState.keys.fill(0); + keysDownState.modifier = 0; + sendKeyboardEvent(keysDownState); + }, [keysDownState, sendKeyboardEvent, cancelKeepAlive]); + // handleKeyPress is used to handle a key press or release event. // This function handle both key press and key release events. // It checks if the keyPressReport API is available and sends the key press event. // If the keyPressReport API is not available, it simulates the device-side key // handling for legacy devices and updates the keysDownState accordingly. // It then sends the full keyboard state to the device. + + const sendKeypress = useCallback( + (key: number, press: boolean) => { + cancelKeepAlive(); + + sendKeypressEventHidRpc(key, press); + + if (press) { + scheduleKeepAlive(); + } + }, + [sendKeypressEventHidRpc, scheduleKeepAlive, cancelKeepAlive], + ); + const handleKeyPress = useCallback( async (key: number, press: boolean) => { if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; @@ -129,7 +171,7 @@ export default function useKeyboard() { // Older device version doesn't support this API, so we will switch to local key handling // In that case we will switch to local key handling and update the keysDownState // in client/browser-side code using simulateDeviceSideKeyHandlingForLegacyDevices. - sendKeypressEventHidRpc(key, press); + sendKeypress(key, press); } else { // if the keyPress api is not available, we need to handle the key locally const downState = simulateDeviceSideKeyHandlingForLegacyDevices(keysDownState, key, press); @@ -147,7 +189,7 @@ export default function useKeyboard() { resetKeyboardState, rpcDataChannel?.readyState, sendKeyboardEvent, - sendKeypressEventHidRpc, + sendKeypress, ], ); @@ -210,5 +252,10 @@ export default function useKeyboard() { return { modifier: modifiers, keys }; } - return { handleKeyPress, resetKeyboardState, executeMacro }; + // Cleanup function to cancel keepalive timer + const cleanup = useCallback(() => { + cancelKeepAlive(); + }, [cancelKeepAlive]); + + return { handleKeyPress, resetKeyboardState, executeMacro, cleanup }; } diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx index 4318447e..8a363199 100644 --- a/ui/src/routes/devices.$id.tsx +++ b/ui/src/routes/devices.$id.tsx @@ -136,6 +136,8 @@ export default function KvmIdRoute() { rpcDataChannel, setTransceiver, setRpcHidChannel, + setRpcHidUnreliableNonOrderedChannel, + setRpcHidUnreliableChannel, } = useRTCStore(); const location = useLocation(); @@ -488,6 +490,24 @@ export default function KvmIdRoute() { setRpcHidChannel(rpcHidChannel); }; + const rpcHidUnreliableChannel = pc.createDataChannel("hidrpc-unreliable-ordered", { + ordered: true, + maxRetransmits: 0, + }); + rpcHidUnreliableChannel.binaryType = "arraybuffer"; + rpcHidUnreliableChannel.onopen = () => { + setRpcHidUnreliableChannel(rpcHidUnreliableChannel); + }; + + const rpcHidUnreliableNonOrderedChannel = pc.createDataChannel("hidrpc-unreliable-nonordered", { + ordered: false, + maxRetransmits: 0, + }); + rpcHidUnreliableNonOrderedChannel.binaryType = "arraybuffer"; + rpcHidUnreliableNonOrderedChannel.onopen = () => { + setRpcHidUnreliableNonOrderedChannel(rpcHidUnreliableNonOrderedChannel); + }; + setPeerConnection(pc); }, [ cleanupAndStopReconnecting, @@ -499,6 +519,8 @@ export default function KvmIdRoute() { setPeerConnectionState, setRpcDataChannel, setRpcHidChannel, + setRpcHidUnreliableNonOrderedChannel, + setRpcHidUnreliableChannel, setTransceiver, ]); diff --git a/webrtc.go b/webrtc.go index c3d0dc1b..333a58b8 100644 --- a/webrtc.go +++ b/webrtc.go @@ -29,7 +29,12 @@ type Session struct { hidRPCAvailable bool hidQueueLock sync.Mutex - hidQueue []chan webrtc.DataChannelMessage + hidQueue []chan hidQueueMessage +} + +type hidQueueMessage struct { + webrtc.DataChannelMessage + channel string } type SessionConfig struct { @@ -78,16 +83,59 @@ func (s *Session) initQueues() { s.hidQueueLock.Lock() defer s.hidQueueLock.Unlock() - s.hidQueue = make([]chan webrtc.DataChannelMessage, 0) + s.hidQueue = make([]chan hidQueueMessage, 0) for i := 0; i < 4; i++ { - q := make(chan webrtc.DataChannelMessage, 256) + q := make(chan hidQueueMessage, 256) s.hidQueue = append(s.hidQueue, q) } } func (s *Session) handleQueues(index int) { for msg := range s.hidQueue[index] { - onHidMessage(msg.Data, s) + onHidMessage(msg, s) + } +} + +func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, channel string) func(msg webrtc.DataChannelMessage) { + return func(msg webrtc.DataChannelMessage) { + l := scopedLogger.With(). + Str("channel", channel). + Int("length", len(msg.Data)). + Logger() + // only log data if the log level is debug or lower + if scopedLogger.GetLevel() > zerolog.DebugLevel { + l = l.With().Str("data", string(msg.Data)).Logger() + } + + if msg.IsString { + l.Warn().Msg("received string data in HID RPC message handler") + return + } + + if len(msg.Data) < 1 { + l.Warn().Msg("received empty data in HID RPC message handler") + return + } + + l.Trace().Msg("received data in HID RPC message handler") + + // Enqueue to ensure ordered processing + queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) + if queueIndex >= len(session.hidQueue) || queueIndex < 0 { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") + queueIndex = 3 + } + + queue := session.hidQueue[queueIndex] + if queue != nil { + queue <- hidQueueMessage{ + DataChannelMessage: msg, + channel: channel, + } + } else { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") + return + } } } @@ -157,40 +205,12 @@ func newSession(config SessionConfig) (*Session, error) { switch d.Label() { case "hidrpc": session.HidChannel = d - d.OnMessage(func(msg webrtc.DataChannelMessage) { - l := scopedLogger.With().Int("length", len(msg.Data)).Logger() - // only log data if the log level is debug or lower - if scopedLogger.GetLevel() > zerolog.DebugLevel { - l = l.With().Str("data", string(msg.Data)).Logger() - } - - if msg.IsString { - l.Warn().Msg("received string data in HID RPC message handler") - return - } - - if len(msg.Data) < 1 { - l.Warn().Msg("received empty data in HID RPC message handler") - return - } - - l.Trace().Msg("received data in HID RPC message handler") - - // Enqueue to ensure ordered processing - queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) - if queueIndex >= len(session.hidQueue) || queueIndex < 0 { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") - queueIndex = 3 - } - - queue := session.hidQueue[queueIndex] - if queue != nil { - queue <- msg - } else { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") - return - } - }) + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc")) + // we won't send anything over the unreliable channels + case "hidrpc-unreliable-ordered": + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-ordered")) + case "hidrpc-unreliable-nonordered": + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-nonordered")) case "rpc": session.RPCChannel = d d.OnMessage(func(msg webrtc.DataChannelMessage) {