From f712cb1719cd20f81b51757f07bfb505d49256d3 Mon Sep 17 00:00:00 2001 From: Siyuan Miao Date: Sat, 12 Apr 2025 17:32:15 +0200 Subject: [PATCH] refactor(network): rewrite network and timesync component --- cloud.go | 14 +- config.go | 4 +- display.go | 6 +- go.mod | 1 + go.sum | 2 + internal/timesync/http.go | 54 ++++ internal/timesync/ntp.go | 42 ++++ internal/timesync/rtc.go | 26 ++ internal/timesync/rtc_linux.go | 103 ++++++++ internal/timesync/rtc_notlinux.go | 16 ++ internal/timesync/timesync.go | 151 +++++++++++ internal/udhcpc/options.go | 12 + internal/udhcpc/parser.go | 150 +++++++++++ internal/udhcpc/parser_test.go | 74 ++++++ internal/udhcpc/proc.go | 212 ++++++++++++++++ internal/udhcpc/udhcpc.go | 145 +++++++++++ log.go | 3 +- main.go | 22 +- mdns.go | 60 +++++ network.go | 400 +++++++++++++++++++----------- ntp.go | 214 ---------------- timesync.go | 69 ++++++ web_tls.go | 2 +- 23 files changed, 1401 insertions(+), 381 deletions(-) create mode 100644 internal/timesync/http.go create mode 100644 internal/timesync/ntp.go create mode 100644 internal/timesync/rtc.go create mode 100644 internal/timesync/rtc_linux.go create mode 100644 internal/timesync/rtc_notlinux.go create mode 100644 internal/timesync/timesync.go create mode 100644 internal/udhcpc/options.go create mode 100644 internal/udhcpc/parser.go create mode 100644 internal/udhcpc/parser_test.go create mode 100644 internal/udhcpc/proc.go create mode 100644 internal/udhcpc/udhcpc.go create mode 100644 mdns.go delete mode 100644 ntp.go create mode 100644 timesync.go diff --git a/cloud.go b/cloud.go index 9fbf00b..f7bdb6e 100644 --- a/cloud.go +++ b/cloud.go @@ -311,11 +311,15 @@ func runWebsocketClient() error { }, }) - // get the request id from the response header - connectionId := resp.Header.Get("X-Request-ID") - if connectionId == "" { - connectionId = resp.Header.Get("Cf-Ray") + var connectionId string + if resp != nil { + // get the request id from the response header + connectionId = resp.Header.Get("X-Request-ID") + if connectionId == "" { + connectionId = resp.Header.Get("Cf-Ray") + } } + if connectionId == "" { connectionId = uuid.New().String() scopedLogger.Warn(). @@ -457,7 +461,7 @@ func RunWebsocketClient() { } // If the system time is not synchronized, the API request will fail anyway because the TLS handshake will fail. - if isTimeSyncNeeded() && !timeSyncSuccess { + if isTimeSyncNeeded() && !timeSync.IsSyncSuccess() { cloudLogger.Warn().Msg("system time is not synced, will retry in 3 seconds") time.Sleep(3 * time.Second) continue diff --git a/config.go b/config.go index cf096a7..ed7477e 100644 --- a/config.go +++ b/config.go @@ -134,7 +134,7 @@ func LoadConfig() { defer configLock.Unlock() if config != nil { - logger.Info().Msg("config already loaded, skipping") + logger.Debug().Msg("config already loaded, skipping") return } @@ -167,6 +167,8 @@ func LoadConfig() { config = &loadedConfig rootLogger.UpdateLogLevel() + + logger.Info().Str("path", configPath).Msg("config loaded") } func SaveConfig() error { diff --git a/display.go b/display.go index 38e12b1..7320cce 100644 --- a/display.go +++ b/display.go @@ -48,7 +48,7 @@ func switchToScreenIfDifferent(screenName string) { } func updateDisplay() { - updateLabelIfChanged("ui_Home_Content_Ip", networkState.IPv4) + updateLabelIfChanged("ui_Home_Content_Ip", networkState.IPv4String()) if usbState == "configured" { updateLabelIfChanged("ui_Home_Footer_Usb_Status_Label", "Connected") _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Usb_Status_Label", "state": "LV_STATE_DEFAULT"}) @@ -64,7 +64,7 @@ func updateDisplay() { _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Hdmi_Status_Label", "state": "LV_STATE_USER_2"}) } updateLabelIfChanged("ui_Home_Header_Cloud_Status_Label", fmt.Sprintf("%d active", actionSessions)) - if networkState.Up { + if networkState.IsUp() { switchToScreenIfDifferent("ui_Home_Screen") } else { switchToScreenIfDifferent("ui_No_Network_Screen") @@ -94,7 +94,7 @@ func requestDisplayUpdate() { func updateStaticContents() { //contents that never change - updateLabelIfChanged("ui_Home_Content_Mac", networkState.MAC) + updateLabelIfChanged("ui_Home_Content_Mac", networkState.MACString()) systemVersion, appVersion, err := GetLocalVersion() if err == nil { updateLabelIfChanged("ui_About_Content_Operating_System_Version_ContentLabel", systemVersion.String()) diff --git a/go.mod b/go.mod index 1311a33..bc231f2 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/coder/websocket v1.8.13 github.com/coreos/go-oidc/v3 v3.11.0 github.com/creack/pty v1.1.23 + github.com/fsnotify/fsnotify v1.9.0 github.com/gin-contrib/logger v1.2.5 github.com/gin-gonic/gin v1.10.0 github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index 565c0cc..018d3a8 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/creack/pty v1.1.23/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfv github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= github.com/gin-contrib/logger v1.2.5 h1:qVQI4omayQecuN4zX9ZZnsOq7w9J/ZLds3J/FMn8ypM= diff --git a/internal/timesync/http.go b/internal/timesync/http.go new file mode 100644 index 0000000..a6be68c --- /dev/null +++ b/internal/timesync/http.go @@ -0,0 +1,54 @@ +package timesync + +import ( + "net/http" + "time" +) + +func queryHttpTime( + url string, + timeout time.Duration, +) (now *time.Time, err error, response *http.Response) { + client := http.Client{ + Timeout: timeout, + } + resp, err := client.Head(url) + if err != nil { + return nil, err, nil + } + dateStr := resp.Header.Get("Date") + parsedTime, err := time.Parse(time.RFC1123, dateStr) + if err != nil { + return nil, err, resp + } + return &parsedTime, nil, resp +} + +func (t *TimeSync) queryAllHttpTime() (now *time.Time) { + for _, url := range t.httpUrls { + now, err, response := queryHttpTime(url, timeSyncTimeout) + + var status string + if response != nil { + status = response.Status + } + + scopedLogger := t.l.With(). + Str("http_url", url). + Str("status", status). + Logger() + + if err == nil { + scopedLogger.Info(). + Str("time", now.Format(time.RFC3339)). + Msg("HTTP server returned time") + return now + } else { + scopedLogger.Error(). + Str("error", err.Error()). + Msg("failed to query HTTP server") + } + } + + return nil +} diff --git a/internal/timesync/ntp.go b/internal/timesync/ntp.go new file mode 100644 index 0000000..9bc9812 --- /dev/null +++ b/internal/timesync/ntp.go @@ -0,0 +1,42 @@ +package timesync + +import ( + "time" + + "github.com/beevik/ntp" +) + +func (t *TimeSync) queryNetworkTime() (now *time.Time) { + for _, server := range t.ntpServers { + now, err, response := queryNtpServer(server, timeSyncTimeout) + + scopedLogger := t.l.With(). + Str("server", server). + Logger() + + if err == nil { + scopedLogger.Info(). + Str("time", now.Format(time.RFC3339)). + Str("reference", response.ReferenceString()). + Str("rtt", response.RTT.String()). + Str("clockOffset", response.ClockOffset.String()). + Uint8("stratum", response.Stratum). + Msg("NTP server returned time") + return now + } else { + scopedLogger.Error(). + Str("error", err.Error()). + Msg("failed to query NTP server") + } + } + + return nil +} + +func queryNtpServer(server string, timeout time.Duration) (now *time.Time, err error, response *ntp.Response) { + resp, err := ntp.QueryWithOptions(server, ntp.QueryOptions{Timeout: timeout}) + if err != nil { + return nil, err, nil + } + return &resp.Time, nil, resp +} diff --git a/internal/timesync/rtc.go b/internal/timesync/rtc.go new file mode 100644 index 0000000..92ee485 --- /dev/null +++ b/internal/timesync/rtc.go @@ -0,0 +1,26 @@ +package timesync + +import ( + "fmt" + "os" +) + +var ( + rtcDeviceSearchPaths = []string{ + "/dev/rtc", + "/dev/rtc0", + "/dev/rtc1", + "/dev/misc/rtc", + "/dev/misc/rtc0", + "/dev/misc/rtc1", + } +) + +func getRtcDevicePath() (string, error) { + for _, path := range rtcDeviceSearchPaths { + if _, err := os.Stat(path); err == nil { + return path, nil + } + } + return "", fmt.Errorf("rtc device not found") +} diff --git a/internal/timesync/rtc_linux.go b/internal/timesync/rtc_linux.go new file mode 100644 index 0000000..dccfab2 --- /dev/null +++ b/internal/timesync/rtc_linux.go @@ -0,0 +1,103 @@ +//go:build linux + +package timesync + +import ( + "fmt" + "os" + "time" + + "golang.org/x/sys/unix" +) + +func TimetoRtcTime(t time.Time) unix.RTCTime { + return unix.RTCTime{ + Sec: int32(t.Second()), + Min: int32(t.Minute()), + Hour: int32(t.Hour()), + Mday: int32(t.Day()), + Mon: int32(t.Month() - 1), + Year: int32(t.Year() - 1900), + Wday: int32(0), + Yday: int32(0), + Isdst: int32(0), + } +} + +func RtcTimetoTime(t unix.RTCTime) time.Time { + return time.Date( + int(t.Year)+1900, + time.Month(t.Mon+1), + int(t.Mday), + int(t.Hour), + int(t.Min), + int(t.Sec), + 0, + time.UTC, + ) +} + +func (t *TimeSync) getRtcDevice() (*os.File, error) { + if t.rtcDevice == nil { + file, err := os.OpenFile(t.rtcDevicePath, os.O_RDWR, 0666) + if err != nil { + return nil, err + } + t.rtcDevice = file + } + return t.rtcDevice, nil +} + +func (t *TimeSync) getRtcDeviceFd() (int, error) { + device, err := t.getRtcDevice() + if err != nil { + return 0, err + } + return int(device.Fd()), nil +} + +// Read implements Read for the Linux RTC +func (t *TimeSync) readRtcTime() (time.Time, error) { + fd, err := t.getRtcDeviceFd() + if err != nil { + return time.Time{}, fmt.Errorf("failed to get RTC device fd: %w", err) + } + + rtcTime, err := unix.IoctlGetRTCTime(fd) + if err != nil { + return time.Time{}, fmt.Errorf("failed to get RTC time: %w", err) + } + + date := RtcTimetoTime(*rtcTime) + + return date, nil +} + +// Set implements Set for the Linux RTC +// ... +// It might be not accurate as the time consumed by the system call is not taken into account +// but it's good enough for our purposes +func (t *TimeSync) setRtcTime(tu time.Time) error { + rt := TimetoRtcTime(tu) + + fd, err := t.getRtcDeviceFd() + if err != nil { + return fmt.Errorf("failed to get RTC device fd: %w", err) + } + + currentRtcTime, err := t.readRtcTime() + if err != nil { + return fmt.Errorf("failed to read RTC time: %w", err) + } + + t.l.Info(). + Interface("rtc_time", tu). + Str("offset", tu.Sub(currentRtcTime).String()). + Msg("set rtc time") + + if err := unix.IoctlSetRTCTime(fd, &rt); err != nil { + return fmt.Errorf("failed to set RTC time: %w", err) + } + + return nil +} diff --git a/internal/timesync/rtc_notlinux.go b/internal/timesync/rtc_notlinux.go new file mode 100644 index 0000000..e3c1b20 --- /dev/null +++ b/internal/timesync/rtc_notlinux.go @@ -0,0 +1,16 @@ +//go:build !linux + +package timesync + +import ( + "errors" + "time" +) + +func (t *TimeSync) readRtcTime() (time.Time, error) { + return time.Now(), nil +} + +func (t *TimeSync) setRtcTime(tu time.Time) error { + return errors.New("not supported") +} diff --git a/internal/timesync/timesync.go b/internal/timesync/timesync.go new file mode 100644 index 0000000..eac4749 --- /dev/null +++ b/internal/timesync/timesync.go @@ -0,0 +1,151 @@ +package timesync + +import ( + "fmt" + "os" + "os/exec" + "sync" + "time" + + "github.com/rs/zerolog" +) + +const ( + timeSyncRetryStep = 5 * time.Second + timeSyncRetryMaxInt = 1 * time.Minute + timeSyncWaitNetChkInt = 100 * time.Millisecond + timeSyncWaitNetUpInt = 3 * time.Second + timeSyncInterval = 1 * time.Hour + timeSyncTimeout = 2 * time.Second +) + +var ( + timeSyncRetryInterval = 0 * time.Second + defaultNTPServers = []string{ + "time.cloudflare.com", + "time.apple.com", + } +) + +type TimeSync struct { + syncLock *sync.Mutex + l *zerolog.Logger + + ntpServers []string + httpUrls []string + + rtcDevicePath string + rtcDevice *os.File + rtcLock *sync.Mutex + + syncSuccess bool + + preCheckFunc func() (bool, error) +} + +func NewTimeSync( + precheckFunc func() (bool, error), + ntpServers []string, + httpUrls []string, + logger *zerolog.Logger, +) *TimeSync { + rtcDevice, err := getRtcDevicePath() + if err != nil { + logger.Error().Err(err).Msg("failed to get RTC device path") + } else { + logger.Info().Str("path", rtcDevice).Msg("RTC device found") + } + + t := &TimeSync{ + syncLock: &sync.Mutex{}, + l: logger, + rtcDevicePath: rtcDevice, + rtcLock: &sync.Mutex{}, + preCheckFunc: precheckFunc, + ntpServers: ntpServers, + httpUrls: httpUrls, + } + + if t.rtcDevicePath != "" { + rtcTime, _ := t.readRtcTime() + t.l.Info().Interface("rtc_time", rtcTime).Msg("read RTC time") + } + + return t +} + +func (t *TimeSync) doTimeSync() { + for { + if ok, err := t.preCheckFunc(); !ok { + if err != nil { + t.l.Error().Err(err).Msg("pre-check failed") + } + time.Sleep(timeSyncWaitNetChkInt) + continue + } + + t.l.Info().Msg("syncing system time") + start := time.Now() + err := t.Sync() + if err != nil { + t.l.Error().Str("error", err.Error()).Msg("failed to sync system time") + + // retry after a delay + timeSyncRetryInterval += timeSyncRetryStep + time.Sleep(timeSyncRetryInterval) + // reset the retry interval if it exceeds the max interval + if timeSyncRetryInterval > timeSyncRetryMaxInt { + timeSyncRetryInterval = 0 + } + + continue + } + t.syncSuccess = true + t.l.Info().Str("now", time.Now().Format(time.RFC3339)). + Str("time_taken", time.Since(start).String()). + Msg("time sync successful") + + time.Sleep(timeSyncInterval) // after the first sync is done + } +} + +func (t *TimeSync) Sync() error { + var now *time.Time + now = t.queryNetworkTime() + if now == nil { + now = t.queryAllHttpTime() + } + + if now == nil { + return fmt.Errorf("failed to get time from any source") + } + + err := t.setSystemTime(*now) + if err != nil { + return fmt.Errorf("failed to set system time: %w", err) + } + + return nil +} + +func (t *TimeSync) IsSyncSuccess() bool { + return t.syncSuccess +} + +func (t *TimeSync) Start() { + go t.doTimeSync() +} + +func (t *TimeSync) setSystemTime(now time.Time) error { + nowStr := now.Format("2006-01-02 15:04:05") + output, err := exec.Command("date", "-s", nowStr).CombinedOutput() + if err != nil { + return fmt.Errorf("failed to run date -s: %w, %s", err, string(output)) + } + + if t.rtcDevicePath != "" { + return t.setRtcTime(now) + } + + return nil +} diff --git a/internal/udhcpc/options.go b/internal/udhcpc/options.go new file mode 100644 index 0000000..10c9f75 --- /dev/null +++ b/internal/udhcpc/options.go @@ -0,0 +1,12 @@ +package udhcpc + +func (u *DHCPClient) GetNtpServers() []string { + if u.lease == nil { + return nil + } + servers := make([]string, len(u.lease.NTPServers)) + for i, server := range u.lease.NTPServers { + servers[i] = server.String() + } + return servers +} diff --git a/internal/udhcpc/parser.go b/internal/udhcpc/parser.go new file mode 100644 index 0000000..ee529aa --- /dev/null +++ b/internal/udhcpc/parser.go @@ -0,0 +1,150 @@ +package udhcpc + +import ( + "encoding/json" + "fmt" + "log" + "net" + "reflect" + "strconv" + "strings" + "time" +) + +type Lease struct { + // from https://udhcp.busybox.net/README.udhcpc + IPAddress net.IP `env:"ip" json:"ip"` // The obtained IP + Netmask net.IP `env:"subnet" json:"netmask"` // The assigned subnet mask + Broadcast net.IP `env:"broadcast" json:"broadcast"` // The broadcast address for this network + TTL int `env:"ipttl" json:"ttl,omitempty"` // The TTL to use for this network + MTU int `env:"mtu" json:"mtu,omitempty"` // The MTU to use for this network + HostName string `env:"hostname" json:"hostname,omitempty"` // The assigned hostname + Domain string `env:"domain" json:"domain,omitempty"` // The domain name of the network + BootPNextServer net.IP `env:"siaddr" json:"bootp_next_server,omitempty"` // The bootp next server option + BootPServerName string `env:"sname" json:"bootp_server_name,omitempty"` // The bootp server name option + BootPFile string `env:"boot_file" json:"bootp_file,omitempty"` // The bootp boot file option + Timezone string `env:"timezone" json:"timezone,omitempty"` // Offset in seconds from UTC + Routers []net.IP `env:"router" json:"routers,omitempty"` // A list of routers + DNS []net.IP `env:"dns" json:"dns_servers,omitempty"` // A list of DNS servers + NTPServers []net.IP `env:"ntpsrv" json:"ntp_servers,omitempty"` // A list of NTP servers + LPRServers []net.IP `env:"lprsvr" json:"lpr_servers,omitempty"` // A list of LPR servers + TimeServers []net.IP `env:"timesvr" json:"_time_servers,omitempty"` // A list of time servers (obsolete) + IEN116NameServers []net.IP `env:"namesvr" json:"_name_servers,omitempty"` // A list of IEN 116 name servers (obsolete) + LogServers []net.IP `env:"logsvr" json:"_log_servers,omitempty"` // A list of MIT-LCS UDP log servers (obsolete) + CookieServers []net.IP `env:"cookiesvr" json:"_cookie_servers,omitempty"` // A list of RFC 865 cookie servers (obsolete) + WINSServers []net.IP `env:"wins" json:"_wins_servers,omitempty"` // A list of WINS servers + SwapServer net.IP `env:"swapsvr" json:"_swap_server,omitempty"` // The IP address of the client's swap server + BootSize int `env:"bootsize" json:"bootsize,omitempty"` // The length in 512 octect blocks of the bootfile + RootPath string `env:"rootpath" json:"root_path,omitempty"` // The path name of the client's root disk + LeaseTime time.Duration `env:"lease" json:"lease,omitempty"` // The lease time, in seconds + DHCPType string `env:"dhcptype" json:"dhcp_type,omitempty"` // DHCP message type (safely ignored) + ServerID string `env:"serverid" json:"server_id,omitempty"` // The IP of the server + Message string `env:"message" json:"reason,omitempty"` // Reason for a DHCPNAK + TFTPServerName string `env:"tftp" json:"tftp,omitempty"` // The TFTP server name + BootFileName string `env:"bootfile" json:"bootfile,omitempty"` // The boot file name + isEmpty map[string]bool +} + +func (l *Lease) setIsEmpty(m map[string]bool) { + l.isEmpty = m +} + +func (l *Lease) IsEmpty(key string) bool { + return l.isEmpty[key] +} + +func (l *Lease) ToJSON() string { + json, err := json.Marshal(l) + if err != nil { + return "" + } + return string(json) +} + +func UnmarshalDHCPCLease(lease *Lease, str string) error { + // parse the lease file as a map + data := make(map[string]string) + for _, line := range strings.Split(str, "\n") { + line = strings.TrimSpace(line) + // skip empty lines and comments + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + log.Printf("invalid line: %s", line) + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + data[key] = value + } + + // now iterate over the lease struct and set the values + leaseType := reflect.TypeOf(lease).Elem() + leaseValue := reflect.ValueOf(lease).Elem() + + valuesParsed := make(map[string]bool) + + for i := 0; i < leaseType.NumField(); i++ { + field := leaseValue.Field(i) + + // get the env tag + key := leaseType.Field(i).Tag.Get("env") + if key == "" { + continue + } + + valuesParsed[key] = false + + // get the value from the data map + value, ok := data[key] + if !ok || value == "" { + continue + } + + switch field.Interface().(type) { + case string: + field.SetString(value) + case int: + val, err := strconv.Atoi(value) + if err != nil { + continue + } + field.SetInt(int64(val)) + case time.Duration: + val, err := time.ParseDuration(value + "s") + if err != nil { + continue + } + field.Set(reflect.ValueOf(val)) + case net.IP: + ip := net.ParseIP(value) + if ip == nil { + continue + } + field.Set(reflect.ValueOf(ip)) + case []net.IP: + val := make([]net.IP, 0) + for _, ipStr := range strings.Fields(value) { + ip := net.ParseIP(ipStr) + if ip == nil { + continue + } + val = append(val, ip) + } + field.Set(reflect.ValueOf(val)) + default: + return fmt.Errorf("unsupported field `%s` type: %s", key, field.Type().String()) + } + + valuesParsed[key] = true + } + + lease.setIsEmpty(valuesParsed) + + return nil +} diff --git a/internal/udhcpc/parser_test.go b/internal/udhcpc/parser_test.go new file mode 100644 index 0000000..423ab53 --- /dev/null +++ b/internal/udhcpc/parser_test.go @@ -0,0 +1,74 @@ +package udhcpc + +import ( + "testing" + "time" +) + +func TestUnmarshalDHCPCLease(t *testing.T) { + lease := &Lease{} + err := UnmarshalDHCPCLease(lease, ` +# generated @ Mon Jan 4 19:31:53 UTC 2021 +# 19:31:53 up 0 min, 0 users, load average: 0.72, 0.14, 0.04 +# the date might be inaccurate if the clock is not set +ip=192.168.0.240 +siaddr=192.168.0.1 +sname= +boot_file= +subnet=255.255.255.0 +timezone= +router=192.168.0.1 +timesvr= +namesvr= +dns=172.19.53.2 +logsvr= +cookiesvr= +lprsvr= +hostname= +bootsize= +domain= +swapsvr= +rootpath= +ipttl= +mtu= +broadcast= +ntpsrv=162.159.200.123 +wins= +lease=172800 +dhcptype= +serverid=192.168.0.1 +message= +tftp= +bootfile= + `) + if lease.IPAddress.String() != "192.168.0.240" { + t.Fatalf("expected ip to be 192.168.0.240, got %s", lease.IPAddress.String()) + } + if lease.Netmask.String() != "255.255.255.0" { + t.Fatalf("expected netmask to be 255.255.255.0, got %s", lease.Netmask.String()) + } + if len(lease.Routers) != 1 { + t.Fatalf("expected 1 router, got %d", len(lease.Routers)) + } + if lease.Routers[0].String() != "192.168.0.1" { + t.Fatalf("expected router to be 192.168.0.1, got %s", lease.Routers[0].String()) + } + if len(lease.NTPServers) != 1 { + t.Fatalf("expected 1 timeserver, got %d", len(lease.NTPServers)) + } + if lease.NTPServers[0].String() != "162.159.200.123" { + t.Fatalf("expected timeserver to be 162.159.200.123, got %s", lease.NTPServers[0].String()) + } + if len(lease.DNS) != 1 { + t.Fatalf("expected 1 dns, got %d", len(lease.DNS)) + } + if lease.DNS[0].String() != "172.19.53.2" { + t.Fatalf("expected dns to be 172.19.53.2, got %s", lease.DNS[0].String()) + } + if lease.LeaseTime != 172800*time.Second { + t.Fatalf("expected lease time to be 172800 seconds, got %d", lease.LeaseTime) + } + if err != nil { + t.Fatal(err) + } +} diff --git a/internal/udhcpc/proc.go b/internal/udhcpc/proc.go new file mode 100644 index 0000000..69c2ab9 --- /dev/null +++ b/internal/udhcpc/proc.go @@ -0,0 +1,212 @@ +package udhcpc + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" + "syscall" +) + +func readFileNoStat(filename string) ([]byte, error) { + const maxBufferSize = 1024 * 1024 + + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + + reader := io.LimitReader(f, maxBufferSize) + return io.ReadAll(reader) +} + +func toCmdline(path string) ([]string, error) { + data, err := readFileNoStat(path) + if err != nil { + return nil, err + } + + if len(data) < 1 { + return []string{}, nil + } + + return strings.Split(string(bytes.TrimRight(data, "\x00")), "\x00"), nil +} + +func (p *DHCPClient) findUdhcpcProcess() (int, error) { + // read procfs for udhcpc processes + // we do not use procfs.AllProcs() because we want to avoid the overhead of reading the entire procfs + processes, err := os.ReadDir("/proc") + if err != nil { + return 0, err + } + + // iterate over the processes + for _, d := range processes { + // check if file is numeric + pid, err := strconv.Atoi(d.Name()) + if err != nil { + continue + } + + // check if it's a directory + if !d.IsDir() { + continue + } + + cmdline, err := toCmdline(filepath.Join("/proc", d.Name(), "cmdline")) + if err != nil { + continue + } + + if len(cmdline) < 1 { + continue + } + + if cmdline[0] != "udhcpc" { + continue + } + + cmdlineText := strings.Join(cmdline, " ") + + // check if it's a udhcpc process + if strings.Contains(cmdlineText, fmt.Sprintf("-i %s", p.InterfaceName)) { + p.logger.Debug(). + Str("pid", d.Name()). + Interface("cmdline", cmdline). + Msg("found udhcpc process") + return pid, nil + } + } + + return 0, errors.New("udhcpc process not found") +} + +func (c *DHCPClient) getProcessPid() (int, error) { + var pid int + if c.pidFile != "" { + // try to read the pid file + pidHandle, err := os.ReadFile(c.pidFile) + if err != nil { + c.logger.Warn().Err(err). + Str("pidFile", c.pidFile).Msg("failed to read udhcpc pid file") + } + + // if it exists, try to read the pid + if pidHandle != nil { + pidFromFile, err := strconv.Atoi(string(pidHandle)) + if err != nil { + c.logger.Warn().Err(err). + Str("pidFile", c.pidFile).Msg("failed to convert pid file to int") + } + pid = pidFromFile + } + } + + // if the pid is 0, try to find the pid using procfs + if pid == 0 { + newPid, err := c.findUdhcpcProcess() + if err != nil { + return 0, err + } + pid = newPid + } + + return pid, nil +} + +func (c *DHCPClient) getProcess() *os.Process { + pid, err := c.getProcessPid() + if err != nil { + return nil + } + + process, err := os.FindProcess(pid) + if err != nil { + c.logger.Warn().Err(err). + Int("pid", pid).Msg("failed to find process") + return nil + } + + return process +} + +func (c *DHCPClient) GetProcess() *os.Process { + if c.process == nil { + process := c.getProcess() + if process == nil { + return nil + } + c.process = process + } + + err := c.process.Signal(syscall.Signal(0)) + if err != nil && errors.Is(err, os.ErrProcessDone) { + oldPid := c.process.Pid + + c.process = nil + c.process = c.getProcess() + if c.process == nil { + c.logger.Error().Msg("failed to find new udhcpc process") + return nil + } + c.logger.Warn(). + Int("oldPid", oldPid). + Int("newPid", c.process.Pid). + Msg("udhcpc process pid changed") + } else if err != nil { + c.logger.Warn().Err(err). + Int("pid", c.process.Pid).Msg("udhcpc process is not running") + } + + return c.process +} + +func (c *DHCPClient) KillProcess() error { + process := c.GetProcess() + if process == nil { + return nil + } + + return process.Kill() +} + +func (c *DHCPClient) ReleaseProcess() error { + process := c.GetProcess() + if process == nil { + return nil + } + + return process.Release() +} + +func (c *DHCPClient) signalProcess(sig syscall.Signal) error { + process := c.GetProcess() + if process == nil { + return nil + } + + s := process.Signal(sig) + if s != nil { + c.logger.Warn().Err(s). + Int("pid", process.Pid). + Str("signal", sig.String()). + Msg("failed to signal udhcpc process") + return s + } + + return nil +} + +func (c *DHCPClient) Renew() error { + return c.signalProcess(syscall.SIGUSR1) +} + +func (c *DHCPClient) Release() error { + return c.signalProcess(syscall.SIGUSR2) +} diff --git a/internal/udhcpc/udhcpc.go b/internal/udhcpc/udhcpc.go new file mode 100644 index 0000000..1459b40 --- /dev/null +++ b/internal/udhcpc/udhcpc.go @@ -0,0 +1,145 @@ +package udhcpc + +import ( + "errors" + "fmt" + "os" + + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog" +) + +const ( + DHCPLeaseFile = "/run/udhcpc.%s.info" + DHCPPidFile = "/run/udhcpc.%s.pid" +) + +type DHCPClient struct { + InterfaceName string + leaseFile string + pidFile string + lease *Lease + logger *zerolog.Logger + process *os.Process + onLeaseChange func(lease *Lease) +} + +type DHCPClientOptions struct { + InterfaceName string + PidFile string + Logger *zerolog.Logger + OnLeaseChange func(lease *Lease) +} + +var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel) + +func NewDHCPClient(options *DHCPClientOptions) *DHCPClient { + if options.Logger == nil { + options.Logger = &defaultLogger + } + + l := options.Logger.With().Str("interface", options.InterfaceName).Logger() + return &DHCPClient{ + InterfaceName: options.InterfaceName, + logger: &l, + leaseFile: fmt.Sprintf(DHCPLeaseFile, options.InterfaceName), + pidFile: options.PidFile, + onLeaseChange: options.OnLeaseChange, + } +} + +// Run starts the DHCP client and watches the lease file for changes. +// this isn't a blocking call, and the lease file is reloaded when a change is detected. +func (c *DHCPClient) Run() error { + err := c.loadLeaseFile() + if err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + defer watcher.Close() + + go func() { + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { + c.logger.Debug(). + Str("event", event.Name). + Msg("udhcpc lease file updated, reloading lease") + c.loadLeaseFile() + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + c.logger.Error().Err(err).Msg("error watching lease file") + } + } + }() + + watcher.Add(c.leaseFile) + + // TODO: update udhcpc pid file + // we'll comment this out for now because the pid might change + // process := c.GetProcess() + // if process == nil { + // c.logger.Error().Msg("udhcpc process not found") + // } + + // block the goroutine until the lease file is updated + <-make(chan struct{}) + + return nil +} + +func (c *DHCPClient) loadLeaseFile() error { + file, err := os.ReadFile(c.leaseFile) + if err != nil { + return err + } + + data := string(file) + if data == "" { + c.logger.Debug().Msg("udhcpc lease file is empty") + return nil + } + + lease := &Lease{} + err = UnmarshalDHCPCLease(lease, string(file)) + if err != nil { + return err + } + + isFirstLoad := c.lease == nil + c.lease = lease + + if lease.IPAddress == nil { + c.logger.Info(). + Interface("lease", lease). + Str("data", string(file)). + Msg("udhcpc lease cleared") + return nil + } + + msg := "udhcpc lease updated" + if isFirstLoad { + msg = "udhcpc lease loaded" + } + + c.onLeaseChange(lease) + + c.logger.Info(). + Str("ip", lease.IPAddress.String()). + Str("leaseTime", lease.LeaseTime.String()). + Interface("data", lease). + Msg(msg) + + return nil +} diff --git a/log.go b/log.go index ed46852..8bc24d9 100644 --- a/log.go +++ b/log.go @@ -218,12 +218,13 @@ func ErrorfL(l *zerolog.Logger, format string, err error, args ...interface{}) e var ( logger = rootLogger.getLogger("jetkvm") + networkLogger = rootLogger.getLogger("network") cloudLogger = rootLogger.getLogger("cloud") websocketLogger = rootLogger.getLogger("websocket") webrtcLogger = rootLogger.getLogger("webrtc") nativeLogger = rootLogger.getLogger("native") nbdLogger = rootLogger.getLogger("nbd") - ntpLogger = rootLogger.getLogger("ntp") + timesyncLogger = rootLogger.getLogger("timesync") jsonRpcLogger = rootLogger.getLogger("jsonrpc") watchdogLogger = rootLogger.getLogger("watchdog") websecureLogger = rootLogger.getLogger("websecure") diff --git a/main.go b/main.go index 9eab708..73a4702 100644 --- a/main.go +++ b/main.go @@ -15,26 +15,38 @@ var appCtx context.Context func Main() { LoadConfig() - logger.Debug().Msg("config loaded") var cancel context.CancelFunc appCtx, cancel = context.WithCancel(context.Background()) defer cancel() - logger.Info().Msg("starting JetKvm") + + systemVersionLocal, appVersionLocal, err := GetLocalVersion() + if err != nil { + logger.Warn().Err(err).Msg("failed to get local version") + } + + logger.Info(). + Interface("system_version", systemVersionLocal). + Interface("app_version", appVersionLocal). + Msg("starting JetKVM") go runWatchdog() go confirmCurrentSystem() http.DefaultClient.Timeout = 1 * time.Minute - err := rootcerts.UpdateDefaultTransport() + err = rootcerts.UpdateDefaultTransport() if err != nil { - logger.Warn().Err(err).Msg("failed to load CA certs") + logger.Warn().Err(err).Msg("failed to load Root CA certificates") } + logger.Info(). + Int("ca_certs_loaded", len(rootcerts.Certs())). + Msg("loaded Root CA certificates") initNetwork() + initTimeSync() - go TimeSyncLoop() + timeSync.Start() StartNativeCtrlSocketServer() StartNativeVideoSocketServer() diff --git a/mdns.go b/mdns.go new file mode 100644 index 0000000..309709e --- /dev/null +++ b/mdns.go @@ -0,0 +1,60 @@ +package kvm + +import ( + "net" + + "github.com/pion/mdns/v2" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var mDNSConn *mdns.Conn + +func startMDNS() error { + // If server was previously running, stop it + if mDNSConn != nil { + logger.Info().Msg("stopping mDNS server") + err := mDNSConn.Close() + if err != nil { + logger.Warn().Err(err).Msg("failed to stop mDNS server") + } + } + + // Start a new server + hostname := "jetkvm.local" + + scopedLogger := logger.With().Str("hostname", hostname).Logger() + scopedLogger.Info().Msg("starting mDNS server") + + addr4, err := net.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) + if err != nil { + return err + } + + addr6, err := net.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6) + if err != nil { + return err + } + + l4, err := net.ListenUDP("udp4", addr4) + if err != nil { + return err + } + + l6, err := net.ListenUDP("udp6", addr6) + if err != nil { + return err + } + + mDNSConn, err = mdns.Server(ipv4.NewPacketConn(l4), ipv6.NewPacketConn(l6), &mdns.Config{ + LocalNames: []string{hostname}, //TODO: make it configurable + LoggerFactory: defaultLoggerFactory, + }) + if err != nil { + scopedLogger.Warn().Err(err).Msg("failed to start mDNS server") + mDNSConn = nil + return err + } + //defer server.Close() + return nil +} diff --git a/network.go b/network.go index e524e72..14ffe7d 100644 --- a/network.go +++ b/network.go @@ -1,214 +1,308 @@ package kvm import ( - "bytes" "fmt" "net" "os" - "strings" + "sync" "time" - "os/exec" - - "github.com/hashicorp/go-envparse" - "github.com/pion/mdns/v2" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" + "github.com/Masterminds/semver/v3" + "github.com/jetkvm/kvm/internal/udhcpc" + "github.com/rs/zerolog" "github.com/vishvananda/netlink" "github.com/vishvananda/netlink/nl" ) -var mDNSConn *mdns.Conn +var ( + networkState *NetworkInterfaceState +) -var networkState NetworkState +type DhcpTargetState int -type NetworkState struct { - Up bool - IPv4 string - IPv6 string - MAC string +const ( + DhcpTargetStateDoNothing DhcpTargetState = iota + DhcpTargetStateStart + DhcpTargetStateStop + DhcpTargetStateRenew + DhcpTargetStateRelease +) + +type NetworkInterfaceState struct { + interfaceName string + interfaceUp bool + ipv4Addr *net.IP + ipv6Addr *net.IP + macAddr *net.HardwareAddr + + l *zerolog.Logger + stateLock sync.Mutex + + dhcpClient *udhcpc.DHCPClient + + onStateChange func(state *NetworkInterfaceState) + onInitialCheck func(state *NetworkInterfaceState) checked bool } -func (s *NetworkState) IsUp() bool { - return s.Up && s.IPv4 != "" && s.IPv6 != "" +func (s *NetworkInterfaceState) IsUp() bool { + return s.interfaceUp } -func (s *NetworkState) HasIPAssigned() bool { - return s.IPv4 != "" || s.IPv6 != "" +func (s *NetworkInterfaceState) HasIPAssigned() bool { + return s.ipv4Addr != nil || s.ipv6Addr != nil } -func (s *NetworkState) IsOnline() bool { - return s.Up && s.HasIPAssigned() +func (s *NetworkInterfaceState) IsOnline() bool { + return s.IsUp() && s.HasIPAssigned() } -type LocalIpInfo struct { - IPv4 string - IPv6 string - MAC string +func (s *NetworkInterfaceState) IPv4() *net.IP { + return s.ipv4Addr +} + +func (s *NetworkInterfaceState) IPv4String() string { + if s.ipv4Addr == nil { + return "..." + } + return s.ipv4Addr.String() +} + +func (s *NetworkInterfaceState) IPv6() *net.IP { + return s.ipv6Addr +} + +func (s *NetworkInterfaceState) IPv6String() string { + if s.ipv6Addr == nil { + return "..." + } + return s.ipv6Addr.String() +} + +func (s *NetworkInterfaceState) MAC() *net.HardwareAddr { + return s.macAddr +} + +func (s *NetworkInterfaceState) MACString() string { + if s.macAddr == nil { + return "" + } + return s.macAddr.String() } const ( - NetIfName = "eth0" - DHCPLeaseFile = "/run/udhcpc.%s.info" + // TODO: add support for multiple interfaces + NetIfName = "eth0" ) -// setDhcpClientState sends signals to udhcpc to change it's current mode -// of operation. Setting active to true will force udhcpc to renew the DHCP lease. -// Setting active to false will put udhcpc into idle mode. -func setDhcpClientState(active bool) { - var signal string - if active { - signal = "-SIGUSR1" - } else { - signal = "-SIGUSR2" +func NewNetworkInterfaceState(ifname string) *NetworkInterfaceState { + logger := networkLogger.With().Str("interface", ifname).Logger() + + s := &NetworkInterfaceState{ + interfaceName: ifname, + stateLock: sync.Mutex{}, + l: &logger, + onStateChange: func(state *NetworkInterfaceState) { + go func() { + waitCtrlClientConnected() + requestDisplayUpdate() + }() + }, + onInitialCheck: func(state *NetworkInterfaceState) { + go func() { + waitCtrlClientConnected() + requestDisplayUpdate() + }() + }, } - cmd := exec.Command("/usr/bin/killall", signal, "udhcpc") - if err := cmd.Run(); err != nil { - logger.Warn().Err(err).Msg("network: setDhcpClientState: failed to change udhcpc state") + // use a pid file for udhcpc if the system version is 0.2.4 or higher + dhcpPidFile := "" + systemVersionLocal, _, _ := GetLocalVersion() + if systemVersionLocal != nil && + systemVersionLocal.Compare(semver.MustParse("0.2.4")) >= 0 { + dhcpPidFile = fmt.Sprintf("/run/udhcpc.%s.pid", ifname) } + + // create the dhcp client + dhcpClient := udhcpc.NewDHCPClient(&udhcpc.DHCPClientOptions{ + InterfaceName: ifname, + PidFile: dhcpPidFile, + Logger: &logger, + OnLeaseChange: func(lease *udhcpc.Lease) { + s.update() + }, + }) + + s.dhcpClient = dhcpClient + + return s } -func checkNetworkState() { - iface, err := netlink.LinkByName(NetIfName) +func (s *NetworkInterfaceState) update() (DhcpTargetState, error) { + s.stateLock.Lock() + defer s.stateLock.Unlock() + + dhcpTargetState := DhcpTargetStateDoNothing + + iface, err := netlink.LinkByName(s.interfaceName) if err != nil { - logger.Warn().Err(err).Str("interface", NetIfName).Msg("failed to get interface") - return + s.l.Error().Err(err).Msg("failed to get interface") + return dhcpTargetState, err } - newState := NetworkState{ - Up: iface.Attrs().OperState == netlink.OperUp, - MAC: iface.Attrs().HardwareAddr.String(), + // detect if the interface status changed + var changed bool + attrs := iface.Attrs() + state := attrs.OperState + newInterfaceUp := state == netlink.OperUp - checked: true, + // check if the interface is coming up + interfaceGoingUp := s.interfaceUp == false && newInterfaceUp == true + interfaceGoingDown := s.interfaceUp == true && newInterfaceUp == false + + if s.interfaceUp != newInterfaceUp { + s.interfaceUp = newInterfaceUp + changed = true } + if changed { + if interfaceGoingUp { + s.l.Info().Msg("interface state transitioned to up") + dhcpTargetState = DhcpTargetStateRenew + } else if interfaceGoingDown { + s.l.Info().Msg("interface state transitioned to down") + } + } + + // set the mac address + s.macAddr = &attrs.HardwareAddr + + // get the ip addresses addrs, err := netlink.AddrList(iface, nl.FAMILY_ALL) if err != nil { - logger.Warn().Err(err).Str("interface", NetIfName).Msg("failed to get addresses") + s.l.Error().Err(err).Msg("failed to get ip addresses") + return dhcpTargetState, err } - // If the link is going down, put udhcpc into idle mode. - // If the link is coming back up, activate udhcpc and force it to renew the lease. - if newState.Up != networkState.Up { - setDhcpClientState(newState.Up) - } + var ( + ipv4Addresses = make([]net.IP, 0) + ipv6Addresses = make([]net.IP, 0) + ) for _, addr := range addrs { if addr.IP.To4() != nil { - if !newState.Up && networkState.Up { - // If the network is going down, remove all IPv4 addresses from the interface. - logger.Info().Str("address", addr.IP.String()).Msg("network: state transitioned to down, removing IPv4 address") + scopedLogger := s.l.With().Str("ipv4", addr.IP.String()).Logger() + if interfaceGoingDown { + // remove all IPv4 addresses from the interface. + scopedLogger.Info().Msg("state transitioned to down, removing IPv4 address") err := netlink.AddrDel(iface, &addr) if err != nil { - logger.Warn().Err(err).Str("address", addr.IP.String()).Msg("network: failed to delete address") + scopedLogger.Warn().Err(err).Msg("failed to delete address") } - - newState.IPv4 = "..." - } else { - newState.IPv4 = addr.IP.String() + // notify the DHCP client to release the lease + dhcpTargetState = DhcpTargetStateRelease + continue } - } else if addr.IP.To16() != nil && newState.IPv6 == "" { - newState.IPv6 = addr.IP.String() + ipv4Addresses = append(ipv4Addresses, addr.IP) + } else if addr.IP.To16() != nil { + scopedLogger := s.l.With().Str("ipv6", addr.IP.String()).Logger() + // check if it's a link local address + if !addr.IP.IsGlobalUnicast() { + scopedLogger.Trace().Msg("not a global unicast address, skipping") + continue + } + + if interfaceGoingDown { + scopedLogger.Info().Msg("state transitioned to down, removing IPv6 address") + err := netlink.AddrDel(iface, &addr) + if err != nil { + scopedLogger.Warn().Err(err).Msg("failed to delete address") + } + continue + } + ipv6Addresses = append(ipv6Addresses, addr.IP) } } - if newState != networkState { - logger.Info(). - Interface("newState", newState). - Interface("oldState", networkState). - Msg("network state changed") - - // restart MDNS - _ = startMDNS() - networkState = newState - requestDisplayUpdate() + if len(ipv4Addresses) > 0 { + // compare the addresses to see if there's a change + if s.ipv4Addr == nil || s.ipv4Addr.String() != ipv4Addresses[0].String() { + scopedLogger := s.l.With().Str("ipv4", ipv4Addresses[0].String()).Logger() + if s.ipv4Addr != nil { + scopedLogger.Info(). + Str("old_ipv4", s.ipv4Addr.String()). + Msg("IPv4 address changed") + changed = true + } else { + scopedLogger.Info().Msg("IPv4 address found") + } + s.ipv4Addr = &ipv4Addresses[0] + changed = true + } } + + if len(ipv6Addresses) > 0 { + // compare the addresses to see if there's a change + if s.ipv6Addr == nil || s.ipv6Addr.String() != ipv6Addresses[0].String() { + scopedLogger := s.l.With().Str("ipv6", ipv6Addresses[0].String()).Logger() + if s.ipv6Addr != nil { + scopedLogger.Info(). + Str("old_ipv6", s.ipv6Addr.String()). + Msg("IPv6 address changed") + } else { + scopedLogger.Info().Msg("IPv6 address found") + } + s.ipv6Addr = &ipv6Addresses[0] + changed = true + } + } + + // if it's the initial check, we'll set changed to false + initialCheck := !s.checked + if initialCheck { + s.checked = true + changed = false + } + + if initialCheck { + s.onInitialCheck(s) + } else if changed { + s.onStateChange(s) + } + + return dhcpTargetState, nil } -func startMDNS() error { - // If server was previously running, stop it - if mDNSConn != nil { - logger.Info().Msg("stopping mDNS server") - err := mDNSConn.Close() - if err != nil { - logger.Warn().Err(err).Msg("failed to stop mDNS server") - } - } - - // Start a new server - hostname := "jetkvm.local" - - scopedLogger := logger.With().Str("hostname", hostname).Logger() - scopedLogger.Info().Msg("starting mDNS server") - - addr4, err := net.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) +func (s *NetworkInterfaceState) CheckAndUpdateDhcp() error { + dhcpTargetState, err := s.update() if err != nil { - return err + return ErrorfL(s.l, "failed to update network state", err) } - addr6, err := net.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6) - if err != nil { - return err + switch dhcpTargetState { + case DhcpTargetStateRenew: + s.l.Info().Msg("renewing DHCP lease") + s.dhcpClient.Renew() + case DhcpTargetStateRelease: + s.l.Info().Msg("releasing DHCP lease") + s.dhcpClient.Release() + case DhcpTargetStateStart: + s.l.Warn().Msg("dhcpTargetStateStart not implemented") + case DhcpTargetStateStop: + s.l.Warn().Msg("dhcpTargetStateStop not implemented") } - l4, err := net.ListenUDP("udp4", addr4) - if err != nil { - return err - } - - l6, err := net.ListenUDP("udp6", addr6) - if err != nil { - return err - } - - mDNSConn, err = mdns.Server(ipv4.NewPacketConn(l4), ipv6.NewPacketConn(l6), &mdns.Config{ - LocalNames: []string{hostname}, //TODO: make it configurable - LoggerFactory: defaultLoggerFactory, - }) - if err != nil { - scopedLogger.Warn().Err(err).Msg("failed to start mDNS server") - mDNSConn = nil - return err - } - //defer server.Close() return nil } -func getNTPServersFromDHCPInfo() ([]string, error) { - buf, err := os.ReadFile(fmt.Sprintf(DHCPLeaseFile, NetIfName)) - if err != nil { - // do not return error if file does not exist - if os.IsNotExist(err) { - return nil, nil - } - return nil, fmt.Errorf("failed to load udhcpc info: %w", err) +func (s *NetworkInterfaceState) HandleLinkUpdate(update netlink.LinkUpdate) { + if update.Link.Attrs().Name == s.interfaceName { + s.l.Info().Interface("update", update).Msg("interface link update received") + s.CheckAndUpdateDhcp() } - - // parse udhcpc info - env, err := envparse.Parse(bytes.NewReader(buf)) - if err != nil { - return nil, fmt.Errorf("failed to parse udhcpc info: %w", err) - } - - val, ok := env["ntpsrv"] - if !ok { - return nil, nil - } - - var servers []string - - for _, server := range strings.Fields(val) { - if net.ParseIP(server) == nil { - logger.Info().Str("server", server).Msg("invalid NTP server IP, ignoring") - } - servers = append(servers, server) - } - - return servers, nil } func initNetwork() { @@ -218,25 +312,29 @@ func initNetwork() { done := make(chan struct{}) if err := netlink.LinkSubscribe(updates, done); err != nil { - logger.Warn().Err(err).Msg("failed to subscribe to link updates") + networkLogger.Warn().Err(err).Msg("failed to subscribe to link updates") return } + // TODO: support multiple interfaces + networkState = NewNetworkInterfaceState(NetIfName) + go networkState.dhcpClient.Run() + + if err := networkState.CheckAndUpdateDhcp(); err != nil { + os.Exit(1) + } + go func() { waitCtrlClientConnected() - checkNetworkState() ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { select { case update := <-updates: - if update.Link.Attrs().Name == NetIfName { - logger.Info().Interface("update", update).Msg("link update") - checkNetworkState() - } + networkState.HandleLinkUpdate(update) case <-ticker.C: - checkNetworkState() + _ = networkState.CheckAndUpdateDhcp() case <-done: return } diff --git a/ntp.go b/ntp.go deleted file mode 100644 index 481c141..0000000 --- a/ntp.go +++ /dev/null @@ -1,214 +0,0 @@ -package kvm - -import ( - "fmt" - "net/http" - "os/exec" - "strconv" - "time" - - "github.com/beevik/ntp" -) - -const ( - timeSyncRetryStep = 5 * time.Second - timeSyncRetryMaxInt = 1 * time.Minute - timeSyncWaitNetChkInt = 100 * time.Millisecond - timeSyncWaitNetUpInt = 3 * time.Second - timeSyncInterval = 1 * time.Hour - timeSyncTimeout = 2 * time.Second -) - -var ( - builtTimestamp string - timeSyncRetryInterval = 0 * time.Second - timeSyncSuccess = false - defaultNTPServers = []string{ - "time.cloudflare.com", - "time.apple.com", - } -) - -func isTimeSyncNeeded() bool { - if builtTimestamp == "" { - ntpLogger.Warn().Msg("built timestamp is not set, time sync is needed") - return true - } - - ts, err := strconv.Atoi(builtTimestamp) - if err != nil { - ntpLogger.Warn().Str("error", err.Error()).Msg("failed to parse built timestamp") - return true - } - - // builtTimestamp is UNIX timestamp in seconds - builtTime := time.Unix(int64(ts), 0) - now := time.Now() - - if now.Sub(builtTime) < 0 { - ntpLogger.Warn(). - Str("built_time", builtTime.Format(time.RFC3339)). - Str("now", now.Format(time.RFC3339)). - Msg("system time is behind the built time, time sync is needed") - return true - } - - return false -} - -func TimeSyncLoop() { - for { - if !networkState.checked { - time.Sleep(timeSyncWaitNetChkInt) - continue - } - - if !networkState.IsOnline() { - ntpLogger.Info().Msg("waiting for network to be online") - time.Sleep(timeSyncWaitNetUpInt) - continue - } - - // check if time sync is needed, but do nothing for now - isTimeSyncNeeded() - - ntpLogger.Info().Msg("syncing system time") - start := time.Now() - err := SyncSystemTime() - if err != nil { - ntpLogger.Error().Str("error", err.Error()).Msg("failed to sync system time") - - // retry after a delay - timeSyncRetryInterval += timeSyncRetryStep - time.Sleep(timeSyncRetryInterval) - // reset the retry interval if it exceeds the max interval - if timeSyncRetryInterval > timeSyncRetryMaxInt { - timeSyncRetryInterval = 0 - } - - continue - } - timeSyncSuccess = true - ntpLogger.Info().Str("now", time.Now().Format(time.RFC3339)). - Str("time_taken", time.Since(start).String()). - Msg("time sync successful") - time.Sleep(timeSyncInterval) // after the first sync is done - } -} - -func SyncSystemTime() (err error) { - now, err := queryNetworkTime() - if err != nil { - return fmt.Errorf("failed to query network time: %w", err) - } - err = setSystemTime(*now) - if err != nil { - return fmt.Errorf("failed to set system time: %w", err) - } - return nil -} - -func queryNetworkTime() (*time.Time, error) { - ntpServers, err := getNTPServersFromDHCPInfo() - if err != nil { - ntpLogger.Info().Err(err).Msg("failed to get NTP servers from DHCP info") - } - - if ntpServers == nil { - ntpServers = defaultNTPServers - ntpLogger.Info(). - Interface("ntp_servers", ntpServers). - Msg("using default NTP servers") - } else { - ntpLogger.Info(). - Interface("ntp_servers", ntpServers). - Msg("using NTP servers from DHCP") - } - - for _, server := range ntpServers { - now, err, response := queryNtpServer(server, timeSyncTimeout) - - scopedLogger := ntpLogger.With(). - Str("server", server). - Logger() - - if err == nil { - scopedLogger.Info(). - Str("time", now.Format(time.RFC3339)). - Str("reference", response.ReferenceString()). - Str("rtt", response.RTT.String()). - Str("clockOffset", response.ClockOffset.String()). - Uint8("stratum", response.Stratum). - Msg("NTP server returned time") - return now, nil - } else { - scopedLogger.Error(). - Str("error", err.Error()). - Msg("failed to query NTP server") - } - } - - httpUrls := []string{ - "http://apple.com", - "http://cloudflare.com", - } - for _, url := range httpUrls { - now, err, response := queryHttpTime(url, timeSyncTimeout) - - var status string - if response != nil { - status = response.Status - } - - scopedLogger := ntpLogger.With(). - Str("http_url", url). - Str("status", status). - Logger() - - if err == nil { - scopedLogger.Info(). - Str("time", now.Format(time.RFC3339)). - Msg("HTTP server returned time") - return now, nil - } else { - scopedLogger.Error(). - Str("error", err.Error()). - Msg("failed to query HTTP server") - } - } - - return nil, ErrorfL(ntpLogger, "failed to query network time, all NTP servers and HTTP servers failed", nil) -} - -func queryNtpServer(server string, timeout time.Duration) (now *time.Time, err error, response *ntp.Response) { - resp, err := ntp.QueryWithOptions(server, ntp.QueryOptions{Timeout: timeout}) - if err != nil { - return nil, err, nil - } - return &resp.Time, nil, resp -} - -func queryHttpTime(url string, timeout time.Duration) (now *time.Time, err error, response *http.Response) { - client := http.Client{ - Timeout: timeout, - } - resp, err := client.Head(url) - if err != nil { - return nil, err, nil - } - dateStr := resp.Header.Get("Date") - parsedTime, err := time.Parse(time.RFC1123, dateStr) - if err != nil { - return nil, err, resp - } - return &parsedTime, nil, resp -} - -func setSystemTime(now time.Time) error { - nowStr := now.Format("2006-01-02 15:04:05") - output, err := exec.Command("date", "-s", nowStr).CombinedOutput() - if err != nil { - return fmt.Errorf("failed to run date -s: %w, %s", err, string(output)) - } - return nil -} diff --git a/timesync.go b/timesync.go new file mode 100644 index 0000000..20306ec --- /dev/null +++ b/timesync.go @@ -0,0 +1,69 @@ +package kvm + +import ( + "strconv" + "time" + + "github.com/jetkvm/kvm/internal/timesync" +) + +const ( + timeSyncRetryStep = 5 * time.Second + timeSyncRetryMaxInt = 1 * time.Minute + timeSyncWaitNetChkInt = 100 * time.Millisecond + timeSyncWaitNetUpInt = 3 * time.Second +) + +var ( + timeSync *timesync.TimeSync + defaultNTPServers = []string{ + "time.cloudflare.com", + "time.apple.com", + } + defaultHTTPUrls = []string{ + "http://apple.com", + "http://cloudflare.com", + } + builtTimestamp string +) + +func isTimeSyncNeeded() bool { + if builtTimestamp == "" { + timesyncLogger.Warn().Msg("built timestamp is not set, time sync is needed") + return true + } + + ts, err := strconv.Atoi(builtTimestamp) + if err != nil { + timesyncLogger.Warn().Str("error", err.Error()).Msg("failed to parse built timestamp") + return true + } + + // builtTimestamp is UNIX timestamp in seconds + builtTime := time.Unix(int64(ts), 0) + now := time.Now() + + if now.Sub(builtTime) < 0 { + timesyncLogger.Warn(). + Str("built_time", builtTime.Format(time.RFC3339)). + Str("now", now.Format(time.RFC3339)). + Msg("system time is behind the built time, time sync is needed") + return true + } + + return false +} + +func initTimeSync() { + timeSync = timesync.NewTimeSync( + func() (bool, error) { + if !networkState.IsOnline() { + return false, nil + } + return true, nil + }, + defaultNTPServers, + defaultHTTPUrls, + timesyncLogger, + ) +} diff --git a/web_tls.go b/web_tls.go index 2989957..46ba60f 100644 --- a/web_tls.go +++ b/web_tls.go @@ -53,7 +53,7 @@ func initCertStore() { func getCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { if config.TLSMode == "self-signed" { - if isTimeSyncNeeded() || !timeSyncSuccess { + if isTimeSyncNeeded() || !timeSync.IsSyncSuccess() { return nil, fmt.Errorf("time is not synced") } return certSigner.GetCertificate(info)