diff --git a/jsonrpc.go b/jsonrpc.go index b401ac59..3e446756 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -1181,6 +1181,9 @@ var rpcHandlers = map[string]RPCHandler{ "listStorageFiles": {Func: rpcListStorageFiles}, "deleteStorageFile": {Func: rpcDeleteStorageFile, Params: []string{"filename"}}, "startStorageFileUpload": {Func: rpcStartStorageFileUpload, Params: []string{"filename", "size"}}, + "downloadFromUrl": {Func: rpcDownloadFromUrl, Params: []string{"url", "filename"}}, + "getDownloadState": {Func: rpcGetDownloadState}, + "cancelDownload": {Func: rpcCancelDownload}, "getWakeOnLanDevices": {Func: rpcGetWakeOnLanDevices}, "setWakeOnLanDevices": {Func: rpcSetWakeOnLanDevices, Params: []string{"params"}}, "resetConfig": {Func: rpcResetConfig}, diff --git a/ui/localization/messages/en.json b/ui/localization/messages/en.json index 3730fe3a..e5c34687 100644 --- a/ui/localization/messages/en.json +++ b/ui/localization/messages/en.json @@ -568,6 +568,19 @@ "mount_uploaded_has_been_uploaded": "{name} has been uploaded", "mount_uploading": "Uploading…", "mount_uploading_with_name": "Uploading {name}", + "mount_download_title": "Download from URL", + "mount_download_description": "Download an image file directly to JetKVM storage from a URL", + "mount_download_url_label": "Image URL", + "mount_download_filename_label": "Save as filename", + "mount_downloading": "Downloading...", + "mount_downloading_with_name": "Downloading {name}", + "mount_download_successful": "Download successful", + "mount_download_has_been_downloaded": "{name} has been downloaded", + "mount_download_error": "Download error: {error}", + "mount_download_cancelled": "Download cancelled", + "mount_button_start_download": "Start Download", + "mount_button_cancel_download": "Cancel Download", + "mount_button_download_from_url": "Download from URL", "mount_url_description": "Mount files from any public web address", "mount_url_input_label": "Image URL", "mount_url_mount": "URL Mount", diff --git a/ui/src/hooks/stores.ts b/ui/src/hooks/stores.ts index 4c3252d6..dea958f8 100644 --- a/ui/src/hooks/stores.ts +++ b/ui/src/hooks/stores.ts @@ -443,7 +443,7 @@ export interface MountMediaState { remoteVirtualMediaState: RemoteVirtualMediaState | null; setRemoteVirtualMediaState: (state: MountMediaState["remoteVirtualMediaState"]) => void; - modalView: "mode" | "url" | "device" | "upload" | "error" | null; + modalView: "mode" | "url" | "device" | "upload" | "download" | "error" | null; setModalView: (view: MountMediaState["modalView"]) => void; isMountMediaDialogOpen: boolean; diff --git a/ui/src/routes/devices.$id.mount.tsx b/ui/src/routes/devices.$id.mount.tsx index b2ff891f..e8f15a32 100644 --- a/ui/src/routes/devices.$id.mount.tsx +++ b/ui/src/routes/devices.$id.mount.tsx @@ -5,6 +5,7 @@ import { LuRadioReceiver, LuCheck, LuUpload, + LuDownload, } from "react-icons/lu"; import { PlusCircleIcon, ExclamationTriangleIcon } from "@heroicons/react/20/solid"; import { TrashIcon } from "@heroicons/react/16/solid"; @@ -186,6 +187,9 @@ export function Dialog({ onClose }: Readonly<{ onClose: () => void }>) { setIncompleteFileName(incompleteFile || null); setModalView("upload"); }} + onDownloadClick={() => { + setModalView("download"); + }} /> )} @@ -200,6 +204,15 @@ export function Dialog({ onClose }: Readonly<{ onClose: () => void }>) { /> )} + {modalView === "download" && ( + setModalView("device")} + onDownloadComplete={() => { + setModalView("device"); + }} + /> + )} + {modalView === "error" && ( void; mountInProgress: boolean; onBack: () => void; onNewImageClick: (incompleteFileName?: string) => void; + onDownloadClick: () => void; }) { const [onStorageFiles, setOnStorageFiles] = useState< { @@ -799,7 +814,7 @@ function DeviceFileView({ {onStorageFiles.length > 0 && (
onNewImageClick()} /> +
)} @@ -1247,6 +1269,272 @@ function UploadFileView({ ); } +function DownloadFileView({ + onBack, + onDownloadComplete, +}: { + onBack: () => void; + onDownloadComplete: () => void; +}) { + const [downloadViewState, setDownloadViewState] = useState<"idle" | "downloading" | "success" | "error">("idle"); + const [url, setUrl] = useState(""); + const [filename, setFilename] = useState(""); + const [progress, setProgress] = useState(0); + const [downloadSpeed, setDownloadSpeed] = useState(null); + const [downloadError, setDownloadError] = useState(null); + const [totalBytes, setTotalBytes] = useState(0); + + const { send } = useJsonRpc(); + + // Track download speed + const lastBytesRef = useRef(0); + const lastTimeRef = useRef(0); + const speedHistoryRef = useRef([]); + + // Compute URL validity + const isUrlValid = useMemo(() => { + try { + const urlObj = new URL(url); + return urlObj.protocol === 'http:' || urlObj.protocol === 'https:'; + } catch { + return false; + } + }, [url]); + + // Extract filename from URL + const suggestedFilename = useMemo(() => { + if (!url) return ''; + try { + const urlObj = new URL(url); + const pathParts = urlObj.pathname.split('/'); + const lastPart = pathParts[pathParts.length - 1]; + if (lastPart && (lastPart.endsWith('.iso') || lastPart.endsWith('.img'))) { + return lastPart; + } + } catch { + // Invalid URL, ignore + } + return ''; + }, [url]); + + // Update filename when URL changes and user hasn't manually edited it + const [userEditedFilename, setUserEditedFilename] = useState(false); + const effectiveFilename = userEditedFilename ? filename : (suggestedFilename || filename); + + // Listen for download state events via polling + useEffect(() => { + if (downloadViewState !== "downloading") return; + + const pollInterval = setInterval(() => { + send("getDownloadState", {}, (resp: JsonRpcResponse) => { + if ("error" in resp) return; + + const state = resp.result as { + downloading: boolean; + filename: string; + totalBytes: number; + doneBytes: number; + progress: number; + error?: string; + }; + + if (state.error) { + setDownloadError(state.error); + setDownloadViewState("error"); + return; + } + + setTotalBytes(state.totalBytes); + setProgress(state.progress * 100); + + // Calculate speed + const now = Date.now(); + const timeDiff = (now - lastTimeRef.current) / 1000; + const bytesDiff = state.doneBytes - lastBytesRef.current; + + if (timeDiff > 0 && bytesDiff > 0) { + const instantSpeed = bytesDiff / timeDiff; + speedHistoryRef.current.push(instantSpeed); + if (speedHistoryRef.current.length > 5) { + speedHistoryRef.current.shift(); + } + const avgSpeed = speedHistoryRef.current.reduce((a, b) => a + b, 0) / speedHistoryRef.current.length; + setDownloadSpeed(avgSpeed); + } + + lastBytesRef.current = state.doneBytes; + lastTimeRef.current = now; + + if (!state.downloading && state.progress >= 1) { + setDownloadViewState("success"); + } + }); + }, 500); + + return () => clearInterval(pollInterval); + }, [downloadViewState, send]); + + function handleStartDownload() { + if (!url || !effectiveFilename) return; + + setDownloadViewState("downloading"); + setDownloadError(null); + setProgress(0); + setDownloadSpeed(null); + lastBytesRef.current = 0; + lastTimeRef.current = Date.now(); + speedHistoryRef.current = []; + + send("downloadFromUrl", { url, filename: effectiveFilename }, (resp: JsonRpcResponse) => { + if ("error" in resp) { + setDownloadError(resp.error.message); + setDownloadViewState("error"); + } + }); + } + + function handleCancelDownload() { + send("cancelDownload", {}, (resp: JsonRpcResponse) => { + if ("error" in resp) { + console.error("Failed to cancel download:", resp.error); + } + setDownloadViewState("idle"); + }); + } + + return ( +
+ + + {downloadViewState === "idle" && ( + <> +
+ setUrl(e.target.value)} + /> + { + setFilename(e.target.value); + setUserEditedFilename(true); + }} + /> +
+
+
+ + )} + + {downloadViewState === "downloading" && ( +
+ +
+
+ +

+ {m.mount_downloading_with_name({ name: formatters.truncateMiddle(effectiveFilename, 30) })} +

+
+

+ {formatters.bytes(totalBytes)} +

+
+
+
+
+ {m.mount_downloading()} + + {downloadSpeed !== null + ? `${formatters.bytes(downloadSpeed)}/s` + : m.mount_calculating()} + +
+
+ +
+
+
+ )} + + {downloadViewState === "success" && ( +
+ +
+ +

+ {m.mount_download_successful()} +

+

+ {m.mount_download_has_been_downloaded({ name: effectiveFilename })} +

+
+
+
+
+
+ )} + + {downloadViewState === "error" && ( +
+ +
+ +

+ {m.mount_error_title()} +

+

+ {downloadError} +

+
+
+
+
+
+ )} +
+ ); +} + function ErrorView({ errorMessage, onClose, diff --git a/usb_mass_storage.go b/usb_mass_storage.go index 0f1f4b93..7e657874 100644 --- a/usb_mass_storage.go +++ b/usb_mass_storage.go @@ -1,6 +1,7 @@ package kvm import ( + "context" "encoding/json" "errors" "fmt" @@ -87,7 +88,7 @@ func mountImage(imagePath string) error { var nbdDevice *NBDDevice -const imagesFolder = "/userdata/jetkvm/images" +var imagesFolder = "/userdata/jetkvm/images" func initImagesFolder() error { err := os.MkdirAll(imagesFolder, 0755) @@ -612,3 +613,232 @@ func handleUploadHttp(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "Upload completed"}) } + +// Download state management +type DownloadState struct { + Downloading bool `json:"downloading"` + Filename string `json:"filename,omitempty"` + URL string `json:"url,omitempty"` + TotalBytes int64 `json:"totalBytes"` + DoneBytes int64 `json:"doneBytes"` + Progress float32 `json:"progress"` + Error string `json:"error,omitempty"` +} + +var currentDownload *DownloadState +var downloadMutex sync.Mutex +var downloadCancel context.CancelFunc + +func rpcGetDownloadState() (*DownloadState, error) { + downloadMutex.Lock() + defer downloadMutex.Unlock() + if currentDownload == nil { + return &DownloadState{Downloading: false}, nil + } + return currentDownload, nil +} + +func rpcCancelDownload() error { + downloadMutex.Lock() + defer downloadMutex.Unlock() + if downloadCancel != nil { + downloadCancel() + downloadCancel = nil + } + return nil +} + +func rpcDownloadFromUrl(url string, filename string) error { + // Sanitize filename + sanitizedFilename, err := sanitizeFilename(filename) + if err != nil { + return err + } + + // Validate URL + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + return errors.New("invalid URL: must start with http:// or https://") + } + + // Check if already downloading + downloadMutex.Lock() + if currentDownload != nil && currentDownload.Downloading { + downloadMutex.Unlock() + return errors.New("another download is already in progress") + } + + // Check if file already exists + filePath := filepath.Join(imagesFolder, sanitizedFilename) + if _, err := os.Stat(filePath); err == nil { + downloadMutex.Unlock() + return fmt.Errorf("file already exists: %s", sanitizedFilename) + } + + // Initialize download state + ctx, cancel := context.WithCancel(context.Background()) + downloadCancel = cancel + currentDownload = &DownloadState{ + Downloading: true, + Filename: sanitizedFilename, + URL: url, + Progress: 0, + } + downloadMutex.Unlock() + + // Start download in goroutine + go performDownload(ctx, url, sanitizedFilename) + + return nil +} + +func performDownload(ctx context.Context, url string, filename string) { + downloadPath := filepath.Join(imagesFolder, filename+".incomplete") + finalPath := filepath.Join(imagesFolder, filename) + + defer func() { + downloadMutex.Lock() + if currentDownload != nil { + currentDownload.Downloading = false + } + downloadCancel = nil + downloadMutex.Unlock() + broadcastDownloadState() + }() + + // Create HTTP request with context + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + setDownloadError(fmt.Sprintf("failed to create request: %v", err)) + return + } + + // Perform request + client := &http.Client{Timeout: 0} // No timeout for large downloads + resp, err := client.Do(req) + if err != nil { + if ctx.Err() == context.Canceled { + setDownloadError("download cancelled") + // Clean up incomplete file + os.Remove(downloadPath) + } else { + setDownloadError(fmt.Sprintf("failed to download: %v", err)) + } + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + setDownloadError(fmt.Sprintf("server returned status %d", resp.StatusCode)) + return + } + + totalSize := resp.ContentLength + if totalSize <= 0 { + setDownloadError("server did not provide content length") + return + } + + // Update state with total size + downloadMutex.Lock() + if currentDownload != nil { + currentDownload.TotalBytes = totalSize + } + downloadMutex.Unlock() + broadcastDownloadState() + + // Create file + file, err := os.Create(downloadPath) + if err != nil { + setDownloadError(fmt.Sprintf("failed to create file: %v", err)) + return + } + defer file.Close() + + // Download with progress tracking + var written int64 + buf := make([]byte, 32*1024) + lastProgress := float32(0) + + for { + select { + case <-ctx.Done(): + setDownloadError("download cancelled") + file.Close() + os.Remove(downloadPath) + return + default: + } + + nr, er := resp.Body.Read(buf) + if nr > 0 { + nw, ew := file.Write(buf[0:nr]) + if nw < nr { + setDownloadError(fmt.Sprintf("short write: %d < %d", nw, nr)) + return + } + written += int64(nw) + if ew != nil { + setDownloadError(fmt.Sprintf("write error: %v", ew)) + return + } + + progress := float32(written) / float32(totalSize) + if progress-lastProgress >= 0.01 { + downloadMutex.Lock() + if currentDownload != nil { + currentDownload.DoneBytes = written + currentDownload.Progress = progress + } + downloadMutex.Unlock() + broadcastDownloadState() + lastProgress = progress + } + } + if er != nil { + if er == io.EOF { + break + } + setDownloadError(fmt.Sprintf("read error: %v", er)) + return + } + } + + // Sync filesystem + file.Sync() + file.Close() + + // Rename to final filename + if err := os.Rename(downloadPath, finalPath); err != nil { + setDownloadError(fmt.Sprintf("failed to rename file: %v", err)) + return + } + + // Update final state + downloadMutex.Lock() + if currentDownload != nil { + currentDownload.DoneBytes = totalSize + currentDownload.Progress = 1.0 + } + downloadMutex.Unlock() + + logger.Info().Str("filename", filename).Int64("size", totalSize).Msg("download completed") +} + +func setDownloadError(errMsg string) { + downloadMutex.Lock() + if currentDownload != nil { + currentDownload.Error = errMsg + } + downloadMutex.Unlock() + logger.Warn().Str("error", errMsg).Msg("download error") +} + +func broadcastDownloadState() { + downloadMutex.Lock() + state := currentDownload + downloadMutex.Unlock() + + if currentSession != nil && state != nil { + writeJSONRPCEvent("downloadState", state, currentSession) + } +} diff --git a/usb_mass_storage_test.go b/usb_mass_storage_test.go new file mode 100644 index 00000000..5192ea87 --- /dev/null +++ b/usb_mass_storage_test.go @@ -0,0 +1,95 @@ +package kvm + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestSanitizeFilename(t *testing.T) { + tests := []struct { + name string + input string + expected string + wantErr bool + }{ + {"Simple filename", "image.iso", "image.iso", false}, + {"Filename with spaces", "my image.iso", "my image.iso", false}, + {"Path traversal", "../image.iso", "", true}, + {"Absolute path", "/etc/passwd", "", true}, + {"Current directory", ".", "", true}, + {"Parent directory", "..", "", true}, + {"Root directory", "/", "", true}, + {"Nested path", "folder/image.iso", "image.iso", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := sanitizeFilename(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("sanitizeFilename() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.expected { + t.Errorf("sanitizeFilename() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestRpcDownloadFromUrl(t *testing.T) { + // Create temp directory for images + tmpDir := t.TempDir() + originalImagesFolder := imagesFolder + imagesFolder = tmpDir + defer func() { imagesFolder = originalImagesFolder }() + + // Start test server + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "12") + w.Write([]byte("test content")) + })) + defer ts.Close() + + // Test download + filename := "test.iso" + err := rpcDownloadFromUrl(ts.URL, filename) + if err != nil { + t.Fatalf("rpcDownloadFromUrl() error = %v", err) + } + + // Wait for download to complete (since it's async) + // In a real test we might need a better way to wait, but for now we can poll the state + timeout := time.After(2 * time.Second) + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-timeout: + t.Fatal("timeout waiting for download") + case <-ticker.C: + state, _ := rpcGetDownloadState() + if state.Error != "" { + t.Fatalf("download failed with error: %s", state.Error) + } + if !state.Downloading && state.DoneBytes == 12 { + goto Done + } + } + } +Done: + + // Verify file content + content, err := os.ReadFile(filepath.Join(tmpDir, filename)) + if err != nil { + t.Fatalf("failed to read downloaded file: %v", err) + } + if string(content) != "test content" { + t.Errorf("file content = %q, want %q", string(content), "test content") + } +}