kvm/internal/mdns/mdns.go

191 lines
3.5 KiB
Go

package mdns
import (
"fmt"
"net"
"reflect"
"strings"
"sync"
"github.com/jetkvm/kvm/internal/logging"
pion_mdns "github.com/pion/mdns/v2"
"github.com/rs/zerolog"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// MDNS manages a pion mDNS server connection together with the configuration
// needed to start or restart it.
type MDNS struct {
	// conn is the pion mDNS connection created by start; it is nil until a
	// start attempt succeeds.
	conn *pion_mdns.Conn
	// lock is held by start and Stop while they inspect or replace conn.
	lock sync.Mutex
	// l is the logger used for server lifecycle messages.
	l *zerolog.Logger
	// localNames are the host names to answer for. start lower-cases them,
	// trims trailing dots and appends ".local" where missing before passing
	// them to pion.
	localNames []string
	// listenOptions selects which IP families start listens on.
	listenOptions *MDNSListenOptions
}
// MDNSListenOptions selects the IP address families the mDNS server listens
// on. When both fields are false, starting the server logs that mDNS is
// disabled and creates no server.
type MDNSListenOptions struct {
	// IPv4 and IPv6 enable listening on the respective address family.
	IPv4, IPv6 bool
}
// MDNSOptions carries the settings NewMDNS uses to build an MDNS.
type MDNSOptions struct {
	// Logger receives server log messages. NewMDNS falls back to
	// logging.GetDefaultLogger() when it is nil.
	Logger *zerolog.Logger
	// LocalNames are the host names the server should answer for.
	LocalNames []string
	// ListenOptions selects the IP families to listen on. NewMDNS enables
	// both IPv4 and IPv6 when it is nil.
	ListenOptions *MDNSListenOptions
}
// Default mDNS listen addresses, re-exported from pion/mdns. start resolves
// and listens on these for each enabled IP family.
const (
	// DefaultAddressIPv4 is the address used for the "udp4" listener.
	DefaultAddressIPv4 = pion_mdns.DefaultAddressIPv4
	// DefaultAddressIPv6 is the address used for the "udp6" listener.
	DefaultAddressIPv6 = pion_mdns.DefaultAddressIPv6
)
// NewMDNS builds an MDNS from opts without starting the server; call Start to
// begin listening.
//
// A nil opts is treated as an empty MDNSOptions. Missing settings are filled
// with defaults: the logger from logging.GetDefaultLogger() and listen options
// with both IPv4 and IPv6 enabled. The caller's opts struct is never modified.
//
// The returned error is currently always nil.
func NewMDNS(opts *MDNSOptions) (*MDNS, error) {
	// Work on a private copy so that applying defaults does not write back
	// into the caller's struct, and so that a nil opts does not panic.
	var cfg MDNSOptions
	if opts != nil {
		cfg = *opts
	}

	if cfg.Logger == nil {
		cfg.Logger = logging.GetDefaultLogger()
	}

	if cfg.ListenOptions == nil {
		cfg.ListenOptions = &MDNSListenOptions{
			IPv4: true,
			IPv6: true,
		}
	}

	// lock is left at its zero value, which is an unlocked mutex.
	return &MDNS{
		l:             cfg.Logger,
		localNames:    cfg.LocalNames,
		listenOptions: cfg.ListenOptions,
	}, nil
}
// start opens the multicast listeners selected by m.listenOptions and hands
// them to a new pion mDNS server that answers for m.localNames.
//
// If a server is already held, start returns an error unless allowRestart is
// true, in which case the old server is closed and forgotten first. When both
// IPv4 and IPv6 are disabled, no server is created and start returns nil.
//
// On any failure the sockets opened so far are closed and m.conn stays nil.
// The whole operation runs under m.lock.
func (m *MDNS) start(allowRestart bool) error {
	m.lock.Lock()
	defer m.lock.Unlock()

	if m.conn != nil {
		if !allowRestart {
			return fmt.Errorf("mDNS server already running")
		}

		if err := m.conn.Close(); err != nil {
			m.l.Warn().Err(err).Msg("failed to close previous mDNS server")
		}
		// Forget the closed connection immediately. Otherwise an early return
		// below (server disabled, or a setup error) would leave m.conn
		// pointing at a dead server and a later Start would report
		// "already running".
		m.conn = nil
	}

	if m.listenOptions == nil {
		return fmt.Errorf("listen options not set")
	}

	if !m.listenOptions.IPv4 && !m.listenOptions.IPv6 {
		m.l.Info().Msg("mDNS server disabled")
		return nil
	}

	var (
		l4, l6 *net.UDPConn
		p4     *ipv4.PacketConn
		p6     *ipv6.PacketConn
	)

	// closeListeners releases whatever sockets have been opened so far. It is
	// called on every failure path so a half-finished start leaks nothing.
	closeListeners := func() {
		for _, c := range []*net.UDPConn{l4, l6} {
			if c != nil {
				// Best-effort cleanup; the setup error is the one reported.
				_ = c.Close()
			}
		}
	}

	if m.listenOptions.IPv4 {
		addr4, err := net.ResolveUDPAddr("udp4", DefaultAddressIPv4)
		if err != nil {
			return err
		}

		l4, err = net.ListenUDP("udp4", addr4)
		if err != nil {
			return err
		}

		p4 = ipv4.NewPacketConn(l4)
	}

	if m.listenOptions.IPv6 {
		addr6, err := net.ResolveUDPAddr("udp6", DefaultAddressIPv6)
		if err != nil {
			closeListeners()
			return err
		}

		l6, err = net.ListenUDP("udp6", addr6)
		if err != nil {
			closeListeners()
			return err
		}

		p6 = ipv6.NewPacketConn(l6)
	}

	scopeLogger := m.l.With().
		Interface("local_names", m.localNames).
		Bool("ipv4", m.listenOptions.IPv4).
		Bool("ipv6", m.listenOptions.IPv6).
		Logger()

	// Normalize every name to lower case, without trailing dots, and ending
	// in ".local".
	newLocalNames := make([]string, len(m.localNames))
	for i, name := range m.localNames {
		newLocalNames[i] = strings.TrimRight(strings.ToLower(name), ".")
		if !strings.HasSuffix(newLocalNames[i], ".local") {
			newLocalNames[i] = newLocalNames[i] + ".local"
		}
	}

	mDNSConn, err := pion_mdns.Server(p4, p6, &pion_mdns.Config{
		LocalNames:    newLocalNames,
		LoggerFactory: logging.GetPionDefaultLoggerFactory(),
	})
	if err != nil {
		// No server was returned to take ownership of the sockets, so
		// release them here.
		closeListeners()
		scopeLogger.Warn().Err(err).Msg("failed to start mDNS server")
		return err
	}

	// From here on the sockets are presumably owned by the pion connection
	// and released by its Close — TODO confirm against pion/mdns.
	m.conn = mDNSConn
	scopeLogger.Info().Msg("mDNS server started")

	return nil
}
// Start starts the mDNS server with the current local names and listen
// options. It returns an error if a server connection is already held; use
// Restart to replace it.
func (m *MDNS) Start() error {
	return m.start(false)
}
// Restart closes any existing server connection and starts a new one with the
// current local names and listen options. It also works when no server is
// running yet.
func (m *MDNS) Restart() error {
	return m.start(true)
}
// Stop closes the mDNS server connection, if any, and returns the error from
// closing it. Calling Stop when no server is held is a no-op that returns nil.
//
// The connection is forgotten before it is closed, so a later Start can bring
// the server up again instead of failing with "already running".
func (m *MDNS) Stop() error {
	m.lock.Lock()
	defer m.lock.Unlock()

	if m.conn == nil {
		return nil
	}

	conn := m.conn
	m.conn = nil

	return conn.Close()
}
// SetLocalNames replaces the host names the server answers for and restarts
// the server so the change takes effect.
//
// When localNames is deeply equal to the current names, nothing happens unless
// always is true, which forces the restart. The error from the restart is
// returned to the caller.
func (m *MDNS) SetLocalNames(localNames []string, always bool) error {
	// start reads m.localNames under m.lock, so the comparison and the
	// assignment take the same lock. It is released before Restart because
	// start acquires it again and sync.Mutex is not reentrant.
	m.lock.Lock()
	if reflect.DeepEqual(m.localNames, localNames) && !always {
		m.lock.Unlock()
		return nil
	}
	m.localNames = localNames
	m.lock.Unlock()

	return m.Restart()
}
// SetListenOptions replaces the IP families the server listens on and
// restarts the server so the change takes effect.
//
// A nil listenOptions is rejected with an error instead of being
// dereferenced. When the new options match the current ones, nothing happens.
// The error from the restart is returned to the caller.
func (m *MDNS) SetListenOptions(listenOptions *MDNSListenOptions) error {
	if listenOptions == nil {
		return fmt.Errorf("listen options not set")
	}

	// start reads m.listenOptions under m.lock, so the comparison and the
	// assignment take the same lock. It is released before Restart because
	// start acquires it again and sync.Mutex is not reentrant.
	m.lock.Lock()
	current := m.listenOptions
	if current != nil &&
		current.IPv4 == listenOptions.IPv4 &&
		current.IPv6 == listenOptions.IPv6 {
		m.lock.Unlock()
		return nil
	}
	m.listenOptions = listenOptions
	m.lock.Unlock()

	return m.Restart()
}