feat(cloud): Use Websocket signaling in cloud mode

This commit is contained in:
Adam Shiervani 2025-04-04 15:22:06 +02:00 committed by Siyuan Miao
parent fa1b11b228
commit 68f53dcc5f
5 changed files with 152 additions and 118 deletions

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
"github.com/pion/webrtc/v4"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
@ -325,13 +326,27 @@ func runWebsocketClient() error {
// ignore non-text messages // ignore non-text messages
continue continue
} }
var req WebRTCSessionRequest
err = json.Unmarshal(msg, &req) var message struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}
err = json.Unmarshal(msg, &message)
if err != nil { if err != nil {
cloudLogger.Warnf("unable to parse ws message: %v", string(msg)) cloudLogger.Warnf("unable to parse ws message: %v", string(msg))
continue continue
} }
if message.Type == "offer" {
cloudLogger.Infof("new session request received")
var req WebRTCSessionRequest
err = json.Unmarshal(message.Data, &req)
if err != nil {
cloudLogger.Warnf("unable to parse session request data: %v", string(message.Data))
continue
}
cloudLogger.Infof("new session request: %v", req.OidcGoogle) cloudLogger.Infof("new session request: %v", req.OidcGoogle)
cloudLogger.Tracef("session request info: %v", req) cloudLogger.Tracef("session request info: %v", req)
@ -342,6 +357,33 @@ func runWebsocketClient() error {
cloudLogger.Infof("error starting new session: %v", err) cloudLogger.Infof("error starting new session: %v", err)
continue continue
} }
} else if message.Type == "new-ice-candidate" {
cloudLogger.Infof("client has sent us a new ICE candidate: %v", string(message.Data))
var candidate webrtc.ICECandidateInit
// Attempt to unmarshal as a ICECandidateInit
if err := json.Unmarshal(message.Data, &candidate); err != nil {
cloudLogger.Warnf("unable to parse ICE candidate data: %v", string(message.Data))
continue
}
if candidate.Candidate == "" {
cloudLogger.Warnf("empty ICE candidate, skipping")
continue
}
cloudLogger.Infof("unmarshalled ICE candidate: %v", candidate)
if currentSession == nil {
cloudLogger.Infof("no current session, skipping ICE candidate")
continue
}
cloudLogger.Infof("adding ICE candidate to current session: %v", candidate)
if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil {
cloudLogger.Warnf("failed to add ICE candidate: %v", err)
}
}
} }
} }
@ -383,6 +425,7 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
ICEServers: req.ICEServers, ICEServers: req.ICEServers,
LocalIP: req.IP, LocalIP: req.IP,
IsCloud: true, IsCloud: true,
ws: c,
}) })
if err != nil { if err != nil {
_ = wsjson.Write(context.Background(), c, gin.H{"error": err}) _ = wsjson.Write(context.Background(), c, gin.H{"error": err})
@ -406,7 +449,7 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
cloudLogger.Info("new session accepted") cloudLogger.Info("new session accepted")
cloudLogger.Tracef("new session accepted: %v", session) cloudLogger.Tracef("new session accepted: %v", session)
currentSession = session currentSession = session
_ = wsjson.Write(context.Background(), c, gin.H{"sd": sd}) _ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd})
return nil return nil
} }

6
ui/package-lock.json generated
View File

@ -30,6 +30,7 @@
"react-icons": "^5.4.0", "react-icons": "^5.4.0",
"react-router-dom": "^6.22.3", "react-router-dom": "^6.22.3",
"react-simple-keyboard": "^3.7.112", "react-simple-keyboard": "^3.7.112",
"react-use-websocket": "^4.13.0",
"react-xtermjs": "^1.0.9", "react-xtermjs": "^1.0.9",
"recharts": "^2.15.0", "recharts": "^2.15.0",
"tailwind-merge": "^2.5.5", "tailwind-merge": "^2.5.5",
@ -5180,6 +5181,11 @@
"react-dom": ">=16.6.0" "react-dom": ">=16.6.0"
} }
}, },
"node_modules/react-use-websocket": {
"version": "4.13.0",
"resolved": "https://registry.npmjs.org/react-use-websocket/-/react-use-websocket-4.13.0.tgz",
"integrity": "sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw=="
},
"node_modules/react-xtermjs": { "node_modules/react-xtermjs": {
"version": "1.0.9", "version": "1.0.9",
"resolved": "https://registry.npmjs.org/react-xtermjs/-/react-xtermjs-1.0.9.tgz", "resolved": "https://registry.npmjs.org/react-xtermjs/-/react-xtermjs-1.0.9.tgz",

View File

@ -40,6 +40,7 @@
"react-icons": "^5.4.0", "react-icons": "^5.4.0",
"react-router-dom": "^6.22.3", "react-router-dom": "^6.22.3",
"react-simple-keyboard": "^3.7.112", "react-simple-keyboard": "^3.7.112",
"react-use-websocket": "^4.13.0",
"react-xtermjs": "^1.0.9", "react-xtermjs": "^1.0.9",
"recharts": "^2.15.0", "recharts": "^2.15.0",
"tailwind-merge": "^2.5.5", "tailwind-merge": "^2.5.5",

View File

@ -14,6 +14,7 @@ import {
import { useInterval } from "usehooks-ts"; import { useInterval } from "usehooks-ts";
import FocusTrap from "focus-trap-react"; import FocusTrap from "focus-trap-react";
import { motion, AnimatePresence } from "framer-motion"; import { motion, AnimatePresence } from "framer-motion";
import useWebSocket, { ReadyState } from "react-use-websocket";
import { cx } from "@/cva.config"; import { cx } from "@/cva.config";
import { import {
@ -117,7 +118,6 @@ const loader = async ({ params }: LoaderFunctionArgs) => {
export default function KvmIdRoute() { export default function KvmIdRoute() {
const loaderResp = useLoaderData() as LocalLoaderResp | CloudLoaderResp; const loaderResp = useLoaderData() as LocalLoaderResp | CloudLoaderResp;
// Depending on the mode, we set the appropriate variables // Depending on the mode, we set the appropriate variables
const user = "user" in loaderResp ? loaderResp.user : null; const user = "user" in loaderResp ? loaderResp.user : null;
const deviceName = "deviceName" in loaderResp ? loaderResp.deviceName : null; const deviceName = "deviceName" in loaderResp ? loaderResp.deviceName : null;
@ -169,87 +169,70 @@ export default function KvmIdRoute() {
useEffect(() => { useEffect(() => {
connectionFailedRef.current = connectionFailed; connectionFailedRef.current = connectionFailed;
}, [connectionFailed]); }, [connectionFailed]);
const signalingAttempts = useRef(0); const signalingAttempts = useRef(0);
const syncRemoteSessionDescription = useCallback( const setRemoteSessionDescription = useCallback(
async function syncRemoteSessionDescription(pc: RTCPeerConnection) { async function setRemoteSessionDescription(
try { pc: RTCPeerConnection,
if (!pc) return; remoteDescription: RTCSessionDescriptionInit,
) {
const sd = btoa(JSON.stringify(pc.localDescription)); setLoadingMessage("Setting remote description");
pc.setRemoteDescription(new RTCSessionDescription(remoteDescription));
const sessionUrl = isOnDevice
? `${DEVICE_API}/webrtc/session`
: `${CLOUD_API}/webrtc/session`;
console.log("Trying to get remote session description");
setLoadingMessage(
`Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`,
);
const res = await api.POST(sessionUrl, {
sd,
// When on device, we don't need to specify the device id, as it's already known
...(isOnDevice ? {} : { id: params.id }),
});
const json = await res.json();
if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login");
if (!res.ok) {
console.error("Error getting SDP", { status: res.status, json });
throw new Error("Error getting SDP");
}
console.log("Successfully got Remote Session Description. Setting.");
setLoadingMessage("Setting remote session description...");
const decodedSd = atob(json.sd);
const parsedSd = JSON.parse(decodedSd);
pc.setRemoteDescription(new RTCSessionDescription(parsedSd));
await new Promise((resolve, reject) => {
console.log("Waiting for remote description to be set"); console.log("Waiting for remote description to be set");
const maxAttempts = 10;
const interval = 1000;
let attempts = 0; let attempts = 0;
setLoadingMessage("Establishing secure connection...");
const checkInterval = setInterval(() => { const checkInterval = setInterval(() => {
attempts++; attempts++;
// When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects // When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects
if (pc.sctp?.state === "connected") { if (pc.sctp?.state === "connected") {
console.log("Remote description set"); console.log("Remote description set");
clearInterval(checkInterval); clearInterval(checkInterval);
resolve(true); } else if (attempts >= 10) {
} else if (attempts >= maxAttempts) { console.log("Failed to get remote description after 10 attempts");
console.log(
`Failed to get remote description after ${maxAttempts} attempts`,
);
closePeerConnection(); closePeerConnection();
clearInterval(checkInterval); clearInterval(checkInterval);
reject(
new Error(
`Failed to get remote description after ${maxAttempts} attempts`,
),
);
} else { } else {
console.log("Waiting for remote description to be set"); console.log("Waiting for remote description to be set");
} }
}, interval); }, 1000);
}); },
} catch (error) { [closePeerConnection],
console.error("Error getting SDP", { error }); );
console.log("Connection failed", connectionFailedRef.current);
if (connectionFailedRef.current) return; // TODO: Handle auth!!! The old signaling http request could get a 401 on local and on cloud
if (signalingAttempts.current < 5) { const { sendMessage, readyState } = useWebSocket(
signalingAttempts.current++; isOnDevice ? `${DEVICE_API}/client` : `${CLOUD_API}/client?id=${params.id}`,
await new Promise(resolve => setTimeout(resolve, 500)); {
console.log("Attempting to get SDP again", signalingAttempts.current); heartbeat: true,
syncRemoteSessionDescription(pc); onMessage: message => {
} else { if (message.data === "pong") return;
closePeerConnection(); if (!peerConnection) return;
}
const parsedMessage = JSON.parse(message.data);
if (parsedMessage.type === "answer") {
console.log("Setting remote description", parsedMessage.data);
const sd = atob(parsedMessage.data);
const remoteSessionDescription = JSON.parse(sd);
setRemoteSessionDescription(
peerConnection,
new RTCSessionDescription(remoteSessionDescription),
);
} else if (parsedMessage.type === "new-ice-candidate") {
console.log("Received new ICE candidate", parsedMessage.data);
const candidate = parsedMessage.data;
peerConnection.addIceCandidate(candidate);
} }
}, },
[closePeerConnection, navigate, params.id], },
);
const sendWebRTCSignal = useCallback(
(type: string, data: string | RTCIceCandidate) => {
sendMessage(JSON.stringify({ type, data }));
},
[sendMessage],
); );
const setupPeerConnection = useCallback(async () => { const setupPeerConnection = useCallback(async () => {
@ -267,6 +250,7 @@ export default function KvmIdRoute() {
? { iceServers: [iceConfig?.iceServers] } ? { iceServers: [iceConfig?.iceServers] }
: {}), : {}),
}); });
console.log("Peer connection created", pc); console.log("Peer connection created", pc);
setLoadingMessage("Peer connection created"); setLoadingMessage("Peer connection created");
} catch (e) { } catch (e) {
@ -282,21 +266,24 @@ export default function KvmIdRoute() {
console.log("Connection state changed", pc.connectionState); console.log("Connection state changed", pc.connectionState);
}; };
pc.onicegatheringstatechange = event => { pc.onnegotiationneeded = async () => {
const pc = event.currentTarget as RTCPeerConnection; try {
console.log("ICE Gathering State Changed", pc.iceGatheringState); const offer = await pc.createOffer();
if (pc.iceGatheringState === "complete") { await pc.setLocalDescription(offer);
console.log("ICE Gathering completed"); const sd = btoa(JSON.stringify(pc.localDescription));
setLoadingMessage("ICE Gathering completed"); sendWebRTCSignal("offer", sd);
} catch (e) {
// We can now start the https/ws connection to get the remote session description from the KVM device console.error(`Error creating offer: ${e}`, new Date().toISOString());
syncRemoteSessionDescription(pc); closePeerConnection();
} else if (pc.iceGatheringState === "gathering") {
console.log("ICE Gathering Started");
setLoadingMessage("Gathering ICE candidates...");
} }
}; };
pc.onicecandidate = async ({ candidate }) => {
if (!candidate) return;
if (candidate.candidate === "") return;
sendWebRTCSignal("new-ice-candidate", candidate);
};
pc.ontrack = function (event) { pc.ontrack = function (event) {
setMediaMediaStream(event.streams[0]); setMediaMediaStream(event.streams[0]);
}; };
@ -314,31 +301,24 @@ export default function KvmIdRoute() {
}; };
setPeerConnection(pc); setPeerConnection(pc);
try {
const offer = await pc.createOffer();
await pc.setLocalDescription(offer);
} catch (e) {
console.error(`Error creating offer: ${e}`, new Date().toISOString());
closePeerConnection();
}
}, [ }, [
closePeerConnection, closePeerConnection,
iceConfig?.iceServers, iceConfig?.iceServers,
sendWebRTCSignal,
setDiskChannel, setDiskChannel,
setMediaMediaStream, setMediaMediaStream,
setPeerConnection, setPeerConnection,
setRpcDataChannel, setRpcDataChannel,
setTransceiver, setTransceiver,
syncRemoteSessionDescription,
]); ]);
// On boot, if the connection state is undefined, we connect to the WebRTC // On boot, if the connection state is undefined, we connect to the WebRTC
useEffect(() => { useEffect(() => {
if (readyState !== ReadyState.OPEN) return;
if (peerConnection?.connectionState === undefined) { if (peerConnection?.connectionState === undefined) {
setupPeerConnection(); setupPeerConnection();
} }
}, [setupPeerConnection, peerConnection?.connectionState]); }, [readyState, setupPeerConnection, peerConnection?.connectionState]);
// Cleanup effect // Cleanup effect
const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats); const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats);
@ -593,7 +573,7 @@ export default function KvmIdRoute() {
/> />
<div className="flex h-full w-full overflow-hidden"> <div className="flex h-full w-full overflow-hidden">
<div className="pointer-events-none fixed inset-0 isolate z-50 flex h-full w-full items-center justify-center"> <div className="pointer-events-none fixed inset-0 isolate z-20 flex h-full w-full items-center justify-center">
<div className="my-2 h-full max-h-[720px] w-full max-w-[1280px] rounded-md"> <div className="my-2 h-full max-h-[720px] w-full max-w-[1280px] rounded-md">
<LoadingConnectionOverlay <LoadingConnectionOverlay
show={ show={

View File

@ -1,11 +1,15 @@
package kvm package kvm
import ( import (
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"net" "net"
"strings" "strings"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gin-gonic/gin"
"github.com/pion/webrtc/v4" "github.com/pion/webrtc/v4"
) )
@ -23,6 +27,7 @@ type SessionConfig struct {
ICEServers []string ICEServers []string
LocalIP string LocalIP string
IsCloud bool IsCloud bool
ws *websocket.Conn
} }
func (s *Session) ExchangeOffer(offerStr string) (string, error) { func (s *Session) ExchangeOffer(offerStr string) (string, error) {
@ -46,19 +51,11 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) {
return "", err return "", err
} }
// Create channel that is blocked until ICE Gathering is complete
gatherComplete := webrtc.GatheringCompletePromise(s.peerConnection)
// Sets the LocalDescription, and starts our UDP listeners // Sets the LocalDescription, and starts our UDP listeners
if err = s.peerConnection.SetLocalDescription(answer); err != nil { if err = s.peerConnection.SetLocalDescription(answer); err != nil {
return "", err return "", err
} }
// Block until ICE Gathering is complete, disabling trickle ICE
// we do this because we only can exchange one signaling message
// in a production application you should exchange ICE Candidates via OnICECandidate
<-gatherComplete
localDescription, err := json.Marshal(s.peerConnection.LocalDescription()) localDescription, err := json.Marshal(s.peerConnection.LocalDescription())
if err != nil { if err != nil {
return "", err return "", err
@ -144,6 +141,13 @@ func newSession(config SessionConfig) (*Session, error) {
}() }()
var isConnected bool var isConnected bool
peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
cloudLogger.Infof("we got a new ICE candidate: %v", candidate)
if candidate != nil {
wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
}
})
peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
logger.Infof("Connection State has changed %s", connectionState) logger.Infof("Connection State has changed %s", connectionState)
if connectionState == webrtc.ICEConnectionStateConnected { if connectionState == webrtc.ICEConnectionStateConnected {