Compare commits

..

1 Commits

Author SHA1 Message Date
Alex 51d23dca83
Merge 9a10d3ed38 into 74e64f69a7 2025-10-17 10:52:34 +02:00
7 changed files with 125 additions and 312 deletions

View File

@ -16,13 +16,6 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) {
switch message.Type() { switch message.Type() {
case hidrpc.TypeHandshake: case hidrpc.TypeHandshake:
if !session.HasPermission(PermissionVideoView) {
logger.Debug().
Str("sessionID", session.ID).
Str("mode", string(session.Mode)).
Msg("handshake blocked: session lacks PermissionVideoView")
return
}
message, err := hidrpc.NewHandshakeMessage().Marshal() message, err := hidrpc.NewHandshakeMessage().Marshal()
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("failed to marshal handshake message") logger.Warn().Err(err).Msg("failed to marshal handshake message")

View File

@ -129,18 +129,22 @@ func runJiggler() {
} }
inactivitySeconds := config.JigglerConfig.InactivityLimitSeconds inactivitySeconds := config.JigglerConfig.InactivityLimitSeconds
timeSinceLastInput := time.Since(gadget.GetLastUserInputTime()) timeSinceLastInput := time.Since(gadget.GetLastUserInputTime())
logger.Debug().Msgf("Time since last user input %v", timeSinceLastInput)
if timeSinceLastInput > time.Duration(inactivitySeconds)*time.Second { if timeSinceLastInput > time.Duration(inactivitySeconds)*time.Second {
logger.Debug().Msg("Jiggling mouse...") err := gadget.RelMouseReport(1, 0, 0)
//TODO: change to rel mouse
err := rpcAbsMouseReport(1, 1, 0)
if err != nil { if err != nil {
logger.Warn().Msgf("Failed to jiggle mouse: %v", err) logger.Warn().Msgf("Failed to jiggle mouse: %v", err)
} }
err = rpcAbsMouseReport(0, 0, 0) time.Sleep(50 * time.Millisecond)
err = gadget.RelMouseReport(-1, 0, 0)
if err != nil { if err != nil {
logger.Warn().Msgf("Failed to reset mouse position: %v", err) logger.Warn().Msgf("Failed to reset mouse position: %v", err)
} }
if sessionManager != nil {
if primarySession := sessionManager.GetPrimarySession(); primarySession != nil {
sessionManager.UpdateLastActive(primarySession.ID)
}
}
} }
} }
} }

View File

@ -32,32 +32,6 @@ func isValidNickname(nickname string) bool {
return nicknameRegex.MatchString(nickname) return nicknameRegex.MatchString(nickname)
} }
// Global RPC rate limiting (protects against coordinated DoS from multiple sessions)
var (
globalRPCRateLimitMu sync.Mutex
globalRPCRateLimit int
globalRPCRateLimitWin time.Time
)
func checkGlobalRPCRateLimit() bool {
const (
maxGlobalRPCPerSecond = 2000
rateLimitWindow = time.Second
)
globalRPCRateLimitMu.Lock()
defer globalRPCRateLimitMu.Unlock()
now := time.Now()
if now.Sub(globalRPCRateLimitWin) > rateLimitWindow {
globalRPCRateLimit = 0
globalRPCRateLimitWin = now
}
globalRPCRateLimit++
return globalRPCRateLimit <= maxGlobalRPCPerSecond
}
type JSONRPCRequest struct { type JSONRPCRequest struct {
JSONRPC string `json:"jsonrpc"` JSONRPC string `json:"jsonrpc"`
Method string `json:"method"` Method string `json:"method"`
@ -145,24 +119,7 @@ func broadcastJSONRPCEvent(event string, params any) {
} }
func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { func onRPCMessage(message webrtc.DataChannelMessage, session *Session) {
// Global rate limit check (protects against coordinated DoS from multiple sessions) // Rate limit check (DoS protection)
if !checkGlobalRPCRateLimit() {
jsonRpcLogger.Warn().
Str("sessionId", session.ID).
Msg("Global RPC rate limit exceeded")
errorResponse := JSONRPCResponse{
JSONRPC: "2.0",
Error: map[string]any{
"code": -32000,
"message": "Global rate limit exceeded",
},
ID: 0,
}
writeJSONRPCResponse(errorResponse, session)
return
}
// Per-session rate limit check (DoS protection)
if !session.CheckRPCRateLimit() { if !session.CheckRPCRateLimit() {
jsonRpcLogger.Warn(). jsonRpcLogger.Warn().
Str("sessionId", session.ID). Str("sessionId", session.ID).

View File

@ -95,43 +95,19 @@ func handleRequestSessionApprovalRPC(session *Session) (any, error) {
return map[string]interface{}{"status": "requested"}, nil return map[string]interface{}{"status": "requested"}, nil
} }
func validateNickname(nickname string) error { // handleUpdateSessionNicknameRPC handles nickname updates for sessions
if len(nickname) < 2 {
return errors.New("nickname must be at least 2 characters")
}
if len(nickname) > 30 {
return errors.New("nickname must be 30 characters or less")
}
if !isValidNickname(nickname) {
return errors.New("nickname can only contain letters, numbers, spaces, and - _ . @")
}
for i, r := range nickname {
if r < 32 || r == 127 {
return fmt.Errorf("nickname contains control character at position %d", i)
}
if r >= 0x200B && r <= 0x200D {
return errors.New("nickname contains zero-width character")
}
}
trimmed := ""
for _, r := range nickname {
trimmed += string(r)
}
if trimmed != nickname {
return errors.New("nickname contains disallowed unicode")
}
return nil
}
func handleUpdateSessionNicknameRPC(params map[string]any, session *Session) (any, error) { func handleUpdateSessionNicknameRPC(params map[string]any, session *Session) (any, error) {
sessionID, _ := params["sessionId"].(string) sessionID, _ := params["sessionId"].(string)
nickname, _ := params["nickname"].(string) nickname, _ := params["nickname"].(string)
if err := validateNickname(nickname); err != nil { if len(nickname) < 2 {
return nil, err return nil, errors.New("nickname must be at least 2 characters")
}
if len(nickname) > 30 {
return nil, errors.New("nickname must be 30 characters or less")
}
if !isValidNickname(nickname) {
return nil, errors.New("nickname can only contain letters, numbers, spaces, and - _ . @")
} }
targetSession := sessionManager.GetSession(sessionID) targetSession := sessionManager.GetSession(sessionID)

View File

@ -22,43 +22,30 @@ func (sm *SessionManager) attemptEmergencyPromotion(ctx emergencyPromotionContex
return promotedID, false, false return promotedID, false, false
} }
sm.emergencyWindowMutex.Lock() // Emergency promotion path
defer sm.emergencyWindowMutex.Unlock() hasPrimary := sm.primarySessionID != ""
if !hasPrimary {
const slidingWindowDuration = 60 * time.Second
const maxEmergencyPromotionsPerMinute = 3
cutoff := ctx.now.Add(-slidingWindowDuration)
validEntries := make([]time.Time, 0, len(sm.emergencyPromotionWindow))
for _, t := range sm.emergencyPromotionWindow {
if t.After(cutoff) {
validEntries = append(validEntries, t)
}
}
sm.emergencyPromotionWindow = validEntries
if len(sm.emergencyPromotionWindow) >= maxEmergencyPromotionsPerMinute {
sm.logger.Error(). sm.logger.Error().
Str("triggerSessionID", ctx.triggerSessionID). Str("triggerSessionID", ctx.triggerSessionID).
Int("promotionsInLastMinute", len(sm.emergencyPromotionWindow)). Msg("CRITICAL: No primary session exists - bypassing all rate limits")
Msg("Emergency promotion rate limit exceeded - potential attack") } else {
return "", false, true // Rate limiting (only when we have a primary)
} if ctx.now.Sub(sm.lastEmergencyPromotion) < 30*time.Second {
if ctx.now.Sub(sm.lastEmergencyPromotion) < 10*time.Second {
sm.logger.Warn(). sm.logger.Warn().
Str("triggerSessionID", ctx.triggerSessionID). Str("triggerSessionID", ctx.triggerSessionID).
Dur("timeSinceLastEmergency", ctx.now.Sub(sm.lastEmergencyPromotion)). Dur("timeSinceLastEmergency", ctx.now.Sub(sm.lastEmergencyPromotion)).
Msg("Emergency promotion cooldown active") Msgf("Emergency promotion rate limit exceeded - potential attack (%s)", ctx.triggerReason)
return "", false, true return "", false, true // shouldSkip = true
} }
// Limit consecutive emergency promotions
if sm.consecutiveEmergencyPromotions >= 3 { if sm.consecutiveEmergencyPromotions >= 3 {
sm.logger.Error(). sm.logger.Error().
Str("triggerSessionID", ctx.triggerSessionID). Str("triggerSessionID", ctx.triggerSessionID).
Int("consecutiveCount", sm.consecutiveEmergencyPromotions). Int("consecutiveCount", sm.consecutiveEmergencyPromotions).
Msg("Too many consecutive emergency promotions - blocking") Msgf("Too many consecutive emergency promotions - blocking for security (%s)", ctx.triggerReason)
return "", false, true return "", false, true // shouldSkip = true
}
} }
// Find best session for emergency promotion // Find best session for emergency promotion
@ -136,9 +123,6 @@ func (sm *SessionManager) promoteAfterGraceExpiration(expiredSessionID string, n
reason := "grace_expiration_promotion" reason := "grace_expiration_promotion"
if isEmergency { if isEmergency {
reason = "emergency_promotion_deadlock_prevention" reason = "emergency_promotion_deadlock_prevention"
sm.emergencyWindowMutex.Lock()
sm.emergencyPromotionWindow = append(sm.emergencyPromotionWindow, now)
sm.emergencyWindowMutex.Unlock()
sm.lastEmergencyPromotion = now sm.lastEmergencyPromotion = now
sm.consecutiveEmergencyPromotions++ sm.consecutiveEmergencyPromotions++
@ -265,9 +249,6 @@ func (sm *SessionManager) handlePrimarySessionTimeout(now time.Time) bool {
reason := "timeout_promotion" reason := "timeout_promotion"
if isEmergency { if isEmergency {
reason = "emergency_timeout_promotion" reason = "emergency_timeout_promotion"
sm.emergencyWindowMutex.Lock()
sm.emergencyPromotionWindow = append(sm.emergencyPromotionWindow, now)
sm.emergencyWindowMutex.Unlock()
sm.lastEmergencyPromotion = now sm.lastEmergencyPromotion = now
sm.consecutiveEmergencyPromotions++ sm.consecutiveEmergencyPromotions++

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@ -31,23 +30,12 @@ const (
// Session timeout defaults // Session timeout defaults
defaultPendingSessionTimeout = 1 * time.Minute // Timeout for pending sessions (DoS protection) defaultPendingSessionTimeout = 1 * time.Minute // Timeout for pending sessions (DoS protection)
defaultObserverSessionTimeout = 2 * time.Minute // Timeout for inactive observer sessions defaultObserverSessionTimeout = 2 * time.Minute // Timeout for inactive observer sessions
disabledTimeoutValue = 24 * time.Hour // Value used when timeout is disabled (0 setting)
// Transfer and blacklist settings // Transfer and blacklist settings
transferBlacklistDuration = 60 * time.Second // Duration to blacklist sessions after manual transfer transferBlacklistDuration = 60 * time.Second // Duration to blacklist sessions after manual transfer
// Grace period limits // Grace period limits
maxGracePeriodEntries = 10 // Maximum number of grace period entries to prevent memory exhaustion maxGracePeriodEntries = 10 // Maximum number of grace period entries to prevent memory exhaustion
// Emergency promotion limits (DoS protection)
emergencyWindowDuration = 60 * time.Second // Sliding window duration for emergency promotion rate limiting
maxEmergencyPromotionsPerMinute = 3 // Maximum emergency promotions allowed within the sliding window
emergencyPromotionCooldown = 10 * time.Second // Minimum time between individual emergency promotions
maxConsecutiveEmergencyPromotions = 3 // Maximum consecutive emergency promotions before blocking
emergencyPromotionWindowCleanupAge = 60 * time.Second // Age at which emergency window entries are cleaned up
// Trust scoring constants
invalidSessionTrustScore = -1000 // Trust score for non-existent sessions
) )
var ( var (
@ -92,31 +80,30 @@ type TransferBlacklistEntry struct {
ExpiresAt time.Time ExpiresAt time.Time
} }
type SessionManager struct { // Broadcast throttling state (DoS protection)
mu sync.RWMutex var (
primaryPromotionLock sync.Mutex
primaryTimeout time.Duration
logger *zerolog.Logger
sessions map[string]*Session
nicknameIndex map[string]*Session
reconnectGrace map[string]time.Time
reconnectInfo map[string]*SessionData
transferBlacklist []TransferBlacklistEntry
queueOrder []string
primarySessionID string
lastPrimaryID string
maxSessions int
cleanupCancel context.CancelFunc
lastEmergencyPromotion time.Time
consecutiveEmergencyPromotions int
emergencyPromotionWindow []time.Time
emergencyWindowMutex sync.Mutex
lastBroadcast time.Time lastBroadcast time.Time
broadcastMutex sync.Mutex broadcastMutex sync.Mutex
broadcastQueue chan struct{} )
broadcastPending atomic.Bool
type SessionManager struct {
mu sync.RWMutex // 24 bytes - place first for better alignment
primaryTimeout time.Duration // 8 bytes
logger *zerolog.Logger // 8 bytes
sessions map[string]*Session // 8 bytes
nicknameIndex map[string]*Session // 8 bytes - O(1) nickname uniqueness lookups
reconnectGrace map[string]time.Time // 8 bytes
reconnectInfo map[string]*SessionData // 8 bytes
transferBlacklist []TransferBlacklistEntry // Prevent demoted sessions from immediate re-promotion
queueOrder []string // 24 bytes (slice header)
primarySessionID string // 16 bytes
lastPrimaryID string // 16 bytes
maxSessions int // 8 bytes
cleanupCancel context.CancelFunc // For stopping cleanup goroutine
// Emergency promotion tracking for safety
lastEmergencyPromotion time.Time
consecutiveEmergencyPromotions int
} }
// NewSessionManager creates a new session manager // NewSessionManager creates a new session manager
@ -154,13 +141,12 @@ func NewSessionManager(logger *zerolog.Logger) *SessionManager {
logger: logger, logger: logger,
maxSessions: maxSessions, maxSessions: maxSessions,
primaryTimeout: primaryTimeout, primaryTimeout: primaryTimeout,
broadcastQueue: make(chan struct{}, 100),
} }
// Start background cleanup of inactive sessions
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
sm.cleanupCancel = cancel sm.cleanupCancel = cancel
go sm.cleanupInactiveSessions(ctx) go sm.cleanupInactiveSessions(ctx)
go sm.broadcastWorker(ctx)
return sm return sm
} }
@ -189,26 +175,13 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe
sm.mu.Lock() sm.mu.Lock()
defer sm.mu.Unlock() defer sm.mu.Unlock()
nicknameReserved := false // Check nickname uniqueness using O(1) index (only for non-empty nicknames)
defer func() {
if r := recover(); r != nil {
if nicknameReserved && session.Nickname != "" {
if sm.nicknameIndex[session.Nickname] == session {
delete(sm.nicknameIndex, session.Nickname)
}
}
panic(r)
}
}()
if session.Nickname != "" { if session.Nickname != "" {
if existingSession, exists := sm.nicknameIndex[session.Nickname]; exists { if existingSession, exists := sm.nicknameIndex[session.Nickname]; exists {
if existingSession.ID != session.ID { if existingSession.ID != session.ID {
return fmt.Errorf("nickname '%s' is already in use by another session", session.Nickname) return fmt.Errorf("nickname '%s' is already in use by another session", session.Nickname)
} }
} }
sm.nicknameIndex[session.Nickname] = session
nicknameReserved = true
} }
wasWithinGracePeriod := false wasWithinGracePeriod := false
@ -226,6 +199,7 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe
// Check if a session with this ID already exists (reconnection) // Check if a session with this ID already exists (reconnection)
if existing, exists := sm.sessions[session.ID]; exists { if existing, exists := sm.sessions[session.ID]; exists {
// SECURITY: Verify identity matches to prevent session hijacking
if existing.Identity != session.Identity || existing.Source != session.Source { if existing.Identity != session.Identity || existing.Source != session.Source {
return fmt.Errorf("session ID already in use by different user (identity mismatch)") return fmt.Errorf("session ID already in use by different user (identity mismatch)")
} }
@ -241,8 +215,9 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe
existing.ControlChannel = session.ControlChannel existing.ControlChannel = session.ControlChannel
existing.RPCChannel = session.RPCChannel existing.RPCChannel = session.RPCChannel
existing.HidChannel = session.HidChannel existing.HidChannel = session.HidChannel
existing.LastActive = time.Now()
existing.flushCandidates = session.flushCandidates existing.flushCandidates = session.flushCandidates
// Preserve mode and nickname // Preserve existing mode and nickname
session.Mode = existing.Mode session.Mode = existing.Mode
session.Nickname = existing.Nickname session.Nickname = existing.Nickname
session.CreatedAt = existing.CreatedAt session.CreatedAt = existing.CreatedAt
@ -373,9 +348,11 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe
Int("totalSessions", len(sm.sessions)). Int("totalSessions", len(sm.sessions)).
Msg("Session added to manager") Msg("Session added to manager")
// Ensure session has auto-generated nickname if needed
sm.ensureNickname(session) sm.ensureNickname(session)
if !nicknameReserved && session.Nickname != "" { // Add to nickname index
if session.Nickname != "" {
sm.nicknameIndex[session.Nickname] = session sm.nicknameIndex[session.Nickname] = session
} }
@ -406,18 +383,13 @@ func (sm *SessionManager) RemoveSession(sessionID string) {
wasPrimary := session.Mode == SessionModePrimary wasPrimary := session.Mode == SessionModePrimary
delete(sm.sessions, sessionID) delete(sm.sessions, sessionID)
if session.Nickname != "" {
if sm.nicknameIndex[session.Nickname] == session {
delete(sm.nicknameIndex, session.Nickname)
}
}
sm.logger.Info(). sm.logger.Info().
Str("sessionID", sessionID). Str("sessionID", sessionID).
Bool("wasPrimary", wasPrimary). Bool("wasPrimary", wasPrimary).
Int("remainingSessions", len(sm.sessions)). Int("remainingSessions", len(sm.sessions)).
Msg("Session removed from manager") Msg("Session removed from manager")
// Remove from queue if present
sm.removeFromQueue(sessionID) sm.removeFromQueue(sessionID)
// Check if this session was marked for immediate removal (intentional logout) // Check if this session was marked for immediate removal (intentional logout)
@ -1091,10 +1063,9 @@ func (sm *SessionManager) validateSinglePrimary() {
} }
} }
// transferPrimaryRole is the centralized method for all primary role transfers
// It handles bidirectional blacklisting and logging consistently across all transfer types
func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transferType, context string) error { func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transferType, context string) error {
sm.primaryPromotionLock.Lock()
defer sm.primaryPromotionLock.Unlock()
// Validate sessions exist // Validate sessions exist
toSession, toExists := sm.sessions[toSessionID] toSession, toExists := sm.sessions[toSessionID]
if !toExists { if !toExists {
@ -1136,74 +1107,22 @@ func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transf
Msg("Demoted existing primary session") Msg("Demoted existing primary session")
} }
primaryCount := 0 // SECURITY: Before promoting, verify there are no other primary sessions
var existingPrimaryID string
for id, sess := range sm.sessions {
if sess.Mode == SessionModePrimary {
primaryCount++
if id != toSessionID {
existingPrimaryID = id
}
}
}
if primaryCount > 1 || (primaryCount == 1 && existingPrimaryID != "" && existingPrimaryID != sm.primarySessionID) {
sm.logger.Error().
Int("primaryCount", primaryCount).
Str("existingPrimaryID", existingPrimaryID).
Str("targetPromotionID", toSessionID).
Str("managerPrimaryID", sm.primarySessionID).
Str("transferType", transferType).
Msg("CRITICAL: Dual-primary corruption detected - forcing fix")
for id, sess := range sm.sessions {
if sess.Mode == SessionModePrimary {
if id != sm.primarySessionID && id != toSessionID {
sess.Mode = SessionModeObserver
sm.logger.Warn().
Str("demotedSessionID", id).
Msg("Force-demoted session due to dual-primary corruption")
}
}
}
if sm.primarySessionID != "" && sm.sessions[sm.primarySessionID] != nil {
if sm.sessions[sm.primarySessionID].Mode != SessionModePrimary {
sm.primarySessionID = ""
}
}
existingPrimaryID = ""
for id, sess := range sm.sessions { for id, sess := range sm.sessions {
if id != toSessionID && sess.Mode == SessionModePrimary { if id != toSessionID && sess.Mode == SessionModePrimary {
existingPrimaryID = id
break
}
}
if existingPrimaryID != "" {
sm.logger.Error(). sm.logger.Error().
Str("existingPrimaryID", existingPrimaryID). Str("existingPrimaryID", id).
Str("targetPromotionID", toSessionID).
Msg("CRITICAL: Cannot fix dual-primary corruption - blocking promotion")
return fmt.Errorf("cannot promote: dual-primary corruption detected and fix failed (%s)", existingPrimaryID)
}
} else if existingPrimaryID != "" {
sm.logger.Error().
Str("existingPrimaryID", existingPrimaryID).
Str("targetPromotionID", toSessionID). Str("targetPromotionID", toSessionID).
Str("transferType", transferType). Str("transferType", transferType).
Msg("CRITICAL: Attempted to create second primary - blocking promotion") Msg("CRITICAL: Attempted to create second primary - blocking promotion")
return fmt.Errorf("cannot promote: another primary session exists (%s)", existingPrimaryID) return fmt.Errorf("cannot promote: another primary session exists (%s)", id)
}
} }
// Promote target session // Promote target session
toSession.Mode = SessionModePrimary toSession.Mode = SessionModePrimary
toSession.hidRPCAvailable = false toSession.hidRPCAvailable = false // Force re-handshake
// Reset LastActive only for emergency promotions to prevent immediate re-timeout toSession.LastActive = time.Now() // Reset activity timestamp to prevent immediate timeout
if transferType == "emergency_timeout_promotion" || transferType == "emergency_promotion_deadlock_prevention" {
toSession.LastActive = time.Now()
}
sm.primarySessionID = toSessionID sm.primarySessionID = toSessionID
// ALWAYS set lastPrimaryID to the new primary to support WebRTC reconnections // ALWAYS set lastPrimaryID to the new primary to support WebRTC reconnections
@ -1278,7 +1197,7 @@ func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transf
// Send reconnection signal for emergency promotions via WebSocket (more reliable than RPC when channel is stale) // Send reconnection signal for emergency promotions via WebSocket (more reliable than RPC when channel is stale)
if toExists && (transferType == "emergency_timeout_promotion" || transferType == "emergency_auto_promotion") { if toExists && (transferType == "emergency_timeout_promotion" || transferType == "emergency_auto_promotion") {
go func() { go func() {
time.Sleep(globalBroadcastDelay) time.Sleep(100 * time.Millisecond)
eventData := map[string]interface{}{ eventData := map[string]interface{}{
"sessionId": toSessionID, "sessionId": toSessionID,
@ -1368,7 +1287,8 @@ func (sm *SessionManager) getCurrentPrimaryTimeout() time.Duration {
// Use session settings if available // Use session settings if available
if currentSessionSettings != nil { if currentSessionSettings != nil {
if currentSessionSettings.PrimaryTimeout == 0 { if currentSessionSettings.PrimaryTimeout == 0 {
return disabledTimeoutValue // 0 means disabled - return a very large duration
return 24 * time.Hour
} else if currentSessionSettings.PrimaryTimeout > 0 { } else if currentSessionSettings.PrimaryTimeout > 0 {
return time.Duration(currentSessionSettings.PrimaryTimeout) * time.Second return time.Duration(currentSessionSettings.PrimaryTimeout) * time.Second
} }
@ -1381,7 +1301,7 @@ func (sm *SessionManager) getCurrentPrimaryTimeout() time.Duration {
func (sm *SessionManager) getSessionTrustScore(sessionID string) int { func (sm *SessionManager) getSessionTrustScore(sessionID string) int {
session, exists := sm.sessions[sessionID] session, exists := sm.sessions[sessionID]
if !exists { if !exists {
return invalidSessionTrustScore return -1000 // Session doesn't exist
} }
score := 0 score := 0
@ -1427,7 +1347,9 @@ func (sm *SessionManager) findMostTrustedSessionForEmergency() string {
bestSessionID := "" bestSessionID := ""
bestScore := -1 bestScore := -1
// First pass: try to find observers or queued sessions (preferred)
for sessionID, session := range sm.sessions { for sessionID, session := range sm.sessions {
// Skip if blacklisted, primary, or not eligible modes
if sm.isSessionBlacklisted(sessionID) || if sm.isSessionBlacklisted(sessionID) ||
session.Mode == SessionModePrimary || session.Mode == SessionModePrimary ||
(session.Mode != SessionModeObserver && session.Mode != SessionModeQueued) { (session.Mode != SessionModeObserver && session.Mode != SessionModeQueued) {
@ -1441,6 +1363,24 @@ func (sm *SessionManager) findMostTrustedSessionForEmergency() string {
} }
} }
// If no observers/queued found, try pending sessions as last resort
if bestSessionID == "" {
for sessionID, session := range sm.sessions {
if sm.isSessionBlacklisted(sessionID) || session.Mode == SessionModePrimary {
continue
}
if session.Mode == SessionModePending {
score := sm.getSessionTrustScore(sessionID)
if score > bestScore {
bestScore = score
bestSessionID = sessionID
}
}
}
}
// Log the selection decision for audit trail
if bestSessionID != "" { if bestSessionID != "" {
sm.logger.Info(). sm.logger.Info().
Str("selectedSession", bestSessionID). Str("selectedSession", bestSessionID).
@ -1552,37 +1492,21 @@ func (sm *SessionManager) updateAllSessionNicknames() {
} }
} }
func (sm *SessionManager) broadcastWorker(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-sm.broadcastQueue:
sm.broadcastPending.Store(false)
sm.executeBroadcast()
}
}
}
func (sm *SessionManager) broadcastSessionListUpdate() { func (sm *SessionManager) broadcastSessionListUpdate() {
if sm.broadcastPending.CompareAndSwap(false, true) { // Throttle broadcasts to prevent DoS
select { broadcastMutex.Lock()
case sm.broadcastQueue <- struct{}{}: if time.Since(lastBroadcast) < globalBroadcastDelay {
default: broadcastMutex.Unlock()
} return // Skip this broadcast to prevent storm
}
} }
lastBroadcast = time.Now()
broadcastMutex.Unlock()
func (sm *SessionManager) executeBroadcast() { // Must be called in a goroutine to avoid deadlock
sm.broadcastMutex.Lock() // Get all sessions first - use read lock only, no validation during broadcasts
if time.Since(sm.lastBroadcast) < globalBroadcastDelay {
sm.broadcastMutex.Unlock()
return
}
sm.lastBroadcast = time.Now()
sm.broadcastMutex.Unlock()
sm.mu.RLock() sm.mu.RLock()
// Build session infos and collect active sessions in one pass
infos := make([]SessionData, 0, len(sm.sessions)) infos := make([]SessionData, 0, len(sm.sessions))
activeSessions := make([]*Session, 0, len(sm.sessions)) activeSessions := make([]*Session, 0, len(sm.sessions))
@ -1597,13 +1521,17 @@ func (sm *SessionManager) executeBroadcast() {
LastActive: session.LastActive, LastActive: session.LastActive,
}) })
// Only collect sessions ready for broadcast
if session.RPCChannel != nil { if session.RPCChannel != nil {
activeSessions = append(activeSessions, session) activeSessions = append(activeSessions, session)
} }
} }
sm.mu.RUnlock() sm.mu.RUnlock()
// Now send events without holding lock
for _, session := range activeSessions { for _, session := range activeSessions {
// Per-session throttling to prevent broadcast storms
session.lastBroadcastMu.Lock() session.lastBroadcastMu.Lock()
shouldSkip := time.Since(session.LastBroadcast) < sessionBroadcastDelay shouldSkip := time.Since(session.LastBroadcast) < sessionBroadcastDelay
if !shouldSkip { if !shouldSkip {
@ -1632,8 +1560,7 @@ func (sm *SessionManager) Shutdown() {
sm.mu.Lock() sm.mu.Lock()
defer sm.mu.Unlock() defer sm.mu.Unlock()
close(sm.broadcastQueue) // Clean up all sessions
for id := range sm.sessions { for id := range sm.sessions {
delete(sm.sessions, id) delete(sm.sessions, id)
} }
@ -1654,18 +1581,6 @@ func (sm *SessionManager) cleanupInactiveSessions(ctx context.Context) {
now := time.Now() now := time.Now()
needsBroadcast := false needsBroadcast := false
// Clean up expired emergency promotion window entries
sm.emergencyWindowMutex.Lock()
cutoff := now.Add(-emergencyPromotionWindowCleanupAge)
validEntries := make([]time.Time, 0, len(sm.emergencyPromotionWindow))
for _, t := range sm.emergencyPromotionWindow {
if t.After(cutoff) {
validEntries = append(validEntries, t)
}
}
sm.emergencyPromotionWindow = validEntries
sm.emergencyWindowMutex.Unlock()
// Handle expired grace periods // Handle expired grace periods
gracePeriodExpired := sm.handleGracePeriodExpiration(now) gracePeriodExpired := sm.handleGracePeriodExpiration(now)
if gracePeriodExpired { if gracePeriodExpired {

View File

@ -123,10 +123,7 @@ func (s *Session) sendWebSocketSignal(messageType string, data map[string]interf
return nil return nil
} }
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) err := wsjson.Write(context.Background(), s.ws, gin.H{"type": messageType, "data": data})
defer cancel()
err := wsjson.Write(ctx, s.ws, gin.H{"type": messageType, "data": data})
if err != nil { if err != nil {
webrtcLogger.Debug().Err(err).Str("sessionId", s.ID).Msg("Failed to send WebSocket signal") webrtcLogger.Debug().Err(err).Str("sessionId", s.ID).Msg("Failed to send WebSocket signal")
return err return err
@ -350,13 +347,7 @@ func newSession(config SessionConfig) (*Session, error) {
case "rpc": case "rpc":
session.RPCChannel = d session.RPCChannel = d
d.OnMessage(func(msg webrtc.DataChannelMessage) { d.OnMessage(func(msg webrtc.DataChannelMessage) {
queueLen := len(session.rpcQueue) // Enqueue to ensure ordered processing
if queueLen > 200 {
scopedLogger.Warn().
Str("sessionID", session.ID).
Int("queueLen", queueLen).
Msg("RPC queue approaching capacity")
}
session.rpcQueue <- msg session.rpcQueue <- msg
}) })
triggerOTAStateUpdate() triggerOTAStateUpdate()
@ -415,9 +406,7 @@ func newSession(config SessionConfig) (*Session, error) {
} }
candidateBufferMutex.Unlock() candidateBufferMutex.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) err := wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
defer cancel()
err := wsjson.Write(ctx, config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
if err != nil { if err != nil {
scopedLogger.Warn().Err(err).Msg("failed to write new-ice-candidate to WebRTC signaling channel") scopedLogger.Warn().Err(err).Msg("failed to write new-ice-candidate to WebRTC signaling channel")
} }
@ -430,9 +419,7 @@ func newSession(config SessionConfig) (*Session, error) {
answerSent = true answerSent = true
// Send all buffered candidates // Send all buffered candidates
for _, candidate := range candidateBuffer { for _, candidate := range candidateBuffer {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) err := wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate})
err := wsjson.Write(ctx, config.ws, gin.H{"type": "new-ice-candidate", "data": candidate})
cancel()
if err != nil { if err != nil {
scopedLogger.Warn().Err(err).Msg("failed to write buffered new-ice-candidate to WebRTC signaling channel") scopedLogger.Warn().Err(err).Msg("failed to write buffered new-ice-candidate to WebRTC signaling channel")
} }