diff --git a/.gitignore b/.gitignore index beace99a..59a2217a 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,10 @@ node_modules #internal/native/include #internal/native/lib internal/audio/bin/ + +# backup files +*.bak + +# core dumps +core +core.* diff --git a/Makefile b/Makefile index ca7dd61f..0c69d7b8 100644 --- a/Makefile +++ b/Makefile @@ -99,6 +99,7 @@ build_audio_output: build_audio_deps -o $(BIN_DIR)/jetkvm_audio_output \ internal/audio/c/jetkvm_audio_output.c \ internal/audio/c/ipc_protocol.c \ + internal/audio/c/audio_common.c \ internal/audio/c/audio.c \ $(CGO_LDFLAGS); \ fi @@ -114,6 +115,7 @@ build_audio_input: build_audio_deps -o $(BIN_DIR)/jetkvm_audio_input \ internal/audio/c/jetkvm_audio_input.c \ internal/audio/c/ipc_protocol.c \ + internal/audio/c/audio_common.c \ internal/audio/c/audio.c \ $(CGO_LDFLAGS); \ fi diff --git a/audio_handlers.go b/audio_handlers.go index cf9969dd..8b63859f 100644 --- a/audio_handlers.go +++ b/audio_handlers.go @@ -12,7 +12,7 @@ var audioControlService *audio.AudioControlService func ensureAudioControlService() *audio.AudioControlService { if audioControlService == nil { - sessionProvider := &SessionProviderImpl{} + sessionProvider := &KVMSessionProvider{} audioControlService = audio.NewAudioControlService(sessionProvider, logger) // Set up RPC callback function for the audio package diff --git a/audio_session_provider.go b/audio_session_provider.go deleted file mode 100644 index bc93303d..00000000 --- a/audio_session_provider.go +++ /dev/null @@ -1,24 +0,0 @@ -package kvm - -import "github.com/jetkvm/kvm/internal/audio" - -// SessionProviderImpl implements the audio.SessionProvider interface -type SessionProviderImpl struct{} - -// NewSessionProvider creates a new session provider -func NewSessionProvider() *SessionProviderImpl { - return &SessionProviderImpl{} -} - -// IsSessionActive returns whether there's an active session -func (sp *SessionProviderImpl) IsSessionActive() bool { - return currentSession != nil -} - -// GetAudioInputManager returns the current session's audio input manager -func (sp *SessionProviderImpl) GetAudioInputManager() *audio.AudioInputManager { - if currentSession == nil { - return nil - } - return currentSession.AudioInputManager -} diff --git a/internal/audio/c/audio.c b/internal/audio/c/audio.c index c1f582b1..a60a4e06 100644 --- a/internal/audio/c/audio.c +++ b/internal/audio/c/audio.c @@ -155,36 +155,6 @@ static inline void simd_clear_samples_s16(short *buffer, int samples) { } } -/** - * Interleave L/R channels using NEON (8 frames/iteration) - * Converts separate left/right buffers to interleaved stereo (LRLRLR...) - * @param left Left channel samples - * @param right Right channel samples - * @param output Interleaved stereo output buffer - * @param frames Number of stereo frames to process - */ -static inline void simd_interleave_stereo_s16(const short *left, const short *right, - short *output, int frames) { - simd_init_once(); - - int simd_frames = frames & ~7; - - // SIMD path: interleave 8 frames (16 samples) per iteration - for (int i = 0; i < simd_frames; i += 8) { - int16x8_t left_vec = vld1q_s16(&left[i]); - int16x8_t right_vec = vld1q_s16(&right[i]); - int16x8x2_t interleaved = vzipq_s16(left_vec, right_vec); - vst1q_s16(&output[i * 2], interleaved.val[0]); - vst1q_s16(&output[i * 2 + 8], interleaved.val[1]); - } - - // Scalar path: handle remaining frames - for (int i = simd_frames; i < frames; i++) { - output[i * 2] = left[i]; - output[i * 2 + 1] = right[i]; - } -} - /** * Apply gain using NEON Q15 fixed-point math (8 samples/iteration) * Uses vqrdmulhq_s16 for single-instruction saturating rounded multiply-high @@ -214,234 +184,6 @@ static inline void simd_scale_volume_s16(short *samples, int count, float volume } } -/** - * Byte-swap 16-bit samples using NEON (8 samples/iteration) - * Converts between little-endian and big-endian formats - * @param samples Audio buffer to byte-swap in-place - * @param count Number of samples to process - */ -static inline void simd_swap_endian_s16(short *samples, int count) { - int simd_count = count & ~7; - - // SIMD path: swap 8 samples per iteration - for (int i = 0; i < simd_count; i += 8) { - uint16x8_t samples_vec = vld1q_u16((uint16_t*)&samples[i]); - uint8x16_t samples_u8 = vreinterpretq_u8_u16(samples_vec); - uint8x16_t swapped_u8 = vrev16q_u8(samples_u8); - uint16x8_t swapped = vreinterpretq_u16_u8(swapped_u8); - vst1q_u16((uint16_t*)&samples[i], swapped); - } - - // Scalar path: handle remaining samples - for (int i = simd_count; i < count; i++) { - samples[i] = __builtin_bswap16(samples[i]); - } -} - -/** - * Convert S16 to float using NEON (4 samples/iteration) - * Converts 16-bit signed integers to normalized float [-1.0, 1.0] - * @param input S16 audio samples - * @param output Float output buffer - * @param count Number of samples to convert - */ -static inline void simd_s16_to_float(const short *input, float *output, int count) { - const float scale = 1.0f / 32768.0f; - int simd_count = count & ~3; - float32x4_t scale_vec = vdupq_n_f32(scale); - - // SIMD path: convert 4 samples per iteration - for (int i = 0; i < simd_count; i += 4) { - int16x4_t s16_data = vld1_s16(input + i); - int32x4_t s32_data = vmovl_s16(s16_data); - float32x4_t float_data = vcvtq_f32_s32(s32_data); - float32x4_t scaled = vmulq_f32(float_data, scale_vec); - vst1q_f32(output + i, scaled); - } - - // Scalar path: handle remaining samples - for (int i = simd_count; i < count; i++) { - output[i] = (float)input[i] * scale; - } -} - -/** - * Convert float to S16 using NEON (4 samples/iteration) - * Converts normalized float [-1.0, 1.0] to 16-bit signed integers with saturation - * @param input Float audio samples - * @param output S16 output buffer - * @param count Number of samples to convert - */ -static inline void simd_float_to_s16(const float *input, short *output, int count) { - const float scale = 32767.0f; - int simd_count = count & ~3; - float32x4_t scale_vec = vdupq_n_f32(scale); - - // SIMD path: convert 4 samples per iteration with saturation - for (int i = 0; i < simd_count; i += 4) { - float32x4_t float_data = vld1q_f32(input + i); - float32x4_t scaled = vmulq_f32(float_data, scale_vec); - int32x4_t s32_data = vcvtq_s32_f32(scaled); - int16x4_t s16_data = vqmovn_s32(s32_data); - vst1_s16(output + i, s16_data); - } - - // Scalar path: handle remaining samples with clamping - for (int i = simd_count; i < count; i++) { - float scaled = input[i] * scale; - output[i] = (short)__builtin_fmaxf(__builtin_fminf(scaled, 32767.0f), -32768.0f); - } -} - -/** - * Mono → stereo (duplicate samples) using NEON (4 frames/iteration) - * Duplicates mono samples to both L and R channels - * @param mono Mono input buffer - * @param stereo Stereo output buffer - * @param frames Number of frames to process - */ -static inline void simd_mono_to_stereo_s16(const short *mono, short *stereo, int frames) { - int simd_frames = frames & ~3; - - // SIMD path: duplicate 4 frames (8 samples) per iteration - for (int i = 0; i < simd_frames; i += 4) { - int16x4_t mono_data = vld1_s16(mono + i); - int16x4x2_t stereo_data = {mono_data, mono_data}; - vst2_s16(stereo + i * 2, stereo_data); - } - - // Scalar path: handle remaining frames - for (int i = simd_frames; i < frames; i++) { - stereo[i * 2] = mono[i]; - stereo[i * 2 + 1] = mono[i]; - } -} - -/** - * Stereo → mono (average L+R) using NEON (4 frames/iteration) - * Downmixes stereo to mono by averaging left and right channels - * @param stereo Interleaved stereo input buffer - * @param mono Mono output buffer - * @param frames Number of frames to process - */ -static inline void simd_stereo_to_mono_s16(const short *stereo, short *mono, int frames) { - int simd_frames = frames & ~3; - - // SIMD path: average 4 stereo frames per iteration - for (int i = 0; i < simd_frames; i += 4) { - int16x4x2_t stereo_data = vld2_s16(stereo + i * 2); - int32x4_t left_wide = vmovl_s16(stereo_data.val[0]); - int32x4_t right_wide = vmovl_s16(stereo_data.val[1]); - int32x4_t sum = vaddq_s32(left_wide, right_wide); - int32x4_t avg = vshrq_n_s32(sum, 1); - int16x4_t mono_data = vqmovn_s32(avg); - vst1_s16(mono + i, mono_data); - } - - // Scalar path: handle remaining frames - for (int i = simd_frames; i < frames; i++) { - mono[i] = (stereo[i * 2] + stereo[i * 2 + 1]) / 2; - } -} - -/** - * Apply L/R balance using NEON (4 frames/iteration) - * Adjusts stereo balance: negative = more left, positive = more right - * @param stereo Interleaved stereo buffer to modify in-place - * @param frames Number of stereo frames to process - * @param balance Balance factor [-1.0 = full left, 0.0 = center, 1.0 = full right] - */ -static inline void simd_apply_stereo_balance_s16(short *stereo, int frames, float balance) { - int simd_frames = frames & ~3; - float left_gain = balance <= 0.0f ? 1.0f : 1.0f - balance; - float right_gain = balance >= 0.0f ? 1.0f : 1.0f + balance; - float32x4_t left_gain_vec = vdupq_n_f32(left_gain); - float32x4_t right_gain_vec = vdupq_n_f32(right_gain); - - // SIMD path: apply balance to 4 stereo frames per iteration - for (int i = 0; i < simd_frames; i += 4) { - int16x4x2_t stereo_data = vld2_s16(stereo + i * 2); - int32x4_t left_wide = vmovl_s16(stereo_data.val[0]); - int32x4_t right_wide = vmovl_s16(stereo_data.val[1]); - float32x4_t left_float = vcvtq_f32_s32(left_wide); - float32x4_t right_float = vcvtq_f32_s32(right_wide); - left_float = vmulq_f32(left_float, left_gain_vec); - right_float = vmulq_f32(right_float, right_gain_vec); - int32x4_t left_result = vcvtq_s32_f32(left_float); - int32x4_t right_result = vcvtq_s32_f32(right_float); - stereo_data.val[0] = vqmovn_s32(left_result); - stereo_data.val[1] = vqmovn_s32(right_result); - vst2_s16(stereo + i * 2, stereo_data); - } - - // Scalar path: handle remaining frames - for (int i = simd_frames; i < frames; i++) { - stereo[i * 2] = (short)(stereo[i * 2] * left_gain); - stereo[i * 2 + 1] = (short)(stereo[i * 2 + 1] * right_gain); - } -} - -/** - * Deinterleave stereo → L/R channels using NEON (4 frames/iteration) - * Separates interleaved stereo (LRLRLR...) into separate L and R buffers - * @param interleaved Interleaved stereo input buffer - * @param left Left channel output buffer - * @param right Right channel output buffer - * @param frames Number of stereo frames to process - */ -static inline void simd_deinterleave_stereo_s16(const short *interleaved, short *left, - short *right, int frames) { - int simd_frames = frames & ~3; - - // SIMD path: deinterleave 4 frames (8 samples) per iteration - for (int i = 0; i < simd_frames; i += 4) { - int16x4x2_t stereo_data = vld2_s16(interleaved + i * 2); - vst1_s16(left + i, stereo_data.val[0]); - vst1_s16(right + i, stereo_data.val[1]); - } - - // Scalar path: handle remaining frames - for (int i = simd_frames; i < frames; i++) { - left[i] = interleaved[i * 2]; - right[i] = interleaved[i * 2 + 1]; - } -} - -/** - * Find max absolute sample value for silence detection using NEON (8 samples/iteration) - * Used to detect silence (threshold < 50 = ~0.15% max volume) and audio discontinuities - * @param samples Audio buffer to analyze - * @param count Number of samples to process - * @return Maximum absolute sample value in the buffer - */ -static inline short simd_find_max_abs_s16(const short *samples, int count) { - int simd_count = count & ~7; - int16x8_t max_vec = vdupq_n_s16(0); - - // SIMD path: find max of 8 samples per iteration - for (int i = 0; i < simd_count; i += 8) { - int16x8_t samples_vec = vld1q_s16(&samples[i]); - int16x8_t abs_vec = vabsq_s16(samples_vec); - max_vec = vmaxq_s16(max_vec, abs_vec); - } - - // Horizontal reduction: extract single max value from vector - int16x4_t max_half = vmax_s16(vget_low_s16(max_vec), vget_high_s16(max_vec)); - int16x4_t max_folded = vpmax_s16(max_half, max_half); - max_folded = vpmax_s16(max_folded, max_folded); - short max_sample = vget_lane_s16(max_folded, 0); - - // Scalar path: handle remaining samples - for (int i = simd_count; i < count; i++) { - short abs_sample = samples[i] < 0 ? -samples[i] : samples[i]; - if (abs_sample > max_sample) { - max_sample = abs_sample; - } - } - - return max_sample; -} - // ============================================================================ // INITIALIZATION STATE TRACKING // ============================================================================ diff --git a/internal/audio/c/jetkvm_audio_input.c b/internal/audio/c/jetkvm_audio_input.c index 19a5f239..17ba53af 100644 --- a/internal/audio/c/jetkvm_audio_input.c +++ b/internal/audio/c/jetkvm_audio_input.c @@ -11,6 +11,7 @@ */ #include "ipc_protocol.h" +#include "audio_common.h" #include #include #include @@ -48,80 +49,25 @@ typedef struct { int trace_logging; // Enable trace logging (default: 0) } audio_config_t; -// ============================================================================ -// SIGNAL HANDLERS -// ============================================================================ - -static void signal_handler(int signo) { - if (signo == SIGTERM || signo == SIGINT) { - printf("Audio input server: Received signal %d, shutting down...\n", signo); - g_running = 0; - } -} - -static void setup_signal_handlers(void) { - struct sigaction sa; - memset(&sa, 0, sizeof(sa)); - sa.sa_handler = signal_handler; - sigemptyset(&sa.sa_mask); - sa.sa_flags = 0; - - sigaction(SIGTERM, &sa, NULL); - sigaction(SIGINT, &sa, NULL); - - // Ignore SIGPIPE - signal(SIGPIPE, SIG_IGN); -} - // ============================================================================ // CONFIGURATION PARSING // ============================================================================ -static int parse_env_int(const char *name, int default_value) { - const char *str = getenv(name); - if (str == NULL || str[0] == '\0') { - return default_value; - } - return atoi(str); -} - -static const char* parse_env_string(const char *name, const char *default_value) { - const char *str = getenv(name); - if (str == NULL || str[0] == '\0') { - return default_value; - } - return str; -} - -static int is_trace_enabled(void) { - const char *pion_trace = getenv("PION_LOG_TRACE"); - if (pion_trace == NULL) { - return 0; - } - - // Check if "audio" is in comma-separated list - if (strstr(pion_trace, "audio") != NULL) { - return 1; - } - - return 0; -} - static void load_audio_config(audio_config_t *config) { // ALSA device configuration - config->alsa_device = parse_env_string("ALSA_PLAYBACK_DEVICE", "hw:1,0"); + config->alsa_device = audio_common_parse_env_string("ALSA_PLAYBACK_DEVICE", "hw:1,0"); // Opus configuration (informational only for decoder) - config->opus_bitrate = parse_env_int("OPUS_BITRATE", 96000); - config->opus_complexity = parse_env_int("OPUS_COMPLEXITY", 1); + config->opus_bitrate = audio_common_parse_env_int("OPUS_BITRATE", 96000); + config->opus_complexity = audio_common_parse_env_int("OPUS_COMPLEXITY", 1); // Audio format - config->sample_rate = parse_env_int("AUDIO_SAMPLE_RATE", 48000); - config->channels = parse_env_int("AUDIO_CHANNELS", 2); - config->frame_size = parse_env_int("AUDIO_FRAME_SIZE", 960); + config->sample_rate = audio_common_parse_env_int("AUDIO_SAMPLE_RATE", 48000); + config->channels = audio_common_parse_env_int("AUDIO_CHANNELS", 2); + config->frame_size = audio_common_parse_env_int("AUDIO_FRAME_SIZE", 960); // Logging - config->trace_logging = is_trace_enabled(); + config->trace_logging = audio_common_is_trace_enabled(); // Log configuration printf("Audio Input Server Configuration:\n"); @@ -269,7 +215,7 @@ int main(int argc, char **argv) { printf("JetKVM Audio Input Server Starting...\n"); // Setup signal handlers - setup_signal_handlers(); + audio_common_setup_signal_handlers(&g_running); // Load configuration from environment audio_config_t config; diff --git a/internal/audio/c/jetkvm_audio_output.c b/internal/audio/c/jetkvm_audio_output.c index 1863961b..cd98fa7a 100644 --- a/internal/audio/c/jetkvm_audio_output.c +++ b/internal/audio/c/jetkvm_audio_output.c @@ -8,6 +8,7 @@ */ #include "ipc_protocol.h" +#include "audio_common.h" #include #include #include @@ -51,86 +52,31 @@ typedef struct { int trace_logging; // Enable trace logging (default: 0) } audio_config_t; -// ============================================================================ -// SIGNAL HANDLERS -// ============================================================================ - -static void signal_handler(int signo) { - if (signo == SIGTERM || signo == SIGINT) { - printf("Audio output server: Received signal %d, shutting down...\n", signo); - g_running = 0; - } -} - -static void setup_signal_handlers(void) { - struct sigaction sa; - memset(&sa, 0, sizeof(sa)); - sa.sa_handler = signal_handler; - sigemptyset(&sa.sa_mask); - sa.sa_flags = 0; - - sigaction(SIGTERM, &sa, NULL); - sigaction(SIGINT, &sa, NULL); - - // Ignore SIGPIPE (write to closed socket should return error, not crash) - signal(SIGPIPE, SIG_IGN); -} - // ============================================================================ // CONFIGURATION PARSING // ============================================================================ -static int parse_env_int(const char *name, int default_value) { - const char *str = getenv(name); - if (str == NULL || str[0] == '\0') { - return default_value; - } - return atoi(str); -} - -static const char* parse_env_string(const char *name, const char *default_value) { - const char *str = getenv(name); - if (str == NULL || str[0] == '\0') { - return default_value; - } - return str; -} - -static int is_trace_enabled(void) { - const char *pion_trace = getenv("PION_LOG_TRACE"); - if (pion_trace == NULL) { - return 0; - } - - // Check if "audio" is in comma-separated list - if (strstr(pion_trace, "audio") != NULL) { - return 1; - } - - return 0; -} - static void load_audio_config(audio_config_t *config) { // ALSA device configuration - config->alsa_device = parse_env_string("ALSA_CAPTURE_DEVICE", "hw:0,0"); + config->alsa_device = audio_common_parse_env_string("ALSA_CAPTURE_DEVICE", "hw:0,0"); // Opus encoder configuration - config->opus_bitrate = parse_env_int("OPUS_BITRATE", 96000); - config->opus_complexity = parse_env_int("OPUS_COMPLEXITY", 1); - config->opus_vbr = parse_env_int("OPUS_VBR", 1); - config->opus_vbr_constraint = parse_env_int("OPUS_VBR_CONSTRAINT", 1); - config->opus_signal_type = parse_env_int("OPUS_SIGNAL_TYPE", -1000); - config->opus_bandwidth = parse_env_int("OPUS_BANDWIDTH", 1103); - config->opus_dtx = parse_env_int("OPUS_DTX", 0); - config->opus_lsb_depth = parse_env_int("OPUS_LSB_DEPTH", 16); + config->opus_bitrate = audio_common_parse_env_int("OPUS_BITRATE", 96000); + config->opus_complexity = audio_common_parse_env_int("OPUS_COMPLEXITY", 1); + config->opus_vbr = audio_common_parse_env_int("OPUS_VBR", 1); + config->opus_vbr_constraint = audio_common_parse_env_int("OPUS_VBR_CONSTRAINT", 1); + config->opus_signal_type = audio_common_parse_env_int("OPUS_SIGNAL_TYPE", -1000); + config->opus_bandwidth = audio_common_parse_env_int("OPUS_BANDWIDTH", 1103); + config->opus_dtx = audio_common_parse_env_int("OPUS_DTX", 0); + config->opus_lsb_depth = audio_common_parse_env_int("OPUS_LSB_DEPTH", 16); // Audio format - config->sample_rate = parse_env_int("AUDIO_SAMPLE_RATE", 48000); - config->channels = parse_env_int("AUDIO_CHANNELS", 2); - config->frame_size = parse_env_int("AUDIO_FRAME_SIZE", 960); + config->sample_rate = audio_common_parse_env_int("AUDIO_SAMPLE_RATE", 48000); + config->channels = audio_common_parse_env_int("AUDIO_CHANNELS", 2); + config->frame_size = audio_common_parse_env_int("AUDIO_FRAME_SIZE", 960); // Logging - config->trace_logging = is_trace_enabled(); + config->trace_logging = audio_common_is_trace_enabled(); // Log configuration printf("Audio Output Server Configuration:\n"); @@ -310,7 +256,7 @@ int main(int argc, char **argv) { printf("JetKVM Audio Output Server Starting...\n"); // Setup signal handlers - setup_signal_handlers(); + audio_common_setup_signal_handlers(&g_running); // Load configuration from environment audio_config_t config; diff --git a/internal/audio/embed.go b/internal/audio/embed.go index 0e926526..f7a4df40 100644 --- a/internal/audio/embed.go +++ b/internal/audio/embed.go @@ -82,36 +82,6 @@ func GetAudioInputBinaryPath() string { return audioInputBinPath } -// CleanupBinaries removes extracted audio binaries (useful for cleanup/testing) -func CleanupBinaries() error { - var errs []error - - if err := os.Remove(audioOutputBinPath); err != nil && !os.IsNotExist(err) { - errs = append(errs, fmt.Errorf("failed to remove audio output binary: %w", err)) - } - - if err := os.Remove(audioInputBinPath); err != nil && !os.IsNotExist(err) { - errs = append(errs, fmt.Errorf("failed to remove audio input binary: %w", err)) - } - - // Try to remove directory (will only succeed if empty) - os.Remove(audioBinDir) - - if len(errs) > 0 { - return fmt.Errorf("cleanup errors: %v", errs) - } - - return nil -} - -// GetBinaryInfo returns information about embedded binaries -func GetBinaryInfo() map[string]int { - return map[string]int{ - "audio_output_size": len(audioOutputBinary), - "audio_input_size": len(audioInputBinary), - } -} - // init ensures binaries are extracted when package is imported func init() { // Extract binaries on package initialization diff --git a/internal/audio/ipc_unified.go b/internal/audio/ipc_unified.go index 9024863b..5e42d388 100644 --- a/internal/audio/ipc_unified.go +++ b/internal/audio/ipc_unified.go @@ -114,9 +114,10 @@ type UnifiedAudioServer struct { wg sync.WaitGroup // Wait group for goroutine coordination // Configuration - socketPath string - magicNumber uint32 - socketBufferConfig SocketBufferConfig + socketPath string + magicNumber uint32 + sendBufferSize int + recvBufferSize int } // NewUnifiedAudioServer creates a new unified audio server @@ -143,7 +144,8 @@ func NewUnifiedAudioServer(isInput bool) (*UnifiedAudioServer, error) { magicNumber: magicNumber, messageChan: make(chan *UnifiedIPCMessage, Config.ChannelBufferSize), processChan: make(chan *UnifiedIPCMessage, Config.ChannelBufferSize), - socketBufferConfig: DefaultSocketBufferConfig(), + sendBufferSize: Config.SocketOptimalBuffer, + recvBufferSize: Config.SocketOptimalBuffer, } return server, nil diff --git a/internal/audio/mic_contention.go b/internal/audio/mic_contention.go deleted file mode 100644 index 08d60d3c..00000000 --- a/internal/audio/mic_contention.go +++ /dev/null @@ -1,127 +0,0 @@ -package audio - -import ( - "sync/atomic" - "time" - "unsafe" -) - -// MicrophoneContentionManager manages microphone access with cooldown periods -type MicrophoneContentionManager struct { - // Atomic fields MUST be first for ARM32 alignment (int64 fields need 8-byte alignment) - lastOpNano int64 - cooldownNanos int64 - operationID int64 - - lockPtr unsafe.Pointer -} - -func NewMicrophoneContentionManager(cooldown time.Duration) *MicrophoneContentionManager { - return &MicrophoneContentionManager{ - cooldownNanos: int64(cooldown), - } -} - -type OperationResult struct { - Allowed bool - RemainingCooldown time.Duration - OperationID int64 -} - -func (mcm *MicrophoneContentionManager) TryOperation() OperationResult { - now := time.Now().UnixNano() - cooldown := atomic.LoadInt64(&mcm.cooldownNanos) - lastOp := atomic.LoadInt64(&mcm.lastOpNano) - elapsed := now - lastOp - - if elapsed >= cooldown { - if atomic.CompareAndSwapInt64(&mcm.lastOpNano, lastOp, now) { - opID := atomic.AddInt64(&mcm.operationID, 1) - return OperationResult{ - Allowed: true, - RemainingCooldown: 0, - OperationID: opID, - } - } - // Retry once if CAS failed - lastOp = atomic.LoadInt64(&mcm.lastOpNano) - elapsed = now - lastOp - if elapsed >= cooldown && atomic.CompareAndSwapInt64(&mcm.lastOpNano, lastOp, now) { - opID := atomic.AddInt64(&mcm.operationID, 1) - return OperationResult{ - Allowed: true, - RemainingCooldown: 0, - OperationID: opID, - } - } - } - - remaining := time.Duration(cooldown - elapsed) - if remaining < 0 { - remaining = 0 - } - - return OperationResult{ - Allowed: false, - RemainingCooldown: remaining, - OperationID: atomic.LoadInt64(&mcm.operationID), - } -} - -func (mcm *MicrophoneContentionManager) SetCooldown(cooldown time.Duration) { - atomic.StoreInt64(&mcm.cooldownNanos, int64(cooldown)) -} - -func (mcm *MicrophoneContentionManager) GetCooldown() time.Duration { - return time.Duration(atomic.LoadInt64(&mcm.cooldownNanos)) -} - -func (mcm *MicrophoneContentionManager) GetLastOperationTime() time.Time { - nanos := atomic.LoadInt64(&mcm.lastOpNano) - if nanos == 0 { - return time.Time{} - } - return time.Unix(0, nanos) -} - -func (mcm *MicrophoneContentionManager) GetOperationCount() int64 { - return atomic.LoadInt64(&mcm.operationID) -} - -func (mcm *MicrophoneContentionManager) Reset() { - atomic.StoreInt64(&mcm.lastOpNano, 0) - atomic.StoreInt64(&mcm.operationID, 0) -} - -var ( - globalMicContentionManager unsafe.Pointer - micContentionInitialized int32 -) - -func GetMicrophoneContentionManager() *MicrophoneContentionManager { - ptr := atomic.LoadPointer(&globalMicContentionManager) - if ptr != nil { - return (*MicrophoneContentionManager)(ptr) - } - - if atomic.CompareAndSwapInt32(&micContentionInitialized, 0, 1) { - manager := NewMicrophoneContentionManager(Config.MicContentionTimeout) - atomic.StorePointer(&globalMicContentionManager, unsafe.Pointer(manager)) - return manager - } - - ptr = atomic.LoadPointer(&globalMicContentionManager) - if ptr != nil { - return (*MicrophoneContentionManager)(ptr) - } - - return NewMicrophoneContentionManager(Config.MicContentionTimeout) -} - -func TryMicrophoneOperation() OperationResult { - return GetMicrophoneContentionManager().TryOperation() -} - -func SetMicrophoneCooldown(cooldown time.Duration) { - GetMicrophoneContentionManager().SetCooldown(cooldown) -} diff --git a/internal/audio/socket_buffer.go b/internal/audio/socket_buffer.go deleted file mode 100644 index e6a5512e..00000000 --- a/internal/audio/socket_buffer.go +++ /dev/null @@ -1,166 +0,0 @@ -package audio - -import ( - "fmt" - "net" - "syscall" -) - -// Socket buffer sizes are now centralized in config_constants.go - -// SocketBufferConfig holds socket buffer configuration -type SocketBufferConfig struct { - SendBufferSize int - RecvBufferSize int - Enabled bool -} - -// DefaultSocketBufferConfig returns the default socket buffer configuration -func DefaultSocketBufferConfig() SocketBufferConfig { - return SocketBufferConfig{ - SendBufferSize: Config.SocketOptimalBuffer, - RecvBufferSize: Config.SocketOptimalBuffer, - Enabled: true, - } -} - -// HighLoadSocketBufferConfig returns configuration for high-load scenarios -func HighLoadSocketBufferConfig() SocketBufferConfig { - return SocketBufferConfig{ - SendBufferSize: Config.SocketMaxBuffer, - RecvBufferSize: Config.SocketMaxBuffer, - Enabled: true, - } -} - -// ConfigureSocketBuffers applies socket buffer configuration to a Unix socket connection -func ConfigureSocketBuffers(conn net.Conn, config SocketBufferConfig) error { - if !config.Enabled { - return nil - } - - if err := ValidateSocketBufferConfig(config); err != nil { - return fmt.Errorf("invalid socket buffer config: %w", err) - } - - unixConn, ok := conn.(*net.UnixConn) - if !ok { - return fmt.Errorf("connection is not a Unix socket") - } - - file, err := unixConn.File() - if err != nil { - return fmt.Errorf("failed to get socket file descriptor: %w", err) - } - defer file.Close() - - fd := int(file.Fd()) - - if config.SendBufferSize > 0 { - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, config.SendBufferSize); err != nil { - return fmt.Errorf("failed to set SO_SNDBUF to %d: %w", config.SendBufferSize, err) - } - } - - if config.RecvBufferSize > 0 { - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, config.RecvBufferSize); err != nil { - return fmt.Errorf("failed to set SO_RCVBUF to %d: %w", config.RecvBufferSize, err) - } - } - - return nil -} - -// GetSocketBufferSizes retrieves current socket buffer sizes -func GetSocketBufferSizes(conn net.Conn) (sendSize, recvSize int, err error) { - unixConn, ok := conn.(*net.UnixConn) - if !ok { - return 0, 0, fmt.Errorf("socket buffer query only supported for Unix sockets") - } - - file, err := unixConn.File() - if err != nil { - return 0, 0, fmt.Errorf("failed to get socket file descriptor: %w", err) - } - defer file.Close() - - fd := int(file.Fd()) - - // Get send buffer size - sendSize, err = syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF) - if err != nil { - return 0, 0, fmt.Errorf("failed to get SO_SNDBUF: %w", err) - } - - // Get receive buffer size - recvSize, err = syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF) - if err != nil { - return 0, 0, fmt.Errorf("failed to get SO_RCVBUF: %w", err) - } - - return sendSize, recvSize, nil -} - -// ValidateSocketBufferConfig validates socket buffer configuration parameters. -// -// Validation Rules: -// - If config.Enabled is false, no validation is performed (returns nil) -// - SendBufferSize must be >= SocketMinBuffer (default: 8192 bytes) -// - RecvBufferSize must be >= SocketMinBuffer (default: 8192 bytes) -// - SendBufferSize must be <= SocketMaxBuffer (default: 1048576 bytes) -// - RecvBufferSize must be <= SocketMaxBuffer (default: 1048576 bytes) -// -// Error Conditions: -// - Returns error if send buffer size is below minimum threshold -// - Returns error if receive buffer size is below minimum threshold -// - Returns error if send buffer size exceeds maximum threshold -// - Returns error if receive buffer size exceeds maximum threshold -// -// The validation ensures socket buffers are sized appropriately for audio streaming -// performance while preventing excessive memory usage. -func ValidateSocketBufferConfig(config SocketBufferConfig) error { - if !config.Enabled { - return nil - } - - minBuffer := Config.SocketMinBuffer - maxBuffer := Config.SocketMaxBuffer - - if config.SendBufferSize < minBuffer { - return fmt.Errorf("send buffer size validation failed: got %d bytes, minimum required %d bytes (configured range: %d-%d)", - config.SendBufferSize, minBuffer, minBuffer, maxBuffer) - } - - if config.RecvBufferSize < minBuffer { - return fmt.Errorf("receive buffer size validation failed: got %d bytes, minimum required %d bytes (configured range: %d-%d)", - config.RecvBufferSize, minBuffer, minBuffer, maxBuffer) - } - - if config.SendBufferSize > maxBuffer { - return fmt.Errorf("send buffer size validation failed: got %d bytes, maximum allowed %d bytes (configured range: %d-%d)", - config.SendBufferSize, maxBuffer, minBuffer, maxBuffer) - } - - if config.RecvBufferSize > maxBuffer { - return fmt.Errorf("receive buffer size validation failed: got %d bytes, maximum allowed %d bytes (configured range: %d-%d)", - config.RecvBufferSize, maxBuffer, minBuffer, maxBuffer) - } - - return nil -} - -// RecordSocketBufferMetrics records socket buffer metrics for monitoring -func RecordSocketBufferMetrics(conn net.Conn, component string) { - if conn == nil { - return - } - - // Get current socket buffer sizes - _, _, err := GetSocketBufferSizes(conn) - if err != nil { - // Log error but don't fail - return - } - - // Socket buffer sizes recorded for debugging if needed -} diff --git a/internal/audio/util_env.go b/internal/audio/util_env.go deleted file mode 100644 index 70b9c12c..00000000 --- a/internal/audio/util_env.go +++ /dev/null @@ -1,56 +0,0 @@ -package audio - -import ( - "os" - "strconv" - - "github.com/jetkvm/kvm/internal/logging" -) - -// getEnvInt reads an integer value from environment variable with fallback to default -func getEnvInt(key string, defaultValue int) int { - if value := os.Getenv(key); value != "" { - if intValue, err := strconv.Atoi(value); err == nil { - return intValue - } - } - return defaultValue -} - -// parseOpusConfig reads OPUS configuration from environment variables -// with fallback to default config values -func parseOpusConfig() (bitrate, complexity, vbr, signalType, bandwidth, dtx int) { - // Read configuration from environment variables with config defaults - bitrate = getEnvInt("JETKVM_OPUS_BITRATE", Config.CGOOpusBitrate) - complexity = getEnvInt("JETKVM_OPUS_COMPLEXITY", Config.CGOOpusComplexity) - vbr = getEnvInt("JETKVM_OPUS_VBR", Config.CGOOpusVBR) - signalType = getEnvInt("JETKVM_OPUS_SIGNAL_TYPE", Config.CGOOpusSignalType) - bandwidth = getEnvInt("JETKVM_OPUS_BANDWIDTH", Config.CGOOpusBandwidth) - dtx = getEnvInt("JETKVM_OPUS_DTX", Config.CGOOpusDTX) - - return bitrate, complexity, vbr, signalType, bandwidth, dtx -} - -// applyOpusConfig applies OPUS configuration to the global config -// with optional logging for the specified component -func applyOpusConfig(bitrate, complexity, vbr, signalType, bandwidth, dtx int, component string, enableLogging bool) { - config := Config - config.CGOOpusBitrate = bitrate - config.CGOOpusComplexity = complexity - config.CGOOpusVBR = vbr - config.CGOOpusSignalType = signalType - config.CGOOpusBandwidth = bandwidth - config.CGOOpusDTX = dtx - - if enableLogging { - logger := logging.GetDefaultLogger().With().Str("component", component).Logger() - logger.Info(). - Int("bitrate", bitrate). - Int("complexity", complexity). - Int("vbr", vbr). - Int("signal_type", signalType). - Int("bandwidth", bandwidth). - Int("dtx", dtx). - Msg("applied OPUS configuration") - } -} diff --git a/internal/native/cgo_linux.go b/internal/native/cgo_linux.go index 77b7d74f..c725b6aa 100644 --- a/internal/native/cgo_linux.go +++ b/internal/native/cgo_linux.go @@ -118,8 +118,6 @@ func uiInit(rotation uint16) { defer cgoLock.Unlock() cRotation := C.u_int16_t(rotation) - defer C.free(unsafe.Pointer(&cRotation)) - C.jetkvm_ui_init(cRotation) } @@ -350,8 +348,6 @@ func uiDispSetRotation(rotation uint16) (bool, error) { nativeLogger.Info().Uint16("rotation", rotation).Msg("setting rotation") cRotation := C.u_int16_t(rotation) - defer C.free(unsafe.Pointer(&cRotation)) - C.jetkvm_ui_set_rotation(cRotation) return true, nil } diff --git a/internal/usbgadget/usbgadget_hardware_test.go b/internal/usbgadget/usbgadget_hardware_test.go deleted file mode 100644 index 66b80b4f..00000000 --- a/internal/usbgadget/usbgadget_hardware_test.go +++ /dev/null @@ -1,333 +0,0 @@ -//go:build arm && linux - -package usbgadget - -import ( - "context" - "os" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -// Hardware integration tests for USB gadget operations -// These tests perform real hardware operations with proper cleanup and timeout handling - -var ( - testConfig = &Config{ - VendorId: "0x1d6b", // The Linux Foundation - ProductId: "0x0104", // Multifunction Composite Gadget - SerialNumber: "", - Manufacturer: "JetKVM", - Product: "USB Emulation Device", - strictMode: false, // Disable strict mode for hardware tests - } - testDevices = &Devices{ - AbsoluteMouse: true, - RelativeMouse: true, - Keyboard: true, - MassStorage: true, - } - testGadgetName = "jetkvm-test" -) - -func TestUsbGadgetHardwareInit(t *testing.T) { - if testing.Short() { - t.Skip("Skipping hardware test in short mode") - } - - // Create context with timeout to prevent hanging - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // Ensure clean state before test - cleanupUsbGadget(t, testGadgetName) - - // Test USB gadget initialization with timeout - var gadget *UsbGadget - done := make(chan bool, 1) - var initErr error - - go func() { - defer func() { - if r := recover(); r != nil { - t.Logf("USB gadget initialization panicked: %v", r) - initErr = assert.AnError - } - done <- true - }() - - gadget = NewUsbGadget(testGadgetName, testDevices, testConfig, nil) - if gadget == nil { - initErr = assert.AnError - } - }() - - // Wait for initialization or timeout - select { - case <-done: - if initErr != nil { - t.Fatalf("USB gadget initialization failed: %v", initErr) - } - assert.NotNil(t, gadget, "USB gadget should be initialized") - case <-ctx.Done(): - t.Fatal("USB gadget initialization timed out") - } - - // Cleanup after test - defer func() { - if gadget != nil { - gadget.CloseHidFiles() - } - cleanupUsbGadget(t, testGadgetName) - }() - - // Validate gadget state - assert.NotNil(t, gadget, "USB gadget should not be nil") - validateHardwareState(t, gadget) - - // Test UDC binding state - bound, err := gadget.IsUDCBound() - assert.NoError(t, err, "Should be able to check UDC binding state") - t.Logf("UDC bound state: %v", bound) -} - -func TestUsbGadgetHardwareReconfiguration(t *testing.T) { - if testing.Short() { - t.Skip("Skipping hardware test in short mode") - } - - // Create context with timeout - ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second) - defer cancel() - - // Ensure clean state - cleanupUsbGadget(t, testGadgetName) - - // Initialize first gadget - gadget1 := createUsbGadgetWithTimeout(t, ctx, testGadgetName, testDevices, testConfig) - defer func() { - if gadget1 != nil { - gadget1.CloseHidFiles() - } - }() - - // Validate initial state - assert.NotNil(t, gadget1, "First USB gadget should be initialized") - - // Close first gadget properly - gadget1.CloseHidFiles() - gadget1 = nil - - // Wait for cleanup to complete - time.Sleep(500 * time.Millisecond) - - // Test reconfiguration with different report descriptor - altGadgetConfig := make(map[string]gadgetConfigItem) - for k, v := range defaultGadgetConfig { - altGadgetConfig[k] = v - } - - // Modify absolute mouse configuration - oldAbsoluteMouseConfig := altGadgetConfig["absolute_mouse"] - oldAbsoluteMouseConfig.reportDesc = absoluteMouseCombinedReportDesc - altGadgetConfig["absolute_mouse"] = oldAbsoluteMouseConfig - - // Create second gadget with modified configuration - gadget2 := createUsbGadgetWithTimeoutAndConfig(t, ctx, testGadgetName, altGadgetConfig, testDevices, testConfig) - defer func() { - if gadget2 != nil { - gadget2.CloseHidFiles() - } - cleanupUsbGadget(t, testGadgetName) - }() - - assert.NotNil(t, gadget2, "Second USB gadget should be initialized") - validateHardwareState(t, gadget2) - - // Validate UDC binding after reconfiguration - udcs := getUdcs() - assert.NotEmpty(t, udcs, "Should have at least one UDC") - - if len(udcs) > 0 { - udc := udcs[0] - t.Logf("Available UDC: %s", udc) - - // Check UDC binding state - udcStr, err := os.ReadFile("/sys/kernel/config/usb_gadget/" + testGadgetName + "/UDC") - if err == nil { - t.Logf("UDC binding: %s", strings.TrimSpace(string(udcStr))) - } else { - t.Logf("Could not read UDC binding: %v", err) - } - } -} - -func TestUsbGadgetHardwareStressTest(t *testing.T) { - if testing.Short() { - t.Skip("Skipping stress test in short mode") - } - - // Create context with longer timeout for stress test - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) - defer cancel() - - // Ensure clean state - cleanupUsbGadget(t, testGadgetName) - - // Perform multiple rapid reconfigurations - for i := 0; i < 3; i++ { - t.Logf("Stress test iteration %d", i+1) - - // Create gadget - gadget := createUsbGadgetWithTimeout(t, ctx, testGadgetName, testDevices, testConfig) - if gadget == nil { - t.Fatalf("Failed to create USB gadget in iteration %d", i+1) - } - - // Validate gadget - assert.NotNil(t, gadget, "USB gadget should be created in iteration %d", i+1) - validateHardwareState(t, gadget) - - // Test basic operations - bound, err := gadget.IsUDCBound() - assert.NoError(t, err, "Should be able to check UDC state in iteration %d", i+1) - t.Logf("Iteration %d: UDC bound = %v", i+1, bound) - - // Cleanup - gadget.CloseHidFiles() - gadget = nil - - // Wait between iterations - time.Sleep(1 * time.Second) - - // Check for timeout - select { - case <-ctx.Done(): - t.Fatal("Stress test timed out") - default: - // Continue - } - } - - // Final cleanup - cleanupUsbGadget(t, testGadgetName) -} - -// Helper functions for hardware tests - -// createUsbGadgetWithTimeout creates a USB gadget with timeout protection -func createUsbGadgetWithTimeout(t *testing.T, ctx context.Context, name string, devices *Devices, config *Config) *UsbGadget { - return createUsbGadgetWithTimeoutAndConfig(t, ctx, name, defaultGadgetConfig, devices, config) -} - -// createUsbGadgetWithTimeoutAndConfig creates a USB gadget with custom config and timeout protection -func createUsbGadgetWithTimeoutAndConfig(t *testing.T, ctx context.Context, name string, gadgetConfig map[string]gadgetConfigItem, devices *Devices, config *Config) *UsbGadget { - var gadget *UsbGadget - done := make(chan bool, 1) - var createErr error - - go func() { - defer func() { - if r := recover(); r != nil { - t.Logf("USB gadget creation panicked: %v", r) - createErr = assert.AnError - } - done <- true - }() - - gadget = newUsbGadget(name, gadgetConfig, devices, config, nil) - if gadget == nil { - createErr = assert.AnError - } - }() - - // Wait for creation or timeout - select { - case <-done: - if createErr != nil { - t.Logf("USB gadget creation failed: %v", createErr) - return nil - } - return gadget - case <-ctx.Done(): - t.Logf("USB gadget creation timed out") - return nil - } -} - -// cleanupUsbGadget ensures clean state by removing any existing USB gadget configuration -func cleanupUsbGadget(t *testing.T, name string) { - t.Logf("Cleaning up USB gadget: %s", name) - - // Try to unbind UDC first - udcPath := "/sys/kernel/config/usb_gadget/" + name + "/UDC" - if _, err := os.Stat(udcPath); err == nil { - // Read current UDC binding - if udcData, err := os.ReadFile(udcPath); err == nil && len(strings.TrimSpace(string(udcData))) > 0 { - // Unbind UDC - if err := os.WriteFile(udcPath, []byte(""), 0644); err != nil { - t.Logf("Failed to unbind UDC: %v", err) - } else { - t.Logf("Successfully unbound UDC") - // Wait for unbinding to complete - time.Sleep(200 * time.Millisecond) - } - } - } - - // Remove gadget directory if it exists - gadgetPath := "/sys/kernel/config/usb_gadget/" + name - if _, err := os.Stat(gadgetPath); err == nil { - // Try to remove configuration links first - configPath := gadgetPath + "/configs/c.1" - if entries, err := os.ReadDir(configPath); err == nil { - for _, entry := range entries { - if entry.Type()&os.ModeSymlink != 0 { - linkPath := configPath + "/" + entry.Name() - if err := os.Remove(linkPath); err != nil { - t.Logf("Failed to remove config link %s: %v", linkPath, err) - } - } - } - } - - // Remove the gadget directory (this should cascade remove everything) - if err := os.RemoveAll(gadgetPath); err != nil { - t.Logf("Failed to remove gadget directory: %v", err) - } else { - t.Logf("Successfully removed gadget directory") - } - } - - // Wait for cleanup to complete - time.Sleep(300 * time.Millisecond) -} - -// validateHardwareState checks the current hardware state -func validateHardwareState(t *testing.T, gadget *UsbGadget) { - if gadget == nil { - return - } - - // Check UDC binding state - bound, err := gadget.IsUDCBound() - if err != nil { - t.Logf("Warning: Could not check UDC binding state: %v", err) - } else { - t.Logf("UDC bound: %v", bound) - } - - // Check available UDCs - udcs := getUdcs() - t.Logf("Available UDCs: %v", udcs) - - // Check configfs mount - if _, err := os.Stat("/sys/kernel/config"); err != nil { - t.Logf("Warning: configfs not available: %v", err) - } else { - t.Logf("configfs is available") - } -} diff --git a/internal/usbgadget/usbgadget_logic_test.go b/internal/usbgadget/usbgadget_logic_test.go deleted file mode 100644 index 454fbb09..00000000 --- a/internal/usbgadget/usbgadget_logic_test.go +++ /dev/null @@ -1,437 +0,0 @@ -package usbgadget - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -// Unit tests for USB gadget configuration logic without hardware dependencies -// These tests follow the pattern of audio tests - testing business logic and validation - -func TestUsbGadgetConfigValidation(t *testing.T) { - tests := []struct { - name string - config *Config - devices *Devices - expected bool - }{ - { - name: "ValidConfig", - config: &Config{ - VendorId: "0x1d6b", - ProductId: "0x0104", - Manufacturer: "JetKVM", - Product: "USB Emulation Device", - }, - devices: &Devices{ - Keyboard: true, - AbsoluteMouse: true, - RelativeMouse: true, - MassStorage: true, - }, - expected: true, - }, - { - name: "InvalidVendorId", - config: &Config{ - VendorId: "invalid", - ProductId: "0x0104", - Manufacturer: "JetKVM", - Product: "USB Emulation Device", - }, - devices: &Devices{ - Keyboard: true, - }, - expected: false, - }, - { - name: "EmptyManufacturer", - config: &Config{ - VendorId: "0x1d6b", - ProductId: "0x0104", - Manufacturer: "", - Product: "USB Emulation Device", - }, - devices: &Devices{ - Keyboard: true, - }, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateUsbGadgetConfiguration(tt.config, tt.devices) - if tt.expected { - assert.NoError(t, err, "Configuration should be valid") - } else { - assert.Error(t, err, "Configuration should be invalid") - } - }) - } -} - -func TestUsbGadgetDeviceConfiguration(t *testing.T) { - tests := []struct { - name string - devices *Devices - expectedConfigs []string - }{ - { - name: "AllDevicesEnabled", - devices: &Devices{ - Keyboard: true, - AbsoluteMouse: true, - RelativeMouse: true, - MassStorage: true, - Audio: true, - }, - expectedConfigs: []string{"keyboard", "absolute_mouse", "relative_mouse", "mass_storage_base", "audio"}, - }, - { - name: "OnlyKeyboard", - devices: &Devices{ - Keyboard: true, - }, - expectedConfigs: []string{"keyboard"}, - }, - { - name: "MouseOnly", - devices: &Devices{ - AbsoluteMouse: true, - RelativeMouse: true, - }, - expectedConfigs: []string{"absolute_mouse", "relative_mouse"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - configs := getEnabledGadgetConfigs(tt.devices) - assert.ElementsMatch(t, tt.expectedConfigs, configs, "Enabled configs should match expected") - }) - } -} - -func TestUsbGadgetStateTransition(t *testing.T) { - if testing.Short() { - t.Skip("Skipping state transition test in short mode") - } - - tests := []struct { - name string - initialDevices *Devices - newDevices *Devices - expectedTransition string - }{ - { - name: "EnableAudio", - initialDevices: &Devices{ - Keyboard: true, - AbsoluteMouse: true, - Audio: false, - }, - newDevices: &Devices{ - Keyboard: true, - AbsoluteMouse: true, - Audio: true, - }, - expectedTransition: "audio_enabled", - }, - { - name: "DisableKeyboard", - initialDevices: &Devices{ - Keyboard: true, - AbsoluteMouse: true, - }, - newDevices: &Devices{ - Keyboard: false, - AbsoluteMouse: true, - }, - expectedTransition: "keyboard_disabled", - }, - { - name: "NoChange", - initialDevices: &Devices{ - Keyboard: true, - AbsoluteMouse: true, - }, - newDevices: &Devices{ - Keyboard: true, - AbsoluteMouse: true, - }, - expectedTransition: "no_change", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - transition := simulateUsbGadgetStateTransition(ctx, tt.initialDevices, tt.newDevices) - assert.Equal(t, tt.expectedTransition, transition, "State transition should match expected") - }) - } -} - -func TestUsbGadgetConfigurationTimeout(t *testing.T) { - if testing.Short() { - t.Skip("Skipping timeout test in short mode") - } - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - // Test that configuration validation completes within reasonable time - start := time.Now() - - // Simulate multiple rapid configuration changes - for i := 0; i < 20; i++ { - devices := &Devices{ - Keyboard: i%2 == 0, - AbsoluteMouse: i%3 == 0, - RelativeMouse: i%4 == 0, - MassStorage: i%5 == 0, - Audio: i%6 == 0, - } - - config := &Config{ - VendorId: "0x1d6b", - ProductId: "0x0104", - Manufacturer: "JetKVM", - Product: "USB Emulation Device", - } - - err := validateUsbGadgetConfiguration(config, devices) - assert.NoError(t, err, "Configuration validation should not fail") - - // Ensure we don't timeout - select { - case <-ctx.Done(): - t.Fatal("USB gadget configuration test timed out") - default: - // Continue - } - } - - elapsed := time.Since(start) - t.Logf("USB gadget configuration test completed in %v", elapsed) - assert.Less(t, elapsed, 2*time.Second, "Configuration validation should complete quickly") -} - -func TestReportDescriptorValidation(t *testing.T) { - tests := []struct { - name string - reportDesc []byte - expected bool - }{ - { - name: "ValidKeyboardReportDesc", - reportDesc: keyboardReportDesc, - expected: true, - }, - { - name: "ValidAbsoluteMouseReportDesc", - reportDesc: absoluteMouseCombinedReportDesc, - expected: true, - }, - { - name: "ValidRelativeMouseReportDesc", - reportDesc: relativeMouseCombinedReportDesc, - expected: true, - }, - { - name: "EmptyReportDesc", - reportDesc: []byte{}, - expected: false, - }, - { - name: "InvalidReportDesc", - reportDesc: []byte{0xFF, 0xFF, 0xFF}, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateReportDescriptor(tt.reportDesc) - if tt.expected { - assert.NoError(t, err, "Report descriptor should be valid") - } else { - assert.Error(t, err, "Report descriptor should be invalid") - } - }) - } -} - -// Helper functions for simulation (similar to audio tests) - -// validateUsbGadgetConfiguration simulates the validation that happens in production -func validateUsbGadgetConfiguration(config *Config, devices *Devices) error { - if config == nil { - return assert.AnError - } - - // Validate vendor ID format - if config.VendorId == "" || len(config.VendorId) < 4 { - return assert.AnError - } - if config.VendorId != "" && config.VendorId[:2] != "0x" { - return assert.AnError - } - - // Validate product ID format - if config.ProductId == "" || len(config.ProductId) < 4 { - return assert.AnError - } - if config.ProductId != "" && config.ProductId[:2] != "0x" { - return assert.AnError - } - - // Validate required fields - if config.Manufacturer == "" { - return assert.AnError - } - if config.Product == "" { - return assert.AnError - } - - // Note: Allow configurations with no devices enabled for testing purposes - // In production, this would typically be validated at a higher level - - return nil -} - -// getEnabledGadgetConfigs returns the list of enabled gadget configurations -func getEnabledGadgetConfigs(devices *Devices) []string { - var configs []string - - if devices.Keyboard { - configs = append(configs, "keyboard") - } - if devices.AbsoluteMouse { - configs = append(configs, "absolute_mouse") - } - if devices.RelativeMouse { - configs = append(configs, "relative_mouse") - } - if devices.MassStorage { - configs = append(configs, "mass_storage_base") - } - if devices.Audio { - configs = append(configs, "audio") - } - - return configs -} - -// simulateUsbGadgetStateTransition simulates the state management during USB reconfiguration -func simulateUsbGadgetStateTransition(ctx context.Context, initial, new *Devices) string { - // Check for audio changes - if initial.Audio != new.Audio { - if new.Audio { - // Simulate enabling audio device - time.Sleep(5 * time.Millisecond) - return "audio_enabled" - } else { - // Simulate disabling audio device - time.Sleep(5 * time.Millisecond) - return "audio_disabled" - } - } - - // Check for keyboard changes - if initial.Keyboard != new.Keyboard { - if new.Keyboard { - time.Sleep(5 * time.Millisecond) - return "keyboard_enabled" - } else { - time.Sleep(5 * time.Millisecond) - return "keyboard_disabled" - } - } - - // Check for mouse changes - if initial.AbsoluteMouse != new.AbsoluteMouse || initial.RelativeMouse != new.RelativeMouse { - time.Sleep(5 * time.Millisecond) - return "mouse_changed" - } - - // Check for mass storage changes - if initial.MassStorage != new.MassStorage { - time.Sleep(5 * time.Millisecond) - return "mass_storage_changed" - } - - return "no_change" -} - -// validateReportDescriptor simulates HID report descriptor validation -func validateReportDescriptor(reportDesc []byte) error { - if len(reportDesc) == 0 { - return assert.AnError - } - - // Basic HID report descriptor validation - // Check for valid usage page (0x05) - found := false - for i := 0; i < len(reportDesc)-1; i++ { - if reportDesc[i] == 0x05 { - found = true - break - } - } - if !found { - return assert.AnError - } - - return nil -} - -// Benchmark tests - -func BenchmarkValidateUsbGadgetConfiguration(b *testing.B) { - config := &Config{ - VendorId: "0x1d6b", - ProductId: "0x0104", - Manufacturer: "JetKVM", - Product: "USB Emulation Device", - } - devices := &Devices{ - Keyboard: true, - AbsoluteMouse: true, - RelativeMouse: true, - MassStorage: true, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = validateUsbGadgetConfiguration(config, devices) - } -} - -func BenchmarkGetEnabledGadgetConfigs(b *testing.B) { - devices := &Devices{ - Keyboard: true, - AbsoluteMouse: true, - RelativeMouse: true, - MassStorage: true, - Audio: true, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = getEnabledGadgetConfigs(devices) - } -} - -func BenchmarkValidateReportDescriptor(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = validateReportDescriptor(keyboardReportDesc) - } -} diff --git a/test_usbgadget b/test_usbgadget deleted file mode 100755 index 75835678..00000000 Binary files a/test_usbgadget and /dev/null differ