chore(websocket): use MetricVec instead of Metric to store metrics

This commit is contained in:
Siyuan Miao 2025-04-08 16:44:10 +02:00
parent 5b80ef9b62
commit 50d4374a26
2 changed files with 89 additions and 50 deletions

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"time" "time"
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
@ -52,19 +51,21 @@ var (
Help: "The timestamp when the cloud connection was established", Help: "The timestamp when the cloud connection was established",
}, },
) )
metricConnectionLastPingTimestamp = promauto.NewGauge( metricConnectionLastPingTimestamp = promauto.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Name: "jetkvm_connection_last_ping_timestamp", Name: "jetkvm_connection_last_ping_timestamp",
Help: "The timestamp when the last ping response was received", Help: "The timestamp when the last ping response was received",
}, },
[]string{"type", "source"},
) )
metricConnectionLastPingDuration = promauto.NewGauge( metricConnectionLastPingDuration = promauto.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Name: "jetkvm_connection_last_ping_duration", Name: "jetkvm_connection_last_ping_duration",
Help: "The duration of the last ping response", Help: "The duration of the last ping response",
}, },
[]string{"type", "source"},
) )
metricConnectionPingDuration = promauto.NewHistogram( metricConnectionPingDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{ prometheus.HistogramOpts{
Name: "jetkvm_connection_ping_duration", Name: "jetkvm_connection_ping_duration",
Help: "The duration of the ping response", Help: "The duration of the ping response",
@ -72,20 +73,23 @@ var (
0.1, 0.5, 1, 10, 0.1, 0.5, 1, 10,
}, },
}, },
[]string{"type", "source"},
) )
metricConnectionTotalPingCount = promauto.NewCounter( metricConnectionTotalPingCount = promauto.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "jetkvm_connection_total_ping_count", Name: "jetkvm_connection_total_ping_count",
Help: "The total number of pings sent to the connection", Help: "The total number of pings sent to the connection",
}, },
[]string{"type", "source"},
) )
metricConnectionSessionRequestCount = promauto.NewCounter( metricConnectionSessionRequestCount = promauto.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "jetkvm_connection_session_total_request_count", Name: "jetkvm_connection_session_total_request_count",
Help: "The total number of session requests received from the", Help: "The total number of session requests received",
}, },
[]string{"type", "source"},
) )
metricConnectionSessionRequestDuration = promauto.NewHistogram( metricConnectionSessionRequestDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{ prometheus.HistogramOpts{
Name: "jetkvm_connection_session_request_duration", Name: "jetkvm_connection_session_request_duration",
Help: "The duration of session requests", Help: "The duration of session requests",
@ -93,18 +97,21 @@ var (
0.1, 0.5, 1, 10, 0.1, 0.5, 1, 10,
}, },
}, },
[]string{"type", "source"},
) )
metricConnectionLastSessionRequestTimestamp = promauto.NewGauge( metricConnectionLastSessionRequestTimestamp = promauto.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Name: "jetkvm_connection_last_session_request_timestamp", Name: "jetkvm_connection_last_session_request_timestamp",
Help: "The timestamp of the last session request", Help: "The timestamp of the last session request",
}, },
[]string{"type", "source"},
) )
metricConnectionLastSessionRequestDuration = promauto.NewGauge( metricConnectionLastSessionRequestDuration = promauto.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Name: "jetkvm_connection_last_session_request_duration", Name: "jetkvm_connection_last_session_request_duration",
Help: "The duration of the last session request", Help: "The duration of the last session request",
}, },
[]string{"type", "source"},
) )
metricCloudConnectionFailureCount = promauto.NewCounter( metricCloudConnectionFailureCount = promauto.NewCounter(
prometheus.CounterOpts{ prometheus.CounterOpts{
@ -114,17 +121,16 @@ var (
) )
) )
var ( func wsResetMetrics(established bool, sourceType string, source string) {
cloudDisconnectChan chan error metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).Set(-1)
cloudDisconnectLock = &sync.Mutex{} metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(-1)
)
func cloudResetMetrics(established bool) { metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).Set(-1)
metricConnectionLastPingTimestamp.Set(-1) metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(-1)
metricConnectionLastPingDuration.Set(-1)
metricConnectionLastSessionRequestTimestamp.Set(-1) if sourceType != "cloud" {
metricConnectionLastSessionRequestDuration.Set(-1) return
}
if established { if established {
metricCloudConnectionEstablishedTimestamp.SetToCurrentTime() metricCloudConnectionEstablishedTimestamp.SetToCurrentTime()
@ -270,9 +276,10 @@ func runWebsocketClient() error {
cloudLogger.Infof("websocket connected to %s", wsURL) cloudLogger.Infof("websocket connected to %s", wsURL)
// set the metrics when we successfully connect to the cloud. // set the metrics when we successfully connect to the cloud.
cloudResetMetrics(true) wsResetMetrics(true, "cloud", "")
return handleWebRTCSignalWsMessages(c, true) // we don't have a source for the cloud connection
return handleWebRTCSignalWsMessages(c, true, "")
} }
func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error { func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {
@ -306,10 +313,17 @@ func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessi
return nil return nil
} }
func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest, isCloudConnection bool) error { func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest, isCloudConnection bool, source string) error {
var sourceType string
if isCloudConnection {
sourceType = "cloud"
} else {
sourceType = "local"
}
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricConnectionLastSessionRequestDuration.Set(v) metricConnectionLastSessionRequestDuration.WithLabelValues(sourceType, source).Set(v)
metricConnectionSessionRequestDuration.Observe(v) metricConnectionSessionRequestDuration.WithLabelValues(sourceType, source).Observe(v)
})) }))
defer timer.ObserveDuration() defer timer.ObserveDuration()
@ -355,7 +369,7 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess
func RunWebsocketClient() { func RunWebsocketClient() {
for { for {
// reset the metrics when we start the websocket client. // reset the metrics when we start the websocket client.
cloudResetMetrics(false) wsResetMetrics(false, "cloud", "")
// If the cloud token is not set, we don't need to run the websocket client. // If the cloud token is not set, we don't need to run the websocket client.
if config.CloudToken == "" { if config.CloudToken == "" {

75
web.go
View File

@ -138,16 +138,19 @@ func handleLocalWebRTCSignal(c *gin.Context) {
return return
} }
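// use the client IP as the source of this local connection
// NOTE(review): source is later used as a metric label value, so label cardinality grows with the number of distinct client IPs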
source := c.ClientIP()
// Now use conn for websocket operations // Now use conn for websocket operations
defer wsCon.Close(websocket.StatusNormalClosure, "") defer wsCon.Close(websocket.StatusNormalClosure, "")
err = handleWebRTCSignalWsMessages(wsCon, false) err = handleWebRTCSignalWsMessages(wsCon, false, source)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
} }
func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool) error { func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, source string) error {
runCtx, cancelRun := context.WithCancel(context.Background()) runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun() defer cancelRun()
@ -155,21 +158,43 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool)
connectionID := uuid.New().String() connectionID := uuid.New().String()
cloudLogger.Infof("new websocket connection established with ID: %s", connectionID) cloudLogger.Infof("new websocket connection established with ID: %s", connectionID)
// connection type
var sourceType string
if isCloudConnection {
sourceType = "cloud"
} else {
sourceType = "local"
}
// TODO: these helpers append source and sourceType to every log message; a logger with structured fields would make them unnecessary
logInfof := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Infof(format+", source: %s, sourceType: %s", args...)
}
logWarnf := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Warnf(format+", source: %s, sourceType: %s", args...)
}
logTracef := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Tracef(format+", source: %s, sourceType: %s", args...)
}
go func() { go func() {
for { for {
time.Sleep(WebsocketPingInterval) time.Sleep(WebsocketPingInterval)
// set the timer for the ping duration // set the timer for the ping duration
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricConnectionLastPingDuration.Set(v) metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(v)
metricConnectionPingDuration.Observe(v) metricConnectionPingDuration.WithLabelValues(sourceType, source).Observe(v)
})) }))
cloudLogger.Infof("pinging websocket") logInfof("pinging websocket")
err := wsCon.Ping(runCtx) err := wsCon.Ping(runCtx)
if err != nil { if err != nil {
cloudLogger.Warnf("websocket ping error: %v", err) logWarnf("websocket ping error: %v", err)
cancelRun() cancelRun()
return return
} }
@ -177,15 +202,15 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool)
// don't use `defer` here because we want to observe the duration of the ping // don't use `defer` here because we want to observe the duration of the ping
timer.ObserveDuration() timer.ObserveDuration()
metricConnectionTotalPingCount.Inc() metricConnectionTotalPingCount.WithLabelValues(sourceType, source).Inc()
metricConnectionLastPingTimestamp.SetToCurrentTime() metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
} }
}() }()
for { for {
typ, msg, err := wsCon.Read(runCtx) typ, msg, err := wsCon.Read(runCtx)
if err != nil { if err != nil {
websocketLogger.Warnf("websocket read error: %v", err) logWarnf("websocket read error: %v", err)
return err return err
} }
if typ != websocket.MessageText { if typ != websocket.MessageText {
@ -200,54 +225,54 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool)
err = json.Unmarshal(msg, &message) err = json.Unmarshal(msg, &message)
if err != nil { if err != nil {
websocketLogger.Warnf("unable to parse ws message: %v", string(msg)) logWarnf("unable to parse ws message: %v", err)
continue continue
} }
if message.Type == "offer" { if message.Type == "offer" {
websocketLogger.Infof("new session request received") logInfof("new session request received")
var req WebRTCSessionRequest var req WebRTCSessionRequest
err = json.Unmarshal(message.Data, &req) err = json.Unmarshal(message.Data, &req)
if err != nil { if err != nil {
websocketLogger.Warnf("unable to parse session request data: %v", string(message.Data)) logWarnf("unable to parse session request data: %v", err)
continue continue
} }
websocketLogger.Infof("new session request: %v", req.OidcGoogle) logInfof("new session request: %v", req.OidcGoogle)
websocketLogger.Tracef("session request info: %v", req) logTracef("session request info: %v", req)
metricConnectionSessionRequestCount.Inc() metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc()
metricConnectionLastSessionRequestTimestamp.SetToCurrentTime() metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection) err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source)
if err != nil { if err != nil {
websocketLogger.Infof("error starting new session: %v", err) logWarnf("error starting new session: %v", err)
continue continue
} }
} else if message.Type == "new-ice-candidate" { } else if message.Type == "new-ice-candidate" {
websocketLogger.Infof("The client sent us a new ICE candidate: %v", string(message.Data)) logInfof("The client sent us a new ICE candidate: %v", string(message.Data))
var candidate webrtc.ICECandidateInit var candidate webrtc.ICECandidateInit
// Attempt to unmarshal as a ICECandidateInit // Attempt to unmarshal as a ICECandidateInit
if err := json.Unmarshal(message.Data, &candidate); err != nil { if err := json.Unmarshal(message.Data, &candidate); err != nil {
websocketLogger.Warnf("unable to parse incoming ICE candidate data: %v", string(message.Data)) logWarnf("unable to parse incoming ICE candidate data: %v", string(message.Data))
continue continue
} }
if candidate.Candidate == "" { if candidate.Candidate == "" {
websocketLogger.Warnf("empty incoming ICE candidate, skipping") logWarnf("empty incoming ICE candidate, skipping")
continue continue
} }
websocketLogger.Infof("unmarshalled incoming ICE candidate: %v", candidate) logInfof("unmarshalled incoming ICE candidate: %v", candidate)
if currentSession == nil { if currentSession == nil {
websocketLogger.Infof("no current session, skipping incoming ICE candidate") logInfof("no current session, skipping incoming ICE candidate")
continue continue
} }
websocketLogger.Infof("adding incoming ICE candidate to current session: %v", candidate) logInfof("adding incoming ICE candidate to current session: %v", candidate)
if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil { if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil {
websocketLogger.Warnf("failed to add incoming ICE candidate to our peer connection: %v", err) logWarnf("failed to add incoming ICE candidate to our peer connection: %v", err)
} }
} }
} }