diff --git a/cloud.go b/cloud.go
index be53b08..7ad8b75 100644
--- a/cloud.go
+++ b/cloud.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"net/http"
 	"net/url"
+	"sync"
 	"time"
 
 	"github.com/coder/websocket/wsjson"
@@ -34,8 +35,8 @@ const (
 	// CloudOidcRequestTimeout is the timeout for OIDC token verification requests
 	// should be lower than the websocket response timeout set in cloud-api
 	CloudOidcRequestTimeout = 10 * time.Second
-	// CloudWebSocketPingInterval is the interval at which the websocket client sends ping messages to the cloud
-	CloudWebSocketPingInterval = 15 * time.Second
+	// WebsocketPingInterval is the interval at which the websocket client sends ping messages to the cloud
+	WebsocketPingInterval = 15 * time.Second
 )
 
 var (
@@ -51,59 +52,67 @@ var (
 			Help: "The timestamp when the cloud connection was established",
 		},
 	)
-	metricCloudConnectionLastPingTimestamp = promauto.NewGauge(
+	metricConnectionLastPingTimestamp = promauto.NewGaugeVec(
 		prometheus.GaugeOpts{
-			Name: "jetkvm_cloud_connection_last_ping_timestamp",
+			Name: "jetkvm_connection_last_ping_timestamp",
 			Help: "The timestamp when the last ping response was received",
 		},
+		[]string{"type", "source"},
 	)
-	metricCloudConnectionLastPingDuration = promauto.NewGauge(
+	metricConnectionLastPingDuration = promauto.NewGaugeVec(
 		prometheus.GaugeOpts{
-			Name: "jetkvm_cloud_connection_last_ping_duration",
+			Name: "jetkvm_connection_last_ping_duration",
 			Help: "The duration of the last ping response",
 		},
+		[]string{"type", "source"},
 	)
-	metricCloudConnectionPingDuration = promauto.NewHistogram(
+	metricConnectionPingDuration = promauto.NewHistogramVec(
 		prometheus.HistogramOpts{
-			Name: "jetkvm_cloud_connection_ping_duration",
+			Name: "jetkvm_connection_ping_duration",
 			Help: "The duration of the ping response",
 			Buckets: []float64{
 				0.1, 0.5, 1, 10,
 			},
 		},
+		[]string{"type", "source"},
 	)
-	metricCloudConnectionTotalPingCount = promauto.NewCounter(
+	metricConnectionTotalPingCount = promauto.NewCounterVec(
 		prometheus.CounterOpts{
-			Name: "jetkvm_cloud_connection_total_ping_count",
-			Help: "The total number of pings sent to the cloud",
+			Name: "jetkvm_connection_total_ping_count",
+			Help: "The total number of pings sent to the connection",
 		},
+		[]string{"type", "source"},
 	)
-	metricCloudConnectionSessionRequestCount = promauto.NewCounter(
+	metricConnectionSessionRequestCount = promauto.NewCounterVec(
 		prometheus.CounterOpts{
-			Name: "jetkvm_cloud_connection_session_total_request_count",
-			Help: "The total number of session requests received from the cloud",
+			Name: "jetkvm_connection_session_total_request_count",
+			Help: "The total number of session requests received",
 		},
+		[]string{"type", "source"},
 	)
-	metricCloudConnectionSessionRequestDuration = promauto.NewHistogram(
+	metricConnectionSessionRequestDuration = promauto.NewHistogramVec(
 		prometheus.HistogramOpts{
-			Name: "jetkvm_cloud_connection_session_request_duration",
+			Name: "jetkvm_connection_session_request_duration",
 			Help: "The duration of session requests",
 			Buckets: []float64{
 				0.1, 0.5, 1, 10,
 			},
 		},
+		[]string{"type", "source"},
 	)
-	metricCloudConnectionLastSessionRequestTimestamp = promauto.NewGauge(
+	metricConnectionLastSessionRequestTimestamp = promauto.NewGaugeVec(
 		prometheus.GaugeOpts{
-			Name: "jetkvm_cloud_connection_last_session_request_timestamp",
+			Name: "jetkvm_connection_last_session_request_timestamp",
 			Help: "The timestamp of the last session request",
 		},
+		[]string{"type", "source"},
 	)
-	metricCloudConnectionLastSessionRequestDuration = promauto.NewGauge(
+	metricConnectionLastSessionRequestDuration = promauto.NewGaugeVec(
 		prometheus.GaugeOpts{
-			Name: "jetkvm_cloud_connection_last_session_request_duration",
+			Name: "jetkvm_connection_last_session_request_duration",
 			Help: "The duration of the last session request",
 		},
+		[]string{"type", "source"},
 	)
 	metricCloudConnectionFailureCount = promauto.NewCounter(
 		prometheus.CounterOpts{
@@ -113,12 +122,21 @@ var (
 	)
 )
 
-func cloudResetMetrics(established bool) {
-	metricCloudConnectionLastPingTimestamp.Set(-1)
-	metricCloudConnectionLastPingDuration.Set(-1)
+var (
+	cloudDisconnectChan chan error
+	cloudDisconnectLock = &sync.Mutex{}
+)
 
-	metricCloudConnectionLastSessionRequestTimestamp.Set(-1)
-	metricCloudConnectionLastSessionRequestDuration.Set(-1)
+func wsResetMetrics(established bool, sourceType string, source string) {
+	metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).Set(-1)
+	metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(-1)
+
+	metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).Set(-1)
+	metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(-1)
+
+	if sourceType != "cloud" {
+		return
+	}
 
 	if established {
 		metricCloudConnectionEstablishedTimestamp.SetToCurrentTime()
@@ -213,6 +231,24 @@ func handleCloudRegister(c *gin.Context) {
 	c.JSON(200, gin.H{"message": "Cloud registration successful"})
 }
 
+func disconnectCloud(reason error) {
+	cloudDisconnectLock.Lock()
+	defer cloudDisconnectLock.Unlock()
+
+	if cloudDisconnectChan == nil {
+		cloudLogger.Tracef("cloud disconnect channel is not set, no need to disconnect")
+		return
+	}
+
+	// just in case the channel is closed, we don't want to panic
+	defer func() {
+		if r := recover(); r != nil {
+			cloudLogger.Infof("cloud disconnect channel is closed, no need to disconnect: %v", r)
+		}
+	}()
+	cloudDisconnectChan <- reason
+}
+
 func runWebsocketClient() error {
 	if config.CloudToken == "" {
 		time.Sleep(5 * time.Second)
@@ -232,6 +268,7 @@ func runWebsocketClient() error {
 
 	header := http.Header{}
 	header.Set("X-Device-ID", GetDeviceID())
+	header.Set("X-App-Version", builtAppVersion)
 	header.Set("Authorization", "Bearer "+config.CloudToken)
 	dialCtx, cancelDial := context.WithTimeout(context.Background(), CloudWebSocketConnectTimeout)
 
@@ -246,71 +283,13 @@ func runWebsocketClient() error {
 	cloudLogger.Infof("websocket connected to %s", wsURL)
 
 	// set the metrics when we successfully connect to the cloud.
-	cloudResetMetrics(true)
+	wsResetMetrics(true, "cloud", "")
 
-	runCtx, cancelRun := context.WithCancel(context.Background())
-	defer cancelRun()
-	go func() {
-		for {
-			time.Sleep(CloudWebSocketPingInterval)
-
-			// set the timer for the ping duration
-			timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
-				metricCloudConnectionLastPingDuration.Set(v)
-				metricCloudConnectionPingDuration.Observe(v)
-			}))
-
-			err := c.Ping(runCtx)
-
-			if err != nil {
-				cloudLogger.Warnf("websocket ping error: %v", err)
-				cancelRun()
-				return
-			}
-
-			// dont use `defer` here because we want to observe the duration of the ping
-			timer.ObserveDuration()
-
-			metricCloudConnectionTotalPingCount.Inc()
-			metricCloudConnectionLastPingTimestamp.SetToCurrentTime()
-		}
-	}()
-	for {
-		typ, msg, err := c.Read(runCtx)
-		if err != nil {
-			return err
-		}
-		if typ != websocket.MessageText {
-			// ignore non-text messages
-			continue
-		}
-		var req WebRTCSessionRequest
-		err = json.Unmarshal(msg, &req)
-		if err != nil {
-			cloudLogger.Warnf("unable to parse ws message: %v", string(msg))
-			continue
-		}
-
-		cloudLogger.Infof("new session request: %v", req.OidcGoogle)
-		cloudLogger.Tracef("session request info: %v", req)
-
-		metricCloudConnectionSessionRequestCount.Inc()
-		metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime()
-		err = handleSessionRequest(runCtx, c, req)
-		if err != nil {
-			cloudLogger.Infof("error starting new session: %v", err)
-			continue
-		}
-	}
+	// we don't have a source for the cloud connection
+	return handleWebRTCSignalWsMessages(c, true, "")
 }
 
-func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
-	timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
-		metricCloudConnectionLastSessionRequestDuration.Set(v)
-		metricCloudConnectionSessionRequestDuration.Observe(v)
-	}))
-	defer timer.ObserveDuration()
-
+func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
 	oidcCtx, cancelOIDC := context.WithTimeout(ctx, CloudOidcRequestTimeout)
 	defer cancelOIDC()
 	provider, err := oidc.NewProvider(oidcCtx, "https://accounts.google.com")
@@ -338,10 +317,35 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
 		return fmt.Errorf("google identity mismatch")
 	}
 
+	return nil
+}
+
+func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest, isCloudConnection bool, source string) error {
+	var sourceType string
+	if isCloudConnection {
+		sourceType = "cloud"
+	} else {
+		sourceType = "local"
+	}
+
+	timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
+		metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(v)
+		metricConnectionSessionRequestDuration.WithLabelValues(sourceType, source).Observe(v)
+	}))
+	defer timer.ObserveDuration()
+
+	// If the message is from the cloud, we need to authenticate the session.
+	if isCloudConnection {
+		if err := authenticateSession(ctx, c, req); err != nil {
+			return err
+		}
+	}
+
 	session, err := newSession(SessionConfig{
-		ICEServers: req.ICEServers,
+		ws:         c,
+		IsCloud:    isCloudConnection,
 		LocalIP:    req.IP,
-		IsCloud:    true,
+		ICEServers: req.ICEServers,
 	})
 	if err != nil {
 		_ = wsjson.Write(context.Background(), c, gin.H{"error": err})
@@ -365,14 +369,14 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
 	cloudLogger.Info("new session accepted")
 	cloudLogger.Tracef("new session accepted: %v", session)
 	currentSession = session
-	_ = wsjson.Write(context.Background(), c, gin.H{"sd": sd})
+	_ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd})
 	return nil
 }
 
 func RunWebsocketClient() {
 	for {
 		// reset the metrics when we start the websocket client.
-		cloudResetMetrics(false)
+		wsResetMetrics(false, "cloud", "")
 
 		// If the cloud token is not set, we don't need to run the websocket client.
 		if config.CloudToken == "" {
@@ -448,6 +452,9 @@ func rpcDeregisterDevice() error {
 			return fmt.Errorf("failed to save configuration after deregistering: %w", err)
 		}
 
+		cloudLogger.Infof("device deregistered, disconnecting from cloud")
+		disconnectCloud(fmt.Errorf("device deregistered"))
+
 		return nil
 	}
 
diff --git a/jsonrpc.go b/jsonrpc.go
index 5dc19e0..e5deb49 100644
--- a/jsonrpc.go
+++ b/jsonrpc.go
@@ -771,9 +771,14 @@ func rpcSetUsbDeviceState(device string, enabled bool) error {
 }
 
 func rpcSetCloudUrl(apiUrl string, appUrl string) error {
+	currentCloudURL := config.CloudURL
 	config.CloudURL = apiUrl
 	config.CloudAppURL = appUrl
 
+	if currentCloudURL != apiUrl {
+		disconnectCloud(fmt.Errorf("cloud url changed from %s to %s", currentCloudURL, apiUrl))
+	}
+
 	if err := SaveConfig(); err != nil {
 		return fmt.Errorf("failed to save config: %w", err)
 	}
diff --git a/log.go b/log.go
index 7718a28..0d36c0d 100644
--- a/log.go
+++ b/log.go
@@ -6,3 +6,4 @@ import "github.com/pion/logging"
 // ref: https://github.com/pion/webrtc/wiki/Debugging-WebRTC
 var logger = logging.NewDefaultLoggerFactory().NewLogger("jetkvm")
 var cloudLogger = logging.NewDefaultLoggerFactory().NewLogger("cloud")
+var websocketLogger = logging.NewDefaultLoggerFactory().NewLogger("websocket")
diff --git a/ota.go b/ota.go
index f813c09..9c583b6 100644
--- a/ota.go
+++ b/ota.go
@@ -126,7 +126,15 @@ func downloadFile(ctx context.Context, path string, url string, downloadProgress
 		return fmt.Errorf("error creating request: %w", err)
 	}
 
-	resp, err := http.DefaultClient.Do(req)
+	client := http.Client{
+		// allow a longer timeout for the download but keep the TLS handshake short
+		Timeout: 10 * time.Minute,
+		Transport: &http.Transport{
+			TLSHandshakeTimeout: 1 * time.Minute,
+		},
+	}
+
+	resp, err := client.Do(req)
 	if err != nil {
 		return fmt.Errorf("error downloading file: %w", err)
 	}
diff --git a/ui/package-lock.json b/ui/package-lock.json
index e9caa20..ebce148 100644
--- a/ui/package-lock.json
+++ b/ui/package-lock.json
@@ -30,6 +30,7 @@
         "react-icons": "^5.4.0",
         "react-router-dom": "^6.22.3",
         "react-simple-keyboard": "^3.7.112",
+        "react-use-websocket": "^4.13.0",
         "react-xtermjs": "^1.0.9",
         "recharts": "^2.15.0",
         "tailwind-merge": "^2.5.5",
@@ -5180,6 +5181,11 @@
         "react-dom": ">=16.6.0"
       }
     },
+    "node_modules/react-use-websocket": {
+      "version": "4.13.0",
+      "resolved": "https://registry.npmjs.org/react-use-websocket/-/react-use-websocket-4.13.0.tgz",
+      "integrity": "sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw=="
+    },
     "node_modules/react-xtermjs": {
       "version": "1.0.9",
       "resolved": "https://registry.npmjs.org/react-xtermjs/-/react-xtermjs-1.0.9.tgz",
diff --git a/ui/package.json b/ui/package.json
index f8f1c7a..a248616 100644
--- a/ui/package.json
+++ b/ui/package.json
@@ -40,6 +40,7 @@
     "react-icons": "^5.4.0",
     "react-router-dom": "^6.22.3",
     "react-simple-keyboard": "^3.7.112",
+    "react-use-websocket": "^4.13.0",
     "react-xtermjs": "^1.0.9",
     "recharts": "^2.15.0",
     "tailwind-merge": "^2.5.5",
diff --git a/ui/src/components/Header.tsx b/ui/src/components/Header.tsx
index 03a907e..19e9652 100644
--- a/ui/src/components/Header.tsx
+++ b/ui/src/components/Header.tsx
@@ -36,7 +36,7 @@ export default function DashboardNavbar({
   picture,
   kvmName,
 }: NavbarProps) {
-  const peerConnection = useRTCStore(state => state.peerConnection);
+  const peerConnectionState = useRTCStore(state => state.peerConnectionState);
   const setUser = useUserStore(state => state.setUser);
   const navigate = useNavigate();
   const onLogout = useCallback(async () => {
@@ -82,14 +82,14 @@ export default function DashboardNavbar({
                 <div className="hidden items-center gap-x-2 md:flex">
                   <div className="w-[159px]">
                     <PeerConnectionStatusCard
-                      state={peerConnection?.connectionState}
+                      state={peerConnectionState}
                       title={kvmName}
                     />
                   </div>
                   <div className="hidden w-[159px] md:block">
                     <USBStateStatus
                       state={usbState}
-                      peerConnectionState={peerConnection?.connectionState}
+                        peerConnectionState={peerConnectionState}
                     />
                   </div>
                 </div>
diff --git a/ui/src/components/VideoOverlay.tsx b/ui/src/components/VideoOverlay.tsx
index 0620af4..e34cf10 100644
--- a/ui/src/components/VideoOverlay.tsx
+++ b/ui/src/components/VideoOverlay.tsx
@@ -6,7 +6,7 @@ import { LuPlay } from "react-icons/lu";
 
 import { Button, LinkButton } from "@components/Button";
 import LoadingSpinner from "@components/LoadingSpinner";
-import { GridCard } from "@components/Card";
+import Card, { GridCard } from "@components/Card";
 
 interface OverlayContentProps {
   children: React.ReactNode;
@@ -94,7 +94,7 @@ interface ConnectionErrorOverlayProps {
   setupPeerConnection: () => Promise<void>;
 }
 
-export function ConnectionErrorOverlay({
+export function ConnectionFailedOverlay({
   show,
   setupPeerConnection,
 }: ConnectionErrorOverlayProps) {
@@ -151,6 +151,60 @@ export function ConnectionErrorOverlay({
   );
 }
 
+interface PeerConnectionDisconnectedOverlay {
+  show: boolean;
+}
+
+export function PeerConnectionDisconnectedOverlay({
+  show,
+}: PeerConnectionDisconnectedOverlay) {
+  return (
+    <AnimatePresence>
+      {show && (
+        <motion.div
+          className="aspect-video h-full w-full"
+          initial={{ opacity: 0 }}
+          animate={{ opacity: 1 }}
+          exit={{ opacity: 0, transition: { duration: 0 } }}
+          transition={{
+            duration: 0.4,
+            ease: "easeInOut",
+          }}
+        >
+          <OverlayContent>
+            <div className="flex flex-col items-start gap-y-1">
+              <ExclamationTriangleIcon className="h-12 w-12 text-yellow-500" />
+              <div className="text-left text-sm text-slate-700 dark:text-slate-300">
+                <div className="space-y-4">
+                  <div className="space-y-2 text-black dark:text-white">
+                    <h2 className="text-xl font-bold">Connection Issue Detected</h2>
+                    <ul className="list-disc space-y-2 pl-4 text-left">
+                      <li>Verify that the device is powered on and properly connected</li>
+                      <li>Check all cable connections for any loose or damaged wires</li>
+                      <li>Ensure your network connection is stable and active</li>
+                      <li>Try restarting both the device and your computer</li>
+                    </ul>
+                  </div>
+                  <div className="flex items-center gap-x-2">
+                    <Card>
+                      <div className="flex items-center gap-x-2 p-4">
+                        <LoadingSpinner className="h-4 w-4 text-blue-800 dark:text-blue-200" />
+                        <p className="text-sm text-slate-700 dark:text-slate-300">
+                          Retrying connection...
+                        </p>
+                      </div>
+                    </Card>
+                  </div>
+                </div>
+              </div>
+            </div>
+          </OverlayContent>
+        </motion.div>
+      )}
+    </AnimatePresence>
+  );
+}
+
 interface HDMIErrorOverlayProps {
   show: boolean;
   hdmiState: string;
diff --git a/ui/src/components/WebRTCVideo.tsx b/ui/src/components/WebRTCVideo.tsx
index a025b7c..3d72f30 100644
--- a/ui/src/components/WebRTCVideo.tsx
+++ b/ui/src/components/WebRTCVideo.tsx
@@ -381,7 +381,7 @@ export default function WebRTCVideo() {
     (mediaStream: MediaStream) => {
       if (!videoElm.current) return;
       const videoElmRefValue = videoElm.current;
-      console.log("Adding stream to video element", videoElmRefValue);
+      // console.log("Adding stream to video element", videoElmRefValue);
       videoElmRefValue.srcObject = mediaStream;
       updateVideoSizeStore(videoElmRefValue);
     },
@@ -397,7 +397,7 @@ export default function WebRTCVideo() {
       peerConnection.addEventListener(
         "track",
         (e: RTCTrackEvent) => {
-          console.log("Adding stream to video element");
+          // console.log("Adding stream to video element");
           addStreamToVideoElm(e.streams[0]);
         },
         { signal },
diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx
index d2662fc..fef1764 100644
--- a/ui/src/routes/devices.$id.tsx
+++ b/ui/src/routes/devices.$id.tsx
@@ -1,4 +1,4 @@
-import { useCallback, useEffect, useRef, useState } from "react";
+import { useCallback, useEffect, useMemo, useRef, useState } from "react";
 import {
   LoaderFunctionArgs,
   Outlet,
@@ -14,6 +14,7 @@ import {
 import { useInterval } from "usehooks-ts";
 import FocusTrap from "focus-trap-react";
 import { motion, AnimatePresence } from "framer-motion";
+import useWebSocket from "react-use-websocket";
 
 import { cx } from "@/cva.config";
 import {
@@ -43,15 +44,16 @@ import UpdateInProgressStatusCard from "../components/UpdateInProgressStatusCard
 import api from "../api";
 import Modal from "../components/Modal";
 import { useDeviceUiNavigation } from "../hooks/useAppNavigation";
+import {
+  ConnectionFailedOverlay,
+  LoadingConnectionOverlay,
+  PeerConnectionDisconnectedOverlay,
+} from "../components/VideoOverlay";
 import { FeatureFlagProvider } from "../providers/FeatureFlagProvider";
 import notifications from "../notifications";
-import {
-  ConnectionErrorOverlay,
-  LoadingConnectionOverlay,
-} from "../components/VideoOverlay";
 
-import { SystemVersionInfo } from "./devices.$id.settings.general.update";
 import { DeviceStatus } from "./welcome-local";
+import { SystemVersionInfo } from "./devices.$id.settings.general.update";
 
 interface LocalLoaderResp {
   authMode: "password" | "noPassword" | null;
@@ -117,7 +119,6 @@ const loader = async ({ params }: LoaderFunctionArgs) => {
 
 export default function KvmIdRoute() {
   const loaderResp = useLoaderData() as LocalLoaderResp | CloudLoaderResp;
-
   // Depending on the mode, we set the appropriate variables
   const user = "user" in loaderResp ? loaderResp.user : null;
   const deviceName = "deviceName" in loaderResp ? loaderResp.deviceName : null;
@@ -130,6 +131,8 @@ export default function KvmIdRoute() {
 
   const setIsTurnServerInUse = useRTCStore(state => state.setTurnServerInUse);
   const peerConnection = useRTCStore(state => state.peerConnection);
+  const setPeerConnectionState = useRTCStore(state => state.setPeerConnectionState);
+  const peerConnectionState = useRTCStore(state => state.peerConnectionState);
   const setMediaMediaStream = useRTCStore(state => state.setMediaStream);
   const setPeerConnection = useRTCStore(state => state.setPeerConnection);
   const setDiskChannel = useRTCStore(state => state.setDiskChannel);
@@ -137,23 +140,28 @@ export default function KvmIdRoute() {
   const setTransceiver = useRTCStore(state => state.setTransceiver);
   const location = useLocation();
 
+  const isLegacySignalingEnabled = useRef(false);
+
   const [connectionFailed, setConnectionFailed] = useState(false);
 
   const navigate = useNavigate();
   const { otaState, setOtaState, setModalView } = useUpdateStore();
 
   const [loadingMessage, setLoadingMessage] = useState("Connecting to device...");
-  const closePeerConnection = useCallback(
-    function closePeerConnection() {
+  const cleanupAndStopReconnecting = useCallback(
+    function cleanupAndStopReconnecting() {
       console.log("Closing peer connection");
 
       setConnectionFailed(true);
+      if (peerConnection) {
+        setPeerConnectionState(peerConnection.connectionState);
+      }
       connectionFailedRef.current = true;
 
       peerConnection?.close();
       signalingAttempts.current = 0;
     },
-    [peerConnection],
+    [peerConnection, setPeerConnectionState],
   );
 
   // We need to track connectionFailed in a ref to avoid stale closure issues
@@ -171,95 +179,233 @@ export default function KvmIdRoute() {
   }, [connectionFailed]);
 
   const signalingAttempts = useRef(0);
-  const syncRemoteSessionDescription = useCallback(
-    async function syncRemoteSessionDescription(pc: RTCPeerConnection) {
+  const setRemoteSessionDescription = useCallback(
+    async function setRemoteSessionDescription(
+      pc: RTCPeerConnection,
+      remoteDescription: RTCSessionDescriptionInit,
+    ) {
+      setLoadingMessage("Setting remote description");
+
       try {
-        if (!pc) return;
-
-        const sd = btoa(JSON.stringify(pc.localDescription));
-
-        const sessionUrl = isOnDevice
-          ? `${DEVICE_API}/webrtc/session`
-          : `${CLOUD_API}/webrtc/session`;
-
-        console.log("Trying to get remote session description");
-        setLoadingMessage(
-          `Getting remote session description...  ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`,
-        );
-        const res = await api.POST(sessionUrl, {
-          sd,
-          // When on device, we don't need to specify the device id, as it's already known
-          ...(isOnDevice ? {} : { id: params.id }),
-        });
-
-        const json = await res.json();
-        if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login");
-        if (!res.ok) {
-          console.error("Error getting SDP", { status: res.status, json });
-          throw new Error("Error getting SDP");
-        }
-
-        console.log("Successfully got Remote Session Description. Setting.");
-        setLoadingMessage("Setting remote session description...");
-
-        const decodedSd = atob(json.sd);
-        const parsedSd = JSON.parse(decodedSd);
-        pc.setRemoteDescription(new RTCSessionDescription(parsedSd));
-
-        await new Promise((resolve, reject) => {
-          console.log("Waiting for remote description to be set");
-          const maxAttempts = 10;
-          const interval = 1000;
-          let attempts = 0;
-
-          const checkInterval = setInterval(() => {
-            attempts++;
-            // When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects
-            if (pc.sctp?.state === "connected") {
-              console.log("Remote description set");
-              clearInterval(checkInterval);
-              resolve(true);
-            } else if (attempts >= maxAttempts) {
-              console.log(
-                `Failed to get remote description after ${maxAttempts} attempts`,
-              );
-              closePeerConnection();
-              clearInterval(checkInterval);
-              reject(
-                new Error(
-                  `Failed to get remote description after ${maxAttempts} attempts`,
-                ),
-              );
-            } else {
-              console.log("Waiting for remote description to be set");
-            }
-          }, interval);
-        });
+        await pc.setRemoteDescription(new RTCSessionDescription(remoteDescription));
+        console.log("[setRemoteSessionDescription] Remote description set successfully");
+        setLoadingMessage("Establishing secure connection...");
       } catch (error) {
-        console.error("Error getting SDP", { error });
-        console.log("Connection failed", connectionFailedRef.current);
-        if (connectionFailedRef.current) return;
-        if (signalingAttempts.current < 5) {
-          signalingAttempts.current++;
-          await new Promise(resolve => setTimeout(resolve, 500));
-          console.log("Attempting to get SDP again", signalingAttempts.current);
-          syncRemoteSessionDescription(pc);
-        } else {
-          closePeerConnection();
-        }
+        console.error(
+          "[setRemoteSessionDescription] Failed to set remote description:",
+          error,
+        );
+        cleanupAndStopReconnecting();
+        return;
       }
+
+      // Replace the interval-based check with a more reliable approach
+      let attempts = 0;
+      const checkInterval = setInterval(() => {
+        attempts++;
+
+        // When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects
+        if (pc.sctp?.state === "connected") {
+          console.log("[setRemoteSessionDescription] Remote description set");
+          clearInterval(checkInterval);
+          setLoadingMessage("Connection established");
+        } else if (attempts >= 10) {
+          console.log(
+            "[setRemoteSessionDescription] Failed to establish connection after 10 attempts",
+            {
+              connectionState: pc.connectionState,
+              iceConnectionState: pc.iceConnectionState,
+            },
+          );
+          cleanupAndStopReconnecting();
+          clearInterval(checkInterval);
+        } else {
+          console.log("[setRemoteSessionDescription] Waiting for connection, state:", {
+            connectionState: pc.connectionState,
+            iceConnectionState: pc.iceConnectionState,
+          });
+        }
+      }, 1000);
     },
-    [closePeerConnection, navigate, params.id],
+    [cleanupAndStopReconnecting],
+  );
+
+  const ignoreOffer = useRef(false);
+  const isSettingRemoteAnswerPending = useRef(false);
+  const makingOffer = useRef(false);
+
+  const wsProtocol = window.location.protocol === "https:" ? "wss:" : "ws:";
+
+  const { sendMessage, getWebSocket } = useWebSocket(
+    isOnDevice
+      ? `${wsProtocol}//${window.location.host}/webrtc/signaling/client`
+      : `${CLOUD_API.replace("http", "ws")}/webrtc/signaling/client?id=${params.id}`,
+    {
+      heartbeat: true,
+      retryOnError: true,
+      reconnectAttempts: 5,
+      reconnectInterval: 1000,
+      onReconnectStop: () => {
+        console.log("Reconnect stopped");
+        cleanupAndStopReconnecting();
+      },
+
+      shouldReconnect(event) {
+        console.log("[Websocket] shouldReconnect", event);
+        // TODO: Why true?
+        return true;
+      },
+
+      onClose(event) {
+        console.log("[Websocket] onClose", event);
+        // We don't want to close everything down, we wait for the reconnect to stop instead
+      },
+
+      onError(event) {
+        console.log("[Websocket] onError", event);
+        // We don't want to close everything down, we wait for the reconnect to stop instead
+      },
+      onOpen() {
+        console.log("[Websocket] onOpen");
+      },
+
+      onMessage: message => {
+        if (message.data === "pong") return;
+
+        /*
+          Currently the signaling process is as follows:
+            After open, the other side will send a `device-metadata` message with the device version
+            If the device version is not set, we can assume the device is using the legacy signaling
+            Otherwise, we can assume the device is using the new signaling
+
+            If the device is using the legacy signaling, we close the websocket connection
+            and use the legacy HTTPSignaling function to get the remote session description
+
+            If the device is using the new signaling, we don't need to do anything special, but continue to use the websocket connection
+            to chat with the other peer about the connection
+        */
+
+        const parsedMessage = JSON.parse(message.data);
+        if (parsedMessage.type === "device-metadata") {
+          const { deviceVersion } = parsedMessage.data;
+          console.log("[Websocket] Received device-metadata message");
+          console.log("[Websocket] Device version", deviceVersion);
+          // If the device version is not set, we can assume the device is using the legacy signaling
+          if (!deviceVersion) {
+            console.log("[Websocket] Device is using legacy signaling");
+
+            // Now we don't need the websocket connection anymore, as we've established that we need to use the legacy signaling
+            // which does everything over HTTP(at least from the perspective of the client)
+            isLegacySignalingEnabled.current = true;
+            getWebSocket()?.close();
+          } else {
+            console.log("[Websocket] Device is using new signaling");
+            isLegacySignalingEnabled.current = false;
+          }
+          setupPeerConnection();
+        }
+
+        if (!peerConnection) return;
+        if (parsedMessage.type === "answer") {
+          console.log("[Websocket] Received answer");
+          const readyForOffer =
+            // If we're making an offer, we don't want to accept an answer
+            !makingOffer &&
+            // If the peer connection is stable or we're setting the remote answer pending, we're ready for an offer
+            (peerConnection?.signalingState === "stable" ||
+              isSettingRemoteAnswerPending.current);
+
+          // If we're not ready for an offer, we don't want to accept an offer
+          ignoreOffer.current = parsedMessage.type === "offer" && !readyForOffer;
+          if (ignoreOffer.current) return;
+
+          // Set so we don't accept an answer while we're setting the remote description
+          isSettingRemoteAnswerPending.current = parsedMessage.type === "answer";
+          console.log(
+            "[Websocket] Setting remote answer pending",
+            isSettingRemoteAnswerPending.current,
+          );
+
+          const sd = atob(parsedMessage.data);
+          const remoteSessionDescription = JSON.parse(sd);
+
+          setRemoteSessionDescription(
+            peerConnection,
+            new RTCSessionDescription(remoteSessionDescription),
+          );
+
+          // Reset the remote answer pending flag
+          isSettingRemoteAnswerPending.current = false;
+        } else if (parsedMessage.type === "new-ice-candidate") {
+          console.log("[Websocket] Received new-ice-candidate");
+          const candidate = parsedMessage.data;
+          peerConnection.addIceCandidate(candidate);
+        }
+      },
+    },
+
+    // Don't even retry once we declare failure
+    !connectionFailed && isLegacySignalingEnabled.current === false,
+  );
+
+  const sendWebRTCSignal = useCallback(
+    (type: string, data: unknown) => {
+      // Second argument tells the library not to queue the message, and send it once the connection is established again.
+      // We have event handlers that handle the connection set up, so we don't need to queue the message.
+      sendMessage(JSON.stringify({ type, data }), false);
+    },
+    [sendMessage],
+  );
+
+  const legacyHTTPSignaling = useCallback(
+    async (pc: RTCPeerConnection) => {
+      const sd = btoa(JSON.stringify(pc.localDescription));
+
+      // Legacy mode == UI in cloud with updated code connecting to older device version.
+      // In device mode, old devices wont server this JS, and on newer devices legacy mode wont be enabled
+      const sessionUrl = `${CLOUD_API}/webrtc/session`;
+
+      console.log("Trying to get remote session description");
+      setLoadingMessage(
+        `Getting remote session description...  ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`,
+      );
+      const res = await api.POST(sessionUrl, {
+        sd,
+        // When on device, we don't need to specify the device id, as it's already known
+        ...(isOnDevice ? {} : { id: params.id }),
+      });
+
+      const json = await res.json();
+      if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login");
+      if (!res.ok) {
+        console.error("Error getting SDP", { status: res.status, json });
+        cleanupAndStopReconnecting();
+        return;
+      }
+
+      console.log("Successfully got Remote Session Description. Setting.");
+      setLoadingMessage("Setting remote session description...");
+
+      const decodedSd = atob(json.sd);
+      const parsedSd = JSON.parse(decodedSd);
+      setRemoteSessionDescription(pc, new RTCSessionDescription(parsedSd));
+    },
+    [cleanupAndStopReconnecting, navigate, params.id, setRemoteSessionDescription],
   );
 
   const setupPeerConnection = useCallback(async () => {
-    console.log("Setting up peer connection");
+    console.log("[setupPeerConnection] Setting up peer connection");
     setConnectionFailed(false);
     setLoadingMessage("Connecting to device...");
 
+    if (peerConnection?.signalingState === "stable") {
+      console.log("[setupPeerConnection] Peer connection already established");
+      return;
+    }
+
     let pc: RTCPeerConnection;
     try {
-      console.log("Creating peer connection");
+      console.log("[setupPeerConnection] Creating peer connection");
       setLoadingMessage("Creating peer connection...");
       pc = new RTCPeerConnection({
         // We only use STUN or TURN servers if we're in the cloud
@@ -267,30 +413,65 @@ export default function KvmIdRoute() {
           ? { iceServers: [iceConfig?.iceServers] }
           : {}),
       });
-      console.log("Peer connection created", pc);
-      setLoadingMessage("Peer connection created");
+
+      setPeerConnectionState(pc.connectionState);
+      console.log("[setupPeerConnection] Peer connection created", pc);
+      setLoadingMessage("Setting up connection to device...");
     } catch (e) {
-      console.error(`Error creating peer connection: ${e}`);
+      console.error(`[setupPeerConnection] Error creating peer connection: ${e}`);
       setTimeout(() => {
-        closePeerConnection();
+        cleanupAndStopReconnecting();
       }, 1000);
       return;
     }
 
     // Set up event listeners and data channels
     pc.onconnectionstatechange = () => {
-      console.log("Connection state changed", pc.connectionState);
+      console.log("[setupPeerConnection] Connection state changed", pc.connectionState);
+      setPeerConnectionState(pc.connectionState);
+    };
+
+    pc.onnegotiationneeded = async () => {
+      try {
+        console.log("[setupPeerConnection] Creating offer");
+        makingOffer.current = true;
+
+        const offer = await pc.createOffer();
+        await pc.setLocalDescription(offer);
+        const sd = btoa(JSON.stringify(pc.localDescription));
+        const isNewSignalingEnabled = isLegacySignalingEnabled.current === false;
+        if (isNewSignalingEnabled) {
+          sendWebRTCSignal("offer", { sd: sd });
+        } else {
+          console.log("Legacy signanling. Waiting for ICE Gathering to complete...");
+        }
+      } catch (e) {
+        console.error(
+          `[setupPeerConnection] Error creating offer: ${e}`,
+          new Date().toISOString(),
+        );
+        cleanupAndStopReconnecting();
+      } finally {
+        makingOffer.current = false;
+      }
+    };
+
+    pc.onicecandidate = async ({ candidate }) => {
+      if (!candidate) return;
+      if (candidate.candidate === "") return;
+      sendWebRTCSignal("new-ice-candidate", candidate);
     };
 
     pc.onicegatheringstatechange = event => {
       const pc = event.currentTarget as RTCPeerConnection;
-      console.log("ICE Gathering State Changed", pc.iceGatheringState);
       if (pc.iceGatheringState === "complete") {
         console.log("ICE Gathering completed");
         setLoadingMessage("ICE Gathering completed");
 
-        // We can now start the https/ws connection to get the remote session description from the KVM device
-        syncRemoteSessionDescription(pc);
+        if (isLegacySignalingEnabled.current) {
+          // We can now start the https/ws connection to get the remote session description from the KVM device
+          legacyHTTPSignaling(pc);
+        }
       } else if (pc.iceGatheringState === "gathering") {
         console.log("ICE Gathering Started");
         setLoadingMessage("Gathering ICE candidates...");
@@ -314,31 +495,26 @@ export default function KvmIdRoute() {
     };
 
     setPeerConnection(pc);
-
-    try {
-      const offer = await pc.createOffer();
-      await pc.setLocalDescription(offer);
-    } catch (e) {
-      console.error(`Error creating offer: ${e}`, new Date().toISOString());
-      closePeerConnection();
-    }
   }, [
-    closePeerConnection,
+    cleanupAndStopReconnecting,
     iceConfig?.iceServers,
+    legacyHTTPSignaling,
+    peerConnection?.signalingState,
+    sendWebRTCSignal,
     setDiskChannel,
     setMediaMediaStream,
     setPeerConnection,
+    setPeerConnectionState,
     setRpcDataChannel,
     setTransceiver,
-    syncRemoteSessionDescription,
   ]);
 
-  // On boot, if the connection state is undefined, we connect to the WebRTC
   useEffect(() => {
-    if (peerConnection?.connectionState === undefined) {
-      setupPeerConnection();
+    if (peerConnectionState === "failed") {
+      console.log("Connection failed, closing peer connection");
+      cleanupAndStopReconnecting();
     }
-  }, [setupPeerConnection, peerConnection?.connectionState]);
+  }, [peerConnectionState, cleanupAndStopReconnecting]);
 
   // Cleanup effect
   const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats);
@@ -363,7 +539,7 @@ export default function KvmIdRoute() {
 
   // TURN server usage detection
   useEffect(() => {
-    if (peerConnection?.connectionState !== "connected") return;
+    if (peerConnectionState !== "connected") return;
     const { localCandidateStats, remoteCandidateStats } = useRTCStore.getState();
 
     const lastLocalStat = Array.from(localCandidateStats).pop();
@@ -375,7 +551,7 @@ export default function KvmIdRoute() {
     const remoteCandidateIsUsingTurn = lastRemoteStat[1].candidateType === "relay"; // [0] is the timestamp, which we don't care about here
 
     setIsTurnServerInUse(localCandidateIsUsingTurn || remoteCandidateIsUsingTurn);
-  }, [peerConnection?.connectionState, setIsTurnServerInUse]);
+  }, [peerConnectionState, setIsTurnServerInUse]);
 
   // TURN server usage reporting
   const isTurnServerInUse = useRTCStore(state => state.isTurnServerInUse);
@@ -466,10 +642,6 @@ export default function KvmIdRoute() {
     });
   }, [rpcDataChannel?.readyState, send, setHdmiState]);
 
-  // eslint-disable-next-line @typescript-eslint/ban-ts-comment
-  // @ts-expect-error
-  window.send = send;
-
   // When the update is successful, we need to refresh the client javascript and show a success modal
   useEffect(() => {
     if (queryParams.get("updateSuccess")) {
@@ -506,12 +678,12 @@ export default function KvmIdRoute() {
   useEffect(() => {
     if (!peerConnection) return;
     if (!kvmTerminal) {
-      console.log('Creating data channel "terminal"');
+      // console.log('Creating data channel "terminal"');
       setKvmTerminal(peerConnection.createDataChannel("terminal"));
     }
 
     if (!serialConsole) {
-      console.log('Creating data channel "serial"');
+      // console.log('Creating data channel "serial"');
       setSerialConsole(peerConnection.createDataChannel("serial"));
     }
   }, [kvmTerminal, peerConnection, serialConsole]);
@@ -554,6 +726,43 @@ export default function KvmIdRoute() {
     [send, setScrollSensitivity],
   );
 
+  const ConnectionStatusElement = useMemo(() => {
+    const hasConnectionFailed =
+      connectionFailed || ["failed", "closed"].includes(peerConnectionState || "");
+
+    const isPeerConnectionLoading =
+      ["connecting", "new"].includes(peerConnectionState || "") ||
+      peerConnection === null;
+
+    const isDisconnected = peerConnectionState === "disconnected";
+
+    const isOtherSession = location.pathname.includes("other-session");
+
+    if (isOtherSession) return null;
+    if (peerConnectionState === "connected") return null;
+    if (isDisconnected) {
+      return <PeerConnectionDisconnectedOverlay show={true} />;
+    }
+
+    if (hasConnectionFailed)
+      return (
+        <ConnectionFailedOverlay show={true} setupPeerConnection={setupPeerConnection} />
+      );
+
+    if (isPeerConnectionLoading) {
+      return <LoadingConnectionOverlay show={true} text={loadingMessage} />;
+    }
+
+    return null;
+  }, [
+    connectionFailed,
+    loadingMessage,
+    location.pathname,
+    peerConnection,
+    peerConnectionState,
+    setupPeerConnection,
+  ]);
+
   return (
     <FeatureFlagProvider appVersion={appVersion}>
       {!outlet && otaState.updating && (
@@ -593,27 +802,13 @@ export default function KvmIdRoute() {
           />
 
           <div className="flex h-full w-full overflow-hidden">
-            <div className="pointer-events-none fixed inset-0 isolate z-50 flex h-full w-full items-center justify-center">
+            <div className="pointer-events-none fixed inset-0 isolate z-20 flex h-full w-full items-center justify-center">
               <div className="my-2 h-full max-h-[720px] w-full max-w-[1280px] rounded-md">
-                <LoadingConnectionOverlay
-                  show={
-                    !connectionFailed &&
-                    (["connecting", "new"].includes(
-                      peerConnection?.connectionState || "",
-                    ) ||
-                      peerConnection === null) &&
-                    !location.pathname.includes("other-session")
-                  }
-                  text={loadingMessage}
-                />
-                <ConnectionErrorOverlay
-                  show={connectionFailed && !location.pathname.includes("other-session")}
-                  setupPeerConnection={setupPeerConnection}
-                />
+                {!!ConnectionStatusElement && ConnectionStatusElement}
               </div>
             </div>
 
-            <WebRTCVideo />
+            {peerConnectionState === "connected" && <WebRTCVideo />}
             <SidebarContainer sidebarView={sidebarView} />
           </div>
         </div>
diff --git a/web.go b/web.go
index 9201e7b..c3f6d8d 100644
--- a/web.go
+++ b/web.go
@@ -1,6 +1,7 @@
 package kvm
 
 import (
+	"context"
 	"embed"
 	"encoding/json"
 	"fmt"
@@ -10,8 +11,12 @@ import (
 	"strings"
 	"time"
 
+	"github.com/coder/websocket"
+	"github.com/coder/websocket/wsjson"
 	"github.com/gin-gonic/gin"
 	"github.com/google/uuid"
+	"github.com/pion/webrtc/v4"
+	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"golang.org/x/crypto/bcrypt"
 )
@@ -94,7 +99,7 @@ func setupRouter() *gin.Engine {
 	protected := r.Group("/")
 	protected.Use(protectedMiddleware())
 	{
-		protected.POST("/webrtc/session", handleWebRTCSession)
+		protected.GET("/webrtc/signaling/client", handleLocalWebRTCSignal)
 		protected.POST("/cloud/register", handleCloudRegister)
 		protected.GET("/cloud/state", handleCloudState)
 		protected.GET("/device", handleDevice)
@@ -121,35 +126,182 @@ func setupRouter() *gin.Engine {
 // TODO: support multiple sessions?
 var currentSession *Session
 
-func handleWebRTCSession(c *gin.Context) {
-	var req WebRTCSessionRequest
-
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
+func handleLocalWebRTCSignal(c *gin.Context) {
+	cloudLogger.Infof("new websocket connection established")
+	// Create WebSocket options with InsecureSkipVerify to bypass origin check
+	wsOptions := &websocket.AcceptOptions{
+		InsecureSkipVerify: true, // Allow connections from any origin
 	}
 
-	session, err := newSession(SessionConfig{})
+	wsCon, err := websocket.Accept(c.Writer, c.Request, wsOptions)
 	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err})
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
-	sd, err := session.ExchangeOffer(req.Sd)
+	// get the source from the request
+	source := c.ClientIP()
+
+	// Now use conn for websocket operations
+	defer wsCon.Close(websocket.StatusNormalClosure, "")
+
+	err = wsjson.Write(context.Background(), wsCon, gin.H{"type": "device-metadata", "data": gin.H{"deviceVersion": builtAppVersion}})
 	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err})
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	if currentSession != nil {
-		writeJSONRPCEvent("otherSessionConnected", nil, currentSession)
-		peerConn := currentSession.peerConnection
+
+	err = handleWebRTCSignalWsMessages(wsCon, false, source)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+}
+
+func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, source string) error {
+	runCtx, cancelRun := context.WithCancel(context.Background())
+	defer cancelRun()
+
+	// Add connection tracking to detect reconnections
+	connectionID := uuid.New().String()
+	cloudLogger.Infof("new websocket connection established with ID: %s", connectionID)
+
+	// connection type
+	var sourceType string
+	if isCloudConnection {
+		sourceType = "cloud"
+	} else {
+		sourceType = "local"
+	}
+
+	// probably we can use a better logging framework here
+	logInfof := func(format string, args ...interface{}) {
+		args = append(args, source, sourceType)
+		websocketLogger.Infof(format+", source: %s, sourceType: %s", args...)
+	}
+	logWarnf := func(format string, args ...interface{}) {
+		args = append(args, source, sourceType)
+		websocketLogger.Warnf(format+", source: %s, sourceType: %s", args...)
+	}
+	logTracef := func(format string, args ...interface{}) {
+		args = append(args, source, sourceType)
+		websocketLogger.Tracef(format+", source: %s, sourceType: %s", args...)
+	}
+
+	go func() {
+		for {
+			time.Sleep(WebsocketPingInterval)
+
+			// set the timer for the ping duration
+			timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
+				metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(v)
+				metricConnectionPingDuration.WithLabelValues(sourceType, source).Observe(v)
+			}))
+
+			logInfof("pinging websocket")
+			err := wsCon.Ping(runCtx)
+
+			if err != nil {
+				logWarnf("websocket ping error: %v", err)
+				cancelRun()
+				return
+			}
+
+			// dont use `defer` here because we want to observe the duration of the ping
+			timer.ObserveDuration()
+
+			metricConnectionTotalPingCount.WithLabelValues(sourceType, source).Inc()
+			metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
+		}
+	}()
+
+	if isCloudConnection {
+		// create a channel to receive the disconnect event, once received, we cancelRun
+		cloudDisconnectChan = make(chan error)
+		defer func() {
+			close(cloudDisconnectChan)
+			cloudDisconnectChan = nil
+		}()
 		go func() {
-			time.Sleep(1 * time.Second)
-			_ = peerConn.Close()
+			for err := range cloudDisconnectChan {
+				if err == nil {
+					continue
+				}
+				cloudLogger.Infof("disconnecting from cloud due to: %v", err)
+				cancelRun()
+			}
 		}()
 	}
-	currentSession = session
-	c.JSON(http.StatusOK, gin.H{"sd": sd})
+
+	for {
+		typ, msg, err := wsCon.Read(runCtx)
+		if err != nil {
+			logWarnf("websocket read error: %v", err)
+			return err
+		}
+		if typ != websocket.MessageText {
+			// ignore non-text messages
+			continue
+		}
+
+		var message struct {
+			Type string          `json:"type"`
+			Data json.RawMessage `json:"data"`
+		}
+
+		err = json.Unmarshal(msg, &message)
+		if err != nil {
+			logWarnf("unable to parse ws message: %v", err)
+			continue
+		}
+
+		if message.Type == "offer" {
+			logInfof("new session request received")
+			var req WebRTCSessionRequest
+			err = json.Unmarshal(message.Data, &req)
+			if err != nil {
+				logWarnf("unable to parse session request data: %v", err)
+				continue
+			}
+
+			logInfof("new session request: %v", req.OidcGoogle)
+			logTracef("session request info: %v", req)
+
+			metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc()
+			metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
+			err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source)
+			if err != nil {
+				logWarnf("error starting new session: %v", err)
+				continue
+			}
+		} else if message.Type == "new-ice-candidate" {
+			logInfof("The client sent us a new ICE candidate: %v", string(message.Data))
+			var candidate webrtc.ICECandidateInit
+
+			// Attempt to unmarshal as a ICECandidateInit
+			if err := json.Unmarshal(message.Data, &candidate); err != nil {
+				logWarnf("unable to parse incoming ICE candidate data: %v", string(message.Data))
+				continue
+			}
+
+			if candidate.Candidate == "" {
+				logWarnf("empty incoming ICE candidate, skipping")
+				continue
+			}
+
+			logInfof("unmarshalled incoming ICE candidate: %v", candidate)
+
+			if currentSession == nil {
+				logInfof("no current session, skipping incoming ICE candidate")
+				continue
+			}
+
+			logInfof("adding incoming ICE candidate to current session: %v", candidate)
+			if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil {
+				logWarnf("failed to add incoming ICE candidate to our peer connection: %v", err)
+			}
+		}
+	}
 }
 
 func handleLogin(c *gin.Context) {
diff --git a/webrtc.go b/webrtc.go
index 12d4f95..a047ecc 100644
--- a/webrtc.go
+++ b/webrtc.go
@@ -1,11 +1,15 @@
 package kvm
 
 import (
+	"context"
 	"encoding/base64"
 	"encoding/json"
 	"net"
 	"strings"
 
+	"github.com/coder/websocket"
+	"github.com/coder/websocket/wsjson"
+	"github.com/gin-gonic/gin"
 	"github.com/pion/webrtc/v4"
 )
 
@@ -23,6 +27,7 @@ type SessionConfig struct {
 	ICEServers []string
 	LocalIP    string
 	IsCloud    bool
+	ws         *websocket.Conn
 }
 
 func (s *Session) ExchangeOffer(offerStr string) (string, error) {
@@ -46,19 +51,11 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) {
 		return "", err
 	}
 
-	// Create channel that is blocked until ICE Gathering is complete
-	gatherComplete := webrtc.GatheringCompletePromise(s.peerConnection)
-
 	// Sets the LocalDescription, and starts our UDP listeners
 	if err = s.peerConnection.SetLocalDescription(answer); err != nil {
 		return "", err
 	}
 
-	// Block until ICE Gathering is complete, disabling trickle ICE
-	// we do this because we only can exchange one signaling message
-	// in a production application you should exchange ICE Candidates via OnICECandidate
-	<-gatherComplete
-
 	localDescription, err := json.Marshal(s.peerConnection.LocalDescription())
 	if err != nil {
 		return "", err
@@ -144,6 +141,16 @@ func newSession(config SessionConfig) (*Session, error) {
 	}()
 	var isConnected bool
 
+	peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
+		logger.Infof("Our WebRTC peerConnection has a new ICE candidate: %v", candidate)
+		if candidate != nil {
+			err := wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
+			if err != nil {
+				logger.Errorf("failed to write new-ice-candidate to WebRTC signaling channel: %v", err)
+			}
+		}
+	})
+
 	peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
 		logger.Infof("Connection State has changed %s", connectionState)
 		if connectionState == webrtc.ICEConnectionStateConnected {