diff --git a/session_manager.go b/session_manager.go index a2e38607..40dea9df 100644 --- a/session_manager.go +++ b/session_manager.go @@ -31,12 +31,23 @@ const ( // Session timeout defaults defaultPendingSessionTimeout = 1 * time.Minute // Timeout for pending sessions (DoS protection) defaultObserverSessionTimeout = 2 * time.Minute // Timeout for inactive observer sessions + disabledTimeoutValue = 24 * time.Hour // Value used when timeout is disabled (0 setting) // Transfer and blacklist settings transferBlacklistDuration = 60 * time.Second // Duration to blacklist sessions after manual transfer // Grace period limits maxGracePeriodEntries = 10 // Maximum number of grace period entries to prevent memory exhaustion + + // Emergency promotion limits (DoS protection) + emergencyWindowDuration = 60 * time.Second // Sliding window duration for emergency promotion rate limiting + maxEmergencyPromotionsPerMinute = 3 // Maximum emergency promotions allowed within the sliding window + emergencyPromotionCooldown = 10 * time.Second // Minimum time between individual emergency promotions + maxConsecutiveEmergencyPromotions = 3 // Maximum consecutive emergency promotions before blocking + emergencyPromotionWindowCleanupAge = 60 * time.Second // Age at which emergency window entries are cleaned up + + // Trust scoring constants + invalidSessionTrustScore = -1000 // Trust score for non-existent sessions ) var ( @@ -81,12 +92,6 @@ type TransferBlacklistEntry struct { ExpiresAt time.Time } -// Broadcast throttling state (DoS protection) -var ( - lastBroadcast time.Time - broadcastMutex sync.Mutex -) - type SessionManager struct { mu sync.RWMutex primaryPromotionLock sync.Mutex @@ -108,6 +113,8 @@ type SessionManager struct { emergencyPromotionWindow []time.Time emergencyWindowMutex sync.Mutex + lastBroadcast time.Time + broadcastMutex sync.Mutex broadcastQueue chan struct{} broadcastPending atomic.Bool } @@ -184,15 +191,13 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe nicknameReserved := false defer func() { - if r := recover(); r != nil || nicknameReserved { + if r := recover(); r != nil { if nicknameReserved && session.Nickname != "" { if sm.nicknameIndex[session.Nickname] == session { delete(sm.nicknameIndex, session.Nickname) } } - if r != nil { - panic(r) - } + panic(r) } }() @@ -221,7 +226,6 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe // Check if a session with this ID already exists (reconnection) if existing, exists := sm.sessions[session.ID]; exists { - // SECURITY: Verify identity matches to prevent session hijacking if existing.Identity != session.Identity || existing.Source != session.Source { return fmt.Errorf("session ID already in use by different user (identity mismatch)") } @@ -237,7 +241,6 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe existing.ControlChannel = session.ControlChannel existing.RPCChannel = session.RPCChannel existing.HidChannel = session.HidChannel - existing.LastActive = time.Now() existing.flushCandidates = session.flushCandidates // Preserve existing mode and nickname session.Mode = existing.Mode @@ -1197,7 +1200,11 @@ func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transf // Promote target session toSession.Mode = SessionModePrimary toSession.hidRPCAvailable = false // Force re-handshake - toSession.LastActive = time.Now() // Reset activity timestamp to prevent immediate timeout + // Only reset LastActive for emergency promotions to prevent immediate re-timeout + // For manual transfers, preserve existing LastActive to maintain timeout accuracy + if transferType == "emergency_timeout_promotion" || transferType == "emergency_promotion_deadlock_prevention" { + toSession.LastActive = time.Now() // Reset for emergency promotions only + } sm.primarySessionID = toSessionID // ALWAYS set lastPrimaryID to the new primary to support WebRTC reconnections @@ -1272,7 +1279,7 @@ func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transf // Send reconnection signal for emergency promotions via WebSocket (more reliable than RPC when channel is stale) if toExists && (transferType == "emergency_timeout_promotion" || transferType == "emergency_auto_promotion") { go func() { - time.Sleep(100 * time.Millisecond) + time.Sleep(globalBroadcastDelay) eventData := map[string]interface{}{ "sessionId": toSessionID, @@ -1362,8 +1369,7 @@ func (sm *SessionManager) getCurrentPrimaryTimeout() time.Duration { // Use session settings if available if currentSessionSettings != nil { if currentSessionSettings.PrimaryTimeout == 0 { - // 0 means disabled - return a very large duration - return 24 * time.Hour + return disabledTimeoutValue } else if currentSessionSettings.PrimaryTimeout > 0 { return time.Duration(currentSessionSettings.PrimaryTimeout) * time.Second } @@ -1376,7 +1382,7 @@ func (sm *SessionManager) getCurrentPrimaryTimeout() time.Duration { func (sm *SessionManager) getSessionTrustScore(sessionID string) int { session, exists := sm.sessions[sessionID] if !exists { - return -1000 // Session doesn't exist + return invalidSessionTrustScore } score := 0 @@ -1422,9 +1428,7 @@ func (sm *SessionManager) findMostTrustedSessionForEmergency() string { bestSessionID := "" bestScore := -1 - // First pass: try to find observers or queued sessions (preferred) for sessionID, session := range sm.sessions { - // Skip if blacklisted, primary, or not eligible modes if sm.isSessionBlacklisted(sessionID) || session.Mode == SessionModePrimary || (session.Mode != SessionModeObserver && session.Mode != SessionModeQueued) { @@ -1438,24 +1442,6 @@ func (sm *SessionManager) findMostTrustedSessionForEmergency() string { } } - // If no observers/queued found, try pending sessions as last resort - if bestSessionID == "" { - for sessionID, session := range sm.sessions { - if sm.isSessionBlacklisted(sessionID) || session.Mode == SessionModePrimary { - continue - } - - if session.Mode == SessionModePending { - score := sm.getSessionTrustScore(sessionID) - if score > bestScore { - bestScore = score - bestSessionID = sessionID - } - } - } - } - - // Log the selection decision for audit trail if bestSessionID != "" { sm.logger.Info(). Str("selectedSession", bestSessionID). @@ -1589,13 +1575,13 @@ func (sm *SessionManager) broadcastSessionListUpdate() { } func (sm *SessionManager) executeBroadcast() { - broadcastMutex.Lock() - if time.Since(lastBroadcast) < globalBroadcastDelay { - broadcastMutex.Unlock() + sm.broadcastMutex.Lock() + if time.Since(sm.lastBroadcast) < globalBroadcastDelay { + sm.broadcastMutex.Unlock() return } - lastBroadcast = time.Now() - broadcastMutex.Unlock() + sm.lastBroadcast = time.Now() + sm.broadcastMutex.Unlock() sm.mu.RLock() infos := make([]SessionData, 0, len(sm.sessions)) @@ -1647,7 +1633,8 @@ func (sm *SessionManager) Shutdown() { sm.mu.Lock() defer sm.mu.Unlock() - // Clean up all sessions + close(sm.broadcastQueue) + for id := range sm.sessions { delete(sm.sessions, id) } @@ -1668,6 +1655,18 @@ func (sm *SessionManager) cleanupInactiveSessions(ctx context.Context) { now := time.Now() needsBroadcast := false + // Clean up expired emergency promotion window entries + sm.emergencyWindowMutex.Lock() + cutoff := now.Add(-emergencyPromotionWindowCleanupAge) + validEntries := make([]time.Time, 0, len(sm.emergencyPromotionWindow)) + for _, t := range sm.emergencyPromotionWindow { + if t.After(cutoff) { + validEntries = append(validEntries, t) + } + } + sm.emergencyPromotionWindow = validEntries + sm.emergencyWindowMutex.Unlock() + // Handle expired grace periods gracePeriodExpired := sm.handleGracePeriodExpiration(now) if gracePeriodExpired {