diff --git a/hidrpc.go b/hidrpc.go index 7b9469f..5925488 100644 --- a/hidrpc.go +++ b/hidrpc.go @@ -8,26 +8,8 @@ import ( "github.com/jetkvm/kvm/internal/usbgadget" ) -func onHidMessage(data []byte, session *Session) { - if len(data) < 1 { - logger.Warn().Int("length", len(data)).Msg("received empty data in HID RPC message handler") - return - } - - var ( - message hidrpc.Message - rpcErr error - ) - - if err := hidrpc.Unmarshal(data, &message); err != nil { - logger.Warn().Err(err).Bytes("data", data).Msg("failed to unmarshal HID RPC message") - return - } - - scopedLogger := hidRpcLogger.With().Str("payload", message.String()).Logger() - - scopedLogger.Debug().Msg("received HID RPC message from the queue") - startTime := time.Now() +func handleHidRpcMessage(message hidrpc.Message, session *Session) { + var rpcErr error switch message.Type() { case hidrpc.TypeHandshake: @@ -68,9 +50,39 @@ func onHidMessage(data []byte, session *Session) { if rpcErr != nil { logger.Warn().Err(rpcErr).Msg("failed to handle HID RPC message") } +} - duration := time.Since(startTime) - scopedLogger.Debug().Dur("duration", duration).Msg("handled HID RPC message") +func onHidMessage(data []byte, session *Session) { + scopedLogger := hidRpcLogger.With().Bytes("data", data).Logger() + scopedLogger.Debug().Msg("HID RPC message received") + + if len(data) < 1 { + scopedLogger.Warn().Int("length", len(data)).Msg("received empty data in HID RPC message handler") + return + } + + var message hidrpc.Message + + if err := hidrpc.Unmarshal(data, &message); err != nil { + scopedLogger.Warn().Err(err).Msg("failed to unmarshal HID RPC message") + return + } + + scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger() + + t := time.Now() + + r := make(chan interface{}) + go func() { + handleHidRpcMessage(message, session) + r <- nil + }() + select { + case <-time.After(1 * time.Second): + scopedLogger.Warn().Msg("HID RPC message timed out") + case <-r: + scopedLogger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC message handled") + } } func handleHidRpcKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState, error) { diff --git a/internal/hidrpc/message.go b/internal/hidrpc/message.go index d44662c..1d00062 100644 --- a/internal/hidrpc/message.go +++ b/internal/hidrpc/message.go @@ -28,9 +28,9 @@ func (m *Message) String() string { case TypeKeyboardReport: return fmt.Sprintf("KeyboardReport{Modifier: %d, Keys: %v}", m.d[0], m.d[1:]) case TypePointerReport: - return fmt.Sprintf("PointerReport{X: %d, Y: %d, Button: %d}", m.d[0], m.d[1], m.d[2]) + return fmt.Sprintf("PointerReport{X: %d, Y: %d, Button: %d}", m.d[0:4], m.d[4:8], m.d[8]) case TypeMouseReport: - return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0], m.d[1], m.d[2]) + return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0:2], m.d[2:4], m.d[4]) default: return fmt.Sprintf("Unknown{Type: %d, Data: %v}", m.t, m.d) } diff --git a/internal/usbgadget/consts.go b/internal/usbgadget/consts.go index 8204d0a..958aecc 100644 --- a/internal/usbgadget/consts.go +++ b/internal/usbgadget/consts.go @@ -1,3 +1,7 @@ package usbgadget +import "time" + const dwc3Path = "/sys/bus/platform/drivers/dwc3" + +const hidWriteTimeout = 10 * time.Millisecond diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index 2c734d2..8b433cd 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -157,7 +157,9 @@ func (u *UsbGadget) updateKeyDownState(state KeysDownState) { u.keysDownState = state if u.onKeysDownChange != nil { + u.log.Trace().Interface("state", state).Msg("calling onKeysDownChange") (*u.onKeysDownChange)(state) + u.log.Trace().Interface("state", state).Msg("onKeysDownChange called") } } @@ -239,7 +241,7 @@ func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { return err } - _, err := u.keyboardHidFile.Write(append([]byte{modifier, 0x00}, keys[:hidKeyBufferSize]...)) + _, err := writeWithTimeout(u.keyboardHidFile, append([]byte{modifier, 0x00}, keys[:hidKeyBufferSize]...)) if err != nil { u.logWithSuppression("keyboardWriteHidFile", 100, u.log, err, "failed to write to hidg0") u.keyboardHidFile.Close() diff --git a/internal/usbgadget/hid_mouse_absolute.go b/internal/usbgadget/hid_mouse_absolute.go index c083b60..4f6f8d7 100644 --- a/internal/usbgadget/hid_mouse_absolute.go +++ b/internal/usbgadget/hid_mouse_absolute.go @@ -74,7 +74,7 @@ func (u *UsbGadget) absMouseWriteHidFile(data []byte) error { } } - _, err := u.absMouseHidFile.Write(data) + _, err := writeWithTimeout(u.absMouseHidFile, data) if err != nil { u.logWithSuppression("absMouseWriteHidFile", 100, u.log, err, "failed to write to hidg1") u.absMouseHidFile.Close() diff --git a/internal/usbgadget/hid_mouse_relative.go b/internal/usbgadget/hid_mouse_relative.go index 70cb72c..25ec2c1 100644 --- a/internal/usbgadget/hid_mouse_relative.go +++ b/internal/usbgadget/hid_mouse_relative.go @@ -64,7 +64,7 @@ func (u *UsbGadget) relMouseWriteHidFile(data []byte) error { } } - _, err := u.relMouseHidFile.Write(data) + _, err := writeWithTimeout(u.relMouseHidFile, data) if err != nil { u.logWithSuppression("relMouseWriteHidFile", 100, u.log, err, "failed to write to hidg2") u.relMouseHidFile.Close() diff --git a/internal/usbgadget/utils.go b/internal/usbgadget/utils.go index 05fcd3a..6c295d6 100644 --- a/internal/usbgadget/utils.go +++ b/internal/usbgadget/utils.go @@ -3,10 +3,13 @@ package usbgadget import ( "bytes" "encoding/json" + "errors" "fmt" + "os" "path/filepath" "strconv" "strings" + "time" "github.com/rs/zerolog" ) @@ -107,6 +110,23 @@ func compareFileContent(oldContent []byte, newContent []byte, looserMatch bool) return false } +func writeWithTimeout(file *os.File, data []byte) (n int, err error) { + if err := file.SetWriteDeadline(time.Now().Add(hidWriteTimeout)); err != nil { + return -1, err + } + + n, err = file.Write(data) + if err == nil { + return + } + + if errors.Is(err, os.ErrDeadlineExceeded) { + err = nil + } + + return +} + func (u *UsbGadget) logWithSuppression(counterName string, every int, logger *zerolog.Logger, err error, msg string, args ...any) { u.logSuppressionLock.Lock() defer u.logSuppressionLock.Unlock()