diff --git a/hidrpc.go b/hidrpc.go index 3d458f3d..8233e0f4 100644 --- a/hidrpc.go +++ b/hidrpc.go @@ -66,8 +66,13 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { } } -func onHidMessage(data []byte, session *Session) { - scopedLogger := hidRPCLogger.With().Bytes("data", data).Logger() +func onHidMessage(msg hidQueueMessage, session *Session) { + data := msg.Data + + scopedLogger := hidRPCLogger.With(). + Str("channel", msg.channel). + Bytes("data", data). + Logger() scopedLogger.Debug().Msg("HID RPC message received") if len(data) < 1 { diff --git a/ui/src/hooks/stores.ts b/ui/src/hooks/stores.ts index 3bc6cf8f..b6eebfc0 100644 --- a/ui/src/hooks/stores.ts +++ b/ui/src/hooks/stores.ts @@ -117,6 +117,9 @@ export interface RTCState { rpcHidUnreliableChannel: RTCDataChannel | null; setRpcHidUnreliableChannel: (channel: RTCDataChannel) => void; + rpcHidUnreliableNonOrderedChannel: RTCDataChannel | null; + setRpcHidUnreliableNonOrderedChannel: (channel: RTCDataChannel) => void; + peerConnectionState: RTCPeerConnectionState | null; setPeerConnectionState: (state: RTCPeerConnectionState) => void; @@ -175,6 +178,9 @@ export const useRTCStore = create(set => ({ rpcHidUnreliableChannel: null, setRpcHidUnreliableChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableChannel: channel }), + rpcHidUnreliableNonOrderedChannel: null, + setRpcHidUnreliableNonOrderedChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableNonOrderedChannel: channel }), + transceiver: null, setTransceiver: (transceiver: RTCRtpTransceiver) => set({ transceiver }), diff --git a/ui/src/hooks/useHidRpc.ts b/ui/src/hooks/useHidRpc.ts index 74f7fe10..6f75c5ae 100644 --- a/ui/src/hooks/useHidRpc.ts +++ b/ui/src/hooks/useHidRpc.ts @@ -22,12 +22,14 @@ const KEEPALIVE_MESSAGE = new KeypressKeepAliveMessage(); interface sendMessageParams { ignoreHandshakeState?: boolean; useUnreliableChannel?: boolean; + requireOrdered?: boolean; } export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { const { rpcHidChannel, rpcHidUnreliableChannel, + rpcHidUnreliableNonOrderedChannel, setRpcHidProtocolVersion, rpcHidProtocolVersion, hidRpcDisabled, } = useRTCStore(); @@ -41,6 +43,10 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { return rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null; }, [rpcHidUnreliableChannel, rpcHidProtocolVersion]); + const rpcHidUnreliableNonOrderedReady = useMemo(() => { + return rpcHidUnreliableNonOrderedChannel?.readyState === "open" && rpcHidProtocolVersion !== null; + }, [rpcHidUnreliableNonOrderedChannel, rpcHidProtocolVersion]); + const rpcHidStatus = useMemo(() => { if (hidRpcDisabled) return "disabled"; @@ -50,7 +56,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { return `ready (v${rpcHidProtocolVersion}${rpcHidUnreliableReady ? "+u" : ""})`; }, [rpcHidChannel, rpcHidUnreliableReady, rpcHidProtocolVersion, hidRpcDisabled]); - const sendMessage = useCallback((message: RpcMessage, { ignoreHandshakeState, useUnreliableChannel }: sendMessageParams = {}) => { + const sendMessage = useCallback((message: RpcMessage, { ignoreHandshakeState, useUnreliableChannel, requireOrdered = true }: sendMessageParams = {}) => { if (hidRpcDisabled) return; if (rpcHidChannel?.readyState !== "open") return; if (!rpcHidReady && !ignoreHandshakeState) return; @@ -63,12 +69,24 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { } if (!data) return; - if (useUnreliableChannel && rpcHidUnreliableReady) { - rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer); - } else { - rpcHidChannel?.send(data as unknown as ArrayBuffer); + if (useUnreliableChannel) { + if (requireOrdered && rpcHidUnreliableReady) { + rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer); + } else if (!requireOrdered && rpcHidUnreliableNonOrderedReady) { + rpcHidUnreliableNonOrderedChannel?.send(data as unknown as ArrayBuffer); + } + return; } - }, [rpcHidChannel, rpcHidReady, hidRpcDisabled, rpcHidUnreliableChannel, rpcHidUnreliableReady]); + + rpcHidChannel?.send(data as unknown as ArrayBuffer); + }, [ + rpcHidChannel, + rpcHidUnreliableChannel, + hidRpcDisabled, rpcHidUnreliableNonOrderedChannel, + rpcHidReady, + rpcHidUnreliableReady, + rpcHidUnreliableNonOrderedReady, + ]); const reportKeyboardEvent = useCallback( (keys: number[], modifier: number) => { @@ -85,7 +103,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { const reportAbsMouseEvent = useCallback( (x: number, y: number, buttons: number) => { - sendMessage(new PointerReportMessage(x, y, buttons)); + sendMessage(new PointerReportMessage(x, y, buttons), { useUnreliableChannel: true }); }, [sendMessage], ); @@ -112,7 +130,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { ); const reportKeypressKeepAlive = useCallback(() => { - sendMessage(KEEPALIVE_MESSAGE, { useUnreliableChannel: true }); + sendMessage(KEEPALIVE_MESSAGE, { useUnreliableChannel: true, requireOrdered: false }); }, [sendMessage]); const sendHandshake = useCallback(() => { diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx index 8e433d6f..fa9c4295 100644 --- a/ui/src/routes/devices.$id.tsx +++ b/ui/src/routes/devices.$id.tsx @@ -136,6 +136,7 @@ export default function KvmIdRoute() { rpcDataChannel, setTransceiver, setRpcHidChannel, + setRpcHidUnreliableNonOrderedChannel, setRpcHidUnreliableChannel, } = useRTCStore(); @@ -489,9 +490,8 @@ export default function KvmIdRoute() { setRpcHidChannel(rpcHidChannel); }; - const rpcHidUnreliableChannel = pc.createDataChannel("hidrpc-unreliable", { - // We don't need to be ordered, as we're using the unreliable channel for keepalive messages - ordered: false, + const rpcHidUnreliableChannel = pc.createDataChannel("hidrpc-unreliable-ordered", { + ordered: true, maxRetransmits: 0, }); rpcHidUnreliableChannel.binaryType = "arraybuffer"; @@ -499,6 +499,15 @@ export default function KvmIdRoute() { setRpcHidUnreliableChannel(rpcHidUnreliableChannel); }; + const rpcHidUnreliableNonOrderedChannel = pc.createDataChannel("hidrpc-unreliable-nonordered", { + ordered: false, + maxRetransmits: 0, + }); + rpcHidUnreliableNonOrderedChannel.binaryType = "arraybuffer"; + rpcHidUnreliableNonOrderedChannel.onopen = () => { + setRpcHidUnreliableNonOrderedChannel(rpcHidUnreliableNonOrderedChannel); + }; + setPeerConnection(pc); }, [ cleanupAndStopReconnecting, @@ -510,6 +519,7 @@ export default function KvmIdRoute() { setPeerConnectionState, setRpcDataChannel, setRpcHidChannel, + setRpcHidUnreliableNonOrderedChannel, setRpcHidUnreliableChannel, setTransceiver, ]); diff --git a/webrtc.go b/webrtc.go index b505e875..707a6a9d 100644 --- a/webrtc.go +++ b/webrtc.go @@ -29,7 +29,12 @@ type Session struct { hidRPCAvailable bool hidQueueLock sync.Mutex - hidQueue []chan webrtc.DataChannelMessage + hidQueue []chan hidQueueMessage +} + +type hidQueueMessage struct { + webrtc.DataChannelMessage + channel string } type SessionConfig struct { @@ -78,16 +83,59 @@ func (s *Session) initQueues() { s.hidQueueLock.Lock() defer s.hidQueueLock.Unlock() - s.hidQueue = make([]chan webrtc.DataChannelMessage, 0) + s.hidQueue = make([]chan hidQueueMessage, 0) for i := 0; i < 4; i++ { - q := make(chan webrtc.DataChannelMessage, 256) + q := make(chan hidQueueMessage, 256) s.hidQueue = append(s.hidQueue, q) } } func (s *Session) handleQueues(index int) { for msg := range s.hidQueue[index] { - onHidMessage(msg.Data, s) + onHidMessage(msg, s) + } +} + +func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, channel string) func(msg webrtc.DataChannelMessage) { + return func(msg webrtc.DataChannelMessage) { + l := scopedLogger.With(). + Str("channel", channel). + Int("length", len(msg.Data)). + Logger() + // only log data if the log level is debug or lower + if scopedLogger.GetLevel() > zerolog.DebugLevel { + l = l.With().Str("data", string(msg.Data)).Logger() + } + + if msg.IsString { + l.Warn().Msg("received string data in HID RPC message handler") + return + } + + if len(msg.Data) < 1 { + l.Warn().Msg("received empty data in HID RPC message handler") + return + } + + l.Trace().Msg("received data in HID RPC message handler") + + // Enqueue to ensure ordered processing + queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) + if queueIndex >= len(session.hidQueue) || queueIndex < 0 { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") + queueIndex = 3 + } + + queue := session.hidQueue[queueIndex] + if queue != nil { + queue <- hidQueueMessage{ + DataChannelMessage: msg, + channel: channel, + } + } else { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") + return + } } } @@ -145,41 +193,6 @@ func newSession(config SessionConfig) (*Session, error) { go session.handleQueues(i) } - onHidMessage := func(msg webrtc.DataChannelMessage) { - l := scopedLogger.With().Int("length", len(msg.Data)).Logger() - // only log data if the log level is debug or lower - if scopedLogger.GetLevel() > zerolog.DebugLevel { - l = l.With().Str("data", string(msg.Data)).Logger() - } - - if msg.IsString { - l.Warn().Msg("received string data in HID RPC message handler") - return - } - - if len(msg.Data) < 1 { - l.Warn().Msg("received empty data in HID RPC message handler") - return - } - - l.Trace().Msg("received data in HID RPC message handler") - - // Enqueue to ensure ordered processing - queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) - if queueIndex >= len(session.hidQueue) || queueIndex < 0 { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") - queueIndex = 3 - } - - queue := session.hidQueue[queueIndex] - if queue != nil { - queue <- msg - } else { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") - return - } - } - peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { defer func() { if r := recover(); r != nil { @@ -192,10 +205,12 @@ func newSession(config SessionConfig) (*Session, error) { switch d.Label() { case "hidrpc": session.HidChannel = d - d.OnMessage(onHidMessage) - // we won't send anything over the unreliable channel - case "hidrpc-unreliable": - d.OnMessage(onHidMessage) + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc")) + // we won't send anything over the unreliable channels + case "hidrpc-unreliable-ordered": + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-ordered")) + case "hidrpc-unreliable-nonordered": + d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-nonordered")) case "rpc": session.RPCChannel = d d.OnMessage(func(msg webrtc.DataChannelMessage) {