mirror of https://github.com/jetkvm/kvm.git
167 lines
3.9 KiB
Go
167 lines
3.9 KiB
Go
package ota
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
)
|
|
|
|
// syncFilesystem runs the external `sync` command to flush pending writes to
// disk, then writes "1" to /proc/sys/vm/drop_caches so the kernel discards its
// page cache and later reads are served from the storage device.
//
// Writing to drop_caches normally requires root privileges.
func syncFilesystem() error {
	flushCmd := exec.Command("sync")
	if runErr := flushCmd.Run(); runErr != nil {
		return fmt.Errorf("error flushing filesystem buffers: %w", runErr)
	}

	const dropCachesPath = "/proc/sys/vm/drop_caches"
	dropPageCache := []byte("1")
	if writeErr := os.WriteFile(dropCachesPath, dropPageCache, 0644); writeErr != nil {
		return fmt.Errorf("error clearing filesystem caches: %w", writeErr)
	}

	return nil
}
|
|
|
|
func (s *State) downloadFile(ctx context.Context, path string, url string, downloadProgress *float32) error {
|
|
if _, err := os.Stat(path); err == nil {
|
|
if err := os.Remove(path); err != nil {
|
|
return fmt.Errorf("error removing existing file: %w", err)
|
|
}
|
|
}
|
|
|
|
unverifiedPath := path + ".unverified"
|
|
if _, err := os.Stat(unverifiedPath); err == nil {
|
|
if err := os.Remove(unverifiedPath); err != nil {
|
|
return fmt.Errorf("error removing existing unverified file: %w", err)
|
|
}
|
|
}
|
|
|
|
file, err := os.Create(unverifiedPath)
|
|
if err != nil {
|
|
return fmt.Errorf("error creating file: %w", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("error creating request: %w", err)
|
|
}
|
|
|
|
client := s.client()
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("error downloading file: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
|
}
|
|
|
|
totalSize := resp.ContentLength
|
|
if totalSize <= 0 {
|
|
return fmt.Errorf("invalid content length")
|
|
}
|
|
|
|
var written int64
|
|
buf := make([]byte, 32*1024)
|
|
for {
|
|
nr, er := resp.Body.Read(buf)
|
|
if nr > 0 {
|
|
nw, ew := file.Write(buf[0:nr])
|
|
if nw < nr {
|
|
return fmt.Errorf("short write: %d < %d", nw, nr)
|
|
}
|
|
written += int64(nw)
|
|
if ew != nil {
|
|
return fmt.Errorf("error writing to file: %w", ew)
|
|
}
|
|
progress := float32(written) / float32(totalSize)
|
|
if progress-*downloadProgress >= 0.01 {
|
|
*downloadProgress = progress
|
|
s.triggerStateUpdate()
|
|
}
|
|
}
|
|
if er != nil {
|
|
if er == io.EOF {
|
|
break
|
|
}
|
|
return fmt.Errorf("error reading response body: %w", er)
|
|
}
|
|
}
|
|
|
|
file.Close()
|
|
|
|
if err := syncFilesystem(); err != nil {
|
|
return fmt.Errorf("error syncing filesystem: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *State) verifyFile(path string, expectedHash string, verifyProgress *float32) error {
|
|
l := s.l.With().Str("path", path).Logger()
|
|
|
|
unverifiedPath := path + ".unverified"
|
|
fileToHash, err := os.Open(unverifiedPath)
|
|
if err != nil {
|
|
return fmt.Errorf("error opening file for hashing: %w", err)
|
|
}
|
|
defer fileToHash.Close()
|
|
|
|
hash := sha256.New()
|
|
fileInfo, err := fileToHash.Stat()
|
|
if err != nil {
|
|
return fmt.Errorf("error getting file info: %w", err)
|
|
}
|
|
totalSize := fileInfo.Size()
|
|
|
|
buf := make([]byte, 32*1024)
|
|
verified := int64(0)
|
|
|
|
for {
|
|
nr, er := fileToHash.Read(buf)
|
|
if nr > 0 {
|
|
nw, ew := hash.Write(buf[0:nr])
|
|
if nw < nr {
|
|
return fmt.Errorf("short write: %d < %d", nw, nr)
|
|
}
|
|
verified += int64(nw)
|
|
if ew != nil {
|
|
return fmt.Errorf("error writing to hash: %w", ew)
|
|
}
|
|
progress := float32(verified) / float32(totalSize)
|
|
if progress-*verifyProgress >= 0.01 {
|
|
*verifyProgress = progress
|
|
s.triggerStateUpdate()
|
|
}
|
|
}
|
|
if er != nil {
|
|
if er == io.EOF {
|
|
break
|
|
}
|
|
return fmt.Errorf("error reading file: %w", er)
|
|
}
|
|
}
|
|
|
|
hashSum := hash.Sum(nil)
|
|
l.Info().Str("hash", hex.EncodeToString(hashSum)).Msg("SHA256 hash of")
|
|
|
|
if hex.EncodeToString(hashSum) != expectedHash {
|
|
return fmt.Errorf("hash mismatch: %x != %s", hashSum, expectedHash)
|
|
}
|
|
|
|
if err := os.Rename(unverifiedPath, path); err != nil {
|
|
return fmt.Errorf("error renaming file: %w", err)
|
|
}
|
|
|
|
if err := os.Chmod(path, 0755); err != nil {
|
|
return fmt.Errorf("error making file executable: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|