diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index 2c4a456..3b9f6c2 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -261,7 +261,23 @@ func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { return nil } -func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) error { +func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { + // if we just reported an error roll over, we should clear the keys + if keys[0] == hidErrorRollOver { + for i := range keys { + keys[i] = 0 + } + } + + downState := KeysDownState{ + Modifier: modifier, + Keys: []byte(keys[:]), + } + u.updateKeyDownState(downState) + return downState +} + +func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) (KeysDownState, error) { u.keyboardLock.Lock() defer u.keyboardLock.Unlock() defer u.resetUserInputTime() @@ -273,7 +289,12 @@ func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) error { keys = append(keys, make([]byte, hidKeyBufferSize-len(keys))...) } - return u.keyboardWriteHidFile(modifier, keys) + err := u.keyboardWriteHidFile(modifier, keys) + if err != nil { + u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keyboard report to hidg0") + } + + return u.UpdateKeysDown(modifier, keys), err } const ( @@ -357,16 +378,10 @@ func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) } } - if err := u.keyboardWriteHidFile(modifier, keys); err != nil { + err := u.keyboardWriteHidFile(modifier, keys) + if err != nil { u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keypress report to hidg0") } - var result = KeysDownState{ - Modifier: modifier, - Keys: []byte(keys[:]), - } - - u.updateKeyDownState(result) - - return result, nil + return u.UpdateKeysDown(modifier, keys), err } diff --git a/jsonrpc.go b/jsonrpc.go index 7d05933..321c1d3 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -134,7 +134,6 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - scopedLogger.Trace().Msg("Calling RPC handler") result, err := callRPCHandler(scopedLogger, handler, request.Params) if err != nil { scopedLogger.Error().Err(err).Msg("Error calling RPC handler") diff --git a/usb.go b/usb.go index 5c8036c..d29e01a 100644 --- a/usb.go +++ b/usb.go @@ -43,7 +43,7 @@ func initUsbGadget() { } } -func rpcKeyboardReport(modifier byte, keys []byte) error { +func rpcKeyboardReport(modifier byte, keys []byte) (usbgadget.KeysDownState, error) { return gadget.KeyboardReport(modifier, keys) }