diff --git a/cloud.go b/cloud.go index 579d1f6..fd96c41 100644 --- a/cloud.go +++ b/cloud.go @@ -12,6 +12,7 @@ import ( "time" "github.com/coder/websocket/wsjson" + "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -290,13 +291,15 @@ func runWebsocketClient() error { header.Set("Authorization", "Bearer "+config.CloudToken) dialCtx, cancelDial := context.WithTimeout(context.Background(), CloudWebSocketConnectTimeout) - scopedLogger := websocketLogger.With(). + l := websocketLogger.With(). Str("source", wsURL.Host). Str("sourceType", "cloud"). Logger() + scopedLogger := &l + defer cancelDial() - c, _, err := websocket.Dial(dialCtx, wsURL.String(), &websocket.DialOptions{ + c, resp, err := websocket.Dial(dialCtx, wsURL.String(), &websocket.DialOptions{ HTTPHeader: header, OnPingReceived: func(ctx context.Context, payload []byte) bool { scopedLogger.Info().Bytes("payload", payload).Int("length", len(payload)).Msg("ping frame received") @@ -307,6 +310,24 @@ func runWebsocketClient() error { return true }, }) + + // get the request id from the response header + connectionId := resp.Header.Get("X-Request-ID") + if connectionId == "" { + connectionId = resp.Header.Get("Cf-Ray") + } + if connectionId == "" { + connectionId = uuid.New().String() + scopedLogger.Warn(). + Str("connectionId", connectionId). + Msg("no connection id received from the server, generating a new one") + } + + lWithConnectionId := scopedLogger.With(). + Str("connectionID", connectionId). + Logger() + scopedLogger = &lWithConnectionId + // if the context is canceled, we don't want to return an error if err != nil { if errors.Is(err, context.Canceled) { @@ -316,13 +337,16 @@ func runWebsocketClient() error { return err } defer c.CloseNow() //nolint:errcheck - cloudLogger.Info().Str("url", wsURL.String()).Msg("websocket connected") + cloudLogger.Info(). + Str("url", wsURL.String()). + Str("connectionID", connectionId). + Msg("websocket connected") // set the metrics when we successfully connect to the cloud. wsResetMetrics(true, "cloud", wsURL.Host) // we don't have a source for the cloud connection - return handleWebRTCSignalWsMessages(c, true, wsURL.Host, &scopedLogger) + return handleWebRTCSignalWsMessages(c, true, wsURL.Host, connectionId, scopedLogger) } func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error { diff --git a/web.go b/web.go index 8ff5929..6e74a13 100644 --- a/web.go +++ b/web.go @@ -189,6 +189,7 @@ var ( func handleLocalWebRTCSignal(c *gin.Context) { // get the source from the request source := c.ClientIP() + connectionID := uuid.New().String() scopedLogger := websocketLogger.With(). Str("component", "websocket"). @@ -226,20 +227,23 @@ func handleLocalWebRTCSignal(c *gin.Context) { return } - err = handleWebRTCSignalWsMessages(wsCon, false, source, &scopedLogger) + err = handleWebRTCSignalWsMessages(wsCon, false, source, connectionID, &scopedLogger) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } } -func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, source string, scopedLogger *zerolog.Logger) error { +func handleWebRTCSignalWsMessages( + wsCon *websocket.Conn, + isCloudConnection bool, + source string, + connectionID string, + scopedLogger *zerolog.Logger, +) error { runCtx, cancelRun := context.WithCancel(context.Background()) defer cancelRun() - // Add connection tracking to detect reconnections - connectionID := uuid.New().String() - // connection type var sourceType string if isCloudConnection {