kvm/cmd/main.go

171 lines
3.6 KiB
Go

package main
import (
"flag"
"fmt"
"io"
"os"
"os/exec"
"os/signal"
"path/filepath"
"syscall"
"time"
"github.com/erikdubbelboer/gspt"
"github.com/jetkvm/kvm"
)
const (
envChildID = "JETKVM_CHILD_ID"
errorDumpDir = "/userdata/jetkvm/"
errorDumpStateFile = ".has_error_dump"
errorDumpTemplate = "jetkvm-%s.log"
)
func program() {
gspt.SetProcTitle(os.Args[0] + " [app]")
kvm.Main()
}
func main() {
versionPtr := flag.Bool("version", false, "print version and exit")
versionJSONPtr := flag.Bool("version-json", false, "print version as json and exit")
flag.Parse()
if *versionPtr || *versionJSONPtr {
versionData, err := kvm.GetVersionData(*versionJSONPtr)
if err != nil {
fmt.Printf("failed to get version data: %v\n", err)
os.Exit(1)
}
fmt.Println(string(versionData))
return
}
childID := os.Getenv(envChildID)
switch childID {
case "":
doSupervise()
case kvm.GetBuiltAppVersion():
program()
default:
fmt.Printf("Invalid build version: %s != %s\n", childID, kvm.GetBuiltAppVersion())
os.Exit(1)
}
}
func supervise() error {
// check binary path
binPath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
// check if binary is same as current binary
if info, statErr := os.Stat(binPath); statErr != nil {
return fmt.Errorf("failed to get executable info: %w", statErr)
// check if binary is empty
} else if info.Size() == 0 {
return fmt.Errorf("binary is empty")
// check if it's executable
} else if info.Mode().Perm()&0111 == 0 {
return fmt.Errorf("binary is not executable")
}
// run the child binary
cmd := exec.Command(binPath)
cmd.Env = append(os.Environ(), []string{envChildID + "=" + kvm.GetBuiltAppVersion()}...)
cmd.Args = os.Args
logFile, err := os.CreateTemp("", "jetkvm-stdout.log")
defer func() {
// we don't care about the errors here
_ = logFile.Close()
_ = os.Remove(logFile.Name())
}()
if err != nil {
return fmt.Errorf("failed to create log file: %w", err)
}
// Use io.MultiWriter to write to both the original streams and our buffers
cmd.Stdout = io.MultiWriter(os.Stdout, logFile)
cmd.Stderr = io.MultiWriter(os.Stderr, logFile)
if startErr := cmd.Start(); startErr != nil {
return fmt.Errorf("failed to start command: %w", startErr)
}
go func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGTERM)
sig := <-sigChan
_ = cmd.Process.Signal(sig)
}()
gspt.SetProcTitle(os.Args[0] + " [sup]")
cmdErr := cmd.Wait()
if cmdErr == nil {
return nil
}
if exiterr, ok := cmdErr.(*exec.ExitError); ok {
createErrorDump(logFile)
os.Exit(exiterr.ExitCode())
}
return nil
}
func createErrorDump(logFile *os.File) {
logFile.Close()
// touch the error dump state file
if err := os.WriteFile(filepath.Join(errorDumpDir, errorDumpStateFile), []byte{}, 0644); err != nil {
return
}
fileName := fmt.Sprintf(errorDumpTemplate, time.Now().Format("20060102150405"))
filePath := filepath.Join(errorDumpDir, fileName)
if err := os.Rename(logFile.Name(), filePath); err == nil {
fmt.Printf("error dump created: %s\n", filePath)
return
}
fnSrc, err := os.Open(logFile.Name())
if err != nil {
return
}
defer fnSrc.Close()
fnDst, err := os.Create(filePath)
if err != nil {
return
}
defer fnDst.Close()
buf := make([]byte, 1024*1024)
for {
n, err := fnSrc.Read(buf)
if err != nil && err != io.EOF {
return
}
if n == 0 {
break
}
if _, err := fnDst.Write(buf[:n]); err != nil {
return
}
}
fmt.Printf("error dump created: %s\n", filePath)
}
func doSupervise() {
err := supervise()
if err == nil {
return
}
}