diff --git a/cloud.go b/cloud.go index 7ad8b75..89666a1 100644 --- a/cloud.go +++ b/cloud.go @@ -275,6 +275,10 @@ func runWebsocketClient() error { defer cancelDial() c, _, err := websocket.Dial(dialCtx, wsURL.String(), &websocket.DialOptions{ HTTPHeader: header, + OnPingReceived: func(ctx context.Context, payload []byte) bool { + websocketLogger.Infof("ping frame received: %v, source: %s, sourceType: cloud", payload, wsURL.Host) + return true + }, }) if err != nil { return err diff --git a/go.mod b/go.mod index 93fedab..fae1cbd 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.21.1 require ( github.com/Masterminds/semver/v3 v3.3.0 github.com/beevik/ntp v1.3.1 - github.com/coder/websocket v1.8.12 + github.com/coder/websocket v1.8.13 github.com/coreos/go-oidc/v3 v3.11.0 github.com/creack/pty v1.1.23 github.com/gin-gonic/gin v1.9.1 diff --git a/go.sum b/go.sum index b5769d8..1563130 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,8 @@ github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= +github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI= github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= github.com/creack/goselect v0.1.2 h1:2DNy14+JPjRBgPzAd1thbQp4BSIihxcBf0IXhQXDRa0= diff --git a/web.go b/web.go index 51bcd98..0258dc6 100644 --- a/web.go +++ b/web.go @@ -1,6 +1,7 @@ package kvm import ( + "bytes" "context" "embed" "encoding/json" @@ -173,11 +174,24 @@ func handleWebRTCSession(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"sd": sd}) } +var ( + pingMessage = []byte("ping") + pongMessage = []byte("pong") +) + func handleLocalWebRTCSignal(c *gin.Context) { cloudLogger.Infof("new websocket connection established") + + // get the source from the request + source := c.ClientIP() + // Create WebSocket options with InsecureSkipVerify to bypass origin check wsOptions := &websocket.AcceptOptions{ InsecureSkipVerify: true, // Allow connections from any origin + OnPingReceived: func(ctx context.Context, payload []byte) bool { + websocketLogger.Infof("ping frame received: %v, source: %s, sourceType: local", payload, source) + return true + }, } wsCon, err := websocket.Accept(c.Writer, c.Request, wsOptions) @@ -186,9 +200,6 @@ func handleLocalWebRTCSignal(c *gin.Context) { return } - // get the source from the request - source := c.ClientIP() - // Now use conn for websocket operations defer wsCon.Close(websocket.StatusNormalClosure, "") @@ -211,7 +222,6 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, // Add connection tracking to detect reconnections connectionID := uuid.New().String() - cloudLogger.Infof("new websocket connection established with ID: %s", connectionID) // connection type var sourceType string @@ -223,18 +233,20 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, // probably we can use a better logging framework here logInfof := func(format string, args ...interface{}) { - args = append(args, source, sourceType) - websocketLogger.Infof(format+", source: %s, sourceType: %s", args...) + args = append(args, source, sourceType, connectionID) + websocketLogger.Infof(format+", source: %s, sourceType: %s, id: %s", args...) } logWarnf := func(format string, args ...interface{}) { - args = append(args, source, sourceType) - websocketLogger.Warnf(format+", source: %s, sourceType: %s", args...) + args = append(args, source, sourceType, connectionID) + websocketLogger.Warnf(format+", source: %s, sourceType: %s, id: %s", args...) } logTracef := func(format string, args ...interface{}) { - args = append(args, source, sourceType) - websocketLogger.Tracef(format+", source: %s, sourceType: %s", args...) + args = append(args, source, sourceType, connectionID) + websocketLogger.Tracef(format+", source: %s, sourceType: %s, id: %s", args...) } + logInfof("new websocket connection established") + go func() { for { time.Sleep(WebsocketPingInterval) @@ -245,7 +257,7 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, metricConnectionPingDuration.WithLabelValues(sourceType, source).Observe(v) })) - logInfof("pinging websocket") + logTracef("sending ping frame") err := wsCon.Ping(runCtx) if err != nil { @@ -255,10 +267,12 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, } // dont use `defer` here because we want to observe the duration of the ping - timer.ObserveDuration() + duration := timer.ObserveDuration() metricConnectionTotalPingCount.WithLabelValues(sourceType, source).Inc() metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime() + + logTracef("received pong frame, duration: %v", duration) } }() @@ -296,6 +310,16 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, Data json.RawMessage `json:"data"` } + if bytes.Equal(msg, pingMessage) { + logInfof("ping message received: %s", string(msg)) + err = wsCon.Write(context.Background(), websocket.MessageText, pongMessage) + if err != nil { + logWarnf("unable to write pong message: %v", err) + return err + } + continue + } + err = json.Unmarshal(msg, &message) if err != nil { logWarnf("unable to parse ws message: %v", err) @@ -311,8 +335,9 @@ func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, continue } - logInfof("new session request: %v", req.OidcGoogle) - logTracef("session request info: %v", req) + if req.OidcGoogle != "" { + logInfof("new session request with OIDC Google: %v", req.OidcGoogle) + } metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc() metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()