diff --git a/dev_deploy.sh b/dev_deploy.sh index 466acc7..267e2c7 100755 --- a/dev_deploy.sh +++ b/dev_deploy.sh @@ -98,9 +98,6 @@ if [ "$SKIP_UI_BUILD" = false ]; then make frontend fi -msg_info "▶ Building go binary" -make build_dev - if [ "$RUN_GO_TESTS" = true ]; then msg_info "▶ Building go tests" make build_dev_test @@ -121,6 +118,9 @@ tar zxvf device-tests.tar.gz EOF fi +msg_info "▶ Building go binary" +make build_dev + # Kill any existing instances of the application ssh "${REMOTE_USER}@${REMOTE_HOST}" "killall jetkvm_app_debug || true" @@ -128,6 +128,8 @@ ssh "${REMOTE_USER}@${REMOTE_HOST}" "killall jetkvm_app_debug || true" ssh "${REMOTE_USER}@${REMOTE_HOST}" "cat > ${REMOTE_PATH}/jetkvm_app_debug" < bin/jetkvm_app if [ "$RESET_USB_HID_DEVICE" = true ]; then + msg_info "▶ Resetting USB HID device" + msg_warn "The option has been deprecated and will be removed in a future version, as JetKVM will now reset USB gadget configuration when needed" # Remove the old USB gadget configuration ssh "${REMOTE_USER}@${REMOTE_HOST}" "rm -rf /sys/kernel/config/usb_gadget/jetkvm/configs/c.1/hid.usb*" ssh "${REMOTE_USER}@${REMOTE_HOST}" "ls /sys/class/udc > /sys/kernel/config/usb_gadget/jetkvm/UDC" diff --git a/go.mod b/go.mod index 6784a59..0e288db 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/jetkvm/kvm -go 1.23.0 +go 1.23.4 + +toolchain go1.24.3 require ( github.com/Masterminds/semver/v3 v3.3.0 @@ -12,21 +14,25 @@ require ( github.com/gin-contrib/logger v1.2.5 github.com/gin-gonic/gin v1.10.0 github.com/google/uuid v1.6.0 + github.com/guregu/null/v6 v6.0.0 github.com/gwatts/rootcerts v0.0.0-20240401182218-3ab9db955caf github.com/hanwen/go-fuse/v2 v2.5.1 - github.com/hashicorp/go-envparse v0.1.0 github.com/pion/logging v0.2.2 github.com/pion/mdns/v2 v2.0.7 github.com/pion/webrtc/v4 v4.0.0 github.com/pojntfx/go-nbd v0.3.2 github.com/prometheus/client_golang v1.21.0 github.com/prometheus/common v0.62.0 + github.com/prometheus/procfs v0.15.1 github.com/psanford/httpreadat v0.1.0 github.com/rs/zerolog v1.34.0 + github.com/sourcegraph/tf-dag v0.2.2-0.20250131204052-3e8ff1477b4f + github.com/stretchr/testify v1.10.0 github.com/vishvananda/netlink v1.3.0 go.bug.st/serial v1.6.2 golang.org/x/crypto v0.36.0 golang.org/x/net v0.38.0 + golang.org/x/sys v0.32.0 ) replace github.com/pojntfx/go-nbd v0.3.2 => github.com/chemhack/go-nbd v0.0.0-20241006125820-59e45f5b1e7b @@ -38,6 +44,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.5 // indirect github.com/creack/goselect v0.1.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/gin-contrib/sse v1.0.0 // indirect github.com/go-jose/go-jose/v4 v4.0.2 // indirect @@ -45,7 +52,6 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.26.0 // indirect github.com/goccy/go-json v0.10.5 // indirect - github.com/guregu/null/v6 v6.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect @@ -70,15 +76,14 @@ require ( github.com/pion/stun/v3 v3.0.0 // indirect github.com/pion/transport/v3 v3.0.7 // indirect github.com/pion/turn/v4 v4.0.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect - github.com/prometheus/procfs v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect github.com/vishvananda/netns v0.0.4 // indirect github.com/wlynxg/anet v0.0.5 // indirect golang.org/x/arch v0.15.0 // indirect golang.org/x/oauth2 v0.24.0 // indirect - golang.org/x/sys v0.32.0 // indirect golang.org/x/text v0.23.0 // indirect google.golang.org/protobuf v1.36.6 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 3ad832a..14aadb1 100644 --- a/go.sum +++ b/go.sum @@ -62,8 +62,6 @@ github.com/gwatts/rootcerts v0.0.0-20240401182218-3ab9db955caf h1:JO6ISZIvEUitto github.com/gwatts/rootcerts v0.0.0-20240401182218-3ab9db955caf/go.mod h1:5Kt9XkWvkGi2OHOq0QsGxebHmhCcqJ8KCbNg/a6+n+g= github.com/hanwen/go-fuse/v2 v2.5.1 h1:OQBE8zVemSocRxA4OaFJbjJ5hlpCmIWbGr7r0M4uoQQ= github.com/hanwen/go-fuse/v2 v2.5.1/go.mod h1:xKwi1cF7nXAOBCXujD5ie0ZKsxc8GGSA1rlMJc+8IJs= -github.com/hashicorp/go-envparse v0.1.0 h1:bE++6bhIsNCPLvgDZkYqo3nA+/PFI51pkrHdmPSDFPY= -github.com/hashicorp/go-envparse v0.1.0/go.mod h1:OHheN1GoygLlAkTlXLXvAdnXdZxy8JUweQ1rAXx1xnc= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= @@ -154,6 +152,8 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/sourcegraph/tf-dag v0.2.2-0.20250131204052-3e8ff1477b4f h1:VgoRCP1efSCEZIcF2THLQ46+pIBzzgNiaUBe9wEDwYU= +github.com/sourcegraph/tf-dag v0.2.2-0.20250131204052-3e8ff1477b4f/go.mod h1:pzro7BGorij2WgrjEammtrkbo3+xldxo+KaGLGUiD+Q= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/internal/usbgadget/changeset.go b/internal/usbgadget/changeset.go new file mode 100644 index 0000000..4465160 --- /dev/null +++ b/internal/usbgadget/changeset.go @@ -0,0 +1,403 @@ +package usbgadget + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "reflect" + "time" + + "github.com/prometheus/procfs" + "github.com/sourcegraph/tf-dag/dag" +) + +// it's a minimalistic implementation of ansible's file module with some modifications +// to make it more suitable for our use case +// https://docs.ansible.com/ansible/latest/modules/file_module.html + +// we use this to check if the files in the gadget config are in the expected state +// and to update them if they are not in the expected state + +type FileState uint8 +type ChangeState uint8 +type FileChangeResolvedAction uint8 + +type ApplyFunc func(c *ChangeSet, changes []*FileChange) error + +const ( + FileStateUnknown FileState = iota + FileStateAbsent + FileStateDirectory + FileStateFile + FileStateFileContentMatch + FileStateFileWrite // update file content without checking + FileStateMounted + FileStateMountedConfigFS + FileStateSymlink + FileStateTouch +) + +var FileStateString = map[FileState]string{ + FileStateUnknown: "UNKNOWN", + FileStateAbsent: "ABSENT", + FileStateDirectory: "DIRECTORY", + FileStateFile: "FILE", + FileStateFileContentMatch: "FILE_CONTENT_MATCH", + FileStateFileWrite: "FILE_WRITE", + FileStateMounted: "MOUNTED", + FileStateMountedConfigFS: "CONFIGFS_MOUNT", + FileStateSymlink: "SYMLINK", + FileStateTouch: "TOUCH", +} + +const ( + ChangeStateUnknown ChangeState = iota + ChangeStateRequired + ChangeStateNotChanged + ChangeStateChanged + ChangeStateError +) + +const ( + FileChangeResolvedActionUnknown FileChangeResolvedAction = iota + FileChangeResolvedActionDoNothing + FileChangeResolvedActionRemove + FileChangeResolvedActionCreateFile + FileChangeResolvedActionWriteFile + FileChangeResolvedActionUpdateFile + FileChangeResolvedActionAppendFile + FileChangeResolvedActionCreateSymlink + FileChangeResolvedActionRecreateSymlink + FileChangeResolvedActionCreateDirectory + FileChangeResolvedActionRemoveDirectory + FileChangeResolvedActionTouch + FileChangeResolvedActionMountConfigFS +) + +var FileChangeResolvedActionString = map[FileChangeResolvedAction]string{ + FileChangeResolvedActionUnknown: "UNKNOWN", + FileChangeResolvedActionDoNothing: "DO_NOTHING", + FileChangeResolvedActionRemove: "REMOVE", + FileChangeResolvedActionCreateFile: "FILE_CREATE", + FileChangeResolvedActionWriteFile: "FILE_WRITE", + FileChangeResolvedActionUpdateFile: "FILE_UPDATE", + FileChangeResolvedActionAppendFile: "FILE_APPEND", + FileChangeResolvedActionCreateSymlink: "SYMLINK_CREATE", + FileChangeResolvedActionRecreateSymlink: "SYMLINK_RECREATE", + FileChangeResolvedActionCreateDirectory: "DIR_CREATE", + FileChangeResolvedActionRemoveDirectory: "DIR_REMOVE", + FileChangeResolvedActionTouch: "TOUCH", + FileChangeResolvedActionMountConfigFS: "CONFIGFS_MOUNT", +} + +type ChangeSet struct { + Changes []FileChange +} + +type RequestedFileChange struct { + Component string + Key string + Path string // will be used as Key if Key is empty + ExpectedState FileState + ExpectedContent []byte + DependsOn []string + BeforeChange []string // if the file is going to be changed, apply the change first + Description string + IgnoreErrors bool + When string // only apply the change if when meets the condition +} + +type FileChange struct { + RequestedFileChange + ActualState FileState + ActualContent []byte + resolvedDeps []string + checked bool + changed ChangeState + action FileChangeResolvedAction +} + +func (f *RequestedFileChange) String() string { + var s string + switch f.ExpectedState { + case FileStateDirectory: + s = fmt.Sprintf("dir: %s", f.Path) + case FileStateFile: + s = fmt.Sprintf("file: %s", f.Path) + case FileStateSymlink: + s = fmt.Sprintf("symlink: %s -> %s", f.Path, f.ExpectedContent) + case FileStateAbsent: + s = fmt.Sprintf("absent: %s", f.Path) + case FileStateFileContentMatch: + s = fmt.Sprintf("file: %s with content [%s]", f.Path, f.ExpectedContent) + case FileStateFileWrite: + s = fmt.Sprintf("write: %s with content [%s]", f.Path, f.ExpectedContent) + case FileStateMountedConfigFS: + s = fmt.Sprintf("configfs: %s", f.Path) + case FileStateTouch: + s = fmt.Sprintf("touch: %s", f.Path) + case FileStateUnknown: + s = fmt.Sprintf("unknown change for %s", f.Path) + default: + s = fmt.Sprintf("unknown expected state %d for %s", f.ExpectedState, f.Path) + } + + return s +} + +func (f *RequestedFileChange) IsSame(other *RequestedFileChange) bool { + return f.Path == other.Path && + f.ExpectedState == other.ExpectedState && + reflect.DeepEqual(f.ExpectedContent, other.ExpectedContent) && + reflect.DeepEqual(f.DependsOn, other.DependsOn) && + f.IgnoreErrors == other.IgnoreErrors +} + +func (fc *FileChange) checkIfDirIsMountPoint() error { + // check if the file is a mount point + mounts, err := procfs.GetMounts() + if err != nil { + return fmt.Errorf("failed to get mounts") + } + + for _, mount := range mounts { + if mount.MountPoint == fc.Path { + fc.ActualState = FileStateMounted + fc.ActualContent = []byte(mount.Source) + + if mount.FSType == "configfs" { + fc.ActualState = FileStateMountedConfigFS + } + + return nil + } + } + + return nil +} + +// GetActualState returns the actual state of the file at the given path. +func (fc *FileChange) getActualState() error { + l := defaultLogger.With().Str("path", fc.Path).Logger() + + fi, err := os.Lstat(fc.Path) + if err != nil { + if os.IsNotExist(err) { + fc.ActualState = FileStateAbsent + } else { + l.Warn().Err(err).Msg("failed to stat file") + fc.ActualState = FileStateUnknown + } + return nil + } + + // check if the file is a symlink + if fi.Mode()&os.ModeSymlink == os.ModeSymlink { + fc.ActualState = FileStateSymlink + // get the target of the symlink + target, err := os.Readlink(fc.Path) + if err != nil { + l.Warn().Err(err).Msg("failed to read symlink") + return fmt.Errorf("failed to read symlink") + } + // check if the target is a relative path + if !filepath.IsAbs(target) { + // make it absolute + target, err = filepath.Abs(filepath.Join(filepath.Dir(fc.Path), target)) + if err != nil { + l.Warn().Err(err).Msg("failed to make symlink target absolute") + return fmt.Errorf("failed to make symlink target absolute") + } + } + fc.ActualContent = []byte(target) + return nil + } + + if fi.IsDir() { + fc.ActualState = FileStateDirectory + + if fc.ExpectedState == FileStateMountedConfigFS { + err := fc.checkIfDirIsMountPoint() + if err != nil { + l.Warn().Err(err).Msg("failed to check if dir is mount point") + return err + } + } + return nil + } + + if fi.Mode()&os.ModeDevice == os.ModeDevice { + l.Info().Msg("file is a device") + return nil + } + + // check if the file is a regular file + if fi.Mode().IsRegular() { + fc.ActualState = FileStateFile + // get the content of the file + content, err := os.ReadFile(fc.Path) + if err != nil { + l.Warn().Err(err).Msg("failed to read file") + return fmt.Errorf("failed to read file") + } + fc.ActualContent = content + return nil + } + + l.Warn().Interface("file_info", fi.Mode()).Bool("is_dir", fi.IsDir()).Msg("unknown file type") + + return fmt.Errorf("unknown file type") +} + +func (fc *FileChange) ResetActionResolution() { + fc.checked = false + fc.action = FileChangeResolvedActionUnknown + fc.changed = ChangeStateUnknown +} + +func (fc *FileChange) Action() FileChangeResolvedAction { + if !fc.checked { + fc.action = fc.getFileChangeResolvedAction() + fc.checked = true + } + + return fc.action +} + +func (fc *FileChange) getFileChangeResolvedAction() FileChangeResolvedAction { + l := defaultLogger.With().Str("path", fc.Path).Logger() + + // some actions are not needed to be checked + switch fc.ExpectedState { + case FileStateFileWrite: + return FileChangeResolvedActionWriteFile + case FileStateTouch: + return FileChangeResolvedActionTouch + } + + // get the actual state of the file + err := fc.getActualState() + if err != nil { + return FileChangeResolvedActionDoNothing + } + + baseName := filepath.Base(fc.Path) + + switch fc.ExpectedState { + case FileStateDirectory: + // if the file is already a directory, do nothing + if fc.ActualState == FileStateDirectory { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionCreateDirectory + case FileStateFile: + // if the file is already a file, do nothing + if fc.ActualState == FileStateFile { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionCreateFile + case FileStateFileContentMatch: + // if the file is already a file with the expected content, do nothing + if fc.ActualState == FileStateFile { + looserMatch := baseName == "inquiry_string" + if compareFileContent(fc.ActualContent, fc.ExpectedContent, looserMatch) { + return FileChangeResolvedActionDoNothing + } + // TODO: move this to somewhere else + // this is a workaround for the fact that the file is not updated if it has no content + if baseName == "file" && + bytes.Equal(fc.ActualContent, []byte{}) && + bytes.Equal(fc.ExpectedContent, []byte{0x0a}) { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionUpdateFile + } + return FileChangeResolvedActionCreateFile + case FileStateSymlink: + // if the file is already a symlink, check if the target is the same + if fc.ActualState == FileStateSymlink { + if reflect.DeepEqual(fc.ActualContent, fc.ExpectedContent) { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionRecreateSymlink + } + return FileChangeResolvedActionCreateSymlink + case FileStateAbsent: + if fc.ActualState == FileStateAbsent { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionRemove + case FileStateMountedConfigFS: + if fc.ActualState == FileStateMountedConfigFS { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionMountConfigFS + default: + l.Warn().Interface("file_change", FileStateString[fc.ExpectedState]).Msg("unknown expected state") + return FileChangeResolvedActionDoNothing + } +} + +func (c *ChangeSet) AddFileChangeStruct(r RequestedFileChange) { + fc := FileChange{ + RequestedFileChange: r, + } + c.Changes = append(c.Changes, fc) +} + +func (c *ChangeSet) AddFileChange(component string, path string, expectedState FileState, expectedContent []byte, dependsOn []string, description string) { + c.AddFileChangeStruct(RequestedFileChange{ + Component: component, + Path: path, + ExpectedState: expectedState, + ExpectedContent: expectedContent, + DependsOn: dependsOn, + Description: description, + }) +} + +func (c *ChangeSet) ApplyChanges() error { + r := ChangeSetResolver{ + changeset: c, + g: &dag.AcyclicGraph{}, + } + + return r.Apply() +} + +func (c *ChangeSet) applyChange(change *FileChange) error { + switch change.Action() { + case FileChangeResolvedActionWriteFile: + return os.WriteFile(change.Path, change.ExpectedContent, 0644) + case FileChangeResolvedActionUpdateFile: + return os.WriteFile(change.Path, change.ExpectedContent, 0644) + case FileChangeResolvedActionCreateFile: + return os.WriteFile(change.Path, change.ExpectedContent, 0644) + case FileChangeResolvedActionCreateSymlink: + return os.Symlink(string(change.ExpectedContent), change.Path) + case FileChangeResolvedActionRecreateSymlink: + if err := os.Remove(change.Path); err != nil { + return fmt.Errorf("failed to remove symlink: %w", err) + } + return os.Symlink(string(change.ExpectedContent), change.Path) + case FileChangeResolvedActionCreateDirectory: + return os.MkdirAll(change.Path, 0755) + case FileChangeResolvedActionRemove: + return os.Remove(change.Path) + case FileChangeResolvedActionRemoveDirectory: + return os.RemoveAll(change.Path) + case FileChangeResolvedActionTouch: + return os.Chtimes(change.Path, time.Now(), time.Now()) + case FileChangeResolvedActionMountConfigFS: + return mountConfigFS(change.Path) + case FileChangeResolvedActionDoNothing: + return nil + default: + return fmt.Errorf("unknown action: %d", change.Action()) + } +} + +func (c *ChangeSet) Apply() error { + return c.ApplyChanges() +} diff --git a/internal/usbgadget/changeset_arm_test.go b/internal/usbgadget/changeset_arm_test.go new file mode 100644 index 0000000..c71c9f6 --- /dev/null +++ b/internal/usbgadget/changeset_arm_test.go @@ -0,0 +1,41 @@ +//go:build arm && linux + +package usbgadget + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +var ( + usbConfig = &Config{ + VendorId: "0x1d6b", //The Linux Foundation + ProductId: "0x0104", //Multifunction Composite Gadget + SerialNumber: "", + Manufacturer: "JetKVM", + Product: "USB Emulation Device", + strictMode: true, + } + usbDevices = &Devices{ + AbsoluteMouse: true, + RelativeMouse: true, + Keyboard: true, + MassStorage: true, + } + usbGadgetName = "jetkvm" + usbGadget *UsbGadget +) + +func TestUsbGadgetInit(t *testing.T) { + assert := assert.New(t) + usbGadget = NewUsbGadget(usbGadgetName, usbDevices, usbConfig, nil) + + assert.NotNil(usbGadget) +} + +func TestUsbGadgetStrictModeInitFail(t *testing.T) { + usbConfig.strictMode = true + u := NewUsbGadget("test", usbDevices, usbConfig, nil) + assert.Nil(t, u, "should be nil") +} diff --git a/internal/usbgadget/changeset_resolver.go b/internal/usbgadget/changeset_resolver.go new file mode 100644 index 0000000..a4bc546 --- /dev/null +++ b/internal/usbgadget/changeset_resolver.go @@ -0,0 +1,177 @@ +package usbgadget + +import ( + "fmt" + + "github.com/sourcegraph/tf-dag/dag" +) + +type ChangeSetResolver struct { + changeset *ChangeSet + + g *dag.AcyclicGraph + + changesMap map[string]*FileChange + conditionalChangesMap map[string]*FileChange + + orderedChanges []dag.Vertex + resolvedChanges []*FileChange + additionalResolveRequired bool +} + +func (c *ChangeSetResolver) toOrderedChanges() error { + for key, change := range c.changesMap { + v := c.g.Add(key) + + for _, dependsOn := range change.DependsOn { + c.g.Connect(dag.BasicEdge(dependsOn, v)) + } + for _, dependsOn := range change.resolvedDeps { + c.g.Connect(dag.BasicEdge(dependsOn, v)) + } + } + + cycles := c.g.Cycles() + if len(cycles) > 0 { + return fmt.Errorf("cycles detected: %v", cycles) + } + + orderedChanges := c.g.TopologicalOrder() + c.orderedChanges = orderedChanges + return nil +} + +func (c *ChangeSetResolver) doResolveChanges(initial bool) error { + resolvedChanges := make([]*FileChange, 0) + + for _, key := range c.orderedChanges { + change := c.changesMap[key.(string)] + if !initial { + change.ResetActionResolution() + } + + resolvedAction := change.Action() + + resolvedChanges = append(resolvedChanges, change) + // no need to check the triggers if there's no change + if resolvedAction == FileChangeResolvedActionDoNothing { + continue + } + + if !initial { + continue + } + + if change.BeforeChange != nil { + change.resolvedDeps = append(change.resolvedDeps, change.BeforeChange...) + c.additionalResolveRequired = true + + // add the dependencies to the changes map + for _, dep := range change.BeforeChange { + depChange, ok := c.conditionalChangesMap[dep] + if !ok { + return fmt.Errorf("dependency %s not found", dep) + } + + c.changesMap[dep] = depChange + } + } + } + + c.resolvedChanges = resolvedChanges + return nil +} + +func (c *ChangeSetResolver) resolveChanges(initial bool) error { + // get the ordered changes + err := c.toOrderedChanges() + if err != nil { + return err + } + + // resolve the changes + err = c.doResolveChanges(initial) + if err != nil { + return err + } + + if !c.additionalResolveRequired || !initial { + return nil + } + + return c.resolveChanges(false) +} + +func (c *ChangeSetResolver) applyChanges() error { + for _, change := range c.resolvedChanges { + change.ResetActionResolution() + action := change.Action() + actionStr := FileChangeResolvedActionString[action] + + l := defaultLogger.Info() + if action == FileChangeResolvedActionDoNothing { + l = defaultLogger.Trace() + } + + l.Str("action", actionStr).Str("change", change.String()).Msg("applying change") + + err := c.changeset.applyChange(change) + if err != nil { + return err + } + } + + return nil +} + +func (c *ChangeSetResolver) GetChanges() ([]*FileChange, error) { + localChanges := c.changeset.Changes + changesMap := make(map[string]*FileChange) + conditionalChangesMap := make(map[string]*FileChange) + + // build the map of the changes + for _, change := range localChanges { + key := change.Key + if key == "" { + key = change.Path + } + + // remove it from the map first + if change.When != "" { + conditionalChangesMap[key] = &change + continue + } + + if _, ok := changesMap[key]; ok { + if changesMap[key].IsSame(&change.RequestedFileChange) { + continue + } + return nil, fmt.Errorf( + "duplicate change: %s, current: %s, requested: %s", + key, + changesMap[key].String(), + change.String(), + ) + } + + changesMap[key] = &change + } + + c.changesMap = changesMap + c.conditionalChangesMap = conditionalChangesMap + + err := c.resolveChanges(true) + if err != nil { + return nil, err + } + + return c.resolvedChanges, nil +} + +func (c *ChangeSetResolver) Apply() error { + if _, err := c.GetChanges(); err != nil { + return err + } + + return c.applyChanges() +} diff --git a/internal/usbgadget/config.go b/internal/usbgadget/config.go index 5c287da..f4c9ce4 100644 --- a/internal/usbgadget/config.go +++ b/internal/usbgadget/config.go @@ -4,9 +4,6 @@ import ( "fmt" "os" "os/exec" - "path" - "path/filepath" - "sort" ) type gadgetConfigItem struct { @@ -160,15 +157,15 @@ func (u *UsbGadget) OverrideGadgetConfig(itemKey string, itemAttr string, value return nil, true } -func mountConfigFS() error { - _, err := os.Stat(gadgetPath) +func mountConfigFS(path string) error { + _, err := os.Stat(path) // TODO: check if it's mounted properly if err == nil { return nil } if os.IsNotExist(err) { - err = exec.Command("mount", "-t", "configfs", "none", configFSPath).Run() + err = exec.Command("mount", "-t", "configfs", "none", path).Run() if err != nil { return fmt.Errorf("failed to mount configfs: %w", err) } @@ -186,26 +183,19 @@ func (u *UsbGadget) Init() error { udcs := getUdcs() if len(udcs) < 1 { - u.log.Error().Msg("no udc found, skipping USB stack init") - return nil + return u.logWarn("no udc found, skipping USB stack init", nil) } u.udc = udcs[0] - _, err := os.Stat(u.kvmGadgetPath) - if err == nil { - u.log.Info().Msg("usb gadget already exists") - } - if err := mountConfigFS(); err != nil { - u.log.Error().Err(err).Msg("failed to mount configfs, usb stack might not function properly") - } - - if err := os.MkdirAll(u.configC1Path, 0755); err != nil { - u.log.Error().Err(err).Msg("failed to create config path") - } - - if err := u.writeGadgetConfig(); err != nil { - u.log.Error().Err(err).Msg("failed to start gadget") + err := u.WithTransaction(func() error { + u.tx.MountConfigFS() + u.tx.CreateConfigPath() + u.tx.WriteGadgetConfig() + return nil + }) + if err != nil { + return u.logError("unable to initialize USB stack", err) } return nil @@ -217,143 +207,13 @@ func (u *UsbGadget) UpdateGadgetConfig() error { u.loadGadgetConfig() - if err := u.writeGadgetConfig(); err != nil { - u.log.Error().Err(err).Msg("failed to update gadget") - } - - return nil -} - -func (u *UsbGadget) getOrderedConfigItems() orderedGadgetConfigItems { - items := make([]gadgetConfigItemWithKey, 0) - for key, item := range u.configMap { - items = append(items, gadgetConfigItemWithKey{key, item}) - } - - sort.Slice(items, func(i, j int) bool { - return items[i].item.order < items[j].item.order + err := u.WithTransaction(func() error { + u.tx.WriteGadgetConfig() + return nil }) - - return items -} - -func (u *UsbGadget) writeGadgetConfig() error { - // create kvm gadget path - err := os.MkdirAll(u.kvmGadgetPath, 0755) if err != nil { - return err - } - - u.log.Trace().Msg("writing gadget config") - for _, val := range u.getOrderedConfigItems() { - key := val.key - item := val.item - - // check if the item is enabled in the config - if !u.isGadgetConfigItemEnabled(key) { - u.log.Trace().Str("key", key).Msg("disabling gadget config") - err = u.disableGadgetItemConfig(item) - if err != nil { - return err - } - continue - } - u.log.Trace().Str("key", key).Msg("writing gadget config") - err = u.writeGadgetItemConfig(item) - if err != nil { - return err - } - } - - if err = u.writeUDC(); err != nil { - u.log.Error().Err(err).Msg("failed to write UDC") - return err - } - - if err = u.rebindUsb(true); err != nil { - u.log.Info().Err(err).Msg("failed to rebind usb") + return u.logError("unable to update gadget config", err) } return nil } - -func (u *UsbGadget) disableGadgetItemConfig(item gadgetConfigItem) error { - // remove symlink if exists - if item.configPath == nil { - return nil - } - - configPath := joinPath(u.configC1Path, item.configPath) - - if _, err := os.Lstat(configPath); os.IsNotExist(err) { - u.log.Trace().Str("path", configPath).Msg("symlink does not exist") - return nil - } - - if err := os.Remove(configPath); err != nil { - return fmt.Errorf("failed to remove symlink %s: %w", item.configPath, err) - } - - return nil -} - -func (u *UsbGadget) writeGadgetItemConfig(item gadgetConfigItem) error { - // create directory for the item - gadgetItemPath := joinPath(u.kvmGadgetPath, item.path) - err := os.MkdirAll(gadgetItemPath, 0755) - if err != nil { - return fmt.Errorf("failed to create path %s: %w", gadgetItemPath, err) - } - - if len(item.attrs) > 0 { - // write attributes for the item - err = u.writeGadgetAttrs(gadgetItemPath, item.attrs) - if err != nil { - return fmt.Errorf("failed to write attributes for %s: %w", gadgetItemPath, err) - } - } - - // write report descriptor if available - if item.reportDesc != nil { - err = u.writeIfDifferent(path.Join(gadgetItemPath, "report_desc"), item.reportDesc, 0644) - if err != nil { - return err - } - } - - // create config directory if configAttrs are set - if len(item.configAttrs) > 0 { - configItemPath := joinPath(u.configC1Path, item.configPath) - err = os.MkdirAll(configItemPath, 0755) - if err != nil { - return fmt.Errorf("failed to create path %s: %w", configItemPath, err) - } - - err = u.writeGadgetAttrs(configItemPath, item.configAttrs) - if err != nil { - return fmt.Errorf("failed to write config attributes for %s: %w", configItemPath, err) - } - } - - // create symlink if configPath is set - if item.configPath != nil && item.configAttrs == nil { - configPath := joinPath(u.configC1Path, item.configPath) - u.log.Trace().Str("source", configPath).Str("target", gadgetItemPath).Msg("creating symlink") - if err := ensureSymlink(configPath, gadgetItemPath); err != nil { - return err - } - } - - return nil -} - -func (u *UsbGadget) writeGadgetAttrs(basePath string, attrs gadgetAttributes) error { - for key, val := range attrs { - filePath := filepath.Join(basePath, key) - err := u.writeIfDifferent(filePath, []byte(val), 0644) - if err != nil { - return fmt.Errorf("failed to write to %s: %w", filePath, err) - } - } - return nil -} diff --git a/internal/usbgadget/config_tx.go b/internal/usbgadget/config_tx.go new file mode 100644 index 0000000..b4f1be0 --- /dev/null +++ b/internal/usbgadget/config_tx.go @@ -0,0 +1,294 @@ +package usbgadget + +import ( + "fmt" + "path" + "path/filepath" + "sort" + + "github.com/rs/zerolog" +) + +// no os package should occur in this file + +type UsbGadgetTransaction struct { + c *ChangeSet + + // below are the fields that are needed to be set by the caller + log *zerolog.Logger + udc string + dwc3Path string + kvmGadgetPath string + configC1Path string + orderedConfigItems orderedGadgetConfigItems + isGadgetConfigItemEnabled func(key string) bool +} + +func (u *UsbGadget) newUsbGadgetTransaction(lock bool) error { + if lock { + u.txLock.Lock() + defer u.txLock.Unlock() + } + + if u.tx != nil { + return fmt.Errorf("transaction already exists") + } + + tx := &UsbGadgetTransaction{ + c: &ChangeSet{}, + log: u.log, + udc: u.udc, + dwc3Path: dwc3Path, + kvmGadgetPath: u.kvmGadgetPath, + configC1Path: u.configC1Path, + orderedConfigItems: u.getOrderedConfigItems(), + isGadgetConfigItemEnabled: u.isGadgetConfigItemEnabled, + } + u.tx = tx + + return nil +} + +func (u *UsbGadget) WithTransaction(fn func() error) error { + u.txLock.Lock() + defer u.txLock.Unlock() + + err := u.newUsbGadgetTransaction(false) + if err != nil { + u.log.Error().Err(err).Msg("failed to create transaction") + return err + } + if err := fn(); err != nil { + u.log.Error().Err(err).Msg("transaction failed") + return err + } + result := u.tx.Commit() + u.tx = nil + + return result +} + +func (tx *UsbGadgetTransaction) addFileChange(component string, change RequestedFileChange) string { + change.Component = component + tx.c.AddFileChangeStruct(change) + + key := change.Key + if key == "" { + key = change.Path + } + return key +} + +func (tx *UsbGadgetTransaction) mkdirAll(component string, path string, description string) string { + return tx.addFileChange(component, RequestedFileChange{ + Path: path, + ExpectedState: FileStateDirectory, + Description: description, + }) +} + +func (tx *UsbGadgetTransaction) removeFile(component string, path string, description string) string { + return tx.addFileChange(component, RequestedFileChange{ + Path: path, + ExpectedState: FileStateAbsent, + Description: description, + }) +} + +func (tx *UsbGadgetTransaction) Commit() error { + err := tx.c.Apply() + if err != nil { + tx.log.Error().Err(err).Msg("failed to update usbgadget configuration") + return err + } + tx.log.Info().Msg("usbgadget configuration updated") + return nil +} + +func (u *UsbGadget) getOrderedConfigItems() orderedGadgetConfigItems { + items := make([]gadgetConfigItemWithKey, 0) + for key, item := range u.configMap { + items = append(items, gadgetConfigItemWithKey{key, item}) + } + + sort.Slice(items, func(i, j int) bool { + return items[i].item.order < items[j].item.order + }) + + return items +} + +func (tx *UsbGadgetTransaction) MountConfigFS() { + tx.addFileChange("gadget", RequestedFileChange{ + Path: configFSPath, + ExpectedState: FileStateMountedConfigFS, + Description: "mount configfs", + }) +} + +func (tx *UsbGadgetTransaction) CreateConfigPath() { + tx.mkdirAll("gadget", tx.configC1Path, "create config path") +} + +func (tx *UsbGadgetTransaction) WriteGadgetConfig() { + // create kvm gadget path + tx.mkdirAll("gadget", tx.kvmGadgetPath, "create kvm gadget path") + + deps := make([]string, 0) + + for _, val := range tx.orderedConfigItems { + key := val.key + item := val.item + + // check if the item is enabled in the config + if !tx.isGadgetConfigItemEnabled(key) { + tx.DisableGadgetItemConfig(item) + continue + } + deps = tx.writeGadgetItemConfig(item, deps) + } + + tx.WriteUDC() +} + +func (tx *UsbGadgetTransaction) DisableGadgetItemConfig(item gadgetConfigItem) { + // remove symlink if exists + if item.configPath == nil { + return + } + + configPath := joinPath(tx.configC1Path, item.configPath) + _ = tx.removeFile("gadget", configPath, "remove symlink: disable gadget config") +} + +func (tx *UsbGadgetTransaction) writeGadgetItemConfig(item gadgetConfigItem, deps []string) []string { + component := item.device + + // create directory for the item + files := make([]string, 0) + files = append(files, deps...) + + gadgetItemPath := joinPath(tx.kvmGadgetPath, item.path) + files = append(files, tx.mkdirAll(component, gadgetItemPath, "create gadget item directory")) + + beforeChange := make([]string, 0) + disableGadgetItemKey := fmt.Sprintf("disable-%s", item.device) + if item.configPath != nil && item.configAttrs == nil { + beforeChange = append(beforeChange, disableGadgetItemKey) + } + + if len(item.attrs) > 0 { + // write attributes for the item + files = append(files, tx.writeGadgetAttrs( + gadgetItemPath, + item.attrs, + component, + beforeChange, + )...) + } + + // write report descriptor if available + reportDescPath := path.Join(gadgetItemPath, "report_desc") + if item.reportDesc != nil { + tx.addFileChange(component, RequestedFileChange{ + Path: reportDescPath, + ExpectedState: FileStateFileContentMatch, + ExpectedContent: item.reportDesc, + Description: "write report descriptor", + BeforeChange: beforeChange, + DependsOn: files, + }) + } else { + tx.addFileChange(component, RequestedFileChange{ + Path: reportDescPath, + ExpectedState: FileStateAbsent, + Description: "remove report descriptor", + BeforeChange: beforeChange, + DependsOn: files, + }) + } + files = append(files, reportDescPath) + + // create config directory if configAttrs are set + if len(item.configAttrs) > 0 { + configItemPath := joinPath(tx.configC1Path, item.configPath) + tx.mkdirAll(component, configItemPath, "create config item directory") + files = append(files, tx.writeGadgetAttrs( + configItemPath, + item.configAttrs, + component, + beforeChange, + )...) + } + + // create symlink if configPath is set + if item.configPath != nil && item.configAttrs == nil { + configPath := joinPath(tx.configC1Path, item.configPath) + + // the change will be only applied by `beforeChange` + tx.addFileChange(component, RequestedFileChange{ + Key: disableGadgetItemKey, + Path: configPath, + ExpectedState: FileStateAbsent, + When: "beforeChange", // TODO: make it more flexible + Description: "remove symlink", + }) + + tx.addFileChange(component, RequestedFileChange{ + Path: configPath, + ExpectedState: FileStateSymlink, + ExpectedContent: []byte(gadgetItemPath), + Description: "create symlink", + DependsOn: files, + }) + } + + return files +} + +func (tx *UsbGadgetTransaction) writeGadgetAttrs(basePath string, attrs gadgetAttributes, component string, beforeChange []string) (files []string) { + files = make([]string, 0) + for key, val := range attrs { + filePath := filepath.Join(basePath, key) + tx.addFileChange(component, RequestedFileChange{ + Path: filePath, + ExpectedState: FileStateFileContentMatch, + ExpectedContent: []byte(val), + Description: "write gadget attribute", + DependsOn: []string{basePath}, + BeforeChange: beforeChange, + }) + files = append(files, filePath) + } + return files +} + +func (tx *UsbGadgetTransaction) WriteUDC() { + // bound the gadget to a UDC (USB Device Controller) + path := path.Join(tx.kvmGadgetPath, "UDC") + tx.addFileChange("udc", RequestedFileChange{ + Path: path, + ExpectedState: FileStateFileContentMatch, + ExpectedContent: []byte(tx.udc), + Description: "write UDC", + }) +} + +func (tx *UsbGadgetTransaction) RebindUsb(ignoreUnbindError bool) { + // remove the gadget from the UDC + tx.addFileChange("udc", RequestedFileChange{ + Path: path.Join(tx.dwc3Path, "unbind"), + ExpectedState: FileStateFileWrite, + ExpectedContent: []byte(tx.udc), + Description: "unbind UDC", + }) + // bind the gadget to the UDC + tx.addFileChange("udc", RequestedFileChange{ + Path: path.Join(tx.dwc3Path, "bind"), + ExpectedState: FileStateFileWrite, + ExpectedContent: []byte(tx.udc), + Description: "bind UDC", + DependsOn: []string{path.Join(tx.dwc3Path, "unbind")}, + IgnoreErrors: ignoreUnbindError, + }) +} diff --git a/internal/usbgadget/log.go b/internal/usbgadget/log.go new file mode 100644 index 0000000..f979f6c --- /dev/null +++ b/internal/usbgadget/log.go @@ -0,0 +1,27 @@ +package usbgadget + +import ( + "errors" +) + +func (u *UsbGadget) logWarn(msg string, err error) error { + if err == nil { + err = errors.New(msg) + } + if u.strictMode { + return err + } + u.log.Warn().Err(err).Msg(msg) + return nil +} + +func (u *UsbGadget) logError(msg string, err error) error { + if err == nil { + err = errors.New(msg) + } + if u.strictMode { + return err + } + u.log.Error().Err(err).Msg(msg) + return nil +} diff --git a/internal/usbgadget/udc.go b/internal/usbgadget/udc.go index 84dfbe4..4b7fbe3 100644 --- a/internal/usbgadget/udc.go +++ b/internal/usbgadget/udc.go @@ -50,18 +50,6 @@ func (u *UsbGadget) RebindUsb(ignoreUnbindError bool) error { return u.rebindUsb(ignoreUnbindError) } -func (u *UsbGadget) writeUDC() error { - path := path.Join(u.kvmGadgetPath, "UDC") - - u.log.Trace().Str("udc", u.udc).Str("path", path).Msg("writing UDC") - err := u.writeIfDifferent(path, []byte(u.udc), 0644) - if err != nil { - return fmt.Errorf("failed to write UDC: %w", err) - } - - return nil -} - // GetUsbState returns the current state of the USB gadget func (u *UsbGadget) GetUsbState() (state string) { stateFile := path.Join("/sys/class/udc", u.udc, "state") diff --git a/internal/usbgadget/usbgadget.go b/internal/usbgadget/usbgadget.go index 1dff2f3..663aa22 100644 --- a/internal/usbgadget/usbgadget.go +++ b/internal/usbgadget/usbgadget.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/jetkvm/kvm/internal/logging" "github.com/rs/zerolog" ) @@ -28,7 +29,8 @@ type Config struct { Manufacturer string `json:"manufacturer"` Product string `json:"product"` - isEmpty bool + strictMode bool // when it's enabled, all warnings will be converted to errors + isEmpty bool } var defaultUsbGadgetDevices = Devices{ @@ -59,22 +61,27 @@ type UsbGadget struct { enabledDevices Devices + strictMode bool // only intended for testing for now + absMouseAccumulatedWheelY float64 lastUserInput time.Time + tx *UsbGadgetTransaction + txLock sync.Mutex + log *zerolog.Logger } const configFSPath = "/sys/kernel/config" const gadgetPath = "/sys/kernel/config/usb_gadget" -var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel) +var defaultLogger = logging.GetSubsystemLogger("usbgadget") // NewUsbGadget creates a new UsbGadget. func NewUsbGadget(name string, enabledDevices *Devices, config *Config, logger *zerolog.Logger) *UsbGadget { if logger == nil { - logger = &defaultLogger + logger = defaultLogger } if enabledDevices == nil { @@ -95,10 +102,13 @@ func NewUsbGadget(name string, enabledDevices *Devices, config *Config, logger * keyboardLock: sync.Mutex{}, absMouseLock: sync.Mutex{}, relMouseLock: sync.Mutex{}, + txLock: sync.Mutex{}, enabledDevices: *enabledDevices, lastUserInput: time.Now(), log: logger, + strictMode: config.strictMode, + absMouseAccumulatedWheelY: 0, } if err := g.Init(); err != nil { diff --git a/internal/usbgadget/utils.go b/internal/usbgadget/utils.go index 0e796c8..7f20036 100644 --- a/internal/usbgadget/utils.go +++ b/internal/usbgadget/utils.go @@ -3,8 +3,9 @@ package usbgadget import ( "bytes" "fmt" - "os" "path/filepath" + "strconv" + "strings" ) // Helper function to get absolute value of float64 @@ -20,44 +21,68 @@ func joinPath(basePath string, paths []string) string { return filepath.Join(pathArr...) } -func ensureSymlink(linkPath string, target string) error { - if _, err := os.Lstat(linkPath); err == nil { - currentTarget, err := os.Readlink(linkPath) - if err != nil || currentTarget != target { - err = os.Remove(linkPath) - if err != nil { - return fmt.Errorf("failed to remove existing symlink %s: %w", linkPath, err) - } - } - } else if !os.IsNotExist(err) { - return fmt.Errorf("failed to check if symlink exists: %w", err) +func hexToDecimal(hex string) (int64, error) { + decimal, err := strconv.ParseInt(hex, 16, 64) + if err != nil { + return 0, err } - - if err := os.Symlink(target, linkPath); err != nil { - return fmt.Errorf("failed to create symlink from %s to %s: %w", linkPath, target, err) - } - - return nil + return decimal, nil } -func (u *UsbGadget) writeIfDifferent(filePath string, content []byte, permMode os.FileMode) error { - if _, err := os.Stat(filePath); err == nil { - oldContent, err := os.ReadFile(filePath) - if err == nil { - if bytes.Equal(oldContent, content) { - u.log.Trace().Str("path", filePath).Msg("skipping writing to as it already has the correct content") - return nil - } +func decimalToOctal(decimal int64) string { + return fmt.Sprintf("%04o", decimal) +} - if len(oldContent) == len(content)+1 && - bytes.Equal(oldContent[:len(content)], content) && - oldContent[len(content)] == 10 { - u.log.Trace().Str("path", filePath).Msg("skipping writing to as it already has the correct content") - return nil - } +func hexToOctal(hex string) (string, error) { + hex = strings.ToLower(hex) + hex = strings.Replace(hex, "0x", "", 1) //remove 0x or 0X - u.log.Trace().Str("path", filePath).Bytes("old", oldContent).Bytes("new", content).Msg("writing to as it has different content") + decimal, err := hexToDecimal(hex) + if err != nil { + return "", err + } + + // Convert the decimal integer to an octal string. + octal := decimalToOctal(decimal) + return octal, nil +} + +func compareFileContent(oldContent []byte, newContent []byte, looserMatch bool) bool { + if bytes.Equal(oldContent, newContent) { + return true + } + + if len(oldContent) == len(newContent)+1 && + bytes.Equal(oldContent[:len(newContent)], newContent) && + oldContent[len(newContent)] == 10 { + return true + } + + if len(newContent) == 4 { + if len(oldContent) < 6 || len(oldContent) > 7 { + return false + } + + if len(oldContent) == 7 && oldContent[6] == 0x0a { + oldContent = oldContent[:6] + } + + oldOctalValue, err := hexToOctal(string(oldContent)) + if err != nil { + return false + } + + if oldOctalValue == string(newContent) { + return true } } - return os.WriteFile(filePath, content, permMode) + + if looserMatch { + oldContentStr := strings.TrimSpace(string(oldContent)) + newContentStr := strings.TrimSpace(string(newContent)) + + return oldContentStr == newContentStr + } + + return false }