kvm/internal/usbgadget/changeset_symlink.go

137 lines
3.1 KiB
Go

package usbgadget
import (
"fmt"
"os"
"path"
"path/filepath"
"reflect"
"github.com/rs/zerolog"
)
// symlink describes one symbolic link: where the link lives and what it
// points at.
type symlink struct {
	// Path is the location of the link itself. recreateSymlinks joins a
	// relative Path onto the directory of the owning FileChange.
	Path string
	// Target is the destination of the link; recreateSymlinks passes it to
	// os.Symlink unchanged.
	Target string
}
// compareSymlinks reports whether expected and actual contain the same
// symlinks in the same order.
//
// Two empty lists are always equal, regardless of whether either one is nil.
func compareSymlinks(expected []symlink, actual []symlink) bool {
	if len(expected) != len(actual) {
		return false
	}
	// reflect.DeepEqual treats a nil slice and an empty non-nil slice as
	// different values; for this comparison both simply mean "no symlinks".
	if len(expected) == 0 {
		return true
	}
	return reflect.DeepEqual(expected, actual)
}
// checkIfSymlinksInOrder reports whether the symbolic links found directly
// inside the directory fc.Path match fc.ParamSymlinks.
//
// Results:
//   - FileStateAbsent: fc.Path does not exist.
//   - FileStateSymlinkInOrderConfigFS: the links on disk equal the expected list.
//   - FileStateSymlinkNotInOrderConfigFS: the links on disk differ from it.
//   - FileStateUnknown with a non-nil error: fc.ParamSymlinks is empty, fc.Path
//     is not a directory, or a filesystem call failed. Filesystem errors wrap
//     the underlying cause.
//
// Directory entries that are not symlinks are ignored. Relative link paths and
// relative targets, both expected and on disk, are resolved against fc.Path
// before comparing, mirroring how recreateSymlinks treats a relative Path.
// The comparison is delegated to compareSymlinks and is order-sensitive;
// os.ReadDir yields entries sorted by filename.
//
// A nil logger falls back to defaultLogger.
func checkIfSymlinksInOrder(fc *FileChange, logger *zerolog.Logger) (FileState, error) {
	if logger == nil {
		logger = defaultLogger
	}
	l := logger.With().Str("path", fc.Path).Logger()

	if len(fc.ParamSymlinks) == 0 {
		return FileStateUnknown, fmt.Errorf("no symlinks to check")
	}

	fi, err := os.Lstat(fc.Path)
	if err != nil {
		if os.IsNotExist(err) {
			return FileStateAbsent, nil
		}
		l.Warn().Err(err).Msg("failed to stat file")
		return FileStateUnknown, fmt.Errorf("failed to stat file: %w", err)
	}
	if !fi.IsDir() {
		return FileStateUnknown, fmt.Errorf("file is not a directory")
	}

	files, err := os.ReadDir(fc.Path)
	if err != nil {
		return FileStateUnknown, fmt.Errorf("failed to read directory: %w", err)
	}

	// absTarget turns a relative link target into an absolute path rooted at
	// fc.Path; absolute targets are returned untouched.
	absTarget := func(target string) (string, error) {
		if filepath.IsAbs(target) {
			return target, nil
		}
		return filepath.Abs(filepath.Join(fc.Path, target))
	}

	// Normalise the expected entries the same way the on-disk entries are
	// normalised below; otherwise a relative expected Path could never match.
	expected := make([]symlink, 0, len(fc.ParamSymlinks))
	for _, want := range fc.ParamSymlinks {
		linkPath := want.Path
		if !filepath.IsAbs(linkPath) {
			linkPath = filepath.Join(fc.Path, linkPath)
		}
		target, err := absTarget(want.Target)
		if err != nil {
			return FileStateUnknown, fmt.Errorf("failed to get absolute path: %w", err)
		}
		expected = append(expected, symlink{Path: linkPath, Target: target})
	}

	actual := make([]symlink, 0, len(files))
	for _, file := range files {
		if file.Type()&os.ModeSymlink == 0 {
			continue
		}
		linkPath := filepath.Join(fc.Path, file.Name())
		target, err := os.Readlink(linkPath)
		if err != nil {
			return FileStateUnknown, fmt.Errorf("failed to read symlink: %w", err)
		}
		target, err = absTarget(target)
		if err != nil {
			return FileStateUnknown, fmt.Errorf("failed to get absolute path: %w", err)
		}
		actual = append(actual, symlink{Path: linkPath, Target: target})
	}

	// compare the symlinks with the expected symlinks
	if compareSymlinks(expected, actual) {
		return FileStateSymlinkInOrderConfigFS, nil
	}

	l.Trace().Interface("expected", expected).Interface("actual", actual).Msg("symlinks are not in order")
	return FileStateSymlinkNotInOrderConfigFS, nil
}
// recreateSymlinks removes every symbolic link found directly inside the
// directory fc.Path and then creates the links listed in fc.ParamSymlinks,
// in list order.
//
// A relative symlink Path is joined onto fc.Path; the Target is passed to
// os.Symlink unchanged. Entries in fc.Path that are not symlinks are left
// alone. The function stops at the first failure and returns an error that
// wraps the underlying cause, so links removed or created before the failure
// stay as they are.
//
// A nil logger falls back to defaultLogger.
func recreateSymlinks(fc *FileChange, logger *zerolog.Logger) error {
	if logger == nil {
		logger = defaultLogger
	}
	l := logger.With().Str("path", fc.Path).Logger()

	files, err := os.ReadDir(fc.Path)
	if err != nil {
		return fmt.Errorf("failed to read directory: %w", err)
	}

	// remove all symlinks
	l.Info().Msg("recreate symlinks")
	for _, file := range files {
		if file.Type()&os.ModeSymlink == 0 {
			continue
		}
		l.Info().Str("name", file.Name()).Msg("remove symlink")
		if err := os.Remove(path.Join(fc.Path, file.Name())); err != nil {
			return fmt.Errorf("failed to remove symlink: %w", err)
		}
	}

	// create the symlinks
	l.Info().Interface("param-symlinks", fc.ParamSymlinks).Msg("create symlinks")
	for _, link := range fc.ParamSymlinks {
		l.Info().Str("name", link.Path).Str("target", link.Target).Msg("create symlink")

		linkPath := link.Path
		if !filepath.IsAbs(linkPath) {
			linkPath = filepath.Join(fc.Path, linkPath)
		}

		if err := os.Symlink(link.Target, linkPath); err != nil {
			l.Warn().Err(err).Msg("failed to create symlink")
			return fmt.Errorf("failed to create symlink: %w", err)
		}
	}

	return nil
}