diff --git a/cloud.go b/cloud.go index f91085a..7ad8b75 100644 --- a/cloud.go +++ b/cloud.go @@ -35,8 +35,8 @@ const ( // CloudOidcRequestTimeout is the timeout for OIDC token verification requests // should be lower than the websocket response timeout set in cloud-api CloudOidcRequestTimeout = 10 * time.Second - // CloudWebSocketPingInterval is the interval at which the websocket client sends ping messages to the cloud - CloudWebSocketPingInterval = 15 * time.Second + // WebsocketPingInterval is the interval at which the websocket client sends ping messages to the cloud + WebsocketPingInterval = 15 * time.Second ) var ( @@ -52,59 +52,67 @@ var ( Help: "The timestamp when the cloud connection was established", }, ) - metricCloudConnectionLastPingTimestamp = promauto.NewGauge( + metricConnectionLastPingTimestamp = promauto.NewGaugeVec( prometheus.GaugeOpts{ - Name: "jetkvm_cloud_connection_last_ping_timestamp", + Name: "jetkvm_connection_last_ping_timestamp", Help: "The timestamp when the last ping response was received", }, + []string{"type", "source"}, ) - metricCloudConnectionLastPingDuration = promauto.NewGauge( + metricConnectionLastPingDuration = promauto.NewGaugeVec( prometheus.GaugeOpts{ - Name: "jetkvm_cloud_connection_last_ping_duration", + Name: "jetkvm_connection_last_ping_duration", Help: "The duration of the last ping response", }, + []string{"type", "source"}, ) - metricCloudConnectionPingDuration = promauto.NewHistogram( + metricConnectionPingDuration = promauto.NewHistogramVec( prometheus.HistogramOpts{ - Name: "jetkvm_cloud_connection_ping_duration", + Name: "jetkvm_connection_ping_duration", Help: "The duration of the ping response", Buckets: []float64{ 0.1, 0.5, 1, 10, }, }, + []string{"type", "source"}, ) - metricCloudConnectionTotalPingCount = promauto.NewCounter( + metricConnectionTotalPingCount = promauto.NewCounterVec( prometheus.CounterOpts{ - Name: "jetkvm_cloud_connection_total_ping_count", - Help: "The total number of pings sent to the cloud", + Name: "jetkvm_connection_total_ping_count", + Help: "The total number of pings sent to the connection", }, + []string{"type", "source"}, ) - metricCloudConnectionSessionRequestCount = promauto.NewCounter( + metricConnectionSessionRequestCount = promauto.NewCounterVec( prometheus.CounterOpts{ - Name: "jetkvm_cloud_connection_session_total_request_count", - Help: "The total number of session requests received from the cloud", + Name: "jetkvm_connection_session_total_request_count", + Help: "The total number of session requests received", }, + []string{"type", "source"}, ) - metricCloudConnectionSessionRequestDuration = promauto.NewHistogram( + metricConnectionSessionRequestDuration = promauto.NewHistogramVec( prometheus.HistogramOpts{ - Name: "jetkvm_cloud_connection_session_request_duration", + Name: "jetkvm_connection_session_request_duration", Help: "The duration of session requests", Buckets: []float64{ 0.1, 0.5, 1, 10, }, }, + []string{"type", "source"}, ) - metricCloudConnectionLastSessionRequestTimestamp = promauto.NewGauge( + metricConnectionLastSessionRequestTimestamp = promauto.NewGaugeVec( prometheus.GaugeOpts{ - Name: "jetkvm_cloud_connection_last_session_request_timestamp", + Name: "jetkvm_connection_last_session_request_timestamp", Help: "The timestamp of the last session request", }, + []string{"type", "source"}, ) - metricCloudConnectionLastSessionRequestDuration = promauto.NewGauge( + metricConnectionLastSessionRequestDuration = promauto.NewGaugeVec( prometheus.GaugeOpts{ - Name: "jetkvm_cloud_connection_last_session_request_duration", + Name: "jetkvm_connection_last_session_request_duration", Help: "The duration of the last session request", }, + []string{"type", "source"}, ) metricCloudConnectionFailureCount = promauto.NewCounter( prometheus.CounterOpts{ @@ -119,12 +127,16 @@ var ( cloudDisconnectLock = &sync.Mutex{} ) -func cloudResetMetrics(established bool) { - metricCloudConnectionLastPingTimestamp.Set(-1) - metricCloudConnectionLastPingDuration.Set(-1) +func wsResetMetrics(established bool, sourceType string, source string) { + metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).Set(-1) + metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(-1) - metricCloudConnectionLastSessionRequestTimestamp.Set(-1) - metricCloudConnectionLastSessionRequestDuration.Set(-1) + metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).Set(-1) + metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(-1) + + if sourceType != "cloud" { + return + } if established { metricCloudConnectionEstablishedTimestamp.SetToCurrentTime() @@ -256,6 +268,7 @@ func runWebsocketClient() error { header := http.Header{} header.Set("X-Device-ID", GetDeviceID()) + header.Set("X-App-Version", builtAppVersion) header.Set("Authorization", "Bearer "+config.CloudToken) dialCtx, cancelDial := context.WithTimeout(context.Background(), CloudWebSocketConnectTimeout) @@ -270,88 +283,13 @@ func runWebsocketClient() error { cloudLogger.Infof("websocket connected to %s", wsURL) // set the metrics when we successfully connect to the cloud. - cloudResetMetrics(true) + wsResetMetrics(true, "cloud", "") - runCtx, cancelRun := context.WithCancel(context.Background()) - defer cancelRun() - go func() { - for { - time.Sleep(CloudWebSocketPingInterval) - - // set the timer for the ping duration - timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { - metricCloudConnectionLastPingDuration.Set(v) - metricCloudConnectionPingDuration.Observe(v) - })) - - err := c.Ping(runCtx) - - if err != nil { - cloudLogger.Warnf("websocket ping error: %v", err) - cancelRun() - return - } - - // dont use `defer` here because we want to observe the duration of the ping - timer.ObserveDuration() - - metricCloudConnectionTotalPingCount.Inc() - metricCloudConnectionLastPingTimestamp.SetToCurrentTime() - } - }() - - // create a channel to receive the disconnect event, once received, we cancelRun - cloudDisconnectChan = make(chan error) - defer func() { - close(cloudDisconnectChan) - cloudDisconnectChan = nil - }() - go func() { - for err := range cloudDisconnectChan { - if err == nil { - continue - } - cloudLogger.Infof("disconnecting from cloud due to: %v", err) - cancelRun() - } - }() - - for { - typ, msg, err := c.Read(runCtx) - if err != nil { - return err - } - if typ != websocket.MessageText { - // ignore non-text messages - continue - } - var req WebRTCSessionRequest - err = json.Unmarshal(msg, &req) - if err != nil { - cloudLogger.Warnf("unable to parse ws message: %v", string(msg)) - continue - } - - cloudLogger.Infof("new session request: %v", req.OidcGoogle) - cloudLogger.Tracef("session request info: %v", req) - - metricCloudConnectionSessionRequestCount.Inc() - metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime() - err = handleSessionRequest(runCtx, c, req) - if err != nil { - cloudLogger.Infof("error starting new session: %v", err) - continue - } - } + // we don't have a source for the cloud connection + return handleWebRTCSignalWsMessages(c, true, "") } -func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error { - timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { - metricCloudConnectionLastSessionRequestDuration.Set(v) - metricCloudConnectionSessionRequestDuration.Observe(v) - })) - defer timer.ObserveDuration() - +func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error { oidcCtx, cancelOIDC := context.WithTimeout(ctx, CloudOidcRequestTimeout) defer cancelOIDC() provider, err := oidc.NewProvider(oidcCtx, "https://accounts.google.com") @@ -379,10 +317,35 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess return fmt.Errorf("google identity mismatch") } + return nil +} + +func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest, isCloudConnection bool, source string) error { + var sourceType string + if isCloudConnection { + sourceType = "cloud" + } else { + sourceType = "local" + } + + timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { + metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(v) + metricConnectionSessionRequestDuration.WithLabelValues(sourceType, source).Observe(v) + })) + defer timer.ObserveDuration() + + // If the message is from the cloud, we need to authenticate the session. + if isCloudConnection { + if err := authenticateSession(ctx, c, req); err != nil { + return err + } + } + session, err := newSession(SessionConfig{ - ICEServers: req.ICEServers, + ws: c, + IsCloud: isCloudConnection, LocalIP: req.IP, - IsCloud: true, + ICEServers: req.ICEServers, }) if err != nil { _ = wsjson.Write(context.Background(), c, gin.H{"error": err}) @@ -406,14 +369,14 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess cloudLogger.Info("new session accepted") cloudLogger.Tracef("new session accepted: %v", session) currentSession = session - _ = wsjson.Write(context.Background(), c, gin.H{"sd": sd}) + _ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd}) return nil } func RunWebsocketClient() { for { // reset the metrics when we start the websocket client. - cloudResetMetrics(false) + wsResetMetrics(false, "cloud", "") // If the cloud token is not set, we don't need to run the websocket client. if config.CloudToken == "" { diff --git a/log.go b/log.go index 7718a28..0d36c0d 100644 --- a/log.go +++ b/log.go @@ -6,3 +6,4 @@ import "github.com/pion/logging" // ref: https://github.com/pion/webrtc/wiki/Debugging-WebRTC var logger = logging.NewDefaultLoggerFactory().NewLogger("jetkvm") var cloudLogger = logging.NewDefaultLoggerFactory().NewLogger("cloud") +var websocketLogger = logging.NewDefaultLoggerFactory().NewLogger("websocket") diff --git a/ui/package-lock.json b/ui/package-lock.json index e9caa20..ebce148 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -30,6 +30,7 @@ "react-icons": "^5.4.0", "react-router-dom": "^6.22.3", "react-simple-keyboard": "^3.7.112", + "react-use-websocket": "^4.13.0", "react-xtermjs": "^1.0.9", "recharts": "^2.15.0", "tailwind-merge": "^2.5.5", @@ -5180,6 +5181,11 @@ "react-dom": ">=16.6.0" } }, + "node_modules/react-use-websocket": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/react-use-websocket/-/react-use-websocket-4.13.0.tgz", + "integrity": "sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw==" + }, "node_modules/react-xtermjs": { "version": "1.0.9", "resolved": "https://registry.npmjs.org/react-xtermjs/-/react-xtermjs-1.0.9.tgz", diff --git a/ui/package.json b/ui/package.json index f8f1c7a..a248616 100644 --- a/ui/package.json +++ b/ui/package.json @@ -40,6 +40,7 @@ "react-icons": "^5.4.0", "react-router-dom": "^6.22.3", "react-simple-keyboard": "^3.7.112", + "react-use-websocket": "^4.13.0", "react-xtermjs": "^1.0.9", "recharts": "^2.15.0", "tailwind-merge": "^2.5.5", diff --git a/ui/src/components/Header.tsx b/ui/src/components/Header.tsx index 03a907e..19e9652 100644 --- a/ui/src/components/Header.tsx +++ b/ui/src/components/Header.tsx @@ -36,7 +36,7 @@ export default function DashboardNavbar({ picture, kvmName, }: NavbarProps) { - const peerConnection = useRTCStore(state => state.peerConnection); + const peerConnectionState = useRTCStore(state => state.peerConnectionState); const setUser = useUserStore(state => state.setUser); const navigate = useNavigate(); const onLogout = useCallback(async () => { @@ -82,14 +82,14 @@ export default function DashboardNavbar({
diff --git a/ui/src/components/VideoOverlay.tsx b/ui/src/components/VideoOverlay.tsx index 0620af4..e34cf10 100644 --- a/ui/src/components/VideoOverlay.tsx +++ b/ui/src/components/VideoOverlay.tsx @@ -6,7 +6,7 @@ import { LuPlay } from "react-icons/lu"; import { Button, LinkButton } from "@components/Button"; import LoadingSpinner from "@components/LoadingSpinner"; -import { GridCard } from "@components/Card"; +import Card, { GridCard } from "@components/Card"; interface OverlayContentProps { children: React.ReactNode; @@ -94,7 +94,7 @@ interface ConnectionErrorOverlayProps { setupPeerConnection: () => Promise; } -export function ConnectionErrorOverlay({ +export function ConnectionFailedOverlay({ show, setupPeerConnection, }: ConnectionErrorOverlayProps) { @@ -151,6 +151,60 @@ export function ConnectionErrorOverlay({ ); } +interface PeerConnectionDisconnectedOverlay { + show: boolean; +} + +export function PeerConnectionDisconnectedOverlay({ + show, +}: PeerConnectionDisconnectedOverlay) { + return ( + + {show && ( + + +
+ +
+
+
+

Connection Issue Detected

+
    +
  • Verify that the device is powered on and properly connected
  • +
  • Check all cable connections for any loose or damaged wires
  • +
  • Ensure your network connection is stable and active
  • +
  • Try restarting both the device and your computer
  • +
+
+
+ +
+ +

+ Retrying connection... +

+
+
+
+
+
+
+
+
+ )} +
+ ); +} + interface HDMIErrorOverlayProps { show: boolean; hdmiState: string; diff --git a/ui/src/components/WebRTCVideo.tsx b/ui/src/components/WebRTCVideo.tsx index 5d8fb55..99c0191 100644 --- a/ui/src/components/WebRTCVideo.tsx +++ b/ui/src/components/WebRTCVideo.tsx @@ -380,7 +380,7 @@ export default function WebRTCVideo() { (mediaStream: MediaStream) => { if (!videoElm.current) return; const videoElmRefValue = videoElm.current; - console.log("Adding stream to video element", videoElmRefValue); + // console.log("Adding stream to video element", videoElmRefValue); videoElmRefValue.srcObject = mediaStream; updateVideoSizeStore(videoElmRefValue); }, @@ -396,7 +396,7 @@ export default function WebRTCVideo() { peerConnection.addEventListener( "track", (e: RTCTrackEvent) => { - console.log("Adding stream to video element"); + // console.log("Adding stream to video element"); addStreamToVideoElm(e.streams[0]); }, { signal }, diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx index d2662fc..fef1764 100644 --- a/ui/src/routes/devices.$id.tsx +++ b/ui/src/routes/devices.$id.tsx @@ -1,4 +1,4 @@ -import { useCallback, useEffect, useRef, useState } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { LoaderFunctionArgs, Outlet, @@ -14,6 +14,7 @@ import { import { useInterval } from "usehooks-ts"; import FocusTrap from "focus-trap-react"; import { motion, AnimatePresence } from "framer-motion"; +import useWebSocket from "react-use-websocket"; import { cx } from "@/cva.config"; import { @@ -43,15 +44,16 @@ import UpdateInProgressStatusCard from "../components/UpdateInProgressStatusCard import api from "../api"; import Modal from "../components/Modal"; import { useDeviceUiNavigation } from "../hooks/useAppNavigation"; +import { + ConnectionFailedOverlay, + LoadingConnectionOverlay, + PeerConnectionDisconnectedOverlay, +} from "../components/VideoOverlay"; import { FeatureFlagProvider } from "../providers/FeatureFlagProvider"; import notifications from "../notifications"; -import { - ConnectionErrorOverlay, - LoadingConnectionOverlay, -} from "../components/VideoOverlay"; -import { SystemVersionInfo } from "./devices.$id.settings.general.update"; import { DeviceStatus } from "./welcome-local"; +import { SystemVersionInfo } from "./devices.$id.settings.general.update"; interface LocalLoaderResp { authMode: "password" | "noPassword" | null; @@ -117,7 +119,6 @@ const loader = async ({ params }: LoaderFunctionArgs) => { export default function KvmIdRoute() { const loaderResp = useLoaderData() as LocalLoaderResp | CloudLoaderResp; - // Depending on the mode, we set the appropriate variables const user = "user" in loaderResp ? loaderResp.user : null; const deviceName = "deviceName" in loaderResp ? loaderResp.deviceName : null; @@ -130,6 +131,8 @@ export default function KvmIdRoute() { const setIsTurnServerInUse = useRTCStore(state => state.setTurnServerInUse); const peerConnection = useRTCStore(state => state.peerConnection); + const setPeerConnectionState = useRTCStore(state => state.setPeerConnectionState); + const peerConnectionState = useRTCStore(state => state.peerConnectionState); const setMediaMediaStream = useRTCStore(state => state.setMediaStream); const setPeerConnection = useRTCStore(state => state.setPeerConnection); const setDiskChannel = useRTCStore(state => state.setDiskChannel); @@ -137,23 +140,28 @@ export default function KvmIdRoute() { const setTransceiver = useRTCStore(state => state.setTransceiver); const location = useLocation(); + const isLegacySignalingEnabled = useRef(false); + const [connectionFailed, setConnectionFailed] = useState(false); const navigate = useNavigate(); const { otaState, setOtaState, setModalView } = useUpdateStore(); const [loadingMessage, setLoadingMessage] = useState("Connecting to device..."); - const closePeerConnection = useCallback( - function closePeerConnection() { + const cleanupAndStopReconnecting = useCallback( + function cleanupAndStopReconnecting() { console.log("Closing peer connection"); setConnectionFailed(true); + if (peerConnection) { + setPeerConnectionState(peerConnection.connectionState); + } connectionFailedRef.current = true; peerConnection?.close(); signalingAttempts.current = 0; }, - [peerConnection], + [peerConnection, setPeerConnectionState], ); // We need to track connectionFailed in a ref to avoid stale closure issues @@ -171,95 +179,233 @@ export default function KvmIdRoute() { }, [connectionFailed]); const signalingAttempts = useRef(0); - const syncRemoteSessionDescription = useCallback( - async function syncRemoteSessionDescription(pc: RTCPeerConnection) { + const setRemoteSessionDescription = useCallback( + async function setRemoteSessionDescription( + pc: RTCPeerConnection, + remoteDescription: RTCSessionDescriptionInit, + ) { + setLoadingMessage("Setting remote description"); + try { - if (!pc) return; - - const sd = btoa(JSON.stringify(pc.localDescription)); - - const sessionUrl = isOnDevice - ? `${DEVICE_API}/webrtc/session` - : `${CLOUD_API}/webrtc/session`; - - console.log("Trying to get remote session description"); - setLoadingMessage( - `Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`, - ); - const res = await api.POST(sessionUrl, { - sd, - // When on device, we don't need to specify the device id, as it's already known - ...(isOnDevice ? {} : { id: params.id }), - }); - - const json = await res.json(); - if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login"); - if (!res.ok) { - console.error("Error getting SDP", { status: res.status, json }); - throw new Error("Error getting SDP"); - } - - console.log("Successfully got Remote Session Description. Setting."); - setLoadingMessage("Setting remote session description..."); - - const decodedSd = atob(json.sd); - const parsedSd = JSON.parse(decodedSd); - pc.setRemoteDescription(new RTCSessionDescription(parsedSd)); - - await new Promise((resolve, reject) => { - console.log("Waiting for remote description to be set"); - const maxAttempts = 10; - const interval = 1000; - let attempts = 0; - - const checkInterval = setInterval(() => { - attempts++; - // When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects - if (pc.sctp?.state === "connected") { - console.log("Remote description set"); - clearInterval(checkInterval); - resolve(true); - } else if (attempts >= maxAttempts) { - console.log( - `Failed to get remote description after ${maxAttempts} attempts`, - ); - closePeerConnection(); - clearInterval(checkInterval); - reject( - new Error( - `Failed to get remote description after ${maxAttempts} attempts`, - ), - ); - } else { - console.log("Waiting for remote description to be set"); - } - }, interval); - }); + await pc.setRemoteDescription(new RTCSessionDescription(remoteDescription)); + console.log("[setRemoteSessionDescription] Remote description set successfully"); + setLoadingMessage("Establishing secure connection..."); } catch (error) { - console.error("Error getting SDP", { error }); - console.log("Connection failed", connectionFailedRef.current); - if (connectionFailedRef.current) return; - if (signalingAttempts.current < 5) { - signalingAttempts.current++; - await new Promise(resolve => setTimeout(resolve, 500)); - console.log("Attempting to get SDP again", signalingAttempts.current); - syncRemoteSessionDescription(pc); - } else { - closePeerConnection(); - } + console.error( + "[setRemoteSessionDescription] Failed to set remote description:", + error, + ); + cleanupAndStopReconnecting(); + return; } + + // Replace the interval-based check with a more reliable approach + let attempts = 0; + const checkInterval = setInterval(() => { + attempts++; + + // When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects + if (pc.sctp?.state === "connected") { + console.log("[setRemoteSessionDescription] Remote description set"); + clearInterval(checkInterval); + setLoadingMessage("Connection established"); + } else if (attempts >= 10) { + console.log( + "[setRemoteSessionDescription] Failed to establish connection after 10 attempts", + { + connectionState: pc.connectionState, + iceConnectionState: pc.iceConnectionState, + }, + ); + cleanupAndStopReconnecting(); + clearInterval(checkInterval); + } else { + console.log("[setRemoteSessionDescription] Waiting for connection, state:", { + connectionState: pc.connectionState, + iceConnectionState: pc.iceConnectionState, + }); + } + }, 1000); }, - [closePeerConnection, navigate, params.id], + [cleanupAndStopReconnecting], + ); + + const ignoreOffer = useRef(false); + const isSettingRemoteAnswerPending = useRef(false); + const makingOffer = useRef(false); + + const wsProtocol = window.location.protocol === "https:" ? "wss:" : "ws:"; + + const { sendMessage, getWebSocket } = useWebSocket( + isOnDevice + ? `${wsProtocol}//${window.location.host}/webrtc/signaling/client` + : `${CLOUD_API.replace("http", "ws")}/webrtc/signaling/client?id=${params.id}`, + { + heartbeat: true, + retryOnError: true, + reconnectAttempts: 5, + reconnectInterval: 1000, + onReconnectStop: () => { + console.log("Reconnect stopped"); + cleanupAndStopReconnecting(); + }, + + shouldReconnect(event) { + console.log("[Websocket] shouldReconnect", event); + // TODO: Why true? + return true; + }, + + onClose(event) { + console.log("[Websocket] onClose", event); + // We don't want to close everything down, we wait for the reconnect to stop instead + }, + + onError(event) { + console.log("[Websocket] onError", event); + // We don't want to close everything down, we wait for the reconnect to stop instead + }, + onOpen() { + console.log("[Websocket] onOpen"); + }, + + onMessage: message => { + if (message.data === "pong") return; + + /* + Currently the signaling process is as follows: + After open, the other side will send a `device-metadata` message with the device version + If the device version is not set, we can assume the device is using the legacy signaling + Otherwise, we can assume the device is using the new signaling + + If the device is using the legacy signaling, we close the websocket connection + and use the legacy HTTPSignaling function to get the remote session description + + If the device is using the new signaling, we don't need to do anything special, but continue to use the websocket connection + to chat with the other peer about the connection + */ + + const parsedMessage = JSON.parse(message.data); + if (parsedMessage.type === "device-metadata") { + const { deviceVersion } = parsedMessage.data; + console.log("[Websocket] Received device-metadata message"); + console.log("[Websocket] Device version", deviceVersion); + // If the device version is not set, we can assume the device is using the legacy signaling + if (!deviceVersion) { + console.log("[Websocket] Device is using legacy signaling"); + + // Now we don't need the websocket connection anymore, as we've established that we need to use the legacy signaling + // which does everything over HTTP(at least from the perspective of the client) + isLegacySignalingEnabled.current = true; + getWebSocket()?.close(); + } else { + console.log("[Websocket] Device is using new signaling"); + isLegacySignalingEnabled.current = false; + } + setupPeerConnection(); + } + + if (!peerConnection) return; + if (parsedMessage.type === "answer") { + console.log("[Websocket] Received answer"); + const readyForOffer = + // If we're making an offer, we don't want to accept an answer + !makingOffer && + // If the peer connection is stable or we're setting the remote answer pending, we're ready for an offer + (peerConnection?.signalingState === "stable" || + isSettingRemoteAnswerPending.current); + + // If we're not ready for an offer, we don't want to accept an offer + ignoreOffer.current = parsedMessage.type === "offer" && !readyForOffer; + if (ignoreOffer.current) return; + + // Set so we don't accept an answer while we're setting the remote description + isSettingRemoteAnswerPending.current = parsedMessage.type === "answer"; + console.log( + "[Websocket] Setting remote answer pending", + isSettingRemoteAnswerPending.current, + ); + + const sd = atob(parsedMessage.data); + const remoteSessionDescription = JSON.parse(sd); + + setRemoteSessionDescription( + peerConnection, + new RTCSessionDescription(remoteSessionDescription), + ); + + // Reset the remote answer pending flag + isSettingRemoteAnswerPending.current = false; + } else if (parsedMessage.type === "new-ice-candidate") { + console.log("[Websocket] Received new-ice-candidate"); + const candidate = parsedMessage.data; + peerConnection.addIceCandidate(candidate); + } + }, + }, + + // Don't even retry once we declare failure + !connectionFailed && isLegacySignalingEnabled.current === false, + ); + + const sendWebRTCSignal = useCallback( + (type: string, data: unknown) => { + // Second argument tells the library not to queue the message, and send it once the connection is established again. + // We have event handlers that handle the connection set up, so we don't need to queue the message. + sendMessage(JSON.stringify({ type, data }), false); + }, + [sendMessage], + ); + + const legacyHTTPSignaling = useCallback( + async (pc: RTCPeerConnection) => { + const sd = btoa(JSON.stringify(pc.localDescription)); + + // Legacy mode == UI in cloud with updated code connecting to older device version. + // In device mode, old devices wont server this JS, and on newer devices legacy mode wont be enabled + const sessionUrl = `${CLOUD_API}/webrtc/session`; + + console.log("Trying to get remote session description"); + setLoadingMessage( + `Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`, + ); + const res = await api.POST(sessionUrl, { + sd, + // When on device, we don't need to specify the device id, as it's already known + ...(isOnDevice ? {} : { id: params.id }), + }); + + const json = await res.json(); + if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login"); + if (!res.ok) { + console.error("Error getting SDP", { status: res.status, json }); + cleanupAndStopReconnecting(); + return; + } + + console.log("Successfully got Remote Session Description. Setting."); + setLoadingMessage("Setting remote session description..."); + + const decodedSd = atob(json.sd); + const parsedSd = JSON.parse(decodedSd); + setRemoteSessionDescription(pc, new RTCSessionDescription(parsedSd)); + }, + [cleanupAndStopReconnecting, navigate, params.id, setRemoteSessionDescription], ); const setupPeerConnection = useCallback(async () => { - console.log("Setting up peer connection"); + console.log("[setupPeerConnection] Setting up peer connection"); setConnectionFailed(false); setLoadingMessage("Connecting to device..."); + if (peerConnection?.signalingState === "stable") { + console.log("[setupPeerConnection] Peer connection already established"); + return; + } + let pc: RTCPeerConnection; try { - console.log("Creating peer connection"); + console.log("[setupPeerConnection] Creating peer connection"); setLoadingMessage("Creating peer connection..."); pc = new RTCPeerConnection({ // We only use STUN or TURN servers if we're in the cloud @@ -267,30 +413,65 @@ export default function KvmIdRoute() { ? { iceServers: [iceConfig?.iceServers] } : {}), }); - console.log("Peer connection created", pc); - setLoadingMessage("Peer connection created"); + + setPeerConnectionState(pc.connectionState); + console.log("[setupPeerConnection] Peer connection created", pc); + setLoadingMessage("Setting up connection to device..."); } catch (e) { - console.error(`Error creating peer connection: ${e}`); + console.error(`[setupPeerConnection] Error creating peer connection: ${e}`); setTimeout(() => { - closePeerConnection(); + cleanupAndStopReconnecting(); }, 1000); return; } // Set up event listeners and data channels pc.onconnectionstatechange = () => { - console.log("Connection state changed", pc.connectionState); + console.log("[setupPeerConnection] Connection state changed", pc.connectionState); + setPeerConnectionState(pc.connectionState); + }; + + pc.onnegotiationneeded = async () => { + try { + console.log("[setupPeerConnection] Creating offer"); + makingOffer.current = true; + + const offer = await pc.createOffer(); + await pc.setLocalDescription(offer); + const sd = btoa(JSON.stringify(pc.localDescription)); + const isNewSignalingEnabled = isLegacySignalingEnabled.current === false; + if (isNewSignalingEnabled) { + sendWebRTCSignal("offer", { sd: sd }); + } else { + console.log("Legacy signanling. Waiting for ICE Gathering to complete..."); + } + } catch (e) { + console.error( + `[setupPeerConnection] Error creating offer: ${e}`, + new Date().toISOString(), + ); + cleanupAndStopReconnecting(); + } finally { + makingOffer.current = false; + } + }; + + pc.onicecandidate = async ({ candidate }) => { + if (!candidate) return; + if (candidate.candidate === "") return; + sendWebRTCSignal("new-ice-candidate", candidate); }; pc.onicegatheringstatechange = event => { const pc = event.currentTarget as RTCPeerConnection; - console.log("ICE Gathering State Changed", pc.iceGatheringState); if (pc.iceGatheringState === "complete") { console.log("ICE Gathering completed"); setLoadingMessage("ICE Gathering completed"); - // We can now start the https/ws connection to get the remote session description from the KVM device - syncRemoteSessionDescription(pc); + if (isLegacySignalingEnabled.current) { + // We can now start the https/ws connection to get the remote session description from the KVM device + legacyHTTPSignaling(pc); + } } else if (pc.iceGatheringState === "gathering") { console.log("ICE Gathering Started"); setLoadingMessage("Gathering ICE candidates..."); @@ -314,31 +495,26 @@ export default function KvmIdRoute() { }; setPeerConnection(pc); - - try { - const offer = await pc.createOffer(); - await pc.setLocalDescription(offer); - } catch (e) { - console.error(`Error creating offer: ${e}`, new Date().toISOString()); - closePeerConnection(); - } }, [ - closePeerConnection, + cleanupAndStopReconnecting, iceConfig?.iceServers, + legacyHTTPSignaling, + peerConnection?.signalingState, + sendWebRTCSignal, setDiskChannel, setMediaMediaStream, setPeerConnection, + setPeerConnectionState, setRpcDataChannel, setTransceiver, - syncRemoteSessionDescription, ]); - // On boot, if the connection state is undefined, we connect to the WebRTC useEffect(() => { - if (peerConnection?.connectionState === undefined) { - setupPeerConnection(); + if (peerConnectionState === "failed") { + console.log("Connection failed, closing peer connection"); + cleanupAndStopReconnecting(); } - }, [setupPeerConnection, peerConnection?.connectionState]); + }, [peerConnectionState, cleanupAndStopReconnecting]); // Cleanup effect const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats); @@ -363,7 +539,7 @@ export default function KvmIdRoute() { // TURN server usage detection useEffect(() => { - if (peerConnection?.connectionState !== "connected") return; + if (peerConnectionState !== "connected") return; const { localCandidateStats, remoteCandidateStats } = useRTCStore.getState(); const lastLocalStat = Array.from(localCandidateStats).pop(); @@ -375,7 +551,7 @@ export default function KvmIdRoute() { const remoteCandidateIsUsingTurn = lastRemoteStat[1].candidateType === "relay"; // [0] is the timestamp, which we don't care about here setIsTurnServerInUse(localCandidateIsUsingTurn || remoteCandidateIsUsingTurn); - }, [peerConnection?.connectionState, setIsTurnServerInUse]); + }, [peerConnectionState, setIsTurnServerInUse]); // TURN server usage reporting const isTurnServerInUse = useRTCStore(state => state.isTurnServerInUse); @@ -466,10 +642,6 @@ export default function KvmIdRoute() { }); }, [rpcDataChannel?.readyState, send, setHdmiState]); - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-expect-error - window.send = send; - // When the update is successful, we need to refresh the client javascript and show a success modal useEffect(() => { if (queryParams.get("updateSuccess")) { @@ -506,12 +678,12 @@ export default function KvmIdRoute() { useEffect(() => { if (!peerConnection) return; if (!kvmTerminal) { - console.log('Creating data channel "terminal"'); + // console.log('Creating data channel "terminal"'); setKvmTerminal(peerConnection.createDataChannel("terminal")); } if (!serialConsole) { - console.log('Creating data channel "serial"'); + // console.log('Creating data channel "serial"'); setSerialConsole(peerConnection.createDataChannel("serial")); } }, [kvmTerminal, peerConnection, serialConsole]); @@ -554,6 +726,43 @@ export default function KvmIdRoute() { [send, setScrollSensitivity], ); + const ConnectionStatusElement = useMemo(() => { + const hasConnectionFailed = + connectionFailed || ["failed", "closed"].includes(peerConnectionState || ""); + + const isPeerConnectionLoading = + ["connecting", "new"].includes(peerConnectionState || "") || + peerConnection === null; + + const isDisconnected = peerConnectionState === "disconnected"; + + const isOtherSession = location.pathname.includes("other-session"); + + if (isOtherSession) return null; + if (peerConnectionState === "connected") return null; + if (isDisconnected) { + return ; + } + + if (hasConnectionFailed) + return ( + + ); + + if (isPeerConnectionLoading) { + return ; + } + + return null; + }, [ + connectionFailed, + loadingMessage, + location.pathname, + peerConnection, + peerConnectionState, + setupPeerConnection, + ]); + return ( {!outlet && otaState.updating && ( @@ -593,27 +802,13 @@ export default function KvmIdRoute() { />
-
+
- - + {!!ConnectionStatusElement && ConnectionStatusElement}
- + {peerConnectionState === "connected" && }
diff --git a/web.go b/web.go index 9201e7b..c3f6d8d 100644 --- a/web.go +++ b/web.go @@ -1,6 +1,7 @@ package kvm import ( + "context" "embed" "encoding/json" "fmt" @@ -10,8 +11,12 @@ import ( "strings" "time" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/pion/webrtc/v4" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/crypto/bcrypt" ) @@ -94,7 +99,7 @@ func setupRouter() *gin.Engine { protected := r.Group("/") protected.Use(protectedMiddleware()) { - protected.POST("/webrtc/session", handleWebRTCSession) + protected.GET("/webrtc/signaling/client", handleLocalWebRTCSignal) protected.POST("/cloud/register", handleCloudRegister) protected.GET("/cloud/state", handleCloudState) protected.GET("/device", handleDevice) @@ -121,35 +126,182 @@ func setupRouter() *gin.Engine { // TODO: support multiple sessions? var currentSession *Session -func handleWebRTCSession(c *gin.Context) { - var req WebRTCSessionRequest - - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return +func handleLocalWebRTCSignal(c *gin.Context) { + cloudLogger.Infof("new websocket connection established") + // Create WebSocket options with InsecureSkipVerify to bypass origin check + wsOptions := &websocket.AcceptOptions{ + InsecureSkipVerify: true, // Allow connections from any origin } - session, err := newSession(SessionConfig{}) + wsCon, err := websocket.Accept(c.Writer, c.Request, wsOptions) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err}) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - sd, err := session.ExchangeOffer(req.Sd) + // get the source from the request + source := c.ClientIP() + + // Now use conn for websocket operations + defer wsCon.Close(websocket.StatusNormalClosure, "") + + err = wsjson.Write(context.Background(), wsCon, gin.H{"type": "device-metadata", "data": gin.H{"deviceVersion": builtAppVersion}}) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err}) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - if currentSession != nil { - writeJSONRPCEvent("otherSessionConnected", nil, currentSession) - peerConn := currentSession.peerConnection + + err = handleWebRTCSignalWsMessages(wsCon, false, source) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } +} + +func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, source string) error { + runCtx, cancelRun := context.WithCancel(context.Background()) + defer cancelRun() + + // Add connection tracking to detect reconnections + connectionID := uuid.New().String() + cloudLogger.Infof("new websocket connection established with ID: %s", connectionID) + + // connection type + var sourceType string + if isCloudConnection { + sourceType = "cloud" + } else { + sourceType = "local" + } + + // probably we can use a better logging framework here + logInfof := func(format string, args ...interface{}) { + args = append(args, source, sourceType) + websocketLogger.Infof(format+", source: %s, sourceType: %s", args...) + } + logWarnf := func(format string, args ...interface{}) { + args = append(args, source, sourceType) + websocketLogger.Warnf(format+", source: %s, sourceType: %s", args...) + } + logTracef := func(format string, args ...interface{}) { + args = append(args, source, sourceType) + websocketLogger.Tracef(format+", source: %s, sourceType: %s", args...) + } + + go func() { + for { + time.Sleep(WebsocketPingInterval) + + // set the timer for the ping duration + timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { + metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(v) + metricConnectionPingDuration.WithLabelValues(sourceType, source).Observe(v) + })) + + logInfof("pinging websocket") + err := wsCon.Ping(runCtx) + + if err != nil { + logWarnf("websocket ping error: %v", err) + cancelRun() + return + } + + // dont use `defer` here because we want to observe the duration of the ping + timer.ObserveDuration() + + metricConnectionTotalPingCount.WithLabelValues(sourceType, source).Inc() + metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime() + } + }() + + if isCloudConnection { + // create a channel to receive the disconnect event, once received, we cancelRun + cloudDisconnectChan = make(chan error) + defer func() { + close(cloudDisconnectChan) + cloudDisconnectChan = nil + }() go func() { - time.Sleep(1 * time.Second) - _ = peerConn.Close() + for err := range cloudDisconnectChan { + if err == nil { + continue + } + cloudLogger.Infof("disconnecting from cloud due to: %v", err) + cancelRun() + } }() } - currentSession = session - c.JSON(http.StatusOK, gin.H{"sd": sd}) + + for { + typ, msg, err := wsCon.Read(runCtx) + if err != nil { + logWarnf("websocket read error: %v", err) + return err + } + if typ != websocket.MessageText { + // ignore non-text messages + continue + } + + var message struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + } + + err = json.Unmarshal(msg, &message) + if err != nil { + logWarnf("unable to parse ws message: %v", err) + continue + } + + if message.Type == "offer" { + logInfof("new session request received") + var req WebRTCSessionRequest + err = json.Unmarshal(message.Data, &req) + if err != nil { + logWarnf("unable to parse session request data: %v", err) + continue + } + + logInfof("new session request: %v", req.OidcGoogle) + logTracef("session request info: %v", req) + + metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc() + metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime() + err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source) + if err != nil { + logWarnf("error starting new session: %v", err) + continue + } + } else if message.Type == "new-ice-candidate" { + logInfof("The client sent us a new ICE candidate: %v", string(message.Data)) + var candidate webrtc.ICECandidateInit + + // Attempt to unmarshal as a ICECandidateInit + if err := json.Unmarshal(message.Data, &candidate); err != nil { + logWarnf("unable to parse incoming ICE candidate data: %v", string(message.Data)) + continue + } + + if candidate.Candidate == "" { + logWarnf("empty incoming ICE candidate, skipping") + continue + } + + logInfof("unmarshalled incoming ICE candidate: %v", candidate) + + if currentSession == nil { + logInfof("no current session, skipping incoming ICE candidate") + continue + } + + logInfof("adding incoming ICE candidate to current session: %v", candidate) + if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil { + logWarnf("failed to add incoming ICE candidate to our peer connection: %v", err) + } + } + } } func handleLogin(c *gin.Context) { diff --git a/webrtc.go b/webrtc.go index 12d4f95..a047ecc 100644 --- a/webrtc.go +++ b/webrtc.go @@ -1,11 +1,15 @@ package kvm import ( + "context" "encoding/base64" "encoding/json" "net" "strings" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/gin-gonic/gin" "github.com/pion/webrtc/v4" ) @@ -23,6 +27,7 @@ type SessionConfig struct { ICEServers []string LocalIP string IsCloud bool + ws *websocket.Conn } func (s *Session) ExchangeOffer(offerStr string) (string, error) { @@ -46,19 +51,11 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) { return "", err } - // Create channel that is blocked until ICE Gathering is complete - gatherComplete := webrtc.GatheringCompletePromise(s.peerConnection) - // Sets the LocalDescription, and starts our UDP listeners if err = s.peerConnection.SetLocalDescription(answer); err != nil { return "", err } - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - <-gatherComplete - localDescription, err := json.Marshal(s.peerConnection.LocalDescription()) if err != nil { return "", err @@ -144,6 +141,16 @@ func newSession(config SessionConfig) (*Session, error) { }() var isConnected bool + peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) { + logger.Infof("Our WebRTC peerConnection has a new ICE candidate: %v", candidate) + if candidate != nil { + err := wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()}) + if err != nil { + logger.Errorf("failed to write new-ice-candidate to WebRTC signaling channel: %v", err) + } + } + }) + peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { logger.Infof("Connection State has changed %s", connectionState) if connectionState == webrtc.ICEConnectionStateConnected {