// kvm/pkg/nmlite/link/manager.go

package link
import (
"context"
"fmt"
"net"
"time"
"github.com/jetkvm/kvm/internal/sync"
"github.com/jetkvm/kvm/internal/network/types"
"github.com/rs/zerolog"
"github.com/vishvananda/netlink"
)
type (
	// StateChangeHandler is invoked with the link whose state changed.
	StateChangeHandler func(link *Link)

	// StateChangeCallback pairs a StateChangeHandler with its dispatch mode.
	// When Async is true the handler is started on its own goroutine;
	// otherwise it runs inline on the goroutine that delivers link updates.
	StateChangeCallback struct {
		Async bool
		Func  StateChangeHandler
	}
)
// NetlinkManager provides centralized netlink operations.
//
// It wraps vishvananda/netlink calls for links, addresses and routes, and
// dispatches per-interface link state callbacks registered through
// AddStateChangeCallback.
type NetlinkManager struct {
	logger *zerolog.Logger
	// mu guards stateChangeCallbacks. Most netlink wrapper methods also take
	// it for reading, which does not serialize them against one another.
	mu sync.RWMutex
	// stateChangeCallbacks maps an interface name to the callbacks run, in
	// registration order, when a link update for that interface arrives.
	stateChangeCallbacks map[string][]StateChangeCallback
}
// newNetlinkManager builds a NetlinkManager and immediately starts its link
// state monitor. A nil logger is replaced by an empty zerolog.Logger, which
// serves as a no-op logger.
func newNetlinkManager(logger *zerolog.Logger) *NetlinkManager {
	nm := &NetlinkManager{
		logger:               logger,
		stateChangeCallbacks: map[string][]StateChangeCallback{},
	}
	if nm.logger == nil {
		nm.logger = &zerolog.Logger{} // Default no-op logger
	}

	nm.monitorStateChange()
	return nm
}
// GetNetlinkManager returns the shared NetlinkManager, creating it with a
// no-op logger if it does not exist yet. Creation goes through
// InitializeNetlinkManager, so both entry points share netlinkManagerOnce and
// whichever runs first decides the logger.
func GetNetlinkManager() *NetlinkManager {
	return InitializeNetlinkManager(nil)
}
// InitializeNetlinkManager creates the shared NetlinkManager with logger and
// returns it. Creation is guarded by netlinkManagerOnce (presumably sync.Once
// semantics), so if the instance already exists — from an earlier call here or
// from GetNetlinkManager — logger is ignored and the existing instance is
// returned unchanged.
func InitializeNetlinkManager(logger *zerolog.Logger) *NetlinkManager {
	create := func() {
		netlinkManagerInstance = newNetlinkManager(logger)
	}
	netlinkManagerOnce.Do(create)
	return netlinkManagerInstance
}
// AddStateChangeCallback registers callback to run whenever a link update for
// the interface named ifname is received. Callbacks for the same interface run
// in the order they were added.
func (nm *NetlinkManager) AddStateChangeCallback(ifname string, callback StateChangeCallback) {
	nm.mu.Lock()
	defer nm.mu.Unlock()

	// Appending to a missing key starts from a nil slice, so no
	// pre-initialisation is needed.
	nm.stateChangeCallbacks[ifname] = append(nm.stateChangeCallbacks[ifname], callback)
}
// Interface operations
// monitorStateChange subscribes to kernel link updates and starts a goroutine
// that hands each update to runCallbacks.
//
// If the subscription cannot be established the failure is logged and no
// dispatcher goroutine is started: nothing would ever be sent on updateCh, so
// such a goroutine would just block forever.
func (nm *NetlinkManager) monitorStateChange() {
	updateCh := make(chan netlink.LinkUpdate)
	// we don't need to stop the subscription, as it will be closed when the program exits
	stopCh := make(chan struct{}) //nolint:unused
	if err := netlink.LinkSubscribe(updateCh, stopCh); err != nil {
		nm.logger.Error().Err(err).Msg("failed to subscribe to link state changes")
		return
	}
	nm.logger.Info().Msg("state change monitoring started")

	go func() {
		for update := range updateCh {
			nm.runCallbacks(update)
		}
	}()
}
// runCallbacks invokes every callback registered for the interface named in
// update. Synchronous callbacks run on the caller's goroutine in registration
// order; asynchronous ones each get their own goroutine. Every callback
// receives its own *Link wrapping update.Link.
//
// The callback list is copied under the read lock and the lock is released
// before any callback runs. Holding it across the calls would deadlock a
// synchronous callback that registers another callback (AddStateChangeCallback
// takes the write lock). With RWMutex semantics, it would also deadlock a
// callback that re-enters nm's read-locked methods while a writer is waiting.
func (nm *NetlinkManager) runCallbacks(update netlink.LinkUpdate) {
	ifname := update.Link.Attrs().Name
	l := nm.logger.With().Str("interface", ifname).Logger()

	nm.mu.RLock()
	callbacks := append([]StateChangeCallback(nil), nm.stateChangeCallbacks[ifname]...)
	nm.mu.RUnlock()

	if len(callbacks) == 0 {
		l.Trace().Msg("no state change callbacks for interface")
		return
	}

	for i, callback := range callbacks {
		// The callback struct holds a func, which cannot be JSON-encoded, so
		// only its position and mode are logged.
		l.Trace().
			Int("index", i).
			Bool("async", callback.Async).
			Msg("calling callback")

		if callback.Async {
			go callback.Func(&Link{Link: update.Link})
			continue
		}
		callback.Func(&Link{Link: update.Link})
	}
}
// GetLinkByName looks up the network link called name and wraps it in a
// *Link. The lookup error from netlink is returned unchanged.
func (nm *NetlinkManager) GetLinkByName(name string) (*Link, error) {
	nm.mu.RLock()
	defer nm.mu.RUnlock()

	var (
		found netlink.Link
		err   error
	)
	if found, err = netlink.LinkByName(name); err != nil {
		return nil, err
	}
	return &Link{Link: found}, nil
}
// LinkSetUp administratively enables link (like `ip link set <dev> up`) and
// returns netlink's error, if any.
func (nm *NetlinkManager) LinkSetUp(link *Link) error {
	nm.mu.RLock()
	defer nm.mu.RUnlock()

	if err := netlink.LinkSetUp(link); err != nil {
		return err
	}
	return nil
}
// LinkSetDown administratively disables link (like `ip link set <dev> down`)
// and returns netlink's error, if any.
func (nm *NetlinkManager) LinkSetDown(link *Link) error {
	nm.mu.RLock()
	defer nm.mu.RUnlock()

	if err := netlink.LinkSetDown(link); err != nil {
		return err
	}
	return nil
}
// EnsureInterfaceUp brings link up unless its attributes already report
// OperUp. The attributes are not refreshed first, so the decision uses
// whatever state link carried when it was obtained.
func (nm *NetlinkManager) EnsureInterfaceUp(link *Link) error {
	if state := link.Attrs().OperState; state != netlink.OperUp {
		return nm.LinkSetUp(link)
	}
	return nil
}
// EnsureInterfaceUpWithTimeout brings the interface named by iface up,
// retrying every 500ms until it reports an up state, ctx is done, or timeout
// elapses.
//
// The link is re-fetched by name on every attempt, so iface only supplies the
// name. OperUp and OperUnknown are both treated as "up", both before and after
// each attempt. On success the freshly fetched *Link is returned. A failure to
// look the link up by name is returned immediately. On cancellation or timeout
// the last error from bringing the link up or refreshing it is returned if
// there was one; otherwise ErrInterfaceUpCanceled or ErrInterfaceUpTimeout
// respectively.
func (nm *NetlinkManager) EnsureInterfaceUpWithTimeout(ctx context.Context, iface *Link, timeout time.Duration) (*Link, error) {
	ifname := iface.Attrs().Name
	base := nm.logger.With().Str("interface", ifname).Logger()

	isUp := func(state netlink.LinkOperState) bool {
		return state == netlink.OperUp || state == netlink.OperUnknown
	}

	linkUpTimeout := time.After(timeout)
	start := time.Now()

	for attempt := 0; ; attempt++ {
		link, err := nm.GetLinkByName(ifname)
		if err != nil {
			return nil, err
		}

		state := link.Attrs().OperState
		if isUp(state) {
			if attempt > 0 {
				// attempt-1 is the attempt whose LinkSetUp brought it up.
				base.Info().
					Int("attempt", attempt-1).
					Dur("duration", time.Since(start)).
					Str("state", state.String()).
					Msg("interface is up")
			}
			return link, nil
		}

		// Derive a fresh logger from base on every attempt; deriving from the
		// previous iteration's logger would pile up duplicate fields.
		l := base.With().
			Int("attempt", attempt).
			Dur("duration", time.Since(start)).
			Str("state", state.String()).
			Logger()

		l.Info().Msg("bringing up interface")

		// lastErr keeps the most recent failure of this attempt so it can be
		// reported if we give up; a successful Refresh must not mask a failed
		// LinkSetUp.
		var lastErr error
		if err := nm.LinkSetUp(link); err != nil {
			l.Error().Err(err).Msg("interface can't make it up")
			lastErr = err
		}
		if err := link.Refresh(); err != nil {
			l.Error().Err(err).Msg("failed to refresh link attributes")
			lastErr = err
		}

		newState := link.Attrs().OperState
		l = l.With().Str("new_state", newState.String()).Logger()
		if isUp(newState) {
			l.Info().Msg("interface is up")
			return link, nil
		}

		l.Warn().Msg("interface is still down, retrying")
		select {
		case <-time.After(500 * time.Millisecond):
			// next attempt
		case <-ctx.Done():
			if lastErr != nil {
				return nil, lastErr
			}
			return nil, ErrInterfaceUpCanceled
		case <-linkUpTimeout:
			base.Error().
				Int("attempt", attempt+1).
				Dur("duration", time.Since(start)).
				Str("state", newState.String()).
				Msg("interface is still down after timeout")
			if lastErr != nil {
				return nil, lastErr
			}
			return nil, ErrInterfaceUpTimeout
		}
	}
}
// Address operations
// AddrList returns the addresses of the given family assigned to link. A nil
// link lists addresses across all links.
//
// A nil *Link is converted to an untyped nil netlink.Link before the call.
// Passing it straight through would give netlink a non-nil interface wrapping
// a nil pointer, which netlink would treat as a real link filter.
func (nm *NetlinkManager) AddrList(link *Link, family int) ([]netlink.Addr, error) {
	nm.mu.RLock()
	defer nm.mu.RUnlock()

	var filterLink netlink.Link
	if link != nil {
		filterLink = link
	}
	return netlink.AddrList(filterLink, family)
}
// AddrAdd assigns addr to link and returns netlink's error, if any.
func (nm *NetlinkManager) AddrAdd(link *Link, addr *netlink.Addr) error {
	nm.mu.RLock()
	defer nm.mu.RUnlock()

	if err := netlink.AddrAdd(link, addr); err != nil {
		return err
	}
	return nil
}
// AddrDel removes addr from link and returns netlink's error, if any.
func (nm *NetlinkManager) AddrDel(link *Link, addr *netlink.Addr) error {
	nm.mu.RLock()
	defer nm.mu.RUnlock()

	if err := netlink.AddrDel(link, addr); err != nil {
		return err
	}
	return nil
}
// RemoveAllAddresses deletes every address of the given family from link,
// link-local ones included. Failures to delete individual addresses are logged
// and skipped; only a failure to list the addresses is returned.
func (nm *NetlinkManager) RemoveAllAddresses(link *Link, family int) error {
	addrs, err := nm.AddrList(link, family)
	if err != nil {
		return fmt.Errorf("failed to get addresses: %w", err)
	}

	for i := range addrs {
		addr := &addrs[i]
		if delErr := nm.AddrDel(link, addr); delErr != nil {
			nm.logger.Warn().Err(delErr).Str("address", addr.IP.String()).Msg("failed to remove address")
		}
	}
	return nil
}
// RemoveNonLinkLocalIPv6Addresses deletes every IPv6 address on link except
// link-local unicast ones. Failures to delete individual addresses are logged
// and skipped; only a failure to list the addresses is returned.
func (nm *NetlinkManager) RemoveNonLinkLocalIPv6Addresses(link *Link) error {
	addrs, err := nm.AddrList(link, AfInet6)
	if err != nil {
		return fmt.Errorf("failed to get IPv6 addresses: %w", err)
	}

	for i := range addrs {
		addr := &addrs[i]
		if addr.IP.IsLinkLocalUnicast() {
			continue // link-local addresses are kept
		}
		if delErr := nm.AddrDel(link, addr); delErr != nil {
			nm.logger.Warn().Err(delErr).Str("address", addr.IP.String()).Msg("failed to remove IPv6 address")
		}
	}
	return nil
}
// RouteList returns routes of the given family, limited to link's routes when
// link is non-nil. A nil link lists routes regardless of link.
//
// A nil *Link is converted to an untyped nil netlink.Link before the call.
// Passing it straight through would give netlink a non-nil interface wrapping
// a nil pointer, which netlink would treat as a real link filter and call
// Attrs() on.
func (nm *NetlinkManager) RouteList(link *Link, family int) ([]netlink.Route, error) {
	nm.mu.RLock()
	defer nm.mu.RUnlock()

	var filterLink netlink.Link
	if link != nil {
		filterLink = link
	}
	return netlink.RouteList(filterLink, family)
}
// RouteAdd installs route and returns netlink's error, if any.
func (nm *NetlinkManager) RouteAdd(route *netlink.Route) error {
	nm.mu.RLock()
	defer nm.mu.RUnlock()

	if err := netlink.RouteAdd(route); err != nil {
		return err
	}
	return nil
}
// RouteDel deletes route and returns netlink's error, if any.
func (nm *NetlinkManager) RouteDel(route *netlink.Route) error {
	nm.mu.RLock()
	defer nm.mu.RUnlock()

	if err := netlink.RouteDel(route); err != nil {
		return err
	}
	return nil
}
// RouteReplace installs route, replacing any existing route it matches, and
// returns netlink's error, if any.
func (nm *NetlinkManager) RouteReplace(route *netlink.Route) error {
	nm.mu.RLock()
	defer nm.mu.RUnlock()

	if err := netlink.RouteReplace(route); err != nil {
		return err
	}
	return nil
}
// ListDefaultRoutes lists the default routes of the given family in the main
// routing table (table 254), selected by filtering on an empty destination.
// A listing error is logged and returned.
func (nm *NetlinkManager) ListDefaultRoutes(family int) ([]netlink.Route, error) {
	const mainTable = 254 // RT_TABLE_MAIN

	filter := &netlink.Route{Dst: nil, Table: mainTable}
	mask := netlink.RT_FILTER_DST | netlink.RT_FILTER_TABLE

	routes, err := netlink.RouteListFiltered(family, filter, mask)
	if err == nil {
		return routes, nil
	}
	nm.logger.Error().Err(err).Int("family", family).Msg("failed to list default routes")
	return nil, err
}
// HasDefaultRoute reports whether ListDefaultRoutes finds at least one default
// route for family. A listing error is treated as "no default route".
func (nm *NetlinkManager) HasDefaultRoute(family int) bool {
	routes, err := nm.ListDefaultRoutes(family)
	return err == nil && len(routes) > 0
}
// AddDefaultRoute installs (via RouteReplace) a default route for family that
// goes through gateway on link. Only AfInet and AfInet6 are supported; any
// other family returns an error before link is inspected or any route is
// touched.
func (nm *NetlinkManager) AddDefaultRoute(link *Link, gateway net.IP, family int) error {
	dst, ok := map[int]*net.IPNet{
		AfInet:  &ipv4DefaultRoute,
		AfInet6: &ipv6DefaultRoute,
	}[family]
	if !ok {
		return fmt.Errorf("unsupported address family: %d", family)
	}

	return nm.RouteReplace(&netlink.Route{
		Dst:       dst,
		Gw:        gateway,
		LinkIndex: link.Attrs().Index,
	})
}
// RemoveDefaultRoute deletes the default routes of family (AfInet or AfInet6)
// found by ListDefaultRoutes. Failures to delete individual routes are logged
// and skipped; only a failure to list the routes is returned. Any other family
// is ignored and nil is returned.
//
// Default routes are selected through ListDefaultRoutes rather than by
// comparing net.IPMask.String() with "0.0.0.0/0" or "::/0". That method
// renders masks in hex (e.g. "00000000"), so such a comparison never matches.
func (nm *NetlinkManager) RemoveDefaultRoute(family int) error {
	var failMsg string
	switch family {
	case AfInet:
		failMsg = "failed to remove IPv4 default route"
	case AfInet6:
		failMsg = "failed to remove IPv6 default route"
	default:
		return nil
	}

	routes, err := nm.ListDefaultRoutes(family)
	if err != nil {
		return fmt.Errorf("failed to get routes: %w", err)
	}

	for i := range routes {
		route := &routes[i]
		if err := nm.RouteDel(route); err != nil {
			nm.logger.Warn().Err(err).Str("gateway", route.Gw.String()).Msg(failMsg)
		}
	}
	return nil
}
// reconcileDefaultRoute makes the default routes of family that go through
// link match the gateways in expected.
//
// Behavior:
//   - Existing default routes on link whose gateway is not expected are
//     deleted.
//   - Expected gateways without a route are added.
//   - Default routes on other links are left alone.
//
// Only the values of expected are used. They are re-keyed by gateway address
// here, so callers may key the map however they like and several entries may
// share one gateway. expected itself is not modified.
//
// Failures to delete or add individual routes are logged and skipped, and are
// not counted. Only a failure to list the current default routes is returned.
func (nm *NetlinkManager) reconcileDefaultRoute(link *Link, expected map[string]net.IP, family int) error {
	linkIndex := link.Attrs().Index
	l := nm.logger.With().
		Str("interface", link.Attrs().Name).
		Int("family", family).
		Logger()

	// Index wanted gateways by their textual form so they can be matched
	// against Route.Gw; this also collapses duplicate gateways.
	wanted := make(map[string]net.IP, len(expected))
	for _, gw := range expected {
		if gw != nil {
			wanted[gw.String()] = gw
		}
	}

	defaultRoutes, err := nm.ListDefaultRoutes(family)
	if err != nil {
		return fmt.Errorf("failed to get default routes: %w", err)
	}

	var toRemove []*netlink.Route
	for i := range defaultRoutes {
		// Point into the slice, not at the loop variable, which is shared
		// across iterations before Go 1.22.
		route := &defaultRoutes[i]
		// only check the default routes for the current link
		// TODO: we should also check others later
		if route.LinkIndex != linkIndex {
			continue
		}
		key := route.Gw.String()
		if _, ok := wanted[key]; !ok {
			toRemove = append(toRemove, route)
			continue
		}
		l.Debug().Str("gateway", key).Msg("keeping default route")
		delete(wanted, key)
	}

	removed := 0
	for _, route := range toRemove {
		l.Info().Str("gateway", route.Gw.String()).Msg("removing default route")
		if err := nm.RouteDel(route); err != nil {
			l.Warn().Err(err).Msg("failed to remove default route")
			continue
		}
		removed++
	}

	dst := &ipv4DefaultRoute
	if family == AfInet6 {
		dst = &ipv6DefaultRoute
	}
	added := 0
	for _, gateway := range wanted {
		l.Info().Str("gateway", gateway.String()).Msg("adding default route")
		route := &netlink.Route{
			Dst:       dst,
			Gw:        gateway,
			LinkIndex: linkIndex,
		}
		if err := nm.RouteAdd(route); err != nil {
			l.Warn().Err(err).Interface("route", route).Msg("failed to add default route")
			continue
		}
		added++
	}

	l.Info().
		Int("added", added).
		Int("removed", removed).
		Msg("default routes reconciled")
	return nil
}
// ReconcileLink makes link's MTU, its addresses of the given family, and its
// default routes match expected.
//
//   - MTU: the last non-zero MTU among the expected addresses is applied to
//     link when it differs from the current MTU.
//   - Addresses: non-link-local addresses that are not expected are removed;
//     expected addresses that differ from what is configured are deleted and
//     re-added; missing ones are added. Link-local addresses on the link are
//     never touched.
//   - Default routes: the gateways of the expected addresses are reconciled
//     via reconcileDefaultRoute.
//
// Per-item failures are logged and skipped. Only a failure to list the current
// addresses is returned.
func (nm *NetlinkManager) ReconcileLink(link *Link, expected []types.IPAddress, family int) error {
	l := nm.logger.With().Str("interface", link.Attrs().Name).Logger()

	expectedAddrs := make(map[string]*types.IPAddress, len(expected))
	expectedGateways := make(map[string]net.IP)

	currentMTU := link.Attrs().MTU
	expectedMTU := currentMTU

	for i := range expected {
		// Point into the slice rather than at the loop variable, which is
		// shared across iterations before Go 1.22.
		addr := &expected[i]
		expectedAddrs[addr.String()] = addr
		if addr.Gateway != nil {
			// Keyed by gateway so addresses sharing one yield a single route.
			expectedGateways[addr.Gateway.String()] = addr.Gateway
		}
		if addr.MTU != 0 {
			expectedMTU = addr.MTU
		}
	}

	if expectedMTU != currentMTU {
		if err := link.SetMTU(expectedMTU); err != nil {
			l.Warn().Err(err).Int("expected_mtu", expectedMTU).Int("mtu", currentMTU).Msg("failed to set MTU")
		}
	}

	addrs, err := nm.AddrList(link, family)
	if err != nil {
		return fmt.Errorf("failed to get addresses: %w", err)
	}

	var (
		toRemove []*netlink.Addr
		toUpdate []*types.IPAddress
	)
	for i := range addrs {
		current := &addrs[i]
		// skip the link-local address
		if current.IP.IsLinkLocalUnicast() {
			continue
		}
		key := current.IPNet.String()
		want, ok := expectedAddrs[key]
		if !ok {
			toRemove = append(toRemove, current)
			continue
		}
		// The address is present: either it is already correct or it is
		// handled as an update below. Either way it must not also be queued
		// as a plain add.
		delete(expectedAddrs, key)
		// if it's not fully equal, we need to update it
		if !want.Compare(*current) {
			toUpdate = append(toUpdate, want)
		}
	}

	// Whatever is left in expectedAddrs is not configured yet.
	toAdd := make([]*types.IPAddress, 0, len(expectedAddrs))
	for _, addr := range expectedAddrs {
		toAdd = append(toAdd, addr)
	}

	// Delete stale addresses before adding anything. An IPv4 address added
	// next to an existing primary becomes a secondary, and the kernel may
	// drop secondaries when that primary is later deleted.
	for _, netlinkAddr := range toRemove {
		if err := nm.AddrDel(link, netlinkAddr); err != nil {
			l.Warn().Err(err).Str("address", netlinkAddr.IP.String()).Msg("failed to remove address")
		}
	}

	// Updates are applied as delete + re-add.
	for _, addr := range toUpdate {
		netlinkAddr := addr.NetlinkAddr()
		if err := nm.AddrDel(link, &netlinkAddr); err != nil {
			l.Warn().Err(err).Str("address", addr.Address.String()).Msg("failed to update address")
		}
	}

	addAddr := func(addr *types.IPAddress) {
		netlinkAddr := addr.NetlinkAddr()
		if err := nm.AddrAdd(link, &netlinkAddr); err != nil {
			l.Warn().Err(err).Str("address", addr.Address.String()).Msg("failed to add address")
		}
	}
	for _, addr := range toAdd {
		addAddr(addr)
	}
	for _, addr := range toUpdate {
		addAddr(addr)
	}

	if len(toAdd) > 0 || len(toUpdate) > 0 || len(toRemove) > 0 {
		l.Info().
			Int("added", len(toAdd)).
			Int("updated", len(toUpdate)).
			Int("removed", len(toRemove)).
			Msg("addresses reconciled")
	}

	if err := nm.reconcileDefaultRoute(link, expectedGateways, family); err != nil {
		l.Warn().Err(err).Msg("failed to reconcile default route")
	}
	return nil
}