From 50d4374a260d5523daa0a07b6ec98f59173efcef Mon Sep 17 00:00:00 2001 From: Siyuan Miao Date: Tue, 8 Apr 2025 16:44:10 +0200 Subject: [PATCH] chore(websocket): use MetricVec instead of Metric to store metrics --- cloud.go | 64 ++++++++++++++++++++++++++++------------------- web.go | 75 +++++++++++++++++++++++++++++++++++++------------------- 2 files changed, 89 insertions(+), 50 deletions(-) diff --git a/cloud.go b/cloud.go index 516a0b1..e22b775 100644 --- a/cloud.go +++ b/cloud.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/url" - "sync" "time" "github.com/coder/websocket/wsjson" @@ -52,19 +51,21 @@ var ( Help: "The timestamp when the cloud connection was established", }, ) - metricConnectionLastPingTimestamp = promauto.NewGauge( + metricConnectionLastPingTimestamp = promauto.NewGaugeVec( prometheus.GaugeOpts{ Name: "jetkvm_connection_last_ping_timestamp", Help: "The timestamp when the last ping response was received", }, + []string{"type", "source"}, ) - metricConnectionLastPingDuration = promauto.NewGauge( + metricConnectionLastPingDuration = promauto.NewGaugeVec( prometheus.GaugeOpts{ Name: "jetkvm_connection_last_ping_duration", Help: "The duration of the last ping response", }, + []string{"type", "source"}, ) - metricConnectionPingDuration = promauto.NewHistogram( + metricConnectionPingDuration = promauto.NewHistogramVec( prometheus.HistogramOpts{ Name: "jetkvm_connection_ping_duration", Help: "The duration of the ping response", @@ -72,20 +73,23 @@ var ( 0.1, 0.5, 1, 10, }, }, + []string{"type", "source"}, ) - metricConnectionTotalPingCount = promauto.NewCounter( + metricConnectionTotalPingCount = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "jetkvm_connection_total_ping_count", Help: "The total number of pings sent to the connection", }, + []string{"type", "source"}, ) - metricConnectionSessionRequestCount = promauto.NewCounter( + metricConnectionSessionRequestCount = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "jetkvm_connection_session_total_request_count", - Help: "The total number of session requests received from the", + Help: "The total number of session requests received", }, + []string{"type", "source"}, ) - metricConnectionSessionRequestDuration = promauto.NewHistogram( + metricConnectionSessionRequestDuration = promauto.NewHistogramVec( prometheus.HistogramOpts{ Name: "jetkvm_connection_session_request_duration", Help: "The duration of session requests", @@ -93,18 +97,21 @@ var ( 0.1, 0.5, 1, 10, }, }, + []string{"type", "source"}, ) - metricConnectionLastSessionRequestTimestamp = promauto.NewGauge( + metricConnectionLastSessionRequestTimestamp = promauto.NewGaugeVec( prometheus.GaugeOpts{ Name: "jetkvm_connection_last_session_request_timestamp", Help: "The timestamp of the last session request", }, + []string{"type", "source"}, ) - metricConnectionLastSessionRequestDuration = promauto.NewGauge( + metricConnectionLastSessionRequestDuration = promauto.NewGaugeVec( prometheus.GaugeOpts{ Name: "jetkvm_connection_last_session_request_duration", Help: "The duration of the last session request", }, + []string{"type", "source"}, ) metricCloudConnectionFailureCount = promauto.NewCounter( prometheus.CounterOpts{ @@ -114,17 +121,16 @@ var ( ) ) -var ( - cloudDisconnectChan chan error - cloudDisconnectLock = &sync.Mutex{} -) +func wsResetMetrics(established bool, sourceType string, source string) { + metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).Set(-1) + metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(-1) -func cloudResetMetrics(established bool) { - metricConnectionLastPingTimestamp.Set(-1) - metricConnectionLastPingDuration.Set(-1) + metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).Set(-1) + metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(-1) - metricConnectionLastSessionRequestTimestamp.Set(-1) - metricConnectionLastSessionRequestDuration.Set(-1) + if sourceType != "cloud" { + return + } if established { metricCloudConnectionEstablishedTimestamp.SetToCurrentTime() @@ -270,9 +276,10 @@ func runWebsocketClient() error { cloudLogger.Infof("websocket connected to %s", wsURL) // set the metrics when we successfully connect to the cloud. - cloudResetMetrics(true) + wsResetMetrics(true, "cloud", "") - return handleWebRTCSignalWsMessages(c, true) + // we don't have a source for the cloud connection + return handleWebRTCSignalWsMessages(c, true, "") } func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error { @@ -306,10 +313,17 @@ func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessi return nil } -func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest, isCloudConnection bool) error { +func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest, isCloudConnection bool, source string) error { + var sourceType string + if isCloudConnection { + sourceType = "cloud" + } else { + sourceType = "local" + } + timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { - metricConnectionLastSessionRequestDuration.Set(v) - metricConnectionSessionRequestDuration.Observe(v) + metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(v) + metricConnectionSessionRequestDuration.WithLabelValues(sourceType, source).Observe(v) })) defer timer.ObserveDuration() @@ -355,7 +369,7 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess func RunWebsocketClient() { for { // reset the metrics when we start the websocket client. - cloudResetMetrics(false) + wsResetMetrics(false, "cloud", "") // If the cloud token is not set, we don't need to run the websocket client. if config.CloudToken == "" { diff --git a/web.go b/web.go index 19efab1..7d7724b 100644 --- a/web.go +++ b/web.go @@ -138,16 +138,19 @@ func handleLocalWebRTCSignal(c *gin.Context) { return } + // get the source from the request + source := c.ClientIP() + // Now use conn for websocket operations defer wsCon.Close(websocket.StatusNormalClosure, "") - err = handleWebRTCSignalWsMessages(wsCon, false) + err = handleWebRTCSignalWsMessages(wsCon, false, source) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } } -func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool) error { +func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, source string) error { runCtx, cancelRun := context.WithCancel(context.Background()) defer cancelRun() @@ -155,21 +158,43 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool) connectionID := uuid.New().String() cloudLogger.Infof("new websocket connection established with ID: %s", connectionID) + // connection type + var sourceType string + if isCloudConnection { + sourceType = "cloud" + } else { + sourceType = "local" + } + + // probably we can use a better logging framework here + logInfof := func(format string, args ...interface{}) { + args = append(args, source, sourceType) + websocketLogger.Infof(format+", source: %s, sourceType: %s", args...) + } + logWarnf := func(format string, args ...interface{}) { + args = append(args, source, sourceType) + websocketLogger.Warnf(format+", source: %s, sourceType: %s", args...) + } + logTracef := func(format string, args ...interface{}) { + args = append(args, source, sourceType) + websocketLogger.Tracef(format+", source: %s, sourceType: %s", args...) + } + go func() { for { time.Sleep(WebsocketPingInterval) // set the timer for the ping duration timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { - metricConnectionLastPingDuration.Set(v) - metricConnectionPingDuration.Observe(v) + metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(v) + metricConnectionPingDuration.WithLabelValues(sourceType, source).Observe(v) })) - cloudLogger.Infof("pinging websocket") + logInfof("pinging websocket") err := wsCon.Ping(runCtx) if err != nil { - cloudLogger.Warnf("websocket ping error: %v", err) + logWarnf("websocket ping error: %v", err) cancelRun() return } @@ -177,15 +202,15 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool) // dont use `defer` here because we want to observe the duration of the ping timer.ObserveDuration() - metricConnectionTotalPingCount.Inc() - metricConnectionLastPingTimestamp.SetToCurrentTime() + metricConnectionTotalPingCount.WithLabelValues(sourceType, source).Inc() + metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime() } }() for { typ, msg, err := wsCon.Read(runCtx) if err != nil { - websocketLogger.Warnf("websocket read error: %v", err) + logWarnf("websocket read error: %v", err) return err } if typ != websocket.MessageText { @@ -200,54 +225,54 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool) err = json.Unmarshal(msg, &message) if err != nil { - websocketLogger.Warnf("unable to parse ws message: %v", string(msg)) + logWarnf("unable to parse ws message: %v", err) continue } if message.Type == "offer" { - websocketLogger.Infof("new session request received") + logInfof("new session request received") var req WebRTCSessionRequest err = json.Unmarshal(message.Data, &req) if err != nil { - websocketLogger.Warnf("unable to parse session request data: %v", string(message.Data)) + logWarnf("unable to parse session request data: %v", err) continue } - websocketLogger.Infof("new session request: %v", req.OidcGoogle) - websocketLogger.Tracef("session request info: %v", req) + logInfof("new session request: %v", req.OidcGoogle) + logTracef("session request info: %v", req) - metricConnectionSessionRequestCount.Inc() - metricConnectionLastSessionRequestTimestamp.SetToCurrentTime() - err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection) + metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc() + metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime() + err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source) if err != nil { - websocketLogger.Infof("error starting new session: %v", err) + logWarnf("error starting new session: %v", err) continue } } else if message.Type == "new-ice-candidate" { - websocketLogger.Infof("The client sent us a new ICE candidate: %v", string(message.Data)) + logInfof("The client sent us a new ICE candidate: %v", string(message.Data)) var candidate webrtc.ICECandidateInit // Attempt to unmarshal as a ICECandidateInit if err := json.Unmarshal(message.Data, &candidate); err != nil { - websocketLogger.Warnf("unable to parse incoming ICE candidate data: %v", string(message.Data)) + logWarnf("unable to parse incoming ICE candidate data: %v", string(message.Data)) continue } if candidate.Candidate == "" { - websocketLogger.Warnf("empty incoming ICE candidate, skipping") + logWarnf("empty incoming ICE candidate, skipping") continue } - websocketLogger.Infof("unmarshalled incoming ICE candidate: %v", candidate) + logInfof("unmarshalled incoming ICE candidate: %v", candidate) if currentSession == nil { - websocketLogger.Infof("no current session, skipping incoming ICE candidate") + logInfof("no current session, skipping incoming ICE candidate") continue } - websocketLogger.Infof("adding incoming ICE candidate to current session: %v", candidate) + logInfof("adding incoming ICE candidate to current session: %v", candidate) if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil { - websocketLogger.Warnf("failed to add incoming ICE candidate to our peer connection: %v", err) + logWarnf("failed to add incoming ICE candidate to our peer connection: %v", err) } } }