package kvm

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/coder/websocket"
	"github.com/coder/websocket/wsjson"
	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/rs/zerolog"
)

// CloudRegisterRequest is the JSON body accepted by handleCloudRegister.
type CloudRegisterRequest struct {
	// Token is the temporary token that is exchanged for a permanent secret token.
	Token string `json:"token"`
	// CloudAPI is decoded from the request but not read by handleCloudRegister.
	CloudAPI string `json:"cloudApi"`
	// OidcGoogle is the raw Google OIDC ID token to verify.
	OidcGoogle string `json:"oidcGoogle"`
	// ClientId is the OIDC client ID the ID token must have been issued for.
	ClientId string `json:"clientId"`
}

const (
	// CloudWebSocketConnectTimeout is the timeout for the websocket connection to the cloud
	CloudWebSocketConnectTimeout = 1 * time.Minute

	// CloudAPIRequestTimeout is the timeout for cloud API requests
	CloudAPIRequestTimeout = 10 * time.Second

	// CloudOidcRequestTimeout is the timeout for OIDC token verification requests
	// should be lower than the websocket response timeout set in cloud-api
	CloudOidcRequestTimeout = 10 * time.Second

	// WebsocketPingInterval is the interval at which the websocket client sends ping messages to the cloud
	WebsocketPingInterval = 15 * time.Second
)

// Prometheus metrics describing the cloud connection and, labelled by
// connection type and source, the ping and session-request activity.
var (
	metricCloudConnectionStatus = promauto.NewGauge(
		prometheus.GaugeOpts{
			Name: "jetkvm_cloud_connection_status",
			Help: "The status of the cloud connection",
		},
	)
	metricCloudConnectionEstablishedTimestamp = promauto.NewGauge(
		prometheus.GaugeOpts{
			Name: "jetkvm_cloud_connection_established_timestamp",
			Help: "The timestamp when the cloud connection was established",
		},
	)
	metricConnectionLastPingTimestamp = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "jetkvm_connection_last_ping_timestamp",
			Help: "The timestamp when the last ping response was received",
		},
		[]string{"type", "source"},
	)
	metricConnectionLastPingReceivedTimestamp = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "jetkvm_connection_last_ping_received_timestamp",
			Help: "The timestamp when the last ping request was received",
		},
		[]string{"type", "source"},
	)
	metricConnectionLastPingDuration = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "jetkvm_connection_last_ping_duration",
			Help: "The duration of the last ping response",
		},
		[]string{"type", "source"},
	)
	metricConnectionPingDuration = promauto.NewHistogramVec(
		prometheus.HistogramOpts{
			Name: "jetkvm_connection_ping_duration",
			Help: "The duration of the ping response",
			Buckets: []float64{
				0.1, 0.5, 1, 10,
			},
		},
		[]string{"type", "source"},
	)
	metricConnectionTotalPingSentCount = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "jetkvm_connection_total_ping_sent",
			Help: "The total number of pings sent to the connection",
		},
		[]string{"type", "source"},
	)
	metricConnectionTotalPingReceivedCount = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "jetkvm_connection_total_ping_received",
			Help: "The total number of pings received from the connection",
		},
		[]string{"type", "source"},
	)
	metricConnectionSessionRequestCount = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "jetkvm_connection_session_total_requests",
			Help: "The total number of session requests received",
		},
		[]string{"type", "source"},
	)
	metricConnectionSessionRequestDuration = promauto.NewHistogramVec(
		prometheus.HistogramOpts{
			Name: "jetkvm_connection_session_request_duration",
			Help: "The duration of session requests",
			Buckets: []float64{
				0.1, 0.5, 1, 10,
			},
		},
		[]string{"type", "source"},
	)
	metricConnectionLastSessionRequestTimestamp = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "jetkvm_connection_last_session_request_timestamp",
			Help: "The timestamp of the last session request",
		},
		[]string{"type", "source"},
	)
	metricConnectionLastSessionRequestDuration = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "jetkvm_connection_last_session_request_duration",
			Help: "The duration of the last session request",
		},
		[]string{"type", "source"},
	)
	metricCloudConnectionFailureCount = promauto.NewCounter(
		prometheus.CounterOpts{
			Name: "jetkvm_cloud_connection_failure_count",
			Help: "The number of times the cloud connection has failed",
		},
	)
)

// CloudConnectionState enumerates the lifecycle states of the cloud connection.
type CloudConnectionState uint8

const (
	CloudConnectionStateNotConfigured CloudConnectionState = iota
	CloudConnectionStateDisconnected
	CloudConnectionStateConnecting
	CloudConnectionStateConnected
)

var (
	// cloudConnectionState is guarded by cloudConnectionStateLock.
	cloudConnectionState     CloudConnectionState = CloudConnectionStateNotConfigured
	cloudConnectionStateLock                      = &sync.Mutex{}

	// cloudDisconnectChan is guarded by cloudDisconnectLock when read in
	// disconnectCloud; it is not assigned anywhere in this file.
	cloudDisconnectChan chan error
	cloudDisconnectLock = &sync.Mutex{}
)

// setCloudConnectionState records the new cloud connection state under
// cloudConnectionStateLock and asynchronously requests a display update,
// passing whether the state actually changed.
func setCloudConnectionState(state CloudConnectionState) {
	cloudConnectionStateLock.Lock()
	defer cloudConnectionStateLock.Unlock()

	// NOTE(review): this tests the *current* state, not the requested one, so
	// the downgrade to NotConfigured only applies when we were already
	// disconnected. Confirm this is intended.
	if cloudConnectionState == CloudConnectionStateDisconnected &&
		(config.CloudToken == "" || config.CloudURL == "") {
		state = CloudConnectionStateNotConfigured
	}

	previousState := cloudConnectionState
	cloudConnectionState = state

	go waitCtrlAndRequestDisplayUpdate(
		previousState != state,
	)
}

// wsResetMetrics resets the per-connection "last ..." gauges for the given
// sourceType/source label pair to -1. For the "cloud" source type it also
// updates the cloud connection gauges: established timestamp and status are
// set to now/1 when established is true, and to -1/-1 otherwise.
func wsResetMetrics(established bool, sourceType string, source string) {
	metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).Set(-1)
	metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(-1)

	metricConnectionLastPingReceivedTimestamp.WithLabelValues(sourceType, source).Set(-1)

	metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).Set(-1)
	metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(-1)

	if sourceType != "cloud" {
		return
	}

	if established {
		metricCloudConnectionEstablishedTimestamp.SetToCurrentTime()
		metricCloudConnectionStatus.Set(1)
	} else {
		metricCloudConnectionEstablishedTimestamp.Set(-1)
		metricCloudConnectionStatus.Set(-1)
	}
}

// googleIdentityFromIDToken builds the "<audience>:<subject>" identity string
// used to bind the device to a Google account. It returns an error instead of
// panicking when the token carries no audience.
func googleIdentityFromIDToken(idToken *oidc.IDToken) (string, error) {
	if len(idToken.Audience) == 0 {
		return "", errors.New("OIDC token has no audience")
	}
	return idToken.Audience[0] + ":" + idToken.Subject, nil
}

// handleCloudRegister handles a cloud registration request: it exchanges the
// temporary token for a permanent secret token at the cloud API, verifies the
// supplied Google OIDC token against the request's client ID, and persists
// the resulting token and Google identity.
//
// The in-memory configuration is only modified after every verification step
// has succeeded, so a rejected registration does not leave a half-applied
// cloud token behind.
func handleCloudRegister(c *gin.Context) {
	var req CloudRegisterRequest

	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(400, gin.H{"error": "Invalid request body"})
		return
	}

	// Exchange the temporary token for a permanent auth token
	payload := struct {
		TempToken string `json:"tempToken"`
	}{
		TempToken: req.Token,
	}
	jsonPayload, err := json.Marshal(payload)
	if err != nil {
		c.JSON(500, gin.H{"error": "Failed to encode JSON payload: " + err.Error()})
		return
	}

	client := &http.Client{Timeout: CloudAPIRequestTimeout}

	apiReq, err := http.NewRequest(http.MethodPost, config.CloudURL+"/devices/token", bytes.NewBuffer(jsonPayload))
	if err != nil {
		c.JSON(500, gin.H{"error": "Failed to create register request: " + err.Error()})
		return
	}
	apiReq.Header.Set("Content-Type", "application/json")

	apiResp, err := client.Do(apiReq)
	if err != nil {
		c.JSON(500, gin.H{"error": "Failed to exchange token: " + err.Error()})
		return
	}
	defer apiResp.Body.Close()

	if apiResp.StatusCode != http.StatusOK {
		c.JSON(apiResp.StatusCode, gin.H{"error": "Failed to exchange token: " + apiResp.Status})
		return
	}

	var tokenResp struct {
		SecretToken string `json:"secretToken"`
	}
	if err := json.NewDecoder(apiResp.Body).Decode(&tokenResp); err != nil {
		c.JSON(500, gin.H{"error": "Failed to parse token response: " + err.Error()})
		return
	}

	if tokenResp.SecretToken == "" {
		c.JSON(500, gin.H{"error": "Received empty secret token"})
		return
	}

	provider, err := oidc.NewProvider(c, "https://accounts.google.com")
	if err != nil {
		c.JSON(500, gin.H{"error": "Failed to initialize OIDC provider: " + err.Error()})
		return
	}

	oidcConfig := &oidc.Config{
		ClientID: req.ClientId,
	}

	verifier := provider.Verifier(oidcConfig)
	idToken, err := verifier.Verify(c, req.OidcGoogle)
	if err != nil {
		c.JSON(400, gin.H{"error": "Invalid OIDC token: " + err.Error()})
		return
	}

	googleIdentity, err := googleIdentityFromIDToken(idToken)
	if err != nil {
		c.JSON(400, gin.H{"error": "Invalid OIDC token: " + err.Error()})
		return
	}

	// Everything checked out: apply the new credentials together.
	config.CloudToken = tokenResp.SecretToken
	config.GoogleIdentity = googleIdentity

	// Save the updated configuration
	if err := SaveConfig(); err != nil {
		c.JSON(500, gin.H{"error": "Failed to save configuration"})
		return
	}

	c.JSON(200, gin.H{"message": "Cloud registration successful"})
}

// disconnectCloud sends reason on cloudDisconnectChan while holding
// cloudDisconnectLock. It is a no-op when the channel is nil, and it recovers
// from the panic raised by sending on a closed channel.
func disconnectCloud(reason error) {
	cloudDisconnectLock.Lock()
	defer cloudDisconnectLock.Unlock()

	if cloudDisconnectChan == nil {
		cloudLogger.Trace().Msg("cloud disconnect channel is not set, no need to disconnect")
		return
	}

	// just in case the channel is closed, we don't want to panic
	defer func() {
		if r := recover(); r != nil {
			cloudLogger.Warn().Interface("reason", r).Msg("cloud disconnect channel is closed, no need to disconnect")
		}
	}()
	cloudDisconnectChan <- reason
}

// runWebsocketClient performs a single cloud websocket session: it dials the
// cloud URL (switching the scheme to ws for http and wss otherwise),
// authenticating with the device ID, app version and cloud token, and then
// hands the connection to handleWebRTCSignalWsMessages.
//
// It returns nil when the dial was canceled and the dial or message-loop
// error otherwise. When no cloud token is set it sleeps five seconds and
// returns an error.
func runWebsocketClient() error {
	if config.CloudToken == "" {
		time.Sleep(5 * time.Second)
		return fmt.Errorf("cloud token is not set")
	}

	wsURL, err := url.Parse(config.CloudURL)
	if err != nil {
		return fmt.Errorf("failed to parse config.CloudURL: %w", err)
	}

	if wsURL.Scheme == "http" {
		wsURL.Scheme = "ws"
	} else {
		wsURL.Scheme = "wss"
	}

	setCloudConnectionState(CloudConnectionStateConnecting)

	header := http.Header{}
	header.Set("X-Device-ID", GetDeviceID())
	header.Set("X-App-Version", builtAppVersion)
	header.Set("Authorization", "Bearer "+config.CloudToken)
	dialCtx, cancelDial := context.WithTimeout(context.Background(), CloudWebSocketConnectTimeout)

	l := websocketLogger.With().
		Str("source", wsURL.Host).
		Str("sourceType", "cloud").
		Logger()
	scopedLogger := &l

	defer cancelDial()
	c, resp, err := websocket.Dial(dialCtx, wsURL.String(), &websocket.DialOptions{
		HTTPHeader: header,
		OnPingReceived: func(ctx context.Context, payload []byte) bool {
			scopedLogger.Debug().Bytes("payload", payload).Int("length", len(payload)).Msg("ping frame received")

			metricConnectionTotalPingReceivedCount.WithLabelValues("cloud", wsURL.Host).Inc()
			metricConnectionLastPingReceivedTimestamp.WithLabelValues("cloud", wsURL.Host).SetToCurrentTime()

			// A ping from the server is treated as proof the connection is up.
			setCloudConnectionState(CloudConnectionStateConnected)

			return true
		},
	})

	var connectionId string
	if resp != nil {
		// get the request id from the response header
		connectionId = resp.Header.Get("X-Request-ID")
		if connectionId == "" {
			connectionId = resp.Header.Get("Cf-Ray")
		}
	}

	if connectionId == "" {
		connectionId = uuid.New().String()
		scopedLogger.Warn().
			Str("connectionId", connectionId).
			Msg("no connection id received from the server, generating a new one")
	}

	lWithConnectionId := scopedLogger.With().
		Str("connectionID", connectionId).
		Logger()
	scopedLogger = &lWithConnectionId

	// if the context is canceled, we don't want to return an error
	if err != nil {
		if errors.Is(err, context.Canceled) {
			cloudLogger.Info().Msg("websocket connection canceled")
			setCloudConnectionState(CloudConnectionStateDisconnected)

			return nil
		}
		return err
	}
	defer c.CloseNow() //nolint:errcheck

	cloudLogger.Info().
		Str("url", wsURL.String()).
		Str("connectionID", connectionId).
		Msg("websocket connected")

	// set the metrics when we successfully connect to the cloud.
	wsResetMetrics(true, "cloud", wsURL.Host)

	// we don't have a source for the cloud connection
	return handleWebRTCSignalWsMessages(c, true, wsURL.Host, connectionId, scopedLogger)
}

// authenticateSession verifies the Google OIDC token carried by a cloud
// session request and checks that the resulting "<audience>:<subject>"
// identity matches config.GoogleIdentity. The OIDC work is bounded by
// CloudOidcRequestTimeout. Provider-initialization and identity-mismatch
// failures are also reported to the peer over the websocket.
func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
	oidcCtx, cancelOIDC := context.WithTimeout(ctx, CloudOidcRequestTimeout)
	defer cancelOIDC()
	provider, err := oidc.NewProvider(oidcCtx, "https://accounts.google.com")
	if err != nil {
		_ = wsjson.Write(context.Background(), c, gin.H{
			"error": fmt.Sprintf("failed to initialize OIDC provider: %v", err),
		})
		cloudLogger.Warn().Err(err).Msg("failed to initialize OIDC provider")
		return err
	}

	// The client ID is not checked here; the audience is instead compared as
	// part of the stored Google identity below.
	oidcConfig := &oidc.Config{
		SkipClientIDCheck: true,
	}

	verifier := provider.Verifier(oidcConfig)
	idToken, err := verifier.Verify(oidcCtx, req.OidcGoogle)
	if err != nil {
		return err
	}

	// With SkipClientIDCheck the verifier does not require an audience, so
	// guard against an empty one rather than indexing blindly.
	googleIdentity, err := googleIdentityFromIDToken(idToken)
	if err != nil {
		return err
	}

	if config.GoogleIdentity != googleIdentity {
		_ = wsjson.Write(context.Background(), c, gin.H{"error": "google identity mismatch"})
		return fmt.Errorf("google identity mismatch")
	}

	return nil
}

// handleSessionRequest processes a WebRTC session request received over the
// websocket c. Cloud requests are authenticated first. It then creates a new
// session, exchanges the SDP offer, asks any existing session to step aside
// (its peer connection is closed one second later), installs the new session
// as currentSession and writes the SDP answer back to the peer.
//
// The time spent is recorded in the session-request duration metrics under
// the "cloud" or "local" type label and the given source.
func handleSessionRequest(
	ctx context.Context,
	c *websocket.Conn,
	req WebRTCSessionRequest,
	isCloudConnection bool,
	source string,
	scopedLogger *zerolog.Logger,
) error {
	var sourceType string
	if isCloudConnection {
		sourceType = "cloud"
	} else {
		sourceType = "local"
	}

	timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
		metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(v)
		metricConnectionSessionRequestDuration.WithLabelValues(sourceType, source).Observe(v)
	}))
	defer timer.ObserveDuration()

	// If the message is from the cloud, we need to authenticate the session.
	if isCloudConnection {
		if err := authenticateSession(ctx, c, req); err != nil {
			return err
		}
	}

	session, err := newSession(SessionConfig{
		ws:         c,
		IsCloud:    isCloudConnection,
		LocalIP:    req.IP,
		ICEServers: req.ICEServers,
		Logger:     scopedLogger,
	})
	if err != nil {
		// Send the message text: an error value itself typically marshals to "{}".
		_ = wsjson.Write(context.Background(), c, gin.H{"error": err.Error()})
		return err
	}

	sd, err := session.ExchangeOffer(req.Sd)
	if err != nil {
		// Send the message text: an error value itself typically marshals to "{}".
		_ = wsjson.Write(context.Background(), c, gin.H{"error": err.Error()})
		return err
	}
	if currentSession != nil {
		writeJSONRPCEvent("otherSessionConnected", nil, currentSession)
		peerConn := currentSession.peerConnection
		// Close the previous peer connection after a one-second delay.
		go func() {
			time.Sleep(1 * time.Second)
			_ = peerConn.Close()
		}()
	}

	cloudLogger.Info().Interface("session", session).Msg("new session accepted")
	cloudLogger.Trace().Interface("session", session).Msg("new session accepted")
	currentSession = session
	_ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd})
	return nil
}

// RunWebsocketClient runs the cloud websocket client forever. Each iteration
// waits until a cloud token is configured, the network is online and (when
// required) the system time is synced, then runs one websocket session. A
// failed session is logged, counted in the failure metrics and retried after
// five seconds.
func RunWebsocketClient() {
	for {
		// If the cloud token is not set, we don't need to run the websocket client.
		if config.CloudToken == "" {
			time.Sleep(5 * time.Second)
			continue
		}

		// If the network is not up, well, we can't connect to the cloud.
		if !networkState.IsOnline() {
			cloudLogger.Warn().Msg("waiting for network to be online, will retry in 3 seconds")
			time.Sleep(3 * time.Second)
			continue
		}

		// If the system time is not synchronized, the API request will fail anyway because the TLS handshake will fail.
		if isTimeSyncNeeded() && !timeSync.IsSyncSuccess() {
			cloudLogger.Warn().Msg("system time is not synced, will retry in 3 seconds")
			time.Sleep(3 * time.Second)
			continue
		}

		err := runWebsocketClient()
		if err != nil {
			cloudLogger.Warn().Err(err).Msg("websocket client error")
			metricCloudConnectionStatus.Set(0)
			metricCloudConnectionFailureCount.Inc()
			time.Sleep(5 * time.Second)
		}
	}
}

// CloudState is the cloud configuration summary returned by rpcGetCloudState.
type CloudState struct {
	Connected bool   `json:"connected"`
	URL       string `json:"url,omitempty"`
	AppURL    string `json:"appUrl,omitempty"`
}

// rpcGetCloudState reports the configured cloud URLs. Connected is true when
// both a cloud token and a cloud URL are configured; it does not reflect the
// live websocket state.
func rpcGetCloudState() CloudState {
	return CloudState{
		Connected: config.CloudToken != "" && config.CloudURL != "",
		URL:       config.CloudURL,
		AppURL:    config.CloudAppURL,
	}
}

// rpcDeregisterDevice asks the cloud API to delete this device and, on
// success, clears the cloud token and Google identity, saves the
// configuration and disconnects from the cloud. It returns an error when the
// cloud is not configured, the request fails, the configuration cannot be
// saved, or the API answers with a status other than 2xx or 404.
func rpcDeregisterDevice() error {
	if config.CloudToken == "" || config.CloudURL == "" {
		return fmt.Errorf("cloud token or URL is not set")
	}

	req, err := http.NewRequest(http.MethodDelete, config.CloudURL+"/devices/"+GetDeviceID(), nil)
	if err != nil {
		return fmt.Errorf("failed to create deregister request: %w", err)
	}

	req.Header.Set("Authorization", "Bearer "+config.CloudToken)
	client := &http.Client{Timeout: CloudAPIRequestTimeout}

	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to send deregister request: %w", err)
	}

	defer resp.Body.Close()
	// We consider both 200 OK and 404 Not Found as successful deregistration.
	// 200 OK means the device was found and deregistered.
	// 404 Not Found means the device is not in the database, which could be due to various reasons
	// (e.g., wrong cloud token, already deregistered). Regardless of the reason, we can safely remove it.
	if resp.StatusCode == http.StatusNotFound || (resp.StatusCode >= 200 && resp.StatusCode < 300) {
		config.CloudToken = ""
		config.GoogleIdentity = ""

		if err := SaveConfig(); err != nil {
			return fmt.Errorf("failed to save configuration after deregistering: %w", err)
		}

		cloudLogger.Info().Msg("device deregistered, disconnecting from cloud")
		disconnectCloud(fmt.Errorf("device deregistered"))

		setCloudConnectionState(CloudConnectionStateNotConfigured)

		return nil
	}

	return fmt.Errorf("deregister request failed with status: %s", resp.Status)
}