package mdns

import (
	"fmt"
	"net"
	"reflect"
	"strings"
	"sync"

	"github.com/jetkvm/kvm/internal/logging"
	pion_mdns "github.com/pion/mdns/v2"
	"github.com/rs/zerolog"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

type MDNS struct {
	conn *pion_mdns.Conn
	lock sync.Mutex
	l    *zerolog.Logger

	localNames    []string
	listenOptions *MDNSListenOptions
}

type MDNSListenOptions struct {
	IPv4 bool
	IPv6 bool
}

type MDNSOptions struct {
	Logger        *zerolog.Logger
	LocalNames    []string
	ListenOptions *MDNSListenOptions
}

const (
	DefaultAddressIPv4 = pion_mdns.DefaultAddressIPv4
	DefaultAddressIPv6 = pion_mdns.DefaultAddressIPv6
)

func NewMDNS(opts *MDNSOptions) (*MDNS, error) {
	if opts.Logger == nil {
		opts.Logger = logging.GetDefaultLogger()
	}

	if opts.ListenOptions == nil {
		opts.ListenOptions = &MDNSListenOptions{
			IPv4: true,
			IPv6: true,
		}
	}

	return &MDNS{
		l:             opts.Logger,
		lock:          sync.Mutex{},
		localNames:    opts.LocalNames,
		listenOptions: opts.ListenOptions,
	}, nil
}

func (m *MDNS) start(allowRestart bool) error {
	m.lock.Lock()
	defer m.lock.Unlock()

	if m.conn != nil {
		if !allowRestart {
			return fmt.Errorf("mDNS server already running")
		}

		m.conn.Close()
	}

	if m.listenOptions == nil {
		return fmt.Errorf("listen options not set")
	}

	if !m.listenOptions.IPv4 && !m.listenOptions.IPv6 {
		m.l.Info().Msg("mDNS server disabled")
		return nil
	}

	var (
		addr4, addr6 *net.UDPAddr
		l4, l6       *net.UDPConn
		p4           *ipv4.PacketConn
		p6           *ipv6.PacketConn
		err          error
	)

	if m.listenOptions.IPv4 {
		addr4, err = net.ResolveUDPAddr("udp4", DefaultAddressIPv4)
		if err != nil {
			return err
		}

		l4, err = net.ListenUDP("udp4", addr4)
		if err != nil {
			return err
		}

		p4 = ipv4.NewPacketConn(l4)
	}

	if m.listenOptions.IPv6 {
		addr6, err = net.ResolveUDPAddr("udp6", DefaultAddressIPv6)
		if err != nil {
			return err
		}

		l6, err = net.ListenUDP("udp6", addr6)
		if err != nil {
			return err
		}

		p6 = ipv6.NewPacketConn(l6)
	}

	scopeLogger := m.l.With().
		Interface("local_names", m.localNames).
		Bool("ipv4", m.listenOptions.IPv4).
		Bool("ipv6", m.listenOptions.IPv6).
		Logger()

	newLocalNames := make([]string, len(m.localNames))
	for i, name := range m.localNames {
		newLocalNames[i] = strings.TrimRight(strings.ToLower(name), ".")
		if !strings.HasSuffix(newLocalNames[i], ".local") {
			newLocalNames[i] = newLocalNames[i] + ".local"
		}
	}

	mDNSConn, err := pion_mdns.Server(p4, p6, &pion_mdns.Config{
		LocalNames:    newLocalNames,
		LoggerFactory: logging.GetPionDefaultLoggerFactory(),
	})

	if err != nil {
		scopeLogger.Warn().Err(err).Msg("failed to start mDNS server")
		return err
	}

	m.conn = mDNSConn
	scopeLogger.Info().Msg("mDNS server started")

	return nil
}

func (m *MDNS) Start() error {
	return m.start(false)
}

func (m *MDNS) Restart() error {
	return m.start(true)
}

func (m *MDNS) Stop() error {
	m.lock.Lock()
	defer m.lock.Unlock()

	if m.conn == nil {
		return nil
	}

	return m.conn.Close()
}

func (m *MDNS) SetLocalNames(localNames []string, always bool) error {
	if reflect.DeepEqual(m.localNames, localNames) && !always {
		return nil
	}

	m.localNames = localNames
	_ = m.Restart()

	return nil
}

func (m *MDNS) SetListenOptions(listenOptions *MDNSListenOptions) error {
	if m.listenOptions != nil &&
		m.listenOptions.IPv4 == listenOptions.IPv4 &&
		m.listenOptions.IPv6 == listenOptions.IPv6 {
		return nil
	}

	m.listenOptions = listenOptions
	_ = m.Restart()

	return nil
}