mirror of https://github.com/jetkvm/kvm.git
Compare commits
5 Commits
9a10d3ed38
...
08b0dd0c37
| Author | SHA1 | Date |
|---|---|---|
|
|
08b0dd0c37 | |
|
|
f2431e9bbf | |
|
|
c9d8dcb553 | |
|
|
711f7818bf | |
|
|
40ccecc902 |
|
|
@ -16,6 +16,13 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) {
|
||||||
|
|
||||||
switch message.Type() {
|
switch message.Type() {
|
||||||
case hidrpc.TypeHandshake:
|
case hidrpc.TypeHandshake:
|
||||||
|
if !session.HasPermission(PermissionVideoView) {
|
||||||
|
logger.Debug().
|
||||||
|
Str("sessionID", session.ID).
|
||||||
|
Str("mode", string(session.Mode)).
|
||||||
|
Msg("handshake blocked: session lacks PermissionVideoView")
|
||||||
|
return
|
||||||
|
}
|
||||||
message, err := hidrpc.NewHandshakeMessage().Marshal()
|
message, err := hidrpc.NewHandshakeMessage().Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("failed to marshal handshake message")
|
logger.Warn().Err(err).Msg("failed to marshal handshake message")
|
||||||
|
|
|
||||||
14
jiggler.go
14
jiggler.go
|
|
@ -129,22 +129,18 @@ func runJiggler() {
|
||||||
}
|
}
|
||||||
inactivitySeconds := config.JigglerConfig.InactivityLimitSeconds
|
inactivitySeconds := config.JigglerConfig.InactivityLimitSeconds
|
||||||
timeSinceLastInput := time.Since(gadget.GetLastUserInputTime())
|
timeSinceLastInput := time.Since(gadget.GetLastUserInputTime())
|
||||||
|
logger.Debug().Msgf("Time since last user input %v", timeSinceLastInput)
|
||||||
if timeSinceLastInput > time.Duration(inactivitySeconds)*time.Second {
|
if timeSinceLastInput > time.Duration(inactivitySeconds)*time.Second {
|
||||||
err := gadget.RelMouseReport(1, 0, 0)
|
logger.Debug().Msg("Jiggling mouse...")
|
||||||
|
//TODO: change to rel mouse
|
||||||
|
err := rpcAbsMouseReport(1, 1, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Msgf("Failed to jiggle mouse: %v", err)
|
logger.Warn().Msgf("Failed to jiggle mouse: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(50 * time.Millisecond)
|
err = rpcAbsMouseReport(0, 0, 0)
|
||||||
err = gadget.RelMouseReport(-1, 0, 0)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Msgf("Failed to reset mouse position: %v", err)
|
logger.Warn().Msgf("Failed to reset mouse position: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if sessionManager != nil {
|
|
||||||
if primarySession := sessionManager.GetPrimarySession(); primarySession != nil {
|
|
||||||
sessionManager.UpdateLastActive(primarySession.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
45
jsonrpc.go
45
jsonrpc.go
|
|
@ -32,6 +32,32 @@ func isValidNickname(nickname string) bool {
|
||||||
return nicknameRegex.MatchString(nickname)
|
return nicknameRegex.MatchString(nickname)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Global RPC rate limiting (protects against coordinated DoS from multiple sessions)
|
||||||
|
var (
|
||||||
|
globalRPCRateLimitMu sync.Mutex
|
||||||
|
globalRPCRateLimit int
|
||||||
|
globalRPCRateLimitWin time.Time
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkGlobalRPCRateLimit() bool {
|
||||||
|
const (
|
||||||
|
maxGlobalRPCPerSecond = 2000
|
||||||
|
rateLimitWindow = time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
globalRPCRateLimitMu.Lock()
|
||||||
|
defer globalRPCRateLimitMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
if now.Sub(globalRPCRateLimitWin) > rateLimitWindow {
|
||||||
|
globalRPCRateLimit = 0
|
||||||
|
globalRPCRateLimitWin = now
|
||||||
|
}
|
||||||
|
|
||||||
|
globalRPCRateLimit++
|
||||||
|
return globalRPCRateLimit <= maxGlobalRPCPerSecond
|
||||||
|
}
|
||||||
|
|
||||||
type JSONRPCRequest struct {
|
type JSONRPCRequest struct {
|
||||||
JSONRPC string `json:"jsonrpc"`
|
JSONRPC string `json:"jsonrpc"`
|
||||||
Method string `json:"method"`
|
Method string `json:"method"`
|
||||||
|
|
@ -119,7 +145,24 @@ func broadcastJSONRPCEvent(event string, params any) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func onRPCMessage(message webrtc.DataChannelMessage, session *Session) {
|
func onRPCMessage(message webrtc.DataChannelMessage, session *Session) {
|
||||||
// Rate limit check (DoS protection)
|
// Global rate limit check (protects against coordinated DoS from multiple sessions)
|
||||||
|
if !checkGlobalRPCRateLimit() {
|
||||||
|
jsonRpcLogger.Warn().
|
||||||
|
Str("sessionId", session.ID).
|
||||||
|
Msg("Global RPC rate limit exceeded")
|
||||||
|
errorResponse := JSONRPCResponse{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
Error: map[string]any{
|
||||||
|
"code": -32000,
|
||||||
|
"message": "Global rate limit exceeded",
|
||||||
|
},
|
||||||
|
ID: 0,
|
||||||
|
}
|
||||||
|
writeJSONRPCResponse(errorResponse, session)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-session rate limit check (DoS protection)
|
||||||
if !session.CheckRPCRateLimit() {
|
if !session.CheckRPCRateLimit() {
|
||||||
jsonRpcLogger.Warn().
|
jsonRpcLogger.Warn().
|
||||||
Str("sessionId", session.ID).
|
Str("sessionId", session.ID).
|
||||||
|
|
|
||||||
|
|
@ -95,19 +95,43 @@ func handleRequestSessionApprovalRPC(session *Session) (any, error) {
|
||||||
return map[string]interface{}{"status": "requested"}, nil
|
return map[string]interface{}{"status": "requested"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleUpdateSessionNicknameRPC handles nickname updates for sessions
|
func validateNickname(nickname string) error {
|
||||||
|
if len(nickname) < 2 {
|
||||||
|
return errors.New("nickname must be at least 2 characters")
|
||||||
|
}
|
||||||
|
if len(nickname) > 30 {
|
||||||
|
return errors.New("nickname must be 30 characters or less")
|
||||||
|
}
|
||||||
|
if !isValidNickname(nickname) {
|
||||||
|
return errors.New("nickname can only contain letters, numbers, spaces, and - _ . @")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, r := range nickname {
|
||||||
|
if r < 32 || r == 127 {
|
||||||
|
return fmt.Errorf("nickname contains control character at position %d", i)
|
||||||
|
}
|
||||||
|
if r >= 0x200B && r <= 0x200D {
|
||||||
|
return errors.New("nickname contains zero-width character")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := ""
|
||||||
|
for _, r := range nickname {
|
||||||
|
trimmed += string(r)
|
||||||
|
}
|
||||||
|
if trimmed != nickname {
|
||||||
|
return errors.New("nickname contains disallowed unicode")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func handleUpdateSessionNicknameRPC(params map[string]any, session *Session) (any, error) {
|
func handleUpdateSessionNicknameRPC(params map[string]any, session *Session) (any, error) {
|
||||||
sessionID, _ := params["sessionId"].(string)
|
sessionID, _ := params["sessionId"].(string)
|
||||||
nickname, _ := params["nickname"].(string)
|
nickname, _ := params["nickname"].(string)
|
||||||
|
|
||||||
if len(nickname) < 2 {
|
if err := validateNickname(nickname); err != nil {
|
||||||
return nil, errors.New("nickname must be at least 2 characters")
|
return nil, err
|
||||||
}
|
|
||||||
if len(nickname) > 30 {
|
|
||||||
return nil, errors.New("nickname must be 30 characters or less")
|
|
||||||
}
|
|
||||||
if !isValidNickname(nickname) {
|
|
||||||
return nil, errors.New("nickname can only contain letters, numbers, spaces, and - _ . @")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
targetSession := sessionManager.GetSession(sessionID)
|
targetSession := sessionManager.GetSession(sessionID)
|
||||||
|
|
|
||||||
|
|
@ -22,30 +22,43 @@ func (sm *SessionManager) attemptEmergencyPromotion(ctx emergencyPromotionContex
|
||||||
return promotedID, false, false
|
return promotedID, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Emergency promotion path
|
sm.emergencyWindowMutex.Lock()
|
||||||
hasPrimary := sm.primarySessionID != ""
|
defer sm.emergencyWindowMutex.Unlock()
|
||||||
if !hasPrimary {
|
|
||||||
|
const slidingWindowDuration = 60 * time.Second
|
||||||
|
const maxEmergencyPromotionsPerMinute = 3
|
||||||
|
|
||||||
|
cutoff := ctx.now.Add(-slidingWindowDuration)
|
||||||
|
validEntries := make([]time.Time, 0, len(sm.emergencyPromotionWindow))
|
||||||
|
for _, t := range sm.emergencyPromotionWindow {
|
||||||
|
if t.After(cutoff) {
|
||||||
|
validEntries = append(validEntries, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sm.emergencyPromotionWindow = validEntries
|
||||||
|
|
||||||
|
if len(sm.emergencyPromotionWindow) >= maxEmergencyPromotionsPerMinute {
|
||||||
sm.logger.Error().
|
sm.logger.Error().
|
||||||
Str("triggerSessionID", ctx.triggerSessionID).
|
Str("triggerSessionID", ctx.triggerSessionID).
|
||||||
Msg("CRITICAL: No primary session exists - bypassing all rate limits")
|
Int("promotionsInLastMinute", len(sm.emergencyPromotionWindow)).
|
||||||
} else {
|
Msg("Emergency promotion rate limit exceeded - potential attack")
|
||||||
// Rate limiting (only when we have a primary)
|
return "", false, true
|
||||||
if ctx.now.Sub(sm.lastEmergencyPromotion) < 30*time.Second {
|
}
|
||||||
sm.logger.Warn().
|
|
||||||
Str("triggerSessionID", ctx.triggerSessionID).
|
|
||||||
Dur("timeSinceLastEmergency", ctx.now.Sub(sm.lastEmergencyPromotion)).
|
|
||||||
Msgf("Emergency promotion rate limit exceeded - potential attack (%s)", ctx.triggerReason)
|
|
||||||
return "", false, true // shouldSkip = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Limit consecutive emergency promotions
|
if ctx.now.Sub(sm.lastEmergencyPromotion) < 10*time.Second {
|
||||||
if sm.consecutiveEmergencyPromotions >= 3 {
|
sm.logger.Warn().
|
||||||
sm.logger.Error().
|
Str("triggerSessionID", ctx.triggerSessionID).
|
||||||
Str("triggerSessionID", ctx.triggerSessionID).
|
Dur("timeSinceLastEmergency", ctx.now.Sub(sm.lastEmergencyPromotion)).
|
||||||
Int("consecutiveCount", sm.consecutiveEmergencyPromotions).
|
Msg("Emergency promotion cooldown active")
|
||||||
Msgf("Too many consecutive emergency promotions - blocking for security (%s)", ctx.triggerReason)
|
return "", false, true
|
||||||
return "", false, true // shouldSkip = true
|
}
|
||||||
}
|
|
||||||
|
if sm.consecutiveEmergencyPromotions >= 3 {
|
||||||
|
sm.logger.Error().
|
||||||
|
Str("triggerSessionID", ctx.triggerSessionID).
|
||||||
|
Int("consecutiveCount", sm.consecutiveEmergencyPromotions).
|
||||||
|
Msg("Too many consecutive emergency promotions - blocking")
|
||||||
|
return "", false, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find best session for emergency promotion
|
// Find best session for emergency promotion
|
||||||
|
|
@ -123,6 +136,9 @@ func (sm *SessionManager) promoteAfterGraceExpiration(expiredSessionID string, n
|
||||||
reason := "grace_expiration_promotion"
|
reason := "grace_expiration_promotion"
|
||||||
if isEmergency {
|
if isEmergency {
|
||||||
reason = "emergency_promotion_deadlock_prevention"
|
reason = "emergency_promotion_deadlock_prevention"
|
||||||
|
sm.emergencyWindowMutex.Lock()
|
||||||
|
sm.emergencyPromotionWindow = append(sm.emergencyPromotionWindow, now)
|
||||||
|
sm.emergencyWindowMutex.Unlock()
|
||||||
sm.lastEmergencyPromotion = now
|
sm.lastEmergencyPromotion = now
|
||||||
sm.consecutiveEmergencyPromotions++
|
sm.consecutiveEmergencyPromotions++
|
||||||
|
|
||||||
|
|
@ -249,6 +265,9 @@ func (sm *SessionManager) handlePrimarySessionTimeout(now time.Time) bool {
|
||||||
reason := "timeout_promotion"
|
reason := "timeout_promotion"
|
||||||
if isEmergency {
|
if isEmergency {
|
||||||
reason = "emergency_timeout_promotion"
|
reason = "emergency_timeout_promotion"
|
||||||
|
sm.emergencyWindowMutex.Lock()
|
||||||
|
sm.emergencyPromotionWindow = append(sm.emergencyPromotionWindow, now)
|
||||||
|
sm.emergencyWindowMutex.Unlock()
|
||||||
sm.lastEmergencyPromotion = now
|
sm.lastEmergencyPromotion = now
|
||||||
sm.consecutiveEmergencyPromotions++
|
sm.consecutiveEmergencyPromotions++
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
@ -30,12 +31,23 @@ const (
|
||||||
// Session timeout defaults
|
// Session timeout defaults
|
||||||
defaultPendingSessionTimeout = 1 * time.Minute // Timeout for pending sessions (DoS protection)
|
defaultPendingSessionTimeout = 1 * time.Minute // Timeout for pending sessions (DoS protection)
|
||||||
defaultObserverSessionTimeout = 2 * time.Minute // Timeout for inactive observer sessions
|
defaultObserverSessionTimeout = 2 * time.Minute // Timeout for inactive observer sessions
|
||||||
|
disabledTimeoutValue = 24 * time.Hour // Value used when timeout is disabled (0 setting)
|
||||||
|
|
||||||
// Transfer and blacklist settings
|
// Transfer and blacklist settings
|
||||||
transferBlacklistDuration = 60 * time.Second // Duration to blacklist sessions after manual transfer
|
transferBlacklistDuration = 60 * time.Second // Duration to blacklist sessions after manual transfer
|
||||||
|
|
||||||
// Grace period limits
|
// Grace period limits
|
||||||
maxGracePeriodEntries = 10 // Maximum number of grace period entries to prevent memory exhaustion
|
maxGracePeriodEntries = 10 // Maximum number of grace period entries to prevent memory exhaustion
|
||||||
|
|
||||||
|
// Emergency promotion limits (DoS protection)
|
||||||
|
emergencyWindowDuration = 60 * time.Second // Sliding window duration for emergency promotion rate limiting
|
||||||
|
maxEmergencyPromotionsPerMinute = 3 // Maximum emergency promotions allowed within the sliding window
|
||||||
|
emergencyPromotionCooldown = 10 * time.Second // Minimum time between individual emergency promotions
|
||||||
|
maxConsecutiveEmergencyPromotions = 3 // Maximum consecutive emergency promotions before blocking
|
||||||
|
emergencyPromotionWindowCleanupAge = 60 * time.Second // Age at which emergency window entries are cleaned up
|
||||||
|
|
||||||
|
// Trust scoring constants
|
||||||
|
invalidSessionTrustScore = -1000 // Trust score for non-existent sessions
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
@ -80,30 +92,31 @@ type TransferBlacklistEntry struct {
|
||||||
ExpiresAt time.Time
|
ExpiresAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// Broadcast throttling state (DoS protection)
|
|
||||||
var (
|
|
||||||
lastBroadcast time.Time
|
|
||||||
broadcastMutex sync.Mutex
|
|
||||||
)
|
|
||||||
|
|
||||||
type SessionManager struct {
|
type SessionManager struct {
|
||||||
mu sync.RWMutex // 24 bytes - place first for better alignment
|
mu sync.RWMutex
|
||||||
primaryTimeout time.Duration // 8 bytes
|
primaryPromotionLock sync.Mutex
|
||||||
logger *zerolog.Logger // 8 bytes
|
primaryTimeout time.Duration
|
||||||
sessions map[string]*Session // 8 bytes
|
logger *zerolog.Logger
|
||||||
nicknameIndex map[string]*Session // 8 bytes - O(1) nickname uniqueness lookups
|
sessions map[string]*Session
|
||||||
reconnectGrace map[string]time.Time // 8 bytes
|
nicknameIndex map[string]*Session
|
||||||
reconnectInfo map[string]*SessionData // 8 bytes
|
reconnectGrace map[string]time.Time
|
||||||
transferBlacklist []TransferBlacklistEntry // Prevent demoted sessions from immediate re-promotion
|
reconnectInfo map[string]*SessionData
|
||||||
queueOrder []string // 24 bytes (slice header)
|
transferBlacklist []TransferBlacklistEntry
|
||||||
primarySessionID string // 16 bytes
|
queueOrder []string
|
||||||
lastPrimaryID string // 16 bytes
|
primarySessionID string
|
||||||
maxSessions int // 8 bytes
|
lastPrimaryID string
|
||||||
cleanupCancel context.CancelFunc // For stopping cleanup goroutine
|
maxSessions int
|
||||||
|
cleanupCancel context.CancelFunc
|
||||||
|
|
||||||
// Emergency promotion tracking for safety
|
|
||||||
lastEmergencyPromotion time.Time
|
lastEmergencyPromotion time.Time
|
||||||
consecutiveEmergencyPromotions int
|
consecutiveEmergencyPromotions int
|
||||||
|
emergencyPromotionWindow []time.Time
|
||||||
|
emergencyWindowMutex sync.Mutex
|
||||||
|
|
||||||
|
lastBroadcast time.Time
|
||||||
|
broadcastMutex sync.Mutex
|
||||||
|
broadcastQueue chan struct{}
|
||||||
|
broadcastPending atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSessionManager creates a new session manager
|
// NewSessionManager creates a new session manager
|
||||||
|
|
@ -141,12 +154,13 @@ func NewSessionManager(logger *zerolog.Logger) *SessionManager {
|
||||||
logger: logger,
|
logger: logger,
|
||||||
maxSessions: maxSessions,
|
maxSessions: maxSessions,
|
||||||
primaryTimeout: primaryTimeout,
|
primaryTimeout: primaryTimeout,
|
||||||
|
broadcastQueue: make(chan struct{}, 100),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start background cleanup of inactive sessions
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
sm.cleanupCancel = cancel
|
sm.cleanupCancel = cancel
|
||||||
go sm.cleanupInactiveSessions(ctx)
|
go sm.cleanupInactiveSessions(ctx)
|
||||||
|
go sm.broadcastWorker(ctx)
|
||||||
|
|
||||||
return sm
|
return sm
|
||||||
}
|
}
|
||||||
|
|
@ -175,13 +189,26 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe
|
||||||
sm.mu.Lock()
|
sm.mu.Lock()
|
||||||
defer sm.mu.Unlock()
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
// Check nickname uniqueness using O(1) index (only for non-empty nicknames)
|
nicknameReserved := false
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if nicknameReserved && session.Nickname != "" {
|
||||||
|
if sm.nicknameIndex[session.Nickname] == session {
|
||||||
|
delete(sm.nicknameIndex, session.Nickname)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if session.Nickname != "" {
|
if session.Nickname != "" {
|
||||||
if existingSession, exists := sm.nicknameIndex[session.Nickname]; exists {
|
if existingSession, exists := sm.nicknameIndex[session.Nickname]; exists {
|
||||||
if existingSession.ID != session.ID {
|
if existingSession.ID != session.ID {
|
||||||
return fmt.Errorf("nickname '%s' is already in use by another session", session.Nickname)
|
return fmt.Errorf("nickname '%s' is already in use by another session", session.Nickname)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
sm.nicknameIndex[session.Nickname] = session
|
||||||
|
nicknameReserved = true
|
||||||
}
|
}
|
||||||
|
|
||||||
wasWithinGracePeriod := false
|
wasWithinGracePeriod := false
|
||||||
|
|
@ -199,7 +226,6 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe
|
||||||
|
|
||||||
// Check if a session with this ID already exists (reconnection)
|
// Check if a session with this ID already exists (reconnection)
|
||||||
if existing, exists := sm.sessions[session.ID]; exists {
|
if existing, exists := sm.sessions[session.ID]; exists {
|
||||||
// SECURITY: Verify identity matches to prevent session hijacking
|
|
||||||
if existing.Identity != session.Identity || existing.Source != session.Source {
|
if existing.Identity != session.Identity || existing.Source != session.Source {
|
||||||
return fmt.Errorf("session ID already in use by different user (identity mismatch)")
|
return fmt.Errorf("session ID already in use by different user (identity mismatch)")
|
||||||
}
|
}
|
||||||
|
|
@ -215,9 +241,8 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe
|
||||||
existing.ControlChannel = session.ControlChannel
|
existing.ControlChannel = session.ControlChannel
|
||||||
existing.RPCChannel = session.RPCChannel
|
existing.RPCChannel = session.RPCChannel
|
||||||
existing.HidChannel = session.HidChannel
|
existing.HidChannel = session.HidChannel
|
||||||
existing.LastActive = time.Now()
|
|
||||||
existing.flushCandidates = session.flushCandidates
|
existing.flushCandidates = session.flushCandidates
|
||||||
// Preserve existing mode and nickname
|
// Preserve mode and nickname
|
||||||
session.Mode = existing.Mode
|
session.Mode = existing.Mode
|
||||||
session.Nickname = existing.Nickname
|
session.Nickname = existing.Nickname
|
||||||
session.CreatedAt = existing.CreatedAt
|
session.CreatedAt = existing.CreatedAt
|
||||||
|
|
@ -348,11 +373,9 @@ func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSe
|
||||||
Int("totalSessions", len(sm.sessions)).
|
Int("totalSessions", len(sm.sessions)).
|
||||||
Msg("Session added to manager")
|
Msg("Session added to manager")
|
||||||
|
|
||||||
// Ensure session has auto-generated nickname if needed
|
|
||||||
sm.ensureNickname(session)
|
sm.ensureNickname(session)
|
||||||
|
|
||||||
// Add to nickname index
|
if !nicknameReserved && session.Nickname != "" {
|
||||||
if session.Nickname != "" {
|
|
||||||
sm.nicknameIndex[session.Nickname] = session
|
sm.nicknameIndex[session.Nickname] = session
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -383,13 +406,18 @@ func (sm *SessionManager) RemoveSession(sessionID string) {
|
||||||
wasPrimary := session.Mode == SessionModePrimary
|
wasPrimary := session.Mode == SessionModePrimary
|
||||||
delete(sm.sessions, sessionID)
|
delete(sm.sessions, sessionID)
|
||||||
|
|
||||||
|
if session.Nickname != "" {
|
||||||
|
if sm.nicknameIndex[session.Nickname] == session {
|
||||||
|
delete(sm.nicknameIndex, session.Nickname)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sm.logger.Info().
|
sm.logger.Info().
|
||||||
Str("sessionID", sessionID).
|
Str("sessionID", sessionID).
|
||||||
Bool("wasPrimary", wasPrimary).
|
Bool("wasPrimary", wasPrimary).
|
||||||
Int("remainingSessions", len(sm.sessions)).
|
Int("remainingSessions", len(sm.sessions)).
|
||||||
Msg("Session removed from manager")
|
Msg("Session removed from manager")
|
||||||
|
|
||||||
// Remove from queue if present
|
|
||||||
sm.removeFromQueue(sessionID)
|
sm.removeFromQueue(sessionID)
|
||||||
|
|
||||||
// Check if this session was marked for immediate removal (intentional logout)
|
// Check if this session was marked for immediate removal (intentional logout)
|
||||||
|
|
@ -1063,9 +1091,10 @@ func (sm *SessionManager) validateSinglePrimary() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// transferPrimaryRole is the centralized method for all primary role transfers
|
|
||||||
// It handles bidirectional blacklisting and logging consistently across all transfer types
|
|
||||||
func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transferType, context string) error {
|
func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transferType, context string) error {
|
||||||
|
sm.primaryPromotionLock.Lock()
|
||||||
|
defer sm.primaryPromotionLock.Unlock()
|
||||||
|
|
||||||
// Validate sessions exist
|
// Validate sessions exist
|
||||||
toSession, toExists := sm.sessions[toSessionID]
|
toSession, toExists := sm.sessions[toSessionID]
|
||||||
if !toExists {
|
if !toExists {
|
||||||
|
|
@ -1107,22 +1136,74 @@ func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transf
|
||||||
Msg("Demoted existing primary session")
|
Msg("Demoted existing primary session")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SECURITY: Before promoting, verify there are no other primary sessions
|
primaryCount := 0
|
||||||
|
var existingPrimaryID string
|
||||||
for id, sess := range sm.sessions {
|
for id, sess := range sm.sessions {
|
||||||
if id != toSessionID && sess.Mode == SessionModePrimary {
|
if sess.Mode == SessionModePrimary {
|
||||||
sm.logger.Error().
|
primaryCount++
|
||||||
Str("existingPrimaryID", id).
|
if id != toSessionID {
|
||||||
Str("targetPromotionID", toSessionID).
|
existingPrimaryID = id
|
||||||
Str("transferType", transferType).
|
}
|
||||||
Msg("CRITICAL: Attempted to create second primary - blocking promotion")
|
|
||||||
return fmt.Errorf("cannot promote: another primary session exists (%s)", id)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if primaryCount > 1 || (primaryCount == 1 && existingPrimaryID != "" && existingPrimaryID != sm.primarySessionID) {
|
||||||
|
sm.logger.Error().
|
||||||
|
Int("primaryCount", primaryCount).
|
||||||
|
Str("existingPrimaryID", existingPrimaryID).
|
||||||
|
Str("targetPromotionID", toSessionID).
|
||||||
|
Str("managerPrimaryID", sm.primarySessionID).
|
||||||
|
Str("transferType", transferType).
|
||||||
|
Msg("CRITICAL: Dual-primary corruption detected - forcing fix")
|
||||||
|
|
||||||
|
for id, sess := range sm.sessions {
|
||||||
|
if sess.Mode == SessionModePrimary {
|
||||||
|
if id != sm.primarySessionID && id != toSessionID {
|
||||||
|
sess.Mode = SessionModeObserver
|
||||||
|
sm.logger.Warn().
|
||||||
|
Str("demotedSessionID", id).
|
||||||
|
Msg("Force-demoted session due to dual-primary corruption")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sm.primarySessionID != "" && sm.sessions[sm.primarySessionID] != nil {
|
||||||
|
if sm.sessions[sm.primarySessionID].Mode != SessionModePrimary {
|
||||||
|
sm.primarySessionID = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
existingPrimaryID = ""
|
||||||
|
for id, sess := range sm.sessions {
|
||||||
|
if id != toSessionID && sess.Mode == SessionModePrimary {
|
||||||
|
existingPrimaryID = id
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if existingPrimaryID != "" {
|
||||||
|
sm.logger.Error().
|
||||||
|
Str("existingPrimaryID", existingPrimaryID).
|
||||||
|
Str("targetPromotionID", toSessionID).
|
||||||
|
Msg("CRITICAL: Cannot fix dual-primary corruption - blocking promotion")
|
||||||
|
return fmt.Errorf("cannot promote: dual-primary corruption detected and fix failed (%s)", existingPrimaryID)
|
||||||
|
}
|
||||||
|
} else if existingPrimaryID != "" {
|
||||||
|
sm.logger.Error().
|
||||||
|
Str("existingPrimaryID", existingPrimaryID).
|
||||||
|
Str("targetPromotionID", toSessionID).
|
||||||
|
Str("transferType", transferType).
|
||||||
|
Msg("CRITICAL: Attempted to create second primary - blocking promotion")
|
||||||
|
return fmt.Errorf("cannot promote: another primary session exists (%s)", existingPrimaryID)
|
||||||
|
}
|
||||||
|
|
||||||
// Promote target session
|
// Promote target session
|
||||||
toSession.Mode = SessionModePrimary
|
toSession.Mode = SessionModePrimary
|
||||||
toSession.hidRPCAvailable = false // Force re-handshake
|
toSession.hidRPCAvailable = false
|
||||||
toSession.LastActive = time.Now() // Reset activity timestamp to prevent immediate timeout
|
// Reset LastActive only for emergency promotions to prevent immediate re-timeout
|
||||||
|
if transferType == "emergency_timeout_promotion" || transferType == "emergency_promotion_deadlock_prevention" {
|
||||||
|
toSession.LastActive = time.Now()
|
||||||
|
}
|
||||||
sm.primarySessionID = toSessionID
|
sm.primarySessionID = toSessionID
|
||||||
|
|
||||||
// ALWAYS set lastPrimaryID to the new primary to support WebRTC reconnections
|
// ALWAYS set lastPrimaryID to the new primary to support WebRTC reconnections
|
||||||
|
|
@ -1197,7 +1278,7 @@ func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transf
|
||||||
// Send reconnection signal for emergency promotions via WebSocket (more reliable than RPC when channel is stale)
|
// Send reconnection signal for emergency promotions via WebSocket (more reliable than RPC when channel is stale)
|
||||||
if toExists && (transferType == "emergency_timeout_promotion" || transferType == "emergency_auto_promotion") {
|
if toExists && (transferType == "emergency_timeout_promotion" || transferType == "emergency_auto_promotion") {
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(globalBroadcastDelay)
|
||||||
|
|
||||||
eventData := map[string]interface{}{
|
eventData := map[string]interface{}{
|
||||||
"sessionId": toSessionID,
|
"sessionId": toSessionID,
|
||||||
|
|
@ -1287,8 +1368,7 @@ func (sm *SessionManager) getCurrentPrimaryTimeout() time.Duration {
|
||||||
// Use session settings if available
|
// Use session settings if available
|
||||||
if currentSessionSettings != nil {
|
if currentSessionSettings != nil {
|
||||||
if currentSessionSettings.PrimaryTimeout == 0 {
|
if currentSessionSettings.PrimaryTimeout == 0 {
|
||||||
// 0 means disabled - return a very large duration
|
return disabledTimeoutValue
|
||||||
return 24 * time.Hour
|
|
||||||
} else if currentSessionSettings.PrimaryTimeout > 0 {
|
} else if currentSessionSettings.PrimaryTimeout > 0 {
|
||||||
return time.Duration(currentSessionSettings.PrimaryTimeout) * time.Second
|
return time.Duration(currentSessionSettings.PrimaryTimeout) * time.Second
|
||||||
}
|
}
|
||||||
|
|
@ -1301,7 +1381,7 @@ func (sm *SessionManager) getCurrentPrimaryTimeout() time.Duration {
|
||||||
func (sm *SessionManager) getSessionTrustScore(sessionID string) int {
|
func (sm *SessionManager) getSessionTrustScore(sessionID string) int {
|
||||||
session, exists := sm.sessions[sessionID]
|
session, exists := sm.sessions[sessionID]
|
||||||
if !exists {
|
if !exists {
|
||||||
return -1000 // Session doesn't exist
|
return invalidSessionTrustScore
|
||||||
}
|
}
|
||||||
|
|
||||||
score := 0
|
score := 0
|
||||||
|
|
@ -1347,9 +1427,7 @@ func (sm *SessionManager) findMostTrustedSessionForEmergency() string {
|
||||||
bestSessionID := ""
|
bestSessionID := ""
|
||||||
bestScore := -1
|
bestScore := -1
|
||||||
|
|
||||||
// First pass: try to find observers or queued sessions (preferred)
|
|
||||||
for sessionID, session := range sm.sessions {
|
for sessionID, session := range sm.sessions {
|
||||||
// Skip if blacklisted, primary, or not eligible modes
|
|
||||||
if sm.isSessionBlacklisted(sessionID) ||
|
if sm.isSessionBlacklisted(sessionID) ||
|
||||||
session.Mode == SessionModePrimary ||
|
session.Mode == SessionModePrimary ||
|
||||||
(session.Mode != SessionModeObserver && session.Mode != SessionModeQueued) {
|
(session.Mode != SessionModeObserver && session.Mode != SessionModeQueued) {
|
||||||
|
|
@ -1363,24 +1441,6 @@ func (sm *SessionManager) findMostTrustedSessionForEmergency() string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no observers/queued found, try pending sessions as last resort
|
|
||||||
if bestSessionID == "" {
|
|
||||||
for sessionID, session := range sm.sessions {
|
|
||||||
if sm.isSessionBlacklisted(sessionID) || session.Mode == SessionModePrimary {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if session.Mode == SessionModePending {
|
|
||||||
score := sm.getSessionTrustScore(sessionID)
|
|
||||||
if score > bestScore {
|
|
||||||
bestScore = score
|
|
||||||
bestSessionID = sessionID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log the selection decision for audit trail
|
|
||||||
if bestSessionID != "" {
|
if bestSessionID != "" {
|
||||||
sm.logger.Info().
|
sm.logger.Info().
|
||||||
Str("selectedSession", bestSessionID).
|
Str("selectedSession", bestSessionID).
|
||||||
|
|
@ -1492,21 +1552,37 @@ func (sm *SessionManager) updateAllSessionNicknames() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sm *SessionManager) broadcastSessionListUpdate() {
|
func (sm *SessionManager) broadcastWorker(ctx context.Context) {
|
||||||
// Throttle broadcasts to prevent DoS
|
for {
|
||||||
broadcastMutex.Lock()
|
select {
|
||||||
if time.Since(lastBroadcast) < globalBroadcastDelay {
|
case <-ctx.Done():
|
||||||
broadcastMutex.Unlock()
|
return
|
||||||
return // Skip this broadcast to prevent storm
|
case <-sm.broadcastQueue:
|
||||||
|
sm.broadcastPending.Store(false)
|
||||||
|
sm.executeBroadcast()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
lastBroadcast = time.Now()
|
}
|
||||||
broadcastMutex.Unlock()
|
|
||||||
|
func (sm *SessionManager) broadcastSessionListUpdate() {
|
||||||
|
if sm.broadcastPending.CompareAndSwap(false, true) {
|
||||||
|
select {
|
||||||
|
case sm.broadcastQueue <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *SessionManager) executeBroadcast() {
|
||||||
|
sm.broadcastMutex.Lock()
|
||||||
|
if time.Since(sm.lastBroadcast) < globalBroadcastDelay {
|
||||||
|
sm.broadcastMutex.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sm.lastBroadcast = time.Now()
|
||||||
|
sm.broadcastMutex.Unlock()
|
||||||
|
|
||||||
// Must be called in a goroutine to avoid deadlock
|
|
||||||
// Get all sessions first - use read lock only, no validation during broadcasts
|
|
||||||
sm.mu.RLock()
|
sm.mu.RLock()
|
||||||
|
|
||||||
// Build session infos and collect active sessions in one pass
|
|
||||||
infos := make([]SessionData, 0, len(sm.sessions))
|
infos := make([]SessionData, 0, len(sm.sessions))
|
||||||
activeSessions := make([]*Session, 0, len(sm.sessions))
|
activeSessions := make([]*Session, 0, len(sm.sessions))
|
||||||
|
|
||||||
|
|
@ -1521,17 +1597,13 @@ func (sm *SessionManager) broadcastSessionListUpdate() {
|
||||||
LastActive: session.LastActive,
|
LastActive: session.LastActive,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Only collect sessions ready for broadcast
|
|
||||||
if session.RPCChannel != nil {
|
if session.RPCChannel != nil {
|
||||||
activeSessions = append(activeSessions, session)
|
activeSessions = append(activeSessions, session)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sm.mu.RUnlock()
|
sm.mu.RUnlock()
|
||||||
|
|
||||||
// Now send events without holding lock
|
|
||||||
for _, session := range activeSessions {
|
for _, session := range activeSessions {
|
||||||
// Per-session throttling to prevent broadcast storms
|
|
||||||
session.lastBroadcastMu.Lock()
|
session.lastBroadcastMu.Lock()
|
||||||
shouldSkip := time.Since(session.LastBroadcast) < sessionBroadcastDelay
|
shouldSkip := time.Since(session.LastBroadcast) < sessionBroadcastDelay
|
||||||
if !shouldSkip {
|
if !shouldSkip {
|
||||||
|
|
@ -1560,7 +1632,8 @@ func (sm *SessionManager) Shutdown() {
|
||||||
sm.mu.Lock()
|
sm.mu.Lock()
|
||||||
defer sm.mu.Unlock()
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
// Clean up all sessions
|
close(sm.broadcastQueue)
|
||||||
|
|
||||||
for id := range sm.sessions {
|
for id := range sm.sessions {
|
||||||
delete(sm.sessions, id)
|
delete(sm.sessions, id)
|
||||||
}
|
}
|
||||||
|
|
@ -1581,6 +1654,18 @@ func (sm *SessionManager) cleanupInactiveSessions(ctx context.Context) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
needsBroadcast := false
|
needsBroadcast := false
|
||||||
|
|
||||||
|
// Clean up expired emergency promotion window entries
|
||||||
|
sm.emergencyWindowMutex.Lock()
|
||||||
|
cutoff := now.Add(-emergencyPromotionWindowCleanupAge)
|
||||||
|
validEntries := make([]time.Time, 0, len(sm.emergencyPromotionWindow))
|
||||||
|
for _, t := range sm.emergencyPromotionWindow {
|
||||||
|
if t.After(cutoff) {
|
||||||
|
validEntries = append(validEntries, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sm.emergencyPromotionWindow = validEntries
|
||||||
|
sm.emergencyWindowMutex.Unlock()
|
||||||
|
|
||||||
// Handle expired grace periods
|
// Handle expired grace periods
|
||||||
gracePeriodExpired := sm.handleGracePeriodExpiration(now)
|
gracePeriodExpired := sm.handleGracePeriodExpiration(now)
|
||||||
if gracePeriodExpired {
|
if gracePeriodExpired {
|
||||||
|
|
|
||||||
21
webrtc.go
21
webrtc.go
|
|
@ -123,7 +123,10 @@ func (s *Session) sendWebSocketSignal(messageType string, data map[string]interf
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := wsjson.Write(context.Background(), s.ws, gin.H{"type": messageType, "data": data})
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := wsjson.Write(ctx, s.ws, gin.H{"type": messageType, "data": data})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
webrtcLogger.Debug().Err(err).Str("sessionId", s.ID).Msg("Failed to send WebSocket signal")
|
webrtcLogger.Debug().Err(err).Str("sessionId", s.ID).Msg("Failed to send WebSocket signal")
|
||||||
return err
|
return err
|
||||||
|
|
@ -347,7 +350,13 @@ func newSession(config SessionConfig) (*Session, error) {
|
||||||
case "rpc":
|
case "rpc":
|
||||||
session.RPCChannel = d
|
session.RPCChannel = d
|
||||||
d.OnMessage(func(msg webrtc.DataChannelMessage) {
|
d.OnMessage(func(msg webrtc.DataChannelMessage) {
|
||||||
// Enqueue to ensure ordered processing
|
queueLen := len(session.rpcQueue)
|
||||||
|
if queueLen > 200 {
|
||||||
|
scopedLogger.Warn().
|
||||||
|
Str("sessionID", session.ID).
|
||||||
|
Int("queueLen", queueLen).
|
||||||
|
Msg("RPC queue approaching capacity")
|
||||||
|
}
|
||||||
session.rpcQueue <- msg
|
session.rpcQueue <- msg
|
||||||
})
|
})
|
||||||
triggerOTAStateUpdate()
|
triggerOTAStateUpdate()
|
||||||
|
|
@ -406,7 +415,9 @@ func newSession(config SessionConfig) (*Session, error) {
|
||||||
}
|
}
|
||||||
candidateBufferMutex.Unlock()
|
candidateBufferMutex.Unlock()
|
||||||
|
|
||||||
err := wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
err := wsjson.Write(ctx, config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
scopedLogger.Warn().Err(err).Msg("failed to write new-ice-candidate to WebRTC signaling channel")
|
scopedLogger.Warn().Err(err).Msg("failed to write new-ice-candidate to WebRTC signaling channel")
|
||||||
}
|
}
|
||||||
|
|
@ -419,7 +430,9 @@ func newSession(config SessionConfig) (*Session, error) {
|
||||||
answerSent = true
|
answerSent = true
|
||||||
// Send all buffered candidates
|
// Send all buffered candidates
|
||||||
for _, candidate := range candidateBuffer {
|
for _, candidate := range candidateBuffer {
|
||||||
err := wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate})
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
err := wsjson.Write(ctx, config.ws, gin.H{"type": "new-ice-candidate", "data": candidate})
|
||||||
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
scopedLogger.Warn().Err(err).Msg("failed to write buffered new-ice-candidate to WebRTC signaling channel")
|
scopedLogger.Warn().Err(err).Msg("failed to write buffered new-ice-candidate to WebRTC signaling channel")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue