chore(cloud): use request id from the cloud

This commit is contained in:
Siyuan Miao 2025-04-11 16:03:46 +02:00
parent f98eaddf15
commit 94e83249ef
2 changed files with 37 additions and 9 deletions

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
@ -290,13 +291,15 @@ func runWebsocketClient() error {
header.Set("Authorization", "Bearer "+config.CloudToken) header.Set("Authorization", "Bearer "+config.CloudToken)
dialCtx, cancelDial := context.WithTimeout(context.Background(), CloudWebSocketConnectTimeout) dialCtx, cancelDial := context.WithTimeout(context.Background(), CloudWebSocketConnectTimeout)
scopedLogger := websocketLogger.With(). l := websocketLogger.With().
Str("source", wsURL.Host). Str("source", wsURL.Host).
Str("sourceType", "cloud"). Str("sourceType", "cloud").
Logger() Logger()
scopedLogger := &l
defer cancelDial() defer cancelDial()
c, _, err := websocket.Dial(dialCtx, wsURL.String(), &websocket.DialOptions{ c, resp, err := websocket.Dial(dialCtx, wsURL.String(), &websocket.DialOptions{
HTTPHeader: header, HTTPHeader: header,
OnPingReceived: func(ctx context.Context, payload []byte) bool { OnPingReceived: func(ctx context.Context, payload []byte) bool {
scopedLogger.Info().Bytes("payload", payload).Int("length", len(payload)).Msg("ping frame received") scopedLogger.Info().Bytes("payload", payload).Int("length", len(payload)).Msg("ping frame received")
@ -307,6 +310,24 @@ func runWebsocketClient() error {
return true return true
}, },
}) })
// get the request id from the response header
connectionId := resp.Header.Get("X-Request-ID")
if connectionId == "" {
connectionId = resp.Header.Get("Cf-Ray")
}
if connectionId == "" {
connectionId = uuid.New().String()
scopedLogger.Warn().
Str("connectionId", connectionId).
Msg("no connection id received from the server, generating a new one")
}
lWithConnectionId := scopedLogger.With().
Str("connectionID", connectionId).
Logger()
scopedLogger = &lWithConnectionId
// if the context is canceled, we don't want to return an error // if the context is canceled, we don't want to return an error
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
@ -316,13 +337,16 @@ func runWebsocketClient() error {
return err return err
} }
defer c.CloseNow() //nolint:errcheck defer c.CloseNow() //nolint:errcheck
cloudLogger.Info().Str("url", wsURL.String()).Msg("websocket connected") cloudLogger.Info().
Str("url", wsURL.String()).
Str("connectionID", connectionId).
Msg("websocket connected")
// set the metrics when we successfully connect to the cloud. // set the metrics when we successfully connect to the cloud.
wsResetMetrics(true, "cloud", wsURL.Host) wsResetMetrics(true, "cloud", wsURL.Host)
// we don't have a source for the cloud connection // we don't have a source for the cloud connection
return handleWebRTCSignalWsMessages(c, true, wsURL.Host, &scopedLogger) return handleWebRTCSignalWsMessages(c, true, wsURL.Host, connectionId, scopedLogger)
} }
func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error { func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error {

14
web.go
View File

@ -189,6 +189,7 @@ var (
func handleLocalWebRTCSignal(c *gin.Context) { func handleLocalWebRTCSignal(c *gin.Context) {
// get the source from the request // get the source from the request
source := c.ClientIP() source := c.ClientIP()
connectionID := uuid.New().String()
scopedLogger := websocketLogger.With(). scopedLogger := websocketLogger.With().
Str("component", "websocket"). Str("component", "websocket").
@ -226,20 +227,23 @@ func handleLocalWebRTCSignal(c *gin.Context) {
return return
} }
err = handleWebRTCSignalWsMessages(wsCon, false, source, &scopedLogger) err = handleWebRTCSignalWsMessages(wsCon, false, source, connectionID, &scopedLogger)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
} }
func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, source string, scopedLogger *zerolog.Logger) error { func handleWebRTCSignalWsMessages(
wsCon *websocket.Conn,
isCloudConnection bool,
source string,
connectionID string,
scopedLogger *zerolog.Logger,
) error {
runCtx, cancelRun := context.WithCancel(context.Background()) runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun() defer cancelRun()
// Add connection tracking to detect reconnections
connectionID := uuid.New().String()
// connection type // connection type
var sourceType string var sourceType string
if isCloudConnection { if isCloudConnection {