diff --git a/Makefile b/Makefile index 6ca1dbb8..c533b96c 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,3 @@ -# --- JetKVM Audio/Toolchain Dev Environment Setup --- .PHONY: setup_toolchain build_audio_deps dev_env lint lint-go lint-ui lint-fix lint-go-fix lint-ui-fix ui-lint # Clone the rv1106-system toolchain to $HOME/.jetkvm/rv1106-system @@ -23,8 +22,8 @@ BRANCH ?= $(shell git rev-parse --abbrev-ref HEAD) BUILDDATE ?= $(shell date -u +%FT%T%z) BUILDTS ?= $(shell date -u +%s) REVISION ?= $(shell git rev-parse HEAD) -VERSION_DEV ?= 0.4.7-dev$(shell date +%Y%m%d%H%M) -VERSION ?= 0.4.6 +VERSION_DEV := 0.4.8-dev$(shell date +%Y%m%d%H%M) +VERSION := 0.4.7 # Audio library versions ALSA_VERSION ?= 1.2.14 @@ -127,7 +126,7 @@ frontend: -exec sh -c 'gzip -9 -kfv {}' \; dev_release: frontend build_dev - @echo "Uploading release..." + @echo "Uploading release... $(VERSION_DEV)" @shasum -a 256 bin/jetkvm_app | cut -d ' ' -f 1 > bin/jetkvm_app.sha256 rclone copyto bin/jetkvm_app r2://jetkvm-update/app/$(VERSION_DEV)/jetkvm_app rclone copyto bin/jetkvm_app.sha256 r2://jetkvm-update/app/$(VERSION_DEV)/jetkvm_app.sha256 diff --git a/cloud.go b/cloud.go index cec749e4..fb138508 100644 --- a/cloud.go +++ b/cloud.go @@ -475,6 +475,10 @@ func handleSessionRequest( cloudLogger.Info().Interface("session", session).Msg("new session accepted") cloudLogger.Trace().Interface("session", session).Msg("new session accepted") + + // Cancel any ongoing keyboard macro when session changes + cancelKeyboardMacro() + currentSession = session _ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd}) return nil diff --git a/display.go b/display.go index 15d3ffcf..8cd632c7 100644 --- a/display.go +++ b/display.go @@ -64,11 +64,11 @@ func lvObjSetOpacity(objName string, opacity int) (*CtrlResponse, error) { // no } func lvObjFadeIn(objName string, duration uint32) (*CtrlResponse, error) { - return CallCtrlAction("lv_obj_fade_in", map[string]any{"obj": objName, "time": duration}) + return CallCtrlAction("lv_obj_fade_in", map[string]any{"obj": objName, "duration": duration}) } func lvObjFadeOut(objName string, duration uint32) (*CtrlResponse, error) { - return CallCtrlAction("lv_obj_fade_out", map[string]any{"obj": objName, "time": duration}) + return CallCtrlAction("lv_obj_fade_out", map[string]any{"obj": objName, "duration": duration}) } func lvLabelSetText(objName string, text string) (*CtrlResponse, error) { diff --git a/go.mod b/go.mod index 962c3a1b..d07ba239 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/gin-contrib/logger v1.2.6 github.com/gin-gonic/gin v1.10.1 github.com/go-co-op/gocron/v2 v2.16.5 + github.com/google/flatbuffers v25.2.10+incompatible github.com/google/uuid v1.6.0 github.com/guregu/null/v6 v6.0.0 github.com/gwatts/rootcerts v0.0.0-20250901182336-dc5ae18bd79f @@ -23,6 +24,7 @@ require ( github.com/prometheus/common v0.66.0 github.com/prometheus/procfs v0.17.0 github.com/psanford/httpreadat v0.1.0 + github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/sourcegraph/tf-dag v0.2.2-0.20250131204052-3e8ff1477b4f github.com/stretchr/testify v1.11.1 diff --git a/go.sum b/go.sum index e19fa9e6..57576a3a 100644 --- a/go.sum +++ b/go.sum @@ -53,6 +53,8 @@ github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAu github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= +github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -152,6 +154,7 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= diff --git a/hidrpc.go b/hidrpc.go index 74fe687f..ebe03daa 100644 --- a/hidrpc.go +++ b/hidrpc.go @@ -1,11 +1,14 @@ package kvm import ( + "errors" "fmt" + "io" "time" "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/usbgadget" + "github.com/rs/zerolog" ) func handleHidRPCMessage(message hidrpc.Message, session *Session) { @@ -24,11 +27,19 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { } session.hidRPCAvailable = true case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: - keysDownState, err := handleHidRPCKeyboardInput(message) - if keysDownState != nil { - session.reportHidRPCKeysDownState(*keysDownState) + rpcErr = handleHidRPCKeyboardInput(message) + case hidrpc.TypeKeyboardMacroReport: + keyboardMacroReport, err := message.KeyboardMacroReport() + if err != nil { + logger.Warn().Err(err).Msg("failed to get keyboard macro report") + return } - rpcErr = err + rpcErr = rpcExecuteKeyboardMacro(keyboardMacroReport.Steps) + case hidrpc.TypeCancelKeyboardMacroReport: + rpcCancelKeyboardMacro() + return + case hidrpc.TypeKeypressKeepAliveReport: + rpcErr = handleHidRPCKeypressKeepAlive(session) case hidrpc.TypePointerReport: pointerReport, err := message.PointerReport() if err != nil { @@ -52,8 +63,13 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { } } -func onHidMessage(data []byte, session *Session) { - scopedLogger := hidRPCLogger.With().Bytes("data", data).Logger() +func onHidMessage(msg hidQueueMessage, session *Session) { + data := msg.Data + + scopedLogger := hidRPCLogger.With(). + Str("channel", msg.channel). + Bytes("data", data). + Logger() scopedLogger.Debug().Msg("HID RPC message received") if len(data) < 1 { @@ -68,7 +84,9 @@ func onHidMessage(data []byte, session *Session) { return } - scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger() + if scopedLogger.GetLevel() <= zerolog.DebugLevel { + scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger() + } t := time.Now() @@ -85,27 +103,88 @@ func onHidMessage(data []byte, session *Session) { } } -func handleHidRPCKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState, error) { +// Tunables +// Keep in mind +// macOS default: 15 * 15 = 225ms https://discussions.apple.com/thread/1316947?sortBy=rank +// Linux default: 250ms https://man.archlinux.org/man/kbdrate.8.en +// Windows default: 1s `HKEY_CURRENT_USER\Control Panel\Accessibility\Keyboard Response\AutoRepeatDelay` + +const expectedRate = 50 * time.Millisecond // expected keepalive interval +const maxLateness = 50 * time.Millisecond // max jitter we'll tolerate OR jitter budget +const baseExtension = expectedRate + maxLateness // 100ms extension on perfect tick + +const maxStaleness = 225 * time.Millisecond // discard ancient packets outright + +func handleHidRPCKeypressKeepAlive(session *Session) error { + session.keepAliveJitterLock.Lock() + defer session.keepAliveJitterLock.Unlock() + + now := time.Now() + + // 1) Staleness guard: ensures packets that arrive far beyond the life of a valid key hold + // (e.g. after a network stall, retransmit burst, or machine sleep) are ignored outright. + // This prevents “zombie” keepalives from reviving a key that should already be released. + if !session.lastTimerResetTime.IsZero() && now.Sub(session.lastTimerResetTime) > maxStaleness { + return nil + } + + validTick := true + timerExtension := baseExtension + + if !session.lastKeepAliveArrivalTime.IsZero() { + timeSinceLastTick := now.Sub(session.lastKeepAliveArrivalTime) + lateness := timeSinceLastTick - expectedRate + + if lateness > 0 { + if lateness <= maxLateness { + // --- Small lateness (within jitterBudget) --- + // This is normal jitter (e.g., Wi-Fi contention). + // We still accept the tick, but *reduce the extension* + // so that the total hold time stays aligned with REAL client side intent. + timerExtension -= lateness + } else { + // --- Large lateness (beyond jitterBudget) --- + // This is likely a retransmit stall or ordering delay. + // We reject the tick entirely and DO NOT extend, + // so the auto-release still fires on time. + validTick = false + } + } + } + + if !validTick { + return nil + } + // Only valid ticks update our state and extend the timer. + session.lastKeepAliveArrivalTime = now + session.lastTimerResetTime = now + if gadget != nil { + gadget.DelayAutoReleaseWithDuration(timerExtension) + } + + // On a miss: do not advance any state — keeps baseline stable. + return nil +} + +func handleHidRPCKeyboardInput(message hidrpc.Message) error { switch message.Type() { case hidrpc.TypeKeypressReport: keypressReport, err := message.KeypressReport() if err != nil { logger.Warn().Err(err).Msg("failed to get keypress report") - return nil, err + return err } - keysDownState, rpcError := rpcKeypressReport(keypressReport.Key, keypressReport.Press) - return &keysDownState, rpcError + return rpcKeypressReport(keypressReport.Key, keypressReport.Press) case hidrpc.TypeKeyboardReport: keyboardReport, err := message.KeyboardReport() if err != nil { logger.Warn().Err(err).Msg("failed to get keyboard report") - return nil, err + return err } - keysDownState, rpcError := rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) - return &keysDownState, rpcError + return rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) } - return nil, fmt.Errorf("unknown HID RPC message type: %d", message.Type()) + return fmt.Errorf("unknown HID RPC message type: %d", message.Type()) } func reportHidRPC(params any, session *Session) { @@ -115,7 +194,10 @@ func reportHidRPC(params any, session *Session) { } if !session.hidRPCAvailable || session.HidChannel == nil { - logger.Warn().Msg("HID RPC is not available, skipping reportHidRPC") + logger.Warn(). + Bool("hidRPCAvailable", session.hidRPCAvailable). + Bool("HidChannel", session.HidChannel != nil). + Msg("HID RPC is not available, skipping reportHidRPC") return } @@ -128,6 +210,8 @@ func reportHidRPC(params any, session *Session) { message, err = hidrpc.NewKeyboardLedMessage(params).Marshal() case usbgadget.KeysDownState: message, err = hidrpc.NewKeydownStateMessage(params).Marshal() + case hidrpc.KeyboardMacroState: + message, err = hidrpc.NewKeyboardMacroStateMessage(params.State, params.IsPaste).Marshal() default: err = fmt.Errorf("unknown HID RPC message type: %T", params) } @@ -143,6 +227,10 @@ func reportHidRPC(params any, session *Session) { } if err := session.HidChannel.Send(message); err != nil { + if errors.Is(err, io.ErrClosedPipe) { + logger.Debug().Err(err).Msg("HID RPC channel closed, skipping reportHidRPC") + return + } logger.Warn().Err(err).Msg("failed to send HID RPC message") } } @@ -156,7 +244,16 @@ func (s *Session) reportHidRPCKeyboardLedState(state usbgadget.KeyboardState) { func (s *Session) reportHidRPCKeysDownState(state usbgadget.KeysDownState) { if !s.hidRPCAvailable { + usbLogger.Debug().Interface("state", state).Msg("reporting keys down state") writeJSONRPCEvent("keysDownState", state, s) } + usbLogger.Debug().Interface("state", state).Msg("reporting keys down state, calling reportHidRPC") + reportHidRPC(state, s) +} + +func (s *Session) reportHidRPCKeyboardMacroState(state hidrpc.KeyboardMacroState) { + if !s.hidRPCAvailable { + writeJSONRPCEvent("keyboardMacroState", state, s) + } reportHidRPC(state, s) } diff --git a/internal/hidrpc/hidrpc.go b/internal/hidrpc/hidrpc.go index e9c8c24d..7313e3b5 100644 --- a/internal/hidrpc/hidrpc.go +++ b/internal/hidrpc/hidrpc.go @@ -10,14 +10,18 @@ import ( type MessageType byte const ( - TypeHandshake MessageType = 0x01 - TypeKeyboardReport MessageType = 0x02 - TypePointerReport MessageType = 0x03 - TypeWheelReport MessageType = 0x04 - TypeKeypressReport MessageType = 0x05 - TypeMouseReport MessageType = 0x06 - TypeKeyboardLedState MessageType = 0x32 - TypeKeydownState MessageType = 0x33 + TypeHandshake MessageType = 0x01 + TypeKeyboardReport MessageType = 0x02 + TypePointerReport MessageType = 0x03 + TypeWheelReport MessageType = 0x04 + TypeKeypressReport MessageType = 0x05 + TypeKeypressKeepAliveReport MessageType = 0x09 + TypeMouseReport MessageType = 0x06 + TypeKeyboardMacroReport MessageType = 0x07 + TypeCancelKeyboardMacroReport MessageType = 0x08 + TypeKeyboardLedState MessageType = 0x32 + TypeKeydownState MessageType = 0x33 + TypeKeyboardMacroState MessageType = 0x34 ) const ( @@ -29,10 +33,13 @@ func GetQueueIndex(messageType MessageType) int { switch messageType { case TypeHandshake: return 0 - case TypeKeyboardReport, TypeKeypressReport, TypeKeyboardLedState, TypeKeydownState: + case TypeKeyboardReport, TypeKeypressReport, TypeKeyboardMacroReport, TypeKeyboardLedState, TypeKeydownState, TypeKeyboardMacroState: return 1 case TypePointerReport, TypeMouseReport, TypeWheelReport: return 2 + // we don't want to block the queue for this message + case TypeCancelKeyboardMacroReport: + return 3 default: return 3 } @@ -98,3 +105,19 @@ func NewKeydownStateMessage(state usbgadget.KeysDownState) *Message { d: data, } } + +// NewKeyboardMacroStateMessage creates a new keyboard macro state message. +func NewKeyboardMacroStateMessage(state bool, isPaste bool) *Message { + data := make([]byte, 2) + if state { + data[0] = 1 + } + if isPaste { + data[1] = 1 + } + + return &Message{ + t: TypeKeyboardMacroState, + d: data, + } +} diff --git a/internal/hidrpc/message.go b/internal/hidrpc/message.go index 84bbda7c..3f3506f7 100644 --- a/internal/hidrpc/message.go +++ b/internal/hidrpc/message.go @@ -1,6 +1,7 @@ package hidrpc import ( + "encoding/binary" "fmt" ) @@ -43,6 +44,13 @@ func (m *Message) String() string { return fmt.Sprintf("MouseReport{Malformed: %v}", m.d) } return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0], m.d[1], m.d[2]) + case TypeKeypressKeepAliveReport: + return "KeypressKeepAliveReport" + case TypeKeyboardMacroReport: + if len(m.d) < 5 { + return fmt.Sprintf("KeyboardMacroReport{Malformed: %v}", m.d) + } + return fmt.Sprintf("KeyboardMacroReport{IsPaste: %v, Length: %d}", m.d[0] == uint8(1), binary.BigEndian.Uint32(m.d[1:5])) default: return fmt.Sprintf("Unknown{Type: %d, Data: %v}", m.t, m.d) } @@ -84,6 +92,55 @@ func (m *Message) KeyboardReport() (KeyboardReport, error) { }, nil } +// Macro .. +type KeyboardMacroStep struct { + Modifier byte // 1 byte + Keys []byte // 6 bytes: hidKeyBufferSize + Delay uint16 // 2 bytes +} +type KeyboardMacroReport struct { + IsPaste bool + StepCount uint32 + Steps []KeyboardMacroStep +} + +// HidKeyBufferSize is the size of the keys buffer in the keyboard report. +const HidKeyBufferSize = 6 + +// KeyboardMacroReport returns the keyboard macro report from the message. +func (m *Message) KeyboardMacroReport() (KeyboardMacroReport, error) { + if m.t != TypeKeyboardMacroReport { + return KeyboardMacroReport{}, fmt.Errorf("invalid message type: %d", m.t) + } + + isPaste := m.d[0] == uint8(1) + stepCount := binary.BigEndian.Uint32(m.d[1:5]) + + // check total length + expectedLength := int(stepCount)*9 + 5 + if len(m.d) != expectedLength { + return KeyboardMacroReport{}, fmt.Errorf("invalid length: %d, expected: %d", len(m.d), expectedLength) + } + + steps := make([]KeyboardMacroStep, 0, int(stepCount)) + offset := 5 + for i := 0; i < int(stepCount); i++ { + steps = append(steps, KeyboardMacroStep{ + Modifier: m.d[offset], + Keys: m.d[offset+1 : offset+7], + Delay: binary.BigEndian.Uint16(m.d[offset+7 : offset+9]), + }) + + offset += 1 + HidKeyBufferSize + 2 + } + + return KeyboardMacroReport{ + IsPaste: isPaste, + Steps: steps, + StepCount: stepCount, + }, nil +} + // PointerReport .. type PointerReport struct { X int @@ -131,3 +188,20 @@ func (m *Message) MouseReport() (MouseReport, error) { Button: uint8(m.d[2]), }, nil } + +type KeyboardMacroState struct { + State bool + IsPaste bool +} + +// KeyboardMacroState returns the keyboard macro state report from the message. +func (m *Message) KeyboardMacroState() (KeyboardMacroState, error) { + if m.t != TypeKeyboardMacroState { + return KeyboardMacroState{}, fmt.Errorf("invalid message type: %d", m.t) + } + + return KeyboardMacroState{ + State: m.d[0] == uint8(1), + IsPaste: m.d[1] == uint8(1), + }, nil +} diff --git a/internal/network/config.go b/internal/network/config.go index 8a28d515..da99496f 100644 --- a/internal/network/config.go +++ b/internal/network/config.go @@ -56,13 +56,12 @@ type NetworkConfig struct { } func (c *NetworkConfig) GetMDNSMode() *mdns.MDNSListenOptions { - mode := c.MDNSMode.String listenOptions := &mdns.MDNSListenOptions{ - IPv4: true, - IPv6: true, + IPv4: c.IPv4Mode.String != "disabled", + IPv6: c.IPv6Mode.String != "disabled", } - switch mode { + switch c.MDNSMode.String { case "ipv4_only": listenOptions.IPv6 = false case "ipv6_only": diff --git a/internal/network/netif.go b/internal/network/netif.go index 5a8dab6c..44bcaa4b 100644 --- a/internal/network/netif.go +++ b/internal/network/netif.go @@ -48,7 +48,7 @@ type NetworkInterfaceOptions struct { DefaultHostname string OnStateChange func(state *NetworkInterfaceState) OnInitialCheck func(state *NetworkInterfaceState) - OnDhcpLeaseChange func(lease *udhcpc.Lease) + OnDhcpLeaseChange func(lease *udhcpc.Lease, state *NetworkInterfaceState) OnConfigChange func(config *NetworkConfig) NetworkConfig *NetworkConfig } @@ -94,7 +94,7 @@ func NewNetworkInterfaceState(opts *NetworkInterfaceOptions) (*NetworkInterfaceS _ = s.updateNtpServersFromLease(lease) _ = s.setHostnameIfNotSame() - opts.OnDhcpLeaseChange(lease) + opts.OnDhcpLeaseChange(lease, s) }, }) @@ -239,6 +239,10 @@ func (s *NetworkInterfaceState) update() (DhcpTargetState, error) { ipv4Addresses = append(ipv4Addresses, addr.IP) ipv4AddressesString = append(ipv4AddressesString, addr.IPNet.String()) } else if addr.IP.To16() != nil { + if s.config.IPv6Mode.String == "disabled" { + continue + } + scopedLogger := s.l.With().Str("ipv6", addr.IP.String()).Logger() // check if it's a link local address if addr.IP.IsLinkLocalUnicast() { @@ -287,35 +291,37 @@ func (s *NetworkInterfaceState) update() (DhcpTargetState, error) { } s.ipv4Addresses = ipv4AddressesString - if ipv6LinkLocal != nil { - if s.ipv6LinkLocal == nil || s.ipv6LinkLocal.String() != ipv6LinkLocal.String() { - scopedLogger := s.l.With().Str("ipv6", ipv6LinkLocal.String()).Logger() - if s.ipv6LinkLocal != nil { - scopedLogger.Info(). - Str("old_ipv6", s.ipv6LinkLocal.String()). - Msg("IPv6 link local address changed") - } else { - scopedLogger.Info().Msg("IPv6 link local address found") + if s.config.IPv6Mode.String != "disabled" { + if ipv6LinkLocal != nil { + if s.ipv6LinkLocal == nil || s.ipv6LinkLocal.String() != ipv6LinkLocal.String() { + scopedLogger := s.l.With().Str("ipv6", ipv6LinkLocal.String()).Logger() + if s.ipv6LinkLocal != nil { + scopedLogger.Info(). + Str("old_ipv6", s.ipv6LinkLocal.String()). + Msg("IPv6 link local address changed") + } else { + scopedLogger.Info().Msg("IPv6 link local address found") + } + s.ipv6LinkLocal = ipv6LinkLocal + changed = true } - s.ipv6LinkLocal = ipv6LinkLocal - changed = true } - } - s.ipv6Addresses = ipv6Addresses + s.ipv6Addresses = ipv6Addresses - if len(ipv6Addresses) > 0 { - // compare the addresses to see if there's a change - if s.ipv6Addr == nil || s.ipv6Addr.String() != ipv6Addresses[0].Address.String() { - scopedLogger := s.l.With().Str("ipv6", ipv6Addresses[0].Address.String()).Logger() - if s.ipv6Addr != nil { - scopedLogger.Info(). - Str("old_ipv6", s.ipv6Addr.String()). - Msg("IPv6 address changed") - } else { - scopedLogger.Info().Msg("IPv6 address found") + if len(ipv6Addresses) > 0 { + // compare the addresses to see if there's a change + if s.ipv6Addr == nil || s.ipv6Addr.String() != ipv6Addresses[0].Address.String() { + scopedLogger := s.l.With().Str("ipv6", ipv6Addresses[0].Address.String()).Logger() + if s.ipv6Addr != nil { + scopedLogger.Info(). + Str("old_ipv6", s.ipv6Addr.String()). + Msg("IPv6 address changed") + } else { + scopedLogger.Info().Msg("IPv6 address found") + } + s.ipv6Addr = &ipv6Addresses[0].Address + changed = true } - s.ipv6Addr = &ipv6Addresses[0].Address - changed = true } } diff --git a/internal/network/rpc.go b/internal/network/rpc.go index 32f34f57..62f21be8 100644 --- a/internal/network/rpc.go +++ b/internal/network/rpc.go @@ -65,7 +65,7 @@ func (s *NetworkInterfaceState) IPv6LinkLocalAddress() string { func (s *NetworkInterfaceState) RpcGetNetworkState() RpcNetworkState { ipv6Addresses := make([]RpcIPv6Address, 0) - if s.ipv6Addresses != nil { + if s.ipv6Addresses != nil && s.config.IPv6Mode.String != "disabled" { for _, addr := range s.ipv6Addresses { ipv6Addresses = append(ipv6Addresses, RpcIPv6Address{ Address: addr.Prefix.String(), diff --git a/internal/timesync/ntp.go b/internal/timesync/ntp.go index c32de2a2..b9ffa249 100644 --- a/internal/timesync/ntp.go +++ b/internal/timesync/ntp.go @@ -9,17 +9,32 @@ import ( "github.com/beevik/ntp" ) -var defaultNTPServers = []string{ +var defaultNTPServerIPs = []string{ + // These servers are known by static IP and as such don't need DNS lookups + // These are from Google and Cloudflare since if they're down, the internet + // is broken anyway + "162.159.200.1", // time.cloudflare.com IPv4 + "162.159.200.123", // time.cloudflare.com IPv4 + "2606:4700:f1::1", // time.cloudflare.com IPv6 + "2606:4700:f1::123", // time.cloudflare.com IPv6 + "216.239.35.0", // time.google.com IPv4 + "216.239.35.4", // time.google.com IPv4 + "216.239.35.8", // time.google.com IPv4 + "216.239.35.12", // time.google.com IPv4 + "2001:4860:4806::", // time.google.com IPv6 + "2001:4860:4806:4::", // time.google.com IPv6 + "2001:4860:4806:8::", // time.google.com IPv6 + "2001:4860:4806:c::", // time.google.com IPv6 +} + +var defaultNTPServerHostnames = []string{ + // should use something from https://github.com/jauderho/public-ntp-servers "time.apple.com", "time.aws.com", "time.windows.com", "time.google.com", - "162.159.200.123", // time.cloudflare.com IPv4 - "2606:4700:f1::123", // time.cloudflare.com IPv6 - "0.pool.ntp.org", - "1.pool.ntp.org", - "2.pool.ntp.org", - "3.pool.ntp.org", + "time.cloudflare.com", + "pool.ntp.org", } func (t *TimeSync) queryNetworkTime(ntpServers []string) (now *time.Time, offset *time.Duration) { diff --git a/internal/timesync/timesync.go b/internal/timesync/timesync.go index db1c96ee..b29a61ab 100644 --- a/internal/timesync/timesync.go +++ b/internal/timesync/timesync.go @@ -158,6 +158,7 @@ func (t *TimeSync) Sync() error { var ( now *time.Time offset *time.Duration + log zerolog.Logger ) metricTimeSyncCount.Inc() @@ -166,54 +167,54 @@ func (t *TimeSync) Sync() error { Orders: for _, mode := range syncMode.Ordering { + log = t.l.With().Str("mode", mode).Logger() switch mode { case "ntp_user_provided": if syncMode.Ntp { - t.l.Info().Msg("using NTP custom servers") + log.Info().Msg("using NTP custom servers") now, offset = t.queryNetworkTime(t.networkConfig.TimeSyncNTPServers) if now != nil { - t.l.Info().Str("source", "NTP").Time("now", *now).Msg("time obtained") break Orders } } case "ntp_dhcp": if syncMode.Ntp { - t.l.Info().Msg("using NTP servers from DHCP") + log.Info().Msg("using NTP servers from DHCP") now, offset = t.queryNetworkTime(t.dhcpNtpAddresses) if now != nil { - t.l.Info().Str("source", "NTP DHCP").Time("now", *now).Msg("time obtained") break Orders } } case "ntp": if syncMode.Ntp && syncMode.NtpUseFallback { - t.l.Info().Msg("using NTP fallback") - now, offset = t.queryNetworkTime(defaultNTPServers) + log.Info().Msg("using NTP fallback IPs") + now, offset = t.queryNetworkTime(defaultNTPServerIPs) + if now == nil { + log.Info().Msg("using NTP fallback hostnames") + now, offset = t.queryNetworkTime(defaultNTPServerHostnames) + } if now != nil { - t.l.Info().Str("source", "NTP fallback").Time("now", *now).Msg("time obtained") break Orders } } case "http_user_provided": if syncMode.Http { - t.l.Info().Msg("using HTTP custom URLs") + log.Info().Msg("using HTTP custom URLs") now = t.queryAllHttpTime(t.networkConfig.TimeSyncHTTPUrls) if now != nil { - t.l.Info().Str("source", "HTTP").Time("now", *now).Msg("time obtained") break Orders } } case "http": if syncMode.Http && syncMode.HttpUseFallback { - t.l.Info().Msg("using HTTP fallback") + log.Info().Msg("using HTTP fallback") now = t.queryAllHttpTime(defaultHTTPUrls) if now != nil { - t.l.Info().Str("source", "HTTP fallback").Time("now", *now).Msg("time obtained") break Orders } } default: - t.l.Warn().Str("mode", mode).Msg("unknown time sync mode, skipping") + log.Warn().Msg("unknown time sync mode, skipping") } } @@ -226,6 +227,8 @@ Orders: now = &newNow } + log.Info().Time("now", *now).Msg("time obtained") + err := t.setSystemTime(*now) if err != nil { return fmt.Errorf("failed to set system time: %w", err) diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index 8208a541..99fa2887 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -5,7 +5,11 @@ import ( "context" "fmt" "os" + "sync" "time" + + "github.com/rs/xid" + "github.com/rs/zerolog" ) var keyboardConfig = gadgetConfigItem{ @@ -145,32 +149,95 @@ func (u *UsbGadget) GetKeysDownState() KeysDownState { return u.keysDownState } -func (u *UsbGadget) updateKeyDownState(state KeysDownState) { - u.log.Trace().Interface("old", u.keysDownState).Interface("new", state).Msg("acquiring keyboardStateLock for updateKeyDownState") +func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) { + u.onKeysDownChange = &f +} - // this is intentional to unlock keyboard state lock before onKeysDownChange callback - { - u.keyboardStateLock.Lock() - defer u.keyboardStateLock.Unlock() +func (u *UsbGadget) SetOnKeepAliveReset(f func()) { + u.onKeepAliveReset = &f +} - if u.keysDownState.Modifier == state.Modifier && - bytes.Equal(u.keysDownState.Keys, state.Keys) { - return // No change in key down state - } +// DefaultAutoReleaseDuration is the default duration for auto-release of a key. +const DefaultAutoReleaseDuration = 100 * time.Millisecond - u.log.Trace().Interface("old", u.keysDownState).Interface("new", state).Msg("keysDownState updated") - u.keysDownState = state +func (u *UsbGadget) scheduleAutoRelease(key byte) { + u.kbdAutoReleaseLock.Lock() + defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease scheduled") + + if u.kbdAutoReleaseTimers[key] != nil { + u.kbdAutoReleaseTimers[key].Stop() } - if u.onKeysDownChange != nil { - u.log.Trace().Interface("state", state).Msg("calling onKeysDownChange") - (*u.onKeysDownChange)(state) - u.log.Trace().Interface("state", state).Msg("onKeysDownChange called") + // TODO: make this configurable + // We currently hardcode the duration to 100ms + // However, it should be the same as the duration of the keep-alive reset called baseExtension. + u.kbdAutoReleaseTimers[key] = time.AfterFunc(100*time.Millisecond, func() { + u.performAutoRelease(key) + }) +} + +func (u *UsbGadget) cancelAutoRelease(key byte) { + u.kbdAutoReleaseLock.Lock() + defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease cancelled") + + if timer := u.kbdAutoReleaseTimers[key]; timer != nil { + timer.Stop() + u.kbdAutoReleaseTimers[key] = nil + delete(u.kbdAutoReleaseTimers, key) + + // Reset keep-alive timing when key is released + if u.onKeepAliveReset != nil { + (*u.onKeepAliveReset)() + } } } -func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) { - u.onKeysDownChange = &f +func (u *UsbGadget) DelayAutoReleaseWithDuration(resetDuration time.Duration) { + u.kbdAutoReleaseLock.Lock() + defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease delayed") + + u.log.Debug().Dur("reset_duration", resetDuration).Msg("delaying auto-release with dynamic duration") + + for _, timer := range u.kbdAutoReleaseTimers { + if timer != nil { + timer.Reset(resetDuration) + } + } +} + +func (u *UsbGadget) performAutoRelease(key byte) { + u.kbdAutoReleaseLock.Lock() + + if u.kbdAutoReleaseTimers[key] == nil { + u.log.Warn().Uint8("key", key).Msg("autoRelease timer not found") + u.kbdAutoReleaseLock.Unlock() + return + } + + u.kbdAutoReleaseTimers[key].Stop() + u.kbdAutoReleaseTimers[key] = nil + delete(u.kbdAutoReleaseTimers, key) + u.kbdAutoReleaseLock.Unlock() + + // Skip if already released + state := u.GetKeysDownState() + alreadyReleased := true + + for i := range state.Keys { + if state.Keys[i] == key { + alreadyReleased = false + break + } + } + + if alreadyReleased { + return + } + + _, err := u.keypressReport(key, false) + if err != nil { + u.log.Warn().Uint8("key", key).Msg("failed to release key") + } } func (u *UsbGadget) listenKeyboardEvents() { @@ -242,7 +309,11 @@ func (u *UsbGadget) OpenKeyboardHidFile() error { return u.openKeyboardHidFile() } +var keyboardWriteHidFileLock sync.Mutex + func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { + keyboardWriteHidFileLock.Lock() + defer keyboardWriteHidFileLock.Unlock() if err := u.openKeyboardHidFile(); err != nil { return err } @@ -265,17 +336,29 @@ func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { } } - downState := KeysDownState{ + state := KeysDownState{ Modifier: modifier, Keys: []byte(keys[:]), } - u.updateKeyDownState(downState) - return downState + + u.keyboardStateLock.Lock() + + if u.keysDownState.Modifier == state.Modifier && + bytes.Equal(u.keysDownState.Keys, state.Keys) { + u.keyboardStateLock.Unlock() + return state // No change in key down state + } + + u.keysDownState = state + u.keyboardStateLock.Unlock() + + if u.onKeysDownChange != nil { + (*u.onKeysDownChange)(state) // this enques to the outgoing hidrpc queue via usb.go → currentSession.enqueueKeysDownState(...) + } + return state } -func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) (KeysDownState, error) { - u.keyboardLock.Lock() - defer u.keyboardLock.Unlock() +func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) error { defer u.resetUserInputTime() if len(keys) > hidKeyBufferSize { @@ -290,7 +373,8 @@ func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) (KeysDownState, e u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keyboard report to hidg0") } - return u.UpdateKeysDown(modifier, keys), err + u.UpdateKeysDown(modifier, keys) + return err } const ( @@ -330,17 +414,23 @@ var KeyCodeToMaskMap = map[byte]byte{ RightSuper: ModifierMaskRightSuper, } -func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) { - u.keyboardLock.Lock() - defer u.keyboardLock.Unlock() +func (u *UsbGadget) keypressReport(key byte, press bool) (KeysDownState, error) { defer u.resetUserInputTime() + l := u.log.With().Uint8("key", key).Bool("press", press).Logger() + if l.GetLevel() <= zerolog.DebugLevel { + requestID := xid.New() + l = l.With().Str("requestID", requestID.String()).Logger() + } + // IMPORTANT: This code parallels the logic in the kernel's hid-gadget driver // for handling key presses and releases. It ensures that the USB gadget // behaves similarly to a real USB HID keyboard. This logic is paralleled // in the client/browser-side code in useKeyboard.ts so make sure to keep // them in sync. - var state = u.keysDownState + var state = u.GetKeysDownState() + l.Trace().Interface("state", state).Msg("got keys down state") + modifier := state.Modifier keys := append([]byte(nil), state.Keys...) @@ -380,22 +470,36 @@ func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) // If we reach here it means we didn't find an empty slot or the key in the buffer if overrun { if press { - u.log.Error().Uint8("key", key).Msg("keyboard buffer overflow, key not added") + l.Error().Msg("keyboard buffer overflow, key not added") // Fill all key slots with ErrorRollOver (0x01) to indicate overflow for i := range keys { keys[i] = hidErrorRollOver } } else { // If we are releasing a key, and we didn't find it in a slot, who cares? - u.log.Warn().Uint8("key", key).Msg("key not found in buffer, nothing to release") + l.Warn().Msg("key not found in buffer, nothing to release") } } } err := u.keyboardWriteHidFile(modifier, keys) - if err != nil { - u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keypress report to hidg0") - } - return u.UpdateKeysDown(modifier, keys), err } + +func (u *UsbGadget) KeypressReport(key byte, press bool) error { + state, err := u.keypressReport(key, press) + if err != nil { + u.log.Warn().Uint8("key", key).Bool("press", press).Msg("failed to report key") + } + isRolledOver := state.Keys[0] == hidErrorRollOver + + if isRolledOver { + u.cancelAutoRelease(key) + } else if press { + u.scheduleAutoRelease(key) + } else { + u.cancelAutoRelease(key) + } + + return err +} diff --git a/internal/usbgadget/usbgadget.go b/internal/usbgadget/usbgadget.go index 5fc7a49b..04db4699 100644 --- a/internal/usbgadget/usbgadget.go +++ b/internal/usbgadget/usbgadget.go @@ -69,6 +69,9 @@ type UsbGadget struct { keyboardState byte // keyboard latched state (NumLock, CapsLock, ScrollLock, Compose, Kana) keysDownState KeysDownState // keyboard dynamic state (modifier keys and pressed keys) + kbdAutoReleaseLock sync.Mutex + kbdAutoReleaseTimers map[byte]*time.Timer + keyboardStateLock sync.Mutex keyboardStateCtx context.Context keyboardStateCancel context.CancelFunc @@ -86,6 +89,7 @@ type UsbGadget struct { onKeyboardStateChange *func(state KeyboardState) onKeysDownChange *func(state KeysDownState) + onKeepAliveReset *func() log *zerolog.Logger @@ -179,23 +183,24 @@ func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDev keyboardCtx, keyboardCancel := context.WithCancel(context.Background()) g := &UsbGadget{ - name: name, - kvmGadgetPath: path.Join(gadgetPath, name), - configC1Path: path.Join(gadgetPath, name, "configs/c.1"), - configMap: configMap, - customConfig: *config, - configLock: sync.Mutex{}, - keyboardLock: sync.Mutex{}, - absMouseLock: sync.Mutex{}, - relMouseLock: sync.Mutex{}, - txLock: sync.Mutex{}, - keyboardStateCtx: keyboardCtx, - keyboardStateCancel: keyboardCancel, - keyboardState: 0, - keysDownState: KeysDownState{Modifier: 0, Keys: []byte{0, 0, 0, 0, 0, 0}}, // must be initialized to hidKeyBufferSize (6) zero bytes - enabledDevices: *enabledDevices, - lastUserInput: time.Now(), - log: logger, + name: name, + kvmGadgetPath: path.Join(gadgetPath, name), + configC1Path: path.Join(gadgetPath, name, "configs/c.1"), + configMap: configMap, + customConfig: *config, + configLock: sync.Mutex{}, + keyboardLock: sync.Mutex{}, + absMouseLock: sync.Mutex{}, + relMouseLock: sync.Mutex{}, + txLock: sync.Mutex{}, + keyboardStateCtx: keyboardCtx, + keyboardStateCancel: keyboardCancel, + keyboardState: 0, + keysDownState: KeysDownState{Modifier: 0, Keys: []byte{0, 0, 0, 0, 0, 0}}, // must be initialized to hidKeyBufferSize (6) zero bytes + kbdAutoReleaseTimers: make(map[byte]*time.Timer), + enabledDevices: *enabledDevices, + lastUserInput: time.Now(), + log: logger, strictMode: config.strictMode, @@ -210,3 +215,37 @@ func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDev return g } + +// Close cleans up resources used by the USB gadget +func (u *UsbGadget) Close() error { + // Cancel keyboard state context + if u.keyboardStateCancel != nil { + u.keyboardStateCancel() + } + + // Stop auto-release timer + u.kbdAutoReleaseLock.Lock() + for _, timer := range u.kbdAutoReleaseTimers { + if timer != nil { + timer.Stop() + } + } + u.kbdAutoReleaseTimers = make(map[byte]*time.Timer) + u.kbdAutoReleaseLock.Unlock() + + // Close HID files + if u.keyboardHidFile != nil { + u.keyboardHidFile.Close() + u.keyboardHidFile = nil + } + if u.absMouseHidFile != nil { + u.absMouseHidFile.Close() + u.absMouseHidFile = nil + } + if u.relMouseHidFile != nil { + u.relMouseHidFile.Close() + u.relMouseHidFile = nil + } + + return nil +} diff --git a/internal/usbgadget/utils.go b/internal/usbgadget/utils.go index d51f9e40..85bf1579 100644 --- a/internal/usbgadget/utils.go +++ b/internal/usbgadget/utils.go @@ -9,6 +9,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" "github.com/rs/zerolog" @@ -120,6 +121,12 @@ func (u *UsbGadget) writeWithTimeout(file *os.File, data []byte) (n int, err err return } + u.log.Trace(). + Str("file", file.Name()). + Bytes("data", data). + Err(err). + Msg("write failed") + if errors.Is(err, os.ErrDeadlineExceeded) { u.logWithSuppression( fmt.Sprintf("writeWithTimeout_%s", file.Name()), @@ -164,3 +171,8 @@ func (u *UsbGadget) resetLogSuppressionCounter(counterName string) { u.logSuppressionCounter[counterName] = 0 } } + +func unlockWithLog(lock *sync.Mutex, logger *zerolog.Logger, msg string, args ...any) { + logger.Trace().Msgf(msg, args...) + lock.Unlock() +} diff --git a/internal/utils/ssh.go b/internal/utils/ssh.go index e4602ffe..9b9e874a 100644 --- a/internal/utils/ssh.go +++ b/internal/utils/ssh.go @@ -20,6 +20,8 @@ var ValidSSHKeyTypes = []string{ ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521, + ssh.KeyAlgoSKED25519, + ssh.KeyAlgoSKECDSA256, } // ValidateSSHKey validates authorized_keys file content diff --git a/internal/utils/ssh_test.go b/internal/utils/ssh_test.go index f89cb90b..7502032b 100644 --- a/internal/utils/ssh_test.go +++ b/internal/utils/ssh_test.go @@ -27,6 +27,16 @@ func TestValidateSSHKey(t *testing.T) { sshKey: "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBAlTkxIo4mXBR+gEX0Q74BpYX4bFFHoX+8Uz7tsob8HvsnMvsEE+BW9h9XrbWX4/4ppL/o6sHbvsqNr9HcyKfdc= test@example.com", expectError: false, }, + { + name: "valid SK-backed ED25519 key", + sshKey: "sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIHHSRVC3qISk/mOorf24au6esimA9Uu1/BkEnVKJ+4bFAAAABHNzaDo= test@example.com", + expectError: false, + }, + { + name: "valid SK-backed ECDSA key", + sshKey: "sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBL/CFBZksvs+gJODMB9StxnkY6xRKH73npOzJBVb0UEGCPTAhDrvzW1PE5X5GDYXmZw1s7c/nS+GH0LF0OFCpwAAAAAEc3NoOg== test@example.com", + expectError: false, + }, { name: "multiple valid keys", sshKey: "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDiYUb9Fy2vlPfO+HwubnshimpVrWPoePyvyN+jPC5gWqZSycjMy6Is2vFVn7oQc72bkY0wZalspT5wUOwKtltSoLpL7vcqGL9zHVw4yjYXtPGIRd3zLpU9wdngevnepPQWTX3LvZTZfmOsrGoMDKIG+Lbmiq/STMuWYecIqMp7tUKRGS8vfAmpu6MsrN9/4UTcdWWXYWJQQn+2nCyMz28jYlWRsKtqFK6owrdZWt8WQnPN+9Upcf2ByQje+0NLnpNrnh+yd2ocuVW9wQYKAZXy7IaTfEJwd5m34sLwkqlZTaBBcmWJU+3RfpYXE763cf3rUoPIGQ8eUEBJ8IdM4vhp test@example.com\nssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBSbM8wuD5ab0nHsXaYOqaD3GLLUwmDzSk79Xi/N+H2j test@example.com", @@ -131,6 +141,8 @@ func TestValidSSHKeyTypes(t *testing.T) { "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521", + "sk-ecdsa-sha2-nistp256@openssh.com", + "sk-ssh-ed25519@openssh.com", } if len(ValidSSHKeyTypes) != len(expectedTypes) { diff --git a/jsonrpc.go b/jsonrpc.go index a2ed916a..4fe42cba 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -1,6 +1,7 @@ package kvm import ( + "bytes" "context" "encoding/json" "errors" @@ -10,6 +11,7 @@ import ( "path/filepath" "reflect" "strconv" + "sync" "time" "github.com/pion/webrtc/v4" @@ -17,6 +19,7 @@ import ( "go.bug.st/serial" "github.com/jetkvm/kvm/internal/audio" + "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/usbgadget" "github.com/jetkvm/kvm/internal/utils" ) @@ -280,6 +283,17 @@ func rpcGetUpdateStatus() (*UpdateStatus, error) { return updateStatus, nil } +func rpcGetLocalVersion() (*LocalMetadata, error) { + systemVersion, appVersion, err := GetLocalVersion() + if err != nil { + return nil, fmt.Errorf("error getting local version: %w", err) + } + return &LocalMetadata{ + AppVersion: appVersion.String(), + SystemVersion: systemVersion.String(), + }, nil +} + func rpcTryUpdate() error { includePreRelease := config.IncludePreRelease go func() { @@ -1228,6 +1242,103 @@ func rpcSetLocalLoopbackOnly(enabled bool) error { return nil } +var ( + keyboardMacroCancel context.CancelFunc + keyboardMacroLock sync.Mutex +) + +// cancelKeyboardMacro cancels any ongoing keyboard macro execution +func cancelKeyboardMacro() { + keyboardMacroLock.Lock() + defer keyboardMacroLock.Unlock() + + if keyboardMacroCancel != nil { + keyboardMacroCancel() + logger.Info().Msg("canceled keyboard macro") + keyboardMacroCancel = nil + } +} + +func setKeyboardMacroCancel(cancel context.CancelFunc) { + keyboardMacroLock.Lock() + defer keyboardMacroLock.Unlock() + + keyboardMacroCancel = cancel +} + +func rpcExecuteKeyboardMacro(macro []hidrpc.KeyboardMacroStep) error { + cancelKeyboardMacro() + + ctx, cancel := context.WithCancel(context.Background()) + setKeyboardMacroCancel(cancel) + + s := hidrpc.KeyboardMacroState{ + State: true, + IsPaste: true, + } + + if currentSession != nil { + currentSession.reportHidRPCKeyboardMacroState(s) + } + + err := rpcDoExecuteKeyboardMacro(ctx, macro) + + setKeyboardMacroCancel(nil) + + s.State = false + if currentSession != nil { + currentSession.reportHidRPCKeyboardMacroState(s) + } + + return err +} + +func rpcCancelKeyboardMacro() { + cancelKeyboardMacro() +} + +var keyboardClearStateKeys = make([]byte, hidrpc.HidKeyBufferSize) + +func isClearKeyStep(step hidrpc.KeyboardMacroStep) bool { + return step.Modifier == 0 && bytes.Equal(step.Keys, keyboardClearStateKeys) +} + +func rpcDoExecuteKeyboardMacro(ctx context.Context, macro []hidrpc.KeyboardMacroStep) error { + logger.Debug().Interface("macro", macro).Msg("Executing keyboard macro") + + for i, step := range macro { + delay := time.Duration(step.Delay) * time.Millisecond + + err := rpcKeyboardReport(step.Modifier, step.Keys) + if err != nil { + logger.Warn().Err(err).Msg("failed to execute keyboard macro") + return err + } + + // notify the device that the keyboard state is being cleared + if isClearKeyStep(step) { + gadget.UpdateKeysDown(0, keyboardClearStateKeys) + } + + // Use context-aware sleep that can be cancelled + select { + case <-time.After(delay): + // Sleep completed normally + case <-ctx.Done(): + // make sure keyboard state is reset + err := rpcKeyboardReport(0, keyboardClearStateKeys) + if err != nil { + logger.Warn().Err(err).Msg("failed to reset keyboard state") + } + + logger.Debug().Int("step", i).Msg("Keyboard macro cancelled during sleep") + return ctx.Err() + } + } + + return nil +} + var rpcHandlers = map[string]RPCHandler{ "ping": {Func: rpcPing}, "reboot": {Func: rpcReboot, Params: []string{"force"}}, @@ -1238,10 +1349,10 @@ var rpcHandlers = map[string]RPCHandler{ "getNetworkSettings": {Func: rpcGetNetworkSettings}, "setNetworkSettings": {Func: rpcSetNetworkSettings, Params: []string{"settings"}}, "renewDHCPLease": {Func: rpcRenewDHCPLease}, - "keyboardReport": {Func: rpcKeyboardReport, Params: []string{"modifier", "keys"}}, "getKeyboardLedState": {Func: rpcGetKeyboardLedState}, - "keypressReport": {Func: rpcKeypressReport, Params: []string{"key", "press"}}, "getKeyDownState": {Func: rpcGetKeysDownState}, + "keyboardReport": {Func: rpcKeyboardReport, Params: []string{"modifier", "keys"}}, + "keypressReport": {Func: rpcKeypressReport, Params: []string{"key", "press"}}, "absMouseReport": {Func: rpcAbsMouseReport, Params: []string{"x", "y", "buttons"}}, "relMouseReport": {Func: rpcRelMouseReport, Params: []string{"dx", "dy", "buttons"}}, "wheelReport": {Func: rpcWheelReport, Params: []string{"wheelY"}}, @@ -1263,6 +1374,7 @@ var rpcHandlers = map[string]RPCHandler{ "setEDID": {Func: rpcSetEDID, Params: []string{"edid"}}, "getDevChannelState": {Func: rpcGetDevChannelState}, "setDevChannelState": {Func: rpcSetDevChannelState, Params: []string{"enabled"}}, + "getLocalVersion": {Func: rpcGetLocalVersion}, "getUpdateStatus": {Func: rpcGetUpdateStatus}, "tryUpdate": {Func: rpcTryUpdate}, "getDevModeState": {Func: rpcGetDevModeState}, diff --git a/main.go b/main.go index 3e380e5a..7f61dbb8 100644 --- a/main.go +++ b/main.go @@ -246,16 +246,25 @@ func Main(audioServer bool, audioInputServer bool) { if !config.AutoUpdateEnabled { return } + + if isTimeSyncNeeded() || !timeSync.IsSyncSuccess() { + logger.Debug().Msg("system time is not synced, will retry in 30 seconds") + time.Sleep(30 * time.Second) + continue + } + if currentSession != nil { logger.Debug().Msg("skipping update since a session is active") time.Sleep(1 * time.Minute) continue } + includePreRelease := config.IncludePreRelease err = TryUpdate(context.Background(), GetDeviceID(), includePreRelease) if err != nil { logger.Warn().Err(err).Msg("failed to auto update") } + time.Sleep(1 * time.Hour) } }() diff --git a/mdns.go b/mdns.go index d7a3b553..4f9b49b1 100644 --- a/mdns.go +++ b/mdns.go @@ -13,10 +13,7 @@ func initMdns() error { networkState.GetHostname(), networkState.GetFQDN(), }, - ListenOptions: &mdns.MDNSListenOptions{ - IPv4: true, - IPv6: true, - }, + ListenOptions: config.NetworkConfig.GetMDNSMode(), }) if err != nil { return err diff --git a/network.go b/network.go index d4f46e7a..af8e50fb 100644 --- a/network.go +++ b/network.go @@ -15,7 +15,7 @@ var ( networkState *network.NetworkInterfaceState ) -func networkStateChanged() { +func networkStateChanged(isOnline bool) { // do not block the main thread go waitCtrlAndRequestDisplayUpdate(true) @@ -37,6 +37,13 @@ func networkStateChanged() { networkState.GetFQDN(), }, true) } + + // if the network is now online, trigger an NTP sync if still needed + if isOnline && timeSync != nil && (isTimeSyncNeeded() || !timeSync.IsSyncSuccess()) { + if err := timeSync.Sync(); err != nil { + logger.Warn().Str("error", err.Error()).Msg("unable to sync time on network state change") + } + } } func initNetwork() error { @@ -48,13 +55,13 @@ func initNetwork() error { NetworkConfig: config.NetworkConfig, Logger: networkLogger, OnStateChange: func(state *network.NetworkInterfaceState) { - networkStateChanged() + networkStateChanged(state.IsOnline()) }, OnInitialCheck: func(state *network.NetworkInterfaceState) { - networkStateChanged() + networkStateChanged(state.IsOnline()) }, - OnDhcpLeaseChange: func(lease *udhcpc.Lease) { - networkStateChanged() + OnDhcpLeaseChange: func(lease *udhcpc.Lease, state *network.NetworkInterfaceState) { + networkStateChanged(state.IsOnline()) if currentSession == nil { return @@ -64,7 +71,15 @@ func initNetwork() error { }, OnConfigChange: func(networkConfig *network.NetworkConfig) { config.NetworkConfig = networkConfig - networkStateChanged() + networkStateChanged(false) + + if mDNS != nil { + _ = mDNS.SetListenOptions(networkConfig.GetMDNSMode()) + _ = mDNS.SetLocalNames([]string{ + networkState.GetHostname(), + networkState.GetFQDN(), + }, true) + } }, }) diff --git a/resource/jetkvm_native b/resource/jetkvm_native old mode 100644 new mode 100755 index 68d0d4e0..f4ea2666 Binary files a/resource/jetkvm_native and b/resource/jetkvm_native differ diff --git a/resource/jetkvm_native.sha256 b/resource/jetkvm_native.sha256 index 0c0a4ff5..5bec8574 100644 --- a/resource/jetkvm_native.sha256 +++ b/resource/jetkvm_native.sha256 @@ -1 +1 @@ -01db2bbcd0bad46c3e21eb3cc5687d15df2153c3d8e2d4665b37acb55f0b5a57 +a4fca98710932aaa2765b57404e080105190cfa3af69171f4b4d95d4b78f9af0 diff --git a/ui/src/components/InfoBar.tsx b/ui/src/components/InfoBar.tsx index 8d0b2822..ce444d85 100644 --- a/ui/src/components/InfoBar.tsx +++ b/ui/src/components/InfoBar.tsx @@ -27,6 +27,7 @@ export default function InfoBar() { const { rpcDataChannel } = useRTCStore(); const { debugMode, mouseMode, showPressedKeys } = useSettingsStore(); + const { isPasteInProgress } = useHidStore(); useEffect(() => { if (!rpcDataChannel) return; @@ -108,7 +109,12 @@ export default function InfoBar() { {rpcHidStatus} )} - + {isPasteInProgress && ( +
+ Paste Mode: + Enabled +
+ )} {showPressedKeys && (
Keys: diff --git a/ui/src/components/Ipv6NetworkCard.tsx b/ui/src/components/Ipv6NetworkCard.tsx index a31b78e0..0cfacc6d 100644 --- a/ui/src/components/Ipv6NetworkCard.tsx +++ b/ui/src/components/Ipv6NetworkCard.tsx @@ -17,7 +17,7 @@ export default function Ipv6NetworkCard({
- {networkState?.dhcp_lease?.ip && ( + {networkState?.ipv6_link_local && (
Link-local diff --git a/ui/src/components/WebRTCVideo.tsx b/ui/src/components/WebRTCVideo.tsx index b26dde47..3d506914 100644 --- a/ui/src/components/WebRTCVideo.tsx +++ b/ui/src/components/WebRTCVideo.tsx @@ -215,7 +215,7 @@ export default function WebRTCVideo({ microphone }: WebRTCVideoProps) { if (!isFullscreenEnabled || !videoElm.current) return; // per https://wicg.github.io/keyboard-lock/#system-key-press-handler - // If keyboard lock is activated after fullscreen is already in effect, then the user my + // If keyboard lock is activated after fullscreen is already in effect, then the user my // see multiple messages about how to exit fullscreen. For this reason, we recommend that // developers call lock() before they enter fullscreen: await requestKeyboardLock(); @@ -262,6 +262,7 @@ export default function WebRTCVideo({ microphone }: WebRTCVideoProps) { const keyDownHandler = useCallback( (e: KeyboardEvent) => { e.preventDefault(); + if (e.repeat) return; const code = getAdjustedKeyCode(e); const hidKey = keys[code]; diff --git a/ui/src/components/popovers/PasteModal.tsx b/ui/src/components/popovers/PasteModal.tsx index 077759b7..6f224eb5 100644 --- a/ui/src/components/popovers/PasteModal.tsx +++ b/ui/src/components/popovers/PasteModal.tsx @@ -1,40 +1,44 @@ -import { useCallback, useEffect, useRef, useState } from "react"; -import { LuCornerDownLeft } from "react-icons/lu"; -import { ExclamationCircleIcon } from "@heroicons/react/16/solid"; import { useClose } from "@headlessui/react"; +import { ExclamationCircleIcon } from "@heroicons/react/16/solid"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { LuCornerDownLeft } from "react-icons/lu"; -import { Button } from "@components/Button"; -import { GridCard } from "@components/Card"; -import { TextAreaWithLabel } from "@components/TextArea"; -import { SettingsPageHeader } from "@components/SettingsPageheader"; +import { cx } from "@/cva.config"; +import { useHidStore, useSettingsStore, useUiStore } from "@/hooks/stores"; import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; -import { useHidStore, useRTCStore, useUiStore, useSettingsStore } from "@/hooks/stores"; -import { keys, modifiers } from "@/keyboardMappings"; -import { KeyStroke } from "@/keyboardLayouts"; +import useKeyboard, { type MacroStep } from "@/hooks/useKeyboard"; import useKeyboardLayout from "@/hooks/useKeyboardLayout"; import notifications from "@/notifications"; +import { Button } from "@components/Button"; +import { GridCard } from "@components/Card"; +import { InputFieldWithLabel } from "@components/InputField"; +import { SettingsPageHeader } from "@components/SettingsPageheader"; +import { TextAreaWithLabel } from "@components/TextArea"; -const hidKeyboardPayload = (modifier: number, keys: number[]) => { - return { modifier, keys }; -}; - -const modifierCode = (shift?: boolean, altRight?: boolean) => { - return (shift ? modifiers.ShiftLeft : 0) - | (altRight ? modifiers.AltRight : 0) -} -const noModifier = 0 +// uint32 max value / 4 +const pasteMaxLength = 1073741824; export default function PasteModal() { const TextAreaRef = useRef(null); - const { setPasteModeEnabled } = useHidStore(); + const { isPasteInProgress } = useHidStore(); const { setDisableVideoFocusTrap } = useUiStore(); const { send } = useJsonRpc(); - const { rpcDataChannel } = useRTCStore(); + const { executeMacro, cancelExecuteMacro } = useKeyboard(); const [invalidChars, setInvalidChars] = useState([]); + const [delayValue, setDelayValue] = useState(100); + const delay = useMemo(() => { + if (delayValue < 50 || delayValue > 65534) { + return 100; + } + return delayValue; + }, [delayValue]); const close = useClose(); + const debugMode = useSettingsStore(state => state.debugMode); + const delayClassName = useMemo(() => debugMode ? "" : "hidden", [debugMode]); + const { setKeyboardLayout } = useSettingsStore(); const { selectedKeyboard } = useKeyboardLayout(); @@ -46,21 +50,19 @@ export default function PasteModal() { }, [send, setKeyboardLayout]); const onCancelPasteMode = useCallback(() => { - setPasteModeEnabled(false); + cancelExecuteMacro(); setDisableVideoFocusTrap(false); setInvalidChars([]); - }, [setDisableVideoFocusTrap, setPasteModeEnabled]); + }, [setDisableVideoFocusTrap, cancelExecuteMacro]); const onConfirmPaste = useCallback(async () => { - setPasteModeEnabled(false); - setDisableVideoFocusTrap(false); - - if (rpcDataChannel?.readyState !== "open" || !TextAreaRef.current) return; - if (!selectedKeyboard) return; + if (!TextAreaRef.current || !selectedKeyboard) return; const text = TextAreaRef.current.value; try { + const macroSteps: MacroStep[] = []; + for (const char of text) { const keyprops = selectedKeyboard.chars[char]; if (!keyprops) continue; @@ -70,39 +72,41 @@ export default function PasteModal() { // if this is an accented character, we need to send that accent FIRST if (accentKey) { - await sendKeystroke({modifier: modifierCode(accentKey.shift, accentKey.altRight), keys: [ keys[accentKey.key] ] }) + const accentModifiers: string[] = []; + if (accentKey.shift) accentModifiers.push("ShiftLeft"); + if (accentKey.altRight) accentModifiers.push("AltRight"); + + macroSteps.push({ + keys: [String(accentKey.key)], + modifiers: accentModifiers.length > 0 ? accentModifiers : null, + delay, + }); } // now send the actual key - await sendKeystroke({ modifier: modifierCode(shift, altRight), keys: [ keys[key] ]}); + const modifiers: string[] = []; + if (shift) modifiers.push("ShiftLeft"); + if (altRight) modifiers.push("AltRight"); + + macroSteps.push({ + keys: [String(key)], + modifiers: modifiers.length > 0 ? modifiers : null, + delay + }); // if what was requested was a dead key, we need to send an unmodified space to emit // just the accent character - if (deadKey) { - await sendKeystroke({ modifier: noModifier, keys: [ keys["Space"] ] }); - } + if (deadKey) macroSteps.push({ keys: ["Space"], modifiers: null, delay }); + } - // now send a message with no keys down to "release" the keys - await sendKeystroke({ modifier: 0, keys: [] }); + if (macroSteps.length > 0) { + await executeMacro(macroSteps); } } catch (error) { console.error("Failed to paste text:", error); notifications.error("Failed to paste text"); } - - async function sendKeystroke(stroke: KeyStroke) { - await new Promise((resolve, reject) => { - send( - "keyboardReport", - hidKeyboardPayload(stroke.modifier, stroke.keys), - params => { - if ("error" in params) return reject(params.error); - resolve(); - } - ); - }); - } - }, [selectedKeyboard, rpcDataChannel?.readyState, send, setDisableVideoFocusTrap, setPasteModeEnabled]); + }, [selectedKeyboard, executeMacro, delay]); useEffect(() => { if (TextAreaRef.current) { @@ -122,19 +126,25 @@ export default function PasteModal() { />
-
e.stopPropagation()} onKeyDown={e => e.stopPropagation()}> +
e.stopPropagation()} + onKeyDown={e => e.stopPropagation()} onKeyDownCapture={e => e.stopPropagation()} + onKeyUpCapture={e => e.stopPropagation()} + > e.stopPropagation()} + maxLength={pasteMaxLength} onKeyDown={e => { e.stopPropagation(); if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) { @@ -171,9 +181,31 @@ export default function PasteModal() { )}
+
+ { + setDelayValue(parseInt(e.target.value, 10)); + }} + /> + {delayValue < 50 || delayValue > 65534 && ( +
+ + + Delay must be between 50 and 65534 + +
+ )} +

- Sending text using keyboard layout: {selectedKeyboard.isoCode}-{selectedKeyboard.name} + Sending text using keyboard layout: {selectedKeyboard.isoCode}- + {selectedKeyboard.name}

@@ -181,7 +213,7 @@ export default function PasteModal() {
diff --git a/ui/src/hooks/hidRpc.ts b/ui/src/hooks/hidRpc.ts index 20b8a108..823384ff 100644 --- a/ui/src/hooks/hidRpc.ts +++ b/ui/src/hooks/hidRpc.ts @@ -1,4 +1,4 @@ -import { KeyboardLedState, KeysDownState } from "./stores"; +import { hidKeyBufferSize, KeyboardLedState, KeysDownState } from "./stores"; export const HID_RPC_MESSAGE_TYPES = { Handshake: 0x01, @@ -6,9 +6,13 @@ export const HID_RPC_MESSAGE_TYPES = { PointerReport: 0x03, WheelReport: 0x04, KeypressReport: 0x05, + KeypressKeepAliveReport: 0x09, MouseReport: 0x06, + KeyboardMacroReport: 0x07, + CancelKeyboardMacroReport: 0x08, KeyboardLedState: 0x32, KeysDownState: 0x33, + KeyboardMacroState: 0x34, } export type HidRpcMessageType = typeof HID_RPC_MESSAGE_TYPES[keyof typeof HID_RPC_MESSAGE_TYPES]; @@ -28,7 +32,31 @@ const fromInt32toUint8 = (n: number) => { (n >> 24) & 0xFF, (n >> 16) & 0xFF, (n >> 8) & 0xFF, - (n >> 0) & 0xFF, + n & 0xFF, + ]); +}; + +const fromUint16toUint8 = (n: number) => { + if (n > 65535 || n < 0) { + throw new Error(`Number ${n} is not within the uint16 range`); + } + + return new Uint8Array([ + (n >> 8) & 0xFF, + n & 0xFF, + ]); +}; + +const fromUint32toUint8 = (n: number) => { + if (n > 4294967295 || n < 0) { + throw new Error(`Number ${n} is not within the uint32 range`); + } + + return new Uint8Array([ + (n >> 24) & 0xFF, + (n >> 16) & 0xFF, + (n >> 8) & 0xFF, + n & 0xFF, ]); }; @@ -37,7 +65,7 @@ const fromInt8ToUint8 = (n: number) => { throw new Error(`Number ${n} is not within the int8 range`); } - return (n >> 0) & 0xFF; + return n & 0xFF; }; const keyboardLedStateMasks = { @@ -186,6 +214,99 @@ export class KeyboardReportMessage extends RpcMessage { } } +export interface KeyboardMacroStep extends KeysDownState { + delay: number; +} + +export class KeyboardMacroReportMessage extends RpcMessage { + isPaste: boolean; + stepCount: number; + steps: KeyboardMacroStep[]; + + KEYS_LENGTH = hidKeyBufferSize; + + constructor(isPaste: boolean, stepCount: number, steps: KeyboardMacroStep[]) { + super(HID_RPC_MESSAGE_TYPES.KeyboardMacroReport); + this.isPaste = isPaste; + this.stepCount = stepCount; + this.steps = steps; + } + + marshal(): Uint8Array { + // validate if length is correct + if (this.stepCount !== this.steps.length) { + throw new Error(`Length ${this.stepCount} is not equal to the number of steps ${this.steps.length}`); + } + + const data = new Uint8Array(this.stepCount * 9 + 6); + data.set(new Uint8Array([ + this.messageType, + this.isPaste ? 1 : 0, + ...fromUint32toUint8(this.stepCount), + ]), 0); + + for (let i = 0; i < this.stepCount; i++) { + const step = this.steps[i]; + if (!withinUint8Range(step.modifier)) { + throw new Error(`Modifier ${step.modifier} is not within the uint8 range`); + } + + // Ensure the keys are within the KEYS_LENGTH range + const keys = step.keys; + if (keys.length > this.KEYS_LENGTH) { + throw new Error(`Keys ${keys} is not within the hidKeyBufferSize range`); + } else if (keys.length < this.KEYS_LENGTH) { + keys.push(...Array(this.KEYS_LENGTH - keys.length).fill(0)); + } + + for (const key of keys) { + if (!withinUint8Range(key)) { + throw new Error(`Key ${key} is not within the uint8 range`); + } + } + + const macroBinary = new Uint8Array([ + step.modifier, + ...keys, + ...fromUint16toUint8(step.delay), + ]); + const offset = 6 + i * 9; + + + data.set(macroBinary, offset); + } + + return data; + } +} + +export class KeyboardMacroStateMessage extends RpcMessage { + state: boolean; + isPaste: boolean; + + constructor(state: boolean, isPaste: boolean) { + super(HID_RPC_MESSAGE_TYPES.KeyboardMacroState); + this.state = state; + this.isPaste = isPaste; + } + + marshal(): Uint8Array { + return new Uint8Array([ + this.messageType, + this.state ? 1 : 0, + this.isPaste ? 1 : 0, + ]); + } + + public static unmarshal(data: Uint8Array): KeyboardMacroStateMessage | undefined { + if (data.length < 1) { + throw new Error(`Invalid keyboard macro state report message length: ${data.length}`); + } + + return new KeyboardMacroStateMessage(data[0] === 1, data[1] === 1); + } +} + export class KeyboardLedStateMessage extends RpcMessage { keyboardLedState: KeyboardLedState; @@ -256,6 +377,17 @@ export class PointerReportMessage extends RpcMessage { } } +export class CancelKeyboardMacroReportMessage extends RpcMessage { + + constructor() { + super(HID_RPC_MESSAGE_TYPES.CancelKeyboardMacroReport); + } + + marshal(): Uint8Array { + return new Uint8Array([this.messageType]); + } +} + export class MouseReportMessage extends RpcMessage { dx: number; dy: number; @@ -278,12 +410,26 @@ export class MouseReportMessage extends RpcMessage { } } +export class KeypressKeepAliveMessage extends RpcMessage { + constructor() { + super(HID_RPC_MESSAGE_TYPES.KeypressKeepAliveReport); + } + + marshal(): Uint8Array { + return new Uint8Array([this.messageType]); + } +} + export const messageRegistry = { [HID_RPC_MESSAGE_TYPES.Handshake]: HandshakeMessage, [HID_RPC_MESSAGE_TYPES.KeysDownState]: KeysDownStateMessage, [HID_RPC_MESSAGE_TYPES.KeyboardLedState]: KeyboardLedStateMessage, [HID_RPC_MESSAGE_TYPES.KeyboardReport]: KeyboardReportMessage, [HID_RPC_MESSAGE_TYPES.KeypressReport]: KeypressReportMessage, + [HID_RPC_MESSAGE_TYPES.KeyboardMacroReport]: KeyboardMacroReportMessage, + [HID_RPC_MESSAGE_TYPES.CancelKeyboardMacroReport]: CancelKeyboardMacroReportMessage, + [HID_RPC_MESSAGE_TYPES.KeyboardMacroState]: KeyboardMacroStateMessage, + [HID_RPC_MESSAGE_TYPES.KeypressKeepAliveReport]: KeypressKeepAliveMessage, } export const unmarshalHidRpcMessage = (data: Uint8Array): RpcMessage | undefined => { diff --git a/ui/src/hooks/stores.ts b/ui/src/hooks/stores.ts index c204a827..e43e5137 100644 --- a/ui/src/hooks/stores.ts +++ b/ui/src/hooks/stores.ts @@ -107,12 +107,21 @@ export interface RTCState { setRpcDataChannel: (channel: RTCDataChannel) => void; rpcDataChannel: RTCDataChannel | null; + hidRpcDisabled: boolean; + setHidRpcDisabled: (disabled: boolean) => void; + rpcHidProtocolVersion: number | null; - setRpcHidProtocolVersion: (version: number) => void; + setRpcHidProtocolVersion: (version: number | null) => void; rpcHidChannel: RTCDataChannel | null; setRpcHidChannel: (channel: RTCDataChannel) => void; + rpcHidUnreliableChannel: RTCDataChannel | null; + setRpcHidUnreliableChannel: (channel: RTCDataChannel) => void; + + rpcHidUnreliableNonOrderedChannel: RTCDataChannel | null; + setRpcHidUnreliableNonOrderedChannel: (channel: RTCDataChannel) => void; + peerConnectionState: RTCPeerConnectionState | null; setPeerConnectionState: (state: RTCPeerConnectionState) => void; @@ -169,12 +178,21 @@ export const useRTCStore = create(set => ({ rpcDataChannel: null, setRpcDataChannel: (channel: RTCDataChannel) => set({ rpcDataChannel: channel }), + hidRpcDisabled: false, + setHidRpcDisabled: (disabled: boolean) => set({ hidRpcDisabled: disabled }), + rpcHidProtocolVersion: null, - setRpcHidProtocolVersion: (version: number) => set({ rpcHidProtocolVersion: version }), + setRpcHidProtocolVersion: (version: number | null) => set({ rpcHidProtocolVersion: version }), rpcHidChannel: null, setRpcHidChannel: (channel: RTCDataChannel) => set({ rpcHidChannel: channel }), + rpcHidUnreliableChannel: null, + setRpcHidUnreliableChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableChannel: channel }), + + rpcHidUnreliableNonOrderedChannel: null, + setRpcHidUnreliableNonOrderedChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableNonOrderedChannel: channel }), + transceiver: null, setTransceiver: (transceiver: RTCRtpTransceiver) => set({ transceiver }), @@ -494,7 +512,7 @@ export interface HidState { isVirtualKeyboardEnabled: boolean; setVirtualKeyboardEnabled: (enabled: boolean) => void; - isPasteModeEnabled: boolean; + isPasteInProgress: boolean; setPasteModeEnabled: (enabled: boolean) => void; usbState: USBStates; @@ -511,8 +529,8 @@ export const useHidStore = create(set => ({ isVirtualKeyboardEnabled: false, setVirtualKeyboardEnabled: (enabled: boolean): void => set({ isVirtualKeyboardEnabled: enabled }), - isPasteModeEnabled: false, - setPasteModeEnabled: (enabled: boolean): void => set({ isPasteModeEnabled: enabled }), + isPasteInProgress: false, + setPasteModeEnabled: (enabled: boolean): void => set({ isPasteInProgress: enabled }), // Add these new properties for USB state usbState: "not attached", diff --git a/ui/src/hooks/useHidRpc.ts b/ui/src/hooks/useHidRpc.ts index ea0c7112..aeb1c4fa 100644 --- a/ui/src/hooks/useHidRpc.ts +++ b/ui/src/hooks/useHidRpc.ts @@ -3,9 +3,13 @@ import { useCallback, useEffect, useMemo } from "react"; import { useRTCStore } from "@/hooks/stores"; import { + CancelKeyboardMacroReportMessage, HID_RPC_VERSION, HandshakeMessage, + KeyboardMacroStep, + KeyboardMacroReportMessage, KeyboardReportMessage, + KeypressKeepAliveMessage, KeypressReportMessage, MouseReportMessage, PointerReportMessage, @@ -13,38 +17,97 @@ import { unmarshalHidRpcMessage, } from "./hidRpc"; +const KEEPALIVE_MESSAGE = new KeypressKeepAliveMessage(); + +interface sendMessageParams { + ignoreHandshakeState?: boolean; + useUnreliableChannel?: boolean; + requireOrdered?: boolean; +} + export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { - const { rpcHidChannel, setRpcHidProtocolVersion, rpcHidProtocolVersion } = useRTCStore(); + const { + rpcHidChannel, + rpcHidUnreliableChannel, + rpcHidUnreliableNonOrderedChannel, + setRpcHidProtocolVersion, + rpcHidProtocolVersion, hidRpcDisabled, + } = useRTCStore(); + const rpcHidReady = useMemo(() => { + if (hidRpcDisabled) return false; return rpcHidChannel?.readyState === "open" && rpcHidProtocolVersion !== null; - }, [rpcHidChannel, rpcHidProtocolVersion]); + }, [rpcHidChannel, rpcHidProtocolVersion, hidRpcDisabled]); + + const rpcHidUnreliableReady = useMemo(() => { + return ( + rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null + ); + }, [rpcHidProtocolVersion, rpcHidUnreliableChannel?.readyState]); + + const rpcHidUnreliableNonOrderedReady = useMemo(() => { + return ( + rpcHidUnreliableNonOrderedChannel?.readyState === "open" && + rpcHidProtocolVersion !== null + ); + }, [rpcHidProtocolVersion, rpcHidUnreliableNonOrderedChannel?.readyState]); const rpcHidStatus = useMemo(() => { + if (hidRpcDisabled) return "disabled"; + if (!rpcHidChannel) return "N/A"; if (rpcHidChannel.readyState !== "open") return rpcHidChannel.readyState; if (!rpcHidProtocolVersion) return "handshaking"; - return `ready (v${rpcHidProtocolVersion})`; - }, [rpcHidChannel, rpcHidProtocolVersion]); + return `ready (v${rpcHidProtocolVersion}${rpcHidUnreliableReady ? "+u" : ""})`; + }, [rpcHidChannel, rpcHidProtocolVersion, rpcHidUnreliableReady, hidRpcDisabled]); - const sendMessage = useCallback((message: RpcMessage, ignoreHandshakeState = false) => { + const sendMessage = useCallback( + ( + message: RpcMessage, + { + ignoreHandshakeState, + useUnreliableChannel, + requireOrdered = true, + }: sendMessageParams = {}, + ) => { + if (hidRpcDisabled) return; if (rpcHidChannel?.readyState !== "open") return; - if (!rpcHidReady && !ignoreHandshakeState) return; + if (!rpcHidReady && !ignoreHandshakeState) return; - let data: Uint8Array | undefined; - try { - data = message.marshal(); - } catch (e) { - console.error("Failed to send HID RPC message", e); - } - if (!data) return; + let data: Uint8Array | undefined; + try { + data = message.marshal(); + } catch (e) { + console.error("Failed to send HID RPC message", e); + } + if (!data) return; - rpcHidChannel?.send(data as unknown as ArrayBuffer); - }, [rpcHidChannel, rpcHidReady]); + if (useUnreliableChannel) { + if (requireOrdered && rpcHidUnreliableReady) { + rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer); + } else if (!requireOrdered && rpcHidUnreliableNonOrderedReady) { + rpcHidUnreliableNonOrderedChannel?.send(data as unknown as ArrayBuffer); + } + return; + } + + rpcHidChannel?.send(data as unknown as ArrayBuffer); + }, + [ + rpcHidChannel, + rpcHidUnreliableChannel, + hidRpcDisabled, rpcHidUnreliableNonOrderedChannel, + rpcHidReady, + rpcHidUnreliableReady, + rpcHidUnreliableNonOrderedReady, + ], + ); const reportKeyboardEvent = useCallback( (keys: number[], modifier: number) => { sendMessage(new KeyboardReportMessage(keys, modifier)); - }, [sendMessage], + }, + [sendMessage], ); const reportKeypressEvent = useCallback( @@ -56,7 +119,9 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { const reportAbsMouseEvent = useCallback( (x: number, y: number, buttons: number) => { - sendMessage(new PointerReportMessage(x, y, buttons)); + sendMessage(new PointerReportMessage(x, y, buttons), { + useUnreliableChannel: true, + }); }, [sendMessage], ); @@ -68,32 +133,57 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { [sendMessage], ); + const reportKeyboardMacroEvent = useCallback( + (steps: KeyboardMacroStep[]) => { + sendMessage(new KeyboardMacroReportMessage(false, steps.length, steps)); + }, + [sendMessage], + ); + + const cancelOngoingKeyboardMacro = useCallback( + () => { + sendMessage(new CancelKeyboardMacroReportMessage()); + }, + [sendMessage], + ); + + const reportKeypressKeepAlive = useCallback(() => { + sendMessage(KEEPALIVE_MESSAGE); + }, [sendMessage]); + const sendHandshake = useCallback(() => { + if (hidRpcDisabled) return; if (rpcHidProtocolVersion) return; if (!rpcHidChannel) return; - sendMessage(new HandshakeMessage(HID_RPC_VERSION), true); - }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage]); + sendMessage(new HandshakeMessage(HID_RPC_VERSION), { ignoreHandshakeState: true }); + }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage, hidRpcDisabled]); + + const handleHandshake = useCallback( + (message: HandshakeMessage) => { + if (hidRpcDisabled) return; - const handleHandshake = useCallback((message: HandshakeMessage) => { if (!message.version) { - console.error("Received handshake message without version", message); - return; - } + console.error("Received handshake message without version", message); + return; + } - if (message.version > HID_RPC_VERSION) { - // we assume that the UI is always using the latest version of the HID RPC protocol - // so we can't support this - // TODO: use capabilities to determine rather than version number - console.error("Server is using a newer HID RPC version than the client", message); - return; - } + if (message.version > HID_RPC_VERSION) { + // we assume that the UI is always using the latest version of the HID RPC protocol + // so we can't support this + // TODO: use capabilities to determine rather than version number + console.error("Server is using a newer HID RPC version than the client", message); + return; + } - setRpcHidProtocolVersion(message.version); - }, [setRpcHidProtocolVersion]); + setRpcHidProtocolVersion(message.version); + }, + [setRpcHidProtocolVersion, hidRpcDisabled], + ); useEffect(() => { if (!rpcHidChannel) return; + if (hidRpcDisabled) return; // send handshake message sendHandshake(); @@ -123,26 +213,42 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { onHidRpcMessage?.(message); }; + const openHandler = () => { + console.info("HID RPC channel opened"); + sendHandshake(); + }; + + const closeHandler = () => { + console.info("HID RPC channel closed"); + setRpcHidProtocolVersion(null); + }; + rpcHidChannel.addEventListener("message", messageHandler); + rpcHidChannel.addEventListener("close", closeHandler); + rpcHidChannel.addEventListener("open", openHandler); return () => { rpcHidChannel.removeEventListener("message", messageHandler); + rpcHidChannel.removeEventListener("close", closeHandler); + rpcHidChannel.removeEventListener("open", openHandler); }; - }, - [ - rpcHidChannel, - onHidRpcMessage, - setRpcHidProtocolVersion, - sendHandshake, - handleHandshake, - ], - ); + }, [ + rpcHidChannel, + onHidRpcMessage, + setRpcHidProtocolVersion, + sendHandshake, + handleHandshake, + hidRpcDisabled, + ]); return { reportKeyboardEvent, reportKeypressEvent, reportAbsMouseEvent, reportRelMouseEvent, + reportKeyboardMacroEvent, + cancelOngoingKeyboardMacro, + reportKeypressKeepAlive, rpcHidProtocolVersion, rpcHidReady, rpcHidStatus, diff --git a/ui/src/hooks/useJsonRpc.ts b/ui/src/hooks/useJsonRpc.ts index b4fcc8ef..5c52d59c 100644 --- a/ui/src/hooks/useJsonRpc.ts +++ b/ui/src/hooks/useJsonRpc.ts @@ -29,6 +29,8 @@ export interface JsonRpcErrorResponse { export type JsonRpcResponse = JsonRpcSuccessResponse | JsonRpcErrorResponse; +export const RpcMethodNotFound = -32601; + const callbackStore = new Map void>(); let requestCounter = 0; diff --git a/ui/src/hooks/useKeyboard.ts b/ui/src/hooks/useKeyboard.ts index 787df9a9..8d101b3b 100644 --- a/ui/src/hooks/useKeyboard.ts +++ b/ui/src/hooks/useKeyboard.ts @@ -1,15 +1,51 @@ -import { useCallback } from "react"; +import { useCallback, useRef } from "react"; -import { hidErrorRollOver, hidKeyBufferSize, KeysDownState, useHidStore, useRTCStore } from "@/hooks/stores"; -import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; +import { + KeyboardLedStateMessage, + KeyboardMacroStateMessage, + KeyboardMacroStep, + KeysDownStateMessage, +} from "@/hooks/hidRpc"; +import { + hidErrorRollOver, + hidKeyBufferSize, + KeysDownState, + useHidStore, + useRTCStore, +} from "@/hooks/stores"; import { useHidRpc } from "@/hooks/useHidRpc"; -import { KeyboardLedStateMessage, KeysDownStateMessage } from "@/hooks/hidRpc"; +import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; import { hidKeyToModifierMask, keys, modifiers } from "@/keyboardMappings"; +const MACRO_RESET_KEYBOARD_STATE = { + keys: new Array(hidKeyBufferSize).fill(0), + modifier: 0, + delay: 0, +}; + +export interface MacroStep { + keys: string[] | null; + modifiers: string[] | null; + delay: number; +} + +export type MacroSteps = MacroStep[]; + +const sleep = (ms: number): Promise => new Promise(resolve => setTimeout(resolve, ms)); + export default function useKeyboard() { const { send } = useJsonRpc(); const { rpcDataChannel } = useRTCStore(); - const { keysDownState, setKeysDownState, setKeyboardLedState } = useHidStore(); + const { keysDownState, setKeysDownState, setKeyboardLedState, setPasteModeEnabled } = + useHidStore(); + + const abortController = useRef(null); + const setAbortController = useCallback((ac: AbortController | null) => { + abortController.current = ac; + }, []); + + // Keepalive timer management + const keepAliveTimerRef = useRef(null); // INTRODUCTION: The earlier version of the JetKVM device shipped with all keyboard state // being tracked on the browser/client-side. When adding the keyPressReport API to the @@ -17,17 +53,19 @@ export default function useKeyboard() { // is running on the cloud against a device that has not been updated yet and thus does not // support the keyPressReport API. In that case, we need to handle the key presses locally // and send the full state to the device, so it can behave like a real USB HID keyboard. - // This flag indicates whether the keyPressReport API is available on the device which is - // dynamically set when the device responds to the first key press event or reports its - // keysDownState when queried since the keyPressReport was introduced together with the + // This flag indicates whether the keyPressReport API is available on the device which is + // dynamically set when the device responds to the first key press event or reports its // keysDownState when queried since the keyPressReport was introduced together with the // getKeysDownState API. // HidRPC is a binary format for exchanging keyboard and mouse events const { reportKeyboardEvent: sendKeyboardEventHidRpc, reportKeypressEvent: sendKeypressEventHidRpc, + reportKeyboardMacroEvent: sendKeyboardMacroEventHidRpc, + cancelOngoingKeyboardMacro: cancelOngoingKeyboardMacroHidRpc, + reportKeypressKeepAlive: sendKeypressKeepAliveHidRpc, rpcHidReady, - } = useHidRpc((message) => { + } = useHidRpc(message => { switch (message.constructor) { case KeysDownStateMessage: setKeysDownState((message as KeysDownStateMessage).keysDownState); @@ -35,81 +73,80 @@ export default function useKeyboard() { case KeyboardLedStateMessage: setKeyboardLedState((message as KeyboardLedStateMessage).keyboardLedState); break; + case KeyboardMacroStateMessage: + if (!(message as KeyboardMacroStateMessage).isPaste) break; + setPasteModeEnabled((message as KeyboardMacroStateMessage).state); + break; default: break; } }); - // sendKeyboardEvent is used to send the full keyboard state to the device for macro handling - // and resetting keyboard state. It sends the keys currently pressed and the modifier state. - // The device will respond with the keysDownState if it supports the keyPressReport API - // or just accept the state if it does not support (returning no result) - const sendKeyboardEvent = useCallback( - async (state: KeysDownState) => { - if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; - - console.debug(`Send keyboardReport keys: ${state.keys}, modifier: ${state.modifier}`); - - if (rpcHidReady) { - console.debug("Sending keyboard report via HidRPC"); - sendKeyboardEventHidRpc(state.keys, state.modifier); - return; - } - - send("keyboardReport", { keys: state.keys, modifier: state.modifier }, (resp: JsonRpcResponse) => { + const handleLegacyKeyboardReport = useCallback( + async (keys: number[], modifier: number) => { + send("keyboardReport", { keys, modifier }, (resp: JsonRpcResponse) => { if ("error" in resp) { - console.error(`Failed to send keyboard report ${state}`, resp.error); + console.error(`Failed to send keyboard report ${keys} ${modifier}`, resp.error); } + + // On older backends, we need to set the keysDownState manually since without the hidRpc API, the state doesn't trickle down from the backend + setKeysDownState({ modifier, keys }); }); }, - [ - rpcDataChannel?.readyState, - rpcHidReady, - send, - sendKeyboardEventHidRpc, - ], + [send, setKeysDownState], ); - // resetKeyboardState is used to reset the keyboard state to no keys pressed and no modifiers. - // This is useful for macros and when the browser loses focus to ensure that the keyboard state - // is clean. - const resetKeyboardState = useCallback( - async () => { - // Reset the keys buffer to zeros and the modifier state to zero - keysDownState.keys.length = hidKeyBufferSize; - keysDownState.keys.fill(0); - keysDownState.modifier = 0; - sendKeyboardEvent(keysDownState); - }, [keysDownState, sendKeyboardEvent]); + const sendKeystrokeLegacy = useCallback(async (keys: number[], modifier: number, ac?: AbortController) => { + return await new Promise((resolve, reject) => { + const abortListener = () => { + reject(new Error("Keyboard report aborted")); + }; - // executeMacro is used to execute a macro consisting of multiple steps. - // Each step can have multiple keys, multiple modifiers and a delay. - // The keys and modifiers are pressed together and held for the delay duration. - // After the delay, the keys and modifiers are released and the next step is executed. - // If a step has no keys or modifiers, it is treated as a delay-only step. - // A small pause is added between steps to ensure that the device can process the events. - const executeMacro = async (steps: { keys: string[] | null; modifiers: string[] | null; delay: number }[]) => { - for (const [index, step] of steps.entries()) { - const keyValues = (step.keys || []).map(key => keys[key]).filter(Boolean); - const modifierMask: number = (step.modifiers || []).map(mod => modifiers[mod]).reduce((acc, val) => acc + val, 0); + ac?.signal?.addEventListener("abort", abortListener); - // If the step has keys and/or modifiers, press them and hold for the delay - if (keyValues.length > 0 || modifierMask > 0) { - sendKeyboardEvent({ keys: keyValues, modifier: modifierMask }); - await new Promise(resolve => setTimeout(resolve, step.delay || 50)); + send( + "keyboardReport", + { keys, modifier }, + params => { + if ("error" in params) return reject(params.error); + resolve(); + }, + ); + }); + }, [send]); - resetKeyboardState(); - } else { - // This is a delay-only step, just wait for the delay amount - await new Promise(resolve => setTimeout(resolve, step.delay || 50)); - } + const KEEPALIVE_INTERVAL = 50; - // Add a small pause between steps if not the last step - if (index < steps.length - 1) { - await new Promise(resolve => setTimeout(resolve, 10)); - } + const cancelKeepAlive = useCallback(() => { + if (keepAliveTimerRef.current) { + clearInterval(keepAliveTimerRef.current); + keepAliveTimerRef.current = null; } - }; + }, []); + + const scheduleKeepAlive = useCallback(() => { + // Clears existing keepalive timer + cancelKeepAlive(); + + keepAliveTimerRef.current = setInterval(() => { + sendKeypressKeepAliveHidRpc(); + }, KEEPALIVE_INTERVAL); + }, [cancelKeepAlive, sendKeypressKeepAliveHidRpc]); + + // resetKeyboardState is used to reset the keyboard state to no keys pressed and no modifiers. + // This is useful for macros, in case of client-side rollover, and when the browser loses focus + const resetKeyboardState = useCallback(async () => { + // Cancel keepalive since we're resetting the keyboard state + cancelKeepAlive(); + // Reset the keys buffer to zeros and the modifier state to zero + const { keys, modifier } = MACRO_RESET_KEYBOARD_STATE; + if (rpcHidReady) { + sendKeyboardEventHidRpc(keys, modifier); + } else { + // Older backends don't support the hidRpc API, so we send the full reset state + handleLegacyKeyboardReport(keys, modifier); + } + }, [rpcHidReady, sendKeyboardEventHidRpc, handleLegacyKeyboardReport, cancelKeepAlive]); // handleKeyPress is used to handle a key press or release event. // This function handle both key press and key release events. @@ -117,6 +154,20 @@ export default function useKeyboard() { // If the keyPressReport API is not available, it simulates the device-side key // handling for legacy devices and updates the keysDownState accordingly. // It then sends the full keyboard state to the device. + + const sendKeypress = useCallback( + (key: number, press: boolean) => { + cancelKeepAlive(); + + sendKeypressEventHidRpc(key, press); + + if (press) { + scheduleKeepAlive(); + } + }, + [sendKeypressEventHidRpc, scheduleKeepAlive, cancelKeepAlive], + ); + const handleKeyPress = useCallback( async (key: number, press: boolean) => { if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; @@ -129,11 +180,18 @@ export default function useKeyboard() { // Older device version doesn't support this API, so we will switch to local key handling // In that case we will switch to local key handling and update the keysDownState // in client/browser-side code using simulateDeviceSideKeyHandlingForLegacyDevices. - sendKeypressEventHidRpc(key, press); + sendKeypress(key, press); } else { - // if the keyPress api is not available, we need to handle the key locally - const downState = simulateDeviceSideKeyHandlingForLegacyDevices(keysDownState, key, press); - sendKeyboardEvent(downState); // then we send the full state + // Older backends don't support the hidRpc API, so we need: + // 1. Calculate the state + // 2. Send the newly calculated state to the device + const downState = simulateDeviceSideKeyHandlingForLegacyDevices( + keysDownState, + key, + press, + ); + + handleLegacyKeyboardReport(downState.keys, downState.modifier); // if we just sent ErrorRollOver, reset to empty state if (downState.keys[0] === hidErrorRollOver) { @@ -142,17 +200,21 @@ export default function useKeyboard() { } }, [ + rpcDataChannel?.readyState, rpcHidReady, keysDownState, + handleLegacyKeyboardReport, resetKeyboardState, - rpcDataChannel?.readyState, - sendKeyboardEvent, - sendKeypressEventHidRpc, + sendKeypress, ], ); // IMPORTANT: See the keyPressReportApiAvailable comment above for the reason this exists - function simulateDeviceSideKeyHandlingForLegacyDevices(state: KeysDownState, key: number, press: boolean): KeysDownState { + function simulateDeviceSideKeyHandlingForLegacyDevices( + state: KeysDownState, + key: number, + press: boolean, + ): KeysDownState { // IMPORTANT: This code parallels the logic in the kernel's hid-gadget driver // for handling key presses and releases. It ensures that the USB gadget // behaves similarly to a real USB HID keyboard. This logic is paralleled @@ -164,7 +226,7 @@ export default function useKeyboard() { if (modifierMask !== 0) { // If the key is a modifier key, we update the keyboardModifier state // by setting or clearing the corresponding bit in the modifier byte. - // This allows us to track the state of dynamic modifier keys like + // This allows us to track the state of dynamic modifier keys like // Shift, Control, Alt, and Super. if (press) { modifiers |= modifierMask; @@ -181,7 +243,7 @@ export default function useKeyboard() { // and if we find a zero byte, we can place the key there (if press is true) if (keys[i] === key || keys[i] === 0) { if (press) { - keys[i] = key // overwrites the zero byte or the same key if already pressed + keys[i] = key; // overwrites the zero byte or the same key if already pressed } else { // we are releasing the key, remove it from the buffer if (keys[i] !== 0) { @@ -203,12 +265,113 @@ export default function useKeyboard() { keys.fill(hidErrorRollOver); } else { // If we are releasing a key, and we didn't find it in a slot, who cares? - console.debug(`key ${key} not found in buffer, nothing to release`) + console.debug(`key ${key} not found in buffer, nothing to release`); } } } return { modifier: modifiers, keys }; } - return { handleKeyPress, resetKeyboardState, executeMacro }; + // Cleanup function to cancel keepalive timer + const cleanup = useCallback(() => { + cancelKeepAlive(); + }, [cancelKeepAlive]); + + + // executeMacro is used to execute a macro consisting of multiple steps. + // Each step can have multiple keys, multiple modifiers and a delay. + // The keys and modifiers are pressed together and held for the delay duration. + // After the delay, the keys and modifiers are released and the next step is executed. + // If a step has no keys or modifiers, it is treated as a delay-only step. + // A small pause is added between steps to ensure that the device can process the events. + const executeMacroRemote = useCallback(async ( + steps: MacroSteps, + ) => { + const macro: KeyboardMacroStep[] = []; + + for (const [_, step] of steps.entries()) { + const keyValues = (step.keys || []).map(key => keys[key]).filter(Boolean); + const modifierMask: number = (step.modifiers || []) + + .map(mod => modifiers[mod]) + + .reduce((acc, val) => acc + val, 0); + + // If the step has keys and/or modifiers, press them and hold for the delay + if (keyValues.length > 0 || modifierMask > 0) { + macro.push({ keys: keyValues, modifier: modifierMask, delay: 20 }); + macro.push({ ...MACRO_RESET_KEYBOARD_STATE, delay: step.delay || 100 }); + } + } + + sendKeyboardMacroEventHidRpc(macro); + }, [sendKeyboardMacroEventHidRpc]); + const executeMacroClientSide = useCallback(async (steps: MacroSteps) => { + const promises: (() => Promise)[] = []; + + const ac = new AbortController(); + setAbortController(ac); + + for (const [_, step] of steps.entries()) { + const keyValues = (step.keys || []).map(key => keys[key]).filter(Boolean); + const modifierMask: number = (step.modifiers || []) + .map(mod => modifiers[mod]) + .reduce((acc, val) => acc + val, 0); + + // If the step has keys and/or modifiers, press them and hold for the delay + if (keyValues.length > 0 || modifierMask > 0) { + promises.push(() => sendKeystrokeLegacy(keyValues, modifierMask, ac)); + promises.push(() => resetKeyboardState()); + promises.push(() => sleep(step.delay || 100)); + } + } + + const runAll = async () => { + for (const promise of promises) { + // Check if we've been aborted before executing each promise + if (ac.signal.aborted) { + throw new Error("Macro execution aborted"); + } + await promise(); + } + } + + return await new Promise((resolve, reject) => { + // Set up abort listener + const abortListener = () => { + reject(new Error("Macro execution aborted")); + }; + + ac.signal.addEventListener("abort", abortListener); + + runAll() + .then(() => { + ac.signal.removeEventListener("abort", abortListener); + resolve(); + }) + .catch((error) => { + ac.signal.removeEventListener("abort", abortListener); + reject(error); + }); + }); + }, [sendKeystrokeLegacy, resetKeyboardState, setAbortController]); + const executeMacro = useCallback(async (steps: MacroSteps) => { + if (rpcHidReady) { + return executeMacroRemote(steps); + } + return executeMacroClientSide(steps); + }, [rpcHidReady, executeMacroRemote, executeMacroClientSide]); + + const cancelExecuteMacro = useCallback(async () => { + if (abortController.current) { + abortController.current.abort(); + } + if (!rpcHidReady) return; + // older versions don't support this API, + // and all paste actions are pure-frontend, + // we don't need to cancel it actually + cancelOngoingKeyboardMacroHidRpc(); + }, [rpcHidReady, cancelOngoingKeyboardMacroHidRpc, abortController]); + + return { handleKeyPress, resetKeyboardState, executeMacro, cleanup, cancelExecuteMacro }; } diff --git a/ui/src/hooks/useVersion.tsx b/ui/src/hooks/useVersion.tsx new file mode 100644 index 00000000..7341dacb --- /dev/null +++ b/ui/src/hooks/useVersion.tsx @@ -0,0 +1,79 @@ +import { useCallback } from "react"; + +import { useDeviceStore } from "@/hooks/stores"; +import { type JsonRpcResponse, RpcMethodNotFound, useJsonRpc } from "@/hooks/useJsonRpc"; +import notifications from "@/notifications"; + +export interface VersionInfo { + appVersion: string; + systemVersion: string; +} + +export interface SystemVersionInfo { + local: VersionInfo; + remote?: VersionInfo; + systemUpdateAvailable: boolean; + appUpdateAvailable: boolean; + error?: string; +} + +export function useVersion() { + const { + appVersion, + systemVersion, + setAppVersion, + setSystemVersion, + } = useDeviceStore(); + const { send } = useJsonRpc(); + const getVersionInfo = useCallback(() => { + return new Promise((resolve, reject) => { + send("getUpdateStatus", {}, (resp: JsonRpcResponse) => { + if ("error" in resp) { + notifications.error(`Failed to check for updates: ${resp.error}`); + reject(new Error("Failed to check for updates")); + } else { + const result = resp.result as SystemVersionInfo; + setAppVersion(result.local.appVersion); + setSystemVersion(result.local.systemVersion); + + if (result.error) { + notifications.error(`Failed to check for updates: ${result.error}`); + reject(new Error("Failed to check for updates")); + } else { + resolve(result); + } + } + }); + }); + }, [send, setAppVersion, setSystemVersion]); + + const getLocalVersion = useCallback(() => { + return new Promise((resolve, reject) => { + send("getLocalVersion", {}, (resp: JsonRpcResponse) => { + if ("error" in resp) { + console.log(resp.error) + if (resp.error.code === RpcMethodNotFound) { + console.warn("Failed to get device version, using legacy version"); + return getVersionInfo().then(result => resolve(result.local)).catch(reject); + } + console.error("Failed to get device version N", resp.error); + notifications.error(`Failed to get device version: ${resp.error}`); + reject(new Error("Failed to get device version")); + } else { + const result = resp.result as VersionInfo; + + setAppVersion(result.appVersion); + setSystemVersion(result.systemVersion); + resolve(result); + } + }); + }); + }, [send, setAppVersion, setSystemVersion, getVersionInfo]); + + return { + getVersionInfo, + getLocalVersion, + appVersion, + systemVersion, + }; +} \ No newline at end of file diff --git a/ui/src/index.css b/ui/src/index.css index ae23db2b..6eaae1f7 100644 --- a/ui/src/index.css +++ b/ui/src/index.css @@ -239,7 +239,7 @@ video::-webkit-media-controls { } .simple-keyboard-arrows .hg-button { - @apply flex w-[50px] grow-0 items-center justify-center; + @apply flex w-[50px] items-center justify-center; } .controlArrows { @@ -264,7 +264,7 @@ video::-webkit-media-controls { } .simple-keyboard-control .hg-button { - @apply flex w-[50px] grow-0 items-center justify-center; + @apply flex w-[50px] items-center justify-center; } .numPad { @@ -332,7 +332,7 @@ video::-webkit-media-controls { .keyboard-detached .simple-keyboard.hg-theme-default div.hg-button { text-wrap: auto; - text-align: center; + text-align: center; min-width: 6ch; } .keyboard-detached .simple-keyboard.hg-theme-default .hg-button span { diff --git a/ui/src/routes/devices.$id.settings.general.update.tsx b/ui/src/routes/devices.$id.settings.general.update.tsx index 80ba0f78..38c15412 100644 --- a/ui/src/routes/devices.$id.settings.general.update.tsx +++ b/ui/src/routes/devices.$id.settings.general.update.tsx @@ -3,12 +3,12 @@ import { useCallback, useEffect, useRef, useState } from "react"; import { CheckCircleIcon } from "@heroicons/react/20/solid"; import Card from "@/components/Card"; -import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; +import { useJsonRpc } from "@/hooks/useJsonRpc"; import { Button } from "@components/Button"; -import { UpdateState, useDeviceStore, useUpdateStore } from "@/hooks/stores"; -import notifications from "@/notifications"; +import { UpdateState, useUpdateStore } from "@/hooks/stores"; import LoadingSpinner from "@/components/LoadingSpinner"; import { useDeviceUiNavigation } from "@/hooks/useAppNavigation"; +import { SystemVersionInfo, useVersion } from "@/hooks/useVersion"; export default function SettingsGeneralUpdateRoute() { const navigate = useNavigate(); @@ -41,13 +41,7 @@ export default function SettingsGeneralUpdateRoute() { return navigate("..")} onConfirmUpdate={onConfirmUpdate} />; } -export interface SystemVersionInfo { - local: { appVersion: string; systemVersion: string }; - remote?: { appVersion: string; systemVersion: string }; - systemUpdateAvailable: boolean; - appUpdateAvailable: boolean; - error?: string; -} + export function Dialog({ onClose, @@ -134,30 +128,8 @@ function LoadingState({ }) { const [progressWidth, setProgressWidth] = useState("0%"); const abortControllerRef = useRef(null); - const { setAppVersion, setSystemVersion } = useDeviceStore(); - const { send } = useJsonRpc(); - const getVersionInfo = useCallback(() => { - return new Promise((resolve, reject) => { - send("getUpdateStatus", {}, (resp: JsonRpcResponse) => { - if ("error" in resp) { - notifications.error(`Failed to check for updates: ${resp.error}`); - reject(new Error("Failed to check for updates")); - } else { - const result = resp.result as SystemVersionInfo; - setAppVersion(result.local.appVersion); - setSystemVersion(result.local.systemVersion); - - if (result.error) { - notifications.error(`Failed to check for updates: ${result.error}`); - reject(new Error("Failed to check for updates")); - } else { - resolve(result); - } - } - }); - }); - }, [send, setAppVersion, setSystemVersion]); + const { getVersionInfo } = useVersion(); const progressBarRef = useRef(null); useEffect(() => { diff --git a/ui/src/routes/devices.$id.settings.network.tsx b/ui/src/routes/devices.$id.settings.network.tsx index d87eb2be..d1ac6966 100644 --- a/ui/src/routes/devices.$id.settings.network.tsx +++ b/ui/src/routes/devices.$id.settings.network.tsx @@ -166,11 +166,11 @@ export default function SettingsNetworkRoute() { }, [getNetworkState, getNetworkSettings]); const handleIpv4ModeChange = (value: IPv4Mode | string) => { - setNetworkSettings({ ...networkSettings, ipv4_mode: value as IPv4Mode }); + setNetworkSettingsRemote({ ...networkSettings, ipv4_mode: value as IPv4Mode }); }; const handleIpv6ModeChange = (value: IPv6Mode | string) => { - setNetworkSettings({ ...networkSettings, ipv6_mode: value as IPv6Mode }); + setNetworkSettingsRemote({ ...networkSettings, ipv6_mode: value as IPv6Mode }); }; const handleLldpModeChange = (value: LLDPMode | string) => { @@ -419,7 +419,7 @@ export default function SettingsNetworkRoute() { value={networkSettings.ipv6_mode} onChange={e => handleIpv6ModeChange(e.target.value)} options={filterUnknown([ - // { value: "disabled", label: "Disabled" }, + { value: "disabled", label: "Disabled" }, { value: "slaac", label: "SLAAC" }, // { value: "dhcpv6", label: "DHCPv6" }, // { value: "slaac_and_dhcpv6", label: "SLAAC and DHCPv6" }, diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx index c3ca97f8..1841e8bd 100644 --- a/ui/src/routes/devices.$id.tsx +++ b/ui/src/routes/devices.$id.tsx @@ -19,14 +19,12 @@ import { CLOUD_API, DEVICE_API } from "@/ui.config"; import api from "@/api"; import { checkAuth, isInCloud, isOnDevice } from "@/main"; import { cx } from "@/cva.config"; -import notifications from "@/notifications"; import { KeyboardLedState, KeysDownState, NetworkState, OtaState, USBStates, - useDeviceStore, useHidStore, useNetworkStateStore, User, @@ -44,7 +42,7 @@ const ConnectionStatsSidebar = lazy(() => import('@/components/sidebar/connectio const Terminal = lazy(() => import('@components/Terminal')); const UpdateInProgressStatusCard = lazy(() => import("@/components/UpdateInProgressStatusCard")); import Modal from "@/components/Modal"; -import { JsonRpcRequest, JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; +import { JsonRpcRequest, JsonRpcResponse, RpcMethodNotFound, useJsonRpc } from "@/hooks/useJsonRpc"; import { ConnectionFailedOverlay, LoadingConnectionOverlay, @@ -53,8 +51,8 @@ import { import { useDeviceUiNavigation } from "@/hooks/useAppNavigation"; import { FeatureFlagProvider } from "@/providers/FeatureFlagProvider"; import { DeviceStatus } from "@routes/welcome-local"; -import { SystemVersionInfo } from "@routes/devices.$id.settings.general.update"; import audioQualityService from "@/services/audioQualityService"; +import { useVersion } from "@/hooks/useVersion"; interface LocalLoaderResp { authMode: "password" | "noPassword" | null; @@ -139,6 +137,8 @@ export default function KvmIdRoute() { rpcDataChannel, setTransceiver, setRpcHidChannel, + setRpcHidUnreliableNonOrderedChannel, + setRpcHidUnreliableChannel, } = useRTCStore(); const location = useLocation(); @@ -513,6 +513,24 @@ export default function KvmIdRoute() { setRpcHidChannel(rpcHidChannel); }; + const rpcHidUnreliableChannel = pc.createDataChannel("hidrpc-unreliable-ordered", { + ordered: true, + maxRetransmits: 0, + }); + rpcHidUnreliableChannel.binaryType = "arraybuffer"; + rpcHidUnreliableChannel.onopen = () => { + setRpcHidUnreliableChannel(rpcHidUnreliableChannel); + }; + + const rpcHidUnreliableNonOrderedChannel = pc.createDataChannel("hidrpc-unreliable-nonordered", { + ordered: false, + maxRetransmits: 0, + }); + rpcHidUnreliableNonOrderedChannel.binaryType = "arraybuffer"; + rpcHidUnreliableNonOrderedChannel.onopen = () => { + setRpcHidUnreliableNonOrderedChannel(rpcHidUnreliableNonOrderedChannel); + }; + setPeerConnection(pc); }, [ cleanupAndStopReconnecting, @@ -524,6 +542,8 @@ export default function KvmIdRoute() { setPeerConnectionState, setRpcDataChannel, setRpcHidChannel, + setRpcHidUnreliableNonOrderedChannel, + setRpcHidUnreliableChannel, setTransceiver, ]); @@ -613,6 +633,7 @@ export default function KvmIdRoute() { keyboardLedState, setKeyboardLedState, keysDownState, setKeysDownState, setUsbState, } = useHidStore(); + const setHidRpcDisabled = useRTCStore(state => state.setHidRpcDisabled); const [hasUpdated, setHasUpdated] = useState(false); const { navigateTo } = useDeviceUiNavigation(); @@ -741,9 +762,10 @@ export default function KvmIdRoute() { send("getKeyDownState", {}, (resp: JsonRpcResponse) => { if ("error" in resp) { // -32601 means the method is not supported - if (resp.error.code === -32601) { + if (resp.error.code === RpcMethodNotFound) { // if we don't support key down state, we know key press is also not available console.warn("Failed to get key down state, switching to old-school", resp.error); + setHidRpcDisabled(true); } else { console.error("Failed to get key down state", resp.error); } @@ -754,7 +776,7 @@ export default function KvmIdRoute() { } setNeedKeyDownState(false); }); - }, [keysDownState, needKeyDownState, rpcDataChannel?.readyState, send, setKeysDownState]); + }, [keysDownState, needKeyDownState, rpcDataChannel?.readyState, send, setKeysDownState, setHidRpcDisabled]); // When the update is successful, we need to refresh the client javascript and show a success modal useEffect(() => { @@ -783,26 +805,13 @@ export default function KvmIdRoute() { if (location.pathname !== "/other-session") navigateTo("/"); }, [navigateTo, location.pathname]); - const { appVersion, setAppVersion, setSystemVersion} = useDeviceStore(); + const { appVersion, getLocalVersion} = useVersion(); useEffect(() => { if (appVersion) return; - send("getUpdateStatus", {}, (resp: JsonRpcResponse) => { - if ("error" in resp) { - notifications.error(`Failed to get device version: ${resp.error}`); - return - } - - const result = resp.result as SystemVersionInfo; - if (result.error) { - notifications.error(`Failed to get device version: ${result.error}`); - } - - setAppVersion(result.local.appVersion); - setSystemVersion(result.local.systemVersion); - }); - }, [appVersion, send, setAppVersion, setSystemVersion]); + getLocalVersion(); + }, [appVersion, getLocalVersion]); const ConnectionStatusElement = useMemo(() => { const hasConnectionFailed = diff --git a/usb.go b/usb.go index 131cd517..99287a30 100644 --- a/usb.go +++ b/usb.go @@ -33,7 +33,13 @@ func initUsbGadget() { gadget.SetOnKeysDownChange(func(state usbgadget.KeysDownState) { if currentSession != nil { - currentSession.reportHidRPCKeysDownState(state) + currentSession.enqueueKeysDownState(state) + } + }) + + gadget.SetOnKeepAliveReset(func() { + if currentSession != nil { + currentSession.resetKeepAliveTime() } }) @@ -43,11 +49,11 @@ func initUsbGadget() { } } -func rpcKeyboardReport(modifier byte, keys []byte) (usbgadget.KeysDownState, error) { +func rpcKeyboardReport(modifier byte, keys []byte) error { return gadget.KeyboardReport(modifier, keys) } -func rpcKeypressReport(key byte, press bool) (usbgadget.KeysDownState, error) { +func rpcKeypressReport(key byte, press bool) error { return gadget.KeypressReport(key, press) } diff --git a/web.go b/web.go index 8c8707a0..7f8a8600 100644 --- a/web.go +++ b/web.go @@ -238,6 +238,10 @@ func handleWebRTCSession(c *gin.Context) { _ = peerConn.Close() }() } + + // Cancel any ongoing keyboard macro when session changes + cancelKeyboardMacro() + currentSession = session c.JSON(http.StatusOK, gin.H{"sd": sd}) } @@ -576,14 +580,31 @@ func RunWebServer() { r := setupRouter() // Determine the binding address based on the config - bindAddress := ":80" // Default to all interfaces + var bindAddress string + listenPort := 80 // default port + useIPv4 := config.NetworkConfig.IPv4Mode.String != "disabled" + useIPv6 := config.NetworkConfig.IPv6Mode.String != "disabled" + if config.LocalLoopbackOnly { - bindAddress = "localhost:80" // Loopback only (both IPv4 and IPv6) + if useIPv4 && useIPv6 { + bindAddress = fmt.Sprintf("localhost:%d", listenPort) + } else if useIPv4 { + bindAddress = fmt.Sprintf("127.0.0.1:%d", listenPort) + } else if useIPv6 { + bindAddress = fmt.Sprintf("[::1]:%d", listenPort) + } + } else { + if useIPv4 && useIPv6 { + bindAddress = fmt.Sprintf(":%d", listenPort) + } else if useIPv4 { + bindAddress = fmt.Sprintf("0.0.0.0:%d", listenPort) + } else if useIPv6 { + bindAddress = fmt.Sprintf("[::]:%d", listenPort) + } } logger.Info().Str("bindAddress", bindAddress).Bool("loopbackOnly", config.LocalLoopbackOnly).Msg("Starting web server") - err := r.Run(bindAddress) - if err != nil { + if err := r.Run(bindAddress); err != nil { panic(err) } } diff --git a/webrtc.go b/webrtc.go index bb796b85..afb5ff27 100644 --- a/webrtc.go +++ b/webrtc.go @@ -17,6 +17,7 @@ import ( "github.com/jetkvm/kvm/internal/audio" "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/usbgadget" "github.com/pion/webrtc/v4" "github.com/rs/zerolog" ) @@ -39,9 +40,26 @@ type Session struct { rpcQueue chan webrtc.DataChannelMessage - hidRPCAvailable bool - hidQueueLock sync.Mutex - hidQueue []chan webrtc.DataChannelMessage + hidRPCAvailable bool + lastKeepAliveArrivalTime time.Time // Track when last keep-alive packet arrived + lastTimerResetTime time.Time // Track when auto-release timer was last reset + keepAliveJitterLock sync.Mutex // Protect jitter compensation timing state + hidQueueLock sync.Mutex + hidQueue []chan hidQueueMessage + + keysDownStateQueue chan usbgadget.KeysDownState +} + +func (s *Session) resetKeepAliveTime() { + s.keepAliveJitterLock.Lock() + defer s.keepAliveJitterLock.Unlock() + s.lastKeepAliveArrivalTime = time.Time{} // Reset keep-alive timing tracking + s.lastTimerResetTime = time.Time{} // Reset auto-release timer tracking +} + +type hidQueueMessage struct { + webrtc.DataChannelMessage + channel string } type SessionConfig struct { @@ -90,16 +108,85 @@ func (s *Session) initQueues() { s.hidQueueLock.Lock() defer s.hidQueueLock.Unlock() - s.hidQueue = make([]chan webrtc.DataChannelMessage, 0) + s.hidQueue = make([]chan hidQueueMessage, 0) for i := 0; i < 4; i++ { - q := make(chan webrtc.DataChannelMessage, 256) + q := make(chan hidQueueMessage, 256) s.hidQueue = append(s.hidQueue, q) } } func (s *Session) handleQueues(index int) { for msg := range s.hidQueue[index] { - onHidMessage(msg.Data, s) + onHidMessage(msg, s) + } +} + +const keysDownStateQueueSize = 64 + +func (s *Session) initKeysDownStateQueue() { + // serialise outbound key state reports so unreliable links can't stall input handling + s.keysDownStateQueue = make(chan usbgadget.KeysDownState, keysDownStateQueueSize) + go s.handleKeysDownStateQueue() +} + +func (s *Session) handleKeysDownStateQueue() { + for state := range s.keysDownStateQueue { + s.reportHidRPCKeysDownState(state) + } +} + +func (s *Session) enqueueKeysDownState(state usbgadget.KeysDownState) { + if s == nil || s.keysDownStateQueue == nil { + return + } + + select { + case s.keysDownStateQueue <- state: + default: + hidRPCLogger.Warn().Msg("dropping keys down state update; queue full") + } +} + +func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, channel string) func(msg webrtc.DataChannelMessage) { + return func(msg webrtc.DataChannelMessage) { + l := scopedLogger.With(). + Str("channel", channel). + Int("length", len(msg.Data)). + Logger() + // only log data if the log level is debug or lower + if scopedLogger.GetLevel() > zerolog.DebugLevel { + l = l.With().Str("data", string(msg.Data)).Logger() + } + + if msg.IsString { + l.Warn().Msg("received string data in HID RPC message handler") + return + } + + if len(msg.Data) < 1 { + l.Warn().Msg("received empty data in HID RPC message handler") + return + } + + l.Trace().Msg("received data in HID RPC message handler") + + // Enqueue to ensure ordered processing + queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) + if queueIndex >= len(session.hidQueue) || queueIndex < 0 { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") + queueIndex = 3 + } + + queue := session.hidQueue[queueIndex] + if queue != nil { + queue <- hidQueueMessage{ + DataChannelMessage: msg, + channel: channel, + } + } else { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") + return + } } } @@ -155,6 +242,7 @@ func newSession(config SessionConfig) (*Session, error) { session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) session.initQueues() + session.initKeysDownStateQueue() go func() { for msg := range session.rpcQueue { @@ -179,40 +267,12 @@ func newSession(config SessionConfig) (*Session, error) { switch d.Label() { case "hidrpc": session.HidChannel = d - d.OnMessage(func(msg webrtc.DataChannelMessage) { - l := scopedLogger.With().Int("length", len(msg.Data)).Logger() - // only log data if the log level is debug or lower - if scopedLogger.GetLevel() > zerolog.DebugLevel { - l = l.With().Str("data", string(msg.Data)).Logger() - } - - if msg.IsString { - l.Warn().Msg("received string data in HID RPC message handler") - return - } - - if len(msg.Data) < 1 { - l.Warn().Msg("received empty data in HID RPC message handler") - return - } - - l.Trace().Msg("received data in HID RPC message handler") - - // Enqueue to ensure ordered processing - queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) - if queueIndex >= len(session.hidQueue) || queueIndex < 0 { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") - queueIndex = 3 - } - - queue := session.hidQueue[queueIndex] - if queue != nil { - queue <- msg - } else { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") - return - } - }) + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc")) + // we won't send anything over the unreliable channels + case "hidrpc-unreliable-ordered": + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-ordered")) + case "hidrpc-unreliable-nonordered": + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-nonordered")) case "rpc": session.RPCChannel = d d.OnMessage(func(msg webrtc.DataChannelMessage) { @@ -338,6 +398,8 @@ func newSession(config SessionConfig) (*Session, error) { if connectionState == webrtc.ICEConnectionStateClosed { scopedLogger.Debug().Msg("ICE Connection State is closed, unmounting virtual media") if session == currentSession { + // Cancel any ongoing keyboard report multi when session closes + cancelKeyboardMacro() currentSession = nil } // Stop RPC processor @@ -352,6 +414,9 @@ func newSession(config SessionConfig) (*Session, error) { session.hidQueue[i] = nil } + close(session.keysDownStateQueue) + session.keysDownStateQueue = nil + if session.shouldUmountVirtualMedia { if err := rpcUnmountImage(); err != nil { scopedLogger.Warn().Err(err).Msg("unmount image failed on connection close")