fix issues

This commit is contained in:
Siyuan Miao 2025-08-30 14:01:25 +02:00
parent fefbc7611f
commit 3dd8645295
7 changed files with 54 additions and 32 deletions

View File

@ -8,7 +8,7 @@ import (
"github.com/jetkvm/kvm/internal/usbgadget" "github.com/jetkvm/kvm/internal/usbgadget"
) )
func handleHidRpcMessage(message hidrpc.Message, session *Session) { func handleHidRPCMessage(message hidrpc.Message, session *Session) {
var rpcErr error var rpcErr error
switch message.Type() { switch message.Type() {
@ -22,11 +22,11 @@ func handleHidRpcMessage(message hidrpc.Message, session *Session) {
logger.Warn().Err(err).Msg("failed to send handshake message") logger.Warn().Err(err).Msg("failed to send handshake message")
return return
} }
session.hidRpcAvailable = true session.hidRPCAvailable = true
case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport:
keysDownState, err := handleHidRpcKeyboardInput(message) keysDownState, err := handleHidRPCKeyboardInput(message)
if keysDownState != nil { if keysDownState != nil {
reportHidRpcKeysDownState(*keysDownState, session) session.reportHidRPCKeysDownState(*keysDownState)
} }
rpcErr = err rpcErr = err
case hidrpc.TypePointerReport: case hidrpc.TypePointerReport:
@ -53,7 +53,7 @@ func handleHidRpcMessage(message hidrpc.Message, session *Session) {
} }
func onHidMessage(data []byte, session *Session) { func onHidMessage(data []byte, session *Session) {
scopedLogger := hidRpcLogger.With().Bytes("data", data).Logger() scopedLogger := hidRPCLogger.With().Bytes("data", data).Logger()
scopedLogger.Debug().Msg("HID RPC message received") scopedLogger.Debug().Msg("HID RPC message received")
if len(data) < 1 { if len(data) < 1 {
@ -74,7 +74,7 @@ func onHidMessage(data []byte, session *Session) {
r := make(chan interface{}) r := make(chan interface{})
go func() { go func() {
handleHidRpcMessage(message, session) handleHidRPCMessage(message, session)
r <- nil r <- nil
}() }()
select { select {
@ -85,7 +85,7 @@ func onHidMessage(data []byte, session *Session) {
} }
} }
func handleHidRpcKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState, error) { func handleHidRPCKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState, error) {
switch message.Type() { switch message.Type() {
case hidrpc.TypeKeypressReport: case hidrpc.TypeKeypressReport:
keypressReport, err := message.KeypressReport() keypressReport, err := message.KeypressReport()
@ -108,7 +108,7 @@ func handleHidRpcKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState
return nil, fmt.Errorf("unknown HID RPC message type: %d", message.Type()) return nil, fmt.Errorf("unknown HID RPC message type: %d", message.Type())
} }
func reportHidRpc(params any, session *Session) { func reportHidRPC(params any, session *Session) {
var ( var (
message []byte message []byte
err error err error
@ -118,6 +118,8 @@ func reportHidRpc(params any, session *Session) {
message, err = hidrpc.NewKeyboardLedMessage(params).Marshal() message, err = hidrpc.NewKeyboardLedMessage(params).Marshal()
case usbgadget.KeysDownState: case usbgadget.KeysDownState:
message, err = hidrpc.NewKeydownStateMessage(params).Marshal() message, err = hidrpc.NewKeydownStateMessage(params).Marshal()
default:
err = fmt.Errorf("unknown HID RPC message type: %T", params)
} }
if err != nil { if err != nil {
@ -135,16 +137,16 @@ func reportHidRpc(params any, session *Session) {
} }
} }
func reportHidRpcKeyboardLedState(state usbgadget.KeyboardState, session *Session) { func (s *Session) reportHidRPCKeyboardLedState(state usbgadget.KeyboardState) {
if !session.hidRpcAvailable { if !s.hidRPCAvailable {
writeJSONRPCEvent("keyboardLedState", state, currentSession) writeJSONRPCEvent("keyboardLedState", state, s)
} }
reportHidRpc(state, session) reportHidRPC(state, s)
} }
func reportHidRpcKeysDownState(state usbgadget.KeysDownState, session *Session) { func (s *Session) reportHidRPCKeysDownState(state usbgadget.KeysDownState) {
if !session.hidRpcAvailable { if !s.hidRPCAvailable {
writeJSONRPCEvent("keysDownState", state, currentSession) writeJSONRPCEvent("keysDownState", state, s)
} }
reportHidRpc(state, session) reportHidRPC(state, s)
} }

View File

@ -6,12 +6,8 @@ import (
"github.com/jetkvm/kvm/internal/usbgadget" "github.com/jetkvm/kvm/internal/usbgadget"
) )
// HID RPC is a variable-length packet format that is used to exchange keyboard and mouse events between the client and the server.
// The packet format is as follows:
// 1 byte: Event Type
// MessageType is the type of the HID RPC message // MessageType is the type of the HID RPC message
type MessageType uint8 type MessageType byte
const ( const (
TypeHandshake MessageType = 0x01 TypeHandshake MessageType = 0x01
@ -24,6 +20,10 @@ const (
TypeKeydownState MessageType = 0x33 TypeKeydownState MessageType = 0x33
) )
const (
Version byte = 0x01 // Version of the HID RPC protocol
)
// ShouldUseEnqueue returns true if the message type should be enqueued to the HID queue. // ShouldUseEnqueue returns true if the message type should be enqueued to the HID queue.
func ShouldUseEnqueue(messageType MessageType) bool { func ShouldUseEnqueue(messageType MessageType) bool {
return messageType == TypeMouseReport return messageType == TypeMouseReport
@ -58,7 +58,7 @@ func Marshal(message *Message) ([]byte, error) {
func NewHandshakeMessage() *Message { func NewHandshakeMessage() *Message {
return &Message{ return &Message{
t: TypeHandshake, t: TypeHandshake,
d: []byte{}, d: []byte{Version},
} }
} }

2
log.go
View File

@ -19,7 +19,7 @@ var (
nbdLogger = logging.GetSubsystemLogger("nbd") nbdLogger = logging.GetSubsystemLogger("nbd")
timesyncLogger = logging.GetSubsystemLogger("timesync") timesyncLogger = logging.GetSubsystemLogger("timesync")
jsonRpcLogger = logging.GetSubsystemLogger("jsonrpc") jsonRpcLogger = logging.GetSubsystemLogger("jsonrpc")
hidRpcLogger = logging.GetSubsystemLogger("hidrpc") hidRPCLogger = logging.GetSubsystemLogger("hidrpc")
watchdogLogger = logging.GetSubsystemLogger("watchdog") watchdogLogger = logging.GetSubsystemLogger("watchdog")
websecureLogger = logging.GetSubsystemLogger("websecure") websecureLogger = logging.GetSubsystemLogger("websecure")
otaLogger = logging.GetSubsystemLogger("ota") otaLogger = logging.GetSubsystemLogger("ota")

View File

@ -108,8 +108,8 @@ export interface RTCState {
rpcHidProtocolVersion: number | null; rpcHidProtocolVersion: number | null;
setRpcHidProtocolVersion: (version: number) => void; setRpcHidProtocolVersion: (version: number) => void;
setRpcHidChannel: (channel: RTCDataChannel) => void;
rpcHidChannel: RTCDataChannel | null; rpcHidChannel: RTCDataChannel | null;
setRpcHidChannel: (channel: RTCDataChannel) => void;
peerConnectionState: RTCPeerConnectionState | null; peerConnectionState: RTCPeerConnectionState | null;
setPeerConnectionState: (state: RTCPeerConnectionState) => void; setPeerConnectionState: (state: RTCPeerConnectionState) => void;

View File

@ -15,6 +15,8 @@ export const HID_RPC_MESSAGE_TYPES = {
export type HidRpcMessageType = typeof HID_RPC_MESSAGE_TYPES[keyof typeof HID_RPC_MESSAGE_TYPES]; export type HidRpcMessageType = typeof HID_RPC_MESSAGE_TYPES[keyof typeof HID_RPC_MESSAGE_TYPES];
export const HID_RPC_VERSION = 0x01;
const withinUint8Range = (value: number) => { const withinUint8Range = (value: number) => {
return value >= 0 && value <= 255; return value >= 0 && value <= 255;
}; };
@ -58,7 +60,7 @@ export const toKeyboardLedState = (s: number): KeyboardLedState => {
num_lock: (s & keyboardLedStateMasks.num_lock) !== 0, num_lock: (s & keyboardLedStateMasks.num_lock) !== 0,
caps_lock: (s & keyboardLedStateMasks.caps_lock) !== 0, caps_lock: (s & keyboardLedStateMasks.caps_lock) !== 0,
scroll_lock: (s & keyboardLedStateMasks.scroll_lock) !== 0, scroll_lock: (s & keyboardLedStateMasks.scroll_lock) !== 0,
compose: (s & keyboardLedStateMasks.compose) !== 0, // TODO: check if this is correct compose: (s & keyboardLedStateMasks.compose) !== 0,
kana: (s & keyboardLedStateMasks.kana) !== 0, kana: (s & keyboardLedStateMasks.kana) !== 0,
shift: (s & keyboardLedStateMasks.shift) !== 0, shift: (s & keyboardLedStateMasks.shift) !== 0,
} as KeyboardLedState; } as KeyboardLedState;
@ -120,11 +122,12 @@ const toKeypressReportEvent = (key: number, press: boolean) => {
}; };
const toHandshakeMessage = () => { const toHandshakeMessage = () => {
return new Uint8Array([HID_RPC_MESSAGE_TYPES.Handshake]); return new Uint8Array([HID_RPC_MESSAGE_TYPES.Handshake, HID_RPC_VERSION]);
}; };
export interface HidRpcMessage { export interface HidRpcMessage {
type: HidRpcMessageType; type: HidRpcMessageType;
version?: number;
keysDownState?: KeysDownState; keysDownState?: KeysDownState;
} }
@ -139,6 +142,7 @@ const unmarshalHidRpcMessage = (data: Uint8Array): HidRpcMessage | undefined =>
case HID_RPC_MESSAGE_TYPES.Handshake: case HID_RPC_MESSAGE_TYPES.Handshake:
return { return {
type: HID_RPC_MESSAGE_TYPES.Handshake, type: HID_RPC_MESSAGE_TYPES.Handshake,
version: payload[0],
}; };
case HID_RPC_MESSAGE_TYPES.KeysDownState: case HID_RPC_MESSAGE_TYPES.KeysDownState:
return { return {
@ -219,7 +223,18 @@ export function useHidRpc(onHidRpcMessage?: (payload: HidRpcMessage) => void) {
} }
if (message.type === HID_RPC_MESSAGE_TYPES.Handshake) { if (message.type === HID_RPC_MESSAGE_TYPES.Handshake) {
setRpcHidProtocolVersion(1); if (!message.version) {
console.error("Received handshake message without version", message);
return;
}
// TODO: use capabilities to determine the supported functions rather than the version
if (message.version < HID_RPC_VERSION) {
console.error("Server is using an older HID RPC version than the client", message);
return;
}
setRpcHidProtocolVersion(message.version);
} }
onHidRpcMessage?.(message); onHidRpcMessage?.(message);

4
usb.go
View File

@ -27,13 +27,13 @@ func initUsbGadget() {
gadget.SetOnKeyboardStateChange(func(state usbgadget.KeyboardState) { gadget.SetOnKeyboardStateChange(func(state usbgadget.KeyboardState) {
if currentSession != nil { if currentSession != nil {
reportHidRpcKeyboardLedState(state, currentSession) currentSession.reportHidRPCKeyboardLedState(state)
} }
}) })
gadget.SetOnKeysDownChange(func(state usbgadget.KeysDownState) { gadget.SetOnKeysDownChange(func(state usbgadget.KeysDownState) {
if currentSession != nil { if currentSession != nil {
reportHidRpcKeysDownState(state, currentSession) currentSession.reportHidRPCKeysDownState(state)
} }
}) })

View File

@ -23,7 +23,7 @@ type Session struct {
HidChannel *webrtc.DataChannel HidChannel *webrtc.DataChannel
shouldUmountVirtualMedia bool shouldUmountVirtualMedia bool
hidRpcAvailable bool hidRPCAvailable bool
hidQueue chan webrtc.DataChannelMessage hidQueue chan webrtc.DataChannelMessage
rpcQueue chan webrtc.DataChannelMessage rpcQueue chan webrtc.DataChannelMessage
} }
@ -111,7 +111,7 @@ func newSession(config SessionConfig) (*Session, error) {
session := &Session{peerConnection: peerConnection} session := &Session{peerConnection: peerConnection}
session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) session.rpcQueue = make(chan webrtc.DataChannelMessage, 256)
session.hidQueue = make(chan webrtc.DataChannelMessage, 1024) session.hidQueue = make(chan webrtc.DataChannelMessage, 256)
go func() { go func() {
for msg := range session.rpcQueue { for msg := range session.rpcQueue {
@ -129,7 +129,7 @@ func newSession(config SessionConfig) (*Session, error) {
peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
scopedLogger.Warn().Interface("error", r).Msg("Recovered from panic in DataChannel handler") scopedLogger.Error().Interface("error", r).Msg("Recovered from panic in DataChannel handler")
} }
}() }()
@ -235,6 +235,11 @@ func newSession(config SessionConfig) (*Session, error) {
close(session.rpcQueue) close(session.rpcQueue)
session.rpcQueue = nil session.rpcQueue = nil
} }
// Stop HID RPC processor
if session.hidQueue != nil {
close(session.hidQueue)
session.hidQueue = nil
}
if session.shouldUmountVirtualMedia { if session.shouldUmountVirtualMedia {
if err := rpcUnmountImage(); err != nil { if err := rpcUnmountImage(); err != nil {
scopedLogger.Warn().Err(err).Msg("unmount image failed on connection close") scopedLogger.Warn().Err(err).Msg("unmount image failed on connection close")