From 87eb555fa23168bd92a18592324cfbb3f2b73909 Mon Sep 17 00:00:00 2001 From: Aveline <352441+ym@users.noreply.github.com> Date: Wed, 19 Nov 2025 15:20:59 +0100 Subject: [PATCH] refactor: OTA (#912) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Adam Shiervani --- .vscode/settings.json | 18 +- config.go | 19 +- hw.go | 36 + internal/ota/app.go | 45 ++ internal/ota/errors.go | 24 + internal/ota/ota.go | 429 ++++++++++++ internal/ota/ota_test.go | 261 +++++++ internal/ota/rpc.go | 167 +++++ internal/ota/state.go | 215 ++++++ internal/ota/sys.go | 101 +++ internal/ota/testdata/ota.schema.json | 159 +++++ .../ota/testdata/ota/app_only_downgrade.json | 34 + .../ota/testdata/ota/app_only_upgrade.json | 33 + internal/ota/testdata/ota/both_downgrade.json | 37 + internal/ota/testdata/ota/both_upgrade.json | 34 + internal/ota/testdata/ota/no_components.json | 32 + .../testdata/ota/system_only_downgrade.json | 34 + .../ota/testdata/ota/system_only_upgrade.json | 33 + internal/ota/testdata/ota/without_certs.json | 17 + internal/ota/utils.go | 193 ++++++ jsonrpc.go | 57 +- main.go | 15 +- network.go | 76 ++- ota.go | 644 ++++-------------- ui/src/components/NestedSettingsGroup.tsx | 22 + ui/src/hooks/useVersion.tsx | 13 + .../devices.$id.settings.access._index.tsx | 6 +- .../routes/devices.$id.settings.advanced.tsx | 298 ++++++-- .../devices.$id.settings.general._index.tsx | 3 +- .../devices.$id.settings.general.reboot.tsx | 11 +- .../devices.$id.settings.general.update.tsx | 137 +++- .../routes/devices.$id.settings.hardware.tsx | 2 +- ui/src/routes/devices.$id.settings.video.tsx | 5 +- ui/src/utils/jsonrpc.ts | 31 +- webrtc.go | 11 +- 35 files changed, 2560 insertions(+), 692 deletions(-) create mode 100644 internal/ota/app.go create mode 100644 internal/ota/errors.go create mode 100644 internal/ota/ota.go create mode 100644 internal/ota/ota_test.go create mode 100644 internal/ota/rpc.go create mode 100644 internal/ota/state.go create mode 100644 internal/ota/sys.go create mode 100644 internal/ota/testdata/ota.schema.json create mode 100644 internal/ota/testdata/ota/app_only_downgrade.json create mode 100644 internal/ota/testdata/ota/app_only_upgrade.json create mode 100644 internal/ota/testdata/ota/both_downgrade.json create mode 100644 internal/ota/testdata/ota/both_upgrade.json create mode 100644 internal/ota/testdata/ota/no_components.json create mode 100644 internal/ota/testdata/ota/system_only_downgrade.json create mode 100644 internal/ota/testdata/ota/system_only_upgrade.json create mode 100644 internal/ota/testdata/ota/without_certs.json create mode 100644 internal/ota/utils.go create mode 100644 ui/src/components/NestedSettingsGroup.tsx diff --git a/.vscode/settings.json b/.vscode/settings.json index a86e6b63..5aeb206a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,5 +3,21 @@ "cva", "cx" ], - "git.ignoreLimitWarning": true + "gopls": { + "build.buildFlags": [ + "-tags", + "synctrace" + ] + }, + "git.ignoreLimitWarning": true, + "cmake.sourceDirectory": "/workspaces/kvm-static-ip/internal/native/cgo", + "cmake.ignoreCMakeListsMissing": true, + "json.schemas": [ + { + "fileMatch": [ + "/internal/ota/testdata/ota/*.json" + ], + "url": "./internal/ota/testdata/ota.schema.json" + } + ] } \ No newline at end of file diff --git a/config.go b/config.go index 9049d980..42c1d427 100644 --- a/config.go +++ b/config.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "strconv" + "strings" "sync" "github.com/jetkvm/kvm/internal/confparser" @@ -15,6 +16,10 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" ) +const ( + DefaultAPIURL = "https://api.jetkvm.com" +) + type WakeOnLanDevice struct { Name string `json:"name"` MacAddress string `json:"macAddress"` @@ -80,6 +85,7 @@ func (m *KeyboardMacro) Validate() error { type Config struct { CloudURL string `json:"cloud_url"` + UpdateAPIURL string `json:"update_api_url"` CloudAppURL string `json:"cloud_app_url"` CloudToken string `json:"cloud_token"` GoogleIdentity string `json:"google_identity"` @@ -109,6 +115,15 @@ type Config struct { VideoQualityFactor float64 `json:"video_quality_factor"` } +// GetUpdateAPIURL returns the update API URL +func (c *Config) GetUpdateAPIURL() string { + if c.UpdateAPIURL == "" { + return DefaultAPIURL + } + return strings.TrimSuffix(c.UpdateAPIURL, "/") + "/releases" +} + +// GetDisplayRotation returns the display rotation func (c *Config) GetDisplayRotation() uint16 { rotationInt, err := strconv.ParseUint(c.DisplayRotation, 10, 16) if err != nil { @@ -118,6 +133,7 @@ func (c *Config) GetDisplayRotation() uint16 { return uint16(rotationInt) } +// SetDisplayRotation sets the display rotation func (c *Config) SetDisplayRotation(rotation string) error { _, err := strconv.ParseUint(rotation, 10, 16) if err != nil { @@ -156,7 +172,8 @@ var ( func getDefaultConfig() Config { return Config{ - CloudURL: "https://api.jetkvm.com", + CloudURL: DefaultAPIURL, + UpdateAPIURL: DefaultAPIURL, CloudAppURL: "https://app.jetkvm.com", AutoUpdateEnabled: true, // Set a default value ActiveExtension: "", diff --git a/hw.go b/hw.go index 20d88ebf..f1670262 100644 --- a/hw.go +++ b/hw.go @@ -7,6 +7,8 @@ import ( "strings" "sync" "time" + + "github.com/jetkvm/kvm/internal/ota" ) func extractSerialNumber() (string, error) { @@ -28,12 +30,46 @@ func extractSerialNumber() (string, error) { return matches[1], nil } +<<<<<<< HEAD func readOtpEntropy() ([]byte, error) { //nolint:unused content, err := os.ReadFile("/sys/bus/nvmem/devices/rockchip-otp0/nvmem") if err != nil { return nil, err } return content[0x17:0x1C], nil +======= +func hwReboot(force bool, postRebootAction *ota.PostRebootAction, delay time.Duration) error { + logger.Info().Dur("delayMs", delay).Msg("reboot requested") + + writeJSONRPCEvent("willReboot", postRebootAction, currentSession) + time.Sleep(1 * time.Second) // Wait for the JSONRPCEvent to be sent + + nativeInstance.SwitchToScreenIfDifferent("rebooting_screen") + if delay > 1*time.Second { + time.Sleep(delay - 1*time.Second) // wait requested extra settle time + } + + args := []string{} + if force { + args = append(args, "-f") + } + + cmd := exec.Command("reboot", args...) + err := cmd.Start() + if err != nil { + logger.Error().Err(err).Msg("failed to reboot") + switchToMainScreen() + return fmt.Errorf("failed to reboot: %w", err) + } + + // If the reboot command is successful, exit the program after 5 seconds + go func() { + time.Sleep(5 * time.Second) + os.Exit(0) + }() + + return nil +>>>>>>> 752fb55 (refactor: OTA (#912)) } var deviceID string diff --git a/internal/ota/app.go b/internal/ota/app.go new file mode 100644 index 00000000..55caa8e8 --- /dev/null +++ b/internal/ota/app.go @@ -0,0 +1,45 @@ +package ota + +import ( + "context" + "time" +) + +const ( + appUpdatePath = "/userdata/jetkvm/jetkvm_app.update" +) + +// DO NOT call it directly, it's not thread safe +// Mutex is currently held by the caller, e.g. doUpdate +func (s *State) updateApp(ctx context.Context, appUpdate *componentUpdateStatus) error { + l := s.l.With().Str("path", appUpdatePath).Logger() + + if err := s.downloadFile(ctx, appUpdatePath, appUpdate.url, "app"); err != nil { + return s.componentUpdateError("Error downloading app update", err, &l) + } + + downloadFinished := time.Now() + appUpdate.downloadFinishedAt = downloadFinished + appUpdate.downloadProgress = 1 + s.triggerComponentUpdateState("app", appUpdate) + + if err := s.verifyFile( + appUpdatePath, + appUpdate.hash, + &appUpdate.verificationProgress, + ); err != nil { + return s.componentUpdateError("Error verifying app update hash", err, &l) + } + verifyFinished := time.Now() + appUpdate.verifiedAt = verifyFinished + appUpdate.verificationProgress = 1 + appUpdate.updatedAt = verifyFinished + appUpdate.updateProgress = 1 + s.triggerComponentUpdateState("app", appUpdate) + + l.Info().Msg("App update downloaded") + + s.rebootNeeded = true + + return nil +} diff --git a/internal/ota/errors.go b/internal/ota/errors.go new file mode 100644 index 00000000..a1d0b4c5 --- /dev/null +++ b/internal/ota/errors.go @@ -0,0 +1,24 @@ +package ota + +import ( + "errors" + "fmt" + + "github.com/rs/zerolog" +) + +var ( + // ErrVersionNotFound is returned when the specified version is not found + ErrVersionNotFound = errors.New("specified version not found") +) + +func (s *State) componentUpdateError(prefix string, err error, l *zerolog.Logger) error { + if l == nil { + l = s.l + } + l.Error().Err(err).Msg(prefix) + s.error = fmt.Sprintf("%s: %v", prefix, err) + s.updating = false + s.triggerStateUpdate() + return err +} diff --git a/internal/ota/ota.go b/internal/ota/ota.go new file mode 100644 index 00000000..52cbb6e2 --- /dev/null +++ b/internal/ota/ota.go @@ -0,0 +1,429 @@ +package ota + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptrace" + "net/url" + "time" + + "github.com/rs/zerolog" +) + +// HttpClient is the interface for the HTTP client +type HttpClient interface { + Do(req *http.Request) (*http.Response, error) +} + +// UpdateReleaseAPIEndpoint updates the release API endpoint +func (s *State) UpdateReleaseAPIEndpoint(endpoint string) { + s.releaseAPIEndpoint = endpoint +} + +// GetReleaseAPIEndpoint returns the release API endpoint +func (s *State) GetReleaseAPIEndpoint() string { + return s.releaseAPIEndpoint +} + +// getUpdateURL returns the update URL for the given parameters +func (s *State) getUpdateURL(params UpdateParams) (string, error, bool) { + updateURL, err := url.Parse(s.releaseAPIEndpoint) + if err != nil { + return "", fmt.Errorf("error parsing update metadata URL: %w", err), false + } + + isCustomVersion := false + + query := updateURL.Query() + query.Set("deviceId", params.DeviceID) + query.Set("prerelease", fmt.Sprintf("%v", params.IncludePreRelease)) + + // set the custom versions if they are specified + for component, constraint := range params.Components { + if constraint == "" { + continue + } + + query.Set(component+"Version", constraint) + isCustomVersion = true + } + + updateURL.RawQuery = query.Encode() + + return updateURL.String(), nil, isCustomVersion +} + +// newHTTPRequestWithTrace creates a new HTTP request with a trace logger +// TODO: use OTEL instead of doing this manually +func (s *State) newHTTPRequestWithTrace(ctx context.Context, method, url string, body io.Reader, logger func() *zerolog.Event) (*http.Request, error) { + localCtx := ctx + if s.l.GetLevel() <= zerolog.TraceLevel { + if logger == nil { + logger = func() *zerolog.Event { return s.l.Trace() } + } + + l := func() *zerolog.Event { return logger().Str("url", url).Str("method", method) } + localCtx = httptrace.WithClientTrace(localCtx, &httptrace.ClientTrace{ + GetConn: func(hostPort string) { l().Str("hostPort", hostPort).Msg("[conn] starting to create conn") }, + GotConn: func(info httptrace.GotConnInfo) { l().Interface("info", info).Msg("[conn] connection established") }, + PutIdleConn: func(err error) { l().Err(err).Msg("[conn] connection returned to idle pool") }, + GotFirstResponseByte: func() { l().Msg("[resp] first response byte received") }, + Got100Continue: func() { l().Msg("[resp] 100 continue received") }, + DNSStart: func(info httptrace.DNSStartInfo) { l().Interface("info", info).Msg("[dns] starting to look up dns") }, + DNSDone: func(info httptrace.DNSDoneInfo) { l().Interface("info", info).Msg("[dns] done looking up dns") }, + ConnectStart: func(network, addr string) { + l().Str("network", network).Str("addr", addr).Msg("[tcp] starting tcp connection") + }, + ConnectDone: func(network, addr string, err error) { + l().Str("network", network).Str("addr", addr).Err(err).Msg("[tcp] tcp connection created") + }, + TLSHandshakeStart: func() { l().Msg("[tls] handshake started") }, + TLSHandshakeDone: func(state tls.ConnectionState, err error) { + l(). + Str("tlsVersion", tls.VersionName(state.Version)). + Str("cipherSuite", tls.CipherSuiteName(state.CipherSuite)). + Str("negotiatedProtocol", state.NegotiatedProtocol). + Str("serverName", state.ServerName). + Err(err).Msg("[tls] handshake done") + }, + }) + } + + return http.NewRequestWithContext(localCtx, method, url, body) +} + +func (s *State) fetchUpdateMetadata(ctx context.Context, params UpdateParams) (*UpdateMetadata, error) { + metadata := &UpdateMetadata{} + + logger := s.l.With().Logger() + if params.RequestID != "" { + logger = logger.With().Str("requestID", params.RequestID).Logger() + } + t := time.Now() + traceLogger := func() *zerolog.Event { + return logger.Trace().Dur("duration", time.Since(t)) + } + + url, err, isCustomVersion := s.getUpdateURL(params) + traceLogger().Err(err). + Msg("fetchUpdateMetadata: getUpdateURL") + if err != nil { + return nil, fmt.Errorf("error getting update URL: %w", err) + } + + traceLogger(). + Str("url", url). + Msg("fetching update metadata") + + req, err := s.newHTTPRequestWithTrace(ctx, "GET", url, nil, traceLogger) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + client := s.client() + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("error sending request: %w", err) + } + defer resp.Body.Close() + + traceLogger(). + Int("status", resp.StatusCode). + Msg("fetchUpdateMetadata: response") + + if isCustomVersion && resp.StatusCode == http.StatusNotFound { + return nil, ErrVersionNotFound + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + err = json.NewDecoder(resp.Body).Decode(metadata) + if err != nil { + return nil, fmt.Errorf("error decoding response: %w", err) + } + + traceLogger(). + Msg("fetchUpdateMetadata: completed") + + return metadata, nil +} + +func (s *State) triggerStateUpdate() { + s.onStateUpdate(s.ToRPCState()) +} + +func (s *State) triggerComponentUpdateState(component string, update *componentUpdateStatus) { + s.componentUpdateStatuses[component] = *update + s.triggerStateUpdate() +} + +// TryUpdate tries to update the given components +// if the update is already in progress, it returns an error +func (s *State) TryUpdate(ctx context.Context, params UpdateParams) error { + locked := s.mu.TryLock() + if !locked { + return fmt.Errorf("update already in progress") + } + + return s.doUpdate(ctx, params) +} + +// before calling doUpdate, the caller must have locked the mutex +// otherwise a runtime error will occur +func (s *State) doUpdate(ctx context.Context, params UpdateParams) error { + defer s.mu.Unlock() + + scopedLogger := s.l.With(). + Interface("params", params). + Logger() + + scopedLogger.Info().Msg("checking for updates") + if s.updating { + return fmt.Errorf("update already in progress") + } + + s.updating = true + s.triggerStateUpdate() + + if len(params.Components) == 0 { + params.Components = defaultComponents + } + + _, shouldUpdateApp := params.Components["app"] + _, shouldUpdateSystem := params.Components["system"] + + if !shouldUpdateApp && !shouldUpdateSystem { + return s.componentUpdateError( + "Update aborted: no components were specified to update. Requested components: ", + fmt.Errorf("%v", params.Components), + &scopedLogger, + ) + } + + appUpdate, systemUpdate, err := s.getUpdateStatus(ctx, params) + if err != nil { + return s.componentUpdateError("Error checking for updates", err, &scopedLogger) + } + + s.metadataFetchedAt = time.Now() + s.triggerStateUpdate() + + if shouldUpdateApp && appUpdate.available { + appUpdate.pending = true + s.updating = true + s.triggerComponentUpdateState("app", appUpdate) + } + + if shouldUpdateSystem && systemUpdate.available { + systemUpdate.pending = true + s.updating = true + s.triggerComponentUpdateState("system", systemUpdate) + } + + scopedLogger.Trace().Bool("pending", appUpdate.pending).Msg("Checking for app update") + + if appUpdate.pending { + scopedLogger.Info(). + Str("url", appUpdate.url). + Str("hash", appUpdate.hash). + Msg("App update available") + + if err := s.updateApp(ctx, appUpdate); err != nil { + return s.componentUpdateError("Error updating app", err, &scopedLogger) + } + } else { + scopedLogger.Info().Msg("App is up to date") + } + + scopedLogger.Trace().Bool("pending", systemUpdate.pending).Msg("Checking for system update") + + if systemUpdate.pending { + if err := s.updateSystem(ctx, systemUpdate); err != nil { + return s.componentUpdateError("Error updating system", err, &scopedLogger) + } + } else { + scopedLogger.Info().Msg("System is up to date") + } + + if s.rebootNeeded { + if appUpdate.customVersionUpdate || systemUpdate.customVersionUpdate { + scopedLogger.Info().Msg("disabling auto-update due to custom version update") + // If they are explicitly updating a custom version, we assume they want to disable auto-update + if _, err := s.setAutoUpdate(false); err != nil { + scopedLogger.Warn().Err(err).Msg("Failed to disable auto-update") + } + } + + scopedLogger.Info().Msg("System Rebooting due to OTA update") + + redirectUrl := "/settings/general/update" + + if params.ResetConfig { + scopedLogger.Info().Msg("Resetting config") + if err := s.resetConfig(); err != nil { + return s.componentUpdateError("Error resetting config", err, &scopedLogger) + } + redirectUrl = "/welcome" + } + + postRebootAction := &PostRebootAction{ + HealthCheck: "/device/status", + RedirectTo: redirectUrl, + } + + // REBOOT_REDIRECT_DELAY_MS is 7 seconds in the UI, + // it means that healthCheckUrl will be called after 7 seconds that we send willReboot JSONRPC event + // so we need to reboot it within 7 seconds to avoid it being called before the device is rebooted + if err := s.reboot(true, postRebootAction, 5*time.Second); err != nil { + return s.componentUpdateError("Error requesting reboot", err, &scopedLogger) + } + } + + // We don't need set the updating flag to false here. Either it will; + // - set to false by the componentUpdateError function + // - device will reboot + return nil +} + +// UpdateParams represents the parameters for the update +type UpdateParams struct { + DeviceID string `json:"deviceID"` + Components map[string]string `json:"components"` + IncludePreRelease bool `json:"includePreRelease"` + ResetConfig bool `json:"resetConfig"` + // RequestID is a unique identifier for the update request + // When it's set, detailed trace logs will be enabled (if the log level is Trace) + RequestID string +} + +// getUpdateStatus gets the update status for the given components +// and updates the componentUpdateStatuses map +func (s *State) getUpdateStatus( + ctx context.Context, + params UpdateParams, +) ( + appUpdate *componentUpdateStatus, + systemUpdate *componentUpdateStatus, + err error, +) { + appUpdate = &componentUpdateStatus{} + systemUpdate = &componentUpdateStatus{} + + if currentAppUpdate, ok := s.componentUpdateStatuses["app"]; ok { + appUpdate = ¤tAppUpdate + } + + if currentSystemUpdate, ok := s.componentUpdateStatuses["system"]; ok { + systemUpdate = ¤tSystemUpdate + } + + err = s.checkUpdateStatus(ctx, params, appUpdate, systemUpdate) + if err != nil { + return nil, nil, err + } + + s.componentUpdateStatuses["app"] = *appUpdate + s.componentUpdateStatuses["system"] = *systemUpdate + + return appUpdate, systemUpdate, nil +} + +// checkUpdateStatus checks the update status for the given components +func (s *State) checkUpdateStatus( + ctx context.Context, + params UpdateParams, + appUpdateStatus *componentUpdateStatus, + systemUpdateStatus *componentUpdateStatus, +) error { + // get the local versions + systemVersionLocal, appVersionLocal, err := s.getLocalVersion() + if err != nil { + return fmt.Errorf("error getting local version: %w", err) + } + appUpdateStatus.localVersion = appVersionLocal.String() + systemUpdateStatus.localVersion = systemVersionLocal.String() + + logger := s.l.With().Logger() + if params.RequestID != "" { + logger = logger.With().Str("requestID", params.RequestID).Logger() + } + t := time.Now() + + logger.Trace(). + Str("appVersionLocal", appVersionLocal.String()). + Str("systemVersionLocal", systemVersionLocal.String()). + Dur("duration", time.Since(t)). + Msg("checkUpdateStatus: getLocalVersion") + + // fetch the remote metadata + remoteMetadata, err := s.fetchUpdateMetadata(ctx, params) + if err != nil { + if err == ErrVersionNotFound || errors.Unwrap(err) == ErrVersionNotFound { + err = ErrVersionNotFound + } else { + err = fmt.Errorf("error checking for updates: %w", err) + } + return err + } + + logger.Trace(). + Interface("remoteMetadata", remoteMetadata). + Dur("duration", time.Since(t)). + Msg("checkUpdateStatus: fetchUpdateMetadata") + + // parse the remote metadata to the componentUpdateStatuses + if err := remoteMetadataToComponentStatus( + remoteMetadata, + "app", + appUpdateStatus, + params, + ); err != nil { + return fmt.Errorf("error parsing remote app version: %w", err) + } + + if err := remoteMetadataToComponentStatus( + remoteMetadata, + "system", + systemUpdateStatus, + params, + ); err != nil { + return fmt.Errorf("error parsing remote system version: %w", err) + } + + if s.l.GetLevel() <= zerolog.TraceLevel { + appUpdateStatus.getZerologLogger(&logger).Trace().Msg("checkUpdateStatus: remoteMetadataToComponentStatus [app]") + systemUpdateStatus.getZerologLogger(&logger).Trace().Msg("checkUpdateStatus: remoteMetadataToComponentStatus [system]") + } + + logger.Trace(). + Dur("duration", time.Since(t)). + Msg("checkUpdateStatus: completed") + + return nil +} + +// GetUpdateStatus returns the current update status (for backwards compatibility) +func (s *State) GetUpdateStatus(ctx context.Context, params UpdateParams) (*UpdateStatus, error) { + // if no components are specified, use the default components + // we should remove this once app router feature is released + if len(params.Components) == 0 { + params.Components = defaultComponents + } + + appUpdateStatus := componentUpdateStatus{} + systemUpdateStatus := componentUpdateStatus{} + err := s.checkUpdateStatus(ctx, params, &appUpdateStatus, &systemUpdateStatus) + if err != nil { + return nil, fmt.Errorf("error getting update status: %w", err) + } + + return toUpdateStatus(&appUpdateStatus, &systemUpdateStatus, ""), nil +} diff --git a/internal/ota/ota_test.go b/internal/ota/ota_test.go new file mode 100644 index 00000000..2c8ce661 --- /dev/null +++ b/internal/ota/ota_test.go @@ -0,0 +1,261 @@ +package ota + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "embed" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "testing" + "time" + + "github.com/Masterminds/semver/v3" + "github.com/gwatts/rootcerts" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" +) + +//go:embed testdata/ota +var testDataFS embed.FS + +const pseudoDeviceID = "golang-test" +const releaseAPIEndpoint = "https://api.jetkvm.com/releases" + +type testData struct { + Name string `json:"name"` + WithoutCerts bool `json:"withoutCerts"` + RemoteMetadata []struct { + Code int `json:"code"` + Params map[string]string `json:"params"` + Data UpdateMetadata `json:"data"` + } `json:"remoteMetadata"` + LocalMetadata struct { + SystemVersion string `json:"systemVersion"` + AppVersion string `json:"appVersion"` + } `json:"localMetadata"` + UpdateParams UpdateParams `json:"updateParams"` + Expected struct { + System bool `json:"system"` + App bool `json:"app"` + Error string `json:"error,omitempty"` + } `json:"expected"` +} + +func (d *testData) ToFixtures(t *testing.T) map[string]mockData { + fixtures := make(map[string]mockData) + for _, resp := range d.RemoteMetadata { + url, err := url.Parse(releaseAPIEndpoint) + if err != nil { + t.Fatalf("failed to parse release API endpoint: %v", err) + } + query := url.Query() + query.Set("deviceId", pseudoDeviceID) + for key, value := range resp.Params { + query.Set(key, value) + } + url.RawQuery = query.Encode() + fixtures[url.String()] = mockData{ + Metadata: &resp.Data, + StatusCode: resp.Code, + } + } + return fixtures +} + +func (d *testData) ToUpdateParams() UpdateParams { + d.UpdateParams.DeviceID = pseudoDeviceID + return d.UpdateParams +} + +func loadTestData(t *testing.T, filename string) *testData { + f, err := testDataFS.ReadFile(filepath.Join("testdata", "ota", filename)) + if err != nil { + t.Fatalf("failed to read test data file %s: %v", filename, err) + } + + var testData testData + if err := json.Unmarshal(f, &testData); err != nil { + t.Fatalf("failed to unmarshal test data file %s: %v", filename, err) + } + + return &testData +} + +type mockData struct { + Metadata *UpdateMetadata + StatusCode int +} + +type mockHTTPClient struct { + DoFunc func(req *http.Request) (*http.Response, error) + Fixtures map[string]mockData +} + +func compareURLs(a *url.URL, b *url.URL) bool { + if a.String() == b.String() { + return true + } + if a.Host != b.Host || a.Scheme != b.Scheme || a.Path != b.Path { + return false + } + + // do a quick check to see if the query parameters are the same + queryA := a.Query() + queryB := b.Query() + if len(queryA) != len(queryB) { + return false + } + for key := range queryA { + if queryA.Get(key) != queryB.Get(key) { + return false + } + } + for key := range queryB { + if queryA.Get(key) != queryB.Get(key) { + return false + } + } + return true +} + +func (m *mockHTTPClient) getFixture(expectedURL *url.URL) *mockData { + for u, fixture := range m.Fixtures { + fixtureURL, err := url.Parse(u) + if err != nil { + continue + } + if compareURLs(fixtureURL, expectedURL) { + return &fixture + } + } + return nil +} + +func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { + fixture := m.getFixture(req.URL) + if fixture == nil { + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(bytes.NewBufferString("")), + }, fmt.Errorf("no fixture found for URL: %s", req.URL.String()) + } + + resp := &http.Response{ + StatusCode: fixture.StatusCode, + } + + jsonData, err := json.Marshal(fixture.Metadata) + if err != nil { + return nil, fmt.Errorf("error marshalling metadata: %w", err) + } + + resp.Body = io.NopCloser(bytes.NewBufferString(string(jsonData))) + return resp, nil +} + +func newMockHTTPClient(fixtures map[string]mockData) *mockHTTPClient { + return &mockHTTPClient{ + Fixtures: fixtures, + } +} + +func newOtaState(d *testData, t *testing.T) *State { + pseudoGetLocalVersion := func() (systemVersion *semver.Version, appVersion *semver.Version, err error) { + appVersion = semver.MustParse(d.LocalMetadata.AppVersion) + systemVersion = semver.MustParse(d.LocalMetadata.SystemVersion) + return systemVersion, appVersion, nil + } + + traceLevel := zerolog.InfoLevel + + if os.Getenv("TEST_LOG_TRACE") == "1" { + traceLevel = zerolog.TraceLevel + } + logger := zerolog.New(os.Stdout).Level(traceLevel) + otaState := NewState(Options{ + SkipConfirmSystem: true, + Logger: &logger, + ReleaseAPIEndpoint: releaseAPIEndpoint, + GetHTTPClient: func() HttpClient { + if d.RemoteMetadata != nil { + return newMockHTTPClient(d.ToFixtures(t)) + } + transport := http.DefaultTransport.(*http.Transport).Clone() + if !d.WithoutCerts { + transport.TLSClientConfig = &tls.Config{RootCAs: rootcerts.ServerCertPool()} + } else { + transport.TLSClientConfig = &tls.Config{RootCAs: x509.NewCertPool()} + } + client := &http.Client{ + Transport: transport, + } + return client + }, + GetLocalVersion: pseudoGetLocalVersion, + HwReboot: func(force bool, postRebootAction *PostRebootAction, delay time.Duration) error { return nil }, + ResetConfig: func() error { return nil }, + OnStateUpdate: func(state *RPCState) {}, + OnProgressUpdate: func(progress float32) {}, + }) + return otaState +} + +func testUsingJson(t *testing.T, filename string) { + td := loadTestData(t, filename) + otaState := newOtaState(td, t) + info, err := otaState.GetUpdateStatus(context.Background(), td.ToUpdateParams()) + if err != nil { + if td.Expected.Error != "" { + assert.ErrorContains(t, err, td.Expected.Error) + } else { + t.Fatalf("failed to get update status: %v", err) + } + } + + if td.Expected.System { + assert.True(t, info.SystemUpdateAvailable, fmt.Sprintf("system update should available, but reason: %s", info.SystemUpdateAvailableReason)) + } else { + assert.False(t, info.SystemUpdateAvailable, fmt.Sprintf("system update should not be available, but reason: %s", info.SystemUpdateAvailableReason)) + } + + if td.Expected.App { + assert.True(t, info.AppUpdateAvailable, fmt.Sprintf("app update should available, but reason: %s", info.AppUpdateAvailableReason)) + } else { + assert.False(t, info.AppUpdateAvailable, fmt.Sprintf("app update should not be available, but reason: %s", info.AppUpdateAvailableReason)) + } +} + +func TestCheckUpdateComponentsSystemOnlyUpgrade(t *testing.T) { + testUsingJson(t, "system_only_upgrade.json") +} + +func TestCheckUpdateComponentsSystemOnlyDowngrade(t *testing.T) { + testUsingJson(t, "system_only_downgrade.json") +} + +func TestCheckUpdateComponentsAppOnlyUpgrade(t *testing.T) { + testUsingJson(t, "app_only_upgrade.json") +} + +func TestCheckUpdateComponentsAppOnlyDowngrade(t *testing.T) { + testUsingJson(t, "app_only_downgrade.json") +} + +func TestCheckUpdateComponentsSystemBothUpgrade(t *testing.T) { + testUsingJson(t, "both_upgrade.json") +} + +func TestCheckUpdateComponentsSystemBothDowngrade(t *testing.T) { + testUsingJson(t, "both_downgrade.json") +} + +func TestCheckUpdateComponentsNoComponents(t *testing.T) { + testUsingJson(t, "no_components.json") +} diff --git a/internal/ota/rpc.go b/internal/ota/rpc.go new file mode 100644 index 00000000..30f132ea --- /dev/null +++ b/internal/ota/rpc.go @@ -0,0 +1,167 @@ +package ota + +import ( + "fmt" + "reflect" + "strings" + "time" + + "github.com/Masterminds/semver/v3" +) + +// to make the field names consistent with the RPCState struct +var componentFieldMap = map[string]string{ + "app": "App", + "system": "System", +} + +// RPCState represents the current OTA state for the RPC API +type RPCState struct { + Updating bool `json:"updating"` + Error string `json:"error,omitempty"` + MetadataFetchedAt *time.Time `json:"metadataFetchedAt,omitempty"` + AppUpdatePending bool `json:"appUpdatePending"` + SystemUpdatePending bool `json:"systemUpdatePending"` + AppDownloadProgress *float32 `json:"appDownloadProgress,omitempty"` //TODO: implement for progress bar + AppDownloadFinishedAt *time.Time `json:"appDownloadFinishedAt,omitempty"` + SystemDownloadProgress *float32 `json:"systemDownloadProgress,omitempty"` //TODO: implement for progress bar + SystemDownloadFinishedAt *time.Time `json:"systemDownloadFinishedAt,omitempty"` + AppVerificationProgress *float32 `json:"appVerificationProgress,omitempty"` + AppVerifiedAt *time.Time `json:"appVerifiedAt,omitempty"` + SystemVerificationProgress *float32 `json:"systemVerificationProgress,omitempty"` + SystemVerifiedAt *time.Time `json:"systemVerifiedAt,omitempty"` + AppUpdateProgress *float32 `json:"appUpdateProgress,omitempty"` //TODO: implement for progress bar + AppUpdatedAt *time.Time `json:"appUpdatedAt,omitempty"` + SystemUpdateProgress *float32 `json:"systemUpdateProgress,omitempty"` //TODO: port rk_ota, then implement + SystemUpdatedAt *time.Time `json:"systemUpdatedAt,omitempty"` +} + +func setTimeIfNotZero(rpcVal reflect.Value, i int, status time.Time) { + if !status.IsZero() { + rpcVal.Field(i).Set(reflect.ValueOf(&status)) + } +} + +func setFloat32IfNotZero(rpcVal reflect.Value, i int, status float32) { + if status != 0 { + rpcVal.Field(i).Set(reflect.ValueOf(&status)) + } +} + +// applyComponentStatusToRPCState uses reflection to map componentUpdateStatus fields to RPCState +func applyComponentStatusToRPCState(component string, status componentUpdateStatus, rpcState *RPCState) { + prefix := componentFieldMap[component] + if prefix == "" { + return + } + + rpcVal := reflect.ValueOf(rpcState).Elem() + + // it's really inefficient, but hey we do not need to use this often + // componentUpdateStatus is for internal use only, and all fields are unexported + for i := 0; i < rpcVal.NumField(); i++ { + rpcFieldName, hasPrefix := strings.CutPrefix(rpcVal.Type().Field(i).Name, prefix) + if !hasPrefix { + continue + } + + switch rpcFieldName { + case "DownloadProgress": + setFloat32IfNotZero(rpcVal, i, status.downloadProgress) + case "DownloadFinishedAt": + setTimeIfNotZero(rpcVal, i, status.downloadFinishedAt) + case "VerificationProgress": + setFloat32IfNotZero(rpcVal, i, status.verificationProgress) + case "VerifiedAt": + setTimeIfNotZero(rpcVal, i, status.verifiedAt) + case "UpdateProgress": + setFloat32IfNotZero(rpcVal, i, status.updateProgress) + case "UpdatedAt": + setTimeIfNotZero(rpcVal, i, status.updatedAt) + case "UpdatePending": + rpcVal.Field(i).SetBool(status.pending) + default: + continue + } + } +} + +// ToRPCState converts the State to the RPCState +func (s *State) ToRPCState() *RPCState { + r := &RPCState{ + Updating: s.updating, + Error: s.error, + MetadataFetchedAt: &s.metadataFetchedAt, + } + + for component, status := range s.componentUpdateStatuses { + applyComponentStatusToRPCState(component, status, r) + } + + return r +} + +func remoteMetadataToComponentStatus( + remoteMetadata *UpdateMetadata, + component string, + componentStatus *componentUpdateStatus, + params UpdateParams, +) error { + prefix := componentFieldMap[component] + if prefix == "" { + return fmt.Errorf("unknown component: %s", component) + } + + remoteMetadataVal := reflect.ValueOf(remoteMetadata).Elem() + for i := 0; i < remoteMetadataVal.NumField(); i++ { + fieldName, hasPrefix := strings.CutPrefix(remoteMetadataVal.Type().Field(i).Name, prefix) + if !hasPrefix { + continue + } + + switch fieldName { + case "URL": + componentStatus.url = remoteMetadataVal.Field(i).String() + case "Hash": + componentStatus.hash = remoteMetadataVal.Field(i).String() + case "Version": + componentStatus.version = remoteMetadataVal.Field(i).String() + default: + // fmt.Printf("unknown field %s", fieldName) + continue + } + } + + localVersion, err := semver.NewVersion(componentStatus.localVersion) + if err != nil { + return fmt.Errorf("error parsing local version: %w", err) + } + + remoteVersion, err := semver.NewVersion(componentStatus.version) + if err != nil { + return fmt.Errorf("error parsing remote version: %w", err) + } + componentStatus.available = remoteVersion.GreaterThan(localVersion) + componentStatus.availableReason = fmt.Sprintf("remote version %s is greater than local version %s", remoteVersion.String(), localVersion.String()) + + // Handle pre-release updates + if remoteVersion.Prerelease() != "" && params.IncludePreRelease && componentStatus.available { + componentStatus.availableReason += " (pre-release)" + } + + // If a custom version is specified, use it to determine if the update is available + constraint, componentExists := params.Components[component] + // we don't need to check again if it's already available + if componentExists && constraint != "" { + componentStatus.available = componentStatus.version != componentStatus.localVersion + if componentStatus.available { + componentStatus.availableReason = fmt.Sprintf("custom version %s is not equal to local version %s", constraint, componentStatus.localVersion) + componentStatus.customVersionUpdate = true + } + } else if !componentExists { + componentStatus.available = false + componentStatus.availableReason = "component not specified in update parameters" + } + + return nil +} diff --git a/internal/ota/state.go b/internal/ota/state.go new file mode 100644 index 00000000..2bb7055e --- /dev/null +++ b/internal/ota/state.go @@ -0,0 +1,215 @@ +package ota + +import ( + "sync" + "time" + + "github.com/Masterminds/semver/v3" + "github.com/rs/zerolog" +) + +var ( + availableComponents = []string{"app", "system"} + defaultComponents = map[string]string{ + "app": "", + "system": "", + } +) + +// UpdateMetadata represents the metadata of an update +type UpdateMetadata struct { + AppVersion string `json:"appVersion"` + AppURL string `json:"appUrl"` + AppHash string `json:"appHash"` + SystemVersion string `json:"systemVersion"` + SystemURL string `json:"systemUrl"` + SystemHash string `json:"systemHash"` +} + +// LocalMetadata represents the local metadata of the system +type LocalMetadata struct { + AppVersion string `json:"appVersion"` + SystemVersion string `json:"systemVersion"` +} + +// UpdateStatus represents the current update status +type UpdateStatus struct { + Local *LocalMetadata `json:"local"` + Remote *UpdateMetadata `json:"remote"` + SystemUpdateAvailable bool `json:"systemUpdateAvailable"` + AppUpdateAvailable bool `json:"appUpdateAvailable"` + WillDisableAutoUpdate bool `json:"willDisableAutoUpdate"` + + // only available for debugging and won't be exported + SystemUpdateAvailableReason string `json:"-"` + AppUpdateAvailableReason string `json:"-"` + + // for backwards compatibility + Error string `json:"error,omitempty"` +} + +// PostRebootAction represents the action to be taken after a reboot +// It is used to redirect the user to a specific page after a reboot +type PostRebootAction struct { + HealthCheck string `json:"healthCheck"` // The health check URL to call after the reboot + RedirectTo string `json:"redirectTo"` // The URL to redirect to after the reboot +} + +// componentUpdateStatus represents the status of a component update +type componentUpdateStatus struct { + pending bool + available bool + availableReason string // why the component is available or not available + customVersionUpdate bool + version string + localVersion string + url string + hash string + downloadProgress float32 + downloadFinishedAt time.Time + verificationProgress float32 + verifiedAt time.Time + updateProgress float32 + updatedAt time.Time + dependsOn []string +} + +func (c *componentUpdateStatus) getZerologLogger(l *zerolog.Logger) *zerolog.Logger { + logger := l.With(). + Bool("pending", c.pending). + Bool("available", c.available). + Str("availableReason", c.availableReason). + Str("version", c.version). + Str("localVersion", c.localVersion). + Str("url", c.url). + Str("hash", c.hash). + Float32("downloadProgress", c.downloadProgress). + Time("downloadFinishedAt", c.downloadFinishedAt). + Float32("verificationProgress", c.verificationProgress). + Time("verifiedAt", c.verifiedAt). + Float32("updateProgress", c.updateProgress). + Time("updatedAt", c.updatedAt). + Strs("dependsOn", c.dependsOn). + Logger() + return &logger +} + +// HwRebootFunc is a function that reboots the hardware +type HwRebootFunc func(force bool, postRebootAction *PostRebootAction, delay time.Duration) error + +// ResetConfigFunc is a function that resets the config +type ResetConfigFunc func() error + +// SetAutoUpdateFunc is a function that sets the auto-update state +type SetAutoUpdateFunc func(enabled bool) (bool, error) + +// GetHTTPClientFunc is a function that returns the HTTP client +type GetHTTPClientFunc func() HttpClient + +// OnStateUpdateFunc is a function that updates the state of the OTA +type OnStateUpdateFunc func(state *RPCState) + +// OnProgressUpdateFunc is a function that updates the progress of the OTA +type OnProgressUpdateFunc func(progress float32) + +// GetLocalVersionFunc is a function that returns the local version of the system and app +type GetLocalVersionFunc func() (systemVersion *semver.Version, appVersion *semver.Version, err error) + +// State represents the current OTA state for the UI +type State struct { + releaseAPIEndpoint string + l *zerolog.Logger + mu sync.Mutex + updating bool + error string + metadataFetchedAt time.Time + rebootNeeded bool + componentUpdateStatuses map[string]componentUpdateStatus + client GetHTTPClientFunc + reboot HwRebootFunc + getLocalVersion GetLocalVersionFunc + onStateUpdate OnStateUpdateFunc + resetConfig ResetConfigFunc + setAutoUpdate SetAutoUpdateFunc +} + +func toUpdateStatus(appUpdate *componentUpdateStatus, systemUpdate *componentUpdateStatus, error string) *UpdateStatus { + return &UpdateStatus{ + Local: &LocalMetadata{ + AppVersion: appUpdate.localVersion, + SystemVersion: systemUpdate.localVersion, + }, + Remote: &UpdateMetadata{ + AppVersion: appUpdate.version, + AppURL: appUpdate.url, + AppHash: appUpdate.hash, + SystemVersion: systemUpdate.version, + SystemURL: systemUpdate.url, + SystemHash: systemUpdate.hash, + }, + SystemUpdateAvailable: systemUpdate.available, + SystemUpdateAvailableReason: systemUpdate.availableReason, + AppUpdateAvailable: appUpdate.available, + AppUpdateAvailableReason: appUpdate.availableReason, + WillDisableAutoUpdate: appUpdate.customVersionUpdate || systemUpdate.customVersionUpdate, + Error: error, + } +} + +// ToUpdateStatus converts the State to the UpdateStatus +func (s *State) ToUpdateStatus() *UpdateStatus { + appUpdate, ok := s.componentUpdateStatuses["app"] + if !ok { + return nil + } + + systemUpdate, ok := s.componentUpdateStatuses["system"] + if !ok { + return nil + } + + return toUpdateStatus(&appUpdate, &systemUpdate, s.error) +} + +// IsUpdatePending returns true if an update is pending +func (s *State) IsUpdatePending() bool { + return s.updating +} + +// Options represents the options for the OTA state +type Options struct { + Logger *zerolog.Logger + GetHTTPClient GetHTTPClientFunc + GetLocalVersion GetLocalVersionFunc + OnStateUpdate OnStateUpdateFunc + OnProgressUpdate OnProgressUpdateFunc + HwReboot HwRebootFunc + ReleaseAPIEndpoint string + ResetConfig ResetConfigFunc + SkipConfirmSystem bool + SetAutoUpdate SetAutoUpdateFunc +} + +// NewState creates a new OTA state +func NewState(opts Options) *State { + components := make(map[string]componentUpdateStatus) + for _, component := range availableComponents { + components[component] = componentUpdateStatus{} + } + + s := &State{ + l: opts.Logger, + client: opts.GetHTTPClient, + reboot: opts.HwReboot, + onStateUpdate: opts.OnStateUpdate, + getLocalVersion: opts.GetLocalVersion, + componentUpdateStatuses: components, + releaseAPIEndpoint: opts.ReleaseAPIEndpoint, + resetConfig: opts.ResetConfig, + setAutoUpdate: opts.SetAutoUpdate, + } + if !opts.SkipConfirmSystem { + go s.confirmCurrentSystem() + } + return s +} diff --git a/internal/ota/sys.go b/internal/ota/sys.go new file mode 100644 index 00000000..6a5002f6 --- /dev/null +++ b/internal/ota/sys.go @@ -0,0 +1,101 @@ +package ota + +import ( + "bytes" + "context" + "os/exec" + "time" +) + +const ( + systemUpdatePath = "/userdata/jetkvm/update_system.tar" +) + +// DO NOT call it directly, it's not thread safe +// Mutex is currently held by the caller, e.g. doUpdate +func (s *State) updateSystem(ctx context.Context, systemUpdate *componentUpdateStatus) error { + l := s.l.With().Str("path", systemUpdatePath).Logger() + + if err := s.downloadFile(ctx, systemUpdatePath, systemUpdate.url, "system"); err != nil { + return s.componentUpdateError("Error downloading system update", err, &l) + } + + downloadFinished := time.Now() + systemUpdate.downloadFinishedAt = downloadFinished + systemUpdate.downloadProgress = 1 + s.triggerComponentUpdateState("system", systemUpdate) + + if err := s.verifyFile( + systemUpdatePath, + systemUpdate.hash, + &systemUpdate.verificationProgress, + ); err != nil { + return s.componentUpdateError("Error verifying system update hash", err, &l) + } + verifyFinished := time.Now() + systemUpdate.verifiedAt = verifyFinished + systemUpdate.verificationProgress = 1 + systemUpdate.updatedAt = verifyFinished + systemUpdate.updateProgress = 1 + s.triggerComponentUpdateState("system", systemUpdate) + + l.Info().Msg("System update downloaded") + + l.Info().Msg("Starting rk_ota command") + + cmd := exec.Command("rk_ota", "--misc=update", "--tar_path=/userdata/jetkvm/update_system.tar", "--save_dir=/userdata/jetkvm/ota_save", "--partition=all") + var b bytes.Buffer + cmd.Stdout = &b + cmd.Stderr = &b + if err := cmd.Start(); err != nil { + return s.componentUpdateError("Error starting rk_ota command", err, &l) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + ticker := time.NewTicker(1800 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if systemUpdate.updateProgress >= 0.99 { + return + } + systemUpdate.updateProgress += 0.01 + if systemUpdate.updateProgress > 0.99 { + systemUpdate.updateProgress = 0.99 + } + s.triggerComponentUpdateState("system", systemUpdate) + case <-ctx.Done(): + return + } + } + }() + + err := cmd.Wait() + cancel() + rkLogger := s.l.With(). + Str("output", b.String()). + Int("exitCode", cmd.ProcessState.ExitCode()).Logger() + if err != nil { + return s.componentUpdateError("Error executing rk_ota command", err, &rkLogger) + } + rkLogger.Info().Msg("rk_ota success") + + s.rebootNeeded = true + systemUpdate.updateProgress = 1 + systemUpdate.updatedAt = verifyFinished + s.triggerComponentUpdateState("system", systemUpdate) + + return nil +} + +func (s *State) confirmCurrentSystem() { + output, err := exec.Command("rk_ota", "--misc=now").CombinedOutput() + if err != nil { + s.l.Warn().Str("output", string(output)).Msg("failed to set current partition in A/B setup") + } + s.l.Trace().Str("output", string(output)).Msg("current partition in A/B setup set") +} diff --git a/internal/ota/testdata/ota.schema.json b/internal/ota/testdata/ota.schema.json new file mode 100644 index 00000000..15965850 --- /dev/null +++ b/internal/ota/testdata/ota.schema.json @@ -0,0 +1,159 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "OTA Test Data Schema", + "description": "Schema for OTA update test data", + "type": "object", + "required": ["name", "remoteMetadata", "localMetadata", "updateParams"], + "properties": { + "name": { + "type": "string", + "description": "Name of the test case" + }, + "withoutCerts": { + "type": "boolean", + "default": false, + "description": "Whether to run the test without Root CA certificates" + }, + "remoteMetadata": { + "type": "array", + "description": "Remote metadata responses", + "items": { + "type": "object", + "required": ["params", "code", "data"], + "properties": { + "params": { + "type": "object", + "description": "Query parameters used for the request", + "required": ["prerelease"], + "properties": { + "prerelease": { + "type": "string", + "description": "Whether to include pre-release versions" + }, + "appVersion": { + "type": "string", + "description": "Application version string", + "pattern": "^[0-9]+\\.[0-9]+\\.[0-9]+$" + }, + "systemVersion": { + "type": "string", + "description": "System version string", + "pattern": "^[0-9]+\\.[0-9]+\\.[0-9]+$" + } + }, + "additionalProperties": false + }, + "code": { + "type": "integer", + "description": "HTTP status code" + }, + "data": { + "type": "object", + "required": ["appVersion", "appUrl", "appHash", "systemVersion", "systemUrl", "systemHash"], + "properties": { + "appVersion": { + "type": "string", + "description": "Application version string", + "pattern": "^[0-9]+\\.[0-9]+\\.[0-9]+$" + }, + "appUrl": { + "type": "string", + "description": "URL to download the application", + "format": "uri" + }, + "appHash": { + "type": "string", + "description": "SHA-256 hash of the application", + "pattern": "^[a-f0-9]{64}$" + }, + "systemVersion": { + "type": "string", + "description": "System version string", + "pattern": "^[0-9]+\\.[0-9]+\\.[0-9]+$" + }, + "systemUrl": { + "type": "string", + "description": "URL to download the system", + "format": "uri" + }, + "systemHash": { + "type": "string", + "description": "SHA-256 hash of the system", + "pattern": "^[a-f0-9]{64}$" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + } + }, + "localMetadata": { + "type": "object", + "description": "Local metadata containing current installed versions", + "required": ["systemVersion", "appVersion"], + "properties": { + "systemVersion": { + "type": "string", + "description": "Currently installed system version", + "pattern": "^[0-9]+\\.[0-9]+\\.[0-9]+$" + }, + "appVersion": { + "type": "string", + "description": "Currently installed application version", + "pattern": "^[0-9]+\\.[0-9]+\\.[0-9]+$" + } + }, + "additionalProperties": false + }, + "updateParams": { + "type": "object", + "description": "Parameters for the update operation", + "required": ["includePreRelease"], + "properties": { + "includePreRelease": { + "type": "boolean", + "description": "Whether to include pre-release versions" + }, + "components": { + "type": "object", + "description": "Component update configuration", + "properties": { + "system": { + "type": "string", + "description": "System component update configuration (empty string to update)" + }, + "app": { + "type": "string", + "description": "App component update configuration (version string to update to)" + } + }, + "additionalProperties": true + } + }, + "additionalProperties": false + }, + "expected": { + "type": "object", + "description": "Expected update results", + "required": [], + "properties": { + "system": { + "type": "boolean", + "description": "Whether system update is expected" + }, + "app": { + "type": "boolean", + "description": "Whether app update is expected" + }, + "error": { + "type": "string", + "description": "Error message if the test case is expected to fail" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false +} + diff --git a/internal/ota/testdata/ota/app_only_downgrade.json b/internal/ota/testdata/ota/app_only_downgrade.json new file mode 100644 index 00000000..e8e2f7d1 --- /dev/null +++ b/internal/ota/testdata/ota/app_only_downgrade.json @@ -0,0 +1,34 @@ +{ + "name": "Downgrade App Only", + "remoteMetadata": [ + { + "params": { + "prerelease": "false", + "appVersion": "0.4.6" + }, + "code": 200, + "data": { + "appVersion": "0.4.6", + "appUrl": "https://update.jetkvm.com/app/0.4.6/jetkvm_app", + "appHash": "714f33432f17035e38d238bf376e98f3073e6cc2845d269ff617503d12d92bdd", + "systemVersion": "0.2.5", + "systemUrl": "https://update.jetkvm.com/system/0.2.5/system.tar", + "systemHash": "2323463ea8652be767d94514e548f90dd61b1ebcc0fb1834d700fac5b3d88a35" + } + } + ], + "localMetadata": { + "systemVersion": "0.2.2", + "appVersion": "0.4.5" + }, + "updateParams": { + "includePreRelease": false, + "components": { + "app": "0.4.6" + } + }, + "expected": { + "system": false, + "app": true + } +} \ No newline at end of file diff --git a/internal/ota/testdata/ota/app_only_upgrade.json b/internal/ota/testdata/ota/app_only_upgrade.json new file mode 100644 index 00000000..69aa7fb7 --- /dev/null +++ b/internal/ota/testdata/ota/app_only_upgrade.json @@ -0,0 +1,33 @@ +{ + "name": "Upgrade App Only", + "remoteMetadata": [ + { + "params": { + "prerelease": "false" + }, + "code": 200, + "data": { + "appVersion": "0.4.7", + "appUrl": "https://update.jetkvm.com/app/0.4.7/jetkvm_app", + "appHash": "714f33432f17035e38d238bf376e98f3073e6cc2845d269ff617503d12d92bdd", + "systemVersion": "0.2.5", + "systemUrl": "https://update.jetkvm.com/system/0.2.5/system.tar", + "systemHash": "2323463ea8652be767d94514e548f90dd61b1ebcc0fb1834d700fac5b3d88a35" + } + } + ], + "localMetadata": { + "systemVersion": "0.2.2", + "appVersion": "0.4.5" + }, + "updateParams": { + "includePreRelease": false, + "components": { + "app": "" + } + }, + "expected": { + "system": false, + "app": true + } +} \ No newline at end of file diff --git a/internal/ota/testdata/ota/both_downgrade.json b/internal/ota/testdata/ota/both_downgrade.json new file mode 100644 index 00000000..3c57461c --- /dev/null +++ b/internal/ota/testdata/ota/both_downgrade.json @@ -0,0 +1,37 @@ +{ + "name": "Downgrade System & App", + "remoteMetadata": [ + { + "params": { + "prerelease": "false", + "systemVersion": "0.2.2", + "appVersion": "0.4.6" + }, + "code": 200, + "data": { + "appVersion": "0.4.6", + "appUrl": "https://update.jetkvm.com/app/0.4.6/jetkvm_app", + "appHash": "714f33432f17035e38d238bf376e98f3073e6cc2845d269ff617503d12d92bdd", + "systemVersion": "0.2.2", + "systemUrl": "https://update.jetkvm.com/system/0.2.2/system.tar", + "systemHash": "2323463ea8652be767d94514e548f90dd61b1ebcc0fb1834d700fac5b3d88a35" + } + } + ], + "localMetadata": { + "systemVersion": "0.2.5", + "appVersion": "0.4.5" + }, + "updateParams": { + "includePreRelease": false, + "components": { + "system": "0.2.2", + "app": "0.4.6" + } + }, + "expected": { + "system": true, + "app": true + } +} + diff --git a/internal/ota/testdata/ota/both_upgrade.json b/internal/ota/testdata/ota/both_upgrade.json new file mode 100644 index 00000000..c3d3daee --- /dev/null +++ b/internal/ota/testdata/ota/both_upgrade.json @@ -0,0 +1,34 @@ +{ + "name": "Upgrade System & App (components given)", + "remoteMetadata": [ + { + "params": { + "prerelease": "false" + }, + "code": 200, + "data": { + "appVersion": "0.4.7", + "appUrl": "https://update.jetkvm.com/app/0.4.7/jetkvm_app", + "appHash": "714f33432f17035e38d238bf376e98f3073e6cc2845d269ff617503d12d92bdd", + "systemVersion": "0.2.5", + "systemUrl": "https://update.jetkvm.com/system/0.2.5/system.tar", + "systemHash": "2323463ea8652be767d94514e548f90dd61b1ebcc0fb1834d700fac5b3d88a35" + } + } + ], + "localMetadata": { + "systemVersion": "0.2.2", + "appVersion": "0.4.5" + }, + "updateParams": { + "includePreRelease": false, + "components": { + "system": "", + "app": "" + } + }, + "expected": { + "system": true, + "app": true + } +} \ No newline at end of file diff --git a/internal/ota/testdata/ota/no_components.json b/internal/ota/testdata/ota/no_components.json new file mode 100644 index 00000000..9fb8b253 --- /dev/null +++ b/internal/ota/testdata/ota/no_components.json @@ -0,0 +1,32 @@ +{ + "name": "Upgrade System & App (no components given)", + "remoteMetadata": [ + { + "params": { + "prerelease": "false" + }, + "code": 200, + "data": { + "appVersion": "0.4.7", + "appUrl": "https://update.jetkvm.com/app/0.4.7/jetkvm_app", + "appHash": "714f33432f17035e38d238bf376e98f3073e6cc2845d269ff617503d12d92bdd", + "systemVersion": "0.2.5", + "systemUrl": "https://update.jetkvm.com/system/0.2.5/system.tar", + "systemHash": "2323463ea8652be767d94514e548f90dd61b1ebcc0fb1834d700fac5b3d88a35" + } + } + ], + "localMetadata": { + "systemVersion": "0.2.2", + "appVersion": "0.4.2" + }, + "updateParams": { + "includePreRelease": false, + "components": {} + }, + "expected": { + "system": true, + "app": true + } +} + diff --git a/internal/ota/testdata/ota/system_only_downgrade.json b/internal/ota/testdata/ota/system_only_downgrade.json new file mode 100644 index 00000000..007f5279 --- /dev/null +++ b/internal/ota/testdata/ota/system_only_downgrade.json @@ -0,0 +1,34 @@ +{ + "name": "Downgrade System Only", + "remoteMetadata": [ + { + "params": { + "prerelease": "false", + "systemVersion": "0.2.2" + }, + "code": 200, + "data": { + "appVersion": "0.4.7", + "appUrl": "https://update.jetkvm.com/app/0.4.7/jetkvm_app", + "appHash": "714f33432f17035e38d238bf376e98f3073e6cc2845d269ff617503d12d92bdd", + "systemVersion": "0.2.2", + "systemUrl": "https://update.jetkvm.com/system/0.2.2/system.tar", + "systemHash": "2323463ea8652be767d94514e548f90dd61b1ebcc0fb1834d700fac5b3d88a35" + } + } + ], + "localMetadata": { + "systemVersion": "0.2.5", + "appVersion": "0.4.5" + }, + "updateParams": { + "includePreRelease": false, + "components": { + "system": "0.2.2" + } + }, + "expected": { + "system": true, + "app": false + } +} \ No newline at end of file diff --git a/internal/ota/testdata/ota/system_only_upgrade.json b/internal/ota/testdata/ota/system_only_upgrade.json new file mode 100644 index 00000000..b32c9434 --- /dev/null +++ b/internal/ota/testdata/ota/system_only_upgrade.json @@ -0,0 +1,33 @@ +{ + "name": "Upgrade System Only", + "remoteMetadata": [ + { + "params": { + "prerelease": "false" + }, + "code": 200, + "data": { + "appVersion": "0.4.7", + "appUrl": "https://update.jetkvm.com/app/0.4.7/jetkvm_app", + "appHash": "714f33432f17035e38d238bf376e98f3073e6cc2845d269ff617503d12d92bdd", + "systemVersion": "0.2.6", + "systemUrl": "https://update.jetkvm.com/system/0.2.6/system.tar", + "systemHash": "2323463ea8652be767d94514e548f90dd61b1ebcc0fb1834d700fac5b3d88a35" + } + } + ], + "localMetadata": { + "systemVersion": "0.2.5", + "appVersion": "0.4.5" + }, + "updateParams": { + "includePreRelease": false, + "components": { + "system": "" + } + }, + "expected": { + "system": true, + "app": false + } +} \ No newline at end of file diff --git a/internal/ota/testdata/ota/without_certs.json b/internal/ota/testdata/ota/without_certs.json new file mode 100644 index 00000000..d5150896 --- /dev/null +++ b/internal/ota/testdata/ota/without_certs.json @@ -0,0 +1,17 @@ +{ + "name": "Without Certs", + "localMetadata": { + "systemVersion": "0.2.5", + "appVersion": "0.4.7" + }, + "updateParams": { + "includePreRelease": false, + "components": {} + }, + "expected": { + "system": false, + "app": false, + "error": "certificate signed by unknown authority" + } +} + diff --git a/internal/ota/utils.go b/internal/ota/utils.go new file mode 100644 index 00000000..b03db342 --- /dev/null +++ b/internal/ota/utils.go @@ -0,0 +1,193 @@ +package ota + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "time" + + "github.com/rs/zerolog" +) + +func syncFilesystem() error { + // Flush filesystem buffers to ensure all data is written to disk + if err := exec.Command("sync").Run(); err != nil { + return fmt.Errorf("error flushing filesystem buffers: %w", err) + } + + // Clear the filesystem caches to force a read from disk + if err := os.WriteFile("/proc/sys/vm/drop_caches", []byte("1"), 0644); err != nil { + return fmt.Errorf("error clearing filesystem caches: %w", err) + } + + return nil +} + +func (s *State) downloadFile(ctx context.Context, path string, url string, component string) error { + logger := s.l.With(). + Str("path", path). + Str("url", url). + Str("downloadComponent", component). + Logger() + t := time.Now() + traceLogger := func() *zerolog.Event { + return logger.Trace().Dur("duration", time.Since(t)) + } + traceLogger().Msg("downloading file") + + componentUpdate, ok := s.componentUpdateStatuses[component] + if !ok { + return fmt.Errorf("component %s not found", component) + } + + downloadProgress := componentUpdate.downloadProgress + + if _, err := os.Stat(path); err == nil { + traceLogger().Msg("removing existing file") + if err := os.Remove(path); err != nil { + return fmt.Errorf("error removing existing file: %w", err) + } + } + + unverifiedPath := path + ".unverified" + if _, err := os.Stat(unverifiedPath); err == nil { + traceLogger().Msg("removing existing unverified file") + if err := os.Remove(unverifiedPath); err != nil { + return fmt.Errorf("error removing existing unverified file: %w", err) + } + } + + traceLogger().Msg("creating unverified file") + file, err := os.Create(unverifiedPath) + if err != nil { + return fmt.Errorf("error creating file: %w", err) + } + defer file.Close() + + traceLogger().Msg("creating request") + req, err := s.newHTTPRequestWithTrace(ctx, "GET", url, nil, traceLogger) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + + client := s.client() + traceLogger().Msg("starting download") + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("error downloading file: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + totalSize := resp.ContentLength + if totalSize <= 0 { + return fmt.Errorf("invalid content length") + } + + var written int64 + buf := make([]byte, 32*1024) + for { + nr, er := resp.Body.Read(buf) + if nr > 0 { + nw, ew := file.Write(buf[0:nr]) + if nw < nr { + return fmt.Errorf("short write: %d < %d", nw, nr) + } + written += int64(nw) + if ew != nil { + return fmt.Errorf("error writing to file: %w", ew) + } + progress := float32(written) / float32(totalSize) + if progress-downloadProgress >= 0.01 { + componentUpdate.downloadProgress = progress + s.triggerComponentUpdateState(component, &componentUpdate) + } + } + if er != nil { + if er == io.EOF { + break + } + return fmt.Errorf("error reading response body: %w", er) + } + } + + traceLogger().Msg("download finished") + file.Close() + + traceLogger().Msg("syncing filesystem") + if err := syncFilesystem(); err != nil { + return fmt.Errorf("error syncing filesystem: %w", err) + } + + return nil +} +func (s *State) verifyFile(path string, expectedHash string, verifyProgress *float32) error { + l := s.l.With().Str("path", path).Logger() + + unverifiedPath := path + ".unverified" + fileToHash, err := os.Open(unverifiedPath) + if err != nil { + return fmt.Errorf("error opening file for hashing: %w", err) + } + defer fileToHash.Close() + + hash := sha256.New() + fileInfo, err := fileToHash.Stat() + if err != nil { + return fmt.Errorf("error getting file info: %w", err) + } + totalSize := fileInfo.Size() + + buf := make([]byte, 32*1024) + verified := int64(0) + + for { + nr, er := fileToHash.Read(buf) + if nr > 0 { + nw, ew := hash.Write(buf[0:nr]) + if nw < nr { + return fmt.Errorf("short write: %d < %d", nw, nr) + } + verified += int64(nw) + if ew != nil { + return fmt.Errorf("error writing to hash: %w", ew) + } + progress := float32(verified) / float32(totalSize) + if progress-*verifyProgress >= 0.01 { + *verifyProgress = progress + s.triggerStateUpdate() + } + } + if er != nil { + if er == io.EOF { + break + } + return fmt.Errorf("error reading file: %w", er) + } + } + + hashSum := hash.Sum(nil) + l.Info().Str("hash", hex.EncodeToString(hashSum)).Msg("SHA256 hash of") + + if hex.EncodeToString(hashSum) != expectedHash { + return fmt.Errorf("hash mismatch: %x != %s", hashSum, expectedHash) + } + + if err := os.Rename(unverifiedPath, path); err != nil { + return fmt.Errorf("error renaming file: %w", err) + } + + if err := os.Chmod(path, 0755); err != nil { + return fmt.Errorf("error making file executable: %w", err) + } + + return nil +} diff --git a/jsonrpc.go b/jsonrpc.go index 9459329f..37dd3c5c 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -123,6 +123,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { Interface("id", request.ID).Logger() scopedLogger.Trace().Msg("Received RPC request") + t := time.Now() handler, ok := rpcHandlers[request.Method] if !ok { @@ -154,7 +155,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - scopedLogger.Trace().Interface("result", result).Msg("RPC handler returned") + scopedLogger.Trace().Dur("duration", time.Since(t)).Interface("result", result).Msg("RPC handler returned") response := JSONRPCResponse{ JSONRPC: "2.0", @@ -258,55 +259,6 @@ func rpcGetVideoLogStatus() (string, error) { return nativeInstance.VideoLogStatus() } -func rpcGetDevChannelState() (bool, error) { - return config.IncludePreRelease, nil -} - -func rpcSetDevChannelState(enabled bool) error { - config.IncludePreRelease = enabled - if err := SaveConfig(); err != nil { - return fmt.Errorf("failed to save config: %w", err) - } - return nil -} - -func rpcGetUpdateStatus() (*UpdateStatus, error) { - includePreRelease := config.IncludePreRelease - updateStatus, err := GetUpdateStatus(context.Background(), GetDeviceID(), includePreRelease) - // to ensure backwards compatibility, - // if there's an error, we won't return an error, but we will set the error field - if err != nil { - if updateStatus == nil { - return nil, fmt.Errorf("error checking for updates: %w", err) - } - updateStatus.Error = err.Error() - } - - return updateStatus, nil -} - -func rpcGetLocalVersion() (*LocalMetadata, error) { - systemVersion, appVersion, err := GetLocalVersion() - if err != nil { - return nil, fmt.Errorf("error getting local version: %w", err) - } - return &LocalMetadata{ - AppVersion: appVersion.String(), - SystemVersion: systemVersion.String(), - }, nil -} - -func rpcTryUpdate() error { - includePreRelease := config.IncludePreRelease - go func() { - err := TryUpdate(context.Background(), GetDeviceID(), includePreRelease) - if err != nil { - logger.Warn().Err(err).Msg("failed to try update") - } - }() - return nil -} - func rpcSetDisplayRotation(params DisplayRotationSettings) error { currentRotation := config.DisplayRotation if currentRotation == params.Rotation { @@ -676,7 +628,7 @@ func rpcGetMassStorageMode() (string, error) { } func rpcIsUpdatePending() (bool, error) { - return IsUpdatePending(), nil + return otaState.IsUpdatePending(), nil } func rpcGetUsbEmulationState() (bool, error) { @@ -1222,7 +1174,10 @@ var rpcHandlers = map[string]RPCHandler{ "setDevChannelState": {Func: rpcSetDevChannelState, Params: []string{"enabled"}}, "getLocalVersion": {Func: rpcGetLocalVersion}, "getUpdateStatus": {Func: rpcGetUpdateStatus}, + "checkUpdateComponents": {Func: rpcCheckUpdateComponents, Params: []string{"params", "includePreRelease"}}, + "getUpdateStatusChannel": {Func: rpcGetUpdateStatusChannel}, "tryUpdate": {Func: rpcTryUpdate}, + "tryUpdateComponents": {Func: rpcTryUpdateComponents, Params: []string{"params", "includePreRelease", "resetConfig"}}, "getDevModeState": {Func: rpcGetDevModeState}, "setDevModeState": {Func: rpcSetDevModeState, Params: []string{"enabled"}}, "getSSHKeyState": {Func: rpcGetSSHKeyState}, diff --git a/main.go b/main.go index 669c6f44..46425d07 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "time" "github.com/gwatts/rootcerts" + "github.com/jetkvm/kvm/internal/ota" ) var appCtx context.Context @@ -36,9 +37,9 @@ func Main() { Msg("starting JetKVM") go runWatchdog() - go confirmCurrentSystem() initNative(systemVersionLocal, appVersionLocal) + initDisplay() http.DefaultClient.Timeout = 1 * time.Minute @@ -50,6 +51,13 @@ func Main() { Int("ca_certs_loaded", len(rootcerts.Certs())). Msg("loaded Root CA certificates") + initOta() + + initNative(systemVersionLocal, appVersionLocal) + initDisplay() + + http.DefaultClient.Timeout = 1 * time.Minute + // Initialize network if err := initNetwork(); err != nil { logger.Error().Err(err).Msg("failed to initialize network") @@ -106,7 +114,10 @@ func Main() { } includePreRelease := config.IncludePreRelease - err = TryUpdate(context.Background(), GetDeviceID(), includePreRelease) + err = otaState.TryUpdate(context.Background(), ota.UpdateParams{ + DeviceID: GetDeviceID(), + IncludePreRelease: includePreRelease, + }) if err != nil { logger.Warn().Err(err).Msg("failed to auto update") } diff --git a/network.go b/network.go index b808d6fe..00dd45fa 100644 --- a/network.go +++ b/network.go @@ -2,8 +2,10 @@ package kvm import ( "fmt" + "reflect" "github.com/jetkvm/kvm/internal/network" + "github.com/jetkvm/kvm/internal/ota" "github.com/jetkvm/kvm/internal/udhcpc" ) @@ -82,21 +84,79 @@ func initNetwork() error { } }, }) +} - if state == nil { - if err == nil { - return fmt.Errorf("failed to create NetworkInterfaceState") +func setHostname(nm *nmlite.NetworkManager, hostname, domain string) error { + if nm == nil { + return nil + } + + if hostname == "" { + hostname = GetDefaultHostname() + } + + return nm.SetHostname(hostname, domain) +} + +func shouldRebootForNetworkChange(oldConfig, newConfig *types.NetworkConfig) (rebootRequired bool, postRebootAction *ota.PostRebootAction) { + oldDhcpClient := oldConfig.DHCPClient.String + + l := networkLogger.With(). + Interface("old", oldConfig). + Interface("new", newConfig). + Logger() + + // DHCP client change always requires reboot + if newConfig.DHCPClient.String != oldDhcpClient { + rebootRequired = true + l.Info().Msg("DHCP client changed, reboot required") + return rebootRequired, postRebootAction + } + + oldIPv4Mode := oldConfig.IPv4Mode.String + newIPv4Mode := newConfig.IPv4Mode.String + + // IPv4 mode change requires reboot + if newIPv4Mode != oldIPv4Mode { + rebootRequired = true + l.Info().Msg("IPv4 mode changed with udhcpc, reboot required") + + if newIPv4Mode == "static" && oldIPv4Mode != "static" { + postRebootAction = &ota.PostRebootAction{ + HealthCheck: fmt.Sprintf("//%s/device/status", newConfig.IPv4Static.Address.String), + RedirectTo: fmt.Sprintf("//%s", newConfig.IPv4Static.Address.String), + } + l.Info().Interface("postRebootAction", postRebootAction).Msg("IPv4 mode changed to static, reboot required") } - return err + + return rebootRequired, postRebootAction } - if err := state.Run(); err != nil { - return err + // IPv4 static config changes require reboot + if !reflect.DeepEqual(oldConfig.IPv4Static, newConfig.IPv4Static) { + rebootRequired = true + + // Handle IP change for redirect (only if both are not nil and IP changed) + if newConfig.IPv4Static != nil && oldConfig.IPv4Static != nil && + newConfig.IPv4Static.Address.String != oldConfig.IPv4Static.Address.String { + postRebootAction = &ota.PostRebootAction{ + HealthCheck: fmt.Sprintf("//%s/device/status", newConfig.IPv4Static.Address.String), + RedirectTo: fmt.Sprintf("//%s", newConfig.IPv4Static.Address.String), + } + + l.Info().Interface("postRebootAction", postRebootAction).Msg("IPv4 static config changed, reboot required") + } + + return rebootRequired, postRebootAction } - networkState = state + // IPv6 mode change requires reboot when using udhcpc + if newConfig.IPv6Mode.String != oldConfig.IPv6Mode.String && oldDhcpClient == "udhcpc" { + rebootRequired = true + l.Info().Msg("IPv6 mode changed with udhcpc, reboot required") + } - return nil + return rebootRequired, postRebootAction } func rpcGetNetworkState() network.RpcNetworkState { diff --git a/ota.go b/ota.go index bf0828dc..ef7f9c21 100644 --- a/ota.go +++ b/ota.go @@ -1,59 +1,65 @@ package kvm import ( - "bytes" "context" - "crypto/sha256" - "crypto/tls" - "encoding/hex" - "encoding/json" "fmt" - "io" "net/http" - "net/url" "os" - "os/exec" "strings" - "time" "github.com/Masterminds/semver/v3" - "github.com/gwatts/rootcerts" - "github.com/rs/zerolog" + "github.com/google/uuid" + "github.com/jetkvm/kvm/internal/ota" ) -type UpdateMetadata struct { - AppVersion string `json:"appVersion"` - AppUrl string `json:"appUrl"` - AppHash string `json:"appHash"` - SystemVersion string `json:"systemVersion"` - SystemUrl string `json:"systemUrl"` - SystemHash string `json:"systemHash"` -} - -type LocalMetadata struct { - AppVersion string `json:"appVersion"` - SystemVersion string `json:"systemVersion"` -} - -// UpdateStatus represents the current update status -type UpdateStatus struct { - Local *LocalMetadata `json:"local"` - Remote *UpdateMetadata `json:"remote"` - SystemUpdateAvailable bool `json:"systemUpdateAvailable"` - AppUpdateAvailable bool `json:"appUpdateAvailable"` - - // for backwards compatibility - Error string `json:"error,omitempty"` -} - -const UpdateMetadataUrl = "https://api.jetkvm.com/releases" - var builtAppVersion = "0.1.0+dev" +var otaState *ota.State + +func initOta() { + otaState = ota.NewState(ota.Options{ + Logger: otaLogger, + ReleaseAPIEndpoint: config.GetUpdateAPIURL(), + GetHTTPClient: func() ota.HttpClient { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = config.NetworkConfig.GetTransportProxyFunc() + + client := &http.Client{ + Transport: transport, + } + return client + }, + GetLocalVersion: GetLocalVersion, + HwReboot: hwReboot, + ResetConfig: rpcResetConfig, + SetAutoUpdate: rpcSetAutoUpdateState, + OnStateUpdate: func(state *ota.RPCState) { + triggerOTAStateUpdate(state) + }, + OnProgressUpdate: func(progress float32) { + writeJSONRPCEvent("otaProgress", progress, currentSession) + }, + }) +} + +func triggerOTAStateUpdate(state *ota.RPCState) { + go func() { + if currentSession == nil || (otaState == nil && state == nil) { + return + } + if state == nil { + state = otaState.ToRPCState() + } + writeJSONRPCEvent("otaState", state, currentSession) + }() +} + +// GetBuiltAppVersion returns the built-in app version func GetBuiltAppVersion() string { return builtAppVersion } +// GetLocalVersion returns the local version of the system and app func GetLocalVersion() (systemVersion *semver.Version, appVersion *semver.Version, err error) { appVersion, err = semver.NewVersion(builtAppVersion) if err != nil { @@ -73,491 +79,107 @@ func GetLocalVersion() (systemVersion *semver.Version, appVersion *semver.Versio return systemVersion, appVersion, nil } -func fetchUpdateMetadata(ctx context.Context, deviceId string, includePreRelease bool) (*UpdateMetadata, error) { - metadata := &UpdateMetadata{} +func getUpdateStatus(includePreRelease bool) (*ota.UpdateStatus, error) { + updateStatus, err := otaState.GetUpdateStatus(context.Background(), ota.UpdateParams{ + DeviceID: GetDeviceID(), + IncludePreRelease: includePreRelease, + RequestID: uuid.New().String(), + }) - updateUrl, err := url.Parse(UpdateMetadataUrl) + // to ensure backwards compatibility, + // if there's an error, we won't return an error, but we will set the error field if err != nil { - return nil, fmt.Errorf("error parsing update metadata URL: %w", err) - } - - query := updateUrl.Query() - query.Set("deviceId", deviceId) - query.Set("prerelease", fmt.Sprintf("%v", includePreRelease)) - updateUrl.RawQuery = query.Encode() - - logger.Info().Str("url", updateUrl.String()).Msg("Checking for updates") - - req, err := http.NewRequestWithContext(ctx, "GET", updateUrl.String(), nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %w", err) - } - - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.Proxy = config.NetworkConfig.GetTransportProxyFunc() - - client := &http.Client{ - Transport: transport, - } - - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("error sending request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - err = json.NewDecoder(resp.Body).Decode(metadata) - if err != nil { - return nil, fmt.Errorf("error decoding response: %w", err) - } - - return metadata, nil -} - -func downloadFile(ctx context.Context, path string, url string, downloadProgress *float32) error { - if _, err := os.Stat(path); err == nil { - if err := os.Remove(path); err != nil { - return fmt.Errorf("error removing existing file: %w", err) + if updateStatus == nil { + return nil, fmt.Errorf("error checking for updates: %w", err) } + updateStatus.Error = err.Error() } - unverifiedPath := path + ".unverified" - if _, err := os.Stat(unverifiedPath); err == nil { - if err := os.Remove(unverifiedPath); err != nil { - return fmt.Errorf("error removing existing unverified file: %w", err) - } + // otaState doesn't have the current auto-update state, so we need to get it from the config + if updateStatus.WillDisableAutoUpdate { + updateStatus.WillDisableAutoUpdate = config.AutoUpdateEnabled } - file, err := os.Create(unverifiedPath) - if err != nil { - return fmt.Errorf("error creating file: %w", err) - } - defer file.Close() - - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return fmt.Errorf("error creating request: %w", err) - } - - client := http.Client{ - Timeout: 10 * time.Minute, - Transport: &http.Transport{ - Proxy: config.NetworkConfig.GetTransportProxyFunc(), - TLSHandshakeTimeout: 30 * time.Second, - TLSClientConfig: &tls.Config{ - RootCAs: rootcerts.ServerCertPool(), - }, - }, - } - - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("error downloading file: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - totalSize := resp.ContentLength - if totalSize <= 0 { - return fmt.Errorf("invalid content length") - } - - var written int64 - buf := make([]byte, 32*1024) - for { - nr, er := resp.Body.Read(buf) - if nr > 0 { - nw, ew := file.Write(buf[0:nr]) - if nw < nr { - return fmt.Errorf("short write: %d < %d", nw, nr) - } - written += int64(nw) - if ew != nil { - return fmt.Errorf("error writing to file: %w", ew) - } - progress := float32(written) / float32(totalSize) - if progress-*downloadProgress >= 0.01 { - *downloadProgress = progress - triggerOTAStateUpdate() - } - } - if er != nil { - if er == io.EOF { - break - } - return fmt.Errorf("error reading response body: %w", er) - } - } - - file.Close() - - // Flush filesystem buffers to ensure all data is written to disk - err = exec.Command("sync").Run() - if err != nil { - return fmt.Errorf("error flushing filesystem buffers: %w", err) - } - - // Clear the filesystem caches to force a read from disk - err = os.WriteFile("/proc/sys/vm/drop_caches", []byte("1"), 0644) - if err != nil { - return fmt.Errorf("error clearing filesystem caches: %w", err) - } - - return nil -} - -func verifyFile(path string, expectedHash string, verifyProgress *float32, scopedLogger *zerolog.Logger) error { - if scopedLogger == nil { - scopedLogger = otaLogger - } - - unverifiedPath := path + ".unverified" - fileToHash, err := os.Open(unverifiedPath) - if err != nil { - return fmt.Errorf("error opening file for hashing: %w", err) - } - defer fileToHash.Close() - - hash := sha256.New() - fileInfo, err := fileToHash.Stat() - if err != nil { - return fmt.Errorf("error getting file info: %w", err) - } - totalSize := fileInfo.Size() - - buf := make([]byte, 32*1024) - verified := int64(0) - - for { - nr, er := fileToHash.Read(buf) - if nr > 0 { - nw, ew := hash.Write(buf[0:nr]) - if nw < nr { - return fmt.Errorf("short write: %d < %d", nw, nr) - } - verified += int64(nw) - if ew != nil { - return fmt.Errorf("error writing to hash: %w", ew) - } - progress := float32(verified) / float32(totalSize) - if progress-*verifyProgress >= 0.01 { - *verifyProgress = progress - triggerOTAStateUpdate() - } - } - if er != nil { - if er == io.EOF { - break - } - return fmt.Errorf("error reading file: %w", er) - } - } - - hashSum := hash.Sum(nil) - scopedLogger.Info().Str("path", path).Str("hash", hex.EncodeToString(hashSum)).Msg("SHA256 hash of") - - if hex.EncodeToString(hashSum) != expectedHash { - return fmt.Errorf("hash mismatch: %x != %s", hashSum, expectedHash) - } - - if err := os.Rename(unverifiedPath, path); err != nil { - return fmt.Errorf("error renaming file: %w", err) - } - - if err := os.Chmod(path, 0755); err != nil { - return fmt.Errorf("error making file executable: %w", err) - } - - return nil -} - -type OTAState struct { - Updating bool `json:"updating"` - Error string `json:"error,omitempty"` - MetadataFetchedAt *time.Time `json:"metadataFetchedAt,omitempty"` - AppUpdatePending bool `json:"appUpdatePending"` - SystemUpdatePending bool `json:"systemUpdatePending"` - AppDownloadProgress float32 `json:"appDownloadProgress,omitempty"` //TODO: implement for progress bar - AppDownloadFinishedAt *time.Time `json:"appDownloadFinishedAt,omitempty"` - SystemDownloadProgress float32 `json:"systemDownloadProgress,omitempty"` //TODO: implement for progress bar - SystemDownloadFinishedAt *time.Time `json:"systemDownloadFinishedAt,omitempty"` - AppVerificationProgress float32 `json:"appVerificationProgress,omitempty"` - AppVerifiedAt *time.Time `json:"appVerifiedAt,omitempty"` - SystemVerificationProgress float32 `json:"systemVerificationProgress,omitempty"` - SystemVerifiedAt *time.Time `json:"systemVerifiedAt,omitempty"` - AppUpdateProgress float32 `json:"appUpdateProgress,omitempty"` //TODO: implement for progress bar - AppUpdatedAt *time.Time `json:"appUpdatedAt,omitempty"` - SystemUpdateProgress float32 `json:"systemUpdateProgress,omitempty"` //TODO: port rk_ota, then implement - SystemUpdatedAt *time.Time `json:"systemUpdatedAt,omitempty"` -} - -var otaState = OTAState{} - -func triggerOTAStateUpdate() { - go func() { - if currentSession == nil { - logger.Info().Msg("No active RPC session, skipping update state update") - return - } - writeJSONRPCEvent("otaState", otaState, currentSession) - }() -} - -func TryUpdate(ctx context.Context, deviceId string, includePreRelease bool) error { - scopedLogger := otaLogger.With(). - Str("deviceId", deviceId). - Str("includePreRelease", fmt.Sprintf("%v", includePreRelease)). - Logger() - - scopedLogger.Info().Msg("Trying to update...") - if otaState.Updating { - return fmt.Errorf("update already in progress") - } - - otaState = OTAState{ - Updating: true, - } - triggerOTAStateUpdate() - - defer func() { - otaState.Updating = false - triggerOTAStateUpdate() - }() - - updateStatus, err := GetUpdateStatus(ctx, deviceId, includePreRelease) - if err != nil { - otaState.Error = fmt.Sprintf("Error checking for updates: %v", err) - scopedLogger.Error().Err(err).Msg("Error checking for updates") - return fmt.Errorf("error checking for updates: %w", err) - } - - now := time.Now() - otaState.MetadataFetchedAt = &now - otaState.AppUpdatePending = updateStatus.AppUpdateAvailable - otaState.SystemUpdatePending = updateStatus.SystemUpdateAvailable - triggerOTAStateUpdate() - - local := updateStatus.Local - remote := updateStatus.Remote - appUpdateAvailable := updateStatus.AppUpdateAvailable - systemUpdateAvailable := updateStatus.SystemUpdateAvailable - - rebootNeeded := false - - if appUpdateAvailable { - scopedLogger.Info(). - Str("local", local.AppVersion). - Str("remote", remote.AppVersion). - Msg("App update available") - - err := downloadFile(ctx, "/userdata/jetkvm/jetkvm_app.update", remote.AppUrl, &otaState.AppDownloadProgress) - if err != nil { - otaState.Error = fmt.Sprintf("Error downloading app update: %v", err) - scopedLogger.Error().Err(err).Msg("Error downloading app update") - triggerOTAStateUpdate() - return err - } - downloadFinished := time.Now() - otaState.AppDownloadFinishedAt = &downloadFinished - otaState.AppDownloadProgress = 1 - triggerOTAStateUpdate() - - err = verifyFile( - "/userdata/jetkvm/jetkvm_app.update", - remote.AppHash, - &otaState.AppVerificationProgress, - &scopedLogger, - ) - if err != nil { - otaState.Error = fmt.Sprintf("Error verifying app update hash: %v", err) - scopedLogger.Error().Err(err).Msg("Error verifying app update hash") - triggerOTAStateUpdate() - return err - } - verifyFinished := time.Now() - otaState.AppVerifiedAt = &verifyFinished - otaState.AppVerificationProgress = 1 - otaState.AppUpdatedAt = &verifyFinished - otaState.AppUpdateProgress = 1 - triggerOTAStateUpdate() - - scopedLogger.Info().Msg("App update downloaded") - rebootNeeded = true - } else { - scopedLogger.Info().Msg("App is up to date") - } - - if systemUpdateAvailable { - scopedLogger.Info(). - Str("local", local.SystemVersion). - Str("remote", remote.SystemVersion). - Msg("System update available") - - err := downloadFile(ctx, "/userdata/jetkvm/update_system.tar", remote.SystemUrl, &otaState.SystemDownloadProgress) - if err != nil { - otaState.Error = fmt.Sprintf("Error downloading system update: %v", err) - scopedLogger.Error().Err(err).Msg("Error downloading system update") - triggerOTAStateUpdate() - return err - } - downloadFinished := time.Now() - otaState.SystemDownloadFinishedAt = &downloadFinished - otaState.SystemDownloadProgress = 1 - triggerOTAStateUpdate() - - err = verifyFile( - "/userdata/jetkvm/update_system.tar", - remote.SystemHash, - &otaState.SystemVerificationProgress, - &scopedLogger, - ) - if err != nil { - otaState.Error = fmt.Sprintf("Error verifying system update hash: %v", err) - scopedLogger.Error().Err(err).Msg("Error verifying system update hash") - triggerOTAStateUpdate() - return err - } - scopedLogger.Info().Msg("System update downloaded") - verifyFinished := time.Now() - otaState.SystemVerifiedAt = &verifyFinished - otaState.SystemVerificationProgress = 1 - triggerOTAStateUpdate() - - scopedLogger.Info().Msg("Starting rk_ota command") - cmd := exec.Command("rk_ota", "--misc=update", "--tar_path=/userdata/jetkvm/update_system.tar", "--save_dir=/userdata/jetkvm/ota_save", "--partition=all") - var b bytes.Buffer - cmd.Stdout = &b - cmd.Stderr = &b - err = cmd.Start() - if err != nil { - otaState.Error = fmt.Sprintf("Error starting rk_ota command: %v", err) - scopedLogger.Error().Err(err).Msg("Error starting rk_ota command") - return fmt.Errorf("error starting rk_ota command: %w", err) - } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - ticker := time.NewTicker(1800 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if otaState.SystemUpdateProgress >= 0.99 { - return - } - otaState.SystemUpdateProgress += 0.01 - if otaState.SystemUpdateProgress > 0.99 { - otaState.SystemUpdateProgress = 0.99 - } - triggerOTAStateUpdate() - case <-ctx.Done(): - return - } - } - }() - - err = cmd.Wait() - cancel() - output := b.String() - if err != nil { - otaState.Error = fmt.Sprintf("Error executing rk_ota command: %v\nOutput: %s", err, output) - scopedLogger.Error(). - Err(err). - Str("output", output). - Int("exitCode", cmd.ProcessState.ExitCode()). - Msg("Error executing rk_ota command") - return fmt.Errorf("error executing rk_ota command: %w\nOutput: %s", err, output) - } - scopedLogger.Info().Str("output", output).Msg("rk_ota success") - otaState.SystemUpdateProgress = 1 - otaState.SystemUpdatedAt = &verifyFinished - triggerOTAStateUpdate() - rebootNeeded = true - } else { - scopedLogger.Info().Msg("System is up to date") - } - - if rebootNeeded { - scopedLogger.Info().Msg("System Rebooting in 10s") - time.Sleep(10 * time.Second) - cmd := exec.Command("reboot") - err := cmd.Start() - if err != nil { - otaState.Error = fmt.Sprintf("Failed to start reboot: %v", err) - scopedLogger.Error().Err(err).Msg("Failed to start reboot") - return fmt.Errorf("failed to start reboot: %w", err) - } else { - os.Exit(0) - } - } - - return nil -} - -func GetUpdateStatus(ctx context.Context, deviceId string, includePreRelease bool) (*UpdateStatus, error) { - updateStatus := &UpdateStatus{} - - // Get local versions - systemVersionLocal, appVersionLocal, err := GetLocalVersion() - if err != nil { - return updateStatus, fmt.Errorf("error getting local version: %w", err) - } - updateStatus.Local = &LocalMetadata{ - AppVersion: appVersionLocal.String(), - SystemVersion: systemVersionLocal.String(), - } - - // Get remote metadata - remoteMetadata, err := fetchUpdateMetadata(ctx, deviceId, includePreRelease) - if err != nil { - return updateStatus, fmt.Errorf("error checking for updates: %w", err) - } - updateStatus.Remote = remoteMetadata - - // Get remote versions - systemVersionRemote, err := semver.NewVersion(remoteMetadata.SystemVersion) - if err != nil { - return updateStatus, fmt.Errorf("error parsing remote system version: %w", err) - } - appVersionRemote, err := semver.NewVersion(remoteMetadata.AppVersion) - if err != nil { - return updateStatus, fmt.Errorf("error parsing remote app version: %w, %s", err, remoteMetadata.AppVersion) - } - - updateStatus.SystemUpdateAvailable = systemVersionRemote.GreaterThan(systemVersionLocal) - updateStatus.AppUpdateAvailable = appVersionRemote.GreaterThan(appVersionLocal) - - // Handle pre-release updates - isRemoteSystemPreRelease := systemVersionRemote.Prerelease() != "" - isRemoteAppPreRelease := appVersionRemote.Prerelease() != "" - - if isRemoteSystemPreRelease && !includePreRelease { - updateStatus.SystemUpdateAvailable = false - } - if isRemoteAppPreRelease && !includePreRelease { - updateStatus.AppUpdateAvailable = false - } + otaLogger.Info().Interface("updateStatus", updateStatus).Msg("Update status") return updateStatus, nil } -func IsUpdatePending() bool { - return otaState.Updating +func rpcGetDevChannelState() (bool, error) { + return config.IncludePreRelease, nil } -// make sure our current a/b partition is set as default -func confirmCurrentSystem() { - output, err := exec.Command("rk_ota", "--misc=now").CombinedOutput() - if err != nil { - logger.Warn().Str("output", string(output)).Msg("failed to set current partition in A/B setup") +func rpcSetDevChannelState(enabled bool) error { + config.IncludePreRelease = enabled + if err := SaveConfig(); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + return nil +} + +func rpcGetUpdateStatus() (*ota.UpdateStatus, error) { + return getUpdateStatus(config.IncludePreRelease) +} + +func rpcGetUpdateStatusChannel(channel string) (*ota.UpdateStatus, error) { + switch channel { + case "stable": + return getUpdateStatus(false) + case "dev": + return getUpdateStatus(true) + default: + return nil, fmt.Errorf("invalid channel: %s", channel) } } + +func rpcGetLocalVersion() (*ota.LocalMetadata, error) { + systemVersion, appVersion, err := GetLocalVersion() + if err != nil { + return nil, fmt.Errorf("error getting local version: %w", err) + } + return &ota.LocalMetadata{ + AppVersion: appVersion.String(), + SystemVersion: systemVersion.String(), + }, nil +} + +type updateParams struct { + Components map[string]string `json:"components,omitempty"` +} + +func rpcTryUpdate() error { + return rpcTryUpdateComponents(updateParams{ + Components: make(map[string]string), + }, config.IncludePreRelease, false) +} + +// rpcCheckUpdateComponents checks the update status for the given components +func rpcCheckUpdateComponents(params updateParams, includePreRelease bool) (*ota.UpdateStatus, error) { + updateParams := ota.UpdateParams{ + DeviceID: GetDeviceID(), + IncludePreRelease: includePreRelease, + Components: params.Components, + } + info, err := otaState.GetUpdateStatus(context.Background(), updateParams) + if err != nil { + return nil, fmt.Errorf("failed to check update: %w", err) + } + return info, nil +} + +func rpcTryUpdateComponents(params updateParams, includePreRelease bool, resetConfig bool) error { + updateParams := ota.UpdateParams{ + DeviceID: GetDeviceID(), + IncludePreRelease: includePreRelease, + ResetConfig: resetConfig, + Components: params.Components, + } + + go func() { + err := otaState.TryUpdate(context.Background(), updateParams) + if err != nil { + otaLogger.Warn().Err(err).Msg("failed to try update") + } + }() + return nil +} diff --git a/ui/src/components/NestedSettingsGroup.tsx b/ui/src/components/NestedSettingsGroup.tsx new file mode 100644 index 00000000..3ee57b0f --- /dev/null +++ b/ui/src/components/NestedSettingsGroup.tsx @@ -0,0 +1,22 @@ +import { cx } from "@/cva.config"; + +interface NestedSettingsGroupProps { + readonly children: React.ReactNode; + readonly className?: string; +} + +export function NestedSettingsGroup(props: NestedSettingsGroupProps) { + const { children, className } = props; + + return ( +
+ {children} +
+ ); +} + diff --git a/ui/src/hooks/useVersion.tsx b/ui/src/hooks/useVersion.tsx index 7341dacb..759782e0 100644 --- a/ui/src/hooks/useVersion.tsx +++ b/ui/src/hooks/useVersion.tsx @@ -17,6 +17,19 @@ export interface SystemVersionInfo { error?: string; } +export interface VersionInfo { + appVersion: string; + systemVersion: string; +} + +export interface SystemVersionInfo { + local: VersionInfo; + remote?: VersionInfo; + systemUpdateAvailable: boolean; + appUpdateAvailable: boolean; + error?: string; +} + export function useVersion() { const { appVersion, diff --git a/ui/src/routes/devices.$id.settings.access._index.tsx b/ui/src/routes/devices.$id.settings.access._index.tsx index f30bfef1..a9404470 100644 --- a/ui/src/routes/devices.$id.settings.access._index.tsx +++ b/ui/src/routes/devices.$id.settings.access._index.tsx @@ -278,7 +278,7 @@ export default function SettingsAccessIndexRoute() { onClick={handleCustomTlsUpdate} /> - + )} {selectedProvider === "custom" && ( -
+
-
+ )} )} diff --git a/ui/src/routes/devices.$id.settings.advanced.tsx b/ui/src/routes/devices.$id.settings.advanced.tsx index 722e31bf..59d01797 100644 --- a/ui/src/routes/devices.$id.settings.advanced.tsx +++ b/ui/src/routes/devices.$id.settings.advanced.tsx @@ -1,20 +1,30 @@ import { useCallback, useEffect, useState } from "react"; +import { useSettingsStore } from "@hooks/stores"; +import { JsonRpcError, JsonRpcResponse, useJsonRpc } from "@hooks/useJsonRpc"; +import { useDeviceUiNavigation } from "@hooks/useAppNavigation"; +import { Button } from "@components/Button"; +import Checkbox, { CheckboxWithLabel } from "@components/Checkbox"; +import { ConfirmDialog } from "@components/ConfirmDialog"; import { GridCard } from "@components/Card"; import { SettingsItem } from "@components/SettingsItem"; +import { SettingsPageHeader } from "@components/SettingsPageheader"; +import { NestedSettingsGroup } from "@components/NestedSettingsGroup"; +import { TextAreaWithLabel } from "@components/TextArea"; +import { InputFieldWithLabel } from "@components/InputField"; +import { SelectMenuBasic } from "@components/SelectMenuBasic"; +import { isOnDevice } from "@/main"; +import notifications from "@/notifications"; +import { m } from "@localizations/messages.js"; +import { sleep } from "@/utils"; +import { checkUpdateComponents, UpdateComponents } from "@/utils/jsonrpc"; +import { SystemVersionInfo } from "@hooks/useVersion"; -import { Button } from "../components/Button"; -import Checkbox from "../components/Checkbox"; -import { ConfirmDialog } from "../components/ConfirmDialog"; -import { SettingsPageHeader } from "../components/SettingsPageheader"; -import { TextAreaWithLabel } from "../components/TextArea"; -import { useSettingsStore } from "../hooks/stores"; -import { JsonRpcResponse, useJsonRpc } from "../hooks/useJsonRpc"; -import { isOnDevice } from "../main"; -import notifications from "../notifications"; +import { FeatureFlag } from "../components/FeatureFlag"; export default function SettingsAdvancedRoute() { const { send } = useJsonRpc(); + const { navigateTo } = useDeviceUiNavigation(); const [sshKey, setSSHKey] = useState(""); const { setDeveloperMode } = useSettingsStore(); @@ -22,7 +32,12 @@ export default function SettingsAdvancedRoute() { const [usbEmulationEnabled, setUsbEmulationEnabled] = useState(false); const [showLoopbackWarning, setShowLoopbackWarning] = useState(false); const [localLoopbackOnly, setLocalLoopbackOnly] = useState(false); - + const [updateTarget, setUpdateTarget] = useState("app"); + const [appVersion, setAppVersion] = useState(""); + const [systemVersion, setSystemVersion] = useState(""); + const [resetConfig, setResetConfig] = useState(false); + const [versionChangeAcknowledged, setVersionChangeAcknowledged] = useState(false); + const [customVersionUpdateLoading, setCustomVersionUpdateLoading] = useState(false); const settings = useSettingsStore(); useEffect(() => { @@ -172,6 +187,61 @@ export default function SettingsAdvancedRoute() { setShowLoopbackWarning(false); }, [applyLoopbackOnlyMode, setShowLoopbackWarning]); + const handleVersionUpdateError = useCallback((error?: JsonRpcError | string) => { + notifications.error( + m.advanced_error_version_update({ + error: typeof error === "string" ? error : (error?.data ?? error?.message ?? m.unknown_error()) + }), + { duration: 1000 * 15 } // 15 seconds + ); + setCustomVersionUpdateLoading(false); + }, []); + + const handleCustomVersionUpdate = useCallback(async () => { + const components: UpdateComponents = {}; + if (["app", "both"].includes(updateTarget) && appVersion) components.app = appVersion; + if (["system", "both"].includes(updateTarget) && systemVersion) components.system = systemVersion; + let versionInfo: SystemVersionInfo | undefined; + + try { + // we do not need to set it to false if check succeeds, + // because it will be redirected to the update page later + setCustomVersionUpdateLoading(true); + versionInfo = await checkUpdateComponents({ + components, + }, devChannel); + } catch (error: unknown) { + const jsonRpcError = error as JsonRpcError; + handleVersionUpdateError(jsonRpcError); + return; + } + + let hasUpdate = false; + + const pageParams = new URLSearchParams(); + if (components.app && versionInfo?.remote?.appVersion && versionInfo?.appUpdateAvailable) { + hasUpdate = true; + pageParams.set("custom_app_version", versionInfo.remote?.appVersion); + } + if (components.system && versionInfo?.remote?.systemVersion && versionInfo?.systemUpdateAvailable) { + hasUpdate = true; + pageParams.set("custom_system_version", versionInfo.remote?.systemVersion); + } + pageParams.set("reset_config", resetConfig.toString()); + + if (!hasUpdate) { + handleVersionUpdateError("No update available"); + return; + } + + // Navigate to update page + navigateTo(`/settings/general/update?${pageParams.toString()}`); + }, [ + updateTarget, appVersion, systemVersion, devChannel, + navigateTo, resetConfig, handleVersionUpdateError, + setCustomVersionUpdateLoading + ]); + return (
handleDevModeChange(e.target.checked)} /> - - {settings.developerMode && ( - -
- - - -
-
-

- Developer Mode Enabled -

-
-
    -
  • Security is weakened while active
  • -
  • Only use if you understand the risks
  • -
+ {settings.developerMode ? ( + + +
+ + + +
+
+

+ {m.advanced_developer_mode_enabled_title()} +

+
+
    +
  • {m.advanced_developer_mode_warning_security()}
  • +
  • {m.advanced_developer_mode_warning_risks()}
  • +
+
+
+
+ {m.advanced_developer_mode_warning_advanced()}
+
+
-
- For advanced users only. Not for production use. + {isOnDevice && ( +
+ + setSSHKey(e.target.value)} + placeholder={m.advanced_ssh_public_key_placeholder()} + /> +

+ {m.advanced_ssh_default_user()}root. +

+
+
-
- - )} + )} + + +
+ + + setUpdateTarget(e.target.value)} + /> + + {(updateTarget === "app" || updateTarget === "both") && ( + setAppVersion(e.target.value)} + /> + )} + + {(updateTarget === "system" || updateTarget === "both") && ( + setSystemVersion(e.target.value)} + /> + )} + +

+ {m.advanced_version_update_helper()}{" "} + + {m.advanced_version_update_github_link()} + +

+ +
+ setResetConfig(e.target.checked)} + /> +
+ +
+ setVersionChangeAcknowledged(e.target.checked)} + /> +
+ +
+
+
+ ) : null} - {isOnDevice && settings.developerMode && ( -
- -
- setSSHKey(e.target.value)} - placeholder="Enter your SSH public key" - /> -

- The default SSH user is root. -

-
-
-
-
- )} + {settings.debugMode && ( - <> + - + )}
diff --git a/ui/src/routes/devices.$id.settings.general._index.tsx b/ui/src/routes/devices.$id.settings.general._index.tsx index c71e858b..c50b536a 100644 --- a/ui/src/routes/devices.$id.settings.general._index.tsx +++ b/ui/src/routes/devices.$id.settings.general._index.tsx @@ -16,7 +16,6 @@ export default function SettingsGeneralRoute() { const { send } = useJsonRpc(); const { navigateTo } = useDeviceUiNavigation(); const [autoUpdate, setAutoUpdate] = useState(true); - const currentVersions = useDeviceStore(state => { const { appVersion, systemVersion } = state; if (!appVersion || !systemVersion) return null; @@ -70,7 +69,7 @@ export default function SettingsGeneralRoute() { ) } /> -
+
diff --git a/ui/src/routes/devices.$id.settings.hardware.tsx b/ui/src/routes/devices.$id.settings.hardware.tsx index dd3ba2ed..8f3dcf31 100644 --- a/ui/src/routes/devices.$id.settings.hardware.tsx +++ b/ui/src/routes/devices.$id.settings.hardware.tsx @@ -189,7 +189,7 @@ export default function SettingsHardwareRoute() { }} /> - + )}

The display will wake up when the connection state changes, or when touched. diff --git a/ui/src/routes/devices.$id.settings.video.tsx b/ui/src/routes/devices.$id.settings.video.tsx index 5bf5a8f6..cdbd606c 100644 --- a/ui/src/routes/devices.$id.settings.video.tsx +++ b/ui/src/routes/devices.$id.settings.video.tsx @@ -7,6 +7,7 @@ import { SettingsItem } from "@components/SettingsItem"; import { SettingsPageHeader } from "@components/SettingsPageheader"; import { useSettingsStore } from "@/hooks/stores"; import { SelectMenuBasic } from "@components/SelectMenuBasic"; +import { NestedSettingsGroup } from "@components/NestedSettingsGroup"; import Fieldset from "@components/Fieldset"; import notifications from "@/notifications"; @@ -180,7 +181,7 @@ export default function SettingsVideoRoute() { description="Adjust color settings to make the video output more vibrant and colorful" /> -

+
-
+
({ method: "getUpdateStatus", // This function calls our api server to see if there are any updates available. // It can be called on page load right after a restart, so we need to give it time to // establish a connection to the api server. - maxAttempts: 6, + maxAttempts: UPDATE_STATUS_RPC_MAX_ATTEMPTS, + attemptTimeoutMs: UPDATE_STATUS_RPC_TIMEOUT_MS, }); if (response.error) throw response.error; @@ -242,3 +247,27 @@ export async function getLocalVersion() { if (response.error) throw response.error; return response.result; } + +export type UpdateComponent = "app" | "system"; +export type UpdateComponents = Partial>; + +export interface updateParams { + components?: UpdateComponents; +} + +export async function checkUpdateComponents(params: updateParams, includePreRelease: boolean) { + const response = await callJsonRpc({ + method: "checkUpdateComponents", + params: { + params, + includePreRelease, + }, + // maxAttempts is set to 1, + // because it currently retry for all errors, + // and we don't want to retry if the error is not a network error + maxAttempts: 1, + attemptTimeoutMs: UPDATE_STATUS_RPC_TIMEOUT_MS, + }); + if (response.error) throw response.error; + return response.result; +} \ No newline at end of file diff --git a/webrtc.go b/webrtc.go index abe1aba7..76de2914 100644 --- a/webrtc.go +++ b/webrtc.go @@ -286,10 +286,13 @@ func newSession(config SessionConfig) (*Session, error) { // Enqueue to ensure ordered processing session.rpcQueue <- msg }) - triggerOTAStateUpdate() - triggerVideoStateUpdate() - triggerUSBStateUpdate() - notifyFailsafeMode(session) + // Wait for channel to be open before sending initial state + d.OnOpen(func() { + triggerOTAStateUpdate(otaState.ToRPCState()) + triggerVideoStateUpdate() + triggerUSBStateUpdate() + notifyFailsafeMode(session) + }) case "terminal": handleTerminalChannel(d) case "serial":