[WIP] Bugfixes: session promotion

This commit is contained in:
Alex P 2025-10-10 10:16:21 +03:00
parent 8dbd98b4f0
commit 309126bef6
2 changed files with 124 additions and 47 deletions

View File

@ -29,6 +29,9 @@ func Main() {
} }
currentSessionSettings = config.SessionSettings currentSessionSettings = config.SessionSettings
// Initialize global session manager (must be called after config and logger are ready)
initSessionManager()
var cancel context.CancelFunc var cancel context.CancelFunc
appCtx, cancel = context.WithCancel(context.Background()) appCtx, cancel = context.WithCancel(context.Background())
defer cancel() defer cancel()

View File

@ -88,6 +88,12 @@ type SessionManager struct {
// NewSessionManager creates a new session manager // NewSessionManager creates a new session manager
func NewSessionManager(logger *zerolog.Logger) *SessionManager { func NewSessionManager(logger *zerolog.Logger) *SessionManager {
// DEBUG: Log every time a new SessionManager is created
if logger != nil {
logger.Warn().
Msg("CREATING NEW SESSION MANAGER - This should only happen once at startup!")
}
// Use configuration values if available // Use configuration values if available
maxSessions := 10 maxSessions := 10
primaryTimeout := 5 * time.Minute primaryTimeout := 5 * time.Minute
@ -389,6 +395,8 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe
Str("sessionID", session.ID). Str("sessionID", session.ID).
Str("mode", string(session.Mode)). Str("mode", string(session.Mode)).
Int("totalSessions", len(sm.sessions)). Int("totalSessions", len(sm.sessions)).
Str("sm_pointer", fmt.Sprintf("%p", sm)).
Str("sm.sessions_pointer", fmt.Sprintf("%p", sm.sessions)).
Msg("Session added to manager") Msg("Session added to manager")
// Ensure session has auto-generated nickname if needed // Ensure session has auto-generated nickname if needed
@ -661,6 +669,13 @@ func (sm *SessionManager) GetAllSessions() []SessionData {
// This was causing immediate demotion during transfers and page refreshes // This was causing immediate demotion during transfers and page refreshes
// Validation should only run during state changes, not data queries // Validation should only run during state changes, not data queries
// DEBUG: Log pointer addresses to verify we're using the same instance
sm.logger.Debug().
Int("sessions_count", len(sm.sessions)).
Str("sm_pointer", fmt.Sprintf("%p", sm)).
Str("sm.sessions_pointer", fmt.Sprintf("%p", sm.sessions)).
Msg("GetAllSessions called")
infos := make([]SessionData, 0, len(sm.sessions)) infos := make([]SessionData, 0, len(sm.sessions))
for _, session := range sm.sessions { for _, session := range sm.sessions {
infos = append(infos, SessionData{ infos = append(infos, SessionData{
@ -965,17 +980,26 @@ func (sm *SessionManager) UpdateLastActive(sessionID string) {
// validateSinglePrimary ensures there's only one primary session and fixes any inconsistencies // validateSinglePrimary ensures there's only one primary session and fixes any inconsistencies
func (sm *SessionManager) validateSinglePrimary() { func (sm *SessionManager) validateSinglePrimary() {
// CRITICAL DEBUG: Check if we actually hold the lock
// The caller should already hold sm.mu.Lock()
primarySessions := make([]*Session, 0) primarySessions := make([]*Session, 0)
// Capture session keys BEFORE logging to avoid lazy evaluation issues
sessionKeys := make([]string, 0, len(sm.sessions))
sessionPointers := make([]string, 0, len(sm.sessions))
for k, v := range sm.sessions {
sessionKeys = append(sessionKeys, k)
sessionPointers = append(sessionPointers, fmt.Sprintf("%s=%p", k[:8], v))
}
// DEBUG: Add pointer address to verify we're using the right manager instance
sm.logger.Debug(). sm.logger.Debug().
Int("sm.sessions_len", len(sm.sessions)). Int("sm.sessions_len_before_loop", len(sm.sessions)).
Interface("sm.sessions_keys", func() []string { Strs("sm.sessions_keys", sessionKeys).
keys := make([]string, 0, len(sm.sessions)) Strs("sm.session_pointers", sessionPointers).
for k := range sm.sessions { Str("sm_pointer", fmt.Sprintf("%p", sm)).
keys = append(keys, k) Str("sm.sessions_map_pointer", fmt.Sprintf("%p", sm.sessions)).
}
return keys
}()).
Msg("validateSinglePrimary: checking sm.sessions map") Msg("validateSinglePrimary: checking sm.sessions map")
// Find all sessions that think they're primary // Find all sessions that think they're primary
@ -1134,6 +1158,9 @@ func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transf
if fromExists && fromSession.Mode == SessionModePrimary { if fromExists && fromSession.Mode == SessionModePrimary {
fromSession.Mode = SessionModeObserver fromSession.Mode = SessionModeObserver
fromSession.hidRPCAvailable = false fromSession.hidRPCAvailable = false
// Always delete grace period when demoting - no exceptions
// If a session times out or is manually transferred, it should not auto-reclaim primary
delete(sm.reconnectGrace, fromSessionID) delete(sm.reconnectGrace, fromSessionID)
delete(sm.reconnectInfo, fromSessionID) delete(sm.reconnectInfo, fromSessionID)
@ -1160,7 +1187,15 @@ func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transf
toSession.Mode = SessionModePrimary toSession.Mode = SessionModePrimary
toSession.hidRPCAvailable = false // Force re-handshake toSession.hidRPCAvailable = false // Force re-handshake
sm.primarySessionID = toSessionID sm.primarySessionID = toSessionID
sm.lastPrimaryID = toSessionID // Set to new primary so grace period works on refresh
// Only set lastPrimaryID for grace period scenarios, NOT for manual transfers
// Manual transfers should clear lastPrimaryID to prevent reconnection conflicts
if transferType == "emergency_auto_promotion" || transferType == "emergency_promotion_deadlock_prevention" ||
transferType == "emergency_timeout_promotion" || transferType == "initial_promotion" {
sm.lastPrimaryID = toSessionID // Allow grace period recovery for emergency promotions
} else {
sm.lastPrimaryID = "" // Clear for manual transfers to prevent reconnection conflicts
}
// Clear input state // Clear input state
sm.clearInputState() sm.clearInputState()
@ -1171,27 +1206,32 @@ func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transf
} }
// Apply bidirectional blacklisting - protect newly promoted session // Apply bidirectional blacklisting - protect newly promoted session
// Only apply blacklisting for MANUAL transfers, not emergency promotions
// Emergency promotions need to happen immediately without blacklist interference
isManualTransfer := (transferType == "direct_transfer" || transferType == "approval_transfer" || transferType == "release_transfer")
now := time.Now() now := time.Now()
blacklistDuration := 60 * time.Second blacklistDuration := 60 * time.Second
blacklistedCount := 0 blacklistedCount := 0
// First, clear any existing blacklist entries for the newly promoted session if isManualTransfer {
cleanedBlacklist := make([]TransferBlacklistEntry, 0) // First, clear any existing blacklist entries for the newly promoted session
for _, entry := range sm.transferBlacklist { cleanedBlacklist := make([]TransferBlacklistEntry, 0)
if entry.SessionID != toSessionID { // Remove any old blacklist entries for the new primary for _, entry := range sm.transferBlacklist {
cleanedBlacklist = append(cleanedBlacklist, entry) if entry.SessionID != toSessionID { // Remove any old blacklist entries for the new primary
cleanedBlacklist = append(cleanedBlacklist, entry)
}
} }
} sm.transferBlacklist = cleanedBlacklist
sm.transferBlacklist = cleanedBlacklist
// Then blacklist all other sessions // Then blacklist all other sessions
for sessionID := range sm.sessions { for sessionID := range sm.sessions {
if sessionID != toSessionID { // Don't blacklist the newly promoted session if sessionID != toSessionID { // Don't blacklist the newly promoted session
sm.transferBlacklist = append(sm.transferBlacklist, TransferBlacklistEntry{ sm.transferBlacklist = append(sm.transferBlacklist, TransferBlacklistEntry{
SessionID: sessionID, SessionID: sessionID,
ExpiresAt: now.Add(blacklistDuration), ExpiresAt: now.Add(blacklistDuration),
}) })
blacklistedCount++ blacklistedCount++
}
} }
} }
@ -1214,8 +1254,9 @@ func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transf
Dur("blacklistDuration", blacklistDuration). Dur("blacklistDuration", blacklistDuration).
Msg("Primary role transferred with bidirectional protection") Msg("Primary role transferred with bidirectional protection")
// Validate session consistency after role transfer // DON'T validate here - causes recursive calls and map iteration issues
sm.validateSinglePrimary() // The caller (AddSession, RemoveSession, etc.) will validate after we return
// sm.validateSinglePrimary() // REMOVED to prevent recursion
// Handle WebRTC connection state for promoted sessions // Handle WebRTC connection state for promoted sessions
// When a session changes from observer to primary, the existing WebRTC connection // When a session changes from observer to primary, the existing WebRTC connection
@ -1629,22 +1670,31 @@ func (sm *SessionManager) cleanupInactiveSessions(ctx context.Context) {
if currentSessionSettings != nil && currentSessionSettings.RequireApproval { if currentSessionSettings != nil && currentSessionSettings.RequireApproval {
isEmergencyPromotion = true isEmergencyPromotion = true
// Rate limiting for emergency promotions // CRITICAL: Ensure we ALWAYS have a primary session
if now.Sub(sm.lastEmergencyPromotion) < 30*time.Second { // If there's NO primary, bypass rate limits entirely
sm.logger.Warn(). hasPrimary := sm.primarySessionID != ""
Str("expiredSessionID", sessionID). if !hasPrimary {
Dur("timeSinceLastEmergency", now.Sub(sm.lastEmergencyPromotion)).
Msg("Emergency promotion rate limit exceeded - potential attack")
continue // Skip this grace period expiration
}
// Limit consecutive emergency promotions
if sm.consecutiveEmergencyPromotions >= 3 {
sm.logger.Error(). sm.logger.Error().
Str("expiredSessionID", sessionID). Str("expiredSessionID", sessionID).
Int("consecutiveCount", sm.consecutiveEmergencyPromotions). Msg("CRITICAL: No primary session exists - bypassing all rate limits")
Msg("Too many consecutive emergency promotions - blocking for security") } else {
continue // Skip this grace period expiration // Rate limiting for emergency promotions (only when we have a primary)
if now.Sub(sm.lastEmergencyPromotion) < 30*time.Second {
sm.logger.Warn().
Str("expiredSessionID", sessionID).
Dur("timeSinceLastEmergency", now.Sub(sm.lastEmergencyPromotion)).
Msg("Emergency promotion rate limit exceeded - potential attack")
continue // Skip this grace period expiration
}
// Limit consecutive emergency promotions
if sm.consecutiveEmergencyPromotions >= 3 {
sm.logger.Error().
Str("expiredSessionID", sessionID).
Int("consecutiveCount", sm.consecutiveEmergencyPromotions).
Msg("Too many consecutive emergency promotions - blocking for security")
continue // Skip this grace period expiration
}
} }
promotedSessionID = sm.findMostTrustedSessionForEmergency() promotedSessionID = sm.findMostTrustedSessionForEmergency()
@ -1745,13 +1795,23 @@ func (sm *SessionManager) cleanupInactiveSessions(ctx context.Context) {
if currentSessionSettings != nil && currentSessionSettings.RequireApproval { if currentSessionSettings != nil && currentSessionSettings.RequireApproval {
isEmergencyPromotion = true isEmergencyPromotion = true
// Rate limiting for emergency promotions // CRITICAL: Ensure we ALWAYS have a primary session
if now.Sub(sm.lastEmergencyPromotion) < 30*time.Second { // primarySessionID was just cleared above, so this will always be empty
sm.logger.Warn(). // But check anyway for completeness
hasPrimary := sm.primarySessionID != ""
if !hasPrimary {
sm.logger.Error().
Str("timedOutSessionID", timedOutSessionID). Str("timedOutSessionID", timedOutSessionID).
Dur("timeSinceLastEmergency", now.Sub(sm.lastEmergencyPromotion)). Msg("CRITICAL: No primary session after timeout - bypassing all rate limits")
Msg("Emergency promotion rate limit exceeded during timeout - potential attack") } else {
continue // Skip this timeout // Rate limiting for emergency promotions (only when we have a primary)
if now.Sub(sm.lastEmergencyPromotion) < 30*time.Second {
sm.logger.Warn().
Str("timedOutSessionID", timedOutSessionID).
Dur("timeSinceLastEmergency", now.Sub(sm.lastEmergencyPromotion)).
Msg("Emergency promotion rate limit exceeded during timeout - potential attack")
continue // Skip this timeout
}
} }
// Use trust-based selection but exclude the timed-out session // Use trust-based selection but exclude the timed-out session
@ -1843,7 +1903,21 @@ func (sm *SessionManager) cleanupInactiveSessions(ctx context.Context) {
} }
// Global session manager instance // Global session manager instance
var sessionManager = NewSessionManager(websocketLogger) var (
sessionManager *SessionManager
sessionManagerOnce sync.Once
)
func initSessionManager() {
sessionManagerOnce.Do(func() {
sessionManager = NewSessionManager(websocketLogger)
if sessionManager != nil && websocketLogger != nil {
websocketLogger.Error().
Str("pointer", fmt.Sprintf("%p", sessionManager)).
Msg("!!! GLOBAL sessionManager VARIABLE INITIALIZED - THIS SHOULD ONLY HAPPEN ONCE !!!")
}
})
}
// Global session settings - references config.SessionSettings for persistence // Global session settings - references config.SessionSettings for persistence
var currentSessionSettings *SessionSettings var currentSessionSettings *SessionSettings