diff --git a/go.mod b/go.mod index 962c3a1b..d07ba239 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/gin-contrib/logger v1.2.6 github.com/gin-gonic/gin v1.10.1 github.com/go-co-op/gocron/v2 v2.16.5 + github.com/google/flatbuffers v25.2.10+incompatible github.com/google/uuid v1.6.0 github.com/guregu/null/v6 v6.0.0 github.com/gwatts/rootcerts v0.0.0-20250901182336-dc5ae18bd79f @@ -23,6 +24,7 @@ require ( github.com/prometheus/common v0.66.0 github.com/prometheus/procfs v0.17.0 github.com/psanford/httpreadat v0.1.0 + github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/sourcegraph/tf-dag v0.2.2-0.20250131204052-3e8ff1477b4f github.com/stretchr/testify v1.11.1 diff --git a/go.sum b/go.sum index e19fa9e6..57576a3a 100644 --- a/go.sum +++ b/go.sum @@ -53,6 +53,8 @@ github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAu github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= +github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -152,6 +154,7 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= diff --git a/hidrpc.go b/hidrpc.go index 74fe687f..e0d9ad54 100644 --- a/hidrpc.go +++ b/hidrpc.go @@ -6,6 +6,7 @@ import ( "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/usbgadget" + "github.com/rs/zerolog" ) func handleHidRPCMessage(message hidrpc.Message, session *Session) { @@ -24,11 +25,9 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { } session.hidRPCAvailable = true case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: - keysDownState, err := handleHidRPCKeyboardInput(message) - if keysDownState != nil { - session.reportHidRPCKeysDownState(*keysDownState) - } - rpcErr = err + rpcErr = handleHidRPCKeyboardInput(message) + case hidrpc.TypeKeypressKeepAliveReport: + rpcErr = handleHidRPCKeypressKeepAlive(session) case hidrpc.TypePointerReport: pointerReport, err := message.PointerReport() if err != nil { @@ -52,8 +51,13 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { } } -func onHidMessage(data []byte, session *Session) { - scopedLogger := hidRPCLogger.With().Bytes("data", data).Logger() +func onHidMessage(msg hidQueueMessage, session *Session) { + data := msg.Data + + scopedLogger := hidRPCLogger.With(). + Str("channel", msg.channel). + Bytes("data", data). + Logger() scopedLogger.Debug().Msg("HID RPC message received") if len(data) < 1 { @@ -68,7 +72,9 @@ func onHidMessage(data []byte, session *Session) { return } - scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger() + if scopedLogger.GetLevel() <= zerolog.DebugLevel { + scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger() + } t := time.Now() @@ -85,27 +91,88 @@ func onHidMessage(data []byte, session *Session) { } } -func handleHidRPCKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState, error) { +// Tunables +// Keep in mind +// macOS default: 15 * 15 = 225ms https://discussions.apple.com/thread/1316947?sortBy=rank +// Linux default: 250ms https://man.archlinux.org/man/kbdrate.8.en +// Windows default: 1s `HKEY_CURRENT_USER\Control Panel\Accessibility\Keyboard Response\AutoRepeatDelay` + +const expectedRate = 50 * time.Millisecond // expected keepalive interval +const maxLateness = 50 * time.Millisecond // max jitter we'll tolerate OR jitter budget +const baseExtension = expectedRate + maxLateness // 100ms extension on perfect tick + +const maxStaleness = 225 * time.Millisecond // discard ancient packets outright + +func handleHidRPCKeypressKeepAlive(session *Session) error { + session.keepAliveJitterLock.Lock() + defer session.keepAliveJitterLock.Unlock() + + now := time.Now() + + // 1) Staleness guard: ensures packets that arrive far beyond the life of a valid key hold + // (e.g. after a network stall, retransmit burst, or machine sleep) are ignored outright. + // This prevents “zombie” keepalives from reviving a key that should already be released. + if !session.lastTimerResetTime.IsZero() && now.Sub(session.lastTimerResetTime) > maxStaleness { + return nil + } + + validTick := true + timerExtension := baseExtension + + if !session.lastKeepAliveArrivalTime.IsZero() { + timeSinceLastTick := now.Sub(session.lastKeepAliveArrivalTime) + lateness := timeSinceLastTick - expectedRate + + if lateness > 0 { + if lateness <= maxLateness { + // --- Small lateness (within jitterBudget) --- + // This is normal jitter (e.g., Wi-Fi contention). + // We still accept the tick, but *reduce the extension* + // so that the total hold time stays aligned with REAL client side intent. + timerExtension -= lateness + } else { + // --- Large lateness (beyond jitterBudget) --- + // This is likely a retransmit stall or ordering delay. + // We reject the tick entirely and DO NOT extend, + // so the auto-release still fires on time. + validTick = false + } + } + } + + if !validTick { + return nil + } + // Only valid ticks update our state and extend the timer. + session.lastKeepAliveArrivalTime = now + session.lastTimerResetTime = now + if gadget != nil { + gadget.DelayAutoReleaseWithDuration(timerExtension) + } + + // On a miss: do not advance any state — keeps baseline stable. + return nil +} + +func handleHidRPCKeyboardInput(message hidrpc.Message) error { switch message.Type() { case hidrpc.TypeKeypressReport: keypressReport, err := message.KeypressReport() if err != nil { logger.Warn().Err(err).Msg("failed to get keypress report") - return nil, err + return err } - keysDownState, rpcError := rpcKeypressReport(keypressReport.Key, keypressReport.Press) - return &keysDownState, rpcError + return rpcKeypressReport(keypressReport.Key, keypressReport.Press) case hidrpc.TypeKeyboardReport: keyboardReport, err := message.KeyboardReport() if err != nil { logger.Warn().Err(err).Msg("failed to get keyboard report") - return nil, err + return err } - keysDownState, rpcError := rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) - return &keysDownState, rpcError + return rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) } - return nil, fmt.Errorf("unknown HID RPC message type: %d", message.Type()) + return fmt.Errorf("unknown HID RPC message type: %d", message.Type()) } func reportHidRPC(params any, session *Session) { @@ -115,7 +182,10 @@ func reportHidRPC(params any, session *Session) { } if !session.hidRPCAvailable || session.HidChannel == nil { - logger.Warn().Msg("HID RPC is not available, skipping reportHidRPC") + logger.Warn(). + Bool("hidRPCAvailable", session.hidRPCAvailable). + Bool("HidChannel", session.HidChannel != nil). + Msg("HID RPC is not available, skipping reportHidRPC") return } @@ -156,7 +226,9 @@ func (s *Session) reportHidRPCKeyboardLedState(state usbgadget.KeyboardState) { func (s *Session) reportHidRPCKeysDownState(state usbgadget.KeysDownState) { if !s.hidRPCAvailable { + usbLogger.Debug().Interface("state", state).Msg("reporting keys down state") writeJSONRPCEvent("keysDownState", state, s) } + usbLogger.Debug().Interface("state", state).Msg("reporting keys down state, calling reportHidRPC") reportHidRPC(state, s) } diff --git a/internal/hidrpc/hidrpc.go b/internal/hidrpc/hidrpc.go index e9c8c24d..1bc9807b 100644 --- a/internal/hidrpc/hidrpc.go +++ b/internal/hidrpc/hidrpc.go @@ -10,14 +10,15 @@ import ( type MessageType byte const ( - TypeHandshake MessageType = 0x01 - TypeKeyboardReport MessageType = 0x02 - TypePointerReport MessageType = 0x03 - TypeWheelReport MessageType = 0x04 - TypeKeypressReport MessageType = 0x05 - TypeMouseReport MessageType = 0x06 - TypeKeyboardLedState MessageType = 0x32 - TypeKeydownState MessageType = 0x33 + TypeHandshake MessageType = 0x01 + TypeKeyboardReport MessageType = 0x02 + TypePointerReport MessageType = 0x03 + TypeWheelReport MessageType = 0x04 + TypeKeypressReport MessageType = 0x05 + TypeKeypressKeepAliveReport MessageType = 0x09 + TypeMouseReport MessageType = 0x06 + TypeKeyboardLedState MessageType = 0x32 + TypeKeydownState MessageType = 0x33 ) const ( diff --git a/internal/hidrpc/message.go b/internal/hidrpc/message.go index 84bbda7c..e0f4493e 100644 --- a/internal/hidrpc/message.go +++ b/internal/hidrpc/message.go @@ -43,6 +43,8 @@ func (m *Message) String() string { return fmt.Sprintf("MouseReport{Malformed: %v}", m.d) } return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0], m.d[1], m.d[2]) + case TypeKeypressKeepAliveReport: + return "KeypressKeepAliveReport" default: return fmt.Sprintf("Unknown{Type: %d, Data: %v}", m.t, m.d) } diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index fb710c20..8335966d 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -5,7 +5,11 @@ import ( "context" "fmt" "os" + "sync" "time" + + "github.com/rs/xid" + "github.com/rs/zerolog" ) var keyboardConfig = gadgetConfigItem{ @@ -145,32 +149,105 @@ func (u *UsbGadget) GetKeysDownState() KeysDownState { return u.keysDownState } -func (u *UsbGadget) updateKeyDownState(state KeysDownState) { - u.log.Trace().Interface("old", u.keysDownState).Interface("new", state).Msg("acquiring keyboardStateLock for updateKeyDownState") +func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) { + u.onKeysDownChange = &f +} - // this is intentional to unlock keyboard state lock before onKeysDownChange callback - { - u.keyboardStateLock.Lock() - defer u.keyboardStateLock.Unlock() +func (u *UsbGadget) SetOnKeepAliveReset(f func()) { + u.onKeepAliveReset = &f +} - if u.keysDownState.Modifier == state.Modifier && - bytes.Equal(u.keysDownState.Keys, state.Keys) { - return // No change in key down state - } +// DefaultAutoReleaseDuration is the default duration for auto-release of a key. +const DefaultAutoReleaseDuration = 100 * time.Millisecond - u.log.Trace().Interface("old", u.keysDownState).Interface("new", state).Msg("keysDownState updated") - u.keysDownState = state +func (u *UsbGadget) scheduleAutoRelease(key byte) { + u.kbdAutoReleaseLock.Lock() + defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease scheduled") + + if u.kbdAutoReleaseTimers[key] != nil { + u.kbdAutoReleaseTimers[key].Stop() } - if u.onKeysDownChange != nil { - u.log.Trace().Interface("state", state).Msg("calling onKeysDownChange") - (*u.onKeysDownChange)(state) - u.log.Trace().Interface("state", state).Msg("onKeysDownChange called") + duration := u.kbdAutoReleaseTimerExtension + if duration == 0 { + duration = DefaultAutoReleaseDuration + } + + u.log.Debug().Dur("duration", duration).Msg("autoRelease scheduled with duration") + + u.kbdAutoReleaseTimers[key] = time.AfterFunc(duration, func() { + u.performAutoRelease(key) + }) +} + +func (u *UsbGadget) cancelAutoRelease(key byte) { + u.kbdAutoReleaseLock.Lock() + defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease cancelled") + + if timer := u.kbdAutoReleaseTimers[key]; timer != nil { + timer.Stop() + u.kbdAutoReleaseTimers[key] = nil + delete(u.kbdAutoReleaseTimers, key) + + // Reset keep-alive timing when key is released + if u.onKeepAliveReset != nil { + (*u.onKeepAliveReset)() + } } } -func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) { - u.onKeysDownChange = &f +func (u *UsbGadget) DelayAutoReleaseWithDuration(resetDuration time.Duration) { + u.kbdAutoReleaseLock.Lock() + defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease delayed") + + if u.kbdAutoReleaseTimers == nil { + return + } + + u.kbdAutoReleaseTimerExtension = resetDuration + + u.log.Debug().Dur("reset_duration", resetDuration).Msg("delaying auto-release with dynamic duration") + + for _, timer := range u.kbdAutoReleaseTimers { + if timer != nil { + timer.Reset(resetDuration) + } + } +} + +func (u *UsbGadget) performAutoRelease(key byte) { + u.kbdAutoReleaseLock.Lock() + + if u.kbdAutoReleaseTimers[key] == nil { + u.log.Warn().Uint8("key", key).Msg("autoRelease timer not found") + u.kbdAutoReleaseLock.Unlock() + return + } + + u.kbdAutoReleaseTimers[key].Stop() + u.kbdAutoReleaseTimers[key] = nil + delete(u.kbdAutoReleaseTimers, key) + u.kbdAutoReleaseLock.Unlock() + + // Skip if already released + state := u.GetKeysDownState() + alreadyReleased := true + + for i := range state.Keys { + if state.Keys[i] == key { + alreadyReleased = false + break + } + } + + if alreadyReleased { + return + } + + _, err := u.keypressReport(key, false) + if err != nil { + u.log.Warn().Uint8("key", key).Msg("failed to release key") + } } func (u *UsbGadget) listenKeyboardEvents() { @@ -242,7 +319,11 @@ func (u *UsbGadget) OpenKeyboardHidFile() error { return u.openKeyboardHidFile() } +var keyboardWriteHidFileLock sync.Mutex + func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { + keyboardWriteHidFileLock.Lock() + defer keyboardWriteHidFileLock.Unlock() if err := u.openKeyboardHidFile(); err != nil { return err } @@ -266,17 +347,29 @@ func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { } } - downState := KeysDownState{ + state := KeysDownState{ Modifier: modifier, Keys: []byte(keys[:]), } - u.updateKeyDownState(downState) - return downState + + u.keyboardStateLock.Lock() + + if u.keysDownState.Modifier == state.Modifier && + bytes.Equal(u.keysDownState.Keys, state.Keys) { + u.keyboardStateLock.Unlock() + return state // No change in key down state + } + + u.keysDownState = state + u.keyboardStateLock.Unlock() + + if u.onKeysDownChange != nil { + (*u.onKeysDownChange)(state) // this enques to the outgoing hidrpc queue via usb.go → currentSession.enqueueKeysDownState(...) + } + return state } -func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) (KeysDownState, error) { - u.keyboardLock.Lock() - defer u.keyboardLock.Unlock() +func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) error { defer u.resetUserInputTime() if len(keys) > hidKeyBufferSize { @@ -291,7 +384,8 @@ func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) (KeysDownState, e u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keyboard report to hidg0") } - return u.UpdateKeysDown(modifier, keys), err + u.UpdateKeysDown(modifier, keys) + return err } const ( @@ -331,17 +425,23 @@ var KeyCodeToMaskMap = map[byte]byte{ RightSuper: ModifierMaskRightSuper, } -func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) { - u.keyboardLock.Lock() - defer u.keyboardLock.Unlock() +func (u *UsbGadget) keypressReport(key byte, press bool) (KeysDownState, error) { defer u.resetUserInputTime() + l := u.log.With().Uint8("key", key).Bool("press", press).Logger() + if l.GetLevel() <= zerolog.DebugLevel { + requestID := xid.New() + l = l.With().Str("requestID", requestID.String()).Logger() + } + // IMPORTANT: This code parallels the logic in the kernel's hid-gadget driver // for handling key presses and releases. It ensures that the USB gadget // behaves similarly to a real USB HID keyboard. This logic is paralleled // in the client/browser-side code in useKeyboard.ts so make sure to keep // them in sync. - var state = u.keysDownState + var state = u.GetKeysDownState() + l.Trace().Interface("state", state).Msg("got keys down state") + modifier := state.Modifier keys := append([]byte(nil), state.Keys...) @@ -381,22 +481,36 @@ func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) // If we reach here it means we didn't find an empty slot or the key in the buffer if overrun { if press { - u.log.Error().Uint8("key", key).Msg("keyboard buffer overflow, key not added") + l.Error().Msg("keyboard buffer overflow, key not added") // Fill all key slots with ErrorRollOver (0x01) to indicate overflow for i := range keys { keys[i] = hidErrorRollOver } } else { // If we are releasing a key, and we didn't find it in a slot, who cares? - u.log.Warn().Uint8("key", key).Msg("key not found in buffer, nothing to release") + l.Warn().Msg("key not found in buffer, nothing to release") } } } err := u.keyboardWriteHidFile(modifier, keys) - if err != nil { - u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keypress report to hidg0") - } - return u.UpdateKeysDown(modifier, keys), err } + +func (u *UsbGadget) KeypressReport(key byte, press bool) error { + state, err := u.keypressReport(key, press) + if err != nil { + u.log.Warn().Uint8("key", key).Bool("press", press).Msg("failed to report key") + } + isRolledOver := state.Keys[0] == hidErrorRollOver + + if isRolledOver { + u.cancelAutoRelease(key) + } else if press { + u.scheduleAutoRelease(key) + } else { + u.cancelAutoRelease(key) + } + + return err +} diff --git a/internal/usbgadget/usbgadget.go b/internal/usbgadget/usbgadget.go index 3a01a447..36d54a2c 100644 --- a/internal/usbgadget/usbgadget.go +++ b/internal/usbgadget/usbgadget.go @@ -68,6 +68,10 @@ type UsbGadget struct { keyboardState byte // keyboard latched state (NumLock, CapsLock, ScrollLock, Compose, Kana) keysDownState KeysDownState // keyboard dynamic state (modifier keys and pressed keys) + kbdAutoReleaseLock sync.Mutex + kbdAutoReleaseTimers map[byte]*time.Timer + kbdAutoReleaseTimerExtension time.Duration + keyboardStateLock sync.Mutex keyboardStateCtx context.Context keyboardStateCancel context.CancelFunc @@ -85,6 +89,7 @@ type UsbGadget struct { onKeyboardStateChange *func(state KeyboardState) onKeysDownChange *func(state KeysDownState) + onKeepAliveReset *func() log *zerolog.Logger @@ -118,23 +123,25 @@ func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDev keyboardCtx, keyboardCancel := context.WithCancel(context.Background()) g := &UsbGadget{ - name: name, - kvmGadgetPath: path.Join(gadgetPath, name), - configC1Path: path.Join(gadgetPath, name, "configs/c.1"), - configMap: configMap, - customConfig: *config, - configLock: sync.Mutex{}, - keyboardLock: sync.Mutex{}, - absMouseLock: sync.Mutex{}, - relMouseLock: sync.Mutex{}, - txLock: sync.Mutex{}, - keyboardStateCtx: keyboardCtx, - keyboardStateCancel: keyboardCancel, - keyboardState: 0, - keysDownState: KeysDownState{Modifier: 0, Keys: []byte{0, 0, 0, 0, 0, 0}}, // must be initialized to hidKeyBufferSize (6) zero bytes - enabledDevices: *enabledDevices, - lastUserInput: time.Now(), - log: logger, + name: name, + kvmGadgetPath: path.Join(gadgetPath, name), + configC1Path: path.Join(gadgetPath, name, "configs/c.1"), + configMap: configMap, + customConfig: *config, + configLock: sync.Mutex{}, + keyboardLock: sync.Mutex{}, + absMouseLock: sync.Mutex{}, + relMouseLock: sync.Mutex{}, + txLock: sync.Mutex{}, + keyboardStateCtx: keyboardCtx, + keyboardStateCancel: keyboardCancel, + keyboardState: 0, + keysDownState: KeysDownState{Modifier: 0, Keys: []byte{0, 0, 0, 0, 0, 0}}, // must be initialized to hidKeyBufferSize (6) zero bytes + kbdAutoReleaseTimers: make(map[byte]*time.Timer), + kbdAutoReleaseTimerExtension: 0, + enabledDevices: *enabledDevices, + lastUserInput: time.Now(), + log: logger, strictMode: config.strictMode, @@ -149,3 +156,35 @@ func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDev return g } + +// Close cleans up resources used by the USB gadget +func (u *UsbGadget) Close() error { + // Cancel keyboard state context + if u.keyboardStateCancel != nil { + u.keyboardStateCancel() + } + + // Stop auto-release timer + u.kbdAutoReleaseLock.Lock() + for _, timer := range u.kbdAutoReleaseTimers { + timer.Stop() + } + u.kbdAutoReleaseTimers = make(map[byte]*time.Timer) + u.kbdAutoReleaseLock.Unlock() + + // Close HID files + if u.keyboardHidFile != nil { + u.keyboardHidFile.Close() + u.keyboardHidFile = nil + } + if u.absMouseHidFile != nil { + u.absMouseHidFile.Close() + u.absMouseHidFile = nil + } + if u.relMouseHidFile != nil { + u.relMouseHidFile.Close() + u.relMouseHidFile = nil + } + + return nil +} diff --git a/internal/usbgadget/utils.go b/internal/usbgadget/utils.go index d51f9e40..85bf1579 100644 --- a/internal/usbgadget/utils.go +++ b/internal/usbgadget/utils.go @@ -9,6 +9,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" "github.com/rs/zerolog" @@ -120,6 +121,12 @@ func (u *UsbGadget) writeWithTimeout(file *os.File, data []byte) (n int, err err return } + u.log.Trace(). + Str("file", file.Name()). + Bytes("data", data). + Err(err). + Msg("write failed") + if errors.Is(err, os.ErrDeadlineExceeded) { u.logWithSuppression( fmt.Sprintf("writeWithTimeout_%s", file.Name()), @@ -164,3 +171,8 @@ func (u *UsbGadget) resetLogSuppressionCounter(counterName string) { u.logSuppressionCounter[counterName] = 0 } } + +func unlockWithLog(lock *sync.Mutex, logger *zerolog.Logger, msg string, args ...any) { + logger.Trace().Msgf(msg, args...) + lock.Unlock() +} diff --git a/jsonrpc.go b/jsonrpc.go index 61f28df5..8136a704 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -1066,9 +1066,7 @@ var rpcHandlers = map[string]RPCHandler{ "getNetworkSettings": {Func: rpcGetNetworkSettings}, "setNetworkSettings": {Func: rpcSetNetworkSettings, Params: []string{"settings"}}, "renewDHCPLease": {Func: rpcRenewDHCPLease}, - "keyboardReport": {Func: rpcKeyboardReport, Params: []string{"modifier", "keys"}}, "getKeyboardLedState": {Func: rpcGetKeyboardLedState}, - "keypressReport": {Func: rpcKeypressReport, Params: []string{"key", "press"}}, "getKeyDownState": {Func: rpcGetKeysDownState}, "absMouseReport": {Func: rpcAbsMouseReport, Params: []string{"x", "y", "buttons"}}, "relMouseReport": {Func: rpcRelMouseReport, Params: []string{"dx", "dy", "buttons"}}, diff --git a/ui/src/components/WebRTCVideo.tsx b/ui/src/components/WebRTCVideo.tsx index bc6897ea..64452bf8 100644 --- a/ui/src/components/WebRTCVideo.tsx +++ b/ui/src/components/WebRTCVideo.tsx @@ -190,7 +190,7 @@ export default function WebRTCVideo() { if (!isFullscreenEnabled || !videoElm.current) return; // per https://wicg.github.io/keyboard-lock/#system-key-press-handler - // If keyboard lock is activated after fullscreen is already in effect, then the user my + // If keyboard lock is activated after fullscreen is already in effect, then the user my // see multiple messages about how to exit fullscreen. For this reason, we recommend that // developers call lock() before they enter fullscreen: await requestKeyboardLock(); @@ -237,6 +237,7 @@ export default function WebRTCVideo() { const keyDownHandler = useCallback( (e: KeyboardEvent) => { e.preventDefault(); + if (e.repeat) return; const code = getAdjustedKeyCode(e); const hidKey = keys[code]; diff --git a/ui/src/hooks/hidRpc.ts b/ui/src/hooks/hidRpc.ts index 20b8a108..d6b4ad96 100644 --- a/ui/src/hooks/hidRpc.ts +++ b/ui/src/hooks/hidRpc.ts @@ -6,6 +6,7 @@ export const HID_RPC_MESSAGE_TYPES = { PointerReport: 0x03, WheelReport: 0x04, KeypressReport: 0x05, + KeypressKeepAliveReport: 0x09, MouseReport: 0x06, KeyboardLedState: 0x32, KeysDownState: 0x33, @@ -278,12 +279,23 @@ export class MouseReportMessage extends RpcMessage { } } +export class KeypressKeepAliveMessage extends RpcMessage { + constructor() { + super(HID_RPC_MESSAGE_TYPES.KeypressKeepAliveReport); + } + + marshal(): Uint8Array { + return new Uint8Array([this.messageType]); + } +} + export const messageRegistry = { [HID_RPC_MESSAGE_TYPES.Handshake]: HandshakeMessage, [HID_RPC_MESSAGE_TYPES.KeysDownState]: KeysDownStateMessage, [HID_RPC_MESSAGE_TYPES.KeyboardLedState]: KeyboardLedStateMessage, [HID_RPC_MESSAGE_TYPES.KeyboardReport]: KeyboardReportMessage, [HID_RPC_MESSAGE_TYPES.KeypressReport]: KeypressReportMessage, + [HID_RPC_MESSAGE_TYPES.KeypressKeepAliveReport]: KeypressKeepAliveMessage, } export const unmarshalHidRpcMessage = (data: Uint8Array): RpcMessage | undefined => { diff --git a/ui/src/hooks/stores.ts b/ui/src/hooks/stores.ts index f99fd07d..9d5117c8 100644 --- a/ui/src/hooks/stores.ts +++ b/ui/src/hooks/stores.ts @@ -106,11 +106,17 @@ export interface RTCState { rpcDataChannel: RTCDataChannel | null; rpcHidProtocolVersion: number | null; - setRpcHidProtocolVersion: (version: number) => void; + setRpcHidProtocolVersion: (version: number | null) => void; rpcHidChannel: RTCDataChannel | null; setRpcHidChannel: (channel: RTCDataChannel) => void; + rpcHidUnreliableChannel: RTCDataChannel | null; + setRpcHidUnreliableChannel: (channel: RTCDataChannel) => void; + + rpcHidUnreliableNonOrderedChannel: RTCDataChannel | null; + setRpcHidUnreliableNonOrderedChannel: (channel: RTCDataChannel) => void; + peerConnectionState: RTCPeerConnectionState | null; setPeerConnectionState: (state: RTCPeerConnectionState) => void; @@ -158,11 +164,17 @@ export const useRTCStore = create(set => ({ setRpcDataChannel: (channel: RTCDataChannel) => set({ rpcDataChannel: channel }), rpcHidProtocolVersion: null, - setRpcHidProtocolVersion: (version: number) => set({ rpcHidProtocolVersion: version }), + setRpcHidProtocolVersion: (version: number | null) => set({ rpcHidProtocolVersion: version }), rpcHidChannel: null, setRpcHidChannel: (channel: RTCDataChannel) => set({ rpcHidChannel: channel }), + rpcHidUnreliableChannel: null, + setRpcHidUnreliableChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableChannel: channel }), + + rpcHidUnreliableNonOrderedChannel: null, + setRpcHidUnreliableNonOrderedChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableNonOrderedChannel: channel }), + transceiver: null, setTransceiver: (transceiver: RTCRtpTransceiver) => set({ transceiver }), diff --git a/ui/src/hooks/useHidRpc.ts b/ui/src/hooks/useHidRpc.ts index ea0c7112..4546b625 100644 --- a/ui/src/hooks/useHidRpc.ts +++ b/ui/src/hooks/useHidRpc.ts @@ -6,6 +6,7 @@ import { HID_RPC_VERSION, HandshakeMessage, KeyboardReportMessage, + KeypressKeepAliveMessage, KeypressReportMessage, MouseReportMessage, PointerReportMessage, @@ -13,38 +14,93 @@ import { unmarshalHidRpcMessage, } from "./hidRpc"; +const KEEPALIVE_MESSAGE = new KeypressKeepAliveMessage(); + +interface sendMessageParams { + ignoreHandshakeState?: boolean; + useUnreliableChannel?: boolean; + requireOrdered?: boolean; +} + export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { - const { rpcHidChannel, setRpcHidProtocolVersion, rpcHidProtocolVersion } = useRTCStore(); + const { + rpcHidChannel, + rpcHidUnreliableChannel, + rpcHidUnreliableNonOrderedChannel, + setRpcHidProtocolVersion, + rpcHidProtocolVersion, + } = useRTCStore(); + const rpcHidReady = useMemo(() => { return rpcHidChannel?.readyState === "open" && rpcHidProtocolVersion !== null; }, [rpcHidChannel, rpcHidProtocolVersion]); + const rpcHidUnreliableReady = useMemo(() => { + return ( + rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null + ); + }, [rpcHidUnreliableChannel, rpcHidProtocolVersion]); + + const rpcHidUnreliableNonOrderedReady = useMemo(() => { + return ( + rpcHidUnreliableNonOrderedChannel?.readyState === "open" && + rpcHidProtocolVersion !== null + ); + }, [rpcHidUnreliableNonOrderedChannel, rpcHidProtocolVersion]); + const rpcHidStatus = useMemo(() => { if (!rpcHidChannel) return "N/A"; if (rpcHidChannel.readyState !== "open") return rpcHidChannel.readyState; if (!rpcHidProtocolVersion) return "handshaking"; - return `ready (v${rpcHidProtocolVersion})`; - }, [rpcHidChannel, rpcHidProtocolVersion]); + return `ready (v${rpcHidProtocolVersion}${rpcHidUnreliableReady ? "+u" : ""})`; + }, [rpcHidChannel, rpcHidUnreliableReady, rpcHidProtocolVersion]); - const sendMessage = useCallback((message: RpcMessage, ignoreHandshakeState = false) => { - if (rpcHidChannel?.readyState !== "open") return; - if (!rpcHidReady && !ignoreHandshakeState) return; + const sendMessage = useCallback( + ( + message: RpcMessage, + { + ignoreHandshakeState, + useUnreliableChannel, + requireOrdered = true, + }: sendMessageParams = {}, + ) => { + if (rpcHidChannel?.readyState !== "open") return; + if (!rpcHidReady && !ignoreHandshakeState) return; - let data: Uint8Array | undefined; - try { - data = message.marshal(); - } catch (e) { - console.error("Failed to send HID RPC message", e); - } - if (!data) return; + let data: Uint8Array | undefined; + try { + data = message.marshal(); + } catch (e) { + console.error("Failed to send HID RPC message", e); + } + if (!data) return; - rpcHidChannel?.send(data as unknown as ArrayBuffer); - }, [rpcHidChannel, rpcHidReady]); + if (useUnreliableChannel) { + if (requireOrdered && rpcHidUnreliableReady) { + rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer); + } else if (!requireOrdered && rpcHidUnreliableNonOrderedReady) { + rpcHidUnreliableNonOrderedChannel?.send(data as unknown as ArrayBuffer); + } + return; + } + + rpcHidChannel?.send(data as unknown as ArrayBuffer); + }, + [ + rpcHidChannel, + rpcHidUnreliableChannel, + rpcHidUnreliableNonOrderedChannel, + rpcHidReady, + rpcHidUnreliableReady, + rpcHidUnreliableNonOrderedReady, + ], + ); const reportKeyboardEvent = useCallback( (keys: number[], modifier: number) => { sendMessage(new KeyboardReportMessage(keys, modifier)); - }, [sendMessage], + }, + [sendMessage], ); const reportKeypressEvent = useCallback( @@ -56,7 +112,9 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { const reportAbsMouseEvent = useCallback( (x: number, y: number, buttons: number) => { - sendMessage(new PointerReportMessage(x, y, buttons)); + sendMessage(new PointerReportMessage(x, y, buttons), { + useUnreliableChannel: true, + }); }, [sendMessage], ); @@ -68,29 +126,36 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { [sendMessage], ); + const reportKeypressKeepAlive = useCallback(() => { + sendMessage(KEEPALIVE_MESSAGE); + }, [sendMessage]); + const sendHandshake = useCallback(() => { if (rpcHidProtocolVersion) return; if (!rpcHidChannel) return; - sendMessage(new HandshakeMessage(HID_RPC_VERSION), true); + sendMessage(new HandshakeMessage(HID_RPC_VERSION), { ignoreHandshakeState: true }); }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage]); - const handleHandshake = useCallback((message: HandshakeMessage) => { - if (!message.version) { - console.error("Received handshake message without version", message); - return; - } + const handleHandshake = useCallback( + (message: HandshakeMessage) => { + if (!message.version) { + console.error("Received handshake message without version", message); + return; + } - if (message.version > HID_RPC_VERSION) { - // we assume that the UI is always using the latest version of the HID RPC protocol - // so we can't support this - // TODO: use capabilities to determine rather than version number - console.error("Server is using a newer HID RPC version than the client", message); - return; - } + if (message.version > HID_RPC_VERSION) { + // we assume that the UI is always using the latest version of the HID RPC protocol + // so we can't support this + // TODO: use capabilities to determine rather than version number + console.error("Server is using a newer HID RPC version than the client", message); + return; + } - setRpcHidProtocolVersion(message.version); - }, [setRpcHidProtocolVersion]); + setRpcHidProtocolVersion(message.version); + }, + [setRpcHidProtocolVersion], + ); useEffect(() => { if (!rpcHidChannel) return; @@ -123,26 +188,39 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { onHidRpcMessage?.(message); }; + const openHandler = () => { + console.info("HID RPC channel opened"); + sendHandshake(); + }; + + const closeHandler = () => { + console.info("HID RPC channel closed"); + setRpcHidProtocolVersion(null); + }; + rpcHidChannel.addEventListener("message", messageHandler); + rpcHidChannel.addEventListener("close", closeHandler); + rpcHidChannel.addEventListener("open", openHandler); return () => { rpcHidChannel.removeEventListener("message", messageHandler); + rpcHidChannel.removeEventListener("close", closeHandler); + rpcHidChannel.removeEventListener("open", openHandler); }; - }, - [ - rpcHidChannel, - onHidRpcMessage, - setRpcHidProtocolVersion, - sendHandshake, - handleHandshake, - ], - ); + }, [ + rpcHidChannel, + onHidRpcMessage, + setRpcHidProtocolVersion, + sendHandshake, + handleHandshake, + ]); return { reportKeyboardEvent, reportKeypressEvent, reportAbsMouseEvent, reportRelMouseEvent, + reportKeypressKeepAlive, rpcHidProtocolVersion, rpcHidReady, rpcHidStatus, diff --git a/ui/src/hooks/useKeyboard.ts b/ui/src/hooks/useKeyboard.ts index 787df9a9..969166a7 100644 --- a/ui/src/hooks/useKeyboard.ts +++ b/ui/src/hooks/useKeyboard.ts @@ -1,6 +1,12 @@ -import { useCallback } from "react"; +import { useCallback, useRef } from "react"; -import { hidErrorRollOver, hidKeyBufferSize, KeysDownState, useHidStore, useRTCStore } from "@/hooks/stores"; +import { + hidErrorRollOver, + hidKeyBufferSize, + KeysDownState, + useHidStore, + useRTCStore, +} from "@/hooks/stores"; import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; import { useHidRpc } from "@/hooks/useHidRpc"; import { KeyboardLedStateMessage, KeysDownStateMessage } from "@/hooks/hidRpc"; @@ -11,23 +17,27 @@ export default function useKeyboard() { const { rpcDataChannel } = useRTCStore(); const { keysDownState, setKeysDownState, setKeyboardLedState } = useHidStore(); + // Keepalive timer management + const keepAliveTimerRef = useRef(null); + // INTRODUCTION: The earlier version of the JetKVM device shipped with all keyboard state // being tracked on the browser/client-side. When adding the keyPressReport API to the // device-side code, we have to still support the situation where the browser/client-side code // is running on the cloud against a device that has not been updated yet and thus does not // support the keyPressReport API. In that case, we need to handle the key presses locally // and send the full state to the device, so it can behave like a real USB HID keyboard. - // This flag indicates whether the keyPressReport API is available on the device which is + // This flag indicates whether the keyPressReport API is available on the device which is // dynamically set when the device responds to the first key press event or reports its - // keysDownState when queried since the keyPressReport was introduced together with the + // keysDownState when queried since the keyPressReport was introduced together with the // getKeysDownState API. // HidRPC is a binary format for exchanging keyboard and mouse events const { reportKeyboardEvent: sendKeyboardEventHidRpc, reportKeypressEvent: sendKeypressEventHidRpc, + reportKeypressKeepAlive: sendKeypressKeepAliveHidRpc, rpcHidReady, - } = useHidRpc((message) => { + } = useHidRpc(message => { switch (message.constructor) { case KeysDownStateMessage: setKeysDownState((message as KeysDownStateMessage).keysDownState); @@ -48,7 +58,9 @@ export default function useKeyboard() { async (state: KeysDownState) => { if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; - console.debug(`Send keyboardReport keys: ${state.keys}, modifier: ${state.modifier}`); + console.debug( + `Send keyboardReport keys: ${state.keys}, modifier: ${state.modifier}`, + ); if (rpcHidReady) { console.debug("Sending keyboard report via HidRPC"); @@ -56,42 +68,33 @@ export default function useKeyboard() { return; } - send("keyboardReport", { keys: state.keys, modifier: state.modifier }, (resp: JsonRpcResponse) => { - if ("error" in resp) { - console.error(`Failed to send keyboard report ${state}`, resp.error); - } - }); + send( + "keyboardReport", + { keys: state.keys, modifier: state.modifier }, + (resp: JsonRpcResponse) => { + if ("error" in resp) { + console.error(`Failed to send keyboard report ${state}`, resp.error); + } + }, + ); }, - [ - rpcDataChannel?.readyState, - rpcHidReady, - send, - sendKeyboardEventHidRpc, - ], + [rpcDataChannel?.readyState, rpcHidReady, send, sendKeyboardEventHidRpc], ); - // resetKeyboardState is used to reset the keyboard state to no keys pressed and no modifiers. - // This is useful for macros and when the browser loses focus to ensure that the keyboard state - // is clean. - const resetKeyboardState = useCallback( - async () => { - // Reset the keys buffer to zeros and the modifier state to zero - keysDownState.keys.length = hidKeyBufferSize; - keysDownState.keys.fill(0); - keysDownState.modifier = 0; - sendKeyboardEvent(keysDownState); - }, [keysDownState, sendKeyboardEvent]); - // executeMacro is used to execute a macro consisting of multiple steps. // Each step can have multiple keys, multiple modifiers and a delay. // The keys and modifiers are pressed together and held for the delay duration. // After the delay, the keys and modifiers are released and the next step is executed. // If a step has no keys or modifiers, it is treated as a delay-only step. // A small pause is added between steps to ensure that the device can process the events. - const executeMacro = async (steps: { keys: string[] | null; modifiers: string[] | null; delay: number }[]) => { + const executeMacro = async ( + steps: { keys: string[] | null; modifiers: string[] | null; delay: number }[], + ) => { for (const [index, step] of steps.entries()) { const keyValues = (step.keys || []).map(key => keys[key]).filter(Boolean); - const modifierMask: number = (step.modifiers || []).map(mod => modifiers[mod]).reduce((acc, val) => acc + val, 0); + const modifierMask: number = (step.modifiers || []) + .map(mod => modifiers[mod]) + .reduce((acc, val) => acc + val, 0); // If the step has keys and/or modifiers, press them and hold for the delay if (keyValues.length > 0 || modifierMask > 0) { @@ -111,12 +114,60 @@ export default function useKeyboard() { } }; + const KEEPALIVE_INTERVAL = 50; + + const cancelKeepAlive = useCallback(() => { + if (keepAliveTimerRef.current) { + clearInterval(keepAliveTimerRef.current); + keepAliveTimerRef.current = null; + } + }, []); + + const scheduleKeepAlive = useCallback(() => { + // Clear existing timer if it exists + if (keepAliveTimerRef.current) { + clearInterval(keepAliveTimerRef.current); + } + + keepAliveTimerRef.current = setInterval(() => { + sendKeypressKeepAliveHidRpc(); + }, KEEPALIVE_INTERVAL); + }, [sendKeypressKeepAliveHidRpc]); + + // resetKeyboardState is used to reset the keyboard state to no keys pressed and no modifiers. + // This is useful for macros and when the browser loses focus to ensure that the keyboard state + // is clean. + const resetKeyboardState = useCallback(async () => { + // Cancel keepalive since we're resetting the keyboard state + cancelKeepAlive(); + + // Reset the keys buffer to zeros and the modifier state to zero + keysDownState.keys.length = hidKeyBufferSize; + keysDownState.keys.fill(0); + keysDownState.modifier = 0; + sendKeyboardEvent(keysDownState); + }, [keysDownState, sendKeyboardEvent, cancelKeepAlive]); + // handleKeyPress is used to handle a key press or release event. // This function handle both key press and key release events. // It checks if the keyPressReport API is available and sends the key press event. // If the keyPressReport API is not available, it simulates the device-side key // handling for legacy devices and updates the keysDownState accordingly. // It then sends the full keyboard state to the device. + + const sendKeypress = useCallback( + (key: number, press: boolean) => { + cancelKeepAlive(); + + sendKeypressEventHidRpc(key, press); + + if (press) { + scheduleKeepAlive(); + } + }, + [sendKeypressEventHidRpc, scheduleKeepAlive, cancelKeepAlive], + ); + const handleKeyPress = useCallback( async (key: number, press: boolean) => { if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; @@ -129,10 +180,14 @@ export default function useKeyboard() { // Older device version doesn't support this API, so we will switch to local key handling // In that case we will switch to local key handling and update the keysDownState // in client/browser-side code using simulateDeviceSideKeyHandlingForLegacyDevices. - sendKeypressEventHidRpc(key, press); + sendKeypress(key, press); } else { // if the keyPress api is not available, we need to handle the key locally - const downState = simulateDeviceSideKeyHandlingForLegacyDevices(keysDownState, key, press); + const downState = simulateDeviceSideKeyHandlingForLegacyDevices( + keysDownState, + key, + press, + ); sendKeyboardEvent(downState); // then we send the full state // if we just sent ErrorRollOver, reset to empty state @@ -147,12 +202,16 @@ export default function useKeyboard() { resetKeyboardState, rpcDataChannel?.readyState, sendKeyboardEvent, - sendKeypressEventHidRpc, + sendKeypress, ], ); // IMPORTANT: See the keyPressReportApiAvailable comment above for the reason this exists - function simulateDeviceSideKeyHandlingForLegacyDevices(state: KeysDownState, key: number, press: boolean): KeysDownState { + function simulateDeviceSideKeyHandlingForLegacyDevices( + state: KeysDownState, + key: number, + press: boolean, + ): KeysDownState { // IMPORTANT: This code parallels the logic in the kernel's hid-gadget driver // for handling key presses and releases. It ensures that the USB gadget // behaves similarly to a real USB HID keyboard. This logic is paralleled @@ -164,7 +223,7 @@ export default function useKeyboard() { if (modifierMask !== 0) { // If the key is a modifier key, we update the keyboardModifier state // by setting or clearing the corresponding bit in the modifier byte. - // This allows us to track the state of dynamic modifier keys like + // This allows us to track the state of dynamic modifier keys like // Shift, Control, Alt, and Super. if (press) { modifiers |= modifierMask; @@ -181,7 +240,7 @@ export default function useKeyboard() { // and if we find a zero byte, we can place the key there (if press is true) if (keys[i] === key || keys[i] === 0) { if (press) { - keys[i] = key // overwrites the zero byte or the same key if already pressed + keys[i] = key; // overwrites the zero byte or the same key if already pressed } else { // we are releasing the key, remove it from the buffer if (keys[i] !== 0) { @@ -197,18 +256,25 @@ export default function useKeyboard() { // If we reach here it means we didn't find an empty slot or the key in the buffer if (overrun) { if (press) { - console.warn(`keyboard buffer overflow current keys ${keys}, key: ${key} not added`); + console.warn( + `keyboard buffer overflow current keys ${keys}, key: ${key} not added`, + ); // Fill all key slots with ErrorRollOver (0x01) to indicate overflow keys.length = hidKeyBufferSize; keys.fill(hidErrorRollOver); } else { // If we are releasing a key, and we didn't find it in a slot, who cares? - console.debug(`key ${key} not found in buffer, nothing to release`) + console.debug(`key ${key} not found in buffer, nothing to release`); } } } return { modifier: modifiers, keys }; } - return { handleKeyPress, resetKeyboardState, executeMacro }; + // Cleanup function to cancel keepalive timer + const cleanup = useCallback(() => { + cancelKeepAlive(); + }, [cancelKeepAlive]); + + return { handleKeyPress, resetKeyboardState, executeMacro, cleanup }; } diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx index 4318447e..8a363199 100644 --- a/ui/src/routes/devices.$id.tsx +++ b/ui/src/routes/devices.$id.tsx @@ -136,6 +136,8 @@ export default function KvmIdRoute() { rpcDataChannel, setTransceiver, setRpcHidChannel, + setRpcHidUnreliableNonOrderedChannel, + setRpcHidUnreliableChannel, } = useRTCStore(); const location = useLocation(); @@ -488,6 +490,24 @@ export default function KvmIdRoute() { setRpcHidChannel(rpcHidChannel); }; + const rpcHidUnreliableChannel = pc.createDataChannel("hidrpc-unreliable-ordered", { + ordered: true, + maxRetransmits: 0, + }); + rpcHidUnreliableChannel.binaryType = "arraybuffer"; + rpcHidUnreliableChannel.onopen = () => { + setRpcHidUnreliableChannel(rpcHidUnreliableChannel); + }; + + const rpcHidUnreliableNonOrderedChannel = pc.createDataChannel("hidrpc-unreliable-nonordered", { + ordered: false, + maxRetransmits: 0, + }); + rpcHidUnreliableNonOrderedChannel.binaryType = "arraybuffer"; + rpcHidUnreliableNonOrderedChannel.onopen = () => { + setRpcHidUnreliableNonOrderedChannel(rpcHidUnreliableNonOrderedChannel); + }; + setPeerConnection(pc); }, [ cleanupAndStopReconnecting, @@ -499,6 +519,8 @@ export default function KvmIdRoute() { setPeerConnectionState, setRpcDataChannel, setRpcHidChannel, + setRpcHidUnreliableNonOrderedChannel, + setRpcHidUnreliableChannel, setTransceiver, ]); diff --git a/usb.go b/usb.go index 131cd517..99287a30 100644 --- a/usb.go +++ b/usb.go @@ -33,7 +33,13 @@ func initUsbGadget() { gadget.SetOnKeysDownChange(func(state usbgadget.KeysDownState) { if currentSession != nil { - currentSession.reportHidRPCKeysDownState(state) + currentSession.enqueueKeysDownState(state) + } + }) + + gadget.SetOnKeepAliveReset(func() { + if currentSession != nil { + currentSession.resetKeepAliveTime() } }) @@ -43,11 +49,11 @@ func initUsbGadget() { } } -func rpcKeyboardReport(modifier byte, keys []byte) (usbgadget.KeysDownState, error) { +func rpcKeyboardReport(modifier byte, keys []byte) error { return gadget.KeyboardReport(modifier, keys) } -func rpcKeypressReport(key byte, press bool) (usbgadget.KeysDownState, error) { +func rpcKeypressReport(key byte, press bool) error { return gadget.KeypressReport(key, press) } diff --git a/webrtc.go b/webrtc.go index c3d0dc1b..d0d5f7ac 100644 --- a/webrtc.go +++ b/webrtc.go @@ -7,12 +7,14 @@ import ( "net" "strings" "sync" + "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" "github.com/gin-gonic/gin" "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/usbgadget" "github.com/pion/webrtc/v4" "github.com/rs/zerolog" ) @@ -27,9 +29,26 @@ type Session struct { rpcQueue chan webrtc.DataChannelMessage - hidRPCAvailable bool - hidQueueLock sync.Mutex - hidQueue []chan webrtc.DataChannelMessage + hidRPCAvailable bool + lastKeepAliveArrivalTime time.Time // Track when last keep-alive packet arrived + lastTimerResetTime time.Time // Track when auto-release timer was last reset + keepAliveJitterLock sync.Mutex // Protect jitter compensation timing state + hidQueueLock sync.Mutex + hidQueue []chan hidQueueMessage + + keysDownStateQueue chan usbgadget.KeysDownState +} + +func (s *Session) resetKeepAliveTime() { + s.keepAliveJitterLock.Lock() + defer s.keepAliveJitterLock.Unlock() + s.lastKeepAliveArrivalTime = time.Time{} // Reset keep-alive timing tracking + s.lastTimerResetTime = time.Time{} // Reset auto-release timer tracking +} + +type hidQueueMessage struct { + webrtc.DataChannelMessage + channel string } type SessionConfig struct { @@ -78,16 +97,85 @@ func (s *Session) initQueues() { s.hidQueueLock.Lock() defer s.hidQueueLock.Unlock() - s.hidQueue = make([]chan webrtc.DataChannelMessage, 0) + s.hidQueue = make([]chan hidQueueMessage, 0) for i := 0; i < 4; i++ { - q := make(chan webrtc.DataChannelMessage, 256) + q := make(chan hidQueueMessage, 256) s.hidQueue = append(s.hidQueue, q) } } func (s *Session) handleQueues(index int) { for msg := range s.hidQueue[index] { - onHidMessage(msg.Data, s) + onHidMessage(msg, s) + } +} + +const keysDownStateQueueSize = 256 + +func (s *Session) initKeysDownStateQueue() { + // serialise outbound key state reports so unreliable links can't stall input handling + s.keysDownStateQueue = make(chan usbgadget.KeysDownState, keysDownStateQueueSize) + go s.handleKeysDownStateQueue() +} + +func (s *Session) handleKeysDownStateQueue() { + for state := range s.keysDownStateQueue { + s.reportHidRPCKeysDownState(state) + } +} + +func (s *Session) enqueueKeysDownState(state usbgadget.KeysDownState) { + if s == nil || s.keysDownStateQueue == nil { + return + } + + select { + case s.keysDownStateQueue <- state: + default: + hidRPCLogger.Warn().Msg("dropping keys down state update; queue full") + } +} + +func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, channel string) func(msg webrtc.DataChannelMessage) { + return func(msg webrtc.DataChannelMessage) { + l := scopedLogger.With(). + Str("channel", channel). + Int("length", len(msg.Data)). + Logger() + // only log data if the log level is debug or lower + if scopedLogger.GetLevel() > zerolog.DebugLevel { + l = l.With().Str("data", string(msg.Data)).Logger() + } + + if msg.IsString { + l.Warn().Msg("received string data in HID RPC message handler") + return + } + + if len(msg.Data) < 1 { + l.Warn().Msg("received empty data in HID RPC message handler") + return + } + + l.Trace().Msg("received data in HID RPC message handler") + + // Enqueue to ensure ordered processing + queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) + if queueIndex >= len(session.hidQueue) || queueIndex < 0 { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") + queueIndex = 3 + } + + queue := session.hidQueue[queueIndex] + if queue != nil { + queue <- hidQueueMessage{ + DataChannelMessage: msg, + channel: channel, + } + } else { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") + return + } } } @@ -133,6 +221,7 @@ func newSession(config SessionConfig) (*Session, error) { session := &Session{peerConnection: peerConnection} session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) session.initQueues() + session.initKeysDownStateQueue() go func() { for msg := range session.rpcQueue { @@ -157,40 +246,12 @@ func newSession(config SessionConfig) (*Session, error) { switch d.Label() { case "hidrpc": session.HidChannel = d - d.OnMessage(func(msg webrtc.DataChannelMessage) { - l := scopedLogger.With().Int("length", len(msg.Data)).Logger() - // only log data if the log level is debug or lower - if scopedLogger.GetLevel() > zerolog.DebugLevel { - l = l.With().Str("data", string(msg.Data)).Logger() - } - - if msg.IsString { - l.Warn().Msg("received string data in HID RPC message handler") - return - } - - if len(msg.Data) < 1 { - l.Warn().Msg("received empty data in HID RPC message handler") - return - } - - l.Trace().Msg("received data in HID RPC message handler") - - // Enqueue to ensure ordered processing - queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) - if queueIndex >= len(session.hidQueue) || queueIndex < 0 { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") - queueIndex = 3 - } - - queue := session.hidQueue[queueIndex] - if queue != nil { - queue <- msg - } else { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") - return - } - }) + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc")) + // we won't send anything over the unreliable channels + case "hidrpc-unreliable-ordered": + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-ordered")) + case "hidrpc-unreliable-nonordered": + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-nonordered")) case "rpc": session.RPCChannel = d d.OnMessage(func(msg webrtc.DataChannelMessage) {