refactor(mdms): move mdns to internal/mdns package

This commit is contained in:
Siyuan Miao 2025-04-15 00:46:46 +02:00
parent 08021f912e
commit b24191d14e
13 changed files with 491 additions and 88 deletions

View File

@ -6,6 +6,7 @@ import (
"os"
"sync"
"github.com/jetkvm/kvm/internal/logging"
"github.com/jetkvm/kvm/internal/network"
"github.com/jetkvm/kvm/internal/usbgadget"
)
@ -123,6 +124,7 @@ var defaultConfig = &Config{
Keyboard: true,
MassStorage: true,
},
NetworkConfig: &network.NetworkConfig{},
DefaultLogLevel: "INFO",
}
@ -172,7 +174,7 @@ func LoadConfig() {
config = &loadedConfig
rootLogger.UpdateLogLevel()
logging.GetRootLogger().UpdateLogLevel(config.DefaultLogLevel)
logger.Info().Str("path", configPath).Msg("config loaded")
}

10
hw.go
View File

@ -4,6 +4,7 @@ import (
"fmt"
"os"
"regexp"
"strings"
"sync"
"time"
)
@ -51,6 +52,15 @@ func GetDeviceID() string {
return deviceID
}
func GetDefaultHostname() string {
deviceId := GetDeviceID()
if deviceId == "unknown_device_id" {
return "jetkvm"
}
return fmt.Sprintf("jetkvm-%s", strings.ToLower(deviceId))
}
func runWatchdog() {
file, err := os.OpenFile("/dev/watchdog", os.O_WRONLY, 0)
if err != nil {

154
internal/mdns/mdns.go Normal file
View File

@ -0,0 +1,154 @@
package mdns
import (
"fmt"
"net"
"reflect"
"strings"
"sync"
"github.com/jetkvm/kvm/internal/logging"
pion_mdns "github.com/pion/mdns/v2"
"github.com/rs/zerolog"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
type MDNS struct {
conn *pion_mdns.Conn
lock sync.Mutex
l *zerolog.Logger
localNames []string
listenOptions *MDNSListenOptions
}
type MDNSListenOptions struct {
IPv4 bool
IPv6 bool
}
type MDNSOptions struct {
Logger *zerolog.Logger
LocalNames []string
ListenOptions *MDNSListenOptions
}
const (
DefaultAddressIPv4 = pion_mdns.DefaultAddressIPv4
DefaultAddressIPv6 = pion_mdns.DefaultAddressIPv6
)
func NewMDNS(opts *MDNSOptions) (*MDNS, error) {
if opts.Logger == nil {
opts.Logger = logging.GetDefaultLogger()
}
return &MDNS{
l: opts.Logger,
lock: sync.Mutex{},
localNames: opts.LocalNames,
listenOptions: opts.ListenOptions,
}, nil
}
func (m *MDNS) start(allowRestart bool) error {
m.lock.Lock()
defer m.lock.Unlock()
if m.conn != nil {
if !allowRestart {
return fmt.Errorf("mDNS server already running")
}
m.conn.Close()
}
addr4, err := net.ResolveUDPAddr("udp4", DefaultAddressIPv4)
if err != nil {
return err
}
addr6, err := net.ResolveUDPAddr("udp6", DefaultAddressIPv6)
if err != nil {
return err
}
l4, err := net.ListenUDP("udp4", addr4)
if err != nil {
return err
}
l6, err := net.ListenUDP("udp6", addr6)
if err != nil {
return err
}
scopeLogger := m.l.With().Interface("local_names", m.localNames).Logger()
newLocalNames := make([]string, len(m.localNames))
for i, name := range m.localNames {
newLocalNames[i] = strings.TrimRight(strings.ToLower(name), ".")
if !strings.HasSuffix(newLocalNames[i], ".local") {
newLocalNames[i] = newLocalNames[i] + ".local"
}
}
mDNSConn, err := pion_mdns.Server(ipv4.NewPacketConn(l4), ipv6.NewPacketConn(l6), &pion_mdns.Config{
LocalNames: newLocalNames,
LoggerFactory: logging.GetPionDefaultLoggerFactory(),
})
if err != nil {
scopeLogger.Warn().Err(err).Msg("failed to start mDNS server")
return err
}
m.conn = mDNSConn
scopeLogger.Info().Msg("mDNS server started")
return nil
}
func (m *MDNS) Start() error {
return m.start(false)
}
func (m *MDNS) Restart() error {
return m.start(true)
}
func (m *MDNS) Stop() error {
m.lock.Lock()
defer m.lock.Unlock()
if m.conn == nil {
return nil
}
return m.conn.Close()
}
func (m *MDNS) SetLocalNames(localNames []string, always bool) error {
if reflect.DeepEqual(m.localNames, localNames) && !always {
return nil
}
m.localNames = localNames
m.Restart()
return nil
}
func (m *MDNS) SetListenOptions(listenOptions *MDNSListenOptions) error {
if m.listenOptions != nil &&
m.listenOptions.IPv4 == listenOptions.IPv4 &&
m.listenOptions.IPv6 == listenOptions.IPv6 {
return nil
}
m.listenOptions = listenOptions
m.Restart()
return nil
}

1
internal/mdns/utils.go Normal file
View File

@ -0,0 +1 @@
package mdns

View File

@ -1,8 +1,11 @@
package network
import (
"fmt"
"net"
"time"
"golang.org/x/net/idna"
)
type IPv6Address struct {
@ -33,8 +36,51 @@ type NetworkConfig struct {
DNS []string `json:"dns" validate_type:"ipv6"`
} `json:"ipv6_static,omitempty" required_if:"ipv6_mode,static"`
LLDPMode string `json:"lldp_mode,omitempty" one_of:"disabled,basic,all" default:"basic"`
LLDPTxTLVs []string `json:"lldp_tx_tlvs,omitempty" one_of:"chassis,port,system,vlan" default:"chassis,port,system,vlan"`
MDNSMode string `json:"mdns_mode,omitempty" one_of:"disabled,auto,ipv4_only,ipv6_only" default:"auto"`
TimeSyncMode string `json:"time_sync_mode,omitempty" one_of:"ntp_only,ntp_and_http,http_only,custom" default:"ntp_and_http"`
LLDPMode string `json:"lldp_mode,omitempty" one_of:"disabled,basic,all" default:"basic"`
LLDPTxTLVs []string `json:"lldp_tx_tlvs,omitempty" one_of:"chassis,port,system,vlan" default:"chassis,port,system,vlan"`
MDNSMode string `json:"mdns_mode,omitempty" one_of:"disabled,auto,ipv4_only,ipv6_only" default:"auto"`
TimeSyncMode string `json:"time_sync_mode,omitempty" one_of:"ntp_only,ntp_and_http,http_only,custom" default:"ntp_and_http"`
TimeSyncOrdering []string `json:"time_sync_ordering,omitempty" one_of:"http,ntp,ntp_dhcp,ntp_user_provided,ntp_fallback" default:"ntp,http"`
TimeSyncDisableFallback bool `json:"time_sync_disable_fallback,omitempty" default:"false"`
TimeSyncParallel int `json:"time_sync_parallel,omitempty" default:"4"`
}
func (s *NetworkInterfaceState) GetHostname() string {
hostname := ToValidHostname(s.config.Hostname)
if hostname == "" {
return s.defaultHostname
}
return hostname
}
func ToValidDomain(domain string) string {
ascii, err := idna.Lookup.ToASCII(domain)
if err != nil {
return ""
}
return ascii
}
func (s *NetworkInterfaceState) GetDomain() string {
domain := ToValidDomain(s.config.Domain)
if domain == "" {
lease := s.dhcpClient.GetLease()
if lease != nil && lease.Domain != "" {
domain = ToValidDomain(lease.Domain)
}
}
if domain == "" {
return "local"
}
return domain
}
func (s *NetworkInterfaceState) GetFQDN() string {
return fmt.Sprintf("%s.%s", s.GetHostname(), s.GetDomain())
}

View File

@ -0,0 +1,124 @@
package network
import (
"fmt"
"io"
"os"
"os/exec"
"strings"
"sync"
"golang.org/x/net/idna"
)
const (
hostnamePath = "/etc/hostname"
hostsPath = "/etc/hosts"
)
var (
hostnameLock sync.Mutex = sync.Mutex{}
)
func updateEtcHosts(hostname string, fqdn string) error {
// update /etc/hosts
hostsFile, err := os.OpenFile(hostsPath, os.O_RDWR|os.O_SYNC, os.ModeExclusive)
if err != nil {
return fmt.Errorf("failed to open %s: %w", hostsPath, err)
}
defer hostsFile.Close()
// read all lines
hostsFile.Seek(0, io.SeekStart)
lines, err := io.ReadAll(hostsFile)
if err != nil {
return fmt.Errorf("failed to read %s: %w", hostsPath, err)
}
newLines := []string{}
hostLine := fmt.Sprintf("127.0.1.1\t%s %s", hostname, fqdn)
hostLineExists := false
for _, line := range strings.Split(string(lines), "\n") {
if strings.HasPrefix(line, "127.0.1.1") {
hostLineExists = true
line = hostLine
}
newLines = append(newLines, line)
}
if !hostLineExists {
newLines = append(newLines, hostLine)
}
hostsFile.Truncate(0)
hostsFile.Seek(0, io.SeekStart)
hostsFile.Write([]byte(strings.Join(newLines, "\n")))
return nil
}
func ToValidHostname(hostname string) string {
ascii, err := idna.Lookup.ToASCII(hostname)
if err != nil {
return ""
}
return ascii
}
func SetHostname(hostname string, fqdn string) error {
hostnameLock.Lock()
defer hostnameLock.Unlock()
hostname = ToValidHostname(strings.TrimSpace(hostname))
fqdn = ToValidHostname(strings.TrimSpace(fqdn))
if hostname == "" {
return fmt.Errorf("invalid hostname: %s", hostname)
}
if fqdn == "" {
fqdn = hostname
}
// update /etc/hostname
os.WriteFile(hostnamePath, []byte(hostname), 0644)
// update /etc/hosts
if err := updateEtcHosts(hostname, fqdn); err != nil {
return fmt.Errorf("failed to update /etc/hosts: %w", err)
}
// run hostname
if err := exec.Command("hostname", "-F", hostnamePath).Run(); err != nil {
return fmt.Errorf("failed to run hostname: %w", err)
}
return nil
}
func (s *NetworkInterfaceState) setHostnameIfNotSame() error {
hostname := s.GetHostname()
currentHostname, _ := os.Hostname()
fqdn := fmt.Sprintf("%s.%s", hostname, s.GetDomain())
if currentHostname == hostname && s.currentFqdn == fqdn && s.currentHostname == hostname {
return nil
}
scopedLogger := s.l.With().Str("hostname", hostname).Str("fqdn", fqdn).Logger()
err := SetHostname(hostname, fqdn)
if err != nil {
scopedLogger.Error().Err(err).Msg("failed to set hostname")
return err
}
s.currentHostname = hostname
s.currentFqdn = fqdn
scopedLogger.Info().Msg("hostname set")
return nil
}

View File

@ -1,6 +1,7 @@
package network
import (
"fmt"
"net"
"sync"
"time"
@ -29,6 +30,10 @@ type NetworkInterfaceState struct {
config *NetworkConfig
dhcpClient *udhcpc.DHCPClient
defaultHostname string
currentHostname string
currentFqdn string
onStateChange func(state *NetworkInterfaceState)
onInitialCheck func(state *NetworkInterfaceState)
@ -39,21 +44,31 @@ type NetworkInterfaceOptions struct {
InterfaceName string
DhcpPidFile string
Logger *zerolog.Logger
DefaultHostname string
OnStateChange func(state *NetworkInterfaceState)
OnInitialCheck func(state *NetworkInterfaceState)
OnDhcpLeaseChange func(lease *udhcpc.Lease)
NetworkConfig *NetworkConfig
}
func NewNetworkInterfaceState(opts *NetworkInterfaceOptions) *NetworkInterfaceState {
func NewNetworkInterfaceState(opts *NetworkInterfaceOptions) (*NetworkInterfaceState, error) {
if opts.NetworkConfig == nil {
return nil, fmt.Errorf("NetworkConfig can not be nil")
}
if opts.DefaultHostname == "" {
opts.DefaultHostname = "jetkvm"
}
l := opts.Logger
s := &NetworkInterfaceState{
interfaceName: opts.InterfaceName,
stateLock: sync.Mutex{},
l: l,
onStateChange: opts.OnStateChange,
onInitialCheck: opts.OnInitialCheck,
config: opts.NetworkConfig,
interfaceName: opts.InterfaceName,
defaultHostname: opts.DefaultHostname,
stateLock: sync.Mutex{},
l: l,
onStateChange: opts.OnStateChange,
onInitialCheck: opts.OnInitialCheck,
config: opts.NetworkConfig,
}
// create the dhcp client
@ -68,13 +83,15 @@ func NewNetworkInterfaceState(opts *NetworkInterfaceOptions) *NetworkInterfaceSt
return
}
s.setHostnameIfNotSame()
opts.OnDhcpLeaseChange(lease)
},
})
s.dhcpClient = dhcpClient
return s
return s, nil
}
func (s *NetworkInterfaceState) IsUp() bool {
@ -277,6 +294,12 @@ func (s *NetworkInterfaceState) update() (DhcpTargetState, error) {
if initialCheck {
s.checked = true
changed = false
if dhcpTargetState == DhcpTargetStateRenew {
// it's the initial check, we'll start the DHCP client
// dhcpTargetState = DhcpTargetStateStart
// TODO: manage DHCP client start/stop
dhcpTargetState = DhcpTargetStateDoNothing
}
}
if initialCheck {
@ -326,6 +349,8 @@ func (s *NetworkInterfaceState) Run() error {
return err
}
_ = s.setHostnameIfNotSame()
// run the dhcp client
go s.dhcpClient.Run() // nolint:errcheck

View File

@ -39,14 +39,18 @@ type RpcNetworkSettings struct {
func (s *NetworkInterfaceState) RpcGetNetworkState() RpcNetworkState {
ipv6Addresses := make([]RpcIPv6Address, 0)
for _, addr := range s.ipv6Addresses {
ipv6Addresses = append(ipv6Addresses, RpcIPv6Address{
Address: addr.Prefix.String(),
ValidLifetime: addr.ValidLifetime,
PreferredLifetime: addr.PreferredLifetime,
Scope: addr.Scope,
})
if s.ipv6Addresses != nil {
for _, addr := range s.ipv6Addresses {
ipv6Addresses = append(ipv6Addresses, RpcIPv6Address{
Address: addr.Prefix.String(),
ValidLifetime: addr.ValidLifetime,
PreferredLifetime: addr.PreferredLifetime,
Scope: addr.Scope,
})
}
}
return RpcNetworkState{
InterfaceName: s.interfaceName,
MacAddress: s.macAddr.String(),
@ -60,6 +64,10 @@ func (s *NetworkInterfaceState) RpcGetNetworkState() RpcNetworkState {
}
func (s *NetworkInterfaceState) RpcGetNetworkSettings() RpcNetworkSettings {
if s.config == nil {
return RpcNetworkSettings{}
}
return RpcNetworkSettings{
Hostname: null.StringFrom(s.config.Hostname),
Domain: null.StringFrom(s.config.Domain),

View File

@ -960,10 +960,10 @@ var rpcHandlers = map[string]RPCHandler{
"getDeviceID": {Func: rpcGetDeviceID},
"deregisterDevice": {Func: rpcDeregisterDevice},
"getCloudState": {Func: rpcGetCloudState},
"getNetworkState": {Func: networkState.RpcGetNetworkState},
"getNetworkSettings": {Func: networkState.RpcGetNetworkSettings},
"setNetworkSettings": {Func: networkState.RpcSetNetworkSettings, Params: []string{"settings"}},
"renewDHCPLease": {Func: networkState.RpcRenewDHCPLease},
"getNetworkState": {Func: rpcGetNetworkState},
"getNetworkSettings": {Func: rpcGetNetworkSettings},
"setNetworkSettings": {Func: rpcSetNetworkSettings, Params: []string{"settings"}},
"renewDHCPLease": {Func: rpcRenewDHCPLease},
"keyboardReport": {Func: rpcKeyboardReport, Params: []string{"modifier", "keys"}},
"absMouseReport": {Func: rpcAbsMouseReport, Params: []string{"x", "y", "buttons"}},
"relMouseReport": {Func: rpcRelMouseReport, Params: []string{"dx", "dy", "buttons"}},

18
main.go
View File

@ -43,12 +43,26 @@ func Main() {
Int("ca_certs_loaded", len(rootcerts.Certs())).
Msg("loaded Root CA certificates")
initNetwork()
initTimeSync()
// Initialize network
if err := initNetwork(); err != nil {
logger.Error().Err(err).Msg("failed to initialize network")
os.Exit(1)
}
// Initialize time sync
initTimeSync()
timeSync.Start()
// Initialize mDNS
if err := initMdns(); err != nil {
logger.Error().Err(err).Msg("failed to initialize mDNS")
os.Exit(1)
}
// Initialize native ctrl socket server
StartNativeCtrlSocketServer()
// Initialize native video socket server
StartNativeVideoSocketServer()
initPrometheus()

69
mdns.go
View File

@ -1,60 +1,33 @@
package kvm
import (
"net"
"github.com/pion/mdns/v2"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"github.com/jetkvm/kvm/internal/mdns"
)
var mDNSConn *mdns.Conn
var mDNS *mdns.MDNS
func startMDNS() error {
// If server was previously running, stop it
if mDNSConn != nil {
logger.Info().Msg("stopping mDNS server")
err := mDNSConn.Close()
if err != nil {
logger.Warn().Err(err).Msg("failed to stop mDNS server")
}
}
// Start a new server
hostname := "jetkvm.local"
scopedLogger := logger.With().Str("hostname", hostname).Logger()
scopedLogger.Info().Msg("starting mDNS server")
addr4, err := net.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4)
if err != nil {
return err
}
addr6, err := net.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6)
if err != nil {
return err
}
l4, err := net.ListenUDP("udp4", addr4)
if err != nil {
return err
}
l6, err := net.ListenUDP("udp6", addr6)
if err != nil {
return err
}
mDNSConn, err = mdns.Server(ipv4.NewPacketConn(l4), ipv6.NewPacketConn(l6), &mdns.Config{
LocalNames: []string{hostname}, //TODO: make it configurable
LoggerFactory: defaultLoggerFactory,
func initMdns() error {
m, err := mdns.NewMDNS(&mdns.MDNSOptions{
Logger: logger,
LocalNames: []string{
networkState.GetHostname(),
networkState.GetFQDN(),
},
ListenOptions: &mdns.MDNSListenOptions{
IPv4: true,
IPv6: true,
},
})
if err != nil {
scopedLogger.Warn().Err(err).Msg("failed to start mDNS server")
mDNSConn = nil
return err
}
//defer server.Close()
err = m.Start()
if err != nil {
return err
}
mDNS = m
return nil
}

View File

@ -1,7 +1,7 @@
package kvm
import (
"os"
"fmt"
"github.com/jetkvm/kvm/internal/network"
"github.com/jetkvm/kvm/internal/udhcpc"
@ -15,27 +15,72 @@ var (
networkState *network.NetworkInterfaceState
)
func initNetwork() {
func networkStateChanged() {
// do not block the main thread
go waitCtrlAndRequestDisplayUpdate(true)
// always restart mDNS when the network state changes
if mDNS != nil {
mDNS.SetLocalNames([]string{
networkState.GetHostname(),
networkState.GetFQDN(),
}, true)
}
}
func initNetwork() error {
ensureConfigLoaded()
networkState = network.NewNetworkInterfaceState(&network.NetworkInterfaceOptions{
InterfaceName: NetIfName,
NetworkConfig: config.NetworkConfig,
Logger: networkLogger,
state, err := network.NewNetworkInterfaceState(&network.NetworkInterfaceOptions{
DefaultHostname: GetDefaultHostname(),
InterfaceName: NetIfName,
NetworkConfig: config.NetworkConfig,
Logger: networkLogger,
OnStateChange: func(state *network.NetworkInterfaceState) {
waitCtrlAndRequestDisplayUpdate(true)
networkStateChanged()
},
OnInitialCheck: func(state *network.NetworkInterfaceState) {
waitCtrlAndRequestDisplayUpdate(true)
networkStateChanged()
},
OnDhcpLeaseChange: func(lease *udhcpc.Lease) {
waitCtrlAndRequestDisplayUpdate(true)
networkStateChanged()
if currentSession == nil {
return
}
writeJSONRPCEvent("networkState", networkState.RpcGetNetworkState(), currentSession)
},
})
err := networkState.Run()
if err != nil {
networkLogger.Error().Err(err).Msg("failed to run network state")
os.Exit(1)
if state == nil {
if err == nil {
return fmt.Errorf("failed to create NetworkInterfaceState")
}
return err
}
if err := state.Run(); err != nil {
return err
}
networkState = state
return nil
}
func rpcGetNetworkState() network.RpcNetworkState {
return networkState.RpcGetNetworkState()
}
func rpcGetNetworkSettings() network.RpcNetworkSettings {
return networkState.RpcGetNetworkSettings()
}
func rpcSetNetworkSettings(settings network.RpcNetworkSettings) error {
return networkState.RpcSetNetworkSettings(settings)
}
func rpcRenewDHCPLease() error {
return networkState.RpcRenewDHCPLease()
}

View File

@ -10,6 +10,7 @@ import (
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gin-gonic/gin"
"github.com/jetkvm/kvm/internal/logging"
"github.com/pion/webrtc/v4"
"github.com/rs/zerolog"
)
@ -68,7 +69,7 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) {
func newSession(config SessionConfig) (*Session, error) {
webrtcSettingEngine := webrtc.SettingEngine{
LoggerFactory: defaultLoggerFactory,
LoggerFactory: logging.GetPionDefaultLoggerFactory(),
}
iceServer := webrtc.ICEServer{}