From 1ec994110393be731f6f3626fd02cecdc1a2cbb2 Mon Sep 17 00:00:00 2001 From: Marc Brooks Date: Mon, 17 Nov 2025 19:57:41 -0600 Subject: [PATCH] Simplify audio management Moved all start/stop of sources into audio (out of jsonrpc) Clean up duplicated code, made direction a bool, more logging, made all source/relay atomics. Eliminate SetConfig since we always set it during start. Eliminate the extra initialized flag. Properly detect when USB audio was previously active. Relay has the pointer to the source, not a copy. CgoSource (and stub) expose the AudioSource interface. --- audio.go | 180 ++++++++++++------------------ internal/audio/cgo_source.go | 69 ++++++------ internal/audio/cgo_source_stub.go | 10 +- internal/audio/relay.go | 16 +-- internal/audio/source.go | 1 - jsonrpc.go | 69 ++++-------- 6 files changed, 144 insertions(+), 201 deletions(-) diff --git a/audio.go b/audio.go index a5305dd2..368b2f7e 100644 --- a/audio.go +++ b/audio.go @@ -15,10 +15,10 @@ var ( audioMutex sync.Mutex setAudioTrackMutex sync.Mutex // Prevents concurrent setAudioTrack() calls inputSourceMutex sync.Mutex // Serializes Connect() and WriteMessage() calls to input source - outputSource audio.AudioSource + outputSource atomic.Pointer[audio.AudioSource] inputSource atomic.Pointer[audio.AudioSource] - outputRelay *audio.OutputRelay - inputRelay *audio.InputRelay + outputRelay atomic.Pointer[audio.OutputRelay] + inputRelay atomic.Pointer[audio.InputRelay] audioInitialized bool activeConnections atomic.Int32 audioLogger zerolog.Logger @@ -79,58 +79,81 @@ func startAudio() error { return nil } - if outputSource == nil && audioOutputEnabled.Load() && currentAudioTrack != nil { - ensureConfigLoaded() - alsaDevice := getAlsaDevice(config.AudioOutputSource) - source := audio.NewCgoOutputSource(alsaDevice) - source.SetConfig(getAudioConfig()) - outputSource = source - outputRelay = audio.NewOutputRelay(outputSource, currentAudioTrack) - if err := outputRelay.Start(); err != nil { - 
audioLogger.Error().Err(err).Msg("Failed to start audio output relay") - } + if activeConnections.Load() <= 0 { + audioLogger.Debug().Msg("No active connections, skipping audio start") + return nil } ensureConfigLoaded() - if inputSource.Load() == nil && audioInputEnabled.Load() && config.UsbDevices != nil && config.UsbDevices.Audio { - alsaPlaybackDevice := getAlsaDevice("usb") - source := audio.NewCgoInputSource(alsaPlaybackDevice) - source.SetConfig(getAudioConfig()) - var audioSource audio.AudioSource = source - inputSource.Store(&audioSource) - inputRelay = audio.NewInputRelay(audioSource) - if err := inputRelay.Start(); err != nil { - audioLogger.Error().Err(err).Msg("Failed to start input relay") - } + if audioOutputEnabled.Load() && currentAudioTrack != nil { + startOutputAudioUnderMutex(getAlsaDevice(config.AudioOutputSource)) + } + + if audioInputEnabled.Load() && config.UsbDevices != nil && config.UsbDevices.Audio { + startInputAudioUnderMutex(getAlsaDevice("usb")) } return nil } +func startOutputAudioUnderMutex(alsaOutputDevice string) { + newSource := audio.NewCgoOutputSource(alsaOutputDevice, getAudioConfig()) + oldSource := outputSource.Swap(&newSource) + newRelay := audio.NewOutputRelay(&newSource, currentAudioTrack) + oldRelay := outputRelay.Swap(newRelay) + + if oldRelay != nil { + oldRelay.Stop() + } + + if oldSource != nil { + (*oldSource).Disconnect() + } + + if err := newRelay.Start(); err != nil { + audioLogger.Error().Err(err).Str("alsaOutputDevice", alsaOutputDevice).Msg("Failed to start audio output relay") + } +} + +func startInputAudioUnderMutex(alsaPlaybackDevice string) { + newSource := audio.NewCgoInputSource(alsaPlaybackDevice, getAudioConfig()) + oldSource := inputSource.Swap(&newSource) + newRelay := audio.NewInputRelay(&newSource) + oldRelay := inputRelay.Swap(newRelay) + + if oldRelay != nil { + oldRelay.Stop() + } + + if oldSource != nil { + (*oldSource).Disconnect() + } + + if err := newRelay.Start(); err != nil { + 
audioLogger.Error().Err(err).Str("alsaPlaybackDevice", alsaPlaybackDevice).Msg("Failed to start input relay") + } +} + func stopOutputAudio() { audioMutex.Lock() - outRelay := outputRelay - outSource := outputSource - outputRelay = nil - outputSource = nil + outRelay := outputRelay.Swap(nil) + outSource := outputSource.Swap(nil) audioMutex.Unlock() if outRelay != nil { outRelay.Stop() } if outSource != nil { - outSource.Disconnect() + (*outSource).Disconnect() } } func stopInputAudio() { audioMutex.Lock() - inRelay := inputRelay - inputRelay = nil - audioMutex.Unlock() - + inRelay := inputRelay.Swap(nil) inSource := inputSource.Swap(nil) + audioMutex.Unlock() if inRelay != nil { inRelay.Stop() @@ -156,7 +179,7 @@ func onWebRTCConnect() { func onWebRTCDisconnect() { count := activeConnections.Add(-1) - if count == 0 { + if count <= 0 { // Stop audio immediately to release HDMI audio device which shares hardware with video device stopAudio() } @@ -166,39 +189,12 @@ func setAudioTrack(audioTrack *webrtc.TrackLocalStaticSample) { setAudioTrackMutex.Lock() defer setAudioTrackMutex.Unlock() - // Capture old resources and update state in single critical section - audioMutex.Lock() + stopOutputAudio() + currentAudioTrack = audioTrack - oldRelay := outputRelay - oldSource := outputSource - outputRelay = nil - outputSource = nil - var newRelay *audio.OutputRelay - var newSource audio.AudioSource - if currentAudioTrack != nil && audioOutputEnabled.Load() { - ensureConfigLoaded() - alsaDevice := getAlsaDevice(config.AudioOutputSource) - newSource := audio.NewCgoOutputSource(alsaDevice) - newSource.SetConfig(getAudioConfig()) - newRelay = audio.NewOutputRelay(newSource, currentAudioTrack) - outputSource = newSource - outputRelay = newRelay - } - audioMutex.Unlock() - - // Stop/start resources outside mutex to avoid blocking on CGO calls - if oldRelay != nil { - oldRelay.Stop() - } - if oldSource != nil { - oldSource.Disconnect() - } - - if newRelay != nil { - if err := 
newRelay.Start(); err != nil { - audioLogger.Error().Err(err).Msg("Failed to start output relay") - } + if err := startAudio(); err != nil { + audioLogger.Error().Err(err).Msg("Failed to start with new audio track") } } @@ -250,72 +246,44 @@ func SetAudioOutputSource(source string) error { return nil } + stopOutputAudio() config.AudioOutputSource = source - stopOutputAudio() - - if audioOutputEnabled.Load() && activeConnections.Load() > 0 && currentAudioTrack != nil { - alsaDevice := getAlsaDevice(source) - newSource := audio.NewCgoOutputSource(alsaDevice) - newSource.SetConfig(getAudioConfig()) - newRelay := audio.NewOutputRelay(newSource, currentAudioTrack) - - audioMutex.Lock() - outputSource = newSource - outputRelay = newRelay - audioMutex.Unlock() - - if err := newRelay.Start(); err != nil { - audioLogger.Error().Err(err).Str("source", source).Msg("Failed to start audio relay with new source") - } + if err := startAudio(); err != nil { + audioLogger.Error().Err(err).Str("source", source).Msg("Failed to start audio output after source change") } return SaveConfig() } -func RestartAudioOutput() { +func RestartAudioOutput() error { audioMutex.Lock() - hasActiveOutput := outputSource != nil && currentAudioTrack != nil && audioOutputEnabled.Load() + hasActiveOutput := audioOutputEnabled.Load() && currentAudioTrack != nil && outputSource.Load() != nil audioMutex.Unlock() if !hasActiveOutput { - return + return nil } audioLogger.Info().Msg("Restarting audio output") - stopOutputAudio() - - ensureConfigLoaded() - alsaDevice := getAlsaDevice(config.AudioOutputSource) - - newSource := audio.NewCgoOutputSource(alsaDevice) - newSource.SetConfig(getAudioConfig()) - newRelay := audio.NewOutputRelay(newSource, currentAudioTrack) - - audioMutex.Lock() - outputSource = newSource - outputRelay = newRelay - audioMutex.Unlock() - - if err := newRelay.Start(); err != nil { - audioLogger.Error().Err(err).Msg("Failed to restart audio output") - } + return startAudio() } func 
handleInputTrackForSession(track *webrtc.TrackRemote) { myTrackID := track.ID() - audioLogger.Debug(). + trackLogger := audioLogger.With(). Str("codec", track.Codec().MimeType). Str("track_id", myTrackID). - Msg("starting input track handler") + Logger() + + trackLogger.Debug().Msg("starting input track handler") for { currentTrackID := currentInputTrack.Load() if currentTrackID != nil && *currentTrackID != myTrackID { - audioLogger.Debug(). - Str("my_track_id", myTrackID). + trackLogger.Debug(). Str("current_track_id", *currentTrackID). Msg("input track handler exiting - superseded") return @@ -324,10 +292,10 @@ func handleInputTrackForSession(track *webrtc.TrackRemote) { rtpPacket, _, err := track.ReadRTP() if err != nil { if err == io.EOF { - audioLogger.Debug().Str("track_id", myTrackID).Msg("input track ended") + trackLogger.Debug().Msg("input track ended") return } - audioLogger.Warn().Err(err).Str("track_id", myTrackID).Msg("failed to read RTP packet") + trackLogger.Warn().Err(err).Msg("failed to read RTP packet") continue } diff --git a/internal/audio/cgo_source.go b/internal/audio/cgo_source.go index 45f08be7..81f9b70b 100644 --- a/internal/audio/cgo_source.go +++ b/internal/audio/cgo_source.go @@ -25,46 +25,47 @@ const ( ) type CgoSource struct { - direction string - alsaDevice string - initialized bool - connected bool - mu sync.Mutex - logger zerolog.Logger - opusBuf []byte - config AudioConfig + outputDevice bool + alsaDevice string + connected bool + mu sync.Mutex + logger zerolog.Logger + opusBuf []byte + config AudioConfig } -func NewCgoOutputSource(alsaDevice string) *CgoSource { - logger := logging.GetDefaultLogger().With().Str("component", "audio-output-cgo").Logger() +var _ AudioSource = (*CgoSource)(nil) + +func NewCgoOutputSource(alsaDevice string, cfg AudioConfig) AudioSource { + logger := logging.GetDefaultLogger().With(). + Str("component", "audio-output-cgo"). + Str("alsa_device", alsaDevice). 
+ Logger() return &CgoSource{ - direction: "output", - alsaDevice: alsaDevice, - logger: logger, - opusBuf: make([]byte, ipcMaxFrameSize), - config: DefaultAudioConfig(), + outputDevice: true, + alsaDevice: alsaDevice, + logger: logger, + opusBuf: make([]byte, ipcMaxFrameSize), + config: cfg, } } -func NewCgoInputSource(alsaDevice string) *CgoSource { - logger := logging.GetDefaultLogger().With().Str("component", "audio-input-cgo").Logger() +func NewCgoInputSource(alsaDevice string, cfg AudioConfig) AudioSource { + logger := logging.GetDefaultLogger().With(). + Str("component", "audio-input-cgo"). + Str("alsa_device", alsaDevice). + Logger() return &CgoSource{ - direction: "input", - alsaDevice: alsaDevice, - logger: logger, - opusBuf: make([]byte, ipcMaxFrameSize), - config: DefaultAudioConfig(), + outputDevice: false, + alsaDevice: alsaDevice, + logger: logger, + opusBuf: make([]byte, ipcMaxFrameSize), + config: cfg, } } -func (c *CgoSource) SetConfig(cfg AudioConfig) { - c.mu.Lock() - defer c.mu.Unlock() - c.config = cfg -} - func (c *CgoSource) Connect() error { c.mu.Lock() defer c.mu.Unlock() @@ -73,7 +74,7 @@ func (c *CgoSource) Connect() error { return nil } - if c.direction == "output" { + if c.outputDevice { os.Setenv("ALSA_CAPTURE_DEVICE", c.alsaDevice) dtx := C.uchar(0) @@ -93,7 +94,6 @@ func (c *CgoSource) Connect() error { Uint8("buffer_periods", c.config.BufferPeriods). Uint32("sample_rate", c.config.SampleRate). Uint8("packet_loss_perc", c.config.PacketLossPerc). - Str("alsa_device", c.alsaDevice). 
Msg("Initializing audio capture") C.update_audio_constants( @@ -139,7 +139,6 @@ func (c *CgoSource) Connect() error { } c.connected = true - c.initialized = true return nil } @@ -151,10 +150,12 @@ func (c *CgoSource) Disconnect() { return } - if c.direction == "output" { + if c.outputDevice { C.jetkvm_audio_capture_close() + os.Unsetenv("ALSA_CAPTURE_DEVICE") } else { C.jetkvm_audio_playback_close() + os.Unsetenv("ALSA_PLAYBACK_DEVICE") } c.connected = false @@ -173,7 +174,7 @@ func (c *CgoSource) ReadMessage() (uint8, []byte, error) { return 0, nil, fmt.Errorf("not connected") } - if c.direction != "output" { + if !c.outputDevice { c.mu.Unlock() return 0, nil, fmt.Errorf("ReadMessage only supported for output direction") } @@ -203,7 +204,7 @@ func (c *CgoSource) WriteMessage(msgType uint8, payload []byte) error { return fmt.Errorf("not connected") } - if c.direction != "input" { + if c.outputDevice { c.mu.Unlock() return fmt.Errorf("WriteMessage only supported for input direction") } diff --git a/internal/audio/cgo_source_stub.go b/internal/audio/cgo_source_stub.go index 3658877d..22cf499b 100644 --- a/internal/audio/cgo_source_stub.go +++ b/internal/audio/cgo_source_stub.go @@ -6,11 +6,13 @@ package audio type CgoSource struct{} -func NewCgoOutputSource(alsaDevice string) *CgoSource { +var _ AudioSource = (*CgoSource)(nil) + +func NewCgoOutputSource(alsaDevice string, audioConfig AudioConfig) AudioSource { panic("audio CGO source not supported on this platform") } -func NewCgoInputSource(alsaDevice string) *CgoSource { +func NewCgoInputSource(alsaDevice string, audioConfig AudioConfig) AudioSource { panic("audio CGO source not supported on this platform") } @@ -33,7 +35,3 @@ func (c *CgoSource) ReadMessage() (uint8, []byte, error) { func (c *CgoSource) WriteMessage(msgType uint8, payload []byte) error { panic("audio CGO source not supported on this platform") } - -func (c *CgoSource) SetConfig(cfg AudioConfig) { - panic("audio CGO source not supported on this 
platform") -} diff --git a/internal/audio/relay.go b/internal/audio/relay.go index e836482d..e877697c 100644 --- a/internal/audio/relay.go +++ b/internal/audio/relay.go @@ -13,7 +13,7 @@ import ( ) type OutputRelay struct { - source AudioSource + source *AudioSource audioTrack *webrtc.TrackLocalStaticSample ctx context.Context cancel context.CancelFunc @@ -26,7 +26,7 @@ type OutputRelay struct { framesDropped atomic.Uint32 } -func NewOutputRelay(source AudioSource, audioTrack *webrtc.TrackLocalStaticSample) *OutputRelay { +func NewOutputRelay(source *AudioSource, audioTrack *webrtc.TrackLocalStaticSample) *OutputRelay { ctx, cancel := context.WithCancel(context.Background()) logger := logging.GetDefaultLogger().With().Str("component", "audio-output-relay").Logger() @@ -73,19 +73,19 @@ func (r *OutputRelay) relayLoop() { const reconnectDelay = 1 * time.Second for r.running.Load() { - if !r.source.IsConnected() { - if err := r.source.Connect(); err != nil { + if !(*r.source).IsConnected() { + if err := (*r.source).Connect(); err != nil { r.logger.Debug().Err(err).Msg("failed to connect, will retry") time.Sleep(reconnectDelay) continue } } - msgType, payload, err := r.source.ReadMessage() + msgType, payload, err := (*r.source).ReadMessage() if err != nil { if r.running.Load() { r.logger.Warn().Err(err).Msg("read error, reconnecting") - r.source.Disconnect() + (*r.source).Disconnect() time.Sleep(reconnectDelay) } continue @@ -104,14 +104,14 @@ func (r *OutputRelay) relayLoop() { } type InputRelay struct { - source AudioSource + source *AudioSource ctx context.Context cancel context.CancelFunc logger zerolog.Logger running atomic.Bool } -func NewInputRelay(source AudioSource) *InputRelay { +func NewInputRelay(source *AudioSource) *InputRelay { ctx, cancel := context.WithCancel(context.Background()) logger := logging.GetDefaultLogger().With().Str("component", "audio-input-relay").Logger() diff --git a/internal/audio/source.go b/internal/audio/source.go index 
bdb953d4..b490b31d 100644 --- a/internal/audio/source.go +++ b/internal/audio/source.go @@ -32,5 +32,4 @@ type AudioSource interface { IsConnected() bool Connect() error Disconnect() - SetConfig(cfg AudioConfig) } diff --git a/jsonrpc.go b/jsonrpc.go index 1d596eaf..bbbb65e5 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -18,7 +18,6 @@ import ( "github.com/rs/zerolog" "go.bug.st/serial" - "github.com/jetkvm/kvm/internal/audio" "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/usbgadget" "github.com/jetkvm/kvm/internal/utils" @@ -688,10 +687,12 @@ func rpcGetUsbConfig() (usbgadget.Config, error) { func rpcSetUsbConfig(usbConfig usbgadget.Config) error { LoadConfig() + wasUsbAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio + config.UsbConfig = &usbConfig gadget.SetGadgetConfig(config.UsbConfig) - wasAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio - return updateUsbRelatedConfig(wasAudioEnabled) + + return updateUsbRelatedConfig(wasUsbAudioEnabled) } func rpcGetWakeOnLanDevices() ([]WakeOnLanDevice, error) { @@ -903,43 +904,23 @@ func rpcGetUsbDevices() (usbgadget.Devices, error) { return *config.UsbDevices, nil } -func updateUsbRelatedConfig(wasAudioEnabled bool) error { +func updateUsbRelatedConfig(wasUsbAudioEnabled bool) error { ensureConfigLoaded() + nowHasUsbAudio := config.UsbDevices != nil && config.UsbDevices.Audio + outputSourceIsUsb := config.AudioOutputSource == "usb" - audioMutex.Lock() - inRelay := inputRelay - inputRelay = nil - audioMutex.Unlock() + // must stop input audio before reconfiguring + stopInputAudio() - inSource := inputSource.Swap(nil) - - if inRelay != nil { - inRelay.Stop() - } - if inSource != nil { - (*inSource).Disconnect() - } - - // Auto-switch to HDMI audio output when USB audio is disabled - audioNowEnabled := config.UsbDevices != nil && config.UsbDevices.Audio - if wasAudioEnabled && !audioNowEnabled && config.AudioOutputSource == "usb" { - config.AudioOutputSource = 
"hdmi" + // if we're currently sourcing audio from USB, stop the output audio before reconfiguring + if outputSourceIsUsb { stopOutputAudio() - if audioOutputEnabled.Load() && activeConnections.Load() > 0 && currentAudioTrack != nil { - alsaDevice := getAlsaDevice("hdmi") - newSource := audio.NewCgoOutputSource(alsaDevice) - newSource.SetConfig(getAudioConfig()) - newRelay := audio.NewOutputRelay(newSource, currentAudioTrack) + } - audioMutex.Lock() - outputSource = newSource - outputRelay = newRelay - audioMutex.Unlock() - - if err := newRelay.Start(); err != nil { - logger.Warn().Err(err).Msg("Failed to start HDMI audio after USB audio disabled") - } - } + // Auto-switch to HDMI audio output when USB audio was selected and is now disabled + if wasUsbAudioEnabled && !nowHasUsbAudio && config.AudioOutputSource == "usb" { + logger.Info().Msg("USB audio just disabled, automatic switch audio output source to HDMI") + config.AudioOutputSource = "hdmi" } if err := gadget.UpdateGadgetConfig(); err != nil { @@ -950,18 +931,15 @@ func updateUsbRelatedConfig(wasAudioEnabled bool) error { return fmt.Errorf("failed to save config: %w", err) } - // Restart audio if USB audio is enabled with active connections - if activeConnections.Load() > 0 && config.UsbDevices != nil && config.UsbDevices.Audio { - if err := startAudio(); err != nil { - logger.Warn().Err(err).Msg("Failed to restart audio after USB reconfiguration") - } + if err := startAudio(); err != nil { + logger.Warn().Err(err).Msg("Failed to restart audio after USB reconfiguration") } return nil } func rpcSetUsbDevices(usbDevices usbgadget.Devices) error { - wasAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio + wasUsbAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio currentDevices := gadget.GetGadgetDevices() // Skip reconfiguration if devices haven't changed to avoid HID disruption @@ -973,11 +951,11 @@ func rpcSetUsbDevices(usbDevices usbgadget.Devices) error { config.UsbDevices = 
&usbDevices gadget.SetGadgetDevices(config.UsbDevices) - return updateUsbRelatedConfig(wasAudioEnabled) + return updateUsbRelatedConfig(wasUsbAudioEnabled) } func rpcSetUsbDeviceState(device string, enabled bool) error { - wasAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio + wasUsbAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio currentDevices := gadget.GetGadgetDevices() switch device { @@ -1002,7 +980,7 @@ func rpcSetUsbDeviceState(device string, enabled bool) error { } gadget.SetGadgetDevices(config.UsbDevices) - return updateUsbRelatedConfig(wasAudioEnabled) + return updateUsbRelatedConfig(wasUsbAudioEnabled) } func rpcGetAudioOutputEnabled() (bool, error) { @@ -1105,8 +1083,7 @@ func rpcSetAudioConfig(bitrate int, complexity int, dtxEnabled bool, fecEnabled } func rpcRestartAudioOutput() error { - RestartAudioOutput() - return nil + return RestartAudioOutput() } func rpcGetAudioInputAutoEnable() (bool, error) {