From 7a9fb7cbb11a3c964037ba1be2df70c830dd482a Mon Sep 17 00:00:00 2001 From: Aveline <352441+ym@users.noreply.github.com> Date: Mon, 19 May 2025 21:48:43 +0200 Subject: [PATCH] chore(usbgadget): update usbgadget config only when needed (#474) --- .github/workflows/smoketest.yml | 9 + dev_deploy.sh | 29 +- go.mod | 15 +- go.sum | 4 +- internal/usbgadget/changeset.go | 432 +++++++++++++++++++++++ internal/usbgadget/changeset_arm_test.go | 115 ++++++ internal/usbgadget/changeset_resolver.go | 183 ++++++++++ internal/usbgadget/changeset_symlink.go | 136 +++++++ internal/usbgadget/config.go | 172 +-------- internal/usbgadget/config_tx.go | 329 +++++++++++++++++ internal/usbgadget/log.go | 27 ++ internal/usbgadget/udc.go | 12 - internal/usbgadget/usbgadget.go | 22 +- internal/usbgadget/utils.go | 93 +++-- 14 files changed, 1357 insertions(+), 221 deletions(-) create mode 100644 internal/usbgadget/changeset.go create mode 100644 internal/usbgadget/changeset_arm_test.go create mode 100644 internal/usbgadget/changeset_resolver.go create mode 100644 internal/usbgadget/changeset_symlink.go create mode 100644 internal/usbgadget/config_tx.go create mode 100644 internal/usbgadget/log.go diff --git a/.github/workflows/smoketest.yml b/.github/workflows/smoketest.yml index d5493e7..3f8eb6c 100644 --- a/.github/workflows/smoketest.yml +++ b/.github/workflows/smoketest.yml @@ -69,6 +69,15 @@ jobs: CI_USER: ${{ vars.JETKVM_CI_USER }} CI_HOST: ${{ vars.JETKVM_CI_HOST }} CI_SSH_PRIVATE: ${{ secrets.JETKVM_CI_SSH_PRIVATE }} + - name: Run tests + run: | + set -e + make build_dev_test + + echo "+ Copying device-tests.tar.gz to remote host" + ssh jkci "cat > /userdata/jetkvm/device-tests.tar.gz" < device-tests.tar.gz + echo "+ Running go tests" + ssh jkci "cd /userdata/jetkvm && tar zxvf device-tests.tar.gz && ./run_all_tests -json" - name: Deploy application run: | set -e diff --git a/dev_deploy.sh b/dev_deploy.sh index 466acc7..7bd649b 100755 --- a/dev_deploy.sh +++ b/dev_deploy.sh @@ -26,6 +26,8 @@ show_help() { echo "Optional:" echo " -u, --user Remote username (default: root)" echo " --run-go-tests Run go tests" + echo " --run-go-tests-only Run go tests and exit" + echo " --run-go-tests-json Run go tests and output JSON" echo " --skip-ui-build Skip frontend/UI build" echo " --help Display this help message" echo @@ -42,6 +44,8 @@ RESET_USB_HID_DEVICE=false LOG_TRACE_SCOPES="${LOG_TRACE_SCOPES:-jetkvm,cloud,websocket,native,jsonrpc}" RUN_GO_TESTS=false RUN_GO_TESTS_JSON=false +RUN_GO_TESTS_ONLY=false + # Parse command line arguments while [[ $# -gt 0 ]]; do case $1 in @@ -67,6 +71,12 @@ while [[ $# -gt 0 ]]; do ;; --run-go-tests-json) RUN_GO_TESTS_JSON=true + RUN_GO_TESTS=true + shift + ;; + --run-go-tests-only) + RUN_GO_TESTS_ONLY=true + RUN_GO_TESTS=true shift ;; --help) @@ -81,10 +91,6 @@ while [[ $# -gt 0 ]]; do esac done -if [ "$RUN_GO_TESTS_JSON" = true ]; then - RUN_GO_TESTS=true -fi - # Verify required parameters if [ -z "$REMOTE_HOST" ]; then msg_err "Error: Remote IP is a required parameter" @@ -98,9 +104,6 @@ if [ "$SKIP_UI_BUILD" = false ]; then make frontend fi -msg_info "▶ Building go binary" -make build_dev - if [ "$RUN_GO_TESTS" = true ]; then msg_info "▶ Building go tests" make build_dev_test @@ -117,10 +120,18 @@ if [ "$RUN_GO_TESTS" = true ]; then set -e cd ${REMOTE_PATH} tar zxvf device-tests.tar.gz -./run_all_tests $TEST_ARGS +PION_LOG_TRACE=all ./run_all_tests $TEST_ARGS EOF + + if [ "$RUN_GO_TESTS_ONLY" = true ]; then + msg_info "▶ Go tests completed" + exit 0 + fi fi +msg_info "▶ Building go binary" +make build_dev + # Kill any existing instances of the application ssh "${REMOTE_USER}@${REMOTE_HOST}" "killall jetkvm_app_debug || true" @@ -128,6 +139,8 @@ ssh "${REMOTE_USER}@${REMOTE_HOST}" "killall jetkvm_app_debug || true" ssh "${REMOTE_USER}@${REMOTE_HOST}" "cat > ${REMOTE_PATH}/jetkvm_app_debug" < bin/jetkvm_app if [ "$RESET_USB_HID_DEVICE" = true ]; then + msg_info "▶ Resetting USB HID device" + msg_warn "The option has been deprecated and will be removed in a future version, as JetKVM will now reset USB gadget configuration when needed" # Remove the old USB gadget configuration ssh "${REMOTE_USER}@${REMOTE_HOST}" "rm -rf /sys/kernel/config/usb_gadget/jetkvm/configs/c.1/hid.usb*" ssh "${REMOTE_USER}@${REMOTE_HOST}" "ls /sys/class/udc > /sys/kernel/config/usb_gadget/jetkvm/UDC" diff --git a/go.mod b/go.mod index 6784a59..0e288db 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/jetkvm/kvm -go 1.23.0 +go 1.23.4 + +toolchain go1.24.3 require ( github.com/Masterminds/semver/v3 v3.3.0 @@ -12,21 +14,25 @@ require ( github.com/gin-contrib/logger v1.2.5 github.com/gin-gonic/gin v1.10.0 github.com/google/uuid v1.6.0 + github.com/guregu/null/v6 v6.0.0 github.com/gwatts/rootcerts v0.0.0-20240401182218-3ab9db955caf github.com/hanwen/go-fuse/v2 v2.5.1 - github.com/hashicorp/go-envparse v0.1.0 github.com/pion/logging v0.2.2 github.com/pion/mdns/v2 v2.0.7 github.com/pion/webrtc/v4 v4.0.0 github.com/pojntfx/go-nbd v0.3.2 github.com/prometheus/client_golang v1.21.0 github.com/prometheus/common v0.62.0 + github.com/prometheus/procfs v0.15.1 github.com/psanford/httpreadat v0.1.0 github.com/rs/zerolog v1.34.0 + github.com/sourcegraph/tf-dag v0.2.2-0.20250131204052-3e8ff1477b4f + github.com/stretchr/testify v1.10.0 github.com/vishvananda/netlink v1.3.0 go.bug.st/serial v1.6.2 golang.org/x/crypto v0.36.0 golang.org/x/net v0.38.0 + golang.org/x/sys v0.32.0 ) replace github.com/pojntfx/go-nbd v0.3.2 => github.com/chemhack/go-nbd v0.0.0-20241006125820-59e45f5b1e7b @@ -38,6 +44,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.5 // indirect github.com/creack/goselect v0.1.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/gin-contrib/sse v1.0.0 // indirect github.com/go-jose/go-jose/v4 v4.0.2 // indirect @@ -45,7 +52,6 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.26.0 // indirect github.com/goccy/go-json v0.10.5 // indirect - github.com/guregu/null/v6 v6.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect @@ -70,15 +76,14 @@ require ( github.com/pion/stun/v3 v3.0.0 // indirect github.com/pion/transport/v3 v3.0.7 // indirect github.com/pion/turn/v4 v4.0.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect - github.com/prometheus/procfs v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect github.com/vishvananda/netns v0.0.4 // indirect github.com/wlynxg/anet v0.0.5 // indirect golang.org/x/arch v0.15.0 // indirect golang.org/x/oauth2 v0.24.0 // indirect - golang.org/x/sys v0.32.0 // indirect golang.org/x/text v0.23.0 // indirect google.golang.org/protobuf v1.36.6 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 3ad832a..14aadb1 100644 --- a/go.sum +++ b/go.sum @@ -62,8 +62,6 @@ github.com/gwatts/rootcerts v0.0.0-20240401182218-3ab9db955caf h1:JO6ISZIvEUitto github.com/gwatts/rootcerts v0.0.0-20240401182218-3ab9db955caf/go.mod h1:5Kt9XkWvkGi2OHOq0QsGxebHmhCcqJ8KCbNg/a6+n+g= github.com/hanwen/go-fuse/v2 v2.5.1 h1:OQBE8zVemSocRxA4OaFJbjJ5hlpCmIWbGr7r0M4uoQQ= github.com/hanwen/go-fuse/v2 v2.5.1/go.mod h1:xKwi1cF7nXAOBCXujD5ie0ZKsxc8GGSA1rlMJc+8IJs= -github.com/hashicorp/go-envparse v0.1.0 h1:bE++6bhIsNCPLvgDZkYqo3nA+/PFI51pkrHdmPSDFPY= -github.com/hashicorp/go-envparse v0.1.0/go.mod h1:OHheN1GoygLlAkTlXLXvAdnXdZxy8JUweQ1rAXx1xnc= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= @@ -154,6 +152,8 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/sourcegraph/tf-dag v0.2.2-0.20250131204052-3e8ff1477b4f h1:VgoRCP1efSCEZIcF2THLQ46+pIBzzgNiaUBe9wEDwYU= +github.com/sourcegraph/tf-dag v0.2.2-0.20250131204052-3e8ff1477b4f/go.mod h1:pzro7BGorij2WgrjEammtrkbo3+xldxo+KaGLGUiD+Q= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/internal/usbgadget/changeset.go b/internal/usbgadget/changeset.go new file mode 100644 index 0000000..3a5ceaa --- /dev/null +++ b/internal/usbgadget/changeset.go @@ -0,0 +1,432 @@ +package usbgadget + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "reflect" + "time" + + "github.com/prometheus/procfs" + "github.com/sourcegraph/tf-dag/dag" +) + +// it's a minimalistic implementation of ansible's file module with some modifications +// to make it more suitable for our use case +// https://docs.ansible.com/ansible/latest/modules/file_module.html + +// we use this to check if the files in the gadget config are in the expected state +// and to update them if they are not in the expected state + +type FileState uint8 +type ChangeState uint8 +type FileChangeResolvedAction uint8 + +type ApplyFunc func(c *ChangeSet, changes []*FileChange) error + +const ( + FileStateUnknown FileState = iota + FileStateAbsent + FileStateDirectory + FileStateFile + FileStateFileContentMatch + FileStateFileWrite // update file content without checking + FileStateMounted + FileStateMountedConfigFS + FileStateSymlink + FileStateSymlinkInOrderConfigFS // configfs is a shithole, so we need to check if the symlinks are created in the correct order + FileStateSymlinkNotInOrderConfigFS + FileStateTouch +) + +var FileStateString = map[FileState]string{ + FileStateUnknown: "UNKNOWN", + FileStateAbsent: "ABSENT", + FileStateDirectory: "DIRECTORY", + FileStateFile: "FILE", + FileStateFileContentMatch: "FILE_CONTENT_MATCH", + FileStateFileWrite: "FILE_WRITE", + FileStateMounted: "MOUNTED", + FileStateMountedConfigFS: "CONFIGFS_MOUNT", + FileStateSymlink: "SYMLINK", + FileStateSymlinkInOrderConfigFS: "SYMLINK_IN_ORDER_CONFIGFS", + FileStateTouch: "TOUCH", +} + +const ( + ChangeStateUnknown ChangeState = iota + ChangeStateRequired + ChangeStateNotChanged + ChangeStateChanged + ChangeStateError +) + +const ( + FileChangeResolvedActionUnknown FileChangeResolvedAction = iota + FileChangeResolvedActionDoNothing + FileChangeResolvedActionRemove + FileChangeResolvedActionCreateFile + FileChangeResolvedActionWriteFile + FileChangeResolvedActionUpdateFile + FileChangeResolvedActionAppendFile + FileChangeResolvedActionCreateSymlink + FileChangeResolvedActionRecreateSymlink + FileChangeResolvedActionCreateDirectoryAndSymlinks + FileChangeResolvedActionReorderSymlinks + FileChangeResolvedActionCreateDirectory + FileChangeResolvedActionRemoveDirectory + FileChangeResolvedActionTouch + FileChangeResolvedActionMountConfigFS +) + +var FileChangeResolvedActionString = map[FileChangeResolvedAction]string{ + FileChangeResolvedActionUnknown: "UNKNOWN", + FileChangeResolvedActionDoNothing: "DO_NOTHING", + FileChangeResolvedActionRemove: "REMOVE", + FileChangeResolvedActionCreateFile: "FILE_CREATE", + FileChangeResolvedActionWriteFile: "FILE_WRITE", + FileChangeResolvedActionUpdateFile: "FILE_UPDATE", + FileChangeResolvedActionAppendFile: "FILE_APPEND", + FileChangeResolvedActionCreateSymlink: "SYMLINK_CREATE", + FileChangeResolvedActionRecreateSymlink: "SYMLINK_RECREATE", + FileChangeResolvedActionCreateDirectoryAndSymlinks: "DIR_CREATE_AND_SYMLINKS", + FileChangeResolvedActionReorderSymlinks: "SYMLINK_REORDER", + FileChangeResolvedActionCreateDirectory: "DIR_CREATE", + FileChangeResolvedActionRemoveDirectory: "DIR_REMOVE", + FileChangeResolvedActionTouch: "TOUCH", + FileChangeResolvedActionMountConfigFS: "CONFIGFS_MOUNT", +} + +type ChangeSet struct { + Changes []FileChange +} + +type RequestedFileChange struct { + Component string + Key string + Path string // will be used as Key if Key is empty + ParamSymlinks []symlink + ExpectedState FileState + ExpectedContent []byte + DependsOn []string + BeforeChange []string // if the file is going to be changed, apply the change first + Description string + IgnoreErrors bool + When string // only apply the change if when meets the condition +} + +type FileChange struct { + RequestedFileChange + ActualState FileState + ActualContent []byte + resolvedDeps []string + checked bool + changed ChangeState + action FileChangeResolvedAction +} + +func (f *RequestedFileChange) String() string { + var s string + switch f.ExpectedState { + case FileStateDirectory: + s = fmt.Sprintf("dir: %s", f.Path) + case FileStateFile: + s = fmt.Sprintf("file: %s", f.Path) + case FileStateSymlink: + s = fmt.Sprintf("symlink: %s -> %s", f.Path, f.ExpectedContent) + case FileStateSymlinkInOrderConfigFS: + s = fmt.Sprintf("symlink_in_order_configfs: %s -> %s", f.Path, f.ExpectedContent) + case FileStateSymlinkNotInOrderConfigFS: + s = fmt.Sprintf("symlink_not_in_order_configfs: %s -> %s", f.Path, f.ExpectedContent) + case FileStateAbsent: + s = fmt.Sprintf("absent: %s", f.Path) + case FileStateFileContentMatch: + s = fmt.Sprintf("file: %s with content [%s]", f.Path, f.ExpectedContent) + case FileStateFileWrite: + s = fmt.Sprintf("write: %s with content [%s]", f.Path, f.ExpectedContent) + case FileStateMountedConfigFS: + s = fmt.Sprintf("configfs: %s", f.Path) + case FileStateTouch: + s = fmt.Sprintf("touch: %s", f.Path) + case FileStateUnknown: + s = fmt.Sprintf("unknown change for %s", f.Path) + default: + s = fmt.Sprintf("unknown expected state %d for %s", f.ExpectedState, f.Path) + } + + return s +} + +func (f *RequestedFileChange) IsSame(other *RequestedFileChange) bool { + return f.Path == other.Path && + f.ExpectedState == other.ExpectedState && + reflect.DeepEqual(f.ExpectedContent, other.ExpectedContent) && + reflect.DeepEqual(f.DependsOn, other.DependsOn) && + f.IgnoreErrors == other.IgnoreErrors +} + +func (fc *FileChange) checkIfDirIsMountPoint() error { + // check if the file is a mount point + mounts, err := procfs.GetMounts() + if err != nil { + return fmt.Errorf("failed to get mounts") + } + + for _, mount := range mounts { + if mount.MountPoint == fc.Path { + fc.ActualState = FileStateMounted + fc.ActualContent = []byte(mount.Source) + + if mount.FSType == "configfs" { + fc.ActualState = FileStateMountedConfigFS + } + + return nil + } + } + + return nil +} + +// GetActualState returns the actual state of the file at the given path. +func (fc *FileChange) getActualState() error { + l := defaultLogger.With().Str("path", fc.Path).Logger() + + fi, err := os.Lstat(fc.Path) + if err != nil { + if os.IsNotExist(err) { + fc.ActualState = FileStateAbsent + } else { + l.Warn().Err(err).Msg("failed to stat file") + fc.ActualState = FileStateUnknown + } + return nil + } + + // check if the file is a symlink + if fi.Mode()&os.ModeSymlink == os.ModeSymlink { + fc.ActualState = FileStateSymlink + // get the target of the symlink + target, err := os.Readlink(fc.Path) + if err != nil { + l.Warn().Err(err).Msg("failed to read symlink") + return fmt.Errorf("failed to read symlink") + } + // check if the target is a relative path + if !filepath.IsAbs(target) { + // make it absolute + target, err = filepath.Abs(filepath.Join(filepath.Dir(fc.Path), target)) + if err != nil { + l.Warn().Err(err).Msg("failed to make symlink target absolute") + return fmt.Errorf("failed to make symlink target absolute") + } + } + fc.ActualContent = []byte(target) + return nil + } + + if fi.IsDir() { + fc.ActualState = FileStateDirectory + + switch fc.ExpectedState { + case FileStateMountedConfigFS: + err := fc.checkIfDirIsMountPoint() + if err != nil { + l.Warn().Err(err).Msg("failed to check if dir is mount point") + return err + } + case FileStateSymlinkInOrderConfigFS: + state, err := checkIfSymlinksInOrder(fc, &l) + if err != nil { + l.Warn().Err(err).Msg("failed to check if symlinks are in order") + return err + } + fc.ActualState = state + } + return nil + } + + if fi.Mode()&os.ModeDevice == os.ModeDevice { + l.Info().Msg("file is a device") + return nil + } + + // check if the file is a regular file + if fi.Mode().IsRegular() { + fc.ActualState = FileStateFile + // get the content of the file + content, err := os.ReadFile(fc.Path) + if err != nil { + l.Warn().Err(err).Msg("failed to read file") + return fmt.Errorf("failed to read file") + } + fc.ActualContent = content + return nil + } + + l.Warn().Interface("file_info", fi.Mode()).Bool("is_dir", fi.IsDir()).Msg("unknown file type") + + return fmt.Errorf("unknown file type") +} + +func (fc *FileChange) ResetActionResolution() { + fc.checked = false + fc.action = FileChangeResolvedActionUnknown + fc.changed = ChangeStateUnknown +} + +func (fc *FileChange) Action() FileChangeResolvedAction { + if !fc.checked { + fc.action = fc.getFileChangeResolvedAction() + fc.checked = true + } + + return fc.action +} + +func (fc *FileChange) getFileChangeResolvedAction() FileChangeResolvedAction { + l := defaultLogger.With().Str("path", fc.Path).Logger() + + // some actions are not needed to be checked + switch fc.ExpectedState { + case FileStateFileWrite: + return FileChangeResolvedActionWriteFile + case FileStateTouch: + return FileChangeResolvedActionTouch + } + + // get the actual state of the file + err := fc.getActualState() + if err != nil { + return FileChangeResolvedActionDoNothing + } + + baseName := filepath.Base(fc.Path) + + switch fc.ExpectedState { + case FileStateDirectory: + // if the file is already a directory, do nothing + if fc.ActualState == FileStateDirectory { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionCreateDirectory + case FileStateFile: + // if the file is already a file, do nothing + if fc.ActualState == FileStateFile { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionCreateFile + case FileStateFileContentMatch: + // if the file is already a file with the expected content, do nothing + if fc.ActualState == FileStateFile { + looserMatch := baseName == "inquiry_string" + if compareFileContent(fc.ActualContent, fc.ExpectedContent, looserMatch) { + return FileChangeResolvedActionDoNothing + } + // TODO: move this to somewhere else + // this is a workaround for the fact that the file is not updated if it has no content + if baseName == "file" && + bytes.Equal(fc.ActualContent, []byte{}) && + bytes.Equal(fc.ExpectedContent, []byte{0x0a}) { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionUpdateFile + } + return FileChangeResolvedActionCreateFile + case FileStateSymlink: + // if the file is already a symlink, check if the target is the same + if fc.ActualState == FileStateSymlink { + if reflect.DeepEqual(fc.ActualContent, fc.ExpectedContent) { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionRecreateSymlink + } + return FileChangeResolvedActionCreateSymlink + case FileStateSymlinkInOrderConfigFS: + // if the file is already a symlink, check if the target is the same + if fc.ActualState == FileStateSymlinkInOrderConfigFS { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionReorderSymlinks + case FileStateAbsent: + if fc.ActualState == FileStateAbsent { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionRemove + case FileStateMountedConfigFS: + if fc.ActualState == FileStateMountedConfigFS { + return FileChangeResolvedActionDoNothing + } + return FileChangeResolvedActionMountConfigFS + default: + l.Warn().Interface("file_change", FileStateString[fc.ExpectedState]).Msg("unknown expected state") + return FileChangeResolvedActionDoNothing + } +} + +func (c *ChangeSet) AddFileChangeStruct(r RequestedFileChange) { + fc := FileChange{ + RequestedFileChange: r, + } + c.Changes = append(c.Changes, fc) +} + +func (c *ChangeSet) AddFileChange(component string, path string, expectedState FileState, expectedContent []byte, dependsOn []string, description string) { + c.AddFileChangeStruct(RequestedFileChange{ + Component: component, + Path: path, + ExpectedState: expectedState, + ExpectedContent: expectedContent, + DependsOn: dependsOn, + Description: description, + }) +} + +func (c *ChangeSet) ApplyChanges() error { + r := ChangeSetResolver{ + changeset: c, + g: &dag.AcyclicGraph{}, + l: defaultLogger, + } + + return r.Apply() +} + +func (c *ChangeSet) applyChange(change *FileChange) error { + switch change.Action() { + case FileChangeResolvedActionWriteFile: + return os.WriteFile(change.Path, change.ExpectedContent, 0644) + case FileChangeResolvedActionUpdateFile: + return os.WriteFile(change.Path, change.ExpectedContent, 0644) + case FileChangeResolvedActionCreateFile: + return os.WriteFile(change.Path, change.ExpectedContent, 0644) + case FileChangeResolvedActionCreateSymlink: + return os.Symlink(string(change.ExpectedContent), change.Path) + case FileChangeResolvedActionRecreateSymlink: + if err := os.Remove(change.Path); err != nil { + return fmt.Errorf("failed to remove symlink: %w", err) + } + return os.Symlink(string(change.ExpectedContent), change.Path) + case FileChangeResolvedActionReorderSymlinks: + return recreateSymlinks(change, nil) + case FileChangeResolvedActionCreateDirectory: + return os.MkdirAll(change.Path, 0755) + case FileChangeResolvedActionRemove: + return os.Remove(change.Path) + case FileChangeResolvedActionRemoveDirectory: + return os.RemoveAll(change.Path) + case FileChangeResolvedActionTouch: + return os.Chtimes(change.Path, time.Now(), time.Now()) + case FileChangeResolvedActionMountConfigFS: + return mountConfigFS(change.Path) + case FileChangeResolvedActionDoNothing: + return nil + default: + return fmt.Errorf("unknown action: %d", change.Action()) + } +} + +func (c *ChangeSet) Apply() error { + return c.ApplyChanges() +} diff --git a/internal/usbgadget/changeset_arm_test.go b/internal/usbgadget/changeset_arm_test.go new file mode 100644 index 0000000..8c0abd5 --- /dev/null +++ b/internal/usbgadget/changeset_arm_test.go @@ -0,0 +1,115 @@ +//go:build arm && linux + +package usbgadget + +import ( + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +var ( + usbConfig = &Config{ + VendorId: "0x1d6b", //The Linux Foundation + ProductId: "0x0104", //Multifunction Composite Gadget + SerialNumber: "", + Manufacturer: "JetKVM", + Product: "USB Emulation Device", + strictMode: true, + } + usbDevices = &Devices{ + AbsoluteMouse: true, + RelativeMouse: true, + Keyboard: true, + MassStorage: true, + } + usbGadgetName = "jetkvm" + usbGadget *UsbGadget +) + +var oldAbsoluteMouseCombinedReportDesc = []byte{ + 0x05, 0x01, // Usage Page (Generic Desktop Ctrls) + 0x09, 0x02, // Usage (Mouse) + 0xA1, 0x01, // Collection (Application) + + // Report ID 1: Absolute Mouse Movement + 0x85, 0x01, // Report ID (1) + 0x09, 0x01, // Usage (Pointer) + 0xA1, 0x00, // Collection (Physical) + 0x05, 0x09, // Usage Page (Button) + 0x19, 0x01, // Usage Minimum (0x01) + 0x29, 0x03, // Usage Maximum (0x03) + 0x15, 0x00, // Logical Minimum (0) + 0x25, 0x01, // Logical Maximum (1) + 0x75, 0x01, // Report Size (1) + 0x95, 0x03, // Report Count (3) + 0x81, 0x02, // Input (Data, Var, Abs) + 0x95, 0x01, // Report Count (1) + 0x75, 0x05, // Report Size (5) + 0x81, 0x03, // Input (Cnst, Var, Abs) + 0x05, 0x01, // Usage Page (Generic Desktop Ctrls) + 0x09, 0x30, // Usage (X) + 0x09, 0x31, // Usage (Y) + 0x16, 0x00, 0x00, // Logical Minimum (0) + 0x26, 0xFF, 0x7F, // Logical Maximum (32767) + 0x36, 0x00, 0x00, // Physical Minimum (0) + 0x46, 0xFF, 0x7F, // Physical Maximum (32767) + 0x75, 0x10, // Report Size (16) + 0x95, 0x02, // Report Count (2) + 0x81, 0x02, // Input (Data, Var, Abs) + 0xC0, // End Collection + + // Report ID 2: Relative Wheel Movement + 0x85, 0x02, // Report ID (2) + 0x09, 0x38, // Usage (Wheel) + 0x15, 0x81, // Logical Minimum (-127) + 0x25, 0x7F, // Logical Maximum (127) + 0x75, 0x08, // Report Size (8) + 0x95, 0x01, // Report Count (1) + 0x81, 0x06, // Input (Data, Var, Rel) + + 0xC0, // End Collection +} + +func TestUsbGadgetInit(t *testing.T) { + assert := assert.New(t) + usbGadget = NewUsbGadget(usbGadgetName, usbDevices, usbConfig, nil) + + assert.NotNil(usbGadget) +} + +func TestUsbGadgetStrictModeInitFail(t *testing.T) { + usbConfig.strictMode = true + u := NewUsbGadget("test", usbDevices, usbConfig, nil) + assert.Nil(t, u, "should be nil") +} + +func TestUsbGadgetUDCNotBoundAfterReportDescrChanged(t *testing.T) { + assert := assert.New(t) + usbGadget = NewUsbGadget(usbGadgetName, usbDevices, usbConfig, nil) + assert.NotNil(usbGadget) + + // release the usb gadget and create a new one + usbGadget = nil + + altGadgetConfig := defaultGadgetConfig + + oldAbsoluteMouseConfig := altGadgetConfig["absolute_mouse"] + oldAbsoluteMouseConfig.reportDesc = oldAbsoluteMouseCombinedReportDesc + altGadgetConfig["absolute_mouse"] = oldAbsoluteMouseConfig + + usbGadget = newUsbGadget(usbGadgetName, altGadgetConfig, usbDevices, usbConfig, nil) + assert.NotNil(usbGadget) + + udcs := getUdcs() + assert.Equal(1, len(udcs), "should be only one UDC") + // check if the UDC is bound + udc := udcs[0] + assert.NotNil(udc, "UDC should exist") + + udcStr, err := os.ReadFile("/sys/kernel/config/usb_gadget/jetkvm/UDC") + assert.Nil(err, "usb_gadget/UDC should exist") + assert.Equal(strings.TrimSpace(udc), strings.TrimSpace(string(udcStr)), "UDC should be the same") +} diff --git a/internal/usbgadget/changeset_resolver.go b/internal/usbgadget/changeset_resolver.go new file mode 100644 index 0000000..9369daf --- /dev/null +++ b/internal/usbgadget/changeset_resolver.go @@ -0,0 +1,183 @@ +package usbgadget + +import ( + "fmt" + + "github.com/rs/zerolog" + "github.com/sourcegraph/tf-dag/dag" +) + +type ChangeSetResolver struct { + changeset *ChangeSet + + l *zerolog.Logger + g *dag.AcyclicGraph + + changesMap map[string]*FileChange + conditionalChangesMap map[string]*FileChange + + orderedChanges []dag.Vertex + resolvedChanges []*FileChange + additionalResolveRequired bool +} + +func (c *ChangeSetResolver) toOrderedChanges() error { + for key, change := range c.changesMap { + v := c.g.Add(key) + + for _, dependsOn := range change.DependsOn { + c.g.Connect(dag.BasicEdge(dependsOn, v)) + } + for _, dependsOn := range change.resolvedDeps { + c.g.Connect(dag.BasicEdge(dependsOn, v)) + } + } + + cycles := c.g.Cycles() + if len(cycles) > 0 { + return fmt.Errorf("cycles detected: %v", cycles) + } + + orderedChanges := c.g.TopologicalOrder() + c.orderedChanges = orderedChanges + return nil +} + +func (c *ChangeSetResolver) doResolveChanges(initial bool) error { + resolvedChanges := make([]*FileChange, 0) + + for _, key := range c.orderedChanges { + change := c.changesMap[key.(string)] + if !initial { + change.ResetActionResolution() + } + + resolvedAction := change.Action() + + resolvedChanges = append(resolvedChanges, change) + // no need to check the triggers if there's no change + if resolvedAction == FileChangeResolvedActionDoNothing { + continue + } + + if !initial { + continue + } + + if change.BeforeChange != nil { + change.resolvedDeps = append(change.resolvedDeps, change.BeforeChange...) + c.additionalResolveRequired = true + + // add the dependencies to the changes map + for _, dep := range change.BeforeChange { + depChange, ok := c.conditionalChangesMap[dep] + if !ok { + return fmt.Errorf("dependency %s not found", dep) + } + + c.changesMap[dep] = depChange + } + } + } + + c.resolvedChanges = resolvedChanges + return nil +} + +func (c *ChangeSetResolver) resolveChanges(initial bool) error { + // get the ordered changes + err := c.toOrderedChanges() + if err != nil { + return err + } + + // resolve the changes + err = c.doResolveChanges(initial) + if err != nil { + return err + } + + for _, change := range c.resolvedChanges { + c.l.Trace().Str("change", change.String()).Msg("resolved change") + } + + if !c.additionalResolveRequired || !initial { + return nil + } + + return c.resolveChanges(false) +} + +func (c *ChangeSetResolver) applyChanges() error { + for _, change := range c.resolvedChanges { + change.ResetActionResolution() + action := change.Action() + actionStr := FileChangeResolvedActionString[action] + + l := c.l.Info() + if action == FileChangeResolvedActionDoNothing { + l = c.l.Trace() + } + + l.Str("action", actionStr).Str("change", change.String()).Msg("applying change") + + err := c.changeset.applyChange(change) + if err != nil { + return err + } + } + + return nil +} + +func (c *ChangeSetResolver) GetChanges() ([]*FileChange, error) { + localChanges := c.changeset.Changes + changesMap := make(map[string]*FileChange) + conditionalChangesMap := make(map[string]*FileChange) + + // build the map of the changes + for _, change := range localChanges { + key := change.Key + if key == "" { + key = change.Path + } + + // remove it from the map first + if change.When != "" { + conditionalChangesMap[key] = &change + continue + } + + if _, ok := changesMap[key]; ok { + if changesMap[key].IsSame(&change.RequestedFileChange) { + continue + } + return nil, fmt.Errorf( + "duplicate change: %s, current: %s, requested: %s", + key, + changesMap[key].String(), + change.String(), + ) + } + + changesMap[key] = &change + } + + c.changesMap = changesMap + c.conditionalChangesMap = conditionalChangesMap + + err := c.resolveChanges(true) + if err != nil { + return nil, err + } + + return c.resolvedChanges, nil +} + +func (c *ChangeSetResolver) Apply() error { + if _, err := c.GetChanges(); err != nil { + return err + } + + return c.applyChanges() +} diff --git a/internal/usbgadget/changeset_symlink.go b/internal/usbgadget/changeset_symlink.go new file mode 100644 index 0000000..16ffb77 --- /dev/null +++ b/internal/usbgadget/changeset_symlink.go @@ -0,0 +1,136 @@ +package usbgadget + +import ( + "fmt" + "os" + "path" + "path/filepath" + "reflect" + + "github.com/rs/zerolog" +) + +type symlink struct { + Path string + Target string +} + +func compareSymlinks(expected []symlink, actual []symlink) bool { + if len(expected) != len(actual) { + return false + } + + return reflect.DeepEqual(expected, actual) +} + +func checkIfSymlinksInOrder(fc *FileChange, logger *zerolog.Logger) (FileState, error) { + if logger == nil { + logger = defaultLogger + } + l := logger.With().Str("path", fc.Path).Logger() + + if fc.ParamSymlinks == nil || len(fc.ParamSymlinks) == 0 { + return FileStateUnknown, fmt.Errorf("no symlinks to check") + } + + fi, err := os.Lstat(fc.Path) + + if err != nil { + if os.IsNotExist(err) { + return FileStateAbsent, nil + } else { + l.Warn().Err(err).Msg("failed to stat file") + return FileStateUnknown, fmt.Errorf("failed to stat file") + } + } + + if !fi.IsDir() { + return FileStateUnknown, fmt.Errorf("file is not a directory") + } + + files, err := os.ReadDir(fc.Path) + symlinks := make([]symlink, 0) + if err != nil { + return FileStateUnknown, fmt.Errorf("failed to read directory") + } + + for _, file := range files { + if file.Type()&os.ModeSymlink != os.ModeSymlink { + continue + } + + path := filepath.Join(fc.Path, file.Name()) + target, err := os.Readlink(path) + if err != nil { + return FileStateUnknown, fmt.Errorf("failed to read symlink") + } + + if !filepath.IsAbs(target) { + target = filepath.Join(fc.Path, target) + newTarget, err := filepath.Abs(target) + if err != nil { + return FileStateUnknown, fmt.Errorf("failed to get absolute path") + } + target = newTarget + } + + symlinks = append(symlinks, symlink{ + Path: path, + Target: target, + }) + } + + // compare the symlinks with the expected symlinks + if compareSymlinks(fc.ParamSymlinks, symlinks) { + return FileStateSymlinkInOrderConfigFS, nil + } + + l.Trace().Interface("expected", fc.ParamSymlinks).Interface("actual", symlinks).Msg("symlinks are not in order") + + return FileStateSymlinkNotInOrderConfigFS, nil +} + +func recreateSymlinks(fc *FileChange, logger *zerolog.Logger) error { + if logger == nil { + logger = defaultLogger + } + // remove all symlinks + files, err := os.ReadDir(fc.Path) + if err != nil { + return fmt.Errorf("failed to read directory") + } + + l := logger.With().Str("path", fc.Path).Logger() + l.Info().Msg("recreate symlinks") + + for _, file := range files { + if file.Type()&os.ModeSymlink != os.ModeSymlink { + continue + } + l.Info().Str("name", file.Name()).Msg("remove symlink") + err := os.Remove(path.Join(fc.Path, file.Name())) + if err != nil { + return fmt.Errorf("failed to remove symlink") + } + } + + l.Info().Interface("param-symlinks", fc.ParamSymlinks).Msg("create symlinks") + + // create the symlinks + for _, symlink := range fc.ParamSymlinks { + l.Info().Str("name", symlink.Path).Str("target", symlink.Target).Msg("create symlink") + + path := symlink.Path + if !filepath.IsAbs(path) { + path = filepath.Join(fc.Path, path) + } + + err := os.Symlink(symlink.Target, path) + if err != nil { + l.Warn().Err(err).Msg("failed to create symlink") + return fmt.Errorf("failed to create symlink") + } + } + + return nil +} diff --git a/internal/usbgadget/config.go b/internal/usbgadget/config.go index 5c287da..f4c9ce4 100644 --- a/internal/usbgadget/config.go +++ b/internal/usbgadget/config.go @@ -4,9 +4,6 @@ import ( "fmt" "os" "os/exec" - "path" - "path/filepath" - "sort" ) type gadgetConfigItem struct { @@ -160,15 +157,15 @@ func (u *UsbGadget) OverrideGadgetConfig(itemKey string, itemAttr string, value return nil, true } -func mountConfigFS() error { - _, err := os.Stat(gadgetPath) +func mountConfigFS(path string) error { + _, err := os.Stat(path) // TODO: check if it's mounted properly if err == nil { return nil } if os.IsNotExist(err) { - err = exec.Command("mount", "-t", "configfs", "none", configFSPath).Run() + err = exec.Command("mount", "-t", "configfs", "none", path).Run() if err != nil { return fmt.Errorf("failed to mount configfs: %w", err) } @@ -186,26 +183,19 @@ func (u *UsbGadget) Init() error { udcs := getUdcs() if len(udcs) < 1 { - u.log.Error().Msg("no udc found, skipping USB stack init") - return nil + return u.logWarn("no udc found, skipping USB stack init", nil) } u.udc = udcs[0] - _, err := os.Stat(u.kvmGadgetPath) - if err == nil { - u.log.Info().Msg("usb gadget already exists") - } - if err := mountConfigFS(); err != nil { - u.log.Error().Err(err).Msg("failed to mount configfs, usb stack might not function properly") - } - - if err := os.MkdirAll(u.configC1Path, 0755); err != nil { - u.log.Error().Err(err).Msg("failed to create config path") - } - - if err := u.writeGadgetConfig(); err != nil { - u.log.Error().Err(err).Msg("failed to start gadget") + err := u.WithTransaction(func() error { + u.tx.MountConfigFS() + u.tx.CreateConfigPath() + u.tx.WriteGadgetConfig() + return nil + }) + if err != nil { + return u.logError("unable to initialize USB stack", err) } return nil @@ -217,143 +207,13 @@ func (u *UsbGadget) UpdateGadgetConfig() error { u.loadGadgetConfig() - if err := u.writeGadgetConfig(); err != nil { - u.log.Error().Err(err).Msg("failed to update gadget") - } - - return nil -} - -func (u *UsbGadget) getOrderedConfigItems() orderedGadgetConfigItems { - items := make([]gadgetConfigItemWithKey, 0) - for key, item := range u.configMap { - items = append(items, gadgetConfigItemWithKey{key, item}) - } - - sort.Slice(items, func(i, j int) bool { - return items[i].item.order < items[j].item.order + err := u.WithTransaction(func() error { + u.tx.WriteGadgetConfig() + return nil }) - - return items -} - -func (u *UsbGadget) writeGadgetConfig() error { - // create kvm gadget path - err := os.MkdirAll(u.kvmGadgetPath, 0755) if err != nil { - return err - } - - u.log.Trace().Msg("writing gadget config") - for _, val := range u.getOrderedConfigItems() { - key := val.key - item := val.item - - // check if the item is enabled in the config - if !u.isGadgetConfigItemEnabled(key) { - u.log.Trace().Str("key", key).Msg("disabling gadget config") - err = u.disableGadgetItemConfig(item) - if err != nil { - return err - } - continue - } - u.log.Trace().Str("key", key).Msg("writing gadget config") - err = u.writeGadgetItemConfig(item) - if err != nil { - return err - } - } - - if err = u.writeUDC(); err != nil { - u.log.Error().Err(err).Msg("failed to write UDC") - return err - } - - if err = u.rebindUsb(true); err != nil { - u.log.Info().Err(err).Msg("failed to rebind usb") + return u.logError("unable to update gadget config", err) } return nil } - -func (u *UsbGadget) disableGadgetItemConfig(item gadgetConfigItem) error { - // remove symlink if exists - if item.configPath == nil { - return nil - } - - configPath := joinPath(u.configC1Path, item.configPath) - - if _, err := os.Lstat(configPath); os.IsNotExist(err) { - u.log.Trace().Str("path", configPath).Msg("symlink does not exist") - return nil - } - - if err := os.Remove(configPath); err != nil { - return fmt.Errorf("failed to remove symlink %s: %w", item.configPath, err) - } - - return nil -} - -func (u *UsbGadget) writeGadgetItemConfig(item gadgetConfigItem) error { - // create directory for the item - gadgetItemPath := joinPath(u.kvmGadgetPath, item.path) - err := os.MkdirAll(gadgetItemPath, 0755) - if err != nil { - return fmt.Errorf("failed to create path %s: %w", gadgetItemPath, err) - } - - if len(item.attrs) > 0 { - // write attributes for the item - err = u.writeGadgetAttrs(gadgetItemPath, item.attrs) - if err != nil { - return fmt.Errorf("failed to write attributes for %s: %w", gadgetItemPath, err) - } - } - - // write report descriptor if available - if item.reportDesc != nil { - err = u.writeIfDifferent(path.Join(gadgetItemPath, "report_desc"), item.reportDesc, 0644) - if err != nil { - return err - } - } - - // create config directory if configAttrs are set - if len(item.configAttrs) > 0 { - configItemPath := joinPath(u.configC1Path, item.configPath) - err = os.MkdirAll(configItemPath, 0755) - if err != nil { - return fmt.Errorf("failed to create path %s: %w", configItemPath, err) - } - - err = u.writeGadgetAttrs(configItemPath, item.configAttrs) - if err != nil { - return fmt.Errorf("failed to write config attributes for %s: %w", configItemPath, err) - } - } - - // create symlink if configPath is set - if item.configPath != nil && item.configAttrs == nil { - configPath := joinPath(u.configC1Path, item.configPath) - u.log.Trace().Str("source", configPath).Str("target", gadgetItemPath).Msg("creating symlink") - if err := ensureSymlink(configPath, gadgetItemPath); err != nil { - return err - } - } - - return nil -} - -func (u *UsbGadget) writeGadgetAttrs(basePath string, attrs gadgetAttributes) error { - for key, val := range attrs { - filePath := filepath.Join(basePath, key) - err := u.writeIfDifferent(filePath, []byte(val), 0644) - if err != nil { - return fmt.Errorf("failed to write to %s: %w", filePath, err) - } - } - return nil -} diff --git a/internal/usbgadget/config_tx.go b/internal/usbgadget/config_tx.go new file mode 100644 index 0000000..be72487 --- /dev/null +++ b/internal/usbgadget/config_tx.go @@ -0,0 +1,329 @@ +package usbgadget + +import ( + "fmt" + "path" + "path/filepath" + "sort" + + "github.com/rs/zerolog" +) + +// no os package should occur in this file + +type UsbGadgetTransaction struct { + c *ChangeSet + + // below are the fields that are needed to be set by the caller + log *zerolog.Logger + udc string + dwc3Path string + kvmGadgetPath string + configC1Path string + orderedConfigItems orderedGadgetConfigItems + isGadgetConfigItemEnabled func(key string) bool + + reorderSymlinkChanges *RequestedFileChange +} + +func (u *UsbGadget) newUsbGadgetTransaction(lock bool) error { + if lock { + u.txLock.Lock() + defer u.txLock.Unlock() + } + + if u.tx != nil { + return fmt.Errorf("transaction already exists") + } + + tx := &UsbGadgetTransaction{ + c: &ChangeSet{}, + log: u.log, + udc: u.udc, + dwc3Path: dwc3Path, + kvmGadgetPath: u.kvmGadgetPath, + configC1Path: u.configC1Path, + orderedConfigItems: u.getOrderedConfigItems(), + isGadgetConfigItemEnabled: u.isGadgetConfigItemEnabled, + } + u.tx = tx + + return nil +} + +func (u *UsbGadget) WithTransaction(fn func() error) error { + u.txLock.Lock() + defer u.txLock.Unlock() + + err := u.newUsbGadgetTransaction(false) + if err != nil { + u.log.Error().Err(err).Msg("failed to create transaction") + return err + } + if err := fn(); err != nil { + u.log.Error().Err(err).Msg("transaction failed") + return err + } + result := u.tx.Commit() + u.tx = nil + + return result +} + +func (tx *UsbGadgetTransaction) addFileChange(component string, change RequestedFileChange) string { + change.Component = component + tx.c.AddFileChangeStruct(change) + + key := change.Key + if key == "" { + key = change.Path + } + return key +} + +func (tx *UsbGadgetTransaction) mkdirAll(component string, path string, description string) string { + return tx.addFileChange(component, RequestedFileChange{ + Path: path, + ExpectedState: FileStateDirectory, + Description: description, + }) +} + +func (tx *UsbGadgetTransaction) removeFile(component string, path string, description string) string { + return tx.addFileChange(component, RequestedFileChange{ + Path: path, + ExpectedState: FileStateAbsent, + Description: description, + }) +} + +func (tx *UsbGadgetTransaction) Commit() error { + tx.addFileChange("gadget-finalize", *tx.reorderSymlinkChanges) + + err := tx.c.Apply() + if err != nil { + tx.log.Error().Err(err).Msg("failed to update usbgadget configuration") + return err + } + tx.log.Info().Msg("usbgadget configuration updated") + return nil +} + +func (u *UsbGadget) getOrderedConfigItems() orderedGadgetConfigItems { + items := make([]gadgetConfigItemWithKey, 0) + for key, item := range u.configMap { + items = append(items, gadgetConfigItemWithKey{key, item}) + } + + sort.Slice(items, func(i, j int) bool { + return items[i].item.order < items[j].item.order + }) + + return items +} + +func (tx *UsbGadgetTransaction) MountConfigFS() { + tx.addFileChange("gadget", RequestedFileChange{ + Path: configFSPath, + ExpectedState: FileStateMountedConfigFS, + Description: "mount configfs", + }) +} + +func (tx *UsbGadgetTransaction) CreateConfigPath() { + tx.mkdirAll("gadget", tx.configC1Path, "create config path") +} + +func (tx *UsbGadgetTransaction) WriteGadgetConfig() { + // create kvm gadget path + tx.mkdirAll("gadget", tx.kvmGadgetPath, "create kvm gadget path") + + deps := make([]string, 0) + + for _, val := range tx.orderedConfigItems { + key := val.key + item := val.item + + // check if the item is enabled in the config + if !tx.isGadgetConfigItemEnabled(key) { + tx.DisableGadgetItemConfig(item) + continue + } + deps = tx.writeGadgetItemConfig(item, deps) + } + + tx.WriteUDC() +} + +func (tx *UsbGadgetTransaction) getDisableKeys() []string { + disableKeys := make([]string, 0) + for _, item := range tx.orderedConfigItems { + if !tx.isGadgetConfigItemEnabled(item.key) { + continue + } + if item.item.configPath == nil || item.item.configAttrs != nil { + continue + } + + disableKeys = append(disableKeys, fmt.Sprintf("disable-%s", item.item.device)) + } + return disableKeys +} + +func (tx *UsbGadgetTransaction) DisableGadgetItemConfig(item gadgetConfigItem) { + // remove symlink if exists + if item.configPath == nil { + return + } + + configPath := joinPath(tx.configC1Path, item.configPath) + _ = tx.removeFile("gadget", configPath, "remove symlink: disable gadget config") +} + +func (tx *UsbGadgetTransaction) writeGadgetItemConfig(item gadgetConfigItem, deps []string) []string { + component := item.device + + // create directory for the item + files := make([]string, 0) + files = append(files, deps...) + + gadgetItemPath := joinPath(tx.kvmGadgetPath, item.path) + files = append(files, tx.mkdirAll(component, gadgetItemPath, "create gadget item directory")) + + beforeChange := make([]string, 0) + disableGadgetItemKey := fmt.Sprintf("disable-%s", item.device) + if item.configPath != nil && item.configAttrs == nil { + beforeChange = append(beforeChange, tx.getDisableKeys()...) + } + + if len(item.attrs) > 0 { + // write attributes for the item + files = append(files, tx.writeGadgetAttrs( + gadgetItemPath, + item.attrs, + component, + beforeChange, + )...) + } + + // write report descriptor if available + reportDescPath := path.Join(gadgetItemPath, "report_desc") + if item.reportDesc != nil { + tx.addFileChange(component, RequestedFileChange{ + Path: reportDescPath, + ExpectedState: FileStateFileContentMatch, + ExpectedContent: item.reportDesc, + Description: "write report descriptor", + BeforeChange: beforeChange, + DependsOn: files, + }) + } else { + tx.addFileChange(component, RequestedFileChange{ + Path: reportDescPath, + ExpectedState: FileStateAbsent, + Description: "remove report descriptor", + BeforeChange: beforeChange, + DependsOn: files, + }) + } + files = append(files, reportDescPath) + + // create config directory if configAttrs are set + if len(item.configAttrs) > 0 { + configItemPath := joinPath(tx.configC1Path, item.configPath) + tx.mkdirAll(component, configItemPath, "create config item directory") + files = append(files, tx.writeGadgetAttrs( + configItemPath, + item.configAttrs, + component, + beforeChange, + )...) + } + + // create symlink if configPath is set + if item.configPath != nil && item.configAttrs == nil { + configPath := joinPath(tx.configC1Path, item.configPath) + + // the change will be only applied by `beforeChange` + tx.addFileChange(component, RequestedFileChange{ + Key: disableGadgetItemKey, + Path: configPath, + ExpectedState: FileStateAbsent, + When: "beforeChange", // TODO: make it more flexible + Description: "remove symlink", + }) + + tx.addReorderSymlinkChange(configPath, gadgetItemPath, files) + } + + return files +} + +func (tx *UsbGadgetTransaction) writeGadgetAttrs(basePath string, attrs gadgetAttributes, component string, beforeChange []string) (files []string) { + files = make([]string, 0) + for key, val := range attrs { + filePath := filepath.Join(basePath, key) + tx.addFileChange(component, RequestedFileChange{ + Path: filePath, + ExpectedState: FileStateFileContentMatch, + ExpectedContent: []byte(val), + Description: "write gadget attribute", + DependsOn: []string{basePath}, + BeforeChange: beforeChange, + }) + files = append(files, filePath) + } + return files +} + +func (tx *UsbGadgetTransaction) addReorderSymlinkChange(path string, target string, deps []string) { + tx.log.Trace().Str("path", path).Str("target", target).Msg("add reorder symlink change") + + if tx.reorderSymlinkChanges == nil { + tx.reorderSymlinkChanges = &RequestedFileChange{ + Component: "gadget-finalize", + Key: "reorder-symlinks", + Path: tx.configC1Path, + ExpectedState: FileStateSymlinkInOrderConfigFS, + Description: "order symlinks", + ParamSymlinks: []symlink{}, + } + } + + tx.reorderSymlinkChanges.DependsOn = append(tx.reorderSymlinkChanges.DependsOn, deps...) + tx.reorderSymlinkChanges.ParamSymlinks = append(tx.reorderSymlinkChanges.ParamSymlinks, symlink{ + Path: path, + Target: target, + }) +} + +func (tx *UsbGadgetTransaction) WriteUDC() { + // bound the gadget to a UDC (USB Device Controller) + path := path.Join(tx.kvmGadgetPath, "UDC") + tx.addFileChange("udc", RequestedFileChange{ + Path: path, + ExpectedState: FileStateFileContentMatch, + ExpectedContent: []byte(tx.udc), + DependsOn: []string{"reorder-symlinks"}, + Description: "write UDC", + }) +} + +func (tx *UsbGadgetTransaction) RebindUsb(ignoreUnbindError bool) { + // remove the gadget from the UDC + tx.addFileChange("udc", RequestedFileChange{ + Path: path.Join(tx.dwc3Path, "unbind"), + ExpectedState: FileStateFileWrite, + ExpectedContent: []byte(tx.udc), + Description: "unbind UDC", + }) + // bind the gadget to the UDC + tx.addFileChange("udc", RequestedFileChange{ + Path: path.Join(tx.dwc3Path, "bind"), + ExpectedState: FileStateFileWrite, + ExpectedContent: []byte(tx.udc), + Description: "bind UDC", + DependsOn: []string{path.Join(tx.dwc3Path, "unbind")}, + IgnoreErrors: ignoreUnbindError, + }) +} diff --git a/internal/usbgadget/log.go b/internal/usbgadget/log.go new file mode 100644 index 0000000..f979f6c --- /dev/null +++ b/internal/usbgadget/log.go @@ -0,0 +1,27 @@ +package usbgadget + +import ( + "errors" +) + +func (u *UsbGadget) logWarn(msg string, err error) error { + if err == nil { + err = errors.New(msg) + } + if u.strictMode { + return err + } + u.log.Warn().Err(err).Msg(msg) + return nil +} + +func (u *UsbGadget) logError(msg string, err error) error { + if err == nil { + err = errors.New(msg) + } + if u.strictMode { + return err + } + u.log.Error().Err(err).Msg(msg) + return nil +} diff --git a/internal/usbgadget/udc.go b/internal/usbgadget/udc.go index 84dfbe4..4b7fbe3 100644 --- a/internal/usbgadget/udc.go +++ b/internal/usbgadget/udc.go @@ -50,18 +50,6 @@ func (u *UsbGadget) RebindUsb(ignoreUnbindError bool) error { return u.rebindUsb(ignoreUnbindError) } -func (u *UsbGadget) writeUDC() error { - path := path.Join(u.kvmGadgetPath, "UDC") - - u.log.Trace().Str("udc", u.udc).Str("path", path).Msg("writing UDC") - err := u.writeIfDifferent(path, []byte(u.udc), 0644) - if err != nil { - return fmt.Errorf("failed to write UDC: %w", err) - } - - return nil -} - // GetUsbState returns the current state of the USB gadget func (u *UsbGadget) GetUsbState() (state string) { stateFile := path.Join("/sys/class/udc", u.udc, "state") diff --git a/internal/usbgadget/usbgadget.go b/internal/usbgadget/usbgadget.go index 1dff2f3..f8b2b3e 100644 --- a/internal/usbgadget/usbgadget.go +++ b/internal/usbgadget/usbgadget.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/jetkvm/kvm/internal/logging" "github.com/rs/zerolog" ) @@ -28,7 +29,8 @@ type Config struct { Manufacturer string `json:"manufacturer"` Product string `json:"product"` - isEmpty bool + strictMode bool // when it's enabled, all warnings will be converted to errors + isEmpty bool } var defaultUsbGadgetDevices = Devices{ @@ -59,22 +61,31 @@ type UsbGadget struct { enabledDevices Devices + strictMode bool // only intended for testing for now + absMouseAccumulatedWheelY float64 lastUserInput time.Time + tx *UsbGadgetTransaction + txLock sync.Mutex + log *zerolog.Logger } const configFSPath = "/sys/kernel/config" const gadgetPath = "/sys/kernel/config/usb_gadget" -var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel) +var defaultLogger = logging.GetSubsystemLogger("usbgadget") // NewUsbGadget creates a new UsbGadget. func NewUsbGadget(name string, enabledDevices *Devices, config *Config, logger *zerolog.Logger) *UsbGadget { + return newUsbGadget(name, defaultGadgetConfig, enabledDevices, config, logger) +} + +func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDevices *Devices, config *Config, logger *zerolog.Logger) *UsbGadget { if logger == nil { - logger = &defaultLogger + logger = defaultLogger } if enabledDevices == nil { @@ -89,16 +100,19 @@ func NewUsbGadget(name string, enabledDevices *Devices, config *Config, logger * name: name, kvmGadgetPath: path.Join(gadgetPath, name), configC1Path: path.Join(gadgetPath, name, "configs/c.1"), - configMap: defaultGadgetConfig, + configMap: configMap, customConfig: *config, configLock: sync.Mutex{}, keyboardLock: sync.Mutex{}, absMouseLock: sync.Mutex{}, relMouseLock: sync.Mutex{}, + txLock: sync.Mutex{}, enabledDevices: *enabledDevices, lastUserInput: time.Now(), log: logger, + strictMode: config.strictMode, + absMouseAccumulatedWheelY: 0, } if err := g.Init(); err != nil { diff --git a/internal/usbgadget/utils.go b/internal/usbgadget/utils.go index 28a9e37..7a6d1db 100644 --- a/internal/usbgadget/utils.go +++ b/internal/usbgadget/utils.go @@ -3,8 +3,9 @@ package usbgadget import ( "bytes" "fmt" - "os" "path/filepath" + "strconv" + "strings" ) func joinPath(basePath string, paths []string) string { @@ -12,44 +13,68 @@ func joinPath(basePath string, paths []string) string { return filepath.Join(pathArr...) } -func ensureSymlink(linkPath string, target string) error { - if _, err := os.Lstat(linkPath); err == nil { - currentTarget, err := os.Readlink(linkPath) - if err != nil || currentTarget != target { - err = os.Remove(linkPath) - if err != nil { - return fmt.Errorf("failed to remove existing symlink %s: %w", linkPath, err) - } - } - } else if !os.IsNotExist(err) { - return fmt.Errorf("failed to check if symlink exists: %w", err) +func hexToDecimal(hex string) (int64, error) { + decimal, err := strconv.ParseInt(hex, 16, 64) + if err != nil { + return 0, err } - - if err := os.Symlink(target, linkPath); err != nil { - return fmt.Errorf("failed to create symlink from %s to %s: %w", linkPath, target, err) - } - - return nil + return decimal, nil } -func (u *UsbGadget) writeIfDifferent(filePath string, content []byte, permMode os.FileMode) error { - if _, err := os.Stat(filePath); err == nil { - oldContent, err := os.ReadFile(filePath) - if err == nil { - if bytes.Equal(oldContent, content) { - u.log.Trace().Str("path", filePath).Msg("skipping writing to as it already has the correct content") - return nil - } +func decimalToOctal(decimal int64) string { + return fmt.Sprintf("%04o", decimal) +} - if len(oldContent) == len(content)+1 && - bytes.Equal(oldContent[:len(content)], content) && - oldContent[len(content)] == 10 { - u.log.Trace().Str("path", filePath).Msg("skipping writing to as it already has the correct content") - return nil - } +func hexToOctal(hex string) (string, error) { + hex = strings.ToLower(hex) + hex = strings.Replace(hex, "0x", "", 1) //remove 0x or 0X - u.log.Trace().Str("path", filePath).Bytes("old", oldContent).Bytes("new", content).Msg("writing to as it has different content") + decimal, err := hexToDecimal(hex) + if err != nil { + return "", err + } + + // Convert the decimal integer to an octal string. + octal := decimalToOctal(decimal) + return octal, nil +} + +func compareFileContent(oldContent []byte, newContent []byte, looserMatch bool) bool { + if bytes.Equal(oldContent, newContent) { + return true + } + + if len(oldContent) == len(newContent)+1 && + bytes.Equal(oldContent[:len(newContent)], newContent) && + oldContent[len(newContent)] == 10 { + return true + } + + if len(newContent) == 4 { + if len(oldContent) < 6 || len(oldContent) > 7 { + return false + } + + if len(oldContent) == 7 && oldContent[6] == 0x0a { + oldContent = oldContent[:6] + } + + oldOctalValue, err := hexToOctal(string(oldContent)) + if err != nil { + return false + } + + if oldOctalValue == string(newContent) { + return true } } - return os.WriteFile(filePath, content, permMode) + + if looserMatch { + oldContentStr := strings.TrimSpace(string(oldContent)) + newContentStr := strings.TrimSpace(string(newContent)) + + return oldContentStr == newContentStr + } + + return false }