diff --git a/cloud.go b/cloud.go index 7ad8b75..070db8d 100644 --- a/cloud.go +++ b/cloud.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -59,6 +60,13 @@ var ( }, []string{"type", "source"}, ) + metricConnectionLastPingReceivedTimestamp = promauto.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "jetkvm_connection_last_ping_received_timestamp", + Help: "The timestamp when the last ping request was received", + }, + []string{"type", "source"}, + ) metricConnectionLastPingDuration = promauto.NewGaugeVec( prometheus.GaugeOpts{ Name: "jetkvm_connection_last_ping_duration", @@ -76,16 +84,23 @@ var ( }, []string{"type", "source"}, ) - metricConnectionTotalPingCount = promauto.NewCounterVec( + metricConnectionTotalPingSentCount = promauto.NewCounterVec( prometheus.CounterOpts{ - Name: "jetkvm_connection_total_ping_count", + Name: "jetkvm_connection_total_ping_sent", Help: "The total number of pings sent to the connection", }, []string{"type", "source"}, ) + metricConnectionTotalPingReceivedCount = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "jetkvm_connection_total_ping_received", + Help: "The total number of pings received from the connection", + }, + []string{"type", "source"}, + ) metricConnectionSessionRequestCount = promauto.NewCounterVec( prometheus.CounterOpts{ - Name: "jetkvm_connection_session_total_request_count", + Name: "jetkvm_connection_session_total_requests", Help: "The total number of session requests received", }, []string{"type", "source"}, @@ -131,6 +146,8 @@ func wsResetMetrics(established bool, sourceType string, source string) { metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).Set(-1) metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(-1) + metricConnectionLastPingReceivedTimestamp.WithLabelValues(sourceType, source).Set(-1) + metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).Set(-1) metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(-1) @@ -275,18 +292,31 @@ func runWebsocketClient() error { defer cancelDial() c, _, err := websocket.Dial(dialCtx, wsURL.String(), &websocket.DialOptions{ HTTPHeader: header, + OnPingReceived: func(ctx context.Context, payload []byte) bool { + websocketLogger.Infof("ping frame received: %v, source: %s, sourceType: cloud", payload, wsURL.Host) + + metricConnectionTotalPingReceivedCount.WithLabelValues("cloud", wsURL.Host).Inc() + metricConnectionLastPingReceivedTimestamp.WithLabelValues("cloud", wsURL.Host).SetToCurrentTime() + + return true + }, }) + // if the context is canceled, we don't want to return an error if err != nil { + if errors.Is(err, context.Canceled) { + cloudLogger.Infof("websocket connection canceled") + return nil + } return err } defer c.CloseNow() //nolint:errcheck cloudLogger.Infof("websocket connected to %s", wsURL) // set the metrics when we successfully connect to the cloud. - wsResetMetrics(true, "cloud", "") + wsResetMetrics(true, "cloud", wsURL.Host) // we don't have a source for the cloud connection - return handleWebRTCSignalWsMessages(c, true, "") + return handleWebRTCSignalWsMessages(c, true, wsURL.Host) } func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error { @@ -375,9 +405,6 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess func RunWebsocketClient() { for { - // reset the metrics when we start the websocket client. - wsResetMetrics(false, "cloud", "") - // If the cloud token is not set, we don't need to run the websocket client. if config.CloudToken == "" { time.Sleep(5 * time.Second) diff --git a/dev_deploy.sh b/dev_deploy.sh index 7fbf29e..02bbb24 100755 --- a/dev_deploy.sh +++ b/dev_deploy.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash +# # Exit immediately if a command exits with a non-zero status set -e @@ -16,7 +18,6 @@ show_help() { echo "Example:" echo " $0 -r 192.168.0.17" echo " $0 -r 192.168.0.17 -u admin" - exit 0 } # Default values @@ -70,10 +71,10 @@ cd bin ssh "${REMOTE_USER}@${REMOTE_HOST}" "killall jetkvm_app_debug || true" # Copy the binary to the remote host -cat jetkvm_app | ssh "${REMOTE_USER}@${REMOTE_HOST}" "cat > $REMOTE_PATH/jetkvm_app_debug" +ssh "${REMOTE_USER}@${REMOTE_HOST}" "cat > ${REMOTE_PATH}/jetkvm_app_debug" < jetkvm_app # Deploy and run the application on the remote host -ssh "${REMOTE_USER}@${REMOTE_HOST}" ash < void; }) { return ( - + -
+
{/* TODO: This doesn't work well with other-sessions */}
(null); const mediaStream = useRTCStore(state => state.mediaStream); const [isPlaying, setIsPlaying] = useState(false); + const peerConnectionState = useRTCStore(state => state.peerConnectionState); // Store hooks const settings = useSettingsStore(); @@ -601,7 +602,10 @@ export default function WebRTCVideo() { "cursor-none": settings.mouseMode === "absolute" && settings.isCursorHidden, - "opacity-0": isVideoLoading || hdmiError, + "opacity-0": + isVideoLoading || + hdmiError || + peerConnectionState !== "connected", "animate-slideUpFade border border-slate-800/30 opacity-0 shadow dark:border-slate-300/20": isPlaying, }, diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx index fef1764..82bb542 100644 --- a/ui/src/routes/devices.$id.tsx +++ b/ui/src/routes/devices.$id.tsx @@ -243,7 +243,7 @@ export default function KvmIdRoute() { { heartbeat: true, retryOnError: true, - reconnectAttempts: 5, + reconnectAttempts: 15, reconnectInterval: 1000, onReconnectStop: () => { console.log("Reconnect stopped"); @@ -398,11 +398,6 @@ export default function KvmIdRoute() { setConnectionFailed(false); setLoadingMessage("Connecting to device..."); - if (peerConnection?.signalingState === "stable") { - console.log("[setupPeerConnection] Peer connection already established"); - return; - } - let pc: RTCPeerConnection; try { console.log("[setupPeerConnection] Creating peer connection"); @@ -499,7 +494,6 @@ export default function KvmIdRoute() { cleanupAndStopReconnecting, iceConfig?.iceServers, legacyHTTPSignaling, - peerConnection?.signalingState, sendWebRTCSignal, setDiskChannel, setMediaMediaStream, @@ -791,6 +785,7 @@ export default function KvmIdRoute() {
+
-
-
-
+
+ +
+
{!!ConnectionStatusElement && ConnectionStatusElement}
- - {peerConnectionState === "connected" && }
e.stopPropagation()} onKeyDown={e => { e.stopPropagation(); diff --git a/web.go b/web.go index c3f6d8d..6c35073 100644 --- a/web.go +++ b/web.go @@ -1,9 +1,11 @@ package kvm import ( + "bytes" "context" "embed" "encoding/json" + "errors" "fmt" "io/fs" "net/http" @@ -99,6 +101,22 @@ func setupRouter() *gin.Engine { protected := r.Group("/") protected.Use(protectedMiddleware()) { + /* + * Legacy WebRTC session endpoint + * + * This endpoint is maintained for backward compatibility when users upgrade from a version + * using the legacy HTTP-based signaling method to the new WebSocket-based signaling method. + * + * During the upgrade process, when the "Rebooting device after update..." message appears, + * the browser still runs the previous JavaScript code which polls this endpoint to establish + * a new WebRTC session. Once the session is established, the page will automatically reload + * with the updated code. + * + * Without this endpoint, the stale JavaScript would fail to establish a connection, + * causing users to see the "Rebooting device after update..." message indefinitely + * until they manually refresh the page, leading to a confusing user experience. + */ + protected.POST("/webrtc/session", handleWebRTCSession) protected.GET("/webrtc/signaling/client", handleLocalWebRTCSignal) protected.POST("/cloud/register", handleCloudRegister) protected.GET("/cloud/state", handleCloudState) @@ -126,11 +144,59 @@ func setupRouter() *gin.Engine { // TODO: support multiple sessions? var currentSession *Session +func handleWebRTCSession(c *gin.Context) { + var req WebRTCSessionRequest + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + session, err := newSession(SessionConfig{}) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err}) + return + } + + sd, err := session.ExchangeOffer(req.Sd) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err}) + return + } + if currentSession != nil { + writeJSONRPCEvent("otherSessionConnected", nil, currentSession) + peerConn := currentSession.peerConnection + go func() { + time.Sleep(1 * time.Second) + _ = peerConn.Close() + }() + } + currentSession = session + c.JSON(http.StatusOK, gin.H{"sd": sd}) +} + +var ( + pingMessage = []byte("ping") + pongMessage = []byte("pong") +) + func handleLocalWebRTCSignal(c *gin.Context) { cloudLogger.Infof("new websocket connection established") + + // get the source from the request + source := c.ClientIP() + // Create WebSocket options with InsecureSkipVerify to bypass origin check wsOptions := &websocket.AcceptOptions{ InsecureSkipVerify: true, // Allow connections from any origin + OnPingReceived: func(ctx context.Context, payload []byte) bool { + websocketLogger.Infof("ping frame received: %v, source: %s, sourceType: local", payload, source) + + metricConnectionTotalPingReceivedCount.WithLabelValues("local", source).Inc() + metricConnectionLastPingReceivedTimestamp.WithLabelValues("local", source).SetToCurrentTime() + + return true + }, } wsCon, err := websocket.Accept(c.Writer, c.Request, wsOptions) @@ -139,9 +205,6 @@ func handleLocalWebRTCSignal(c *gin.Context) { return } - // get the source from the request - source := c.ClientIP() - // Now use conn for websocket operations defer wsCon.Close(websocket.StatusNormalClosure, "") @@ -164,7 +227,6 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, // Add connection tracking to detect reconnections connectionID := uuid.New().String() - cloudLogger.Infof("new websocket connection established with ID: %s", connectionID) // connection type var sourceType string @@ -176,29 +238,40 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, // probably we can use a better logging framework here logInfof := func(format string, args ...interface{}) { - args = append(args, source, sourceType) - websocketLogger.Infof(format+", source: %s, sourceType: %s", args...) + args = append(args, source, sourceType, connectionID) + websocketLogger.Infof(format+", source: %s, sourceType: %s, id: %s", args...) } logWarnf := func(format string, args ...interface{}) { - args = append(args, source, sourceType) - websocketLogger.Warnf(format+", source: %s, sourceType: %s", args...) + args = append(args, source, sourceType, connectionID) + websocketLogger.Warnf(format+", source: %s, sourceType: %s, id: %s", args...) } logTracef := func(format string, args ...interface{}) { - args = append(args, source, sourceType) - websocketLogger.Tracef(format+", source: %s, sourceType: %s", args...) + args = append(args, source, sourceType, connectionID) + websocketLogger.Tracef(format+", source: %s, sourceType: %s, id: %s", args...) } + logInfof("new websocket connection established") + go func() { for { time.Sleep(WebsocketPingInterval) + if ctxErr := runCtx.Err(); ctxErr != nil { + if !errors.Is(ctxErr, context.Canceled) { + logWarnf("websocket connection closed: %v", ctxErr) + } else { + logTracef("websocket connection closed as the context was canceled: %v") + } + return + } + // set the timer for the ping duration timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(v) metricConnectionPingDuration.WithLabelValues(sourceType, source).Observe(v) })) - logInfof("pinging websocket") + logTracef("sending ping frame") err := wsCon.Ping(runCtx) if err != nil { @@ -208,10 +281,12 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, } // dont use `defer` here because we want to observe the duration of the ping - timer.ObserveDuration() + duration := timer.ObserveDuration() - metricConnectionTotalPingCount.WithLabelValues(sourceType, source).Inc() + metricConnectionTotalPingSentCount.WithLabelValues(sourceType, source).Inc() metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime() + + logTracef("received pong frame, duration: %v", duration) } }() @@ -249,6 +324,20 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, Data json.RawMessage `json:"data"` } + if bytes.Equal(msg, pingMessage) { + logInfof("ping message received: %s", string(msg)) + err = wsCon.Write(context.Background(), websocket.MessageText, pongMessage) + if err != nil { + logWarnf("unable to write pong message: %v", err) + return err + } + + metricConnectionTotalPingReceivedCount.WithLabelValues(sourceType, source).Inc() + metricConnectionLastPingReceivedTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime() + + continue + } + err = json.Unmarshal(msg, &message) if err != nil { logWarnf("unable to parse ws message: %v", err) @@ -264,8 +353,9 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, continue } - logInfof("new session request: %v", req.OidcGoogle) - logTracef("session request info: %v", req) + if req.OidcGoogle != "" { + logInfof("new session request with OIDC Google: %v", req.OidcGoogle) + } metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc() metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()