diff --git a/internal/hidrpc/hidrpc.go b/internal/hidrpc/hidrpc.go index e56eb69..e9c8c24 100644 --- a/internal/hidrpc/hidrpc.go +++ b/internal/hidrpc/hidrpc.go @@ -24,9 +24,18 @@ const ( Version byte = 0x01 // Version of the HID RPC protocol ) -// ShouldUseEnqueue returns true if the message type should be enqueued to the HID queue. -func ShouldUseEnqueue(messageType MessageType) bool { - return messageType == TypeMouseReport +// GetQueueIndex returns the index of the queue to which the message should be enqueued. +func GetQueueIndex(messageType MessageType) int { + switch messageType { + case TypeHandshake: + return 0 + case TypeKeyboardReport, TypeKeypressReport, TypeKeyboardLedState, TypeKeydownState: + return 1 + case TypePointerReport, TypeMouseReport, TypeWheelReport: + return 2 + default: + return 3 + } } // Unmarshal unmarshals the HID RPC message from the data. diff --git a/webrtc.go b/webrtc.go index bbfce0a..4b26b51 100644 --- a/webrtc.go +++ b/webrtc.go @@ -6,10 +6,12 @@ import ( "encoding/json" "net" "strings" + "sync" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" "github.com/gin-gonic/gin" + "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/logging" "github.com/pion/webrtc/v4" "github.com/rs/zerolog" @@ -23,9 +25,11 @@ type Session struct { HidChannel *webrtc.DataChannel shouldUmountVirtualMedia bool + rpcQueue chan webrtc.DataChannelMessage + hidRPCAvailable bool - hidQueue chan webrtc.DataChannelMessage - rpcQueue chan webrtc.DataChannelMessage + hidQueueLock sync.Mutex + hidQueue []chan webrtc.DataChannelMessage } type SessionConfig struct { @@ -70,6 +74,23 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) { return base64.StdEncoding.EncodeToString(localDescription), nil } +func (s *Session) initQueues() { + s.hidQueueLock.Lock() + defer s.hidQueueLock.Unlock() + + s.hidQueue = make([]chan webrtc.DataChannelMessage, 0) + for i := 0; i < 4; i++ { + q := make(chan webrtc.DataChannelMessage, 256) + s.hidQueue = append(s.hidQueue, q) + } +} + +func (s *Session) handleQueues(index int) { + for msg := range s.hidQueue[index] { + onHidMessage(msg.Data, s) + } +} + func newSession(config SessionConfig) (*Session, error) { webrtcSettingEngine := webrtc.SettingEngine{ LoggerFactory: logging.GetPionDefaultLoggerFactory(), @@ -111,7 +132,7 @@ func newSession(config SessionConfig) (*Session, error) { session := &Session{peerConnection: peerConnection} session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) - session.hidQueue = make(chan webrtc.DataChannelMessage, 256) + session.initQueues() go func() { for msg := range session.rpcQueue { @@ -120,11 +141,9 @@ func newSession(config SessionConfig) (*Session, error) { } }() - go func() { - for msg := range session.hidQueue { - onHidMessage(msg.Data, session) - } - }() + for i := 0; i < len(session.hidQueue); i++ { + go session.handleQueues(i) + } peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { defer func() { @@ -154,7 +173,19 @@ func newSession(config SessionConfig) (*Session, error) { l.Trace().Msg("received data in HID RPC message handler") // Enqueue to ensure ordered processing - session.hidQueue <- msg + queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) + if queueIndex >= len(session.hidQueue) || queueIndex < 0 { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") + queueIndex = 3 + } + + queue := session.hidQueue[queueIndex] + if queue != nil { + queue <- msg + } else { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") + return + } }) case "rpc": session.RPCChannel = d @@ -238,11 +269,13 @@ func newSession(config SessionConfig) (*Session, error) { close(session.rpcQueue) session.rpcQueue = nil } + // Stop HID RPC processor - if session.hidQueue != nil { - close(session.hidQueue) - session.hidQueue = nil + for i := 0; i < len(session.hidQueue); i++ { + close(session.hidQueue[i]) + session.hidQueue[i] = nil } + if session.shouldUmountVirtualMedia { if err := rpcUnmountImage(); err != nil { scopedLogger.Warn().Err(err).Msg("unmount image failed on connection close")