diff --git a/ui/src/hooks/stores.ts b/ui/src/hooks/stores.ts index bb4b8dd1..3bc6cf8f 100644 --- a/ui/src/hooks/stores.ts +++ b/ui/src/hooks/stores.ts @@ -114,6 +114,9 @@ export interface RTCState { rpcHidChannel: RTCDataChannel | null; setRpcHidChannel: (channel: RTCDataChannel) => void; + rpcHidUnreliableChannel: RTCDataChannel | null; + setRpcHidUnreliableChannel: (channel: RTCDataChannel) => void; + peerConnectionState: RTCPeerConnectionState | null; setPeerConnectionState: (state: RTCPeerConnectionState) => void; @@ -169,6 +172,9 @@ export const useRTCStore = create(set => ({ rpcHidChannel: null, setRpcHidChannel: (channel: RTCDataChannel) => set({ rpcHidChannel: channel }), + rpcHidUnreliableChannel: null, + setRpcHidUnreliableChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableChannel: channel }), + transceiver: null, setTransceiver: (transceiver: RTCRtpTransceiver) => set({ transceiver }), diff --git a/ui/src/hooks/useHidRpc.ts b/ui/src/hooks/useHidRpc.ts index 6288b541..74f7fe10 100644 --- a/ui/src/hooks/useHidRpc.ts +++ b/ui/src/hooks/useHidRpc.ts @@ -19,23 +19,38 @@ import { const KEEPALIVE_MESSAGE = new KeypressKeepAliveMessage(); +interface sendMessageParams { + ignoreHandshakeState?: boolean; + useUnreliableChannel?: boolean; +} + export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { - const { rpcHidChannel, setRpcHidProtocolVersion, rpcHidProtocolVersion, hidRpcDisabled } = useRTCStore(); + const { + rpcHidChannel, + rpcHidUnreliableChannel, + setRpcHidProtocolVersion, + rpcHidProtocolVersion, hidRpcDisabled, + } = useRTCStore(); + const rpcHidReady = useMemo(() => { if (hidRpcDisabled) return false; return rpcHidChannel?.readyState === "open" && rpcHidProtocolVersion !== null; }, [rpcHidChannel, rpcHidProtocolVersion, hidRpcDisabled]); + const rpcHidUnreliableReady = useMemo(() => { + return rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null; + }, [rpcHidUnreliableChannel, rpcHidProtocolVersion]); + const rpcHidStatus = useMemo(() => { if (hidRpcDisabled) return "disabled"; if (!rpcHidChannel) return "N/A"; if (rpcHidChannel.readyState !== "open") return rpcHidChannel.readyState; if (!rpcHidProtocolVersion) return "handshaking"; - return `ready (v${rpcHidProtocolVersion})`; - }, [rpcHidChannel, rpcHidProtocolVersion, hidRpcDisabled]); + return `ready (v${rpcHidProtocolVersion}${rpcHidUnreliableReady ? "+u" : ""})`; + }, [rpcHidChannel, rpcHidUnreliableReady, rpcHidProtocolVersion, hidRpcDisabled]); - const sendMessage = useCallback((message: RpcMessage, ignoreHandshakeState = false) => { + const sendMessage = useCallback((message: RpcMessage, { ignoreHandshakeState, useUnreliableChannel }: sendMessageParams = {}) => { if (hidRpcDisabled) return; if (rpcHidChannel?.readyState !== "open") return; if (!rpcHidReady && !ignoreHandshakeState) return; @@ -48,8 +63,12 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { } if (!data) return; - rpcHidChannel?.send(data as unknown as ArrayBuffer); - }, [rpcHidChannel, rpcHidReady, hidRpcDisabled]); + if (useUnreliableChannel && rpcHidUnreliableReady) { + rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer); + } else { + rpcHidChannel?.send(data as unknown as ArrayBuffer); + } + }, [rpcHidChannel, rpcHidReady, hidRpcDisabled, rpcHidUnreliableChannel, rpcHidUnreliableReady]); const reportKeyboardEvent = useCallback( (keys: number[], modifier: number) => { @@ -93,7 +112,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { ); const reportKeypressKeepAlive = useCallback(() => { - sendMessage(KEEPALIVE_MESSAGE); + sendMessage(KEEPALIVE_MESSAGE, { useUnreliableChannel: true }); }, [sendMessage]); const sendHandshake = useCallback(() => { @@ -101,7 +120,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { if (rpcHidProtocolVersion) return; if (!rpcHidChannel) return; - sendMessage(new HandshakeMessage(HID_RPC_VERSION), true); + sendMessage(new HandshakeMessage(HID_RPC_VERSION), { ignoreHandshakeState: true }); }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage, hidRpcDisabled]); const handleHandshake = useCallback((message: HandshakeMessage) => { diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx index bdf6de9a..8e433d6f 100644 --- a/ui/src/routes/devices.$id.tsx +++ b/ui/src/routes/devices.$id.tsx @@ -136,6 +136,7 @@ export default function KvmIdRoute() { rpcDataChannel, setTransceiver, setRpcHidChannel, + setRpcHidUnreliableChannel, } = useRTCStore(); const location = useLocation(); @@ -488,6 +489,16 @@ export default function KvmIdRoute() { setRpcHidChannel(rpcHidChannel); }; + const rpcHidUnreliableChannel = pc.createDataChannel("hidrpc-unreliable", { + // We don't need to be ordered, as we're using the unreliable channel for keepalive messages + ordered: false, + maxRetransmits: 0, + }); + rpcHidUnreliableChannel.binaryType = "arraybuffer"; + rpcHidUnreliableChannel.onopen = () => { + setRpcHidUnreliableChannel(rpcHidUnreliableChannel); + }; + setPeerConnection(pc); }, [ cleanupAndStopReconnecting, @@ -499,6 +510,7 @@ export default function KvmIdRoute() { setPeerConnectionState, setRpcDataChannel, setRpcHidChannel, + setRpcHidUnreliableChannel, setTransceiver, ]); diff --git a/webrtc.go b/webrtc.go index db9a7c2c..b505e875 100644 --- a/webrtc.go +++ b/webrtc.go @@ -145,6 +145,41 @@ func newSession(config SessionConfig) (*Session, error) { go session.handleQueues(i) } + onHidMessage := func(msg webrtc.DataChannelMessage) { + l := scopedLogger.With().Int("length", len(msg.Data)).Logger() + // only log data if the log level is debug or lower + if scopedLogger.GetLevel() > zerolog.DebugLevel { + l = l.With().Str("data", string(msg.Data)).Logger() + } + + if msg.IsString { + l.Warn().Msg("received string data in HID RPC message handler") + return + } + + if len(msg.Data) < 1 { + l.Warn().Msg("received empty data in HID RPC message handler") + return + } + + l.Trace().Msg("received data in HID RPC message handler") + + // Enqueue to ensure ordered processing + queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) + if queueIndex >= len(session.hidQueue) || queueIndex < 0 { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") + queueIndex = 3 + } + + queue := session.hidQueue[queueIndex] + if queue != nil { + queue <- msg + } else { + l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") + return + } + } + peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { defer func() { if r := recover(); r != nil { @@ -157,40 +192,10 @@ func newSession(config SessionConfig) (*Session, error) { switch d.Label() { case "hidrpc": session.HidChannel = d - d.OnMessage(func(msg webrtc.DataChannelMessage) { - l := scopedLogger.With().Int("length", len(msg.Data)).Logger() - // only log data if the log level is debug or lower - if scopedLogger.GetLevel() > zerolog.DebugLevel { - l = l.With().Str("data", string(msg.Data)).Logger() - } - - if msg.IsString { - l.Warn().Msg("received string data in HID RPC message handler") - return - } - - if len(msg.Data) < 1 { - l.Warn().Msg("received empty data in HID RPC message handler") - return - } - - l.Trace().Msg("received data in HID RPC message handler") - - // Enqueue to ensure ordered processing - queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) - if queueIndex >= len(session.hidQueue) || queueIndex < 0 { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") - queueIndex = 3 - } - - queue := session.hidQueue[queueIndex] - if queue != nil { - queue <- msg - } else { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") - return - } - }) + d.OnMessage(onHidMessage) + // we won't send anything over the unreliable channel + case "hidrpc-unreliable": + d.OnMessage(onHidMessage) case "rpc": session.RPCChannel = d d.OnMessage(func(msg webrtc.DataChannelMessage) {