[WIP] Updates: reduce PR complexity

This commit is contained in:
Alex P 2025-10-01 22:54:37 +03:00
parent 257993ec20
commit 178c7486cc
16 changed files with 40 additions and 1572 deletions

7
.gitignore vendored
View File

@ -20,3 +20,10 @@ node_modules
#internal/native/include #internal/native/include
#internal/native/lib #internal/native/lib
internal/audio/bin/ internal/audio/bin/
# backup files
*.bak
# core dumps
core
core.*

View File

@ -99,6 +99,7 @@ build_audio_output: build_audio_deps
-o $(BIN_DIR)/jetkvm_audio_output \ -o $(BIN_DIR)/jetkvm_audio_output \
internal/audio/c/jetkvm_audio_output.c \ internal/audio/c/jetkvm_audio_output.c \
internal/audio/c/ipc_protocol.c \ internal/audio/c/ipc_protocol.c \
internal/audio/c/audio_common.c \
internal/audio/c/audio.c \ internal/audio/c/audio.c \
$(CGO_LDFLAGS); \ $(CGO_LDFLAGS); \
fi fi
@ -114,6 +115,7 @@ build_audio_input: build_audio_deps
-o $(BIN_DIR)/jetkvm_audio_input \ -o $(BIN_DIR)/jetkvm_audio_input \
internal/audio/c/jetkvm_audio_input.c \ internal/audio/c/jetkvm_audio_input.c \
internal/audio/c/ipc_protocol.c \ internal/audio/c/ipc_protocol.c \
internal/audio/c/audio_common.c \
internal/audio/c/audio.c \ internal/audio/c/audio.c \
$(CGO_LDFLAGS); \ $(CGO_LDFLAGS); \
fi fi

View File

@ -12,7 +12,7 @@ var audioControlService *audio.AudioControlService
func ensureAudioControlService() *audio.AudioControlService { func ensureAudioControlService() *audio.AudioControlService {
if audioControlService == nil { if audioControlService == nil {
sessionProvider := &SessionProviderImpl{} sessionProvider := &KVMSessionProvider{}
audioControlService = audio.NewAudioControlService(sessionProvider, logger) audioControlService = audio.NewAudioControlService(sessionProvider, logger)
// Set up RPC callback function for the audio package // Set up RPC callback function for the audio package

View File

@ -1,24 +0,0 @@
package kvm
import "github.com/jetkvm/kvm/internal/audio"
// SessionProviderImpl implements the audio.SessionProvider interface
type SessionProviderImpl struct{}
// NewSessionProvider creates a new session provider
func NewSessionProvider() *SessionProviderImpl {
return &SessionProviderImpl{}
}
// IsSessionActive returns whether there's an active session
func (sp *SessionProviderImpl) IsSessionActive() bool {
return currentSession != nil
}
// GetAudioInputManager returns the current session's audio input manager
func (sp *SessionProviderImpl) GetAudioInputManager() *audio.AudioInputManager {
if currentSession == nil {
return nil
}
return currentSession.AudioInputManager
}

View File

@ -155,36 +155,6 @@ static inline void simd_clear_samples_s16(short *buffer, int samples) {
} }
} }
/**
* Interleave L/R channels using NEON (8 frames/iteration)
* Converts separate left/right buffers to interleaved stereo (LRLRLR...)
* @param left Left channel samples
* @param right Right channel samples
* @param output Interleaved stereo output buffer
* @param frames Number of stereo frames to process
*/
static inline void simd_interleave_stereo_s16(const short *left, const short *right,
short *output, int frames) {
simd_init_once();
int simd_frames = frames & ~7;
// SIMD path: interleave 8 frames (16 samples) per iteration
for (int i = 0; i < simd_frames; i += 8) {
int16x8_t left_vec = vld1q_s16(&left[i]);
int16x8_t right_vec = vld1q_s16(&right[i]);
int16x8x2_t interleaved = vzipq_s16(left_vec, right_vec);
vst1q_s16(&output[i * 2], interleaved.val[0]);
vst1q_s16(&output[i * 2 + 8], interleaved.val[1]);
}
// Scalar path: handle remaining frames
for (int i = simd_frames; i < frames; i++) {
output[i * 2] = left[i];
output[i * 2 + 1] = right[i];
}
}
/** /**
* Apply gain using NEON Q15 fixed-point math (8 samples/iteration) * Apply gain using NEON Q15 fixed-point math (8 samples/iteration)
* Uses vqrdmulhq_s16 for single-instruction saturating rounded multiply-high * Uses vqrdmulhq_s16 for single-instruction saturating rounded multiply-high
@ -214,234 +184,6 @@ static inline void simd_scale_volume_s16(short *samples, int count, float volume
} }
} }
/**
* Byte-swap 16-bit samples using NEON (8 samples/iteration)
* Converts between little-endian and big-endian formats
* @param samples Audio buffer to byte-swap in-place
* @param count Number of samples to process
*/
static inline void simd_swap_endian_s16(short *samples, int count) {
int simd_count = count & ~7;
// SIMD path: swap 8 samples per iteration
for (int i = 0; i < simd_count; i += 8) {
uint16x8_t samples_vec = vld1q_u16((uint16_t*)&samples[i]);
uint8x16_t samples_u8 = vreinterpretq_u8_u16(samples_vec);
uint8x16_t swapped_u8 = vrev16q_u8(samples_u8);
uint16x8_t swapped = vreinterpretq_u16_u8(swapped_u8);
vst1q_u16((uint16_t*)&samples[i], swapped);
}
// Scalar path: handle remaining samples
for (int i = simd_count; i < count; i++) {
samples[i] = __builtin_bswap16(samples[i]);
}
}
/**
* Convert S16 to float using NEON (4 samples/iteration)
* Converts 16-bit signed integers to normalized float [-1.0, 1.0]
* @param input S16 audio samples
* @param output Float output buffer
* @param count Number of samples to convert
*/
static inline void simd_s16_to_float(const short *input, float *output, int count) {
const float scale = 1.0f / 32768.0f;
int simd_count = count & ~3;
float32x4_t scale_vec = vdupq_n_f32(scale);
// SIMD path: convert 4 samples per iteration
for (int i = 0; i < simd_count; i += 4) {
int16x4_t s16_data = vld1_s16(input + i);
int32x4_t s32_data = vmovl_s16(s16_data);
float32x4_t float_data = vcvtq_f32_s32(s32_data);
float32x4_t scaled = vmulq_f32(float_data, scale_vec);
vst1q_f32(output + i, scaled);
}
// Scalar path: handle remaining samples
for (int i = simd_count; i < count; i++) {
output[i] = (float)input[i] * scale;
}
}
/**
* Convert float to S16 using NEON (4 samples/iteration)
* Converts normalized float [-1.0, 1.0] to 16-bit signed integers with saturation
* @param input Float audio samples
* @param output S16 output buffer
* @param count Number of samples to convert
*/
static inline void simd_float_to_s16(const float *input, short *output, int count) {
const float scale = 32767.0f;
int simd_count = count & ~3;
float32x4_t scale_vec = vdupq_n_f32(scale);
// SIMD path: convert 4 samples per iteration with saturation
for (int i = 0; i < simd_count; i += 4) {
float32x4_t float_data = vld1q_f32(input + i);
float32x4_t scaled = vmulq_f32(float_data, scale_vec);
int32x4_t s32_data = vcvtq_s32_f32(scaled);
int16x4_t s16_data = vqmovn_s32(s32_data);
vst1_s16(output + i, s16_data);
}
// Scalar path: handle remaining samples with clamping
for (int i = simd_count; i < count; i++) {
float scaled = input[i] * scale;
output[i] = (short)__builtin_fmaxf(__builtin_fminf(scaled, 32767.0f), -32768.0f);
}
}
/**
* Mono stereo (duplicate samples) using NEON (4 frames/iteration)
* Duplicates mono samples to both L and R channels
* @param mono Mono input buffer
* @param stereo Stereo output buffer
* @param frames Number of frames to process
*/
static inline void simd_mono_to_stereo_s16(const short *mono, short *stereo, int frames) {
int simd_frames = frames & ~3;
// SIMD path: duplicate 4 frames (8 samples) per iteration
for (int i = 0; i < simd_frames; i += 4) {
int16x4_t mono_data = vld1_s16(mono + i);
int16x4x2_t stereo_data = {mono_data, mono_data};
vst2_s16(stereo + i * 2, stereo_data);
}
// Scalar path: handle remaining frames
for (int i = simd_frames; i < frames; i++) {
stereo[i * 2] = mono[i];
stereo[i * 2 + 1] = mono[i];
}
}
/**
* Stereo mono (average L+R) using NEON (4 frames/iteration)
* Downmixes stereo to mono by averaging left and right channels
* @param stereo Interleaved stereo input buffer
* @param mono Mono output buffer
* @param frames Number of frames to process
*/
static inline void simd_stereo_to_mono_s16(const short *stereo, short *mono, int frames) {
int simd_frames = frames & ~3;
// SIMD path: average 4 stereo frames per iteration
for (int i = 0; i < simd_frames; i += 4) {
int16x4x2_t stereo_data = vld2_s16(stereo + i * 2);
int32x4_t left_wide = vmovl_s16(stereo_data.val[0]);
int32x4_t right_wide = vmovl_s16(stereo_data.val[1]);
int32x4_t sum = vaddq_s32(left_wide, right_wide);
int32x4_t avg = vshrq_n_s32(sum, 1);
int16x4_t mono_data = vqmovn_s32(avg);
vst1_s16(mono + i, mono_data);
}
// Scalar path: handle remaining frames
for (int i = simd_frames; i < frames; i++) {
mono[i] = (stereo[i * 2] + stereo[i * 2 + 1]) / 2;
}
}
/**
* Apply L/R balance using NEON (4 frames/iteration)
* Adjusts stereo balance: negative = more left, positive = more right
* @param stereo Interleaved stereo buffer to modify in-place
* @param frames Number of stereo frames to process
* @param balance Balance factor [-1.0 = full left, 0.0 = center, 1.0 = full right]
*/
static inline void simd_apply_stereo_balance_s16(short *stereo, int frames, float balance) {
int simd_frames = frames & ~3;
float left_gain = balance <= 0.0f ? 1.0f : 1.0f - balance;
float right_gain = balance >= 0.0f ? 1.0f : 1.0f + balance;
float32x4_t left_gain_vec = vdupq_n_f32(left_gain);
float32x4_t right_gain_vec = vdupq_n_f32(right_gain);
// SIMD path: apply balance to 4 stereo frames per iteration
for (int i = 0; i < simd_frames; i += 4) {
int16x4x2_t stereo_data = vld2_s16(stereo + i * 2);
int32x4_t left_wide = vmovl_s16(stereo_data.val[0]);
int32x4_t right_wide = vmovl_s16(stereo_data.val[1]);
float32x4_t left_float = vcvtq_f32_s32(left_wide);
float32x4_t right_float = vcvtq_f32_s32(right_wide);
left_float = vmulq_f32(left_float, left_gain_vec);
right_float = vmulq_f32(right_float, right_gain_vec);
int32x4_t left_result = vcvtq_s32_f32(left_float);
int32x4_t right_result = vcvtq_s32_f32(right_float);
stereo_data.val[0] = vqmovn_s32(left_result);
stereo_data.val[1] = vqmovn_s32(right_result);
vst2_s16(stereo + i * 2, stereo_data);
}
// Scalar path: handle remaining frames
for (int i = simd_frames; i < frames; i++) {
stereo[i * 2] = (short)(stereo[i * 2] * left_gain);
stereo[i * 2 + 1] = (short)(stereo[i * 2 + 1] * right_gain);
}
}
/**
* Deinterleave stereo L/R channels using NEON (4 frames/iteration)
* Separates interleaved stereo (LRLRLR...) into separate L and R buffers
* @param interleaved Interleaved stereo input buffer
* @param left Left channel output buffer
* @param right Right channel output buffer
* @param frames Number of stereo frames to process
*/
static inline void simd_deinterleave_stereo_s16(const short *interleaved, short *left,
short *right, int frames) {
int simd_frames = frames & ~3;
// SIMD path: deinterleave 4 frames (8 samples) per iteration
for (int i = 0; i < simd_frames; i += 4) {
int16x4x2_t stereo_data = vld2_s16(interleaved + i * 2);
vst1_s16(left + i, stereo_data.val[0]);
vst1_s16(right + i, stereo_data.val[1]);
}
// Scalar path: handle remaining frames
for (int i = simd_frames; i < frames; i++) {
left[i] = interleaved[i * 2];
right[i] = interleaved[i * 2 + 1];
}
}
/**
* Find max absolute sample value for silence detection using NEON (8 samples/iteration)
* Used to detect silence (threshold < 50 = ~0.15% max volume) and audio discontinuities
* @param samples Audio buffer to analyze
* @param count Number of samples to process
* @return Maximum absolute sample value in the buffer
*/
static inline short simd_find_max_abs_s16(const short *samples, int count) {
int simd_count = count & ~7;
int16x8_t max_vec = vdupq_n_s16(0);
// SIMD path: find max of 8 samples per iteration
for (int i = 0; i < simd_count; i += 8) {
int16x8_t samples_vec = vld1q_s16(&samples[i]);
int16x8_t abs_vec = vabsq_s16(samples_vec);
max_vec = vmaxq_s16(max_vec, abs_vec);
}
// Horizontal reduction: extract single max value from vector
int16x4_t max_half = vmax_s16(vget_low_s16(max_vec), vget_high_s16(max_vec));
int16x4_t max_folded = vpmax_s16(max_half, max_half);
max_folded = vpmax_s16(max_folded, max_folded);
short max_sample = vget_lane_s16(max_folded, 0);
// Scalar path: handle remaining samples
for (int i = simd_count; i < count; i++) {
short abs_sample = samples[i] < 0 ? -samples[i] : samples[i];
if (abs_sample > max_sample) {
max_sample = abs_sample;
}
}
return max_sample;
}
// ============================================================================ // ============================================================================
// INITIALIZATION STATE TRACKING // INITIALIZATION STATE TRACKING
// ============================================================================ // ============================================================================

View File

@ -11,6 +11,7 @@
*/ */
#include "ipc_protocol.h" #include "ipc_protocol.h"
#include "audio_common.h"
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
@ -48,80 +49,25 @@ typedef struct {
int trace_logging; // Enable trace logging (default: 0) int trace_logging; // Enable trace logging (default: 0)
} audio_config_t; } audio_config_t;
// ============================================================================
// SIGNAL HANDLERS
// ============================================================================
static void signal_handler(int signo) {
if (signo == SIGTERM || signo == SIGINT) {
printf("Audio input server: Received signal %d, shutting down...\n", signo);
g_running = 0;
}
}
static void setup_signal_handlers(void) {
struct sigaction sa;
memset(&sa, 0, sizeof(sa));
sa.sa_handler = signal_handler;
sigemptyset(&sa.sa_mask);
sa.sa_flags = 0;
sigaction(SIGTERM, &sa, NULL);
sigaction(SIGINT, &sa, NULL);
// Ignore SIGPIPE
signal(SIGPIPE, SIG_IGN);
}
// ============================================================================ // ============================================================================
// CONFIGURATION PARSING // CONFIGURATION PARSING
// ============================================================================ // ============================================================================
static int parse_env_int(const char *name, int default_value) {
const char *str = getenv(name);
if (str == NULL || str[0] == '\0') {
return default_value;
}
return atoi(str);
}
static const char* parse_env_string(const char *name, const char *default_value) {
const char *str = getenv(name);
if (str == NULL || str[0] == '\0') {
return default_value;
}
return str;
}
static int is_trace_enabled(void) {
const char *pion_trace = getenv("PION_LOG_TRACE");
if (pion_trace == NULL) {
return 0;
}
// Check if "audio" is in comma-separated list
if (strstr(pion_trace, "audio") != NULL) {
return 1;
}
return 0;
}
static void load_audio_config(audio_config_t *config) { static void load_audio_config(audio_config_t *config) {
// ALSA device configuration // ALSA device configuration
config->alsa_device = parse_env_string("ALSA_PLAYBACK_DEVICE", "hw:1,0"); config->alsa_device = audio_common_parse_env_string("ALSA_PLAYBACK_DEVICE", "hw:1,0");
// Opus configuration (informational only for decoder) // Opus configuration (informational only for decoder)
config->opus_bitrate = parse_env_int("OPUS_BITRATE", 96000); config->opus_bitrate = audio_common_parse_env_int("OPUS_BITRATE", 96000);
config->opus_complexity = parse_env_int("OPUS_COMPLEXITY", 1); config->opus_complexity = audio_common_parse_env_int("OPUS_COMPLEXITY", 1);
// Audio format // Audio format
config->sample_rate = parse_env_int("AUDIO_SAMPLE_RATE", 48000); config->sample_rate = audio_common_parse_env_int("AUDIO_SAMPLE_RATE", 48000);
config->channels = parse_env_int("AUDIO_CHANNELS", 2); config->channels = audio_common_parse_env_int("AUDIO_CHANNELS", 2);
config->frame_size = parse_env_int("AUDIO_FRAME_SIZE", 960); config->frame_size = audio_common_parse_env_int("AUDIO_FRAME_SIZE", 960);
// Logging // Logging
config->trace_logging = is_trace_enabled(); config->trace_logging = audio_common_is_trace_enabled();
// Log configuration // Log configuration
printf("Audio Input Server Configuration:\n"); printf("Audio Input Server Configuration:\n");
@ -269,7 +215,7 @@ int main(int argc, char **argv) {
printf("JetKVM Audio Input Server Starting...\n"); printf("JetKVM Audio Input Server Starting...\n");
// Setup signal handlers // Setup signal handlers
setup_signal_handlers(); audio_common_setup_signal_handlers(&g_running);
// Load configuration from environment // Load configuration from environment
audio_config_t config; audio_config_t config;

View File

@ -8,6 +8,7 @@
*/ */
#include "ipc_protocol.h" #include "ipc_protocol.h"
#include "audio_common.h"
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
@ -51,86 +52,31 @@ typedef struct {
int trace_logging; // Enable trace logging (default: 0) int trace_logging; // Enable trace logging (default: 0)
} audio_config_t; } audio_config_t;
// ============================================================================
// SIGNAL HANDLERS
// ============================================================================
static void signal_handler(int signo) {
if (signo == SIGTERM || signo == SIGINT) {
printf("Audio output server: Received signal %d, shutting down...\n", signo);
g_running = 0;
}
}
static void setup_signal_handlers(void) {
struct sigaction sa;
memset(&sa, 0, sizeof(sa));
sa.sa_handler = signal_handler;
sigemptyset(&sa.sa_mask);
sa.sa_flags = 0;
sigaction(SIGTERM, &sa, NULL);
sigaction(SIGINT, &sa, NULL);
// Ignore SIGPIPE (write to closed socket should return error, not crash)
signal(SIGPIPE, SIG_IGN);
}
// ============================================================================ // ============================================================================
// CONFIGURATION PARSING // CONFIGURATION PARSING
// ============================================================================ // ============================================================================
static int parse_env_int(const char *name, int default_value) {
const char *str = getenv(name);
if (str == NULL || str[0] == '\0') {
return default_value;
}
return atoi(str);
}
static const char* parse_env_string(const char *name, const char *default_value) {
const char *str = getenv(name);
if (str == NULL || str[0] == '\0') {
return default_value;
}
return str;
}
static int is_trace_enabled(void) {
const char *pion_trace = getenv("PION_LOG_TRACE");
if (pion_trace == NULL) {
return 0;
}
// Check if "audio" is in comma-separated list
if (strstr(pion_trace, "audio") != NULL) {
return 1;
}
return 0;
}
static void load_audio_config(audio_config_t *config) { static void load_audio_config(audio_config_t *config) {
// ALSA device configuration // ALSA device configuration
config->alsa_device = parse_env_string("ALSA_CAPTURE_DEVICE", "hw:0,0"); config->alsa_device = audio_common_parse_env_string("ALSA_CAPTURE_DEVICE", "hw:0,0");
// Opus encoder configuration // Opus encoder configuration
config->opus_bitrate = parse_env_int("OPUS_BITRATE", 96000); config->opus_bitrate = audio_common_parse_env_int("OPUS_BITRATE", 96000);
config->opus_complexity = parse_env_int("OPUS_COMPLEXITY", 1); config->opus_complexity = audio_common_parse_env_int("OPUS_COMPLEXITY", 1);
config->opus_vbr = parse_env_int("OPUS_VBR", 1); config->opus_vbr = audio_common_parse_env_int("OPUS_VBR", 1);
config->opus_vbr_constraint = parse_env_int("OPUS_VBR_CONSTRAINT", 1); config->opus_vbr_constraint = audio_common_parse_env_int("OPUS_VBR_CONSTRAINT", 1);
config->opus_signal_type = parse_env_int("OPUS_SIGNAL_TYPE", -1000); config->opus_signal_type = audio_common_parse_env_int("OPUS_SIGNAL_TYPE", -1000);
config->opus_bandwidth = parse_env_int("OPUS_BANDWIDTH", 1103); config->opus_bandwidth = audio_common_parse_env_int("OPUS_BANDWIDTH", 1103);
config->opus_dtx = parse_env_int("OPUS_DTX", 0); config->opus_dtx = audio_common_parse_env_int("OPUS_DTX", 0);
config->opus_lsb_depth = parse_env_int("OPUS_LSB_DEPTH", 16); config->opus_lsb_depth = audio_common_parse_env_int("OPUS_LSB_DEPTH", 16);
// Audio format // Audio format
config->sample_rate = parse_env_int("AUDIO_SAMPLE_RATE", 48000); config->sample_rate = audio_common_parse_env_int("AUDIO_SAMPLE_RATE", 48000);
config->channels = parse_env_int("AUDIO_CHANNELS", 2); config->channels = audio_common_parse_env_int("AUDIO_CHANNELS", 2);
config->frame_size = parse_env_int("AUDIO_FRAME_SIZE", 960); config->frame_size = audio_common_parse_env_int("AUDIO_FRAME_SIZE", 960);
// Logging // Logging
config->trace_logging = is_trace_enabled(); config->trace_logging = audio_common_is_trace_enabled();
// Log configuration // Log configuration
printf("Audio Output Server Configuration:\n"); printf("Audio Output Server Configuration:\n");
@ -310,7 +256,7 @@ int main(int argc, char **argv) {
printf("JetKVM Audio Output Server Starting...\n"); printf("JetKVM Audio Output Server Starting...\n");
// Setup signal handlers // Setup signal handlers
setup_signal_handlers(); audio_common_setup_signal_handlers(&g_running);
// Load configuration from environment // Load configuration from environment
audio_config_t config; audio_config_t config;

View File

@ -82,36 +82,6 @@ func GetAudioInputBinaryPath() string {
return audioInputBinPath return audioInputBinPath
} }
// CleanupBinaries removes extracted audio binaries (useful for cleanup/testing)
func CleanupBinaries() error {
var errs []error
if err := os.Remove(audioOutputBinPath); err != nil && !os.IsNotExist(err) {
errs = append(errs, fmt.Errorf("failed to remove audio output binary: %w", err))
}
if err := os.Remove(audioInputBinPath); err != nil && !os.IsNotExist(err) {
errs = append(errs, fmt.Errorf("failed to remove audio input binary: %w", err))
}
// Try to remove directory (will only succeed if empty)
os.Remove(audioBinDir)
if len(errs) > 0 {
return fmt.Errorf("cleanup errors: %v", errs)
}
return nil
}
// GetBinaryInfo returns information about embedded binaries
func GetBinaryInfo() map[string]int {
return map[string]int{
"audio_output_size": len(audioOutputBinary),
"audio_input_size": len(audioInputBinary),
}
}
// init ensures binaries are extracted when package is imported // init ensures binaries are extracted when package is imported
func init() { func init() {
// Extract binaries on package initialization // Extract binaries on package initialization

View File

@ -116,7 +116,8 @@ type UnifiedAudioServer struct {
// Configuration // Configuration
socketPath string socketPath string
magicNumber uint32 magicNumber uint32
socketBufferConfig SocketBufferConfig sendBufferSize int
recvBufferSize int
} }
// NewUnifiedAudioServer creates a new unified audio server // NewUnifiedAudioServer creates a new unified audio server
@ -143,7 +144,8 @@ func NewUnifiedAudioServer(isInput bool) (*UnifiedAudioServer, error) {
magicNumber: magicNumber, magicNumber: magicNumber,
messageChan: make(chan *UnifiedIPCMessage, Config.ChannelBufferSize), messageChan: make(chan *UnifiedIPCMessage, Config.ChannelBufferSize),
processChan: make(chan *UnifiedIPCMessage, Config.ChannelBufferSize), processChan: make(chan *UnifiedIPCMessage, Config.ChannelBufferSize),
socketBufferConfig: DefaultSocketBufferConfig(), sendBufferSize: Config.SocketOptimalBuffer,
recvBufferSize: Config.SocketOptimalBuffer,
} }
return server, nil return server, nil

View File

@ -1,127 +0,0 @@
package audio
import (
"sync/atomic"
"time"
"unsafe"
)
// MicrophoneContentionManager manages microphone access with cooldown periods
type MicrophoneContentionManager struct {
// Atomic fields MUST be first for ARM32 alignment (int64 fields need 8-byte alignment)
lastOpNano int64
cooldownNanos int64
operationID int64
lockPtr unsafe.Pointer
}
func NewMicrophoneContentionManager(cooldown time.Duration) *MicrophoneContentionManager {
return &MicrophoneContentionManager{
cooldownNanos: int64(cooldown),
}
}
type OperationResult struct {
Allowed bool
RemainingCooldown time.Duration
OperationID int64
}
func (mcm *MicrophoneContentionManager) TryOperation() OperationResult {
now := time.Now().UnixNano()
cooldown := atomic.LoadInt64(&mcm.cooldownNanos)
lastOp := atomic.LoadInt64(&mcm.lastOpNano)
elapsed := now - lastOp
if elapsed >= cooldown {
if atomic.CompareAndSwapInt64(&mcm.lastOpNano, lastOp, now) {
opID := atomic.AddInt64(&mcm.operationID, 1)
return OperationResult{
Allowed: true,
RemainingCooldown: 0,
OperationID: opID,
}
}
// Retry once if CAS failed
lastOp = atomic.LoadInt64(&mcm.lastOpNano)
elapsed = now - lastOp
if elapsed >= cooldown && atomic.CompareAndSwapInt64(&mcm.lastOpNano, lastOp, now) {
opID := atomic.AddInt64(&mcm.operationID, 1)
return OperationResult{
Allowed: true,
RemainingCooldown: 0,
OperationID: opID,
}
}
}
remaining := time.Duration(cooldown - elapsed)
if remaining < 0 {
remaining = 0
}
return OperationResult{
Allowed: false,
RemainingCooldown: remaining,
OperationID: atomic.LoadInt64(&mcm.operationID),
}
}
func (mcm *MicrophoneContentionManager) SetCooldown(cooldown time.Duration) {
atomic.StoreInt64(&mcm.cooldownNanos, int64(cooldown))
}
func (mcm *MicrophoneContentionManager) GetCooldown() time.Duration {
return time.Duration(atomic.LoadInt64(&mcm.cooldownNanos))
}
func (mcm *MicrophoneContentionManager) GetLastOperationTime() time.Time {
nanos := atomic.LoadInt64(&mcm.lastOpNano)
if nanos == 0 {
return time.Time{}
}
return time.Unix(0, nanos)
}
func (mcm *MicrophoneContentionManager) GetOperationCount() int64 {
return atomic.LoadInt64(&mcm.operationID)
}
func (mcm *MicrophoneContentionManager) Reset() {
atomic.StoreInt64(&mcm.lastOpNano, 0)
atomic.StoreInt64(&mcm.operationID, 0)
}
var (
globalMicContentionManager unsafe.Pointer
micContentionInitialized int32
)
func GetMicrophoneContentionManager() *MicrophoneContentionManager {
ptr := atomic.LoadPointer(&globalMicContentionManager)
if ptr != nil {
return (*MicrophoneContentionManager)(ptr)
}
if atomic.CompareAndSwapInt32(&micContentionInitialized, 0, 1) {
manager := NewMicrophoneContentionManager(Config.MicContentionTimeout)
atomic.StorePointer(&globalMicContentionManager, unsafe.Pointer(manager))
return manager
}
ptr = atomic.LoadPointer(&globalMicContentionManager)
if ptr != nil {
return (*MicrophoneContentionManager)(ptr)
}
return NewMicrophoneContentionManager(Config.MicContentionTimeout)
}
func TryMicrophoneOperation() OperationResult {
return GetMicrophoneContentionManager().TryOperation()
}
func SetMicrophoneCooldown(cooldown time.Duration) {
GetMicrophoneContentionManager().SetCooldown(cooldown)
}

View File

@ -1,166 +0,0 @@
package audio
import (
"fmt"
"net"
"syscall"
)
// Socket buffer sizes are now centralized in config_constants.go
// SocketBufferConfig holds socket buffer configuration
type SocketBufferConfig struct {
SendBufferSize int
RecvBufferSize int
Enabled bool
}
// DefaultSocketBufferConfig returns the default socket buffer configuration
func DefaultSocketBufferConfig() SocketBufferConfig {
return SocketBufferConfig{
SendBufferSize: Config.SocketOptimalBuffer,
RecvBufferSize: Config.SocketOptimalBuffer,
Enabled: true,
}
}
// HighLoadSocketBufferConfig returns configuration for high-load scenarios
func HighLoadSocketBufferConfig() SocketBufferConfig {
return SocketBufferConfig{
SendBufferSize: Config.SocketMaxBuffer,
RecvBufferSize: Config.SocketMaxBuffer,
Enabled: true,
}
}
// ConfigureSocketBuffers applies socket buffer configuration to a Unix socket connection
func ConfigureSocketBuffers(conn net.Conn, config SocketBufferConfig) error {
if !config.Enabled {
return nil
}
if err := ValidateSocketBufferConfig(config); err != nil {
return fmt.Errorf("invalid socket buffer config: %w", err)
}
unixConn, ok := conn.(*net.UnixConn)
if !ok {
return fmt.Errorf("connection is not a Unix socket")
}
file, err := unixConn.File()
if err != nil {
return fmt.Errorf("failed to get socket file descriptor: %w", err)
}
defer file.Close()
fd := int(file.Fd())
if config.SendBufferSize > 0 {
if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, config.SendBufferSize); err != nil {
return fmt.Errorf("failed to set SO_SNDBUF to %d: %w", config.SendBufferSize, err)
}
}
if config.RecvBufferSize > 0 {
if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, config.RecvBufferSize); err != nil {
return fmt.Errorf("failed to set SO_RCVBUF to %d: %w", config.RecvBufferSize, err)
}
}
return nil
}
// GetSocketBufferSizes retrieves current socket buffer sizes
func GetSocketBufferSizes(conn net.Conn) (sendSize, recvSize int, err error) {
unixConn, ok := conn.(*net.UnixConn)
if !ok {
return 0, 0, fmt.Errorf("socket buffer query only supported for Unix sockets")
}
file, err := unixConn.File()
if err != nil {
return 0, 0, fmt.Errorf("failed to get socket file descriptor: %w", err)
}
defer file.Close()
fd := int(file.Fd())
// Get send buffer size
sendSize, err = syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
if err != nil {
return 0, 0, fmt.Errorf("failed to get SO_SNDBUF: %w", err)
}
// Get receive buffer size
recvSize, err = syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF)
if err != nil {
return 0, 0, fmt.Errorf("failed to get SO_RCVBUF: %w", err)
}
return sendSize, recvSize, nil
}
// ValidateSocketBufferConfig validates socket buffer configuration parameters.
//
// Validation Rules:
// - If config.Enabled is false, no validation is performed (returns nil)
// - SendBufferSize must be >= SocketMinBuffer (default: 8192 bytes)
// - RecvBufferSize must be >= SocketMinBuffer (default: 8192 bytes)
// - SendBufferSize must be <= SocketMaxBuffer (default: 1048576 bytes)
// - RecvBufferSize must be <= SocketMaxBuffer (default: 1048576 bytes)
//
// Error Conditions:
// - Returns error if send buffer size is below minimum threshold
// - Returns error if receive buffer size is below minimum threshold
// - Returns error if send buffer size exceeds maximum threshold
// - Returns error if receive buffer size exceeds maximum threshold
//
// The validation ensures socket buffers are sized appropriately for audio streaming
// performance while preventing excessive memory usage.
func ValidateSocketBufferConfig(config SocketBufferConfig) error {
if !config.Enabled {
return nil
}
minBuffer := Config.SocketMinBuffer
maxBuffer := Config.SocketMaxBuffer
if config.SendBufferSize < minBuffer {
return fmt.Errorf("send buffer size validation failed: got %d bytes, minimum required %d bytes (configured range: %d-%d)",
config.SendBufferSize, minBuffer, minBuffer, maxBuffer)
}
if config.RecvBufferSize < minBuffer {
return fmt.Errorf("receive buffer size validation failed: got %d bytes, minimum required %d bytes (configured range: %d-%d)",
config.RecvBufferSize, minBuffer, minBuffer, maxBuffer)
}
if config.SendBufferSize > maxBuffer {
return fmt.Errorf("send buffer size validation failed: got %d bytes, maximum allowed %d bytes (configured range: %d-%d)",
config.SendBufferSize, maxBuffer, minBuffer, maxBuffer)
}
if config.RecvBufferSize > maxBuffer {
return fmt.Errorf("receive buffer size validation failed: got %d bytes, maximum allowed %d bytes (configured range: %d-%d)",
config.RecvBufferSize, maxBuffer, minBuffer, maxBuffer)
}
return nil
}
// RecordSocketBufferMetrics records socket buffer metrics for monitoring
func RecordSocketBufferMetrics(conn net.Conn, component string) {
if conn == nil {
return
}
// Get current socket buffer sizes
_, _, err := GetSocketBufferSizes(conn)
if err != nil {
// Log error but don't fail
return
}
// Socket buffer sizes recorded for debugging if needed
}

View File

@ -1,56 +0,0 @@
package audio
import (
"os"
"strconv"
"github.com/jetkvm/kvm/internal/logging"
)
// getEnvInt reads an integer value from environment variable with fallback to default
func getEnvInt(key string, defaultValue int) int {
if value := os.Getenv(key); value != "" {
if intValue, err := strconv.Atoi(value); err == nil {
return intValue
}
}
return defaultValue
}
// parseOpusConfig reads OPUS configuration from environment variables
// with fallback to default config values
func parseOpusConfig() (bitrate, complexity, vbr, signalType, bandwidth, dtx int) {
// Read configuration from environment variables with config defaults
bitrate = getEnvInt("JETKVM_OPUS_BITRATE", Config.CGOOpusBitrate)
complexity = getEnvInt("JETKVM_OPUS_COMPLEXITY", Config.CGOOpusComplexity)
vbr = getEnvInt("JETKVM_OPUS_VBR", Config.CGOOpusVBR)
signalType = getEnvInt("JETKVM_OPUS_SIGNAL_TYPE", Config.CGOOpusSignalType)
bandwidth = getEnvInt("JETKVM_OPUS_BANDWIDTH", Config.CGOOpusBandwidth)
dtx = getEnvInt("JETKVM_OPUS_DTX", Config.CGOOpusDTX)
return bitrate, complexity, vbr, signalType, bandwidth, dtx
}
// applyOpusConfig applies OPUS configuration to the global config
// with optional logging for the specified component
func applyOpusConfig(bitrate, complexity, vbr, signalType, bandwidth, dtx int, component string, enableLogging bool) {
config := Config
config.CGOOpusBitrate = bitrate
config.CGOOpusComplexity = complexity
config.CGOOpusVBR = vbr
config.CGOOpusSignalType = signalType
config.CGOOpusBandwidth = bandwidth
config.CGOOpusDTX = dtx
if enableLogging {
logger := logging.GetDefaultLogger().With().Str("component", component).Logger()
logger.Info().
Int("bitrate", bitrate).
Int("complexity", complexity).
Int("vbr", vbr).
Int("signal_type", signalType).
Int("bandwidth", bandwidth).
Int("dtx", dtx).
Msg("applied OPUS configuration")
}
}

View File

@ -118,8 +118,6 @@ func uiInit(rotation uint16) {
defer cgoLock.Unlock() defer cgoLock.Unlock()
cRotation := C.u_int16_t(rotation) cRotation := C.u_int16_t(rotation)
defer C.free(unsafe.Pointer(&cRotation))
C.jetkvm_ui_init(cRotation) C.jetkvm_ui_init(cRotation)
} }
@ -350,8 +348,6 @@ func uiDispSetRotation(rotation uint16) (bool, error) {
nativeLogger.Info().Uint16("rotation", rotation).Msg("setting rotation") nativeLogger.Info().Uint16("rotation", rotation).Msg("setting rotation")
cRotation := C.u_int16_t(rotation) cRotation := C.u_int16_t(rotation)
defer C.free(unsafe.Pointer(&cRotation))
C.jetkvm_ui_set_rotation(cRotation) C.jetkvm_ui_set_rotation(cRotation)
return true, nil return true, nil
} }

View File

@ -1,333 +0,0 @@
//go:build arm && linux
package usbgadget
import (
"context"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// Hardware integration tests for USB gadget operations
// These tests perform real hardware operations with proper cleanup and timeout handling
var (
testConfig = &Config{
VendorId: "0x1d6b", // The Linux Foundation
ProductId: "0x0104", // Multifunction Composite Gadget
SerialNumber: "",
Manufacturer: "JetKVM",
Product: "USB Emulation Device",
strictMode: false, // Disable strict mode for hardware tests
}
testDevices = &Devices{
AbsoluteMouse: true,
RelativeMouse: true,
Keyboard: true,
MassStorage: true,
}
testGadgetName = "jetkvm-test"
)
func TestUsbGadgetHardwareInit(t *testing.T) {
if testing.Short() {
t.Skip("Skipping hardware test in short mode")
}
// Create context with timeout to prevent hanging
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Ensure clean state before test
cleanupUsbGadget(t, testGadgetName)
// Test USB gadget initialization with timeout
var gadget *UsbGadget
done := make(chan bool, 1)
var initErr error
go func() {
defer func() {
if r := recover(); r != nil {
t.Logf("USB gadget initialization panicked: %v", r)
initErr = assert.AnError
}
done <- true
}()
gadget = NewUsbGadget(testGadgetName, testDevices, testConfig, nil)
if gadget == nil {
initErr = assert.AnError
}
}()
// Wait for initialization or timeout
select {
case <-done:
if initErr != nil {
t.Fatalf("USB gadget initialization failed: %v", initErr)
}
assert.NotNil(t, gadget, "USB gadget should be initialized")
case <-ctx.Done():
t.Fatal("USB gadget initialization timed out")
}
// Cleanup after test
defer func() {
if gadget != nil {
gadget.CloseHidFiles()
}
cleanupUsbGadget(t, testGadgetName)
}()
// Validate gadget state
assert.NotNil(t, gadget, "USB gadget should not be nil")
validateHardwareState(t, gadget)
// Test UDC binding state
bound, err := gadget.IsUDCBound()
assert.NoError(t, err, "Should be able to check UDC binding state")
t.Logf("UDC bound state: %v", bound)
}
func TestUsbGadgetHardwareReconfiguration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping hardware test in short mode")
}
// Create context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
defer cancel()
// Ensure clean state
cleanupUsbGadget(t, testGadgetName)
// Initialize first gadget
gadget1 := createUsbGadgetWithTimeout(t, ctx, testGadgetName, testDevices, testConfig)
defer func() {
if gadget1 != nil {
gadget1.CloseHidFiles()
}
}()
// Validate initial state
assert.NotNil(t, gadget1, "First USB gadget should be initialized")
// Close first gadget properly
gadget1.CloseHidFiles()
gadget1 = nil
// Wait for cleanup to complete
time.Sleep(500 * time.Millisecond)
// Test reconfiguration with different report descriptor
altGadgetConfig := make(map[string]gadgetConfigItem)
for k, v := range defaultGadgetConfig {
altGadgetConfig[k] = v
}
// Modify absolute mouse configuration
oldAbsoluteMouseConfig := altGadgetConfig["absolute_mouse"]
oldAbsoluteMouseConfig.reportDesc = absoluteMouseCombinedReportDesc
altGadgetConfig["absolute_mouse"] = oldAbsoluteMouseConfig
// Create second gadget with modified configuration
gadget2 := createUsbGadgetWithTimeoutAndConfig(t, ctx, testGadgetName, altGadgetConfig, testDevices, testConfig)
defer func() {
if gadget2 != nil {
gadget2.CloseHidFiles()
}
cleanupUsbGadget(t, testGadgetName)
}()
assert.NotNil(t, gadget2, "Second USB gadget should be initialized")
validateHardwareState(t, gadget2)
// Validate UDC binding after reconfiguration
udcs := getUdcs()
assert.NotEmpty(t, udcs, "Should have at least one UDC")
if len(udcs) > 0 {
udc := udcs[0]
t.Logf("Available UDC: %s", udc)
// Check UDC binding state
udcStr, err := os.ReadFile("/sys/kernel/config/usb_gadget/" + testGadgetName + "/UDC")
if err == nil {
t.Logf("UDC binding: %s", strings.TrimSpace(string(udcStr)))
} else {
t.Logf("Could not read UDC binding: %v", err)
}
}
}
func TestUsbGadgetHardwareStressTest(t *testing.T) {
if testing.Short() {
t.Skip("Skipping stress test in short mode")
}
// Create context with longer timeout for stress test
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
// Ensure clean state
cleanupUsbGadget(t, testGadgetName)
// Perform multiple rapid reconfigurations
for i := 0; i < 3; i++ {
t.Logf("Stress test iteration %d", i+1)
// Create gadget
gadget := createUsbGadgetWithTimeout(t, ctx, testGadgetName, testDevices, testConfig)
if gadget == nil {
t.Fatalf("Failed to create USB gadget in iteration %d", i+1)
}
// Validate gadget
assert.NotNil(t, gadget, "USB gadget should be created in iteration %d", i+1)
validateHardwareState(t, gadget)
// Test basic operations
bound, err := gadget.IsUDCBound()
assert.NoError(t, err, "Should be able to check UDC state in iteration %d", i+1)
t.Logf("Iteration %d: UDC bound = %v", i+1, bound)
// Cleanup
gadget.CloseHidFiles()
gadget = nil
// Wait between iterations
time.Sleep(1 * time.Second)
// Check for timeout
select {
case <-ctx.Done():
t.Fatal("Stress test timed out")
default:
// Continue
}
}
// Final cleanup
cleanupUsbGadget(t, testGadgetName)
}
// Helper functions for hardware tests
// createUsbGadgetWithTimeout creates a USB gadget with timeout protection
func createUsbGadgetWithTimeout(t *testing.T, ctx context.Context, name string, devices *Devices, config *Config) *UsbGadget {
return createUsbGadgetWithTimeoutAndConfig(t, ctx, name, defaultGadgetConfig, devices, config)
}
// createUsbGadgetWithTimeoutAndConfig creates a USB gadget with custom config and timeout protection
func createUsbGadgetWithTimeoutAndConfig(t *testing.T, ctx context.Context, name string, gadgetConfig map[string]gadgetConfigItem, devices *Devices, config *Config) *UsbGadget {
var gadget *UsbGadget
done := make(chan bool, 1)
var createErr error
go func() {
defer func() {
if r := recover(); r != nil {
t.Logf("USB gadget creation panicked: %v", r)
createErr = assert.AnError
}
done <- true
}()
gadget = newUsbGadget(name, gadgetConfig, devices, config, nil)
if gadget == nil {
createErr = assert.AnError
}
}()
// Wait for creation or timeout
select {
case <-done:
if createErr != nil {
t.Logf("USB gadget creation failed: %v", createErr)
return nil
}
return gadget
case <-ctx.Done():
t.Logf("USB gadget creation timed out")
return nil
}
}
// cleanupUsbGadget ensures clean state by removing any existing USB gadget configuration
func cleanupUsbGadget(t *testing.T, name string) {
t.Logf("Cleaning up USB gadget: %s", name)
// Try to unbind UDC first
udcPath := "/sys/kernel/config/usb_gadget/" + name + "/UDC"
if _, err := os.Stat(udcPath); err == nil {
// Read current UDC binding
if udcData, err := os.ReadFile(udcPath); err == nil && len(strings.TrimSpace(string(udcData))) > 0 {
// Unbind UDC
if err := os.WriteFile(udcPath, []byte(""), 0644); err != nil {
t.Logf("Failed to unbind UDC: %v", err)
} else {
t.Logf("Successfully unbound UDC")
// Wait for unbinding to complete
time.Sleep(200 * time.Millisecond)
}
}
}
// Remove gadget directory if it exists
gadgetPath := "/sys/kernel/config/usb_gadget/" + name
if _, err := os.Stat(gadgetPath); err == nil {
// Try to remove configuration links first
configPath := gadgetPath + "/configs/c.1"
if entries, err := os.ReadDir(configPath); err == nil {
for _, entry := range entries {
if entry.Type()&os.ModeSymlink != 0 {
linkPath := configPath + "/" + entry.Name()
if err := os.Remove(linkPath); err != nil {
t.Logf("Failed to remove config link %s: %v", linkPath, err)
}
}
}
}
// Remove the gadget directory (this should cascade remove everything)
if err := os.RemoveAll(gadgetPath); err != nil {
t.Logf("Failed to remove gadget directory: %v", err)
} else {
t.Logf("Successfully removed gadget directory")
}
}
// Wait for cleanup to complete
time.Sleep(300 * time.Millisecond)
}
// validateHardwareState checks the current hardware state
func validateHardwareState(t *testing.T, gadget *UsbGadget) {
if gadget == nil {
return
}
// Check UDC binding state
bound, err := gadget.IsUDCBound()
if err != nil {
t.Logf("Warning: Could not check UDC binding state: %v", err)
} else {
t.Logf("UDC bound: %v", bound)
}
// Check available UDCs
udcs := getUdcs()
t.Logf("Available UDCs: %v", udcs)
// Check configfs mount
if _, err := os.Stat("/sys/kernel/config"); err != nil {
t.Logf("Warning: configfs not available: %v", err)
} else {
t.Logf("configfs is available")
}
}

View File

@ -1,437 +0,0 @@
package usbgadget
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// Unit tests for USB gadget configuration logic without hardware dependencies
// These tests follow the pattern of audio tests - testing business logic and validation
func TestUsbGadgetConfigValidation(t *testing.T) {
tests := []struct {
name string
config *Config
devices *Devices
expected bool
}{
{
name: "ValidConfig",
config: &Config{
VendorId: "0x1d6b",
ProductId: "0x0104",
Manufacturer: "JetKVM",
Product: "USB Emulation Device",
},
devices: &Devices{
Keyboard: true,
AbsoluteMouse: true,
RelativeMouse: true,
MassStorage: true,
},
expected: true,
},
{
name: "InvalidVendorId",
config: &Config{
VendorId: "invalid",
ProductId: "0x0104",
Manufacturer: "JetKVM",
Product: "USB Emulation Device",
},
devices: &Devices{
Keyboard: true,
},
expected: false,
},
{
name: "EmptyManufacturer",
config: &Config{
VendorId: "0x1d6b",
ProductId: "0x0104",
Manufacturer: "",
Product: "USB Emulation Device",
},
devices: &Devices{
Keyboard: true,
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateUsbGadgetConfiguration(tt.config, tt.devices)
if tt.expected {
assert.NoError(t, err, "Configuration should be valid")
} else {
assert.Error(t, err, "Configuration should be invalid")
}
})
}
}
func TestUsbGadgetDeviceConfiguration(t *testing.T) {
tests := []struct {
name string
devices *Devices
expectedConfigs []string
}{
{
name: "AllDevicesEnabled",
devices: &Devices{
Keyboard: true,
AbsoluteMouse: true,
RelativeMouse: true,
MassStorage: true,
Audio: true,
},
expectedConfigs: []string{"keyboard", "absolute_mouse", "relative_mouse", "mass_storage_base", "audio"},
},
{
name: "OnlyKeyboard",
devices: &Devices{
Keyboard: true,
},
expectedConfigs: []string{"keyboard"},
},
{
name: "MouseOnly",
devices: &Devices{
AbsoluteMouse: true,
RelativeMouse: true,
},
expectedConfigs: []string{"absolute_mouse", "relative_mouse"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configs := getEnabledGadgetConfigs(tt.devices)
assert.ElementsMatch(t, tt.expectedConfigs, configs, "Enabled configs should match expected")
})
}
}
func TestUsbGadgetStateTransition(t *testing.T) {
if testing.Short() {
t.Skip("Skipping state transition test in short mode")
}
tests := []struct {
name string
initialDevices *Devices
newDevices *Devices
expectedTransition string
}{
{
name: "EnableAudio",
initialDevices: &Devices{
Keyboard: true,
AbsoluteMouse: true,
Audio: false,
},
newDevices: &Devices{
Keyboard: true,
AbsoluteMouse: true,
Audio: true,
},
expectedTransition: "audio_enabled",
},
{
name: "DisableKeyboard",
initialDevices: &Devices{
Keyboard: true,
AbsoluteMouse: true,
},
newDevices: &Devices{
Keyboard: false,
AbsoluteMouse: true,
},
expectedTransition: "keyboard_disabled",
},
{
name: "NoChange",
initialDevices: &Devices{
Keyboard: true,
AbsoluteMouse: true,
},
newDevices: &Devices{
Keyboard: true,
AbsoluteMouse: true,
},
expectedTransition: "no_change",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
transition := simulateUsbGadgetStateTransition(ctx, tt.initialDevices, tt.newDevices)
assert.Equal(t, tt.expectedTransition, transition, "State transition should match expected")
})
}
}
func TestUsbGadgetConfigurationTimeout(t *testing.T) {
if testing.Short() {
t.Skip("Skipping timeout test in short mode")
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
// Test that configuration validation completes within reasonable time
start := time.Now()
// Simulate multiple rapid configuration changes
for i := 0; i < 20; i++ {
devices := &Devices{
Keyboard: i%2 == 0,
AbsoluteMouse: i%3 == 0,
RelativeMouse: i%4 == 0,
MassStorage: i%5 == 0,
Audio: i%6 == 0,
}
config := &Config{
VendorId: "0x1d6b",
ProductId: "0x0104",
Manufacturer: "JetKVM",
Product: "USB Emulation Device",
}
err := validateUsbGadgetConfiguration(config, devices)
assert.NoError(t, err, "Configuration validation should not fail")
// Ensure we don't timeout
select {
case <-ctx.Done():
t.Fatal("USB gadget configuration test timed out")
default:
// Continue
}
}
elapsed := time.Since(start)
t.Logf("USB gadget configuration test completed in %v", elapsed)
assert.Less(t, elapsed, 2*time.Second, "Configuration validation should complete quickly")
}
func TestReportDescriptorValidation(t *testing.T) {
tests := []struct {
name string
reportDesc []byte
expected bool
}{
{
name: "ValidKeyboardReportDesc",
reportDesc: keyboardReportDesc,
expected: true,
},
{
name: "ValidAbsoluteMouseReportDesc",
reportDesc: absoluteMouseCombinedReportDesc,
expected: true,
},
{
name: "ValidRelativeMouseReportDesc",
reportDesc: relativeMouseCombinedReportDesc,
expected: true,
},
{
name: "EmptyReportDesc",
reportDesc: []byte{},
expected: false,
},
{
name: "InvalidReportDesc",
reportDesc: []byte{0xFF, 0xFF, 0xFF},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateReportDescriptor(tt.reportDesc)
if tt.expected {
assert.NoError(t, err, "Report descriptor should be valid")
} else {
assert.Error(t, err, "Report descriptor should be invalid")
}
})
}
}
// Helper functions for simulation (similar to audio tests)
// validateUsbGadgetConfiguration simulates the validation that happens in production
func validateUsbGadgetConfiguration(config *Config, devices *Devices) error {
if config == nil {
return assert.AnError
}
// Validate vendor ID format
if config.VendorId == "" || len(config.VendorId) < 4 {
return assert.AnError
}
if config.VendorId != "" && config.VendorId[:2] != "0x" {
return assert.AnError
}
// Validate product ID format
if config.ProductId == "" || len(config.ProductId) < 4 {
return assert.AnError
}
if config.ProductId != "" && config.ProductId[:2] != "0x" {
return assert.AnError
}
// Validate required fields
if config.Manufacturer == "" {
return assert.AnError
}
if config.Product == "" {
return assert.AnError
}
// Note: Allow configurations with no devices enabled for testing purposes
// In production, this would typically be validated at a higher level
return nil
}
// getEnabledGadgetConfigs returns the list of enabled gadget configurations
func getEnabledGadgetConfigs(devices *Devices) []string {
var configs []string
if devices.Keyboard {
configs = append(configs, "keyboard")
}
if devices.AbsoluteMouse {
configs = append(configs, "absolute_mouse")
}
if devices.RelativeMouse {
configs = append(configs, "relative_mouse")
}
if devices.MassStorage {
configs = append(configs, "mass_storage_base")
}
if devices.Audio {
configs = append(configs, "audio")
}
return configs
}
// simulateUsbGadgetStateTransition simulates the state management during USB reconfiguration
func simulateUsbGadgetStateTransition(ctx context.Context, initial, new *Devices) string {
// Check for audio changes
if initial.Audio != new.Audio {
if new.Audio {
// Simulate enabling audio device
time.Sleep(5 * time.Millisecond)
return "audio_enabled"
} else {
// Simulate disabling audio device
time.Sleep(5 * time.Millisecond)
return "audio_disabled"
}
}
// Check for keyboard changes
if initial.Keyboard != new.Keyboard {
if new.Keyboard {
time.Sleep(5 * time.Millisecond)
return "keyboard_enabled"
} else {
time.Sleep(5 * time.Millisecond)
return "keyboard_disabled"
}
}
// Check for mouse changes
if initial.AbsoluteMouse != new.AbsoluteMouse || initial.RelativeMouse != new.RelativeMouse {
time.Sleep(5 * time.Millisecond)
return "mouse_changed"
}
// Check for mass storage changes
if initial.MassStorage != new.MassStorage {
time.Sleep(5 * time.Millisecond)
return "mass_storage_changed"
}
return "no_change"
}
// validateReportDescriptor simulates HID report descriptor validation
func validateReportDescriptor(reportDesc []byte) error {
if len(reportDesc) == 0 {
return assert.AnError
}
// Basic HID report descriptor validation
// Check for valid usage page (0x05)
found := false
for i := 0; i < len(reportDesc)-1; i++ {
if reportDesc[i] == 0x05 {
found = true
break
}
}
if !found {
return assert.AnError
}
return nil
}
// Benchmark tests
func BenchmarkValidateUsbGadgetConfiguration(b *testing.B) {
config := &Config{
VendorId: "0x1d6b",
ProductId: "0x0104",
Manufacturer: "JetKVM",
Product: "USB Emulation Device",
}
devices := &Devices{
Keyboard: true,
AbsoluteMouse: true,
RelativeMouse: true,
MassStorage: true,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = validateUsbGadgetConfiguration(config, devices)
}
}
func BenchmarkGetEnabledGadgetConfigs(b *testing.B) {
devices := &Devices{
Keyboard: true,
AbsoluteMouse: true,
RelativeMouse: true,
MassStorage: true,
Audio: true,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = getEnabledGadgetConfigs(devices)
}
}
func BenchmarkValidateReportDescriptor(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = validateReportDescriptor(keyboardReportDesc)
}
}

Binary file not shown.