Compare commits

...

10 Commits

Author SHA1 Message Date
Aveline cc6effa70f
Merge cbf1c7fba6 into ea068414dc 2025-09-11 23:58:18 +02:00
Aveline ea068414dc
feat: validate ssh public key before saving (#794)
* feat: validate ssh public key before saving

* fix: TestValidSSHKeyTypes
2025-09-11 23:32:40 +02:00
Adam Shiervani 8d1a66806c
refactor(ui): Don't fetch KeybardAndMouse Icon on every re-render (#795) 2025-09-11 19:57:35 +02:00
Siyuan Miao cbf1c7fba6 chore: adjust auto release key interval 2025-09-11 17:46:50 +02:00
Siyuan Miao 3901ac91e0 chore: use ordered unreliable channel for pointer events 2025-09-11 17:32:43 +02:00
Siyuan Miao 8e9554f4d6 chore: use unreliable channel to send keepalive events 2025-09-11 16:16:47 +02:00
Siyuan Miao 46aee61565 clean up logging 2025-09-11 15:04:35 +02:00
Siyuan Miao 8b07903388 remove logging 2025-09-11 15:00:03 +02:00
Siyuan Miao 4b42c7e7e3 send keepalive when pressing the key 2025-09-11 14:10:18 +02:00
Siyuan Miao 2f0aa18d1d feat: release keyPress automatically 2025-09-11 13:05:20 +02:00
15 changed files with 698 additions and 115 deletions

View File

@ -29,6 +29,8 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) {
session.reportHidRPCKeysDownState(*keysDownState) session.reportHidRPCKeysDownState(*keysDownState)
} }
rpcErr = err rpcErr = err
case hidrpc.TypeKeypressKeepAliveReport:
gadget.DelayAutoRelease()
case hidrpc.TypePointerReport: case hidrpc.TypePointerReport:
pointerReport, err := message.PointerReport() pointerReport, err := message.PointerReport()
if err != nil { if err != nil {
@ -52,8 +54,13 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) {
} }
} }
func onHidMessage(data []byte, session *Session) { func onHidMessage(msg hidQueueMessage, session *Session) {
scopedLogger := hidRPCLogger.With().Bytes("data", data).Logger() data := msg.Data
scopedLogger := hidRPCLogger.With().
Str("channel", msg.channel).
Bytes("data", data).
Logger()
scopedLogger.Debug().Msg("HID RPC message received") scopedLogger.Debug().Msg("HID RPC message received")
if len(data) < 1 { if len(data) < 1 {

View File

@ -15,6 +15,7 @@ const (
TypePointerReport MessageType = 0x03 TypePointerReport MessageType = 0x03
TypeWheelReport MessageType = 0x04 TypeWheelReport MessageType = 0x04
TypeKeypressReport MessageType = 0x05 TypeKeypressReport MessageType = 0x05
TypeKeypressKeepAliveReport MessageType = 0x09
TypeMouseReport MessageType = 0x06 TypeMouseReport MessageType = 0x06
TypeKeyboardLedState MessageType = 0x32 TypeKeyboardLedState MessageType = 0x32
TypeKeydownState MessageType = 0x33 TypeKeydownState MessageType = 0x33
@ -98,3 +99,11 @@ func NewKeydownStateMessage(state usbgadget.KeysDownState) *Message {
d: data, d: data,
} }
} }
// NewKeypressKeepAliveMessage creates a new keypress keep alive message.
func NewKeypressKeepAliveMessage() *Message {
return &Message{
t: TypeKeypressKeepAliveReport,
d: []byte{},
}
}

View File

@ -43,6 +43,8 @@ func (m *Message) String() string {
return fmt.Sprintf("MouseReport{Malformed: %v}", m.d) return fmt.Sprintf("MouseReport{Malformed: %v}", m.d)
} }
return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0], m.d[1], m.d[2]) return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0], m.d[1], m.d[2])
case TypeKeypressKeepAliveReport:
return "KeypressKeepAliveReport"
default: default:
return fmt.Sprintf("Unknown{Type: %d, Data: %v}", m.t, m.d) return fmt.Sprintf("Unknown{Type: %d, Data: %v}", m.t, m.d)
} }

View File

@ -22,6 +22,11 @@ var keyboardConfig = gadgetConfigItem{
reportDesc: keyboardReportDesc, reportDesc: keyboardReportDesc,
} }
// macOS default: 15 * 15 = 225ms https://discussions.apple.com/thread/1316947?sortBy=rank
// Linux default: 250ms https://man.archlinux.org/man/kbdrate.8.en
// Windows default: 1ms `HKEY_CURRENT_USER\Control Panel\Accessibility\Keyboard Response\AutoRepeatDelay`
const autoReleaseKeyboardInterval = time.Millisecond * 225
// Source: https://www.kernel.org/doc/Documentation/usb/gadget_hid.txt // Source: https://www.kernel.org/doc/Documentation/usb/gadget_hid.txt
var keyboardReportDesc = []byte{ var keyboardReportDesc = []byte{
0x05, 0x01, /* USAGE_PAGE (Generic Desktop) */ 0x05, 0x01, /* USAGE_PAGE (Generic Desktop) */
@ -173,6 +178,64 @@ func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) {
u.onKeysDownChange = &f u.onKeysDownChange = &f
} }
func (u *UsbGadget) scheduleAutoRelease(key byte) {
u.kbdAutoReleaseLock.Lock()
defer u.kbdAutoReleaseLock.Unlock()
if u.kbdAutoReleaseTimer != nil {
u.kbdAutoReleaseTimer.Stop()
}
u.kbdAutoReleaseTimer = time.AfterFunc(autoReleaseKeyboardInterval, func() {
u.performAutoRelease(key)
})
}
func (u *UsbGadget) cancelAutoRelease() {
u.kbdAutoReleaseLock.Lock()
defer u.kbdAutoReleaseLock.Unlock()
if u.kbdAutoReleaseTimer != nil {
u.kbdAutoReleaseTimer.Stop()
}
}
func (u *UsbGadget) DelayAutoRelease() {
u.kbdAutoReleaseLock.Lock()
defer u.kbdAutoReleaseLock.Unlock()
u.log.Trace().Msg("delaying auto-release")
if u.kbdAutoReleaseTimer == nil {
return
}
u.kbdAutoReleaseTimer.Reset(autoReleaseKeyboardInterval)
u.log.Trace().Msg("auto-release timer reset")
}
func (u *UsbGadget) performAutoRelease(key byte) {
u.kbdAutoReleaseLock.Lock()
defer u.kbdAutoReleaseLock.Unlock()
select {
case <-u.keyboardStateCtx.Done():
return
default:
}
// we just reset the keyboard state to 0 no matter what
_, err := u.keypressReport(0, false, false)
if err != nil {
u.log.Warn().Uint8("key", key).Msg("failed to auto-release keyboard key")
}
u.kbdAutoReleaseTimer = nil
u.log.Trace().Uint8("key", key).Msg("auto release performed")
}
func (u *UsbGadget) listenKeyboardEvents() { func (u *UsbGadget) listenKeyboardEvents() {
var path string var path string
if u.keyboardHidFile != nil { if u.keyboardHidFile != nil {
@ -331,9 +394,10 @@ var KeyCodeToMaskMap = map[byte]byte{
RightSuper: ModifierMaskRightSuper, RightSuper: ModifierMaskRightSuper,
} }
func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) { func (u *UsbGadget) keypressReport(key byte, press bool, autoRelease bool) (KeysDownState, error) {
u.keyboardLock.Lock() u.keyboardLock.Lock()
defer u.keyboardLock.Unlock() defer u.keyboardLock.Unlock()
defer u.resetUserInputTime() defer u.resetUserInputTime()
// IMPORTANT: This code parallels the logic in the kernel's hid-gadget driver // IMPORTANT: This code parallels the logic in the kernel's hid-gadget driver
@ -398,5 +462,25 @@ func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error)
u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keypress report to hidg0") u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keypress report to hidg0")
} }
if press {
{
u.kbdAutoReleaseLock.Lock()
u.kbdAutoReleaseLastKey = key
u.kbdAutoReleaseLock.Unlock()
}
if autoRelease {
u.scheduleAutoRelease(key)
}
} else {
if autoRelease {
u.cancelAutoRelease()
}
}
return u.UpdateKeysDown(modifier, keys), err return u.UpdateKeysDown(modifier, keys), err
} }
func (u *UsbGadget) KeypressReport(key byte, press bool) (KeysDownState, error) {
return u.keypressReport(key, press, true)
}

View File

@ -68,6 +68,10 @@ type UsbGadget struct {
keyboardState byte // keyboard latched state (NumLock, CapsLock, ScrollLock, Compose, Kana) keyboardState byte // keyboard latched state (NumLock, CapsLock, ScrollLock, Compose, Kana)
keysDownState KeysDownState // keyboard dynamic state (modifier keys and pressed keys) keysDownState KeysDownState // keyboard dynamic state (modifier keys and pressed keys)
kbdAutoReleaseLock sync.Mutex
kbdAutoReleaseTimer *time.Timer
kbdAutoReleaseLastKey byte
keyboardStateLock sync.Mutex keyboardStateLock sync.Mutex
keyboardStateCtx context.Context keyboardStateCtx context.Context
keyboardStateCancel context.CancelFunc keyboardStateCancel context.CancelFunc
@ -149,3 +153,35 @@ func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDev
return g return g
} }
// Close cleans up resources used by the USB gadget
func (u *UsbGadget) Close() error {
// Cancel keyboard state context
if u.keyboardStateCancel != nil {
u.keyboardStateCancel()
}
// Stop auto-release timer
u.kbdAutoReleaseLock.Lock()
if u.kbdAutoReleaseTimer != nil {
u.kbdAutoReleaseTimer.Stop()
u.kbdAutoReleaseTimer = nil
}
u.kbdAutoReleaseLock.Unlock()
// Close HID files
if u.keyboardHidFile != nil {
u.keyboardHidFile.Close()
u.keyboardHidFile = nil
}
if u.absMouseHidFile != nil {
u.absMouseHidFile.Close()
u.absMouseHidFile = nil
}
if u.relMouseHidFile != nil {
u.relMouseHidFile.Close()
u.relMouseHidFile = nil
}
return nil
}

71
internal/utils/ssh.go Normal file
View File

@ -0,0 +1,71 @@
package utils
import (
"fmt"
"slices"
"strings"
"golang.org/x/crypto/ssh"
)
// ValidSSHKeyTypes is a list of valid SSH key types
//
// Please make sure that all the types in this list are supported by dropbear
// https://github.com/mkj/dropbear/blob/003c5fcaabc114430d5d14142e95ffdbbd2d19b6/src/signkey.c#L37
//
// ssh-dss is not allowed here as it's insecure
var ValidSSHKeyTypes = []string{
ssh.KeyAlgoRSA,
ssh.KeyAlgoED25519,
ssh.KeyAlgoECDSA256,
ssh.KeyAlgoECDSA384,
ssh.KeyAlgoECDSA521,
}
// ValidateSSHKey validates authorized_keys file content
func ValidateSSHKey(sshKey string) error {
// validate SSH key
var (
hasValidPublicKey = false
lastError = fmt.Errorf("no valid SSH key found")
)
for _, key := range strings.Split(sshKey, "\n") {
key = strings.TrimSpace(key)
// skip empty lines and comments
if key == "" || strings.HasPrefix(key, "#") {
continue
}
parsedPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key))
if err != nil {
lastError = err
continue
}
if parsedPublicKey == nil {
continue
}
parsedType := parsedPublicKey.Type()
textType := strings.Fields(key)[0]
if parsedType != textType {
lastError = fmt.Errorf("parsed SSH key type %s does not match type in text %s", parsedType, textType)
continue
}
if !slices.Contains(ValidSSHKeyTypes, parsedType) {
lastError = fmt.Errorf("invalid SSH key type: %s", parsedType)
continue
}
hasValidPublicKey = true
}
if !hasValidPublicKey {
return lastError
}
return nil
}

208
internal/utils/ssh_test.go Normal file
View File

@ -0,0 +1,208 @@
package utils
import (
"strings"
"testing"
)
func TestValidateSSHKey(t *testing.T) {
tests := []struct {
name string
sshKey string
expectError bool
errorMsg string
}{
{
name: "valid RSA key",
sshKey: "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDiYUb9Fy2vlPfO+HwubnshimpVrWPoePyvyN+jPC5gWqZSycjMy6Is2vFVn7oQc72bkY0wZalspT5wUOwKtltSoLpL7vcqGL9zHVw4yjYXtPGIRd3zLpU9wdngevnepPQWTX3LvZTZfmOsrGoMDKIG+Lbmiq/STMuWYecIqMp7tUKRGS8vfAmpu6MsrN9/4UTcdWWXYWJQQn+2nCyMz28jYlWRsKtqFK6owrdZWt8WQnPN+9Upcf2ByQje+0NLnpNrnh+yd2ocuVW9wQYKAZXy7IaTfEJwd5m34sLwkqlZTaBBcmWJU+3RfpYXE763cf3rUoPIGQ8eUEBJ8IdM4vhp test@example.com",
expectError: false,
},
{
name: "valid ED25519 key",
sshKey: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBSbM8wuD5ab0nHsXaYOqaD3GLLUwmDzSk79Xi/N+H2j test@example.com",
expectError: false,
},
{
name: "valid ECDSA key",
sshKey: "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBAlTkxIo4mXBR+gEX0Q74BpYX4bFFHoX+8Uz7tsob8HvsnMvsEE+BW9h9XrbWX4/4ppL/o6sHbvsqNr9HcyKfdc= test@example.com",
expectError: false,
},
{
name: "multiple valid keys",
sshKey: "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDiYUb9Fy2vlPfO+HwubnshimpVrWPoePyvyN+jPC5gWqZSycjMy6Is2vFVn7oQc72bkY0wZalspT5wUOwKtltSoLpL7vcqGL9zHVw4yjYXtPGIRd3zLpU9wdngevnepPQWTX3LvZTZfmOsrGoMDKIG+Lbmiq/STMuWYecIqMp7tUKRGS8vfAmpu6MsrN9/4UTcdWWXYWJQQn+2nCyMz28jYlWRsKtqFK6owrdZWt8WQnPN+9Upcf2ByQje+0NLnpNrnh+yd2ocuVW9wQYKAZXy7IaTfEJwd5m34sLwkqlZTaBBcmWJU+3RfpYXE763cf3rUoPIGQ8eUEBJ8IdM4vhp test@example.com\nssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBSbM8wuD5ab0nHsXaYOqaD3GLLUwmDzSk79Xi/N+H2j test@example.com",
expectError: false,
},
{
name: "valid key with comment",
sshKey: "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDiYUb9Fy2vlPfO+HwubnshimpVrWPoePyvyN+jPC5gWqZSycjMy6Is2vFVn7oQc72bkY0wZalspT5wUOwKtltSoLpL7vcqGL9zHVw4yjYXtPGIRd3zLpU9wdngevnepPQWTX3LvZTZfmOsrGoMDKIG+Lbmiq/STMuWYecIqMp7tUKRGS8vfAmpu6MsrN9/4UTcdWWXYWJQQn+2nCyMz28jYlWRsKtqFK6owrdZWt8WQnPN+9Upcf2ByQje+0NLnpNrnh+yd2ocuVW9wQYKAZXy7IaTfEJwd5m34sLwkqlZTaBBcmWJU+3RfpYXE763cf3rUoPIGQ8eUEBJ8IdM4vhp user@example.com",
expectError: false,
},
{
name: "valid key with options and comment (we don't support options yet)",
sshKey: "command=\"echo hello\" ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDiYUb9Fy2vlPfO+HwubnshimpVrWPoePyvyN+jPC5gWqZSycjMy6Is2vFVn7oQc72bkY0wZalspT5wUOwKtltSoLpL7vcqGL9zHVw4yjYXtPGIRd3zLpU9wdngevnepPQWTX3LvZTZfmOsrGoMDKIG+Lbmiq/STMuWYecIqMp7tUKRGS8vfAmpu6MsrN9/4UTcdWWXYWJQQn+2nCyMz28jYlWRsKtqFK6owrdZWt8WQnPN+9Upcf2ByQje+0NLnpNrnh+yd2ocuVW9wQYKAZXy7IaTfEJwd5m34sLwkqlZTaBBcmWJU+3RfpYXE763cf3rUoPIGQ8eUEBJ8IdM4vhp user@example.com",
expectError: true,
},
{
name: "empty string",
sshKey: "",
expectError: true,
errorMsg: "no valid SSH key found",
},
{
name: "whitespace only",
sshKey: " \n\t \n ",
expectError: true,
errorMsg: "no valid SSH key found",
},
{
name: "comment only",
sshKey: "# This is a comment\n# Another comment",
expectError: true,
errorMsg: "no valid SSH key found",
},
{
name: "invalid key format",
sshKey: "not-a-valid-ssh-key",
expectError: true,
},
{
name: "invalid key type",
sshKey: "ssh-dss AAAAB3NzaC1kc3MAAACBAOeB...",
expectError: true,
errorMsg: "invalid SSH key type: ssh-dss",
},
{
name: "unsupported key type",
sshKey: "ssh-rsa-cert-v01@openssh.com AAAAB3NzaC1yc2EAAAADAQABAAABgQC7vbqajDhA...",
expectError: true,
errorMsg: "invalid SSH key type: ssh-rsa-cert-v01@openssh.com",
},
{
name: "malformed key data",
sshKey: "ssh-rsa invalid-base64-data",
expectError: true,
},
{
name: "type mismatch",
sshKey: "ssh-rsa AAAAC3NzaC1lZDI1NTE5AAAAIGomKoH...",
expectError: true,
errorMsg: "parsed SSH key type ssh-ed25519 does not match type in text ssh-rsa",
},
{
name: "mixed valid and invalid keys",
sshKey: "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDiYUb9Fy2vlPfO+HwubnshimpVrWPoePyvyN+jPC5gWqZSycjMy6Is2vFVn7oQc72bkY0wZalspT5wUOwKtltSoLpL7vcqGL9zHVw4yjYXtPGIRd3zLpU9wdngevnepPQWTX3LvZTZfmOsrGoMDKIG+Lbmiq/STMuWYecIqMp7tUKRGS8vfAmpu6MsrN9/4UTcdWWXYWJQQn+2nCyMz28jYlWRsKtqFK6owrdZWt8WQnPN+9Upcf2ByQje+0NLnpNrnh+yd2ocuVW9wQYKAZXy7IaTfEJwd5m34sLwkqlZTaBBcmWJU+3RfpYXE763cf3rUoPIGQ8eUEBJ8IdM4vhp test@example.com\ninvalid-key\nssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBSbM8wuD5ab0nHsXaYOqaD3GLLUwmDzSk79Xi/N+H2j test@example.com",
expectError: false,
},
{
name: "valid key with empty lines and comments",
sshKey: "# Comment line\n\nssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDiYUb9Fy2vlPfO+HwubnshimpVrWPoePyvyN+jPC5gWqZSycjMy6Is2vFVn7oQc72bkY0wZalspT5wUOwKtltSoLpL7vcqGL9zHVw4yjYXtPGIRd3zLpU9wdngevnepPQWTX3LvZTZfmOsrGoMDKIG+Lbmiq/STMuWYecIqMp7tUKRGS8vfAmpu6MsrN9/4UTcdWWXYWJQQn+2nCyMz28jYlWRsKtqFK6owrdZWt8WQnPN+9Upcf2ByQje+0NLnpNrnh+yd2ocuVW9wQYKAZXy7IaTfEJwd5m34sLwkqlZTaBBcmWJU+3RfpYXE763cf3rUoPIGQ8eUEBJ8IdM4vhp test@example.com\n# Another comment\n\t\n",
expectError: false,
},
{
name: "all invalid keys",
sshKey: "invalid-key-1\ninvalid-key-2\nssh-dss AAAAB3NzaC1kc3MAAACBAOeB...",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateSSHKey(tt.sshKey)
if tt.expectError {
if err == nil {
t.Errorf("ValidateSSHKey() expected error but got none")
} else if tt.errorMsg != "" && !strings.ContainsAny(err.Error(), tt.errorMsg) {
t.Errorf("ValidateSSHKey() error = %v, expected to contain %v", err, tt.errorMsg)
}
} else {
if err != nil {
t.Errorf("ValidateSSHKey() unexpected error = %v", err)
}
}
})
}
}
func TestValidSSHKeyTypes(t *testing.T) {
expectedTypes := []string{
"ssh-rsa",
"ssh-ed25519",
"ecdsa-sha2-nistp256",
"ecdsa-sha2-nistp384",
"ecdsa-sha2-nistp521",
}
if len(ValidSSHKeyTypes) != len(expectedTypes) {
t.Errorf("ValidSSHKeyTypes length = %d, expected %d", len(ValidSSHKeyTypes), len(expectedTypes))
}
for _, expectedType := range expectedTypes {
found := false
for _, actualType := range ValidSSHKeyTypes {
if actualType == expectedType {
found = true
break
}
}
if !found {
t.Errorf("ValidSSHKeyTypes missing expected type: %s", expectedType)
}
}
}
// TestValidateSSHKeyEdgeCases tests edge cases and boundary conditions
func TestValidateSSHKeyEdgeCases(t *testing.T) {
tests := []struct {
name string
sshKey string
expectError bool
}{
{
name: "key with only type",
sshKey: "ssh-rsa",
expectError: true,
},
{
name: "key with type and empty data",
sshKey: "ssh-rsa ",
expectError: true,
},
{
name: "key with type and whitespace data",
sshKey: "ssh-rsa \t ",
expectError: true,
},
{
name: "key with multiple spaces between type and data",
sshKey: "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDiYUb9Fy2vlPfO+HwubnshimpVrWPoePyvyN+jPC5gWqZSycjMy6Is2vFVn7oQc72bkY0wZalspT5wUOwKtltSoLpL7vcqGL9zHVw4yjYXtPGIRd3zLpU9wdngevnepPQWTX3LvZTZfmOsrGoMDKIG+Lbmiq/STMuWYecIqMp7tUKRGS8vfAmpu6MsrN9/4UTcdWWXYWJQQn+2nCyMz28jYlWRsKtqFK6owrdZWt8WQnPN+9Upcf2ByQje+0NLnpNrnh+yd2ocuVW9wQYKAZXy7IaTfEJwd5m34sLwkqlZTaBBcmWJU+3RfpYXE763cf3rUoPIGQ8eUEBJ8IdM4vhp test@example.com",
expectError: false,
},
{
name: "key with tabs",
sshKey: "\tssh-rsa\tAAAAB3NzaC1yc2EAAAADAQABAAABAQDiYUb9Fy2vlPfO+HwubnshimpVrWPoePyvyN+jPC5gWqZSycjMy6Is2vFVn7oQc72bkY0wZalspT5wUOwKtltSoLpL7vcqGL9zHVw4yjYXtPGIRd3zLpU9wdngevnepPQWTX3LvZTZfmOsrGoMDKIG+Lbmiq/STMuWYecIqMp7tUKRGS8vfAmpu6MsrN9/4UTcdWWXYWJQQn+2nCyMz28jYlWRsKtqFK6owrdZWt8WQnPN+9Upcf2ByQje+0NLnpNrnh+yd2ocuVW9wQYKAZXy7IaTfEJwd5m34sLwkqlZTaBBcmWJU+3RfpYXE763cf3rUoPIGQ8eUEBJ8IdM4vhp test@example.com",
expectError: false,
},
{
name: "very long line",
sshKey: "ssh-rsa " + string(make([]byte, 10000)),
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateSSHKey(tt.sshKey)
if tt.expectError {
if err == nil {
t.Errorf("ValidateSSHKey() expected error but got none")
}
} else {
if err != nil {
t.Errorf("ValidateSSHKey() unexpected error = %v", err)
}
}
})
}
}

View File

@ -17,6 +17,7 @@ import (
"go.bug.st/serial" "go.bug.st/serial"
"github.com/jetkvm/kvm/internal/usbgadget" "github.com/jetkvm/kvm/internal/usbgadget"
"github.com/jetkvm/kvm/internal/utils"
) )
type JSONRPCRequest struct { type JSONRPCRequest struct {
@ -429,7 +430,19 @@ func rpcGetSSHKeyState() (string, error) {
} }
func rpcSetSSHKeyState(sshKey string) error { func rpcSetSSHKeyState(sshKey string) error {
if sshKey != "" { if sshKey == "" {
// Remove SSH key file if empty string is provided
if err := os.Remove(sshKeyFile); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove SSH key file: %w", err)
}
return nil
}
// Validate SSH key
if err := utils.ValidateSSHKey(sshKey); err != nil {
return err
}
// Create directory if it doesn't exist // Create directory if it doesn't exist
if err := os.MkdirAll(sshKeyDir, 0700); err != nil { if err := os.MkdirAll(sshKeyDir, 0700); err != nil {
return fmt.Errorf("failed to create SSH key directory: %w", err) return fmt.Errorf("failed to create SSH key directory: %w", err)
@ -439,12 +452,6 @@ func rpcSetSSHKeyState(sshKey string) error {
if err := os.WriteFile(sshKeyFile, []byte(sshKey), 0600); err != nil { if err := os.WriteFile(sshKeyFile, []byte(sshKey), 0600); err != nil {
return fmt.Errorf("failed to write SSH key: %w", err) return fmt.Errorf("failed to write SSH key: %w", err)
} }
} else {
// Remove SSH key file if empty string is provided
if err := os.Remove(sshKeyFile); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove SSH key file: %w", err)
}
}
return nil return nil
} }

View File

@ -22,14 +22,6 @@ const USBStateMap: Record<USBStates, string> = {
"not attached": "Disconnected", "not attached": "Disconnected",
suspended: "Low power mode", suspended: "Low power mode",
}; };
export default function USBStateStatus({
state,
peerConnectionState,
}: {
state: USBStates;
peerConnectionState?: RTCPeerConnectionState | null;
}) {
const StatusCardProps: StatusProps = { const StatusCardProps: StatusProps = {
configured: { configured: {
icon: ({ className }) => ( icon: ({ className }) => (
@ -63,6 +55,15 @@ export default function USBStateStatus({
statusIndicatorClassName: "bg-green-500 border-green-600", statusIndicatorClassName: "bg-green-500 border-green-600",
}, },
}; };
export default function USBStateStatus({
state,
peerConnectionState,
}: {
state: USBStates;
peerConnectionState?: RTCPeerConnectionState | null;
}) {
const props = StatusCardProps[state]; const props = StatusCardProps[state];
if (!props) { if (!props) {
console.warn("Unsupported USB state: ", state); console.warn("Unsupported USB state: ", state);

View File

@ -6,6 +6,7 @@ export const HID_RPC_MESSAGE_TYPES = {
PointerReport: 0x03, PointerReport: 0x03,
WheelReport: 0x04, WheelReport: 0x04,
KeypressReport: 0x05, KeypressReport: 0x05,
KeypressKeepAliveReport: 0x09,
MouseReport: 0x06, MouseReport: 0x06,
KeyboardLedState: 0x32, KeyboardLedState: 0x32,
KeysDownState: 0x33, KeysDownState: 0x33,
@ -278,12 +279,23 @@ export class MouseReportMessage extends RpcMessage {
} }
} }
export class KeypressKeepAliveMessage extends RpcMessage {
constructor() {
super(HID_RPC_MESSAGE_TYPES.KeypressKeepAliveReport);
}
marshal(): Uint8Array {
return new Uint8Array([this.messageType]);
}
}
export const messageRegistry = { export const messageRegistry = {
[HID_RPC_MESSAGE_TYPES.Handshake]: HandshakeMessage, [HID_RPC_MESSAGE_TYPES.Handshake]: HandshakeMessage,
[HID_RPC_MESSAGE_TYPES.KeysDownState]: KeysDownStateMessage, [HID_RPC_MESSAGE_TYPES.KeysDownState]: KeysDownStateMessage,
[HID_RPC_MESSAGE_TYPES.KeyboardLedState]: KeyboardLedStateMessage, [HID_RPC_MESSAGE_TYPES.KeyboardLedState]: KeyboardLedStateMessage,
[HID_RPC_MESSAGE_TYPES.KeyboardReport]: KeyboardReportMessage, [HID_RPC_MESSAGE_TYPES.KeyboardReport]: KeyboardReportMessage,
[HID_RPC_MESSAGE_TYPES.KeypressReport]: KeypressReportMessage, [HID_RPC_MESSAGE_TYPES.KeypressReport]: KeypressReportMessage,
[HID_RPC_MESSAGE_TYPES.KeypressKeepAliveReport]: KeypressKeepAliveMessage,
} }
export const unmarshalHidRpcMessage = (data: Uint8Array): RpcMessage | undefined => { export const unmarshalHidRpcMessage = (data: Uint8Array): RpcMessage | undefined => {

View File

@ -111,6 +111,12 @@ export interface RTCState {
rpcHidChannel: RTCDataChannel | null; rpcHidChannel: RTCDataChannel | null;
setRpcHidChannel: (channel: RTCDataChannel) => void; setRpcHidChannel: (channel: RTCDataChannel) => void;
rpcHidUnreliableChannel: RTCDataChannel | null;
setRpcHidUnreliableChannel: (channel: RTCDataChannel) => void;
rpcHidUnreliableNonOrderedChannel: RTCDataChannel | null;
setRpcHidUnreliableNonOrderedChannel: (channel: RTCDataChannel) => void;
peerConnectionState: RTCPeerConnectionState | null; peerConnectionState: RTCPeerConnectionState | null;
setPeerConnectionState: (state: RTCPeerConnectionState) => void; setPeerConnectionState: (state: RTCPeerConnectionState) => void;
@ -163,6 +169,12 @@ export const useRTCStore = create<RTCState>(set => ({
rpcHidChannel: null, rpcHidChannel: null,
setRpcHidChannel: (channel: RTCDataChannel) => set({ rpcHidChannel: channel }), setRpcHidChannel: (channel: RTCDataChannel) => set({ rpcHidChannel: channel }),
rpcHidUnreliableChannel: null,
setRpcHidUnreliableChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableChannel: channel }),
rpcHidUnreliableNonOrderedChannel: null,
setRpcHidUnreliableNonOrderedChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableNonOrderedChannel: channel }),
transceiver: null, transceiver: null,
setTransceiver: (transceiver: RTCRtpTransceiver) => set({ transceiver }), setTransceiver: (transceiver: RTCRtpTransceiver) => set({ transceiver }),

View File

@ -6,6 +6,7 @@ import {
HID_RPC_VERSION, HID_RPC_VERSION,
HandshakeMessage, HandshakeMessage,
KeyboardReportMessage, KeyboardReportMessage,
KeypressKeepAliveMessage,
KeypressReportMessage, KeypressReportMessage,
MouseReportMessage, MouseReportMessage,
PointerReportMessage, PointerReportMessage,
@ -13,20 +14,43 @@ import {
unmarshalHidRpcMessage, unmarshalHidRpcMessage,
} from "./hidRpc"; } from "./hidRpc";
const KEEPALIVE_MESSAGE = new KeypressKeepAliveMessage();
interface sendMessageParams {
ignoreHandshakeState?: boolean;
useUnreliableChannel?: boolean;
requireOrdered?: boolean;
}
export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
const { rpcHidChannel, setRpcHidProtocolVersion, rpcHidProtocolVersion } = useRTCStore(); const {
rpcHidChannel,
rpcHidUnreliableChannel,
rpcHidUnreliableNonOrderedChannel,
setRpcHidProtocolVersion,
rpcHidProtocolVersion,
} = useRTCStore();
const rpcHidReady = useMemo(() => { const rpcHidReady = useMemo(() => {
return rpcHidChannel?.readyState === "open" && rpcHidProtocolVersion !== null; return rpcHidChannel?.readyState === "open" && rpcHidProtocolVersion !== null;
}, [rpcHidChannel, rpcHidProtocolVersion]); }, [rpcHidChannel, rpcHidProtocolVersion]);
const rpcHidUnreliableReady = useMemo(() => {
return rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null;
}, [rpcHidUnreliableChannel, rpcHidProtocolVersion]);
const rpcHidUnreliableNonOrderedReady = useMemo(() => {
return rpcHidUnreliableNonOrderedChannel?.readyState === "open" && rpcHidProtocolVersion !== null;
}, [rpcHidUnreliableNonOrderedChannel, rpcHidProtocolVersion]);
const rpcHidStatus = useMemo(() => { const rpcHidStatus = useMemo(() => {
if (!rpcHidChannel) return "N/A"; if (!rpcHidChannel) return "N/A";
if (rpcHidChannel.readyState !== "open") return rpcHidChannel.readyState; if (rpcHidChannel.readyState !== "open") return rpcHidChannel.readyState;
if (!rpcHidProtocolVersion) return "handshaking"; if (!rpcHidProtocolVersion) return "handshaking";
return `ready (v${rpcHidProtocolVersion})`; return `ready (v${rpcHidProtocolVersion}${rpcHidUnreliableReady ? "+u" : ""})`;
}, [rpcHidChannel, rpcHidProtocolVersion]); }, [rpcHidChannel, rpcHidUnreliableReady, rpcHidProtocolVersion]);
const sendMessage = useCallback((message: RpcMessage, ignoreHandshakeState = false) => { const sendMessage = useCallback((message: RpcMessage, { ignoreHandshakeState, useUnreliableChannel, requireOrdered = true }: sendMessageParams = {}) => {
if (rpcHidChannel?.readyState !== "open") return; if (rpcHidChannel?.readyState !== "open") return;
if (!rpcHidReady && !ignoreHandshakeState) return; if (!rpcHidReady && !ignoreHandshakeState) return;
@ -38,8 +62,24 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
} }
if (!data) return; if (!data) return;
if (useUnreliableChannel) {
if (requireOrdered && rpcHidUnreliableReady) {
rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer);
} else if (!requireOrdered && rpcHidUnreliableNonOrderedReady) {
rpcHidUnreliableNonOrderedChannel?.send(data as unknown as ArrayBuffer);
}
return;
}
rpcHidChannel?.send(data as unknown as ArrayBuffer); rpcHidChannel?.send(data as unknown as ArrayBuffer);
}, [rpcHidChannel, rpcHidReady]); }, [
rpcHidChannel,
rpcHidUnreliableChannel,
rpcHidUnreliableNonOrderedChannel,
rpcHidReady,
rpcHidUnreliableReady,
rpcHidUnreliableNonOrderedReady,
]);
const reportKeyboardEvent = useCallback( const reportKeyboardEvent = useCallback(
(keys: number[], modifier: number) => { (keys: number[], modifier: number) => {
@ -56,7 +96,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
const reportAbsMouseEvent = useCallback( const reportAbsMouseEvent = useCallback(
(x: number, y: number, buttons: number) => { (x: number, y: number, buttons: number) => {
sendMessage(new PointerReportMessage(x, y, buttons)); sendMessage(new PointerReportMessage(x, y, buttons), { useUnreliableChannel: true });
}, },
[sendMessage], [sendMessage],
); );
@ -68,11 +108,15 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
[sendMessage], [sendMessage],
); );
const reportKeypressKeepAlive = useCallback(() => {
sendMessage(KEEPALIVE_MESSAGE, { useUnreliableChannel: true, requireOrdered: false });
}, [sendMessage]);
const sendHandshake = useCallback(() => { const sendHandshake = useCallback(() => {
if (rpcHidProtocolVersion) return; if (rpcHidProtocolVersion) return;
if (!rpcHidChannel) return; if (!rpcHidChannel) return;
sendMessage(new HandshakeMessage(HID_RPC_VERSION), true); sendMessage(new HandshakeMessage(HID_RPC_VERSION), { ignoreHandshakeState: true });
}, [rpcHidChannel, rpcHidProtocolVersion, sendMessage]); }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage]);
const handleHandshake = useCallback((message: HandshakeMessage) => { const handleHandshake = useCallback((message: HandshakeMessage) => {
@ -143,6 +187,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
reportKeypressEvent, reportKeypressEvent,
reportAbsMouseEvent, reportAbsMouseEvent,
reportRelMouseEvent, reportRelMouseEvent,
reportKeypressKeepAlive,
rpcHidProtocolVersion, rpcHidProtocolVersion,
rpcHidReady, rpcHidReady,
rpcHidStatus, rpcHidStatus,

View File

@ -1,4 +1,4 @@
import { useCallback } from "react"; import { useCallback, useRef } from "react";
import { hidErrorRollOver, hidKeyBufferSize, KeysDownState, useHidStore, useRTCStore } from "@/hooks/stores"; import { hidErrorRollOver, hidKeyBufferSize, KeysDownState, useHidStore, useRTCStore } from "@/hooks/stores";
import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc"; import { JsonRpcResponse, useJsonRpc } from "@/hooks/useJsonRpc";
@ -11,6 +11,9 @@ export default function useKeyboard() {
const { rpcDataChannel } = useRTCStore(); const { rpcDataChannel } = useRTCStore();
const { keysDownState, setKeysDownState, setKeyboardLedState } = useHidStore(); const { keysDownState, setKeysDownState, setKeyboardLedState } = useHidStore();
// Keepalive timer management
const keepAliveTimerRef = useRef<number | null>(null);
// INTRODUCTION: The earlier version of the JetKVM device shipped with all keyboard state // INTRODUCTION: The earlier version of the JetKVM device shipped with all keyboard state
// being tracked on the browser/client-side. When adding the keyPressReport API to the // being tracked on the browser/client-side. When adding the keyPressReport API to the
// device-side code, we have to still support the situation where the browser/client-side code // device-side code, we have to still support the situation where the browser/client-side code
@ -26,6 +29,7 @@ export default function useKeyboard() {
const { const {
reportKeyboardEvent: sendKeyboardEventHidRpc, reportKeyboardEvent: sendKeyboardEventHidRpc,
reportKeypressEvent: sendKeypressEventHidRpc, reportKeypressEvent: sendKeypressEventHidRpc,
reportKeypressKeepAlive: sendKeypressKeepAliveHidRpc,
rpcHidReady, rpcHidReady,
} = useHidRpc((message) => { } = useHidRpc((message) => {
switch (message.constructor) { switch (message.constructor) {
@ -70,17 +74,6 @@ export default function useKeyboard() {
], ],
); );
// resetKeyboardState is used to reset the keyboard state to no keys pressed and no modifiers.
// This is useful for macros and when the browser loses focus to ensure that the keyboard state
// is clean.
const resetKeyboardState = useCallback(
async () => {
// Reset the keys buffer to zeros and the modifier state to zero
keysDownState.keys.length = hidKeyBufferSize;
keysDownState.keys.fill(0);
keysDownState.modifier = 0;
sendKeyboardEvent(keysDownState);
}, [keysDownState, sendKeyboardEvent]);
// executeMacro is used to execute a macro consisting of multiple steps. // executeMacro is used to execute a macro consisting of multiple steps.
// Each step can have multiple keys, multiple modifiers and a delay. // Each step can have multiple keys, multiple modifiers and a delay.
@ -111,12 +104,61 @@ export default function useKeyboard() {
} }
}; };
const KEEPALIVE_INTERVAL = 75; // 200ms interval
const cancelKeepAlive = useCallback(() => {
if (keepAliveTimerRef.current) {
clearInterval(keepAliveTimerRef.current);
keepAliveTimerRef.current = null;
}
}, []);
const scheduleKeepAlive = useCallback(() => {
// Clear existing timer if it exists
if (keepAliveTimerRef.current) {
clearInterval(keepAliveTimerRef.current);
}
// Create new interval timer
keepAliveTimerRef.current = setInterval(() => {
sendKeypressKeepAliveHidRpc();
}, KEEPALIVE_INTERVAL);
}, [sendKeypressKeepAliveHidRpc]);
// resetKeyboardState is used to reset the keyboard state to no keys pressed and no modifiers.
// This is useful for macros and when the browser loses focus to ensure that the keyboard state
// is clean.
const resetKeyboardState = useCallback(async () => {
// Cancel keepalive since we're resetting the keyboard state
cancelKeepAlive();
// Reset the keys buffer to zeros and the modifier state to zero
keysDownState.keys.length = hidKeyBufferSize;
keysDownState.keys.fill(0);
keysDownState.modifier = 0;
sendKeyboardEvent(keysDownState);
}, [keysDownState, sendKeyboardEvent, cancelKeepAlive]);
// handleKeyPress is used to handle a key press or release event. // handleKeyPress is used to handle a key press or release event.
// This function handle both key press and key release events. // This function handle both key press and key release events.
// It checks if the keyPressReport API is available and sends the key press event. // It checks if the keyPressReport API is available and sends the key press event.
// If the keyPressReport API is not available, it simulates the device-side key // If the keyPressReport API is not available, it simulates the device-side key
// handling for legacy devices and updates the keysDownState accordingly. // handling for legacy devices and updates the keysDownState accordingly.
// It then sends the full keyboard state to the device. // It then sends the full keyboard state to the device.
const sendKeypress = useCallback(
(key: number, press: boolean) => {
cancelKeepAlive();
sendKeypressEventHidRpc(key, press);
if (press) {
scheduleKeepAlive();
}
},
[sendKeypressEventHidRpc, scheduleKeepAlive, cancelKeepAlive],
);
const handleKeyPress = useCallback( const handleKeyPress = useCallback(
async (key: number, press: boolean) => { async (key: number, press: boolean) => {
if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return; if (rpcDataChannel?.readyState !== "open" && !rpcHidReady) return;
@ -129,7 +171,7 @@ export default function useKeyboard() {
// Older device version doesn't support this API, so we will switch to local key handling // Older device version doesn't support this API, so we will switch to local key handling
// In that case we will switch to local key handling and update the keysDownState // In that case we will switch to local key handling and update the keysDownState
// in client/browser-side code using simulateDeviceSideKeyHandlingForLegacyDevices. // in client/browser-side code using simulateDeviceSideKeyHandlingForLegacyDevices.
sendKeypressEventHidRpc(key, press); sendKeypress(key, press);
} else { } else {
// if the keyPress api is not available, we need to handle the key locally // if the keyPress api is not available, we need to handle the key locally
const downState = simulateDeviceSideKeyHandlingForLegacyDevices(keysDownState, key, press); const downState = simulateDeviceSideKeyHandlingForLegacyDevices(keysDownState, key, press);
@ -147,7 +189,7 @@ export default function useKeyboard() {
resetKeyboardState, resetKeyboardState,
rpcDataChannel?.readyState, rpcDataChannel?.readyState,
sendKeyboardEvent, sendKeyboardEvent,
sendKeypressEventHidRpc, sendKeypress,
], ],
); );
@ -210,5 +252,10 @@ export default function useKeyboard() {
return { modifier: modifiers, keys }; return { modifier: modifiers, keys };
} }
return { handleKeyPress, resetKeyboardState, executeMacro }; // Cleanup function to cancel keepalive timer
const cleanup = useCallback(() => {
cancelKeepAlive();
}, [cancelKeepAlive]);
return { handleKeyPress, resetKeyboardState, executeMacro, cleanup };
} }

View File

@ -136,6 +136,8 @@ export default function KvmIdRoute() {
rpcDataChannel, rpcDataChannel,
setTransceiver, setTransceiver,
setRpcHidChannel, setRpcHidChannel,
setRpcHidUnreliableNonOrderedChannel,
setRpcHidUnreliableChannel,
} = useRTCStore(); } = useRTCStore();
const location = useLocation(); const location = useLocation();
@ -488,6 +490,24 @@ export default function KvmIdRoute() {
setRpcHidChannel(rpcHidChannel); setRpcHidChannel(rpcHidChannel);
}; };
const rpcHidUnreliableChannel = pc.createDataChannel("hidrpc-unreliable-ordered", {
ordered: true,
maxRetransmits: 0,
});
rpcHidUnreliableChannel.binaryType = "arraybuffer";
rpcHidUnreliableChannel.onopen = () => {
setRpcHidUnreliableChannel(rpcHidUnreliableChannel);
};
const rpcHidUnreliableNonOrderedChannel = pc.createDataChannel("hidrpc-unreliable-nonordered", {
ordered: false,
maxRetransmits: 0,
});
rpcHidUnreliableNonOrderedChannel.binaryType = "arraybuffer";
rpcHidUnreliableNonOrderedChannel.onopen = () => {
setRpcHidUnreliableNonOrderedChannel(rpcHidUnreliableNonOrderedChannel);
};
setPeerConnection(pc); setPeerConnection(pc);
}, [ }, [
cleanupAndStopReconnecting, cleanupAndStopReconnecting,
@ -499,6 +519,8 @@ export default function KvmIdRoute() {
setPeerConnectionState, setPeerConnectionState,
setRpcDataChannel, setRpcDataChannel,
setRpcHidChannel, setRpcHidChannel,
setRpcHidUnreliableNonOrderedChannel,
setRpcHidUnreliableChannel,
setTransceiver, setTransceiver,
]); ]);

View File

@ -29,7 +29,12 @@ type Session struct {
hidRPCAvailable bool hidRPCAvailable bool
hidQueueLock sync.Mutex hidQueueLock sync.Mutex
hidQueue []chan webrtc.DataChannelMessage hidQueue []chan hidQueueMessage
}
type hidQueueMessage struct {
webrtc.DataChannelMessage
channel string
} }
type SessionConfig struct { type SessionConfig struct {
@ -78,16 +83,59 @@ func (s *Session) initQueues() {
s.hidQueueLock.Lock() s.hidQueueLock.Lock()
defer s.hidQueueLock.Unlock() defer s.hidQueueLock.Unlock()
s.hidQueue = make([]chan webrtc.DataChannelMessage, 0) s.hidQueue = make([]chan hidQueueMessage, 0)
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
q := make(chan webrtc.DataChannelMessage, 256) q := make(chan hidQueueMessage, 256)
s.hidQueue = append(s.hidQueue, q) s.hidQueue = append(s.hidQueue, q)
} }
} }
func (s *Session) handleQueues(index int) { func (s *Session) handleQueues(index int) {
for msg := range s.hidQueue[index] { for msg := range s.hidQueue[index] {
onHidMessage(msg.Data, s) onHidMessage(msg, s)
}
}
func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, channel string) func(msg webrtc.DataChannelMessage) {
return func(msg webrtc.DataChannelMessage) {
l := scopedLogger.With().
Str("channel", channel).
Int("length", len(msg.Data)).
Logger()
// only log data if the log level is debug or lower
if scopedLogger.GetLevel() > zerolog.DebugLevel {
l = l.With().Str("data", string(msg.Data)).Logger()
}
if msg.IsString {
l.Warn().Msg("received string data in HID RPC message handler")
return
}
if len(msg.Data) < 1 {
l.Warn().Msg("received empty data in HID RPC message handler")
return
}
l.Trace().Msg("received data in HID RPC message handler")
// Enqueue to ensure ordered processing
queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0]))
if queueIndex >= len(session.hidQueue) || queueIndex < 0 {
l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found")
queueIndex = 3
}
queue := session.hidQueue[queueIndex]
if queue != nil {
queue <- hidQueueMessage{
DataChannelMessage: msg,
channel: channel,
}
} else {
l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil")
return
}
} }
} }
@ -157,40 +205,12 @@ func newSession(config SessionConfig) (*Session, error) {
switch d.Label() { switch d.Label() {
case "hidrpc": case "hidrpc":
session.HidChannel = d session.HidChannel = d
d.OnMessage(func(msg webrtc.DataChannelMessage) { d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc"))
l := scopedLogger.With().Int("length", len(msg.Data)).Logger() // we won't send anything over the unreliable channels
// only log data if the log level is debug or lower case "hidrpc-unreliable-ordered":
if scopedLogger.GetLevel() > zerolog.DebugLevel { d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-ordered"))
l = l.With().Str("data", string(msg.Data)).Logger() case "hidrpc-unreliable-nonordered":
} d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-nonordered"))
if msg.IsString {
l.Warn().Msg("received string data in HID RPC message handler")
return
}
if len(msg.Data) < 1 {
l.Warn().Msg("received empty data in HID RPC message handler")
return
}
l.Trace().Msg("received data in HID RPC message handler")
// Enqueue to ensure ordered processing
queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0]))
if queueIndex >= len(session.hidQueue) || queueIndex < 0 {
l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found")
queueIndex = 3
}
queue := session.hidQueue[queueIndex]
if queue != nil {
queue <- msg
} else {
l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil")
return
}
})
case "rpc": case "rpc":
session.RPCChannel = d session.RPCChannel = d
d.OnMessage(func(msg webrtc.DataChannelMessage) { d.OnMessage(func(msg webrtc.DataChannelMessage) {