From d61ea2195bc4b73246683dd3d6f17583efeacf53 Mon Sep 17 00:00:00 2001 From: Siyuan Miao Date: Sat, 30 Aug 2025 14:56:43 +0200 Subject: [PATCH] refactor hidrpc marshal / unmarshal --- internal/hidrpc/message.go | 14 +- internal/usbgadget/hid_keyboard.go | 1 + ui/src/hooks/hidRpc.ts | 303 +++++++++++++++++++++++++++++ ui/src/hooks/useHidRpc.ts | 253 ++++++------------------ ui/src/hooks/useKeyboard.ts | 24 ++- 5 files changed, 394 insertions(+), 201 deletions(-) create mode 100644 ui/src/hooks/hidRpc.ts diff --git a/internal/hidrpc/message.go b/internal/hidrpc/message.go index 1d00062..84bbda7 100644 --- a/internal/hidrpc/message.go +++ b/internal/hidrpc/message.go @@ -24,13 +24,25 @@ func (m *Message) String() string { case TypeHandshake: return "Handshake" case TypeKeypressReport: + if len(m.d) < 2 { + return fmt.Sprintf("KeypressReport{Malformed: %v}", m.d) + } return fmt.Sprintf("KeypressReport{Key: %d, Press: %v}", m.d[0], m.d[1] == uint8(1)) case TypeKeyboardReport: + if len(m.d) < 2 { + return fmt.Sprintf("KeyboardReport{Malformed: %v}", m.d) + } return fmt.Sprintf("KeyboardReport{Modifier: %d, Keys: %v}", m.d[0], m.d[1:]) case TypePointerReport: + if len(m.d) < 9 { + return fmt.Sprintf("PointerReport{Malformed: %v}", m.d) + } return fmt.Sprintf("PointerReport{X: %d, Y: %d, Button: %d}", m.d[0:4], m.d[4:8], m.d[8]) case TypeMouseReport: - return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0:2], m.d[2:4], m.d[4]) + if len(m.d) < 3 { + return fmt.Sprintf("MouseReport{Malformed: %v}", m.d) + } + return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0], m.d[1], m.d[2]) default: return fmt.Sprintf("Unknown{Type: %d, Data: %v}", m.t, m.d) } diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index 8b433cd..61b6115 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -103,6 +103,7 @@ func getKeyboardState(b byte) KeyboardState { Compose: b&KeyboardLedMaskCompose != 0, Kana: b&KeyboardLedMaskKana != 0, Shift: b&KeyboardLedMaskShift != 0, + raw: b, } } diff --git a/ui/src/hooks/hidRpc.ts b/ui/src/hooks/hidRpc.ts new file mode 100644 index 0000000..7910039 --- /dev/null +++ b/ui/src/hooks/hidRpc.ts @@ -0,0 +1,303 @@ +import { KeyboardLedState, KeysDownState } from "./stores"; + +export const HID_RPC_MESSAGE_TYPES = { + Handshake: 0x01, + KeyboardReport: 0x02, + PointerReport: 0x03, + WheelReport: 0x04, + KeypressReport: 0x05, + MouseReport: 0x06, + KeyboardLedState: 0x32, + KeysDownState: 0x33, +} + +export type HidRpcMessageType = typeof HID_RPC_MESSAGE_TYPES[keyof typeof HID_RPC_MESSAGE_TYPES]; + +export const HID_RPC_VERSION = 0x01; + +const withinUint8Range = (value: number) => { + return value >= 0 && value <= 255; +}; + +const fromInt32toUint8 = (n: number) => { + if (n !== n >> 0) { + throw new Error(`Number ${n} is not within the int32 range`); + } + + return new Uint8Array([ + (n >> 24) & 0xFF, + (n >> 16) & 0xFF, + (n >> 8) & 0xFF, + (n >> 0) & 0xFF, + ]); +}; + +const fromInt8ToUint8 = (n: number) => { + if (n < -128 || n > 127) { + throw new Error(`Number ${n} is not within the int8 range`); + } + + return (n >> 0) & 0xFF; +}; + +const keyboardLedStateMasks = { + num_lock: 1 << 0, + caps_lock: 1 << 1, + scroll_lock: 1 << 2, + compose: 1 << 3, + kana: 1 << 4, + shift: 1 << 6, +} + +export class RpcMessage { + messageType: HidRpcMessageType; + + constructor(messageType: HidRpcMessageType) { + this.messageType = messageType; + } + + marshal(): Uint8Array { + throw new Error("Not implemented"); + } + + // @ts-expect-error: this is a base class, so we don't need to implement it + public static unmarshal(data: Uint8Array): RpcMessage | undefined { + throw new Error("Not implemented"); + } +} + +export class HandshakeMessage extends RpcMessage { + version: number; + + constructor(version: number) { + super(HID_RPC_MESSAGE_TYPES.Handshake); + this.version = version; + } + + marshal(): Uint8Array { + return new Uint8Array([this.messageType, this.version]); + } + + public static unmarshal(data: Uint8Array): HandshakeMessage | undefined { + if (data.length < 1) { + throw new Error(`Invalid handshake message length: ${data.length}`); + } + + return new HandshakeMessage(data[0]); + } +} + +export class KeypressReportMessage extends RpcMessage { + private _key = 0; + private _press = false; + + get key(): number { + return this._key; + } + + set key(value: number) { + if (!withinUint8Range(value)) { + throw new Error(`Key ${value} is not within the uint8 range`); + } + + this._key = value; + } + + get press(): boolean { + return this._press; + } + + set press(value: boolean) { + this._press = value; + } + + constructor(key: number, press: boolean) { + super(HID_RPC_MESSAGE_TYPES.KeypressReport); + this.key = key; + this.press = press; + } + + marshal(): Uint8Array { + return new Uint8Array([ + this.messageType, + this.key, + this.press ? 1 : 0, + ]); + } + + public static unmarshal(data: Uint8Array): KeypressReportMessage | undefined { + if (data.length < 1) { + throw new Error(`Invalid keypress report message length: ${data.length}`); + } + + return new KeypressReportMessage(data[0], data[1] === 1); + } +} + +export class KeyboardReportMessage extends RpcMessage { + private _keys: number[] = []; + private _modifier = 0; + + get keys(): number[] { + return this._keys; + } + + set keys(value: number[]) { + value.forEach((k) => { + if (!withinUint8Range(k)) { + throw new Error(`Key ${k} is not within the uint8 range`); + } + }); + + this._keys = value; + } + + get modifier(): number { + return this._modifier; + } + + set modifier(value: number) { + if (!withinUint8Range(value)) { + throw new Error(`Modifier ${value} is not within the uint8 range`); + } + + this._modifier = value; + } + + constructor(keys: number[], modifier: number) { + super(HID_RPC_MESSAGE_TYPES.KeyboardReport); + this.keys = keys; + this.modifier = modifier; + } + + marshal(): Uint8Array { + return new Uint8Array([ + this.messageType, + this.modifier, + ...this.keys, + ]); + } + + public static unmarshal(data: Uint8Array): KeyboardReportMessage | undefined { + if (data.length < 1) { + throw new Error(`Invalid keyboard report message length: ${data.length}`); + } + + return new KeyboardReportMessage(Array.from(data.slice(1)), data[0]); + } +} + +export class KeyboardLedStateMessage extends RpcMessage { + keyboardLedState: KeyboardLedState; + + constructor(keyboardLedState: KeyboardLedState) { + super(HID_RPC_MESSAGE_TYPES.KeyboardLedState); + this.keyboardLedState = keyboardLedState; + } + + public static unmarshal(data: Uint8Array): KeyboardLedStateMessage | undefined { + if (data.length < 1) { + throw new Error(`Invalid keyboard led state message length: ${data.length}`); + } + + const s = data[0]; + + const state = { + num_lock: (s & keyboardLedStateMasks.num_lock) !== 0, + caps_lock: (s & keyboardLedStateMasks.caps_lock) !== 0, + scroll_lock: (s & keyboardLedStateMasks.scroll_lock) !== 0, + compose: (s & keyboardLedStateMasks.compose) !== 0, + kana: (s & keyboardLedStateMasks.kana) !== 0, + shift: (s & keyboardLedStateMasks.shift) !== 0, + } as KeyboardLedState; + + return new KeyboardLedStateMessage(state); + } +} + +export class KeysDownStateMessage extends RpcMessage { + keysDownState: KeysDownState; + + constructor(keysDownState: KeysDownState) { + super(HID_RPC_MESSAGE_TYPES.KeysDownState); + this.keysDownState = keysDownState; + } + + public static unmarshal(data: Uint8Array): KeysDownStateMessage | undefined { + if (data.length < 1) { + throw new Error(`Invalid keys down state message length: ${data.length}`); + } + + return new KeysDownStateMessage({ + modifier: data[0], + keys: Array.from(data.slice(1)) + }); + } +} + +export class PointerReportMessage extends RpcMessage { + x: number; + y: number; + buttons: number; + + constructor(x: number, y: number, buttons: number) { + super(HID_RPC_MESSAGE_TYPES.PointerReport); + this.x = x; + this.y = y; + this.buttons = buttons; + } + + marshal(): Uint8Array { + return new Uint8Array([ + this.messageType, + ...fromInt32toUint8(this.x), + ...fromInt32toUint8(this.y), + this.buttons, + ]); + } +} + +export class MouseReportMessage extends RpcMessage { + dx: number; + dy: number; + buttons: number; + + constructor(dx: number, dy: number, buttons: number) { + super(HID_RPC_MESSAGE_TYPES.MouseReport); + this.dx = dx; + this.dy = dy; + this.buttons = buttons; + } + + marshal(): Uint8Array { + return new Uint8Array([ + this.messageType, + fromInt8ToUint8(this.dx), + fromInt8ToUint8(this.dy), + this.buttons, + ]); + } +} + +export const messageRegistry = { + [HID_RPC_MESSAGE_TYPES.Handshake]: HandshakeMessage, + [HID_RPC_MESSAGE_TYPES.KeysDownState]: KeysDownStateMessage, + [HID_RPC_MESSAGE_TYPES.KeyboardLedState]: KeyboardLedStateMessage, + [HID_RPC_MESSAGE_TYPES.KeyboardReport]: KeyboardReportMessage, + [HID_RPC_MESSAGE_TYPES.KeypressReport]: KeypressReportMessage, +} + +export const unmarshalHidRpcMessage = (data: Uint8Array): RpcMessage | undefined => { + if (data.length < 1) { + throw new Error(`Invalid HID RPC message length: ${data.length}`); + } + + const payload = data.slice(1); + + const messageType = data[0]; + if (!(messageType in messageRegistry)) { + throw new Error(`Unknown HID RPC message type: ${messageType}`); + } + + return messageRegistry[messageType].unmarshal(payload); +}; \ No newline at end of file diff --git a/ui/src/hooks/useHidRpc.ts b/ui/src/hooks/useHidRpc.ts index ebf5225..36ba038 100644 --- a/ui/src/hooks/useHidRpc.ts +++ b/ui/src/hooks/useHidRpc.ts @@ -1,163 +1,19 @@ import { useCallback, useEffect, useMemo } from "react"; -import { KeyboardLedState, KeysDownState, useRTCStore } from "@/hooks/stores"; +import { useRTCStore } from "@/hooks/stores"; -export const HID_RPC_MESSAGE_TYPES = { - Handshake: 0x01, - KeyboardReport: 0x02, - PointerReport: 0x03, - WheelReport: 0x04, - KeypressReport: 0x05, - MouseReport: 0x06, - KeyboardLedState: 0x32, - KeysDownState: 0x33, -} +import { + HID_RPC_VERSION, + HandshakeMessage, + KeyboardReportMessage, + KeypressReportMessage, + MouseReportMessage, + PointerReportMessage, + RpcMessage, + unmarshalHidRpcMessage, +} from "./hidRpc"; -export type HidRpcMessageType = typeof HID_RPC_MESSAGE_TYPES[keyof typeof HID_RPC_MESSAGE_TYPES]; - -export const HID_RPC_VERSION = 0x01; - -const withinUint8Range = (value: number) => { - return value >= 0 && value <= 255; -}; - -const fromInt32toUint8 = (n: number) => { - if (n !== n >> 0) { - throw new Error(`Number ${n} is not within the int32 range`); - } - - return new Uint8Array([ - (n >> 24) & 0xFF, - (n >> 16) & 0xFF, - (n >> 8) & 0xFF, - (n >> 0) & 0xFF, - ]); -}; - -const fromInt8ToUint8 = (n: number) => { - if (n < -128 || n > 127) { - throw new Error(`Number ${n} is not within the int8 range`); - } - - return (n >> 0) & 0xFF; -}; - -const keyboardLedStateMasks = { - num_lock: 1 << 0, - caps_lock: 1 << 1, - scroll_lock: 1 << 2, - compose: 1 << 3, - kana: 1 << 4, - shift: 1 << 6, -} - -export const toKeyboardLedState = (s: number): KeyboardLedState => { - if (!withinUint8Range(s)) { - throw new Error(`State ${s} is not within the uint8 range`); - } - - return { - num_lock: (s & keyboardLedStateMasks.num_lock) !== 0, - caps_lock: (s & keyboardLedStateMasks.caps_lock) !== 0, - scroll_lock: (s & keyboardLedStateMasks.scroll_lock) !== 0, - compose: (s & keyboardLedStateMasks.compose) !== 0, - kana: (s & keyboardLedStateMasks.kana) !== 0, - shift: (s & keyboardLedStateMasks.shift) !== 0, - } as KeyboardLedState; -}; - -const toPointerReportEvent = (x: number, y: number, buttons: number) => { - if (!withinUint8Range(buttons)) { - throw new Error(`Buttons ${buttons} is not within the uint8 range`); - } - - return new Uint8Array([ - HID_RPC_MESSAGE_TYPES.PointerReport, - ...fromInt32toUint8(x), - ...fromInt32toUint8(y), - buttons, - ]); -}; - -const toMouseReportEvent = (dx: number, dy: number, buttons: number) => { - if (!withinUint8Range(buttons)) { - throw new Error(`Buttons ${buttons} is not within the uint8 range`); - } - return new Uint8Array([ - HID_RPC_MESSAGE_TYPES.MouseReport, - fromInt8ToUint8(dx), - fromInt8ToUint8(dy), - buttons, - ]); -}; - -const toKeyboardReportEvent = (keys: number[], modifier: number) => { - if (!withinUint8Range(modifier)) { - throw new Error(`Modifier ${modifier} is not within the uint8 range`); - } - - keys.forEach((k) => { - if (!withinUint8Range(k)) { - throw new Error(`Key ${k} is not within the uint8 range`); - } - }); - - return new Uint8Array([ - HID_RPC_MESSAGE_TYPES.KeyboardReport, - modifier, - ...keys, - ]); -}; - -const toKeypressReportEvent = (key: number, press: boolean) => { - if (!withinUint8Range(key)) { - throw new Error(`Key ${key} is not within the uint8 range`); - } - - return new Uint8Array([ - HID_RPC_MESSAGE_TYPES.KeypressReport, - key, - press ? 1 : 0, - ]); -}; - -const toHandshakeMessage = () => { - return new Uint8Array([HID_RPC_MESSAGE_TYPES.Handshake, HID_RPC_VERSION]); -}; - -export interface HidRpcMessage { - type: HidRpcMessageType; - version?: number; - keysDownState?: KeysDownState; -} - -const unmarshalHidRpcMessage = (data: Uint8Array): HidRpcMessage | undefined => { - if (data.length < 1) { - throw new Error(`Invalid HID RPC message length: ${data.length}`); - } - - const payload = data.slice(1); - - switch (data[0]) { - case HID_RPC_MESSAGE_TYPES.Handshake: - return { - type: HID_RPC_MESSAGE_TYPES.Handshake, - version: payload[0], - }; - case HID_RPC_MESSAGE_TYPES.KeysDownState: - return { - type: HID_RPC_MESSAGE_TYPES.KeysDownState, - keysDownState: { - modifier: payload[0], - keys: Array.from(payload.slice(1)) - }, - }; - default: - throw new Error(`Unknown HID RPC message type: ${data[0]}`); - } -}; - -export function useHidRpc(onHidRpcMessage?: (payload: HidRpcMessage) => void) { +export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { const { rpcHidChannel, setRpcHidProtocolVersion, rpcHidProtocolVersion } = useRTCStore(); const rpcHidReady = useMemo(() => { return rpcHidChannel?.readyState === "open" && rpcHidProtocolVersion !== null; @@ -170,50 +26,74 @@ export function useHidRpc(onHidRpcMessage?: (payload: HidRpcMessage) => void) { return `ready (v${rpcHidProtocolVersion})`; }, [rpcHidChannel, rpcHidProtocolVersion]); + const sendMessage = useCallback((message: RpcMessage, ignoreHandshakeState = false) => { + if (rpcHidChannel?.readyState !== "open") return; + if (!rpcHidReady && !ignoreHandshakeState) return; + + let data: Uint8Array | undefined; + try { + data = message.marshal(); + } catch (e) { + console.error("Failed to send HID RPC message", e); + } + if (!data) return; + + rpcHidChannel?.send(data as unknown as ArrayBuffer); + }, [rpcHidChannel, rpcHidReady]); + const reportKeyboardEvent = useCallback( (keys: number[], modifier: number) => { - if (!rpcHidReady) return; - rpcHidChannel?.send(toKeyboardReportEvent(keys, modifier)); - }, - [rpcHidChannel, rpcHidReady], + sendMessage(new KeyboardReportMessage(keys, modifier)); + }, [sendMessage], ); const reportKeypressEvent = useCallback( (key: number, press: boolean) => { - if (!rpcHidReady) return; - rpcHidChannel?.send(toKeypressReportEvent(key, press)); + sendMessage(new KeypressReportMessage(key, press)); }, - [rpcHidChannel, rpcHidReady], + [sendMessage], ); const reportAbsMouseEvent = useCallback( (x: number, y: number, buttons: number) => { - if (!rpcHidReady) return; - rpcHidChannel?.send(toPointerReportEvent(x, y, buttons)); + sendMessage(new PointerReportMessage(x, y, buttons)); }, - [rpcHidChannel, rpcHidReady], + [sendMessage], ); const reportRelMouseEvent = useCallback( (dx: number, dy: number, buttons: number) => { - if (!rpcHidReady) return; - rpcHidChannel?.send(toMouseReportEvent(dx, dy, buttons)); + sendMessage(new MouseReportMessage(dx, dy, buttons)); }, - [rpcHidChannel, rpcHidReady], + [sendMessage], ); - const doHandshake = useCallback(() => { + const sendHandshake = useCallback(() => { if (rpcHidProtocolVersion) return; if (!rpcHidChannel) return; - rpcHidChannel?.send(toHandshakeMessage()); - }, [rpcHidChannel, rpcHidProtocolVersion]); + sendMessage(new HandshakeMessage(HID_RPC_VERSION), true); + }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage]); + + const handleHandshake = useCallback((message: HandshakeMessage) => { + if (!message.version) { + console.error("Received handshake message without version", message); + return; + } + + if (message.version < HID_RPC_VERSION) { + console.error("Server is using an older HID RPC version than the client", message); + return; + } + + setRpcHidProtocolVersion(message.version); + }, [setRpcHidProtocolVersion]); useEffect(() => { if (!rpcHidChannel) return; // send handshake message - doHandshake(); + sendHandshake(); const messageHandler = (e: MessageEvent) => { if (typeof e.data === "string") { @@ -221,27 +101,20 @@ export function useHidRpc(onHidRpcMessage?: (payload: HidRpcMessage) => void) { return; } - console.debug("Received HID RPC message", e.data); - const message = unmarshalHidRpcMessage(new Uint8Array(e.data)); if (!message) { console.warn("Received invalid HID RPC message", e.data); return; } - if (message.type === HID_RPC_MESSAGE_TYPES.Handshake) { - if (!message.version) { - console.error("Received handshake message without version", message); - return; - } - - // TODO: use capabilities to determine the supported functions rather than the version - if (message.version < HID_RPC_VERSION) { - console.error("Server is using an older HID RPC version than the client", message); - return; - } - - setRpcHidProtocolVersion(message.version); + console.debug("Received HID RPC message", message); + switch (message.constructor) { + case HandshakeMessage: + handleHandshake(message as HandshakeMessage); + break; + default: + // not all events are handled here, the rest are handled by the onHidRpcMessage callback + break; } onHidRpcMessage?.(message); @@ -257,8 +130,8 @@ export function useHidRpc(onHidRpcMessage?: (payload: HidRpcMessage) => void) { rpcHidChannel, onHidRpcMessage, setRpcHidProtocolVersion, - doHandshake, - rpcHidReady, + sendHandshake, + handleHandshake, ], ); diff --git a/ui/src/hooks/useKeyboard.ts b/ui/src/hooks/useKeyboard.ts index ae790bb..5c6b364 100644 --- a/ui/src/hooks/useKeyboard.ts +++ b/ui/src/hooks/useKeyboard.ts @@ -2,13 +2,14 @@ import { useCallback } from "react"; import { hidErrorRollOver, hidKeyBufferSize, KeysDownState, useHidStore, useRTCStore } from "@/hooks/stores"; import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; -import { HID_RPC_MESSAGE_TYPES, useHidRpc } from "@/hooks/useHidRpc"; +import { useHidRpc } from "@/hooks/useHidRpc"; +import { KeyboardLedStateMessage, KeysDownStateMessage } from "@/hooks/hidRpc"; import { hidKeyToModifierMask, keys, modifiers } from "@/keyboardMappings"; export default function useKeyboard() { const { send } = useJsonRpc(); const { rpcDataChannel } = useRTCStore(); - const { keysDownState, setKeysDownState } = useHidStore(); + const { keysDownState, setKeysDownState, setKeyboardLedState } = useHidStore(); // INTRODUCTION: The earlier version of the JetKVM device shipped with all keyboard state // being tracked on the browser/client-side. When adding the keyPressReport API to the @@ -24,13 +25,16 @@ export default function useKeyboard() { // HidRPC is a binary format for exchanging keyboard and mouse events const { reportKeyboardEvent, reportKeypressEvent, rpcHidReady } = useHidRpc((message) => { - if (message.type === HID_RPC_MESSAGE_TYPES.KeysDownState) { - if (!message.keysDownState) { - return; - } - - setKeysDownState(message.keysDownState); - setkeyPressReportApiAvailable(true); + switch (message.constructor) { + case KeysDownStateMessage: + setKeysDownState((message as KeysDownStateMessage).keysDownState); + setkeyPressReportApiAvailable(true); + break; + case KeyboardLedStateMessage: + setKeyboardLedState((message as KeyboardLedStateMessage).keyboardLedState); + break; + default: + break; } }); @@ -95,7 +99,7 @@ export default function useKeyboard() { if (rpcHidReady) { console.debug("Sending keypress event via HidRPC"); reportKeypressEvent(key, press); - return; + return; } send("keypressReport", { key, press }, (resp: JsonRpcResponse) => {