Feat/Trickle ice (#336)

* feat(cloud): Use Websocket signaling in cloud mode

* refactor: Enhance WebRTC signaling and connection handling

* refactor: Improve WebRTC connection management and logging in KvmIdRoute

* refactor: Update PeerConnectionDisconnectedOverlay to use Card component for better UI structure

* refactor: Standardize metric naming and improve websocket logging

* refactor: Rename WebRTC signaling functions and update deployment script for debug version

* fix: Handle error when writing new ICE candidate to WebRTC signaling channel

* refactor: Rename signaling handler function for clarity

* refactor: Remove old http local http endpoint

* refactor: Improve metric help text and standardize comparison operator in KvmIdRoute

* chore(websocket): use MetricVec instead of Metric to store metrics

* fix conflicts

* fix: use wss when the page is served over https

* feat: Add app version header and update WebRTC signaling endpoint

* fix: Handle error when writing device metadata to WebRTC signaling channel

---------

Co-authored-by: Siyuan Miao <i@xswan.net>
This commit is contained in:
Adam Shiervani 2025-04-09 00:10:38 +02:00 committed by GitHub
parent fa1b11b228
commit 1a30977085
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 654 additions and 275 deletions

179
cloud.go
View File

@ -35,8 +35,8 @@ const (
// CloudOidcRequestTimeout is the timeout for OIDC token verification requests // CloudOidcRequestTimeout is the timeout for OIDC token verification requests
// should be lower than the websocket response timeout set in cloud-api // should be lower than the websocket response timeout set in cloud-api
CloudOidcRequestTimeout = 10 * time.Second CloudOidcRequestTimeout = 10 * time.Second
// CloudWebSocketPingInterval is the interval at which the websocket client sends ping messages to the cloud // WebsocketPingInterval is the interval at which the websocket client sends ping messages to the cloud
CloudWebSocketPingInterval = 15 * time.Second WebsocketPingInterval = 15 * time.Second
) )
var ( var (
@ -52,59 +52,67 @@ var (
Help: "The timestamp when the cloud connection was established", Help: "The timestamp when the cloud connection was established",
}, },
) )
metricCloudConnectionLastPingTimestamp = promauto.NewGauge( metricConnectionLastPingTimestamp = promauto.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Name: "jetkvm_cloud_connection_last_ping_timestamp", Name: "jetkvm_connection_last_ping_timestamp",
Help: "The timestamp when the last ping response was received", Help: "The timestamp when the last ping response was received",
}, },
[]string{"type", "source"},
) )
metricCloudConnectionLastPingDuration = promauto.NewGauge( metricConnectionLastPingDuration = promauto.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Name: "jetkvm_cloud_connection_last_ping_duration", Name: "jetkvm_connection_last_ping_duration",
Help: "The duration of the last ping response", Help: "The duration of the last ping response",
}, },
[]string{"type", "source"},
) )
metricCloudConnectionPingDuration = promauto.NewHistogram( metricConnectionPingDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{ prometheus.HistogramOpts{
Name: "jetkvm_cloud_connection_ping_duration", Name: "jetkvm_connection_ping_duration",
Help: "The duration of the ping response", Help: "The duration of the ping response",
Buckets: []float64{ Buckets: []float64{
0.1, 0.5, 1, 10, 0.1, 0.5, 1, 10,
}, },
}, },
[]string{"type", "source"},
) )
metricCloudConnectionTotalPingCount = promauto.NewCounter( metricConnectionTotalPingCount = promauto.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "jetkvm_cloud_connection_total_ping_count", Name: "jetkvm_connection_total_ping_count",
Help: "The total number of pings sent to the cloud", Help: "The total number of pings sent to the connection",
}, },
[]string{"type", "source"},
) )
metricCloudConnectionSessionRequestCount = promauto.NewCounter( metricConnectionSessionRequestCount = promauto.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "jetkvm_cloud_connection_session_total_request_count", Name: "jetkvm_connection_session_total_request_count",
Help: "The total number of session requests received from the cloud", Help: "The total number of session requests received",
}, },
[]string{"type", "source"},
) )
metricCloudConnectionSessionRequestDuration = promauto.NewHistogram( metricConnectionSessionRequestDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{ prometheus.HistogramOpts{
Name: "jetkvm_cloud_connection_session_request_duration", Name: "jetkvm_connection_session_request_duration",
Help: "The duration of session requests", Help: "The duration of session requests",
Buckets: []float64{ Buckets: []float64{
0.1, 0.5, 1, 10, 0.1, 0.5, 1, 10,
}, },
}, },
[]string{"type", "source"},
) )
metricCloudConnectionLastSessionRequestTimestamp = promauto.NewGauge( metricConnectionLastSessionRequestTimestamp = promauto.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Name: "jetkvm_cloud_connection_last_session_request_timestamp", Name: "jetkvm_connection_last_session_request_timestamp",
Help: "The timestamp of the last session request", Help: "The timestamp of the last session request",
}, },
[]string{"type", "source"},
) )
metricCloudConnectionLastSessionRequestDuration = promauto.NewGauge( metricConnectionLastSessionRequestDuration = promauto.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Name: "jetkvm_cloud_connection_last_session_request_duration", Name: "jetkvm_connection_last_session_request_duration",
Help: "The duration of the last session request", Help: "The duration of the last session request",
}, },
[]string{"type", "source"},
) )
metricCloudConnectionFailureCount = promauto.NewCounter( metricCloudConnectionFailureCount = promauto.NewCounter(
prometheus.CounterOpts{ prometheus.CounterOpts{
@ -119,12 +127,16 @@ var (
cloudDisconnectLock = &sync.Mutex{} cloudDisconnectLock = &sync.Mutex{}
) )
func cloudResetMetrics(established bool) { func wsResetMetrics(established bool, sourceType string, source string) {
metricCloudConnectionLastPingTimestamp.Set(-1) metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).Set(-1)
metricCloudConnectionLastPingDuration.Set(-1) metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(-1)
metricCloudConnectionLastSessionRequestTimestamp.Set(-1) metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).Set(-1)
metricCloudConnectionLastSessionRequestDuration.Set(-1) metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(-1)
if sourceType != "cloud" {
return
}
if established { if established {
metricCloudConnectionEstablishedTimestamp.SetToCurrentTime() metricCloudConnectionEstablishedTimestamp.SetToCurrentTime()
@ -256,6 +268,7 @@ func runWebsocketClient() error {
header := http.Header{} header := http.Header{}
header.Set("X-Device-ID", GetDeviceID()) header.Set("X-Device-ID", GetDeviceID())
header.Set("X-App-Version", builtAppVersion)
header.Set("Authorization", "Bearer "+config.CloudToken) header.Set("Authorization", "Bearer "+config.CloudToken)
dialCtx, cancelDial := context.WithTimeout(context.Background(), CloudWebSocketConnectTimeout) dialCtx, cancelDial := context.WithTimeout(context.Background(), CloudWebSocketConnectTimeout)
@ -270,88 +283,13 @@ func runWebsocketClient() error {
cloudLogger.Infof("websocket connected to %s", wsURL) cloudLogger.Infof("websocket connected to %s", wsURL)
// set the metrics when we successfully connect to the cloud. // set the metrics when we successfully connect to the cloud.
cloudResetMetrics(true) wsResetMetrics(true, "cloud", "")
runCtx, cancelRun := context.WithCancel(context.Background()) // we don't have a source for the cloud connection
defer cancelRun() return handleWebRTCSignalWsMessages(c, true, "")
go func() {
for {
time.Sleep(CloudWebSocketPingInterval)
// set the timer for the ping duration
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricCloudConnectionLastPingDuration.Set(v)
metricCloudConnectionPingDuration.Observe(v)
}))
err := c.Ping(runCtx)
if err != nil {
cloudLogger.Warnf("websocket ping error: %v", err)
cancelRun()
return
}
// dont use `defer` here because we want to observe the duration of the ping
timer.ObserveDuration()
metricCloudConnectionTotalPingCount.Inc()
metricCloudConnectionLastPingTimestamp.SetToCurrentTime()
}
}()
// create a channel to receive the disconnect event, once received, we cancelRun
cloudDisconnectChan = make(chan error)
defer func() {
close(cloudDisconnectChan)
cloudDisconnectChan = nil
}()
go func() {
for err := range cloudDisconnectChan {
if err == nil {
continue
}
cloudLogger.Infof("disconnecting from cloud due to: %v", err)
cancelRun()
}
}()
for {
typ, msg, err := c.Read(runCtx)
if err != nil {
return err
}
if typ != websocket.MessageText {
// ignore non-text messages
continue
}
var req WebRTCSessionRequest
err = json.Unmarshal(msg, &req)
if err != nil {
cloudLogger.Warnf("unable to parse ws message: %v", string(msg))
continue
}
cloudLogger.Infof("new session request: %v", req.OidcGoogle)
cloudLogger.Tracef("session request info: %v", req)
metricCloudConnectionSessionRequestCount.Inc()
metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime()
err = handleSessionRequest(runCtx, c, req)
if err != nil {
cloudLogger.Infof("error starting new session: %v", err)
continue
}
}
} }
func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error { func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricCloudConnectionLastSessionRequestDuration.Set(v)
metricCloudConnectionSessionRequestDuration.Observe(v)
}))
defer timer.ObserveDuration()
oidcCtx, cancelOIDC := context.WithTimeout(ctx, CloudOidcRequestTimeout) oidcCtx, cancelOIDC := context.WithTimeout(ctx, CloudOidcRequestTimeout)
defer cancelOIDC() defer cancelOIDC()
provider, err := oidc.NewProvider(oidcCtx, "https://accounts.google.com") provider, err := oidc.NewProvider(oidcCtx, "https://accounts.google.com")
@ -379,10 +317,35 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
return fmt.Errorf("google identity mismatch") return fmt.Errorf("google identity mismatch")
} }
return nil
}
func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest, isCloudConnection bool, source string) error {
var sourceType string
if isCloudConnection {
sourceType = "cloud"
} else {
sourceType = "local"
}
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(v)
metricConnectionSessionRequestDuration.WithLabelValues(sourceType, source).Observe(v)
}))
defer timer.ObserveDuration()
// If the message is from the cloud, we need to authenticate the session.
if isCloudConnection {
if err := authenticateSession(ctx, c, req); err != nil {
return err
}
}
session, err := newSession(SessionConfig{ session, err := newSession(SessionConfig{
ICEServers: req.ICEServers, ws: c,
IsCloud: isCloudConnection,
LocalIP: req.IP, LocalIP: req.IP,
IsCloud: true, ICEServers: req.ICEServers,
}) })
if err != nil { if err != nil {
_ = wsjson.Write(context.Background(), c, gin.H{"error": err}) _ = wsjson.Write(context.Background(), c, gin.H{"error": err})
@ -406,14 +369,14 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
cloudLogger.Info("new session accepted") cloudLogger.Info("new session accepted")
cloudLogger.Tracef("new session accepted: %v", session) cloudLogger.Tracef("new session accepted: %v", session)
currentSession = session currentSession = session
_ = wsjson.Write(context.Background(), c, gin.H{"sd": sd}) _ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd})
return nil return nil
} }
func RunWebsocketClient() { func RunWebsocketClient() {
for { for {
// reset the metrics when we start the websocket client. // reset the metrics when we start the websocket client.
cloudResetMetrics(false) wsResetMetrics(false, "cloud", "")
// If the cloud token is not set, we don't need to run the websocket client. // If the cloud token is not set, we don't need to run the websocket client.
if config.CloudToken == "" { if config.CloudToken == "" {

1
log.go
View File

@ -6,3 +6,4 @@ import "github.com/pion/logging"
// ref: https://github.com/pion/webrtc/wiki/Debugging-WebRTC // ref: https://github.com/pion/webrtc/wiki/Debugging-WebRTC
var logger = logging.NewDefaultLoggerFactory().NewLogger("jetkvm") var logger = logging.NewDefaultLoggerFactory().NewLogger("jetkvm")
var cloudLogger = logging.NewDefaultLoggerFactory().NewLogger("cloud") var cloudLogger = logging.NewDefaultLoggerFactory().NewLogger("cloud")
var websocketLogger = logging.NewDefaultLoggerFactory().NewLogger("websocket")

6
ui/package-lock.json generated
View File

@ -30,6 +30,7 @@
"react-icons": "^5.4.0", "react-icons": "^5.4.0",
"react-router-dom": "^6.22.3", "react-router-dom": "^6.22.3",
"react-simple-keyboard": "^3.7.112", "react-simple-keyboard": "^3.7.112",
"react-use-websocket": "^4.13.0",
"react-xtermjs": "^1.0.9", "react-xtermjs": "^1.0.9",
"recharts": "^2.15.0", "recharts": "^2.15.0",
"tailwind-merge": "^2.5.5", "tailwind-merge": "^2.5.5",
@ -5180,6 +5181,11 @@
"react-dom": ">=16.6.0" "react-dom": ">=16.6.0"
} }
}, },
"node_modules/react-use-websocket": {
"version": "4.13.0",
"resolved": "https://registry.npmjs.org/react-use-websocket/-/react-use-websocket-4.13.0.tgz",
"integrity": "sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw=="
},
"node_modules/react-xtermjs": { "node_modules/react-xtermjs": {
"version": "1.0.9", "version": "1.0.9",
"resolved": "https://registry.npmjs.org/react-xtermjs/-/react-xtermjs-1.0.9.tgz", "resolved": "https://registry.npmjs.org/react-xtermjs/-/react-xtermjs-1.0.9.tgz",

View File

@ -40,6 +40,7 @@
"react-icons": "^5.4.0", "react-icons": "^5.4.0",
"react-router-dom": "^6.22.3", "react-router-dom": "^6.22.3",
"react-simple-keyboard": "^3.7.112", "react-simple-keyboard": "^3.7.112",
"react-use-websocket": "^4.13.0",
"react-xtermjs": "^1.0.9", "react-xtermjs": "^1.0.9",
"recharts": "^2.15.0", "recharts": "^2.15.0",
"tailwind-merge": "^2.5.5", "tailwind-merge": "^2.5.5",

View File

@ -36,7 +36,7 @@ export default function DashboardNavbar({
picture, picture,
kvmName, kvmName,
}: NavbarProps) { }: NavbarProps) {
const peerConnection = useRTCStore(state => state.peerConnection); const peerConnectionState = useRTCStore(state => state.peerConnectionState);
const setUser = useUserStore(state => state.setUser); const setUser = useUserStore(state => state.setUser);
const navigate = useNavigate(); const navigate = useNavigate();
const onLogout = useCallback(async () => { const onLogout = useCallback(async () => {
@ -82,14 +82,14 @@ export default function DashboardNavbar({
<div className="hidden items-center gap-x-2 md:flex"> <div className="hidden items-center gap-x-2 md:flex">
<div className="w-[159px]"> <div className="w-[159px]">
<PeerConnectionStatusCard <PeerConnectionStatusCard
state={peerConnection?.connectionState} state={peerConnectionState}
title={kvmName} title={kvmName}
/> />
</div> </div>
<div className="hidden w-[159px] md:block"> <div className="hidden w-[159px] md:block">
<USBStateStatus <USBStateStatus
state={usbState} state={usbState}
peerConnectionState={peerConnection?.connectionState} peerConnectionState={peerConnectionState}
/> />
</div> </div>
</div> </div>

View File

@ -6,7 +6,7 @@ import { LuPlay } from "react-icons/lu";
import { Button, LinkButton } from "@components/Button"; import { Button, LinkButton } from "@components/Button";
import LoadingSpinner from "@components/LoadingSpinner"; import LoadingSpinner from "@components/LoadingSpinner";
import { GridCard } from "@components/Card"; import Card, { GridCard } from "@components/Card";
interface OverlayContentProps { interface OverlayContentProps {
children: React.ReactNode; children: React.ReactNode;
@ -94,7 +94,7 @@ interface ConnectionErrorOverlayProps {
setupPeerConnection: () => Promise<void>; setupPeerConnection: () => Promise<void>;
} }
export function ConnectionErrorOverlay({ export function ConnectionFailedOverlay({
show, show,
setupPeerConnection, setupPeerConnection,
}: ConnectionErrorOverlayProps) { }: ConnectionErrorOverlayProps) {
@ -151,6 +151,60 @@ export function ConnectionErrorOverlay({
); );
} }
interface PeerConnectionDisconnectedOverlay {
show: boolean;
}
export function PeerConnectionDisconnectedOverlay({
show,
}: PeerConnectionDisconnectedOverlay) {
return (
<AnimatePresence>
{show && (
<motion.div
className="aspect-video h-full w-full"
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0, transition: { duration: 0 } }}
transition={{
duration: 0.4,
ease: "easeInOut",
}}
>
<OverlayContent>
<div className="flex flex-col items-start gap-y-1">
<ExclamationTriangleIcon className="h-12 w-12 text-yellow-500" />
<div className="text-left text-sm text-slate-700 dark:text-slate-300">
<div className="space-y-4">
<div className="space-y-2 text-black dark:text-white">
<h2 className="text-xl font-bold">Connection Issue Detected</h2>
<ul className="list-disc space-y-2 pl-4 text-left">
<li>Verify that the device is powered on and properly connected</li>
<li>Check all cable connections for any loose or damaged wires</li>
<li>Ensure your network connection is stable and active</li>
<li>Try restarting both the device and your computer</li>
</ul>
</div>
<div className="flex items-center gap-x-2">
<Card>
<div className="flex items-center gap-x-2 p-4">
<LoadingSpinner className="h-4 w-4 text-blue-800 dark:text-blue-200" />
<p className="text-sm text-slate-700 dark:text-slate-300">
Retrying connection...
</p>
</div>
</Card>
</div>
</div>
</div>
</div>
</OverlayContent>
</motion.div>
)}
</AnimatePresence>
);
}
interface HDMIErrorOverlayProps { interface HDMIErrorOverlayProps {
show: boolean; show: boolean;
hdmiState: string; hdmiState: string;

View File

@ -380,7 +380,7 @@ export default function WebRTCVideo() {
(mediaStream: MediaStream) => { (mediaStream: MediaStream) => {
if (!videoElm.current) return; if (!videoElm.current) return;
const videoElmRefValue = videoElm.current; const videoElmRefValue = videoElm.current;
console.log("Adding stream to video element", videoElmRefValue); // console.log("Adding stream to video element", videoElmRefValue);
videoElmRefValue.srcObject = mediaStream; videoElmRefValue.srcObject = mediaStream;
updateVideoSizeStore(videoElmRefValue); updateVideoSizeStore(videoElmRefValue);
}, },
@ -396,7 +396,7 @@ export default function WebRTCVideo() {
peerConnection.addEventListener( peerConnection.addEventListener(
"track", "track",
(e: RTCTrackEvent) => { (e: RTCTrackEvent) => {
console.log("Adding stream to video element"); // console.log("Adding stream to video element");
addStreamToVideoElm(e.streams[0]); addStreamToVideoElm(e.streams[0]);
}, },
{ signal }, { signal },

View File

@ -1,4 +1,4 @@
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { import {
LoaderFunctionArgs, LoaderFunctionArgs,
Outlet, Outlet,
@ -14,6 +14,7 @@ import {
import { useInterval } from "usehooks-ts"; import { useInterval } from "usehooks-ts";
import FocusTrap from "focus-trap-react"; import FocusTrap from "focus-trap-react";
import { motion, AnimatePresence } from "framer-motion"; import { motion, AnimatePresence } from "framer-motion";
import useWebSocket from "react-use-websocket";
import { cx } from "@/cva.config"; import { cx } from "@/cva.config";
import { import {
@ -43,15 +44,16 @@ import UpdateInProgressStatusCard from "../components/UpdateInProgressStatusCard
import api from "../api"; import api from "../api";
import Modal from "../components/Modal"; import Modal from "../components/Modal";
import { useDeviceUiNavigation } from "../hooks/useAppNavigation"; import { useDeviceUiNavigation } from "../hooks/useAppNavigation";
import {
ConnectionFailedOverlay,
LoadingConnectionOverlay,
PeerConnectionDisconnectedOverlay,
} from "../components/VideoOverlay";
import { FeatureFlagProvider } from "../providers/FeatureFlagProvider"; import { FeatureFlagProvider } from "../providers/FeatureFlagProvider";
import notifications from "../notifications"; import notifications from "../notifications";
import {
ConnectionErrorOverlay,
LoadingConnectionOverlay,
} from "../components/VideoOverlay";
import { SystemVersionInfo } from "./devices.$id.settings.general.update";
import { DeviceStatus } from "./welcome-local"; import { DeviceStatus } from "./welcome-local";
import { SystemVersionInfo } from "./devices.$id.settings.general.update";
interface LocalLoaderResp { interface LocalLoaderResp {
authMode: "password" | "noPassword" | null; authMode: "password" | "noPassword" | null;
@ -117,7 +119,6 @@ const loader = async ({ params }: LoaderFunctionArgs) => {
export default function KvmIdRoute() { export default function KvmIdRoute() {
const loaderResp = useLoaderData() as LocalLoaderResp | CloudLoaderResp; const loaderResp = useLoaderData() as LocalLoaderResp | CloudLoaderResp;
// Depending on the mode, we set the appropriate variables // Depending on the mode, we set the appropriate variables
const user = "user" in loaderResp ? loaderResp.user : null; const user = "user" in loaderResp ? loaderResp.user : null;
const deviceName = "deviceName" in loaderResp ? loaderResp.deviceName : null; const deviceName = "deviceName" in loaderResp ? loaderResp.deviceName : null;
@ -130,6 +131,8 @@ export default function KvmIdRoute() {
const setIsTurnServerInUse = useRTCStore(state => state.setTurnServerInUse); const setIsTurnServerInUse = useRTCStore(state => state.setTurnServerInUse);
const peerConnection = useRTCStore(state => state.peerConnection); const peerConnection = useRTCStore(state => state.peerConnection);
const setPeerConnectionState = useRTCStore(state => state.setPeerConnectionState);
const peerConnectionState = useRTCStore(state => state.peerConnectionState);
const setMediaMediaStream = useRTCStore(state => state.setMediaStream); const setMediaMediaStream = useRTCStore(state => state.setMediaStream);
const setPeerConnection = useRTCStore(state => state.setPeerConnection); const setPeerConnection = useRTCStore(state => state.setPeerConnection);
const setDiskChannel = useRTCStore(state => state.setDiskChannel); const setDiskChannel = useRTCStore(state => state.setDiskChannel);
@ -137,23 +140,28 @@ export default function KvmIdRoute() {
const setTransceiver = useRTCStore(state => state.setTransceiver); const setTransceiver = useRTCStore(state => state.setTransceiver);
const location = useLocation(); const location = useLocation();
const isLegacySignalingEnabled = useRef(false);
const [connectionFailed, setConnectionFailed] = useState(false); const [connectionFailed, setConnectionFailed] = useState(false);
const navigate = useNavigate(); const navigate = useNavigate();
const { otaState, setOtaState, setModalView } = useUpdateStore(); const { otaState, setOtaState, setModalView } = useUpdateStore();
const [loadingMessage, setLoadingMessage] = useState("Connecting to device..."); const [loadingMessage, setLoadingMessage] = useState("Connecting to device...");
const closePeerConnection = useCallback( const cleanupAndStopReconnecting = useCallback(
function closePeerConnection() { function cleanupAndStopReconnecting() {
console.log("Closing peer connection"); console.log("Closing peer connection");
setConnectionFailed(true); setConnectionFailed(true);
if (peerConnection) {
setPeerConnectionState(peerConnection.connectionState);
}
connectionFailedRef.current = true; connectionFailedRef.current = true;
peerConnection?.close(); peerConnection?.close();
signalingAttempts.current = 0; signalingAttempts.current = 0;
}, },
[peerConnection], [peerConnection, setPeerConnectionState],
); );
// We need to track connectionFailed in a ref to avoid stale closure issues // We need to track connectionFailed in a ref to avoid stale closure issues
@ -171,95 +179,233 @@ export default function KvmIdRoute() {
}, [connectionFailed]); }, [connectionFailed]);
const signalingAttempts = useRef(0); const signalingAttempts = useRef(0);
const syncRemoteSessionDescription = useCallback( const setRemoteSessionDescription = useCallback(
async function syncRemoteSessionDescription(pc: RTCPeerConnection) { async function setRemoteSessionDescription(
pc: RTCPeerConnection,
remoteDescription: RTCSessionDescriptionInit,
) {
setLoadingMessage("Setting remote description");
try { try {
if (!pc) return; await pc.setRemoteDescription(new RTCSessionDescription(remoteDescription));
console.log("[setRemoteSessionDescription] Remote description set successfully");
const sd = btoa(JSON.stringify(pc.localDescription)); setLoadingMessage("Establishing secure connection...");
const sessionUrl = isOnDevice
? `${DEVICE_API}/webrtc/session`
: `${CLOUD_API}/webrtc/session`;
console.log("Trying to get remote session description");
setLoadingMessage(
`Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`,
);
const res = await api.POST(sessionUrl, {
sd,
// When on device, we don't need to specify the device id, as it's already known
...(isOnDevice ? {} : { id: params.id }),
});
const json = await res.json();
if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login");
if (!res.ok) {
console.error("Error getting SDP", { status: res.status, json });
throw new Error("Error getting SDP");
}
console.log("Successfully got Remote Session Description. Setting.");
setLoadingMessage("Setting remote session description...");
const decodedSd = atob(json.sd);
const parsedSd = JSON.parse(decodedSd);
pc.setRemoteDescription(new RTCSessionDescription(parsedSd));
await new Promise((resolve, reject) => {
console.log("Waiting for remote description to be set");
const maxAttempts = 10;
const interval = 1000;
let attempts = 0;
const checkInterval = setInterval(() => {
attempts++;
// When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects
if (pc.sctp?.state === "connected") {
console.log("Remote description set");
clearInterval(checkInterval);
resolve(true);
} else if (attempts >= maxAttempts) {
console.log(
`Failed to get remote description after ${maxAttempts} attempts`,
);
closePeerConnection();
clearInterval(checkInterval);
reject(
new Error(
`Failed to get remote description after ${maxAttempts} attempts`,
),
);
} else {
console.log("Waiting for remote description to be set");
}
}, interval);
});
} catch (error) { } catch (error) {
console.error("Error getting SDP", { error }); console.error(
console.log("Connection failed", connectionFailedRef.current); "[setRemoteSessionDescription] Failed to set remote description:",
if (connectionFailedRef.current) return; error,
if (signalingAttempts.current < 5) { );
signalingAttempts.current++; cleanupAndStopReconnecting();
await new Promise(resolve => setTimeout(resolve, 500)); return;
console.log("Attempting to get SDP again", signalingAttempts.current);
syncRemoteSessionDescription(pc);
} else {
closePeerConnection();
}
} }
// Replace the interval-based check with a more reliable approach
let attempts = 0;
const checkInterval = setInterval(() => {
attempts++;
// When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects
if (pc.sctp?.state === "connected") {
console.log("[setRemoteSessionDescription] Remote description set");
clearInterval(checkInterval);
setLoadingMessage("Connection established");
} else if (attempts >= 10) {
console.log(
"[setRemoteSessionDescription] Failed to establish connection after 10 attempts",
{
connectionState: pc.connectionState,
iceConnectionState: pc.iceConnectionState,
},
);
cleanupAndStopReconnecting();
clearInterval(checkInterval);
} else {
console.log("[setRemoteSessionDescription] Waiting for connection, state:", {
connectionState: pc.connectionState,
iceConnectionState: pc.iceConnectionState,
});
}
}, 1000);
}, },
[closePeerConnection, navigate, params.id], [cleanupAndStopReconnecting],
);
const ignoreOffer = useRef(false);
const isSettingRemoteAnswerPending = useRef(false);
const makingOffer = useRef(false);
const wsProtocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const { sendMessage, getWebSocket } = useWebSocket(
isOnDevice
? `${wsProtocol}//${window.location.host}/webrtc/signaling/client`
: `${CLOUD_API.replace("http", "ws")}/webrtc/signaling/client?id=${params.id}`,
{
heartbeat: true,
retryOnError: true,
reconnectAttempts: 5,
reconnectInterval: 1000,
onReconnectStop: () => {
console.log("Reconnect stopped");
cleanupAndStopReconnecting();
},
shouldReconnect(event) {
console.log("[Websocket] shouldReconnect", event);
// TODO: Why true?
return true;
},
onClose(event) {
console.log("[Websocket] onClose", event);
// We don't want to close everything down, we wait for the reconnect to stop instead
},
onError(event) {
console.log("[Websocket] onError", event);
// We don't want to close everything down, we wait for the reconnect to stop instead
},
onOpen() {
console.log("[Websocket] onOpen");
},
onMessage: message => {
if (message.data === "pong") return;
/*
Currently the signaling process is as follows:
After open, the other side will send a `device-metadata` message with the device version
If the device version is not set, we can assume the device is using the legacy signaling
Otherwise, we can assume the device is using the new signaling
If the device is using the legacy signaling, we close the websocket connection
and use the legacy HTTPSignaling function to get the remote session description
If the device is using the new signaling, we don't need to do anything special, but continue to use the websocket connection
to chat with the other peer about the connection
*/
const parsedMessage = JSON.parse(message.data);
if (parsedMessage.type === "device-metadata") {
const { deviceVersion } = parsedMessage.data;
console.log("[Websocket] Received device-metadata message");
console.log("[Websocket] Device version", deviceVersion);
// If the device version is not set, we can assume the device is using the legacy signaling
if (!deviceVersion) {
console.log("[Websocket] Device is using legacy signaling");
// Now we don't need the websocket connection anymore, as we've established that we need to use the legacy signaling
// which does everything over HTTP(at least from the perspective of the client)
isLegacySignalingEnabled.current = true;
getWebSocket()?.close();
} else {
console.log("[Websocket] Device is using new signaling");
isLegacySignalingEnabled.current = false;
}
setupPeerConnection();
}
if (!peerConnection) return;
if (parsedMessage.type === "answer") {
console.log("[Websocket] Received answer");
const readyForOffer =
// If we're making an offer, we don't want to accept an answer
!makingOffer &&
// If the peer connection is stable or we're setting the remote answer pending, we're ready for an offer
(peerConnection?.signalingState === "stable" ||
isSettingRemoteAnswerPending.current);
// If we're not ready for an offer, we don't want to accept an offer
ignoreOffer.current = parsedMessage.type === "offer" && !readyForOffer;
if (ignoreOffer.current) return;
// Set so we don't accept an answer while we're setting the remote description
isSettingRemoteAnswerPending.current = parsedMessage.type === "answer";
console.log(
"[Websocket] Setting remote answer pending",
isSettingRemoteAnswerPending.current,
);
const sd = atob(parsedMessage.data);
const remoteSessionDescription = JSON.parse(sd);
setRemoteSessionDescription(
peerConnection,
new RTCSessionDescription(remoteSessionDescription),
);
// Reset the remote answer pending flag
isSettingRemoteAnswerPending.current = false;
} else if (parsedMessage.type === "new-ice-candidate") {
console.log("[Websocket] Received new-ice-candidate");
const candidate = parsedMessage.data;
peerConnection.addIceCandidate(candidate);
}
},
},
// Don't even retry once we declare failure
!connectionFailed && isLegacySignalingEnabled.current === false,
);
const sendWebRTCSignal = useCallback(
(type: string, data: unknown) => {
// Second argument tells the library not to queue the message, and send it once the connection is established again.
// We have event handlers that handle the connection set up, so we don't need to queue the message.
sendMessage(JSON.stringify({ type, data }), false);
},
[sendMessage],
);
const legacyHTTPSignaling = useCallback(
async (pc: RTCPeerConnection) => {
const sd = btoa(JSON.stringify(pc.localDescription));
// Legacy mode == UI in cloud with updated code connecting to older device version.
// In device mode, old devices wont server this JS, and on newer devices legacy mode wont be enabled
const sessionUrl = `${CLOUD_API}/webrtc/session`;
console.log("Trying to get remote session description");
setLoadingMessage(
`Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`,
);
const res = await api.POST(sessionUrl, {
sd,
// When on device, we don't need to specify the device id, as it's already known
...(isOnDevice ? {} : { id: params.id }),
});
const json = await res.json();
if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login");
if (!res.ok) {
console.error("Error getting SDP", { status: res.status, json });
cleanupAndStopReconnecting();
return;
}
console.log("Successfully got Remote Session Description. Setting.");
setLoadingMessage("Setting remote session description...");
const decodedSd = atob(json.sd);
const parsedSd = JSON.parse(decodedSd);
setRemoteSessionDescription(pc, new RTCSessionDescription(parsedSd));
},
[cleanupAndStopReconnecting, navigate, params.id, setRemoteSessionDescription],
); );
const setupPeerConnection = useCallback(async () => { const setupPeerConnection = useCallback(async () => {
console.log("Setting up peer connection"); console.log("[setupPeerConnection] Setting up peer connection");
setConnectionFailed(false); setConnectionFailed(false);
setLoadingMessage("Connecting to device..."); setLoadingMessage("Connecting to device...");
if (peerConnection?.signalingState === "stable") {
console.log("[setupPeerConnection] Peer connection already established");
return;
}
let pc: RTCPeerConnection; let pc: RTCPeerConnection;
try { try {
console.log("Creating peer connection"); console.log("[setupPeerConnection] Creating peer connection");
setLoadingMessage("Creating peer connection..."); setLoadingMessage("Creating peer connection...");
pc = new RTCPeerConnection({ pc = new RTCPeerConnection({
// We only use STUN or TURN servers if we're in the cloud // We only use STUN or TURN servers if we're in the cloud
@ -267,30 +413,65 @@ export default function KvmIdRoute() {
? { iceServers: [iceConfig?.iceServers] } ? { iceServers: [iceConfig?.iceServers] }
: {}), : {}),
}); });
console.log("Peer connection created", pc);
setLoadingMessage("Peer connection created"); setPeerConnectionState(pc.connectionState);
console.log("[setupPeerConnection] Peer connection created", pc);
setLoadingMessage("Setting up connection to device...");
} catch (e) { } catch (e) {
console.error(`Error creating peer connection: ${e}`); console.error(`[setupPeerConnection] Error creating peer connection: ${e}`);
setTimeout(() => { setTimeout(() => {
closePeerConnection(); cleanupAndStopReconnecting();
}, 1000); }, 1000);
return; return;
} }
// Set up event listeners and data channels // Set up event listeners and data channels
pc.onconnectionstatechange = () => { pc.onconnectionstatechange = () => {
console.log("Connection state changed", pc.connectionState); console.log("[setupPeerConnection] Connection state changed", pc.connectionState);
setPeerConnectionState(pc.connectionState);
};
pc.onnegotiationneeded = async () => {
try {
console.log("[setupPeerConnection] Creating offer");
makingOffer.current = true;
const offer = await pc.createOffer();
await pc.setLocalDescription(offer);
const sd = btoa(JSON.stringify(pc.localDescription));
const isNewSignalingEnabled = isLegacySignalingEnabled.current === false;
if (isNewSignalingEnabled) {
sendWebRTCSignal("offer", { sd: sd });
} else {
console.log("Legacy signanling. Waiting for ICE Gathering to complete...");
}
} catch (e) {
console.error(
`[setupPeerConnection] Error creating offer: ${e}`,
new Date().toISOString(),
);
cleanupAndStopReconnecting();
} finally {
makingOffer.current = false;
}
};
pc.onicecandidate = async ({ candidate }) => {
if (!candidate) return;
if (candidate.candidate === "") return;
sendWebRTCSignal("new-ice-candidate", candidate);
}; };
pc.onicegatheringstatechange = event => { pc.onicegatheringstatechange = event => {
const pc = event.currentTarget as RTCPeerConnection; const pc = event.currentTarget as RTCPeerConnection;
console.log("ICE Gathering State Changed", pc.iceGatheringState);
if (pc.iceGatheringState === "complete") { if (pc.iceGatheringState === "complete") {
console.log("ICE Gathering completed"); console.log("ICE Gathering completed");
setLoadingMessage("ICE Gathering completed"); setLoadingMessage("ICE Gathering completed");
// We can now start the https/ws connection to get the remote session description from the KVM device if (isLegacySignalingEnabled.current) {
syncRemoteSessionDescription(pc); // We can now start the https/ws connection to get the remote session description from the KVM device
legacyHTTPSignaling(pc);
}
} else if (pc.iceGatheringState === "gathering") { } else if (pc.iceGatheringState === "gathering") {
console.log("ICE Gathering Started"); console.log("ICE Gathering Started");
setLoadingMessage("Gathering ICE candidates..."); setLoadingMessage("Gathering ICE candidates...");
@ -314,31 +495,26 @@ export default function KvmIdRoute() {
}; };
setPeerConnection(pc); setPeerConnection(pc);
try {
const offer = await pc.createOffer();
await pc.setLocalDescription(offer);
} catch (e) {
console.error(`Error creating offer: ${e}`, new Date().toISOString());
closePeerConnection();
}
}, [ }, [
closePeerConnection, cleanupAndStopReconnecting,
iceConfig?.iceServers, iceConfig?.iceServers,
legacyHTTPSignaling,
peerConnection?.signalingState,
sendWebRTCSignal,
setDiskChannel, setDiskChannel,
setMediaMediaStream, setMediaMediaStream,
setPeerConnection, setPeerConnection,
setPeerConnectionState,
setRpcDataChannel, setRpcDataChannel,
setTransceiver, setTransceiver,
syncRemoteSessionDescription,
]); ]);
// On boot, if the connection state is undefined, we connect to the WebRTC
useEffect(() => { useEffect(() => {
if (peerConnection?.connectionState === undefined) { if (peerConnectionState === "failed") {
setupPeerConnection(); console.log("Connection failed, closing peer connection");
cleanupAndStopReconnecting();
} }
}, [setupPeerConnection, peerConnection?.connectionState]); }, [peerConnectionState, cleanupAndStopReconnecting]);
// Cleanup effect // Cleanup effect
const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats); const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats);
@ -363,7 +539,7 @@ export default function KvmIdRoute() {
// TURN server usage detection // TURN server usage detection
useEffect(() => { useEffect(() => {
if (peerConnection?.connectionState !== "connected") return; if (peerConnectionState !== "connected") return;
const { localCandidateStats, remoteCandidateStats } = useRTCStore.getState(); const { localCandidateStats, remoteCandidateStats } = useRTCStore.getState();
const lastLocalStat = Array.from(localCandidateStats).pop(); const lastLocalStat = Array.from(localCandidateStats).pop();
@ -375,7 +551,7 @@ export default function KvmIdRoute() {
const remoteCandidateIsUsingTurn = lastRemoteStat[1].candidateType === "relay"; // [0] is the timestamp, which we don't care about here const remoteCandidateIsUsingTurn = lastRemoteStat[1].candidateType === "relay"; // [0] is the timestamp, which we don't care about here
setIsTurnServerInUse(localCandidateIsUsingTurn || remoteCandidateIsUsingTurn); setIsTurnServerInUse(localCandidateIsUsingTurn || remoteCandidateIsUsingTurn);
}, [peerConnection?.connectionState, setIsTurnServerInUse]); }, [peerConnectionState, setIsTurnServerInUse]);
// TURN server usage reporting // TURN server usage reporting
const isTurnServerInUse = useRTCStore(state => state.isTurnServerInUse); const isTurnServerInUse = useRTCStore(state => state.isTurnServerInUse);
@ -466,10 +642,6 @@ export default function KvmIdRoute() {
}); });
}, [rpcDataChannel?.readyState, send, setHdmiState]); }, [rpcDataChannel?.readyState, send, setHdmiState]);
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error
window.send = send;
// When the update is successful, we need to refresh the client javascript and show a success modal // When the update is successful, we need to refresh the client javascript and show a success modal
useEffect(() => { useEffect(() => {
if (queryParams.get("updateSuccess")) { if (queryParams.get("updateSuccess")) {
@ -506,12 +678,12 @@ export default function KvmIdRoute() {
useEffect(() => { useEffect(() => {
if (!peerConnection) return; if (!peerConnection) return;
if (!kvmTerminal) { if (!kvmTerminal) {
console.log('Creating data channel "terminal"'); // console.log('Creating data channel "terminal"');
setKvmTerminal(peerConnection.createDataChannel("terminal")); setKvmTerminal(peerConnection.createDataChannel("terminal"));
} }
if (!serialConsole) { if (!serialConsole) {
console.log('Creating data channel "serial"'); // console.log('Creating data channel "serial"');
setSerialConsole(peerConnection.createDataChannel("serial")); setSerialConsole(peerConnection.createDataChannel("serial"));
} }
}, [kvmTerminal, peerConnection, serialConsole]); }, [kvmTerminal, peerConnection, serialConsole]);
@ -554,6 +726,43 @@ export default function KvmIdRoute() {
[send, setScrollSensitivity], [send, setScrollSensitivity],
); );
const ConnectionStatusElement = useMemo(() => {
const hasConnectionFailed =
connectionFailed || ["failed", "closed"].includes(peerConnectionState || "");
const isPeerConnectionLoading =
["connecting", "new"].includes(peerConnectionState || "") ||
peerConnection === null;
const isDisconnected = peerConnectionState === "disconnected";
const isOtherSession = location.pathname.includes("other-session");
if (isOtherSession) return null;
if (peerConnectionState === "connected") return null;
if (isDisconnected) {
return <PeerConnectionDisconnectedOverlay show={true} />;
}
if (hasConnectionFailed)
return (
<ConnectionFailedOverlay show={true} setupPeerConnection={setupPeerConnection} />
);
if (isPeerConnectionLoading) {
return <LoadingConnectionOverlay show={true} text={loadingMessage} />;
}
return null;
}, [
connectionFailed,
loadingMessage,
location.pathname,
peerConnection,
peerConnectionState,
setupPeerConnection,
]);
return ( return (
<FeatureFlagProvider appVersion={appVersion}> <FeatureFlagProvider appVersion={appVersion}>
{!outlet && otaState.updating && ( {!outlet && otaState.updating && (
@ -593,27 +802,13 @@ export default function KvmIdRoute() {
/> />
<div className="flex h-full w-full overflow-hidden"> <div className="flex h-full w-full overflow-hidden">
<div className="pointer-events-none fixed inset-0 isolate z-50 flex h-full w-full items-center justify-center"> <div className="pointer-events-none fixed inset-0 isolate z-20 flex h-full w-full items-center justify-center">
<div className="my-2 h-full max-h-[720px] w-full max-w-[1280px] rounded-md"> <div className="my-2 h-full max-h-[720px] w-full max-w-[1280px] rounded-md">
<LoadingConnectionOverlay {!!ConnectionStatusElement && ConnectionStatusElement}
show={
!connectionFailed &&
(["connecting", "new"].includes(
peerConnection?.connectionState || "",
) ||
peerConnection === null) &&
!location.pathname.includes("other-session")
}
text={loadingMessage}
/>
<ConnectionErrorOverlay
show={connectionFailed && !location.pathname.includes("other-session")}
setupPeerConnection={setupPeerConnection}
/>
</div> </div>
</div> </div>
<WebRTCVideo /> {peerConnectionState === "connected" && <WebRTCVideo />}
<SidebarContainer sidebarView={sidebarView} /> <SidebarContainer sidebarView={sidebarView} />
</div> </div>
</div> </div>

188
web.go
View File

@ -1,6 +1,7 @@
package kvm package kvm
import ( import (
"context"
"embed" "embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -10,8 +11,12 @@ import (
"strings" "strings"
"time" "time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pion/webrtc/v4"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -94,7 +99,7 @@ func setupRouter() *gin.Engine {
protected := r.Group("/") protected := r.Group("/")
protected.Use(protectedMiddleware()) protected.Use(protectedMiddleware())
{ {
protected.POST("/webrtc/session", handleWebRTCSession) protected.GET("/webrtc/signaling/client", handleLocalWebRTCSignal)
protected.POST("/cloud/register", handleCloudRegister) protected.POST("/cloud/register", handleCloudRegister)
protected.GET("/cloud/state", handleCloudState) protected.GET("/cloud/state", handleCloudState)
protected.GET("/device", handleDevice) protected.GET("/device", handleDevice)
@ -121,35 +126,182 @@ func setupRouter() *gin.Engine {
// TODO: support multiple sessions? // TODO: support multiple sessions?
var currentSession *Session var currentSession *Session
func handleWebRTCSession(c *gin.Context) { func handleLocalWebRTCSignal(c *gin.Context) {
var req WebRTCSessionRequest cloudLogger.Infof("new websocket connection established")
// Create WebSocket options with InsecureSkipVerify to bypass origin check
if err := c.ShouldBindJSON(&req); err != nil { wsOptions := &websocket.AcceptOptions{
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) InsecureSkipVerify: true, // Allow connections from any origin
return
} }
session, err := newSession(SessionConfig{}) wsCon, err := websocket.Accept(c.Writer, c.Request, wsOptions)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sd, err := session.ExchangeOffer(req.Sd) // get the source from the request
source := c.ClientIP()
// Now use conn for websocket operations
defer wsCon.Close(websocket.StatusNormalClosure, "")
err = wsjson.Write(context.Background(), wsCon, gin.H{"type": "device-metadata", "data": gin.H{"deviceVersion": builtAppVersion}})
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if currentSession != nil {
writeJSONRPCEvent("otherSessionConnected", nil, currentSession) err = handleWebRTCSignalWsMessages(wsCon, false, source)
peerConn := currentSession.peerConnection if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, source string) error {
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
// Add connection tracking to detect reconnections
connectionID := uuid.New().String()
cloudLogger.Infof("new websocket connection established with ID: %s", connectionID)
// connection type
var sourceType string
if isCloudConnection {
sourceType = "cloud"
} else {
sourceType = "local"
}
// probably we can use a better logging framework here
logInfof := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Infof(format+", source: %s, sourceType: %s", args...)
}
logWarnf := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Warnf(format+", source: %s, sourceType: %s", args...)
}
logTracef := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Tracef(format+", source: %s, sourceType: %s", args...)
}
go func() {
for {
time.Sleep(WebsocketPingInterval)
// set the timer for the ping duration
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(v)
metricConnectionPingDuration.WithLabelValues(sourceType, source).Observe(v)
}))
logInfof("pinging websocket")
err := wsCon.Ping(runCtx)
if err != nil {
logWarnf("websocket ping error: %v", err)
cancelRun()
return
}
// dont use `defer` here because we want to observe the duration of the ping
timer.ObserveDuration()
metricConnectionTotalPingCount.WithLabelValues(sourceType, source).Inc()
metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
}
}()
if isCloudConnection {
// create a channel to receive the disconnect event, once received, we cancelRun
cloudDisconnectChan = make(chan error)
defer func() {
close(cloudDisconnectChan)
cloudDisconnectChan = nil
}()
go func() { go func() {
time.Sleep(1 * time.Second) for err := range cloudDisconnectChan {
_ = peerConn.Close() if err == nil {
continue
}
cloudLogger.Infof("disconnecting from cloud due to: %v", err)
cancelRun()
}
}() }()
} }
currentSession = session
c.JSON(http.StatusOK, gin.H{"sd": sd}) for {
typ, msg, err := wsCon.Read(runCtx)
if err != nil {
logWarnf("websocket read error: %v", err)
return err
}
if typ != websocket.MessageText {
// ignore non-text messages
continue
}
var message struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}
err = json.Unmarshal(msg, &message)
if err != nil {
logWarnf("unable to parse ws message: %v", err)
continue
}
if message.Type == "offer" {
logInfof("new session request received")
var req WebRTCSessionRequest
err = json.Unmarshal(message.Data, &req)
if err != nil {
logWarnf("unable to parse session request data: %v", err)
continue
}
logInfof("new session request: %v", req.OidcGoogle)
logTracef("session request info: %v", req)
metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc()
metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source)
if err != nil {
logWarnf("error starting new session: %v", err)
continue
}
} else if message.Type == "new-ice-candidate" {
logInfof("The client sent us a new ICE candidate: %v", string(message.Data))
var candidate webrtc.ICECandidateInit
// Attempt to unmarshal as a ICECandidateInit
if err := json.Unmarshal(message.Data, &candidate); err != nil {
logWarnf("unable to parse incoming ICE candidate data: %v", string(message.Data))
continue
}
if candidate.Candidate == "" {
logWarnf("empty incoming ICE candidate, skipping")
continue
}
logInfof("unmarshalled incoming ICE candidate: %v", candidate)
if currentSession == nil {
logInfof("no current session, skipping incoming ICE candidate")
continue
}
logInfof("adding incoming ICE candidate to current session: %v", candidate)
if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil {
logWarnf("failed to add incoming ICE candidate to our peer connection: %v", err)
}
}
}
} }
func handleLogin(c *gin.Context) { func handleLogin(c *gin.Context) {

View File

@ -1,11 +1,15 @@
package kvm package kvm
import ( import (
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"net" "net"
"strings" "strings"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gin-gonic/gin"
"github.com/pion/webrtc/v4" "github.com/pion/webrtc/v4"
) )
@ -23,6 +27,7 @@ type SessionConfig struct {
ICEServers []string ICEServers []string
LocalIP string LocalIP string
IsCloud bool IsCloud bool
ws *websocket.Conn
} }
func (s *Session) ExchangeOffer(offerStr string) (string, error) { func (s *Session) ExchangeOffer(offerStr string) (string, error) {
@ -46,19 +51,11 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) {
return "", err return "", err
} }
// Create channel that is blocked until ICE Gathering is complete
gatherComplete := webrtc.GatheringCompletePromise(s.peerConnection)
// Sets the LocalDescription, and starts our UDP listeners // Sets the LocalDescription, and starts our UDP listeners
if err = s.peerConnection.SetLocalDescription(answer); err != nil { if err = s.peerConnection.SetLocalDescription(answer); err != nil {
return "", err return "", err
} }
// Block until ICE Gathering is complete, disabling trickle ICE
// we do this because we only can exchange one signaling message
// in a production application you should exchange ICE Candidates via OnICECandidate
<-gatherComplete
localDescription, err := json.Marshal(s.peerConnection.LocalDescription()) localDescription, err := json.Marshal(s.peerConnection.LocalDescription())
if err != nil { if err != nil {
return "", err return "", err
@ -144,6 +141,16 @@ func newSession(config SessionConfig) (*Session, error) {
}() }()
var isConnected bool var isConnected bool
peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
logger.Infof("Our WebRTC peerConnection has a new ICE candidate: %v", candidate)
if candidate != nil {
err := wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
if err != nil {
logger.Errorf("failed to write new-ice-candidate to WebRTC signaling channel: %v", err)
}
}
})
peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
logger.Infof("Connection State has changed %s", connectionState) logger.Infof("Connection State has changed %s", connectionState)
if connectionState == webrtc.ICEConnectionStateConnected { if connectionState == webrtc.ICEConnectionStateConnected {