diff --git a/cloud.go b/cloud.go
index f91085a..7ad8b75 100644
--- a/cloud.go
+++ b/cloud.go
@@ -35,8 +35,8 @@ const (
// CloudOidcRequestTimeout is the timeout for OIDC token verification requests
// should be lower than the websocket response timeout set in cloud-api
CloudOidcRequestTimeout = 10 * time.Second
- // CloudWebSocketPingInterval is the interval at which the websocket client sends ping messages to the cloud
- CloudWebSocketPingInterval = 15 * time.Second
+ // WebsocketPingInterval is the interval at which the websocket client sends ping messages to the cloud
+ WebsocketPingInterval = 15 * time.Second
)
var (
@@ -52,59 +52,67 @@ var (
Help: "The timestamp when the cloud connection was established",
},
)
- metricCloudConnectionLastPingTimestamp = promauto.NewGauge(
+ metricConnectionLastPingTimestamp = promauto.NewGaugeVec(
prometheus.GaugeOpts{
- Name: "jetkvm_cloud_connection_last_ping_timestamp",
+ Name: "jetkvm_connection_last_ping_timestamp",
Help: "The timestamp when the last ping response was received",
},
+ []string{"type", "source"},
)
- metricCloudConnectionLastPingDuration = promauto.NewGauge(
+ metricConnectionLastPingDuration = promauto.NewGaugeVec(
prometheus.GaugeOpts{
- Name: "jetkvm_cloud_connection_last_ping_duration",
+ Name: "jetkvm_connection_last_ping_duration",
Help: "The duration of the last ping response",
},
+ []string{"type", "source"},
)
- metricCloudConnectionPingDuration = promauto.NewHistogram(
+ metricConnectionPingDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{
- Name: "jetkvm_cloud_connection_ping_duration",
+ Name: "jetkvm_connection_ping_duration",
Help: "The duration of the ping response",
Buckets: []float64{
0.1, 0.5, 1, 10,
},
},
+ []string{"type", "source"},
)
- metricCloudConnectionTotalPingCount = promauto.NewCounter(
+ metricConnectionTotalPingCount = promauto.NewCounterVec(
prometheus.CounterOpts{
- Name: "jetkvm_cloud_connection_total_ping_count",
- Help: "The total number of pings sent to the cloud",
+ Name: "jetkvm_connection_total_ping_count",
+ Help: "The total number of pings sent to the connection",
},
+ []string{"type", "source"},
)
- metricCloudConnectionSessionRequestCount = promauto.NewCounter(
+ metricConnectionSessionRequestCount = promauto.NewCounterVec(
prometheus.CounterOpts{
- Name: "jetkvm_cloud_connection_session_total_request_count",
- Help: "The total number of session requests received from the cloud",
+ Name: "jetkvm_connection_session_total_request_count",
+ Help: "The total number of session requests received",
},
+ []string{"type", "source"},
)
- metricCloudConnectionSessionRequestDuration = promauto.NewHistogram(
+ metricConnectionSessionRequestDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{
- Name: "jetkvm_cloud_connection_session_request_duration",
+ Name: "jetkvm_connection_session_request_duration",
Help: "The duration of session requests",
Buckets: []float64{
0.1, 0.5, 1, 10,
},
},
+ []string{"type", "source"},
)
- metricCloudConnectionLastSessionRequestTimestamp = promauto.NewGauge(
+ metricConnectionLastSessionRequestTimestamp = promauto.NewGaugeVec(
prometheus.GaugeOpts{
- Name: "jetkvm_cloud_connection_last_session_request_timestamp",
+ Name: "jetkvm_connection_last_session_request_timestamp",
Help: "The timestamp of the last session request",
},
+ []string{"type", "source"},
)
- metricCloudConnectionLastSessionRequestDuration = promauto.NewGauge(
+ metricConnectionLastSessionRequestDuration = promauto.NewGaugeVec(
prometheus.GaugeOpts{
- Name: "jetkvm_cloud_connection_last_session_request_duration",
+ Name: "jetkvm_connection_last_session_request_duration",
Help: "The duration of the last session request",
},
+ []string{"type", "source"},
)
metricCloudConnectionFailureCount = promauto.NewCounter(
prometheus.CounterOpts{
@@ -119,12 +127,16 @@ var (
cloudDisconnectLock = &sync.Mutex{}
)
-func cloudResetMetrics(established bool) {
- metricCloudConnectionLastPingTimestamp.Set(-1)
- metricCloudConnectionLastPingDuration.Set(-1)
+func wsResetMetrics(established bool, sourceType string, source string) {
+ metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).Set(-1)
+ metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(-1)
- metricCloudConnectionLastSessionRequestTimestamp.Set(-1)
- metricCloudConnectionLastSessionRequestDuration.Set(-1)
+ metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).Set(-1)
+ metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(-1)
+
+ if sourceType != "cloud" {
+ return
+ }
if established {
metricCloudConnectionEstablishedTimestamp.SetToCurrentTime()
@@ -256,6 +268,7 @@ func runWebsocketClient() error {
header := http.Header{}
header.Set("X-Device-ID", GetDeviceID())
+ header.Set("X-App-Version", builtAppVersion)
header.Set("Authorization", "Bearer "+config.CloudToken)
dialCtx, cancelDial := context.WithTimeout(context.Background(), CloudWebSocketConnectTimeout)
@@ -270,88 +283,13 @@ func runWebsocketClient() error {
cloudLogger.Infof("websocket connected to %s", wsURL)
// set the metrics when we successfully connect to the cloud.
- cloudResetMetrics(true)
+ wsResetMetrics(true, "cloud", "")
- runCtx, cancelRun := context.WithCancel(context.Background())
- defer cancelRun()
- go func() {
- for {
- time.Sleep(CloudWebSocketPingInterval)
-
- // set the timer for the ping duration
- timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
- metricCloudConnectionLastPingDuration.Set(v)
- metricCloudConnectionPingDuration.Observe(v)
- }))
-
- err := c.Ping(runCtx)
-
- if err != nil {
- cloudLogger.Warnf("websocket ping error: %v", err)
- cancelRun()
- return
- }
-
- // dont use `defer` here because we want to observe the duration of the ping
- timer.ObserveDuration()
-
- metricCloudConnectionTotalPingCount.Inc()
- metricCloudConnectionLastPingTimestamp.SetToCurrentTime()
- }
- }()
-
- // create a channel to receive the disconnect event, once received, we cancelRun
- cloudDisconnectChan = make(chan error)
- defer func() {
- close(cloudDisconnectChan)
- cloudDisconnectChan = nil
- }()
- go func() {
- for err := range cloudDisconnectChan {
- if err == nil {
- continue
- }
- cloudLogger.Infof("disconnecting from cloud due to: %v", err)
- cancelRun()
- }
- }()
-
- for {
- typ, msg, err := c.Read(runCtx)
- if err != nil {
- return err
- }
- if typ != websocket.MessageText {
- // ignore non-text messages
- continue
- }
- var req WebRTCSessionRequest
- err = json.Unmarshal(msg, &req)
- if err != nil {
- cloudLogger.Warnf("unable to parse ws message: %v", string(msg))
- continue
- }
-
- cloudLogger.Infof("new session request: %v", req.OidcGoogle)
- cloudLogger.Tracef("session request info: %v", req)
-
- metricCloudConnectionSessionRequestCount.Inc()
- metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime()
- err = handleSessionRequest(runCtx, c, req)
- if err != nil {
- cloudLogger.Infof("error starting new session: %v", err)
- continue
- }
- }
+ // we don't have a source for the cloud connection
+ return handleWebRTCSignalWsMessages(c, true, "")
}
-func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
- timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
- metricCloudConnectionLastSessionRequestDuration.Set(v)
- metricCloudConnectionSessionRequestDuration.Observe(v)
- }))
- defer timer.ObserveDuration()
-
+func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
oidcCtx, cancelOIDC := context.WithTimeout(ctx, CloudOidcRequestTimeout)
defer cancelOIDC()
provider, err := oidc.NewProvider(oidcCtx, "https://accounts.google.com")
@@ -379,10 +317,35 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
return fmt.Errorf("google identity mismatch")
}
+ return nil
+}
+
+func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest, isCloudConnection bool, source string) error {
+ var sourceType string
+ if isCloudConnection {
+ sourceType = "cloud"
+ } else {
+ sourceType = "local"
+ }
+
+ timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
+ metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(v)
+ metricConnectionSessionRequestDuration.WithLabelValues(sourceType, source).Observe(v)
+ }))
+ defer timer.ObserveDuration()
+
+ // If the message is from the cloud, we need to authenticate the session.
+ if isCloudConnection {
+ if err := authenticateSession(ctx, c, req); err != nil {
+ return err
+ }
+ }
+
session, err := newSession(SessionConfig{
- ICEServers: req.ICEServers,
+ ws: c,
+ IsCloud: isCloudConnection,
LocalIP: req.IP,
- IsCloud: true,
+ ICEServers: req.ICEServers,
})
if err != nil {
_ = wsjson.Write(context.Background(), c, gin.H{"error": err})
@@ -406,14 +369,14 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
cloudLogger.Info("new session accepted")
cloudLogger.Tracef("new session accepted: %v", session)
currentSession = session
- _ = wsjson.Write(context.Background(), c, gin.H{"sd": sd})
+ _ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd})
return nil
}
func RunWebsocketClient() {
for {
// reset the metrics when we start the websocket client.
- cloudResetMetrics(false)
+ wsResetMetrics(false, "cloud", "")
// If the cloud token is not set, we don't need to run the websocket client.
if config.CloudToken == "" {
diff --git a/log.go b/log.go
index 7718a28..0d36c0d 100644
--- a/log.go
+++ b/log.go
@@ -6,3 +6,4 @@ import "github.com/pion/logging"
// ref: https://github.com/pion/webrtc/wiki/Debugging-WebRTC
var logger = logging.NewDefaultLoggerFactory().NewLogger("jetkvm")
var cloudLogger = logging.NewDefaultLoggerFactory().NewLogger("cloud")
+var websocketLogger = logging.NewDefaultLoggerFactory().NewLogger("websocket")
diff --git a/ui/package-lock.json b/ui/package-lock.json
index e9caa20..ebce148 100644
--- a/ui/package-lock.json
+++ b/ui/package-lock.json
@@ -30,6 +30,7 @@
"react-icons": "^5.4.0",
"react-router-dom": "^6.22.3",
"react-simple-keyboard": "^3.7.112",
+ "react-use-websocket": "^4.13.0",
"react-xtermjs": "^1.0.9",
"recharts": "^2.15.0",
"tailwind-merge": "^2.5.5",
@@ -5180,6 +5181,11 @@
"react-dom": ">=16.6.0"
}
},
+ "node_modules/react-use-websocket": {
+ "version": "4.13.0",
+ "resolved": "https://registry.npmjs.org/react-use-websocket/-/react-use-websocket-4.13.0.tgz",
+ "integrity": "sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw=="
+ },
"node_modules/react-xtermjs": {
"version": "1.0.9",
"resolved": "https://registry.npmjs.org/react-xtermjs/-/react-xtermjs-1.0.9.tgz",
diff --git a/ui/package.json b/ui/package.json
index f8f1c7a..a248616 100644
--- a/ui/package.json
+++ b/ui/package.json
@@ -40,6 +40,7 @@
"react-icons": "^5.4.0",
"react-router-dom": "^6.22.3",
"react-simple-keyboard": "^3.7.112",
+ "react-use-websocket": "^4.13.0",
"react-xtermjs": "^1.0.9",
"recharts": "^2.15.0",
"tailwind-merge": "^2.5.5",
diff --git a/ui/src/components/Header.tsx b/ui/src/components/Header.tsx
index 03a907e..19e9652 100644
--- a/ui/src/components/Header.tsx
+++ b/ui/src/components/Header.tsx
@@ -36,7 +36,7 @@ export default function DashboardNavbar({
picture,
kvmName,
}: NavbarProps) {
- const peerConnection = useRTCStore(state => state.peerConnection);
+ const peerConnectionState = useRTCStore(state => state.peerConnectionState);
const setUser = useUserStore(state => state.setUser);
const navigate = useNavigate();
const onLogout = useCallback(async () => {
@@ -82,14 +82,14 @@ export default function DashboardNavbar({
diff --git a/ui/src/components/VideoOverlay.tsx b/ui/src/components/VideoOverlay.tsx
index 0620af4..e34cf10 100644
--- a/ui/src/components/VideoOverlay.tsx
+++ b/ui/src/components/VideoOverlay.tsx
@@ -6,7 +6,7 @@ import { LuPlay } from "react-icons/lu";
import { Button, LinkButton } from "@components/Button";
import LoadingSpinner from "@components/LoadingSpinner";
-import { GridCard } from "@components/Card";
+import Card, { GridCard } from "@components/Card";
interface OverlayContentProps {
children: React.ReactNode;
@@ -94,7 +94,7 @@ interface ConnectionErrorOverlayProps {
setupPeerConnection: () => Promise;
}
-export function ConnectionErrorOverlay({
+export function ConnectionFailedOverlay({
show,
setupPeerConnection,
}: ConnectionErrorOverlayProps) {
@@ -151,6 +151,60 @@ export function ConnectionErrorOverlay({
);
}
+interface PeerConnectionDisconnectedOverlay {
+ show: boolean;
+}
+
+export function PeerConnectionDisconnectedOverlay({
+ show,
+}: PeerConnectionDisconnectedOverlay) {
+ return (
+
+ {show && (
+
+
+
+
+
+
+
+
Connection Issue Detected
+
+ - Verify that the device is powered on and properly connected
+ - Check all cable connections for any loose or damaged wires
+ - Ensure your network connection is stable and active
+ - Try restarting both the device and your computer
+
+
+
+
+
+
+
+ Retrying connection...
+
+
+
+
+
+
+
+
+
+ )}
+
+ );
+}
+
interface HDMIErrorOverlayProps {
show: boolean;
hdmiState: string;
diff --git a/ui/src/components/WebRTCVideo.tsx b/ui/src/components/WebRTCVideo.tsx
index 5d8fb55..99c0191 100644
--- a/ui/src/components/WebRTCVideo.tsx
+++ b/ui/src/components/WebRTCVideo.tsx
@@ -380,7 +380,7 @@ export default function WebRTCVideo() {
(mediaStream: MediaStream) => {
if (!videoElm.current) return;
const videoElmRefValue = videoElm.current;
- console.log("Adding stream to video element", videoElmRefValue);
+ // console.log("Adding stream to video element", videoElmRefValue);
videoElmRefValue.srcObject = mediaStream;
updateVideoSizeStore(videoElmRefValue);
},
@@ -396,7 +396,7 @@ export default function WebRTCVideo() {
peerConnection.addEventListener(
"track",
(e: RTCTrackEvent) => {
- console.log("Adding stream to video element");
+ // console.log("Adding stream to video element");
addStreamToVideoElm(e.streams[0]);
},
{ signal },
diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx
index d2662fc..fef1764 100644
--- a/ui/src/routes/devices.$id.tsx
+++ b/ui/src/routes/devices.$id.tsx
@@ -1,4 +1,4 @@
-import { useCallback, useEffect, useRef, useState } from "react";
+import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import {
LoaderFunctionArgs,
Outlet,
@@ -14,6 +14,7 @@ import {
import { useInterval } from "usehooks-ts";
import FocusTrap from "focus-trap-react";
import { motion, AnimatePresence } from "framer-motion";
+import useWebSocket from "react-use-websocket";
import { cx } from "@/cva.config";
import {
@@ -43,15 +44,16 @@ import UpdateInProgressStatusCard from "../components/UpdateInProgressStatusCard
import api from "../api";
import Modal from "../components/Modal";
import { useDeviceUiNavigation } from "../hooks/useAppNavigation";
+import {
+ ConnectionFailedOverlay,
+ LoadingConnectionOverlay,
+ PeerConnectionDisconnectedOverlay,
+} from "../components/VideoOverlay";
import { FeatureFlagProvider } from "../providers/FeatureFlagProvider";
import notifications from "../notifications";
-import {
- ConnectionErrorOverlay,
- LoadingConnectionOverlay,
-} from "../components/VideoOverlay";
-import { SystemVersionInfo } from "./devices.$id.settings.general.update";
import { DeviceStatus } from "./welcome-local";
+import { SystemVersionInfo } from "./devices.$id.settings.general.update";
interface LocalLoaderResp {
authMode: "password" | "noPassword" | null;
@@ -117,7 +119,6 @@ const loader = async ({ params }: LoaderFunctionArgs) => {
export default function KvmIdRoute() {
const loaderResp = useLoaderData() as LocalLoaderResp | CloudLoaderResp;
-
// Depending on the mode, we set the appropriate variables
const user = "user" in loaderResp ? loaderResp.user : null;
const deviceName = "deviceName" in loaderResp ? loaderResp.deviceName : null;
@@ -130,6 +131,8 @@ export default function KvmIdRoute() {
const setIsTurnServerInUse = useRTCStore(state => state.setTurnServerInUse);
const peerConnection = useRTCStore(state => state.peerConnection);
+ const setPeerConnectionState = useRTCStore(state => state.setPeerConnectionState);
+ const peerConnectionState = useRTCStore(state => state.peerConnectionState);
const setMediaMediaStream = useRTCStore(state => state.setMediaStream);
const setPeerConnection = useRTCStore(state => state.setPeerConnection);
const setDiskChannel = useRTCStore(state => state.setDiskChannel);
@@ -137,23 +140,28 @@ export default function KvmIdRoute() {
const setTransceiver = useRTCStore(state => state.setTransceiver);
const location = useLocation();
+ const isLegacySignalingEnabled = useRef(false);
+
const [connectionFailed, setConnectionFailed] = useState(false);
const navigate = useNavigate();
const { otaState, setOtaState, setModalView } = useUpdateStore();
const [loadingMessage, setLoadingMessage] = useState("Connecting to device...");
- const closePeerConnection = useCallback(
- function closePeerConnection() {
+ const cleanupAndStopReconnecting = useCallback(
+ function cleanupAndStopReconnecting() {
console.log("Closing peer connection");
setConnectionFailed(true);
+ if (peerConnection) {
+ setPeerConnectionState(peerConnection.connectionState);
+ }
connectionFailedRef.current = true;
peerConnection?.close();
signalingAttempts.current = 0;
},
- [peerConnection],
+ [peerConnection, setPeerConnectionState],
);
// We need to track connectionFailed in a ref to avoid stale closure issues
@@ -171,95 +179,233 @@ export default function KvmIdRoute() {
}, [connectionFailed]);
const signalingAttempts = useRef(0);
- const syncRemoteSessionDescription = useCallback(
- async function syncRemoteSessionDescription(pc: RTCPeerConnection) {
+ const setRemoteSessionDescription = useCallback(
+ async function setRemoteSessionDescription(
+ pc: RTCPeerConnection,
+ remoteDescription: RTCSessionDescriptionInit,
+ ) {
+ setLoadingMessage("Setting remote description");
+
try {
- if (!pc) return;
-
- const sd = btoa(JSON.stringify(pc.localDescription));
-
- const sessionUrl = isOnDevice
- ? `${DEVICE_API}/webrtc/session`
- : `${CLOUD_API}/webrtc/session`;
-
- console.log("Trying to get remote session description");
- setLoadingMessage(
- `Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`,
- );
- const res = await api.POST(sessionUrl, {
- sd,
- // When on device, we don't need to specify the device id, as it's already known
- ...(isOnDevice ? {} : { id: params.id }),
- });
-
- const json = await res.json();
- if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login");
- if (!res.ok) {
- console.error("Error getting SDP", { status: res.status, json });
- throw new Error("Error getting SDP");
- }
-
- console.log("Successfully got Remote Session Description. Setting.");
- setLoadingMessage("Setting remote session description...");
-
- const decodedSd = atob(json.sd);
- const parsedSd = JSON.parse(decodedSd);
- pc.setRemoteDescription(new RTCSessionDescription(parsedSd));
-
- await new Promise((resolve, reject) => {
- console.log("Waiting for remote description to be set");
- const maxAttempts = 10;
- const interval = 1000;
- let attempts = 0;
-
- const checkInterval = setInterval(() => {
- attempts++;
- // When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects
- if (pc.sctp?.state === "connected") {
- console.log("Remote description set");
- clearInterval(checkInterval);
- resolve(true);
- } else if (attempts >= maxAttempts) {
- console.log(
- `Failed to get remote description after ${maxAttempts} attempts`,
- );
- closePeerConnection();
- clearInterval(checkInterval);
- reject(
- new Error(
- `Failed to get remote description after ${maxAttempts} attempts`,
- ),
- );
- } else {
- console.log("Waiting for remote description to be set");
- }
- }, interval);
- });
+ await pc.setRemoteDescription(new RTCSessionDescription(remoteDescription));
+ console.log("[setRemoteSessionDescription] Remote description set successfully");
+ setLoadingMessage("Establishing secure connection...");
} catch (error) {
- console.error("Error getting SDP", { error });
- console.log("Connection failed", connectionFailedRef.current);
- if (connectionFailedRef.current) return;
- if (signalingAttempts.current < 5) {
- signalingAttempts.current++;
- await new Promise(resolve => setTimeout(resolve, 500));
- console.log("Attempting to get SDP again", signalingAttempts.current);
- syncRemoteSessionDescription(pc);
- } else {
- closePeerConnection();
- }
+ console.error(
+ "[setRemoteSessionDescription] Failed to set remote description:",
+ error,
+ );
+ cleanupAndStopReconnecting();
+ return;
}
+
+ // Replace the interval-based check with a more reliable approach
+ let attempts = 0;
+ const checkInterval = setInterval(() => {
+ attempts++;
+
+ // When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects
+ if (pc.sctp?.state === "connected") {
+ console.log("[setRemoteSessionDescription] Remote description set");
+ clearInterval(checkInterval);
+ setLoadingMessage("Connection established");
+ } else if (attempts >= 10) {
+ console.log(
+ "[setRemoteSessionDescription] Failed to establish connection after 10 attempts",
+ {
+ connectionState: pc.connectionState,
+ iceConnectionState: pc.iceConnectionState,
+ },
+ );
+ cleanupAndStopReconnecting();
+ clearInterval(checkInterval);
+ } else {
+ console.log("[setRemoteSessionDescription] Waiting for connection, state:", {
+ connectionState: pc.connectionState,
+ iceConnectionState: pc.iceConnectionState,
+ });
+ }
+ }, 1000);
},
- [closePeerConnection, navigate, params.id],
+ [cleanupAndStopReconnecting],
+ );
+
+ const ignoreOffer = useRef(false);
+ const isSettingRemoteAnswerPending = useRef(false);
+ const makingOffer = useRef(false);
+
+ const wsProtocol = window.location.protocol === "https:" ? "wss:" : "ws:";
+
+ const { sendMessage, getWebSocket } = useWebSocket(
+ isOnDevice
+ ? `${wsProtocol}//${window.location.host}/webrtc/signaling/client`
+ : `${CLOUD_API.replace("http", "ws")}/webrtc/signaling/client?id=${params.id}`,
+ {
+ heartbeat: true,
+ retryOnError: true,
+ reconnectAttempts: 5,
+ reconnectInterval: 1000,
+ onReconnectStop: () => {
+ console.log("Reconnect stopped");
+ cleanupAndStopReconnecting();
+ },
+
+ shouldReconnect(event) {
+ console.log("[Websocket] shouldReconnect", event);
+ // TODO: Why true?
+ return true;
+ },
+
+ onClose(event) {
+ console.log("[Websocket] onClose", event);
+ // We don't want to close everything down, we wait for the reconnect to stop instead
+ },
+
+ onError(event) {
+ console.log("[Websocket] onError", event);
+ // We don't want to close everything down, we wait for the reconnect to stop instead
+ },
+ onOpen() {
+ console.log("[Websocket] onOpen");
+ },
+
+ onMessage: message => {
+ if (message.data === "pong") return;
+
+ /*
+ Currently the signaling process is as follows:
+ After open, the other side will send a `device-metadata` message with the device version
+ If the device version is not set, we can assume the device is using the legacy signaling
+ Otherwise, we can assume the device is using the new signaling
+
+ If the device is using the legacy signaling, we close the websocket connection
+ and use the legacy HTTPSignaling function to get the remote session description
+
+ If the device is using the new signaling, we don't need to do anything special, but continue to use the websocket connection
+ to chat with the other peer about the connection
+ */
+
+ const parsedMessage = JSON.parse(message.data);
+ if (parsedMessage.type === "device-metadata") {
+ const { deviceVersion } = parsedMessage.data;
+ console.log("[Websocket] Received device-metadata message");
+ console.log("[Websocket] Device version", deviceVersion);
+ // If the device version is not set, we can assume the device is using the legacy signaling
+ if (!deviceVersion) {
+ console.log("[Websocket] Device is using legacy signaling");
+
+ // Now we don't need the websocket connection anymore, as we've established that we need to use the legacy signaling
+ // which does everything over HTTP(at least from the perspective of the client)
+ isLegacySignalingEnabled.current = true;
+ getWebSocket()?.close();
+ } else {
+ console.log("[Websocket] Device is using new signaling");
+ isLegacySignalingEnabled.current = false;
+ }
+ setupPeerConnection();
+ }
+
+ if (!peerConnection) return;
+ if (parsedMessage.type === "answer") {
+ console.log("[Websocket] Received answer");
+ const readyForOffer =
+ // If we're making an offer, we don't want to accept an answer
+ !makingOffer &&
+ // If the peer connection is stable or we're setting the remote answer pending, we're ready for an offer
+ (peerConnection?.signalingState === "stable" ||
+ isSettingRemoteAnswerPending.current);
+
+ // If we're not ready for an offer, we don't want to accept an offer
+ ignoreOffer.current = parsedMessage.type === "offer" && !readyForOffer;
+ if (ignoreOffer.current) return;
+
+ // Set so we don't accept an answer while we're setting the remote description
+ isSettingRemoteAnswerPending.current = parsedMessage.type === "answer";
+ console.log(
+ "[Websocket] Setting remote answer pending",
+ isSettingRemoteAnswerPending.current,
+ );
+
+ const sd = atob(parsedMessage.data);
+ const remoteSessionDescription = JSON.parse(sd);
+
+ setRemoteSessionDescription(
+ peerConnection,
+ new RTCSessionDescription(remoteSessionDescription),
+ );
+
+ // Reset the remote answer pending flag
+ isSettingRemoteAnswerPending.current = false;
+ } else if (parsedMessage.type === "new-ice-candidate") {
+ console.log("[Websocket] Received new-ice-candidate");
+ const candidate = parsedMessage.data;
+ peerConnection.addIceCandidate(candidate);
+ }
+ },
+ },
+
+ // Don't even retry once we declare failure
+ !connectionFailed && isLegacySignalingEnabled.current === false,
+ );
+
+ const sendWebRTCSignal = useCallback(
+ (type: string, data: unknown) => {
+ // Second argument tells the library not to queue the message, and send it once the connection is established again.
+ // We have event handlers that handle the connection set up, so we don't need to queue the message.
+ sendMessage(JSON.stringify({ type, data }), false);
+ },
+ [sendMessage],
+ );
+
+ const legacyHTTPSignaling = useCallback(
+ async (pc: RTCPeerConnection) => {
+ const sd = btoa(JSON.stringify(pc.localDescription));
+
+ // Legacy mode == UI in cloud with updated code connecting to older device version.
+ // In device mode, old devices wont server this JS, and on newer devices legacy mode wont be enabled
+ const sessionUrl = `${CLOUD_API}/webrtc/session`;
+
+ console.log("Trying to get remote session description");
+ setLoadingMessage(
+ `Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`,
+ );
+ const res = await api.POST(sessionUrl, {
+ sd,
+ // When on device, we don't need to specify the device id, as it's already known
+ ...(isOnDevice ? {} : { id: params.id }),
+ });
+
+ const json = await res.json();
+ if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login");
+ if (!res.ok) {
+ console.error("Error getting SDP", { status: res.status, json });
+ cleanupAndStopReconnecting();
+ return;
+ }
+
+ console.log("Successfully got Remote Session Description. Setting.");
+ setLoadingMessage("Setting remote session description...");
+
+ const decodedSd = atob(json.sd);
+ const parsedSd = JSON.parse(decodedSd);
+ setRemoteSessionDescription(pc, new RTCSessionDescription(parsedSd));
+ },
+ [cleanupAndStopReconnecting, navigate, params.id, setRemoteSessionDescription],
);
const setupPeerConnection = useCallback(async () => {
- console.log("Setting up peer connection");
+ console.log("[setupPeerConnection] Setting up peer connection");
setConnectionFailed(false);
setLoadingMessage("Connecting to device...");
+ if (peerConnection?.signalingState === "stable") {
+ console.log("[setupPeerConnection] Peer connection already established");
+ return;
+ }
+
let pc: RTCPeerConnection;
try {
- console.log("Creating peer connection");
+ console.log("[setupPeerConnection] Creating peer connection");
setLoadingMessage("Creating peer connection...");
pc = new RTCPeerConnection({
// We only use STUN or TURN servers if we're in the cloud
@@ -267,30 +413,65 @@ export default function KvmIdRoute() {
? { iceServers: [iceConfig?.iceServers] }
: {}),
});
- console.log("Peer connection created", pc);
- setLoadingMessage("Peer connection created");
+
+ setPeerConnectionState(pc.connectionState);
+ console.log("[setupPeerConnection] Peer connection created", pc);
+ setLoadingMessage("Setting up connection to device...");
} catch (e) {
- console.error(`Error creating peer connection: ${e}`);
+ console.error(`[setupPeerConnection] Error creating peer connection: ${e}`);
setTimeout(() => {
- closePeerConnection();
+ cleanupAndStopReconnecting();
}, 1000);
return;
}
// Set up event listeners and data channels
pc.onconnectionstatechange = () => {
- console.log("Connection state changed", pc.connectionState);
+ console.log("[setupPeerConnection] Connection state changed", pc.connectionState);
+ setPeerConnectionState(pc.connectionState);
+ };
+
+ pc.onnegotiationneeded = async () => {
+ try {
+ console.log("[setupPeerConnection] Creating offer");
+ makingOffer.current = true;
+
+ const offer = await pc.createOffer();
+ await pc.setLocalDescription(offer);
+ const sd = btoa(JSON.stringify(pc.localDescription));
+ const isNewSignalingEnabled = isLegacySignalingEnabled.current === false;
+ if (isNewSignalingEnabled) {
+ sendWebRTCSignal("offer", { sd: sd });
+ } else {
+ console.log("Legacy signanling. Waiting for ICE Gathering to complete...");
+ }
+ } catch (e) {
+ console.error(
+ `[setupPeerConnection] Error creating offer: ${e}`,
+ new Date().toISOString(),
+ );
+ cleanupAndStopReconnecting();
+ } finally {
+ makingOffer.current = false;
+ }
+ };
+
+ pc.onicecandidate = async ({ candidate }) => {
+ if (!candidate) return;
+ if (candidate.candidate === "") return;
+ sendWebRTCSignal("new-ice-candidate", candidate);
};
pc.onicegatheringstatechange = event => {
const pc = event.currentTarget as RTCPeerConnection;
- console.log("ICE Gathering State Changed", pc.iceGatheringState);
if (pc.iceGatheringState === "complete") {
console.log("ICE Gathering completed");
setLoadingMessage("ICE Gathering completed");
- // We can now start the https/ws connection to get the remote session description from the KVM device
- syncRemoteSessionDescription(pc);
+ if (isLegacySignalingEnabled.current) {
+ // We can now start the https/ws connection to get the remote session description from the KVM device
+ legacyHTTPSignaling(pc);
+ }
} else if (pc.iceGatheringState === "gathering") {
console.log("ICE Gathering Started");
setLoadingMessage("Gathering ICE candidates...");
@@ -314,31 +495,26 @@ export default function KvmIdRoute() {
};
setPeerConnection(pc);
-
- try {
- const offer = await pc.createOffer();
- await pc.setLocalDescription(offer);
- } catch (e) {
- console.error(`Error creating offer: ${e}`, new Date().toISOString());
- closePeerConnection();
- }
}, [
- closePeerConnection,
+ cleanupAndStopReconnecting,
iceConfig?.iceServers,
+ legacyHTTPSignaling,
+ peerConnection?.signalingState,
+ sendWebRTCSignal,
setDiskChannel,
setMediaMediaStream,
setPeerConnection,
+ setPeerConnectionState,
setRpcDataChannel,
setTransceiver,
- syncRemoteSessionDescription,
]);
- // On boot, if the connection state is undefined, we connect to the WebRTC
useEffect(() => {
- if (peerConnection?.connectionState === undefined) {
- setupPeerConnection();
+ if (peerConnectionState === "failed") {
+ console.log("Connection failed, closing peer connection");
+ cleanupAndStopReconnecting();
}
- }, [setupPeerConnection, peerConnection?.connectionState]);
+ }, [peerConnectionState, cleanupAndStopReconnecting]);
// Cleanup effect
const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats);
@@ -363,7 +539,7 @@ export default function KvmIdRoute() {
// TURN server usage detection
useEffect(() => {
- if (peerConnection?.connectionState !== "connected") return;
+ if (peerConnectionState !== "connected") return;
const { localCandidateStats, remoteCandidateStats } = useRTCStore.getState();
const lastLocalStat = Array.from(localCandidateStats).pop();
@@ -375,7 +551,7 @@ export default function KvmIdRoute() {
const remoteCandidateIsUsingTurn = lastRemoteStat[1].candidateType === "relay"; // [0] is the timestamp, which we don't care about here
setIsTurnServerInUse(localCandidateIsUsingTurn || remoteCandidateIsUsingTurn);
- }, [peerConnection?.connectionState, setIsTurnServerInUse]);
+ }, [peerConnectionState, setIsTurnServerInUse]);
// TURN server usage reporting
const isTurnServerInUse = useRTCStore(state => state.isTurnServerInUse);
@@ -466,10 +642,6 @@ export default function KvmIdRoute() {
});
}, [rpcDataChannel?.readyState, send, setHdmiState]);
- // eslint-disable-next-line @typescript-eslint/ban-ts-comment
- // @ts-expect-error
- window.send = send;
-
// When the update is successful, we need to refresh the client javascript and show a success modal
useEffect(() => {
if (queryParams.get("updateSuccess")) {
@@ -506,12 +678,12 @@ export default function KvmIdRoute() {
useEffect(() => {
if (!peerConnection) return;
if (!kvmTerminal) {
- console.log('Creating data channel "terminal"');
+ // console.log('Creating data channel "terminal"');
setKvmTerminal(peerConnection.createDataChannel("terminal"));
}
if (!serialConsole) {
- console.log('Creating data channel "serial"');
+ // console.log('Creating data channel "serial"');
setSerialConsole(peerConnection.createDataChannel("serial"));
}
}, [kvmTerminal, peerConnection, serialConsole]);
@@ -554,6 +726,43 @@ export default function KvmIdRoute() {
[send, setScrollSensitivity],
);
+ const ConnectionStatusElement = useMemo(() => {
+ const hasConnectionFailed =
+ connectionFailed || ["failed", "closed"].includes(peerConnectionState || "");
+
+ const isPeerConnectionLoading =
+ ["connecting", "new"].includes(peerConnectionState || "") ||
+ peerConnection === null;
+
+ const isDisconnected = peerConnectionState === "disconnected";
+
+ const isOtherSession = location.pathname.includes("other-session");
+
+ if (isOtherSession) return null;
+ if (peerConnectionState === "connected") return null;
+ if (isDisconnected) {
+ return ;
+ }
+
+ if (hasConnectionFailed)
+ return (
+
+ );
+
+ if (isPeerConnectionLoading) {
+ return ;
+ }
+
+ return null;
+ }, [
+ connectionFailed,
+ loadingMessage,
+ location.pathname,
+ peerConnection,
+ peerConnectionState,
+ setupPeerConnection,
+ ]);
+
return (
{!outlet && otaState.updating && (
@@ -593,27 +802,13 @@ export default function KvmIdRoute() {
/>
-
+
-
-
+ {!!ConnectionStatusElement && ConnectionStatusElement}
-
+ {peerConnectionState === "connected" &&
}
diff --git a/web.go b/web.go
index 9201e7b..c3f6d8d 100644
--- a/web.go
+++ b/web.go
@@ -1,6 +1,7 @@
package kvm
import (
+ "context"
"embed"
"encoding/json"
"fmt"
@@ -10,8 +11,12 @@ import (
"strings"
"time"
+ "github.com/coder/websocket"
+ "github.com/coder/websocket/wsjson"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
+ "github.com/pion/webrtc/v4"
+ "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/crypto/bcrypt"
)
@@ -94,7 +99,7 @@ func setupRouter() *gin.Engine {
protected := r.Group("/")
protected.Use(protectedMiddleware())
{
- protected.POST("/webrtc/session", handleWebRTCSession)
+ protected.GET("/webrtc/signaling/client", handleLocalWebRTCSignal)
protected.POST("/cloud/register", handleCloudRegister)
protected.GET("/cloud/state", handleCloudState)
protected.GET("/device", handleDevice)
@@ -121,35 +126,182 @@ func setupRouter() *gin.Engine {
// TODO: support multiple sessions?
var currentSession *Session
-func handleWebRTCSession(c *gin.Context) {
- var req WebRTCSessionRequest
-
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- return
+func handleLocalWebRTCSignal(c *gin.Context) {
+ cloudLogger.Infof("new websocket connection established")
+ // Create WebSocket options with InsecureSkipVerify to bypass origin check
+ wsOptions := &websocket.AcceptOptions{
+ InsecureSkipVerify: true, // Allow connections from any origin
}
- session, err := newSession(SessionConfig{})
+ wsCon, err := websocket.Accept(c.Writer, c.Request, wsOptions)
if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err})
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
- sd, err := session.ExchangeOffer(req.Sd)
+ // get the source from the request
+ source := c.ClientIP()
+
+ // Now use conn for websocket operations
+ defer wsCon.Close(websocket.StatusNormalClosure, "")
+
+ err = wsjson.Write(context.Background(), wsCon, gin.H{"type": "device-metadata", "data": gin.H{"deviceVersion": builtAppVersion}})
if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err})
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
- if currentSession != nil {
- writeJSONRPCEvent("otherSessionConnected", nil, currentSession)
- peerConn := currentSession.peerConnection
+
+ err = handleWebRTCSignalWsMessages(wsCon, false, source)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+}
+
+func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, source string) error {
+ runCtx, cancelRun := context.WithCancel(context.Background())
+ defer cancelRun()
+
+ // Add connection tracking to detect reconnections
+ connectionID := uuid.New().String()
+ cloudLogger.Infof("new websocket connection established with ID: %s", connectionID)
+
+ // connection type
+ var sourceType string
+ if isCloudConnection {
+ sourceType = "cloud"
+ } else {
+ sourceType = "local"
+ }
+
+ // probably we can use a better logging framework here
+ logInfof := func(format string, args ...interface{}) {
+ args = append(args, source, sourceType)
+ websocketLogger.Infof(format+", source: %s, sourceType: %s", args...)
+ }
+ logWarnf := func(format string, args ...interface{}) {
+ args = append(args, source, sourceType)
+ websocketLogger.Warnf(format+", source: %s, sourceType: %s", args...)
+ }
+ logTracef := func(format string, args ...interface{}) {
+ args = append(args, source, sourceType)
+ websocketLogger.Tracef(format+", source: %s, sourceType: %s", args...)
+ }
+
+ go func() {
+ for {
+ time.Sleep(WebsocketPingInterval)
+
+ // set the timer for the ping duration
+ timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
+ metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(v)
+ metricConnectionPingDuration.WithLabelValues(sourceType, source).Observe(v)
+ }))
+
+ logInfof("pinging websocket")
+ err := wsCon.Ping(runCtx)
+
+ if err != nil {
+ logWarnf("websocket ping error: %v", err)
+ cancelRun()
+ return
+ }
+
+ // dont use `defer` here because we want to observe the duration of the ping
+ timer.ObserveDuration()
+
+ metricConnectionTotalPingCount.WithLabelValues(sourceType, source).Inc()
+ metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
+ }
+ }()
+
+ if isCloudConnection {
+ // create a channel to receive the disconnect event, once received, we cancelRun
+ cloudDisconnectChan = make(chan error)
+ defer func() {
+ close(cloudDisconnectChan)
+ cloudDisconnectChan = nil
+ }()
go func() {
- time.Sleep(1 * time.Second)
- _ = peerConn.Close()
+ for err := range cloudDisconnectChan {
+ if err == nil {
+ continue
+ }
+ cloudLogger.Infof("disconnecting from cloud due to: %v", err)
+ cancelRun()
+ }
}()
}
- currentSession = session
- c.JSON(http.StatusOK, gin.H{"sd": sd})
+
+ for {
+ typ, msg, err := wsCon.Read(runCtx)
+ if err != nil {
+ logWarnf("websocket read error: %v", err)
+ return err
+ }
+ if typ != websocket.MessageText {
+ // ignore non-text messages
+ continue
+ }
+
+ var message struct {
+ Type string `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }
+
+ err = json.Unmarshal(msg, &message)
+ if err != nil {
+ logWarnf("unable to parse ws message: %v", err)
+ continue
+ }
+
+ if message.Type == "offer" {
+ logInfof("new session request received")
+ var req WebRTCSessionRequest
+ err = json.Unmarshal(message.Data, &req)
+ if err != nil {
+ logWarnf("unable to parse session request data: %v", err)
+ continue
+ }
+
+ logInfof("new session request: %v", req.OidcGoogle)
+ logTracef("session request info: %v", req)
+
+ metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc()
+ metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
+ err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source)
+ if err != nil {
+ logWarnf("error starting new session: %v", err)
+ continue
+ }
+ } else if message.Type == "new-ice-candidate" {
+ logInfof("The client sent us a new ICE candidate: %v", string(message.Data))
+ var candidate webrtc.ICECandidateInit
+
+ // Attempt to unmarshal as a ICECandidateInit
+ if err := json.Unmarshal(message.Data, &candidate); err != nil {
+ logWarnf("unable to parse incoming ICE candidate data: %v", string(message.Data))
+ continue
+ }
+
+ if candidate.Candidate == "" {
+ logWarnf("empty incoming ICE candidate, skipping")
+ continue
+ }
+
+ logInfof("unmarshalled incoming ICE candidate: %v", candidate)
+
+ if currentSession == nil {
+ logInfof("no current session, skipping incoming ICE candidate")
+ continue
+ }
+
+ logInfof("adding incoming ICE candidate to current session: %v", candidate)
+ if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil {
+ logWarnf("failed to add incoming ICE candidate to our peer connection: %v", err)
+ }
+ }
+ }
}
func handleLogin(c *gin.Context) {
diff --git a/webrtc.go b/webrtc.go
index 12d4f95..a047ecc 100644
--- a/webrtc.go
+++ b/webrtc.go
@@ -1,11 +1,15 @@
package kvm
import (
+ "context"
"encoding/base64"
"encoding/json"
"net"
"strings"
+ "github.com/coder/websocket"
+ "github.com/coder/websocket/wsjson"
+ "github.com/gin-gonic/gin"
"github.com/pion/webrtc/v4"
)
@@ -23,6 +27,7 @@ type SessionConfig struct {
ICEServers []string
LocalIP string
IsCloud bool
+ ws *websocket.Conn
}
func (s *Session) ExchangeOffer(offerStr string) (string, error) {
@@ -46,19 +51,11 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) {
return "", err
}
- // Create channel that is blocked until ICE Gathering is complete
- gatherComplete := webrtc.GatheringCompletePromise(s.peerConnection)
-
// Sets the LocalDescription, and starts our UDP listeners
if err = s.peerConnection.SetLocalDescription(answer); err != nil {
return "", err
}
- // Block until ICE Gathering is complete, disabling trickle ICE
- // we do this because we only can exchange one signaling message
- // in a production application you should exchange ICE Candidates via OnICECandidate
- <-gatherComplete
-
localDescription, err := json.Marshal(s.peerConnection.LocalDescription())
if err != nil {
return "", err
@@ -144,6 +141,16 @@ func newSession(config SessionConfig) (*Session, error) {
}()
var isConnected bool
+ peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
+ logger.Infof("Our WebRTC peerConnection has a new ICE candidate: %v", candidate)
+ if candidate != nil {
+ err := wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
+ if err != nil {
+ logger.Errorf("failed to write new-ice-candidate to WebRTC signaling channel: %v", err)
+ }
+ }
+ })
+
peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
logger.Infof("Connection State has changed %s", connectionState)
if connectionState == webrtc.ICEConnectionStateConnected {