diff --git a/block_device.go b/block_device.go
index e4eab80..2274098 100644
--- a/block_device.go
+++ b/block_device.go
@@ -7,7 +7,6 @@ import (
"os"
"time"
- "github.com/pojntfx/go-nbd/pkg/client"
"github.com/pojntfx/go-nbd/pkg/server"
"github.com/rs/zerolog"
)
@@ -149,30 +148,3 @@ func (d *NBDDevice) runServerConn() {
d.l.Info().Err(err).Msg("nbd server exited")
}
-
-func (d *NBDDevice) runClientConn() {
- err := client.Connect(d.clientConn, d.dev, &client.Options{
- ExportName: "jetkvm",
- BlockSize: uint32(4 * 1024),
- })
- d.l.Info().Err(err).Msg("nbd client exited")
-}
-
-func (d *NBDDevice) Close() {
- if d.dev != nil {
- err := client.Disconnect(d.dev)
- if err != nil {
- d.l.Warn().Err(err).Msg("error disconnecting nbd client")
- }
- _ = d.dev.Close()
- }
- if d.listener != nil {
- _ = d.listener.Close()
- }
- if d.clientConn != nil {
- _ = d.clientConn.Close()
- }
- if d.serverConn != nil {
- _ = d.serverConn.Close()
- }
-}
diff --git a/block_device_linux.go b/block_device_linux.go
new file mode 100644
index 0000000..8ca9372
--- /dev/null
+++ b/block_device_linux.go
@@ -0,0 +1,34 @@
+//go:build linux
+
+package kvm
+
+import (
+ "github.com/pojntfx/go-nbd/pkg/client"
+)
+
+func (d *NBDDevice) runClientConn() {
+ err := client.Connect(d.clientConn, d.dev, &client.Options{
+ ExportName: "jetkvm",
+ BlockSize: uint32(4 * 1024),
+ })
+ d.l.Info().Err(err).Msg("nbd client exited")
+}
+
+func (d *NBDDevice) Close() {
+ if d.dev != nil {
+ err := client.Disconnect(d.dev)
+ if err != nil {
+ d.l.Warn().Err(err).Msg("error disconnecting nbd client")
+ }
+ _ = d.dev.Close()
+ }
+ if d.listener != nil {
+ _ = d.listener.Close()
+ }
+ if d.clientConn != nil {
+ _ = d.clientConn.Close()
+ }
+ if d.serverConn != nil {
+ _ = d.serverConn.Close()
+ }
+}
diff --git a/block_device_notlinux.go b/block_device_notlinux.go
new file mode 100644
index 0000000..b6a9aba
--- /dev/null
+++ b/block_device_notlinux.go
@@ -0,0 +1,17 @@
+//go:build !linux
+
+package kvm
+
+import (
+ "os"
+)
+
+func (d *NBDDevice) runClientConn() {
+ d.l.Error().Msg("platform not supported")
+ os.Exit(1)
+}
+
+func (d *NBDDevice) Close() {
+ d.l.Error().Msg("platform not supported")
+ os.Exit(1)
+}
diff --git a/cloud.go b/cloud.go
index fd96c41..fb1998a 100644
--- a/cloud.go
+++ b/cloud.go
@@ -139,11 +139,40 @@ var (
)
)
+type CloudConnectionState uint8
+
+const (
+ CloudConnectionStateNotConfigured CloudConnectionState = iota
+ CloudConnectionStateDisconnected
+ CloudConnectionStateConnecting
+ CloudConnectionStateConnected
+)
+
var (
+ cloudConnectionState CloudConnectionState = CloudConnectionStateNotConfigured
+ cloudConnectionStateLock = &sync.Mutex{}
+
cloudDisconnectChan chan error
cloudDisconnectLock = &sync.Mutex{}
)
+func setCloudConnectionState(state CloudConnectionState) {
+ cloudConnectionStateLock.Lock()
+ defer cloudConnectionStateLock.Unlock()
+
+ if cloudConnectionState == CloudConnectionStateDisconnected &&
+ (config.CloudToken == "" || config.CloudURL == "") {
+ state = CloudConnectionStateNotConfigured
+ }
+
+ previousState := cloudConnectionState
+ cloudConnectionState = state
+
+ go waitCtrlAndRequestDisplayUpdate(
+ previousState != state,
+ )
+}
+
func wsResetMetrics(established bool, sourceType string, source string) {
metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).Set(-1)
metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(-1)
@@ -285,6 +314,8 @@ func runWebsocketClient() error {
wsURL.Scheme = "wss"
}
+ setCloudConnectionState(CloudConnectionStateConnecting)
+
header := http.Header{}
header.Set("X-Device-ID", GetDeviceID())
header.Set("X-App-Version", builtAppVersion)
@@ -302,20 +333,26 @@ func runWebsocketClient() error {
c, resp, err := websocket.Dial(dialCtx, wsURL.String(), &websocket.DialOptions{
HTTPHeader: header,
OnPingReceived: func(ctx context.Context, payload []byte) bool {
- scopedLogger.Info().Bytes("payload", payload).Int("length", len(payload)).Msg("ping frame received")
+ scopedLogger.Debug().Bytes("payload", payload).Int("length", len(payload)).Msg("ping frame received")
metricConnectionTotalPingReceivedCount.WithLabelValues("cloud", wsURL.Host).Inc()
metricConnectionLastPingReceivedTimestamp.WithLabelValues("cloud", wsURL.Host).SetToCurrentTime()
+ setCloudConnectionState(CloudConnectionStateConnected)
+
return true
},
})
- // get the request id from the response header
- connectionId := resp.Header.Get("X-Request-ID")
- if connectionId == "" {
- connectionId = resp.Header.Get("Cf-Ray")
+ var connectionId string
+ if resp != nil {
+ // get the request id from the response header
+ connectionId = resp.Header.Get("X-Request-ID")
+ if connectionId == "" {
+ connectionId = resp.Header.Get("Cf-Ray")
+ }
}
+
if connectionId == "" {
connectionId = uuid.New().String()
scopedLogger.Warn().
@@ -332,6 +369,8 @@ func runWebsocketClient() error {
if err != nil {
if errors.Is(err, context.Canceled) {
cloudLogger.Info().Msg("websocket connection canceled")
+ setCloudConnectionState(CloudConnectionStateDisconnected)
+
return nil
}
return err
@@ -450,14 +489,14 @@ func RunWebsocketClient() {
}
// If the network is not up, well, we can't connect to the cloud.
- if !networkState.Up {
- cloudLogger.Warn().Msg("waiting for network to be up, will retry in 3 seconds")
+ if !networkState.IsOnline() {
+ cloudLogger.Warn().Msg("waiting for network to be online, will retry in 3 seconds")
time.Sleep(3 * time.Second)
continue
}
// If the system time is not synchronized, the API request will fail anyway because the TLS handshake will fail.
- if isTimeSyncNeeded() && !timeSyncSuccess {
+ if isTimeSyncNeeded() && !timeSync.IsSyncSuccess() {
cloudLogger.Warn().Msg("system time is not synced, will retry in 3 seconds")
time.Sleep(3 * time.Second)
continue
@@ -520,6 +559,8 @@ func rpcDeregisterDevice() error {
cloudLogger.Info().Msg("device deregistered, disconnecting from cloud")
disconnectCloud(fmt.Errorf("device deregistered"))
+ setCloudConnectionState(CloudConnectionStateNotConfigured)
+
return nil
}
diff --git a/config.go b/config.go
index cf096a7..23d4c84 100644
--- a/config.go
+++ b/config.go
@@ -6,6 +6,8 @@ import (
"os"
"sync"
+ "github.com/jetkvm/kvm/internal/logging"
+ "github.com/jetkvm/kvm/internal/network"
"github.com/jetkvm/kvm/internal/usbgadget"
)
@@ -73,27 +75,28 @@ func (m *KeyboardMacro) Validate() error {
}
type Config struct {
- CloudURL string `json:"cloud_url"`
- CloudAppURL string `json:"cloud_app_url"`
- CloudToken string `json:"cloud_token"`
- GoogleIdentity string `json:"google_identity"`
- JigglerEnabled bool `json:"jiggler_enabled"`
- AutoUpdateEnabled bool `json:"auto_update_enabled"`
- IncludePreRelease bool `json:"include_pre_release"`
- HashedPassword string `json:"hashed_password"`
- LocalAuthToken string `json:"local_auth_token"`
- LocalAuthMode string `json:"localAuthMode"` //TODO: fix it with migration
- WakeOnLanDevices []WakeOnLanDevice `json:"wake_on_lan_devices"`
- KeyboardMacros []KeyboardMacro `json:"keyboard_macros"`
- EdidString string `json:"hdmi_edid_string"`
- ActiveExtension string `json:"active_extension"`
- DisplayMaxBrightness int `json:"display_max_brightness"`
- DisplayDimAfterSec int `json:"display_dim_after_sec"`
- DisplayOffAfterSec int `json:"display_off_after_sec"`
- TLSMode string `json:"tls_mode"` // options: "self-signed", "user-defined", ""
- UsbConfig *usbgadget.Config `json:"usb_config"`
- UsbDevices *usbgadget.Devices `json:"usb_devices"`
- DefaultLogLevel string `json:"default_log_level"`
+ CloudURL string `json:"cloud_url"`
+ CloudAppURL string `json:"cloud_app_url"`
+ CloudToken string `json:"cloud_token"`
+ GoogleIdentity string `json:"google_identity"`
+ JigglerEnabled bool `json:"jiggler_enabled"`
+ AutoUpdateEnabled bool `json:"auto_update_enabled"`
+ IncludePreRelease bool `json:"include_pre_release"`
+ HashedPassword string `json:"hashed_password"`
+ LocalAuthToken string `json:"local_auth_token"`
+ LocalAuthMode string `json:"localAuthMode"` //TODO: fix it with migration
+ WakeOnLanDevices []WakeOnLanDevice `json:"wake_on_lan_devices"`
+ KeyboardMacros []KeyboardMacro `json:"keyboard_macros"`
+ EdidString string `json:"hdmi_edid_string"`
+ ActiveExtension string `json:"active_extension"`
+ DisplayMaxBrightness int `json:"display_max_brightness"`
+ DisplayDimAfterSec int `json:"display_dim_after_sec"`
+ DisplayOffAfterSec int `json:"display_off_after_sec"`
+ TLSMode string `json:"tls_mode"` // options: "self-signed", "user-defined", ""
+ UsbConfig *usbgadget.Config `json:"usb_config"`
+ UsbDevices *usbgadget.Devices `json:"usb_devices"`
+ NetworkConfig *network.NetworkConfig `json:"network_config"`
+ DefaultLogLevel string `json:"default_log_level"`
}
const configPath = "/userdata/kvm_config.json"
@@ -121,6 +124,7 @@ var defaultConfig = &Config{
Keyboard: true,
MassStorage: true,
},
+ NetworkConfig: &network.NetworkConfig{},
DefaultLogLevel: "INFO",
}
@@ -134,7 +138,7 @@ func LoadConfig() {
defer configLock.Unlock()
if config != nil {
- logger.Info().Msg("config already loaded, skipping")
+ logger.Debug().Msg("config already loaded, skipping")
return
}
@@ -164,9 +168,15 @@ func LoadConfig() {
loadedConfig.UsbDevices = defaultConfig.UsbDevices
}
+ if loadedConfig.NetworkConfig == nil {
+ loadedConfig.NetworkConfig = defaultConfig.NetworkConfig
+ }
+
config = &loadedConfig
- rootLogger.UpdateLogLevel()
+ logging.GetRootLogger().UpdateLogLevel(config.DefaultLogLevel)
+
+ logger.Info().Str("path", configPath).Msg("config loaded")
}
func SaveConfig() error {
diff --git a/dev_deploy.sh b/dev_deploy.sh
index 02bbb24..d0ccaf2 100755
--- a/dev_deploy.sh
+++ b/dev_deploy.sh
@@ -24,6 +24,7 @@ show_help() {
REMOTE_USER="root"
REMOTE_PATH="/userdata/jetkvm/bin"
SKIP_UI_BUILD=false
+LOG_TRACE_SCOPES="${LOG_TRACE_SCOPES:-jetkvm,cloud,websocket,native,jsonrpc}"
# Parse command line arguments
while [[ $# -gt 0 ]]; do
@@ -91,7 +92,7 @@ cd "${REMOTE_PATH}"
chmod +x jetkvm_app_debug
# Run the application in the background
-PION_LOG_TRACE=jetkvm,cloud,websocket ./jetkvm_app_debug
+PION_LOG_TRACE=${LOG_TRACE_SCOPES} ./jetkvm_app_debug
EOF
echo "Deployment complete."
diff --git a/display.go b/display.go
index cbe9ddd..e2e82e1 100644
--- a/display.go
+++ b/display.go
@@ -33,50 +33,153 @@ func switchToScreen(screen string) {
var displayedTexts = make(map[string]string)
+func lvObjSetState(objName string, state string) (*CtrlResponse, error) {
+ return CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": objName, "state": state})
+}
+
+func lvObjAddFlag(objName string, flag string) (*CtrlResponse, error) {
+ return CallCtrlAction("lv_obj_add_flag", map[string]interface{}{"obj": objName, "flag": flag})
+}
+
+func lvObjClearFlag(objName string, flag string) (*CtrlResponse, error) {
+ return CallCtrlAction("lv_obj_clear_flag", map[string]interface{}{"obj": objName, "flag": flag})
+}
+
+func lvObjHide(objName string) (*CtrlResponse, error) {
+ return lvObjAddFlag(objName, "LV_OBJ_FLAG_HIDDEN")
+}
+
+func lvObjShow(objName string) (*CtrlResponse, error) {
+ return lvObjClearFlag(objName, "LV_OBJ_FLAG_HIDDEN")
+}
+
+func lvObjSetOpacity(objName string, opacity int) (*CtrlResponse, error) { // nolint:unused
+ return CallCtrlAction("lv_obj_set_style_opa_layered", map[string]interface{}{"obj": objName, "opa": opacity})
+}
+
+func lvObjFadeIn(objName string, duration uint32) (*CtrlResponse, error) {
+ return CallCtrlAction("lv_obj_fade_in", map[string]interface{}{"obj": objName, "time": duration})
+}
+
+func lvObjFadeOut(objName string, duration uint32) (*CtrlResponse, error) {
+ return CallCtrlAction("lv_obj_fade_out", map[string]interface{}{"obj": objName, "time": duration})
+}
+
+func lvLabelSetText(objName string, text string) (*CtrlResponse, error) {
+ return CallCtrlAction("lv_label_set_text", map[string]interface{}{"obj": objName, "text": text})
+}
+
+func lvImgSetSrc(objName string, src string) (*CtrlResponse, error) {
+ return CallCtrlAction("lv_img_set_src", map[string]interface{}{"obj": objName, "src": src})
+}
+
func updateLabelIfChanged(objName string, newText string) {
if newText != "" && newText != displayedTexts[objName] {
- _, _ = CallCtrlAction("lv_label_set_text", map[string]interface{}{"obj": objName, "text": newText})
+ _, _ = lvLabelSetText(objName, newText)
displayedTexts[objName] = newText
}
}
func switchToScreenIfDifferent(screenName string) {
- displayLogger.Info().Str("from", currentScreen).Str("to", screenName).Msg("switching screen")
if currentScreen != screenName {
+ displayLogger.Info().Str("from", currentScreen).Str("to", screenName).Msg("switching screen")
switchToScreen(screenName)
}
}
+var (
+ cloudBlinkLock sync.Mutex = sync.Mutex{}
+ cloudBlinkStopped bool
+ cloudBlinkTicker *time.Ticker
+)
+
func updateDisplay() {
- updateLabelIfChanged("ui_Home_Content_Ip", networkState.IPv4)
+ updateLabelIfChanged("ui_Home_Content_Ip", networkState.IPv4String())
if usbState == "configured" {
updateLabelIfChanged("ui_Home_Footer_Usb_Status_Label", "Connected")
- _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Usb_Status_Label", "state": "LV_STATE_DEFAULT"})
+ _, _ = lvObjSetState("ui_Home_Footer_Usb_Status_Label", "LV_STATE_DEFAULT")
} else {
updateLabelIfChanged("ui_Home_Footer_Usb_Status_Label", "Disconnected")
- _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Usb_Status_Label", "state": "LV_STATE_USER_2"})
+ _, _ = lvObjSetState("ui_Home_Footer_Usb_Status_Label", "LV_STATE_USER_2")
}
if lastVideoState.Ready {
updateLabelIfChanged("ui_Home_Footer_Hdmi_Status_Label", "Connected")
- _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Hdmi_Status_Label", "state": "LV_STATE_DEFAULT"})
+ _, _ = lvObjSetState("ui_Home_Footer_Hdmi_Status_Label", "LV_STATE_DEFAULT")
} else {
updateLabelIfChanged("ui_Home_Footer_Hdmi_Status_Label", "Disconnected")
- _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Hdmi_Status_Label", "state": "LV_STATE_USER_2"})
+ _, _ = lvObjSetState("ui_Home_Footer_Hdmi_Status_Label", "LV_STATE_USER_2")
}
updateLabelIfChanged("ui_Home_Header_Cloud_Status_Label", fmt.Sprintf("%d active", actionSessions))
- if networkState.Up {
+
+ if networkState.IsUp() {
switchToScreenIfDifferent("ui_Home_Screen")
} else {
switchToScreenIfDifferent("ui_No_Network_Screen")
}
+
+ if cloudConnectionState == CloudConnectionStateNotConfigured {
+ _, _ = lvObjHide("ui_Home_Header_Cloud_Status_Icon")
+ } else {
+ _, _ = lvObjShow("ui_Home_Header_Cloud_Status_Icon")
+ }
+
+ switch cloudConnectionState {
+ case CloudConnectionStateDisconnected:
+ _, _ = lvImgSetSrc("ui_Home_Header_Cloud_Status_Icon", "cloud_disconnected.png")
+ stopCloudBlink()
+ case CloudConnectionStateConnecting:
+ _, _ = lvImgSetSrc("ui_Home_Header_Cloud_Status_Icon", "cloud.png")
+ startCloudBlink()
+ case CloudConnectionStateConnected:
+ _, _ = lvImgSetSrc("ui_Home_Header_Cloud_Status_Icon", "cloud.png")
+ stopCloudBlink()
+ }
+}
+
+func startCloudBlink() {
+ if cloudBlinkTicker == nil {
+ cloudBlinkTicker = time.NewTicker(2 * time.Second)
+ } else {
+ // do nothing if the blink isn't stopped
+ if cloudBlinkStopped {
+ cloudBlinkLock.Lock()
+ defer cloudBlinkLock.Unlock()
+
+ cloudBlinkStopped = false
+ cloudBlinkTicker.Reset(2 * time.Second)
+ }
+ }
+
+ go func() {
+ for range cloudBlinkTicker.C {
+ if cloudConnectionState != CloudConnectionStateConnecting {
+ continue
+ }
+ _, _ = lvObjFadeOut("ui_Home_Header_Cloud_Status_Icon", 1000)
+ time.Sleep(1000 * time.Millisecond)
+ _, _ = lvObjFadeIn("ui_Home_Header_Cloud_Status_Icon", 1000)
+ time.Sleep(1000 * time.Millisecond)
+ }
+ }()
+}
+
+func stopCloudBlink() {
+ if cloudBlinkTicker != nil {
+ cloudBlinkTicker.Stop()
+ }
+
+ cloudBlinkLock.Lock()
+ defer cloudBlinkLock.Unlock()
+ cloudBlinkStopped = true
}
var (
displayInited = false
displayUpdateLock = sync.Mutex{}
+ waitDisplayUpdate = sync.Mutex{}
)
-func requestDisplayUpdate() {
+func requestDisplayUpdate(shouldWakeDisplay bool) {
displayUpdateLock.Lock()
defer displayUpdateLock.Unlock()
@@ -85,16 +188,26 @@ func requestDisplayUpdate() {
return
}
go func() {
- wakeDisplay(false)
- displayLogger.Info().Msg("display updating")
+ if shouldWakeDisplay {
+ wakeDisplay(false)
+ }
+ displayLogger.Debug().Msg("display updating")
//TODO: only run once regardless how many pending updates
updateDisplay()
}()
}
+func waitCtrlAndRequestDisplayUpdate(shouldWakeDisplay bool) {
+ waitDisplayUpdate.Lock()
+ defer waitDisplayUpdate.Unlock()
+
+ waitCtrlClientConnected()
+ requestDisplayUpdate(shouldWakeDisplay)
+}
+
func updateStaticContents() {
//contents that never change
- updateLabelIfChanged("ui_Home_Content_Mac", networkState.MAC)
+ updateLabelIfChanged("ui_Home_Content_Mac", networkState.MACString())
systemVersion, appVersion, err := GetLocalVersion()
if err == nil {
updateLabelIfChanged("ui_About_Content_Operating_System_Version_ContentLabel", systemVersion.String())
@@ -265,7 +378,7 @@ func init() {
displayLogger.Info().Msg("display inited")
startBacklightTickers()
wakeDisplay(true)
- requestDisplayUpdate()
+ requestDisplayUpdate(true)
}()
go watchTsEvents()
diff --git a/go.mod b/go.mod
index 1311a33..6784a59 100644
--- a/go.mod
+++ b/go.mod
@@ -8,6 +8,7 @@ require (
github.com/coder/websocket v1.8.13
github.com/coreos/go-oidc/v3 v3.11.0
github.com/creack/pty v1.1.23
+ github.com/fsnotify/fsnotify v1.9.0
github.com/gin-contrib/logger v1.2.5
github.com/gin-gonic/gin v1.10.0
github.com/google/uuid v1.6.0
@@ -44,6 +45,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.26.0 // indirect
github.com/goccy/go-json v0.10.5 // indirect
+ github.com/guregu/null/v6 v6.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
diff --git a/go.sum b/go.sum
index 565c0cc..3ad832a 100644
--- a/go.sum
+++ b/go.sum
@@ -28,6 +28,8 @@ github.com/creack/pty v1.1.23/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfv
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
+github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/gin-contrib/logger v1.2.5 h1:qVQI4omayQecuN4zX9ZZnsOq7w9J/ZLds3J/FMn8ypM=
@@ -54,6 +56,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/guregu/null/v6 v6.0.0 h1:N14VRS+4di81i1PXRiprbQJ9EM9gqBa0+KVMeS/QSjQ=
+github.com/guregu/null/v6 v6.0.0/go.mod h1:hrMIhIfrOZeLPZhROSn149tpw2gHkidAqxoXNyeX3iQ=
github.com/gwatts/rootcerts v0.0.0-20240401182218-3ab9db955caf h1:JO6ISZIvEUitto5zjQ3/VEnDM5rPbqIFuOhS0U0ByeA=
github.com/gwatts/rootcerts v0.0.0-20240401182218-3ab9db955caf/go.mod h1:5Kt9XkWvkGi2OHOq0QsGxebHmhCcqJ8KCbNg/a6+n+g=
github.com/hanwen/go-fuse/v2 v2.5.1 h1:OQBE8zVemSocRxA4OaFJbjJ5hlpCmIWbGr7r0M4uoQQ=
diff --git a/hw.go b/hw.go
index 21bffad..20d88eb 100644
--- a/hw.go
+++ b/hw.go
@@ -4,6 +4,7 @@ import (
"fmt"
"os"
"regexp"
+ "strings"
"sync"
"time"
)
@@ -51,6 +52,15 @@ func GetDeviceID() string {
return deviceID
}
+func GetDefaultHostname() string {
+ deviceId := GetDeviceID()
+ if deviceId == "unknown_device_id" {
+ return "jetkvm"
+ }
+
+ return fmt.Sprintf("jetkvm-%s", strings.ToLower(deviceId))
+}
+
func runWatchdog() {
file, err := os.OpenFile("/dev/watchdog", os.O_WRONLY, 0)
if err != nil {
diff --git a/internal/confparser/confparser.go b/internal/confparser/confparser.go
new file mode 100644
index 0000000..76102a3
--- /dev/null
+++ b/internal/confparser/confparser.go
@@ -0,0 +1,381 @@
+package confparser
+
+import (
+ "fmt"
+ "net"
+ "reflect"
+ "slices"
+ "strconv"
+ "strings"
+
+ "github.com/guregu/null/v6"
+ "golang.org/x/net/idna"
+)
+
+type FieldConfig struct {
+ Name string
+ Required bool
+ RequiredIf map[string]interface{}
+ OneOf []string
+ ValidateTypes []string
+ Defaults interface{}
+ IsEmpty bool
+ CurrentValue interface{}
+ TypeString string
+ Delegated bool
+ shouldUpdateValue bool
+}
+
+func SetDefaultsAndValidate(config interface{}) error {
+ return setDefaultsAndValidate(config, true)
+}
+
+func setDefaultsAndValidate(config interface{}, isRoot bool) error {
+ // first we need to check if the config is a pointer
+ if reflect.TypeOf(config).Kind() != reflect.Ptr {
+ return fmt.Errorf("config is not a pointer")
+ }
+
+ // now iterate over the lease struct and set the values
+ configType := reflect.TypeOf(config).Elem()
+ configValue := reflect.ValueOf(config).Elem()
+
+ fields := make(map[string]FieldConfig)
+
+ for i := 0; i < configType.NumField(); i++ {
+ field := configType.Field(i)
+ fieldValue := configValue.Field(i)
+
+ defaultValue := field.Tag.Get("default")
+
+ fieldType := field.Type.String()
+
+ fieldConfig := FieldConfig{
+ Name: field.Name,
+ OneOf: splitString(field.Tag.Get("one_of")),
+ ValidateTypes: splitString(field.Tag.Get("validate_type")),
+ RequiredIf: make(map[string]interface{}),
+ CurrentValue: fieldValue.Interface(),
+ IsEmpty: false,
+ TypeString: fieldType,
+ }
+
+ // check if the field is required
+ required := field.Tag.Get("required")
+ if required != "" {
+ requiredBool, _ := strconv.ParseBool(required)
+ fieldConfig.Required = requiredBool
+ }
+
+ var canUseOneOff = false
+
+ // use switch to get the type
+ switch fieldValue.Interface().(type) {
+ case string, null.String:
+ if defaultValue != "" {
+ fieldConfig.Defaults = defaultValue
+ }
+ canUseOneOff = true
+ case []string:
+ if defaultValue != "" {
+ fieldConfig.Defaults = strings.Split(defaultValue, ",")
+ }
+ canUseOneOff = true
+ case int, null.Int:
+ if defaultValue != "" {
+ defaultValueInt, err := strconv.Atoi(defaultValue)
+ if err != nil {
+ return fmt.Errorf("invalid default value for field `%s`: %s", field.Name, defaultValue)
+ }
+
+ fieldConfig.Defaults = defaultValueInt
+ }
+ case bool, null.Bool:
+ if defaultValue != "" {
+ defaultValueBool, err := strconv.ParseBool(defaultValue)
+ if err != nil {
+ return fmt.Errorf("invalid default value for field `%s`: %s", field.Name, defaultValue)
+ }
+
+ fieldConfig.Defaults = defaultValueBool
+ }
+ default:
+ if defaultValue != "" {
+ return fmt.Errorf("field `%s` cannot use default value: unsupported type: %s", field.Name, fieldType)
+ }
+
+ // check if it's a pointer
+ if fieldValue.Kind() == reflect.Ptr {
+ // check if the pointer is nil
+ if fieldValue.IsNil() {
+ fieldConfig.IsEmpty = true
+ } else {
+ fieldConfig.CurrentValue = fieldValue.Elem().Addr()
+ fieldConfig.Delegated = true
+ }
+ } else {
+ fieldConfig.Delegated = true
+ }
+ }
+
+ // now check if the field is nullable interface
+ switch fieldValue.Interface().(type) {
+ case null.String:
+ if fieldValue.Interface().(null.String).IsZero() {
+ fieldConfig.IsEmpty = true
+ }
+ case null.Int:
+ if fieldValue.Interface().(null.Int).IsZero() {
+ fieldConfig.IsEmpty = true
+ }
+ case null.Bool:
+ if fieldValue.Interface().(null.Bool).IsZero() {
+ fieldConfig.IsEmpty = true
+ }
+ case []string:
+ if len(fieldValue.Interface().([]string)) == 0 {
+ fieldConfig.IsEmpty = true
+ }
+ }
+
+ // now check if the field has required_if
+ requiredIf := field.Tag.Get("required_if")
+ if requiredIf != "" {
+ requiredIfParts := strings.Split(requiredIf, ",")
+ for _, part := range requiredIfParts {
+ partVal := strings.SplitN(part, "=", 2)
+ if len(partVal) != 2 {
+ return fmt.Errorf("invalid required_if for field `%s`: %s", field.Name, requiredIf)
+ }
+
+ fieldConfig.RequiredIf[partVal[0]] = partVal[1]
+ }
+ }
+
+ // check if the field can use one_of
+ if !canUseOneOff && len(fieldConfig.OneOf) > 0 {
+ return fmt.Errorf("field `%s` cannot use one_of: unsupported type: %s", field.Name, fieldType)
+ }
+
+ fields[field.Name] = fieldConfig
+ }
+
+ if err := validateFields(config, fields); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func validateFields(config interface{}, fields map[string]FieldConfig) error {
+ // now we can start to validate the fields
+ for _, fieldConfig := range fields {
+ if err := fieldConfig.validate(fields); err != nil {
+ return err
+ }
+
+ fieldConfig.populate(config)
+ }
+
+ return nil
+}
+
+func (f *FieldConfig) validate(fields map[string]FieldConfig) error {
+ var required bool
+ var err error
+
+ if required, err = f.validateRequired(fields); err != nil {
+ return err
+ }
+
+ // check if the field needs to be updated and set defaults if needed
+ if err := f.checkIfFieldNeedsUpdate(); err != nil {
+ return err
+ }
+
+ // then we can check if the field is one_of
+ if err := f.validateOneOf(); err != nil {
+ return err
+ }
+
+ // and validate the type
+ if err := f.validateField(); err != nil {
+ return err
+ }
+
+ // if the field is delegated, we need to validate the nested field
+ // but before that, let's check if the field is required
+ if required && f.Delegated {
+ if err := setDefaultsAndValidate(f.CurrentValue.(reflect.Value).Interface(), false); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (f *FieldConfig) populate(config interface{}) {
+ // update the field if it's not empty
+ if !f.shouldUpdateValue {
+ return
+ }
+
+ reflect.ValueOf(config).Elem().FieldByName(f.Name).Set(reflect.ValueOf(f.CurrentValue))
+}
+
+func (f *FieldConfig) checkIfFieldNeedsUpdate() error {
+ // populate the field if it's empty and has a default value
+ if f.IsEmpty && f.Defaults != nil {
+ switch f.CurrentValue.(type) {
+ case null.String:
+ f.CurrentValue = null.StringFrom(f.Defaults.(string))
+ case null.Int:
+ f.CurrentValue = null.IntFrom(int64(f.Defaults.(int)))
+ case null.Bool:
+ f.CurrentValue = null.BoolFrom(f.Defaults.(bool))
+ case string:
+ f.CurrentValue = f.Defaults.(string)
+ case int:
+ f.CurrentValue = f.Defaults.(int)
+ case bool:
+ f.CurrentValue = f.Defaults.(bool)
+ case []string:
+ f.CurrentValue = f.Defaults.([]string)
+ default:
+ return fmt.Errorf("field `%s` cannot use default value: unsupported type: %s", f.Name, f.TypeString)
+ }
+
+ f.shouldUpdateValue = true
+ }
+
+ return nil
+}
+
+func (f *FieldConfig) validateRequired(fields map[string]FieldConfig) (bool, error) {
+ var required = f.Required
+
+ // if the field is not required, we need to check if it's required_if
+ if !required && len(f.RequiredIf) > 0 {
+ for key, value := range f.RequiredIf {
+ // check if the field's result matches the required_if
+ // right now we only support string and int
+ requiredField, ok := fields[key]
+ if !ok {
+ return required, fmt.Errorf("required_if field `%s` not found", key)
+ }
+
+ switch requiredField.CurrentValue.(type) {
+ case string:
+ if requiredField.CurrentValue.(string) == value.(string) {
+ required = true
+ }
+ case int:
+ if requiredField.CurrentValue.(int) == value.(int) {
+ required = true
+ }
+ case null.String:
+ if !requiredField.CurrentValue.(null.String).IsZero() &&
+ requiredField.CurrentValue.(null.String).String == value.(string) {
+ required = true
+ }
+ case null.Int:
+ if !requiredField.CurrentValue.(null.Int).IsZero() &&
+ requiredField.CurrentValue.(null.Int).Int64 == value.(int64) {
+ required = true
+ }
+ }
+
+ // if the field is required, we can break the loop
+ // because we only need one of the required_if fields to be true
+ if required {
+ break
+ }
+ }
+ }
+
+ if required && f.IsEmpty {
+ return false, fmt.Errorf("field `%s` is required", f.Name)
+ }
+
+ return required, nil
+}
+
+func checkIfSliceContains(slice []string, one_of []string) bool {
+ for _, oneOf := range one_of {
+ if slices.Contains(slice, oneOf) {
+ return true
+ }
+ }
+
+ return false
+}
+
+func (f *FieldConfig) validateOneOf() error {
+ if len(f.OneOf) == 0 {
+ return nil
+ }
+
+ var val []string
+ switch f.CurrentValue.(type) {
+ case string:
+ val = []string{f.CurrentValue.(string)}
+ case null.String:
+ val = []string{f.CurrentValue.(null.String).String}
+ case []string:
+ // let's validate the value here
+ val = f.CurrentValue.([]string)
+ default:
+ return fmt.Errorf("field `%s` cannot use one_of: unsupported type: %s", f.Name, f.TypeString)
+ }
+
+ if !checkIfSliceContains(val, f.OneOf) {
+ return fmt.Errorf(
+ "field `%s` is not one of the allowed values: %s, current value: %s",
+ f.Name,
+ strings.Join(f.OneOf, ", "),
+ strings.Join(val, ", "),
+ )
+ }
+
+ return nil
+}
+
+func (f *FieldConfig) validateField() error {
+ if len(f.ValidateTypes) == 0 || f.IsEmpty {
+ return nil
+ }
+
+ val, err := toString(f.CurrentValue)
+ if err != nil {
+ return fmt.Errorf("field `%s` cannot use validate_type: %s", f.Name, err)
+ }
+
+ if val == "" {
+ return nil
+ }
+
+ for _, validateType := range f.ValidateTypes {
+ switch validateType {
+ case "ipv4":
+ if net.ParseIP(val).To4() == nil {
+ return fmt.Errorf("field `%s` is not a valid IPv4 address: %s", f.Name, val)
+ }
+ case "ipv6":
+ if net.ParseIP(val).To16() == nil {
+ return fmt.Errorf("field `%s` is not a valid IPv6 address: %s", f.Name, val)
+ }
+ case "hwaddr":
+ if _, err := net.ParseMAC(val); err != nil {
+ return fmt.Errorf("field `%s` is not a valid MAC address: %s", f.Name, val)
+ }
+ case "hostname":
+ if _, err := idna.Lookup.ToASCII(val); err != nil {
+ return fmt.Errorf("field `%s` is not a valid hostname: %s", f.Name, val)
+ }
+ default:
+ return fmt.Errorf("field `%s` cannot use validate_type: unsupported validator: %s", f.Name, validateType)
+ }
+ }
+
+ return nil
+}
diff --git a/internal/confparser/confparser_test.go b/internal/confparser/confparser_test.go
new file mode 100644
index 0000000..dd5e00a
--- /dev/null
+++ b/internal/confparser/confparser_test.go
@@ -0,0 +1,100 @@
+package confparser
+
+import (
+ "net"
+ "testing"
+ "time"
+
+ "github.com/guregu/null/v6"
+)
+
+type testIPv6Address struct { //nolint:unused
+ Address net.IP `json:"address"`
+ Prefix net.IPNet `json:"prefix"`
+ ValidLifetime *time.Time `json:"valid_lifetime"`
+ PreferredLifetime *time.Time `json:"preferred_lifetime"`
+ Scope int `json:"scope"`
+}
+
+type testIPv4StaticConfig struct {
+ Address null.String `json:"address" validate_type:"ipv4" required:"true"`
+ Netmask null.String `json:"netmask" validate_type:"ipv4" required:"true"`
+ Gateway null.String `json:"gateway" validate_type:"ipv4" required:"true"`
+ DNS []string `json:"dns" validate_type:"ipv4" required:"true"`
+}
+
+type testIPv6StaticConfig struct {
+ Address null.String `json:"address" validate_type:"ipv6" required:"true"`
+ Prefix null.String `json:"prefix" validate_type:"ipv6" required:"true"`
+ Gateway null.String `json:"gateway" validate_type:"ipv6" required:"true"`
+ DNS []string `json:"dns" validate_type:"ipv6" required:"true"`
+}
+type testNetworkConfig struct {
+ Hostname null.String `json:"hostname,omitempty"`
+ Domain null.String `json:"domain,omitempty"`
+
+ IPv4Mode null.String `json:"ipv4_mode" one_of:"dhcp,static,disabled" default:"dhcp"`
+ IPv4Static *testIPv4StaticConfig `json:"ipv4_static,omitempty" required_if:"IPv4Mode=static"`
+
+ IPv6Mode null.String `json:"ipv6_mode" one_of:"slaac,dhcpv6,slaac_and_dhcpv6,static,link_local,disabled" default:"slaac"`
+ IPv6Static *testIPv6StaticConfig `json:"ipv6_static,omitempty" required_if:"IPv6Mode=static"`
+
+ LLDPMode null.String `json:"lldp_mode,omitempty" one_of:"disabled,basic,all" default:"basic"`
+ LLDPTxTLVs []string `json:"lldp_tx_tlvs,omitempty" one_of:"chassis,port,system,vlan" default:"chassis,port,system,vlan"`
+ MDNSMode null.String `json:"mdns_mode,omitempty" one_of:"disabled,auto,ipv4_only,ipv6_only" default:"auto"`
+ TimeSyncMode null.String `json:"time_sync_mode,omitempty" one_of:"ntp_only,ntp_and_http,http_only,custom" default:"ntp_and_http"`
+ TimeSyncOrdering []string `json:"time_sync_ordering,omitempty" one_of:"http,ntp,ntp_dhcp,ntp_user_provided,ntp_fallback" default:"ntp,http"`
+ TimeSyncDisableFallback null.Bool `json:"time_sync_disable_fallback,omitempty" default:"false"`
+ TimeSyncParallel null.Int `json:"time_sync_parallel,omitempty" default:"4"`
+}
+
+func TestValidateConfig(t *testing.T) {
+ config := &testNetworkConfig{}
+
+ err := SetDefaultsAndValidate(config)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+}
+
+func TestValidateIPv4StaticConfigRequired(t *testing.T) {
+ config := &testNetworkConfig{
+ IPv4Static: &testIPv4StaticConfig{
+ Address: null.StringFrom("192.168.1.1"),
+ Gateway: null.StringFrom("192.168.1.1"),
+ },
+ }
+
+ err := SetDefaultsAndValidate(config)
+ if err == nil {
+ t.Fatalf("expected error, got nil")
+ }
+}
+
+func TestValidateIPv4StaticConfigRequiredIf(t *testing.T) {
+ config := &testNetworkConfig{
+ IPv4Mode: null.StringFrom("static"),
+ }
+
+ err := SetDefaultsAndValidate(config)
+ if err == nil {
+ t.Fatalf("expected error, got nil")
+ }
+}
+
+func TestValidateIPv4StaticConfigValidateType(t *testing.T) {
+ config := &testNetworkConfig{
+ IPv4Static: &testIPv4StaticConfig{
+ Address: null.StringFrom("X"),
+ Netmask: null.StringFrom("255.255.255.0"),
+ Gateway: null.StringFrom("192.168.1.1"),
+ DNS: []string{"8.8.8.8", "8.8.4.4"},
+ },
+ IPv4Mode: null.StringFrom("static"),
+ }
+
+ err := SetDefaultsAndValidate(config)
+ if err == nil {
+ t.Fatalf("expected error, got nil")
+ }
+}
diff --git a/internal/confparser/utils.go b/internal/confparser/utils.go
new file mode 100644
index 0000000..a46871e
--- /dev/null
+++ b/internal/confparser/utils.go
@@ -0,0 +1,28 @@
+package confparser
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/guregu/null/v6"
+)
+
+func splitString(s string) []string {
+ if s == "" {
+ return []string{}
+ }
+
+ return strings.Split(s, ",")
+}
+
+func toString(v interface{}) (string, error) {
+ switch v := v.(type) {
+ case string:
+ return v, nil
+ case null.String:
+ return v.String, nil
+ }
+
+ return "", fmt.Errorf("unsupported type: %s", reflect.TypeOf(v))
+}
diff --git a/internal/logging/logger.go b/internal/logging/logger.go
new file mode 100644
index 0000000..39156ec
--- /dev/null
+++ b/internal/logging/logger.go
@@ -0,0 +1,197 @@
+package logging
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/rs/zerolog"
+)
+
+type Logger struct {
+ l *zerolog.Logger
+ scopeLoggers map[string]*zerolog.Logger
+ scopeLevels map[string]zerolog.Level
+ scopeLevelMutex sync.Mutex
+
+ defaultLogLevelFromEnv zerolog.Level
+ defaultLogLevelFromConfig zerolog.Level
+ defaultLogLevel zerolog.Level
+}
+
+const (
+ defaultLogLevel = zerolog.ErrorLevel
+)
+
+type logOutput struct {
+ mu *sync.Mutex
+}
+
+func (w *logOutput) Write(p []byte) (n int, err error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // TODO: write to file or syslog
+ if sseServer != nil {
+ // use a goroutine to avoid blocking the Write method
+ go func() {
+ sseServer.Message <- string(p)
+ }()
+ }
+ return len(p), nil
+}
+
+var (
+ consoleLogOutput io.Writer = zerolog.ConsoleWriter{
+ Out: os.Stdout,
+ TimeFormat: time.RFC3339,
+ PartsOrder: []string{"time", "level", "scope", "component", "message"},
+ FieldsExclude: []string{"scope", "component"},
+ FormatPartValueByName: func(value interface{}, name string) string {
+ val := fmt.Sprintf("%s", value)
+ if name == "component" {
+ if value == nil {
+ return "-"
+ }
+ }
+ return val
+ },
+ }
+ fileLogOutput io.Writer = &logOutput{mu: &sync.Mutex{}}
+ defaultLogOutput = zerolog.MultiLevelWriter(consoleLogOutput, fileLogOutput)
+
+ zerologLevels = map[string]zerolog.Level{
+ "DISABLE": zerolog.Disabled,
+ "NOLEVEL": zerolog.NoLevel,
+ "PANIC": zerolog.PanicLevel,
+ "FATAL": zerolog.FatalLevel,
+ "ERROR": zerolog.ErrorLevel,
+ "WARN": zerolog.WarnLevel,
+ "INFO": zerolog.InfoLevel,
+ "DEBUG": zerolog.DebugLevel,
+ "TRACE": zerolog.TraceLevel,
+ }
+)
+
+func NewLogger(zerologLogger zerolog.Logger) *Logger {
+ return &Logger{
+ l: &zerologLogger,
+ scopeLoggers: make(map[string]*zerolog.Logger),
+ scopeLevels: make(map[string]zerolog.Level),
+ scopeLevelMutex: sync.Mutex{},
+ defaultLogLevelFromEnv: -2,
+ defaultLogLevelFromConfig: -2,
+ defaultLogLevel: defaultLogLevel,
+ }
+}
+
+func (l *Logger) updateLogLevel() {
+ l.scopeLevelMutex.Lock()
+ defer l.scopeLevelMutex.Unlock()
+
+ l.scopeLevels = make(map[string]zerolog.Level)
+
+ finalDefaultLogLevel := l.defaultLogLevel
+
+ for name, level := range zerologLevels {
+ env := os.Getenv(fmt.Sprintf("JETKVM_LOG_%s", name))
+
+ if env == "" {
+ env = os.Getenv(fmt.Sprintf("PION_LOG_%s", name))
+ }
+
+ if env == "" {
+ env = os.Getenv(fmt.Sprintf("PIONS_LOG_%s", name))
+ }
+
+ if env == "" {
+ continue
+ }
+
+ if strings.ToLower(env) == "all" {
+ l.defaultLogLevelFromEnv = level
+
+ if finalDefaultLogLevel > level {
+ finalDefaultLogLevel = level
+ }
+
+ continue
+ }
+
+ scopes := strings.Split(strings.ToLower(env), ",")
+ for _, scope := range scopes {
+ l.scopeLevels[scope] = level
+ }
+ }
+
+ l.defaultLogLevel = finalDefaultLogLevel
+}
+
+func (l *Logger) getScopeLoggerLevel(scope string) zerolog.Level {
+ if l.scopeLevels == nil {
+ l.updateLogLevel()
+ }
+
+ var scopeLevel zerolog.Level
+ if l.defaultLogLevelFromConfig != -2 {
+ scopeLevel = l.defaultLogLevelFromConfig
+ }
+ if l.defaultLogLevelFromEnv != -2 {
+ scopeLevel = l.defaultLogLevelFromEnv
+ }
+
+ // if the scope is not in the map, use the default level from the root logger
+ if level, ok := l.scopeLevels[scope]; ok {
+ scopeLevel = level
+ }
+
+ return scopeLevel
+}
+
+func (l *Logger) newScopeLogger(scope string) zerolog.Logger {
+ scopeLevel := l.getScopeLoggerLevel(scope)
+ logger := l.l.Level(scopeLevel).With().Str("component", scope).Logger()
+
+ return logger
+}
+
+func (l *Logger) getLogger(scope string) *zerolog.Logger {
+ logger, ok := l.scopeLoggers[scope]
+ if !ok || logger == nil {
+ scopeLogger := l.newScopeLogger(scope)
+ l.scopeLoggers[scope] = &scopeLogger
+ }
+
+ return l.scopeLoggers[scope]
+}
+
+func (l *Logger) UpdateLogLevel(configDefaultLogLevel string) {
+ needUpdate := false
+
+ if configDefaultLogLevel != "" {
+ if logLevel, ok := zerologLevels[configDefaultLogLevel]; ok {
+ l.defaultLogLevelFromConfig = logLevel
+ } else {
+ l.l.Warn().Str("logLevel", configDefaultLogLevel).Msg("invalid defaultLogLevel from config, using ERROR")
+ }
+
+ if l.defaultLogLevelFromConfig != l.defaultLogLevel {
+ needUpdate = true
+ }
+ }
+
+ l.updateLogLevel()
+
+ if needUpdate {
+ for scope, logger := range l.scopeLoggers {
+ currentLevel := logger.GetLevel()
+ targetLevel := l.getScopeLoggerLevel(scope)
+ if currentLevel != targetLevel {
+ *logger = l.newScopeLogger(scope)
+ }
+ }
+ }
+}
diff --git a/internal/logging/pion.go b/internal/logging/pion.go
new file mode 100644
index 0000000..453b8bc
--- /dev/null
+++ b/internal/logging/pion.go
@@ -0,0 +1,63 @@
+package logging
+
+import (
+ "github.com/pion/logging"
+ "github.com/rs/zerolog"
+)
+
+type pionLogger struct {
+ logger *zerolog.Logger
+}
+
+// Print all messages except trace.
+func (c pionLogger) Trace(msg string) {
+ c.logger.Trace().Msg(msg)
+}
+func (c pionLogger) Tracef(format string, args ...interface{}) {
+ c.logger.Trace().Msgf(format, args...)
+}
+
+func (c pionLogger) Debug(msg string) {
+ c.logger.Debug().Msg(msg)
+}
+func (c pionLogger) Debugf(format string, args ...interface{}) {
+ c.logger.Debug().Msgf(format, args...)
+}
+func (c pionLogger) Info(msg string) {
+ c.logger.Info().Msg(msg)
+}
+func (c pionLogger) Infof(format string, args ...interface{}) {
+ c.logger.Info().Msgf(format, args...)
+}
+func (c pionLogger) Warn(msg string) {
+ c.logger.Warn().Msg(msg)
+}
+func (c pionLogger) Warnf(format string, args ...interface{}) {
+ c.logger.Warn().Msgf(format, args...)
+}
+func (c pionLogger) Error(msg string) {
+ c.logger.Error().Msg(msg)
+}
+func (c pionLogger) Errorf(format string, args ...interface{}) {
+ c.logger.Error().Msgf(format, args...)
+}
+
+// customLoggerFactory satisfies the interface logging.LoggerFactory
+// This allows us to create different loggers per subsystem. So we can
+// add custom behavior.
+type pionLoggerFactory struct{}
+
+func (c pionLoggerFactory) NewLogger(subsystem string) logging.LeveledLogger {
+ logger := rootLogger.getLogger(subsystem).With().
+ Str("scope", "pion").
+ Str("component", subsystem).
+ Logger()
+
+ return pionLogger{logger: &logger}
+}
+
+var defaultLoggerFactory = &pionLoggerFactory{}
+
+func GetPionDefaultLoggerFactory() logging.LoggerFactory {
+ return defaultLoggerFactory
+}
diff --git a/internal/logging/root.go b/internal/logging/root.go
new file mode 100644
index 0000000..397ca64
--- /dev/null
+++ b/internal/logging/root.go
@@ -0,0 +1,20 @@
+package logging
+
+import "github.com/rs/zerolog"
+
+var (
+ rootZerologLogger = zerolog.New(defaultLogOutput).With().
+ Str("scope", "jetkvm").
+ Timestamp().
+ Stack().
+ Logger()
+ rootLogger = NewLogger(rootZerologLogger)
+)
+
+func GetRootLogger() *Logger {
+ return rootLogger
+}
+
+func GetSubsystemLogger(subsystem string) *zerolog.Logger {
+ return rootLogger.getLogger(subsystem)
+}
diff --git a/internal/logging/sse.go b/internal/logging/sse.go
new file mode 100644
index 0000000..05e6e9e
--- /dev/null
+++ b/internal/logging/sse.go
@@ -0,0 +1,137 @@
+package logging
+
+import (
+ "embed"
+ "io"
+ "net/http"
+
+ "github.com/gin-gonic/gin"
+ "github.com/rs/zerolog"
+)
+
+//go:embed sse.html
+var sseHTML embed.FS
+
+type sseEvent struct {
+ Message chan string
+ NewClients chan chan string
+ ClosedClients chan chan string
+ TotalClients map[chan string]bool
+}
+
+// New event messages are broadcast to all registered client connection channels
+type sseClientChan chan string
+
+var (
+ sseServer *sseEvent
+ sseLogger *zerolog.Logger
+)
+
+func init() {
+ sseServer = newSseServer()
+ sseLogger = GetSubsystemLogger("sse")
+}
+
+// Initialize event and Start procnteessing requests
+func newSseServer() (event *sseEvent) {
+ event = &sseEvent{
+ Message: make(chan string),
+ NewClients: make(chan chan string),
+ ClosedClients: make(chan chan string),
+ TotalClients: make(map[chan string]bool),
+ }
+
+ go event.listen()
+
+ return
+}
+
+// It Listens all incoming requests from clients.
+// Handles addition and removal of clients and broadcast messages to clients.
+func (stream *sseEvent) listen() {
+ for {
+ select {
+ // Add new available client
+ case client := <-stream.NewClients:
+ stream.TotalClients[client] = true
+ sseLogger.Info().
+ Int("total_clients", len(stream.TotalClients)).
+ Msg("new client connected")
+
+ // Remove closed client
+ case client := <-stream.ClosedClients:
+ delete(stream.TotalClients, client)
+ close(client)
+ sseLogger.Info().Int("total_clients", len(stream.TotalClients)).Msg("client disconnected")
+
+ // Broadcast message to client
+ case eventMsg := <-stream.Message:
+ for clientMessageChan := range stream.TotalClients {
+ select {
+ case clientMessageChan <- eventMsg:
+ // Message sent successfully
+ default:
+ // Failed to send, dropping message
+ }
+ }
+ }
+ }
+}
+
+func (stream *sseEvent) serveHTTP() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ clientChan := make(sseClientChan)
+ stream.NewClients <- clientChan
+
+ go func() {
+ <-c.Writer.CloseNotify()
+
+ for range clientChan {
+ }
+
+ stream.ClosedClients <- clientChan
+ }()
+
+ c.Set("clientChan", clientChan)
+ c.Next()
+ }
+}
+
+func sseHeadersMiddleware() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ if c.Request.Method == "GET" && c.NegotiateFormat(gin.MIMEHTML) == gin.MIMEHTML {
+ c.FileFromFS("/sse.html", http.FS(sseHTML))
+ c.Status(http.StatusOK)
+ c.Abort()
+ return
+ }
+
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("Transfer-Encoding", "chunked")
+ c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
+ c.Next()
+ }
+}
+
+func AttachSSEHandler(router *gin.RouterGroup) {
+ router.StaticFS("/log-stream", http.FS(sseHTML))
+ router.GET("/log-stream", sseHeadersMiddleware(), sseServer.serveHTTP(), func(c *gin.Context) {
+ v, ok := c.Get("clientChan")
+ if !ok {
+ return
+ }
+ clientChan, ok := v.(sseClientChan)
+ if !ok {
+ return
+ }
+ c.Stream(func(w io.Writer) bool {
+ if msg, ok := <-clientChan; ok {
+ c.SSEvent("message", msg)
+ return true
+ }
+ return false
+ })
+ })
+}
diff --git a/internal/logging/sse.html b/internal/logging/sse.html
new file mode 100644
index 0000000..192b464
--- /dev/null
+++ b/internal/logging/sse.html
@@ -0,0 +1,319 @@
+
+
+
+
+
+ Server Sent Event
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/internal/logging/utils.go b/internal/logging/utils.go
new file mode 100644
index 0000000..e622d96
--- /dev/null
+++ b/internal/logging/utils.go
@@ -0,0 +1,32 @@
+package logging
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/rs/zerolog"
+)
+
+var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel)
+
+func GetDefaultLogger() *zerolog.Logger {
+ return &defaultLogger
+}
+
+func ErrorfL(l *zerolog.Logger, format string, err error, args ...interface{}) error {
+ // TODO: move rootLogger to logging package
+ if l == nil {
+ l = &defaultLogger
+ }
+
+ l.Error().Err(err).Msgf(format, args...)
+
+ if err == nil {
+ return fmt.Errorf(format, args...)
+ }
+
+ err_msg := err.Error() + ": %v"
+ err_args := append(args, err)
+
+ return fmt.Errorf(err_msg, err_args...)
+}
diff --git a/internal/mdns/mdns.go b/internal/mdns/mdns.go
new file mode 100644
index 0000000..b882b93
--- /dev/null
+++ b/internal/mdns/mdns.go
@@ -0,0 +1,190 @@
+package mdns
+
+import (
+ "fmt"
+ "net"
+ "reflect"
+ "strings"
+ "sync"
+
+ "github.com/jetkvm/kvm/internal/logging"
+ pion_mdns "github.com/pion/mdns/v2"
+ "github.com/rs/zerolog"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+)
+
+type MDNS struct {
+ conn *pion_mdns.Conn
+ lock sync.Mutex
+ l *zerolog.Logger
+
+ localNames []string
+ listenOptions *MDNSListenOptions
+}
+
+type MDNSListenOptions struct {
+ IPv4 bool
+ IPv6 bool
+}
+
+type MDNSOptions struct {
+ Logger *zerolog.Logger
+ LocalNames []string
+ ListenOptions *MDNSListenOptions
+}
+
+const (
+ DefaultAddressIPv4 = pion_mdns.DefaultAddressIPv4
+ DefaultAddressIPv6 = pion_mdns.DefaultAddressIPv6
+)
+
+func NewMDNS(opts *MDNSOptions) (*MDNS, error) {
+ if opts.Logger == nil {
+ opts.Logger = logging.GetDefaultLogger()
+ }
+
+ if opts.ListenOptions == nil {
+ opts.ListenOptions = &MDNSListenOptions{
+ IPv4: true,
+ IPv6: true,
+ }
+ }
+
+ return &MDNS{
+ l: opts.Logger,
+ lock: sync.Mutex{},
+ localNames: opts.LocalNames,
+ listenOptions: opts.ListenOptions,
+ }, nil
+}
+
+func (m *MDNS) start(allowRestart bool) error {
+ m.lock.Lock()
+ defer m.lock.Unlock()
+
+ if m.conn != nil {
+ if !allowRestart {
+ return fmt.Errorf("mDNS server already running")
+ }
+
+ m.conn.Close()
+ }
+
+ if m.listenOptions == nil {
+ return fmt.Errorf("listen options not set")
+ }
+
+ if !m.listenOptions.IPv4 && !m.listenOptions.IPv6 {
+ m.l.Info().Msg("mDNS server disabled")
+ return nil
+ }
+
+ var (
+ addr4, addr6 *net.UDPAddr
+ l4, l6 *net.UDPConn
+ p4 *ipv4.PacketConn
+ p6 *ipv6.PacketConn
+ err error
+ )
+
+ if m.listenOptions.IPv4 {
+ addr4, err = net.ResolveUDPAddr("udp4", DefaultAddressIPv4)
+ if err != nil {
+ return err
+ }
+
+ l4, err = net.ListenUDP("udp4", addr4)
+ if err != nil {
+ return err
+ }
+
+ p4 = ipv4.NewPacketConn(l4)
+ }
+
+ if m.listenOptions.IPv6 {
+ addr6, err = net.ResolveUDPAddr("udp6", DefaultAddressIPv6)
+ if err != nil {
+ return err
+ }
+
+ l6, err = net.ListenUDP("udp6", addr6)
+ if err != nil {
+ return err
+ }
+
+ p6 = ipv6.NewPacketConn(l6)
+ }
+
+ scopeLogger := m.l.With().
+ Interface("local_names", m.localNames).
+ Bool("ipv4", m.listenOptions.IPv4).
+ Bool("ipv6", m.listenOptions.IPv6).
+ Logger()
+
+ newLocalNames := make([]string, len(m.localNames))
+ for i, name := range m.localNames {
+ newLocalNames[i] = strings.TrimRight(strings.ToLower(name), ".")
+ if !strings.HasSuffix(newLocalNames[i], ".local") {
+ newLocalNames[i] = newLocalNames[i] + ".local"
+ }
+ }
+
+ mDNSConn, err := pion_mdns.Server(p4, p6, &pion_mdns.Config{
+ LocalNames: newLocalNames,
+ LoggerFactory: logging.GetPionDefaultLoggerFactory(),
+ })
+
+ if err != nil {
+ scopeLogger.Warn().Err(err).Msg("failed to start mDNS server")
+ return err
+ }
+
+ m.conn = mDNSConn
+ scopeLogger.Info().Msg("mDNS server started")
+
+ return nil
+}
+
+func (m *MDNS) Start() error {
+ return m.start(false)
+}
+
+func (m *MDNS) Restart() error {
+ return m.start(true)
+}
+
+func (m *MDNS) Stop() error {
+ m.lock.Lock()
+ defer m.lock.Unlock()
+
+ if m.conn == nil {
+ return nil
+ }
+
+ return m.conn.Close()
+}
+
+func (m *MDNS) SetLocalNames(localNames []string, always bool) error {
+ if reflect.DeepEqual(m.localNames, localNames) && !always {
+ return nil
+ }
+
+ m.localNames = localNames
+ _ = m.Restart()
+
+ return nil
+}
+
+func (m *MDNS) SetListenOptions(listenOptions *MDNSListenOptions) error {
+ if m.listenOptions != nil &&
+ m.listenOptions.IPv4 == listenOptions.IPv4 &&
+ m.listenOptions.IPv6 == listenOptions.IPv6 {
+ return nil
+ }
+
+ m.listenOptions = listenOptions
+ _ = m.Restart()
+
+ return nil
+}
diff --git a/internal/mdns/utils.go b/internal/mdns/utils.go
new file mode 100644
index 0000000..7565eee
--- /dev/null
+++ b/internal/mdns/utils.go
@@ -0,0 +1 @@
+package mdns
diff --git a/internal/network/config.go b/internal/network/config.go
new file mode 100644
index 0000000..74ddf19
--- /dev/null
+++ b/internal/network/config.go
@@ -0,0 +1,110 @@
+package network
+
+import (
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/guregu/null/v6"
+ "github.com/jetkvm/kvm/internal/mdns"
+ "golang.org/x/net/idna"
+)
+
+type IPv6Address struct {
+ Address net.IP `json:"address"`
+ Prefix net.IPNet `json:"prefix"`
+ ValidLifetime *time.Time `json:"valid_lifetime"`
+ PreferredLifetime *time.Time `json:"preferred_lifetime"`
+ Scope int `json:"scope"`
+}
+
+type IPv4StaticConfig struct {
+ Address null.String `json:"address,omitempty" validate_type:"ipv4" required:"true"`
+ Netmask null.String `json:"netmask,omitempty" validate_type:"ipv4" required:"true"`
+ Gateway null.String `json:"gateway,omitempty" validate_type:"ipv4" required:"true"`
+ DNS []string `json:"dns,omitempty" validate_type:"ipv4" required:"true"`
+}
+
+type IPv6StaticConfig struct {
+ Address null.String `json:"address,omitempty" validate_type:"ipv6" required:"true"`
+ Prefix null.String `json:"prefix,omitempty" validate_type:"ipv6" required:"true"`
+ Gateway null.String `json:"gateway,omitempty" validate_type:"ipv6" required:"true"`
+ DNS []string `json:"dns,omitempty" validate_type:"ipv6" required:"true"`
+}
+type NetworkConfig struct {
+ Hostname null.String `json:"hostname,omitempty" validate_type:"hostname"`
+ Domain null.String `json:"domain,omitempty" validate_type:"hostname"`
+
+ IPv4Mode null.String `json:"ipv4_mode,omitempty" one_of:"dhcp,static,disabled" default:"dhcp"`
+ IPv4Static *IPv4StaticConfig `json:"ipv4_static,omitempty" required_if:"IPv4Mode=static"`
+
+ IPv6Mode null.String `json:"ipv6_mode,omitempty" one_of:"slaac,dhcpv6,slaac_and_dhcpv6,static,link_local,disabled" default:"slaac"`
+ IPv6Static *IPv6StaticConfig `json:"ipv6_static,omitempty" required_if:"IPv6Mode=static"`
+
+ LLDPMode null.String `json:"lldp_mode,omitempty" one_of:"disabled,basic,all" default:"basic"`
+ LLDPTxTLVs []string `json:"lldp_tx_tlvs,omitempty" one_of:"chassis,port,system,vlan" default:"chassis,port,system,vlan"`
+ MDNSMode null.String `json:"mdns_mode,omitempty" one_of:"disabled,auto,ipv4_only,ipv6_only" default:"auto"`
+ TimeSyncMode null.String `json:"time_sync_mode,omitempty" one_of:"ntp_only,ntp_and_http,http_only,custom" default:"ntp_and_http"`
+ TimeSyncOrdering []string `json:"time_sync_ordering,omitempty" one_of:"http,ntp,ntp_dhcp,ntp_user_provided,ntp_fallback" default:"ntp,http"`
+ TimeSyncDisableFallback null.Bool `json:"time_sync_disable_fallback,omitempty" default:"false"`
+ TimeSyncParallel null.Int `json:"time_sync_parallel,omitempty" default:"4"`
+}
+
+func (c *NetworkConfig) GetMDNSMode() *mdns.MDNSListenOptions {
+ mode := c.MDNSMode.String
+ listenOptions := &mdns.MDNSListenOptions{
+ IPv4: true,
+ IPv6: true,
+ }
+
+ switch mode {
+ case "ipv4_only":
+ listenOptions.IPv6 = false
+ case "ipv6_only":
+ listenOptions.IPv4 = false
+ case "disabled":
+ listenOptions.IPv4 = false
+ listenOptions.IPv6 = false
+ }
+
+ return listenOptions
+}
+func (s *NetworkInterfaceState) GetHostname() string {
+ hostname := ToValidHostname(s.config.Hostname.String)
+
+ if hostname == "" {
+ return s.defaultHostname
+ }
+
+ return hostname
+}
+
+func ToValidDomain(domain string) string {
+ ascii, err := idna.Lookup.ToASCII(domain)
+ if err != nil {
+ return ""
+ }
+
+ return ascii
+}
+
+func (s *NetworkInterfaceState) GetDomain() string {
+ domain := ToValidDomain(s.config.Domain.String)
+
+ if domain == "" {
+ lease := s.dhcpClient.GetLease()
+ if lease != nil && lease.Domain != "" {
+ domain = ToValidDomain(lease.Domain)
+ }
+ }
+
+ if domain == "" {
+ return "local"
+ }
+
+ return domain
+}
+
+func (s *NetworkInterfaceState) GetFQDN() string {
+ return fmt.Sprintf("%s.%s", s.GetHostname(), s.GetDomain())
+}
diff --git a/internal/network/dhcp.go b/internal/network/dhcp.go
new file mode 100644
index 0000000..9e173cc
--- /dev/null
+++ b/internal/network/dhcp.go
@@ -0,0 +1,11 @@
+package network
+
+type DhcpTargetState int
+
+const (
+ DhcpTargetStateDoNothing DhcpTargetState = iota
+ DhcpTargetStateStart
+ DhcpTargetStateStop
+ DhcpTargetStateRenew
+ DhcpTargetStateRelease
+)
diff --git a/internal/network/hostname.go b/internal/network/hostname.go
new file mode 100644
index 0000000..d75255c
--- /dev/null
+++ b/internal/network/hostname.go
@@ -0,0 +1,137 @@
+package network
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "strings"
+ "sync"
+
+ "golang.org/x/net/idna"
+)
+
+const (
+ hostnamePath = "/etc/hostname"
+ hostsPath = "/etc/hosts"
+)
+
+var (
+ hostnameLock sync.Mutex = sync.Mutex{}
+)
+
+func updateEtcHosts(hostname string, fqdn string) error {
+ // update /etc/hosts
+ hostsFile, err := os.OpenFile(hostsPath, os.O_RDWR|os.O_SYNC, os.ModeExclusive)
+ if err != nil {
+ return fmt.Errorf("failed to open %s: %w", hostsPath, err)
+ }
+ defer hostsFile.Close()
+
+ // read all lines
+ if _, err := hostsFile.Seek(0, io.SeekStart); err != nil {
+ return fmt.Errorf("failed to seek %s: %w", hostsPath, err)
+ }
+
+ lines, err := io.ReadAll(hostsFile)
+ if err != nil {
+ return fmt.Errorf("failed to read %s: %w", hostsPath, err)
+ }
+
+ newLines := []string{}
+ hostLine := fmt.Sprintf("127.0.1.1\t%s %s", hostname, fqdn)
+ hostLineExists := false
+
+ for _, line := range strings.Split(string(lines), "\n") {
+ if strings.HasPrefix(line, "127.0.1.1") {
+ hostLineExists = true
+ line = hostLine
+ }
+ newLines = append(newLines, line)
+ }
+
+ if !hostLineExists {
+ newLines = append(newLines, hostLine)
+ }
+
+ if err := hostsFile.Truncate(0); err != nil {
+ return fmt.Errorf("failed to truncate %s: %w", hostsPath, err)
+ }
+
+ if _, err := hostsFile.Seek(0, io.SeekStart); err != nil {
+ return fmt.Errorf("failed to seek %s: %w", hostsPath, err)
+ }
+
+ if _, err := hostsFile.Write([]byte(strings.Join(newLines, "\n"))); err != nil {
+ return fmt.Errorf("failed to write %s: %w", hostsPath, err)
+ }
+
+ return nil
+}
+
+func ToValidHostname(hostname string) string {
+ ascii, err := idna.Lookup.ToASCII(hostname)
+ if err != nil {
+ return ""
+ }
+ return ascii
+}
+
+func SetHostname(hostname string, fqdn string) error {
+ hostnameLock.Lock()
+ defer hostnameLock.Unlock()
+
+ hostname = ToValidHostname(strings.TrimSpace(hostname))
+ fqdn = ToValidHostname(strings.TrimSpace(fqdn))
+
+ if hostname == "" {
+ return fmt.Errorf("invalid hostname: %s", hostname)
+ }
+
+ if fqdn == "" {
+ fqdn = hostname
+ }
+
+ // update /etc/hostname
+ if err := os.WriteFile(hostnamePath, []byte(hostname), 0644); err != nil {
+ return fmt.Errorf("failed to write %s: %w", hostnamePath, err)
+ }
+
+ // update /etc/hosts
+ if err := updateEtcHosts(hostname, fqdn); err != nil {
+ return fmt.Errorf("failed to update /etc/hosts: %w", err)
+ }
+
+ // run hostname
+ if err := exec.Command("hostname", "-F", hostnamePath).Run(); err != nil {
+ return fmt.Errorf("failed to run hostname: %w", err)
+ }
+
+ return nil
+}
+
+func (s *NetworkInterfaceState) setHostnameIfNotSame() error {
+ hostname := s.GetHostname()
+ currentHostname, _ := os.Hostname()
+
+ fqdn := fmt.Sprintf("%s.%s", hostname, s.GetDomain())
+
+ if currentHostname == hostname && s.currentFqdn == fqdn && s.currentHostname == hostname {
+ return nil
+ }
+
+ scopedLogger := s.l.With().Str("hostname", hostname).Str("fqdn", fqdn).Logger()
+
+ err := SetHostname(hostname, fqdn)
+ if err != nil {
+ scopedLogger.Error().Err(err).Msg("failed to set hostname")
+ return err
+ }
+
+ s.currentHostname = hostname
+ s.currentFqdn = fqdn
+
+ scopedLogger.Info().Msg("hostname set")
+
+ return nil
+}
diff --git a/internal/network/netif.go b/internal/network/netif.go
new file mode 100644
index 0000000..c5db806
--- /dev/null
+++ b/internal/network/netif.go
@@ -0,0 +1,346 @@
+package network
+
+import (
+ "fmt"
+ "net"
+ "sync"
+
+ "github.com/jetkvm/kvm/internal/confparser"
+ "github.com/jetkvm/kvm/internal/logging"
+ "github.com/jetkvm/kvm/internal/udhcpc"
+ "github.com/rs/zerolog"
+
+ "github.com/vishvananda/netlink"
+)
+
+type NetworkInterfaceState struct {
+ interfaceName string
+ interfaceUp bool
+ ipv4Addr *net.IP
+ ipv4Addresses []string
+ ipv6Addr *net.IP
+ ipv6Addresses []IPv6Address
+ ipv6LinkLocal *net.IP
+ macAddr *net.HardwareAddr
+
+ l *zerolog.Logger
+ stateLock sync.Mutex
+
+ config *NetworkConfig
+ dhcpClient *udhcpc.DHCPClient
+
+ defaultHostname string
+ currentHostname string
+ currentFqdn string
+
+ onStateChange func(state *NetworkInterfaceState)
+ onInitialCheck func(state *NetworkInterfaceState)
+ cbConfigChange func(config *NetworkConfig)
+
+ checked bool
+}
+
+type NetworkInterfaceOptions struct {
+ InterfaceName string
+ DhcpPidFile string
+ Logger *zerolog.Logger
+ DefaultHostname string
+ OnStateChange func(state *NetworkInterfaceState)
+ OnInitialCheck func(state *NetworkInterfaceState)
+ OnDhcpLeaseChange func(lease *udhcpc.Lease)
+ OnConfigChange func(config *NetworkConfig)
+ NetworkConfig *NetworkConfig
+}
+
+func NewNetworkInterfaceState(opts *NetworkInterfaceOptions) (*NetworkInterfaceState, error) {
+ if opts.NetworkConfig == nil {
+ return nil, fmt.Errorf("NetworkConfig can not be nil")
+ }
+
+ if opts.DefaultHostname == "" {
+ opts.DefaultHostname = "jetkvm"
+ }
+
+ err := confparser.SetDefaultsAndValidate(opts.NetworkConfig)
+ if err != nil {
+ return nil, err
+ }
+
+ l := opts.Logger
+ s := &NetworkInterfaceState{
+ interfaceName: opts.InterfaceName,
+ defaultHostname: opts.DefaultHostname,
+ stateLock: sync.Mutex{},
+ l: l,
+ onStateChange: opts.OnStateChange,
+ onInitialCheck: opts.OnInitialCheck,
+ cbConfigChange: opts.OnConfigChange,
+ config: opts.NetworkConfig,
+ }
+
+ // create the dhcp client
+ dhcpClient := udhcpc.NewDHCPClient(&udhcpc.DHCPClientOptions{
+ InterfaceName: opts.InterfaceName,
+ PidFile: opts.DhcpPidFile,
+ Logger: l,
+ OnLeaseChange: func(lease *udhcpc.Lease) {
+ _, err := s.update()
+ if err != nil {
+ opts.Logger.Error().Err(err).Msg("failed to update network state")
+ return
+ }
+
+ _ = s.setHostnameIfNotSame()
+
+ opts.OnDhcpLeaseChange(lease)
+ },
+ })
+
+ s.dhcpClient = dhcpClient
+
+ return s, nil
+}
+
+func (s *NetworkInterfaceState) IsUp() bool {
+ return s.interfaceUp
+}
+
+func (s *NetworkInterfaceState) HasIPAssigned() bool {
+ return s.ipv4Addr != nil || s.ipv6Addr != nil
+}
+
+func (s *NetworkInterfaceState) IsOnline() bool {
+ return s.IsUp() && s.HasIPAssigned()
+}
+
+func (s *NetworkInterfaceState) IPv4() *net.IP {
+ return s.ipv4Addr
+}
+
+func (s *NetworkInterfaceState) IPv4String() string {
+ if s.ipv4Addr == nil {
+ return "..."
+ }
+ return s.ipv4Addr.String()
+}
+
+func (s *NetworkInterfaceState) IPv6() *net.IP {
+ return s.ipv6Addr
+}
+
+func (s *NetworkInterfaceState) IPv6String() string {
+ if s.ipv6Addr == nil {
+ return "..."
+ }
+ return s.ipv6Addr.String()
+}
+
+func (s *NetworkInterfaceState) MAC() *net.HardwareAddr {
+ return s.macAddr
+}
+
+func (s *NetworkInterfaceState) MACString() string {
+ if s.macAddr == nil {
+ return ""
+ }
+ return s.macAddr.String()
+}
+
+func (s *NetworkInterfaceState) update() (DhcpTargetState, error) {
+ s.stateLock.Lock()
+ defer s.stateLock.Unlock()
+
+ dhcpTargetState := DhcpTargetStateDoNothing
+
+ iface, err := netlink.LinkByName(s.interfaceName)
+ if err != nil {
+ s.l.Error().Err(err).Msg("failed to get interface")
+ return dhcpTargetState, err
+ }
+
+ // detect if the interface status changed
+ var changed bool
+ attrs := iface.Attrs()
+ state := attrs.OperState
+ newInterfaceUp := state == netlink.OperUp
+
+ // check if the interface is coming up
+ interfaceGoingUp := !s.interfaceUp && newInterfaceUp
+ interfaceGoingDown := s.interfaceUp && !newInterfaceUp
+
+ if s.interfaceUp != newInterfaceUp {
+ s.interfaceUp = newInterfaceUp
+ changed = true
+ }
+
+ if changed {
+ if interfaceGoingUp {
+ s.l.Info().Msg("interface state transitioned to up")
+ dhcpTargetState = DhcpTargetStateRenew
+ } else if interfaceGoingDown {
+ s.l.Info().Msg("interface state transitioned to down")
+ }
+ }
+
+ // set the mac address
+ s.macAddr = &attrs.HardwareAddr
+
+ // get the ip addresses
+ addrs, err := netlinkAddrs(iface)
+ if err != nil {
+ return dhcpTargetState, logging.ErrorfL(s.l, "failed to get ip addresses", err)
+ }
+
+ var (
+ ipv4Addresses = make([]net.IP, 0)
+ ipv4AddressesString = make([]string, 0)
+ ipv6Addresses = make([]IPv6Address, 0)
+ // ipv6AddressesString = make([]string, 0)
+ ipv6LinkLocal *net.IP
+ )
+
+ for _, addr := range addrs {
+ if addr.IP.To4() != nil {
+ scopedLogger := s.l.With().Str("ipv4", addr.IP.String()).Logger()
+ if interfaceGoingDown {
+ // remove all IPv4 addresses from the interface.
+ scopedLogger.Info().Msg("state transitioned to down, removing IPv4 address")
+ err := netlink.AddrDel(iface, &addr)
+ if err != nil {
+ scopedLogger.Warn().Err(err).Msg("failed to delete address")
+ }
+ // notify the DHCP client to release the lease
+ dhcpTargetState = DhcpTargetStateRelease
+ continue
+ }
+ ipv4Addresses = append(ipv4Addresses, addr.IP)
+ ipv4AddressesString = append(ipv4AddressesString, addr.IPNet.String())
+ } else if addr.IP.To16() != nil {
+ scopedLogger := s.l.With().Str("ipv6", addr.IP.String()).Logger()
+ // check if it's a link local address
+ if addr.IP.IsLinkLocalUnicast() {
+ ipv6LinkLocal = &addr.IP
+ continue
+ }
+
+ if !addr.IP.IsGlobalUnicast() {
+ scopedLogger.Trace().Msg("not a global unicast address, skipping")
+ continue
+ }
+
+ if interfaceGoingDown {
+ scopedLogger.Info().Msg("state transitioned to down, removing IPv6 address")
+ err := netlink.AddrDel(iface, &addr)
+ if err != nil {
+ scopedLogger.Warn().Err(err).Msg("failed to delete address")
+ }
+ continue
+ }
+ ipv6Addresses = append(ipv6Addresses, IPv6Address{
+ Address: addr.IP,
+ Prefix: *addr.IPNet,
+ ValidLifetime: lifetimeToTime(addr.ValidLft),
+ PreferredLifetime: lifetimeToTime(addr.PreferedLft),
+ Scope: addr.Scope,
+ })
+ // ipv6AddressesString = append(ipv6AddressesString, addr.IPNet.String())
+ }
+ }
+
+ if len(ipv4Addresses) > 0 {
+ // compare the addresses to see if there's a change
+ if s.ipv4Addr == nil || s.ipv4Addr.String() != ipv4Addresses[0].String() {
+ scopedLogger := s.l.With().Str("ipv4", ipv4Addresses[0].String()).Logger()
+ if s.ipv4Addr != nil {
+ scopedLogger.Info().
+ Str("old_ipv4", s.ipv4Addr.String()).
+ Msg("IPv4 address changed")
+ } else {
+ scopedLogger.Info().Msg("IPv4 address found")
+ }
+ s.ipv4Addr = &ipv4Addresses[0]
+ changed = true
+ }
+ }
+ s.ipv4Addresses = ipv4AddressesString
+
+ if ipv6LinkLocal != nil {
+ if s.ipv6LinkLocal == nil || s.ipv6LinkLocal.String() != ipv6LinkLocal.String() {
+ scopedLogger := s.l.With().Str("ipv6", ipv6LinkLocal.String()).Logger()
+ if s.ipv6LinkLocal != nil {
+ scopedLogger.Info().
+ Str("old_ipv6", s.ipv6LinkLocal.String()).
+ Msg("IPv6 link local address changed")
+ } else {
+ scopedLogger.Info().Msg("IPv6 link local address found")
+ }
+ s.ipv6LinkLocal = ipv6LinkLocal
+ changed = true
+ }
+ }
+ s.ipv6Addresses = ipv6Addresses
+
+ if len(ipv6Addresses) > 0 {
+ // compare the addresses to see if there's a change
+ if s.ipv6Addr == nil || s.ipv6Addr.String() != ipv6Addresses[0].Address.String() {
+ scopedLogger := s.l.With().Str("ipv6", ipv6Addresses[0].Address.String()).Logger()
+ if s.ipv6Addr != nil {
+ scopedLogger.Info().
+ Str("old_ipv6", s.ipv6Addr.String()).
+ Msg("IPv6 address changed")
+ } else {
+ scopedLogger.Info().Msg("IPv6 address found")
+ }
+ s.ipv6Addr = &ipv6Addresses[0].Address
+ changed = true
+ }
+ }
+
+ // if it's the initial check, we'll set changed to false
+ initialCheck := !s.checked
+ if initialCheck {
+ s.checked = true
+ changed = false
+ if dhcpTargetState == DhcpTargetStateRenew {
+ // it's the initial check, we'll start the DHCP client
+ // dhcpTargetState = DhcpTargetStateStart
+ // TODO: manage DHCP client start/stop
+ dhcpTargetState = DhcpTargetStateDoNothing
+ }
+ }
+
+ if initialCheck {
+ s.onInitialCheck(s)
+ } else if changed {
+ s.onStateChange(s)
+ }
+
+ return dhcpTargetState, nil
+}
+
+func (s *NetworkInterfaceState) CheckAndUpdateDhcp() error {
+ dhcpTargetState, err := s.update()
+ if err != nil {
+ return logging.ErrorfL(s.l, "failed to update network state", err)
+ }
+
+ switch dhcpTargetState {
+ case DhcpTargetStateRenew:
+ s.l.Info().Msg("renewing DHCP lease")
+ _ = s.dhcpClient.Renew()
+ case DhcpTargetStateRelease:
+ s.l.Info().Msg("releasing DHCP lease")
+ _ = s.dhcpClient.Release()
+ case DhcpTargetStateStart:
+ s.l.Warn().Msg("dhcpTargetStateStart not implemented")
+ case DhcpTargetStateStop:
+ s.l.Warn().Msg("dhcpTargetStateStop not implemented")
+ }
+
+ return nil
+}
+
+func (s *NetworkInterfaceState) onConfigChange(config *NetworkConfig) {
+ _ = s.setHostnameIfNotSame()
+ s.cbConfigChange(config)
+}
diff --git a/internal/network/netif_linux.go b/internal/network/netif_linux.go
new file mode 100644
index 0000000..ec057f1
--- /dev/null
+++ b/internal/network/netif_linux.go
@@ -0,0 +1,58 @@
+//go:build linux
+
+package network
+
+import (
+ "time"
+
+ "github.com/vishvananda/netlink"
+ "github.com/vishvananda/netlink/nl"
+)
+
+func (s *NetworkInterfaceState) HandleLinkUpdate(update netlink.LinkUpdate) {
+ if update.Link.Attrs().Name == s.interfaceName {
+ s.l.Info().Interface("update", update).Msg("interface link update received")
+ _ = s.CheckAndUpdateDhcp()
+ }
+}
+
+func (s *NetworkInterfaceState) Run() error {
+ updates := make(chan netlink.LinkUpdate)
+ done := make(chan struct{})
+
+ if err := netlink.LinkSubscribe(updates, done); err != nil {
+ s.l.Warn().Err(err).Msg("failed to subscribe to link updates")
+ return err
+ }
+
+ _ = s.setHostnameIfNotSame()
+
+ // run the dhcp client
+ go s.dhcpClient.Run() // nolint:errcheck
+
+ if err := s.CheckAndUpdateDhcp(); err != nil {
+ return err
+ }
+
+ go func() {
+ ticker := time.NewTicker(1 * time.Second)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case update := <-updates:
+ s.HandleLinkUpdate(update)
+ case <-ticker.C:
+ _ = s.CheckAndUpdateDhcp()
+ case <-done:
+ return
+ }
+ }
+ }()
+
+ return nil
+}
+
+func netlinkAddrs(iface netlink.Link) ([]netlink.Addr, error) {
+ return netlink.AddrList(iface, nl.FAMILY_ALL)
+}
diff --git a/internal/network/netif_notlinux.go b/internal/network/netif_notlinux.go
new file mode 100644
index 0000000..d101630
--- /dev/null
+++ b/internal/network/netif_notlinux.go
@@ -0,0 +1,21 @@
+//go:build !linux
+
+package network
+
+import (
+ "fmt"
+
+ "github.com/vishvananda/netlink"
+)
+
+func (s *NetworkInterfaceState) HandleLinkUpdate() error {
+ return fmt.Errorf("not implemented")
+}
+
+func (s *NetworkInterfaceState) Run() error {
+ return fmt.Errorf("not implemented")
+}
+
+func netlinkAddrs(iface netlink.Link) ([]netlink.Addr, error) {
+ return nil, fmt.Errorf("not implemented")
+}
diff --git a/internal/network/rpc.go b/internal/network/rpc.go
new file mode 100644
index 0000000..32f34f5
--- /dev/null
+++ b/internal/network/rpc.go
@@ -0,0 +1,126 @@
+package network
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/jetkvm/kvm/internal/confparser"
+ "github.com/jetkvm/kvm/internal/udhcpc"
+)
+
+type RpcIPv6Address struct {
+ Address string `json:"address"`
+ ValidLifetime *time.Time `json:"valid_lifetime,omitempty"`
+ PreferredLifetime *time.Time `json:"preferred_lifetime,omitempty"`
+ Scope int `json:"scope"`
+}
+
+type RpcNetworkState struct {
+ InterfaceName string `json:"interface_name"`
+ MacAddress string `json:"mac_address"`
+ IPv4 string `json:"ipv4,omitempty"`
+ IPv6 string `json:"ipv6,omitempty"`
+ IPv6LinkLocal string `json:"ipv6_link_local,omitempty"`
+ IPv4Addresses []string `json:"ipv4_addresses,omitempty"`
+ IPv6Addresses []RpcIPv6Address `json:"ipv6_addresses,omitempty"`
+ DHCPLease *udhcpc.Lease `json:"dhcp_lease,omitempty"`
+}
+
+type RpcNetworkSettings struct {
+ NetworkConfig
+}
+
+func (s *NetworkInterfaceState) MacAddress() string {
+ if s.macAddr == nil {
+ return ""
+ }
+
+ return s.macAddr.String()
+}
+
+func (s *NetworkInterfaceState) IPv4Address() string {
+ if s.ipv4Addr == nil {
+ return ""
+ }
+
+ return s.ipv4Addr.String()
+}
+
+func (s *NetworkInterfaceState) IPv6Address() string {
+ if s.ipv6Addr == nil {
+ return ""
+ }
+
+ return s.ipv6Addr.String()
+}
+
+func (s *NetworkInterfaceState) IPv6LinkLocalAddress() string {
+ if s.ipv6LinkLocal == nil {
+ return ""
+ }
+
+ return s.ipv6LinkLocal.String()
+}
+
+func (s *NetworkInterfaceState) RpcGetNetworkState() RpcNetworkState {
+ ipv6Addresses := make([]RpcIPv6Address, 0)
+
+ if s.ipv6Addresses != nil {
+ for _, addr := range s.ipv6Addresses {
+ ipv6Addresses = append(ipv6Addresses, RpcIPv6Address{
+ Address: addr.Prefix.String(),
+ ValidLifetime: addr.ValidLifetime,
+ PreferredLifetime: addr.PreferredLifetime,
+ Scope: addr.Scope,
+ })
+ }
+ }
+
+ return RpcNetworkState{
+ InterfaceName: s.interfaceName,
+ MacAddress: s.MacAddress(),
+ IPv4: s.IPv4Address(),
+ IPv6: s.IPv6Address(),
+ IPv6LinkLocal: s.IPv6LinkLocalAddress(),
+ IPv4Addresses: s.ipv4Addresses,
+ IPv6Addresses: ipv6Addresses,
+ DHCPLease: s.dhcpClient.GetLease(),
+ }
+}
+
+func (s *NetworkInterfaceState) RpcGetNetworkSettings() RpcNetworkSettings {
+ if s.config == nil {
+ return RpcNetworkSettings{}
+ }
+
+ return RpcNetworkSettings{
+ NetworkConfig: *s.config,
+ }
+}
+
+func (s *NetworkInterfaceState) RpcSetNetworkSettings(settings RpcNetworkSettings) error {
+ currentSettings := s.config
+
+ err := confparser.SetDefaultsAndValidate(&settings.NetworkConfig)
+ if err != nil {
+ return err
+ }
+
+ if IsSame(currentSettings, settings.NetworkConfig) {
+ // no changes, do nothing
+ return nil
+ }
+
+ s.config = &settings.NetworkConfig
+ s.onConfigChange(s.config)
+
+ return nil
+}
+
+func (s *NetworkInterfaceState) RpcRenewDHCPLease() error {
+ if s.dhcpClient == nil {
+ return fmt.Errorf("dhcp client not initialized")
+ }
+
+ return s.dhcpClient.Renew()
+}
diff --git a/internal/network/utils.go b/internal/network/utils.go
new file mode 100644
index 0000000..6d64332
--- /dev/null
+++ b/internal/network/utils.go
@@ -0,0 +1,26 @@
+package network
+
+import (
+ "encoding/json"
+ "time"
+)
+
+func lifetimeToTime(lifetime int) *time.Time {
+ if lifetime == 0 {
+ return nil
+ }
+ t := time.Now().Add(time.Duration(lifetime) * time.Second)
+ return &t
+}
+
+func IsSame(a, b interface{}) bool {
+ aJSON, err := json.Marshal(a)
+ if err != nil {
+ return false
+ }
+ bJSON, err := json.Marshal(b)
+ if err != nil {
+ return false
+ }
+ return string(aJSON) == string(bJSON)
+}
diff --git a/internal/timesync/http.go b/internal/timesync/http.go
new file mode 100644
index 0000000..3a51463
--- /dev/null
+++ b/internal/timesync/http.go
@@ -0,0 +1,132 @@
+package timesync
+
+import (
+ "context"
+ "errors"
+ "math/rand"
+ "net/http"
+ "strconv"
+ "time"
+)
+
+var defaultHTTPUrls = []string{
+ "http://www.gstatic.com/generate_204",
+ "http://cp.cloudflare.com/",
+ "http://edge-http.microsoft.com/captiveportal/generate_204",
+ // Firefox, Apple, and Microsoft have inconsistent results, so we don't use it
+ // "http://detectportal.firefox.com/",
+ // "http://www.apple.com/library/test/success.html",
+ // "http://www.msftconnecttest.com/connecttest.txt",
+}
+
+func (t *TimeSync) queryAllHttpTime() (now *time.Time) {
+ chunkSize := 4
+ httpUrls := t.httpUrls
+
+ // shuffle the http urls to avoid always querying the same servers
+ rand.Shuffle(len(httpUrls), func(i, j int) { httpUrls[i], httpUrls[j] = httpUrls[j], httpUrls[i] })
+
+ for i := 0; i < len(httpUrls); i += chunkSize {
+ chunk := httpUrls[i:min(i+chunkSize, len(httpUrls))]
+ results := t.queryMultipleHttp(chunk, timeSyncTimeout)
+ if results != nil {
+ return results
+ }
+ }
+
+ return nil
+}
+
+func (t *TimeSync) queryMultipleHttp(urls []string, timeout time.Duration) (now *time.Time) {
+ results := make(chan *time.Time, len(urls))
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+
+ for _, url := range urls {
+ go func(url string) {
+ scopedLogger := t.l.With().
+ Str("http_url", url).
+ Logger()
+
+ metricHttpRequestCount.WithLabelValues(url).Inc()
+ metricHttpTotalRequestCount.Inc()
+
+ startTime := time.Now()
+ now, response, err := queryHttpTime(
+ ctx,
+ url,
+ timeout,
+ )
+ duration := time.Since(startTime)
+
+ metricHttpServerLastRTT.WithLabelValues(url).Set(float64(duration.Milliseconds()))
+ metricHttpServerRttHistogram.WithLabelValues(url).Observe(float64(duration.Milliseconds()))
+
+ status := 0
+ if response != nil {
+ status = response.StatusCode
+ }
+ metricHttpServerInfo.WithLabelValues(
+ url,
+ strconv.Itoa(status),
+ ).Set(1)
+
+ if err == nil {
+ metricHttpTotalSuccessCount.Inc()
+ metricHttpSuccessCount.WithLabelValues(url).Inc()
+
+ requestId := response.Header.Get("X-Request-Id")
+ if requestId != "" {
+ requestId = response.Header.Get("X-Msedge-Ref")
+ }
+ if requestId == "" {
+ requestId = response.Header.Get("Cf-Ray")
+ }
+ scopedLogger.Info().
+ Str("time", now.Format(time.RFC3339)).
+ Int("status", status).
+ Str("request_id", requestId).
+ Str("time_taken", duration.String()).
+ Msg("HTTP server returned time")
+
+ cancel()
+ results <- now
+ } else if errors.Is(err, context.Canceled) {
+ metricHttpCancelCount.WithLabelValues(url).Inc()
+ metricHttpTotalCancelCount.Inc()
+ } else {
+ scopedLogger.Warn().
+ Str("error", err.Error()).
+ Int("status", status).
+ Msg("failed to query HTTP server")
+ }
+ }(url)
+ }
+
+ return <-results
+}
+
+func queryHttpTime(
+ ctx context.Context,
+ url string,
+ timeout time.Duration,
+) (now *time.Time, response *http.Response, err error) {
+ client := http.Client{
+ Timeout: timeout,
+ }
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, nil, err
+ }
+ dateStr := resp.Header.Get("Date")
+ parsedTime, err := time.Parse(time.RFC1123, dateStr)
+ if err != nil {
+ return nil, nil, err
+ }
+ return &parsedTime, resp, nil
+}
diff --git a/internal/timesync/metrics.go b/internal/timesync/metrics.go
new file mode 100644
index 0000000..0e28acb
--- /dev/null
+++ b/internal/timesync/metrics.go
@@ -0,0 +1,147 @@
+package timesync
+
+import (
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+var (
+ metricTimeSyncStatus = promauto.NewGauge(
+ prometheus.GaugeOpts{
+ Name: "jetkvm_timesync_status",
+ Help: "The status of the timesync, 1 if successful, 0 if not",
+ },
+ )
+ metricTimeSyncCount = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_count",
+ Help: "The number of times the timesync has been run",
+ },
+ )
+ metricTimeSyncSuccessCount = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_success_count",
+ Help: "The number of times the timesync has been successful",
+ },
+ )
+ metricRTCUpdateCount = promauto.NewCounter( //nolint:unused
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_rtc_update_count",
+ Help: "The number of times the RTC has been updated",
+ },
+ )
+ metricNtpTotalSuccessCount = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_ntp_total_success_count",
+ Help: "The total number of successful NTP requests",
+ },
+ )
+ metricNtpTotalRequestCount = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_ntp_total_request_count",
+ Help: "The total number of NTP requests sent",
+ },
+ )
+ metricNtpSuccessCount = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_ntp_success_count",
+ Help: "The number of successful NTP requests",
+ },
+ []string{"url"},
+ )
+ metricNtpRequestCount = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_ntp_request_count",
+ Help: "The number of NTP requests sent to the server",
+ },
+ []string{"url"},
+ )
+ metricNtpServerLastRTT = promauto.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "jetkvm_timesync_ntp_server_last_rtt",
+ Help: "The last RTT of the NTP server in milliseconds",
+ },
+ []string{"url"},
+ )
+ metricNtpServerRttHistogram = promauto.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Name: "jetkvm_timesync_ntp_server_rtt",
+ Help: "The histogram of the RTT of the NTP server in milliseconds",
+ Buckets: []float64{
+ 10, 25, 50, 100, 200, 300, 500, 1000,
+ },
+ },
+ []string{"url"},
+ )
+ metricNtpServerInfo = promauto.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "jetkvm_timesync_ntp_server_info",
+ Help: "The info of the NTP server",
+ },
+ []string{"url", "reference", "stratum", "precision"},
+ )
+
+ metricHttpTotalSuccessCount = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_http_total_success_count",
+ Help: "The total number of successful HTTP requests",
+ },
+ )
+ metricHttpTotalRequestCount = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_http_total_request_count",
+ Help: "The total number of HTTP requests sent",
+ },
+ )
+ metricHttpTotalCancelCount = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_http_total_cancel_count",
+ Help: "The total number of HTTP requests cancelled",
+ },
+ )
+ metricHttpSuccessCount = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_http_success_count",
+ Help: "The number of successful HTTP requests",
+ },
+ []string{"url"},
+ )
+ metricHttpRequestCount = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_http_request_count",
+ Help: "The number of HTTP requests sent to the server",
+ },
+ []string{"url"},
+ )
+ metricHttpCancelCount = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "jetkvm_timesync_http_cancel_count",
+ Help: "The number of HTTP requests cancelled",
+ },
+ []string{"url"},
+ )
+ metricHttpServerLastRTT = promauto.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "jetkvm_timesync_http_server_last_rtt",
+ Help: "The last RTT of the HTTP server in milliseconds",
+ },
+ []string{"url"},
+ )
+ metricHttpServerRttHistogram = promauto.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Name: "jetkvm_timesync_http_server_rtt",
+ Help: "The histogram of the RTT of the HTTP server in milliseconds",
+ Buckets: []float64{
+ 10, 25, 50, 100, 200, 300, 500, 1000,
+ },
+ },
+ []string{"url"},
+ )
+ metricHttpServerInfo = promauto.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "jetkvm_timesync_http_server_info",
+ Help: "The info of the HTTP server",
+ },
+ []string{"url", "http_code"},
+ )
+)
diff --git a/internal/timesync/ntp.go b/internal/timesync/ntp.go
new file mode 100644
index 0000000..41656b7
--- /dev/null
+++ b/internal/timesync/ntp.go
@@ -0,0 +1,113 @@
+package timesync
+
+import (
+ "math/rand/v2"
+ "strconv"
+ "time"
+
+ "github.com/beevik/ntp"
+)
+
+var defaultNTPServers = []string{
+ "time.apple.com",
+ "time.aws.com",
+ "time.windows.com",
+ "time.google.com",
+ "162.159.200.123", // time.cloudflare.com
+ "0.pool.ntp.org",
+ "1.pool.ntp.org",
+ "2.pool.ntp.org",
+ "3.pool.ntp.org",
+}
+
+func (t *TimeSync) queryNetworkTime() (now *time.Time, offset *time.Duration) {
+ chunkSize := 4
+ ntpServers := t.ntpServers
+
+ // shuffle the ntp servers to avoid always querying the same servers
+ rand.Shuffle(len(ntpServers), func(i, j int) { ntpServers[i], ntpServers[j] = ntpServers[j], ntpServers[i] })
+
+ for i := 0; i < len(ntpServers); i += chunkSize {
+ chunk := ntpServers[i:min(i+chunkSize, len(ntpServers))]
+ now, offset := t.queryMultipleNTP(chunk, timeSyncTimeout)
+ if now != nil {
+ return now, offset
+ }
+ }
+
+ return nil, nil
+}
+
+type ntpResult struct {
+ now *time.Time
+ offset *time.Duration
+}
+
+func (t *TimeSync) queryMultipleNTP(servers []string, timeout time.Duration) (now *time.Time, offset *time.Duration) {
+ results := make(chan *ntpResult, len(servers))
+ for _, server := range servers {
+ go func(server string) {
+ scopedLogger := t.l.With().
+ Str("server", server).
+ Logger()
+
+ // increase request count
+ metricNtpTotalRequestCount.Inc()
+ metricNtpRequestCount.WithLabelValues(server).Inc()
+
+ // query the server
+ now, response, err := queryNtpServer(server, timeout)
+
+ // set the last RTT
+ metricNtpServerLastRTT.WithLabelValues(
+ server,
+ ).Set(float64(response.RTT.Milliseconds()))
+
+ // set the RTT histogram
+ metricNtpServerRttHistogram.WithLabelValues(
+ server,
+ ).Observe(float64(response.RTT.Milliseconds()))
+
+ // set the server info
+ metricNtpServerInfo.WithLabelValues(
+ server,
+ response.ReferenceString(),
+ strconv.Itoa(int(response.Stratum)),
+ strconv.Itoa(int(response.Precision)),
+ ).Set(1)
+
+ if err == nil {
+ // increase success count
+ metricNtpTotalSuccessCount.Inc()
+ metricNtpSuccessCount.WithLabelValues(server).Inc()
+
+ scopedLogger.Info().
+ Str("time", now.Format(time.RFC3339)).
+ Str("reference", response.ReferenceString()).
+ Str("rtt", response.RTT.String()).
+ Str("clockOffset", response.ClockOffset.String()).
+ Uint8("stratum", response.Stratum).
+ Msg("NTP server returned time")
+ results <- &ntpResult{
+ now: now,
+ offset: &response.ClockOffset,
+ }
+ } else {
+ scopedLogger.Warn().
+ Str("error", err.Error()).
+ Msg("failed to query NTP server")
+ }
+ }(server)
+ }
+
+ result := <-results
+ return result.now, result.offset
+}
+
+func queryNtpServer(server string, timeout time.Duration) (now *time.Time, response *ntp.Response, err error) {
+ resp, err := ntp.QueryWithOptions(server, ntp.QueryOptions{Timeout: timeout})
+ if err != nil {
+ return nil, nil, err
+ }
+ return &resp.Time, resp, nil
+}
diff --git a/internal/timesync/rtc.go b/internal/timesync/rtc.go
new file mode 100644
index 0000000..92ee485
--- /dev/null
+++ b/internal/timesync/rtc.go
@@ -0,0 +1,26 @@
+package timesync
+
+import (
+ "fmt"
+ "os"
+)
+
+var (
+ rtcDeviceSearchPaths = []string{
+ "/dev/rtc",
+ "/dev/rtc0",
+ "/dev/rtc1",
+ "/dev/misc/rtc",
+ "/dev/misc/rtc0",
+ "/dev/misc/rtc1",
+ }
+)
+
+func getRtcDevicePath() (string, error) {
+ for _, path := range rtcDeviceSearchPaths {
+ if _, err := os.Stat(path); err == nil {
+ return path, nil
+ }
+ }
+ return "", fmt.Errorf("rtc device not found")
+}
diff --git a/internal/timesync/rtc_linux.go b/internal/timesync/rtc_linux.go
new file mode 100644
index 0000000..27e4ec7
--- /dev/null
+++ b/internal/timesync/rtc_linux.go
@@ -0,0 +1,105 @@
+//go:build linux
+
+package timesync
+
+import (
+ "fmt"
+ "os"
+ "time"
+
+ "golang.org/x/sys/unix"
+)
+
+func TimetoRtcTime(t time.Time) unix.RTCTime {
+ return unix.RTCTime{
+ Sec: int32(t.Second()),
+ Min: int32(t.Minute()),
+ Hour: int32(t.Hour()),
+ Mday: int32(t.Day()),
+ Mon: int32(t.Month() - 1),
+ Year: int32(t.Year() - 1900),
+ Wday: int32(0),
+ Yday: int32(0),
+ Isdst: int32(0),
+ }
+}
+
+func RtcTimetoTime(t unix.RTCTime) time.Time {
+ return time.Date(
+ int(t.Year)+1900,
+ time.Month(t.Mon+1),
+ int(t.Mday),
+ int(t.Hour),
+ int(t.Min),
+ int(t.Sec),
+ 0,
+ time.UTC,
+ )
+}
+
+func (t *TimeSync) getRtcDevice() (*os.File, error) {
+ if t.rtcDevice == nil {
+ file, err := os.OpenFile(t.rtcDevicePath, os.O_RDWR, 0666)
+ if err != nil {
+ return nil, err
+ }
+ t.rtcDevice = file
+ }
+ return t.rtcDevice, nil
+}
+
+func (t *TimeSync) getRtcDeviceFd() (int, error) {
+ device, err := t.getRtcDevice()
+ if err != nil {
+ return 0, err
+ }
+ return int(device.Fd()), nil
+}
+
+// Read implements Read for the Linux RTC
+func (t *TimeSync) readRtcTime() (time.Time, error) {
+ fd, err := t.getRtcDeviceFd()
+ if err != nil {
+ return time.Time{}, fmt.Errorf("failed to get RTC device fd: %w", err)
+ }
+
+ rtcTime, err := unix.IoctlGetRTCTime(fd)
+ if err != nil {
+ return time.Time{}, fmt.Errorf("failed to get RTC time: %w", err)
+ }
+
+ date := RtcTimetoTime(*rtcTime)
+
+ return date, nil
+}
+
+// Set implements Set for the Linux RTC
+// ...
+// It might be not accurate as the time consumed by the system call is not taken into account
+// but it's good enough for our purposes
+func (t *TimeSync) setRtcTime(tu time.Time) error {
+ rt := TimetoRtcTime(tu)
+
+ fd, err := t.getRtcDeviceFd()
+ if err != nil {
+ return fmt.Errorf("failed to get RTC device fd: %w", err)
+ }
+
+ currentRtcTime, err := t.readRtcTime()
+ if err != nil {
+ return fmt.Errorf("failed to read RTC time: %w", err)
+ }
+
+ t.l.Info().
+ Interface("rtc_time", tu).
+ Str("offset", tu.Sub(currentRtcTime).String()).
+ Msg("set rtc time")
+
+ if err := unix.IoctlSetRTCTime(fd, &rt); err != nil {
+ return fmt.Errorf("failed to set RTC time: %w", err)
+ }
+
+ metricRTCUpdateCount.Inc()
+
+ return nil
+}
diff --git a/internal/timesync/rtc_notlinux.go b/internal/timesync/rtc_notlinux.go
new file mode 100644
index 0000000..e3c1b20
--- /dev/null
+++ b/internal/timesync/rtc_notlinux.go
@@ -0,0 +1,16 @@
+//go:build !linux
+
+package timesync
+
+import (
+ "errors"
+ "time"
+)
+
+func (t *TimeSync) readRtcTime() (time.Time, error) {
+ return time.Now(), nil
+}
+
+func (t *TimeSync) setRtcTime(tu time.Time) error {
+ return errors.New("not supported")
+}
diff --git a/internal/timesync/timesync.go b/internal/timesync/timesync.go
new file mode 100644
index 0000000..e956cf9
--- /dev/null
+++ b/internal/timesync/timesync.go
@@ -0,0 +1,208 @@
+package timesync
+
+import (
+ "fmt"
+ "os"
+ "os/exec"
+ "sync"
+ "time"
+
+ "github.com/jetkvm/kvm/internal/network"
+ "github.com/rs/zerolog"
+)
+
+const (
+ timeSyncRetryStep = 5 * time.Second
+ timeSyncRetryMaxInt = 1 * time.Minute
+ timeSyncWaitNetChkInt = 100 * time.Millisecond
+ timeSyncWaitNetUpInt = 3 * time.Second
+ timeSyncInterval = 1 * time.Hour
+ timeSyncTimeout = 2 * time.Second
+)
+
+var (
+ timeSyncRetryInterval = 0 * time.Second
+)
+
+type TimeSync struct {
+ syncLock *sync.Mutex
+ l *zerolog.Logger
+
+ ntpServers []string
+ httpUrls []string
+ networkConfig *network.NetworkConfig
+
+ rtcDevicePath string
+ rtcDevice *os.File //nolint:unused
+ rtcLock *sync.Mutex
+
+ syncSuccess bool
+
+ preCheckFunc func() (bool, error)
+}
+
+type TimeSyncOptions struct {
+ PreCheckFunc func() (bool, error)
+ Logger *zerolog.Logger
+ NetworkConfig *network.NetworkConfig
+}
+
+type SyncMode struct {
+ Ntp bool
+ Http bool
+ Ordering []string
+ NtpUseFallback bool
+ HttpUseFallback bool
+}
+
+func NewTimeSync(opts *TimeSyncOptions) *TimeSync {
+ rtcDevice, err := getRtcDevicePath()
+ if err != nil {
+ opts.Logger.Error().Err(err).Msg("failed to get RTC device path")
+ } else {
+ opts.Logger.Info().Str("path", rtcDevice).Msg("RTC device found")
+ }
+
+ t := &TimeSync{
+ syncLock: &sync.Mutex{},
+ l: opts.Logger,
+ rtcDevicePath: rtcDevice,
+ rtcLock: &sync.Mutex{},
+ preCheckFunc: opts.PreCheckFunc,
+ ntpServers: defaultNTPServers,
+ httpUrls: defaultHTTPUrls,
+ networkConfig: opts.NetworkConfig,
+ }
+
+ if t.rtcDevicePath != "" {
+ rtcTime, _ := t.readRtcTime()
+ t.l.Info().Interface("rtc_time", rtcTime).Msg("read RTC time")
+ }
+
+ return t
+}
+
+func (t *TimeSync) getSyncMode() SyncMode {
+ syncMode := SyncMode{
+ NtpUseFallback: true,
+ HttpUseFallback: true,
+ }
+ var syncModeString string
+
+ if t.networkConfig != nil {
+ syncModeString = t.networkConfig.TimeSyncMode.String
+ if t.networkConfig.TimeSyncDisableFallback.Bool {
+ syncMode.NtpUseFallback = false
+ syncMode.HttpUseFallback = false
+ }
+ }
+
+ switch syncModeString {
+ case "ntp_only":
+ syncMode.Ntp = true
+ case "http_only":
+ syncMode.Http = true
+ default:
+ syncMode.Ntp = true
+ syncMode.Http = true
+ }
+
+ return syncMode
+}
+
+func (t *TimeSync) doTimeSync() {
+ metricTimeSyncStatus.Set(0)
+ for {
+ if ok, err := t.preCheckFunc(); !ok {
+ if err != nil {
+ t.l.Error().Err(err).Msg("pre-check failed")
+ }
+ time.Sleep(timeSyncWaitNetChkInt)
+ continue
+ }
+
+ t.l.Info().Msg("syncing system time")
+ start := time.Now()
+ err := t.Sync()
+ if err != nil {
+ t.l.Error().Str("error", err.Error()).Msg("failed to sync system time")
+
+ // retry after a delay
+ timeSyncRetryInterval += timeSyncRetryStep
+ time.Sleep(timeSyncRetryInterval)
+ // reset the retry interval if it exceeds the max interval
+ if timeSyncRetryInterval > timeSyncRetryMaxInt {
+ timeSyncRetryInterval = 0
+ }
+
+ continue
+ }
+ t.syncSuccess = true
+ t.l.Info().Str("now", time.Now().Format(time.RFC3339)).
+ Str("time_taken", time.Since(start).String()).
+ Msg("time sync successful")
+
+ metricTimeSyncStatus.Set(1)
+
+ time.Sleep(timeSyncInterval) // after the first sync is done
+ }
+}
+
+func (t *TimeSync) Sync() error {
+ var (
+ now *time.Time
+ offset *time.Duration
+ )
+
+ syncMode := t.getSyncMode()
+
+ metricTimeSyncCount.Inc()
+
+ if syncMode.Ntp {
+ now, offset = t.queryNetworkTime()
+ }
+
+ if syncMode.Http && now == nil {
+ now = t.queryAllHttpTime()
+ }
+
+ if now == nil {
+ return fmt.Errorf("failed to get time from any source")
+ }
+
+ if offset != nil {
+ newNow := time.Now().Add(*offset)
+ now = &newNow
+ }
+
+ err := t.setSystemTime(*now)
+ if err != nil {
+ return fmt.Errorf("failed to set system time: %w", err)
+ }
+
+ metricTimeSyncSuccessCount.Inc()
+
+ return nil
+}
+
+func (t *TimeSync) IsSyncSuccess() bool {
+ return t.syncSuccess
+}
+
+func (t *TimeSync) Start() {
+ go t.doTimeSync()
+}
+
+func (t *TimeSync) setSystemTime(now time.Time) error {
+ nowStr := now.Format("2006-01-02 15:04:05")
+ output, err := exec.Command("date", "-s", nowStr).CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("failed to run date -s: %w, %s", err, string(output))
+ }
+
+ if t.rtcDevicePath != "" {
+ return t.setRtcTime(now)
+ }
+
+ return nil
+}
diff --git a/internal/udhcpc/options.go b/internal/udhcpc/options.go
new file mode 100644
index 0000000..10c9f75
--- /dev/null
+++ b/internal/udhcpc/options.go
@@ -0,0 +1,12 @@
+package udhcpc
+
+func (u *DHCPClient) GetNtpServers() []string {
+ if u.lease == nil {
+ return nil
+ }
+ servers := make([]string, len(u.lease.NTPServers))
+ for i, server := range u.lease.NTPServers {
+ servers[i] = server.String()
+ }
+ return servers
+}
diff --git a/internal/udhcpc/parser.go b/internal/udhcpc/parser.go
new file mode 100644
index 0000000..66c3ba2
--- /dev/null
+++ b/internal/udhcpc/parser.go
@@ -0,0 +1,186 @@
+package udhcpc
+
+import (
+ "bufio"
+ "encoding/json"
+ "fmt"
+ "net"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+)
+
+type Lease struct {
+ // from https://udhcp.busybox.net/README.udhcpc
+ IPAddress net.IP `env:"ip" json:"ip"` // The obtained IP
+ Netmask net.IP `env:"subnet" json:"netmask"` // The assigned subnet mask
+ Broadcast net.IP `env:"broadcast" json:"broadcast"` // The broadcast address for this network
+ TTL int `env:"ipttl" json:"ttl,omitempty"` // The TTL to use for this network
+ MTU int `env:"mtu" json:"mtu,omitempty"` // The MTU to use for this network
+ HostName string `env:"hostname" json:"hostname,omitempty"` // The assigned hostname
+ Domain string `env:"domain" json:"domain,omitempty"` // The domain name of the network
+ BootPNextServer net.IP `env:"siaddr" json:"bootp_next_server,omitempty"` // The bootp next server option
+ BootPServerName string `env:"sname" json:"bootp_server_name,omitempty"` // The bootp server name option
+ BootPFile string `env:"boot_file" json:"bootp_file,omitempty"` // The bootp boot file option
+ Timezone string `env:"timezone" json:"timezone,omitempty"` // Offset in seconds from UTC
+ Routers []net.IP `env:"router" json:"routers,omitempty"` // A list of routers
+ DNS []net.IP `env:"dns" json:"dns_servers,omitempty"` // A list of DNS servers
+ NTPServers []net.IP `env:"ntpsrv" json:"ntp_servers,omitempty"` // A list of NTP servers
+ LPRServers []net.IP `env:"lprsvr" json:"lpr_servers,omitempty"` // A list of LPR servers
+ TimeServers []net.IP `env:"timesvr" json:"_time_servers,omitempty"` // A list of time servers (obsolete)
+ IEN116NameServers []net.IP `env:"namesvr" json:"_name_servers,omitempty"` // A list of IEN 116 name servers (obsolete)
+ LogServers []net.IP `env:"logsvr" json:"_log_servers,omitempty"` // A list of MIT-LCS UDP log servers (obsolete)
+ CookieServers []net.IP `env:"cookiesvr" json:"_cookie_servers,omitempty"` // A list of RFC 865 cookie servers (obsolete)
+ WINSServers []net.IP `env:"wins" json:"_wins_servers,omitempty"` // A list of WINS servers
+ SwapServer net.IP `env:"swapsvr" json:"_swap_server,omitempty"` // The IP address of the client's swap server
+ BootSize int `env:"bootsize" json:"bootsize,omitempty"` // The length in 512 octect blocks of the bootfile
+ RootPath string `env:"rootpath" json:"root_path,omitempty"` // The path name of the client's root disk
+ LeaseTime time.Duration `env:"lease" json:"lease,omitempty"` // The lease time, in seconds
+ DHCPType string `env:"dhcptype" json:"dhcp_type,omitempty"` // DHCP message type (safely ignored)
+ ServerID string `env:"serverid" json:"server_id,omitempty"` // The IP of the server
+ Message string `env:"message" json:"reason,omitempty"` // Reason for a DHCPNAK
+ TFTPServerName string `env:"tftp" json:"tftp,omitempty"` // The TFTP server name
+ BootFileName string `env:"bootfile" json:"bootfile,omitempty"` // The boot file name
+ Uptime time.Duration `env:"uptime" json:"uptime,omitempty"` // The uptime of the device when the lease was obtained, in seconds
+ LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // The expiry time of the lease
+ isEmpty map[string]bool
+}
+
+func (l *Lease) setIsEmpty(m map[string]bool) {
+ l.isEmpty = m
+}
+
+func (l *Lease) IsEmpty(key string) bool {
+ return l.isEmpty[key]
+}
+
+func (l *Lease) ToJSON() string {
+ json, err := json.Marshal(l)
+ if err != nil {
+ return ""
+ }
+ return string(json)
+}
+
+func (l *Lease) SetLeaseExpiry() (time.Time, error) {
+ if l.Uptime == 0 || l.LeaseTime == 0 {
+ return time.Time{}, fmt.Errorf("uptime or lease time isn't set")
+ }
+
+ // get the uptime of the device
+
+ file, err := os.Open("/proc/uptime")
+ if err != nil {
+ return time.Time{}, fmt.Errorf("failed to open uptime file: %w", err)
+ }
+ defer file.Close()
+
+ var uptime time.Duration
+
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ text := scanner.Text()
+ parts := strings.Split(text, " ")
+ uptime, err = time.ParseDuration(parts[0] + "s")
+
+ if err != nil {
+ return time.Time{}, fmt.Errorf("failed to parse uptime: %w", err)
+ }
+ }
+
+ relativeLeaseRemaining := (l.Uptime + l.LeaseTime) - uptime
+ leaseExpiry := time.Now().Add(relativeLeaseRemaining)
+
+ l.LeaseExpiry = &leaseExpiry
+
+ return leaseExpiry, nil
+}
+
+func UnmarshalDHCPCLease(lease *Lease, str string) error {
+ // parse the lease file as a map
+ data := make(map[string]string)
+ for _, line := range strings.Split(str, "\n") {
+ line = strings.TrimSpace(line)
+ // skip empty lines and comments
+ if line == "" || strings.HasPrefix(line, "#") {
+ continue
+ }
+
+ parts := strings.SplitN(line, "=", 2)
+ if len(parts) != 2 {
+ continue
+ }
+
+ key := strings.TrimSpace(parts[0])
+ value := strings.TrimSpace(parts[1])
+
+ data[key] = value
+ }
+
+ // now iterate over the lease struct and set the values
+ leaseType := reflect.TypeOf(lease).Elem()
+ leaseValue := reflect.ValueOf(lease).Elem()
+
+ valuesParsed := make(map[string]bool)
+
+ for i := 0; i < leaseType.NumField(); i++ {
+ field := leaseValue.Field(i)
+
+ // get the env tag
+ key := leaseType.Field(i).Tag.Get("env")
+ if key == "" {
+ continue
+ }
+
+ valuesParsed[key] = false
+
+ // get the value from the data map
+ value, ok := data[key]
+ if !ok || value == "" {
+ continue
+ }
+
+ switch field.Interface().(type) {
+ case string:
+ field.SetString(value)
+ case int:
+ val, err := strconv.Atoi(value)
+ if err != nil {
+ continue
+ }
+ field.SetInt(int64(val))
+ case time.Duration:
+ val, err := time.ParseDuration(value + "s")
+ if err != nil {
+ continue
+ }
+ field.Set(reflect.ValueOf(val))
+ case net.IP:
+ ip := net.ParseIP(value)
+ if ip == nil {
+ continue
+ }
+ field.Set(reflect.ValueOf(ip))
+ case []net.IP:
+ val := make([]net.IP, 0)
+ for _, ipStr := range strings.Fields(value) {
+ ip := net.ParseIP(ipStr)
+ if ip == nil {
+ continue
+ }
+ val = append(val, ip)
+ }
+ field.Set(reflect.ValueOf(val))
+ default:
+ return fmt.Errorf("unsupported field `%s` type: %s", key, field.Type().String())
+ }
+
+ valuesParsed[key] = true
+ }
+
+ lease.setIsEmpty(valuesParsed)
+
+ return nil
+}
diff --git a/internal/udhcpc/parser_test.go b/internal/udhcpc/parser_test.go
new file mode 100644
index 0000000..423ab53
--- /dev/null
+++ b/internal/udhcpc/parser_test.go
@@ -0,0 +1,74 @@
+package udhcpc
+
+import (
+ "testing"
+ "time"
+)
+
+func TestUnmarshalDHCPCLease(t *testing.T) {
+ lease := &Lease{}
+ err := UnmarshalDHCPCLease(lease, `
+# generated @ Mon Jan 4 19:31:53 UTC 2021
+# 19:31:53 up 0 min, 0 users, load average: 0.72, 0.14, 0.04
+# the date might be inaccurate if the clock is not set
+ip=192.168.0.240
+siaddr=192.168.0.1
+sname=
+boot_file=
+subnet=255.255.255.0
+timezone=
+router=192.168.0.1
+timesvr=
+namesvr=
+dns=172.19.53.2
+logsvr=
+cookiesvr=
+lprsvr=
+hostname=
+bootsize=
+domain=
+swapsvr=
+rootpath=
+ipttl=
+mtu=
+broadcast=
+ntpsrv=162.159.200.123
+wins=
+lease=172800
+dhcptype=
+serverid=192.168.0.1
+message=
+tftp=
+bootfile=
+ `)
+ if lease.IPAddress.String() != "192.168.0.240" {
+ t.Fatalf("expected ip to be 192.168.0.240, got %s", lease.IPAddress.String())
+ }
+ if lease.Netmask.String() != "255.255.255.0" {
+ t.Fatalf("expected netmask to be 255.255.255.0, got %s", lease.Netmask.String())
+ }
+ if len(lease.Routers) != 1 {
+ t.Fatalf("expected 1 router, got %d", len(lease.Routers))
+ }
+ if lease.Routers[0].String() != "192.168.0.1" {
+ t.Fatalf("expected router to be 192.168.0.1, got %s", lease.Routers[0].String())
+ }
+ if len(lease.NTPServers) != 1 {
+ t.Fatalf("expected 1 timeserver, got %d", len(lease.NTPServers))
+ }
+ if lease.NTPServers[0].String() != "162.159.200.123" {
+ t.Fatalf("expected timeserver to be 162.159.200.123, got %s", lease.NTPServers[0].String())
+ }
+ if len(lease.DNS) != 1 {
+ t.Fatalf("expected 1 dns, got %d", len(lease.DNS))
+ }
+ if lease.DNS[0].String() != "172.19.53.2" {
+ t.Fatalf("expected dns to be 172.19.53.2, got %s", lease.DNS[0].String())
+ }
+ if lease.LeaseTime != 172800*time.Second {
+ t.Fatalf("expected lease time to be 172800 seconds, got %d", lease.LeaseTime)
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/internal/udhcpc/proc.go b/internal/udhcpc/proc.go
new file mode 100644
index 0000000..69c2ab9
--- /dev/null
+++ b/internal/udhcpc/proc.go
@@ -0,0 +1,212 @@
+package udhcpc
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "syscall"
+)
+
+func readFileNoStat(filename string) ([]byte, error) {
+ const maxBufferSize = 1024 * 1024
+
+ f, err := os.Open(filename)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ reader := io.LimitReader(f, maxBufferSize)
+ return io.ReadAll(reader)
+}
+
+func toCmdline(path string) ([]string, error) {
+ data, err := readFileNoStat(path)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(data) < 1 {
+ return []string{}, nil
+ }
+
+ return strings.Split(string(bytes.TrimRight(data, "\x00")), "\x00"), nil
+}
+
+func (p *DHCPClient) findUdhcpcProcess() (int, error) {
+ // read procfs for udhcpc processes
+ // we do not use procfs.AllProcs() because we want to avoid the overhead of reading the entire procfs
+ processes, err := os.ReadDir("/proc")
+ if err != nil {
+ return 0, err
+ }
+
+ // iterate over the processes
+ for _, d := range processes {
+ // check if file is numeric
+ pid, err := strconv.Atoi(d.Name())
+ if err != nil {
+ continue
+ }
+
+ // check if it's a directory
+ if !d.IsDir() {
+ continue
+ }
+
+ cmdline, err := toCmdline(filepath.Join("/proc", d.Name(), "cmdline"))
+ if err != nil {
+ continue
+ }
+
+ if len(cmdline) < 1 {
+ continue
+ }
+
+ if cmdline[0] != "udhcpc" {
+ continue
+ }
+
+ cmdlineText := strings.Join(cmdline, " ")
+
+ // check if it's a udhcpc process
+ if strings.Contains(cmdlineText, fmt.Sprintf("-i %s", p.InterfaceName)) {
+ p.logger.Debug().
+ Str("pid", d.Name()).
+ Interface("cmdline", cmdline).
+ Msg("found udhcpc process")
+ return pid, nil
+ }
+ }
+
+ return 0, errors.New("udhcpc process not found")
+}
+
+func (c *DHCPClient) getProcessPid() (int, error) {
+ var pid int
+ if c.pidFile != "" {
+ // try to read the pid file
+ pidHandle, err := os.ReadFile(c.pidFile)
+ if err != nil {
+ c.logger.Warn().Err(err).
+ Str("pidFile", c.pidFile).Msg("failed to read udhcpc pid file")
+ }
+
+ // if it exists, try to read the pid
+ if pidHandle != nil {
+ pidFromFile, err := strconv.Atoi(string(pidHandle))
+ if err != nil {
+ c.logger.Warn().Err(err).
+ Str("pidFile", c.pidFile).Msg("failed to convert pid file to int")
+ }
+ pid = pidFromFile
+ }
+ }
+
+ // if the pid is 0, try to find the pid using procfs
+ if pid == 0 {
+ newPid, err := c.findUdhcpcProcess()
+ if err != nil {
+ return 0, err
+ }
+ pid = newPid
+ }
+
+ return pid, nil
+}
+
+func (c *DHCPClient) getProcess() *os.Process {
+ pid, err := c.getProcessPid()
+ if err != nil {
+ return nil
+ }
+
+ process, err := os.FindProcess(pid)
+ if err != nil {
+ c.logger.Warn().Err(err).
+ Int("pid", pid).Msg("failed to find process")
+ return nil
+ }
+
+ return process
+}
+
+func (c *DHCPClient) GetProcess() *os.Process {
+ if c.process == nil {
+ process := c.getProcess()
+ if process == nil {
+ return nil
+ }
+ c.process = process
+ }
+
+ err := c.process.Signal(syscall.Signal(0))
+ if err != nil && errors.Is(err, os.ErrProcessDone) {
+ oldPid := c.process.Pid
+
+ c.process = nil
+ c.process = c.getProcess()
+ if c.process == nil {
+ c.logger.Error().Msg("failed to find new udhcpc process")
+ return nil
+ }
+ c.logger.Warn().
+ Int("oldPid", oldPid).
+ Int("newPid", c.process.Pid).
+ Msg("udhcpc process pid changed")
+ } else if err != nil {
+ c.logger.Warn().Err(err).
+ Int("pid", c.process.Pid).Msg("udhcpc process is not running")
+ }
+
+ return c.process
+}
+
+func (c *DHCPClient) KillProcess() error {
+ process := c.GetProcess()
+ if process == nil {
+ return nil
+ }
+
+ return process.Kill()
+}
+
+func (c *DHCPClient) ReleaseProcess() error {
+ process := c.GetProcess()
+ if process == nil {
+ return nil
+ }
+
+ return process.Release()
+}
+
+func (c *DHCPClient) signalProcess(sig syscall.Signal) error {
+ process := c.GetProcess()
+ if process == nil {
+ return nil
+ }
+
+ s := process.Signal(sig)
+ if s != nil {
+ c.logger.Warn().Err(s).
+ Int("pid", process.Pid).
+ Str("signal", sig.String()).
+ Msg("failed to signal udhcpc process")
+ return s
+ }
+
+ return nil
+}
+
+func (c *DHCPClient) Renew() error {
+ return c.signalProcess(syscall.SIGUSR1)
+}
+
+func (c *DHCPClient) Release() error {
+ return c.signalProcess(syscall.SIGUSR2)
+}
diff --git a/internal/udhcpc/udhcpc.go b/internal/udhcpc/udhcpc.go
new file mode 100644
index 0000000..70ac1b8
--- /dev/null
+++ b/internal/udhcpc/udhcpc.go
@@ -0,0 +1,191 @@
+package udhcpc
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/fsnotify/fsnotify"
+ "github.com/rs/zerolog"
+)
+
+const (
+ DHCPLeaseFile = "/run/udhcpc.%s.info"
+ DHCPPidFile = "/run/udhcpc.%s.pid"
+)
+
+type DHCPClient struct {
+ InterfaceName string
+ leaseFile string
+ pidFile string
+ lease *Lease
+ logger *zerolog.Logger
+ process *os.Process
+ onLeaseChange func(lease *Lease)
+}
+
+type DHCPClientOptions struct {
+ InterfaceName string
+ PidFile string
+ Logger *zerolog.Logger
+ OnLeaseChange func(lease *Lease)
+}
+
+var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel)
+
+func NewDHCPClient(options *DHCPClientOptions) *DHCPClient {
+ if options.Logger == nil {
+ options.Logger = &defaultLogger
+ }
+
+ l := options.Logger.With().Str("interface", options.InterfaceName).Logger()
+ return &DHCPClient{
+ InterfaceName: options.InterfaceName,
+ logger: &l,
+ leaseFile: fmt.Sprintf(DHCPLeaseFile, options.InterfaceName),
+ pidFile: options.PidFile,
+ onLeaseChange: options.OnLeaseChange,
+ }
+}
+
+func (c *DHCPClient) getWatchPaths() []string {
+ watchPaths := make(map[string]interface{})
+ watchPaths[filepath.Dir(c.leaseFile)] = nil
+
+ if c.pidFile != "" {
+ watchPaths[filepath.Dir(c.pidFile)] = nil
+ }
+
+ paths := make([]string, 0)
+ for path := range watchPaths {
+ paths = append(paths, path)
+ }
+ return paths
+}
+
+// Run starts the DHCP client and watches the lease file for changes.
+// this isn't a blocking call, and the lease file is reloaded when a change is detected.
+func (c *DHCPClient) Run() error {
+ err := c.loadLeaseFile()
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ return err
+ }
+
+ watcher, err := fsnotify.NewWatcher()
+ if err != nil {
+ return err
+ }
+ defer watcher.Close()
+
+ go func() {
+ for {
+ select {
+ case event, ok := <-watcher.Events:
+ if !ok {
+ continue
+ }
+ if !event.Has(fsnotify.Write) && !event.Has(fsnotify.Create) {
+ continue
+ }
+
+ if event.Name == c.leaseFile {
+ c.logger.Debug().
+ Str("event", event.Op.String()).
+ Str("path", event.Name).
+ Msg("udhcpc lease file updated, reloading lease")
+ _ = c.loadLeaseFile()
+ }
+ case err, ok := <-watcher.Errors:
+ if !ok {
+ return
+ }
+ c.logger.Error().Err(err).Msg("error watching lease file")
+ }
+ }
+ }()
+
+ for _, path := range c.getWatchPaths() {
+ err = watcher.Add(path)
+ if err != nil {
+ c.logger.Error().
+ Err(err).
+ Str("path", path).
+ Msg("failed to watch directory")
+ return err
+ }
+ }
+
+ // TODO: update udhcpc pid file
+ // we'll comment this out for now because the pid might change
+ // process := c.GetProcess()
+ // if process == nil {
+ // c.logger.Error().Msg("udhcpc process not found")
+ // }
+
+ // block the goroutine until the lease file is updated
+ <-make(chan struct{})
+
+ return nil
+}
+
+func (c *DHCPClient) loadLeaseFile() error {
+ file, err := os.ReadFile(c.leaseFile)
+ if err != nil {
+ return err
+ }
+
+ data := string(file)
+ if data == "" {
+ c.logger.Debug().Msg("udhcpc lease file is empty")
+ return nil
+ }
+
+ lease := &Lease{}
+ err = UnmarshalDHCPCLease(lease, string(file))
+ if err != nil {
+ return err
+ }
+
+ isFirstLoad := c.lease == nil
+ c.lease = lease
+
+ if lease.IPAddress == nil {
+ c.logger.Info().
+ Interface("lease", lease).
+ Str("data", string(file)).
+ Msg("udhcpc lease cleared")
+ return nil
+ }
+
+ msg := "udhcpc lease updated"
+ if isFirstLoad {
+ msg = "udhcpc lease loaded"
+ }
+
+ leaseExpiry, err := lease.SetLeaseExpiry()
+ if err != nil {
+ c.logger.Error().Err(err).Msg("failed to get dhcp lease expiry")
+ } else {
+ expiresIn := time.Until(leaseExpiry)
+ c.logger.Info().
+ Interface("expiry", leaseExpiry).
+ Str("expiresIn", expiresIn.String()).
+ Msg("current dhcp lease expiry time calculated")
+ }
+
+ c.onLeaseChange(lease)
+
+ c.logger.Info().
+ Str("ip", lease.IPAddress.String()).
+ Str("leaseTime", lease.LeaseTime.String()).
+ Interface("data", lease).
+ Msg(msg)
+
+ return nil
+}
+
+func (c *DHCPClient) GetLease() *Lease {
+ return c.lease
+}
diff --git a/internal/websecure/store.go b/internal/websecure/store.go
index 69ae3ef..ea7911c 100644
--- a/internal/websecure/store.go
+++ b/internal/websecure/store.go
@@ -96,7 +96,11 @@ func (s *CertStore) loadCertificate(hostname string) {
s.certificates[hostname] = &cert
- s.log.Info().Str("hostname", hostname).Msg("Loaded certificate")
+ if hostname == selfSignerCAMagicName {
+ s.log.Info().Msg("loaded CA certificate")
+ } else {
+ s.log.Info().Str("hostname", hostname).Msg("loaded certificate")
+ }
}
// GetCertificate returns the certificate for the given hostname
@@ -131,7 +135,7 @@ func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key
if !ignoreWarning {
return nil, fmt.Errorf("certificate does not match hostname: %w", err)
}
- s.log.Warn().Err(err).Msg("Certificate does not match hostname")
+ s.log.Warn().Err(err).Msg("certificate does not match hostname")
}
}
diff --git a/jsonrpc.go b/jsonrpc.go
index 248390e..d35f635 100644
--- a/jsonrpc.go
+++ b/jsonrpc.go
@@ -962,6 +962,10 @@ var rpcHandlers = map[string]RPCHandler{
"getDeviceID": {Func: rpcGetDeviceID},
"deregisterDevice": {Func: rpcDeregisterDevice},
"getCloudState": {Func: rpcGetCloudState},
+ "getNetworkState": {Func: rpcGetNetworkState},
+ "getNetworkSettings": {Func: rpcGetNetworkSettings},
+ "setNetworkSettings": {Func: rpcSetNetworkSettings, Params: []string{"settings"}},
+ "renewDHCPLease": {Func: rpcRenewDHCPLease},
"keyboardReport": {Func: rpcKeyboardReport, Params: []string{"modifier", "keys"}},
"absMouseReport": {Func: rpcAbsMouseReport, Params: []string{"x", "y", "buttons"}},
"relMouseReport": {Func: rpcRelMouseReport, Params: []string{"dx", "dy", "buttons"}},
diff --git a/log.go b/log.go
index ed46852..b353a2c 100644
--- a/log.go
+++ b/log.go
@@ -1,291 +1,32 @@
package kvm
import (
- "fmt"
- "io"
- "os"
- "strings"
- "sync"
- "time"
-
- "github.com/pion/logging"
+ "github.com/jetkvm/kvm/internal/logging"
"github.com/rs/zerolog"
)
-type Logger struct {
- l *zerolog.Logger
- scopeLoggers map[string]*zerolog.Logger
- scopeLevels map[string]zerolog.Level
- scopeLevelMutex sync.Mutex
-
- defaultLogLevelFromEnv zerolog.Level
- defaultLogLevelFromConfig zerolog.Level
- defaultLogLevel zerolog.Level
-}
-
-const (
- defaultLogLevel = zerolog.ErrorLevel
-)
-
-type logOutput struct {
- mu *sync.Mutex
-}
-
-func (w *logOutput) Write(p []byte) (n int, err error) {
- w.mu.Lock()
- defer w.mu.Unlock()
-
- // TODO: write to file or syslog
-
- return len(p), nil
-}
-
-var (
- consoleLogOutput io.Writer = zerolog.ConsoleWriter{
- Out: os.Stdout,
- TimeFormat: time.RFC3339,
- PartsOrder: []string{"time", "level", "scope", "component", "message"},
- FieldsExclude: []string{"scope", "component"},
- FormatPartValueByName: func(value interface{}, name string) string {
- val := fmt.Sprintf("%s", value)
- if name == "component" {
- if value == nil {
- return "-"
- }
- }
- return val
- },
- }
- fileLogOutput io.Writer = &logOutput{mu: &sync.Mutex{}}
- defaultLogOutput = zerolog.MultiLevelWriter(consoleLogOutput, fileLogOutput)
-
- zerologLevels = map[string]zerolog.Level{
- "DISABLE": zerolog.Disabled,
- "NOLEVEL": zerolog.NoLevel,
- "PANIC": zerolog.PanicLevel,
- "FATAL": zerolog.FatalLevel,
- "ERROR": zerolog.ErrorLevel,
- "WARN": zerolog.WarnLevel,
- "INFO": zerolog.InfoLevel,
- "DEBUG": zerolog.DebugLevel,
- "TRACE": zerolog.TraceLevel,
- }
-
- rootZerologLogger = zerolog.New(defaultLogOutput).With().
- Str("scope", "jetkvm").
- Timestamp().
- Stack().
- Logger()
- rootLogger = NewLogger(rootZerologLogger)
-)
-
-func NewLogger(zerologLogger zerolog.Logger) *Logger {
- return &Logger{
- l: &zerologLogger,
- scopeLoggers: make(map[string]*zerolog.Logger),
- scopeLevels: make(map[string]zerolog.Level),
- scopeLevelMutex: sync.Mutex{},
- defaultLogLevelFromEnv: -2,
- defaultLogLevelFromConfig: -2,
- defaultLogLevel: defaultLogLevel,
- }
-}
-
-func (l *Logger) updateLogLevel() {
- l.scopeLevelMutex.Lock()
- defer l.scopeLevelMutex.Unlock()
-
- l.scopeLevels = make(map[string]zerolog.Level)
-
- finalDefaultLogLevel := l.defaultLogLevel
-
- for name, level := range zerologLevels {
- env := os.Getenv(fmt.Sprintf("JETKVM_LOG_%s", name))
-
- if env == "" {
- env = os.Getenv(fmt.Sprintf("PION_LOG_%s", name))
- }
-
- if env == "" {
- env = os.Getenv(fmt.Sprintf("PIONS_LOG_%s", name))
- }
-
- if env == "" {
- continue
- }
-
- if strings.ToLower(env) == "all" {
- l.defaultLogLevelFromEnv = level
-
- if finalDefaultLogLevel > level {
- finalDefaultLogLevel = level
- }
-
- continue
- }
-
- scopes := strings.Split(strings.ToLower(env), ",")
- for _, scope := range scopes {
- l.scopeLevels[scope] = level
- }
- }
-
- l.defaultLogLevel = finalDefaultLogLevel
-}
-
-func (l *Logger) getScopeLoggerLevel(scope string) zerolog.Level {
- if l.scopeLevels == nil {
- l.updateLogLevel()
- }
-
- var scopeLevel zerolog.Level
- if l.defaultLogLevelFromConfig != -2 {
- scopeLevel = l.defaultLogLevelFromConfig
- }
- if l.defaultLogLevelFromEnv != -2 {
- scopeLevel = l.defaultLogLevelFromEnv
- }
-
- // if the scope is not in the map, use the default level from the root logger
- if level, ok := l.scopeLevels[scope]; ok {
- scopeLevel = level
- }
-
- return scopeLevel
-}
-
-func (l *Logger) newScopeLogger(scope string) zerolog.Logger {
- scopeLevel := l.getScopeLoggerLevel(scope)
- logger := l.l.Level(scopeLevel).With().Str("component", scope).Logger()
-
- return logger
-}
-
-func (l *Logger) getLogger(scope string) *zerolog.Logger {
- logger, ok := l.scopeLoggers[scope]
- if !ok || logger == nil {
- scopeLogger := l.newScopeLogger(scope)
- l.scopeLoggers[scope] = &scopeLogger
- }
-
- return l.scopeLoggers[scope]
-}
-
-func (l *Logger) UpdateLogLevel() {
- needUpdate := false
-
- if config != nil && config.DefaultLogLevel != "" {
- if logLevel, ok := zerologLevels[config.DefaultLogLevel]; ok {
- l.defaultLogLevelFromConfig = logLevel
- } else {
- l.l.Warn().Str("logLevel", config.DefaultLogLevel).Msg("invalid defaultLogLevel from config, using ERROR")
- }
-
- if l.defaultLogLevelFromConfig != l.defaultLogLevel {
- needUpdate = true
- }
- }
-
- l.updateLogLevel()
-
- if needUpdate {
- for scope, logger := range l.scopeLoggers {
- currentLevel := logger.GetLevel()
- targetLevel := l.getScopeLoggerLevel(scope)
- if currentLevel != targetLevel {
- *logger = l.newScopeLogger(scope)
- }
- }
- }
-}
-
func ErrorfL(l *zerolog.Logger, format string, err error, args ...interface{}) error {
- if l == nil {
- l = rootLogger.getLogger("jetkvm")
- }
-
- l.Error().Err(err).Msgf(format, args...)
-
- if err == nil {
- return fmt.Errorf(format, args...)
- }
-
- err_msg := err.Error() + ": %v"
- err_args := append(args, err)
-
- return fmt.Errorf(err_msg, err_args...)
+ return logging.ErrorfL(l, format, err, args...)
}
var (
- logger = rootLogger.getLogger("jetkvm")
- cloudLogger = rootLogger.getLogger("cloud")
- websocketLogger = rootLogger.getLogger("websocket")
- webrtcLogger = rootLogger.getLogger("webrtc")
- nativeLogger = rootLogger.getLogger("native")
- nbdLogger = rootLogger.getLogger("nbd")
- ntpLogger = rootLogger.getLogger("ntp")
- jsonRpcLogger = rootLogger.getLogger("jsonrpc")
- watchdogLogger = rootLogger.getLogger("watchdog")
- websecureLogger = rootLogger.getLogger("websecure")
- otaLogger = rootLogger.getLogger("ota")
- serialLogger = rootLogger.getLogger("serial")
- terminalLogger = rootLogger.getLogger("terminal")
- displayLogger = rootLogger.getLogger("display")
- wolLogger = rootLogger.getLogger("wol")
- usbLogger = rootLogger.getLogger("usb")
+ logger = logging.GetSubsystemLogger("jetkvm")
+ networkLogger = logging.GetSubsystemLogger("network")
+ cloudLogger = logging.GetSubsystemLogger("cloud")
+ websocketLogger = logging.GetSubsystemLogger("websocket")
+ webrtcLogger = logging.GetSubsystemLogger("webrtc")
+ nativeLogger = logging.GetSubsystemLogger("native")
+ nbdLogger = logging.GetSubsystemLogger("nbd")
+ timesyncLogger = logging.GetSubsystemLogger("timesync")
+ jsonRpcLogger = logging.GetSubsystemLogger("jsonrpc")
+ watchdogLogger = logging.GetSubsystemLogger("watchdog")
+ websecureLogger = logging.GetSubsystemLogger("websecure")
+ otaLogger = logging.GetSubsystemLogger("ota")
+ serialLogger = logging.GetSubsystemLogger("serial")
+ terminalLogger = logging.GetSubsystemLogger("terminal")
+ displayLogger = logging.GetSubsystemLogger("display")
+ wolLogger = logging.GetSubsystemLogger("wol")
+ usbLogger = logging.GetSubsystemLogger("usb")
// external components
- ginLogger = rootLogger.getLogger("gin")
+ ginLogger = logging.GetSubsystemLogger("gin")
)
-
-type pionLogger struct {
- logger *zerolog.Logger
-}
-
-// Print all messages except trace.
-func (c pionLogger) Trace(msg string) {
- c.logger.Trace().Msg(msg)
-}
-func (c pionLogger) Tracef(format string, args ...interface{}) {
- c.logger.Trace().Msgf(format, args...)
-}
-
-func (c pionLogger) Debug(msg string) {
- c.logger.Debug().Msg(msg)
-}
-func (c pionLogger) Debugf(format string, args ...interface{}) {
- c.logger.Debug().Msgf(format, args...)
-}
-func (c pionLogger) Info(msg string) {
- c.logger.Info().Msg(msg)
-}
-func (c pionLogger) Infof(format string, args ...interface{}) {
- c.logger.Info().Msgf(format, args...)
-}
-func (c pionLogger) Warn(msg string) {
- c.logger.Warn().Msg(msg)
-}
-func (c pionLogger) Warnf(format string, args ...interface{}) {
- c.logger.Warn().Msgf(format, args...)
-}
-func (c pionLogger) Error(msg string) {
- c.logger.Error().Msg(msg)
-}
-func (c pionLogger) Errorf(format string, args ...interface{}) {
- c.logger.Error().Msgf(format, args...)
-}
-
-// customLoggerFactory satisfies the interface logging.LoggerFactory
-// This allows us to create different loggers per subsystem. So we can
-// add custom behavior.
-type pionLoggerFactory struct{}
-
-func (c pionLoggerFactory) NewLogger(subsystem string) logging.LeveledLogger {
- logger := rootLogger.getLogger(subsystem).With().
- Str("scope", "pion").
- Str("component", subsystem).
- Logger()
-
- return pionLogger{logger: &logger}
-}
-
-var defaultLoggerFactory = &pionLoggerFactory{}
diff --git a/main.go b/main.go
index 9eab708..25fbb3a 100644
--- a/main.go
+++ b/main.go
@@ -15,28 +15,54 @@ var appCtx context.Context
func Main() {
LoadConfig()
- logger.Debug().Msg("config loaded")
var cancel context.CancelFunc
appCtx, cancel = context.WithCancel(context.Background())
defer cancel()
- logger.Info().Msg("starting JetKvm")
+
+ systemVersionLocal, appVersionLocal, err := GetLocalVersion()
+ if err != nil {
+ logger.Warn().Err(err).Msg("failed to get local version")
+ }
+
+ logger.Info().
+ Interface("system_version", systemVersionLocal).
+ Interface("app_version", appVersionLocal).
+ Msg("starting JetKVM")
go runWatchdog()
go confirmCurrentSystem()
http.DefaultClient.Timeout = 1 * time.Minute
- err := rootcerts.UpdateDefaultTransport()
+ err = rootcerts.UpdateDefaultTransport()
if err != nil {
- logger.Warn().Err(err).Msg("failed to load CA certs")
+ logger.Warn().Err(err).Msg("failed to load Root CA certificates")
+ }
+ logger.Info().
+ Int("ca_certs_loaded", len(rootcerts.Certs())).
+ Msg("loaded Root CA certificates")
+
+ // Initialize network
+ if err := initNetwork(); err != nil {
+ logger.Error().Err(err).Msg("failed to initialize network")
+ os.Exit(1)
}
- initNetwork()
+ // Initialize time sync
+ initTimeSync()
+ timeSync.Start()
- go TimeSyncLoop()
+ // Initialize mDNS
+ if err := initMdns(); err != nil {
+ logger.Error().Err(err).Msg("failed to initialize mDNS")
+ os.Exit(1)
+ }
+ // Initialize native ctrl socket server
StartNativeCtrlSocketServer()
+
+ // Initialize native video socket server
StartNativeVideoSocketServer()
initPrometheus()
diff --git a/mdns.go b/mdns.go
new file mode 100644
index 0000000..d7a3b55
--- /dev/null
+++ b/mdns.go
@@ -0,0 +1,29 @@
+package kvm
+
+import (
+ "github.com/jetkvm/kvm/internal/mdns"
+)
+
+var mDNS *mdns.MDNS
+
+func initMdns() error {
+ m, err := mdns.NewMDNS(&mdns.MDNSOptions{
+ Logger: logger,
+ LocalNames: []string{
+ networkState.GetHostname(),
+ networkState.GetFQDN(),
+ },
+ ListenOptions: &mdns.MDNSListenOptions{
+ IPv4: true,
+ IPv6: true,
+ },
+ })
+ if err != nil {
+ return err
+ }
+
+ // do not start the server yet, as we need to wait for the network state to be set
+ mDNS = m
+
+ return nil
+}
diff --git a/native.go b/native.go
index b61598c..496f580 100644
--- a/native.go
+++ b/native.go
@@ -8,13 +8,10 @@ import (
"io"
"net"
"os"
- "os/exec"
"sync"
- "syscall"
"time"
"github.com/jetkvm/kvm/resource"
- "github.com/rs/zerolog"
"github.com/pion/webrtc/v4/pkg/media"
)
@@ -36,19 +33,6 @@ type CtrlResponse struct {
Data json.RawMessage `json:"data,omitempty"`
}
-type nativeOutput struct {
- mu *sync.Mutex
- logger *zerolog.Event
-}
-
-func (w *nativeOutput) Write(p []byte) (n int, err error) {
- w.mu.Lock()
- defer w.mu.Unlock()
-
- w.logger.Msg(string(p))
- return len(p), nil
-}
-
type EventHandler func(event CtrlResponse)
var seq int32 = 1
@@ -262,30 +246,8 @@ func ExtractAndRunNativeBin() error {
return fmt.Errorf("failed to make binary executable: %w", err)
}
// Run the binary in the background
- cmd := exec.Command(binaryPath)
-
- nativeOutputLock := sync.Mutex{}
- nativeStdout := &nativeOutput{
- mu: &nativeOutputLock,
- logger: nativeLogger.Info().Str("pipe", "stdout"),
- }
- nativeStderr := &nativeOutput{
- mu: &nativeOutputLock,
- logger: nativeLogger.Info().Str("pipe", "stderr"),
- }
-
- // Redirect stdout and stderr to the current process
- cmd.Stdout = nativeStdout
- cmd.Stderr = nativeStderr
-
- // Set the process group ID so we can kill the process and its children when this process exits
- cmd.SysProcAttr = &syscall.SysProcAttr{
- Setpgid: true,
- Pdeathsig: syscall.SIGKILL,
- }
-
- // Start the command
- if err := cmd.Start(); err != nil {
+ cmd, err := startNativeBinary(binaryPath)
+ if err != nil {
return fmt.Errorf("failed to start binary: %w", err)
}
@@ -335,7 +297,10 @@ func ensureBinaryUpdated(destPath string) error {
_, err = os.Stat(destPath)
if shouldOverwrite(destPath, srcHash) || err != nil {
- nativeLogger.Info().Msg("writing jetkvm_native")
+ nativeLogger.Info().
+ Interface("hash", srcHash).
+ Msg("writing jetkvm_native")
+
_ = os.Remove(destPath)
destFile, err := os.OpenFile(destPath, os.O_CREATE|os.O_RDWR, 0755)
if err != nil {
diff --git a/native_linux.go b/native_linux.go
new file mode 100644
index 0000000..54d2150
--- /dev/null
+++ b/native_linux.go
@@ -0,0 +1,57 @@
+//go:build linux
+
+package kvm
+
+import (
+ "fmt"
+ "os/exec"
+ "sync"
+ "syscall"
+
+ "github.com/rs/zerolog"
+)
+
+type nativeOutput struct {
+ mu *sync.Mutex
+ logger *zerolog.Event
+}
+
+func (w *nativeOutput) Write(p []byte) (n int, err error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ w.logger.Msg(string(p))
+ return len(p), nil
+}
+
+func startNativeBinary(binaryPath string) (*exec.Cmd, error) {
+ // Run the binary in the background
+ cmd := exec.Command(binaryPath)
+
+ nativeOutputLock := sync.Mutex{}
+ nativeStdout := &nativeOutput{
+ mu: &nativeOutputLock,
+ logger: nativeLogger.Info().Str("pipe", "stdout"),
+ }
+ nativeStderr := &nativeOutput{
+ mu: &nativeOutputLock,
+ logger: nativeLogger.Info().Str("pipe", "stderr"),
+ }
+
+ // Redirect stdout and stderr to the current process
+ cmd.Stdout = nativeStdout
+ cmd.Stderr = nativeStderr
+
+ // Set the process group ID so we can kill the process and its children when this process exits
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setpgid: true,
+ Pdeathsig: syscall.SIGKILL,
+ }
+
+ // Start the command
+ if err := cmd.Start(); err != nil {
+ return nil, fmt.Errorf("failed to start binary: %w", err)
+ }
+
+ return cmd, nil
+}
diff --git a/native_notlinux.go b/native_notlinux.go
new file mode 100644
index 0000000..df6df74
--- /dev/null
+++ b/native_notlinux.go
@@ -0,0 +1,12 @@
+//go:build !linux
+
+package kvm
+
+import (
+ "fmt"
+ "os/exec"
+)
+
+func startNativeBinary(binaryPath string) (*exec.Cmd, error) {
+ return nil, fmt.Errorf("not supported")
+}
diff --git a/network.go b/network.go
index 6948d9a..8d9261b 100644
--- a/network.go
+++ b/network.go
@@ -1,237 +1,107 @@
package kvm
import (
- "bytes"
"fmt"
- "net"
- "os"
- "strings"
- "time"
- "os/exec"
-
- "github.com/hashicorp/go-envparse"
- "github.com/pion/mdns/v2"
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
-
- "github.com/vishvananda/netlink"
- "github.com/vishvananda/netlink/nl"
+ "github.com/jetkvm/kvm/internal/network"
+ "github.com/jetkvm/kvm/internal/udhcpc"
)
-var mDNSConn *mdns.Conn
-
-var networkState NetworkState
-
-type NetworkState struct {
- Up bool
- IPv4 string
- IPv6 string
- MAC string
-
- checked bool
-}
-
-type LocalIpInfo struct {
- IPv4 string
- IPv6 string
- MAC string
-}
-
const (
- NetIfName = "eth0"
- DHCPLeaseFile = "/run/udhcpc.%s.info"
+ NetIfName = "eth0"
)
-// setDhcpClientState sends signals to udhcpc to change it's current mode
-// of operation. Setting active to true will force udhcpc to renew the DHCP lease.
-// Setting active to false will put udhcpc into idle mode.
-func setDhcpClientState(active bool) {
- var signal string
- if active {
- signal = "-SIGUSR1"
- } else {
- signal = "-SIGUSR2"
- }
+var (
+ networkState *network.NetworkInterfaceState
+)
- cmd := exec.Command("/usr/bin/killall", signal, "udhcpc")
- if err := cmd.Run(); err != nil {
- logger.Warn().Err(err).Msg("network: setDhcpClientState: failed to change udhcpc state")
+func networkStateChanged() {
+ // do not block the main thread
+ go waitCtrlAndRequestDisplayUpdate(true)
+
+ // always restart mDNS when the network state changes
+ if mDNS != nil {
+ _ = mDNS.SetLocalNames([]string{
+ networkState.GetHostname(),
+ networkState.GetFQDN(),
+ }, true)
}
}
-func checkNetworkState() {
- iface, err := netlink.LinkByName(NetIfName)
- if err != nil {
- logger.Warn().Err(err).Str("interface", NetIfName).Msg("failed to get interface")
- return
- }
+func initNetwork() error {
+ ensureConfigLoaded()
- newState := NetworkState{
- Up: iface.Attrs().OperState == netlink.OperUp,
- MAC: iface.Attrs().HardwareAddr.String(),
+ state, err := network.NewNetworkInterfaceState(&network.NetworkInterfaceOptions{
+ DefaultHostname: GetDefaultHostname(),
+ InterfaceName: NetIfName,
+ NetworkConfig: config.NetworkConfig,
+ Logger: networkLogger,
+ OnStateChange: func(state *network.NetworkInterfaceState) {
+ networkStateChanged()
+ },
+ OnInitialCheck: func(state *network.NetworkInterfaceState) {
+ networkStateChanged()
+ },
+ OnDhcpLeaseChange: func(lease *udhcpc.Lease) {
+ networkStateChanged()
- checked: true,
- }
-
- addrs, err := netlink.AddrList(iface, nl.FAMILY_ALL)
- if err != nil {
- logger.Warn().Err(err).Str("interface", NetIfName).Msg("failed to get addresses")
- }
-
- // If the link is going down, put udhcpc into idle mode.
- // If the link is coming back up, activate udhcpc and force it to renew the lease.
- if newState.Up != networkState.Up {
- setDhcpClientState(newState.Up)
- }
-
- for _, addr := range addrs {
- if addr.IP.To4() != nil {
- if !newState.Up && networkState.Up {
- // If the network is going down, remove all IPv4 addresses from the interface.
- logger.Info().Str("address", addr.IP.String()).Msg("network: state transitioned to down, removing IPv4 address")
- err := netlink.AddrDel(iface, &addr)
- if err != nil {
- logger.Warn().Err(err).Str("address", addr.IP.String()).Msg("network: failed to delete address")
- }
-
- newState.IPv4 = "..."
- } else {
- newState.IPv4 = addr.IP.String()
+ if currentSession == nil {
+ return
}
- } else if addr.IP.To16() != nil && newState.IPv6 == "" {
- newState.IPv6 = addr.IP.String()
- }
- }
- if newState != networkState {
- logger.Info().
- Interface("newState", newState).
- Interface("oldState", networkState).
- Msg("network state changed")
+ writeJSONRPCEvent("networkState", networkState.RpcGetNetworkState(), currentSession)
+ },
+ OnConfigChange: func(networkConfig *network.NetworkConfig) {
+ config.NetworkConfig = networkConfig
+ networkStateChanged()
- // restart MDNS
- _ = startMDNS()
- networkState = newState
- requestDisplayUpdate()
- }
-}
-
-func startMDNS() error {
- // If server was previously running, stop it
- if mDNSConn != nil {
- logger.Info().Msg("stopping mDNS server")
- err := mDNSConn.Close()
- if err != nil {
- logger.Warn().Err(err).Msg("failed to stop mDNS server")
- }
- }
-
- // Start a new server
- hostname := "jetkvm.local"
-
- scopedLogger := logger.With().Str("hostname", hostname).Logger()
- scopedLogger.Info().Msg("starting mDNS server")
-
- addr4, err := net.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4)
- if err != nil {
- return err
- }
-
- addr6, err := net.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6)
- if err != nil {
- return err
- }
-
- l4, err := net.ListenUDP("udp4", addr4)
- if err != nil {
- return err
- }
-
- l6, err := net.ListenUDP("udp6", addr6)
- if err != nil {
- return err
- }
-
- mDNSConn, err = mdns.Server(ipv4.NewPacketConn(l4), ipv6.NewPacketConn(l6), &mdns.Config{
- LocalNames: []string{hostname}, //TODO: make it configurable
- LoggerFactory: defaultLoggerFactory,
+ if mDNS != nil {
+ _ = mDNS.SetListenOptions(networkConfig.GetMDNSMode())
+ _ = mDNS.SetLocalNames([]string{
+ networkState.GetHostname(),
+ networkState.GetFQDN(),
+ }, true)
+ }
+ },
})
- if err != nil {
- scopedLogger.Warn().Err(err).Msg("failed to start mDNS server")
- mDNSConn = nil
+
+ if state == nil {
+ if err == nil {
+ return fmt.Errorf("failed to create NetworkInterfaceState")
+ }
return err
}
- //defer server.Close()
+
+ if err := state.Run(); err != nil {
+ return err
+ }
+
+ networkState = state
+
return nil
}
-func getNTPServersFromDHCPInfo() ([]string, error) {
- buf, err := os.ReadFile(fmt.Sprintf(DHCPLeaseFile, NetIfName))
- if err != nil {
- // do not return error if file does not exist
- if os.IsNotExist(err) {
- return nil, nil
- }
- return nil, fmt.Errorf("failed to load udhcpc info: %w", err)
- }
-
- // parse udhcpc info
- env, err := envparse.Parse(bytes.NewReader(buf))
- if err != nil {
- return nil, fmt.Errorf("failed to parse udhcpc info: %w", err)
- }
-
- val, ok := env["ntpsrv"]
- if !ok {
- return nil, nil
- }
-
- var servers []string
-
- for _, server := range strings.Fields(val) {
- if net.ParseIP(server) == nil {
- logger.Info().Str("server", server).Msg("invalid NTP server IP, ignoring")
- }
- servers = append(servers, server)
- }
-
- return servers, nil
+func rpcGetNetworkState() network.RpcNetworkState {
+ return networkState.RpcGetNetworkState()
}
-func initNetwork() {
- ensureConfigLoaded()
-
- updates := make(chan netlink.LinkUpdate)
- done := make(chan struct{})
-
- if err := netlink.LinkSubscribe(updates, done); err != nil {
- logger.Warn().Err(err).Msg("failed to subscribe to link updates")
- return
- }
-
- go func() {
- waitCtrlClientConnected()
- checkNetworkState()
- ticker := time.NewTicker(1 * time.Second)
- defer ticker.Stop()
-
- for {
- select {
- case update := <-updates:
- if update.Link.Attrs().Name == NetIfName {
- logger.Info().Interface("update", update).Msg("link update")
- checkNetworkState()
- }
- case <-ticker.C:
- checkNetworkState()
- case <-done:
- return
- }
- }
- }()
- err := startMDNS()
- if err != nil {
- logger.Warn().Err(err).Msg("failed to run mDNS")
- }
+func rpcGetNetworkSettings() network.RpcNetworkSettings {
+ return networkState.RpcGetNetworkSettings()
+}
+
+func rpcSetNetworkSettings(settings network.RpcNetworkSettings) (*network.RpcNetworkSettings, error) {
+ s := networkState.RpcSetNetworkSettings(settings)
+ if s != nil {
+ return nil, s
+ }
+
+ if err := SaveConfig(); err != nil {
+ return nil, err
+ }
+
+ return &network.RpcNetworkSettings{NetworkConfig: *config.NetworkConfig}, nil
+}
+
+func rpcRenewDHCPLease() error {
+ return networkState.RpcRenewDHCPLease()
}
diff --git a/ntp.go b/ntp.go
deleted file mode 100644
index a104c56..0000000
--- a/ntp.go
+++ /dev/null
@@ -1,197 +0,0 @@
-package kvm
-
-import (
- "fmt"
- "net/http"
- "os/exec"
- "strconv"
- "time"
-
- "github.com/beevik/ntp"
-)
-
-const (
- timeSyncRetryStep = 5 * time.Second
- timeSyncRetryMaxInt = 1 * time.Minute
- timeSyncWaitNetChkInt = 100 * time.Millisecond
- timeSyncWaitNetUpInt = 3 * time.Second
- timeSyncInterval = 1 * time.Hour
- timeSyncTimeout = 2 * time.Second
-)
-
-var (
- builtTimestamp string
- timeSyncRetryInterval = 0 * time.Second
- timeSyncSuccess = false
- defaultNTPServers = []string{
- "time.cloudflare.com",
- "time.apple.com",
- }
-)
-
-func isTimeSyncNeeded() bool {
- if builtTimestamp == "" {
- ntpLogger.Warn().Msg("Built timestamp is not set, time sync is needed")
- return true
- }
-
- ts, err := strconv.Atoi(builtTimestamp)
- if err != nil {
- ntpLogger.Warn().Str("error", err.Error()).Msg("Failed to parse built timestamp")
- return true
- }
-
- // builtTimestamp is UNIX timestamp in seconds
- builtTime := time.Unix(int64(ts), 0)
- now := time.Now()
-
- ntpLogger.Debug().Str("built_time", builtTime.Format(time.RFC3339)).Str("now", now.Format(time.RFC3339)).Msg("Built time and now")
-
- if now.Sub(builtTime) < 0 {
- ntpLogger.Warn().Msg("System time is behind the built time, time sync is needed")
- return true
- }
-
- return false
-}
-
-func TimeSyncLoop() {
- for {
- if !networkState.checked {
- time.Sleep(timeSyncWaitNetChkInt)
- continue
- }
-
- if !networkState.Up {
- ntpLogger.Info().Msg("Waiting for network to come up")
- time.Sleep(timeSyncWaitNetUpInt)
- continue
- }
-
- // check if time sync is needed, but do nothing for now
- isTimeSyncNeeded()
-
- ntpLogger.Info().Msg("Syncing system time")
- start := time.Now()
- err := SyncSystemTime()
- if err != nil {
- ntpLogger.Error().Str("error", err.Error()).Msg("Failed to sync system time")
-
- // retry after a delay
- timeSyncRetryInterval += timeSyncRetryStep
- time.Sleep(timeSyncRetryInterval)
- // reset the retry interval if it exceeds the max interval
- if timeSyncRetryInterval > timeSyncRetryMaxInt {
- timeSyncRetryInterval = 0
- }
-
- continue
- }
- timeSyncSuccess = true
- ntpLogger.Info().Str("now", time.Now().Format(time.RFC3339)).
- Str("time_taken", time.Since(start).String()).
- Msg("Time sync successful")
- time.Sleep(timeSyncInterval) // after the first sync is done
- }
-}
-
-func SyncSystemTime() (err error) {
- now, err := queryNetworkTime()
- if err != nil {
- return fmt.Errorf("failed to query network time: %w", err)
- }
- err = setSystemTime(*now)
- if err != nil {
- return fmt.Errorf("failed to set system time: %w", err)
- }
- return nil
-}
-
-func queryNetworkTime() (*time.Time, error) {
- ntpServers, err := getNTPServersFromDHCPInfo()
- if err != nil {
- ntpLogger.Info().Err(err).Msg("failed to get NTP servers from DHCP info")
- }
-
- if ntpServers == nil {
- ntpServers = defaultNTPServers
- ntpLogger.Info().
- Interface("ntp_servers", ntpServers).
- Msg("Using default NTP servers")
- } else {
- ntpLogger.Info().
- Interface("ntp_servers", ntpServers).
- Msg("Using NTP servers from DHCP")
- }
-
- for _, server := range ntpServers {
- now, err := queryNtpServer(server, timeSyncTimeout)
- if err == nil {
- ntpLogger.Info().
- Str("ntp_server", server).
- Str("time", now.Format(time.RFC3339)).
- Msg("NTP server returned time")
- return now, nil
- } else {
- ntpLogger.Error().
- Str("ntp_server", server).
- Str("error", err.Error()).
- Msg("failed to query NTP server")
- }
- }
-
- httpUrls := []string{
- "http://apple.com",
- "http://cloudflare.com",
- }
- for _, url := range httpUrls {
- now, err := queryHttpTime(url, timeSyncTimeout)
- if err == nil {
- ntpLogger.Info().
- Str("http_url", url).
- Str("time", now.Format(time.RFC3339)).
- Msg("HTTP server returned time")
- return now, nil
- } else {
- ntpLogger.Error().
- Str("http_url", url).
- Str("error", err.Error()).
- Msg("failed to query HTTP server")
- }
- }
-
- return nil, ErrorfL(ntpLogger, "failed to query network time, all NTP servers and HTTP servers failed", nil)
-}
-
-func queryNtpServer(server string, timeout time.Duration) (now *time.Time, err error) {
- resp, err := ntp.QueryWithOptions(server, ntp.QueryOptions{Timeout: timeout})
- if err != nil {
- return nil, err
- }
- return &resp.Time, nil
-}
-
-func queryHttpTime(url string, timeout time.Duration) (*time.Time, error) {
- client := http.Client{
- Timeout: timeout,
- }
- resp, err := client.Head(url)
- if err != nil {
- return nil, err
- }
- dateStr := resp.Header.Get("Date")
- now, err := time.Parse(time.RFC1123, dateStr)
- if err != nil {
- return nil, err
- }
- return &now, nil
-}
-
-func setSystemTime(now time.Time) error {
- nowStr := now.Format("2006-01-02 15:04:05")
- output, err := exec.Command("date", "-s", nowStr).CombinedOutput()
- if err != nil {
- return fmt.Errorf("failed to run date -s: %w, %s", err, string(output))
- }
- return nil
-}
diff --git a/ota.go b/ota.go
index a5da772..0559978 100644
--- a/ota.go
+++ b/ota.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/sha256"
+ "crypto/tls"
"encoding/hex"
"encoding/json"
"fmt"
@@ -16,6 +17,7 @@ import (
"time"
"github.com/Masterminds/semver/v3"
+ "github.com/gwatts/rootcerts"
"github.com/rs/zerolog"
)
@@ -127,10 +129,14 @@ func downloadFile(ctx context.Context, path string, url string, downloadProgress
return fmt.Errorf("error creating request: %w", err)
}
- // TODO: set a separate timeout for the download but keep the TLS handshake short
- // use Transport here will cause CA certificate validation failure so we temporarily removed it
client := http.Client{
Timeout: 10 * time.Minute,
+ Transport: &http.Transport{
+ TLSHandshakeTimeout: 30 * time.Second,
+ TLSClientConfig: &tls.Config{
+ RootCAs: rootcerts.ServerCertPool(),
+ },
+ },
}
resp, err := client.Do(req)
diff --git a/resource/jetkvm_native b/resource/jetkvm_native
index 0d0719c..084ce14 100644
Binary files a/resource/jetkvm_native and b/resource/jetkvm_native differ
diff --git a/resource/jetkvm_native.sha256 b/resource/jetkvm_native.sha256
index 65da816..b540b94 100644
--- a/resource/jetkvm_native.sha256
+++ b/resource/jetkvm_native.sha256
@@ -1 +1 @@
-c0803a9185298398eff9a925de69bd0ca882cd5983b989a45b748648146475c6
+4b925c7aa73d2e35a227833e806658cb17e1d25900611f93ed70b11ac9f1716d
diff --git a/timesync.go b/timesync.go
new file mode 100644
index 0000000..7b25fe2
--- /dev/null
+++ b/timesync.go
@@ -0,0 +1,53 @@
+package kvm
+
+import (
+ "strconv"
+ "time"
+
+ "github.com/jetkvm/kvm/internal/timesync"
+)
+
+var (
+ timeSync *timesync.TimeSync
+ builtTimestamp string
+)
+
+func isTimeSyncNeeded() bool {
+ if builtTimestamp == "" {
+ timesyncLogger.Warn().Msg("built timestamp is not set, time sync is needed")
+ return true
+ }
+
+ ts, err := strconv.Atoi(builtTimestamp)
+ if err != nil {
+ timesyncLogger.Warn().Str("error", err.Error()).Msg("failed to parse built timestamp")
+ return true
+ }
+
+ // builtTimestamp is UNIX timestamp in seconds
+ builtTime := time.Unix(int64(ts), 0)
+ now := time.Now()
+
+ if now.Sub(builtTime) < 0 {
+ timesyncLogger.Warn().
+ Str("built_time", builtTime.Format(time.RFC3339)).
+ Str("now", now.Format(time.RFC3339)).
+ Msg("system time is behind the built time, time sync is needed")
+ return true
+ }
+
+ return false
+}
+
+func initTimeSync() {
+ timeSync = timesync.NewTimeSync(×ync.TimeSyncOptions{
+ Logger: timesyncLogger,
+ NetworkConfig: config.NetworkConfig,
+ PreCheckFunc: func() (bool, error) {
+ if !networkState.IsOnline() {
+ return false, nil
+ }
+ return true, nil
+ },
+ })
+}
diff --git a/ui/dev_device.sh b/ui/dev_device.sh
index 650cadd..2c7b497 100755
--- a/ui/dev_device.sh
+++ b/ui/dev_device.sh
@@ -15,5 +15,15 @@ echo "└───────────────────────
# Set the environment variable and run Vite
echo "Starting development server with JetKVM device at: $ip_address"
+
+# Check if pwd is the current directory of the script
+if [ "$(pwd)" != "$(dirname "$0")" ]; then
+ pushd "$(dirname "$0")" > /dev/null
+ echo "Changed directory to: $(pwd)"
+fi
+
sleep 1
+
JETKVM_PROXY_URL="ws://$ip_address" npx vite dev --mode=device
+
+popd > /dev/null
diff --git a/ui/package-lock.json b/ui/package-lock.json
index b51a2ea..9e77e10 100644
--- a/ui/package-lock.json
+++ b/ui/package-lock.json
@@ -19,6 +19,7 @@
"@xterm/addon-webgl": "^0.18.0",
"@xterm/xterm": "^5.5.0",
"cva": "^1.0.0-beta.1",
+ "dayjs": "^1.11.13",
"eslint-import-resolver-alias": "^1.1.2",
"focus-trap-react": "^10.2.3",
"framer-motion": "^11.15.0",
@@ -2433,6 +2434,11 @@
"url": "https://github.com/sponsors/ljharb"
}
},
+ "node_modules/dayjs": {
+ "version": "1.11.13",
+ "resolved": "https://registry.npmjs.org/dayjs/-/dayjs-1.11.13.tgz",
+ "integrity": "sha512-oaMBel6gjolK862uaPQOVTA7q3TZhuSvuMQAAglQDOWYO9A91IrAOUJEyKVlqJlHE0vq5p5UXxzdPfMH/x6xNg=="
+ },
"node_modules/debug": {
"version": "4.3.4",
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz",
diff --git a/ui/package.json b/ui/package.json
index 3160297..4dab092 100644
--- a/ui/package.json
+++ b/ui/package.json
@@ -30,6 +30,7 @@
"@xterm/addon-webgl": "^0.18.0",
"@xterm/xterm": "^5.5.0",
"cva": "^1.0.0-beta.1",
+ "dayjs": "^1.11.13",
"eslint-import-resolver-alias": "^1.1.2",
"focus-trap-react": "^10.2.3",
"framer-motion": "^11.15.0",
diff --git a/ui/public/sse.html b/ui/public/sse.html
new file mode 120000
index 0000000..0a8b4f3
--- /dev/null
+++ b/ui/public/sse.html
@@ -0,0 +1 @@
+../../internal/logging/sse.html
\ No newline at end of file
diff --git a/ui/src/hooks/stores.ts b/ui/src/hooks/stores.ts
index 0fa4121..db1fd04 100644
--- a/ui/src/hooks/stores.ts
+++ b/ui/src/hooks/stores.ts
@@ -663,6 +663,95 @@ export const useDeviceStore = create(set => ({
setSystemVersion: version => set({ systemVersion: version }),
}));
+export interface DhcpLease {
+ ip?: string;
+ netmask?: string;
+ broadcast?: string;
+ ttl?: string;
+ mtu?: string;
+ hostname?: string;
+ domain?: string;
+ bootp_next_server?: string;
+ bootp_server_name?: string;
+ bootp_file?: string;
+ timezone?: string;
+ routers?: string[];
+ dns?: string[];
+ ntp_servers?: string[];
+ lpr_servers?: string[];
+ _time_servers?: string[];
+ _name_servers?: string[];
+ _log_servers?: string[];
+ _cookie_servers?: string[];
+ _wins_servers?: string[];
+ _swap_server?: string;
+ boot_size?: string;
+ root_path?: string;
+ lease?: string;
+ lease_expiry?: Date;
+ dhcp_type?: string;
+ server_id?: string;
+ message?: string;
+ tftp?: string;
+ bootfile?: string;
+}
+
+export interface IPv6Address {
+ address: string;
+ prefix: string;
+ valid_lifetime: string;
+ preferred_lifetime: string;
+ scope: string;
+}
+
+export interface NetworkState {
+ interface_name?: string;
+ mac_address?: string;
+ ipv4?: string;
+ ipv4_addresses?: string[];
+ ipv6?: string;
+ ipv6_addresses?: IPv6Address[];
+ ipv6_link_local?: string;
+ dhcp_lease?: DhcpLease;
+
+ setNetworkState: (state: NetworkState) => void;
+ setDhcpLease: (lease: NetworkState["dhcp_lease"]) => void;
+ setDhcpLeaseExpiry: (expiry: Date) => void;
+}
+
+
+export type IPv6Mode = "disabled" | "slaac" | "dhcpv6" | "slaac_and_dhcpv6" | "static" | "link_local" | "unknown";
+export type IPv4Mode = "disabled" | "static" | "dhcp" | "unknown";
+export type LLDPMode = "disabled" | "basic" | "all" | "unknown";
+export type mDNSMode = "disabled" | "auto" | "ipv4_only" | "ipv6_only" | "unknown";
+export type TimeSyncMode = "ntp_only" | "ntp_and_http" | "http_only" | "custom" | "unknown";
+
+export interface NetworkSettings {
+ hostname: string;
+ domain: string;
+ ipv4_mode: IPv4Mode;
+ ipv6_mode: IPv6Mode;
+ lldp_mode: LLDPMode;
+ lldp_tx_tlvs: string[];
+ mdns_mode: mDNSMode;
+ time_sync_mode: TimeSyncMode;
+}
+
+export const useNetworkStateStore = create((set, get) => ({
+ setNetworkState: (state: NetworkState) => set(state),
+ setDhcpLease: (lease: NetworkState["dhcp_lease"]) => set({ dhcp_lease: lease }),
+ setDhcpLeaseExpiry: (expiry: Date) => {
+ const lease = get().dhcp_lease;
+ if (!lease) {
+ console.warn("No lease found");
+ return;
+ }
+
+ lease.lease_expiry = expiry;
+ set({ dhcp_lease: lease });
+ }
+}));
+
export interface KeySequenceStep {
keys: string[];
modifiers: string[];
@@ -767,8 +856,8 @@ export const useMacrosStore = create((set, get) => ({
for (let i = 0; i < macro.steps.length; i++) {
const step = macro.steps[i];
if (step.keys && step.keys.length > MAX_KEYS_PER_STEP) {
- console.error(`Cannot save: macro "${macro.name}" step ${i+1} exceeds maximum of ${MAX_KEYS_PER_STEP} keys`);
- throw new Error(`Cannot save: macro "${macro.name}" step ${i+1} exceeds maximum of ${MAX_KEYS_PER_STEP} keys`);
+ console.error(`Cannot save: macro "${macro.name}" step ${i + 1} exceeds maximum of ${MAX_KEYS_PER_STEP} keys`);
+ throw new Error(`Cannot save: macro "${macro.name}" step ${i + 1} exceeds maximum of ${MAX_KEYS_PER_STEP} keys`);
}
}
}
diff --git a/ui/src/main.tsx b/ui/src/main.tsx
index e09a2a9..f4bdd34 100644
--- a/ui/src/main.tsx
+++ b/ui/src/main.tsx
@@ -42,6 +42,7 @@ import SettingsVideoRoute from "./routes/devices.$id.settings.video";
import SettingsAppearanceRoute from "./routes/devices.$id.settings.appearance";
import * as SettingsGeneralIndexRoute from "./routes/devices.$id.settings.general._index";
import SettingsGeneralUpdateRoute from "./routes/devices.$id.settings.general.update";
+import SettingsNetworkRoute from "./routes/devices.$id.settings.network";
import SecurityAccessLocalAuthRoute from "./routes/devices.$id.settings.access.local-auth";
import SettingsMacrosRoute from "./routes/devices.$id.settings.macros";
import SettingsMacrosAddRoute from "./routes/devices.$id.settings.macros.add";
@@ -156,6 +157,10 @@ if (isOnDevice) {
path: "hardware",
element: ,
},
+ {
+ path: "network",
+ element: ,
+ },
{
path: "access",
children: [
diff --git a/ui/src/routes/devices.$id.settings.network.tsx b/ui/src/routes/devices.$id.settings.network.tsx
new file mode 100644
index 0000000..59d52ef
--- /dev/null
+++ b/ui/src/routes/devices.$id.settings.network.tsx
@@ -0,0 +1,408 @@
+import { useCallback, useEffect, useState } from "react";
+
+import { SelectMenuBasic } from "../components/SelectMenuBasic";
+import { SettingsPageHeader } from "../components/SettingsPageheader";
+
+import { IPv4Mode, IPv6Mode, LLDPMode, mDNSMode, NetworkSettings, NetworkState, TimeSyncMode, useNetworkStateStore } from "@/hooks/stores";
+import { useJsonRpc } from "@/hooks/useJsonRpc";
+import notifications from "@/notifications";
+import { Button } from "@components/Button";
+import { GridCard } from "@components/Card";
+import InputField from "@components/InputField";
+import { SettingsItem } from "./devices.$id.settings";
+
+import dayjs from 'dayjs';
+import relativeTime from 'dayjs/plugin/relativeTime';
+
+dayjs.extend(relativeTime);
+
+const defaultNetworkSettings: NetworkSettings = {
+ hostname: "",
+ domain: "",
+ ipv4_mode: "unknown",
+ ipv6_mode: "unknown",
+ lldp_mode: "unknown",
+ lldp_tx_tlvs: [],
+ mdns_mode: "unknown",
+ time_sync_mode: "unknown",
+}
+
+export function LifeTimeLabel({ lifetime }: { lifetime: string }) {
+ if (lifetime == "") {
+ return N/A ;
+ }
+
+ const [remaining, setRemaining] = useState(null);
+
+ useEffect(() => {
+ setRemaining(dayjs(lifetime).fromNow());
+
+ const interval = setInterval(() => {
+ setRemaining(dayjs(lifetime).fromNow());
+ }, 1000 * 30);
+ return () => clearInterval(interval);
+ }, [lifetime]);
+
+ return <>
+ {dayjs(lifetime).format()}
+ {remaining && <>
+ {" "}
+ ({remaining})
+
+ >}
+ >
+}
+
+export default function SettingsNetworkRoute() {
+ const [send] = useJsonRpc();
+ const [networkState, setNetworkState] = useNetworkStateStore(state => [state, state.setNetworkState]);
+
+ const [networkSettings, setNetworkSettings] = useState(defaultNetworkSettings);
+ const [networkSettingsLoaded, setNetworkSettingsLoaded] = useState(false);
+
+ const getNetworkSettings = useCallback(() => {
+ setNetworkSettingsLoaded(false);
+ send("getNetworkSettings", {}, resp => {
+ if ("error" in resp) return;
+ console.log(resp.result);
+ setNetworkSettings(resp.result as NetworkSettings);
+ setNetworkSettingsLoaded(true);
+ });
+ }, [send]);
+
+ const setNetworkSettingsRemote = useCallback((settings: NetworkSettings) => {
+ setNetworkSettingsLoaded(false);
+ send("setNetworkSettings", { settings }, resp => {
+ if ("error" in resp) {
+ notifications.error("Failed to save network settings: " + (resp.error.data ? resp.error.data : resp.error.message));
+ setNetworkSettingsLoaded(true);
+ return;
+ }
+ setNetworkSettings(resp.result as NetworkSettings);
+ setNetworkSettingsLoaded(true);
+ notifications.success("Network settings saved");
+ });
+ }, [send]);
+
+ const getNetworkState = useCallback(() => {
+ send("getNetworkState", {}, resp => {
+ if ("error" in resp) return;
+ console.log(resp.result);
+ setNetworkState(resp.result as NetworkState);
+ });
+ }, [send]);
+
+ const handleRenewLease = useCallback(() => {
+ send("renewDHCPLease", {}, resp => {
+ if ("error" in resp) {
+ notifications.error("Failed to renew lease: " + resp.error.message);
+ } else {
+ notifications.success("DHCP lease renewed");
+ }
+ });
+ }, [send]);
+
+ useEffect(() => {
+ getNetworkState();
+ getNetworkSettings();
+ }, [getNetworkState, getNetworkSettings]);
+
+ const handleIpv4ModeChange = (value: IPv4Mode | string) => {
+ setNetworkSettings({ ...networkSettings, ipv4_mode: value as IPv4Mode });
+ };
+
+ const handleIpv6ModeChange = (value: IPv6Mode | string) => {
+ setNetworkSettings({ ...networkSettings, ipv6_mode: value as IPv6Mode });
+ };
+
+ const handleLldpModeChange = (value: LLDPMode | string) => {
+ setNetworkSettings({ ...networkSettings, lldp_mode: value as LLDPMode });
+ };
+
+ // const handleLldpTxTlvsChange = (value: string[]) => {
+ // setNetworkSettings({ ...networkSettings, lldp_tx_tlvs: value });
+ // };
+
+ const handleMdnsModeChange = (value: mDNSMode | string) => {
+ setNetworkSettings({ ...networkSettings, mdns_mode: value as mDNSMode });
+ };
+
+ const handleTimeSyncModeChange = (value: TimeSyncMode | string) => {
+ setNetworkSettings({ ...networkSettings, time_sync_mode: value as TimeSyncMode });
+ };
+
+ const filterUnknown = useCallback((options: { value: string; label: string; }[]) => {
+ if (!networkSettingsLoaded) return options;
+ return options.filter(option => option.value !== "unknown");
+ }, [networkSettingsLoaded]);
+
+ return (
+
+
+
+ >}
+ >
+
+ {networkState?.mac_address}
+
+
+
+
+
+ Hostname for the device
+
+
+ Leave blank for default
+
+ >
+ }
+ >
+ {
+ setNetworkSettings({ ...networkSettings, hostname: e.target.value });
+ }}
+ disabled={!networkSettingsLoaded}
+ />
+
+
+
+
+ Domain for the device
+
+
+ Leave blank to use DHCP provided domain, if there is no domain, use local
+
+ >
+ }
+ >
+ {
+ setNetworkSettings({ ...networkSettings, domain: e.target.value });
+ }}
+ disabled={!networkSettingsLoaded}
+ />
+
+
+
+
+ handleIpv4ModeChange(e.target.value)}
+ disabled={!networkSettingsLoaded}
+ options={filterUnknown([
+ { value: "dhcp", label: "DHCP" },
+ // { value: "static", label: "Static" },
+ ])}
+ />
+
+ {networkState?.dhcp_lease && (
+
+
+
+
+
+ Current DHCP Lease
+
+
+
+ {networkState?.dhcp_lease?.ip && IP: {networkState?.dhcp_lease?.ip} }
+ {networkState?.dhcp_lease?.netmask && Subnet: {networkState?.dhcp_lease?.netmask} }
+ {networkState?.dhcp_lease?.broadcast && Broadcast: {networkState?.dhcp_lease?.broadcast} }
+ {networkState?.dhcp_lease?.ttl && TTL: {networkState?.dhcp_lease?.ttl} }
+ {networkState?.dhcp_lease?.mtu && MTU: {networkState?.dhcp_lease?.mtu} }
+ {networkState?.dhcp_lease?.hostname && Hostname: {networkState?.dhcp_lease?.hostname} }
+ {networkState?.dhcp_lease?.domain && Domain: {networkState?.dhcp_lease?.domain} }
+ {networkState?.dhcp_lease?.routers && Gateway: {networkState?.dhcp_lease?.routers.join(", ")} }
+ {networkState?.dhcp_lease?.dns && DNS: {networkState?.dhcp_lease?.dns.join(", ")} }
+ {networkState?.dhcp_lease?.ntp_servers && NTP Servers: {networkState?.dhcp_lease?.ntp_servers.join(", ")} }
+ {networkState?.dhcp_lease?.server_id && Server ID: {networkState?.dhcp_lease?.server_id} }
+ {networkState?.dhcp_lease?.bootp_next_server && BootP Next Server: {networkState?.dhcp_lease?.bootp_next_server} }
+ {networkState?.dhcp_lease?.bootp_server_name && BootP Server Name: {networkState?.dhcp_lease?.bootp_server_name} }
+ {networkState?.dhcp_lease?.bootp_file && Boot File: {networkState?.dhcp_lease?.bootp_file} }
+ {networkState?.dhcp_lease?.lease_expiry &&
+ Lease Expiry:
+ }
+ {/* {JSON.stringify(networkState?.dhcp_lease)} */}
+
+
+
+
+
+ {
+ handleRenewLease();
+ }}
+ />
+
+
+
+
+ )}
+
+
+
+ handleIpv6ModeChange(e.target.value)}
+ disabled={!networkSettingsLoaded}
+ options={filterUnknown([
+ // { value: "disabled", label: "Disabled" },
+ { value: "slaac", label: "SLAAC" },
+ // { value: "dhcpv6", label: "DHCPv6" },
+ // { value: "slaac_and_dhcpv6", label: "SLAAC and DHCPv6" },
+ // { value: "static", label: "Static" },
+ // { value: "link_local", label: "Link-local only" },
+ ])}
+ />
+
+ {networkState?.ipv6_addresses && (
+
+
+
+
+
+ IPv6 Information
+
+
+
+
+ IPv6 Link-local
+
+
+ {networkState?.ipv6_link_local}
+
+
+
+
+ IPv6 Addresses
+
+
+ {networkState?.ipv6_addresses && networkState?.ipv6_addresses.map(addr => (
+
+ {addr.address}
+ {addr.valid_lifetime && <>
+
+ - valid_lft: {" "}
+
+
+
+ >}
+ {addr.preferred_lifetime && <>
+
+ - pref_lft: {" "}
+
+
+
+ >}
+
+ ))}
+
+
+
+
+
+
+
+ )}
+
+
+
+ handleLldpModeChange(e.target.value)}
+ disabled={!networkSettingsLoaded}
+ options={filterUnknown([
+ { value: "disabled", label: "Disabled" },
+ { value: "basic", label: "Basic" },
+ { value: "all", label: "All" },
+ ])}
+ />
+
+
+
+
+ handleMdnsModeChange(e.target.value)}
+ disabled={!networkSettingsLoaded}
+ options={filterUnknown([
+ { value: "disabled", label: "Disabled" },
+ { value: "auto", label: "Auto" },
+ { value: "ipv4_only", label: "IPv4 only" },
+ { value: "ipv6_only", label: "IPv6 only" },
+ ])}
+ />
+
+
+
+
+ handleTimeSyncModeChange(e.target.value)}
+ disabled={!networkSettingsLoaded}
+ options={filterUnknown([
+ { value: "unknown", label: "..." },
+ // { value: "auto", label: "Auto" },
+ { value: "ntp_only", label: "NTP only" },
+ { value: "ntp_and_http", label: "NTP and HTTP" },
+ { value: "http_only", label: "HTTP only" },
+ // { value: "custom", label: "Custom" },
+ ])}
+ />
+
+
+
+ {
+ setNetworkSettingsRemote(networkSettings);
+ }}
+ size="SM"
+ theme="light"
+ text="Save Settings"
+ />
+
+
+ );
+}
diff --git a/ui/src/routes/devices.$id.settings.tsx b/ui/src/routes/devices.$id.settings.tsx
index c0b4181..f8e5262 100644
--- a/ui/src/routes/devices.$id.settings.tsx
+++ b/ui/src/routes/devices.$id.settings.tsx
@@ -9,6 +9,7 @@ import {
LuArrowLeft,
LuPalette,
LuCommand,
+ LuNetwork,
} from "react-icons/lu";
import React, { useEffect, useRef, useState } from "react";
@@ -207,6 +208,17 @@ export default function SettingsRoute() {
+
+
(isActive ? "active" : "")}
+ >
+
+
+
Network
+
+
+
state.setNetworkState);
+
const setUsbState = useHidStore(state => state.setUsbState);
const setHdmiState = useVideoStore(state => state.setHdmiState);
@@ -600,6 +604,11 @@ export default function KvmIdRoute() {
setHdmiState(resp.params as Parameters[0]);
}
+ if (resp.method === "networkState") {
+ console.log("Setting network state", resp.params);
+ setNetworkState(resp.params as NetworkState);
+ }
+
if (resp.method === "otaState") {
const otaState = resp.params as UpdateState["otaState"];
setOtaState(otaState);
diff --git a/ui/vite.config.ts b/ui/vite.config.ts
index f8459cd..e47774f 100644
--- a/ui/vite.config.ts
+++ b/ui/vite.config.ts
@@ -35,6 +35,7 @@ export default defineConfig(({ mode, command }) => {
"/auth": JETKVM_PROXY_URL,
"/storage": JETKVM_PROXY_URL,
"/cloud": JETKVM_PROXY_URL,
+ "/developer": JETKVM_PROXY_URL,
}
: undefined,
},
diff --git a/usb.go b/usb.go
index 3395db4..91674c9 100644
--- a/usb.go
+++ b/usb.go
@@ -66,6 +66,6 @@ func checkUSBState() {
usbState = newState
usbLogger.Info().Str("from", usbState).Str("to", newState).Msg("USB state changed")
- requestDisplayUpdate()
+ requestDisplayUpdate(true)
triggerUSBStateUpdate()
}
diff --git a/usb_mass_storage.go b/usb_mass_storage.go
index 2b03f1f..79a05d1 100644
--- a/usb_mass_storage.go
+++ b/usb_mass_storage.go
@@ -62,7 +62,11 @@ func onDiskMessage(msg webrtc.DataChannelMessage) {
func mountImage(imagePath string) error {
err := setMassStorageImage("")
if err != nil {
- return fmt.Errorf("remove Mass Storage Image Error: %w", err)
+ return fmt.Errorf("remove mass storage image error: %w", err)
+ }
+ err = setMassStorageImage(imagePath)
+ if err != nil {
+ return fmt.Errorf("set mass storage image error: %w", err)
}
err = setMassStorageImage(imagePath)
if err != nil {
@@ -477,7 +481,6 @@ func handleUploadChannel(d *webrtc.DataChannel) {
totalBytesWritten += int64(bytesWritten)
sendProgress := time.Since(lastProgressTime) >= 200*time.Millisecond
-
if totalBytesWritten >= pendingUpload.Size {
sendProgress = true
close(uploadComplete)
diff --git a/video.go b/video.go
index d74add8..6fa77b9 100644
--- a/video.go
+++ b/video.go
@@ -43,7 +43,7 @@ func HandleVideoStateMessage(event CtrlResponse) {
}
lastVideoState = videoState
triggerVideoStateUpdate()
- requestDisplayUpdate()
+ requestDisplayUpdate(true)
}
func rpcGetVideoState() (VideoInputState, error) {
diff --git a/web.go b/web.go
index 6e74a13..766eaf5 100644
--- a/web.go
+++ b/web.go
@@ -9,6 +9,7 @@ import (
"fmt"
"io/fs"
"net/http"
+ "net/http/pprof"
"path/filepath"
"strings"
"time"
@@ -18,6 +19,7 @@ import (
gin_logger "github.com/gin-contrib/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
+ "github.com/jetkvm/kvm/internal/logging"
"github.com/pion/webrtc/v4"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -103,6 +105,27 @@ func setupRouter() *gin.Engine {
// A Prometheus metrics endpoint.
r.GET("/metrics", gin.WrapH(promhttp.Handler()))
+ // Developer mode protected routes
+ developerModeRouter := r.Group("/developer/")
+ developerModeRouter.Use(basicAuthProtectedMiddleware(true))
+ {
+ // pprof
+ developerModeRouter.GET("/pprof/", gin.WrapF(pprof.Index))
+ developerModeRouter.GET("/pprof/cmdline", gin.WrapF(pprof.Cmdline))
+ developerModeRouter.GET("/pprof/profile", gin.WrapF(pprof.Profile))
+ developerModeRouter.POST("/pprof/symbol", gin.WrapF(pprof.Symbol))
+ developerModeRouter.GET("/pprof/symbol", gin.WrapF(pprof.Symbol))
+ developerModeRouter.GET("/pprof/trace", gin.WrapF(pprof.Trace))
+ developerModeRouter.GET("/pprof/allocs", gin.WrapH(pprof.Handler("allocs")))
+ developerModeRouter.GET("/pprof/block", gin.WrapH(pprof.Handler("block")))
+ developerModeRouter.GET("/pprof/goroutine", gin.WrapH(pprof.Handler("goroutine")))
+ developerModeRouter.GET("/pprof/heap", gin.WrapH(pprof.Handler("heap")))
+ developerModeRouter.GET("/pprof/mutex", gin.WrapH(pprof.Handler("mutex")))
+ developerModeRouter.GET("/pprof/threadcreate", gin.WrapH(pprof.Handler("threadcreate")))
+
+ logging.AttachSSEHandler(developerModeRouter)
+ }
+
// Protected routes (allows both password and noPassword modes)
protected := r.Group("/")
protected.Use(protectedMiddleware())
@@ -203,7 +226,7 @@ func handleLocalWebRTCSignal(c *gin.Context) {
wsOptions := &websocket.AcceptOptions{
InsecureSkipVerify: true, // Allow connections from any origin
OnPingReceived: func(ctx context.Context, payload []byte) bool {
- scopedLogger.Info().Bytes("payload", payload).Msg("ping frame received")
+ scopedLogger.Debug().Bytes("payload", payload).Msg("ping frame received")
metricConnectionTotalPingReceivedCount.WithLabelValues("local", source).Inc()
metricConnectionLastPingReceivedTimestamp.WithLabelValues("local", source).SetToCurrentTime()
@@ -242,7 +265,12 @@ func handleWebRTCSignalWsMessages(
scopedLogger *zerolog.Logger,
) error {
runCtx, cancelRun := context.WithCancel(context.Background())
- defer cancelRun()
+ defer func() {
+ if isCloudConnection {
+ setCloudConnectionState(CloudConnectionStateDisconnected)
+ }
+ cancelRun()
+ }()
// connection type
var sourceType string
@@ -459,11 +487,51 @@ func protectedMiddleware() gin.HandlerFunc {
}
}
+func sendErrorJsonThenAbort(c *gin.Context, status int, message string) {
+ c.JSON(status, gin.H{"error": message})
+ c.Abort()
+}
+
+func basicAuthProtectedMiddleware(requireDeveloperMode bool) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ if requireDeveloperMode {
+ devModeState, err := rpcGetDevModeState()
+ if err != nil {
+ sendErrorJsonThenAbort(c, http.StatusInternalServerError, "Failed to get developer mode state")
+ return
+ }
+
+ if !devModeState.Enabled {
+ sendErrorJsonThenAbort(c, http.StatusUnauthorized, "Developer mode is not enabled")
+ return
+ }
+ }
+
+ if config.LocalAuthMode == "noPassword" {
+ sendErrorJsonThenAbort(c, http.StatusForbidden, "The resource is not available in noPassword mode")
+ return
+ }
+
+ // calculate basic auth credentials
+ _, password, ok := c.Request.BasicAuth()
+ if !ok {
+ c.Header("WWW-Authenticate", "Basic realm=\"JetKVM\"")
+ sendErrorJsonThenAbort(c, http.StatusUnauthorized, "Basic auth is required")
+ return
+ }
+
+ err := bcrypt.CompareHashAndPassword([]byte(config.HashedPassword), []byte(password))
+ if err != nil {
+ sendErrorJsonThenAbort(c, http.StatusUnauthorized, "Invalid password")
+ return
+ }
+
+ c.Next()
+ }
+}
+
func RunWebServer() {
r := setupRouter()
- //if strings.Contains(builtAppVersion, "-dev") {
- // pprof.Register(r)
- //}
err := r.Run(":80")
if err != nil {
panic(err)
diff --git a/web_tls.go b/web_tls.go
index cbff56b..564f150 100644
--- a/web_tls.go
+++ b/web_tls.go
@@ -54,7 +54,7 @@ func initCertStore() {
func getCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
switch config.TLSMode {
case "self-signed":
- if isTimeSyncNeeded() || !timeSyncSuccess {
+ if isTimeSyncNeeded() || !timeSync.IsSyncSuccess() {
return nil, fmt.Errorf("time is not synced")
}
return certSigner.GetCertificate(info)
@@ -174,7 +174,7 @@ func runWebSecureServer() {
websecureLogger.Info().Msg("Shutting down websecure server")
err := server.Shutdown(context.Background())
if err != nil {
- websecureLogger.Error().Err(err).Msg("Failed to shutdown websecure server")
+ websecureLogger.Error().Err(err).Msg("failed to shutdown websecure server")
}
}
}()
diff --git a/webrtc.go b/webrtc.go
index 1e093e2..f6c8529 100644
--- a/webrtc.go
+++ b/webrtc.go
@@ -10,6 +10,7 @@ import (
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gin-gonic/gin"
+ "github.com/jetkvm/kvm/internal/logging"
"github.com/pion/webrtc/v4"
"github.com/rs/zerolog"
)
@@ -68,7 +69,7 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) {
func newSession(config SessionConfig) (*Session, error) {
webrtcSettingEngine := webrtc.SettingEngine{
- LoggerFactory: defaultLoggerFactory,
+ LoggerFactory: logging.GetPionDefaultLoggerFactory(),
}
iceServer := webrtc.ICEServer{}
@@ -205,7 +206,7 @@ func newSession(config SessionConfig) (*Session, error) {
var actionSessions = 0
func onActiveSessionsChanged() {
- requestDisplayUpdate()
+ requestDisplayUpdate(true)
}
func onFirstSessionConnected() {