diff --git a/hidrpc.go b/hidrpc.go new file mode 100644 index 0000000..0de564e --- /dev/null +++ b/hidrpc.go @@ -0,0 +1,152 @@ +package kvm + +import ( + "fmt" + "time" + + "github.com/jetkvm/kvm/internal/hidrpc" + "github.com/jetkvm/kvm/internal/usbgadget" +) + +func handleHidRPCMessage(message hidrpc.Message, session *Session) { + var rpcErr error + + switch message.Type() { + case hidrpc.TypeHandshake: + message, err := hidrpc.NewHandshakeMessage().Marshal() + if err != nil { + logger.Warn().Err(err).Msg("failed to marshal handshake message") + return + } + if err := session.HidChannel.Send(message); err != nil { + logger.Warn().Err(err).Msg("failed to send handshake message") + return + } + session.hidRPCAvailable = true + case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: + keysDownState, err := handleHidRPCKeyboardInput(message) + if keysDownState != nil { + session.reportHidRPCKeysDownState(*keysDownState) + } + rpcErr = err + case hidrpc.TypePointerReport: + pointerReport, err := message.PointerReport() + if err != nil { + logger.Warn().Err(err).Msg("failed to get pointer report") + return + } + rpcErr = rpcAbsMouseReport(pointerReport.X, pointerReport.Y, pointerReport.Button) + case hidrpc.TypeMouseReport: + mouseReport, err := message.MouseReport() + if err != nil { + logger.Warn().Err(err).Msg("failed to get mouse report") + return + } + rpcErr = rpcRelMouseReport(mouseReport.DX, mouseReport.DY, mouseReport.Button) + default: + logger.Warn().Uint8("type", uint8(message.Type())).Msg("unknown HID RPC message type") + } + + if rpcErr != nil { + logger.Warn().Err(rpcErr).Msg("failed to handle HID RPC message") + } +} + +func onHidMessage(data []byte, session *Session) { + scopedLogger := hidRPCLogger.With().Bytes("data", data).Logger() + scopedLogger.Debug().Msg("HID RPC message received") + + if len(data) < 1 { + scopedLogger.Warn().Int("length", len(data)).Msg("received empty data in HID RPC message handler") + return + } + + var message hidrpc.Message + + if err := hidrpc.Unmarshal(data, &message); err != nil { + scopedLogger.Warn().Err(err).Msg("failed to unmarshal HID RPC message") + return + } + + scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger() + + t := time.Now() + + r := make(chan interface{}) + go func() { + handleHidRPCMessage(message, session) + r <- nil + }() + select { + case <-time.After(1 * time.Second): + scopedLogger.Warn().Msg("HID RPC message timed out") + case <-r: + scopedLogger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC message handled") + } +} + +func handleHidRPCKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState, error) { + switch message.Type() { + case hidrpc.TypeKeypressReport: + keypressReport, err := message.KeypressReport() + if err != nil { + logger.Warn().Err(err).Msg("failed to get keypress report") + return nil, err + } + keysDownState, rpcError := rpcKeypressReport(keypressReport.Key, keypressReport.Press) + return &keysDownState, rpcError + case hidrpc.TypeKeyboardReport: + keyboardReport, err := message.KeyboardReport() + if err != nil { + logger.Warn().Err(err).Msg("failed to get keyboard report") + return nil, err + } + keysDownState, rpcError := rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) + return &keysDownState, rpcError + } + + return nil, fmt.Errorf("unknown HID RPC message type: %d", message.Type()) +} + +func reportHidRPC(params any, session *Session) { + var ( + message []byte + err error + ) + switch params := params.(type) { + case usbgadget.KeyboardState: + message, err = hidrpc.NewKeyboardLedMessage(params).Marshal() + case usbgadget.KeysDownState: + message, err = hidrpc.NewKeydownStateMessage(params).Marshal() + default: + err = fmt.Errorf("unknown HID RPC message type: %T", params) + } + + if err != nil { + logger.Warn().Err(err).Msg("failed to marshal HID RPC message") + return + } + + if message == nil { + logger.Warn().Msg("failed to marshal HID RPC message") + return + } + + if err := session.HidChannel.Send(message); err != nil { + logger.Warn().Err(err).Msg("failed to send HID RPC message") + } +} + +func (s *Session) reportHidRPCKeyboardLedState(state usbgadget.KeyboardState) { + if !s.hidRPCAvailable { + writeJSONRPCEvent("keyboardLedState", state, s) + } + reportHidRPC(state, s) +} + +func (s *Session) reportHidRPCKeysDownState(state usbgadget.KeysDownState) { + if !s.hidRPCAvailable { + writeJSONRPCEvent("keysDownState", state, s) + } + reportHidRPC(state, s) +} diff --git a/internal/hidrpc/hidrpc.go b/internal/hidrpc/hidrpc.go new file mode 100644 index 0000000..e9c8c24 --- /dev/null +++ b/internal/hidrpc/hidrpc.go @@ -0,0 +1,100 @@ +package hidrpc + +import ( + "fmt" + + "github.com/jetkvm/kvm/internal/usbgadget" +) + +// MessageType is the type of the HID RPC message +type MessageType byte + +const ( + TypeHandshake MessageType = 0x01 + TypeKeyboardReport MessageType = 0x02 + TypePointerReport MessageType = 0x03 + TypeWheelReport MessageType = 0x04 + TypeKeypressReport MessageType = 0x05 + TypeMouseReport MessageType = 0x06 + TypeKeyboardLedState MessageType = 0x32 + TypeKeydownState MessageType = 0x33 +) + +const ( + Version byte = 0x01 // Version of the HID RPC protocol +) + +// GetQueueIndex returns the index of the queue to which the message should be enqueued. +func GetQueueIndex(messageType MessageType) int { + switch messageType { + case TypeHandshake: + return 0 + case TypeKeyboardReport, TypeKeypressReport, TypeKeyboardLedState, TypeKeydownState: + return 1 + case TypePointerReport, TypeMouseReport, TypeWheelReport: + return 2 + default: + return 3 + } +} + +// Unmarshal unmarshals the HID RPC message from the data. +func Unmarshal(data []byte, message *Message) error { + l := len(data) + if l < 1 { + return fmt.Errorf("invalid data length: %d", l) + } + + message.t = MessageType(data[0]) + message.d = data[1:] + return nil +} + +// Marshal marshals the HID RPC message to the data. +func Marshal(message *Message) ([]byte, error) { + if message.t == 0 { + return nil, fmt.Errorf("invalid message type: %d", message.t) + } + + data := make([]byte, len(message.d)+1) + data[0] = byte(message.t) + copy(data[1:], message.d) + + return data, nil +} + +// NewHandshakeMessage creates a new handshake message. +func NewHandshakeMessage() *Message { + return &Message{ + t: TypeHandshake, + d: []byte{Version}, + } +} + +// NewKeyboardReportMessage creates a new keyboard report message. +func NewKeyboardReportMessage(keys []byte, modifier uint8) *Message { + return &Message{ + t: TypeKeyboardReport, + d: append([]byte{modifier}, keys...), + } +} + +// NewKeyboardLedMessage creates a new keyboard LED message. +func NewKeyboardLedMessage(state usbgadget.KeyboardState) *Message { + return &Message{ + t: TypeKeyboardLedState, + d: []byte{state.Byte()}, + } +} + +// NewKeydownStateMessage creates a new keydown state message. +func NewKeydownStateMessage(state usbgadget.KeysDownState) *Message { + data := make([]byte, len(state.Keys)+1) + data[0] = state.Modifier + copy(data[1:], state.Keys) + + return &Message{ + t: TypeKeydownState, + d: data, + } +} diff --git a/internal/hidrpc/message.go b/internal/hidrpc/message.go new file mode 100644 index 0000000..84bbda7 --- /dev/null +++ b/internal/hidrpc/message.go @@ -0,0 +1,133 @@ +package hidrpc + +import ( + "fmt" +) + +// Message .. +type Message struct { + t MessageType + d []byte +} + +// Marshal marshals the message to a byte array. +func (m *Message) Marshal() ([]byte, error) { + return Marshal(m) +} + +func (m *Message) Type() MessageType { + return m.t +} + +func (m *Message) String() string { + switch m.t { + case TypeHandshake: + return "Handshake" + case TypeKeypressReport: + if len(m.d) < 2 { + return fmt.Sprintf("KeypressReport{Malformed: %v}", m.d) + } + return fmt.Sprintf("KeypressReport{Key: %d, Press: %v}", m.d[0], m.d[1] == uint8(1)) + case TypeKeyboardReport: + if len(m.d) < 2 { + return fmt.Sprintf("KeyboardReport{Malformed: %v}", m.d) + } + return fmt.Sprintf("KeyboardReport{Modifier: %d, Keys: %v}", m.d[0], m.d[1:]) + case TypePointerReport: + if len(m.d) < 9 { + return fmt.Sprintf("PointerReport{Malformed: %v}", m.d) + } + return fmt.Sprintf("PointerReport{X: %d, Y: %d, Button: %d}", m.d[0:4], m.d[4:8], m.d[8]) + case TypeMouseReport: + if len(m.d) < 3 { + return fmt.Sprintf("MouseReport{Malformed: %v}", m.d) + } + return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0], m.d[1], m.d[2]) + default: + return fmt.Sprintf("Unknown{Type: %d, Data: %v}", m.t, m.d) + } +} + +// KeypressReport .. +type KeypressReport struct { + Key byte + Press bool +} + +// KeypressReport returns the keypress report from the message. +func (m *Message) KeypressReport() (KeypressReport, error) { + if m.t != TypeKeypressReport { + return KeypressReport{}, fmt.Errorf("invalid message type: %d", m.t) + } + + return KeypressReport{ + Key: m.d[0], + Press: m.d[1] == uint8(1), + }, nil +} + +// KeyboardReport .. +type KeyboardReport struct { + Modifier byte + Keys []byte +} + +// KeyboardReport returns the keyboard report from the message. +func (m *Message) KeyboardReport() (KeyboardReport, error) { + if m.t != TypeKeyboardReport { + return KeyboardReport{}, fmt.Errorf("invalid message type: %d", m.t) + } + + return KeyboardReport{ + Modifier: m.d[0], + Keys: m.d[1:], + }, nil +} + +// PointerReport .. +type PointerReport struct { + X int + Y int + Button uint8 +} + +func toInt(b []byte) int { + return int(b[0])<<24 + int(b[1])<<16 + int(b[2])<<8 + int(b[3])<<0 +} + +// PointerReport returns the point report from the message. +func (m *Message) PointerReport() (PointerReport, error) { + if m.t != TypePointerReport { + return PointerReport{}, fmt.Errorf("invalid message type: %d", m.t) + } + + if len(m.d) != 9 { + return PointerReport{}, fmt.Errorf("invalid message length: %d", len(m.d)) + } + + return PointerReport{ + X: toInt(m.d[0:4]), + Y: toInt(m.d[4:8]), + Button: uint8(m.d[8]), + }, nil +} + +// MouseReport .. +type MouseReport struct { + DX int8 + DY int8 + Button uint8 +} + +// MouseReport returns the mouse report from the message. +func (m *Message) MouseReport() (MouseReport, error) { + if m.t != TypeMouseReport { + return MouseReport{}, fmt.Errorf("invalid message type: %d", m.t) + } + + return MouseReport{ + DX: int8(m.d[0]), + DY: int8(m.d[1]), + Button: uint8(m.d[2]), + }, nil +} diff --git a/internal/usbgadget/consts.go b/internal/usbgadget/consts.go index 8204d0a..958aecc 100644 --- a/internal/usbgadget/consts.go +++ b/internal/usbgadget/consts.go @@ -1,3 +1,7 @@ package usbgadget +import "time" + const dwc3Path = "/sys/bus/platform/drivers/dwc3" + +const hidWriteTimeout = 10 * time.Millisecond diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index f4fbaa6..61b6115 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -86,6 +86,12 @@ type KeyboardState struct { Compose bool `json:"compose"` Kana bool `json:"kana"` Shift bool `json:"shift"` // This is not part of the main USB HID spec + raw byte +} + +// Byte returns the raw byte representation of the keyboard state. +func (k *KeyboardState) Byte() byte { + return k.raw } func getKeyboardState(b byte) KeyboardState { @@ -97,6 +103,7 @@ func getKeyboardState(b byte) KeyboardState { Compose: b&KeyboardLedMaskCompose != 0, Kana: b&KeyboardLedMaskKana != 0, Shift: b&KeyboardLedMaskShift != 0, + raw: b, } } @@ -151,7 +158,9 @@ func (u *UsbGadget) updateKeyDownState(state KeysDownState) { u.keysDownState = state if u.onKeysDownChange != nil { + u.log.Trace().Interface("state", state).Msg("calling onKeysDownChange") (*u.onKeysDownChange)(state) + u.log.Trace().Interface("state", state).Msg("onKeysDownChange called") } } @@ -233,7 +242,7 @@ func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { return err } - _, err := u.keyboardHidFile.Write(append([]byte{modifier, 0x00}, keys[:hidKeyBufferSize]...)) + _, err := writeWithTimeout(u.keyboardHidFile, append([]byte{modifier, 0x00}, keys[:hidKeyBufferSize]...)) if err != nil { u.logWithSuppression("keyboardWriteHidFile", 100, u.log, err, "failed to write to hidg0") u.keyboardHidFile.Close() diff --git a/internal/usbgadget/hid_mouse_absolute.go b/internal/usbgadget/hid_mouse_absolute.go index c083b60..4f6f8d7 100644 --- a/internal/usbgadget/hid_mouse_absolute.go +++ b/internal/usbgadget/hid_mouse_absolute.go @@ -74,7 +74,7 @@ func (u *UsbGadget) absMouseWriteHidFile(data []byte) error { } } - _, err := u.absMouseHidFile.Write(data) + _, err := writeWithTimeout(u.absMouseHidFile, data) if err != nil { u.logWithSuppression("absMouseWriteHidFile", 100, u.log, err, "failed to write to hidg1") u.absMouseHidFile.Close() diff --git a/internal/usbgadget/hid_mouse_relative.go b/internal/usbgadget/hid_mouse_relative.go index 70cb72c..25ec2c1 100644 --- a/internal/usbgadget/hid_mouse_relative.go +++ b/internal/usbgadget/hid_mouse_relative.go @@ -64,7 +64,7 @@ func (u *UsbGadget) relMouseWriteHidFile(data []byte) error { } } - _, err := u.relMouseHidFile.Write(data) + _, err := writeWithTimeout(u.relMouseHidFile, data) if err != nil { u.logWithSuppression("relMouseWriteHidFile", 100, u.log, err, "failed to write to hidg2") u.relMouseHidFile.Close() diff --git a/internal/usbgadget/utils.go b/internal/usbgadget/utils.go index 05fcd3a..6c295d6 100644 --- a/internal/usbgadget/utils.go +++ b/internal/usbgadget/utils.go @@ -3,10 +3,13 @@ package usbgadget import ( "bytes" "encoding/json" + "errors" "fmt" + "os" "path/filepath" "strconv" "strings" + "time" "github.com/rs/zerolog" ) @@ -107,6 +110,23 @@ func compareFileContent(oldContent []byte, newContent []byte, looserMatch bool) return false } +func writeWithTimeout(file *os.File, data []byte) (n int, err error) { + if err := file.SetWriteDeadline(time.Now().Add(hidWriteTimeout)); err != nil { + return -1, err + } + + n, err = file.Write(data) + if err == nil { + return + } + + if errors.Is(err, os.ErrDeadlineExceeded) { + err = nil + } + + return +} + func (u *UsbGadget) logWithSuppression(counterName string, every int, logger *zerolog.Logger, err error, msg string, args ...any) { u.logSuppressionLock.Lock() defer u.logSuppressionLock.Unlock() diff --git a/jsonrpc.go b/jsonrpc.go index 82b12d0..ff3a4b1 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -83,7 +83,7 @@ func writeJSONRPCEvent(event string, params any, session *Session) { Str("data", requestString). Logger() - scopedLogger.Info().Msg("sending JSONRPC event") + scopedLogger.Trace().Msg("sending JSONRPC event") err = session.RPCChannel.SendText(requestString) if err != nil { diff --git a/log.go b/log.go index 1a091b1..2047bbf 100644 --- a/log.go +++ b/log.go @@ -19,6 +19,7 @@ var ( nbdLogger = logging.GetSubsystemLogger("nbd") timesyncLogger = logging.GetSubsystemLogger("timesync") jsonRpcLogger = logging.GetSubsystemLogger("jsonrpc") + hidRPCLogger = logging.GetSubsystemLogger("hidrpc") watchdogLogger = logging.GetSubsystemLogger("watchdog") websecureLogger = logging.GetSubsystemLogger("websecure") otaLogger = logging.GetSubsystemLogger("ota") diff --git a/ui/src/components/InfoBar.tsx b/ui/src/components/InfoBar.tsx index 29f159d..8d0b282 100644 --- a/ui/src/components/InfoBar.tsx +++ b/ui/src/components/InfoBar.tsx @@ -10,10 +10,12 @@ import { VideoState } from "@/hooks/stores"; import { keys, modifiers } from "@/keyboardMappings"; +import { useHidRpc } from "@/hooks/useHidRpc"; export default function InfoBar() { const { keysDownState } = useHidStore(); const { mouseX, mouseY, mouseMove } = useMouseStore(); + const { rpcHidStatus } = useHidRpc(); const videoClientSize = useVideoStore( (state: VideoState) => `${Math.round(state.clientWidth)}x${Math.round(state.clientHeight)}`, @@ -46,7 +48,7 @@ export default function InfoBar() { const modifierNames = Object.entries(modifiers).filter(([_, mask]) => (activeModifierMask & mask) !== 0).map(([name, _]) => name); const keyNames = Object.entries(keys).filter(([_, value]) => keysDown.includes(value)).map(([name, _]) => name); - return [...modifierNames,...keyNames].join(", "); + return [...modifierNames, ...keyNames].join(", "); }, [keysDownState, showPressedKeys]); return ( @@ -100,6 +102,12 @@ export default function InfoBar() { {hdmiState} )} + {debugMode && ( +
+ HidRPC State: + {rpcHidStatus} +
+ )} {showPressedKeys && (
diff --git a/ui/src/components/WebRTCVideo.tsx b/ui/src/components/WebRTCVideo.tsx index 9e2f0f2..ba6ee5c 100644 --- a/ui/src/components/WebRTCVideo.tsx +++ b/ui/src/components/WebRTCVideo.tsx @@ -8,6 +8,7 @@ import InfoBar from "@components/InfoBar"; import notifications from "@/notifications"; import useKeyboard from "@/hooks/useKeyboard"; import { useJsonRpc } from "@/hooks/useJsonRpc"; +import { useHidRpc } from "@/hooks/useHidRpc"; import { cx } from "@/cva.config"; import { keys } from "@/keyboardMappings"; import { @@ -60,10 +61,11 @@ export default function WebRTCVideo() { // Misc states and hooks const { send } = useJsonRpc(); + const { reportAbsMouseEvent, reportRelMouseEvent, rpcHidReady } = useHidRpc(); // Video-related const handleResize = useCallback( - ( { width, height }: { width: number | undefined; height: number | undefined }) => { + ({ width, height }: { width: number | undefined; height: number | undefined }) => { if (!videoElm.current) return; // Do something with width and height, e.g.: setVideoClientSize(width || 0, height || 0); @@ -222,10 +224,22 @@ export default function WebRTCVideo() { if (settings.mouseMode !== "relative") return; // if we ignore the event, double-click will not work // if (x === 0 && y === 0 && buttons === 0) return; - send("relMouseReport", { dx: calcDelta(x), dy: calcDelta(y), buttons }); + const dx = calcDelta(x); + const dy = calcDelta(y); + if (rpcHidReady) { + reportRelMouseEvent(dx, dy, buttons); + } else { + send("relMouseReport", { dx, dy, buttons }); + } setMouseMove({ x, y, buttons }); }, - [send, setMouseMove, settings.mouseMode], + [ + send, + reportRelMouseEvent, + setMouseMove, + settings.mouseMode, + rpcHidReady, + ], ); const relMouseMoveHandler = useCallback( @@ -243,11 +257,21 @@ export default function WebRTCVideo() { const sendAbsMouseMovement = useCallback( (x: number, y: number, buttons: number) => { if (settings.mouseMode !== "absolute") return; - send("absMouseReport", { x, y, buttons }); + if (rpcHidReady) { + reportAbsMouseEvent(x, y, buttons); + } else { + send("absMouseReport", { x, y, buttons }); + } // We set that for the debug info bar setMousePosition(x, y); }, - [send, setMousePosition, settings.mouseMode], + [ + send, + reportAbsMouseEvent, + setMousePosition, + settings.mouseMode, + rpcHidReady, + ], ); const absMouseMoveHandler = useCallback( @@ -357,7 +381,7 @@ export default function WebRTCVideo() { } console.debug(`Key down: ${hidKey}`); handleKeyPress(hidKey, true); - + if (!isKeyboardLockActive && hidKey === keys.MetaLeft) { // If the left meta key was just pressed and we're not keyboard locked // we'll never see the keyup event because the browser is going to lose diff --git a/ui/src/hooks/hidRpc.ts b/ui/src/hooks/hidRpc.ts new file mode 100644 index 0000000..7910039 --- /dev/null +++ b/ui/src/hooks/hidRpc.ts @@ -0,0 +1,303 @@ +import { KeyboardLedState, KeysDownState } from "./stores"; + +export const HID_RPC_MESSAGE_TYPES = { + Handshake: 0x01, + KeyboardReport: 0x02, + PointerReport: 0x03, + WheelReport: 0x04, + KeypressReport: 0x05, + MouseReport: 0x06, + KeyboardLedState: 0x32, + KeysDownState: 0x33, +} + +export type HidRpcMessageType = typeof HID_RPC_MESSAGE_TYPES[keyof typeof HID_RPC_MESSAGE_TYPES]; + +export const HID_RPC_VERSION = 0x01; + +const withinUint8Range = (value: number) => { + return value >= 0 && value <= 255; +}; + +const fromInt32toUint8 = (n: number) => { + if (n !== n >> 0) { + throw new Error(`Number ${n} is not within the int32 range`); + } + + return new Uint8Array([ + (n >> 24) & 0xFF, + (n >> 16) & 0xFF, + (n >> 8) & 0xFF, + (n >> 0) & 0xFF, + ]); +}; + +const fromInt8ToUint8 = (n: number) => { + if (n < -128 || n > 127) { + throw new Error(`Number ${n} is not within the int8 range`); + } + + return (n >> 0) & 0xFF; +}; + +const keyboardLedStateMasks = { + num_lock: 1 << 0, + caps_lock: 1 << 1, + scroll_lock: 1 << 2, + compose: 1 << 3, + kana: 1 << 4, + shift: 1 << 6, +} + +export class RpcMessage { + messageType: HidRpcMessageType; + + constructor(messageType: HidRpcMessageType) { + this.messageType = messageType; + } + + marshal(): Uint8Array { + throw new Error("Not implemented"); + } + + // @ts-expect-error: this is a base class, so we don't need to implement it + public static unmarshal(data: Uint8Array): RpcMessage | undefined { + throw new Error("Not implemented"); + } +} + +export class HandshakeMessage extends RpcMessage { + version: number; + + constructor(version: number) { + super(HID_RPC_MESSAGE_TYPES.Handshake); + this.version = version; + } + + marshal(): Uint8Array { + return new Uint8Array([this.messageType, this.version]); + } + + public static unmarshal(data: Uint8Array): HandshakeMessage | undefined { + if (data.length < 1) { + throw new Error(`Invalid handshake message length: ${data.length}`); + } + + return new HandshakeMessage(data[0]); + } +} + +export class KeypressReportMessage extends RpcMessage { + private _key = 0; + private _press = false; + + get key(): number { + return this._key; + } + + set key(value: number) { + if (!withinUint8Range(value)) { + throw new Error(`Key ${value} is not within the uint8 range`); + } + + this._key = value; + } + + get press(): boolean { + return this._press; + } + + set press(value: boolean) { + this._press = value; + } + + constructor(key: number, press: boolean) { + super(HID_RPC_MESSAGE_TYPES.KeypressReport); + this.key = key; + this.press = press; + } + + marshal(): Uint8Array { + return new Uint8Array([ + this.messageType, + this.key, + this.press ? 1 : 0, + ]); + } + + public static unmarshal(data: Uint8Array): KeypressReportMessage | undefined { + if (data.length < 1) { + throw new Error(`Invalid keypress report message length: ${data.length}`); + } + + return new KeypressReportMessage(data[0], data[1] === 1); + } +} + +export class KeyboardReportMessage extends RpcMessage { + private _keys: number[] = []; + private _modifier = 0; + + get keys(): number[] { + return this._keys; + } + + set keys(value: number[]) { + value.forEach((k) => { + if (!withinUint8Range(k)) { + throw new Error(`Key ${k} is not within the uint8 range`); + } + }); + + this._keys = value; + } + + get modifier(): number { + return this._modifier; + } + + set modifier(value: number) { + if (!withinUint8Range(value)) { + throw new Error(`Modifier ${value} is not within the uint8 range`); + } + + this._modifier = value; + } + + constructor(keys: number[], modifier: number) { + super(HID_RPC_MESSAGE_TYPES.KeyboardReport); + this.keys = keys; + this.modifier = modifier; + } + + marshal(): Uint8Array { + return new Uint8Array([ + this.messageType, + this.modifier, + ...this.keys, + ]); + } + + public static unmarshal(data: Uint8Array): KeyboardReportMessage | undefined { + if (data.length < 1) { + throw new Error(`Invalid keyboard report message length: ${data.length}`); + } + + return new KeyboardReportMessage(Array.from(data.slice(1)), data[0]); + } +} + +export class KeyboardLedStateMessage extends RpcMessage { + keyboardLedState: KeyboardLedState; + + constructor(keyboardLedState: KeyboardLedState) { + super(HID_RPC_MESSAGE_TYPES.KeyboardLedState); + this.keyboardLedState = keyboardLedState; + } + + public static unmarshal(data: Uint8Array): KeyboardLedStateMessage | undefined { + if (data.length < 1) { + throw new Error(`Invalid keyboard led state message length: ${data.length}`); + } + + const s = data[0]; + + const state = { + num_lock: (s & keyboardLedStateMasks.num_lock) !== 0, + caps_lock: (s & keyboardLedStateMasks.caps_lock) !== 0, + scroll_lock: (s & keyboardLedStateMasks.scroll_lock) !== 0, + compose: (s & keyboardLedStateMasks.compose) !== 0, + kana: (s & keyboardLedStateMasks.kana) !== 0, + shift: (s & keyboardLedStateMasks.shift) !== 0, + } as KeyboardLedState; + + return new KeyboardLedStateMessage(state); + } +} + +export class KeysDownStateMessage extends RpcMessage { + keysDownState: KeysDownState; + + constructor(keysDownState: KeysDownState) { + super(HID_RPC_MESSAGE_TYPES.KeysDownState); + this.keysDownState = keysDownState; + } + + public static unmarshal(data: Uint8Array): KeysDownStateMessage | undefined { + if (data.length < 1) { + throw new Error(`Invalid keys down state message length: ${data.length}`); + } + + return new KeysDownStateMessage({ + modifier: data[0], + keys: Array.from(data.slice(1)) + }); + } +} + +export class PointerReportMessage extends RpcMessage { + x: number; + y: number; + buttons: number; + + constructor(x: number, y: number, buttons: number) { + super(HID_RPC_MESSAGE_TYPES.PointerReport); + this.x = x; + this.y = y; + this.buttons = buttons; + } + + marshal(): Uint8Array { + return new Uint8Array([ + this.messageType, + ...fromInt32toUint8(this.x), + ...fromInt32toUint8(this.y), + this.buttons, + ]); + } +} + +export class MouseReportMessage extends RpcMessage { + dx: number; + dy: number; + buttons: number; + + constructor(dx: number, dy: number, buttons: number) { + super(HID_RPC_MESSAGE_TYPES.MouseReport); + this.dx = dx; + this.dy = dy; + this.buttons = buttons; + } + + marshal(): Uint8Array { + return new Uint8Array([ + this.messageType, + fromInt8ToUint8(this.dx), + fromInt8ToUint8(this.dy), + this.buttons, + ]); + } +} + +export const messageRegistry = { + [HID_RPC_MESSAGE_TYPES.Handshake]: HandshakeMessage, + [HID_RPC_MESSAGE_TYPES.KeysDownState]: KeysDownStateMessage, + [HID_RPC_MESSAGE_TYPES.KeyboardLedState]: KeyboardLedStateMessage, + [HID_RPC_MESSAGE_TYPES.KeyboardReport]: KeyboardReportMessage, + [HID_RPC_MESSAGE_TYPES.KeypressReport]: KeypressReportMessage, +} + +export const unmarshalHidRpcMessage = (data: Uint8Array): RpcMessage | undefined => { + if (data.length < 1) { + throw new Error(`Invalid HID RPC message length: ${data.length}`); + } + + const payload = data.slice(1); + + const messageType = data[0]; + if (!(messageType in messageRegistry)) { + throw new Error(`Unknown HID RPC message type: ${messageType}`); + } + + return messageRegistry[messageType].unmarshal(payload); +}; \ No newline at end of file diff --git a/ui/src/hooks/stores.ts b/ui/src/hooks/stores.ts index a6dc95b..43f838b 100644 --- a/ui/src/hooks/stores.ts +++ b/ui/src/hooks/stores.ts @@ -105,6 +105,12 @@ export interface RTCState { setRpcDataChannel: (channel: RTCDataChannel) => void; rpcDataChannel: RTCDataChannel | null; + rpcHidProtocolVersion: number | null; + setRpcHidProtocolVersion: (version: number) => void; + + rpcHidChannel: RTCDataChannel | null; + setRpcHidChannel: (channel: RTCDataChannel) => void; + peerConnectionState: RTCPeerConnectionState | null; setPeerConnectionState: (state: RTCPeerConnectionState) => void; @@ -151,6 +157,12 @@ export const useRTCStore = create(set => ({ rpcDataChannel: null, setRpcDataChannel: (channel: RTCDataChannel) => set({ rpcDataChannel: channel }), + rpcHidProtocolVersion: null, + setRpcHidProtocolVersion: (version: number) => set({ rpcHidProtocolVersion: version }), + + rpcHidChannel: null, + setRpcHidChannel: (channel: RTCDataChannel) => set({ rpcHidChannel: channel }), + transceiver: null, setTransceiver: (transceiver: RTCRtpTransceiver) => set({ transceiver }), diff --git a/ui/src/hooks/useHidRpc.ts b/ui/src/hooks/useHidRpc.ts new file mode 100644 index 0000000..36ba038 --- /dev/null +++ b/ui/src/hooks/useHidRpc.ts @@ -0,0 +1,147 @@ +import { useCallback, useEffect, useMemo } from "react"; + +import { useRTCStore } from "@/hooks/stores"; + +import { + HID_RPC_VERSION, + HandshakeMessage, + KeyboardReportMessage, + KeypressReportMessage, + MouseReportMessage, + PointerReportMessage, + RpcMessage, + unmarshalHidRpcMessage, +} from "./hidRpc"; + +export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { + const { rpcHidChannel, setRpcHidProtocolVersion, rpcHidProtocolVersion } = useRTCStore(); + const rpcHidReady = useMemo(() => { + return rpcHidChannel?.readyState === "open" && rpcHidProtocolVersion !== null; + }, [rpcHidChannel, rpcHidProtocolVersion]); + + const rpcHidStatus = useMemo(() => { + if (!rpcHidChannel) return "N/A"; + if (rpcHidChannel.readyState !== "open") return rpcHidChannel.readyState; + if (!rpcHidProtocolVersion) return "handshaking"; + return `ready (v${rpcHidProtocolVersion})`; + }, [rpcHidChannel, rpcHidProtocolVersion]); + + const sendMessage = useCallback((message: RpcMessage, ignoreHandshakeState = false) => { + if (rpcHidChannel?.readyState !== "open") return; + if (!rpcHidReady && !ignoreHandshakeState) return; + + let data: Uint8Array | undefined; + try { + data = message.marshal(); + } catch (e) { + console.error("Failed to send HID RPC message", e); + } + if (!data) return; + + rpcHidChannel?.send(data as unknown as ArrayBuffer); + }, [rpcHidChannel, rpcHidReady]); + + const reportKeyboardEvent = useCallback( + (keys: number[], modifier: number) => { + sendMessage(new KeyboardReportMessage(keys, modifier)); + }, [sendMessage], + ); + + const reportKeypressEvent = useCallback( + (key: number, press: boolean) => { + sendMessage(new KeypressReportMessage(key, press)); + }, + [sendMessage], + ); + + const reportAbsMouseEvent = useCallback( + (x: number, y: number, buttons: number) => { + sendMessage(new PointerReportMessage(x, y, buttons)); + }, + [sendMessage], + ); + + const reportRelMouseEvent = useCallback( + (dx: number, dy: number, buttons: number) => { + sendMessage(new MouseReportMessage(dx, dy, buttons)); + }, + [sendMessage], + ); + + const sendHandshake = useCallback(() => { + if (rpcHidProtocolVersion) return; + if (!rpcHidChannel) return; + + sendMessage(new HandshakeMessage(HID_RPC_VERSION), true); + }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage]); + + const handleHandshake = useCallback((message: HandshakeMessage) => { + if (!message.version) { + console.error("Received handshake message without version", message); + return; + } + + if (message.version < HID_RPC_VERSION) { + console.error("Server is using an older HID RPC version than the client", message); + return; + } + + setRpcHidProtocolVersion(message.version); + }, [setRpcHidProtocolVersion]); + + useEffect(() => { + if (!rpcHidChannel) return; + + // send handshake message + sendHandshake(); + + const messageHandler = (e: MessageEvent) => { + if (typeof e.data === "string") { + console.warn("Received string data in HID RPC message handler", e.data); + return; + } + + const message = unmarshalHidRpcMessage(new Uint8Array(e.data)); + if (!message) { + console.warn("Received invalid HID RPC message", e.data); + return; + } + + console.debug("Received HID RPC message", message); + switch (message.constructor) { + case HandshakeMessage: + handleHandshake(message as HandshakeMessage); + break; + default: + // not all events are handled here, the rest are handled by the onHidRpcMessage callback + break; + } + + onHidRpcMessage?.(message); + }; + + rpcHidChannel.addEventListener("message", messageHandler); + + return () => { + rpcHidChannel.removeEventListener("message", messageHandler); + }; + }, + [ + rpcHidChannel, + onHidRpcMessage, + setRpcHidProtocolVersion, + sendHandshake, + handleHandshake, + ], + ); + + return { + reportKeyboardEvent, + reportKeypressEvent, + reportAbsMouseEvent, + reportRelMouseEvent, + rpcHidProtocolVersion, + rpcHidReady, + rpcHidStatus, + }; +} diff --git a/ui/src/hooks/useKeyboard.ts b/ui/src/hooks/useKeyboard.ts index 5f587b0..5c6b364 100644 --- a/ui/src/hooks/useKeyboard.ts +++ b/ui/src/hooks/useKeyboard.ts @@ -1,13 +1,15 @@ import { useCallback } from "react"; -import { KeysDownState, useHidStore, useRTCStore, hidKeyBufferSize, hidErrorRollOver } from "@/hooks/stores"; +import { hidErrorRollOver, hidKeyBufferSize, KeysDownState, useHidStore, useRTCStore } from "@/hooks/stores"; import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; +import { useHidRpc } from "@/hooks/useHidRpc"; +import { KeyboardLedStateMessage, KeysDownStateMessage } from "@/hooks/hidRpc"; import { hidKeyToModifierMask, keys, modifiers } from "@/keyboardMappings"; export default function useKeyboard() { const { send } = useJsonRpc(); const { rpcDataChannel } = useRTCStore(); - const { keysDownState, setKeysDownState } = useHidStore(); + const { keysDownState, setKeysDownState, setKeyboardLedState } = useHidStore(); // INTRODUCTION: The earlier version of the JetKVM device shipped with all keyboard state // being tracked on the browser/client-side. When adding the keyPressReport API to the @@ -19,7 +21,22 @@ export default function useKeyboard() { // dynamically set when the device responds to the first key press event or reports its // keysDownState when queried since the keyPressReport was introduced together with the // getKeysDownState API. - const { keyPressReportApiAvailable, setkeyPressReportApiAvailable} = useHidStore(); + const { keyPressReportApiAvailable, setkeyPressReportApiAvailable } = useHidStore(); + + // HidRPC is a binary format for exchanging keyboard and mouse events + const { reportKeyboardEvent, reportKeypressEvent, rpcHidReady } = useHidRpc((message) => { + switch (message.constructor) { + case KeysDownStateMessage: + setKeysDownState((message as KeysDownStateMessage).keysDownState); + setkeyPressReportApiAvailable(true); + break; + case KeyboardLedStateMessage: + setKeyboardLedState((message as KeyboardLedStateMessage).keyboardLedState); + break; + default: + break; + } + }); // sendKeyboardEvent is used to send the full keyboard state to the device for macro handling // and resetting keyboard state. It sends the keys currently pressed and the modifier state. @@ -27,9 +44,16 @@ export default function useKeyboard() { // or just accept the state if it does not support (returning no result) const sendKeyboardEvent = useCallback( async (state: KeysDownState) => { - if (rpcDataChannel?.readyState !== "open") return; + if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; console.debug(`Send keyboardReport keys: ${state.keys}, modifier: ${state.modifier}`); + + if (rpcHidReady) { + console.debug("Sending keyboard report via HidRPC"); + reportKeyboardEvent(state.keys, state.modifier); + return; + } + send("keyboardReport", { keys: state.keys, modifier: state.modifier }, (resp: JsonRpcResponse) => { if ("error" in resp) { console.error(`Failed to send keyboard report ${state}`, resp.error); @@ -44,13 +68,20 @@ export default function useKeyboard() { } else { // older devices versions do not return the keyDownState // so we just pretend they accepted what we sent - setKeysDownState(state); + setKeysDownState(state); setkeyPressReportApiAvailable(false); // we ALSO know they do not support keyPressReport } } }); }, - [rpcDataChannel?.readyState, send, setKeysDownState, setkeyPressReportApiAvailable], + [ + rpcDataChannel?.readyState, + rpcHidReady, + send, + reportKeyboardEvent, + setKeysDownState, + setkeyPressReportApiAvailable, + ], ); // sendKeypressEvent is used to send a single key press/release event to the device. @@ -61,9 +92,16 @@ export default function useKeyboard() { // in client/browser-side code using simulateDeviceSideKeyHandlingForLegacyDevices. const sendKeypressEvent = useCallback( async (key: number, press: boolean) => { - if (rpcDataChannel?.readyState !== "open") return; + if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; console.debug(`Send keypressEvent key: ${key}, press: ${press}`); + + if (rpcHidReady) { + console.debug("Sending keypress event via HidRPC"); + reportKeypressEvent(key, press); + return; + } + send("keypressReport", { key, press }, (resp: JsonRpcResponse) => { if ("error" in resp) { // -32601 means the method is not supported because the device is running an older version @@ -83,7 +121,14 @@ export default function useKeyboard() { } }); }, - [rpcDataChannel?.readyState, send, setkeyPressReportApiAvailable, setKeysDownState], + [ + rpcDataChannel?.readyState, + rpcHidReady, + send, + setkeyPressReportApiAvailable, + setKeysDownState, + reportKeypressEvent, + ], ); // resetKeyboardState is used to reset the keyboard state to no keys pressed and no modifiers. @@ -135,9 +180,15 @@ export default function useKeyboard() { // It then sends the full keyboard state to the device. const handleKeyPress = useCallback( async (key: number, press: boolean) => { - if (rpcDataChannel?.readyState !== "open") return; + if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; if ((key || 0) === 0) return; // ignore zero key presses (they are bad mappings) + if (rpcHidReady) { + console.debug("Sending keypress event via HidRPC"); + reportKeypressEvent(key, press); + return; + } + if (keyPressReportApiAvailable) { // if the keyPress api is available, we can just send the key press event sendKeypressEvent(key, press); @@ -152,7 +203,16 @@ export default function useKeyboard() { } } }, - [keyPressReportApiAvailable, keysDownState, resetKeyboardState, rpcDataChannel?.readyState, sendKeyboardEvent, sendKeypressEvent], + [ + keyPressReportApiAvailable, + keysDownState, + resetKeyboardState, + rpcDataChannel?.readyState, + rpcHidReady, + sendKeyboardEvent, + sendKeypressEvent, + reportKeypressEvent, + ], ); // IMPORTANT: See the keyPressReportApiAvailable comment above for the reason this exists diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx index 9be05f6..7737029 100644 --- a/ui/src/routes/devices.$id.tsx +++ b/ui/src/routes/devices.$id.tsx @@ -135,7 +135,8 @@ export default function KvmIdRoute() { setRpcDataChannel, isTurnServerInUse, setTurnServerInUse, rpcDataChannel, - setTransceiver + setTransceiver, + setRpcHidChannel, } = useRTCStore(); const location = useLocation(); @@ -482,6 +483,12 @@ export default function KvmIdRoute() { setRpcDataChannel(rpcDataChannel); }; + const rpcHidChannel = pc.createDataChannel("hidrpc"); + rpcHidChannel.binaryType = "arraybuffer"; + rpcHidChannel.onopen = () => { + setRpcHidChannel(rpcHidChannel); + }; + setPeerConnection(pc); }, [ cleanupAndStopReconnecting, @@ -492,6 +499,7 @@ export default function KvmIdRoute() { setPeerConnection, setPeerConnectionState, setRpcDataChannel, + setRpcHidChannel, setTransceiver, ]); diff --git a/ui/vite.config.ts b/ui/vite.config.ts index 5871c4b..44eec3a 100644 --- a/ui/vite.config.ts +++ b/ui/vite.config.ts @@ -28,6 +28,9 @@ export default defineConfig(({ mode, command }) => { return { plugins, + esbuild: { + pure: ["console.debug"], + }, build: { outDir: isCloud ? "dist" : "../static" }, server: { host: "0.0.0.0", diff --git a/usb.go b/usb.go index d29e01a..131cd51 100644 --- a/usb.go +++ b/usb.go @@ -27,13 +27,13 @@ func initUsbGadget() { gadget.SetOnKeyboardStateChange(func(state usbgadget.KeyboardState) { if currentSession != nil { - writeJSONRPCEvent("keyboardLedState", state, currentSession) + currentSession.reportHidRPCKeyboardLedState(state) } }) gadget.SetOnKeysDownChange(func(state usbgadget.KeysDownState) { if currentSession != nil { - writeJSONRPCEvent("keysDownState", state, currentSession) + currentSession.reportHidRPCKeysDownState(state) } }) diff --git a/webrtc.go b/webrtc.go index c0f159a..4b26b51 100644 --- a/webrtc.go +++ b/webrtc.go @@ -6,10 +6,12 @@ import ( "encoding/json" "net" "strings" + "sync" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" "github.com/gin-gonic/gin" + "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/logging" "github.com/pion/webrtc/v4" "github.com/rs/zerolog" @@ -22,7 +24,12 @@ type Session struct { RPCChannel *webrtc.DataChannel HidChannel *webrtc.DataChannel shouldUmountVirtualMedia bool - rpcQueue chan webrtc.DataChannelMessage + + rpcQueue chan webrtc.DataChannelMessage + + hidRPCAvailable bool + hidQueueLock sync.Mutex + hidQueue []chan webrtc.DataChannelMessage } type SessionConfig struct { @@ -67,6 +74,23 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) { return base64.StdEncoding.EncodeToString(localDescription), nil } +func (s *Session) initQueues() { + s.hidQueueLock.Lock() + defer s.hidQueueLock.Unlock() + + s.hidQueue = make([]chan webrtc.DataChannelMessage, 0) + for i := 0; i < 4; i++ { + q := make(chan webrtc.DataChannelMessage, 256) + s.hidQueue = append(s.hidQueue, q) + } +} + +func (s *Session) handleQueues(index int) { + for msg := range s.hidQueue[index] { + onHidMessage(msg.Data, s) + } +} + func newSession(config SessionConfig) (*Session, error) { webrtcSettingEngine := webrtc.SettingEngine{ LoggerFactory: logging.GetPionDefaultLoggerFactory(), @@ -105,17 +129,64 @@ func newSession(config SessionConfig) (*Session, error) { scopedLogger.Warn().Err(err).Msg("Failed to create PeerConnection") return nil, err } + session := &Session{peerConnection: peerConnection} session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) + session.initQueues() + go func() { for msg := range session.rpcQueue { - onRPCMessage(msg, session) + // TODO: only use goroutine if the task is asynchronous + go onRPCMessage(msg, session) } }() + for i := 0; i < len(session.hidQueue); i++ { + go session.handleQueues(i) + } + peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { + defer func() { + if r := recover(); r != nil { + scopedLogger.Error().Interface("error", r).Msg("Recovered from panic in DataChannel handler") + } + }() + scopedLogger.Info().Str("label", d.Label()).Uint16("id", *d.ID()).Msg("New DataChannel") + switch d.Label() { + case "hidrpc": + session.HidChannel = d + d.OnMessage(func(msg webrtc.DataChannelMessage) { + l := scopedLogger.With().Str("data", string(msg.Data)).Int("length", len(msg.Data)).Logger() + + if msg.IsString { + l.Warn().Msg("received string data in HID RPC message handler") + return + } + + if len(msg.Data) < 1 { + l.Warn().Msg("received empty data in HID RPC message handler") + return + } + + l.Trace().Msg("received data in HID RPC message handler") + + // Enqueue to ensure ordered processing + queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) + if queueIndex >= len(session.hidQueue) || queueIndex < 0 { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") + queueIndex = 3 + } + + queue := session.hidQueue[queueIndex] + if queue != nil { + queue <- msg + } else { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") + return + } + }) case "rpc": session.RPCChannel = d d.OnMessage(func(msg webrtc.DataChannelMessage) { @@ -198,6 +269,13 @@ func newSession(config SessionConfig) (*Session, error) { close(session.rpcQueue) session.rpcQueue = nil } + + // Stop HID RPC processor + for i := 0; i < len(session.hidQueue); i++ { + close(session.hidQueue[i]) + session.hidQueue[i] = nil + } + if session.shouldUmountVirtualMedia { if err := rpcUnmountImage(); err != nil { scopedLogger.Warn().Err(err).Msg("unmount image failed on connection close")