From 44ac37d11f36afa588df6230272f69558257ad7e Mon Sep 17 00:00:00 2001 From: Adam Shiervani Date: Sat, 5 Apr 2025 16:04:49 +0200 Subject: [PATCH] refactor: Enhance WebRTC signaling and connection handling --- cloud.go | 134 +++---------------------- dev_deploy.sh | 10 +- ui/src/components/Header.tsx | 6 +- ui/src/components/VideoOverlay.tsx | 53 +++++++++- ui/src/components/WebRTCVideo.tsx | 4 +- ui/src/routes/devices.$id.tsx | 141 ++++++++++++++++++++------ web.go | 156 +++++++++++++++++++++++++++++ webrtc.go | 2 +- 8 files changed, 342 insertions(+), 164 deletions(-) diff --git a/cloud.go b/cloud.go index fe77482..f69f416 100644 --- a/cloud.go +++ b/cloud.go @@ -11,7 +11,6 @@ import ( "time" "github.com/coder/websocket/wsjson" - "github.com/pion/webrtc/v4" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -273,121 +272,10 @@ func runWebsocketClient() error { // set the metrics when we successfully connect to the cloud. cloudResetMetrics(true) - runCtx, cancelRun := context.WithCancel(context.Background()) - defer cancelRun() - go func() { - for { - time.Sleep(CloudWebSocketPingInterval) - - // set the timer for the ping duration - timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { - metricCloudConnectionLastPingDuration.Set(v) - metricCloudConnectionPingDuration.Observe(v) - })) - - err := c.Ping(runCtx) - - if err != nil { - cloudLogger.Warnf("websocket ping error: %v", err) - cancelRun() - return - } - - // dont use `defer` here because we want to observe the duration of the ping - timer.ObserveDuration() - - metricCloudConnectionTotalPingCount.Inc() - metricCloudConnectionLastPingTimestamp.SetToCurrentTime() - } - }() - - // create a channel to receive the disconnect event, once received, we cancelRun - cloudDisconnectChan = make(chan error) - defer func() { - close(cloudDisconnectChan) - cloudDisconnectChan = nil - }() - go func() { - for err := range cloudDisconnectChan { - if 
err == nil { - continue - } - cloudLogger.Infof("disconnecting from cloud due to: %v", err) - cancelRun() - } - }() - - for { - typ, msg, err := c.Read(runCtx) - if err != nil { - return err - } - if typ != websocket.MessageText { - // ignore non-text messages - continue - } - - var message struct { - Type string `json:"type"` - Data json.RawMessage `json:"data"` - } - - err = json.Unmarshal(msg, &message) - if err != nil { - cloudLogger.Warnf("unable to parse ws message: %v", string(msg)) - continue - } - - if message.Type == "offer" { - cloudLogger.Infof("new session request received") - var req WebRTCSessionRequest - err = json.Unmarshal(message.Data, &req) - if err != nil { - cloudLogger.Warnf("unable to parse session request data: %v", string(message.Data)) - continue - } - - cloudLogger.Infof("new session request: %v", req.OidcGoogle) - cloudLogger.Tracef("session request info: %v", req) - - metricCloudConnectionSessionRequestCount.Inc() - metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime() - err = handleSessionRequest(runCtx, c, req) - if err != nil { - cloudLogger.Infof("error starting new session: %v", err) - continue - } - } else if message.Type == "new-ice-candidate" { - cloudLogger.Infof("client has sent us a new ICE candidate: %v", string(message.Data)) - var candidate webrtc.ICECandidateInit - - // Attempt to unmarshal as a ICECandidateInit - if err := json.Unmarshal(message.Data, &candidate); err != nil { - cloudLogger.Warnf("unable to parse ICE candidate data: %v", string(message.Data)) - continue - } - - if candidate.Candidate == "" { - cloudLogger.Warnf("empty ICE candidate, skipping") - continue - } - - cloudLogger.Infof("unmarshalled ICE candidate: %v", candidate) - - if currentSession == nil { - cloudLogger.Infof("no current session, skipping ICE candidate") - continue - } - - cloudLogger.Infof("adding ICE candidate to current session: %v", candidate) - if err = currentSession.peerConnection.AddICECandidate(candidate); err != 
nil {
-				cloudLogger.Warnf("failed to add ICE candidate: %v", err)
-			}
-		}
-	}
+	return handleWebRTCSignalWsConnection(c, true)
 }
 
-func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
+func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
 	timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
 		metricCloudConnectionLastSessionRequestDuration.Set(v)
 		metricCloudConnectionSessionRequestDuration.Observe(v)
@@ -421,12 +309,29 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
 		return fmt.Errorf("google identity mismatch")
 	}
 
-	session, err := newSession(SessionConfig{
-		ICEServers: req.ICEServers,
-		LocalIP:    req.IP,
-		IsCloud:    true,
-		ws:         c,
-	})
+	return nil
+}
+
+// handleSessionRequest starts a WebRTC session for the offer in req, using c
+// as the signaling websocket. Requests arriving over the cloud connection
+// (isCloudConnection) are authenticated first; an authentication or session
+// setup failure is returned to the caller.
+func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest, isCloudConnection bool) error {
+	// If the message is from the cloud, we need to authenticate the session.
+	if isCloudConnection {
+		if err := authenticateSession(ctx, c, req); err != nil {
+			return err
+		}
+	}
+
+	// Keep forwarding the ICE servers and IP from the request, as the
+	// cloud-only handler did before this refactor.
+	session, err := newSession(SessionConfig{
+		ICEServers: req.ICEServers,
+		LocalIP:    req.IP,
+		IsCloud:    isCloudConnection,
+		ws:         c,
+	})
 	if err != nil {
-		_ = wsjson.Write(context.Background(), c, gin.H{"error": err})
+		_ = wsjson.Write(context.Background(), c, gin.H{"error": err.Error()})
 		return err
diff --git a/dev_deploy.sh b/dev_deploy.sh
index 7fbf29e..c3e3716 100755
--- a/dev_deploy.sh
+++ b/dev_deploy.sh
@@ -67,10 +67,10 @@ make build_dev
 cd bin
 
 # Kill any existing instances of the application
-ssh "${REMOTE_USER}@${REMOTE_HOST}" "killall jetkvm_app_debug || true"
+ssh "${REMOTE_USER}@${REMOTE_HOST}" "killall jetkvm_app || true"
 
 # Copy the binary to the remote host
-cat jetkvm_app | ssh "${REMOTE_USER}@${REMOTE_HOST}" "cat > $REMOTE_PATH/jetkvm_app_debug"
+cat jetkvm_app | ssh "${REMOTE_USER}@${REMOTE_HOST}" "cat > $REMOTE_PATH/jetkvm_app"
 
 # Deploy and run the application on the remote host
 ssh "${REMOTE_USER}@${REMOTE_HOST}" ash < state.peerConnection);
+  const peerConnectionState = useRTCStore(state => state.peerConnectionState);
   const setUser = useUserStore(state => state.setUser);
   const navigate = useNavigate();
   const onLogout = useCallback(async () => {
@@ -82,14 +82,14 @@ export default function DashboardNavbar({
diff --git a/ui/src/components/VideoOverlay.tsx b/ui/src/components/VideoOverlay.tsx index 0620af4..fe1f120 100644 --- a/ui/src/components/VideoOverlay.tsx +++ b/ui/src/components/VideoOverlay.tsx @@ -94,7 +94,7 @@ interface ConnectionErrorOverlayProps { setupPeerConnection: () => Promise; } -export function ConnectionErrorOverlay({ +export function ConnectionFailedOverlay({ show, setupPeerConnection, }: ConnectionErrorOverlayProps) { @@ -151,6 +151,57 @@ export function ConnectionErrorOverlay({ ); } +interface PeerConnectionDisconnectedOverlay { + show: boolean; + setupPeerConnection: () => Promise; +} + +export function PeerConnectionDisconnectedOverlay({ show }: ConnectionErrorOverlayProps) { + return ( + + {show && ( + + +
+ +
+
+
+

Connection Issue Detected

+
    +
  • Verify that the device is powered on and properly connected
  • +
  • Check all cable connections for any loose or damaged wires
  • +
  • Ensure your network connection is stable and active
  • +
  • Try restarting both the device and your computer
  • +
+
+
+
+ +

+ Retrying connection... +

+
+
+
+
+
+
+
+ )} +
+ ); +} + interface HDMIErrorOverlayProps { show: boolean; hdmiState: string; diff --git a/ui/src/components/WebRTCVideo.tsx b/ui/src/components/WebRTCVideo.tsx index 5d8fb55..99c0191 100644 --- a/ui/src/components/WebRTCVideo.tsx +++ b/ui/src/components/WebRTCVideo.tsx @@ -380,7 +380,7 @@ export default function WebRTCVideo() { (mediaStream: MediaStream) => { if (!videoElm.current) return; const videoElmRefValue = videoElm.current; - console.log("Adding stream to video element", videoElmRefValue); + // console.log("Adding stream to video element", videoElmRefValue); videoElmRefValue.srcObject = mediaStream; updateVideoSizeStore(videoElmRefValue); }, @@ -396,7 +396,7 @@ export default function WebRTCVideo() { peerConnection.addEventListener( "track", (e: RTCTrackEvent) => { - console.log("Adding stream to video element"); + // console.log("Adding stream to video element"); addStreamToVideoElm(e.streams[0]); }, { signal }, diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx index 08ed023..e148fcc 100644 --- a/ui/src/routes/devices.$id.tsx +++ b/ui/src/routes/devices.$id.tsx @@ -14,7 +14,7 @@ import { import { useInterval } from "usehooks-ts"; import FocusTrap from "focus-trap-react"; import { motion, AnimatePresence } from "framer-motion"; -import useWebSocket, { ReadyState } from "react-use-websocket"; +import useWebSocket from "react-use-websocket"; import { cx } from "@/cva.config"; import { @@ -47,8 +47,9 @@ import { useDeviceUiNavigation } from "../hooks/useAppNavigation"; import { FeatureFlagProvider } from "../providers/FeatureFlagProvider"; import notifications from "../notifications"; import { - ConnectionErrorOverlay, + ConnectionFailedOverlay, LoadingConnectionOverlay, + PeerConnectionDisconnectedOverlay, } from "../components/VideoOverlay"; import { SystemVersionInfo } from "./devices.$id.settings.general.update"; @@ -130,6 +131,8 @@ export default function KvmIdRoute() { const setIsTurnServerInUse = useRTCStore(state => 
state.setTurnServerInUse); const peerConnection = useRTCStore(state => state.peerConnection); + const setPeerConnectionState = useRTCStore(state => state.setPeerConnectionState); + const peerConnectionState = useRTCStore(state => state.peerConnectionState); const setMediaMediaStream = useRTCStore(state => state.setMediaStream); const setPeerConnection = useRTCStore(state => state.setPeerConnection); const setDiskChannel = useRTCStore(state => state.setDiskChannel); @@ -176,11 +179,19 @@ export default function KvmIdRoute() { remoteDescription: RTCSessionDescriptionInit, ) { setLoadingMessage("Setting remote description"); - pc.setRemoteDescription(new RTCSessionDescription(remoteDescription)); - console.log("Waiting for remote description to be set"); - let attempts = 0; - setLoadingMessage("Establishing secure connection..."); + try { + await pc.setRemoteDescription(new RTCSessionDescription(remoteDescription)); + console.log("Remote description set successfully"); + setLoadingMessage("Establishing secure connection..."); + } catch (error) { + console.error("Failed to set remote description:", error); + closePeerConnection(); + return; + } + + // Replace the interval-based check with a more reliable approach + let attempts = 0; const checkInterval = setInterval(() => { attempts++; @@ -189,29 +200,71 @@ export default function KvmIdRoute() { console.log("Remote description set"); clearInterval(checkInterval); } else if (attempts >= 10) { - console.log("Failed to get remote description after 10 attempts"); + console.log("Failed to establish connection after 10 attempts"); closePeerConnection(); clearInterval(checkInterval); } else { - console.log("Waiting for remote description to be set"); + console.log( + "Waiting for connection, state:", + pc.connectionState, + pc.iceConnectionState, + ); } }, 1000); }, [closePeerConnection], ); - // TODO: Handle auth!!! 
The old signaling http request could get a 401 on local and on cloud - const { sendMessage, readyState } = useWebSocket( - isOnDevice ? `${DEVICE_API}/client` : `${CLOUD_API}/client?id=${params.id}`, + const ignoreOffer = useRef(false); + const isSettingRemoteAnswerPending = useRef(false); + + const { sendMessage } = useWebSocket( + isOnDevice + ? `ws://192.168.1.77/webrtc/signaling` + : `${CLOUD_API.replace("http", "ws")}/webrtc/signaling?id=${params.id}`, { heartbeat: true, + retryOnError: true, + reconnectAttempts: 5, + reconnectInterval: 1000, + onReconnectStop: () => { + console.log("Reconnect stopped"); + closePeerConnection(); + }, + shouldReconnect(event) { + console.log("shouldReconnect", event); + return true; + }, + onClose(event) { + console.log("onClose", event); + }, + onError(event) { + console.log("onError", event); + }, + onOpen(event) { + console.log("onOpen", event); + console.log("signalingState", peerConnection?.signalingState); + setupPeerConnection(); + }, + onMessage: message => { if (message.data === "pong") return; if (!peerConnection) return; - + console.log("Received WebSocket message:", message.data); const parsedMessage = JSON.parse(message.data); if (parsedMessage.type === "answer") { - console.log("Setting remote description", parsedMessage.data); + const polite = false; + const readyForOffer = + !makingOffer && + (peerConnection?.signalingState === "stable" || + isSettingRemoteAnswerPending.current); + const offerCollision = parsedMessage.type === "offer" && !readyForOffer; + + ignoreOffer.current = !polite && offerCollision; + if (ignoreOffer.current) return; + + isSettingRemoteAnswerPending.current = parsedMessage.type == "answer"; + const sd = atob(parsedMessage.data); const remoteSessionDescription = JSON.parse(sd); @@ -219,27 +272,35 @@ export default function KvmIdRoute() { peerConnection, new RTCSessionDescription(remoteSessionDescription), ); + + isSettingRemoteAnswerPending.current = false; } else if (parsedMessage.type === 
"new-ice-candidate") { - console.log("Received new ICE candidate", parsedMessage.data); const candidate = parsedMessage.data; peerConnection.addIceCandidate(candidate); } }, }, + + connectionFailed ? false : true, ); const sendWebRTCSignal = useCallback( - (type: string, data: string | RTCIceCandidate) => { + (type: string, data: any) => { sendMessage(JSON.stringify({ type, data })); }, [sendMessage], ); - + const makingOffer = useRef(false); const setupPeerConnection = useCallback(async () => { console.log("Setting up peer connection"); setConnectionFailed(false); setLoadingMessage("Connecting to device..."); + if (peerConnection?.signalingState === "stable") { + console.log("Peer connection already established"); + return; + } + let pc: RTCPeerConnection; try { console.log("Creating peer connection"); @@ -264,17 +325,23 @@ export default function KvmIdRoute() { // Set up event listeners and data channels pc.onconnectionstatechange = () => { console.log("Connection state changed", pc.connectionState); + setPeerConnectionState(pc.connectionState); }; pc.onnegotiationneeded = async () => { try { + console.log("Creating offer"); + makingOffer.current = true; + const offer = await pc.createOffer(); await pc.setLocalDescription(offer); const sd = btoa(JSON.stringify(pc.localDescription)); - sendWebRTCSignal("offer", sd); + sendWebRTCSignal("offer", { sd: sd }); } catch (e) { console.error(`Error creating offer: ${e}`, new Date().toISOString()); closePeerConnection(); + } finally { + makingOffer.current = false; } }; @@ -308,17 +375,17 @@ export default function KvmIdRoute() { setDiskChannel, setMediaMediaStream, setPeerConnection, + setPeerConnectionState, setRpcDataChannel, setTransceiver, ]); - // On boot, if the connection state is undefined, we connect to the WebRTC useEffect(() => { - if (readyState !== ReadyState.OPEN) return; - if (peerConnection?.connectionState === undefined) { - setupPeerConnection(); + if (peerConnectionState === "failed") { + 
console.log("Connection failed, closing peer connection"); + closePeerConnection(); } - }, [readyState, setupPeerConnection, peerConnection?.connectionState]); + }, [peerConnectionState, closePeerConnection]); // Cleanup effect const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats); @@ -343,7 +410,7 @@ export default function KvmIdRoute() { // TURN server usage detection useEffect(() => { - if (peerConnection?.connectionState !== "connected") return; + if (peerConnectionState !== "connected") return; const { localCandidateStats, remoteCandidateStats } = useRTCStore.getState(); const lastLocalStat = Array.from(localCandidateStats).pop(); @@ -355,7 +422,7 @@ export default function KvmIdRoute() { const remoteCandidateIsUsingTurn = lastRemoteStat[1].candidateType === "relay"; // [0] is the timestamp, which we don't care about here setIsTurnServerInUse(localCandidateIsUsingTurn || remoteCandidateIsUsingTurn); - }, [peerConnection?.connectionState, setIsTurnServerInUse]); + }, [peerConnectionState, setIsTurnServerInUse]); // TURN server usage reporting const isTurnServerInUse = useRTCStore(state => state.isTurnServerInUse); @@ -486,12 +553,12 @@ export default function KvmIdRoute() { useEffect(() => { if (!peerConnection) return; if (!kvmTerminal) { - console.log('Creating data channel "terminal"'); + // console.log('Creating data channel "terminal"'); setKvmTerminal(peerConnection.createDataChannel("terminal")); } if (!serialConsole) { - console.log('Creating data channel "serial"'); + // console.log('Creating data channel "serial"'); setSerialConsole(peerConnection.createDataChannel("serial")); } }, [kvmTerminal, peerConnection, serialConsole]); @@ -578,22 +645,32 @@ export default function KvmIdRoute() { - + + - + {peerConnectionState === "connected" && } diff --git a/web.go b/web.go index 9201e7b..b32bec6 100644 --- a/web.go +++ b/web.go @@ -1,6 +1,8 @@ package kvm import ( + "context" + "crypto/sha256" "embed" "encoding/json" "fmt" @@ -8,10 
+10,13 @@ import (
 	"net/http"
 	"path/filepath"
 	"strings"
 	"time"
 
+	"github.com/coder/websocket"
 	"github.com/gin-gonic/gin"
 	"github.com/google/uuid"
+	"github.com/pion/webrtc/v4"
+	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"golang.org/x/crypto/bcrypt"
 )
@@ -89,6 +94,7 @@ func setupRouter() *gin.Engine {
 
 	// A Prometheus metrics endpoint.
 	r.GET("/metrics", gin.WrapH(promhttp.Handler()))
+	r.GET("/webrtc/signaling", handleWebRTCSignal)
 
 	// Protected routes (allows both password and noPassword modes)
 	protected := r.Group("/")
@@ -121,6 +127,163 @@ func setupRouter() *gin.Engine {
 // TODO: support multiple sessions?
 var currentSession *Session
 
+// handleWebRTCSignal upgrades a local HTTP request to a websocket and runs the
+// WebRTC signaling loop on it until the connection ends.
+func handleWebRTCSignal(c *gin.Context) {
+	cloudLogger.Infof("new websocket connection established")
+
+	// NOTE(review): InsecureSkipVerify disables the websocket Origin check, and
+	// this route is registered outside the protected group, so the endpoint
+	// looks reachable without authentication. Confirm this is intended before
+	// release.
+	wsOptions := &websocket.AcceptOptions{
+		InsecureSkipVerify: true, // Allow connections from any origin
+	}
+
+	wsCon, err := websocket.Accept(c.Writer, c.Request, wsOptions)
+	if err != nil {
+		// Accept writes the HTTP error response itself, so a second response
+		// must not be written with c.JSON here.
+		cloudLogger.Warnf("unable to accept websocket connection: %v", err)
+		return
+	}
+	defer wsCon.Close(websocket.StatusNormalClosure, "")
+
+	// The HTTP connection now belongs to the websocket, so a failure can only
+	// be logged; a JSON response can no longer be written.
+	if err := handleWebRTCSignalWsConnection(wsCon, false); err != nil {
+		cloudLogger.Warnf("websocket signaling connection closed: %v", err)
+	}
+}
+
+// pingWebRTCSignalWsConnection pings wsCon every CloudWebSocketPingInterval and
+// records the ping metrics. It returns when ctx is cancelled, and calls cancel
+// itself when a ping fails so that the read loop sharing ctx stops as well.
+func pingWebRTCSignalWsConnection(ctx context.Context, cancel context.CancelFunc, wsCon *websocket.Conn) {
+	for {
+		// Wait on the context too, so the goroutine exits as soon as the
+		// connection handler returns instead of sleeping out the interval.
+		select {
+		case <-ctx.Done():
+			return
+		case <-time.After(CloudWebSocketPingInterval):
+		}
+
+		// set the timer for the ping duration
+		timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
+			metricCloudConnectionLastPingDuration.Set(v)
+			metricCloudConnectionPingDuration.Observe(v)
+		}))
+
+		cloudLogger.Infof("pinging websocket")
+		if err := wsCon.Ping(ctx); err != nil {
+			cloudLogger.Warnf("websocket ping error: %v", err)
+			cancel()
+			return
+		}
+
+		// dont use `defer` here because we want to observe the duration of the ping
+		timer.ObserveDuration()
+
+		metricCloudConnectionTotalPingCount.Inc()
+		metricCloudConnectionLastPingTimestamp.SetToCurrentTime()
+	}
+}
+
+// handleWebRTCSignalWsConnection runs the WebRTC signaling loop on wsCon. It
+// answers "offer" messages by starting a session and feeds "new-ice-candidate"
+// messages to the current session. isCloudConnection is passed on to
+// handleSessionRequest, which authenticates cloud offers. The function blocks
+// until reading from wsCon fails and returns that read error.
+func handleWebRTCSignalWsConnection(wsCon *websocket.Conn, isCloudConnection bool) error {
+	runCtx, cancelRun := context.WithCancel(context.Background())
+	defer cancelRun()
+
+	// Tag the connection so reconnections can be told apart in the logs.
+	connectionID := uuid.New().String()
+	cloudLogger.Infof("new websocket connection established with ID: %s", connectionID)
+
+	// Hashes of the offers already handled on this connection. Only this
+	// goroutine touches the set, so it needs no lock.
+	processedOffers := make(map[string]struct{})
+
+	go pingWebRTCSignalWsConnection(runCtx, cancelRun, wsCon)
+
+	for {
+		typ, msg, err := wsCon.Read(runCtx)
+		if err != nil {
+			return err
+		}
+		if typ != websocket.MessageText {
+			// ignore non-text messages
+			continue
+		}
+
+		var message struct {
+			Type string          `json:"type"`
+			Data json.RawMessage `json:"data"`
+		}
+		if err := json.Unmarshal(msg, &message); err != nil {
+			cloudLogger.Warnf("unable to parse ws message: %v", string(msg))
+			continue
+		}
+
+		switch message.Type {
+		case "offer":
+			cloudLogger.Infof("new session request received")
+			var req WebRTCSessionRequest
+			if err := json.Unmarshal(message.Data, &req); err != nil {
+				cloudLogger.Warnf("unable to parse session request data: %v", string(message.Data))
+				continue
+			}
+
+			// Hash the raw offer so that a resent, identical offer is not
+			// turned into a second session.
+			offerHash := fmt.Sprintf("%x", sha256.Sum256(message.Data))
+			if _, seen := processedOffers[offerHash]; seen {
+				cloudLogger.Infof("duplicate offer detected, ignoring: %s", offerHash[:8])
+				continue
+			}
+			processedOffers[offerHash] = struct{}{}
+
+			cloudLogger.Infof("new session request: %v", req.OidcGoogle)
+			cloudLogger.Tracef("session request info: %v", req)
+
+			metricCloudConnectionSessionRequestCount.Inc()
+			metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime()
+			if err := handleSessionRequest(runCtx, wsCon, req, isCloudConnection); err != nil {
+				cloudLogger.Infof("error starting new session: %v", err)
+			}
+		case "new-ice-candidate":
+			cloudLogger.Infof("client has sent us a new ICE candidate: %v", string(message.Data))
+			var candidate webrtc.ICECandidateInit
+
+			// Attempt to unmarshal as a ICECandidateInit
+			if err := json.Unmarshal(message.Data, &candidate); err != nil {
+				cloudLogger.Warnf("unable to parse ICE candidate data: %v", string(message.Data))
+				continue
+			}
+
+			if candidate.Candidate == "" {
+				cloudLogger.Warnf("empty ICE candidate, skipping")
+				continue
+			}
+
+			cloudLogger.Infof("unmarshalled ICE candidate: %v", candidate)
+
+			if currentSession == nil {
+				cloudLogger.Infof("no current session, skipping ICE candidate")
+				continue
+			}
+
+			cloudLogger.Infof("adding ICE candidate to current session: %v", candidate)
+			if err := currentSession.peerConnection.AddICECandidate(candidate); err != nil {
+				cloudLogger.Warnf("failed to add ICE candidate: %v", err)
+			}
+		}
+	}
+}
+
 func handleWebRTCSession(c *gin.Context) {
 	var req WebRTCSessionRequest
 
diff --git a/webrtc.go b/webrtc.go
index a485b1f..642516d 100644
--- a/webrtc.go
+++ b/webrtc.go
@@ -142,7 +142,7 @@ func newSession(config SessionConfig) (*Session, error) {
 	var isConnected bool
 
 	peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
-		cloudLogger.Infof("we got a new ICE candidate: %v", candidate)
+		cloudLogger.Infof("we got a new local ICE candidate: %v", candidate)
 		if candidate != nil {
 			wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
 		}