From 85f7f6061866ed2a754d923276653bb53bf44d30 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Tue, 28 Oct 2025 08:54:41 +0000 Subject: [PATCH] WIP: OTA refactor --- config.go | 13 + hw.go | 4 +- internal/ota/app.go | 58 ++ internal/ota/logger.go | 5 + internal/ota/ota.go | 211 +++++++ internal/ota/state.go | 209 ++++++ internal/ota/sys.go | 100 +++ internal/ota/utils.go | 166 +++++ jsonrpc.go | 32 +- main.go | 15 +- network.go | 7 +- ota.go | 597 ++---------------- ui/localization/messages/en.json | 1 + ui/src/hooks/useVersion.tsx | 3 +- .../devices.$id.settings.general._index.tsx | 23 +- .../devices.$id.settings.general.update.tsx | 2 +- webrtc.go | 2 +- 17 files changed, 865 insertions(+), 583 deletions(-) create mode 100644 internal/ota/app.go create mode 100644 internal/ota/logger.go create mode 100644 internal/ota/ota.go create mode 100644 internal/ota/state.go create mode 100644 internal/ota/sys.go create mode 100644 internal/ota/utils.go diff --git a/config.go b/config.go index 5a3e7dc8..7dc3db30 100644 --- a/config.go +++ b/config.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "strconv" + "strings" "sync" "github.com/jetkvm/kvm/internal/confparser" @@ -80,6 +81,7 @@ func (m *KeyboardMacro) Validate() error { type Config struct { CloudURL string `json:"cloud_url"` + UpdateAPIURL string `json:"update_api_url"` CloudAppURL string `json:"cloud_app_url"` CloudToken string `json:"cloud_token"` GoogleIdentity string `json:"google_identity"` @@ -109,6 +111,15 @@ type Config struct { VideoQualityFactor float64 `json:"video_quality_factor"` } +// GetUpdateAPIURL returns the update API URL +func (c *Config) GetUpdateAPIURL() string { + if c.UpdateAPIURL == "" { + return "https://api.jetkvm.com" + } + return strings.TrimSuffix(c.UpdateAPIURL, "/") + "/releases" +} + +// GetDisplayRotation returns the display rotation func (c *Config) GetDisplayRotation() uint16 { rotationInt, err := strconv.ParseUint(c.DisplayRotation, 10, 16) if err != nil { @@ -118,6 +129,7 @@ func (c *Config) GetDisplayRotation() uint16 { return uint16(rotationInt) } +// SetDisplayRotation sets the display rotation func (c *Config) SetDisplayRotation(rotation string) error { _, err := strconv.ParseUint(rotation, 10, 16) if err != nil { @@ -157,6 +169,7 @@ var ( func getDefaultConfig() Config { return Config{ CloudURL: "https://api.jetkvm.com", + UpdateAPIURL: "https://api.jetkvm.com", CloudAppURL: "https://app.jetkvm.com", AutoUpdateEnabled: true, // Set a default value ActiveExtension: "", diff --git a/hw.go b/hw.go index 7797adc1..b6416e25 100644 --- a/hw.go +++ b/hw.go @@ -8,6 +8,8 @@ import ( "strings" "sync" "time" + + "github.com/jetkvm/kvm/internal/ota" ) func extractSerialNumber() (string, error) { @@ -37,7 +39,7 @@ func readOtpEntropy() ([]byte, error) { //nolint:unused return content[0x17:0x1C], nil } -func hwReboot(force bool, postRebootAction *PostRebootAction, delay time.Duration) error { +func hwReboot(force bool, postRebootAction *ota.PostRebootAction, delay time.Duration) error { //nolint:unused logger.Info().Msgf("Reboot requested, rebooting in %d seconds...", delay) writeJSONRPCEvent("willReboot", postRebootAction, currentSession) diff --git a/internal/ota/app.go b/internal/ota/app.go new file mode 100644 index 00000000..482a07de --- /dev/null +++ b/internal/ota/app.go @@ -0,0 +1,58 @@ +package ota + +import ( + "context" + "fmt" + "time" + + "github.com/rs/zerolog" +) + +const ( + appUpdatePath = "/userdata/jetkvm/jetkvm_app.update" +) + +func (s *State) componentUpdateError(prefix string, err error, l *zerolog.Logger) error { + if l == nil { + l = s.l + } + l.Error().Err(err).Msg(prefix) + s.error = fmt.Sprintf("%s: %v", prefix, err) + return err +} + +func (s *State) updateApp(ctx context.Context, appUpdate *componentUpdateStatus) error { + s.mu.Lock() + defer s.mu.Unlock() + + l := s.l.With().Str("path", appUpdatePath).Logger() + + if err := s.downloadFile(ctx, appUpdatePath, appUpdate.url, &appUpdate.downloadProgress); err != nil { + return s.componentUpdateError("Error downloading app update", err, &l) + } + + downloadFinished := time.Now() + appUpdate.downloadFinishedAt = downloadFinished + appUpdate.downloadProgress = 1 + s.onProgressUpdate() + + if err := s.verifyFile( + appUpdatePath, + appUpdate.hash, + &appUpdate.verificationProgress, + ); err != nil { + return s.componentUpdateError("Error verifying app update hash", err, &l) + } + verifyFinished := time.Now() + appUpdate.verifiedAt = verifyFinished + appUpdate.verificationProgress = 1 + appUpdate.updatedAt = verifyFinished + appUpdate.updateProgress = 1 + s.onProgressUpdate() + + l.Info().Msg("App update downloaded") + + s.rebootNeeded = true + + return nil +} diff --git a/internal/ota/logger.go b/internal/ota/logger.go new file mode 100644 index 00000000..a13036de --- /dev/null +++ b/internal/ota/logger.go @@ -0,0 +1,5 @@ +package ota + +import "github.com/jetkvm/kvm/internal/logging" + +var logger = logging.GetSubsystemLogger("ota") diff --git a/internal/ota/ota.go b/internal/ota/ota.go new file mode 100644 index 00000000..b45909ec --- /dev/null +++ b/internal/ota/ota.go @@ -0,0 +1,211 @@ +package ota + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/Masterminds/semver/v3" +) + +// UpdateReleaseAPIEndpoint updates the release API endpoint +func (s *State) UpdateReleaseAPIEndpoint(endpoint string) { + s.releaseAPIEndpoint = endpoint +} + +// GetReleaseAPIEndpoint returns the release API endpoint +func (s *State) GetReleaseAPIEndpoint() string { + return s.releaseAPIEndpoint +} + +func (s *State) fetchUpdateMetadata(ctx context.Context, deviceID string, includePreRelease bool) (*UpdateMetadata, error) { + metadata := &UpdateMetadata{} + + updateURL, err := url.Parse(s.releaseAPIEndpoint) + if err != nil { + return nil, fmt.Errorf("error parsing update metadata URL: %w", err) + } + + query := updateURL.Query() + query.Set("deviceId", deviceID) + query.Set("prerelease", fmt.Sprintf("%v", includePreRelease)) + updateURL.RawQuery = query.Encode() + + logger.Info().Str("url", updateURL.String()).Msg("Checking for updates") + + req, err := http.NewRequestWithContext(ctx, "GET", updateURL.String(), nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + client := s.client() + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("error sending request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + err = json.NewDecoder(resp.Body).Decode(metadata) + if err != nil { + return nil, fmt.Errorf("error decoding response: %w", err) + } + + return metadata, nil +} + +func (s *State) TryUpdate(ctx context.Context, deviceID string, includePreRelease bool) error { + scopedLogger := s.l.With(). + Str("deviceID", deviceID). + Str("includePreRelease", fmt.Sprintf("%v", includePreRelease)). + Logger() + + scopedLogger.Info().Msg("Trying to update...") + if s.updating { + return fmt.Errorf("update already in progress") + } + + s.updating = true + s.onProgressUpdate() + + defer func() { + s.updating = false + s.onProgressUpdate() + }() + + appUpdate, systemUpdate, err := s.getUpdateStatus(ctx, deviceID, includePreRelease) + if err != nil { + return s.componentUpdateError("Error checking for updates", err, &scopedLogger) + } + + s.metadataFetchedAt = time.Now() + s.onProgressUpdate() + + if appUpdate.available { + appUpdate.pending = true + } + + if systemUpdate.available { + systemUpdate.pending = true + } + + if appUpdate.pending { + scopedLogger.Info(). + Str("url", appUpdate.url). + Str("hash", appUpdate.hash). + Msg("App update available") + + if err := s.updateApp(ctx, appUpdate); err != nil { + return s.componentUpdateError("Error updating app", err, &scopedLogger) + } + } else { + scopedLogger.Info().Msg("App is up to date") + } + + if systemUpdate.pending { + if err := s.updateSystem(ctx, systemUpdate); err != nil { + return s.componentUpdateError("Error updating system", err, &scopedLogger) + } + } else { + scopedLogger.Info().Msg("System is up to date") + } + + if s.rebootNeeded { + scopedLogger.Info().Msg("System Rebooting due to OTA update") + + postRebootAction := &PostRebootAction{ + HealthCheck: "/device/status", + RedirectUrl: fmt.Sprintf("/settings/general/update?version=%s", systemUpdate.version), + } + + if err := s.reboot(true, postRebootAction, 10*time.Second); err != nil { + return s.componentUpdateError("Error requesting reboot", err, &scopedLogger) + } + } + + return nil +} + +func (s *State) getUpdateStatus( + ctx context.Context, + deviceID string, + includePreRelease bool, +) ( + appUpdate *componentUpdateStatus, + systemUpdate *componentUpdateStatus, + err error, +) { + appUpdate = &componentUpdateStatus{} + systemUpdate = &componentUpdateStatus{} + err = nil + + // Get local versions + systemVersionLocal, appVersionLocal, err := s.getLocalVersion() + if err != nil { + return nil, nil, fmt.Errorf("error getting local version: %w", err) + } + appUpdate.localVersion = appVersionLocal.String() + systemUpdate.localVersion = systemVersionLocal.String() + + // Get remote metadata + remoteMetadata, err := s.fetchUpdateMetadata(ctx, deviceID, includePreRelease) + if err != nil { + err = fmt.Errorf("error checking for updates: %w", err) + return + } + appUpdate.url = remoteMetadata.AppURL + appUpdate.hash = remoteMetadata.AppHash + appUpdate.version = remoteMetadata.AppVersion + + systemUpdate.url = remoteMetadata.SystemURL + systemUpdate.hash = remoteMetadata.SystemHash + systemUpdate.version = remoteMetadata.SystemVersion + + // Get remote versions + systemVersionRemote, err := semver.NewVersion(remoteMetadata.SystemVersion) + if err != nil { + err = fmt.Errorf("error parsing remote system version: %w", err) + return + } + systemUpdate.available = systemVersionRemote.GreaterThan(systemVersionLocal) + + appVersionRemote, err := semver.NewVersion(remoteMetadata.AppVersion) + if err != nil { + err = fmt.Errorf("error parsing remote app version: %w, %s", err, remoteMetadata.AppVersion) + return + } + appUpdate.available = appVersionRemote.GreaterThan(appVersionLocal) + + // Handle pre-release updates + isRemoteSystemPreRelease := systemVersionRemote.Prerelease() != "" + isRemoteAppPreRelease := appVersionRemote.Prerelease() != "" + + if isRemoteSystemPreRelease && !includePreRelease { + systemUpdate.available = false + } + if isRemoteAppPreRelease && !includePreRelease { + appUpdate.available = false + } + + s.componentUpdateStatuses["app"] = *appUpdate + s.componentUpdateStatuses["system"] = *systemUpdate + + return +} + +// GetUpdateStatus returns the current update status (for backwards compatibility) +func (s *State) GetUpdateStatus(ctx context.Context, deviceID string, includePreRelease bool) (*UpdateStatus, error) { + _, _, err := s.getUpdateStatus(ctx, deviceID, includePreRelease) + if err != nil { + return nil, fmt.Errorf("error getting update status: %w", err) + } + + return s.ToUpdateStatus(), nil +} diff --git a/internal/ota/state.go b/internal/ota/state.go new file mode 100644 index 00000000..4375a08f --- /dev/null +++ b/internal/ota/state.go @@ -0,0 +1,209 @@ +package ota + +import ( + "net/http" + "sync" + "time" + + "github.com/Masterminds/semver/v3" + "github.com/rs/zerolog" +) + +// UpdateMetadata represents the metadata of an update +type UpdateMetadata struct { + AppVersion string `json:"appVersion"` + AppURL string `json:"appUrl"` + AppHash string `json:"appHash"` + SystemVersion string `json:"systemVersion"` + SystemURL string `json:"systemUrl"` + SystemHash string `json:"systemHash"` +} + +// LocalMetadata represents the local metadata of the system +type LocalMetadata struct { + AppVersion string `json:"appVersion"` + SystemVersion string `json:"systemVersion"` +} + +// UpdateStatus represents the current update status +type UpdateStatus struct { + Local *LocalMetadata `json:"local"` + Remote *UpdateMetadata `json:"remote"` + SystemUpdateAvailable bool `json:"systemUpdateAvailable"` + AppUpdateAvailable bool `json:"appUpdateAvailable"` + + // for backwards compatibility + Error string `json:"error,omitempty"` +} + +// PostRebootAction represents the action to be taken after a reboot +// It is used to redirect the user to a specific page after a reboot +type PostRebootAction struct { + HealthCheck string `json:"healthCheck"` // The health check URL to call after the reboot + RedirectUrl string `json:"redirectUrl"` // The URL to redirect to after the reboot +} + +// componentUpdateStatus represents the status of a component update +type componentUpdateStatus struct { + pending bool + available bool + version string + localVersion string + url string + hash string + downloadProgress float32 + downloadFinishedAt time.Time + verificationProgress float32 + verifiedAt time.Time + updateProgress float32 + updatedAt time.Time + dependsOn []string +} + +// RPCState represents the current OTA state for the RPC API +type RPCState struct { + Updating bool `json:"updating"` + Error string `json:"error,omitempty"` + MetadataFetchedAt time.Time `json:"metadataFetchedAt,omitempty"` + AppUpdatePending bool `json:"appUpdatePending"` + SystemUpdatePending bool `json:"systemUpdatePending"` + AppDownloadProgress float32 `json:"appDownloadProgress,omitempty"` //TODO: implement for progress bar + AppDownloadFinishedAt time.Time `json:"appDownloadFinishedAt,omitempty"` + SystemDownloadProgress float32 `json:"systemDownloadProgress,omitempty"` //TODO: implement for progress bar + SystemDownloadFinishedAt time.Time `json:"systemDownloadFinishedAt,omitempty"` + AppVerificationProgress float32 `json:"appVerificationProgress,omitempty"` + AppVerifiedAt time.Time `json:"appVerifiedAt,omitempty"` + SystemVerificationProgress float32 `json:"systemVerificationProgress,omitempty"` + SystemVerifiedAt time.Time `json:"systemVerifiedAt,omitempty"` + AppUpdateProgress float32 `json:"appUpdateProgress,omitempty"` //TODO: implement for progress bar + AppUpdatedAt time.Time `json:"appUpdatedAt,omitempty"` + SystemUpdateProgress float32 `json:"systemUpdateProgress,omitempty"` //TODO: port rk_ota, then implement + SystemUpdatedAt time.Time `json:"systemUpdatedAt,omitempty"` +} + +// HwRebootFunc is a function that reboots the hardware +type HwRebootFunc func(force bool, postRebootAction *PostRebootAction, delay time.Duration) error + +// GetHTTPClientFunc is a function that returns the HTTP client +type GetHTTPClientFunc func() *http.Client + +// OnStateUpdateFunc is a function that updates the state of the OTA +type OnStateUpdateFunc func(state *RPCState) + +// OnProgressUpdateFunc is a function that updates the progress of the OTA +type OnProgressUpdateFunc func(progress float32) + +// GetLocalVersionFunc is a function that returns the local version of the system and app +type GetLocalVersionFunc func() (systemVersion *semver.Version, appVersion *semver.Version, err error) + +// State represents the current OTA state for the UI +type State struct { + releaseAPIEndpoint string + l *zerolog.Logger + mu sync.Mutex + updating bool + error string + metadataFetchedAt time.Time + rebootNeeded bool + componentUpdateStatuses map[string]componentUpdateStatus + client GetHTTPClientFunc + reboot HwRebootFunc + getLocalVersion GetLocalVersionFunc +} + +// ToUpdateStatus converts the State to the UpdateStatus +func (s *State) ToUpdateStatus() *UpdateStatus { + appUpdate, ok := s.componentUpdateStatuses["app"] + if !ok { + return nil + } + + systemUpdate, ok := s.componentUpdateStatuses["system"] + if !ok { + return nil + } + + return &UpdateStatus{ + Local: &LocalMetadata{ + AppVersion: appUpdate.localVersion, + SystemVersion: systemUpdate.localVersion, + }, + Remote: &UpdateMetadata{ + AppVersion: appUpdate.version, + AppURL: appUpdate.url, + AppHash: appUpdate.hash, + SystemVersion: systemUpdate.version, + SystemURL: systemUpdate.url, + SystemHash: systemUpdate.hash, + }, + SystemUpdateAvailable: systemUpdate.available, + AppUpdateAvailable: appUpdate.available, + Error: s.error, + } +} + +// IsUpdatePending returns true if an update is pending +func (s *State) IsUpdatePending() bool { + return s.updating +} + +// Options represents the options for the OTA state +type Options struct { + Logger *zerolog.Logger + GetHTTPClient GetHTTPClientFunc + GetLocalVersion GetLocalVersionFunc + OnStateUpdate OnStateUpdateFunc + OnProgressUpdate OnProgressUpdateFunc + HwReboot HwRebootFunc + ReleaseAPIEndpoint string +} + +// NewState creates a new OTA state +func NewState(opts Options) *State { + s := &State{ + l: opts.Logger, + client: opts.GetHTTPClient, + reboot: opts.HwReboot, + getLocalVersion: opts.GetLocalVersion, + componentUpdateStatuses: make(map[string]componentUpdateStatus), + releaseAPIEndpoint: opts.ReleaseAPIEndpoint, + } + go s.confirmCurrentSystem() + return s +} + +// ToRPCState converts the State to the RPCState +func (s *State) ToRPCState() *RPCState { + r := &RPCState{ + Updating: s.updating, + Error: s.error, + MetadataFetchedAt: s.metadataFetchedAt, + } + + app, ok := s.componentUpdateStatuses["app"] + if ok { + r.AppUpdatePending = app.pending + r.AppDownloadProgress = app.downloadProgress + r.AppDownloadFinishedAt = app.downloadFinishedAt + r.AppVerificationProgress = app.verificationProgress + r.AppVerifiedAt = app.verifiedAt + r.AppUpdateProgress = app.updateProgress + r.AppUpdatedAt = app.updatedAt + } + + system, ok := s.componentUpdateStatuses["system"] + if ok { + r.SystemUpdatePending = system.pending + r.SystemDownloadProgress = system.downloadProgress + r.SystemDownloadFinishedAt = system.downloadFinishedAt + r.SystemVerificationProgress = system.verificationProgress + r.SystemVerifiedAt = system.verifiedAt + r.SystemUpdateProgress = system.updateProgress + r.SystemUpdatedAt = system.updatedAt + } + + return r +} + +func (s *State) onProgressUpdate() { +} diff --git a/internal/ota/sys.go b/internal/ota/sys.go new file mode 100644 index 00000000..2426c353 --- /dev/null +++ b/internal/ota/sys.go @@ -0,0 +1,100 @@ +package ota + +import ( + "bytes" + "context" + "os/exec" + "time" +) + +const ( + systemUpdatePath = "/userdata/jetkvm/update_system.tar" +) + +func (s *State) updateSystem(ctx context.Context, systemUpdate *componentUpdateStatus) error { + s.mu.Lock() + defer s.mu.Unlock() + + l := s.l.With().Str("path", systemUpdatePath).Logger() + + if err := s.downloadFile(ctx, systemUpdatePath, systemUpdate.url, &systemUpdate.downloadProgress); err != nil { + return s.componentUpdateError("Error downloading system update", err, &l) + } + + downloadFinished := time.Now() + systemUpdate.downloadFinishedAt = downloadFinished + systemUpdate.downloadProgress = 1 + s.onProgressUpdate() + + if err := s.verifyFile( + systemUpdatePath, + systemUpdate.hash, + &systemUpdate.verificationProgress, + ); err != nil { + return s.componentUpdateError("Error verifying system update hash", err, &l) + } + verifyFinished := time.Now() + systemUpdate.verifiedAt = verifyFinished + systemUpdate.verificationProgress = 1 + systemUpdate.updatedAt = verifyFinished + systemUpdate.updateProgress = 1 + s.onProgressUpdate() + + l.Info().Msg("System update downloaded") + + l.Info().Msg("Starting rk_ota command") + + cmd := exec.Command("rk_ota", "--misc=update", "--tar_path=/userdata/jetkvm/update_system.tar", "--save_dir=/userdata/jetkvm/ota_save", "--partition=all") + var b bytes.Buffer + cmd.Stdout = &b + cmd.Stderr = &b + if err := cmd.Start(); err != nil { + return s.componentUpdateError("Error starting rk_ota command", err, &l) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + ticker := time.NewTicker(1800 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if systemUpdate.updateProgress >= 0.99 { + return + } + systemUpdate.updateProgress += 0.01 + if systemUpdate.updateProgress > 0.99 { + systemUpdate.updateProgress = 0.99 + } + s.onProgressUpdate() + case <-ctx.Done(): + return + } + } + }() + + err := cmd.Wait() + cancel() + rkLogger := s.l.With(). + Str("output", b.String()). + Int("exitCode", cmd.ProcessState.ExitCode()).Logger() + if err != nil { + return s.componentUpdateError("Error executing rk_ota command", err, &rkLogger) + } + rkLogger.Info().Msg("rk_ota success") + systemUpdate.updateProgress = 1 + systemUpdate.updatedAt = verifyFinished + s.onProgressUpdate() + + return nil +} + +func (s *State) confirmCurrentSystem() { + output, err := exec.Command("rk_ota", "--misc=now").CombinedOutput() + if err != nil { + s.l.Warn().Str("output", string(output)).Msg("failed to set current partition in A/B setup") + } + s.l.Trace().Str("output", string(output)).Msg("current partition in A/B setup set") +} diff --git a/internal/ota/utils.go b/internal/ota/utils.go new file mode 100644 index 00000000..caa03384 --- /dev/null +++ b/internal/ota/utils.go @@ -0,0 +1,166 @@ +package ota + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "os/exec" +) + +func syncFilesystem() error { + // Flush filesystem buffers to ensure all data is written to disk + if err := exec.Command("sync").Run(); err != nil { + return fmt.Errorf("error flushing filesystem buffers: %w", err) + } + + // Clear the filesystem caches to force a read from disk + if err := os.WriteFile("/proc/sys/vm/drop_caches", []byte("1"), 0644); err != nil { + return fmt.Errorf("error clearing filesystem caches: %w", err) + } + + return nil +} + +func (s *State) downloadFile(ctx context.Context, path string, url string, downloadProgress *float32) error { + if _, err := os.Stat(path); err == nil { + if err := os.Remove(path); err != nil { + return fmt.Errorf("error removing existing file: %w", err) + } + } + + unverifiedPath := path + ".unverified" + if _, err := os.Stat(unverifiedPath); err == nil { + if err := os.Remove(unverifiedPath); err != nil { + return fmt.Errorf("error removing existing unverified file: %w", err) + } + } + + file, err := os.Create(unverifiedPath) + if err != nil { + return fmt.Errorf("error creating file: %w", err) + } + defer file.Close() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + + client := s.client() + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("error downloading file: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + totalSize := resp.ContentLength + if totalSize <= 0 { + return fmt.Errorf("invalid content length") + } + + var written int64 + buf := make([]byte, 32*1024) + for { + nr, er := resp.Body.Read(buf) + if nr > 0 { + nw, ew := file.Write(buf[0:nr]) + if nw < nr { + return fmt.Errorf("short write: %d < %d", nw, nr) + } + written += int64(nw) + if ew != nil { + return fmt.Errorf("error writing to file: %w", ew) + } + progress := float32(written) / float32(totalSize) + if progress-*downloadProgress >= 0.01 { + *downloadProgress = progress + s.onProgressUpdate() + } + } + if er != nil { + if er == io.EOF { + break + } + return fmt.Errorf("error reading response body: %w", er) + } + } + + file.Close() + + if err := syncFilesystem(); err != nil { + return fmt.Errorf("error syncing filesystem: %w", err) + } + + return nil +} + +func (s *State) verifyFile(path string, expectedHash string, verifyProgress *float32) error { + l := s.l.With().Str("path", path).Logger() + + unverifiedPath := path + ".unverified" + fileToHash, err := os.Open(unverifiedPath) + if err != nil { + return fmt.Errorf("error opening file for hashing: %w", err) + } + defer fileToHash.Close() + + hash := sha256.New() + fileInfo, err := fileToHash.Stat() + if err != nil { + return fmt.Errorf("error getting file info: %w", err) + } + totalSize := fileInfo.Size() + + buf := make([]byte, 32*1024) + verified := int64(0) + + for { + nr, er := fileToHash.Read(buf) + if nr > 0 { + nw, ew := hash.Write(buf[0:nr]) + if nw < nr { + return fmt.Errorf("short write: %d < %d", nw, nr) + } + verified += int64(nw) + if ew != nil { + return fmt.Errorf("error writing to hash: %w", ew) + } + progress := float32(verified) / float32(totalSize) + if progress-*verifyProgress >= 0.01 { + *verifyProgress = progress + s.onProgressUpdate() + } + } + if er != nil { + if er == io.EOF { + break + } + return fmt.Errorf("error reading file: %w", er) + } + } + + hashSum := hash.Sum(nil) + l.Info().Str("hash", hex.EncodeToString(hashSum)).Msg("SHA256 hash of") + + if hex.EncodeToString(hashSum) != expectedHash { + return fmt.Errorf("hash mismatch: %x != %s", hashSum, expectedHash) + } + + if err := os.Rename(unverifiedPath, path); err != nil { + return fmt.Errorf("error renaming file: %w", err) + } + + if err := os.Chmod(path, 0755); err != nil { + return fmt.Errorf("error making file executable: %w", err) + } + + return nil +} diff --git a/jsonrpc.go b/jsonrpc.go index 5ed90a7a..960e0bea 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -19,6 +19,7 @@ import ( "go.bug.st/serial" "github.com/jetkvm/kvm/internal/hidrpc" + "github.com/jetkvm/kvm/internal/ota" "github.com/jetkvm/kvm/internal/usbgadget" "github.com/jetkvm/kvm/internal/utils" ) @@ -248,9 +249,8 @@ func rpcSetDevChannelState(enabled bool) error { return nil } -func rpcGetUpdateStatus() (*UpdateStatus, error) { - includePreRelease := config.IncludePreRelease - updateStatus, err := GetUpdateStatus(context.Background(), GetDeviceID(), includePreRelease) +func getUpdateStatus(includePreRelease bool) (*ota.UpdateStatus, error) { + updateStatus, err := otaState.GetUpdateStatus(context.Background(), GetDeviceID(), includePreRelease) // to ensure backwards compatibility, // if there's an error, we won't return an error, but we will set the error field if err != nil { @@ -260,15 +260,32 @@ func rpcGetUpdateStatus() (*UpdateStatus, error) { updateStatus.Error = err.Error() } + logger.Info().Interface("updateStatus", updateStatus).Msg("Update status") + return updateStatus, nil } -func rpcGetLocalVersion() (*LocalMetadata, error) { +func rpcGetUpdateStatus() (*ota.UpdateStatus, error) { + return getUpdateStatus(config.IncludePreRelease) +} + +func rpcGetUpdateStatusChannel(channel string) (*ota.UpdateStatus, error) { + switch channel { + case "stable": + return getUpdateStatus(false) + case "dev": + return getUpdateStatus(true) + default: + return nil, fmt.Errorf("invalid channel: %s", channel) + } +} + +func rpcGetLocalVersion() (*ota.LocalMetadata, error) { systemVersion, appVersion, err := GetLocalVersion() if err != nil { return nil, fmt.Errorf("error getting local version: %w", err) } - return &LocalMetadata{ + return &ota.LocalMetadata{ AppVersion: appVersion.String(), SystemVersion: systemVersion.String(), }, nil @@ -277,7 +294,7 @@ func rpcGetLocalVersion() (*LocalMetadata, error) { func rpcTryUpdate() error { includePreRelease := config.IncludePreRelease go func() { - err := TryUpdate(context.Background(), GetDeviceID(), includePreRelease) + err := otaState.TryUpdate(context.Background(), GetDeviceID(), includePreRelease) if err != nil { logger.Warn().Err(err).Msg("failed to try update") } @@ -654,7 +671,7 @@ func rpcGetMassStorageMode() (string, error) { } func rpcIsUpdatePending() (bool, error) { - return IsUpdatePending(), nil + return otaState.IsUpdatePending(), nil } func rpcGetUsbEmulationState() (bool, error) { @@ -1200,6 +1217,7 @@ var rpcHandlers = map[string]RPCHandler{ "setDevChannelState": {Func: rpcSetDevChannelState, Params: []string{"enabled"}}, "getLocalVersion": {Func: rpcGetLocalVersion}, "getUpdateStatus": {Func: rpcGetUpdateStatus}, + "getUpdateStatusChannel": {Func: rpcGetUpdateStatusChannel}, "tryUpdate": {Func: rpcTryUpdate}, "getDevModeState": {Func: rpcGetDevModeState}, "setDevModeState": {Func: rpcSetDevModeState, Params: []string{"enabled"}}, diff --git a/main.go b/main.go index bcc2d73d..b6fc469d 100644 --- a/main.go +++ b/main.go @@ -32,12 +32,6 @@ func Main() { Msg("starting JetKVM") go runWatchdog() - go confirmCurrentSystem() - - initDisplay() - initNative(systemVersionLocal, appVersionLocal) - - http.DefaultClient.Timeout = 1 * time.Minute err = rootcerts.UpdateDefaultTransport() if err != nil { @@ -47,6 +41,13 @@ func Main() { Int("ca_certs_loaded", len(rootcerts.Certs())). Msg("loaded Root CA certificates") + initOta() + + initDisplay() + initNative(systemVersionLocal, appVersionLocal) + + http.DefaultClient.Timeout = 1 * time.Minute + // Initialize network if err := initNetwork(); err != nil { logger.Error().Err(err).Msg("failed to initialize network") @@ -106,7 +107,7 @@ func Main() { } includePreRelease := config.IncludePreRelease - err = TryUpdate(context.Background(), GetDeviceID(), includePreRelease) + err = otaState.TryUpdate(context.Background(), GetDeviceID(), includePreRelease) if err != nil { logger.Warn().Err(err).Msg("failed to auto update") } diff --git a/network.go b/network.go index 846f41f1..25e562a0 100644 --- a/network.go +++ b/network.go @@ -8,6 +8,7 @@ import ( "github.com/jetkvm/kvm/internal/confparser" "github.com/jetkvm/kvm/internal/mdns" "github.com/jetkvm/kvm/internal/network/types" + "github.com/jetkvm/kvm/internal/ota" "github.com/jetkvm/kvm/pkg/nmlite" ) @@ -176,7 +177,7 @@ func setHostname(nm *nmlite.NetworkManager, hostname, domain string) error { return nm.SetHostname(hostname, domain) } -func shouldRebootForNetworkChange(oldConfig, newConfig *types.NetworkConfig) (rebootRequired bool, postRebootAction *PostRebootAction) { +func shouldRebootForNetworkChange(oldConfig, newConfig *types.NetworkConfig) (rebootRequired bool, postRebootAction *ota.PostRebootAction) { oldDhcpClient := oldConfig.DHCPClient.String l := networkLogger.With(). @@ -200,7 +201,7 @@ func shouldRebootForNetworkChange(oldConfig, newConfig *types.NetworkConfig) (re l.Info().Msg("IPv4 mode changed with udhcpc, reboot required") if newIPv4Mode == "static" && oldIPv4Mode != "static" { - postRebootAction = &PostRebootAction{ + postRebootAction = &ota.PostRebootAction{ HealthCheck: fmt.Sprintf("//%s/device/status", newConfig.IPv4Static.Address.String), RedirectTo: fmt.Sprintf("//%s", newConfig.IPv4Static.Address.String), } @@ -217,7 +218,7 @@ func shouldRebootForNetworkChange(oldConfig, newConfig *types.NetworkConfig) (re // Handle IP change for redirect (only if both are not nil and IP changed) if newConfig.IPv4Static != nil && oldConfig.IPv4Static != nil && newConfig.IPv4Static.Address.String != oldConfig.IPv4Static.Address.String { - postRebootAction = &PostRebootAction{ + postRebootAction = &ota.PostRebootAction{ HealthCheck: fmt.Sprintf("//%s/device/status", newConfig.IPv4Static.Address.String), RedirectTo: fmt.Sprintf("//%s", newConfig.IPv4Static.Address.String), } diff --git a/ota.go b/ota.go index 5371e428..90bd8f28 100644 --- a/ota.go +++ b/ota.go @@ -1,59 +1,61 @@ package kvm import ( - "bytes" - "context" - "crypto/sha256" - "crypto/tls" - "encoding/hex" - "encoding/json" "fmt" - "io" "net/http" - "net/url" "os" - "os/exec" "strings" - "time" "github.com/Masterminds/semver/v3" - "github.com/gwatts/rootcerts" - "github.com/rs/zerolog" + "github.com/jetkvm/kvm/internal/ota" ) -type UpdateMetadata struct { - AppVersion string `json:"appVersion"` - AppUrl string `json:"appUrl"` - AppHash string `json:"appHash"` - SystemVersion string `json:"systemVersion"` - SystemUrl string `json:"systemUrl"` - SystemHash string `json:"systemHash"` -} - -type LocalMetadata struct { - AppVersion string `json:"appVersion"` - SystemVersion string `json:"systemVersion"` -} - -// UpdateStatus represents the current update status -type UpdateStatus struct { - Local *LocalMetadata `json:"local"` - Remote *UpdateMetadata `json:"remote"` - SystemUpdateAvailable bool `json:"systemUpdateAvailable"` - AppUpdateAvailable bool `json:"appUpdateAvailable"` - - // for backwards compatibility - Error string `json:"error,omitempty"` -} - -const UpdateMetadataUrl = "https://api.jetkvm.com/releases" - var builtAppVersion = "0.1.0+dev" +var otaState *ota.State + +func initOta() { + otaState = ota.NewState(ota.Options{ + Logger: otaLogger, + ReleaseAPIEndpoint: config.GetUpdateAPIURL(), + GetHTTPClient: func() *http.Client { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = config.NetworkConfig.GetTransportProxyFunc() + + client := &http.Client{ + Transport: transport, + } + return client + }, + GetLocalVersion: GetLocalVersion, + HwReboot: hwReboot, + OnStateUpdate: func(state *ota.RPCState) { + triggerOTAStateUpdate(state) + }, + OnProgressUpdate: func(progress float32) { + writeJSONRPCEvent("otaProgress", progress, currentSession) + }, + }) +} + +func triggerOTAStateUpdate(state *ota.RPCState) { + go func() { + if currentSession == nil || (otaState == nil && state == nil) { + return + } + if state == nil { + state = otaState.ToRPCState() + } + writeJSONRPCEvent("otaState", state, currentSession) + }() +} + +// GetBuiltAppVersion returns the built-in app version func GetBuiltAppVersion() string { return builtAppVersion } +// GetLocalVersion returns the local version of the system and app func GetLocalVersion() (systemVersion *semver.Version, appVersion *semver.Version, err error) { appVersion, err = semver.NewVersion(builtAppVersion) if err != nil { @@ -72,520 +74,3 @@ func GetLocalVersion() (systemVersion *semver.Version, appVersion *semver.Versio return systemVersion, appVersion, nil } - -func fetchUpdateMetadata(ctx context.Context, deviceId string, includePreRelease bool) (*UpdateMetadata, error) { - metadata := &UpdateMetadata{} - - updateUrl, err := url.Parse(UpdateMetadataUrl) - if err != nil { - return nil, fmt.Errorf("error parsing update metadata URL: %w", err) - } - - query := updateUrl.Query() - query.Set("deviceId", deviceId) - query.Set("prerelease", fmt.Sprintf("%v", includePreRelease)) - updateUrl.RawQuery = query.Encode() - - logger.Info().Str("url", updateUrl.String()).Msg("Checking for updates") - - req, err := http.NewRequestWithContext(ctx, "GET", updateUrl.String(), nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %w", err) - } - - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.Proxy = config.NetworkConfig.GetTransportProxyFunc() - - client := &http.Client{ - Transport: transport, - } - - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("error sending request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - err = json.NewDecoder(resp.Body).Decode(metadata) - if err != nil { - return nil, fmt.Errorf("error decoding response: %w", err) - } - - return metadata, nil -} - -func downloadFile(ctx context.Context, path string, url string, downloadProgress *float32) error { - if _, err := os.Stat(path); err == nil { - if err := os.Remove(path); err != nil { - return fmt.Errorf("error removing existing file: %w", err) - } - } - - unverifiedPath := path + ".unverified" - if _, err := os.Stat(unverifiedPath); err == nil { - if err := os.Remove(unverifiedPath); err != nil { - return fmt.Errorf("error removing existing unverified file: %w", err) - } - } - - file, err := os.Create(unverifiedPath) - if err != nil { - return fmt.Errorf("error creating file: %w", err) - } - defer file.Close() - - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return fmt.Errorf("error creating request: %w", err) - } - - client := http.Client{ - Timeout: 10 * time.Minute, - Transport: &http.Transport{ - Proxy: config.NetworkConfig.GetTransportProxyFunc(), - TLSHandshakeTimeout: 30 * time.Second, - TLSClientConfig: &tls.Config{ - RootCAs: rootcerts.ServerCertPool(), - }, - }, - } - - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("error downloading file: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - totalSize := resp.ContentLength - if totalSize <= 0 { - return fmt.Errorf("invalid content length") - } - - var written int64 - buf := make([]byte, 32*1024) - for { - nr, er := resp.Body.Read(buf) - if nr > 0 { - nw, ew := file.Write(buf[0:nr]) - if nw < nr { - return fmt.Errorf("short file write: %d < %d", nw, nr) - } - written += int64(nw) - if ew != nil { - return fmt.Errorf("error writing to file: %w", ew) - } - progress := float32(written) / float32(totalSize) - if progress-*downloadProgress >= 0.01 { - *downloadProgress = progress - triggerOTAStateUpdate() - } - } - if er != nil { - if er == io.EOF { - break - } - return fmt.Errorf("error reading response body: %w", er) - } - } - - file.Close() - - // Flush filesystem buffers to ensure all data is written to disk - err = exec.Command("sync").Run() - if err != nil { - return fmt.Errorf("error flushing filesystem buffers: %w", err) - } - - // Clear the filesystem caches to force a read from disk - err = os.WriteFile("/proc/sys/vm/drop_caches", []byte("1"), 0644) - if err != nil { - return fmt.Errorf("error clearing filesystem caches: %w", err) - } - - return nil -} - -func verifyFile(path string, expectedHash string, verifyProgress *float32, scopedLogger *zerolog.Logger) error { - if scopedLogger == nil { - scopedLogger = otaLogger - } - - unverifiedPath := path + ".unverified" - fileToHash, err := os.Open(unverifiedPath) - if err != nil { - return fmt.Errorf("error opening file for hashing: %w", err) - } - defer fileToHash.Close() - - hash := sha256.New() - fileInfo, err := fileToHash.Stat() - if err != nil { - return fmt.Errorf("error getting file info: %w", err) - } - totalSize := fileInfo.Size() - - buf := make([]byte, 32*1024) - verified := int64(0) - - for { - nr, er := fileToHash.Read(buf) - if nr > 0 { - nw, ew := hash.Write(buf[0:nr]) - if nw < nr { - return fmt.Errorf("short hash write: %d < %d", nw, nr) - } - verified += int64(nw) - if ew != nil { - return fmt.Errorf("error writing to hash: %w", ew) - } - progress := float32(verified) / float32(totalSize) - if progress-*verifyProgress >= 0.01 { - *verifyProgress = progress - triggerOTAStateUpdate() - } - } - if er != nil { - if er == io.EOF { - break - } - return fmt.Errorf("error reading file: %w", er) - } - } - - // close the file so we can rename below - if err := fileToHash.Close(); err != nil { - return fmt.Errorf("error closing file: %w", err) - } - - hashSum := hex.EncodeToString(hash.Sum(nil)) - scopedLogger.Info().Str("path", path).Str("hash", hashSum).Msg("SHA256 hash of") - - if hashSum != expectedHash { - return fmt.Errorf("hash mismatch: %s != %s", hashSum, expectedHash) - } - - if err := os.Rename(unverifiedPath, path); err != nil { - return fmt.Errorf("error renaming file: %w", err) - } - - if err := os.Chmod(path, 0755); err != nil { - return fmt.Errorf("error making file executable: %w", err) - } - - return nil -} - -type OTAState struct { - Updating bool `json:"updating"` - Error string `json:"error,omitempty"` - MetadataFetchedAt *time.Time `json:"metadataFetchedAt,omitempty"` - AppUpdatePending bool `json:"appUpdatePending"` - SystemUpdatePending bool `json:"systemUpdatePending"` - AppDownloadProgress float32 `json:"appDownloadProgress,omitempty"` //TODO: implement for progress bar - AppDownloadFinishedAt *time.Time `json:"appDownloadFinishedAt,omitempty"` - SystemDownloadProgress float32 `json:"systemDownloadProgress,omitempty"` //TODO: implement for progress bar - SystemDownloadFinishedAt *time.Time `json:"systemDownloadFinishedAt,omitempty"` - AppVerificationProgress float32 `json:"appVerificationProgress,omitempty"` - AppVerifiedAt *time.Time `json:"appVerifiedAt,omitempty"` - SystemVerificationProgress float32 `json:"systemVerificationProgress,omitempty"` - SystemVerifiedAt *time.Time `json:"systemVerifiedAt,omitempty"` - AppUpdateProgress float32 `json:"appUpdateProgress,omitempty"` //TODO: implement for progress bar - AppUpdatedAt *time.Time `json:"appUpdatedAt,omitempty"` - SystemUpdateProgress float32 `json:"systemUpdateProgress,omitempty"` //TODO: port rk_ota, then implement - SystemUpdatedAt *time.Time `json:"systemUpdatedAt,omitempty"` -} - -var otaState = OTAState{} - -func triggerOTAStateUpdate() { - go func() { - if currentSession == nil { - logger.Info().Msg("No active RPC session, skipping update state update") - return - } - writeJSONRPCEvent("otaState", otaState, currentSession) - }() -} - -func TryUpdate(ctx context.Context, deviceId string, includePreRelease bool) error { - scopedLogger := otaLogger.With(). - Str("deviceId", deviceId). - Bool("includePreRelease", includePreRelease). - Logger() - - scopedLogger.Info().Msg("Trying to update...") - if otaState.Updating { - return fmt.Errorf("update already in progress") - } - - otaState = OTAState{ - Updating: true, - } - triggerOTAStateUpdate() - - defer func() { - otaState.Updating = false - triggerOTAStateUpdate() - }() - - updateStatus, err := GetUpdateStatus(ctx, deviceId, includePreRelease) - if err != nil { - otaState.Error = fmt.Sprintf("Error checking for updates: %v", err) - scopedLogger.Error().Err(err).Msg("Error checking for updates") - return fmt.Errorf("error checking for updates: %w", err) - } - - now := time.Now() - otaState.MetadataFetchedAt = &now - otaState.AppUpdatePending = updateStatus.AppUpdateAvailable - otaState.SystemUpdatePending = updateStatus.SystemUpdateAvailable - triggerOTAStateUpdate() - - local := updateStatus.Local - remote := updateStatus.Remote - appUpdateAvailable := updateStatus.AppUpdateAvailable - systemUpdateAvailable := updateStatus.SystemUpdateAvailable - - rebootNeeded := false - - if appUpdateAvailable { - scopedLogger.Info(). - Str("local", local.AppVersion). - Str("remote", remote.AppVersion). - Msg("App update available") - - err := downloadFile(ctx, "/userdata/jetkvm/jetkvm_app.update", remote.AppUrl, &otaState.AppDownloadProgress) - if err != nil { - otaState.Error = fmt.Sprintf("Error downloading app update: %v", err) - scopedLogger.Error().Err(err).Msg("Error downloading app update") - triggerOTAStateUpdate() - return fmt.Errorf("error downloading app update: %w", err) - } - - downloadFinished := time.Now() - otaState.AppDownloadFinishedAt = &downloadFinished - otaState.AppDownloadProgress = 1 - triggerOTAStateUpdate() - - err = verifyFile( - "/userdata/jetkvm/jetkvm_app.update", - remote.AppHash, - &otaState.AppVerificationProgress, - &scopedLogger, - ) - if err != nil { - otaState.Error = fmt.Sprintf("Error verifying app update hash: %v", err) - scopedLogger.Error().Err(err).Msg("Error verifying app update hash") - triggerOTAStateUpdate() - return fmt.Errorf("error verifying app update: %w", err) - } - - verifyFinished := time.Now() - otaState.AppVerifiedAt = &verifyFinished - otaState.AppVerificationProgress = 1 - triggerOTAStateUpdate() - - otaState.AppUpdatedAt = &verifyFinished - otaState.AppUpdateProgress = 1 - triggerOTAStateUpdate() - - scopedLogger.Info().Msg("App update downloaded") - rebootNeeded = true - triggerOTAStateUpdate() - } else { - scopedLogger.Info().Msg("App is up to date") - } - - if systemUpdateAvailable { - scopedLogger.Info(). - Str("local", local.SystemVersion). - Str("remote", remote.SystemVersion). - Msg("System update available") - - err := downloadFile(ctx, "/userdata/jetkvm/update_system.tar", remote.SystemUrl, &otaState.SystemDownloadProgress) - if err != nil { - otaState.Error = fmt.Sprintf("Error downloading system update: %v", err) - scopedLogger.Error().Err(err).Msg("Error downloading system update") - triggerOTAStateUpdate() - return fmt.Errorf("error downloading system update: %w", err) - } - - downloadFinished := time.Now() - otaState.SystemDownloadFinishedAt = &downloadFinished - otaState.SystemDownloadProgress = 1 - triggerOTAStateUpdate() - - err = verifyFile( - "/userdata/jetkvm/update_system.tar", - remote.SystemHash, - &otaState.SystemVerificationProgress, - &scopedLogger, - ) - if err != nil { - otaState.Error = fmt.Sprintf("Error verifying system update hash: %v", err) - scopedLogger.Error().Err(err).Msg("Error verifying system update hash") - triggerOTAStateUpdate() - return fmt.Errorf("error verifying system update: %w", err) - } - - scopedLogger.Info().Msg("System update downloaded") - verifyFinished := time.Now() - otaState.SystemVerifiedAt = &verifyFinished - otaState.SystemVerificationProgress = 1 - triggerOTAStateUpdate() - - scopedLogger.Info().Msg("Starting rk_ota command") - cmd := exec.Command("rk_ota", "--misc=update", "--tar_path=/userdata/jetkvm/update_system.tar", "--save_dir=/userdata/jetkvm/ota_save", "--partition=all") - var b bytes.Buffer - cmd.Stdout = &b - cmd.Stderr = &b - err = cmd.Start() - if err != nil { - otaState.Error = fmt.Sprintf("Error starting rk_ota command: %v", err) - scopedLogger.Error().Err(err).Msg("Error starting rk_ota command") - triggerOTAStateUpdate() - return fmt.Errorf("error starting rk_ota command: %w", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - ticker := time.NewTicker(1800 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if otaState.SystemUpdateProgress >= 0.99 { - return - } - otaState.SystemUpdateProgress += 0.01 - if otaState.SystemUpdateProgress > 0.99 { - otaState.SystemUpdateProgress = 0.99 - } - triggerOTAStateUpdate() - case <-ctx.Done(): - return - } - } - }() - - err = cmd.Wait() - cancel() - output := b.String() - if err != nil { - otaState.Error = fmt.Sprintf("Error executing rk_ota command: %v\nOutput: %s", err, output) - scopedLogger.Error(). - Err(err). - Str("output", output). - Int("exitCode", cmd.ProcessState.ExitCode()). - Msg("Error executing rk_ota command") - triggerOTAStateUpdate() - return fmt.Errorf("error executing rk_ota command: %w\nOutput: %s", err, output) - } - - scopedLogger.Info().Str("output", output).Msg("rk_ota success") - otaState.SystemUpdateProgress = 1 - otaState.SystemUpdatedAt = &verifyFinished - rebootNeeded = true - triggerOTAStateUpdate() - } else { - scopedLogger.Info().Msg("System is up to date") - } - - if rebootNeeded { - scopedLogger.Info().Msg("System Rebooting due to OTA update") - - // Build redirect URL with conditional query parameters - redirectTo := "/settings/general/update" - queryParams := url.Values{} - if systemUpdateAvailable { - queryParams.Set("systemVersion", remote.SystemVersion) - } - if appUpdateAvailable { - queryParams.Set("appVersion", remote.AppVersion) - } - if len(queryParams) > 0 { - redirectTo += "?" + queryParams.Encode() - } - - postRebootAction := &PostRebootAction{ - HealthCheck: "/device/status", - RedirectTo: redirectTo, - } - - if err := hwReboot(true, postRebootAction, 10*time.Second); err != nil { - return fmt.Errorf("error requesting reboot: %w", err) - } - } - - return nil -} - -func GetUpdateStatus(ctx context.Context, deviceId string, includePreRelease bool) (*UpdateStatus, error) { - updateStatus := &UpdateStatus{} - - // Get local versions - systemVersionLocal, appVersionLocal, err := GetLocalVersion() - if err != nil { - return updateStatus, fmt.Errorf("error getting local version: %w", err) - } - updateStatus.Local = &LocalMetadata{ - AppVersion: appVersionLocal.String(), - SystemVersion: systemVersionLocal.String(), - } - - // Get remote metadata - remoteMetadata, err := fetchUpdateMetadata(ctx, deviceId, includePreRelease) - if err != nil { - return updateStatus, fmt.Errorf("error checking for updates: %w", err) - } - updateStatus.Remote = remoteMetadata - - // Get remote versions - systemVersionRemote, err := semver.NewVersion(remoteMetadata.SystemVersion) - if err != nil { - return updateStatus, fmt.Errorf("error parsing remote system version: %w", err) - } - appVersionRemote, err := semver.NewVersion(remoteMetadata.AppVersion) - if err != nil { - return updateStatus, fmt.Errorf("error parsing remote app version: %w, %s", err, remoteMetadata.AppVersion) - } - - updateStatus.SystemUpdateAvailable = systemVersionRemote.GreaterThan(systemVersionLocal) - updateStatus.AppUpdateAvailable = appVersionRemote.GreaterThan(appVersionLocal) - - // Handle pre-release updates - isRemoteSystemPreRelease := systemVersionRemote.Prerelease() != "" - isRemoteAppPreRelease := appVersionRemote.Prerelease() != "" - - if isRemoteSystemPreRelease && !includePreRelease { - updateStatus.SystemUpdateAvailable = false - } - if isRemoteAppPreRelease && !includePreRelease { - updateStatus.AppUpdateAvailable = false - } - - return updateStatus, nil -} - -func IsUpdatePending() bool { - return otaState.Updating -} - -// make sure our current a/b partition is set as default -func confirmCurrentSystem() { - output, err := exec.Command("rk_ota", "--misc=now").CombinedOutput() - if err != nil { - logger.Warn().Str("output", string(output)).Msg("failed to set current partition in A/B setup") - } -} diff --git a/ui/localization/messages/en.json b/ui/localization/messages/en.json index 0356e8e5..78b04538 100644 --- a/ui/localization/messages/en.json +++ b/ui/localization/messages/en.json @@ -242,6 +242,7 @@ "general_auto_update_error": "Failed to set auto-update: {error}", "general_auto_update_title": "Auto Update", "general_check_for_updates": "Check for Updates", + "general_check_for_stable_updates": "Downgrade", "general_page_description": "Configure device settings and update preferences", "general_reboot_description": "Do you want to proceed with rebooting the system?", "general_reboot_device": "Reboot Device", diff --git a/ui/src/hooks/useVersion.tsx b/ui/src/hooks/useVersion.tsx index 94c2f99d..78bbc313 100644 --- a/ui/src/hooks/useVersion.tsx +++ b/ui/src/hooks/useVersion.tsx @@ -1,4 +1,4 @@ -import { useCallback } from "react"; +import { useCallback, useMemo } from "react"; import { useDeviceStore } from "@/hooks/stores"; import { JsonRpcError, RpcMethodNotFound } from "@/hooks/useJsonRpc"; @@ -53,5 +53,6 @@ export function useVersion() { getLocalVersion, appVersion, systemVersion, + isOnDevVersion, }; } diff --git a/ui/src/routes/devices.$id.settings.general._index.tsx b/ui/src/routes/devices.$id.settings.general._index.tsx index 86e92bcd..7f70f2aa 100644 --- a/ui/src/routes/devices.$id.settings.general._index.tsx +++ b/ui/src/routes/devices.$id.settings.general._index.tsx @@ -12,12 +12,13 @@ import notifications from "@/notifications"; import { getLocale, setLocale, locales, baseLocale } from '@localizations/runtime.js'; import { m } from "@localizations/messages.js"; import { deleteCookie, map_locale_code_to_name } from "@/utils"; +import { useVersion } from "@hooks/useVersion"; export default function SettingsGeneralRoute() { const { send } = useJsonRpc(); const { navigateTo } = useDeviceUiNavigation(); const [autoUpdate, setAutoUpdate] = useState(true); - + const { isOnDevVersion } = useVersion(); const currentVersions = useDeviceStore(state => { const { appVersion, systemVersion } = state; if (!appVersion || !systemVersion) return null; @@ -48,10 +49,10 @@ export default function SettingsGeneralRoute() { const localeOptions = useMemo(() => { return ["", ...locales] .map((code) => { - const [localizedName, nativeName] = map_locale_code_to_name(currentLocale, code); - // don't repeat the name if it's the same in both locales (or blank) - const label = nativeName && nativeName !== localizedName ? `${localizedName} - ${nativeName}` : localizedName; - return { value: code, label: label } + const [localizedName, nativeName] = map_locale_code_to_name(currentLocale, code); + // don't repeat the name if it's the same in both locales (or blank) + const label = nativeName && nativeName !== localizedName ? `${localizedName} - ${nativeName}` : localizedName; + return { value: code, label: label } }); }, [currentLocale]); @@ -74,6 +75,10 @@ export default function SettingsGeneralRoute() { notifications.success(m.locale_change_success({ locale: validLocale || m.locale_auto() })); }; + const downgradeAvailable = useMemo(() => { + return isOnDevVersion; + }, [isOnDevVersion]); + return (
} /> -
+
+ {downgradeAvailable &&