feat(cloud): Use Websocket signaling in cloud mode

This commit is contained in:
Adam Shiervani 2025-04-04 15:22:06 +02:00 committed by Siyuan Miao
parent fa1b11b228
commit 68f53dcc5f
5 changed files with 152 additions and 118 deletions

View File

@ -11,6 +11,7 @@ import (
"time"
"github.com/coder/websocket/wsjson"
"github.com/pion/webrtc/v4"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
@ -325,22 +326,63 @@ func runWebsocketClient() error {
// ignore non-text messages
continue
}
var req WebRTCSessionRequest
err = json.Unmarshal(msg, &req)
var message struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}
err = json.Unmarshal(msg, &message)
if err != nil {
cloudLogger.Warnf("unable to parse ws message: %v", string(msg))
continue
}
cloudLogger.Infof("new session request: %v", req.OidcGoogle)
cloudLogger.Tracef("session request info: %v", req)
if message.Type == "offer" {
cloudLogger.Infof("new session request received")
var req WebRTCSessionRequest
err = json.Unmarshal(message.Data, &req)
if err != nil {
cloudLogger.Warnf("unable to parse session request data: %v", string(message.Data))
continue
}
metricCloudConnectionSessionRequestCount.Inc()
metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime()
err = handleSessionRequest(runCtx, c, req)
if err != nil {
cloudLogger.Infof("error starting new session: %v", err)
continue
cloudLogger.Infof("new session request: %v", req.OidcGoogle)
cloudLogger.Tracef("session request info: %v", req)
metricCloudConnectionSessionRequestCount.Inc()
metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime()
err = handleSessionRequest(runCtx, c, req)
if err != nil {
cloudLogger.Infof("error starting new session: %v", err)
continue
}
} else if message.Type == "new-ice-candidate" {
cloudLogger.Infof("client has sent us a new ICE candidate: %v", string(message.Data))
var candidate webrtc.ICECandidateInit
// Attempt to unmarshal as a ICECandidateInit
if err := json.Unmarshal(message.Data, &candidate); err != nil {
cloudLogger.Warnf("unable to parse ICE candidate data: %v", string(message.Data))
continue
}
if candidate.Candidate == "" {
cloudLogger.Warnf("empty ICE candidate, skipping")
continue
}
cloudLogger.Infof("unmarshalled ICE candidate: %v", candidate)
if currentSession == nil {
cloudLogger.Infof("no current session, skipping ICE candidate")
continue
}
cloudLogger.Infof("adding ICE candidate to current session: %v", candidate)
if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil {
cloudLogger.Warnf("failed to add ICE candidate: %v", err)
}
}
}
}
@ -383,6 +425,7 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
ICEServers: req.ICEServers,
LocalIP: req.IP,
IsCloud: true,
ws: c,
})
if err != nil {
_ = wsjson.Write(context.Background(), c, gin.H{"error": err})
@ -406,7 +449,7 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
cloudLogger.Info("new session accepted")
cloudLogger.Tracef("new session accepted: %v", session)
currentSession = session
_ = wsjson.Write(context.Background(), c, gin.H{"sd": sd})
_ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd})
return nil
}

6
ui/package-lock.json generated
View File

@ -30,6 +30,7 @@
"react-icons": "^5.4.0",
"react-router-dom": "^6.22.3",
"react-simple-keyboard": "^3.7.112",
"react-use-websocket": "^4.13.0",
"react-xtermjs": "^1.0.9",
"recharts": "^2.15.0",
"tailwind-merge": "^2.5.5",
@ -5180,6 +5181,11 @@
"react-dom": ">=16.6.0"
}
},
"node_modules/react-use-websocket": {
"version": "4.13.0",
"resolved": "https://registry.npmjs.org/react-use-websocket/-/react-use-websocket-4.13.0.tgz",
"integrity": "sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw=="
},
"node_modules/react-xtermjs": {
"version": "1.0.9",
"resolved": "https://registry.npmjs.org/react-xtermjs/-/react-xtermjs-1.0.9.tgz",

View File

@ -40,6 +40,7 @@
"react-icons": "^5.4.0",
"react-router-dom": "^6.22.3",
"react-simple-keyboard": "^3.7.112",
"react-use-websocket": "^4.13.0",
"react-xtermjs": "^1.0.9",
"recharts": "^2.15.0",
"tailwind-merge": "^2.5.5",

View File

@ -14,6 +14,7 @@ import {
import { useInterval } from "usehooks-ts";
import FocusTrap from "focus-trap-react";
import { motion, AnimatePresence } from "framer-motion";
import useWebSocket, { ReadyState } from "react-use-websocket";
import { cx } from "@/cva.config";
import {
@ -117,7 +118,6 @@ const loader = async ({ params }: LoaderFunctionArgs) => {
export default function KvmIdRoute() {
const loaderResp = useLoaderData() as LocalLoaderResp | CloudLoaderResp;
// Depending on the mode, we set the appropriate variables
const user = "user" in loaderResp ? loaderResp.user : null;
const deviceName = "deviceName" in loaderResp ? loaderResp.deviceName : null;
@ -169,87 +169,70 @@ export default function KvmIdRoute() {
useEffect(() => {
connectionFailedRef.current = connectionFailed;
}, [connectionFailed]);
const signalingAttempts = useRef(0);
const syncRemoteSessionDescription = useCallback(
async function syncRemoteSessionDescription(pc: RTCPeerConnection) {
try {
if (!pc) return;
const setRemoteSessionDescription = useCallback(
async function setRemoteSessionDescription(
pc: RTCPeerConnection,
remoteDescription: RTCSessionDescriptionInit,
) {
setLoadingMessage("Setting remote description");
pc.setRemoteDescription(new RTCSessionDescription(remoteDescription));
console.log("Waiting for remote description to be set");
let attempts = 0;
const sd = btoa(JSON.stringify(pc.localDescription));
setLoadingMessage("Establishing secure connection...");
const checkInterval = setInterval(() => {
attempts++;
const sessionUrl = isOnDevice
? `${DEVICE_API}/webrtc/session`
: `${CLOUD_API}/webrtc/session`;
console.log("Trying to get remote session description");
setLoadingMessage(
`Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`,
);
const res = await api.POST(sessionUrl, {
sd,
// When on device, we don't need to specify the device id, as it's already known
...(isOnDevice ? {} : { id: params.id }),
});
const json = await res.json();
if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login");
if (!res.ok) {
console.error("Error getting SDP", { status: res.status, json });
throw new Error("Error getting SDP");
}
console.log("Successfully got Remote Session Description. Setting.");
setLoadingMessage("Setting remote session description...");
const decodedSd = atob(json.sd);
const parsedSd = JSON.parse(decodedSd);
pc.setRemoteDescription(new RTCSessionDescription(parsedSd));
await new Promise((resolve, reject) => {
console.log("Waiting for remote description to be set");
const maxAttempts = 10;
const interval = 1000;
let attempts = 0;
const checkInterval = setInterval(() => {
attempts++;
// When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects
if (pc.sctp?.state === "connected") {
console.log("Remote description set");
clearInterval(checkInterval);
resolve(true);
} else if (attempts >= maxAttempts) {
console.log(
`Failed to get remote description after ${maxAttempts} attempts`,
);
closePeerConnection();
clearInterval(checkInterval);
reject(
new Error(
`Failed to get remote description after ${maxAttempts} attempts`,
),
);
} else {
console.log("Waiting for remote description to be set");
}
}, interval);
});
} catch (error) {
console.error("Error getting SDP", { error });
console.log("Connection failed", connectionFailedRef.current);
if (connectionFailedRef.current) return;
if (signalingAttempts.current < 5) {
signalingAttempts.current++;
await new Promise(resolve => setTimeout(resolve, 500));
console.log("Attempting to get SDP again", signalingAttempts.current);
syncRemoteSessionDescription(pc);
} else {
// When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects
if (pc.sctp?.state === "connected") {
console.log("Remote description set");
clearInterval(checkInterval);
} else if (attempts >= 10) {
console.log("Failed to get remote description after 10 attempts");
closePeerConnection();
clearInterval(checkInterval);
} else {
console.log("Waiting for remote description to be set");
}
}
}, 1000);
},
[closePeerConnection, navigate, params.id],
[closePeerConnection],
);
// TODO: Handle auth!!! The old signaling http request could get a 401 on local and on cloud
const { sendMessage, readyState } = useWebSocket(
isOnDevice ? `${DEVICE_API}/client` : `${CLOUD_API}/client?id=${params.id}`,
{
heartbeat: true,
onMessage: message => {
if (message.data === "pong") return;
if (!peerConnection) return;
const parsedMessage = JSON.parse(message.data);
if (parsedMessage.type === "answer") {
console.log("Setting remote description", parsedMessage.data);
const sd = atob(parsedMessage.data);
const remoteSessionDescription = JSON.parse(sd);
setRemoteSessionDescription(
peerConnection,
new RTCSessionDescription(remoteSessionDescription),
);
} else if (parsedMessage.type === "new-ice-candidate") {
console.log("Received new ICE candidate", parsedMessage.data);
const candidate = parsedMessage.data;
peerConnection.addIceCandidate(candidate);
}
},
},
);
const sendWebRTCSignal = useCallback(
(type: string, data: string | RTCIceCandidate) => {
sendMessage(JSON.stringify({ type, data }));
},
[sendMessage],
);
const setupPeerConnection = useCallback(async () => {
@ -267,6 +250,7 @@ export default function KvmIdRoute() {
? { iceServers: [iceConfig?.iceServers] }
: {}),
});
console.log("Peer connection created", pc);
setLoadingMessage("Peer connection created");
} catch (e) {
@ -282,21 +266,24 @@ export default function KvmIdRoute() {
console.log("Connection state changed", pc.connectionState);
};
pc.onicegatheringstatechange = event => {
const pc = event.currentTarget as RTCPeerConnection;
console.log("ICE Gathering State Changed", pc.iceGatheringState);
if (pc.iceGatheringState === "complete") {
console.log("ICE Gathering completed");
setLoadingMessage("ICE Gathering completed");
// We can now start the https/ws connection to get the remote session description from the KVM device
syncRemoteSessionDescription(pc);
} else if (pc.iceGatheringState === "gathering") {
console.log("ICE Gathering Started");
setLoadingMessage("Gathering ICE candidates...");
pc.onnegotiationneeded = async () => {
try {
const offer = await pc.createOffer();
await pc.setLocalDescription(offer);
const sd = btoa(JSON.stringify(pc.localDescription));
sendWebRTCSignal("offer", sd);
} catch (e) {
console.error(`Error creating offer: ${e}`, new Date().toISOString());
closePeerConnection();
}
};
pc.onicecandidate = async ({ candidate }) => {
if (!candidate) return;
if (candidate.candidate === "") return;
sendWebRTCSignal("new-ice-candidate", candidate);
};
pc.ontrack = function (event) {
setMediaMediaStream(event.streams[0]);
};
@ -314,31 +301,24 @@ export default function KvmIdRoute() {
};
setPeerConnection(pc);
try {
const offer = await pc.createOffer();
await pc.setLocalDescription(offer);
} catch (e) {
console.error(`Error creating offer: ${e}`, new Date().toISOString());
closePeerConnection();
}
}, [
closePeerConnection,
iceConfig?.iceServers,
sendWebRTCSignal,
setDiskChannel,
setMediaMediaStream,
setPeerConnection,
setRpcDataChannel,
setTransceiver,
syncRemoteSessionDescription,
]);
// On boot, if the connection state is undefined, we connect to the WebRTC
useEffect(() => {
if (readyState !== ReadyState.OPEN) return;
if (peerConnection?.connectionState === undefined) {
setupPeerConnection();
}
}, [setupPeerConnection, peerConnection?.connectionState]);
}, [readyState, setupPeerConnection, peerConnection?.connectionState]);
// Cleanup effect
const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats);
@ -593,7 +573,7 @@ export default function KvmIdRoute() {
/>
<div className="flex h-full w-full overflow-hidden">
<div className="pointer-events-none fixed inset-0 isolate z-50 flex h-full w-full items-center justify-center">
<div className="pointer-events-none fixed inset-0 isolate z-20 flex h-full w-full items-center justify-center">
<div className="my-2 h-full max-h-[720px] w-full max-w-[1280px] rounded-md">
<LoadingConnectionOverlay
show={

View File

@ -1,11 +1,15 @@
package kvm
import (
"context"
"encoding/base64"
"encoding/json"
"net"
"strings"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gin-gonic/gin"
"github.com/pion/webrtc/v4"
)
@ -23,6 +27,7 @@ type SessionConfig struct {
ICEServers []string
LocalIP string
IsCloud bool
ws *websocket.Conn
}
func (s *Session) ExchangeOffer(offerStr string) (string, error) {
@ -46,19 +51,11 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) {
return "", err
}
// Create channel that is blocked until ICE Gathering is complete
gatherComplete := webrtc.GatheringCompletePromise(s.peerConnection)
// Sets the LocalDescription, and starts our UDP listeners
if err = s.peerConnection.SetLocalDescription(answer); err != nil {
return "", err
}
// Block until ICE Gathering is complete, disabling trickle ICE
// we do this because we only can exchange one signaling message
// in a production application you should exchange ICE Candidates via OnICECandidate
<-gatherComplete
localDescription, err := json.Marshal(s.peerConnection.LocalDescription())
if err != nil {
return "", err
@ -144,6 +141,13 @@ func newSession(config SessionConfig) (*Session, error) {
}()
var isConnected bool
peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
cloudLogger.Infof("we got a new ICE candidate: %v", candidate)
if candidate != nil {
wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
}
})
peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
logger.Infof("Connection State has changed %s", connectionState)
if connectionState == webrtc.ICEConnectionStateConnected {