diff --git a/block_device.go b/block_device.go index e4eab80..2274098 100644 --- a/block_device.go +++ b/block_device.go @@ -7,7 +7,6 @@ import ( "os" "time" - "github.com/pojntfx/go-nbd/pkg/client" "github.com/pojntfx/go-nbd/pkg/server" "github.com/rs/zerolog" ) @@ -149,30 +148,3 @@ func (d *NBDDevice) runServerConn() { d.l.Info().Err(err).Msg("nbd server exited") } - -func (d *NBDDevice) runClientConn() { - err := client.Connect(d.clientConn, d.dev, &client.Options{ - ExportName: "jetkvm", - BlockSize: uint32(4 * 1024), - }) - d.l.Info().Err(err).Msg("nbd client exited") -} - -func (d *NBDDevice) Close() { - if d.dev != nil { - err := client.Disconnect(d.dev) - if err != nil { - d.l.Warn().Err(err).Msg("error disconnecting nbd client") - } - _ = d.dev.Close() - } - if d.listener != nil { - _ = d.listener.Close() - } - if d.clientConn != nil { - _ = d.clientConn.Close() - } - if d.serverConn != nil { - _ = d.serverConn.Close() - } -} diff --git a/block_device_linux.go b/block_device_linux.go new file mode 100644 index 0000000..8ca9372 --- /dev/null +++ b/block_device_linux.go @@ -0,0 +1,34 @@ +//go:build linux + +package kvm + +import ( + "github.com/pojntfx/go-nbd/pkg/client" +) + +func (d *NBDDevice) runClientConn() { + err := client.Connect(d.clientConn, d.dev, &client.Options{ + ExportName: "jetkvm", + BlockSize: uint32(4 * 1024), + }) + d.l.Info().Err(err).Msg("nbd client exited") +} + +func (d *NBDDevice) Close() { + if d.dev != nil { + err := client.Disconnect(d.dev) + if err != nil { + d.l.Warn().Err(err).Msg("error disconnecting nbd client") + } + _ = d.dev.Close() + } + if d.listener != nil { + _ = d.listener.Close() + } + if d.clientConn != nil { + _ = d.clientConn.Close() + } + if d.serverConn != nil { + _ = d.serverConn.Close() + } +} diff --git a/block_device_notlinux.go b/block_device_notlinux.go new file mode 100644 index 0000000..b6a9aba --- /dev/null +++ b/block_device_notlinux.go @@ -0,0 +1,17 @@ +//go:build !linux + +package kvm + +import ( + "os" +) + +func (d *NBDDevice) runClientConn() { + d.l.Error().Msg("platform not supported") + os.Exit(1) +} + +func (d *NBDDevice) Close() { + d.l.Error().Msg("platform not supported") + os.Exit(1) +} diff --git a/cloud.go b/cloud.go index fd96c41..fb1998a 100644 --- a/cloud.go +++ b/cloud.go @@ -139,11 +139,40 @@ var ( ) ) +type CloudConnectionState uint8 + +const ( + CloudConnectionStateNotConfigured CloudConnectionState = iota + CloudConnectionStateDisconnected + CloudConnectionStateConnecting + CloudConnectionStateConnected +) + var ( + cloudConnectionState CloudConnectionState = CloudConnectionStateNotConfigured + cloudConnectionStateLock = &sync.Mutex{} + cloudDisconnectChan chan error cloudDisconnectLock = &sync.Mutex{} ) +func setCloudConnectionState(state CloudConnectionState) { + cloudConnectionStateLock.Lock() + defer cloudConnectionStateLock.Unlock() + + if cloudConnectionState == CloudConnectionStateDisconnected && + (config.CloudToken == "" || config.CloudURL == "") { + state = CloudConnectionStateNotConfigured + } + + previousState := cloudConnectionState + cloudConnectionState = state + + go waitCtrlAndRequestDisplayUpdate( + previousState != state, + ) +} + func wsResetMetrics(established bool, sourceType string, source string) { metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).Set(-1) metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(-1) @@ -285,6 +314,8 @@ func runWebsocketClient() error { wsURL.Scheme = "wss" } + setCloudConnectionState(CloudConnectionStateConnecting) + header := http.Header{} header.Set("X-Device-ID", GetDeviceID()) header.Set("X-App-Version", builtAppVersion) @@ -302,20 +333,26 @@ func runWebsocketClient() error { c, resp, err := websocket.Dial(dialCtx, wsURL.String(), &websocket.DialOptions{ HTTPHeader: header, OnPingReceived: func(ctx context.Context, payload []byte) bool { - scopedLogger.Info().Bytes("payload", payload).Int("length", len(payload)).Msg("ping frame received") + scopedLogger.Debug().Bytes("payload", payload).Int("length", len(payload)).Msg("ping frame received") metricConnectionTotalPingReceivedCount.WithLabelValues("cloud", wsURL.Host).Inc() metricConnectionLastPingReceivedTimestamp.WithLabelValues("cloud", wsURL.Host).SetToCurrentTime() + setCloudConnectionState(CloudConnectionStateConnected) + return true }, }) - // get the request id from the response header - connectionId := resp.Header.Get("X-Request-ID") - if connectionId == "" { - connectionId = resp.Header.Get("Cf-Ray") + var connectionId string + if resp != nil { + // get the request id from the response header + connectionId = resp.Header.Get("X-Request-ID") + if connectionId == "" { + connectionId = resp.Header.Get("Cf-Ray") + } } + if connectionId == "" { connectionId = uuid.New().String() scopedLogger.Warn(). @@ -332,6 +369,8 @@ func runWebsocketClient() error { if err != nil { if errors.Is(err, context.Canceled) { cloudLogger.Info().Msg("websocket connection canceled") + setCloudConnectionState(CloudConnectionStateDisconnected) + return nil } return err @@ -450,14 +489,14 @@ func RunWebsocketClient() { } // If the network is not up, well, we can't connect to the cloud. - if !networkState.Up { - cloudLogger.Warn().Msg("waiting for network to be up, will retry in 3 seconds") + if !networkState.IsOnline() { + cloudLogger.Warn().Msg("waiting for network to be online, will retry in 3 seconds") time.Sleep(3 * time.Second) continue } // If the system time is not synchronized, the API request will fail anyway because the TLS handshake will fail. - if isTimeSyncNeeded() && !timeSyncSuccess { + if isTimeSyncNeeded() && !timeSync.IsSyncSuccess() { cloudLogger.Warn().Msg("system time is not synced, will retry in 3 seconds") time.Sleep(3 * time.Second) continue @@ -520,6 +559,8 @@ func rpcDeregisterDevice() error { cloudLogger.Info().Msg("device deregistered, disconnecting from cloud") disconnectCloud(fmt.Errorf("device deregistered")) + setCloudConnectionState(CloudConnectionStateNotConfigured) + return nil } diff --git a/config.go b/config.go index cf096a7..23d4c84 100644 --- a/config.go +++ b/config.go @@ -6,6 +6,8 @@ import ( "os" "sync" + "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/network" "github.com/jetkvm/kvm/internal/usbgadget" ) @@ -73,27 +75,28 @@ func (m *KeyboardMacro) Validate() error { } type Config struct { - CloudURL string `json:"cloud_url"` - CloudAppURL string `json:"cloud_app_url"` - CloudToken string `json:"cloud_token"` - GoogleIdentity string `json:"google_identity"` - JigglerEnabled bool `json:"jiggler_enabled"` - AutoUpdateEnabled bool `json:"auto_update_enabled"` - IncludePreRelease bool `json:"include_pre_release"` - HashedPassword string `json:"hashed_password"` - LocalAuthToken string `json:"local_auth_token"` - LocalAuthMode string `json:"localAuthMode"` //TODO: fix it with migration - WakeOnLanDevices []WakeOnLanDevice `json:"wake_on_lan_devices"` - KeyboardMacros []KeyboardMacro `json:"keyboard_macros"` - EdidString string `json:"hdmi_edid_string"` - ActiveExtension string `json:"active_extension"` - DisplayMaxBrightness int `json:"display_max_brightness"` - DisplayDimAfterSec int `json:"display_dim_after_sec"` - DisplayOffAfterSec int `json:"display_off_after_sec"` - TLSMode string `json:"tls_mode"` // options: "self-signed", "user-defined", "" - UsbConfig *usbgadget.Config `json:"usb_config"` - UsbDevices *usbgadget.Devices `json:"usb_devices"` - DefaultLogLevel string `json:"default_log_level"` + CloudURL string `json:"cloud_url"` + CloudAppURL string `json:"cloud_app_url"` + CloudToken string `json:"cloud_token"` + GoogleIdentity string `json:"google_identity"` + JigglerEnabled bool `json:"jiggler_enabled"` + AutoUpdateEnabled bool `json:"auto_update_enabled"` + IncludePreRelease bool `json:"include_pre_release"` + HashedPassword string `json:"hashed_password"` + LocalAuthToken string `json:"local_auth_token"` + LocalAuthMode string `json:"localAuthMode"` //TODO: fix it with migration + WakeOnLanDevices []WakeOnLanDevice `json:"wake_on_lan_devices"` + KeyboardMacros []KeyboardMacro `json:"keyboard_macros"` + EdidString string `json:"hdmi_edid_string"` + ActiveExtension string `json:"active_extension"` + DisplayMaxBrightness int `json:"display_max_brightness"` + DisplayDimAfterSec int `json:"display_dim_after_sec"` + DisplayOffAfterSec int `json:"display_off_after_sec"` + TLSMode string `json:"tls_mode"` // options: "self-signed", "user-defined", "" + UsbConfig *usbgadget.Config `json:"usb_config"` + UsbDevices *usbgadget.Devices `json:"usb_devices"` + NetworkConfig *network.NetworkConfig `json:"network_config"` + DefaultLogLevel string `json:"default_log_level"` } const configPath = "/userdata/kvm_config.json" @@ -121,6 +124,7 @@ var defaultConfig = &Config{ Keyboard: true, MassStorage: true, }, + NetworkConfig: &network.NetworkConfig{}, DefaultLogLevel: "INFO", } @@ -134,7 +138,7 @@ func LoadConfig() { defer configLock.Unlock() if config != nil { - logger.Info().Msg("config already loaded, skipping") + logger.Debug().Msg("config already loaded, skipping") return } @@ -164,9 +168,15 @@ func LoadConfig() { loadedConfig.UsbDevices = defaultConfig.UsbDevices } + if loadedConfig.NetworkConfig == nil { + loadedConfig.NetworkConfig = defaultConfig.NetworkConfig + } + config = &loadedConfig - rootLogger.UpdateLogLevel() + logging.GetRootLogger().UpdateLogLevel(config.DefaultLogLevel) + + logger.Info().Str("path", configPath).Msg("config loaded") } func SaveConfig() error { diff --git a/dev_deploy.sh b/dev_deploy.sh index 02bbb24..d0ccaf2 100755 --- a/dev_deploy.sh +++ b/dev_deploy.sh @@ -24,6 +24,7 @@ show_help() { REMOTE_USER="root" REMOTE_PATH="/userdata/jetkvm/bin" SKIP_UI_BUILD=false +LOG_TRACE_SCOPES="${LOG_TRACE_SCOPES:-jetkvm,cloud,websocket,native,jsonrpc}" # Parse command line arguments while [[ $# -gt 0 ]]; do @@ -91,7 +92,7 @@ cd "${REMOTE_PATH}" chmod +x jetkvm_app_debug # Run the application in the background -PION_LOG_TRACE=jetkvm,cloud,websocket ./jetkvm_app_debug +PION_LOG_TRACE=${LOG_TRACE_SCOPES} ./jetkvm_app_debug EOF echo "Deployment complete." diff --git a/display.go b/display.go index cbe9ddd..e2e82e1 100644 --- a/display.go +++ b/display.go @@ -33,50 +33,153 @@ func switchToScreen(screen string) { var displayedTexts = make(map[string]string) +func lvObjSetState(objName string, state string) (*CtrlResponse, error) { + return CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": objName, "state": state}) +} + +func lvObjAddFlag(objName string, flag string) (*CtrlResponse, error) { + return CallCtrlAction("lv_obj_add_flag", map[string]interface{}{"obj": objName, "flag": flag}) +} + +func lvObjClearFlag(objName string, flag string) (*CtrlResponse, error) { + return CallCtrlAction("lv_obj_clear_flag", map[string]interface{}{"obj": objName, "flag": flag}) +} + +func lvObjHide(objName string) (*CtrlResponse, error) { + return lvObjAddFlag(objName, "LV_OBJ_FLAG_HIDDEN") +} + +func lvObjShow(objName string) (*CtrlResponse, error) { + return lvObjClearFlag(objName, "LV_OBJ_FLAG_HIDDEN") +} + +func lvObjSetOpacity(objName string, opacity int) (*CtrlResponse, error) { // nolint:unused + return CallCtrlAction("lv_obj_set_style_opa_layered", map[string]interface{}{"obj": objName, "opa": opacity}) +} + +func lvObjFadeIn(objName string, duration uint32) (*CtrlResponse, error) { + return CallCtrlAction("lv_obj_fade_in", map[string]interface{}{"obj": objName, "time": duration}) +} + +func lvObjFadeOut(objName string, duration uint32) (*CtrlResponse, error) { + return CallCtrlAction("lv_obj_fade_out", map[string]interface{}{"obj": objName, "time": duration}) +} + +func lvLabelSetText(objName string, text string) (*CtrlResponse, error) { + return CallCtrlAction("lv_label_set_text", map[string]interface{}{"obj": objName, "text": text}) +} + +func lvImgSetSrc(objName string, src string) (*CtrlResponse, error) { + return CallCtrlAction("lv_img_set_src", map[string]interface{}{"obj": objName, "src": src}) +} + func updateLabelIfChanged(objName string, newText string) { if newText != "" && newText != displayedTexts[objName] { - _, _ = CallCtrlAction("lv_label_set_text", map[string]interface{}{"obj": objName, "text": newText}) + _, _ = lvLabelSetText(objName, newText) displayedTexts[objName] = newText } } func switchToScreenIfDifferent(screenName string) { - displayLogger.Info().Str("from", currentScreen).Str("to", screenName).Msg("switching screen") if currentScreen != screenName { + displayLogger.Info().Str("from", currentScreen).Str("to", screenName).Msg("switching screen") switchToScreen(screenName) } } +var ( + cloudBlinkLock sync.Mutex = sync.Mutex{} + cloudBlinkStopped bool + cloudBlinkTicker *time.Ticker +) + func updateDisplay() { - updateLabelIfChanged("ui_Home_Content_Ip", networkState.IPv4) + updateLabelIfChanged("ui_Home_Content_Ip", networkState.IPv4String()) if usbState == "configured" { updateLabelIfChanged("ui_Home_Footer_Usb_Status_Label", "Connected") - _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Usb_Status_Label", "state": "LV_STATE_DEFAULT"}) + _, _ = lvObjSetState("ui_Home_Footer_Usb_Status_Label", "LV_STATE_DEFAULT") } else { updateLabelIfChanged("ui_Home_Footer_Usb_Status_Label", "Disconnected") - _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Usb_Status_Label", "state": "LV_STATE_USER_2"}) + _, _ = lvObjSetState("ui_Home_Footer_Usb_Status_Label", "LV_STATE_USER_2") } if lastVideoState.Ready { updateLabelIfChanged("ui_Home_Footer_Hdmi_Status_Label", "Connected") - _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Hdmi_Status_Label", "state": "LV_STATE_DEFAULT"}) + _, _ = lvObjSetState("ui_Home_Footer_Hdmi_Status_Label", "LV_STATE_DEFAULT") } else { updateLabelIfChanged("ui_Home_Footer_Hdmi_Status_Label", "Disconnected") - _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Hdmi_Status_Label", "state": "LV_STATE_USER_2"}) + _, _ = lvObjSetState("ui_Home_Footer_Hdmi_Status_Label", "LV_STATE_USER_2") } updateLabelIfChanged("ui_Home_Header_Cloud_Status_Label", fmt.Sprintf("%d active", actionSessions)) - if networkState.Up { + + if networkState.IsUp() { switchToScreenIfDifferent("ui_Home_Screen") } else { switchToScreenIfDifferent("ui_No_Network_Screen") } + + if cloudConnectionState == CloudConnectionStateNotConfigured { + _, _ = lvObjHide("ui_Home_Header_Cloud_Status_Icon") + } else { + _, _ = lvObjShow("ui_Home_Header_Cloud_Status_Icon") + } + + switch cloudConnectionState { + case CloudConnectionStateDisconnected: + _, _ = lvImgSetSrc("ui_Home_Header_Cloud_Status_Icon", "cloud_disconnected.png") + stopCloudBlink() + case CloudConnectionStateConnecting: + _, _ = lvImgSetSrc("ui_Home_Header_Cloud_Status_Icon", "cloud.png") + startCloudBlink() + case CloudConnectionStateConnected: + _, _ = lvImgSetSrc("ui_Home_Header_Cloud_Status_Icon", "cloud.png") + stopCloudBlink() + } +} + +func startCloudBlink() { + if cloudBlinkTicker == nil { + cloudBlinkTicker = time.NewTicker(2 * time.Second) + } else { + // do nothing if the blink isn't stopped + if cloudBlinkStopped { + cloudBlinkLock.Lock() + defer cloudBlinkLock.Unlock() + + cloudBlinkStopped = false + cloudBlinkTicker.Reset(2 * time.Second) + } + } + + go func() { + for range cloudBlinkTicker.C { + if cloudConnectionState != CloudConnectionStateConnecting { + continue + } + _, _ = lvObjFadeOut("ui_Home_Header_Cloud_Status_Icon", 1000) + time.Sleep(1000 * time.Millisecond) + _, _ = lvObjFadeIn("ui_Home_Header_Cloud_Status_Icon", 1000) + time.Sleep(1000 * time.Millisecond) + } + }() +} + +func stopCloudBlink() { + if cloudBlinkTicker != nil { + cloudBlinkTicker.Stop() + } + + cloudBlinkLock.Lock() + defer cloudBlinkLock.Unlock() + cloudBlinkStopped = true } var ( displayInited = false displayUpdateLock = sync.Mutex{} + waitDisplayUpdate = sync.Mutex{} ) -func requestDisplayUpdate() { +func requestDisplayUpdate(shouldWakeDisplay bool) { displayUpdateLock.Lock() defer displayUpdateLock.Unlock() @@ -85,16 +188,26 @@ func requestDisplayUpdate() { return } go func() { - wakeDisplay(false) - displayLogger.Info().Msg("display updating") + if shouldWakeDisplay { + wakeDisplay(false) + } + displayLogger.Debug().Msg("display updating") //TODO: only run once regardless how many pending updates updateDisplay() }() } +func waitCtrlAndRequestDisplayUpdate(shouldWakeDisplay bool) { + waitDisplayUpdate.Lock() + defer waitDisplayUpdate.Unlock() + + waitCtrlClientConnected() + requestDisplayUpdate(shouldWakeDisplay) +} + func updateStaticContents() { //contents that never change - updateLabelIfChanged("ui_Home_Content_Mac", networkState.MAC) + updateLabelIfChanged("ui_Home_Content_Mac", networkState.MACString()) systemVersion, appVersion, err := GetLocalVersion() if err == nil { updateLabelIfChanged("ui_About_Content_Operating_System_Version_ContentLabel", systemVersion.String()) @@ -265,7 +378,7 @@ func init() { displayLogger.Info().Msg("display inited") startBacklightTickers() wakeDisplay(true) - requestDisplayUpdate() + requestDisplayUpdate(true) }() go watchTsEvents() diff --git a/go.mod b/go.mod index 1311a33..6784a59 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/coder/websocket v1.8.13 github.com/coreos/go-oidc/v3 v3.11.0 github.com/creack/pty v1.1.23 + github.com/fsnotify/fsnotify v1.9.0 github.com/gin-contrib/logger v1.2.5 github.com/gin-gonic/gin v1.10.0 github.com/google/uuid v1.6.0 @@ -44,6 +45,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.26.0 // indirect github.com/goccy/go-json v0.10.5 // indirect + github.com/guregu/null/v6 v6.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect diff --git a/go.sum b/go.sum index 565c0cc..3ad832a 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/creack/pty v1.1.23/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfv github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= github.com/gin-contrib/logger v1.2.5 h1:qVQI4omayQecuN4zX9ZZnsOq7w9J/ZLds3J/FMn8ypM= @@ -54,6 +56,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/guregu/null/v6 v6.0.0 h1:N14VRS+4di81i1PXRiprbQJ9EM9gqBa0+KVMeS/QSjQ= +github.com/guregu/null/v6 v6.0.0/go.mod h1:hrMIhIfrOZeLPZhROSn149tpw2gHkidAqxoXNyeX3iQ= github.com/gwatts/rootcerts v0.0.0-20240401182218-3ab9db955caf h1:JO6ISZIvEUitto5zjQ3/VEnDM5rPbqIFuOhS0U0ByeA= github.com/gwatts/rootcerts v0.0.0-20240401182218-3ab9db955caf/go.mod h1:5Kt9XkWvkGi2OHOq0QsGxebHmhCcqJ8KCbNg/a6+n+g= github.com/hanwen/go-fuse/v2 v2.5.1 h1:OQBE8zVemSocRxA4OaFJbjJ5hlpCmIWbGr7r0M4uoQQ= diff --git a/hw.go b/hw.go index 21bffad..20d88eb 100644 --- a/hw.go +++ b/hw.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "regexp" + "strings" "sync" "time" ) @@ -51,6 +52,15 @@ func GetDeviceID() string { return deviceID } +func GetDefaultHostname() string { + deviceId := GetDeviceID() + if deviceId == "unknown_device_id" { + return "jetkvm" + } + + return fmt.Sprintf("jetkvm-%s", strings.ToLower(deviceId)) +} + func runWatchdog() { file, err := os.OpenFile("/dev/watchdog", os.O_WRONLY, 0) if err != nil { diff --git a/internal/confparser/confparser.go b/internal/confparser/confparser.go new file mode 100644 index 0000000..76102a3 --- /dev/null +++ b/internal/confparser/confparser.go @@ -0,0 +1,381 @@ +package confparser + +import ( + "fmt" + "net" + "reflect" + "slices" + "strconv" + "strings" + + "github.com/guregu/null/v6" + "golang.org/x/net/idna" +) + +type FieldConfig struct { + Name string + Required bool + RequiredIf map[string]interface{} + OneOf []string + ValidateTypes []string + Defaults interface{} + IsEmpty bool + CurrentValue interface{} + TypeString string + Delegated bool + shouldUpdateValue bool +} + +func SetDefaultsAndValidate(config interface{}) error { + return setDefaultsAndValidate(config, true) +} + +func setDefaultsAndValidate(config interface{}, isRoot bool) error { + // first we need to check if the config is a pointer + if reflect.TypeOf(config).Kind() != reflect.Ptr { + return fmt.Errorf("config is not a pointer") + } + + // now iterate over the lease struct and set the values + configType := reflect.TypeOf(config).Elem() + configValue := reflect.ValueOf(config).Elem() + + fields := make(map[string]FieldConfig) + + for i := 0; i < configType.NumField(); i++ { + field := configType.Field(i) + fieldValue := configValue.Field(i) + + defaultValue := field.Tag.Get("default") + + fieldType := field.Type.String() + + fieldConfig := FieldConfig{ + Name: field.Name, + OneOf: splitString(field.Tag.Get("one_of")), + ValidateTypes: splitString(field.Tag.Get("validate_type")), + RequiredIf: make(map[string]interface{}), + CurrentValue: fieldValue.Interface(), + IsEmpty: false, + TypeString: fieldType, + } + + // check if the field is required + required := field.Tag.Get("required") + if required != "" { + requiredBool, _ := strconv.ParseBool(required) + fieldConfig.Required = requiredBool + } + + var canUseOneOff = false + + // use switch to get the type + switch fieldValue.Interface().(type) { + case string, null.String: + if defaultValue != "" { + fieldConfig.Defaults = defaultValue + } + canUseOneOff = true + case []string: + if defaultValue != "" { + fieldConfig.Defaults = strings.Split(defaultValue, ",") + } + canUseOneOff = true + case int, null.Int: + if defaultValue != "" { + defaultValueInt, err := strconv.Atoi(defaultValue) + if err != nil { + return fmt.Errorf("invalid default value for field `%s`: %s", field.Name, defaultValue) + } + + fieldConfig.Defaults = defaultValueInt + } + case bool, null.Bool: + if defaultValue != "" { + defaultValueBool, err := strconv.ParseBool(defaultValue) + if err != nil { + return fmt.Errorf("invalid default value for field `%s`: %s", field.Name, defaultValue) + } + + fieldConfig.Defaults = defaultValueBool + } + default: + if defaultValue != "" { + return fmt.Errorf("field `%s` cannot use default value: unsupported type: %s", field.Name, fieldType) + } + + // check if it's a pointer + if fieldValue.Kind() == reflect.Ptr { + // check if the pointer is nil + if fieldValue.IsNil() { + fieldConfig.IsEmpty = true + } else { + fieldConfig.CurrentValue = fieldValue.Elem().Addr() + fieldConfig.Delegated = true + } + } else { + fieldConfig.Delegated = true + } + } + + // now check if the field is nullable interface + switch fieldValue.Interface().(type) { + case null.String: + if fieldValue.Interface().(null.String).IsZero() { + fieldConfig.IsEmpty = true + } + case null.Int: + if fieldValue.Interface().(null.Int).IsZero() { + fieldConfig.IsEmpty = true + } + case null.Bool: + if fieldValue.Interface().(null.Bool).IsZero() { + fieldConfig.IsEmpty = true + } + case []string: + if len(fieldValue.Interface().([]string)) == 0 { + fieldConfig.IsEmpty = true + } + } + + // now check if the field has required_if + requiredIf := field.Tag.Get("required_if") + if requiredIf != "" { + requiredIfParts := strings.Split(requiredIf, ",") + for _, part := range requiredIfParts { + partVal := strings.SplitN(part, "=", 2) + if len(partVal) != 2 { + return fmt.Errorf("invalid required_if for field `%s`: %s", field.Name, requiredIf) + } + + fieldConfig.RequiredIf[partVal[0]] = partVal[1] + } + } + + // check if the field can use one_of + if !canUseOneOff && len(fieldConfig.OneOf) > 0 { + return fmt.Errorf("field `%s` cannot use one_of: unsupported type: %s", field.Name, fieldType) + } + + fields[field.Name] = fieldConfig + } + + if err := validateFields(config, fields); err != nil { + return err + } + + return nil +} + +func validateFields(config interface{}, fields map[string]FieldConfig) error { + // now we can start to validate the fields + for _, fieldConfig := range fields { + if err := fieldConfig.validate(fields); err != nil { + return err + } + + fieldConfig.populate(config) + } + + return nil +} + +func (f *FieldConfig) validate(fields map[string]FieldConfig) error { + var required bool + var err error + + if required, err = f.validateRequired(fields); err != nil { + return err + } + + // check if the field needs to be updated and set defaults if needed + if err := f.checkIfFieldNeedsUpdate(); err != nil { + return err + } + + // then we can check if the field is one_of + if err := f.validateOneOf(); err != nil { + return err + } + + // and validate the type + if err := f.validateField(); err != nil { + return err + } + + // if the field is delegated, we need to validate the nested field + // but before that, let's check if the field is required + if required && f.Delegated { + if err := setDefaultsAndValidate(f.CurrentValue.(reflect.Value).Interface(), false); err != nil { + return err + } + } + + return nil +} + +func (f *FieldConfig) populate(config interface{}) { + // update the field if it's not empty + if !f.shouldUpdateValue { + return + } + + reflect.ValueOf(config).Elem().FieldByName(f.Name).Set(reflect.ValueOf(f.CurrentValue)) +} + +func (f *FieldConfig) checkIfFieldNeedsUpdate() error { + // populate the field if it's empty and has a default value + if f.IsEmpty && f.Defaults != nil { + switch f.CurrentValue.(type) { + case null.String: + f.CurrentValue = null.StringFrom(f.Defaults.(string)) + case null.Int: + f.CurrentValue = null.IntFrom(int64(f.Defaults.(int))) + case null.Bool: + f.CurrentValue = null.BoolFrom(f.Defaults.(bool)) + case string: + f.CurrentValue = f.Defaults.(string) + case int: + f.CurrentValue = f.Defaults.(int) + case bool: + f.CurrentValue = f.Defaults.(bool) + case []string: + f.CurrentValue = f.Defaults.([]string) + default: + return fmt.Errorf("field `%s` cannot use default value: unsupported type: %s", f.Name, f.TypeString) + } + + f.shouldUpdateValue = true + } + + return nil +} + +func (f *FieldConfig) validateRequired(fields map[string]FieldConfig) (bool, error) { + var required = f.Required + + // if the field is not required, we need to check if it's required_if + if !required && len(f.RequiredIf) > 0 { + for key, value := range f.RequiredIf { + // check if the field's result matches the required_if + // right now we only support string and int + requiredField, ok := fields[key] + if !ok { + return required, fmt.Errorf("required_if field `%s` not found", key) + } + + switch requiredField.CurrentValue.(type) { + case string: + if requiredField.CurrentValue.(string) == value.(string) { + required = true + } + case int: + if requiredField.CurrentValue.(int) == value.(int) { + required = true + } + case null.String: + if !requiredField.CurrentValue.(null.String).IsZero() && + requiredField.CurrentValue.(null.String).String == value.(string) { + required = true + } + case null.Int: + if !requiredField.CurrentValue.(null.Int).IsZero() && + requiredField.CurrentValue.(null.Int).Int64 == value.(int64) { + required = true + } + } + + // if the field is required, we can break the loop + // because we only need one of the required_if fields to be true + if required { + break + } + } + } + + if required && f.IsEmpty { + return false, fmt.Errorf("field `%s` is required", f.Name) + } + + return required, nil +} + +func checkIfSliceContains(slice []string, one_of []string) bool { + for _, oneOf := range one_of { + if slices.Contains(slice, oneOf) { + return true + } + } + + return false +} + +func (f *FieldConfig) validateOneOf() error { + if len(f.OneOf) == 0 { + return nil + } + + var val []string + switch f.CurrentValue.(type) { + case string: + val = []string{f.CurrentValue.(string)} + case null.String: + val = []string{f.CurrentValue.(null.String).String} + case []string: + // let's validate the value here + val = f.CurrentValue.([]string) + default: + return fmt.Errorf("field `%s` cannot use one_of: unsupported type: %s", f.Name, f.TypeString) + } + + if !checkIfSliceContains(val, f.OneOf) { + return fmt.Errorf( + "field `%s` is not one of the allowed values: %s, current value: %s", + f.Name, + strings.Join(f.OneOf, ", "), + strings.Join(val, ", "), + ) + } + + return nil +} + +func (f *FieldConfig) validateField() error { + if len(f.ValidateTypes) == 0 || f.IsEmpty { + return nil + } + + val, err := toString(f.CurrentValue) + if err != nil { + return fmt.Errorf("field `%s` cannot use validate_type: %s", f.Name, err) + } + + if val == "" { + return nil + } + + for _, validateType := range f.ValidateTypes { + switch validateType { + case "ipv4": + if net.ParseIP(val).To4() == nil { + return fmt.Errorf("field `%s` is not a valid IPv4 address: %s", f.Name, val) + } + case "ipv6": + if net.ParseIP(val).To16() == nil { + return fmt.Errorf("field `%s` is not a valid IPv6 address: %s", f.Name, val) + } + case "hwaddr": + if _, err := net.ParseMAC(val); err != nil { + return fmt.Errorf("field `%s` is not a valid MAC address: %s", f.Name, val) + } + case "hostname": + if _, err := idna.Lookup.ToASCII(val); err != nil { + return fmt.Errorf("field `%s` is not a valid hostname: %s", f.Name, val) + } + default: + return fmt.Errorf("field `%s` cannot use validate_type: unsupported validator: %s", f.Name, validateType) + } + } + + return nil +} diff --git a/internal/confparser/confparser_test.go b/internal/confparser/confparser_test.go new file mode 100644 index 0000000..dd5e00a --- /dev/null +++ b/internal/confparser/confparser_test.go @@ -0,0 +1,100 @@ +package confparser + +import ( + "net" + "testing" + "time" + + "github.com/guregu/null/v6" +) + +type testIPv6Address struct { //nolint:unused + Address net.IP `json:"address"` + Prefix net.IPNet `json:"prefix"` + ValidLifetime *time.Time `json:"valid_lifetime"` + PreferredLifetime *time.Time `json:"preferred_lifetime"` + Scope int `json:"scope"` +} + +type testIPv4StaticConfig struct { + Address null.String `json:"address" validate_type:"ipv4" required:"true"` + Netmask null.String `json:"netmask" validate_type:"ipv4" required:"true"` + Gateway null.String `json:"gateway" validate_type:"ipv4" required:"true"` + DNS []string `json:"dns" validate_type:"ipv4" required:"true"` +} + +type testIPv6StaticConfig struct { + Address null.String `json:"address" validate_type:"ipv6" required:"true"` + Prefix null.String `json:"prefix" validate_type:"ipv6" required:"true"` + Gateway null.String `json:"gateway" validate_type:"ipv6" required:"true"` + DNS []string `json:"dns" validate_type:"ipv6" required:"true"` +} +type testNetworkConfig struct { + Hostname null.String `json:"hostname,omitempty"` + Domain null.String `json:"domain,omitempty"` + + IPv4Mode null.String `json:"ipv4_mode" one_of:"dhcp,static,disabled" default:"dhcp"` + IPv4Static *testIPv4StaticConfig `json:"ipv4_static,omitempty" required_if:"IPv4Mode=static"` + + IPv6Mode null.String `json:"ipv6_mode" one_of:"slaac,dhcpv6,slaac_and_dhcpv6,static,link_local,disabled" default:"slaac"` + IPv6Static *testIPv6StaticConfig `json:"ipv6_static,omitempty" required_if:"IPv6Mode=static"` + + LLDPMode null.String `json:"lldp_mode,omitempty" one_of:"disabled,basic,all" default:"basic"` + LLDPTxTLVs []string `json:"lldp_tx_tlvs,omitempty" one_of:"chassis,port,system,vlan" default:"chassis,port,system,vlan"` + MDNSMode null.String `json:"mdns_mode,omitempty" one_of:"disabled,auto,ipv4_only,ipv6_only" default:"auto"` + TimeSyncMode null.String `json:"time_sync_mode,omitempty" one_of:"ntp_only,ntp_and_http,http_only,custom" default:"ntp_and_http"` + TimeSyncOrdering []string `json:"time_sync_ordering,omitempty" one_of:"http,ntp,ntp_dhcp,ntp_user_provided,ntp_fallback" default:"ntp,http"` + TimeSyncDisableFallback null.Bool `json:"time_sync_disable_fallback,omitempty" default:"false"` + TimeSyncParallel null.Int `json:"time_sync_parallel,omitempty" default:"4"` +} + +func TestValidateConfig(t *testing.T) { + config := &testNetworkConfig{} + + err := SetDefaultsAndValidate(config) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestValidateIPv4StaticConfigRequired(t *testing.T) { + config := &testNetworkConfig{ + IPv4Static: &testIPv4StaticConfig{ + Address: null.StringFrom("192.168.1.1"), + Gateway: null.StringFrom("192.168.1.1"), + }, + } + + err := SetDefaultsAndValidate(config) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestValidateIPv4StaticConfigRequiredIf(t *testing.T) { + config := &testNetworkConfig{ + IPv4Mode: null.StringFrom("static"), + } + + err := SetDefaultsAndValidate(config) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestValidateIPv4StaticConfigValidateType(t *testing.T) { + config := &testNetworkConfig{ + IPv4Static: &testIPv4StaticConfig{ + Address: null.StringFrom("X"), + Netmask: null.StringFrom("255.255.255.0"), + Gateway: null.StringFrom("192.168.1.1"), + DNS: []string{"8.8.8.8", "8.8.4.4"}, + }, + IPv4Mode: null.StringFrom("static"), + } + + err := SetDefaultsAndValidate(config) + if err == nil { + t.Fatalf("expected error, got nil") + } +} diff --git a/internal/confparser/utils.go b/internal/confparser/utils.go new file mode 100644 index 0000000..a46871e --- /dev/null +++ b/internal/confparser/utils.go @@ -0,0 +1,28 @@ +package confparser + +import ( + "fmt" + "reflect" + "strings" + + "github.com/guregu/null/v6" +) + +func splitString(s string) []string { + if s == "" { + return []string{} + } + + return strings.Split(s, ",") +} + +func toString(v interface{}) (string, error) { + switch v := v.(type) { + case string: + return v, nil + case null.String: + return v.String, nil + } + + return "", fmt.Errorf("unsupported type: %s", reflect.TypeOf(v)) +} diff --git a/internal/logging/logger.go b/internal/logging/logger.go new file mode 100644 index 0000000..39156ec --- /dev/null +++ b/internal/logging/logger.go @@ -0,0 +1,197 @@ +package logging + +import ( + "fmt" + "io" + "os" + "strings" + "sync" + "time" + + "github.com/rs/zerolog" +) + +type Logger struct { + l *zerolog.Logger + scopeLoggers map[string]*zerolog.Logger + scopeLevels map[string]zerolog.Level + scopeLevelMutex sync.Mutex + + defaultLogLevelFromEnv zerolog.Level + defaultLogLevelFromConfig zerolog.Level + defaultLogLevel zerolog.Level +} + +const ( + defaultLogLevel = zerolog.ErrorLevel +) + +type logOutput struct { + mu *sync.Mutex +} + +func (w *logOutput) Write(p []byte) (n int, err error) { + w.mu.Lock() + defer w.mu.Unlock() + + // TODO: write to file or syslog + if sseServer != nil { + // use a goroutine to avoid blocking the Write method + go func() { + sseServer.Message <- string(p) + }() + } + return len(p), nil +} + +var ( + consoleLogOutput io.Writer = zerolog.ConsoleWriter{ + Out: os.Stdout, + TimeFormat: time.RFC3339, + PartsOrder: []string{"time", "level", "scope", "component", "message"}, + FieldsExclude: []string{"scope", "component"}, + FormatPartValueByName: func(value interface{}, name string) string { + val := fmt.Sprintf("%s", value) + if name == "component" { + if value == nil { + return "-" + } + } + return val + }, + } + fileLogOutput io.Writer = &logOutput{mu: &sync.Mutex{}} + defaultLogOutput = zerolog.MultiLevelWriter(consoleLogOutput, fileLogOutput) + + zerologLevels = map[string]zerolog.Level{ + "DISABLE": zerolog.Disabled, + "NOLEVEL": zerolog.NoLevel, + "PANIC": zerolog.PanicLevel, + "FATAL": zerolog.FatalLevel, + "ERROR": zerolog.ErrorLevel, + "WARN": zerolog.WarnLevel, + "INFO": zerolog.InfoLevel, + "DEBUG": zerolog.DebugLevel, + "TRACE": zerolog.TraceLevel, + } +) + +func NewLogger(zerologLogger zerolog.Logger) *Logger { + return &Logger{ + l: &zerologLogger, + scopeLoggers: make(map[string]*zerolog.Logger), + scopeLevels: make(map[string]zerolog.Level), + scopeLevelMutex: sync.Mutex{}, + defaultLogLevelFromEnv: -2, + defaultLogLevelFromConfig: -2, + defaultLogLevel: defaultLogLevel, + } +} + +func (l *Logger) updateLogLevel() { + l.scopeLevelMutex.Lock() + defer l.scopeLevelMutex.Unlock() + + l.scopeLevels = make(map[string]zerolog.Level) + + finalDefaultLogLevel := l.defaultLogLevel + + for name, level := range zerologLevels { + env := os.Getenv(fmt.Sprintf("JETKVM_LOG_%s", name)) + + if env == "" { + env = os.Getenv(fmt.Sprintf("PION_LOG_%s", name)) + } + + if env == "" { + env = os.Getenv(fmt.Sprintf("PIONS_LOG_%s", name)) + } + + if env == "" { + continue + } + + if strings.ToLower(env) == "all" { + l.defaultLogLevelFromEnv = level + + if finalDefaultLogLevel > level { + finalDefaultLogLevel = level + } + + continue + } + + scopes := strings.Split(strings.ToLower(env), ",") + for _, scope := range scopes { + l.scopeLevels[scope] = level + } + } + + l.defaultLogLevel = finalDefaultLogLevel +} + +func (l *Logger) getScopeLoggerLevel(scope string) zerolog.Level { + if l.scopeLevels == nil { + l.updateLogLevel() + } + + var scopeLevel zerolog.Level + if l.defaultLogLevelFromConfig != -2 { + scopeLevel = l.defaultLogLevelFromConfig + } + if l.defaultLogLevelFromEnv != -2 { + scopeLevel = l.defaultLogLevelFromEnv + } + + // if the scope is not in the map, use the default level from the root logger + if level, ok := l.scopeLevels[scope]; ok { + scopeLevel = level + } + + return scopeLevel +} + +func (l *Logger) newScopeLogger(scope string) zerolog.Logger { + scopeLevel := l.getScopeLoggerLevel(scope) + logger := l.l.Level(scopeLevel).With().Str("component", scope).Logger() + + return logger +} + +func (l *Logger) getLogger(scope string) *zerolog.Logger { + logger, ok := l.scopeLoggers[scope] + if !ok || logger == nil { + scopeLogger := l.newScopeLogger(scope) + l.scopeLoggers[scope] = &scopeLogger + } + + return l.scopeLoggers[scope] +} + +func (l *Logger) UpdateLogLevel(configDefaultLogLevel string) { + needUpdate := false + + if configDefaultLogLevel != "" { + if logLevel, ok := zerologLevels[configDefaultLogLevel]; ok { + l.defaultLogLevelFromConfig = logLevel + } else { + l.l.Warn().Str("logLevel", configDefaultLogLevel).Msg("invalid defaultLogLevel from config, using ERROR") + } + + if l.defaultLogLevelFromConfig != l.defaultLogLevel { + needUpdate = true + } + } + + l.updateLogLevel() + + if needUpdate { + for scope, logger := range l.scopeLoggers { + currentLevel := logger.GetLevel() + targetLevel := l.getScopeLoggerLevel(scope) + if currentLevel != targetLevel { + *logger = l.newScopeLogger(scope) + } + } + } +} diff --git a/internal/logging/pion.go b/internal/logging/pion.go new file mode 100644 index 0000000..453b8bc --- /dev/null +++ b/internal/logging/pion.go @@ -0,0 +1,63 @@ +package logging + +import ( + "github.com/pion/logging" + "github.com/rs/zerolog" +) + +type pionLogger struct { + logger *zerolog.Logger +} + +// Print all messages except trace. +func (c pionLogger) Trace(msg string) { + c.logger.Trace().Msg(msg) +} +func (c pionLogger) Tracef(format string, args ...interface{}) { + c.logger.Trace().Msgf(format, args...) +} + +func (c pionLogger) Debug(msg string) { + c.logger.Debug().Msg(msg) +} +func (c pionLogger) Debugf(format string, args ...interface{}) { + c.logger.Debug().Msgf(format, args...) +} +func (c pionLogger) Info(msg string) { + c.logger.Info().Msg(msg) +} +func (c pionLogger) Infof(format string, args ...interface{}) { + c.logger.Info().Msgf(format, args...) +} +func (c pionLogger) Warn(msg string) { + c.logger.Warn().Msg(msg) +} +func (c pionLogger) Warnf(format string, args ...interface{}) { + c.logger.Warn().Msgf(format, args...) +} +func (c pionLogger) Error(msg string) { + c.logger.Error().Msg(msg) +} +func (c pionLogger) Errorf(format string, args ...interface{}) { + c.logger.Error().Msgf(format, args...) +} + +// customLoggerFactory satisfies the interface logging.LoggerFactory +// This allows us to create different loggers per subsystem. So we can +// add custom behavior. +type pionLoggerFactory struct{} + +func (c pionLoggerFactory) NewLogger(subsystem string) logging.LeveledLogger { + logger := rootLogger.getLogger(subsystem).With(). + Str("scope", "pion"). + Str("component", subsystem). + Logger() + + return pionLogger{logger: &logger} +} + +var defaultLoggerFactory = &pionLoggerFactory{} + +func GetPionDefaultLoggerFactory() logging.LoggerFactory { + return defaultLoggerFactory +} diff --git a/internal/logging/root.go b/internal/logging/root.go new file mode 100644 index 0000000..397ca64 --- /dev/null +++ b/internal/logging/root.go @@ -0,0 +1,20 @@ +package logging + +import "github.com/rs/zerolog" + +var ( + rootZerologLogger = zerolog.New(defaultLogOutput).With(). + Str("scope", "jetkvm"). + Timestamp(). + Stack(). + Logger() + rootLogger = NewLogger(rootZerologLogger) +) + +func GetRootLogger() *Logger { + return rootLogger +} + +func GetSubsystemLogger(subsystem string) *zerolog.Logger { + return rootLogger.getLogger(subsystem) +} diff --git a/internal/logging/sse.go b/internal/logging/sse.go new file mode 100644 index 0000000..05e6e9e --- /dev/null +++ b/internal/logging/sse.go @@ -0,0 +1,137 @@ +package logging + +import ( + "embed" + "io" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog" +) + +//go:embed sse.html +var sseHTML embed.FS + +type sseEvent struct { + Message chan string + NewClients chan chan string + ClosedClients chan chan string + TotalClients map[chan string]bool +} + +// New event messages are broadcast to all registered client connection channels +type sseClientChan chan string + +var ( + sseServer *sseEvent + sseLogger *zerolog.Logger +) + +func init() { + sseServer = newSseServer() + sseLogger = GetSubsystemLogger("sse") +} + +// Initialize event and Start procnteessing requests +func newSseServer() (event *sseEvent) { + event = &sseEvent{ + Message: make(chan string), + NewClients: make(chan chan string), + ClosedClients: make(chan chan string), + TotalClients: make(map[chan string]bool), + } + + go event.listen() + + return +} + +// It Listens all incoming requests from clients. +// Handles addition and removal of clients and broadcast messages to clients. +func (stream *sseEvent) listen() { + for { + select { + // Add new available client + case client := <-stream.NewClients: + stream.TotalClients[client] = true + sseLogger.Info(). + Int("total_clients", len(stream.TotalClients)). + Msg("new client connected") + + // Remove closed client + case client := <-stream.ClosedClients: + delete(stream.TotalClients, client) + close(client) + sseLogger.Info().Int("total_clients", len(stream.TotalClients)).Msg("client disconnected") + + // Broadcast message to client + case eventMsg := <-stream.Message: + for clientMessageChan := range stream.TotalClients { + select { + case clientMessageChan <- eventMsg: + // Message sent successfully + default: + // Failed to send, dropping message + } + } + } + } +} + +func (stream *sseEvent) serveHTTP() gin.HandlerFunc { + return func(c *gin.Context) { + clientChan := make(sseClientChan) + stream.NewClients <- clientChan + + go func() { + <-c.Writer.CloseNotify() + + for range clientChan { + } + + stream.ClosedClients <- clientChan + }() + + c.Set("clientChan", clientChan) + c.Next() + } +} + +func sseHeadersMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if c.Request.Method == "GET" && c.NegotiateFormat(gin.MIMEHTML) == gin.MIMEHTML { + c.FileFromFS("/sse.html", http.FS(sseHTML)) + c.Status(http.StatusOK) + c.Abort() + return + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Next() + } +} + +func AttachSSEHandler(router *gin.RouterGroup) { + router.StaticFS("/log-stream", http.FS(sseHTML)) + router.GET("/log-stream", sseHeadersMiddleware(), sseServer.serveHTTP(), func(c *gin.Context) { + v, ok := c.Get("clientChan") + if !ok { + return + } + clientChan, ok := v.(sseClientChan) + if !ok { + return + } + c.Stream(func(w io.Writer) bool { + if msg, ok := <-clientChan; ok { + c.SSEvent("message", msg) + return true + } + return false + }) + }) +} diff --git a/internal/logging/sse.html b/internal/logging/sse.html new file mode 100644 index 0000000..192b464 --- /dev/null +++ b/internal/logging/sse.html @@ -0,0 +1,319 @@ + + + + + + Server Sent Event + + + + +
+ +
+ +
+
+ + + + + \ No newline at end of file diff --git a/internal/logging/utils.go b/internal/logging/utils.go new file mode 100644 index 0000000..e622d96 --- /dev/null +++ b/internal/logging/utils.go @@ -0,0 +1,32 @@ +package logging + +import ( + "fmt" + "os" + + "github.com/rs/zerolog" +) + +var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel) + +func GetDefaultLogger() *zerolog.Logger { + return &defaultLogger +} + +func ErrorfL(l *zerolog.Logger, format string, err error, args ...interface{}) error { + // TODO: move rootLogger to logging package + if l == nil { + l = &defaultLogger + } + + l.Error().Err(err).Msgf(format, args...) + + if err == nil { + return fmt.Errorf(format, args...) + } + + err_msg := err.Error() + ": %v" + err_args := append(args, err) + + return fmt.Errorf(err_msg, err_args...) +} diff --git a/internal/mdns/mdns.go b/internal/mdns/mdns.go new file mode 100644 index 0000000..b882b93 --- /dev/null +++ b/internal/mdns/mdns.go @@ -0,0 +1,190 @@ +package mdns + +import ( + "fmt" + "net" + "reflect" + "strings" + "sync" + + "github.com/jetkvm/kvm/internal/logging" + pion_mdns "github.com/pion/mdns/v2" + "github.com/rs/zerolog" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +type MDNS struct { + conn *pion_mdns.Conn + lock sync.Mutex + l *zerolog.Logger + + localNames []string + listenOptions *MDNSListenOptions +} + +type MDNSListenOptions struct { + IPv4 bool + IPv6 bool +} + +type MDNSOptions struct { + Logger *zerolog.Logger + LocalNames []string + ListenOptions *MDNSListenOptions +} + +const ( + DefaultAddressIPv4 = pion_mdns.DefaultAddressIPv4 + DefaultAddressIPv6 = pion_mdns.DefaultAddressIPv6 +) + +func NewMDNS(opts *MDNSOptions) (*MDNS, error) { + if opts.Logger == nil { + opts.Logger = logging.GetDefaultLogger() + } + + if opts.ListenOptions == nil { + opts.ListenOptions = &MDNSListenOptions{ + IPv4: true, + IPv6: true, + } + } + + return &MDNS{ + l: opts.Logger, + lock: sync.Mutex{}, + localNames: opts.LocalNames, + listenOptions: opts.ListenOptions, + }, nil +} + +func (m *MDNS) start(allowRestart bool) error { + m.lock.Lock() + defer m.lock.Unlock() + + if m.conn != nil { + if !allowRestart { + return fmt.Errorf("mDNS server already running") + } + + m.conn.Close() + } + + if m.listenOptions == nil { + return fmt.Errorf("listen options not set") + } + + if !m.listenOptions.IPv4 && !m.listenOptions.IPv6 { + m.l.Info().Msg("mDNS server disabled") + return nil + } + + var ( + addr4, addr6 *net.UDPAddr + l4, l6 *net.UDPConn + p4 *ipv4.PacketConn + p6 *ipv6.PacketConn + err error + ) + + if m.listenOptions.IPv4 { + addr4, err = net.ResolveUDPAddr("udp4", DefaultAddressIPv4) + if err != nil { + return err + } + + l4, err = net.ListenUDP("udp4", addr4) + if err != nil { + return err + } + + p4 = ipv4.NewPacketConn(l4) + } + + if m.listenOptions.IPv6 { + addr6, err = net.ResolveUDPAddr("udp6", DefaultAddressIPv6) + if err != nil { + return err + } + + l6, err = net.ListenUDP("udp6", addr6) + if err != nil { + return err + } + + p6 = ipv6.NewPacketConn(l6) + } + + scopeLogger := m.l.With(). + Interface("local_names", m.localNames). + Bool("ipv4", m.listenOptions.IPv4). + Bool("ipv6", m.listenOptions.IPv6). + Logger() + + newLocalNames := make([]string, len(m.localNames)) + for i, name := range m.localNames { + newLocalNames[i] = strings.TrimRight(strings.ToLower(name), ".") + if !strings.HasSuffix(newLocalNames[i], ".local") { + newLocalNames[i] = newLocalNames[i] + ".local" + } + } + + mDNSConn, err := pion_mdns.Server(p4, p6, &pion_mdns.Config{ + LocalNames: newLocalNames, + LoggerFactory: logging.GetPionDefaultLoggerFactory(), + }) + + if err != nil { + scopeLogger.Warn().Err(err).Msg("failed to start mDNS server") + return err + } + + m.conn = mDNSConn + scopeLogger.Info().Msg("mDNS server started") + + return nil +} + +func (m *MDNS) Start() error { + return m.start(false) +} + +func (m *MDNS) Restart() error { + return m.start(true) +} + +func (m *MDNS) Stop() error { + m.lock.Lock() + defer m.lock.Unlock() + + if m.conn == nil { + return nil + } + + return m.conn.Close() +} + +func (m *MDNS) SetLocalNames(localNames []string, always bool) error { + if reflect.DeepEqual(m.localNames, localNames) && !always { + return nil + } + + m.localNames = localNames + _ = m.Restart() + + return nil +} + +func (m *MDNS) SetListenOptions(listenOptions *MDNSListenOptions) error { + if m.listenOptions != nil && + m.listenOptions.IPv4 == listenOptions.IPv4 && + m.listenOptions.IPv6 == listenOptions.IPv6 { + return nil + } + + m.listenOptions = listenOptions + _ = m.Restart() + + return nil +} diff --git a/internal/mdns/utils.go b/internal/mdns/utils.go new file mode 100644 index 0000000..7565eee --- /dev/null +++ b/internal/mdns/utils.go @@ -0,0 +1 @@ +package mdns diff --git a/internal/network/config.go b/internal/network/config.go new file mode 100644 index 0000000..74ddf19 --- /dev/null +++ b/internal/network/config.go @@ -0,0 +1,110 @@ +package network + +import ( + "fmt" + "net" + "time" + + "github.com/guregu/null/v6" + "github.com/jetkvm/kvm/internal/mdns" + "golang.org/x/net/idna" +) + +type IPv6Address struct { + Address net.IP `json:"address"` + Prefix net.IPNet `json:"prefix"` + ValidLifetime *time.Time `json:"valid_lifetime"` + PreferredLifetime *time.Time `json:"preferred_lifetime"` + Scope int `json:"scope"` +} + +type IPv4StaticConfig struct { + Address null.String `json:"address,omitempty" validate_type:"ipv4" required:"true"` + Netmask null.String `json:"netmask,omitempty" validate_type:"ipv4" required:"true"` + Gateway null.String `json:"gateway,omitempty" validate_type:"ipv4" required:"true"` + DNS []string `json:"dns,omitempty" validate_type:"ipv4" required:"true"` +} + +type IPv6StaticConfig struct { + Address null.String `json:"address,omitempty" validate_type:"ipv6" required:"true"` + Prefix null.String `json:"prefix,omitempty" validate_type:"ipv6" required:"true"` + Gateway null.String `json:"gateway,omitempty" validate_type:"ipv6" required:"true"` + DNS []string `json:"dns,omitempty" validate_type:"ipv6" required:"true"` +} +type NetworkConfig struct { + Hostname null.String `json:"hostname,omitempty" validate_type:"hostname"` + Domain null.String `json:"domain,omitempty" validate_type:"hostname"` + + IPv4Mode null.String `json:"ipv4_mode,omitempty" one_of:"dhcp,static,disabled" default:"dhcp"` + IPv4Static *IPv4StaticConfig `json:"ipv4_static,omitempty" required_if:"IPv4Mode=static"` + + IPv6Mode null.String `json:"ipv6_mode,omitempty" one_of:"slaac,dhcpv6,slaac_and_dhcpv6,static,link_local,disabled" default:"slaac"` + IPv6Static *IPv6StaticConfig `json:"ipv6_static,omitempty" required_if:"IPv6Mode=static"` + + LLDPMode null.String `json:"lldp_mode,omitempty" one_of:"disabled,basic,all" default:"basic"` + LLDPTxTLVs []string `json:"lldp_tx_tlvs,omitempty" one_of:"chassis,port,system,vlan" default:"chassis,port,system,vlan"` + MDNSMode null.String `json:"mdns_mode,omitempty" one_of:"disabled,auto,ipv4_only,ipv6_only" default:"auto"` + TimeSyncMode null.String `json:"time_sync_mode,omitempty" one_of:"ntp_only,ntp_and_http,http_only,custom" default:"ntp_and_http"` + TimeSyncOrdering []string `json:"time_sync_ordering,omitempty" one_of:"http,ntp,ntp_dhcp,ntp_user_provided,ntp_fallback" default:"ntp,http"` + TimeSyncDisableFallback null.Bool `json:"time_sync_disable_fallback,omitempty" default:"false"` + TimeSyncParallel null.Int `json:"time_sync_parallel,omitempty" default:"4"` +} + +func (c *NetworkConfig) GetMDNSMode() *mdns.MDNSListenOptions { + mode := c.MDNSMode.String + listenOptions := &mdns.MDNSListenOptions{ + IPv4: true, + IPv6: true, + } + + switch mode { + case "ipv4_only": + listenOptions.IPv6 = false + case "ipv6_only": + listenOptions.IPv4 = false + case "disabled": + listenOptions.IPv4 = false + listenOptions.IPv6 = false + } + + return listenOptions +} +func (s *NetworkInterfaceState) GetHostname() string { + hostname := ToValidHostname(s.config.Hostname.String) + + if hostname == "" { + return s.defaultHostname + } + + return hostname +} + +func ToValidDomain(domain string) string { + ascii, err := idna.Lookup.ToASCII(domain) + if err != nil { + return "" + } + + return ascii +} + +func (s *NetworkInterfaceState) GetDomain() string { + domain := ToValidDomain(s.config.Domain.String) + + if domain == "" { + lease := s.dhcpClient.GetLease() + if lease != nil && lease.Domain != "" { + domain = ToValidDomain(lease.Domain) + } + } + + if domain == "" { + return "local" + } + + return domain +} + +func (s *NetworkInterfaceState) GetFQDN() string { + return fmt.Sprintf("%s.%s", s.GetHostname(), s.GetDomain()) +} diff --git a/internal/network/dhcp.go b/internal/network/dhcp.go new file mode 100644 index 0000000..9e173cc --- /dev/null +++ b/internal/network/dhcp.go @@ -0,0 +1,11 @@ +package network + +type DhcpTargetState int + +const ( + DhcpTargetStateDoNothing DhcpTargetState = iota + DhcpTargetStateStart + DhcpTargetStateStop + DhcpTargetStateRenew + DhcpTargetStateRelease +) diff --git a/internal/network/hostname.go b/internal/network/hostname.go new file mode 100644 index 0000000..d75255c --- /dev/null +++ b/internal/network/hostname.go @@ -0,0 +1,137 @@ +package network + +import ( + "fmt" + "io" + "os" + "os/exec" + "strings" + "sync" + + "golang.org/x/net/idna" +) + +const ( + hostnamePath = "/etc/hostname" + hostsPath = "/etc/hosts" +) + +var ( + hostnameLock sync.Mutex = sync.Mutex{} +) + +func updateEtcHosts(hostname string, fqdn string) error { + // update /etc/hosts + hostsFile, err := os.OpenFile(hostsPath, os.O_RDWR|os.O_SYNC, os.ModeExclusive) + if err != nil { + return fmt.Errorf("failed to open %s: %w", hostsPath, err) + } + defer hostsFile.Close() + + // read all lines + if _, err := hostsFile.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek %s: %w", hostsPath, err) + } + + lines, err := io.ReadAll(hostsFile) + if err != nil { + return fmt.Errorf("failed to read %s: %w", hostsPath, err) + } + + newLines := []string{} + hostLine := fmt.Sprintf("127.0.1.1\t%s %s", hostname, fqdn) + hostLineExists := false + + for _, line := range strings.Split(string(lines), "\n") { + if strings.HasPrefix(line, "127.0.1.1") { + hostLineExists = true + line = hostLine + } + newLines = append(newLines, line) + } + + if !hostLineExists { + newLines = append(newLines, hostLine) + } + + if err := hostsFile.Truncate(0); err != nil { + return fmt.Errorf("failed to truncate %s: %w", hostsPath, err) + } + + if _, err := hostsFile.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek %s: %w", hostsPath, err) + } + + if _, err := hostsFile.Write([]byte(strings.Join(newLines, "\n"))); err != nil { + return fmt.Errorf("failed to write %s: %w", hostsPath, err) + } + + return nil +} + +func ToValidHostname(hostname string) string { + ascii, err := idna.Lookup.ToASCII(hostname) + if err != nil { + return "" + } + return ascii +} + +func SetHostname(hostname string, fqdn string) error { + hostnameLock.Lock() + defer hostnameLock.Unlock() + + hostname = ToValidHostname(strings.TrimSpace(hostname)) + fqdn = ToValidHostname(strings.TrimSpace(fqdn)) + + if hostname == "" { + return fmt.Errorf("invalid hostname: %s", hostname) + } + + if fqdn == "" { + fqdn = hostname + } + + // update /etc/hostname + if err := os.WriteFile(hostnamePath, []byte(hostname), 0644); err != nil { + return fmt.Errorf("failed to write %s: %w", hostnamePath, err) + } + + // update /etc/hosts + if err := updateEtcHosts(hostname, fqdn); err != nil { + return fmt.Errorf("failed to update /etc/hosts: %w", err) + } + + // run hostname + if err := exec.Command("hostname", "-F", hostnamePath).Run(); err != nil { + return fmt.Errorf("failed to run hostname: %w", err) + } + + return nil +} + +func (s *NetworkInterfaceState) setHostnameIfNotSame() error { + hostname := s.GetHostname() + currentHostname, _ := os.Hostname() + + fqdn := fmt.Sprintf("%s.%s", hostname, s.GetDomain()) + + if currentHostname == hostname && s.currentFqdn == fqdn && s.currentHostname == hostname { + return nil + } + + scopedLogger := s.l.With().Str("hostname", hostname).Str("fqdn", fqdn).Logger() + + err := SetHostname(hostname, fqdn) + if err != nil { + scopedLogger.Error().Err(err).Msg("failed to set hostname") + return err + } + + s.currentHostname = hostname + s.currentFqdn = fqdn + + scopedLogger.Info().Msg("hostname set") + + return nil +} diff --git a/internal/network/netif.go b/internal/network/netif.go new file mode 100644 index 0000000..c5db806 --- /dev/null +++ b/internal/network/netif.go @@ -0,0 +1,346 @@ +package network + +import ( + "fmt" + "net" + "sync" + + "github.com/jetkvm/kvm/internal/confparser" + "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/udhcpc" + "github.com/rs/zerolog" + + "github.com/vishvananda/netlink" +) + +type NetworkInterfaceState struct { + interfaceName string + interfaceUp bool + ipv4Addr *net.IP + ipv4Addresses []string + ipv6Addr *net.IP + ipv6Addresses []IPv6Address + ipv6LinkLocal *net.IP + macAddr *net.HardwareAddr + + l *zerolog.Logger + stateLock sync.Mutex + + config *NetworkConfig + dhcpClient *udhcpc.DHCPClient + + defaultHostname string + currentHostname string + currentFqdn string + + onStateChange func(state *NetworkInterfaceState) + onInitialCheck func(state *NetworkInterfaceState) + cbConfigChange func(config *NetworkConfig) + + checked bool +} + +type NetworkInterfaceOptions struct { + InterfaceName string + DhcpPidFile string + Logger *zerolog.Logger + DefaultHostname string + OnStateChange func(state *NetworkInterfaceState) + OnInitialCheck func(state *NetworkInterfaceState) + OnDhcpLeaseChange func(lease *udhcpc.Lease) + OnConfigChange func(config *NetworkConfig) + NetworkConfig *NetworkConfig +} + +func NewNetworkInterfaceState(opts *NetworkInterfaceOptions) (*NetworkInterfaceState, error) { + if opts.NetworkConfig == nil { + return nil, fmt.Errorf("NetworkConfig can not be nil") + } + + if opts.DefaultHostname == "" { + opts.DefaultHostname = "jetkvm" + } + + err := confparser.SetDefaultsAndValidate(opts.NetworkConfig) + if err != nil { + return nil, err + } + + l := opts.Logger + s := &NetworkInterfaceState{ + interfaceName: opts.InterfaceName, + defaultHostname: opts.DefaultHostname, + stateLock: sync.Mutex{}, + l: l, + onStateChange: opts.OnStateChange, + onInitialCheck: opts.OnInitialCheck, + cbConfigChange: opts.OnConfigChange, + config: opts.NetworkConfig, + } + + // create the dhcp client + dhcpClient := udhcpc.NewDHCPClient(&udhcpc.DHCPClientOptions{ + InterfaceName: opts.InterfaceName, + PidFile: opts.DhcpPidFile, + Logger: l, + OnLeaseChange: func(lease *udhcpc.Lease) { + _, err := s.update() + if err != nil { + opts.Logger.Error().Err(err).Msg("failed to update network state") + return + } + + _ = s.setHostnameIfNotSame() + + opts.OnDhcpLeaseChange(lease) + }, + }) + + s.dhcpClient = dhcpClient + + return s, nil +} + +func (s *NetworkInterfaceState) IsUp() bool { + return s.interfaceUp +} + +func (s *NetworkInterfaceState) HasIPAssigned() bool { + return s.ipv4Addr != nil || s.ipv6Addr != nil +} + +func (s *NetworkInterfaceState) IsOnline() bool { + return s.IsUp() && s.HasIPAssigned() +} + +func (s *NetworkInterfaceState) IPv4() *net.IP { + return s.ipv4Addr +} + +func (s *NetworkInterfaceState) IPv4String() string { + if s.ipv4Addr == nil { + return "..." + } + return s.ipv4Addr.String() +} + +func (s *NetworkInterfaceState) IPv6() *net.IP { + return s.ipv6Addr +} + +func (s *NetworkInterfaceState) IPv6String() string { + if s.ipv6Addr == nil { + return "..." + } + return s.ipv6Addr.String() +} + +func (s *NetworkInterfaceState) MAC() *net.HardwareAddr { + return s.macAddr +} + +func (s *NetworkInterfaceState) MACString() string { + if s.macAddr == nil { + return "" + } + return s.macAddr.String() +} + +func (s *NetworkInterfaceState) update() (DhcpTargetState, error) { + s.stateLock.Lock() + defer s.stateLock.Unlock() + + dhcpTargetState := DhcpTargetStateDoNothing + + iface, err := netlink.LinkByName(s.interfaceName) + if err != nil { + s.l.Error().Err(err).Msg("failed to get interface") + return dhcpTargetState, err + } + + // detect if the interface status changed + var changed bool + attrs := iface.Attrs() + state := attrs.OperState + newInterfaceUp := state == netlink.OperUp + + // check if the interface is coming up + interfaceGoingUp := !s.interfaceUp && newInterfaceUp + interfaceGoingDown := s.interfaceUp && !newInterfaceUp + + if s.interfaceUp != newInterfaceUp { + s.interfaceUp = newInterfaceUp + changed = true + } + + if changed { + if interfaceGoingUp { + s.l.Info().Msg("interface state transitioned to up") + dhcpTargetState = DhcpTargetStateRenew + } else if interfaceGoingDown { + s.l.Info().Msg("interface state transitioned to down") + } + } + + // set the mac address + s.macAddr = &attrs.HardwareAddr + + // get the ip addresses + addrs, err := netlinkAddrs(iface) + if err != nil { + return dhcpTargetState, logging.ErrorfL(s.l, "failed to get ip addresses", err) + } + + var ( + ipv4Addresses = make([]net.IP, 0) + ipv4AddressesString = make([]string, 0) + ipv6Addresses = make([]IPv6Address, 0) + // ipv6AddressesString = make([]string, 0) + ipv6LinkLocal *net.IP + ) + + for _, addr := range addrs { + if addr.IP.To4() != nil { + scopedLogger := s.l.With().Str("ipv4", addr.IP.String()).Logger() + if interfaceGoingDown { + // remove all IPv4 addresses from the interface. + scopedLogger.Info().Msg("state transitioned to down, removing IPv4 address") + err := netlink.AddrDel(iface, &addr) + if err != nil { + scopedLogger.Warn().Err(err).Msg("failed to delete address") + } + // notify the DHCP client to release the lease + dhcpTargetState = DhcpTargetStateRelease + continue + } + ipv4Addresses = append(ipv4Addresses, addr.IP) + ipv4AddressesString = append(ipv4AddressesString, addr.IPNet.String()) + } else if addr.IP.To16() != nil { + scopedLogger := s.l.With().Str("ipv6", addr.IP.String()).Logger() + // check if it's a link local address + if addr.IP.IsLinkLocalUnicast() { + ipv6LinkLocal = &addr.IP + continue + } + + if !addr.IP.IsGlobalUnicast() { + scopedLogger.Trace().Msg("not a global unicast address, skipping") + continue + } + + if interfaceGoingDown { + scopedLogger.Info().Msg("state transitioned to down, removing IPv6 address") + err := netlink.AddrDel(iface, &addr) + if err != nil { + scopedLogger.Warn().Err(err).Msg("failed to delete address") + } + continue + } + ipv6Addresses = append(ipv6Addresses, IPv6Address{ + Address: addr.IP, + Prefix: *addr.IPNet, + ValidLifetime: lifetimeToTime(addr.ValidLft), + PreferredLifetime: lifetimeToTime(addr.PreferedLft), + Scope: addr.Scope, + }) + // ipv6AddressesString = append(ipv6AddressesString, addr.IPNet.String()) + } + } + + if len(ipv4Addresses) > 0 { + // compare the addresses to see if there's a change + if s.ipv4Addr == nil || s.ipv4Addr.String() != ipv4Addresses[0].String() { + scopedLogger := s.l.With().Str("ipv4", ipv4Addresses[0].String()).Logger() + if s.ipv4Addr != nil { + scopedLogger.Info(). + Str("old_ipv4", s.ipv4Addr.String()). + Msg("IPv4 address changed") + } else { + scopedLogger.Info().Msg("IPv4 address found") + } + s.ipv4Addr = &ipv4Addresses[0] + changed = true + } + } + s.ipv4Addresses = ipv4AddressesString + + if ipv6LinkLocal != nil { + if s.ipv6LinkLocal == nil || s.ipv6LinkLocal.String() != ipv6LinkLocal.String() { + scopedLogger := s.l.With().Str("ipv6", ipv6LinkLocal.String()).Logger() + if s.ipv6LinkLocal != nil { + scopedLogger.Info(). + Str("old_ipv6", s.ipv6LinkLocal.String()). + Msg("IPv6 link local address changed") + } else { + scopedLogger.Info().Msg("IPv6 link local address found") + } + s.ipv6LinkLocal = ipv6LinkLocal + changed = true + } + } + s.ipv6Addresses = ipv6Addresses + + if len(ipv6Addresses) > 0 { + // compare the addresses to see if there's a change + if s.ipv6Addr == nil || s.ipv6Addr.String() != ipv6Addresses[0].Address.String() { + scopedLogger := s.l.With().Str("ipv6", ipv6Addresses[0].Address.String()).Logger() + if s.ipv6Addr != nil { + scopedLogger.Info(). + Str("old_ipv6", s.ipv6Addr.String()). + Msg("IPv6 address changed") + } else { + scopedLogger.Info().Msg("IPv6 address found") + } + s.ipv6Addr = &ipv6Addresses[0].Address + changed = true + } + } + + // if it's the initial check, we'll set changed to false + initialCheck := !s.checked + if initialCheck { + s.checked = true + changed = false + if dhcpTargetState == DhcpTargetStateRenew { + // it's the initial check, we'll start the DHCP client + // dhcpTargetState = DhcpTargetStateStart + // TODO: manage DHCP client start/stop + dhcpTargetState = DhcpTargetStateDoNothing + } + } + + if initialCheck { + s.onInitialCheck(s) + } else if changed { + s.onStateChange(s) + } + + return dhcpTargetState, nil +} + +func (s *NetworkInterfaceState) CheckAndUpdateDhcp() error { + dhcpTargetState, err := s.update() + if err != nil { + return logging.ErrorfL(s.l, "failed to update network state", err) + } + + switch dhcpTargetState { + case DhcpTargetStateRenew: + s.l.Info().Msg("renewing DHCP lease") + _ = s.dhcpClient.Renew() + case DhcpTargetStateRelease: + s.l.Info().Msg("releasing DHCP lease") + _ = s.dhcpClient.Release() + case DhcpTargetStateStart: + s.l.Warn().Msg("dhcpTargetStateStart not implemented") + case DhcpTargetStateStop: + s.l.Warn().Msg("dhcpTargetStateStop not implemented") + } + + return nil +} + +func (s *NetworkInterfaceState) onConfigChange(config *NetworkConfig) { + _ = s.setHostnameIfNotSame() + s.cbConfigChange(config) +} diff --git a/internal/network/netif_linux.go b/internal/network/netif_linux.go new file mode 100644 index 0000000..ec057f1 --- /dev/null +++ b/internal/network/netif_linux.go @@ -0,0 +1,58 @@ +//go:build linux + +package network + +import ( + "time" + + "github.com/vishvananda/netlink" + "github.com/vishvananda/netlink/nl" +) + +func (s *NetworkInterfaceState) HandleLinkUpdate(update netlink.LinkUpdate) { + if update.Link.Attrs().Name == s.interfaceName { + s.l.Info().Interface("update", update).Msg("interface link update received") + _ = s.CheckAndUpdateDhcp() + } +} + +func (s *NetworkInterfaceState) Run() error { + updates := make(chan netlink.LinkUpdate) + done := make(chan struct{}) + + if err := netlink.LinkSubscribe(updates, done); err != nil { + s.l.Warn().Err(err).Msg("failed to subscribe to link updates") + return err + } + + _ = s.setHostnameIfNotSame() + + // run the dhcp client + go s.dhcpClient.Run() // nolint:errcheck + + if err := s.CheckAndUpdateDhcp(); err != nil { + return err + } + + go func() { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case update := <-updates: + s.HandleLinkUpdate(update) + case <-ticker.C: + _ = s.CheckAndUpdateDhcp() + case <-done: + return + } + } + }() + + return nil +} + +func netlinkAddrs(iface netlink.Link) ([]netlink.Addr, error) { + return netlink.AddrList(iface, nl.FAMILY_ALL) +} diff --git a/internal/network/netif_notlinux.go b/internal/network/netif_notlinux.go new file mode 100644 index 0000000..d101630 --- /dev/null +++ b/internal/network/netif_notlinux.go @@ -0,0 +1,21 @@ +//go:build !linux + +package network + +import ( + "fmt" + + "github.com/vishvananda/netlink" +) + +func (s *NetworkInterfaceState) HandleLinkUpdate() error { + return fmt.Errorf("not implemented") +} + +func (s *NetworkInterfaceState) Run() error { + return fmt.Errorf("not implemented") +} + +func netlinkAddrs(iface netlink.Link) ([]netlink.Addr, error) { + return nil, fmt.Errorf("not implemented") +} diff --git a/internal/network/rpc.go b/internal/network/rpc.go new file mode 100644 index 0000000..32f34f5 --- /dev/null +++ b/internal/network/rpc.go @@ -0,0 +1,126 @@ +package network + +import ( + "fmt" + "time" + + "github.com/jetkvm/kvm/internal/confparser" + "github.com/jetkvm/kvm/internal/udhcpc" +) + +type RpcIPv6Address struct { + Address string `json:"address"` + ValidLifetime *time.Time `json:"valid_lifetime,omitempty"` + PreferredLifetime *time.Time `json:"preferred_lifetime,omitempty"` + Scope int `json:"scope"` +} + +type RpcNetworkState struct { + InterfaceName string `json:"interface_name"` + MacAddress string `json:"mac_address"` + IPv4 string `json:"ipv4,omitempty"` + IPv6 string `json:"ipv6,omitempty"` + IPv6LinkLocal string `json:"ipv6_link_local,omitempty"` + IPv4Addresses []string `json:"ipv4_addresses,omitempty"` + IPv6Addresses []RpcIPv6Address `json:"ipv6_addresses,omitempty"` + DHCPLease *udhcpc.Lease `json:"dhcp_lease,omitempty"` +} + +type RpcNetworkSettings struct { + NetworkConfig +} + +func (s *NetworkInterfaceState) MacAddress() string { + if s.macAddr == nil { + return "" + } + + return s.macAddr.String() +} + +func (s *NetworkInterfaceState) IPv4Address() string { + if s.ipv4Addr == nil { + return "" + } + + return s.ipv4Addr.String() +} + +func (s *NetworkInterfaceState) IPv6Address() string { + if s.ipv6Addr == nil { + return "" + } + + return s.ipv6Addr.String() +} + +func (s *NetworkInterfaceState) IPv6LinkLocalAddress() string { + if s.ipv6LinkLocal == nil { + return "" + } + + return s.ipv6LinkLocal.String() +} + +func (s *NetworkInterfaceState) RpcGetNetworkState() RpcNetworkState { + ipv6Addresses := make([]RpcIPv6Address, 0) + + if s.ipv6Addresses != nil { + for _, addr := range s.ipv6Addresses { + ipv6Addresses = append(ipv6Addresses, RpcIPv6Address{ + Address: addr.Prefix.String(), + ValidLifetime: addr.ValidLifetime, + PreferredLifetime: addr.PreferredLifetime, + Scope: addr.Scope, + }) + } + } + + return RpcNetworkState{ + InterfaceName: s.interfaceName, + MacAddress: s.MacAddress(), + IPv4: s.IPv4Address(), + IPv6: s.IPv6Address(), + IPv6LinkLocal: s.IPv6LinkLocalAddress(), + IPv4Addresses: s.ipv4Addresses, + IPv6Addresses: ipv6Addresses, + DHCPLease: s.dhcpClient.GetLease(), + } +} + +func (s *NetworkInterfaceState) RpcGetNetworkSettings() RpcNetworkSettings { + if s.config == nil { + return RpcNetworkSettings{} + } + + return RpcNetworkSettings{ + NetworkConfig: *s.config, + } +} + +func (s *NetworkInterfaceState) RpcSetNetworkSettings(settings RpcNetworkSettings) error { + currentSettings := s.config + + err := confparser.SetDefaultsAndValidate(&settings.NetworkConfig) + if err != nil { + return err + } + + if IsSame(currentSettings, settings.NetworkConfig) { + // no changes, do nothing + return nil + } + + s.config = &settings.NetworkConfig + s.onConfigChange(s.config) + + return nil +} + +func (s *NetworkInterfaceState) RpcRenewDHCPLease() error { + if s.dhcpClient == nil { + return fmt.Errorf("dhcp client not initialized") + } + + return s.dhcpClient.Renew() +} diff --git a/internal/network/utils.go b/internal/network/utils.go new file mode 100644 index 0000000..6d64332 --- /dev/null +++ b/internal/network/utils.go @@ -0,0 +1,26 @@ +package network + +import ( + "encoding/json" + "time" +) + +func lifetimeToTime(lifetime int) *time.Time { + if lifetime == 0 { + return nil + } + t := time.Now().Add(time.Duration(lifetime) * time.Second) + return &t +} + +func IsSame(a, b interface{}) bool { + aJSON, err := json.Marshal(a) + if err != nil { + return false + } + bJSON, err := json.Marshal(b) + if err != nil { + return false + } + return string(aJSON) == string(bJSON) +} diff --git a/internal/timesync/http.go b/internal/timesync/http.go new file mode 100644 index 0000000..3a51463 --- /dev/null +++ b/internal/timesync/http.go @@ -0,0 +1,132 @@ +package timesync + +import ( + "context" + "errors" + "math/rand" + "net/http" + "strconv" + "time" +) + +var defaultHTTPUrls = []string{ + "http://www.gstatic.com/generate_204", + "http://cp.cloudflare.com/", + "http://edge-http.microsoft.com/captiveportal/generate_204", + // Firefox, Apple, and Microsoft have inconsistent results, so we don't use it + // "http://detectportal.firefox.com/", + // "http://www.apple.com/library/test/success.html", + // "http://www.msftconnecttest.com/connecttest.txt", +} + +func (t *TimeSync) queryAllHttpTime() (now *time.Time) { + chunkSize := 4 + httpUrls := t.httpUrls + + // shuffle the http urls to avoid always querying the same servers + rand.Shuffle(len(httpUrls), func(i, j int) { httpUrls[i], httpUrls[j] = httpUrls[j], httpUrls[i] }) + + for i := 0; i < len(httpUrls); i += chunkSize { + chunk := httpUrls[i:min(i+chunkSize, len(httpUrls))] + results := t.queryMultipleHttp(chunk, timeSyncTimeout) + if results != nil { + return results + } + } + + return nil +} + +func (t *TimeSync) queryMultipleHttp(urls []string, timeout time.Duration) (now *time.Time) { + results := make(chan *time.Time, len(urls)) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for _, url := range urls { + go func(url string) { + scopedLogger := t.l.With(). + Str("http_url", url). + Logger() + + metricHttpRequestCount.WithLabelValues(url).Inc() + metricHttpTotalRequestCount.Inc() + + startTime := time.Now() + now, response, err := queryHttpTime( + ctx, + url, + timeout, + ) + duration := time.Since(startTime) + + metricHttpServerLastRTT.WithLabelValues(url).Set(float64(duration.Milliseconds())) + metricHttpServerRttHistogram.WithLabelValues(url).Observe(float64(duration.Milliseconds())) + + status := 0 + if response != nil { + status = response.StatusCode + } + metricHttpServerInfo.WithLabelValues( + url, + strconv.Itoa(status), + ).Set(1) + + if err == nil { + metricHttpTotalSuccessCount.Inc() + metricHttpSuccessCount.WithLabelValues(url).Inc() + + requestId := response.Header.Get("X-Request-Id") + if requestId != "" { + requestId = response.Header.Get("X-Msedge-Ref") + } + if requestId == "" { + requestId = response.Header.Get("Cf-Ray") + } + scopedLogger.Info(). + Str("time", now.Format(time.RFC3339)). + Int("status", status). + Str("request_id", requestId). + Str("time_taken", duration.String()). + Msg("HTTP server returned time") + + cancel() + results <- now + } else if errors.Is(err, context.Canceled) { + metricHttpCancelCount.WithLabelValues(url).Inc() + metricHttpTotalCancelCount.Inc() + } else { + scopedLogger.Warn(). + Str("error", err.Error()). + Int("status", status). + Msg("failed to query HTTP server") + } + }(url) + } + + return <-results +} + +func queryHttpTime( + ctx context.Context, + url string, + timeout time.Duration, +) (now *time.Time, response *http.Response, err error) { + client := http.Client{ + Timeout: timeout, + } + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, nil, err + } + resp, err := client.Do(req) + if err != nil { + return nil, nil, err + } + dateStr := resp.Header.Get("Date") + parsedTime, err := time.Parse(time.RFC1123, dateStr) + if err != nil { + return nil, nil, err + } + return &parsedTime, resp, nil +} diff --git a/internal/timesync/metrics.go b/internal/timesync/metrics.go new file mode 100644 index 0000000..0e28acb --- /dev/null +++ b/internal/timesync/metrics.go @@ -0,0 +1,147 @@ +package timesync + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + metricTimeSyncStatus = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "jetkvm_timesync_status", + Help: "The status of the timesync, 1 if successful, 0 if not", + }, + ) + metricTimeSyncCount = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_count", + Help: "The number of times the timesync has been run", + }, + ) + metricTimeSyncSuccessCount = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_success_count", + Help: "The number of times the timesync has been successful", + }, + ) + metricRTCUpdateCount = promauto.NewCounter( //nolint:unused + prometheus.CounterOpts{ + Name: "jetkvm_timesync_rtc_update_count", + Help: "The number of times the RTC has been updated", + }, + ) + metricNtpTotalSuccessCount = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_ntp_total_success_count", + Help: "The total number of successful NTP requests", + }, + ) + metricNtpTotalRequestCount = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_ntp_total_request_count", + Help: "The total number of NTP requests sent", + }, + ) + metricNtpSuccessCount = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_ntp_success_count", + Help: "The number of successful NTP requests", + }, + []string{"url"}, + ) + metricNtpRequestCount = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_ntp_request_count", + Help: "The number of NTP requests sent to the server", + }, + []string{"url"}, + ) + metricNtpServerLastRTT = promauto.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "jetkvm_timesync_ntp_server_last_rtt", + Help: "The last RTT of the NTP server in milliseconds", + }, + []string{"url"}, + ) + metricNtpServerRttHistogram = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "jetkvm_timesync_ntp_server_rtt", + Help: "The histogram of the RTT of the NTP server in milliseconds", + Buckets: []float64{ + 10, 25, 50, 100, 200, 300, 500, 1000, + }, + }, + []string{"url"}, + ) + metricNtpServerInfo = promauto.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "jetkvm_timesync_ntp_server_info", + Help: "The info of the NTP server", + }, + []string{"url", "reference", "stratum", "precision"}, + ) + + metricHttpTotalSuccessCount = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_http_total_success_count", + Help: "The total number of successful HTTP requests", + }, + ) + metricHttpTotalRequestCount = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_http_total_request_count", + Help: "The total number of HTTP requests sent", + }, + ) + metricHttpTotalCancelCount = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_http_total_cancel_count", + Help: "The total number of HTTP requests cancelled", + }, + ) + metricHttpSuccessCount = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_http_success_count", + Help: "The number of successful HTTP requests", + }, + []string{"url"}, + ) + metricHttpRequestCount = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_http_request_count", + Help: "The number of HTTP requests sent to the server", + }, + []string{"url"}, + ) + metricHttpCancelCount = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "jetkvm_timesync_http_cancel_count", + Help: "The number of HTTP requests cancelled", + }, + []string{"url"}, + ) + metricHttpServerLastRTT = promauto.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "jetkvm_timesync_http_server_last_rtt", + Help: "The last RTT of the HTTP server in milliseconds", + }, + []string{"url"}, + ) + metricHttpServerRttHistogram = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "jetkvm_timesync_http_server_rtt", + Help: "The histogram of the RTT of the HTTP server in milliseconds", + Buckets: []float64{ + 10, 25, 50, 100, 200, 300, 500, 1000, + }, + }, + []string{"url"}, + ) + metricHttpServerInfo = promauto.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "jetkvm_timesync_http_server_info", + Help: "The info of the HTTP server", + }, + []string{"url", "http_code"}, + ) +) diff --git a/internal/timesync/ntp.go b/internal/timesync/ntp.go new file mode 100644 index 0000000..41656b7 --- /dev/null +++ b/internal/timesync/ntp.go @@ -0,0 +1,113 @@ +package timesync + +import ( + "math/rand/v2" + "strconv" + "time" + + "github.com/beevik/ntp" +) + +var defaultNTPServers = []string{ + "time.apple.com", + "time.aws.com", + "time.windows.com", + "time.google.com", + "162.159.200.123", // time.cloudflare.com + "0.pool.ntp.org", + "1.pool.ntp.org", + "2.pool.ntp.org", + "3.pool.ntp.org", +} + +func (t *TimeSync) queryNetworkTime() (now *time.Time, offset *time.Duration) { + chunkSize := 4 + ntpServers := t.ntpServers + + // shuffle the ntp servers to avoid always querying the same servers + rand.Shuffle(len(ntpServers), func(i, j int) { ntpServers[i], ntpServers[j] = ntpServers[j], ntpServers[i] }) + + for i := 0; i < len(ntpServers); i += chunkSize { + chunk := ntpServers[i:min(i+chunkSize, len(ntpServers))] + now, offset := t.queryMultipleNTP(chunk, timeSyncTimeout) + if now != nil { + return now, offset + } + } + + return nil, nil +} + +type ntpResult struct { + now *time.Time + offset *time.Duration +} + +func (t *TimeSync) queryMultipleNTP(servers []string, timeout time.Duration) (now *time.Time, offset *time.Duration) { + results := make(chan *ntpResult, len(servers)) + for _, server := range servers { + go func(server string) { + scopedLogger := t.l.With(). + Str("server", server). + Logger() + + // increase request count + metricNtpTotalRequestCount.Inc() + metricNtpRequestCount.WithLabelValues(server).Inc() + + // query the server + now, response, err := queryNtpServer(server, timeout) + + // set the last RTT + metricNtpServerLastRTT.WithLabelValues( + server, + ).Set(float64(response.RTT.Milliseconds())) + + // set the RTT histogram + metricNtpServerRttHistogram.WithLabelValues( + server, + ).Observe(float64(response.RTT.Milliseconds())) + + // set the server info + metricNtpServerInfo.WithLabelValues( + server, + response.ReferenceString(), + strconv.Itoa(int(response.Stratum)), + strconv.Itoa(int(response.Precision)), + ).Set(1) + + if err == nil { + // increase success count + metricNtpTotalSuccessCount.Inc() + metricNtpSuccessCount.WithLabelValues(server).Inc() + + scopedLogger.Info(). + Str("time", now.Format(time.RFC3339)). + Str("reference", response.ReferenceString()). + Str("rtt", response.RTT.String()). + Str("clockOffset", response.ClockOffset.String()). + Uint8("stratum", response.Stratum). + Msg("NTP server returned time") + results <- &ntpResult{ + now: now, + offset: &response.ClockOffset, + } + } else { + scopedLogger.Warn(). + Str("error", err.Error()). + Msg("failed to query NTP server") + } + }(server) + } + + result := <-results + return result.now, result.offset +} + +func queryNtpServer(server string, timeout time.Duration) (now *time.Time, response *ntp.Response, err error) { + resp, err := ntp.QueryWithOptions(server, ntp.QueryOptions{Timeout: timeout}) + if err != nil { + return nil, nil, err + } + return &resp.Time, resp, nil +} diff --git a/internal/timesync/rtc.go b/internal/timesync/rtc.go new file mode 100644 index 0000000..92ee485 --- /dev/null +++ b/internal/timesync/rtc.go @@ -0,0 +1,26 @@ +package timesync + +import ( + "fmt" + "os" +) + +var ( + rtcDeviceSearchPaths = []string{ + "/dev/rtc", + "/dev/rtc0", + "/dev/rtc1", + "/dev/misc/rtc", + "/dev/misc/rtc0", + "/dev/misc/rtc1", + } +) + +func getRtcDevicePath() (string, error) { + for _, path := range rtcDeviceSearchPaths { + if _, err := os.Stat(path); err == nil { + return path, nil + } + } + return "", fmt.Errorf("rtc device not found") +} diff --git a/internal/timesync/rtc_linux.go b/internal/timesync/rtc_linux.go new file mode 100644 index 0000000..27e4ec7 --- /dev/null +++ b/internal/timesync/rtc_linux.go @@ -0,0 +1,105 @@ +//go:build linux + +package timesync + +import ( + "fmt" + "os" + "time" + + "golang.org/x/sys/unix" +) + +func TimetoRtcTime(t time.Time) unix.RTCTime { + return unix.RTCTime{ + Sec: int32(t.Second()), + Min: int32(t.Minute()), + Hour: int32(t.Hour()), + Mday: int32(t.Day()), + Mon: int32(t.Month() - 1), + Year: int32(t.Year() - 1900), + Wday: int32(0), + Yday: int32(0), + Isdst: int32(0), + } +} + +func RtcTimetoTime(t unix.RTCTime) time.Time { + return time.Date( + int(t.Year)+1900, + time.Month(t.Mon+1), + int(t.Mday), + int(t.Hour), + int(t.Min), + int(t.Sec), + 0, + time.UTC, + ) +} + +func (t *TimeSync) getRtcDevice() (*os.File, error) { + if t.rtcDevice == nil { + file, err := os.OpenFile(t.rtcDevicePath, os.O_RDWR, 0666) + if err != nil { + return nil, err + } + t.rtcDevice = file + } + return t.rtcDevice, nil +} + +func (t *TimeSync) getRtcDeviceFd() (int, error) { + device, err := t.getRtcDevice() + if err != nil { + return 0, err + } + return int(device.Fd()), nil +} + +// Read implements Read for the Linux RTC +func (t *TimeSync) readRtcTime() (time.Time, error) { + fd, err := t.getRtcDeviceFd() + if err != nil { + return time.Time{}, fmt.Errorf("failed to get RTC device fd: %w", err) + } + + rtcTime, err := unix.IoctlGetRTCTime(fd) + if err != nil { + return time.Time{}, fmt.Errorf("failed to get RTC time: %w", err) + } + + date := RtcTimetoTime(*rtcTime) + + return date, nil +} + +// Set implements Set for the Linux RTC +// ... +// It might be not accurate as the time consumed by the system call is not taken into account +// but it's good enough for our purposes +func (t *TimeSync) setRtcTime(tu time.Time) error { + rt := TimetoRtcTime(tu) + + fd, err := t.getRtcDeviceFd() + if err != nil { + return fmt.Errorf("failed to get RTC device fd: %w", err) + } + + currentRtcTime, err := t.readRtcTime() + if err != nil { + return fmt.Errorf("failed to read RTC time: %w", err) + } + + t.l.Info(). + Interface("rtc_time", tu). + Str("offset", tu.Sub(currentRtcTime).String()). + Msg("set rtc time") + + if err := unix.IoctlSetRTCTime(fd, &rt); err != nil { + return fmt.Errorf("failed to set RTC time: %w", err) + } + + metricRTCUpdateCount.Inc() + + return nil +} diff --git a/internal/timesync/rtc_notlinux.go b/internal/timesync/rtc_notlinux.go new file mode 100644 index 0000000..e3c1b20 --- /dev/null +++ b/internal/timesync/rtc_notlinux.go @@ -0,0 +1,16 @@ +//go:build !linux + +package timesync + +import ( + "errors" + "time" +) + +func (t *TimeSync) readRtcTime() (time.Time, error) { + return time.Now(), nil +} + +func (t *TimeSync) setRtcTime(tu time.Time) error { + return errors.New("not supported") +} diff --git a/internal/timesync/timesync.go b/internal/timesync/timesync.go new file mode 100644 index 0000000..e956cf9 --- /dev/null +++ b/internal/timesync/timesync.go @@ -0,0 +1,208 @@ +package timesync + +import ( + "fmt" + "os" + "os/exec" + "sync" + "time" + + "github.com/jetkvm/kvm/internal/network" + "github.com/rs/zerolog" +) + +const ( + timeSyncRetryStep = 5 * time.Second + timeSyncRetryMaxInt = 1 * time.Minute + timeSyncWaitNetChkInt = 100 * time.Millisecond + timeSyncWaitNetUpInt = 3 * time.Second + timeSyncInterval = 1 * time.Hour + timeSyncTimeout = 2 * time.Second +) + +var ( + timeSyncRetryInterval = 0 * time.Second +) + +type TimeSync struct { + syncLock *sync.Mutex + l *zerolog.Logger + + ntpServers []string + httpUrls []string + networkConfig *network.NetworkConfig + + rtcDevicePath string + rtcDevice *os.File //nolint:unused + rtcLock *sync.Mutex + + syncSuccess bool + + preCheckFunc func() (bool, error) +} + +type TimeSyncOptions struct { + PreCheckFunc func() (bool, error) + Logger *zerolog.Logger + NetworkConfig *network.NetworkConfig +} + +type SyncMode struct { + Ntp bool + Http bool + Ordering []string + NtpUseFallback bool + HttpUseFallback bool +} + +func NewTimeSync(opts *TimeSyncOptions) *TimeSync { + rtcDevice, err := getRtcDevicePath() + if err != nil { + opts.Logger.Error().Err(err).Msg("failed to get RTC device path") + } else { + opts.Logger.Info().Str("path", rtcDevice).Msg("RTC device found") + } + + t := &TimeSync{ + syncLock: &sync.Mutex{}, + l: opts.Logger, + rtcDevicePath: rtcDevice, + rtcLock: &sync.Mutex{}, + preCheckFunc: opts.PreCheckFunc, + ntpServers: defaultNTPServers, + httpUrls: defaultHTTPUrls, + networkConfig: opts.NetworkConfig, + } + + if t.rtcDevicePath != "" { + rtcTime, _ := t.readRtcTime() + t.l.Info().Interface("rtc_time", rtcTime).Msg("read RTC time") + } + + return t +} + +func (t *TimeSync) getSyncMode() SyncMode { + syncMode := SyncMode{ + NtpUseFallback: true, + HttpUseFallback: true, + } + var syncModeString string + + if t.networkConfig != nil { + syncModeString = t.networkConfig.TimeSyncMode.String + if t.networkConfig.TimeSyncDisableFallback.Bool { + syncMode.NtpUseFallback = false + syncMode.HttpUseFallback = false + } + } + + switch syncModeString { + case "ntp_only": + syncMode.Ntp = true + case "http_only": + syncMode.Http = true + default: + syncMode.Ntp = true + syncMode.Http = true + } + + return syncMode +} + +func (t *TimeSync) doTimeSync() { + metricTimeSyncStatus.Set(0) + for { + if ok, err := t.preCheckFunc(); !ok { + if err != nil { + t.l.Error().Err(err).Msg("pre-check failed") + } + time.Sleep(timeSyncWaitNetChkInt) + continue + } + + t.l.Info().Msg("syncing system time") + start := time.Now() + err := t.Sync() + if err != nil { + t.l.Error().Str("error", err.Error()).Msg("failed to sync system time") + + // retry after a delay + timeSyncRetryInterval += timeSyncRetryStep + time.Sleep(timeSyncRetryInterval) + // reset the retry interval if it exceeds the max interval + if timeSyncRetryInterval > timeSyncRetryMaxInt { + timeSyncRetryInterval = 0 + } + + continue + } + t.syncSuccess = true + t.l.Info().Str("now", time.Now().Format(time.RFC3339)). + Str("time_taken", time.Since(start).String()). + Msg("time sync successful") + + metricTimeSyncStatus.Set(1) + + time.Sleep(timeSyncInterval) // after the first sync is done + } +} + +func (t *TimeSync) Sync() error { + var ( + now *time.Time + offset *time.Duration + ) + + syncMode := t.getSyncMode() + + metricTimeSyncCount.Inc() + + if syncMode.Ntp { + now, offset = t.queryNetworkTime() + } + + if syncMode.Http && now == nil { + now = t.queryAllHttpTime() + } + + if now == nil { + return fmt.Errorf("failed to get time from any source") + } + + if offset != nil { + newNow := time.Now().Add(*offset) + now = &newNow + } + + err := t.setSystemTime(*now) + if err != nil { + return fmt.Errorf("failed to set system time: %w", err) + } + + metricTimeSyncSuccessCount.Inc() + + return nil +} + +func (t *TimeSync) IsSyncSuccess() bool { + return t.syncSuccess +} + +func (t *TimeSync) Start() { + go t.doTimeSync() +} + +func (t *TimeSync) setSystemTime(now time.Time) error { + nowStr := now.Format("2006-01-02 15:04:05") + output, err := exec.Command("date", "-s", nowStr).CombinedOutput() + if err != nil { + return fmt.Errorf("failed to run date -s: %w, %s", err, string(output)) + } + + if t.rtcDevicePath != "" { + return t.setRtcTime(now) + } + + return nil +} diff --git a/internal/udhcpc/options.go b/internal/udhcpc/options.go new file mode 100644 index 0000000..10c9f75 --- /dev/null +++ b/internal/udhcpc/options.go @@ -0,0 +1,12 @@ +package udhcpc + +func (u *DHCPClient) GetNtpServers() []string { + if u.lease == nil { + return nil + } + servers := make([]string, len(u.lease.NTPServers)) + for i, server := range u.lease.NTPServers { + servers[i] = server.String() + } + return servers +} diff --git a/internal/udhcpc/parser.go b/internal/udhcpc/parser.go new file mode 100644 index 0000000..66c3ba2 --- /dev/null +++ b/internal/udhcpc/parser.go @@ -0,0 +1,186 @@ +package udhcpc + +import ( + "bufio" + "encoding/json" + "fmt" + "net" + "os" + "reflect" + "strconv" + "strings" + "time" +) + +type Lease struct { + // from https://udhcp.busybox.net/README.udhcpc + IPAddress net.IP `env:"ip" json:"ip"` // The obtained IP + Netmask net.IP `env:"subnet" json:"netmask"` // The assigned subnet mask + Broadcast net.IP `env:"broadcast" json:"broadcast"` // The broadcast address for this network + TTL int `env:"ipttl" json:"ttl,omitempty"` // The TTL to use for this network + MTU int `env:"mtu" json:"mtu,omitempty"` // The MTU to use for this network + HostName string `env:"hostname" json:"hostname,omitempty"` // The assigned hostname + Domain string `env:"domain" json:"domain,omitempty"` // The domain name of the network + BootPNextServer net.IP `env:"siaddr" json:"bootp_next_server,omitempty"` // The bootp next server option + BootPServerName string `env:"sname" json:"bootp_server_name,omitempty"` // The bootp server name option + BootPFile string `env:"boot_file" json:"bootp_file,omitempty"` // The bootp boot file option + Timezone string `env:"timezone" json:"timezone,omitempty"` // Offset in seconds from UTC + Routers []net.IP `env:"router" json:"routers,omitempty"` // A list of routers + DNS []net.IP `env:"dns" json:"dns_servers,omitempty"` // A list of DNS servers + NTPServers []net.IP `env:"ntpsrv" json:"ntp_servers,omitempty"` // A list of NTP servers + LPRServers []net.IP `env:"lprsvr" json:"lpr_servers,omitempty"` // A list of LPR servers + TimeServers []net.IP `env:"timesvr" json:"_time_servers,omitempty"` // A list of time servers (obsolete) + IEN116NameServers []net.IP `env:"namesvr" json:"_name_servers,omitempty"` // A list of IEN 116 name servers (obsolete) + LogServers []net.IP `env:"logsvr" json:"_log_servers,omitempty"` // A list of MIT-LCS UDP log servers (obsolete) + CookieServers []net.IP `env:"cookiesvr" json:"_cookie_servers,omitempty"` // A list of RFC 865 cookie servers (obsolete) + WINSServers []net.IP `env:"wins" json:"_wins_servers,omitempty"` // A list of WINS servers + SwapServer net.IP `env:"swapsvr" json:"_swap_server,omitempty"` // The IP address of the client's swap server + BootSize int `env:"bootsize" json:"bootsize,omitempty"` // The length in 512 octect blocks of the bootfile + RootPath string `env:"rootpath" json:"root_path,omitempty"` // The path name of the client's root disk + LeaseTime time.Duration `env:"lease" json:"lease,omitempty"` // The lease time, in seconds + DHCPType string `env:"dhcptype" json:"dhcp_type,omitempty"` // DHCP message type (safely ignored) + ServerID string `env:"serverid" json:"server_id,omitempty"` // The IP of the server + Message string `env:"message" json:"reason,omitempty"` // Reason for a DHCPNAK + TFTPServerName string `env:"tftp" json:"tftp,omitempty"` // The TFTP server name + BootFileName string `env:"bootfile" json:"bootfile,omitempty"` // The boot file name + Uptime time.Duration `env:"uptime" json:"uptime,omitempty"` // The uptime of the device when the lease was obtained, in seconds + LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // The expiry time of the lease + isEmpty map[string]bool +} + +func (l *Lease) setIsEmpty(m map[string]bool) { + l.isEmpty = m +} + +func (l *Lease) IsEmpty(key string) bool { + return l.isEmpty[key] +} + +func (l *Lease) ToJSON() string { + json, err := json.Marshal(l) + if err != nil { + return "" + } + return string(json) +} + +func (l *Lease) SetLeaseExpiry() (time.Time, error) { + if l.Uptime == 0 || l.LeaseTime == 0 { + return time.Time{}, fmt.Errorf("uptime or lease time isn't set") + } + + // get the uptime of the device + + file, err := os.Open("/proc/uptime") + if err != nil { + return time.Time{}, fmt.Errorf("failed to open uptime file: %w", err) + } + defer file.Close() + + var uptime time.Duration + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + text := scanner.Text() + parts := strings.Split(text, " ") + uptime, err = time.ParseDuration(parts[0] + "s") + + if err != nil { + return time.Time{}, fmt.Errorf("failed to parse uptime: %w", err) + } + } + + relativeLeaseRemaining := (l.Uptime + l.LeaseTime) - uptime + leaseExpiry := time.Now().Add(relativeLeaseRemaining) + + l.LeaseExpiry = &leaseExpiry + + return leaseExpiry, nil +} + +func UnmarshalDHCPCLease(lease *Lease, str string) error { + // parse the lease file as a map + data := make(map[string]string) + for _, line := range strings.Split(str, "\n") { + line = strings.TrimSpace(line) + // skip empty lines and comments + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + data[key] = value + } + + // now iterate over the lease struct and set the values + leaseType := reflect.TypeOf(lease).Elem() + leaseValue := reflect.ValueOf(lease).Elem() + + valuesParsed := make(map[string]bool) + + for i := 0; i < leaseType.NumField(); i++ { + field := leaseValue.Field(i) + + // get the env tag + key := leaseType.Field(i).Tag.Get("env") + if key == "" { + continue + } + + valuesParsed[key] = false + + // get the value from the data map + value, ok := data[key] + if !ok || value == "" { + continue + } + + switch field.Interface().(type) { + case string: + field.SetString(value) + case int: + val, err := strconv.Atoi(value) + if err != nil { + continue + } + field.SetInt(int64(val)) + case time.Duration: + val, err := time.ParseDuration(value + "s") + if err != nil { + continue + } + field.Set(reflect.ValueOf(val)) + case net.IP: + ip := net.ParseIP(value) + if ip == nil { + continue + } + field.Set(reflect.ValueOf(ip)) + case []net.IP: + val := make([]net.IP, 0) + for _, ipStr := range strings.Fields(value) { + ip := net.ParseIP(ipStr) + if ip == nil { + continue + } + val = append(val, ip) + } + field.Set(reflect.ValueOf(val)) + default: + return fmt.Errorf("unsupported field `%s` type: %s", key, field.Type().String()) + } + + valuesParsed[key] = true + } + + lease.setIsEmpty(valuesParsed) + + return nil +} diff --git a/internal/udhcpc/parser_test.go b/internal/udhcpc/parser_test.go new file mode 100644 index 0000000..423ab53 --- /dev/null +++ b/internal/udhcpc/parser_test.go @@ -0,0 +1,74 @@ +package udhcpc + +import ( + "testing" + "time" +) + +func TestUnmarshalDHCPCLease(t *testing.T) { + lease := &Lease{} + err := UnmarshalDHCPCLease(lease, ` +# generated @ Mon Jan 4 19:31:53 UTC 2021 +# 19:31:53 up 0 min, 0 users, load average: 0.72, 0.14, 0.04 +# the date might be inaccurate if the clock is not set +ip=192.168.0.240 +siaddr=192.168.0.1 +sname= +boot_file= +subnet=255.255.255.0 +timezone= +router=192.168.0.1 +timesvr= +namesvr= +dns=172.19.53.2 +logsvr= +cookiesvr= +lprsvr= +hostname= +bootsize= +domain= +swapsvr= +rootpath= +ipttl= +mtu= +broadcast= +ntpsrv=162.159.200.123 +wins= +lease=172800 +dhcptype= +serverid=192.168.0.1 +message= +tftp= +bootfile= + `) + if lease.IPAddress.String() != "192.168.0.240" { + t.Fatalf("expected ip to be 192.168.0.240, got %s", lease.IPAddress.String()) + } + if lease.Netmask.String() != "255.255.255.0" { + t.Fatalf("expected netmask to be 255.255.255.0, got %s", lease.Netmask.String()) + } + if len(lease.Routers) != 1 { + t.Fatalf("expected 1 router, got %d", len(lease.Routers)) + } + if lease.Routers[0].String() != "192.168.0.1" { + t.Fatalf("expected router to be 192.168.0.1, got %s", lease.Routers[0].String()) + } + if len(lease.NTPServers) != 1 { + t.Fatalf("expected 1 timeserver, got %d", len(lease.NTPServers)) + } + if lease.NTPServers[0].String() != "162.159.200.123" { + t.Fatalf("expected timeserver to be 162.159.200.123, got %s", lease.NTPServers[0].String()) + } + if len(lease.DNS) != 1 { + t.Fatalf("expected 1 dns, got %d", len(lease.DNS)) + } + if lease.DNS[0].String() != "172.19.53.2" { + t.Fatalf("expected dns to be 172.19.53.2, got %s", lease.DNS[0].String()) + } + if lease.LeaseTime != 172800*time.Second { + t.Fatalf("expected lease time to be 172800 seconds, got %d", lease.LeaseTime) + } + if err != nil { + t.Fatal(err) + } +} diff --git a/internal/udhcpc/proc.go b/internal/udhcpc/proc.go new file mode 100644 index 0000000..69c2ab9 --- /dev/null +++ b/internal/udhcpc/proc.go @@ -0,0 +1,212 @@ +package udhcpc + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" + "syscall" +) + +func readFileNoStat(filename string) ([]byte, error) { + const maxBufferSize = 1024 * 1024 + + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + + reader := io.LimitReader(f, maxBufferSize) + return io.ReadAll(reader) +} + +func toCmdline(path string) ([]string, error) { + data, err := readFileNoStat(path) + if err != nil { + return nil, err + } + + if len(data) < 1 { + return []string{}, nil + } + + return strings.Split(string(bytes.TrimRight(data, "\x00")), "\x00"), nil +} + +func (p *DHCPClient) findUdhcpcProcess() (int, error) { + // read procfs for udhcpc processes + // we do not use procfs.AllProcs() because we want to avoid the overhead of reading the entire procfs + processes, err := os.ReadDir("/proc") + if err != nil { + return 0, err + } + + // iterate over the processes + for _, d := range processes { + // check if file is numeric + pid, err := strconv.Atoi(d.Name()) + if err != nil { + continue + } + + // check if it's a directory + if !d.IsDir() { + continue + } + + cmdline, err := toCmdline(filepath.Join("/proc", d.Name(), "cmdline")) + if err != nil { + continue + } + + if len(cmdline) < 1 { + continue + } + + if cmdline[0] != "udhcpc" { + continue + } + + cmdlineText := strings.Join(cmdline, " ") + + // check if it's a udhcpc process + if strings.Contains(cmdlineText, fmt.Sprintf("-i %s", p.InterfaceName)) { + p.logger.Debug(). + Str("pid", d.Name()). + Interface("cmdline", cmdline). + Msg("found udhcpc process") + return pid, nil + } + } + + return 0, errors.New("udhcpc process not found") +} + +func (c *DHCPClient) getProcessPid() (int, error) { + var pid int + if c.pidFile != "" { + // try to read the pid file + pidHandle, err := os.ReadFile(c.pidFile) + if err != nil { + c.logger.Warn().Err(err). + Str("pidFile", c.pidFile).Msg("failed to read udhcpc pid file") + } + + // if it exists, try to read the pid + if pidHandle != nil { + pidFromFile, err := strconv.Atoi(string(pidHandle)) + if err != nil { + c.logger.Warn().Err(err). + Str("pidFile", c.pidFile).Msg("failed to convert pid file to int") + } + pid = pidFromFile + } + } + + // if the pid is 0, try to find the pid using procfs + if pid == 0 { + newPid, err := c.findUdhcpcProcess() + if err != nil { + return 0, err + } + pid = newPid + } + + return pid, nil +} + +func (c *DHCPClient) getProcess() *os.Process { + pid, err := c.getProcessPid() + if err != nil { + return nil + } + + process, err := os.FindProcess(pid) + if err != nil { + c.logger.Warn().Err(err). + Int("pid", pid).Msg("failed to find process") + return nil + } + + return process +} + +func (c *DHCPClient) GetProcess() *os.Process { + if c.process == nil { + process := c.getProcess() + if process == nil { + return nil + } + c.process = process + } + + err := c.process.Signal(syscall.Signal(0)) + if err != nil && errors.Is(err, os.ErrProcessDone) { + oldPid := c.process.Pid + + c.process = nil + c.process = c.getProcess() + if c.process == nil { + c.logger.Error().Msg("failed to find new udhcpc process") + return nil + } + c.logger.Warn(). + Int("oldPid", oldPid). + Int("newPid", c.process.Pid). + Msg("udhcpc process pid changed") + } else if err != nil { + c.logger.Warn().Err(err). + Int("pid", c.process.Pid).Msg("udhcpc process is not running") + } + + return c.process +} + +func (c *DHCPClient) KillProcess() error { + process := c.GetProcess() + if process == nil { + return nil + } + + return process.Kill() +} + +func (c *DHCPClient) ReleaseProcess() error { + process := c.GetProcess() + if process == nil { + return nil + } + + return process.Release() +} + +func (c *DHCPClient) signalProcess(sig syscall.Signal) error { + process := c.GetProcess() + if process == nil { + return nil + } + + s := process.Signal(sig) + if s != nil { + c.logger.Warn().Err(s). + Int("pid", process.Pid). + Str("signal", sig.String()). + Msg("failed to signal udhcpc process") + return s + } + + return nil +} + +func (c *DHCPClient) Renew() error { + return c.signalProcess(syscall.SIGUSR1) +} + +func (c *DHCPClient) Release() error { + return c.signalProcess(syscall.SIGUSR2) +} diff --git a/internal/udhcpc/udhcpc.go b/internal/udhcpc/udhcpc.go new file mode 100644 index 0000000..70ac1b8 --- /dev/null +++ b/internal/udhcpc/udhcpc.go @@ -0,0 +1,191 @@ +package udhcpc + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog" +) + +const ( + DHCPLeaseFile = "/run/udhcpc.%s.info" + DHCPPidFile = "/run/udhcpc.%s.pid" +) + +type DHCPClient struct { + InterfaceName string + leaseFile string + pidFile string + lease *Lease + logger *zerolog.Logger + process *os.Process + onLeaseChange func(lease *Lease) +} + +type DHCPClientOptions struct { + InterfaceName string + PidFile string + Logger *zerolog.Logger + OnLeaseChange func(lease *Lease) +} + +var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel) + +func NewDHCPClient(options *DHCPClientOptions) *DHCPClient { + if options.Logger == nil { + options.Logger = &defaultLogger + } + + l := options.Logger.With().Str("interface", options.InterfaceName).Logger() + return &DHCPClient{ + InterfaceName: options.InterfaceName, + logger: &l, + leaseFile: fmt.Sprintf(DHCPLeaseFile, options.InterfaceName), + pidFile: options.PidFile, + onLeaseChange: options.OnLeaseChange, + } +} + +func (c *DHCPClient) getWatchPaths() []string { + watchPaths := make(map[string]interface{}) + watchPaths[filepath.Dir(c.leaseFile)] = nil + + if c.pidFile != "" { + watchPaths[filepath.Dir(c.pidFile)] = nil + } + + paths := make([]string, 0) + for path := range watchPaths { + paths = append(paths, path) + } + return paths +} + +// Run starts the DHCP client and watches the lease file for changes. +// this isn't a blocking call, and the lease file is reloaded when a change is detected. +func (c *DHCPClient) Run() error { + err := c.loadLeaseFile() + if err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + defer watcher.Close() + + go func() { + for { + select { + case event, ok := <-watcher.Events: + if !ok { + continue + } + if !event.Has(fsnotify.Write) && !event.Has(fsnotify.Create) { + continue + } + + if event.Name == c.leaseFile { + c.logger.Debug(). + Str("event", event.Op.String()). + Str("path", event.Name). + Msg("udhcpc lease file updated, reloading lease") + _ = c.loadLeaseFile() + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + c.logger.Error().Err(err).Msg("error watching lease file") + } + } + }() + + for _, path := range c.getWatchPaths() { + err = watcher.Add(path) + if err != nil { + c.logger.Error(). + Err(err). + Str("path", path). + Msg("failed to watch directory") + return err + } + } + + // TODO: update udhcpc pid file + // we'll comment this out for now because the pid might change + // process := c.GetProcess() + // if process == nil { + // c.logger.Error().Msg("udhcpc process not found") + // } + + // block the goroutine until the lease file is updated + <-make(chan struct{}) + + return nil +} + +func (c *DHCPClient) loadLeaseFile() error { + file, err := os.ReadFile(c.leaseFile) + if err != nil { + return err + } + + data := string(file) + if data == "" { + c.logger.Debug().Msg("udhcpc lease file is empty") + return nil + } + + lease := &Lease{} + err = UnmarshalDHCPCLease(lease, string(file)) + if err != nil { + return err + } + + isFirstLoad := c.lease == nil + c.lease = lease + + if lease.IPAddress == nil { + c.logger.Info(). + Interface("lease", lease). + Str("data", string(file)). + Msg("udhcpc lease cleared") + return nil + } + + msg := "udhcpc lease updated" + if isFirstLoad { + msg = "udhcpc lease loaded" + } + + leaseExpiry, err := lease.SetLeaseExpiry() + if err != nil { + c.logger.Error().Err(err).Msg("failed to get dhcp lease expiry") + } else { + expiresIn := time.Until(leaseExpiry) + c.logger.Info(). + Interface("expiry", leaseExpiry). + Str("expiresIn", expiresIn.String()). + Msg("current dhcp lease expiry time calculated") + } + + c.onLeaseChange(lease) + + c.logger.Info(). + Str("ip", lease.IPAddress.String()). + Str("leaseTime", lease.LeaseTime.String()). + Interface("data", lease). + Msg(msg) + + return nil +} + +func (c *DHCPClient) GetLease() *Lease { + return c.lease +} diff --git a/internal/websecure/store.go b/internal/websecure/store.go index 69ae3ef..ea7911c 100644 --- a/internal/websecure/store.go +++ b/internal/websecure/store.go @@ -96,7 +96,11 @@ func (s *CertStore) loadCertificate(hostname string) { s.certificates[hostname] = &cert - s.log.Info().Str("hostname", hostname).Msg("Loaded certificate") + if hostname == selfSignerCAMagicName { + s.log.Info().Msg("loaded CA certificate") + } else { + s.log.Info().Str("hostname", hostname).Msg("loaded certificate") + } } // GetCertificate returns the certificate for the given hostname @@ -131,7 +135,7 @@ func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key if !ignoreWarning { return nil, fmt.Errorf("certificate does not match hostname: %w", err) } - s.log.Warn().Err(err).Msg("Certificate does not match hostname") + s.log.Warn().Err(err).Msg("certificate does not match hostname") } } diff --git a/jsonrpc.go b/jsonrpc.go index 248390e..d35f635 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -962,6 +962,10 @@ var rpcHandlers = map[string]RPCHandler{ "getDeviceID": {Func: rpcGetDeviceID}, "deregisterDevice": {Func: rpcDeregisterDevice}, "getCloudState": {Func: rpcGetCloudState}, + "getNetworkState": {Func: rpcGetNetworkState}, + "getNetworkSettings": {Func: rpcGetNetworkSettings}, + "setNetworkSettings": {Func: rpcSetNetworkSettings, Params: []string{"settings"}}, + "renewDHCPLease": {Func: rpcRenewDHCPLease}, "keyboardReport": {Func: rpcKeyboardReport, Params: []string{"modifier", "keys"}}, "absMouseReport": {Func: rpcAbsMouseReport, Params: []string{"x", "y", "buttons"}}, "relMouseReport": {Func: rpcRelMouseReport, Params: []string{"dx", "dy", "buttons"}}, diff --git a/log.go b/log.go index ed46852..b353a2c 100644 --- a/log.go +++ b/log.go @@ -1,291 +1,32 @@ package kvm import ( - "fmt" - "io" - "os" - "strings" - "sync" - "time" - - "github.com/pion/logging" + "github.com/jetkvm/kvm/internal/logging" "github.com/rs/zerolog" ) -type Logger struct { - l *zerolog.Logger - scopeLoggers map[string]*zerolog.Logger - scopeLevels map[string]zerolog.Level - scopeLevelMutex sync.Mutex - - defaultLogLevelFromEnv zerolog.Level - defaultLogLevelFromConfig zerolog.Level - defaultLogLevel zerolog.Level -} - -const ( - defaultLogLevel = zerolog.ErrorLevel -) - -type logOutput struct { - mu *sync.Mutex -} - -func (w *logOutput) Write(p []byte) (n int, err error) { - w.mu.Lock() - defer w.mu.Unlock() - - // TODO: write to file or syslog - - return len(p), nil -} - -var ( - consoleLogOutput io.Writer = zerolog.ConsoleWriter{ - Out: os.Stdout, - TimeFormat: time.RFC3339, - PartsOrder: []string{"time", "level", "scope", "component", "message"}, - FieldsExclude: []string{"scope", "component"}, - FormatPartValueByName: func(value interface{}, name string) string { - val := fmt.Sprintf("%s", value) - if name == "component" { - if value == nil { - return "-" - } - } - return val - }, - } - fileLogOutput io.Writer = &logOutput{mu: &sync.Mutex{}} - defaultLogOutput = zerolog.MultiLevelWriter(consoleLogOutput, fileLogOutput) - - zerologLevels = map[string]zerolog.Level{ - "DISABLE": zerolog.Disabled, - "NOLEVEL": zerolog.NoLevel, - "PANIC": zerolog.PanicLevel, - "FATAL": zerolog.FatalLevel, - "ERROR": zerolog.ErrorLevel, - "WARN": zerolog.WarnLevel, - "INFO": zerolog.InfoLevel, - "DEBUG": zerolog.DebugLevel, - "TRACE": zerolog.TraceLevel, - } - - rootZerologLogger = zerolog.New(defaultLogOutput).With(). - Str("scope", "jetkvm"). - Timestamp(). - Stack(). - Logger() - rootLogger = NewLogger(rootZerologLogger) -) - -func NewLogger(zerologLogger zerolog.Logger) *Logger { - return &Logger{ - l: &zerologLogger, - scopeLoggers: make(map[string]*zerolog.Logger), - scopeLevels: make(map[string]zerolog.Level), - scopeLevelMutex: sync.Mutex{}, - defaultLogLevelFromEnv: -2, - defaultLogLevelFromConfig: -2, - defaultLogLevel: defaultLogLevel, - } -} - -func (l *Logger) updateLogLevel() { - l.scopeLevelMutex.Lock() - defer l.scopeLevelMutex.Unlock() - - l.scopeLevels = make(map[string]zerolog.Level) - - finalDefaultLogLevel := l.defaultLogLevel - - for name, level := range zerologLevels { - env := os.Getenv(fmt.Sprintf("JETKVM_LOG_%s", name)) - - if env == "" { - env = os.Getenv(fmt.Sprintf("PION_LOG_%s", name)) - } - - if env == "" { - env = os.Getenv(fmt.Sprintf("PIONS_LOG_%s", name)) - } - - if env == "" { - continue - } - - if strings.ToLower(env) == "all" { - l.defaultLogLevelFromEnv = level - - if finalDefaultLogLevel > level { - finalDefaultLogLevel = level - } - - continue - } - - scopes := strings.Split(strings.ToLower(env), ",") - for _, scope := range scopes { - l.scopeLevels[scope] = level - } - } - - l.defaultLogLevel = finalDefaultLogLevel -} - -func (l *Logger) getScopeLoggerLevel(scope string) zerolog.Level { - if l.scopeLevels == nil { - l.updateLogLevel() - } - - var scopeLevel zerolog.Level - if l.defaultLogLevelFromConfig != -2 { - scopeLevel = l.defaultLogLevelFromConfig - } - if l.defaultLogLevelFromEnv != -2 { - scopeLevel = l.defaultLogLevelFromEnv - } - - // if the scope is not in the map, use the default level from the root logger - if level, ok := l.scopeLevels[scope]; ok { - scopeLevel = level - } - - return scopeLevel -} - -func (l *Logger) newScopeLogger(scope string) zerolog.Logger { - scopeLevel := l.getScopeLoggerLevel(scope) - logger := l.l.Level(scopeLevel).With().Str("component", scope).Logger() - - return logger -} - -func (l *Logger) getLogger(scope string) *zerolog.Logger { - logger, ok := l.scopeLoggers[scope] - if !ok || logger == nil { - scopeLogger := l.newScopeLogger(scope) - l.scopeLoggers[scope] = &scopeLogger - } - - return l.scopeLoggers[scope] -} - -func (l *Logger) UpdateLogLevel() { - needUpdate := false - - if config != nil && config.DefaultLogLevel != "" { - if logLevel, ok := zerologLevels[config.DefaultLogLevel]; ok { - l.defaultLogLevelFromConfig = logLevel - } else { - l.l.Warn().Str("logLevel", config.DefaultLogLevel).Msg("invalid defaultLogLevel from config, using ERROR") - } - - if l.defaultLogLevelFromConfig != l.defaultLogLevel { - needUpdate = true - } - } - - l.updateLogLevel() - - if needUpdate { - for scope, logger := range l.scopeLoggers { - currentLevel := logger.GetLevel() - targetLevel := l.getScopeLoggerLevel(scope) - if currentLevel != targetLevel { - *logger = l.newScopeLogger(scope) - } - } - } -} - func ErrorfL(l *zerolog.Logger, format string, err error, args ...interface{}) error { - if l == nil { - l = rootLogger.getLogger("jetkvm") - } - - l.Error().Err(err).Msgf(format, args...) - - if err == nil { - return fmt.Errorf(format, args...) - } - - err_msg := err.Error() + ": %v" - err_args := append(args, err) - - return fmt.Errorf(err_msg, err_args...) + return logging.ErrorfL(l, format, err, args...) } var ( - logger = rootLogger.getLogger("jetkvm") - cloudLogger = rootLogger.getLogger("cloud") - websocketLogger = rootLogger.getLogger("websocket") - webrtcLogger = rootLogger.getLogger("webrtc") - nativeLogger = rootLogger.getLogger("native") - nbdLogger = rootLogger.getLogger("nbd") - ntpLogger = rootLogger.getLogger("ntp") - jsonRpcLogger = rootLogger.getLogger("jsonrpc") - watchdogLogger = rootLogger.getLogger("watchdog") - websecureLogger = rootLogger.getLogger("websecure") - otaLogger = rootLogger.getLogger("ota") - serialLogger = rootLogger.getLogger("serial") - terminalLogger = rootLogger.getLogger("terminal") - displayLogger = rootLogger.getLogger("display") - wolLogger = rootLogger.getLogger("wol") - usbLogger = rootLogger.getLogger("usb") + logger = logging.GetSubsystemLogger("jetkvm") + networkLogger = logging.GetSubsystemLogger("network") + cloudLogger = logging.GetSubsystemLogger("cloud") + websocketLogger = logging.GetSubsystemLogger("websocket") + webrtcLogger = logging.GetSubsystemLogger("webrtc") + nativeLogger = logging.GetSubsystemLogger("native") + nbdLogger = logging.GetSubsystemLogger("nbd") + timesyncLogger = logging.GetSubsystemLogger("timesync") + jsonRpcLogger = logging.GetSubsystemLogger("jsonrpc") + watchdogLogger = logging.GetSubsystemLogger("watchdog") + websecureLogger = logging.GetSubsystemLogger("websecure") + otaLogger = logging.GetSubsystemLogger("ota") + serialLogger = logging.GetSubsystemLogger("serial") + terminalLogger = logging.GetSubsystemLogger("terminal") + displayLogger = logging.GetSubsystemLogger("display") + wolLogger = logging.GetSubsystemLogger("wol") + usbLogger = logging.GetSubsystemLogger("usb") // external components - ginLogger = rootLogger.getLogger("gin") + ginLogger = logging.GetSubsystemLogger("gin") ) - -type pionLogger struct { - logger *zerolog.Logger -} - -// Print all messages except trace. -func (c pionLogger) Trace(msg string) { - c.logger.Trace().Msg(msg) -} -func (c pionLogger) Tracef(format string, args ...interface{}) { - c.logger.Trace().Msgf(format, args...) -} - -func (c pionLogger) Debug(msg string) { - c.logger.Debug().Msg(msg) -} -func (c pionLogger) Debugf(format string, args ...interface{}) { - c.logger.Debug().Msgf(format, args...) -} -func (c pionLogger) Info(msg string) { - c.logger.Info().Msg(msg) -} -func (c pionLogger) Infof(format string, args ...interface{}) { - c.logger.Info().Msgf(format, args...) -} -func (c pionLogger) Warn(msg string) { - c.logger.Warn().Msg(msg) -} -func (c pionLogger) Warnf(format string, args ...interface{}) { - c.logger.Warn().Msgf(format, args...) -} -func (c pionLogger) Error(msg string) { - c.logger.Error().Msg(msg) -} -func (c pionLogger) Errorf(format string, args ...interface{}) { - c.logger.Error().Msgf(format, args...) -} - -// customLoggerFactory satisfies the interface logging.LoggerFactory -// This allows us to create different loggers per subsystem. So we can -// add custom behavior. -type pionLoggerFactory struct{} - -func (c pionLoggerFactory) NewLogger(subsystem string) logging.LeveledLogger { - logger := rootLogger.getLogger(subsystem).With(). - Str("scope", "pion"). - Str("component", subsystem). - Logger() - - return pionLogger{logger: &logger} -} - -var defaultLoggerFactory = &pionLoggerFactory{} diff --git a/main.go b/main.go index 9eab708..25fbb3a 100644 --- a/main.go +++ b/main.go @@ -15,28 +15,54 @@ var appCtx context.Context func Main() { LoadConfig() - logger.Debug().Msg("config loaded") var cancel context.CancelFunc appCtx, cancel = context.WithCancel(context.Background()) defer cancel() - logger.Info().Msg("starting JetKvm") + + systemVersionLocal, appVersionLocal, err := GetLocalVersion() + if err != nil { + logger.Warn().Err(err).Msg("failed to get local version") + } + + logger.Info(). + Interface("system_version", systemVersionLocal). + Interface("app_version", appVersionLocal). + Msg("starting JetKVM") go runWatchdog() go confirmCurrentSystem() http.DefaultClient.Timeout = 1 * time.Minute - err := rootcerts.UpdateDefaultTransport() + err = rootcerts.UpdateDefaultTransport() if err != nil { - logger.Warn().Err(err).Msg("failed to load CA certs") + logger.Warn().Err(err).Msg("failed to load Root CA certificates") + } + logger.Info(). + Int("ca_certs_loaded", len(rootcerts.Certs())). + Msg("loaded Root CA certificates") + + // Initialize network + if err := initNetwork(); err != nil { + logger.Error().Err(err).Msg("failed to initialize network") + os.Exit(1) } - initNetwork() + // Initialize time sync + initTimeSync() + timeSync.Start() - go TimeSyncLoop() + // Initialize mDNS + if err := initMdns(); err != nil { + logger.Error().Err(err).Msg("failed to initialize mDNS") + os.Exit(1) + } + // Initialize native ctrl socket server StartNativeCtrlSocketServer() + + // Initialize native video socket server StartNativeVideoSocketServer() initPrometheus() diff --git a/mdns.go b/mdns.go new file mode 100644 index 0000000..d7a3b55 --- /dev/null +++ b/mdns.go @@ -0,0 +1,29 @@ +package kvm + +import ( + "github.com/jetkvm/kvm/internal/mdns" +) + +var mDNS *mdns.MDNS + +func initMdns() error { + m, err := mdns.NewMDNS(&mdns.MDNSOptions{ + Logger: logger, + LocalNames: []string{ + networkState.GetHostname(), + networkState.GetFQDN(), + }, + ListenOptions: &mdns.MDNSListenOptions{ + IPv4: true, + IPv6: true, + }, + }) + if err != nil { + return err + } + + // do not start the server yet, as we need to wait for the network state to be set + mDNS = m + + return nil +} diff --git a/native.go b/native.go index b61598c..496f580 100644 --- a/native.go +++ b/native.go @@ -8,13 +8,10 @@ import ( "io" "net" "os" - "os/exec" "sync" - "syscall" "time" "github.com/jetkvm/kvm/resource" - "github.com/rs/zerolog" "github.com/pion/webrtc/v4/pkg/media" ) @@ -36,19 +33,6 @@ type CtrlResponse struct { Data json.RawMessage `json:"data,omitempty"` } -type nativeOutput struct { - mu *sync.Mutex - logger *zerolog.Event -} - -func (w *nativeOutput) Write(p []byte) (n int, err error) { - w.mu.Lock() - defer w.mu.Unlock() - - w.logger.Msg(string(p)) - return len(p), nil -} - type EventHandler func(event CtrlResponse) var seq int32 = 1 @@ -262,30 +246,8 @@ func ExtractAndRunNativeBin() error { return fmt.Errorf("failed to make binary executable: %w", err) } // Run the binary in the background - cmd := exec.Command(binaryPath) - - nativeOutputLock := sync.Mutex{} - nativeStdout := &nativeOutput{ - mu: &nativeOutputLock, - logger: nativeLogger.Info().Str("pipe", "stdout"), - } - nativeStderr := &nativeOutput{ - mu: &nativeOutputLock, - logger: nativeLogger.Info().Str("pipe", "stderr"), - } - - // Redirect stdout and stderr to the current process - cmd.Stdout = nativeStdout - cmd.Stderr = nativeStderr - - // Set the process group ID so we can kill the process and its children when this process exits - cmd.SysProcAttr = &syscall.SysProcAttr{ - Setpgid: true, - Pdeathsig: syscall.SIGKILL, - } - - // Start the command - if err := cmd.Start(); err != nil { + cmd, err := startNativeBinary(binaryPath) + if err != nil { return fmt.Errorf("failed to start binary: %w", err) } @@ -335,7 +297,10 @@ func ensureBinaryUpdated(destPath string) error { _, err = os.Stat(destPath) if shouldOverwrite(destPath, srcHash) || err != nil { - nativeLogger.Info().Msg("writing jetkvm_native") + nativeLogger.Info(). + Interface("hash", srcHash). + Msg("writing jetkvm_native") + _ = os.Remove(destPath) destFile, err := os.OpenFile(destPath, os.O_CREATE|os.O_RDWR, 0755) if err != nil { diff --git a/native_linux.go b/native_linux.go new file mode 100644 index 0000000..54d2150 --- /dev/null +++ b/native_linux.go @@ -0,0 +1,57 @@ +//go:build linux + +package kvm + +import ( + "fmt" + "os/exec" + "sync" + "syscall" + + "github.com/rs/zerolog" +) + +type nativeOutput struct { + mu *sync.Mutex + logger *zerolog.Event +} + +func (w *nativeOutput) Write(p []byte) (n int, err error) { + w.mu.Lock() + defer w.mu.Unlock() + + w.logger.Msg(string(p)) + return len(p), nil +} + +func startNativeBinary(binaryPath string) (*exec.Cmd, error) { + // Run the binary in the background + cmd := exec.Command(binaryPath) + + nativeOutputLock := sync.Mutex{} + nativeStdout := &nativeOutput{ + mu: &nativeOutputLock, + logger: nativeLogger.Info().Str("pipe", "stdout"), + } + nativeStderr := &nativeOutput{ + mu: &nativeOutputLock, + logger: nativeLogger.Info().Str("pipe", "stderr"), + } + + // Redirect stdout and stderr to the current process + cmd.Stdout = nativeStdout + cmd.Stderr = nativeStderr + + // Set the process group ID so we can kill the process and its children when this process exits + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + Pdeathsig: syscall.SIGKILL, + } + + // Start the command + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start binary: %w", err) + } + + return cmd, nil +} diff --git a/native_notlinux.go b/native_notlinux.go new file mode 100644 index 0000000..df6df74 --- /dev/null +++ b/native_notlinux.go @@ -0,0 +1,12 @@ +//go:build !linux + +package kvm + +import ( + "fmt" + "os/exec" +) + +func startNativeBinary(binaryPath string) (*exec.Cmd, error) { + return nil, fmt.Errorf("not supported") +} diff --git a/network.go b/network.go index 6948d9a..8d9261b 100644 --- a/network.go +++ b/network.go @@ -1,237 +1,107 @@ package kvm import ( - "bytes" "fmt" - "net" - "os" - "strings" - "time" - "os/exec" - - "github.com/hashicorp/go-envparse" - "github.com/pion/mdns/v2" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - - "github.com/vishvananda/netlink" - "github.com/vishvananda/netlink/nl" + "github.com/jetkvm/kvm/internal/network" + "github.com/jetkvm/kvm/internal/udhcpc" ) -var mDNSConn *mdns.Conn - -var networkState NetworkState - -type NetworkState struct { - Up bool - IPv4 string - IPv6 string - MAC string - - checked bool -} - -type LocalIpInfo struct { - IPv4 string - IPv6 string - MAC string -} - const ( - NetIfName = "eth0" - DHCPLeaseFile = "/run/udhcpc.%s.info" + NetIfName = "eth0" ) -// setDhcpClientState sends signals to udhcpc to change it's current mode -// of operation. Setting active to true will force udhcpc to renew the DHCP lease. -// Setting active to false will put udhcpc into idle mode. -func setDhcpClientState(active bool) { - var signal string - if active { - signal = "-SIGUSR1" - } else { - signal = "-SIGUSR2" - } +var ( + networkState *network.NetworkInterfaceState +) - cmd := exec.Command("/usr/bin/killall", signal, "udhcpc") - if err := cmd.Run(); err != nil { - logger.Warn().Err(err).Msg("network: setDhcpClientState: failed to change udhcpc state") +func networkStateChanged() { + // do not block the main thread + go waitCtrlAndRequestDisplayUpdate(true) + + // always restart mDNS when the network state changes + if mDNS != nil { + _ = mDNS.SetLocalNames([]string{ + networkState.GetHostname(), + networkState.GetFQDN(), + }, true) } } -func checkNetworkState() { - iface, err := netlink.LinkByName(NetIfName) - if err != nil { - logger.Warn().Err(err).Str("interface", NetIfName).Msg("failed to get interface") - return - } +func initNetwork() error { + ensureConfigLoaded() - newState := NetworkState{ - Up: iface.Attrs().OperState == netlink.OperUp, - MAC: iface.Attrs().HardwareAddr.String(), + state, err := network.NewNetworkInterfaceState(&network.NetworkInterfaceOptions{ + DefaultHostname: GetDefaultHostname(), + InterfaceName: NetIfName, + NetworkConfig: config.NetworkConfig, + Logger: networkLogger, + OnStateChange: func(state *network.NetworkInterfaceState) { + networkStateChanged() + }, + OnInitialCheck: func(state *network.NetworkInterfaceState) { + networkStateChanged() + }, + OnDhcpLeaseChange: func(lease *udhcpc.Lease) { + networkStateChanged() - checked: true, - } - - addrs, err := netlink.AddrList(iface, nl.FAMILY_ALL) - if err != nil { - logger.Warn().Err(err).Str("interface", NetIfName).Msg("failed to get addresses") - } - - // If the link is going down, put udhcpc into idle mode. - // If the link is coming back up, activate udhcpc and force it to renew the lease. - if newState.Up != networkState.Up { - setDhcpClientState(newState.Up) - } - - for _, addr := range addrs { - if addr.IP.To4() != nil { - if !newState.Up && networkState.Up { - // If the network is going down, remove all IPv4 addresses from the interface. - logger.Info().Str("address", addr.IP.String()).Msg("network: state transitioned to down, removing IPv4 address") - err := netlink.AddrDel(iface, &addr) - if err != nil { - logger.Warn().Err(err).Str("address", addr.IP.String()).Msg("network: failed to delete address") - } - - newState.IPv4 = "..." - } else { - newState.IPv4 = addr.IP.String() + if currentSession == nil { + return } - } else if addr.IP.To16() != nil && newState.IPv6 == "" { - newState.IPv6 = addr.IP.String() - } - } - if newState != networkState { - logger.Info(). - Interface("newState", newState). - Interface("oldState", networkState). - Msg("network state changed") + writeJSONRPCEvent("networkState", networkState.RpcGetNetworkState(), currentSession) + }, + OnConfigChange: func(networkConfig *network.NetworkConfig) { + config.NetworkConfig = networkConfig + networkStateChanged() - // restart MDNS - _ = startMDNS() - networkState = newState - requestDisplayUpdate() - } -} - -func startMDNS() error { - // If server was previously running, stop it - if mDNSConn != nil { - logger.Info().Msg("stopping mDNS server") - err := mDNSConn.Close() - if err != nil { - logger.Warn().Err(err).Msg("failed to stop mDNS server") - } - } - - // Start a new server - hostname := "jetkvm.local" - - scopedLogger := logger.With().Str("hostname", hostname).Logger() - scopedLogger.Info().Msg("starting mDNS server") - - addr4, err := net.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) - if err != nil { - return err - } - - addr6, err := net.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6) - if err != nil { - return err - } - - l4, err := net.ListenUDP("udp4", addr4) - if err != nil { - return err - } - - l6, err := net.ListenUDP("udp6", addr6) - if err != nil { - return err - } - - mDNSConn, err = mdns.Server(ipv4.NewPacketConn(l4), ipv6.NewPacketConn(l6), &mdns.Config{ - LocalNames: []string{hostname}, //TODO: make it configurable - LoggerFactory: defaultLoggerFactory, + if mDNS != nil { + _ = mDNS.SetListenOptions(networkConfig.GetMDNSMode()) + _ = mDNS.SetLocalNames([]string{ + networkState.GetHostname(), + networkState.GetFQDN(), + }, true) + } + }, }) - if err != nil { - scopedLogger.Warn().Err(err).Msg("failed to start mDNS server") - mDNSConn = nil + + if state == nil { + if err == nil { + return fmt.Errorf("failed to create NetworkInterfaceState") + } return err } - //defer server.Close() + + if err := state.Run(); err != nil { + return err + } + + networkState = state + return nil } -func getNTPServersFromDHCPInfo() ([]string, error) { - buf, err := os.ReadFile(fmt.Sprintf(DHCPLeaseFile, NetIfName)) - if err != nil { - // do not return error if file does not exist - if os.IsNotExist(err) { - return nil, nil - } - return nil, fmt.Errorf("failed to load udhcpc info: %w", err) - } - - // parse udhcpc info - env, err := envparse.Parse(bytes.NewReader(buf)) - if err != nil { - return nil, fmt.Errorf("failed to parse udhcpc info: %w", err) - } - - val, ok := env["ntpsrv"] - if !ok { - return nil, nil - } - - var servers []string - - for _, server := range strings.Fields(val) { - if net.ParseIP(server) == nil { - logger.Info().Str("server", server).Msg("invalid NTP server IP, ignoring") - } - servers = append(servers, server) - } - - return servers, nil +func rpcGetNetworkState() network.RpcNetworkState { + return networkState.RpcGetNetworkState() } -func initNetwork() { - ensureConfigLoaded() - - updates := make(chan netlink.LinkUpdate) - done := make(chan struct{}) - - if err := netlink.LinkSubscribe(updates, done); err != nil { - logger.Warn().Err(err).Msg("failed to subscribe to link updates") - return - } - - go func() { - waitCtrlClientConnected() - checkNetworkState() - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case update := <-updates: - if update.Link.Attrs().Name == NetIfName { - logger.Info().Interface("update", update).Msg("link update") - checkNetworkState() - } - case <-ticker.C: - checkNetworkState() - case <-done: - return - } - } - }() - err := startMDNS() - if err != nil { - logger.Warn().Err(err).Msg("failed to run mDNS") - } +func rpcGetNetworkSettings() network.RpcNetworkSettings { + return networkState.RpcGetNetworkSettings() +} + +func rpcSetNetworkSettings(settings network.RpcNetworkSettings) (*network.RpcNetworkSettings, error) { + s := networkState.RpcSetNetworkSettings(settings) + if s != nil { + return nil, s + } + + if err := SaveConfig(); err != nil { + return nil, err + } + + return &network.RpcNetworkSettings{NetworkConfig: *config.NetworkConfig}, nil +} + +func rpcRenewDHCPLease() error { + return networkState.RpcRenewDHCPLease() } diff --git a/ntp.go b/ntp.go deleted file mode 100644 index a104c56..0000000 --- a/ntp.go +++ /dev/null @@ -1,197 +0,0 @@ -package kvm - -import ( - "fmt" - "net/http" - "os/exec" - "strconv" - "time" - - "github.com/beevik/ntp" -) - -const ( - timeSyncRetryStep = 5 * time.Second - timeSyncRetryMaxInt = 1 * time.Minute - timeSyncWaitNetChkInt = 100 * time.Millisecond - timeSyncWaitNetUpInt = 3 * time.Second - timeSyncInterval = 1 * time.Hour - timeSyncTimeout = 2 * time.Second -) - -var ( - builtTimestamp string - timeSyncRetryInterval = 0 * time.Second - timeSyncSuccess = false - defaultNTPServers = []string{ - "time.cloudflare.com", - "time.apple.com", - } -) - -func isTimeSyncNeeded() bool { - if builtTimestamp == "" { - ntpLogger.Warn().Msg("Built timestamp is not set, time sync is needed") - return true - } - - ts, err := strconv.Atoi(builtTimestamp) - if err != nil { - ntpLogger.Warn().Str("error", err.Error()).Msg("Failed to parse built timestamp") - return true - } - - // builtTimestamp is UNIX timestamp in seconds - builtTime := time.Unix(int64(ts), 0) - now := time.Now() - - ntpLogger.Debug().Str("built_time", builtTime.Format(time.RFC3339)).Str("now", now.Format(time.RFC3339)).Msg("Built time and now") - - if now.Sub(builtTime) < 0 { - ntpLogger.Warn().Msg("System time is behind the built time, time sync is needed") - return true - } - - return false -} - -func TimeSyncLoop() { - for { - if !networkState.checked { - time.Sleep(timeSyncWaitNetChkInt) - continue - } - - if !networkState.Up { - ntpLogger.Info().Msg("Waiting for network to come up") - time.Sleep(timeSyncWaitNetUpInt) - continue - } - - // check if time sync is needed, but do nothing for now - isTimeSyncNeeded() - - ntpLogger.Info().Msg("Syncing system time") - start := time.Now() - err := SyncSystemTime() - if err != nil { - ntpLogger.Error().Str("error", err.Error()).Msg("Failed to sync system time") - - // retry after a delay - timeSyncRetryInterval += timeSyncRetryStep - time.Sleep(timeSyncRetryInterval) - // reset the retry interval if it exceeds the max interval - if timeSyncRetryInterval > timeSyncRetryMaxInt { - timeSyncRetryInterval = 0 - } - - continue - } - timeSyncSuccess = true - ntpLogger.Info().Str("now", time.Now().Format(time.RFC3339)). - Str("time_taken", time.Since(start).String()). - Msg("Time sync successful") - time.Sleep(timeSyncInterval) // after the first sync is done - } -} - -func SyncSystemTime() (err error) { - now, err := queryNetworkTime() - if err != nil { - return fmt.Errorf("failed to query network time: %w", err) - } - err = setSystemTime(*now) - if err != nil { - return fmt.Errorf("failed to set system time: %w", err) - } - return nil -} - -func queryNetworkTime() (*time.Time, error) { - ntpServers, err := getNTPServersFromDHCPInfo() - if err != nil { - ntpLogger.Info().Err(err).Msg("failed to get NTP servers from DHCP info") - } - - if ntpServers == nil { - ntpServers = defaultNTPServers - ntpLogger.Info(). - Interface("ntp_servers", ntpServers). - Msg("Using default NTP servers") - } else { - ntpLogger.Info(). - Interface("ntp_servers", ntpServers). - Msg("Using NTP servers from DHCP") - } - - for _, server := range ntpServers { - now, err := queryNtpServer(server, timeSyncTimeout) - if err == nil { - ntpLogger.Info(). - Str("ntp_server", server). - Str("time", now.Format(time.RFC3339)). - Msg("NTP server returned time") - return now, nil - } else { - ntpLogger.Error(). - Str("ntp_server", server). - Str("error", err.Error()). - Msg("failed to query NTP server") - } - } - - httpUrls := []string{ - "http://apple.com", - "http://cloudflare.com", - } - for _, url := range httpUrls { - now, err := queryHttpTime(url, timeSyncTimeout) - if err == nil { - ntpLogger.Info(). - Str("http_url", url). - Str("time", now.Format(time.RFC3339)). - Msg("HTTP server returned time") - return now, nil - } else { - ntpLogger.Error(). - Str("http_url", url). - Str("error", err.Error()). - Msg("failed to query HTTP server") - } - } - - return nil, ErrorfL(ntpLogger, "failed to query network time, all NTP servers and HTTP servers failed", nil) -} - -func queryNtpServer(server string, timeout time.Duration) (now *time.Time, err error) { - resp, err := ntp.QueryWithOptions(server, ntp.QueryOptions{Timeout: timeout}) - if err != nil { - return nil, err - } - return &resp.Time, nil -} - -func queryHttpTime(url string, timeout time.Duration) (*time.Time, error) { - client := http.Client{ - Timeout: timeout, - } - resp, err := client.Head(url) - if err != nil { - return nil, err - } - dateStr := resp.Header.Get("Date") - now, err := time.Parse(time.RFC1123, dateStr) - if err != nil { - return nil, err - } - return &now, nil -} - -func setSystemTime(now time.Time) error { - nowStr := now.Format("2006-01-02 15:04:05") - output, err := exec.Command("date", "-s", nowStr).CombinedOutput() - if err != nil { - return fmt.Errorf("failed to run date -s: %w, %s", err, string(output)) - } - return nil -} diff --git a/ota.go b/ota.go index a5da772..0559978 100644 --- a/ota.go +++ b/ota.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/sha256" + "crypto/tls" "encoding/hex" "encoding/json" "fmt" @@ -16,6 +17,7 @@ import ( "time" "github.com/Masterminds/semver/v3" + "github.com/gwatts/rootcerts" "github.com/rs/zerolog" ) @@ -127,10 +129,14 @@ func downloadFile(ctx context.Context, path string, url string, downloadProgress return fmt.Errorf("error creating request: %w", err) } - // TODO: set a separate timeout for the download but keep the TLS handshake short - // use Transport here will cause CA certificate validation failure so we temporarily removed it client := http.Client{ Timeout: 10 * time.Minute, + Transport: &http.Transport{ + TLSHandshakeTimeout: 30 * time.Second, + TLSClientConfig: &tls.Config{ + RootCAs: rootcerts.ServerCertPool(), + }, + }, } resp, err := client.Do(req) diff --git a/resource/jetkvm_native b/resource/jetkvm_native index 0d0719c..084ce14 100644 Binary files a/resource/jetkvm_native and b/resource/jetkvm_native differ diff --git a/resource/jetkvm_native.sha256 b/resource/jetkvm_native.sha256 index 65da816..b540b94 100644 --- a/resource/jetkvm_native.sha256 +++ b/resource/jetkvm_native.sha256 @@ -1 +1 @@ -c0803a9185298398eff9a925de69bd0ca882cd5983b989a45b748648146475c6 +4b925c7aa73d2e35a227833e806658cb17e1d25900611f93ed70b11ac9f1716d diff --git a/timesync.go b/timesync.go new file mode 100644 index 0000000..7b25fe2 --- /dev/null +++ b/timesync.go @@ -0,0 +1,53 @@ +package kvm + +import ( + "strconv" + "time" + + "github.com/jetkvm/kvm/internal/timesync" +) + +var ( + timeSync *timesync.TimeSync + builtTimestamp string +) + +func isTimeSyncNeeded() bool { + if builtTimestamp == "" { + timesyncLogger.Warn().Msg("built timestamp is not set, time sync is needed") + return true + } + + ts, err := strconv.Atoi(builtTimestamp) + if err != nil { + timesyncLogger.Warn().Str("error", err.Error()).Msg("failed to parse built timestamp") + return true + } + + // builtTimestamp is UNIX timestamp in seconds + builtTime := time.Unix(int64(ts), 0) + now := time.Now() + + if now.Sub(builtTime) < 0 { + timesyncLogger.Warn(). + Str("built_time", builtTime.Format(time.RFC3339)). + Str("now", now.Format(time.RFC3339)). + Msg("system time is behind the built time, time sync is needed") + return true + } + + return false +} + +func initTimeSync() { + timeSync = timesync.NewTimeSync(×ync.TimeSyncOptions{ + Logger: timesyncLogger, + NetworkConfig: config.NetworkConfig, + PreCheckFunc: func() (bool, error) { + if !networkState.IsOnline() { + return false, nil + } + return true, nil + }, + }) +} diff --git a/ui/dev_device.sh b/ui/dev_device.sh index 650cadd..2c7b497 100755 --- a/ui/dev_device.sh +++ b/ui/dev_device.sh @@ -15,5 +15,15 @@ echo "└─────────────────────── # Set the environment variable and run Vite echo "Starting development server with JetKVM device at: $ip_address" + +# Check if pwd is the current directory of the script +if [ "$(pwd)" != "$(dirname "$0")" ]; then + pushd "$(dirname "$0")" > /dev/null + echo "Changed directory to: $(pwd)" +fi + sleep 1 + JETKVM_PROXY_URL="ws://$ip_address" npx vite dev --mode=device + +popd > /dev/null diff --git a/ui/package-lock.json b/ui/package-lock.json index b51a2ea..9e77e10 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -19,6 +19,7 @@ "@xterm/addon-webgl": "^0.18.0", "@xterm/xterm": "^5.5.0", "cva": "^1.0.0-beta.1", + "dayjs": "^1.11.13", "eslint-import-resolver-alias": "^1.1.2", "focus-trap-react": "^10.2.3", "framer-motion": "^11.15.0", @@ -2433,6 +2434,11 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/dayjs": { + "version": "1.11.13", + "resolved": "https://registry.npmjs.org/dayjs/-/dayjs-1.11.13.tgz", + "integrity": "sha512-oaMBel6gjolK862uaPQOVTA7q3TZhuSvuMQAAglQDOWYO9A91IrAOUJEyKVlqJlHE0vq5p5UXxzdPfMH/x6xNg==" + }, "node_modules/debug": { "version": "4.3.4", "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", diff --git a/ui/package.json b/ui/package.json index 3160297..4dab092 100644 --- a/ui/package.json +++ b/ui/package.json @@ -30,6 +30,7 @@ "@xterm/addon-webgl": "^0.18.0", "@xterm/xterm": "^5.5.0", "cva": "^1.0.0-beta.1", + "dayjs": "^1.11.13", "eslint-import-resolver-alias": "^1.1.2", "focus-trap-react": "^10.2.3", "framer-motion": "^11.15.0", diff --git a/ui/public/sse.html b/ui/public/sse.html new file mode 120000 index 0000000..0a8b4f3 --- /dev/null +++ b/ui/public/sse.html @@ -0,0 +1 @@ +../../internal/logging/sse.html \ No newline at end of file diff --git a/ui/src/hooks/stores.ts b/ui/src/hooks/stores.ts index 0fa4121..db1fd04 100644 --- a/ui/src/hooks/stores.ts +++ b/ui/src/hooks/stores.ts @@ -663,6 +663,95 @@ export const useDeviceStore = create(set => ({ setSystemVersion: version => set({ systemVersion: version }), })); +export interface DhcpLease { + ip?: string; + netmask?: string; + broadcast?: string; + ttl?: string; + mtu?: string; + hostname?: string; + domain?: string; + bootp_next_server?: string; + bootp_server_name?: string; + bootp_file?: string; + timezone?: string; + routers?: string[]; + dns?: string[]; + ntp_servers?: string[]; + lpr_servers?: string[]; + _time_servers?: string[]; + _name_servers?: string[]; + _log_servers?: string[]; + _cookie_servers?: string[]; + _wins_servers?: string[]; + _swap_server?: string; + boot_size?: string; + root_path?: string; + lease?: string; + lease_expiry?: Date; + dhcp_type?: string; + server_id?: string; + message?: string; + tftp?: string; + bootfile?: string; +} + +export interface IPv6Address { + address: string; + prefix: string; + valid_lifetime: string; + preferred_lifetime: string; + scope: string; +} + +export interface NetworkState { + interface_name?: string; + mac_address?: string; + ipv4?: string; + ipv4_addresses?: string[]; + ipv6?: string; + ipv6_addresses?: IPv6Address[]; + ipv6_link_local?: string; + dhcp_lease?: DhcpLease; + + setNetworkState: (state: NetworkState) => void; + setDhcpLease: (lease: NetworkState["dhcp_lease"]) => void; + setDhcpLeaseExpiry: (expiry: Date) => void; +} + + +export type IPv6Mode = "disabled" | "slaac" | "dhcpv6" | "slaac_and_dhcpv6" | "static" | "link_local" | "unknown"; +export type IPv4Mode = "disabled" | "static" | "dhcp" | "unknown"; +export type LLDPMode = "disabled" | "basic" | "all" | "unknown"; +export type mDNSMode = "disabled" | "auto" | "ipv4_only" | "ipv6_only" | "unknown"; +export type TimeSyncMode = "ntp_only" | "ntp_and_http" | "http_only" | "custom" | "unknown"; + +export interface NetworkSettings { + hostname: string; + domain: string; + ipv4_mode: IPv4Mode; + ipv6_mode: IPv6Mode; + lldp_mode: LLDPMode; + lldp_tx_tlvs: string[]; + mdns_mode: mDNSMode; + time_sync_mode: TimeSyncMode; +} + +export const useNetworkStateStore = create((set, get) => ({ + setNetworkState: (state: NetworkState) => set(state), + setDhcpLease: (lease: NetworkState["dhcp_lease"]) => set({ dhcp_lease: lease }), + setDhcpLeaseExpiry: (expiry: Date) => { + const lease = get().dhcp_lease; + if (!lease) { + console.warn("No lease found"); + return; + } + + lease.lease_expiry = expiry; + set({ dhcp_lease: lease }); + } +})); + export interface KeySequenceStep { keys: string[]; modifiers: string[]; @@ -767,8 +856,8 @@ export const useMacrosStore = create((set, get) => ({ for (let i = 0; i < macro.steps.length; i++) { const step = macro.steps[i]; if (step.keys && step.keys.length > MAX_KEYS_PER_STEP) { - console.error(`Cannot save: macro "${macro.name}" step ${i+1} exceeds maximum of ${MAX_KEYS_PER_STEP} keys`); - throw new Error(`Cannot save: macro "${macro.name}" step ${i+1} exceeds maximum of ${MAX_KEYS_PER_STEP} keys`); + console.error(`Cannot save: macro "${macro.name}" step ${i + 1} exceeds maximum of ${MAX_KEYS_PER_STEP} keys`); + throw new Error(`Cannot save: macro "${macro.name}" step ${i + 1} exceeds maximum of ${MAX_KEYS_PER_STEP} keys`); } } } diff --git a/ui/src/main.tsx b/ui/src/main.tsx index e09a2a9..f4bdd34 100644 --- a/ui/src/main.tsx +++ b/ui/src/main.tsx @@ -42,6 +42,7 @@ import SettingsVideoRoute from "./routes/devices.$id.settings.video"; import SettingsAppearanceRoute from "./routes/devices.$id.settings.appearance"; import * as SettingsGeneralIndexRoute from "./routes/devices.$id.settings.general._index"; import SettingsGeneralUpdateRoute from "./routes/devices.$id.settings.general.update"; +import SettingsNetworkRoute from "./routes/devices.$id.settings.network"; import SecurityAccessLocalAuthRoute from "./routes/devices.$id.settings.access.local-auth"; import SettingsMacrosRoute from "./routes/devices.$id.settings.macros"; import SettingsMacrosAddRoute from "./routes/devices.$id.settings.macros.add"; @@ -156,6 +157,10 @@ if (isOnDevice) { path: "hardware", element: , }, + { + path: "network", + element: , + }, { path: "access", children: [ diff --git a/ui/src/routes/devices.$id.settings.network.tsx b/ui/src/routes/devices.$id.settings.network.tsx new file mode 100644 index 0000000..59d52ef --- /dev/null +++ b/ui/src/routes/devices.$id.settings.network.tsx @@ -0,0 +1,408 @@ +import { useCallback, useEffect, useState } from "react"; + +import { SelectMenuBasic } from "../components/SelectMenuBasic"; +import { SettingsPageHeader } from "../components/SettingsPageheader"; + +import { IPv4Mode, IPv6Mode, LLDPMode, mDNSMode, NetworkSettings, NetworkState, TimeSyncMode, useNetworkStateStore } from "@/hooks/stores"; +import { useJsonRpc } from "@/hooks/useJsonRpc"; +import notifications from "@/notifications"; +import { Button } from "@components/Button"; +import { GridCard } from "@components/Card"; +import InputField from "@components/InputField"; +import { SettingsItem } from "./devices.$id.settings"; + +import dayjs from 'dayjs'; +import relativeTime from 'dayjs/plugin/relativeTime'; + +dayjs.extend(relativeTime); + +const defaultNetworkSettings: NetworkSettings = { + hostname: "", + domain: "", + ipv4_mode: "unknown", + ipv6_mode: "unknown", + lldp_mode: "unknown", + lldp_tx_tlvs: [], + mdns_mode: "unknown", + time_sync_mode: "unknown", +} + +export function LifeTimeLabel({ lifetime }: { lifetime: string }) { + if (lifetime == "") { + return N/A; + } + + const [remaining, setRemaining] = useState(null); + + useEffect(() => { + setRemaining(dayjs(lifetime).fromNow()); + + const interval = setInterval(() => { + setRemaining(dayjs(lifetime).fromNow()); + }, 1000 * 30); + return () => clearInterval(interval); + }, [lifetime]); + + return <> + {dayjs(lifetime).format()} + {remaining && <> + {" "} + ({remaining}) + + } + +} + +export default function SettingsNetworkRoute() { + const [send] = useJsonRpc(); + const [networkState, setNetworkState] = useNetworkStateStore(state => [state, state.setNetworkState]); + + const [networkSettings, setNetworkSettings] = useState(defaultNetworkSettings); + const [networkSettingsLoaded, setNetworkSettingsLoaded] = useState(false); + + const getNetworkSettings = useCallback(() => { + setNetworkSettingsLoaded(false); + send("getNetworkSettings", {}, resp => { + if ("error" in resp) return; + console.log(resp.result); + setNetworkSettings(resp.result as NetworkSettings); + setNetworkSettingsLoaded(true); + }); + }, [send]); + + const setNetworkSettingsRemote = useCallback((settings: NetworkSettings) => { + setNetworkSettingsLoaded(false); + send("setNetworkSettings", { settings }, resp => { + if ("error" in resp) { + notifications.error("Failed to save network settings: " + (resp.error.data ? resp.error.data : resp.error.message)); + setNetworkSettingsLoaded(true); + return; + } + setNetworkSettings(resp.result as NetworkSettings); + setNetworkSettingsLoaded(true); + notifications.success("Network settings saved"); + }); + }, [send]); + + const getNetworkState = useCallback(() => { + send("getNetworkState", {}, resp => { + if ("error" in resp) return; + console.log(resp.result); + setNetworkState(resp.result as NetworkState); + }); + }, [send]); + + const handleRenewLease = useCallback(() => { + send("renewDHCPLease", {}, resp => { + if ("error" in resp) { + notifications.error("Failed to renew lease: " + resp.error.message); + } else { + notifications.success("DHCP lease renewed"); + } + }); + }, [send]); + + useEffect(() => { + getNetworkState(); + getNetworkSettings(); + }, [getNetworkState, getNetworkSettings]); + + const handleIpv4ModeChange = (value: IPv4Mode | string) => { + setNetworkSettings({ ...networkSettings, ipv4_mode: value as IPv4Mode }); + }; + + const handleIpv6ModeChange = (value: IPv6Mode | string) => { + setNetworkSettings({ ...networkSettings, ipv6_mode: value as IPv6Mode }); + }; + + const handleLldpModeChange = (value: LLDPMode | string) => { + setNetworkSettings({ ...networkSettings, lldp_mode: value as LLDPMode }); + }; + + // const handleLldpTxTlvsChange = (value: string[]) => { + // setNetworkSettings({ ...networkSettings, lldp_tx_tlvs: value }); + // }; + + const handleMdnsModeChange = (value: mDNSMode | string) => { + setNetworkSettings({ ...networkSettings, mdns_mode: value as mDNSMode }); + }; + + const handleTimeSyncModeChange = (value: TimeSyncMode | string) => { + setNetworkSettings({ ...networkSettings, time_sync_mode: value as TimeSyncMode }); + }; + + const filterUnknown = useCallback((options: { value: string; label: string; }[]) => { + if (!networkSettingsLoaded) return options; + return options.filter(option => option.value !== "unknown"); + }, [networkSettingsLoaded]); + + return ( +
+ +
+ } + > + + {networkState?.mac_address} + + +
+
+ + Hostname for the device +
+ + Leave blank for default + + + } + > + { + setNetworkSettings({ ...networkSettings, hostname: e.target.value }); + }} + disabled={!networkSettingsLoaded} + /> +
+
+
+ + Domain for the device +
+ + Leave blank to use DHCP provided domain, if there is no domain, use local + + + } + > + { + setNetworkSettings({ ...networkSettings, domain: e.target.value }); + }} + disabled={!networkSettingsLoaded} + /> +
+
+
+ + handleIpv4ModeChange(e.target.value)} + disabled={!networkSettingsLoaded} + options={filterUnknown([ + { value: "dhcp", label: "DHCP" }, + // { value: "static", label: "Static" }, + ])} + /> + + {networkState?.dhcp_lease && ( + +
+
+
+

+ Current DHCP Lease +

+
+
    + {networkState?.dhcp_lease?.ip &&
  • IP: {networkState?.dhcp_lease?.ip}
  • } + {networkState?.dhcp_lease?.netmask &&
  • Subnet: {networkState?.dhcp_lease?.netmask}
  • } + {networkState?.dhcp_lease?.broadcast &&
  • Broadcast: {networkState?.dhcp_lease?.broadcast}
  • } + {networkState?.dhcp_lease?.ttl &&
  • TTL: {networkState?.dhcp_lease?.ttl}
  • } + {networkState?.dhcp_lease?.mtu &&
  • MTU: {networkState?.dhcp_lease?.mtu}
  • } + {networkState?.dhcp_lease?.hostname &&
  • Hostname: {networkState?.dhcp_lease?.hostname}
  • } + {networkState?.dhcp_lease?.domain &&
  • Domain: {networkState?.dhcp_lease?.domain}
  • } + {networkState?.dhcp_lease?.routers &&
  • Gateway: {networkState?.dhcp_lease?.routers.join(", ")}
  • } + {networkState?.dhcp_lease?.dns &&
  • DNS: {networkState?.dhcp_lease?.dns.join(", ")}
  • } + {networkState?.dhcp_lease?.ntp_servers &&
  • NTP Servers: {networkState?.dhcp_lease?.ntp_servers.join(", ")}
  • } + {networkState?.dhcp_lease?.server_id &&
  • Server ID: {networkState?.dhcp_lease?.server_id}
  • } + {networkState?.dhcp_lease?.bootp_next_server &&
  • BootP Next Server: {networkState?.dhcp_lease?.bootp_next_server}
  • } + {networkState?.dhcp_lease?.bootp_server_name &&
  • BootP Server Name: {networkState?.dhcp_lease?.bootp_server_name}
  • } + {networkState?.dhcp_lease?.bootp_file &&
  • Boot File: {networkState?.dhcp_lease?.bootp_file}
  • } + {networkState?.dhcp_lease?.lease_expiry &&
  • + Lease Expiry: +
  • } + {/* {JSON.stringify(networkState?.dhcp_lease)} */} +
+
+
+
+
+
+
+
+
+ )} +
+
+ + handleIpv6ModeChange(e.target.value)} + disabled={!networkSettingsLoaded} + options={filterUnknown([ + // { value: "disabled", label: "Disabled" }, + { value: "slaac", label: "SLAAC" }, + // { value: "dhcpv6", label: "DHCPv6" }, + // { value: "slaac_and_dhcpv6", label: "SLAAC and DHCPv6" }, + // { value: "static", label: "Static" }, + // { value: "link_local", label: "Link-local only" }, + ])} + /> + + {networkState?.ipv6_addresses && ( + +
+
+
+

+ IPv6 Information +

+
+
+

+ IPv6 Link-local +

+

+ {networkState?.ipv6_link_local} +

+
+
+

+ IPv6 Addresses +

+
    + {networkState?.ipv6_addresses && networkState?.ipv6_addresses.map(addr => ( +
  • + {addr.address} + {addr.valid_lifetime && <> +
    + - valid_lft: {" "} + + + + } + {addr.preferred_lifetime && <> +
    + - pref_lft: {" "} + + + + } +
  • + ))} +
+
+
+
+
+
+
+ )} +
+
+ + handleLldpModeChange(e.target.value)} + disabled={!networkSettingsLoaded} + options={filterUnknown([ + { value: "disabled", label: "Disabled" }, + { value: "basic", label: "Basic" }, + { value: "all", label: "All" }, + ])} + /> + +
+
+ + handleMdnsModeChange(e.target.value)} + disabled={!networkSettingsLoaded} + options={filterUnknown([ + { value: "disabled", label: "Disabled" }, + { value: "auto", label: "Auto" }, + { value: "ipv4_only", label: "IPv4 only" }, + { value: "ipv6_only", label: "IPv6 only" }, + ])} + /> + +
+
+ + handleTimeSyncModeChange(e.target.value)} + disabled={!networkSettingsLoaded} + options={filterUnknown([ + { value: "unknown", label: "..." }, + // { value: "auto", label: "Auto" }, + { value: "ntp_only", label: "NTP only" }, + { value: "ntp_and_http", label: "NTP and HTTP" }, + { value: "http_only", label: "HTTP only" }, + // { value: "custom", label: "Custom" }, + ])} + /> + +
+
+
+
+ ); +} diff --git a/ui/src/routes/devices.$id.settings.tsx b/ui/src/routes/devices.$id.settings.tsx index c0b4181..f8e5262 100644 --- a/ui/src/routes/devices.$id.settings.tsx +++ b/ui/src/routes/devices.$id.settings.tsx @@ -9,6 +9,7 @@ import { LuArrowLeft, LuPalette, LuCommand, + LuNetwork, } from "react-icons/lu"; import React, { useEffect, useRef, useState } from "react"; @@ -207,6 +208,17 @@ export default function SettingsRoute() { +
+ (isActive ? "active" : "")} + > +
+ +

Network

+
+
+
state.setNetworkState); + const setUsbState = useHidStore(state => state.setUsbState); const setHdmiState = useVideoStore(state => state.setHdmiState); @@ -600,6 +604,11 @@ export default function KvmIdRoute() { setHdmiState(resp.params as Parameters[0]); } + if (resp.method === "networkState") { + console.log("Setting network state", resp.params); + setNetworkState(resp.params as NetworkState); + } + if (resp.method === "otaState") { const otaState = resp.params as UpdateState["otaState"]; setOtaState(otaState); diff --git a/ui/vite.config.ts b/ui/vite.config.ts index f8459cd..e47774f 100644 --- a/ui/vite.config.ts +++ b/ui/vite.config.ts @@ -35,6 +35,7 @@ export default defineConfig(({ mode, command }) => { "/auth": JETKVM_PROXY_URL, "/storage": JETKVM_PROXY_URL, "/cloud": JETKVM_PROXY_URL, + "/developer": JETKVM_PROXY_URL, } : undefined, }, diff --git a/usb.go b/usb.go index 3395db4..91674c9 100644 --- a/usb.go +++ b/usb.go @@ -66,6 +66,6 @@ func checkUSBState() { usbState = newState usbLogger.Info().Str("from", usbState).Str("to", newState).Msg("USB state changed") - requestDisplayUpdate() + requestDisplayUpdate(true) triggerUSBStateUpdate() } diff --git a/usb_mass_storage.go b/usb_mass_storage.go index 2b03f1f..79a05d1 100644 --- a/usb_mass_storage.go +++ b/usb_mass_storage.go @@ -62,7 +62,11 @@ func onDiskMessage(msg webrtc.DataChannelMessage) { func mountImage(imagePath string) error { err := setMassStorageImage("") if err != nil { - return fmt.Errorf("remove Mass Storage Image Error: %w", err) + return fmt.Errorf("remove mass storage image error: %w", err) + } + err = setMassStorageImage(imagePath) + if err != nil { + return fmt.Errorf("set mass storage image error: %w", err) } err = setMassStorageImage(imagePath) if err != nil { @@ -477,7 +481,6 @@ func handleUploadChannel(d *webrtc.DataChannel) { totalBytesWritten += int64(bytesWritten) sendProgress := time.Since(lastProgressTime) >= 200*time.Millisecond - if totalBytesWritten >= pendingUpload.Size { sendProgress = true close(uploadComplete) diff --git a/video.go b/video.go index d74add8..6fa77b9 100644 --- a/video.go +++ b/video.go @@ -43,7 +43,7 @@ func HandleVideoStateMessage(event CtrlResponse) { } lastVideoState = videoState triggerVideoStateUpdate() - requestDisplayUpdate() + requestDisplayUpdate(true) } func rpcGetVideoState() (VideoInputState, error) { diff --git a/web.go b/web.go index 6e74a13..766eaf5 100644 --- a/web.go +++ b/web.go @@ -9,6 +9,7 @@ import ( "fmt" "io/fs" "net/http" + "net/http/pprof" "path/filepath" "strings" "time" @@ -18,6 +19,7 @@ import ( gin_logger "github.com/gin-contrib/logger" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/jetkvm/kvm/internal/logging" "github.com/pion/webrtc/v4" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -103,6 +105,27 @@ func setupRouter() *gin.Engine { // A Prometheus metrics endpoint. r.GET("/metrics", gin.WrapH(promhttp.Handler())) + // Developer mode protected routes + developerModeRouter := r.Group("/developer/") + developerModeRouter.Use(basicAuthProtectedMiddleware(true)) + { + // pprof + developerModeRouter.GET("/pprof/", gin.WrapF(pprof.Index)) + developerModeRouter.GET("/pprof/cmdline", gin.WrapF(pprof.Cmdline)) + developerModeRouter.GET("/pprof/profile", gin.WrapF(pprof.Profile)) + developerModeRouter.POST("/pprof/symbol", gin.WrapF(pprof.Symbol)) + developerModeRouter.GET("/pprof/symbol", gin.WrapF(pprof.Symbol)) + developerModeRouter.GET("/pprof/trace", gin.WrapF(pprof.Trace)) + developerModeRouter.GET("/pprof/allocs", gin.WrapH(pprof.Handler("allocs"))) + developerModeRouter.GET("/pprof/block", gin.WrapH(pprof.Handler("block"))) + developerModeRouter.GET("/pprof/goroutine", gin.WrapH(pprof.Handler("goroutine"))) + developerModeRouter.GET("/pprof/heap", gin.WrapH(pprof.Handler("heap"))) + developerModeRouter.GET("/pprof/mutex", gin.WrapH(pprof.Handler("mutex"))) + developerModeRouter.GET("/pprof/threadcreate", gin.WrapH(pprof.Handler("threadcreate"))) + + logging.AttachSSEHandler(developerModeRouter) + } + // Protected routes (allows both password and noPassword modes) protected := r.Group("/") protected.Use(protectedMiddleware()) @@ -203,7 +226,7 @@ func handleLocalWebRTCSignal(c *gin.Context) { wsOptions := &websocket.AcceptOptions{ InsecureSkipVerify: true, // Allow connections from any origin OnPingReceived: func(ctx context.Context, payload []byte) bool { - scopedLogger.Info().Bytes("payload", payload).Msg("ping frame received") + scopedLogger.Debug().Bytes("payload", payload).Msg("ping frame received") metricConnectionTotalPingReceivedCount.WithLabelValues("local", source).Inc() metricConnectionLastPingReceivedTimestamp.WithLabelValues("local", source).SetToCurrentTime() @@ -242,7 +265,12 @@ func handleWebRTCSignalWsMessages( scopedLogger *zerolog.Logger, ) error { runCtx, cancelRun := context.WithCancel(context.Background()) - defer cancelRun() + defer func() { + if isCloudConnection { + setCloudConnectionState(CloudConnectionStateDisconnected) + } + cancelRun() + }() // connection type var sourceType string @@ -459,11 +487,51 @@ func protectedMiddleware() gin.HandlerFunc { } } +func sendErrorJsonThenAbort(c *gin.Context, status int, message string) { + c.JSON(status, gin.H{"error": message}) + c.Abort() +} + +func basicAuthProtectedMiddleware(requireDeveloperMode bool) gin.HandlerFunc { + return func(c *gin.Context) { + if requireDeveloperMode { + devModeState, err := rpcGetDevModeState() + if err != nil { + sendErrorJsonThenAbort(c, http.StatusInternalServerError, "Failed to get developer mode state") + return + } + + if !devModeState.Enabled { + sendErrorJsonThenAbort(c, http.StatusUnauthorized, "Developer mode is not enabled") + return + } + } + + if config.LocalAuthMode == "noPassword" { + sendErrorJsonThenAbort(c, http.StatusForbidden, "The resource is not available in noPassword mode") + return + } + + // calculate basic auth credentials + _, password, ok := c.Request.BasicAuth() + if !ok { + c.Header("WWW-Authenticate", "Basic realm=\"JetKVM\"") + sendErrorJsonThenAbort(c, http.StatusUnauthorized, "Basic auth is required") + return + } + + err := bcrypt.CompareHashAndPassword([]byte(config.HashedPassword), []byte(password)) + if err != nil { + sendErrorJsonThenAbort(c, http.StatusUnauthorized, "Invalid password") + return + } + + c.Next() + } +} + func RunWebServer() { r := setupRouter() - //if strings.Contains(builtAppVersion, "-dev") { - // pprof.Register(r) - //} err := r.Run(":80") if err != nil { panic(err) diff --git a/web_tls.go b/web_tls.go index cbff56b..564f150 100644 --- a/web_tls.go +++ b/web_tls.go @@ -54,7 +54,7 @@ func initCertStore() { func getCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { switch config.TLSMode { case "self-signed": - if isTimeSyncNeeded() || !timeSyncSuccess { + if isTimeSyncNeeded() || !timeSync.IsSyncSuccess() { return nil, fmt.Errorf("time is not synced") } return certSigner.GetCertificate(info) @@ -174,7 +174,7 @@ func runWebSecureServer() { websecureLogger.Info().Msg("Shutting down websecure server") err := server.Shutdown(context.Background()) if err != nil { - websecureLogger.Error().Err(err).Msg("Failed to shutdown websecure server") + websecureLogger.Error().Err(err).Msg("failed to shutdown websecure server") } } }() diff --git a/webrtc.go b/webrtc.go index 1e093e2..f6c8529 100644 --- a/webrtc.go +++ b/webrtc.go @@ -10,6 +10,7 @@ import ( "github.com/coder/websocket" "github.com/coder/websocket/wsjson" "github.com/gin-gonic/gin" + "github.com/jetkvm/kvm/internal/logging" "github.com/pion/webrtc/v4" "github.com/rs/zerolog" ) @@ -68,7 +69,7 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) { func newSession(config SessionConfig) (*Session, error) { webrtcSettingEngine := webrtc.SettingEngine{ - LoggerFactory: defaultLoggerFactory, + LoggerFactory: logging.GetPionDefaultLoggerFactory(), } iceServer := webrtc.ICEServer{} @@ -205,7 +206,7 @@ func newSession(config SessionConfig) (*Session, error) { var actionSessions = 0 func onActiveSessionsChanged() { - requestDisplayUpdate() + requestDisplayUpdate(true) } func onFirstSessionConnected() {