diff --git a/hidrpc.go b/hidrpc.go new file mode 100644 index 0000000..5925488 --- /dev/null +++ b/hidrpc.go @@ -0,0 +1,150 @@ +package kvm + +import ( + "fmt" + "time" + + "github.com/jetkvm/kvm/internal/hidrpc" + "github.com/jetkvm/kvm/internal/usbgadget" +) + +func handleHidRpcMessage(message hidrpc.Message, session *Session) { + var rpcErr error + + switch message.Type() { + case hidrpc.TypeHandshake: + message, err := hidrpc.NewHandshakeMessage().Marshal() + if err != nil { + logger.Warn().Err(err).Msg("failed to marshal handshake message") + return + } + if err := session.HidChannel.Send(message); err != nil { + logger.Warn().Err(err).Msg("failed to send handshake message") + return + } + session.hidRpcAvailable = true + case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: + keysDownState, err := handleHidRpcKeyboardInput(message) + if keysDownState != nil { + reportHidRpcKeysDownState(*keysDownState, session) + } + rpcErr = err + case hidrpc.TypePointerReport: + pointerReport, err := message.PointerReport() + if err != nil { + logger.Warn().Err(err).Msg("failed to get pointer report") + return + } + rpcErr = rpcAbsMouseReport(pointerReport.X, pointerReport.Y, pointerReport.Button) + case hidrpc.TypeMouseReport: + mouseReport, err := message.MouseReport() + if err != nil { + logger.Warn().Err(err).Msg("failed to get mouse report") + return + } + rpcErr = rpcRelMouseReport(mouseReport.DX, mouseReport.DY, mouseReport.Button) + default: + logger.Warn().Uint8("type", uint8(message.Type())).Msg("unknown HID RPC message type") + } + + if rpcErr != nil { + logger.Warn().Err(rpcErr).Msg("failed to handle HID RPC message") + } +} + +func onHidMessage(data []byte, session *Session) { + scopedLogger := hidRpcLogger.With().Bytes("data", data).Logger() + scopedLogger.Debug().Msg("HID RPC message received") + + if len(data) < 1 { + scopedLogger.Warn().Int("length", len(data)).Msg("received empty data in HID RPC message handler") + return + } + + var message hidrpc.Message + + if err := hidrpc.Unmarshal(data, &message); err != nil { + scopedLogger.Warn().Err(err).Msg("failed to unmarshal HID RPC message") + return + } + + scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger() + + t := time.Now() + + r := make(chan interface{}) + go func() { + handleHidRpcMessage(message, session) + r <- nil + }() + select { + case <-time.After(1 * time.Second): + scopedLogger.Warn().Msg("HID RPC message timed out") + case <-r: + scopedLogger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC message handled") + } +} + +func handleHidRpcKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState, error) { + switch message.Type() { + case hidrpc.TypeKeypressReport: + keypressReport, err := message.KeypressReport() + if err != nil { + logger.Warn().Err(err).Msg("failed to get keypress report") + return nil, err + } + keysDownState, rpcError := rpcKeypressReport(keypressReport.Key, keypressReport.Press) + return &keysDownState, rpcError + case hidrpc.TypeKeyboardReport: + keyboardReport, err := message.KeyboardReport() + if err != nil { + logger.Warn().Err(err).Msg("failed to get keyboard report") + return nil, err + } + keysDownState, rpcError := rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) + return &keysDownState, rpcError + } + + return nil, fmt.Errorf("unknown HID RPC message type: %d", message.Type()) +} + +func reportHidRpc(params any, session *Session) { + var ( + message []byte + err error + ) + switch params := params.(type) { + case usbgadget.KeyboardState: + message, err = hidrpc.NewKeyboardLedMessage(params).Marshal() + case usbgadget.KeysDownState: + message, err = hidrpc.NewKeydownStateMessage(params).Marshal() + } + + if err != nil { + logger.Warn().Err(err).Msg("failed to marshal HID RPC message") + return + } + + if message == nil { + logger.Warn().Msg("failed to marshal HID RPC message") + return + } + + if err := session.HidChannel.Send(message); err != nil { + logger.Warn().Err(err).Msg("failed to send HID RPC message") + } +} + +func reportHidRpcKeyboardLedState(state usbgadget.KeyboardState, session *Session) { + if !session.hidRpcAvailable { + writeJSONRPCEvent("keyboardLedState", state, currentSession) + } + reportHidRpc(state, session) +} + +func reportHidRpcKeysDownState(state usbgadget.KeysDownState, session *Session) { + if !session.hidRpcAvailable { + writeJSONRPCEvent("keysDownState", state, currentSession) + } + reportHidRpc(state, session) +} diff --git a/internal/hidrpc/hidrpc.go b/internal/hidrpc/hidrpc.go new file mode 100644 index 0000000..dfe916c --- /dev/null +++ b/internal/hidrpc/hidrpc.go @@ -0,0 +1,91 @@ +package hidrpc + +import ( + "fmt" + + "github.com/jetkvm/kvm/internal/usbgadget" +) + +// HID RPC is a variable-length packet format that is used to exchange keyboard and mouse events between the client and the server. +// The packet format is as follows: +// 1 byte: Event Type + +// MessageType is the type of the HID RPC message +type MessageType uint8 + +const ( + TypeHandshake MessageType = 0x01 + TypeKeyboardReport MessageType = 0x02 + TypePointerReport MessageType = 0x03 + TypeWheelReport MessageType = 0x04 + TypeKeypressReport MessageType = 0x05 + TypeMouseReport MessageType = 0x06 + TypeKeyboardLedState MessageType = 0x32 + TypeKeydownState MessageType = 0x33 +) + +// ShouldUseEnqueue returns true if the message type should be enqueued to the HID queue. +func ShouldUseEnqueue(messageType MessageType) bool { + return messageType == TypeMouseReport +} + +// Unmarshal unmarshals the HID RPC message from the data. +func Unmarshal(data []byte, message *Message) error { + l := len(data) + if l < 1 { + return fmt.Errorf("invalid data length: %d", l) + } + + message.t = MessageType(data[0]) + message.d = data[1:] + return nil +} + +// Marshal marshals the HID RPC message to the data. +func Marshal(message *Message) ([]byte, error) { + if message.t == 0 { + return nil, fmt.Errorf("invalid message type: %d", message.t) + } + + data := make([]byte, len(message.d)+1) + data[0] = byte(message.t) + copy(data[1:], message.d) + + return data, nil +} + +// NewHandshakeMessage creates a new handshake message. +func NewHandshakeMessage() *Message { + return &Message{ + t: TypeHandshake, + d: []byte{}, + } +} + +// NewKeyboardReportMessage creates a new keyboard report message. +func NewKeyboardReportMessage(keys []byte, modifier uint8) *Message { + return &Message{ + t: TypeKeyboardReport, + d: append([]byte{modifier}, keys...), + } +} + +// NewKeyboardLedMessage creates a new keyboard LED message. +func NewKeyboardLedMessage(state usbgadget.KeyboardState) *Message { + return &Message{ + t: TypeKeyboardLedState, + d: []byte{state.Byte()}, + } +} + +// NewKeydownStateMessage creates a new keydown state message. +func NewKeydownStateMessage(state usbgadget.KeysDownState) *Message { + data := make([]byte, len(state.Keys)+1) + data[0] = state.Modifier + copy(data[1:], state.Keys) + + return &Message{ + t: TypeKeydownState, + d: data, + } +} diff --git a/internal/hidrpc/message.go b/internal/hidrpc/message.go new file mode 100644 index 0000000..1d00062 --- /dev/null +++ b/internal/hidrpc/message.go @@ -0,0 +1,121 @@ +package hidrpc + +import ( + "fmt" +) + +// Message .. +type Message struct { + t MessageType + d []byte +} + +// Marshal marshals the message to a byte array. +func (m *Message) Marshal() ([]byte, error) { + return Marshal(m) +} + +func (m *Message) Type() MessageType { + return m.t +} + +func (m *Message) String() string { + switch m.t { + case TypeHandshake: + return "Handshake" + case TypeKeypressReport: + return fmt.Sprintf("KeypressReport{Key: %d, Press: %v}", m.d[0], m.d[1] == uint8(1)) + case TypeKeyboardReport: + return fmt.Sprintf("KeyboardReport{Modifier: %d, Keys: %v}", m.d[0], m.d[1:]) + case TypePointerReport: + return fmt.Sprintf("PointerReport{X: %d, Y: %d, Button: %d}", m.d[0:4], m.d[4:8], m.d[8]) + case TypeMouseReport: + return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0:2], m.d[2:4], m.d[4]) + default: + return fmt.Sprintf("Unknown{Type: %d, Data: %v}", m.t, m.d) + } +} + +// KeypressReport .. +type KeypressReport struct { + Key byte + Press bool +} + +// KeypressReport returns the keypress report from the message. +func (m *Message) KeypressReport() (KeypressReport, error) { + if m.t != TypeKeypressReport { + return KeypressReport{}, fmt.Errorf("invalid message type: %d", m.t) + } + + return KeypressReport{ + Key: m.d[0], + Press: m.d[1] == uint8(1), + }, nil +} + +// KeyboardReport .. +type KeyboardReport struct { + Modifier byte + Keys []byte +} + +// KeyboardReport returns the keyboard report from the message. +func (m *Message) KeyboardReport() (KeyboardReport, error) { + if m.t != TypeKeyboardReport { + return KeyboardReport{}, fmt.Errorf("invalid message type: %d", m.t) + } + + return KeyboardReport{ + Modifier: m.d[0], + Keys: m.d[1:], + }, nil +} + +// PointerReport .. +type PointerReport struct { + X int + Y int + Button uint8 +} + +func toInt(b []byte) int { + return int(b[0])<<24 + int(b[1])<<16 + int(b[2])<<8 + int(b[3])<<0 +} + +// PointerReport returns the point report from the message. +func (m *Message) PointerReport() (PointerReport, error) { + if m.t != TypePointerReport { + return PointerReport{}, fmt.Errorf("invalid message type: %d", m.t) + } + + if len(m.d) != 9 { + return PointerReport{}, fmt.Errorf("invalid message length: %d", len(m.d)) + } + + return PointerReport{ + X: toInt(m.d[0:4]), + Y: toInt(m.d[4:8]), + Button: uint8(m.d[8]), + }, nil +} + +// MouseReport .. +type MouseReport struct { + DX int8 + DY int8 + Button uint8 +} + +// MouseReport returns the mouse report from the message. +func (m *Message) MouseReport() (MouseReport, error) { + if m.t != TypeMouseReport { + return MouseReport{}, fmt.Errorf("invalid message type: %d", m.t) + } + + return MouseReport{ + DX: int8(m.d[0]), + DY: int8(m.d[1]), + Button: uint8(m.d[2]), + }, nil +} diff --git a/internal/usbgadget/consts.go b/internal/usbgadget/consts.go index 8204d0a..958aecc 100644 --- a/internal/usbgadget/consts.go +++ b/internal/usbgadget/consts.go @@ -1,3 +1,7 @@ package usbgadget +import "time" + const dwc3Path = "/sys/bus/platform/drivers/dwc3" + +const hidWriteTimeout = 10 * time.Millisecond diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index f4fbaa6..8b433cd 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -86,6 +86,12 @@ type KeyboardState struct { Compose bool `json:"compose"` Kana bool `json:"kana"` Shift bool `json:"shift"` // This is not part of the main USB HID spec + raw byte +} + +// Byte returns the raw byte representation of the keyboard state. +func (k *KeyboardState) Byte() byte { + return k.raw } func getKeyboardState(b byte) KeyboardState { @@ -151,7 +157,9 @@ func (u *UsbGadget) updateKeyDownState(state KeysDownState) { u.keysDownState = state if u.onKeysDownChange != nil { + u.log.Trace().Interface("state", state).Msg("calling onKeysDownChange") (*u.onKeysDownChange)(state) + u.log.Trace().Interface("state", state).Msg("onKeysDownChange called") } } @@ -233,7 +241,7 @@ func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { return err } - _, err := u.keyboardHidFile.Write(append([]byte{modifier, 0x00}, keys[:hidKeyBufferSize]...)) + _, err := writeWithTimeout(u.keyboardHidFile, append([]byte{modifier, 0x00}, keys[:hidKeyBufferSize]...)) if err != nil { u.logWithSuppression("keyboardWriteHidFile", 100, u.log, err, "failed to write to hidg0") u.keyboardHidFile.Close() diff --git a/internal/usbgadget/hid_mouse_absolute.go b/internal/usbgadget/hid_mouse_absolute.go index c083b60..4f6f8d7 100644 --- a/internal/usbgadget/hid_mouse_absolute.go +++ b/internal/usbgadget/hid_mouse_absolute.go @@ -74,7 +74,7 @@ func (u *UsbGadget) absMouseWriteHidFile(data []byte) error { } } - _, err := u.absMouseHidFile.Write(data) + _, err := writeWithTimeout(u.absMouseHidFile, data) if err != nil { u.logWithSuppression("absMouseWriteHidFile", 100, u.log, err, "failed to write to hidg1") u.absMouseHidFile.Close() diff --git a/internal/usbgadget/hid_mouse_relative.go b/internal/usbgadget/hid_mouse_relative.go index 70cb72c..25ec2c1 100644 --- a/internal/usbgadget/hid_mouse_relative.go +++ b/internal/usbgadget/hid_mouse_relative.go @@ -64,7 +64,7 @@ func (u *UsbGadget) relMouseWriteHidFile(data []byte) error { } } - _, err := u.relMouseHidFile.Write(data) + _, err := writeWithTimeout(u.relMouseHidFile, data) if err != nil { u.logWithSuppression("relMouseWriteHidFile", 100, u.log, err, "failed to write to hidg2") u.relMouseHidFile.Close() diff --git a/internal/usbgadget/utils.go b/internal/usbgadget/utils.go index 05fcd3a..6c295d6 100644 --- a/internal/usbgadget/utils.go +++ b/internal/usbgadget/utils.go @@ -3,10 +3,13 @@ package usbgadget import ( "bytes" "encoding/json" + "errors" "fmt" + "os" "path/filepath" "strconv" "strings" + "time" "github.com/rs/zerolog" ) @@ -107,6 +110,23 @@ func compareFileContent(oldContent []byte, newContent []byte, looserMatch bool) return false } +func writeWithTimeout(file *os.File, data []byte) (n int, err error) { + if err := file.SetWriteDeadline(time.Now().Add(hidWriteTimeout)); err != nil { + return -1, err + } + + n, err = file.Write(data) + if err == nil { + return + } + + if errors.Is(err, os.ErrDeadlineExceeded) { + err = nil + } + + return +} + func (u *UsbGadget) logWithSuppression(counterName string, every int, logger *zerolog.Logger, err error, msg string, args ...any) { u.logSuppressionLock.Lock() defer u.logSuppressionLock.Unlock() diff --git a/log.go b/log.go index 1a091b1..8b8194e 100644 --- a/log.go +++ b/log.go @@ -19,6 +19,7 @@ var ( nbdLogger = logging.GetSubsystemLogger("nbd") timesyncLogger = logging.GetSubsystemLogger("timesync") jsonRpcLogger = logging.GetSubsystemLogger("jsonrpc") + hidRpcLogger = logging.GetSubsystemLogger("hidrpc") watchdogLogger = logging.GetSubsystemLogger("watchdog") websecureLogger = logging.GetSubsystemLogger("websecure") otaLogger = logging.GetSubsystemLogger("ota") diff --git a/ui/src/components/InfoBar.tsx b/ui/src/components/InfoBar.tsx index 29f159d..36f6e95 100644 --- a/ui/src/components/InfoBar.tsx +++ b/ui/src/components/InfoBar.tsx @@ -10,11 +10,13 @@ import { VideoState } from "@/hooks/stores"; import { keys, modifiers } from "@/keyboardMappings"; +import { useHidRpc } from "@/hooks/useHidRpc"; export default function InfoBar() { const { keysDownState } = useHidStore(); const { mouseX, mouseY, mouseMove } = useMouseStore(); - + const { rpcHidReady } = useHidRpc(); + const videoClientSize = useVideoStore( (state: VideoState) => `${Math.round(state.clientWidth)}x${Math.round(state.clientHeight)}`, ); @@ -100,6 +102,12 @@ export default function InfoBar() { {hdmiState} )} + {debugMode && ( +
+ HidRPC State: + {rpcHidReady ? "Ready" : "Not Ready"} +
+ )} {showPressedKeys && (
diff --git a/ui/src/components/WebRTCVideo.tsx b/ui/src/components/WebRTCVideo.tsx index 9e2f0f2..ba6ee5c 100644 --- a/ui/src/components/WebRTCVideo.tsx +++ b/ui/src/components/WebRTCVideo.tsx @@ -8,6 +8,7 @@ import InfoBar from "@components/InfoBar"; import notifications from "@/notifications"; import useKeyboard from "@/hooks/useKeyboard"; import { useJsonRpc } from "@/hooks/useJsonRpc"; +import { useHidRpc } from "@/hooks/useHidRpc"; import { cx } from "@/cva.config"; import { keys } from "@/keyboardMappings"; import { @@ -60,10 +61,11 @@ export default function WebRTCVideo() { // Misc states and hooks const { send } = useJsonRpc(); + const { reportAbsMouseEvent, reportRelMouseEvent, rpcHidReady } = useHidRpc(); // Video-related const handleResize = useCallback( - ( { width, height }: { width: number | undefined; height: number | undefined }) => { + ({ width, height }: { width: number | undefined; height: number | undefined }) => { if (!videoElm.current) return; // Do something with width and height, e.g.: setVideoClientSize(width || 0, height || 0); @@ -222,10 +224,22 @@ export default function WebRTCVideo() { if (settings.mouseMode !== "relative") return; // if we ignore the event, double-click will not work // if (x === 0 && y === 0 && buttons === 0) return; - send("relMouseReport", { dx: calcDelta(x), dy: calcDelta(y), buttons }); + const dx = calcDelta(x); + const dy = calcDelta(y); + if (rpcHidReady) { + reportRelMouseEvent(dx, dy, buttons); + } else { + send("relMouseReport", { dx, dy, buttons }); + } setMouseMove({ x, y, buttons }); }, - [send, setMouseMove, settings.mouseMode], + [ + send, + reportRelMouseEvent, + setMouseMove, + settings.mouseMode, + rpcHidReady, + ], ); const relMouseMoveHandler = useCallback( @@ -243,11 +257,21 @@ export default function WebRTCVideo() { const sendAbsMouseMovement = useCallback( (x: number, y: number, buttons: number) => { if (settings.mouseMode !== "absolute") return; - send("absMouseReport", { x, y, buttons }); + if (rpcHidReady) { + reportAbsMouseEvent(x, y, buttons); + } else { + send("absMouseReport", { x, y, buttons }); + } // We set that for the debug info bar setMousePosition(x, y); }, - [send, setMousePosition, settings.mouseMode], + [ + send, + reportAbsMouseEvent, + setMousePosition, + settings.mouseMode, + rpcHidReady, + ], ); const absMouseMoveHandler = useCallback( @@ -357,7 +381,7 @@ export default function WebRTCVideo() { } console.debug(`Key down: ${hidKey}`); handleKeyPress(hidKey, true); - + if (!isKeyboardLockActive && hidKey === keys.MetaLeft) { // If the left meta key was just pressed and we're not keyboard locked // we'll never see the keyup event because the browser is going to lose diff --git a/ui/src/hooks/stores.ts b/ui/src/hooks/stores.ts index a6dc95b..21cd7ed 100644 --- a/ui/src/hooks/stores.ts +++ b/ui/src/hooks/stores.ts @@ -105,6 +105,12 @@ export interface RTCState { setRpcDataChannel: (channel: RTCDataChannel) => void; rpcDataChannel: RTCDataChannel | null; + rpcHidProtocolVersion: number | null; + setRpcHidProtocolVersion: (version: number) => void; + + setRpcHidChannel: (channel: RTCDataChannel) => void; + rpcHidChannel: RTCDataChannel | null; + peerConnectionState: RTCPeerConnectionState | null; setPeerConnectionState: (state: RTCPeerConnectionState) => void; @@ -151,6 +157,12 @@ export const useRTCStore = create(set => ({ rpcDataChannel: null, setRpcDataChannel: (channel: RTCDataChannel) => set({ rpcDataChannel: channel }), + rpcHidProtocolVersion: null, + setRpcHidProtocolVersion: (version: number) => set({ rpcHidProtocolVersion: version }), + + rpcHidChannel: null, + setRpcHidChannel: (channel: RTCDataChannel) => set({ rpcHidChannel: channel }), + transceiver: null, setTransceiver: (transceiver: RTCRtpTransceiver) => set({ transceiver }), diff --git a/ui/src/hooks/useHidRpc.ts b/ui/src/hooks/useHidRpc.ts new file mode 100644 index 0000000..8dce581 --- /dev/null +++ b/ui/src/hooks/useHidRpc.ts @@ -0,0 +1,251 @@ +import { useCallback, useEffect, useMemo } from "react"; + +import { KeyboardLedState, KeysDownState, useRTCStore } from "@/hooks/stores"; + +export const HID_RPC_MESSAGE_TYPES = { + Handshake: 0x01, + KeyboardReport: 0x02, + PointerReport: 0x03, + WheelReport: 0x04, + KeypressReport: 0x05, + MouseReport: 0x06, + KeyboardLedState: 0x32, + KeysDownState: 0x33, +} + +export type HidRpcMessageType = typeof HID_RPC_MESSAGE_TYPES[keyof typeof HID_RPC_MESSAGE_TYPES]; + +const withinUint8Range = (value: number) => { + return value >= 0 && value <= 255; +}; + +const fromInt32toUint8 = (n: number) => { + if (n !== n >> 0) { + throw new Error(`Number ${n} is not within the int32 range`); + } + + return new Uint8Array([ + (n >> 24) & 0xFF, + (n >> 16) & 0xFF, + (n >> 8) & 0xFF, + (n >> 0) & 0xFF, + ]); +}; + +const fromInt8ToUint8 = (n: number) => { + if (n < -128 || n > 127) { + throw new Error(`Number ${n} is not within the int8 range`); + } + + return (n >> 0) & 0xFF; +}; + +const keyboardLedStateMasks = { + num_lock: 1 << 0, + caps_lock: 1 << 1, + scroll_lock: 1 << 2, + compose: 1 << 3, + kana: 1 << 4, + shift: 1 << 6, +} + +export const toKeyboardLedState = (s: number): KeyboardLedState => { + if (!withinUint8Range(s)) { + throw new Error(`State ${s} is not within the uint8 range`); + } + + return { + num_lock: (s & keyboardLedStateMasks.num_lock) !== 0, + caps_lock: (s & keyboardLedStateMasks.caps_lock) !== 0, + scroll_lock: (s & keyboardLedStateMasks.scroll_lock) !== 0, + compose: (s & keyboardLedStateMasks.compose) !== 0, // TODO: check if this is correct + kana: (s & keyboardLedStateMasks.kana) !== 0, + shift: (s & keyboardLedStateMasks.shift) !== 0, + } as KeyboardLedState; +}; + +const toPointerReportEvent = (x: number, y: number, buttons: number) => { + if (!withinUint8Range(buttons)) { + throw new Error(`Buttons ${buttons} is not within the uint8 range`); + } + + return new Uint8Array([ + HID_RPC_MESSAGE_TYPES.PointerReport, + ...fromInt32toUint8(x), + ...fromInt32toUint8(y), + buttons, + ]); +}; + +const toMouseReportEvent = (dx: number, dy: number, buttons: number) => { + if (!withinUint8Range(buttons)) { + throw new Error(`Buttons ${buttons} is not within the uint8 range`); + } + return new Uint8Array([ + HID_RPC_MESSAGE_TYPES.MouseReport, + fromInt8ToUint8(dx), + fromInt8ToUint8(dy), + buttons, + ]); +}; + +const toKeyboardReportEvent = (keys: number[], modifier: number) => { + if (!withinUint8Range(modifier)) { + throw new Error(`Modifier ${modifier} is not within the uint8 range`); + } + + keys.forEach((k) => { + if (!withinUint8Range(k)) { + throw new Error(`Key ${k} is not within the uint8 range`); + } + }); + + return new Uint8Array([ + HID_RPC_MESSAGE_TYPES.KeyboardReport, + modifier, + ...keys, + ]); +}; + +const toKeypressReportEvent = (key: number, press: boolean) => { + if (!withinUint8Range(key)) { + throw new Error(`Key ${key} is not within the uint8 range`); + } + + return new Uint8Array([ + HID_RPC_MESSAGE_TYPES.KeypressReport, + key, + press ? 1 : 0, + ]); +}; + +const toHandshakeMessage = () => { + return new Uint8Array([HID_RPC_MESSAGE_TYPES.Handshake]); +}; + +export interface HidRpcMessage { + type: HidRpcMessageType; + keysDownState?: KeysDownState; +} + +const unmarshalHidRpcMessage = (data: Uint8Array): HidRpcMessage | undefined => { + if (data.length < 1) { + throw new Error(`Invalid HID RPC message length: ${data.length}`); + } + + const payload = data.slice(1); + + switch (data[0]) { + case HID_RPC_MESSAGE_TYPES.Handshake: + return { + type: HID_RPC_MESSAGE_TYPES.Handshake, + }; + case HID_RPC_MESSAGE_TYPES.KeysDownState: + return { + type: HID_RPC_MESSAGE_TYPES.KeysDownState, + keysDownState: { + modifier: payload[0], + keys: Array.from(payload.slice(1)) + }, + }; + default: + throw new Error(`Unknown HID RPC message type: ${data[0]}`); + } +}; + +export function useHidRpc(onHidRpcMessage?: (payload: HidRpcMessage) => void) { + const { rpcHidChannel, setRpcHidProtocolVersion, rpcHidProtocolVersion } = useRTCStore(); + const rpcHidReady = useMemo(() => { + return rpcHidChannel?.readyState === "open" && rpcHidProtocolVersion !== null; + }, [rpcHidChannel, rpcHidProtocolVersion]); + + const reportKeyboardEvent = useCallback( + (keys: number[], modifier: number) => { + if (!rpcHidReady) return; + rpcHidChannel?.send(toKeyboardReportEvent(keys, modifier)); + }, + [rpcHidChannel, rpcHidReady], + ); + + const reportKeypressEvent = useCallback( + (key: number, press: boolean) => { + if (!rpcHidReady) return; + rpcHidChannel?.send(toKeypressReportEvent(key, press)); + }, + [rpcHidChannel, rpcHidReady], + ); + + const reportAbsMouseEvent = useCallback( + (x: number, y: number, buttons: number) => { + if (!rpcHidReady) return; + rpcHidChannel?.send(toPointerReportEvent(x, y, buttons)); + }, + [rpcHidChannel, rpcHidReady], + ); + + const reportRelMouseEvent = useCallback( + (dx: number, dy: number, buttons: number) => { + if (!rpcHidReady) return; + rpcHidChannel?.send(toMouseReportEvent(dx, dy, buttons)); + }, + [rpcHidChannel, rpcHidReady], + ); + + const doHandshake = useCallback(() => { + if (rpcHidProtocolVersion) return; + if (!rpcHidChannel) return; + + rpcHidChannel?.send(toHandshakeMessage()); + }, [rpcHidChannel, rpcHidProtocolVersion]); + + useEffect(() => { + if (!rpcHidChannel) return; + + // send handshake message + doHandshake(); + + const messageHandler = (e: MessageEvent) => { + if (typeof e.data === "string") { + console.warn("Received string data in HID RPC message handler", e.data); + return; + } + + console.debug("Received HID RPC message", e.data); + + const message = unmarshalHidRpcMessage(new Uint8Array(e.data)); + if (!message) { + console.warn("Received invalid HID RPC message", e.data); + return; + } + + if (message.type === HID_RPC_MESSAGE_TYPES.Handshake) { + setRpcHidProtocolVersion(1); + } + + onHidRpcMessage?.(message); + }; + + rpcHidChannel.addEventListener("message", messageHandler); + + return () => { + rpcHidChannel.removeEventListener("message", messageHandler); + }; + }, + [ + rpcHidChannel, + onHidRpcMessage, + setRpcHidProtocolVersion, + doHandshake, + rpcHidReady, + ], + ); + + return { + reportKeyboardEvent, + reportKeypressEvent, + reportAbsMouseEvent, + reportRelMouseEvent, + rpcHidProtocolVersion, + rpcHidReady, + }; +} diff --git a/ui/src/hooks/useKeyboard.ts b/ui/src/hooks/useKeyboard.ts index 5f587b0..ae790bb 100644 --- a/ui/src/hooks/useKeyboard.ts +++ b/ui/src/hooks/useKeyboard.ts @@ -1,7 +1,8 @@ import { useCallback } from "react"; -import { KeysDownState, useHidStore, useRTCStore, hidKeyBufferSize, hidErrorRollOver } from "@/hooks/stores"; +import { hidErrorRollOver, hidKeyBufferSize, KeysDownState, useHidStore, useRTCStore } from "@/hooks/stores"; import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; +import { HID_RPC_MESSAGE_TYPES, useHidRpc } from "@/hooks/useHidRpc"; import { hidKeyToModifierMask, keys, modifiers } from "@/keyboardMappings"; export default function useKeyboard() { @@ -19,7 +20,19 @@ export default function useKeyboard() { // dynamically set when the device responds to the first key press event or reports its // keysDownState when queried since the keyPressReport was introduced together with the // getKeysDownState API. - const { keyPressReportApiAvailable, setkeyPressReportApiAvailable} = useHidStore(); + const { keyPressReportApiAvailable, setkeyPressReportApiAvailable } = useHidStore(); + + // HidRPC is a binary format for exchanging keyboard and mouse events + const { reportKeyboardEvent, reportKeypressEvent, rpcHidReady } = useHidRpc((message) => { + if (message.type === HID_RPC_MESSAGE_TYPES.KeysDownState) { + if (!message.keysDownState) { + return; + } + + setKeysDownState(message.keysDownState); + setkeyPressReportApiAvailable(true); + } + }); // sendKeyboardEvent is used to send the full keyboard state to the device for macro handling // and resetting keyboard state. It sends the keys currently pressed and the modifier state. @@ -27,9 +40,16 @@ export default function useKeyboard() { // or just accept the state if it does not support (returning no result) const sendKeyboardEvent = useCallback( async (state: KeysDownState) => { - if (rpcDataChannel?.readyState !== "open") return; + if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; console.debug(`Send keyboardReport keys: ${state.keys}, modifier: ${state.modifier}`); + + if (rpcHidReady) { + console.debug("Sending keyboard report via HidRPC"); + reportKeyboardEvent(state.keys, state.modifier); + return; + } + send("keyboardReport", { keys: state.keys, modifier: state.modifier }, (resp: JsonRpcResponse) => { if ("error" in resp) { console.error(`Failed to send keyboard report ${state}`, resp.error); @@ -44,13 +64,20 @@ export default function useKeyboard() { } else { // older devices versions do not return the keyDownState // so we just pretend they accepted what we sent - setKeysDownState(state); + setKeysDownState(state); setkeyPressReportApiAvailable(false); // we ALSO know they do not support keyPressReport } } }); }, - [rpcDataChannel?.readyState, send, setKeysDownState, setkeyPressReportApiAvailable], + [ + rpcDataChannel?.readyState, + rpcHidReady, + send, + reportKeyboardEvent, + setKeysDownState, + setkeyPressReportApiAvailable, + ], ); // sendKeypressEvent is used to send a single key press/release event to the device. @@ -61,9 +88,16 @@ export default function useKeyboard() { // in client/browser-side code using simulateDeviceSideKeyHandlingForLegacyDevices. const sendKeypressEvent = useCallback( async (key: number, press: boolean) => { - if (rpcDataChannel?.readyState !== "open") return; + if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; console.debug(`Send keypressEvent key: ${key}, press: ${press}`); + + if (rpcHidReady) { + console.debug("Sending keypress event via HidRPC"); + reportKeypressEvent(key, press); + return; + } + send("keypressReport", { key, press }, (resp: JsonRpcResponse) => { if ("error" in resp) { // -32601 means the method is not supported because the device is running an older version @@ -83,7 +117,14 @@ export default function useKeyboard() { } }); }, - [rpcDataChannel?.readyState, send, setkeyPressReportApiAvailable, setKeysDownState], + [ + rpcDataChannel?.readyState, + rpcHidReady, + send, + setkeyPressReportApiAvailable, + setKeysDownState, + reportKeypressEvent, + ], ); // resetKeyboardState is used to reset the keyboard state to no keys pressed and no modifiers. @@ -135,9 +176,15 @@ export default function useKeyboard() { // It then sends the full keyboard state to the device. const handleKeyPress = useCallback( async (key: number, press: boolean) => { - if (rpcDataChannel?.readyState !== "open") return; + if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; if ((key || 0) === 0) return; // ignore zero key presses (they are bad mappings) + if (rpcHidReady) { + console.debug("Sending keypress event via HidRPC"); + reportKeypressEvent(key, press); + return; + } + if (keyPressReportApiAvailable) { // if the keyPress api is available, we can just send the key press event sendKeypressEvent(key, press); @@ -152,7 +199,16 @@ export default function useKeyboard() { } } }, - [keyPressReportApiAvailable, keysDownState, resetKeyboardState, rpcDataChannel?.readyState, sendKeyboardEvent, sendKeypressEvent], + [ + keyPressReportApiAvailable, + keysDownState, + resetKeyboardState, + rpcDataChannel?.readyState, + rpcHidReady, + sendKeyboardEvent, + sendKeypressEvent, + reportKeypressEvent, + ], ); // IMPORTANT: See the keyPressReportApiAvailable comment above for the reason this exists diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx index 9be05f6..7737029 100644 --- a/ui/src/routes/devices.$id.tsx +++ b/ui/src/routes/devices.$id.tsx @@ -135,7 +135,8 @@ export default function KvmIdRoute() { setRpcDataChannel, isTurnServerInUse, setTurnServerInUse, rpcDataChannel, - setTransceiver + setTransceiver, + setRpcHidChannel, } = useRTCStore(); const location = useLocation(); @@ -482,6 +483,12 @@ export default function KvmIdRoute() { setRpcDataChannel(rpcDataChannel); }; + const rpcHidChannel = pc.createDataChannel("hidrpc"); + rpcHidChannel.binaryType = "arraybuffer"; + rpcHidChannel.onopen = () => { + setRpcHidChannel(rpcHidChannel); + }; + setPeerConnection(pc); }, [ cleanupAndStopReconnecting, @@ -492,6 +499,7 @@ export default function KvmIdRoute() { setPeerConnection, setPeerConnectionState, setRpcDataChannel, + setRpcHidChannel, setTransceiver, ]); diff --git a/ui/vite.config.ts b/ui/vite.config.ts index 5871c4b..44eec3a 100644 --- a/ui/vite.config.ts +++ b/ui/vite.config.ts @@ -28,6 +28,9 @@ export default defineConfig(({ mode, command }) => { return { plugins, + esbuild: { + pure: ["console.debug"], + }, build: { outDir: isCloud ? "dist" : "../static" }, server: { host: "0.0.0.0", diff --git a/usb.go b/usb.go index d29e01a..8c06719 100644 --- a/usb.go +++ b/usb.go @@ -27,13 +27,13 @@ func initUsbGadget() { gadget.SetOnKeyboardStateChange(func(state usbgadget.KeyboardState) { if currentSession != nil { - writeJSONRPCEvent("keyboardLedState", state, currentSession) + reportHidRpcKeyboardLedState(state, currentSession) } }) gadget.SetOnKeysDownChange(func(state usbgadget.KeysDownState) { if currentSession != nil { - writeJSONRPCEvent("keysDownState", state, currentSession) + reportHidRpcKeysDownState(state, currentSession) } }) diff --git a/webrtc.go b/webrtc.go index c0f159a..d801ef8 100644 --- a/webrtc.go +++ b/webrtc.go @@ -22,7 +22,10 @@ type Session struct { RPCChannel *webrtc.DataChannel HidChannel *webrtc.DataChannel shouldUmountVirtualMedia bool - rpcQueue chan webrtc.DataChannelMessage + + hidRpcAvailable bool + hidQueue chan webrtc.DataChannelMessage + rpcQueue chan webrtc.DataChannelMessage } type SessionConfig struct { @@ -105,17 +108,51 @@ func newSession(config SessionConfig) (*Session, error) { scopedLogger.Warn().Err(err).Msg("Failed to create PeerConnection") return nil, err } + session := &Session{peerConnection: peerConnection} session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) + session.hidQueue = make(chan webrtc.DataChannelMessage, 1024) + go func() { for msg := range session.rpcQueue { - onRPCMessage(msg, session) + // TODO: only use goroutine if the task is asynchronous + go onRPCMessage(msg, session) + } + }() + + go func() { + for msg := range session.hidQueue { + onHidMessage(msg.Data, session) } }() peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { + defer func() { + if r := recover(); r != nil { + scopedLogger.Warn().Interface("error", r).Msg("Recovered from panic in DataChannel handler") + } + }() + scopedLogger.Info().Str("label", d.Label()).Uint16("id", *d.ID()).Msg("New DataChannel") switch d.Label() { + case "hidrpc": + session.HidChannel = d + d.OnMessage(func(msg webrtc.DataChannelMessage) { + if msg.IsString { + scopedLogger.Warn().Str("data", string(msg.Data)).Msg("received string data in HID RPC message handler") + return + } + + if len(msg.Data) < 1 { + scopedLogger.Warn().Int("length", len(msg.Data)).Msg("received empty data in HID RPC message handler") + return + } + + scopedLogger.Debug().Str("data", string(msg.Data)).Msg("received data in HID RPC message handler") + + // Enqueue to ensure ordered processing + session.hidQueue <- msg + }) case "rpc": session.RPCChannel = d d.OnMessage(func(msg webrtc.DataChannelMessage) {