Feat/Trickle ice (#336)

* feat(cloud): Use Websocket signaling in cloud mode

* refactor: Enhance WebRTC signaling and connection handling

* refactor: Improve WebRTC connection management and logging in KvmIdRoute

* refactor: Update PeerConnectionDisconnectedOverlay to use Card component for better UI structure

* refactor: Standardize metric naming and improve websocket logging

* refactor: Rename WebRTC signaling functions and update deployment script for debug version

* fix: Handle error when writing new ICE candidate to WebRTC signaling channel

* refactor: Rename signaling handler function for clarity

* refactor: Remove old http local http endpoint

* refactor: Improve metric help text and standardize comparison operator in KvmIdRoute

* chore(websocket): use MetricVec instead of Metric to store metrics

* fix conflicts

* fix: use wss when the page is served over https

* feat: Add app version header and update WebRTC signaling endpoint

* fix: Handle error when writing device metadata to WebRTC signaling channel

---------

Co-authored-by: Siyuan Miao <i@xswan.net>
This commit is contained in:
Adam Shiervani 2025-04-09 00:10:38 +02:00 committed by GitHub
parent fa1b11b228
commit 1a30977085
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 654 additions and 275 deletions

179
cloud.go
View File

@ -35,8 +35,8 @@ const (
// CloudOidcRequestTimeout is the timeout for OIDC token verification requests
// should be lower than the websocket response timeout set in cloud-api
CloudOidcRequestTimeout = 10 * time.Second
// CloudWebSocketPingInterval is the interval at which the websocket client sends ping messages to the cloud
CloudWebSocketPingInterval = 15 * time.Second
// WebsocketPingInterval is the interval at which the websocket client sends ping messages to the cloud
WebsocketPingInterval = 15 * time.Second
)
var (
@ -52,59 +52,67 @@ var (
Help: "The timestamp when the cloud connection was established",
},
)
metricCloudConnectionLastPingTimestamp = promauto.NewGauge(
metricConnectionLastPingTimestamp = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "jetkvm_cloud_connection_last_ping_timestamp",
Name: "jetkvm_connection_last_ping_timestamp",
Help: "The timestamp when the last ping response was received",
},
[]string{"type", "source"},
)
metricCloudConnectionLastPingDuration = promauto.NewGauge(
metricConnectionLastPingDuration = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "jetkvm_cloud_connection_last_ping_duration",
Name: "jetkvm_connection_last_ping_duration",
Help: "The duration of the last ping response",
},
[]string{"type", "source"},
)
metricCloudConnectionPingDuration = promauto.NewHistogram(
metricConnectionPingDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "jetkvm_cloud_connection_ping_duration",
Name: "jetkvm_connection_ping_duration",
Help: "The duration of the ping response",
Buckets: []float64{
0.1, 0.5, 1, 10,
},
},
[]string{"type", "source"},
)
metricCloudConnectionTotalPingCount = promauto.NewCounter(
metricConnectionTotalPingCount = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "jetkvm_cloud_connection_total_ping_count",
Help: "The total number of pings sent to the cloud",
Name: "jetkvm_connection_total_ping_count",
Help: "The total number of pings sent to the connection",
},
[]string{"type", "source"},
)
metricCloudConnectionSessionRequestCount = promauto.NewCounter(
metricConnectionSessionRequestCount = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "jetkvm_cloud_connection_session_total_request_count",
Help: "The total number of session requests received from the cloud",
Name: "jetkvm_connection_session_total_request_count",
Help: "The total number of session requests received",
},
[]string{"type", "source"},
)
metricCloudConnectionSessionRequestDuration = promauto.NewHistogram(
metricConnectionSessionRequestDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "jetkvm_cloud_connection_session_request_duration",
Name: "jetkvm_connection_session_request_duration",
Help: "The duration of session requests",
Buckets: []float64{
0.1, 0.5, 1, 10,
},
},
[]string{"type", "source"},
)
metricCloudConnectionLastSessionRequestTimestamp = promauto.NewGauge(
metricConnectionLastSessionRequestTimestamp = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "jetkvm_cloud_connection_last_session_request_timestamp",
Name: "jetkvm_connection_last_session_request_timestamp",
Help: "The timestamp of the last session request",
},
[]string{"type", "source"},
)
metricCloudConnectionLastSessionRequestDuration = promauto.NewGauge(
metricConnectionLastSessionRequestDuration = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "jetkvm_cloud_connection_last_session_request_duration",
Name: "jetkvm_connection_last_session_request_duration",
Help: "The duration of the last session request",
},
[]string{"type", "source"},
)
metricCloudConnectionFailureCount = promauto.NewCounter(
prometheus.CounterOpts{
@ -119,12 +127,16 @@ var (
cloudDisconnectLock = &sync.Mutex{}
)
func cloudResetMetrics(established bool) {
metricCloudConnectionLastPingTimestamp.Set(-1)
metricCloudConnectionLastPingDuration.Set(-1)
func wsResetMetrics(established bool, sourceType string, source string) {
metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).Set(-1)
metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(-1)
metricCloudConnectionLastSessionRequestTimestamp.Set(-1)
metricCloudConnectionLastSessionRequestDuration.Set(-1)
metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).Set(-1)
metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(-1)
if sourceType != "cloud" {
return
}
if established {
metricCloudConnectionEstablishedTimestamp.SetToCurrentTime()
@ -256,6 +268,7 @@ func runWebsocketClient() error {
header := http.Header{}
header.Set("X-Device-ID", GetDeviceID())
header.Set("X-App-Version", builtAppVersion)
header.Set("Authorization", "Bearer "+config.CloudToken)
dialCtx, cancelDial := context.WithTimeout(context.Background(), CloudWebSocketConnectTimeout)
@ -270,88 +283,13 @@ func runWebsocketClient() error {
cloudLogger.Infof("websocket connected to %s", wsURL)
// set the metrics when we successfully connect to the cloud.
cloudResetMetrics(true)
wsResetMetrics(true, "cloud", "")
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
go func() {
for {
time.Sleep(CloudWebSocketPingInterval)
// set the timer for the ping duration
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricCloudConnectionLastPingDuration.Set(v)
metricCloudConnectionPingDuration.Observe(v)
}))
err := c.Ping(runCtx)
if err != nil {
cloudLogger.Warnf("websocket ping error: %v", err)
cancelRun()
return
}
// dont use `defer` here because we want to observe the duration of the ping
timer.ObserveDuration()
metricCloudConnectionTotalPingCount.Inc()
metricCloudConnectionLastPingTimestamp.SetToCurrentTime()
}
}()
// create a channel to receive the disconnect event, once received, we cancelRun
cloudDisconnectChan = make(chan error)
defer func() {
close(cloudDisconnectChan)
cloudDisconnectChan = nil
}()
go func() {
for err := range cloudDisconnectChan {
if err == nil {
continue
}
cloudLogger.Infof("disconnecting from cloud due to: %v", err)
cancelRun()
}
}()
for {
typ, msg, err := c.Read(runCtx)
if err != nil {
return err
}
if typ != websocket.MessageText {
// ignore non-text messages
continue
}
var req WebRTCSessionRequest
err = json.Unmarshal(msg, &req)
if err != nil {
cloudLogger.Warnf("unable to parse ws message: %v", string(msg))
continue
}
cloudLogger.Infof("new session request: %v", req.OidcGoogle)
cloudLogger.Tracef("session request info: %v", req)
metricCloudConnectionSessionRequestCount.Inc()
metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime()
err = handleSessionRequest(runCtx, c, req)
if err != nil {
cloudLogger.Infof("error starting new session: %v", err)
continue
}
}
// we don't have a source for the cloud connection
return handleWebRTCSignalWsMessages(c, true, "")
}
func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricCloudConnectionLastSessionRequestDuration.Set(v)
metricCloudConnectionSessionRequestDuration.Observe(v)
}))
defer timer.ObserveDuration()
func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
oidcCtx, cancelOIDC := context.WithTimeout(ctx, CloudOidcRequestTimeout)
defer cancelOIDC()
provider, err := oidc.NewProvider(oidcCtx, "https://accounts.google.com")
@ -379,10 +317,35 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
return fmt.Errorf("google identity mismatch")
}
return nil
}
func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest, isCloudConnection bool, source string) error {
var sourceType string
if isCloudConnection {
sourceType = "cloud"
} else {
sourceType = "local"
}
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(v)
metricConnectionSessionRequestDuration.WithLabelValues(sourceType, source).Observe(v)
}))
defer timer.ObserveDuration()
// If the message is from the cloud, we need to authenticate the session.
if isCloudConnection {
if err := authenticateSession(ctx, c, req); err != nil {
return err
}
}
session, err := newSession(SessionConfig{
ICEServers: req.ICEServers,
ws: c,
IsCloud: isCloudConnection,
LocalIP: req.IP,
IsCloud: true,
ICEServers: req.ICEServers,
})
if err != nil {
_ = wsjson.Write(context.Background(), c, gin.H{"error": err})
@ -406,14 +369,14 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
cloudLogger.Info("new session accepted")
cloudLogger.Tracef("new session accepted: %v", session)
currentSession = session
_ = wsjson.Write(context.Background(), c, gin.H{"sd": sd})
_ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd})
return nil
}
func RunWebsocketClient() {
for {
// reset the metrics when we start the websocket client.
cloudResetMetrics(false)
wsResetMetrics(false, "cloud", "")
// If the cloud token is not set, we don't need to run the websocket client.
if config.CloudToken == "" {

1
log.go
View File

@ -6,3 +6,4 @@ import "github.com/pion/logging"
// ref: https://github.com/pion/webrtc/wiki/Debugging-WebRTC
var logger = logging.NewDefaultLoggerFactory().NewLogger("jetkvm")
var cloudLogger = logging.NewDefaultLoggerFactory().NewLogger("cloud")
var websocketLogger = logging.NewDefaultLoggerFactory().NewLogger("websocket")

6
ui/package-lock.json generated
View File

@ -30,6 +30,7 @@
"react-icons": "^5.4.0",
"react-router-dom": "^6.22.3",
"react-simple-keyboard": "^3.7.112",
"react-use-websocket": "^4.13.0",
"react-xtermjs": "^1.0.9",
"recharts": "^2.15.0",
"tailwind-merge": "^2.5.5",
@ -5180,6 +5181,11 @@
"react-dom": ">=16.6.0"
}
},
"node_modules/react-use-websocket": {
"version": "4.13.0",
"resolved": "https://registry.npmjs.org/react-use-websocket/-/react-use-websocket-4.13.0.tgz",
"integrity": "sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw=="
},
"node_modules/react-xtermjs": {
"version": "1.0.9",
"resolved": "https://registry.npmjs.org/react-xtermjs/-/react-xtermjs-1.0.9.tgz",

View File

@ -40,6 +40,7 @@
"react-icons": "^5.4.0",
"react-router-dom": "^6.22.3",
"react-simple-keyboard": "^3.7.112",
"react-use-websocket": "^4.13.0",
"react-xtermjs": "^1.0.9",
"recharts": "^2.15.0",
"tailwind-merge": "^2.5.5",

View File

@ -36,7 +36,7 @@ export default function DashboardNavbar({
picture,
kvmName,
}: NavbarProps) {
const peerConnection = useRTCStore(state => state.peerConnection);
const peerConnectionState = useRTCStore(state => state.peerConnectionState);
const setUser = useUserStore(state => state.setUser);
const navigate = useNavigate();
const onLogout = useCallback(async () => {
@ -82,14 +82,14 @@ export default function DashboardNavbar({
<div className="hidden items-center gap-x-2 md:flex">
<div className="w-[159px]">
<PeerConnectionStatusCard
state={peerConnection?.connectionState}
state={peerConnectionState}
title={kvmName}
/>
</div>
<div className="hidden w-[159px] md:block">
<USBStateStatus
state={usbState}
peerConnectionState={peerConnection?.connectionState}
peerConnectionState={peerConnectionState}
/>
</div>
</div>

View File

@ -6,7 +6,7 @@ import { LuPlay } from "react-icons/lu";
import { Button, LinkButton } from "@components/Button";
import LoadingSpinner from "@components/LoadingSpinner";
import { GridCard } from "@components/Card";
import Card, { GridCard } from "@components/Card";
interface OverlayContentProps {
children: React.ReactNode;
@ -94,7 +94,7 @@ interface ConnectionErrorOverlayProps {
setupPeerConnection: () => Promise<void>;
}
export function ConnectionErrorOverlay({
export function ConnectionFailedOverlay({
show,
setupPeerConnection,
}: ConnectionErrorOverlayProps) {
@ -151,6 +151,60 @@ export function ConnectionErrorOverlay({
);
}
interface PeerConnectionDisconnectedOverlay {
show: boolean;
}
export function PeerConnectionDisconnectedOverlay({
show,
}: PeerConnectionDisconnectedOverlay) {
return (
<AnimatePresence>
{show && (
<motion.div
className="aspect-video h-full w-full"
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0, transition: { duration: 0 } }}
transition={{
duration: 0.4,
ease: "easeInOut",
}}
>
<OverlayContent>
<div className="flex flex-col items-start gap-y-1">
<ExclamationTriangleIcon className="h-12 w-12 text-yellow-500" />
<div className="text-left text-sm text-slate-700 dark:text-slate-300">
<div className="space-y-4">
<div className="space-y-2 text-black dark:text-white">
<h2 className="text-xl font-bold">Connection Issue Detected</h2>
<ul className="list-disc space-y-2 pl-4 text-left">
<li>Verify that the device is powered on and properly connected</li>
<li>Check all cable connections for any loose or damaged wires</li>
<li>Ensure your network connection is stable and active</li>
<li>Try restarting both the device and your computer</li>
</ul>
</div>
<div className="flex items-center gap-x-2">
<Card>
<div className="flex items-center gap-x-2 p-4">
<LoadingSpinner className="h-4 w-4 text-blue-800 dark:text-blue-200" />
<p className="text-sm text-slate-700 dark:text-slate-300">
Retrying connection...
</p>
</div>
</Card>
</div>
</div>
</div>
</div>
</OverlayContent>
</motion.div>
)}
</AnimatePresence>
);
}
interface HDMIErrorOverlayProps {
show: boolean;
hdmiState: string;

View File

@ -380,7 +380,7 @@ export default function WebRTCVideo() {
(mediaStream: MediaStream) => {
if (!videoElm.current) return;
const videoElmRefValue = videoElm.current;
console.log("Adding stream to video element", videoElmRefValue);
// console.log("Adding stream to video element", videoElmRefValue);
videoElmRefValue.srcObject = mediaStream;
updateVideoSizeStore(videoElmRefValue);
},
@ -396,7 +396,7 @@ export default function WebRTCVideo() {
peerConnection.addEventListener(
"track",
(e: RTCTrackEvent) => {
console.log("Adding stream to video element");
// console.log("Adding stream to video element");
addStreamToVideoElm(e.streams[0]);
},
{ signal },

View File

@ -1,4 +1,4 @@
import { useCallback, useEffect, useRef, useState } from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import {
LoaderFunctionArgs,
Outlet,
@ -14,6 +14,7 @@ import {
import { useInterval } from "usehooks-ts";
import FocusTrap from "focus-trap-react";
import { motion, AnimatePresence } from "framer-motion";
import useWebSocket from "react-use-websocket";
import { cx } from "@/cva.config";
import {
@ -43,15 +44,16 @@ import UpdateInProgressStatusCard from "../components/UpdateInProgressStatusCard
import api from "../api";
import Modal from "../components/Modal";
import { useDeviceUiNavigation } from "../hooks/useAppNavigation";
import {
ConnectionFailedOverlay,
LoadingConnectionOverlay,
PeerConnectionDisconnectedOverlay,
} from "../components/VideoOverlay";
import { FeatureFlagProvider } from "../providers/FeatureFlagProvider";
import notifications from "../notifications";
import {
ConnectionErrorOverlay,
LoadingConnectionOverlay,
} from "../components/VideoOverlay";
import { SystemVersionInfo } from "./devices.$id.settings.general.update";
import { DeviceStatus } from "./welcome-local";
import { SystemVersionInfo } from "./devices.$id.settings.general.update";
interface LocalLoaderResp {
authMode: "password" | "noPassword" | null;
@ -117,7 +119,6 @@ const loader = async ({ params }: LoaderFunctionArgs) => {
export default function KvmIdRoute() {
const loaderResp = useLoaderData() as LocalLoaderResp | CloudLoaderResp;
// Depending on the mode, we set the appropriate variables
const user = "user" in loaderResp ? loaderResp.user : null;
const deviceName = "deviceName" in loaderResp ? loaderResp.deviceName : null;
@ -130,6 +131,8 @@ export default function KvmIdRoute() {
const setIsTurnServerInUse = useRTCStore(state => state.setTurnServerInUse);
const peerConnection = useRTCStore(state => state.peerConnection);
const setPeerConnectionState = useRTCStore(state => state.setPeerConnectionState);
const peerConnectionState = useRTCStore(state => state.peerConnectionState);
const setMediaMediaStream = useRTCStore(state => state.setMediaStream);
const setPeerConnection = useRTCStore(state => state.setPeerConnection);
const setDiskChannel = useRTCStore(state => state.setDiskChannel);
@ -137,23 +140,28 @@ export default function KvmIdRoute() {
const setTransceiver = useRTCStore(state => state.setTransceiver);
const location = useLocation();
const isLegacySignalingEnabled = useRef(false);
const [connectionFailed, setConnectionFailed] = useState(false);
const navigate = useNavigate();
const { otaState, setOtaState, setModalView } = useUpdateStore();
const [loadingMessage, setLoadingMessage] = useState("Connecting to device...");
const closePeerConnection = useCallback(
function closePeerConnection() {
const cleanupAndStopReconnecting = useCallback(
function cleanupAndStopReconnecting() {
console.log("Closing peer connection");
setConnectionFailed(true);
if (peerConnection) {
setPeerConnectionState(peerConnection.connectionState);
}
connectionFailedRef.current = true;
peerConnection?.close();
signalingAttempts.current = 0;
},
[peerConnection],
[peerConnection, setPeerConnectionState],
);
// We need to track connectionFailed in a ref to avoid stale closure issues
@ -171,95 +179,233 @@ export default function KvmIdRoute() {
}, [connectionFailed]);
const signalingAttempts = useRef(0);
const syncRemoteSessionDescription = useCallback(
async function syncRemoteSessionDescription(pc: RTCPeerConnection) {
const setRemoteSessionDescription = useCallback(
async function setRemoteSessionDescription(
pc: RTCPeerConnection,
remoteDescription: RTCSessionDescriptionInit,
) {
setLoadingMessage("Setting remote description");
try {
if (!pc) return;
const sd = btoa(JSON.stringify(pc.localDescription));
const sessionUrl = isOnDevice
? `${DEVICE_API}/webrtc/session`
: `${CLOUD_API}/webrtc/session`;
console.log("Trying to get remote session description");
setLoadingMessage(
`Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`,
);
const res = await api.POST(sessionUrl, {
sd,
// When on device, we don't need to specify the device id, as it's already known
...(isOnDevice ? {} : { id: params.id }),
});
const json = await res.json();
if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login");
if (!res.ok) {
console.error("Error getting SDP", { status: res.status, json });
throw new Error("Error getting SDP");
}
console.log("Successfully got Remote Session Description. Setting.");
setLoadingMessage("Setting remote session description...");
const decodedSd = atob(json.sd);
const parsedSd = JSON.parse(decodedSd);
pc.setRemoteDescription(new RTCSessionDescription(parsedSd));
await new Promise((resolve, reject) => {
console.log("Waiting for remote description to be set");
const maxAttempts = 10;
const interval = 1000;
let attempts = 0;
const checkInterval = setInterval(() => {
attempts++;
// When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects
if (pc.sctp?.state === "connected") {
console.log("Remote description set");
clearInterval(checkInterval);
resolve(true);
} else if (attempts >= maxAttempts) {
console.log(
`Failed to get remote description after ${maxAttempts} attempts`,
);
closePeerConnection();
clearInterval(checkInterval);
reject(
new Error(
`Failed to get remote description after ${maxAttempts} attempts`,
),
);
} else {
console.log("Waiting for remote description to be set");
}
}, interval);
});
await pc.setRemoteDescription(new RTCSessionDescription(remoteDescription));
console.log("[setRemoteSessionDescription] Remote description set successfully");
setLoadingMessage("Establishing secure connection...");
} catch (error) {
console.error("Error getting SDP", { error });
console.log("Connection failed", connectionFailedRef.current);
if (connectionFailedRef.current) return;
if (signalingAttempts.current < 5) {
signalingAttempts.current++;
await new Promise(resolve => setTimeout(resolve, 500));
console.log("Attempting to get SDP again", signalingAttempts.current);
syncRemoteSessionDescription(pc);
} else {
closePeerConnection();
}
console.error(
"[setRemoteSessionDescription] Failed to set remote description:",
error,
);
cleanupAndStopReconnecting();
return;
}
// Replace the interval-based check with a more reliable approach
let attempts = 0;
const checkInterval = setInterval(() => {
attempts++;
// When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects
if (pc.sctp?.state === "connected") {
console.log("[setRemoteSessionDescription] Remote description set");
clearInterval(checkInterval);
setLoadingMessage("Connection established");
} else if (attempts >= 10) {
console.log(
"[setRemoteSessionDescription] Failed to establish connection after 10 attempts",
{
connectionState: pc.connectionState,
iceConnectionState: pc.iceConnectionState,
},
);
cleanupAndStopReconnecting();
clearInterval(checkInterval);
} else {
console.log("[setRemoteSessionDescription] Waiting for connection, state:", {
connectionState: pc.connectionState,
iceConnectionState: pc.iceConnectionState,
});
}
}, 1000);
},
[closePeerConnection, navigate, params.id],
[cleanupAndStopReconnecting],
);
const ignoreOffer = useRef(false);
const isSettingRemoteAnswerPending = useRef(false);
const makingOffer = useRef(false);
const wsProtocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const { sendMessage, getWebSocket } = useWebSocket(
isOnDevice
? `${wsProtocol}//${window.location.host}/webrtc/signaling/client`
: `${CLOUD_API.replace("http", "ws")}/webrtc/signaling/client?id=${params.id}`,
{
heartbeat: true,
retryOnError: true,
reconnectAttempts: 5,
reconnectInterval: 1000,
onReconnectStop: () => {
console.log("Reconnect stopped");
cleanupAndStopReconnecting();
},
shouldReconnect(event) {
console.log("[Websocket] shouldReconnect", event);
// TODO: Why true?
return true;
},
onClose(event) {
console.log("[Websocket] onClose", event);
// We don't want to close everything down, we wait for the reconnect to stop instead
},
onError(event) {
console.log("[Websocket] onError", event);
// We don't want to close everything down, we wait for the reconnect to stop instead
},
onOpen() {
console.log("[Websocket] onOpen");
},
onMessage: message => {
if (message.data === "pong") return;
/*
Currently the signaling process is as follows:
After open, the other side will send a `device-metadata` message with the device version
If the device version is not set, we can assume the device is using the legacy signaling
Otherwise, we can assume the device is using the new signaling
If the device is using the legacy signaling, we close the websocket connection
and use the legacy HTTPSignaling function to get the remote session description
If the device is using the new signaling, we don't need to do anything special, but continue to use the websocket connection
to chat with the other peer about the connection
*/
const parsedMessage = JSON.parse(message.data);
if (parsedMessage.type === "device-metadata") {
const { deviceVersion } = parsedMessage.data;
console.log("[Websocket] Received device-metadata message");
console.log("[Websocket] Device version", deviceVersion);
// If the device version is not set, we can assume the device is using the legacy signaling
if (!deviceVersion) {
console.log("[Websocket] Device is using legacy signaling");
// Now we don't need the websocket connection anymore, as we've established that we need to use the legacy signaling
// which does everything over HTTP(at least from the perspective of the client)
isLegacySignalingEnabled.current = true;
getWebSocket()?.close();
} else {
console.log("[Websocket] Device is using new signaling");
isLegacySignalingEnabled.current = false;
}
setupPeerConnection();
}
if (!peerConnection) return;
if (parsedMessage.type === "answer") {
console.log("[Websocket] Received answer");
const readyForOffer =
// If we're making an offer, we don't want to accept an answer
!makingOffer &&
// If the peer connection is stable or we're setting the remote answer pending, we're ready for an offer
(peerConnection?.signalingState === "stable" ||
isSettingRemoteAnswerPending.current);
// If we're not ready for an offer, we don't want to accept an offer
ignoreOffer.current = parsedMessage.type === "offer" && !readyForOffer;
if (ignoreOffer.current) return;
// Set so we don't accept an answer while we're setting the remote description
isSettingRemoteAnswerPending.current = parsedMessage.type === "answer";
console.log(
"[Websocket] Setting remote answer pending",
isSettingRemoteAnswerPending.current,
);
const sd = atob(parsedMessage.data);
const remoteSessionDescription = JSON.parse(sd);
setRemoteSessionDescription(
peerConnection,
new RTCSessionDescription(remoteSessionDescription),
);
// Reset the remote answer pending flag
isSettingRemoteAnswerPending.current = false;
} else if (parsedMessage.type === "new-ice-candidate") {
console.log("[Websocket] Received new-ice-candidate");
const candidate = parsedMessage.data;
peerConnection.addIceCandidate(candidate);
}
},
},
// Don't even retry once we declare failure
!connectionFailed && isLegacySignalingEnabled.current === false,
);
const sendWebRTCSignal = useCallback(
(type: string, data: unknown) => {
// Second argument tells the library not to queue the message, and send it once the connection is established again.
// We have event handlers that handle the connection set up, so we don't need to queue the message.
sendMessage(JSON.stringify({ type, data }), false);
},
[sendMessage],
);
const legacyHTTPSignaling = useCallback(
async (pc: RTCPeerConnection) => {
const sd = btoa(JSON.stringify(pc.localDescription));
// Legacy mode == UI in cloud with updated code connecting to older device version.
// In device mode, old devices wont server this JS, and on newer devices legacy mode wont be enabled
const sessionUrl = `${CLOUD_API}/webrtc/session`;
console.log("Trying to get remote session description");
setLoadingMessage(
`Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`,
);
const res = await api.POST(sessionUrl, {
sd,
// When on device, we don't need to specify the device id, as it's already known
...(isOnDevice ? {} : { id: params.id }),
});
const json = await res.json();
if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login");
if (!res.ok) {
console.error("Error getting SDP", { status: res.status, json });
cleanupAndStopReconnecting();
return;
}
console.log("Successfully got Remote Session Description. Setting.");
setLoadingMessage("Setting remote session description...");
const decodedSd = atob(json.sd);
const parsedSd = JSON.parse(decodedSd);
setRemoteSessionDescription(pc, new RTCSessionDescription(parsedSd));
},
[cleanupAndStopReconnecting, navigate, params.id, setRemoteSessionDescription],
);
const setupPeerConnection = useCallback(async () => {
console.log("Setting up peer connection");
console.log("[setupPeerConnection] Setting up peer connection");
setConnectionFailed(false);
setLoadingMessage("Connecting to device...");
if (peerConnection?.signalingState === "stable") {
console.log("[setupPeerConnection] Peer connection already established");
return;
}
let pc: RTCPeerConnection;
try {
console.log("Creating peer connection");
console.log("[setupPeerConnection] Creating peer connection");
setLoadingMessage("Creating peer connection...");
pc = new RTCPeerConnection({
// We only use STUN or TURN servers if we're in the cloud
@ -267,30 +413,65 @@ export default function KvmIdRoute() {
? { iceServers: [iceConfig?.iceServers] }
: {}),
});
console.log("Peer connection created", pc);
setLoadingMessage("Peer connection created");
setPeerConnectionState(pc.connectionState);
console.log("[setupPeerConnection] Peer connection created", pc);
setLoadingMessage("Setting up connection to device...");
} catch (e) {
console.error(`Error creating peer connection: ${e}`);
console.error(`[setupPeerConnection] Error creating peer connection: ${e}`);
setTimeout(() => {
closePeerConnection();
cleanupAndStopReconnecting();
}, 1000);
return;
}
// Set up event listeners and data channels
pc.onconnectionstatechange = () => {
console.log("Connection state changed", pc.connectionState);
console.log("[setupPeerConnection] Connection state changed", pc.connectionState);
setPeerConnectionState(pc.connectionState);
};
pc.onnegotiationneeded = async () => {
try {
console.log("[setupPeerConnection] Creating offer");
makingOffer.current = true;
const offer = await pc.createOffer();
await pc.setLocalDescription(offer);
const sd = btoa(JSON.stringify(pc.localDescription));
const isNewSignalingEnabled = isLegacySignalingEnabled.current === false;
if (isNewSignalingEnabled) {
sendWebRTCSignal("offer", { sd: sd });
} else {
console.log("Legacy signanling. Waiting for ICE Gathering to complete...");
}
} catch (e) {
console.error(
`[setupPeerConnection] Error creating offer: ${e}`,
new Date().toISOString(),
);
cleanupAndStopReconnecting();
} finally {
makingOffer.current = false;
}
};
pc.onicecandidate = async ({ candidate }) => {
if (!candidate) return;
if (candidate.candidate === "") return;
sendWebRTCSignal("new-ice-candidate", candidate);
};
pc.onicegatheringstatechange = event => {
const pc = event.currentTarget as RTCPeerConnection;
console.log("ICE Gathering State Changed", pc.iceGatheringState);
if (pc.iceGatheringState === "complete") {
console.log("ICE Gathering completed");
setLoadingMessage("ICE Gathering completed");
// We can now start the https/ws connection to get the remote session description from the KVM device
syncRemoteSessionDescription(pc);
if (isLegacySignalingEnabled.current) {
// We can now start the https/ws connection to get the remote session description from the KVM device
legacyHTTPSignaling(pc);
}
} else if (pc.iceGatheringState === "gathering") {
console.log("ICE Gathering Started");
setLoadingMessage("Gathering ICE candidates...");
@ -314,31 +495,26 @@ export default function KvmIdRoute() {
};
setPeerConnection(pc);
try {
const offer = await pc.createOffer();
await pc.setLocalDescription(offer);
} catch (e) {
console.error(`Error creating offer: ${e}`, new Date().toISOString());
closePeerConnection();
}
}, [
closePeerConnection,
cleanupAndStopReconnecting,
iceConfig?.iceServers,
legacyHTTPSignaling,
peerConnection?.signalingState,
sendWebRTCSignal,
setDiskChannel,
setMediaMediaStream,
setPeerConnection,
setPeerConnectionState,
setRpcDataChannel,
setTransceiver,
syncRemoteSessionDescription,
]);
// On boot, if the connection state is undefined, we connect to the WebRTC
useEffect(() => {
if (peerConnection?.connectionState === undefined) {
setupPeerConnection();
if (peerConnectionState === "failed") {
console.log("Connection failed, closing peer connection");
cleanupAndStopReconnecting();
}
}, [setupPeerConnection, peerConnection?.connectionState]);
}, [peerConnectionState, cleanupAndStopReconnecting]);
// Cleanup effect
const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats);
@ -363,7 +539,7 @@ export default function KvmIdRoute() {
// TURN server usage detection
useEffect(() => {
if (peerConnection?.connectionState !== "connected") return;
if (peerConnectionState !== "connected") return;
const { localCandidateStats, remoteCandidateStats } = useRTCStore.getState();
const lastLocalStat = Array.from(localCandidateStats).pop();
@ -375,7 +551,7 @@ export default function KvmIdRoute() {
const remoteCandidateIsUsingTurn = lastRemoteStat[1].candidateType === "relay"; // [0] is the timestamp, which we don't care about here
setIsTurnServerInUse(localCandidateIsUsingTurn || remoteCandidateIsUsingTurn);
}, [peerConnection?.connectionState, setIsTurnServerInUse]);
}, [peerConnectionState, setIsTurnServerInUse]);
// TURN server usage reporting
const isTurnServerInUse = useRTCStore(state => state.isTurnServerInUse);
@ -466,10 +642,6 @@ export default function KvmIdRoute() {
});
}, [rpcDataChannel?.readyState, send, setHdmiState]);
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error
window.send = send;
// When the update is successful, we need to refresh the client javascript and show a success modal
useEffect(() => {
if (queryParams.get("updateSuccess")) {
@ -506,12 +678,12 @@ export default function KvmIdRoute() {
useEffect(() => {
if (!peerConnection) return;
if (!kvmTerminal) {
console.log('Creating data channel "terminal"');
// console.log('Creating data channel "terminal"');
setKvmTerminal(peerConnection.createDataChannel("terminal"));
}
if (!serialConsole) {
console.log('Creating data channel "serial"');
// console.log('Creating data channel "serial"');
setSerialConsole(peerConnection.createDataChannel("serial"));
}
}, [kvmTerminal, peerConnection, serialConsole]);
@ -554,6 +726,43 @@ export default function KvmIdRoute() {
[send, setScrollSensitivity],
);
const ConnectionStatusElement = useMemo(() => {
const hasConnectionFailed =
connectionFailed || ["failed", "closed"].includes(peerConnectionState || "");
const isPeerConnectionLoading =
["connecting", "new"].includes(peerConnectionState || "") ||
peerConnection === null;
const isDisconnected = peerConnectionState === "disconnected";
const isOtherSession = location.pathname.includes("other-session");
if (isOtherSession) return null;
if (peerConnectionState === "connected") return null;
if (isDisconnected) {
return <PeerConnectionDisconnectedOverlay show={true} />;
}
if (hasConnectionFailed)
return (
<ConnectionFailedOverlay show={true} setupPeerConnection={setupPeerConnection} />
);
if (isPeerConnectionLoading) {
return <LoadingConnectionOverlay show={true} text={loadingMessage} />;
}
return null;
}, [
connectionFailed,
loadingMessage,
location.pathname,
peerConnection,
peerConnectionState,
setupPeerConnection,
]);
return (
<FeatureFlagProvider appVersion={appVersion}>
{!outlet && otaState.updating && (
@ -593,27 +802,13 @@ export default function KvmIdRoute() {
/>
<div className="flex h-full w-full overflow-hidden">
<div className="pointer-events-none fixed inset-0 isolate z-50 flex h-full w-full items-center justify-center">
<div className="pointer-events-none fixed inset-0 isolate z-20 flex h-full w-full items-center justify-center">
<div className="my-2 h-full max-h-[720px] w-full max-w-[1280px] rounded-md">
<LoadingConnectionOverlay
show={
!connectionFailed &&
(["connecting", "new"].includes(
peerConnection?.connectionState || "",
) ||
peerConnection === null) &&
!location.pathname.includes("other-session")
}
text={loadingMessage}
/>
<ConnectionErrorOverlay
show={connectionFailed && !location.pathname.includes("other-session")}
setupPeerConnection={setupPeerConnection}
/>
{!!ConnectionStatusElement && ConnectionStatusElement}
</div>
</div>
<WebRTCVideo />
{peerConnectionState === "connected" && <WebRTCVideo />}
<SidebarContainer sidebarView={sidebarView} />
</div>
</div>

188
web.go
View File

@ -1,6 +1,7 @@
package kvm
import (
"context"
"embed"
"encoding/json"
"fmt"
@ -10,8 +11,12 @@ import (
"strings"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/pion/webrtc/v4"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/crypto/bcrypt"
)
@ -94,7 +99,7 @@ func setupRouter() *gin.Engine {
protected := r.Group("/")
protected.Use(protectedMiddleware())
{
protected.POST("/webrtc/session", handleWebRTCSession)
protected.GET("/webrtc/signaling/client", handleLocalWebRTCSignal)
protected.POST("/cloud/register", handleCloudRegister)
protected.GET("/cloud/state", handleCloudState)
protected.GET("/device", handleDevice)
@ -121,35 +126,182 @@ func setupRouter() *gin.Engine {
// TODO: support multiple sessions?
var currentSession *Session
func handleWebRTCSession(c *gin.Context) {
var req WebRTCSessionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
func handleLocalWebRTCSignal(c *gin.Context) {
cloudLogger.Infof("new websocket connection established")
// Create WebSocket options with InsecureSkipVerify to bypass origin check
wsOptions := &websocket.AcceptOptions{
InsecureSkipVerify: true, // Allow connections from any origin
}
session, err := newSession(SessionConfig{})
wsCon, err := websocket.Accept(c.Writer, c.Request, wsOptions)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sd, err := session.ExchangeOffer(req.Sd)
// get the source from the request
source := c.ClientIP()
// Now use conn for websocket operations
defer wsCon.Close(websocket.StatusNormalClosure, "")
err = wsjson.Write(context.Background(), wsCon, gin.H{"type": "device-metadata", "data": gin.H{"deviceVersion": builtAppVersion}})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if currentSession != nil {
writeJSONRPCEvent("otherSessionConnected", nil, currentSession)
peerConn := currentSession.peerConnection
err = handleWebRTCSignalWsMessages(wsCon, false, source)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, source string) error {
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
// Add connection tracking to detect reconnections
connectionID := uuid.New().String()
cloudLogger.Infof("new websocket connection established with ID: %s", connectionID)
// connection type
var sourceType string
if isCloudConnection {
sourceType = "cloud"
} else {
sourceType = "local"
}
// probably we can use a better logging framework here
logInfof := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Infof(format+", source: %s, sourceType: %s", args...)
}
logWarnf := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Warnf(format+", source: %s, sourceType: %s", args...)
}
logTracef := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Tracef(format+", source: %s, sourceType: %s", args...)
}
go func() {
for {
time.Sleep(WebsocketPingInterval)
// set the timer for the ping duration
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(v)
metricConnectionPingDuration.WithLabelValues(sourceType, source).Observe(v)
}))
logInfof("pinging websocket")
err := wsCon.Ping(runCtx)
if err != nil {
logWarnf("websocket ping error: %v", err)
cancelRun()
return
}
// dont use `defer` here because we want to observe the duration of the ping
timer.ObserveDuration()
metricConnectionTotalPingCount.WithLabelValues(sourceType, source).Inc()
metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
}
}()
if isCloudConnection {
// create a channel to receive the disconnect event, once received, we cancelRun
cloudDisconnectChan = make(chan error)
defer func() {
close(cloudDisconnectChan)
cloudDisconnectChan = nil
}()
go func() {
time.Sleep(1 * time.Second)
_ = peerConn.Close()
for err := range cloudDisconnectChan {
if err == nil {
continue
}
cloudLogger.Infof("disconnecting from cloud due to: %v", err)
cancelRun()
}
}()
}
currentSession = session
c.JSON(http.StatusOK, gin.H{"sd": sd})
for {
typ, msg, err := wsCon.Read(runCtx)
if err != nil {
logWarnf("websocket read error: %v", err)
return err
}
if typ != websocket.MessageText {
// ignore non-text messages
continue
}
var message struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}
err = json.Unmarshal(msg, &message)
if err != nil {
logWarnf("unable to parse ws message: %v", err)
continue
}
if message.Type == "offer" {
logInfof("new session request received")
var req WebRTCSessionRequest
err = json.Unmarshal(message.Data, &req)
if err != nil {
logWarnf("unable to parse session request data: %v", err)
continue
}
logInfof("new session request: %v", req.OidcGoogle)
logTracef("session request info: %v", req)
metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc()
metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source)
if err != nil {
logWarnf("error starting new session: %v", err)
continue
}
} else if message.Type == "new-ice-candidate" {
logInfof("The client sent us a new ICE candidate: %v", string(message.Data))
var candidate webrtc.ICECandidateInit
// Attempt to unmarshal as a ICECandidateInit
if err := json.Unmarshal(message.Data, &candidate); err != nil {
logWarnf("unable to parse incoming ICE candidate data: %v", string(message.Data))
continue
}
if candidate.Candidate == "" {
logWarnf("empty incoming ICE candidate, skipping")
continue
}
logInfof("unmarshalled incoming ICE candidate: %v", candidate)
if currentSession == nil {
logInfof("no current session, skipping incoming ICE candidate")
continue
}
logInfof("adding incoming ICE candidate to current session: %v", candidate)
if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil {
logWarnf("failed to add incoming ICE candidate to our peer connection: %v", err)
}
}
}
}
func handleLogin(c *gin.Context) {

View File

@ -1,11 +1,15 @@
package kvm
import (
"context"
"encoding/base64"
"encoding/json"
"net"
"strings"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gin-gonic/gin"
"github.com/pion/webrtc/v4"
)
@ -23,6 +27,7 @@ type SessionConfig struct {
ICEServers []string
LocalIP string
IsCloud bool
ws *websocket.Conn
}
func (s *Session) ExchangeOffer(offerStr string) (string, error) {
@ -46,19 +51,11 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) {
return "", err
}
// Create channel that is blocked until ICE Gathering is complete
gatherComplete := webrtc.GatheringCompletePromise(s.peerConnection)
// Sets the LocalDescription, and starts our UDP listeners
if err = s.peerConnection.SetLocalDescription(answer); err != nil {
return "", err
}
// Block until ICE Gathering is complete, disabling trickle ICE
// we do this because we only can exchange one signaling message
// in a production application you should exchange ICE Candidates via OnICECandidate
<-gatherComplete
localDescription, err := json.Marshal(s.peerConnection.LocalDescription())
if err != nil {
return "", err
@ -144,6 +141,16 @@ func newSession(config SessionConfig) (*Session, error) {
}()
var isConnected bool
peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
logger.Infof("Our WebRTC peerConnection has a new ICE candidate: %v", candidate)
if candidate != nil {
err := wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
if err != nil {
logger.Errorf("failed to write new-ice-candidate to WebRTC signaling channel: %v", err)
}
}
})
peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
logger.Infof("Connection State has changed %s", connectionState)
if connectionState == webrtc.ICEConnectionStateConnected {