diff --git a/internal/native/grpc_server.go b/internal/native/grpc_server.go index dc177ef9..9b54fb5b 100644 --- a/internal/native/grpc_server.go +++ b/internal/native/grpc_server.go @@ -15,18 +15,20 @@ import ( // grpcServer wraps the Native instance and implements the gRPC service type grpcServer struct { pb.UnimplementedNativeServiceServer - native *Native - logger *zerolog.Logger - eventCh chan *pb.Event - eventM sync.Mutex + native *Native + logger *zerolog.Logger + eventStreamChan chan *pb.Event + eventStreamMu sync.Mutex + eventStreamCtx context.Context + eventStreamCancel context.CancelFunc } // NewGRPCServer creates a new gRPC server for the native service func NewGRPCServer(n *Native, logger *zerolog.Logger) *grpcServer { s := &grpcServer{ - native: n, - logger: logger, - eventCh: make(chan *pb.Event, 100), + native: n, + logger: logger, + eventStreamChan: make(chan *pb.Event, 100), } // Store original callbacks and wrap them to also broadcast events @@ -82,7 +84,7 @@ func NewGRPCServer(n *Native, logger *zerolog.Logger) *grpcServer { } func (s *grpcServer) broadcastEvent(event *pb.Event) { - s.eventCh <- event + s.eventStreamChan <- event } func (s *grpcServer) IsReady(ctx context.Context, req *pb.IsReadyRequest) (*pb.IsReadyResponse, error) { @@ -94,15 +96,49 @@ func (s *grpcServer) StreamEvents(req *pb.Empty, stream pb.NativeService_StreamE setProcTitle("connected") defer setProcTitle("waiting") + // Cancel previous stream if exists + s.eventStreamMu.Lock() + if s.eventStreamCancel != nil { + s.logger.Debug().Msg("cancelling previous StreamEvents call") + s.eventStreamCancel() + } + + // Create a cancellable context for this stream + ctx, cancel := context.WithCancel(stream.Context()) + s.eventStreamCtx = ctx + s.eventStreamCancel = cancel + s.eventStreamMu.Unlock() + + // Clean up when this stream ends + defer func() { + s.eventStreamMu.Lock() + defer s.eventStreamMu.Unlock() + if s.eventStreamCtx == ctx { + s.eventStreamCancel = nil + s.eventStreamCtx = nil + } + cancel() + }() + // Stream events for { select { - case event := <-s.eventCh: + case event := <-s.eventStreamChan: + // Check if this stream is still the active one + s.eventStreamMu.Lock() + isActive := s.eventStreamCtx == ctx + s.eventStreamMu.Unlock() + + if !isActive { + s.logger.Debug().Msg("stream replaced by new call, exiting") + return context.Canceled + } + if err := stream.Send(event); err != nil { return err } - case <-stream.Context().Done(): - return stream.Context().Err() + case <-ctx.Done(): + return ctx.Err() } } }