diff --git a/audio.go b/audio.go index a5305dd2..368b2f7e 100644 --- a/audio.go +++ b/audio.go @@ -15,10 +15,10 @@ var ( audioMutex sync.Mutex setAudioTrackMutex sync.Mutex // Prevents concurrent setAudioTrack() calls inputSourceMutex sync.Mutex // Serializes Connect() and WriteMessage() calls to input source - outputSource audio.AudioSource + outputSource atomic.Pointer[audio.AudioSource] inputSource atomic.Pointer[audio.AudioSource] - outputRelay *audio.OutputRelay - inputRelay *audio.InputRelay + outputRelay atomic.Pointer[audio.OutputRelay] + inputRelay atomic.Pointer[audio.InputRelay] audioInitialized bool activeConnections atomic.Int32 audioLogger zerolog.Logger @@ -79,58 +79,81 @@ func startAudio() error { return nil } - if outputSource == nil && audioOutputEnabled.Load() && currentAudioTrack != nil { - ensureConfigLoaded() - alsaDevice := getAlsaDevice(config.AudioOutputSource) - source := audio.NewCgoOutputSource(alsaDevice) - source.SetConfig(getAudioConfig()) - outputSource = source - outputRelay = audio.NewOutputRelay(outputSource, currentAudioTrack) - if err := outputRelay.Start(); err != nil { - audioLogger.Error().Err(err).Msg("Failed to start audio output relay") - } + if activeConnections.Load() <= 0 { + audioLogger.Debug().Msg("No active connections, skipping audio start") + return nil } ensureConfigLoaded() - if inputSource.Load() == nil && audioInputEnabled.Load() && config.UsbDevices != nil && config.UsbDevices.Audio { - alsaPlaybackDevice := getAlsaDevice("usb") - source := audio.NewCgoInputSource(alsaPlaybackDevice) - source.SetConfig(getAudioConfig()) - var audioSource audio.AudioSource = source - inputSource.Store(&audioSource) - inputRelay = audio.NewInputRelay(audioSource) - if err := inputRelay.Start(); err != nil { - audioLogger.Error().Err(err).Msg("Failed to start input relay") - } + if audioOutputEnabled.Load() && currentAudioTrack != nil { + startOutputAudioUnderMutex(getAlsaDevice(config.AudioOutputSource)) + } + + if 
audioInputEnabled.Load() && config.UsbDevices != nil && config.UsbDevices.Audio { + startInputAudioUnderMutex(getAlsaDevice("usb")) } return nil } +func startOutputAudioUnderMutex(alsaOutputDevice string) { + newSource := audio.NewCgoOutputSource(alsaOutputDevice, getAudioConfig()) + oldSource := outputSource.Swap(&newSource) + newRelay := audio.NewOutputRelay(&newSource, currentAudioTrack) + oldRelay := outputRelay.Swap(newRelay) + + if oldRelay != nil { + oldRelay.Stop() + } + + if oldSource != nil { + (*oldSource).Disconnect() + } + + if err := newRelay.Start(); err != nil { + audioLogger.Error().Err(err).Str("alsaOutputDevice", alsaOutputDevice).Msg("Failed to start audio output relay") + } +} + +func startInputAudioUnderMutex(alsaPlaybackDevice string) { + newSource := audio.NewCgoInputSource(alsaPlaybackDevice, getAudioConfig()) + oldSource := inputSource.Swap(&newSource) + newRelay := audio.NewInputRelay(&newSource) + oldRelay := inputRelay.Swap(newRelay) + + if oldRelay != nil { + oldRelay.Stop() + } + + if oldSource != nil { + (*oldSource).Disconnect() + } + + if err := newRelay.Start(); err != nil { + audioLogger.Error().Err(err).Str("alsaPlaybackDevice", alsaPlaybackDevice).Msg("Failed to start input relay") + } +} + func stopOutputAudio() { audioMutex.Lock() - outRelay := outputRelay - outSource := outputSource - outputRelay = nil - outputSource = nil + outRelay := outputRelay.Swap(nil) + outSource := outputSource.Swap(nil) audioMutex.Unlock() if outRelay != nil { outRelay.Stop() } if outSource != nil { - outSource.Disconnect() + (*outSource).Disconnect() } } func stopInputAudio() { audioMutex.Lock() - inRelay := inputRelay - inputRelay = nil - audioMutex.Unlock() - + inRelay := inputRelay.Swap(nil) inSource := inputSource.Swap(nil) + audioMutex.Unlock() if inRelay != nil { inRelay.Stop() @@ -156,7 +179,7 @@ func onWebRTCConnect() { func onWebRTCDisconnect() { count := activeConnections.Add(-1) - if count == 0 { + if count <= 0 { // Stop audio 
immediately to release HDMI audio device which shares hardware with video device stopAudio() } @@ -166,39 +189,12 @@ func setAudioTrack(audioTrack *webrtc.TrackLocalStaticSample) { setAudioTrackMutex.Lock() defer setAudioTrackMutex.Unlock() - // Capture old resources and update state in single critical section - audioMutex.Lock() + stopOutputAudio() + currentAudioTrack = audioTrack - oldRelay := outputRelay - oldSource := outputSource - outputRelay = nil - outputSource = nil - var newRelay *audio.OutputRelay - var newSource audio.AudioSource - if currentAudioTrack != nil && audioOutputEnabled.Load() { - ensureConfigLoaded() - alsaDevice := getAlsaDevice(config.AudioOutputSource) - newSource := audio.NewCgoOutputSource(alsaDevice) - newSource.SetConfig(getAudioConfig()) - newRelay = audio.NewOutputRelay(newSource, currentAudioTrack) - outputSource = newSource - outputRelay = newRelay - } - audioMutex.Unlock() - - // Stop/start resources outside mutex to avoid blocking on CGO calls - if oldRelay != nil { - oldRelay.Stop() - } - if oldSource != nil { - oldSource.Disconnect() - } - - if newRelay != nil { - if err := newRelay.Start(); err != nil { - audioLogger.Error().Err(err).Msg("Failed to start output relay") - } + if err := startAudio(); err != nil { + audioLogger.Error().Err(err).Msg("Failed to start with new audio track") } } @@ -250,72 +246,44 @@ func SetAudioOutputSource(source string) error { return nil } + stopOutputAudio() config.AudioOutputSource = source - stopOutputAudio() - - if audioOutputEnabled.Load() && activeConnections.Load() > 0 && currentAudioTrack != nil { - alsaDevice := getAlsaDevice(source) - newSource := audio.NewCgoOutputSource(alsaDevice) - newSource.SetConfig(getAudioConfig()) - newRelay := audio.NewOutputRelay(newSource, currentAudioTrack) - - audioMutex.Lock() - outputSource = newSource - outputRelay = newRelay - audioMutex.Unlock() - - if err := newRelay.Start(); err != nil { - audioLogger.Error().Err(err).Str("source", 
source).Msg("Failed to start audio relay with new source") - } + if err := startAudio(); err != nil { + audioLogger.Error().Err(err).Str("source", source).Msg("Failed to start audio output after source change") } return SaveConfig() } -func RestartAudioOutput() { +func RestartAudioOutput() error { audioMutex.Lock() - hasActiveOutput := outputSource != nil && currentAudioTrack != nil && audioOutputEnabled.Load() + hasActiveOutput := audioOutputEnabled.Load() && currentAudioTrack != nil && outputSource.Load() != nil audioMutex.Unlock() if !hasActiveOutput { - return + return nil } audioLogger.Info().Msg("Restarting audio output") - stopOutputAudio() - - ensureConfigLoaded() - alsaDevice := getAlsaDevice(config.AudioOutputSource) - - newSource := audio.NewCgoOutputSource(alsaDevice) - newSource.SetConfig(getAudioConfig()) - newRelay := audio.NewOutputRelay(newSource, currentAudioTrack) - - audioMutex.Lock() - outputSource = newSource - outputRelay = newRelay - audioMutex.Unlock() - - if err := newRelay.Start(); err != nil { - audioLogger.Error().Err(err).Msg("Failed to restart audio output") - } + return startAudio() } func handleInputTrackForSession(track *webrtc.TrackRemote) { myTrackID := track.ID() - audioLogger.Debug(). + trackLogger := audioLogger.With(). Str("codec", track.Codec().MimeType). Str("track_id", myTrackID). - Msg("starting input track handler") + Logger() + + trackLogger.Debug().Msg("starting input track handler") for { currentTrackID := currentInputTrack.Load() if currentTrackID != nil && *currentTrackID != myTrackID { - audioLogger.Debug(). - Str("my_track_id", myTrackID). + trackLogger.Debug(). Str("current_track_id", *currentTrackID). 
Msg("input track handler exiting - superseded") return @@ -324,10 +292,10 @@ func handleInputTrackForSession(track *webrtc.TrackRemote) { rtpPacket, _, err := track.ReadRTP() if err != nil { if err == io.EOF { - audioLogger.Debug().Str("track_id", myTrackID).Msg("input track ended") + trackLogger.Debug().Msg("input track ended") return } - audioLogger.Warn().Err(err).Str("track_id", myTrackID).Msg("failed to read RTP packet") + trackLogger.Warn().Err(err).Msg("failed to read RTP packet") continue } diff --git a/internal/audio/cgo_source.go b/internal/audio/cgo_source.go index 45f08be7..81f9b70b 100644 --- a/internal/audio/cgo_source.go +++ b/internal/audio/cgo_source.go @@ -25,46 +25,47 @@ const ( ) type CgoSource struct { - direction string - alsaDevice string - initialized bool - connected bool - mu sync.Mutex - logger zerolog.Logger - opusBuf []byte - config AudioConfig + outputDevice bool + alsaDevice string + connected bool + mu sync.Mutex + logger zerolog.Logger + opusBuf []byte + config AudioConfig } -func NewCgoOutputSource(alsaDevice string) *CgoSource { - logger := logging.GetDefaultLogger().With().Str("component", "audio-output-cgo").Logger() +var _ AudioSource = (*CgoSource)(nil) + +func NewCgoOutputSource(alsaDevice string, cfg AudioConfig) AudioSource { + logger := logging.GetDefaultLogger().With(). + Str("component", "audio-output-cgo"). + Str("alsa_device", alsaDevice). + Logger() return &CgoSource{ - direction: "output", - alsaDevice: alsaDevice, - logger: logger, - opusBuf: make([]byte, ipcMaxFrameSize), - config: DefaultAudioConfig(), + outputDevice: true, + alsaDevice: alsaDevice, + logger: logger, + opusBuf: make([]byte, ipcMaxFrameSize), + config: cfg, } } -func NewCgoInputSource(alsaDevice string) *CgoSource { - logger := logging.GetDefaultLogger().With().Str("component", "audio-input-cgo").Logger() +func NewCgoInputSource(alsaDevice string, cfg AudioConfig) AudioSource { + logger := logging.GetDefaultLogger().With(). 
+ Str("component", "audio-input-cgo"). + Str("alsa_device", alsaDevice). + Logger() return &CgoSource{ - direction: "input", - alsaDevice: alsaDevice, - logger: logger, - opusBuf: make([]byte, ipcMaxFrameSize), - config: DefaultAudioConfig(), + outputDevice: false, + alsaDevice: alsaDevice, + logger: logger, + opusBuf: make([]byte, ipcMaxFrameSize), + config: cfg, } } -func (c *CgoSource) SetConfig(cfg AudioConfig) { - c.mu.Lock() - defer c.mu.Unlock() - c.config = cfg -} - func (c *CgoSource) Connect() error { c.mu.Lock() defer c.mu.Unlock() @@ -73,7 +74,7 @@ func (c *CgoSource) Connect() error { return nil } - if c.direction == "output" { + if c.outputDevice { os.Setenv("ALSA_CAPTURE_DEVICE", c.alsaDevice) dtx := C.uchar(0) @@ -93,7 +94,6 @@ func (c *CgoSource) Connect() error { Uint8("buffer_periods", c.config.BufferPeriods). Uint32("sample_rate", c.config.SampleRate). Uint8("packet_loss_perc", c.config.PacketLossPerc). - Str("alsa_device", c.alsaDevice). Msg("Initializing audio capture") C.update_audio_constants( @@ -139,7 +139,6 @@ func (c *CgoSource) Connect() error { } c.connected = true - c.initialized = true return nil } @@ -151,10 +150,12 @@ func (c *CgoSource) Disconnect() { return } - if c.direction == "output" { + if c.outputDevice { C.jetkvm_audio_capture_close() + os.Unsetenv("ALSA_CAPTURE_DEVICE") } else { C.jetkvm_audio_playback_close() + os.Unsetenv("ALSA_PLAYBACK_DEVICE") } c.connected = false @@ -173,7 +174,7 @@ func (c *CgoSource) ReadMessage() (uint8, []byte, error) { return 0, nil, fmt.Errorf("not connected") } - if c.direction != "output" { + if !c.outputDevice { c.mu.Unlock() return 0, nil, fmt.Errorf("ReadMessage only supported for output direction") } @@ -203,7 +204,7 @@ func (c *CgoSource) WriteMessage(msgType uint8, payload []byte) error { return fmt.Errorf("not connected") } - if c.direction != "input" { + if c.outputDevice { c.mu.Unlock() return fmt.Errorf("WriteMessage only supported for input direction") } diff --git 
a/internal/audio/cgo_source_stub.go b/internal/audio/cgo_source_stub.go index 3658877d..22cf499b 100644 --- a/internal/audio/cgo_source_stub.go +++ b/internal/audio/cgo_source_stub.go @@ -6,11 +6,13 @@ package audio type CgoSource struct{} -func NewCgoOutputSource(alsaDevice string) *CgoSource { +var _ AudioSource = (*CgoSource)(nil) + +func NewCgoOutputSource(alsaDevice string, audioConfig AudioConfig) AudioSource { panic("audio CGO source not supported on this platform") } -func NewCgoInputSource(alsaDevice string) *CgoSource { +func NewCgoInputSource(alsaDevice string, audioConfig AudioConfig) AudioSource { panic("audio CGO source not supported on this platform") } @@ -33,7 +35,3 @@ func (c *CgoSource) ReadMessage() (uint8, []byte, error) { func (c *CgoSource) WriteMessage(msgType uint8, payload []byte) error { panic("audio CGO source not supported on this platform") } - -func (c *CgoSource) SetConfig(cfg AudioConfig) { - panic("audio CGO source not supported on this platform") -} diff --git a/internal/audio/relay.go b/internal/audio/relay.go index e836482d..e877697c 100644 --- a/internal/audio/relay.go +++ b/internal/audio/relay.go @@ -13,7 +13,7 @@ import ( ) type OutputRelay struct { - source AudioSource + source *AudioSource audioTrack *webrtc.TrackLocalStaticSample ctx context.Context cancel context.CancelFunc @@ -26,7 +26,7 @@ type OutputRelay struct { framesDropped atomic.Uint32 } -func NewOutputRelay(source AudioSource, audioTrack *webrtc.TrackLocalStaticSample) *OutputRelay { +func NewOutputRelay(source *AudioSource, audioTrack *webrtc.TrackLocalStaticSample) *OutputRelay { ctx, cancel := context.WithCancel(context.Background()) logger := logging.GetDefaultLogger().With().Str("component", "audio-output-relay").Logger() @@ -73,19 +73,19 @@ func (r *OutputRelay) relayLoop() { const reconnectDelay = 1 * time.Second for r.running.Load() { - if !r.source.IsConnected() { - if err := r.source.Connect(); err != nil { + if !(*r.source).IsConnected() { + if err 
:= (*r.source).Connect(); err != nil { r.logger.Debug().Err(err).Msg("failed to connect, will retry") time.Sleep(reconnectDelay) continue } } - msgType, payload, err := r.source.ReadMessage() + msgType, payload, err := (*r.source).ReadMessage() if err != nil { if r.running.Load() { r.logger.Warn().Err(err).Msg("read error, reconnecting") - r.source.Disconnect() + (*r.source).Disconnect() time.Sleep(reconnectDelay) } continue @@ -104,14 +104,14 @@ func (r *OutputRelay) relayLoop() { } type InputRelay struct { - source AudioSource + source *AudioSource ctx context.Context cancel context.CancelFunc logger zerolog.Logger running atomic.Bool } -func NewInputRelay(source AudioSource) *InputRelay { +func NewInputRelay(source *AudioSource) *InputRelay { ctx, cancel := context.WithCancel(context.Background()) logger := logging.GetDefaultLogger().With().Str("component", "audio-input-relay").Logger() diff --git a/internal/audio/source.go b/internal/audio/source.go index bdb953d4..b490b31d 100644 --- a/internal/audio/source.go +++ b/internal/audio/source.go @@ -32,5 +32,4 @@ type AudioSource interface { IsConnected() bool Connect() error Disconnect() - SetConfig(cfg AudioConfig) } diff --git a/jsonrpc.go b/jsonrpc.go index 1d596eaf..bbbb65e5 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -18,7 +18,6 @@ import ( "github.com/rs/zerolog" "go.bug.st/serial" - "github.com/jetkvm/kvm/internal/audio" "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/usbgadget" "github.com/jetkvm/kvm/internal/utils" @@ -688,10 +687,12 @@ func rpcGetUsbConfig() (usbgadget.Config, error) { func rpcSetUsbConfig(usbConfig usbgadget.Config) error { LoadConfig() + wasUsbAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio + config.UsbConfig = &usbConfig gadget.SetGadgetConfig(config.UsbConfig) - wasAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio - return updateUsbRelatedConfig(wasAudioEnabled) + + return updateUsbRelatedConfig(wasUsbAudioEnabled) } func 
rpcGetWakeOnLanDevices() ([]WakeOnLanDevice, error) { @@ -903,43 +904,23 @@ func rpcGetUsbDevices() (usbgadget.Devices, error) { return *config.UsbDevices, nil } -func updateUsbRelatedConfig(wasAudioEnabled bool) error { +func updateUsbRelatedConfig(wasUsbAudioEnabled bool) error { ensureConfigLoaded() + nowHasUsbAudio := config.UsbDevices != nil && config.UsbDevices.Audio + outputSourceIsUsb := config.AudioOutputSource == "usb" - audioMutex.Lock() - inRelay := inputRelay - inputRelay = nil - audioMutex.Unlock() + // must stop input audio before reconfiguring + stopInputAudio() - inSource := inputSource.Swap(nil) - - if inRelay != nil { - inRelay.Stop() - } - if inSource != nil { - (*inSource).Disconnect() - } - - // Auto-switch to HDMI audio output when USB audio is disabled - audioNowEnabled := config.UsbDevices != nil && config.UsbDevices.Audio - if wasAudioEnabled && !audioNowEnabled && config.AudioOutputSource == "usb" { - config.AudioOutputSource = "hdmi" + // if we're currently sourcing audio from USB, stop the output audio before reconfiguring + if outputSourceIsUsb { stopOutputAudio() - if audioOutputEnabled.Load() && activeConnections.Load() > 0 && currentAudioTrack != nil { - alsaDevice := getAlsaDevice("hdmi") - newSource := audio.NewCgoOutputSource(alsaDevice) - newSource.SetConfig(getAudioConfig()) - newRelay := audio.NewOutputRelay(newSource, currentAudioTrack) + } - audioMutex.Lock() - outputSource = newSource - outputRelay = newRelay - audioMutex.Unlock() - - if err := newRelay.Start(); err != nil { - logger.Warn().Err(err).Msg("Failed to start HDMI audio after USB audio disabled") - } - } + // Auto-switch to HDMI audio output when USB audio was selected and is now disabled + if wasUsbAudioEnabled && !nowHasUsbAudio && config.AudioOutputSource == "usb" { + logger.Info().Msg("USB audio just disabled, automatic switch audio output source to HDMI") + config.AudioOutputSource = "hdmi" } if err := gadget.UpdateGadgetConfig(); err != nil { @@ -950,18 
+931,15 @@ func updateUsbRelatedConfig(wasAudioEnabled bool) error { return fmt.Errorf("failed to save config: %w", err) } - // Restart audio if USB audio is enabled with active connections - if activeConnections.Load() > 0 && config.UsbDevices != nil && config.UsbDevices.Audio { - if err := startAudio(); err != nil { - logger.Warn().Err(err).Msg("Failed to restart audio after USB reconfiguration") - } + if err := startAudio(); err != nil { + logger.Warn().Err(err).Msg("Failed to restart audio after USB reconfiguration") } return nil } func rpcSetUsbDevices(usbDevices usbgadget.Devices) error { - wasAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio + wasUsbAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio currentDevices := gadget.GetGadgetDevices() // Skip reconfiguration if devices haven't changed to avoid HID disruption @@ -973,11 +951,11 @@ func rpcSetUsbDevices(usbDevices usbgadget.Devices) error { config.UsbDevices = &usbDevices gadget.SetGadgetDevices(config.UsbDevices) - return updateUsbRelatedConfig(wasAudioEnabled) + return updateUsbRelatedConfig(wasUsbAudioEnabled) } func rpcSetUsbDeviceState(device string, enabled bool) error { - wasAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio + wasUsbAudioEnabled := config.UsbDevices != nil && config.UsbDevices.Audio currentDevices := gadget.GetGadgetDevices() switch device { @@ -1002,7 +980,7 @@ func rpcSetUsbDeviceState(device string, enabled bool) error { } gadget.SetGadgetDevices(config.UsbDevices) - return updateUsbRelatedConfig(wasAudioEnabled) + return updateUsbRelatedConfig(wasUsbAudioEnabled) } func rpcGetAudioOutputEnabled() (bool, error) { @@ -1105,8 +1083,7 @@ func rpcSetAudioConfig(bitrate int, complexity int, dtxEnabled bool, fecEnabled } func rpcRestartAudioOutput() error { - RestartAudioOutput() - return nil + return RestartAudioOutput() } func rpcGetAudioInputAutoEnable() (bool, error) {