chore: use unreliable channel to send keepalive events

This commit is contained in:
Siyuan Miao 2025-09-11 16:16:47 +02:00
parent 19be5ea885
commit 3f83efa830
4 changed files with 84 additions and 42 deletions

View File

@ -114,6 +114,9 @@ export interface RTCState {
rpcHidChannel: RTCDataChannel | null; rpcHidChannel: RTCDataChannel | null;
setRpcHidChannel: (channel: RTCDataChannel) => void; setRpcHidChannel: (channel: RTCDataChannel) => void;
rpcHidUnreliableChannel: RTCDataChannel | null;
setRpcHidUnreliableChannel: (channel: RTCDataChannel) => void;
peerConnectionState: RTCPeerConnectionState | null; peerConnectionState: RTCPeerConnectionState | null;
setPeerConnectionState: (state: RTCPeerConnectionState) => void; setPeerConnectionState: (state: RTCPeerConnectionState) => void;
@ -169,6 +172,9 @@ export const useRTCStore = create<RTCState>(set => ({
rpcHidChannel: null, rpcHidChannel: null,
setRpcHidChannel: (channel: RTCDataChannel) => set({ rpcHidChannel: channel }), setRpcHidChannel: (channel: RTCDataChannel) => set({ rpcHidChannel: channel }),
rpcHidUnreliableChannel: null,
setRpcHidUnreliableChannel: (channel: RTCDataChannel) => set({ rpcHidUnreliableChannel: channel }),
transceiver: null, transceiver: null,
setTransceiver: (transceiver: RTCRtpTransceiver) => set({ transceiver }), setTransceiver: (transceiver: RTCRtpTransceiver) => set({ transceiver }),

View File

@ -19,23 +19,38 @@ import {
const KEEPALIVE_MESSAGE = new KeypressKeepAliveMessage(); const KEEPALIVE_MESSAGE = new KeypressKeepAliveMessage();
interface sendMessageParams {
ignoreHandshakeState?: boolean;
useUnreliableChannel?: boolean;
}
export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) { export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
const { rpcHidChannel, setRpcHidProtocolVersion, rpcHidProtocolVersion, hidRpcDisabled } = useRTCStore(); const {
rpcHidChannel,
rpcHidUnreliableChannel,
setRpcHidProtocolVersion,
rpcHidProtocolVersion, hidRpcDisabled,
} = useRTCStore();
const rpcHidReady = useMemo(() => { const rpcHidReady = useMemo(() => {
if (hidRpcDisabled) return false; if (hidRpcDisabled) return false;
return rpcHidChannel?.readyState === "open" && rpcHidProtocolVersion !== null; return rpcHidChannel?.readyState === "open" && rpcHidProtocolVersion !== null;
}, [rpcHidChannel, rpcHidProtocolVersion, hidRpcDisabled]); }, [rpcHidChannel, rpcHidProtocolVersion, hidRpcDisabled]);
const rpcHidUnreliableReady = useMemo(() => {
return rpcHidUnreliableChannel?.readyState === "open" && rpcHidProtocolVersion !== null;
}, [rpcHidUnreliableChannel, rpcHidProtocolVersion]);
const rpcHidStatus = useMemo(() => { const rpcHidStatus = useMemo(() => {
if (hidRpcDisabled) return "disabled"; if (hidRpcDisabled) return "disabled";
if (!rpcHidChannel) return "N/A"; if (!rpcHidChannel) return "N/A";
if (rpcHidChannel.readyState !== "open") return rpcHidChannel.readyState; if (rpcHidChannel.readyState !== "open") return rpcHidChannel.readyState;
if (!rpcHidProtocolVersion) return "handshaking"; if (!rpcHidProtocolVersion) return "handshaking";
return `ready (v${rpcHidProtocolVersion})`; return `ready (v${rpcHidProtocolVersion}${rpcHidUnreliableReady ? "+u" : ""})`;
}, [rpcHidChannel, rpcHidProtocolVersion, hidRpcDisabled]); }, [rpcHidChannel, rpcHidUnreliableReady, rpcHidProtocolVersion, hidRpcDisabled]);
const sendMessage = useCallback((message: RpcMessage, ignoreHandshakeState = false) => { const sendMessage = useCallback((message: RpcMessage, { ignoreHandshakeState, useUnreliableChannel }: sendMessageParams = {}) => {
if (hidRpcDisabled) return; if (hidRpcDisabled) return;
if (rpcHidChannel?.readyState !== "open") return; if (rpcHidChannel?.readyState !== "open") return;
if (!rpcHidReady && !ignoreHandshakeState) return; if (!rpcHidReady && !ignoreHandshakeState) return;
@ -48,8 +63,12 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
} }
if (!data) return; if (!data) return;
rpcHidChannel?.send(data as unknown as ArrayBuffer); if (useUnreliableChannel && rpcHidUnreliableReady) {
}, [rpcHidChannel, rpcHidReady, hidRpcDisabled]); rpcHidUnreliableChannel?.send(data as unknown as ArrayBuffer);
} else {
rpcHidChannel?.send(data as unknown as ArrayBuffer);
}
}, [rpcHidChannel, rpcHidReady, hidRpcDisabled, rpcHidUnreliableChannel, rpcHidUnreliableReady]);
const reportKeyboardEvent = useCallback( const reportKeyboardEvent = useCallback(
(keys: number[], modifier: number) => { (keys: number[], modifier: number) => {
@ -93,7 +112,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
); );
const reportKeypressKeepAlive = useCallback(() => { const reportKeypressKeepAlive = useCallback(() => {
sendMessage(KEEPALIVE_MESSAGE); sendMessage(KEEPALIVE_MESSAGE, { useUnreliableChannel: true });
}, [sendMessage]); }, [sendMessage]);
const sendHandshake = useCallback(() => { const sendHandshake = useCallback(() => {
@ -101,7 +120,7 @@ export function useHidRpc(onHidRpcMessage?: (payload: RpcMessage) => void) {
if (rpcHidProtocolVersion) return; if (rpcHidProtocolVersion) return;
if (!rpcHidChannel) return; if (!rpcHidChannel) return;
sendMessage(new HandshakeMessage(HID_RPC_VERSION), true); sendMessage(new HandshakeMessage(HID_RPC_VERSION), { ignoreHandshakeState: true });
}, [rpcHidChannel, rpcHidProtocolVersion, sendMessage, hidRpcDisabled]); }, [rpcHidChannel, rpcHidProtocolVersion, sendMessage, hidRpcDisabled]);
const handleHandshake = useCallback((message: HandshakeMessage) => { const handleHandshake = useCallback((message: HandshakeMessage) => {

View File

@ -136,6 +136,7 @@ export default function KvmIdRoute() {
rpcDataChannel, rpcDataChannel,
setTransceiver, setTransceiver,
setRpcHidChannel, setRpcHidChannel,
setRpcHidUnreliableChannel,
} = useRTCStore(); } = useRTCStore();
const location = useLocation(); const location = useLocation();
@ -488,6 +489,16 @@ export default function KvmIdRoute() {
setRpcHidChannel(rpcHidChannel); setRpcHidChannel(rpcHidChannel);
}; };
const rpcHidUnreliableChannel = pc.createDataChannel("hidrpc-unreliable", {
// We don't need to be ordered, as we're using the unreliable channel for keepalive messages
ordered: false,
maxRetransmits: 0,
});
rpcHidUnreliableChannel.binaryType = "arraybuffer";
rpcHidUnreliableChannel.onopen = () => {
setRpcHidUnreliableChannel(rpcHidUnreliableChannel);
};
setPeerConnection(pc); setPeerConnection(pc);
}, [ }, [
cleanupAndStopReconnecting, cleanupAndStopReconnecting,
@ -499,6 +510,7 @@ export default function KvmIdRoute() {
setPeerConnectionState, setPeerConnectionState,
setRpcDataChannel, setRpcDataChannel,
setRpcHidChannel, setRpcHidChannel,
setRpcHidUnreliableChannel,
setTransceiver, setTransceiver,
]); ]);

View File

@ -145,6 +145,41 @@ func newSession(config SessionConfig) (*Session, error) {
go session.handleQueues(i) go session.handleQueues(i)
} }
onHidMessage := func(msg webrtc.DataChannelMessage) {
l := scopedLogger.With().Int("length", len(msg.Data)).Logger()
// only log data if the log level is debug or lower
if scopedLogger.GetLevel() > zerolog.DebugLevel {
l = l.With().Str("data", string(msg.Data)).Logger()
}
if msg.IsString {
l.Warn().Msg("received string data in HID RPC message handler")
return
}
if len(msg.Data) < 1 {
l.Warn().Msg("received empty data in HID RPC message handler")
return
}
l.Trace().Msg("received data in HID RPC message handler")
// Enqueue to ensure ordered processing
queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0]))
if queueIndex >= len(session.hidQueue) || queueIndex < 0 {
l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found")
queueIndex = 3
}
queue := session.hidQueue[queueIndex]
if queue != nil {
queue <- msg
} else {
l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil")
return
}
}
peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -157,40 +192,10 @@ func newSession(config SessionConfig) (*Session, error) {
switch d.Label() { switch d.Label() {
case "hidrpc": case "hidrpc":
session.HidChannel = d session.HidChannel = d
d.OnMessage(func(msg webrtc.DataChannelMessage) { d.OnMessage(onHidMessage)
l := scopedLogger.With().Int("length", len(msg.Data)).Logger() // we won't send anything over the unreliable channel
// only log data if the log level is debug or lower case "hidrpc-unreliable":
if scopedLogger.GetLevel() > zerolog.DebugLevel { d.OnMessage(onHidMessage)
l = l.With().Str("data", string(msg.Data)).Logger()
}
if msg.IsString {
l.Warn().Msg("received string data in HID RPC message handler")
return
}
if len(msg.Data) < 1 {
l.Warn().Msg("received empty data in HID RPC message handler")
return
}
l.Trace().Msg("received data in HID RPC message handler")
// Enqueue to ensure ordered processing
queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0]))
if queueIndex >= len(session.hidQueue) || queueIndex < 0 {
l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found")
queueIndex = 3
}
queue := session.hidQueue[queueIndex]
if queue != nil {
queue <- msg
} else {
l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil")
return
}
})
case "rpc": case "rpc":
session.RPCChannel = d session.RPCChannel = d
d.OnMessage(func(msg webrtc.DataChannelMessage) { d.OnMessage(func(msg webrtc.DataChannelMessage) {