refactor(network): rewrite network and timesync component

This commit is contained in:
Siyuan Miao 2025-04-12 17:32:15 +02:00
parent 7ebf7ba9fb
commit f712cb1719
23 changed files with 1401 additions and 381 deletions

View File

@ -311,11 +311,15 @@ func runWebsocketClient() error {
}, },
}) })
var connectionId string
if resp != nil {
// get the request id from the response header // get the request id from the response header
connectionId := resp.Header.Get("X-Request-ID") connectionId = resp.Header.Get("X-Request-ID")
if connectionId == "" { if connectionId == "" {
connectionId = resp.Header.Get("Cf-Ray") connectionId = resp.Header.Get("Cf-Ray")
} }
}
if connectionId == "" { if connectionId == "" {
connectionId = uuid.New().String() connectionId = uuid.New().String()
scopedLogger.Warn(). scopedLogger.Warn().
@ -457,7 +461,7 @@ func RunWebsocketClient() {
} }
// If the system time is not synchronized, the API request will fail anyway because the TLS handshake will fail. // If the system time is not synchronized, the API request will fail anyway because the TLS handshake will fail.
if isTimeSyncNeeded() && !timeSyncSuccess { if isTimeSyncNeeded() && !timeSync.IsSyncSuccess() {
cloudLogger.Warn().Msg("system time is not synced, will retry in 3 seconds") cloudLogger.Warn().Msg("system time is not synced, will retry in 3 seconds")
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
continue continue

View File

@ -134,7 +134,7 @@ func LoadConfig() {
defer configLock.Unlock() defer configLock.Unlock()
if config != nil { if config != nil {
logger.Info().Msg("config already loaded, skipping") logger.Debug().Msg("config already loaded, skipping")
return return
} }
@ -167,6 +167,8 @@ func LoadConfig() {
config = &loadedConfig config = &loadedConfig
rootLogger.UpdateLogLevel() rootLogger.UpdateLogLevel()
logger.Info().Str("path", configPath).Msg("config loaded")
} }
func SaveConfig() error { func SaveConfig() error {

View File

@ -48,7 +48,7 @@ func switchToScreenIfDifferent(screenName string) {
} }
func updateDisplay() { func updateDisplay() {
updateLabelIfChanged("ui_Home_Content_Ip", networkState.IPv4) updateLabelIfChanged("ui_Home_Content_Ip", networkState.IPv4String())
if usbState == "configured" { if usbState == "configured" {
updateLabelIfChanged("ui_Home_Footer_Usb_Status_Label", "Connected") updateLabelIfChanged("ui_Home_Footer_Usb_Status_Label", "Connected")
_, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Usb_Status_Label", "state": "LV_STATE_DEFAULT"}) _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Usb_Status_Label", "state": "LV_STATE_DEFAULT"})
@ -64,7 +64,7 @@ func updateDisplay() {
_, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Hdmi_Status_Label", "state": "LV_STATE_USER_2"}) _, _ = CallCtrlAction("lv_obj_set_state", map[string]interface{}{"obj": "ui_Home_Footer_Hdmi_Status_Label", "state": "LV_STATE_USER_2"})
} }
updateLabelIfChanged("ui_Home_Header_Cloud_Status_Label", fmt.Sprintf("%d active", actionSessions)) updateLabelIfChanged("ui_Home_Header_Cloud_Status_Label", fmt.Sprintf("%d active", actionSessions))
if networkState.Up { if networkState.IsUp() {
switchToScreenIfDifferent("ui_Home_Screen") switchToScreenIfDifferent("ui_Home_Screen")
} else { } else {
switchToScreenIfDifferent("ui_No_Network_Screen") switchToScreenIfDifferent("ui_No_Network_Screen")
@ -94,7 +94,7 @@ func requestDisplayUpdate() {
func updateStaticContents() { func updateStaticContents() {
//contents that never change //contents that never change
updateLabelIfChanged("ui_Home_Content_Mac", networkState.MAC) updateLabelIfChanged("ui_Home_Content_Mac", networkState.MACString())
systemVersion, appVersion, err := GetLocalVersion() systemVersion, appVersion, err := GetLocalVersion()
if err == nil { if err == nil {
updateLabelIfChanged("ui_About_Content_Operating_System_Version_ContentLabel", systemVersion.String()) updateLabelIfChanged("ui_About_Content_Operating_System_Version_ContentLabel", systemVersion.String())

1
go.mod
View File

@ -8,6 +8,7 @@ require (
github.com/coder/websocket v1.8.13 github.com/coder/websocket v1.8.13
github.com/coreos/go-oidc/v3 v3.11.0 github.com/coreos/go-oidc/v3 v3.11.0
github.com/creack/pty v1.1.23 github.com/creack/pty v1.1.23
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-contrib/logger v1.2.5 github.com/gin-contrib/logger v1.2.5
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0

2
go.sum
View File

@ -28,6 +28,8 @@ github.com/creack/pty v1.1.23/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfv
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/gin-contrib/logger v1.2.5 h1:qVQI4omayQecuN4zX9ZZnsOq7w9J/ZLds3J/FMn8ypM= github.com/gin-contrib/logger v1.2.5 h1:qVQI4omayQecuN4zX9ZZnsOq7w9J/ZLds3J/FMn8ypM=

54
internal/timesync/http.go Normal file
View File

@ -0,0 +1,54 @@
package timesync
import (
"net/http"
"time"
)
func queryHttpTime(
url string,
timeout time.Duration,
) (now *time.Time, err error, response *http.Response) {
client := http.Client{
Timeout: timeout,
}
resp, err := client.Head(url)
if err != nil {
return nil, err, nil
}
dateStr := resp.Header.Get("Date")
parsedTime, err := time.Parse(time.RFC1123, dateStr)
if err != nil {
return nil, err, resp
}
return &parsedTime, nil, resp
}
func (t *TimeSync) queryAllHttpTime() (now *time.Time) {
for _, url := range t.httpUrls {
now, err, response := queryHttpTime(url, timeSyncTimeout)
var status string
if response != nil {
status = response.Status
}
scopedLogger := t.l.With().
Str("http_url", url).
Str("status", status).
Logger()
if err == nil {
scopedLogger.Info().
Str("time", now.Format(time.RFC3339)).
Msg("HTTP server returned time")
return now
} else {
scopedLogger.Error().
Str("error", err.Error()).
Msg("failed to query HTTP server")
}
}
return nil
}

42
internal/timesync/ntp.go Normal file
View File

@ -0,0 +1,42 @@
package timesync
import (
"time"
"github.com/beevik/ntp"
)
func (t *TimeSync) queryNetworkTime() (now *time.Time) {
for _, server := range t.ntpServers {
now, err, response := queryNtpServer(server, timeSyncTimeout)
scopedLogger := t.l.With().
Str("server", server).
Logger()
if err == nil {
scopedLogger.Info().
Str("time", now.Format(time.RFC3339)).
Str("reference", response.ReferenceString()).
Str("rtt", response.RTT.String()).
Str("clockOffset", response.ClockOffset.String()).
Uint8("stratum", response.Stratum).
Msg("NTP server returned time")
return now
} else {
scopedLogger.Error().
Str("error", err.Error()).
Msg("failed to query NTP server")
}
}
return nil
}
func queryNtpServer(server string, timeout time.Duration) (now *time.Time, err error, response *ntp.Response) {
resp, err := ntp.QueryWithOptions(server, ntp.QueryOptions{Timeout: timeout})
if err != nil {
return nil, err, nil
}
return &resp.Time, nil, resp
}

26
internal/timesync/rtc.go Normal file
View File

@ -0,0 +1,26 @@
package timesync
import (
"fmt"
"os"
)
var (
rtcDeviceSearchPaths = []string{
"/dev/rtc",
"/dev/rtc0",
"/dev/rtc1",
"/dev/misc/rtc",
"/dev/misc/rtc0",
"/dev/misc/rtc1",
}
)
func getRtcDevicePath() (string, error) {
for _, path := range rtcDeviceSearchPaths {
if _, err := os.Stat(path); err == nil {
return path, nil
}
}
return "", fmt.Errorf("rtc device not found")
}

View File

@ -0,0 +1,103 @@
//go:build linux
package timesync
import (
"fmt"
"os"
"time"
"golang.org/x/sys/unix"
)
func TimetoRtcTime(t time.Time) unix.RTCTime {
return unix.RTCTime{
Sec: int32(t.Second()),
Min: int32(t.Minute()),
Hour: int32(t.Hour()),
Mday: int32(t.Day()),
Mon: int32(t.Month() - 1),
Year: int32(t.Year() - 1900),
Wday: int32(0),
Yday: int32(0),
Isdst: int32(0),
}
}
func RtcTimetoTime(t unix.RTCTime) time.Time {
return time.Date(
int(t.Year)+1900,
time.Month(t.Mon+1),
int(t.Mday),
int(t.Hour),
int(t.Min),
int(t.Sec),
0,
time.UTC,
)
}
func (t *TimeSync) getRtcDevice() (*os.File, error) {
if t.rtcDevice == nil {
file, err := os.OpenFile(t.rtcDevicePath, os.O_RDWR, 0666)
if err != nil {
return nil, err
}
t.rtcDevice = file
}
return t.rtcDevice, nil
}
func (t *TimeSync) getRtcDeviceFd() (int, error) {
device, err := t.getRtcDevice()
if err != nil {
return 0, err
}
return int(device.Fd()), nil
}
// Read implements Read for the Linux RTC
func (t *TimeSync) readRtcTime() (time.Time, error) {
fd, err := t.getRtcDeviceFd()
if err != nil {
return time.Time{}, fmt.Errorf("failed to get RTC device fd: %w", err)
}
rtcTime, err := unix.IoctlGetRTCTime(fd)
if err != nil {
return time.Time{}, fmt.Errorf("failed to get RTC time: %w", err)
}
date := RtcTimetoTime(*rtcTime)
return date, nil
}
// Set implements Set for the Linux RTC
// ...
// It might be not accurate as the time consumed by the system call is not taken into account
// but it's good enough for our purposes
func (t *TimeSync) setRtcTime(tu time.Time) error {
rt := TimetoRtcTime(tu)
fd, err := t.getRtcDeviceFd()
if err != nil {
return fmt.Errorf("failed to get RTC device fd: %w", err)
}
currentRtcTime, err := t.readRtcTime()
if err != nil {
return fmt.Errorf("failed to read RTC time: %w", err)
}
t.l.Info().
Interface("rtc_time", tu).
Str("offset", tu.Sub(currentRtcTime).String()).
Msg("set rtc time")
if err := unix.IoctlSetRTCTime(fd, &rt); err != nil {
return fmt.Errorf("failed to set RTC time: %w", err)
}
return nil
}

View File

@ -0,0 +1,16 @@
//go:build !linux
package timesync
import (
"errors"
"time"
)
func (t *TimeSync) readRtcTime() (time.Time, error) {
return time.Now(), nil
}
func (t *TimeSync) setRtcTime(tu time.Time) error {
return errors.New("not supported")
}

View File

@ -0,0 +1,151 @@
package timesync
import (
"fmt"
"os"
"os/exec"
"sync"
"time"
"github.com/rs/zerolog"
)
const (
timeSyncRetryStep = 5 * time.Second
timeSyncRetryMaxInt = 1 * time.Minute
timeSyncWaitNetChkInt = 100 * time.Millisecond
timeSyncWaitNetUpInt = 3 * time.Second
timeSyncInterval = 1 * time.Hour
timeSyncTimeout = 2 * time.Second
)
var (
timeSyncRetryInterval = 0 * time.Second
defaultNTPServers = []string{
"time.cloudflare.com",
"time.apple.com",
}
)
type TimeSync struct {
syncLock *sync.Mutex
l *zerolog.Logger
ntpServers []string
httpUrls []string
rtcDevicePath string
rtcDevice *os.File
rtcLock *sync.Mutex
syncSuccess bool
preCheckFunc func() (bool, error)
}
func NewTimeSync(
precheckFunc func() (bool, error),
ntpServers []string,
httpUrls []string,
logger *zerolog.Logger,
) *TimeSync {
rtcDevice, err := getRtcDevicePath()
if err != nil {
logger.Error().Err(err).Msg("failed to get RTC device path")
} else {
logger.Info().Str("path", rtcDevice).Msg("RTC device found")
}
t := &TimeSync{
syncLock: &sync.Mutex{},
l: logger,
rtcDevicePath: rtcDevice,
rtcLock: &sync.Mutex{},
preCheckFunc: precheckFunc,
ntpServers: ntpServers,
httpUrls: httpUrls,
}
if t.rtcDevicePath != "" {
rtcTime, _ := t.readRtcTime()
t.l.Info().Interface("rtc_time", rtcTime).Msg("read RTC time")
}
return t
}
func (t *TimeSync) doTimeSync() {
for {
if ok, err := t.preCheckFunc(); !ok {
if err != nil {
t.l.Error().Err(err).Msg("pre-check failed")
}
time.Sleep(timeSyncWaitNetChkInt)
continue
}
t.l.Info().Msg("syncing system time")
start := time.Now()
err := t.Sync()
if err != nil {
t.l.Error().Str("error", err.Error()).Msg("failed to sync system time")
// retry after a delay
timeSyncRetryInterval += timeSyncRetryStep
time.Sleep(timeSyncRetryInterval)
// reset the retry interval if it exceeds the max interval
if timeSyncRetryInterval > timeSyncRetryMaxInt {
timeSyncRetryInterval = 0
}
continue
}
t.syncSuccess = true
t.l.Info().Str("now", time.Now().Format(time.RFC3339)).
Str("time_taken", time.Since(start).String()).
Msg("time sync successful")
time.Sleep(timeSyncInterval) // after the first sync is done
}
}
func (t *TimeSync) Sync() error {
var now *time.Time
now = t.queryNetworkTime()
if now == nil {
now = t.queryAllHttpTime()
}
if now == nil {
return fmt.Errorf("failed to get time from any source")
}
err := t.setSystemTime(*now)
if err != nil {
return fmt.Errorf("failed to set system time: %w", err)
}
return nil
}
func (t *TimeSync) IsSyncSuccess() bool {
return t.syncSuccess
}
func (t *TimeSync) Start() {
go t.doTimeSync()
}
func (t *TimeSync) setSystemTime(now time.Time) error {
nowStr := now.Format("2006-01-02 15:04:05")
output, err := exec.Command("date", "-s", nowStr).CombinedOutput()
if err != nil {
return fmt.Errorf("failed to run date -s: %w, %s", err, string(output))
}
if t.rtcDevicePath != "" {
return t.setRtcTime(now)
}
return nil
}

View File

@ -0,0 +1,12 @@
package udhcpc
func (u *DHCPClient) GetNtpServers() []string {
if u.lease == nil {
return nil
}
servers := make([]string, len(u.lease.NTPServers))
for i, server := range u.lease.NTPServers {
servers[i] = server.String()
}
return servers
}

150
internal/udhcpc/parser.go Normal file
View File

@ -0,0 +1,150 @@
package udhcpc
import (
"encoding/json"
"fmt"
"log"
"net"
"reflect"
"strconv"
"strings"
"time"
)
type Lease struct {
// from https://udhcp.busybox.net/README.udhcpc
IPAddress net.IP `env:"ip" json:"ip"` // The obtained IP
Netmask net.IP `env:"subnet" json:"netmask"` // The assigned subnet mask
Broadcast net.IP `env:"broadcast" json:"broadcast"` // The broadcast address for this network
TTL int `env:"ipttl" json:"ttl,omitempty"` // The TTL to use for this network
MTU int `env:"mtu" json:"mtu,omitempty"` // The MTU to use for this network
HostName string `env:"hostname" json:"hostname,omitempty"` // The assigned hostname
Domain string `env:"domain" json:"domain,omitempty"` // The domain name of the network
BootPNextServer net.IP `env:"siaddr" json:"bootp_next_server,omitempty"` // The bootp next server option
BootPServerName string `env:"sname" json:"bootp_server_name,omitempty"` // The bootp server name option
BootPFile string `env:"boot_file" json:"bootp_file,omitempty"` // The bootp boot file option
Timezone string `env:"timezone" json:"timezone,omitempty"` // Offset in seconds from UTC
Routers []net.IP `env:"router" json:"routers,omitempty"` // A list of routers
DNS []net.IP `env:"dns" json:"dns_servers,omitempty"` // A list of DNS servers
NTPServers []net.IP `env:"ntpsrv" json:"ntp_servers,omitempty"` // A list of NTP servers
LPRServers []net.IP `env:"lprsvr" json:"lpr_servers,omitempty"` // A list of LPR servers
TimeServers []net.IP `env:"timesvr" json:"_time_servers,omitempty"` // A list of time servers (obsolete)
IEN116NameServers []net.IP `env:"namesvr" json:"_name_servers,omitempty"` // A list of IEN 116 name servers (obsolete)
LogServers []net.IP `env:"logsvr" json:"_log_servers,omitempty"` // A list of MIT-LCS UDP log servers (obsolete)
CookieServers []net.IP `env:"cookiesvr" json:"_cookie_servers,omitempty"` // A list of RFC 865 cookie servers (obsolete)
WINSServers []net.IP `env:"wins" json:"_wins_servers,omitempty"` // A list of WINS servers
SwapServer net.IP `env:"swapsvr" json:"_swap_server,omitempty"` // The IP address of the client's swap server
BootSize int `env:"bootsize" json:"bootsize,omitempty"` // The length in 512 octect blocks of the bootfile
RootPath string `env:"rootpath" json:"root_path,omitempty"` // The path name of the client's root disk
LeaseTime time.Duration `env:"lease" json:"lease,omitempty"` // The lease time, in seconds
DHCPType string `env:"dhcptype" json:"dhcp_type,omitempty"` // DHCP message type (safely ignored)
ServerID string `env:"serverid" json:"server_id,omitempty"` // The IP of the server
Message string `env:"message" json:"reason,omitempty"` // Reason for a DHCPNAK
TFTPServerName string `env:"tftp" json:"tftp,omitempty"` // The TFTP server name
BootFileName string `env:"bootfile" json:"bootfile,omitempty"` // The boot file name
isEmpty map[string]bool
}
func (l *Lease) setIsEmpty(m map[string]bool) {
l.isEmpty = m
}
func (l *Lease) IsEmpty(key string) bool {
return l.isEmpty[key]
}
func (l *Lease) ToJSON() string {
json, err := json.Marshal(l)
if err != nil {
return ""
}
return string(json)
}
func UnmarshalDHCPCLease(lease *Lease, str string) error {
// parse the lease file as a map
data := make(map[string]string)
for _, line := range strings.Split(str, "\n") {
line = strings.TrimSpace(line)
// skip empty lines and comments
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
log.Printf("invalid line: %s", line)
continue
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
data[key] = value
}
// now iterate over the lease struct and set the values
leaseType := reflect.TypeOf(lease).Elem()
leaseValue := reflect.ValueOf(lease).Elem()
valuesParsed := make(map[string]bool)
for i := 0; i < leaseType.NumField(); i++ {
field := leaseValue.Field(i)
// get the env tag
key := leaseType.Field(i).Tag.Get("env")
if key == "" {
continue
}
valuesParsed[key] = false
// get the value from the data map
value, ok := data[key]
if !ok || value == "" {
continue
}
switch field.Interface().(type) {
case string:
field.SetString(value)
case int:
val, err := strconv.Atoi(value)
if err != nil {
continue
}
field.SetInt(int64(val))
case time.Duration:
val, err := time.ParseDuration(value + "s")
if err != nil {
continue
}
field.Set(reflect.ValueOf(val))
case net.IP:
ip := net.ParseIP(value)
if ip == nil {
continue
}
field.Set(reflect.ValueOf(ip))
case []net.IP:
val := make([]net.IP, 0)
for _, ipStr := range strings.Fields(value) {
ip := net.ParseIP(ipStr)
if ip == nil {
continue
}
val = append(val, ip)
}
field.Set(reflect.ValueOf(val))
default:
return fmt.Errorf("unsupported field `%s` type: %s", key, field.Type().String())
}
valuesParsed[key] = true
}
lease.setIsEmpty(valuesParsed)
return nil
}

View File

@ -0,0 +1,74 @@
package udhcpc
import (
"testing"
"time"
)
func TestUnmarshalDHCPCLease(t *testing.T) {
lease := &Lease{}
err := UnmarshalDHCPCLease(lease, `
# generated @ Mon Jan 4 19:31:53 UTC 2021
# 19:31:53 up 0 min, 0 users, load average: 0.72, 0.14, 0.04
# the date might be inaccurate if the clock is not set
ip=192.168.0.240
siaddr=192.168.0.1
sname=
boot_file=
subnet=255.255.255.0
timezone=
router=192.168.0.1
timesvr=
namesvr=
dns=172.19.53.2
logsvr=
cookiesvr=
lprsvr=
hostname=
bootsize=
domain=
swapsvr=
rootpath=
ipttl=
mtu=
broadcast=
ntpsrv=162.159.200.123
wins=
lease=172800
dhcptype=
serverid=192.168.0.1
message=
tftp=
bootfile=
`)
if lease.IPAddress.String() != "192.168.0.240" {
t.Fatalf("expected ip to be 192.168.0.240, got %s", lease.IPAddress.String())
}
if lease.Netmask.String() != "255.255.255.0" {
t.Fatalf("expected netmask to be 255.255.255.0, got %s", lease.Netmask.String())
}
if len(lease.Routers) != 1 {
t.Fatalf("expected 1 router, got %d", len(lease.Routers))
}
if lease.Routers[0].String() != "192.168.0.1" {
t.Fatalf("expected router to be 192.168.0.1, got %s", lease.Routers[0].String())
}
if len(lease.NTPServers) != 1 {
t.Fatalf("expected 1 timeserver, got %d", len(lease.NTPServers))
}
if lease.NTPServers[0].String() != "162.159.200.123" {
t.Fatalf("expected timeserver to be 162.159.200.123, got %s", lease.NTPServers[0].String())
}
if len(lease.DNS) != 1 {
t.Fatalf("expected 1 dns, got %d", len(lease.DNS))
}
if lease.DNS[0].String() != "172.19.53.2" {
t.Fatalf("expected dns to be 172.19.53.2, got %s", lease.DNS[0].String())
}
if lease.LeaseTime != 172800*time.Second {
t.Fatalf("expected lease time to be 172800 seconds, got %d", lease.LeaseTime)
}
if err != nil {
t.Fatal(err)
}
}

212
internal/udhcpc/proc.go Normal file
View File

@ -0,0 +1,212 @@
package udhcpc
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
)
func readFileNoStat(filename string) ([]byte, error) {
const maxBufferSize = 1024 * 1024
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer f.Close()
reader := io.LimitReader(f, maxBufferSize)
return io.ReadAll(reader)
}
func toCmdline(path string) ([]string, error) {
data, err := readFileNoStat(path)
if err != nil {
return nil, err
}
if len(data) < 1 {
return []string{}, nil
}
return strings.Split(string(bytes.TrimRight(data, "\x00")), "\x00"), nil
}
func (p *DHCPClient) findUdhcpcProcess() (int, error) {
// read procfs for udhcpc processes
// we do not use procfs.AllProcs() because we want to avoid the overhead of reading the entire procfs
processes, err := os.ReadDir("/proc")
if err != nil {
return 0, err
}
// iterate over the processes
for _, d := range processes {
// check if file is numeric
pid, err := strconv.Atoi(d.Name())
if err != nil {
continue
}
// check if it's a directory
if !d.IsDir() {
continue
}
cmdline, err := toCmdline(filepath.Join("/proc", d.Name(), "cmdline"))
if err != nil {
continue
}
if len(cmdline) < 1 {
continue
}
if cmdline[0] != "udhcpc" {
continue
}
cmdlineText := strings.Join(cmdline, " ")
// check if it's a udhcpc process
if strings.Contains(cmdlineText, fmt.Sprintf("-i %s", p.InterfaceName)) {
p.logger.Debug().
Str("pid", d.Name()).
Interface("cmdline", cmdline).
Msg("found udhcpc process")
return pid, nil
}
}
return 0, errors.New("udhcpc process not found")
}
func (c *DHCPClient) getProcessPid() (int, error) {
var pid int
if c.pidFile != "" {
// try to read the pid file
pidHandle, err := os.ReadFile(c.pidFile)
if err != nil {
c.logger.Warn().Err(err).
Str("pidFile", c.pidFile).Msg("failed to read udhcpc pid file")
}
// if it exists, try to read the pid
if pidHandle != nil {
pidFromFile, err := strconv.Atoi(string(pidHandle))
if err != nil {
c.logger.Warn().Err(err).
Str("pidFile", c.pidFile).Msg("failed to convert pid file to int")
}
pid = pidFromFile
}
}
// if the pid is 0, try to find the pid using procfs
if pid == 0 {
newPid, err := c.findUdhcpcProcess()
if err != nil {
return 0, err
}
pid = newPid
}
return pid, nil
}
func (c *DHCPClient) getProcess() *os.Process {
pid, err := c.getProcessPid()
if err != nil {
return nil
}
process, err := os.FindProcess(pid)
if err != nil {
c.logger.Warn().Err(err).
Int("pid", pid).Msg("failed to find process")
return nil
}
return process
}
func (c *DHCPClient) GetProcess() *os.Process {
if c.process == nil {
process := c.getProcess()
if process == nil {
return nil
}
c.process = process
}
err := c.process.Signal(syscall.Signal(0))
if err != nil && errors.Is(err, os.ErrProcessDone) {
oldPid := c.process.Pid
c.process = nil
c.process = c.getProcess()
if c.process == nil {
c.logger.Error().Msg("failed to find new udhcpc process")
return nil
}
c.logger.Warn().
Int("oldPid", oldPid).
Int("newPid", c.process.Pid).
Msg("udhcpc process pid changed")
} else if err != nil {
c.logger.Warn().Err(err).
Int("pid", c.process.Pid).Msg("udhcpc process is not running")
}
return c.process
}
func (c *DHCPClient) KillProcess() error {
process := c.GetProcess()
if process == nil {
return nil
}
return process.Kill()
}
func (c *DHCPClient) ReleaseProcess() error {
process := c.GetProcess()
if process == nil {
return nil
}
return process.Release()
}
func (c *DHCPClient) signalProcess(sig syscall.Signal) error {
process := c.GetProcess()
if process == nil {
return nil
}
s := process.Signal(sig)
if s != nil {
c.logger.Warn().Err(s).
Int("pid", process.Pid).
Str("signal", sig.String()).
Msg("failed to signal udhcpc process")
return s
}
return nil
}
func (c *DHCPClient) Renew() error {
return c.signalProcess(syscall.SIGUSR1)
}
func (c *DHCPClient) Release() error {
return c.signalProcess(syscall.SIGUSR2)
}

145
internal/udhcpc/udhcpc.go Normal file
View File

@ -0,0 +1,145 @@
package udhcpc
import (
"errors"
"fmt"
"os"
"github.com/fsnotify/fsnotify"
"github.com/rs/zerolog"
)
const (
DHCPLeaseFile = "/run/udhcpc.%s.info"
DHCPPidFile = "/run/udhcpc.%s.pid"
)
type DHCPClient struct {
InterfaceName string
leaseFile string
pidFile string
lease *Lease
logger *zerolog.Logger
process *os.Process
onLeaseChange func(lease *Lease)
}
type DHCPClientOptions struct {
InterfaceName string
PidFile string
Logger *zerolog.Logger
OnLeaseChange func(lease *Lease)
}
var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel)
func NewDHCPClient(options *DHCPClientOptions) *DHCPClient {
if options.Logger == nil {
options.Logger = &defaultLogger
}
l := options.Logger.With().Str("interface", options.InterfaceName).Logger()
return &DHCPClient{
InterfaceName: options.InterfaceName,
logger: &l,
leaseFile: fmt.Sprintf(DHCPLeaseFile, options.InterfaceName),
pidFile: options.PidFile,
onLeaseChange: options.OnLeaseChange,
}
}
// Run starts the DHCP client and watches the lease file for changes.
// this isn't a blocking call, and the lease file is reloaded when a change is detected.
func (c *DHCPClient) Run() error {
err := c.loadLeaseFile()
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
defer watcher.Close()
go func() {
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
c.logger.Debug().
Str("event", event.Name).
Msg("udhcpc lease file updated, reloading lease")
c.loadLeaseFile()
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
c.logger.Error().Err(err).Msg("error watching lease file")
}
}
}()
watcher.Add(c.leaseFile)
// TODO: update udhcpc pid file
// we'll comment this out for now because the pid might change
// process := c.GetProcess()
// if process == nil {
// c.logger.Error().Msg("udhcpc process not found")
// }
// block the goroutine until the lease file is updated
<-make(chan struct{})
return nil
}
func (c *DHCPClient) loadLeaseFile() error {
file, err := os.ReadFile(c.leaseFile)
if err != nil {
return err
}
data := string(file)
if data == "" {
c.logger.Debug().Msg("udhcpc lease file is empty")
return nil
}
lease := &Lease{}
err = UnmarshalDHCPCLease(lease, string(file))
if err != nil {
return err
}
isFirstLoad := c.lease == nil
c.lease = lease
if lease.IPAddress == nil {
c.logger.Info().
Interface("lease", lease).
Str("data", string(file)).
Msg("udhcpc lease cleared")
return nil
}
msg := "udhcpc lease updated"
if isFirstLoad {
msg = "udhcpc lease loaded"
}
c.onLeaseChange(lease)
c.logger.Info().
Str("ip", lease.IPAddress.String()).
Str("leaseTime", lease.LeaseTime.String()).
Interface("data", lease).
Msg(msg)
return nil
}

3
log.go
View File

@ -218,12 +218,13 @@ func ErrorfL(l *zerolog.Logger, format string, err error, args ...interface{}) e
var ( var (
logger = rootLogger.getLogger("jetkvm") logger = rootLogger.getLogger("jetkvm")
networkLogger = rootLogger.getLogger("network")
cloudLogger = rootLogger.getLogger("cloud") cloudLogger = rootLogger.getLogger("cloud")
websocketLogger = rootLogger.getLogger("websocket") websocketLogger = rootLogger.getLogger("websocket")
webrtcLogger = rootLogger.getLogger("webrtc") webrtcLogger = rootLogger.getLogger("webrtc")
nativeLogger = rootLogger.getLogger("native") nativeLogger = rootLogger.getLogger("native")
nbdLogger = rootLogger.getLogger("nbd") nbdLogger = rootLogger.getLogger("nbd")
ntpLogger = rootLogger.getLogger("ntp") timesyncLogger = rootLogger.getLogger("timesync")
jsonRpcLogger = rootLogger.getLogger("jsonrpc") jsonRpcLogger = rootLogger.getLogger("jsonrpc")
watchdogLogger = rootLogger.getLogger("watchdog") watchdogLogger = rootLogger.getLogger("watchdog")
websecureLogger = rootLogger.getLogger("websecure") websecureLogger = rootLogger.getLogger("websecure")

22
main.go
View File

@ -15,26 +15,38 @@ var appCtx context.Context
func Main() { func Main() {
LoadConfig() LoadConfig()
logger.Debug().Msg("config loaded")
var cancel context.CancelFunc var cancel context.CancelFunc
appCtx, cancel = context.WithCancel(context.Background()) appCtx, cancel = context.WithCancel(context.Background())
defer cancel() defer cancel()
logger.Info().Msg("starting JetKvm")
systemVersionLocal, appVersionLocal, err := GetLocalVersion()
if err != nil {
logger.Warn().Err(err).Msg("failed to get local version")
}
logger.Info().
Interface("system_version", systemVersionLocal).
Interface("app_version", appVersionLocal).
Msg("starting JetKVM")
go runWatchdog() go runWatchdog()
go confirmCurrentSystem() go confirmCurrentSystem()
http.DefaultClient.Timeout = 1 * time.Minute http.DefaultClient.Timeout = 1 * time.Minute
err := rootcerts.UpdateDefaultTransport() err = rootcerts.UpdateDefaultTransport()
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("failed to load CA certs") logger.Warn().Err(err).Msg("failed to load Root CA certificates")
} }
logger.Info().
Int("ca_certs_loaded", len(rootcerts.Certs())).
Msg("loaded Root CA certificates")
initNetwork() initNetwork()
initTimeSync()
go TimeSyncLoop() timeSync.Start()
StartNativeCtrlSocketServer() StartNativeCtrlSocketServer()
StartNativeVideoSocketServer() StartNativeVideoSocketServer()

60
mdns.go Normal file
View File

@ -0,0 +1,60 @@
package kvm
import (
"net"
"github.com/pion/mdns/v2"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
var mDNSConn *mdns.Conn
func startMDNS() error {
// If server was previously running, stop it
if mDNSConn != nil {
logger.Info().Msg("stopping mDNS server")
err := mDNSConn.Close()
if err != nil {
logger.Warn().Err(err).Msg("failed to stop mDNS server")
}
}
// Start a new server
hostname := "jetkvm.local"
scopedLogger := logger.With().Str("hostname", hostname).Logger()
scopedLogger.Info().Msg("starting mDNS server")
addr4, err := net.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4)
if err != nil {
return err
}
addr6, err := net.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6)
if err != nil {
return err
}
l4, err := net.ListenUDP("udp4", addr4)
if err != nil {
return err
}
l6, err := net.ListenUDP("udp6", addr6)
if err != nil {
return err
}
mDNSConn, err = mdns.Server(ipv4.NewPacketConn(l4), ipv6.NewPacketConn(l6), &mdns.Config{
LocalNames: []string{hostname}, //TODO: make it configurable
LoggerFactory: defaultLoggerFactory,
})
if err != nil {
scopedLogger.Warn().Err(err).Msg("failed to start mDNS server")
mDNSConn = nil
return err
}
//defer server.Close()
return nil
}

View File

@ -1,214 +1,308 @@
package kvm package kvm
import ( import (
"bytes"
"fmt" "fmt"
"net" "net"
"os" "os"
"strings" "sync"
"time" "time"
"os/exec" "github.com/Masterminds/semver/v3"
"github.com/jetkvm/kvm/internal/udhcpc"
"github.com/hashicorp/go-envparse" "github.com/rs/zerolog"
"github.com/pion/mdns/v2"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl" "github.com/vishvananda/netlink/nl"
) )
var mDNSConn *mdns.Conn var (
networkState *NetworkInterfaceState
)
var networkState NetworkState type DhcpTargetState int
type NetworkState struct { const (
Up bool DhcpTargetStateDoNothing DhcpTargetState = iota
IPv4 string DhcpTargetStateStart
IPv6 string DhcpTargetStateStop
MAC string DhcpTargetStateRenew
DhcpTargetStateRelease
)
type NetworkInterfaceState struct {
interfaceName string
interfaceUp bool
ipv4Addr *net.IP
ipv6Addr *net.IP
macAddr *net.HardwareAddr
l *zerolog.Logger
stateLock sync.Mutex
dhcpClient *udhcpc.DHCPClient
onStateChange func(state *NetworkInterfaceState)
onInitialCheck func(state *NetworkInterfaceState)
checked bool checked bool
} }
func (s *NetworkState) IsUp() bool { func (s *NetworkInterfaceState) IsUp() bool {
return s.Up && s.IPv4 != "" && s.IPv6 != "" return s.interfaceUp
} }
func (s *NetworkState) HasIPAssigned() bool { func (s *NetworkInterfaceState) HasIPAssigned() bool {
return s.IPv4 != "" || s.IPv6 != "" return s.ipv4Addr != nil || s.ipv6Addr != nil
} }
func (s *NetworkState) IsOnline() bool { func (s *NetworkInterfaceState) IsOnline() bool {
return s.Up && s.HasIPAssigned() return s.IsUp() && s.HasIPAssigned()
} }
type LocalIpInfo struct { func (s *NetworkInterfaceState) IPv4() *net.IP {
IPv4 string return s.ipv4Addr
IPv6 string }
MAC string
func (s *NetworkInterfaceState) IPv4String() string {
if s.ipv4Addr == nil {
return "..."
}
return s.ipv4Addr.String()
}
func (s *NetworkInterfaceState) IPv6() *net.IP {
return s.ipv6Addr
}
func (s *NetworkInterfaceState) IPv6String() string {
if s.ipv6Addr == nil {
return "..."
}
return s.ipv6Addr.String()
}
func (s *NetworkInterfaceState) MAC() *net.HardwareAddr {
return s.macAddr
}
func (s *NetworkInterfaceState) MACString() string {
if s.macAddr == nil {
return ""
}
return s.macAddr.String()
} }
const ( const (
// TODO: add support for multiple interfaces
NetIfName = "eth0" NetIfName = "eth0"
DHCPLeaseFile = "/run/udhcpc.%s.info"
) )
// setDhcpClientState sends signals to udhcpc to change it's current mode func NewNetworkInterfaceState(ifname string) *NetworkInterfaceState {
// of operation. Setting active to true will force udhcpc to renew the DHCP lease. logger := networkLogger.With().Str("interface", ifname).Logger()
// Setting active to false will put udhcpc into idle mode.
func setDhcpClientState(active bool) { s := &NetworkInterfaceState{
var signal string interfaceName: ifname,
if active { stateLock: sync.Mutex{},
signal = "-SIGUSR1" l: &logger,
} else { onStateChange: func(state *NetworkInterfaceState) {
signal = "-SIGUSR2" go func() {
waitCtrlClientConnected()
requestDisplayUpdate()
}()
},
onInitialCheck: func(state *NetworkInterfaceState) {
go func() {
waitCtrlClientConnected()
requestDisplayUpdate()
}()
},
} }
cmd := exec.Command("/usr/bin/killall", signal, "udhcpc") // use a pid file for udhcpc if the system version is 0.2.4 or higher
if err := cmd.Run(); err != nil { dhcpPidFile := ""
logger.Warn().Err(err).Msg("network: setDhcpClientState: failed to change udhcpc state") systemVersionLocal, _, _ := GetLocalVersion()
if systemVersionLocal != nil &&
systemVersionLocal.Compare(semver.MustParse("0.2.4")) >= 0 {
dhcpPidFile = fmt.Sprintf("/run/udhcpc.%s.pid", ifname)
} }
// create the dhcp client
dhcpClient := udhcpc.NewDHCPClient(&udhcpc.DHCPClientOptions{
InterfaceName: ifname,
PidFile: dhcpPidFile,
Logger: &logger,
OnLeaseChange: func(lease *udhcpc.Lease) {
s.update()
},
})
s.dhcpClient = dhcpClient
return s
} }
func checkNetworkState() { func (s *NetworkInterfaceState) update() (DhcpTargetState, error) {
iface, err := netlink.LinkByName(NetIfName) s.stateLock.Lock()
defer s.stateLock.Unlock()
dhcpTargetState := DhcpTargetStateDoNothing
iface, err := netlink.LinkByName(s.interfaceName)
if err != nil { if err != nil {
logger.Warn().Err(err).Str("interface", NetIfName).Msg("failed to get interface") s.l.Error().Err(err).Msg("failed to get interface")
return return dhcpTargetState, err
} }
newState := NetworkState{ // detect if the interface status changed
Up: iface.Attrs().OperState == netlink.OperUp, var changed bool
MAC: iface.Attrs().HardwareAddr.String(), attrs := iface.Attrs()
state := attrs.OperState
newInterfaceUp := state == netlink.OperUp
checked: true, // check if the interface is coming up
interfaceGoingUp := s.interfaceUp == false && newInterfaceUp == true
interfaceGoingDown := s.interfaceUp == true && newInterfaceUp == false
if s.interfaceUp != newInterfaceUp {
s.interfaceUp = newInterfaceUp
changed = true
} }
if changed {
if interfaceGoingUp {
s.l.Info().Msg("interface state transitioned to up")
dhcpTargetState = DhcpTargetStateRenew
} else if interfaceGoingDown {
s.l.Info().Msg("interface state transitioned to down")
}
}
// set the mac address
s.macAddr = &attrs.HardwareAddr
// get the ip addresses
addrs, err := netlink.AddrList(iface, nl.FAMILY_ALL) addrs, err := netlink.AddrList(iface, nl.FAMILY_ALL)
if err != nil { if err != nil {
logger.Warn().Err(err).Str("interface", NetIfName).Msg("failed to get addresses") s.l.Error().Err(err).Msg("failed to get ip addresses")
return dhcpTargetState, err
} }
// If the link is going down, put udhcpc into idle mode. var (
// If the link is coming back up, activate udhcpc and force it to renew the lease. ipv4Addresses = make([]net.IP, 0)
if newState.Up != networkState.Up { ipv6Addresses = make([]net.IP, 0)
setDhcpClientState(newState.Up) )
}
for _, addr := range addrs { for _, addr := range addrs {
if addr.IP.To4() != nil { if addr.IP.To4() != nil {
if !newState.Up && networkState.Up { scopedLogger := s.l.With().Str("ipv4", addr.IP.String()).Logger()
// If the network is going down, remove all IPv4 addresses from the interface. if interfaceGoingDown {
logger.Info().Str("address", addr.IP.String()).Msg("network: state transitioned to down, removing IPv4 address") // remove all IPv4 addresses from the interface.
scopedLogger.Info().Msg("state transitioned to down, removing IPv4 address")
err := netlink.AddrDel(iface, &addr) err := netlink.AddrDel(iface, &addr)
if err != nil { if err != nil {
logger.Warn().Err(err).Str("address", addr.IP.String()).Msg("network: failed to delete address") scopedLogger.Warn().Err(err).Msg("failed to delete address")
}
// notify the DHCP client to release the lease
dhcpTargetState = DhcpTargetStateRelease
continue
}
ipv4Addresses = append(ipv4Addresses, addr.IP)
} else if addr.IP.To16() != nil {
scopedLogger := s.l.With().Str("ipv6", addr.IP.String()).Logger()
// check if it's a link local address
if !addr.IP.IsGlobalUnicast() {
scopedLogger.Trace().Msg("not a global unicast address, skipping")
continue
} }
newState.IPv4 = "..." if interfaceGoingDown {
scopedLogger.Info().Msg("state transitioned to down, removing IPv6 address")
err := netlink.AddrDel(iface, &addr)
if err != nil {
scopedLogger.Warn().Err(err).Msg("failed to delete address")
}
continue
}
ipv6Addresses = append(ipv6Addresses, addr.IP)
}
}
if len(ipv4Addresses) > 0 {
// compare the addresses to see if there's a change
if s.ipv4Addr == nil || s.ipv4Addr.String() != ipv4Addresses[0].String() {
scopedLogger := s.l.With().Str("ipv4", ipv4Addresses[0].String()).Logger()
if s.ipv4Addr != nil {
scopedLogger.Info().
Str("old_ipv4", s.ipv4Addr.String()).
Msg("IPv4 address changed")
changed = true
} else { } else {
newState.IPv4 = addr.IP.String() scopedLogger.Info().Msg("IPv4 address found")
} }
} else if addr.IP.To16() != nil && newState.IPv6 == "" { s.ipv4Addr = &ipv4Addresses[0]
newState.IPv6 = addr.IP.String() changed = true
} }
} }
if newState != networkState { if len(ipv6Addresses) > 0 {
logger.Info(). // compare the addresses to see if there's a change
Interface("newState", newState). if s.ipv6Addr == nil || s.ipv6Addr.String() != ipv6Addresses[0].String() {
Interface("oldState", networkState). scopedLogger := s.l.With().Str("ipv6", ipv6Addresses[0].String()).Logger()
Msg("network state changed") if s.ipv6Addr != nil {
scopedLogger.Info().
// restart MDNS Str("old_ipv6", s.ipv6Addr.String()).
_ = startMDNS() Msg("IPv6 address changed")
networkState = newState } else {
requestDisplayUpdate() scopedLogger.Info().Msg("IPv6 address found")
} }
s.ipv6Addr = &ipv6Addresses[0]
changed = true
}
}
// if it's the initial check, we'll set changed to false
initialCheck := !s.checked
if initialCheck {
s.checked = true
changed = false
}
if initialCheck {
s.onInitialCheck(s)
} else if changed {
s.onStateChange(s)
}
return dhcpTargetState, nil
} }
func startMDNS() error { func (s *NetworkInterfaceState) CheckAndUpdateDhcp() error {
// If server was previously running, stop it dhcpTargetState, err := s.update()
if mDNSConn != nil {
logger.Info().Msg("stopping mDNS server")
err := mDNSConn.Close()
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("failed to stop mDNS server") return ErrorfL(s.l, "failed to update network state", err)
}
} }
// Start a new server switch dhcpTargetState {
hostname := "jetkvm.local" case DhcpTargetStateRenew:
s.l.Info().Msg("renewing DHCP lease")
scopedLogger := logger.With().Str("hostname", hostname).Logger() s.dhcpClient.Renew()
scopedLogger.Info().Msg("starting mDNS server") case DhcpTargetStateRelease:
s.l.Info().Msg("releasing DHCP lease")
addr4, err := net.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) s.dhcpClient.Release()
if err != nil { case DhcpTargetStateStart:
return err s.l.Warn().Msg("dhcpTargetStateStart not implemented")
case DhcpTargetStateStop:
s.l.Warn().Msg("dhcpTargetStateStop not implemented")
} }
addr6, err := net.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6)
if err != nil {
return err
}
l4, err := net.ListenUDP("udp4", addr4)
if err != nil {
return err
}
l6, err := net.ListenUDP("udp6", addr6)
if err != nil {
return err
}
mDNSConn, err = mdns.Server(ipv4.NewPacketConn(l4), ipv6.NewPacketConn(l6), &mdns.Config{
LocalNames: []string{hostname}, //TODO: make it configurable
LoggerFactory: defaultLoggerFactory,
})
if err != nil {
scopedLogger.Warn().Err(err).Msg("failed to start mDNS server")
mDNSConn = nil
return err
}
//defer server.Close()
return nil return nil
} }
func getNTPServersFromDHCPInfo() ([]string, error) { func (s *NetworkInterfaceState) HandleLinkUpdate(update netlink.LinkUpdate) {
buf, err := os.ReadFile(fmt.Sprintf(DHCPLeaseFile, NetIfName)) if update.Link.Attrs().Name == s.interfaceName {
if err != nil { s.l.Info().Interface("update", update).Msg("interface link update received")
// do not return error if file does not exist s.CheckAndUpdateDhcp()
if os.IsNotExist(err) {
return nil, nil
} }
return nil, fmt.Errorf("failed to load udhcpc info: %w", err)
}
// parse udhcpc info
env, err := envparse.Parse(bytes.NewReader(buf))
if err != nil {
return nil, fmt.Errorf("failed to parse udhcpc info: %w", err)
}
val, ok := env["ntpsrv"]
if !ok {
return nil, nil
}
var servers []string
for _, server := range strings.Fields(val) {
if net.ParseIP(server) == nil {
logger.Info().Str("server", server).Msg("invalid NTP server IP, ignoring")
}
servers = append(servers, server)
}
return servers, nil
} }
func initNetwork() { func initNetwork() {
@ -218,25 +312,29 @@ func initNetwork() {
done := make(chan struct{}) done := make(chan struct{})
if err := netlink.LinkSubscribe(updates, done); err != nil { if err := netlink.LinkSubscribe(updates, done); err != nil {
logger.Warn().Err(err).Msg("failed to subscribe to link updates") networkLogger.Warn().Err(err).Msg("failed to subscribe to link updates")
return return
} }
// TODO: support multiple interfaces
networkState = NewNetworkInterfaceState(NetIfName)
go networkState.dhcpClient.Run()
if err := networkState.CheckAndUpdateDhcp(); err != nil {
os.Exit(1)
}
go func() { go func() {
waitCtrlClientConnected() waitCtrlClientConnected()
checkNetworkState()
ticker := time.NewTicker(1 * time.Second) ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case update := <-updates: case update := <-updates:
if update.Link.Attrs().Name == NetIfName { networkState.HandleLinkUpdate(update)
logger.Info().Interface("update", update).Msg("link update")
checkNetworkState()
}
case <-ticker.C: case <-ticker.C:
checkNetworkState() _ = networkState.CheckAndUpdateDhcp()
case <-done: case <-done:
return return
} }

214
ntp.go
View File

@ -1,214 +0,0 @@
package kvm
import (
"fmt"
"net/http"
"os/exec"
"strconv"
"time"
"github.com/beevik/ntp"
)
const (
timeSyncRetryStep = 5 * time.Second
timeSyncRetryMaxInt = 1 * time.Minute
timeSyncWaitNetChkInt = 100 * time.Millisecond
timeSyncWaitNetUpInt = 3 * time.Second
timeSyncInterval = 1 * time.Hour
timeSyncTimeout = 2 * time.Second
)
var (
builtTimestamp string
timeSyncRetryInterval = 0 * time.Second
timeSyncSuccess = false
defaultNTPServers = []string{
"time.cloudflare.com",
"time.apple.com",
}
)
func isTimeSyncNeeded() bool {
if builtTimestamp == "" {
ntpLogger.Warn().Msg("built timestamp is not set, time sync is needed")
return true
}
ts, err := strconv.Atoi(builtTimestamp)
if err != nil {
ntpLogger.Warn().Str("error", err.Error()).Msg("failed to parse built timestamp")
return true
}
// builtTimestamp is UNIX timestamp in seconds
builtTime := time.Unix(int64(ts), 0)
now := time.Now()
if now.Sub(builtTime) < 0 {
ntpLogger.Warn().
Str("built_time", builtTime.Format(time.RFC3339)).
Str("now", now.Format(time.RFC3339)).
Msg("system time is behind the built time, time sync is needed")
return true
}
return false
}
func TimeSyncLoop() {
for {
if !networkState.checked {
time.Sleep(timeSyncWaitNetChkInt)
continue
}
if !networkState.IsOnline() {
ntpLogger.Info().Msg("waiting for network to be online")
time.Sleep(timeSyncWaitNetUpInt)
continue
}
// check if time sync is needed, but do nothing for now
isTimeSyncNeeded()
ntpLogger.Info().Msg("syncing system time")
start := time.Now()
err := SyncSystemTime()
if err != nil {
ntpLogger.Error().Str("error", err.Error()).Msg("failed to sync system time")
// retry after a delay
timeSyncRetryInterval += timeSyncRetryStep
time.Sleep(timeSyncRetryInterval)
// reset the retry interval if it exceeds the max interval
if timeSyncRetryInterval > timeSyncRetryMaxInt {
timeSyncRetryInterval = 0
}
continue
}
timeSyncSuccess = true
ntpLogger.Info().Str("now", time.Now().Format(time.RFC3339)).
Str("time_taken", time.Since(start).String()).
Msg("time sync successful")
time.Sleep(timeSyncInterval) // after the first sync is done
}
}
func SyncSystemTime() (err error) {
now, err := queryNetworkTime()
if err != nil {
return fmt.Errorf("failed to query network time: %w", err)
}
err = setSystemTime(*now)
if err != nil {
return fmt.Errorf("failed to set system time: %w", err)
}
return nil
}
func queryNetworkTime() (*time.Time, error) {
ntpServers, err := getNTPServersFromDHCPInfo()
if err != nil {
ntpLogger.Info().Err(err).Msg("failed to get NTP servers from DHCP info")
}
if ntpServers == nil {
ntpServers = defaultNTPServers
ntpLogger.Info().
Interface("ntp_servers", ntpServers).
Msg("using default NTP servers")
} else {
ntpLogger.Info().
Interface("ntp_servers", ntpServers).
Msg("using NTP servers from DHCP")
}
for _, server := range ntpServers {
now, err, response := queryNtpServer(server, timeSyncTimeout)
scopedLogger := ntpLogger.With().
Str("server", server).
Logger()
if err == nil {
scopedLogger.Info().
Str("time", now.Format(time.RFC3339)).
Str("reference", response.ReferenceString()).
Str("rtt", response.RTT.String()).
Str("clockOffset", response.ClockOffset.String()).
Uint8("stratum", response.Stratum).
Msg("NTP server returned time")
return now, nil
} else {
scopedLogger.Error().
Str("error", err.Error()).
Msg("failed to query NTP server")
}
}
httpUrls := []string{
"http://apple.com",
"http://cloudflare.com",
}
for _, url := range httpUrls {
now, err, response := queryHttpTime(url, timeSyncTimeout)
var status string
if response != nil {
status = response.Status
}
scopedLogger := ntpLogger.With().
Str("http_url", url).
Str("status", status).
Logger()
if err == nil {
scopedLogger.Info().
Str("time", now.Format(time.RFC3339)).
Msg("HTTP server returned time")
return now, nil
} else {
scopedLogger.Error().
Str("error", err.Error()).
Msg("failed to query HTTP server")
}
}
return nil, ErrorfL(ntpLogger, "failed to query network time, all NTP servers and HTTP servers failed", nil)
}
func queryNtpServer(server string, timeout time.Duration) (now *time.Time, err error, response *ntp.Response) {
resp, err := ntp.QueryWithOptions(server, ntp.QueryOptions{Timeout: timeout})
if err != nil {
return nil, err, nil
}
return &resp.Time, nil, resp
}
func queryHttpTime(url string, timeout time.Duration) (now *time.Time, err error, response *http.Response) {
client := http.Client{
Timeout: timeout,
}
resp, err := client.Head(url)
if err != nil {
return nil, err, nil
}
dateStr := resp.Header.Get("Date")
parsedTime, err := time.Parse(time.RFC1123, dateStr)
if err != nil {
return nil, err, resp
}
return &parsedTime, nil, resp
}
func setSystemTime(now time.Time) error {
nowStr := now.Format("2006-01-02 15:04:05")
output, err := exec.Command("date", "-s", nowStr).CombinedOutput()
if err != nil {
return fmt.Errorf("failed to run date -s: %w, %s", err, string(output))
}
return nil
}

69
timesync.go Normal file
View File

@ -0,0 +1,69 @@
package kvm
import (
"strconv"
"time"
"github.com/jetkvm/kvm/internal/timesync"
)
const (
timeSyncRetryStep = 5 * time.Second
timeSyncRetryMaxInt = 1 * time.Minute
timeSyncWaitNetChkInt = 100 * time.Millisecond
timeSyncWaitNetUpInt = 3 * time.Second
)
var (
timeSync *timesync.TimeSync
defaultNTPServers = []string{
"time.cloudflare.com",
"time.apple.com",
}
defaultHTTPUrls = []string{
"http://apple.com",
"http://cloudflare.com",
}
builtTimestamp string
)
func isTimeSyncNeeded() bool {
if builtTimestamp == "" {
timesyncLogger.Warn().Msg("built timestamp is not set, time sync is needed")
return true
}
ts, err := strconv.Atoi(builtTimestamp)
if err != nil {
timesyncLogger.Warn().Str("error", err.Error()).Msg("failed to parse built timestamp")
return true
}
// builtTimestamp is UNIX timestamp in seconds
builtTime := time.Unix(int64(ts), 0)
now := time.Now()
if now.Sub(builtTime) < 0 {
timesyncLogger.Warn().
Str("built_time", builtTime.Format(time.RFC3339)).
Str("now", now.Format(time.RFC3339)).
Msg("system time is behind the built time, time sync is needed")
return true
}
return false
}
func initTimeSync() {
timeSync = timesync.NewTimeSync(
func() (bool, error) {
if !networkState.IsOnline() {
return false, nil
}
return true, nil
},
defaultNTPServers,
defaultHTTPUrls,
timesyncLogger,
)
}

View File

@ -53,7 +53,7 @@ func initCertStore() {
func getCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { func getCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
if config.TLSMode == "self-signed" { if config.TLSMode == "self-signed" {
if isTimeSyncNeeded() || !timeSyncSuccess { if isTimeSyncNeeded() || !timeSync.IsSyncSuccess() {
return nil, fmt.Errorf("time is not synced") return nil, fmt.Errorf("time is not synced")
} }
return certSigner.GetCertificate(info) return certSigner.GetCertificate(info)