refactor: Enhance WebRTC signaling and connection handling

This commit is contained in:
Adam Shiervani 2025-04-05 16:04:49 +02:00 committed by Siyuan Miao
parent 68f53dcc5f
commit 44ac37d11f
8 changed files with 342 additions and 164 deletions

134
cloud.go
View File

@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
"github.com/pion/webrtc/v4"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
@ -273,121 +272,10 @@ func runWebsocketClient() error {
// set the metrics when we successfully connect to the cloud. // set the metrics when we successfully connect to the cloud.
cloudResetMetrics(true) cloudResetMetrics(true)
runCtx, cancelRun := context.WithCancel(context.Background()) return handleWebRTCSignalWsConnection(c, true)
defer cancelRun()
go func() {
for {
time.Sleep(CloudWebSocketPingInterval)
// set the timer for the ping duration
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricCloudConnectionLastPingDuration.Set(v)
metricCloudConnectionPingDuration.Observe(v)
}))
err := c.Ping(runCtx)
if err != nil {
cloudLogger.Warnf("websocket ping error: %v", err)
cancelRun()
return
}
// dont use `defer` here because we want to observe the duration of the ping
timer.ObserveDuration()
metricCloudConnectionTotalPingCount.Inc()
metricCloudConnectionLastPingTimestamp.SetToCurrentTime()
}
}()
// create a channel to receive the disconnect event, once received, we cancelRun
cloudDisconnectChan = make(chan error)
defer func() {
close(cloudDisconnectChan)
cloudDisconnectChan = nil
}()
go func() {
for err := range cloudDisconnectChan {
if err == nil {
continue
}
cloudLogger.Infof("disconnecting from cloud due to: %v", err)
cancelRun()
}
}()
for {
typ, msg, err := c.Read(runCtx)
if err != nil {
return err
}
if typ != websocket.MessageText {
// ignore non-text messages
continue
}
var message struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}
err = json.Unmarshal(msg, &message)
if err != nil {
cloudLogger.Warnf("unable to parse ws message: %v", string(msg))
continue
}
if message.Type == "offer" {
cloudLogger.Infof("new session request received")
var req WebRTCSessionRequest
err = json.Unmarshal(message.Data, &req)
if err != nil {
cloudLogger.Warnf("unable to parse session request data: %v", string(message.Data))
continue
}
cloudLogger.Infof("new session request: %v", req.OidcGoogle)
cloudLogger.Tracef("session request info: %v", req)
metricCloudConnectionSessionRequestCount.Inc()
metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime()
err = handleSessionRequest(runCtx, c, req)
if err != nil {
cloudLogger.Infof("error starting new session: %v", err)
continue
}
} else if message.Type == "new-ice-candidate" {
cloudLogger.Infof("client has sent us a new ICE candidate: %v", string(message.Data))
var candidate webrtc.ICECandidateInit
// Attempt to unmarshal as a ICECandidateInit
if err := json.Unmarshal(message.Data, &candidate); err != nil {
cloudLogger.Warnf("unable to parse ICE candidate data: %v", string(message.Data))
continue
}
if candidate.Candidate == "" {
cloudLogger.Warnf("empty ICE candidate, skipping")
continue
}
cloudLogger.Infof("unmarshalled ICE candidate: %v", candidate)
if currentSession == nil {
cloudLogger.Infof("no current session, skipping ICE candidate")
continue
}
cloudLogger.Infof("adding ICE candidate to current session: %v", candidate)
if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil {
cloudLogger.Warnf("failed to add ICE candidate: %v", err)
}
}
}
} }
func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error { func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricCloudConnectionLastSessionRequestDuration.Set(v) metricCloudConnectionLastSessionRequestDuration.Set(v)
metricCloudConnectionSessionRequestDuration.Observe(v) metricCloudConnectionSessionRequestDuration.Observe(v)
@ -421,12 +309,18 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
return fmt.Errorf("google identity mismatch") return fmt.Errorf("google identity mismatch")
} }
session, err := newSession(SessionConfig{ return nil
ICEServers: req.ICEServers, }
LocalIP: req.IP,
IsCloud: true, func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest, isCloudConnection bool) error {
ws: c, // If the message is from the cloud, we need to authenticate the session.
}) if isCloudConnection {
if err := authenticateSession(ctx, c, req); err != nil {
return err
}
}
session, err := newSession(SessionConfig{ws: c})
if err != nil { if err != nil {
_ = wsjson.Write(context.Background(), c, gin.H{"error": err}) _ = wsjson.Write(context.Background(), c, gin.H{"error": err})
return err return err

View File

@ -67,10 +67,10 @@ make build_dev
cd bin cd bin
# Kill any existing instances of the application # Kill any existing instances of the application
ssh "${REMOTE_USER}@${REMOTE_HOST}" "killall jetkvm_app_debug || true" ssh "${REMOTE_USER}@${REMOTE_HOST}" "killall jetkvm_app || true"
# Copy the binary to the remote host # Copy the binary to the remote host
cat jetkvm_app | ssh "${REMOTE_USER}@${REMOTE_HOST}" "cat > $REMOTE_PATH/jetkvm_app_debug" cat jetkvm_app | ssh "${REMOTE_USER}@${REMOTE_HOST}" "cat > $REMOTE_PATH/jetkvm_app"
# Deploy and run the application on the remote host # Deploy and run the application on the remote host
ssh "${REMOTE_USER}@${REMOTE_HOST}" ash <<EOF ssh "${REMOTE_USER}@${REMOTE_HOST}" ash <<EOF
@ -81,16 +81,16 @@ export LD_LIBRARY_PATH=/oem/usr/lib:\$LD_LIBRARY_PATH
# Kill any existing instances of the application # Kill any existing instances of the application
killall jetkvm_app || true killall jetkvm_app || true
killall jetkvm_app_debug || true killall jetkvm_app || true
# Navigate to the directory where the binary will be stored # Navigate to the directory where the binary will be stored
cd "$REMOTE_PATH" cd "$REMOTE_PATH"
# Make the new binary executable # Make the new binary executable
chmod +x jetkvm_app_debug chmod +x jetkvm_app
# Run the application in the background # Run the application in the background
PION_LOG_TRACE=jetkvm,cloud ./jetkvm_app_debug PION_LOG_TRACE=jetkvm,cloud ./jetkvm_app
EOF EOF
echo "Deployment complete." echo "Deployment complete."

View File

@ -36,7 +36,7 @@ export default function DashboardNavbar({
picture, picture,
kvmName, kvmName,
}: NavbarProps) { }: NavbarProps) {
const peerConnection = useRTCStore(state => state.peerConnection); const peerConnectionState = useRTCStore(state => state.peerConnectionState);
const setUser = useUserStore(state => state.setUser); const setUser = useUserStore(state => state.setUser);
const navigate = useNavigate(); const navigate = useNavigate();
const onLogout = useCallback(async () => { const onLogout = useCallback(async () => {
@ -82,14 +82,14 @@ export default function DashboardNavbar({
<div className="hidden items-center gap-x-2 md:flex"> <div className="hidden items-center gap-x-2 md:flex">
<div className="w-[159px]"> <div className="w-[159px]">
<PeerConnectionStatusCard <PeerConnectionStatusCard
state={peerConnection?.connectionState} state={peerConnectionState}
title={kvmName} title={kvmName}
/> />
</div> </div>
<div className="hidden w-[159px] md:block"> <div className="hidden w-[159px] md:block">
<USBStateStatus <USBStateStatus
state={usbState} state={usbState}
peerConnectionState={peerConnection?.connectionState} peerConnectionState={peerConnectionState}
/> />
</div> </div>
</div> </div>

View File

@ -94,7 +94,7 @@ interface ConnectionErrorOverlayProps {
setupPeerConnection: () => Promise<void>; setupPeerConnection: () => Promise<void>;
} }
export function ConnectionErrorOverlay({ export function ConnectionFailedOverlay({
show, show,
setupPeerConnection, setupPeerConnection,
}: ConnectionErrorOverlayProps) { }: ConnectionErrorOverlayProps) {
@ -151,6 +151,57 @@ export function ConnectionErrorOverlay({
); );
} }
interface PeerConnectionDisconnectedOverlay {
show: boolean;
setupPeerConnection: () => Promise<void>;
}
export function PeerConnectionDisconnectedOverlay({ show }: ConnectionErrorOverlayProps) {
return (
<AnimatePresence>
{show && (
<motion.div
className="aspect-video h-full w-full"
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0, transition: { duration: 0 } }}
transition={{
duration: 0.4,
ease: "easeInOut",
}}
>
<OverlayContent>
<div className="flex flex-col items-start gap-y-1">
<ExclamationTriangleIcon className="h-12 w-12 text-yellow-500" />
<div className="text-left text-sm text-slate-700 dark:text-slate-300">
<div className="space-y-4">
<div className="space-y-2 text-black dark:text-white">
<h2 className="text-xl font-bold">Connection Issue Detected</h2>
<ul className="list-disc space-y-2 pl-4 text-left">
<li>Verify that the device is powered on and properly connected</li>
<li>Check all cable connections for any loose or damaged wires</li>
<li>Ensure your network connection is stable and active</li>
<li>Try restarting both the device and your computer</li>
</ul>
</div>
<div className="flex items-center gap-x-2">
<div className="flex flex-col items-center gap-y-2">
<LoadingSpinner className="h-4 w-4 text-blue-800 dark:text-blue-200" />
<p className="text-sm text-slate-700 dark:text-slate-300">
Retrying connection...
</p>
</div>
</div>
</div>
</div>
</div>
</OverlayContent>
</motion.div>
)}
</AnimatePresence>
);
}
interface HDMIErrorOverlayProps { interface HDMIErrorOverlayProps {
show: boolean; show: boolean;
hdmiState: string; hdmiState: string;

View File

@ -380,7 +380,7 @@ export default function WebRTCVideo() {
(mediaStream: MediaStream) => { (mediaStream: MediaStream) => {
if (!videoElm.current) return; if (!videoElm.current) return;
const videoElmRefValue = videoElm.current; const videoElmRefValue = videoElm.current;
console.log("Adding stream to video element", videoElmRefValue); // console.log("Adding stream to video element", videoElmRefValue);
videoElmRefValue.srcObject = mediaStream; videoElmRefValue.srcObject = mediaStream;
updateVideoSizeStore(videoElmRefValue); updateVideoSizeStore(videoElmRefValue);
}, },
@ -396,7 +396,7 @@ export default function WebRTCVideo() {
peerConnection.addEventListener( peerConnection.addEventListener(
"track", "track",
(e: RTCTrackEvent) => { (e: RTCTrackEvent) => {
console.log("Adding stream to video element"); // console.log("Adding stream to video element");
addStreamToVideoElm(e.streams[0]); addStreamToVideoElm(e.streams[0]);
}, },
{ signal }, { signal },

View File

@ -14,7 +14,7 @@ import {
import { useInterval } from "usehooks-ts"; import { useInterval } from "usehooks-ts";
import FocusTrap from "focus-trap-react"; import FocusTrap from "focus-trap-react";
import { motion, AnimatePresence } from "framer-motion"; import { motion, AnimatePresence } from "framer-motion";
import useWebSocket, { ReadyState } from "react-use-websocket"; import useWebSocket from "react-use-websocket";
import { cx } from "@/cva.config"; import { cx } from "@/cva.config";
import { import {
@ -47,8 +47,9 @@ import { useDeviceUiNavigation } from "../hooks/useAppNavigation";
import { FeatureFlagProvider } from "../providers/FeatureFlagProvider"; import { FeatureFlagProvider } from "../providers/FeatureFlagProvider";
import notifications from "../notifications"; import notifications from "../notifications";
import { import {
ConnectionErrorOverlay, ConnectionFailedOverlay,
LoadingConnectionOverlay, LoadingConnectionOverlay,
PeerConnectionDisconnectedOverlay,
} from "../components/VideoOverlay"; } from "../components/VideoOverlay";
import { SystemVersionInfo } from "./devices.$id.settings.general.update"; import { SystemVersionInfo } from "./devices.$id.settings.general.update";
@ -130,6 +131,8 @@ export default function KvmIdRoute() {
const setIsTurnServerInUse = useRTCStore(state => state.setTurnServerInUse); const setIsTurnServerInUse = useRTCStore(state => state.setTurnServerInUse);
const peerConnection = useRTCStore(state => state.peerConnection); const peerConnection = useRTCStore(state => state.peerConnection);
const setPeerConnectionState = useRTCStore(state => state.setPeerConnectionState);
const peerConnectionState = useRTCStore(state => state.peerConnectionState);
const setMediaMediaStream = useRTCStore(state => state.setMediaStream); const setMediaMediaStream = useRTCStore(state => state.setMediaStream);
const setPeerConnection = useRTCStore(state => state.setPeerConnection); const setPeerConnection = useRTCStore(state => state.setPeerConnection);
const setDiskChannel = useRTCStore(state => state.setDiskChannel); const setDiskChannel = useRTCStore(state => state.setDiskChannel);
@ -176,11 +179,19 @@ export default function KvmIdRoute() {
remoteDescription: RTCSessionDescriptionInit, remoteDescription: RTCSessionDescriptionInit,
) { ) {
setLoadingMessage("Setting remote description"); setLoadingMessage("Setting remote description");
pc.setRemoteDescription(new RTCSessionDescription(remoteDescription));
console.log("Waiting for remote description to be set");
let attempts = 0;
try {
await pc.setRemoteDescription(new RTCSessionDescription(remoteDescription));
console.log("Remote description set successfully");
setLoadingMessage("Establishing secure connection..."); setLoadingMessage("Establishing secure connection...");
} catch (error) {
console.error("Failed to set remote description:", error);
closePeerConnection();
return;
}
// Replace the interval-based check with a more reliable approach
let attempts = 0;
const checkInterval = setInterval(() => { const checkInterval = setInterval(() => {
attempts++; attempts++;
@ -189,29 +200,71 @@ export default function KvmIdRoute() {
console.log("Remote description set"); console.log("Remote description set");
clearInterval(checkInterval); clearInterval(checkInterval);
} else if (attempts >= 10) { } else if (attempts >= 10) {
console.log("Failed to get remote description after 10 attempts"); console.log("Failed to establish connection after 10 attempts");
closePeerConnection(); closePeerConnection();
clearInterval(checkInterval); clearInterval(checkInterval);
} else { } else {
console.log("Waiting for remote description to be set"); console.log(
"Waiting for connection, state:",
pc.connectionState,
pc.iceConnectionState,
);
} }
}, 1000); }, 1000);
}, },
[closePeerConnection], [closePeerConnection],
); );
// TODO: Handle auth!!! The old signaling http request could get a 401 on local and on cloud const ignoreOffer = useRef(false);
const { sendMessage, readyState } = useWebSocket( const isSettingRemoteAnswerPending = useRef(false);
isOnDevice ? `${DEVICE_API}/client` : `${CLOUD_API}/client?id=${params.id}`,
const { sendMessage } = useWebSocket(
isOnDevice
? `ws://192.168.1.77/webrtc/signaling`
: `${CLOUD_API.replace("http", "ws")}/webrtc/signaling?id=${params.id}`,
{ {
heartbeat: true, heartbeat: true,
retryOnError: true,
reconnectAttempts: 5,
reconnectInterval: 1000,
onReconnectStop: () => {
console.log("Reconnect stopped");
closePeerConnection();
},
shouldReconnect(event) {
console.log("shouldReconnect", event);
return true;
},
onClose(event) {
console.log("onClose", event);
},
onError(event) {
console.log("onError", event);
},
onOpen(event) {
console.log("onOpen", event);
console.log("signalingState", peerConnection?.signalingState);
setupPeerConnection();
},
onMessage: message => { onMessage: message => {
if (message.data === "pong") return; if (message.data === "pong") return;
if (!peerConnection) return; if (!peerConnection) return;
console.log("Received WebSocket message:", message.data);
const parsedMessage = JSON.parse(message.data); const parsedMessage = JSON.parse(message.data);
if (parsedMessage.type === "answer") { if (parsedMessage.type === "answer") {
console.log("Setting remote description", parsedMessage.data); const polite = false;
const readyForOffer =
!makingOffer &&
(peerConnection?.signalingState === "stable" ||
isSettingRemoteAnswerPending.current);
const offerCollision = parsedMessage.type === "offer" && !readyForOffer;
ignoreOffer.current = !polite && offerCollision;
if (ignoreOffer.current) return;
isSettingRemoteAnswerPending.current = parsedMessage.type == "answer";
const sd = atob(parsedMessage.data); const sd = atob(parsedMessage.data);
const remoteSessionDescription = JSON.parse(sd); const remoteSessionDescription = JSON.parse(sd);
@ -219,27 +272,35 @@ export default function KvmIdRoute() {
peerConnection, peerConnection,
new RTCSessionDescription(remoteSessionDescription), new RTCSessionDescription(remoteSessionDescription),
); );
isSettingRemoteAnswerPending.current = false;
} else if (parsedMessage.type === "new-ice-candidate") { } else if (parsedMessage.type === "new-ice-candidate") {
console.log("Received new ICE candidate", parsedMessage.data);
const candidate = parsedMessage.data; const candidate = parsedMessage.data;
peerConnection.addIceCandidate(candidate); peerConnection.addIceCandidate(candidate);
} }
}, },
}, },
connectionFailed ? false : true,
); );
const sendWebRTCSignal = useCallback( const sendWebRTCSignal = useCallback(
(type: string, data: string | RTCIceCandidate) => { (type: string, data: any) => {
sendMessage(JSON.stringify({ type, data })); sendMessage(JSON.stringify({ type, data }));
}, },
[sendMessage], [sendMessage],
); );
const makingOffer = useRef(false);
const setupPeerConnection = useCallback(async () => { const setupPeerConnection = useCallback(async () => {
console.log("Setting up peer connection"); console.log("Setting up peer connection");
setConnectionFailed(false); setConnectionFailed(false);
setLoadingMessage("Connecting to device..."); setLoadingMessage("Connecting to device...");
if (peerConnection?.signalingState === "stable") {
console.log("Peer connection already established");
return;
}
let pc: RTCPeerConnection; let pc: RTCPeerConnection;
try { try {
console.log("Creating peer connection"); console.log("Creating peer connection");
@ -264,17 +325,23 @@ export default function KvmIdRoute() {
// Set up event listeners and data channels // Set up event listeners and data channels
pc.onconnectionstatechange = () => { pc.onconnectionstatechange = () => {
console.log("Connection state changed", pc.connectionState); console.log("Connection state changed", pc.connectionState);
setPeerConnectionState(pc.connectionState);
}; };
pc.onnegotiationneeded = async () => { pc.onnegotiationneeded = async () => {
try { try {
console.log("Creating offer");
makingOffer.current = true;
const offer = await pc.createOffer(); const offer = await pc.createOffer();
await pc.setLocalDescription(offer); await pc.setLocalDescription(offer);
const sd = btoa(JSON.stringify(pc.localDescription)); const sd = btoa(JSON.stringify(pc.localDescription));
sendWebRTCSignal("offer", sd); sendWebRTCSignal("offer", { sd: sd });
} catch (e) { } catch (e) {
console.error(`Error creating offer: ${e}`, new Date().toISOString()); console.error(`Error creating offer: ${e}`, new Date().toISOString());
closePeerConnection(); closePeerConnection();
} finally {
makingOffer.current = false;
} }
}; };
@ -308,17 +375,17 @@ export default function KvmIdRoute() {
setDiskChannel, setDiskChannel,
setMediaMediaStream, setMediaMediaStream,
setPeerConnection, setPeerConnection,
setPeerConnectionState,
setRpcDataChannel, setRpcDataChannel,
setTransceiver, setTransceiver,
]); ]);
// On boot, if the connection state is undefined, we connect to the WebRTC
useEffect(() => { useEffect(() => {
if (readyState !== ReadyState.OPEN) return; if (peerConnectionState === "failed") {
if (peerConnection?.connectionState === undefined) { console.log("Connection failed, closing peer connection");
setupPeerConnection(); closePeerConnection();
} }
}, [readyState, setupPeerConnection, peerConnection?.connectionState]); }, [peerConnectionState, closePeerConnection]);
// Cleanup effect // Cleanup effect
const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats); const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats);
@ -343,7 +410,7 @@ export default function KvmIdRoute() {
// TURN server usage detection // TURN server usage detection
useEffect(() => { useEffect(() => {
if (peerConnection?.connectionState !== "connected") return; if (peerConnectionState !== "connected") return;
const { localCandidateStats, remoteCandidateStats } = useRTCStore.getState(); const { localCandidateStats, remoteCandidateStats } = useRTCStore.getState();
const lastLocalStat = Array.from(localCandidateStats).pop(); const lastLocalStat = Array.from(localCandidateStats).pop();
@ -355,7 +422,7 @@ export default function KvmIdRoute() {
const remoteCandidateIsUsingTurn = lastRemoteStat[1].candidateType === "relay"; // [0] is the timestamp, which we don't care about here const remoteCandidateIsUsingTurn = lastRemoteStat[1].candidateType === "relay"; // [0] is the timestamp, which we don't care about here
setIsTurnServerInUse(localCandidateIsUsingTurn || remoteCandidateIsUsingTurn); setIsTurnServerInUse(localCandidateIsUsingTurn || remoteCandidateIsUsingTurn);
}, [peerConnection?.connectionState, setIsTurnServerInUse]); }, [peerConnectionState, setIsTurnServerInUse]);
// TURN server usage reporting // TURN server usage reporting
const isTurnServerInUse = useRTCStore(state => state.isTurnServerInUse); const isTurnServerInUse = useRTCStore(state => state.isTurnServerInUse);
@ -486,12 +553,12 @@ export default function KvmIdRoute() {
useEffect(() => { useEffect(() => {
if (!peerConnection) return; if (!peerConnection) return;
if (!kvmTerminal) { if (!kvmTerminal) {
console.log('Creating data channel "terminal"'); // console.log('Creating data channel "terminal"');
setKvmTerminal(peerConnection.createDataChannel("terminal")); setKvmTerminal(peerConnection.createDataChannel("terminal"));
} }
if (!serialConsole) { if (!serialConsole) {
console.log('Creating data channel "serial"'); // console.log('Creating data channel "serial"');
setSerialConsole(peerConnection.createDataChannel("serial")); setSerialConsole(peerConnection.createDataChannel("serial"));
} }
}, [kvmTerminal, peerConnection, serialConsole]); }, [kvmTerminal, peerConnection, serialConsole]);
@ -578,22 +645,32 @@ export default function KvmIdRoute() {
<LoadingConnectionOverlay <LoadingConnectionOverlay
show={ show={
!connectionFailed && !connectionFailed &&
(["connecting", "new"].includes( peerConnectionState !== "disconnected" &&
peerConnection?.connectionState || "", (["connecting", "new"].includes(peerConnectionState || "") ||
) ||
peerConnection === null) && peerConnection === null) &&
!location.pathname.includes("other-session") !location.pathname.includes("other-session")
} }
text={loadingMessage} text={loadingMessage}
/> />
<ConnectionErrorOverlay <ConnectionFailedOverlay
show={connectionFailed && !location.pathname.includes("other-session")} show={
(connectionFailed || peerConnectionState === "failed") &&
!location.pathname.includes("other-session")
}
setupPeerConnection={setupPeerConnection}
/>
<PeerConnectionDisconnectedOverlay
show={
peerConnectionState === "disconnected" &&
!location.pathname.includes("other-session")
}
setupPeerConnection={setupPeerConnection} setupPeerConnection={setupPeerConnection}
/> />
</div> </div>
</div> </div>
<WebRTCVideo /> {peerConnectionState === "connected" && <WebRTCVideo />}
<SidebarContainer sidebarView={sidebarView} /> <SidebarContainer sidebarView={sidebarView} />
</div> </div>
</div> </div>

156
web.go
View File

@ -1,6 +1,8 @@
package kvm package kvm
import ( import (
"context"
"crypto/sha256"
"embed" "embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -8,10 +10,14 @@ import (
"net/http" "net/http"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"time" "time"
"github.com/coder/websocket"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pion/webrtc/v4"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -89,6 +95,7 @@ func setupRouter() *gin.Engine {
// A Prometheus metrics endpoint. // A Prometheus metrics endpoint.
r.GET("/metrics", gin.WrapH(promhttp.Handler())) r.GET("/metrics", gin.WrapH(promhttp.Handler()))
r.GET("/webrtc/signaling", handleWebRTCSignal)
// Protected routes (allows both password and noPassword modes) // Protected routes (allows both password and noPassword modes)
protected := r.Group("/") protected := r.Group("/")
@ -121,6 +128,155 @@ func setupRouter() *gin.Engine {
// TODO: support multiple sessions? // TODO: support multiple sessions?
var currentSession *Session var currentSession *Session
func handleWebRTCSignal(c *gin.Context) {
cloudLogger.Infof("new websocket connection established")
// Create WebSocket options with InsecureSkipVerify to bypass origin check
wsOptions := &websocket.AcceptOptions{
InsecureSkipVerify: true, // Allow connections from any origin
}
wsCon, err := websocket.Accept(c.Writer, c.Request, wsOptions)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Now use conn for websocket operations
defer wsCon.Close(websocket.StatusNormalClosure, "")
err = handleWebRTCSignalWsConnection(wsCon, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
func handleWebRTCSignalWsConnection(wsCon *websocket.Conn, isCloudConnection bool) error {
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
// Add connection tracking to detect reconnections
connectionID := uuid.New().String()
cloudLogger.Infof("new websocket connection established with ID: %s", connectionID)
// Add a mutex to protect against concurrent access to session state
sessionMutex := &sync.Mutex{}
// Track processed offers to avoid duplicates
processedOffers := make(map[string]bool)
go func() {
for {
time.Sleep(CloudWebSocketPingInterval)
// set the timer for the ping duration
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricCloudConnectionLastPingDuration.Set(v)
metricCloudConnectionPingDuration.Observe(v)
}))
cloudLogger.Infof("pinging websocket")
err := wsCon.Ping(runCtx)
if err != nil {
cloudLogger.Warnf("websocket ping error: %v", err)
cancelRun()
return
}
// dont use `defer` here because we want to observe the duration of the ping
timer.ObserveDuration()
metricCloudConnectionTotalPingCount.Inc()
metricCloudConnectionLastPingTimestamp.SetToCurrentTime()
}
}()
for {
typ, msg, err := wsCon.Read(runCtx)
if err != nil {
cloudLogger.Warnf("websocket read error: %v", err)
return err
}
if typ != websocket.MessageText {
// ignore non-text messages
continue
}
var message struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}
err = json.Unmarshal(msg, &message)
if err != nil {
cloudLogger.Warnf("unable to parse ws message: %v", string(msg))
continue
}
if message.Type == "offer" {
cloudLogger.Infof("new session request received")
var req WebRTCSessionRequest
err = json.Unmarshal(message.Data, &req)
if err != nil {
cloudLogger.Warnf("unable to parse session request data: %v", string(message.Data))
continue
}
// Create a hash of the offer to deduplicate
offerHash := fmt.Sprintf("%x", sha256.Sum256(message.Data))
sessionMutex.Lock()
isDuplicate := processedOffers[offerHash]
if !isDuplicate {
processedOffers[offerHash] = true
}
sessionMutex.Unlock()
if isDuplicate {
cloudLogger.Infof("duplicate offer detected, ignoring: %s", offerHash[:8])
continue
}
cloudLogger.Infof("new session request: %v", req.OidcGoogle)
cloudLogger.Tracef("session request info: %v", req)
metricCloudConnectionSessionRequestCount.Inc()
metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime()
err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection)
if err != nil {
cloudLogger.Infof("error starting new session: %v", err)
continue
}
} else if message.Type == "new-ice-candidate" {
cloudLogger.Infof("client has sent us a new ICE candidate: %v", string(message.Data))
var candidate webrtc.ICECandidateInit
// Attempt to unmarshal as a ICECandidateInit
if err := json.Unmarshal(message.Data, &candidate); err != nil {
cloudLogger.Warnf("unable to parse ICE candidate data: %v", string(message.Data))
continue
}
if candidate.Candidate == "" {
cloudLogger.Warnf("empty ICE candidate, skipping")
continue
}
cloudLogger.Infof("unmarshalled ICE candidate: %v", candidate)
if currentSession == nil {
cloudLogger.Infof("no current session, skipping ICE candidate")
continue
}
cloudLogger.Infof("adding ICE candidate to current session: %v", candidate)
if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil {
cloudLogger.Warnf("failed to add ICE candidate: %v", err)
}
}
}
}
func handleWebRTCSession(c *gin.Context) { func handleWebRTCSession(c *gin.Context) {
var req WebRTCSessionRequest var req WebRTCSessionRequest

View File

@ -142,7 +142,7 @@ func newSession(config SessionConfig) (*Session, error) {
var isConnected bool var isConnected bool
peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) { peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
cloudLogger.Infof("we got a new ICE candidate: %v", candidate) cloudLogger.Infof("AAAAAAA got a new ICE candidate: %v", candidate)
if candidate != nil { if candidate != nil {
wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()}) wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
} }