From 63c2272c4589c9b052ee75d5496a555af1a1632b Mon Sep 17 00:00:00 2001
From: Aveline <352441+ym@users.noreply.github.com>
Date: Mon, 12 May 2025 19:07:27 +0200
Subject: [PATCH] feat(usb_mass_storage): mount as disk (#333)

* feat(usb_mass_storage): mount as disk

* chore: try to set initial virtual media state from sysfs

* chore(usb-mass-storage): fix inquiry_string
---
 internal/usbgadget/config.go        |  23 ++++++
 internal/usbgadget/mass_storage.go  |  13 ++--
 jsonrpc.go                          |   9 ++-
 main.go                             |   5 ++
 ui/src/routes/devices.$id.mount.tsx |  10 +--
 usb_mass_storage.go                 | 104 ++++++++++++++++++++++++----
 6 files changed, 137 insertions(+), 27 deletions(-)

diff --git a/internal/usbgadget/config.go b/internal/usbgadget/config.go
index b73d392..5c287da 100644
--- a/internal/usbgadget/config.go
+++ b/internal/usbgadget/config.go
@@ -137,6 +137,29 @@ func (u *UsbGadget) GetPath(itemKey string) (string, error) {
 	return joinPath(u.kvmGadgetPath, item.path), nil
 }
 
+// OverrideGadgetConfig overrides the gadget config for the given item and attribute.
+// It returns an error if the item is not found or the attribute is not found.
+// It returns true if the attribute is overridden, false otherwise.
+func (u *UsbGadget) OverrideGadgetConfig(itemKey string, itemAttr string, value string) (error, bool) {
+	u.configLock.Lock()
+	defer u.configLock.Unlock()
+
+	// get it as a pointer
+	_, ok := u.configMap[itemKey]
+	if !ok {
+		return fmt.Errorf("config item %s not found", itemKey), false
+	}
+
+	if u.configMap[itemKey].attrs[itemAttr] == value {
+		return nil, false
+	}
+
+	u.configMap[itemKey].attrs[itemAttr] = value
+	u.log.Info().Str("itemKey", itemKey).Str("itemAttr", itemAttr).Str("value", value).Msg("overriding gadget config")
+
+	return nil, true
+}
+
 func mountConfigFS() error {
 	_, err := os.Stat(gadgetPath)
 	// TODO: check if it's mounted properly
diff --git a/internal/usbgadget/mass_storage.go b/internal/usbgadget/mass_storage.go
index f962cb4..41c1521 100644
--- a/internal/usbgadget/mass_storage.go
+++ b/internal/usbgadget/mass_storage.go
@@ -14,10 +14,13 @@ var massStorageLun0Config = gadgetConfigItem{
 	order: 3001,
 	path:  []string{"functions", "mass_storage.usb0", "lun.0"},
 	attrs: gadgetAttributes{
-		"cdrom":          "1",
-		"ro":             "1",
-		"removable":      "1",
-		"file":           "\n",
-		"inquiry_string": "JetKVM Virtual Media",
+		"cdrom":     "1",
+		"ro":        "1",
+		"removable": "1",
+		"file":      "\n",
+		// the additional whitespace is intentional to avoid the "JetKVM V irtual Media" string
+		// https://github.com/jetkvm/rv1106-system/blob/778133a1c153041e73f7de86c9c434a2753ea65d/sysdrv/source/uboot/u-boot/drivers/usb/gadget/f_mass_storage.c#L2556
+		// Vendor (8 chars), product (16 chars)
+		"inquiry_string": "JetKVM  Virtual Media",
 	},
 }
diff --git a/jsonrpc.go b/jsonrpc.go
index 05db3d5..3c805e4 100644
--- a/jsonrpc.go
+++ b/jsonrpc.go
@@ -566,9 +566,12 @@ type RPCHandler struct {
 func rpcSetMassStorageMode(mode string) (string, error) {
 	logger.Info().Str("mode", mode).Msg("Setting mass storage mode")
 	var cdrom bool
-	if mode == "cdrom" {
+	switch mode {
+	case "cdrom":
 		cdrom = true
-	} else if mode != "file" {
+	case "file":
+		cdrom = false
+	default:
 		logger.Info().Str("mode", mode).Msg("Invalid mode provided")
 		return "", fmt.Errorf("invalid mode: %s", mode)
 	}
@@ -587,7 +590,7 @@ func rpcSetMassStorageMode(mode string) (string, error) {
 }
 
 func rpcGetMassStorageMode() (string, error) {
-	cdrom, err := getMassStorageMode()
+	cdrom, err := getMassStorageCDROMEnabled()
 	if err != nil {
 		return "", fmt.Errorf("failed to get mass storage mode: %w", err)
 	}
diff --git a/main.go b/main.go
index 25fbb3a..39b1427 100644
--- a/main.go
+++ b/main.go
@@ -77,6 +77,11 @@ func Main() {
 
 	initUsbGadget()
 
+	err = setInitialVirtualMediaState()
+	if err != nil {
+		logger.Warn().Err(err).Msg("failed to set initial virtual media state")
+	}
+
 	go func() {
 		time.Sleep(15 * time.Minute)
 		for {
diff --git a/ui/src/routes/devices.$id.mount.tsx b/ui/src/routes/devices.$id.mount.tsx
index 74fcae2..4d3369a 100644
--- a/ui/src/routes/devices.$id.mount.tsx
+++ b/ui/src/routes/devices.$id.mount.tsx
@@ -414,7 +414,7 @@ function BrowserFileView({
     if (file?.name.endsWith(".iso")) {
       setUsbMode("CDROM");
     } else if (file?.name.endsWith(".img")) {
-      setUsbMode("CDROM");
+      setUsbMode("Disk");
     }
   };
 
@@ -566,7 +566,7 @@ function UrlView({
     if (url.endsWith(".iso")) {
       setUsbMode("CDROM");
     } else if (url.endsWith(".img")) {
-      setUsbMode("CDROM");
+      setUsbMode("Disk");
     }
   }
 
@@ -773,7 +773,7 @@ function DeviceFileView({
     if (file.name.endsWith(".iso")) {
       setUsbMode("CDROM");
     } else if (file.name.endsWith(".img")) {
-      setUsbMode("CDROM");
+      setUsbMode("Disk");
     }
   }
 
@@ -1579,7 +1579,6 @@ function UsbModeSelector({
             type="radio"
             id="disk"
             name="mountType"
-            disabled
             checked={usbMode === "Disk"}
             onChange={() => setUsbMode("Disk")}
             className="h-3 w-3 border-slate-800/30 bg-white text-blue-700 transition-opacity focus:ring-blue-500 disabled:opacity-30 dark:bg-slate-800"
@@ -1588,9 +1587,6 @@ function UsbModeSelector({
             <span className="text-sm font-medium leading-none text-slate-900 opacity-50 dark:text-white">
               Disk
             </span>
-            <div className="text-[10px] text-slate-500 dark:text-slate-400">
-              Coming soon
-            </div>
           </div>
         </label>
       </div>
diff --git a/usb_mass_storage.go b/usb_mass_storage.go
index 79a05d1..3ecbdd8 100644
--- a/usb_mass_storage.go
+++ b/usb_mass_storage.go
@@ -26,6 +26,19 @@ func writeFile(path string, data string) error {
 	return os.WriteFile(path, []byte(data), 0644)
 }
 
+func getMassStorageImage() (string, error) {
+	massStorageFunctionPath, err := gadget.GetPath("mass_storage_lun0")
+	if err != nil {
+		return "", fmt.Errorf("failed to get mass storage path: %w", err)
+	}
+
+	imagePath, err := os.ReadFile(path.Join(massStorageFunctionPath, "file"))
+	if err != nil {
+		return "", fmt.Errorf("failed to get mass storage image path: %w", err)
+	}
+	return strings.TrimSpace(string(imagePath)), nil
+}
+
 func setMassStorageImage(imagePath string) error {
 	massStorageFunctionPath, err := gadget.GetPath("mass_storage_lun0")
 	if err != nil {
@@ -39,19 +52,21 @@ func setMassStorageImage(imagePath string) error {
 }
 
 func setMassStorageMode(cdrom bool) error {
-	massStorageFunctionPath, err := gadget.GetPath("mass_storage_lun0")
-	if err != nil {
-		return fmt.Errorf("failed to get mass storage path: %w", err)
-	}
-
 	mode := "0"
 	if cdrom {
 		mode = "1"
 	}
-	if err := writeFile(path.Join(massStorageFunctionPath, "lun.0", "cdrom"), mode); err != nil {
+
+	err, changed := gadget.OverrideGadgetConfig("mass_storage_lun0", "cdrom", mode)
+	if err != nil {
 		return fmt.Errorf("failed to set cdrom mode: %w", err)
 	}
-	return nil
+
+	if !changed {
+		return nil
+	}
+
+	return gadget.UpdateGadgetConfig()
 }
 
 func onDiskMessage(msg webrtc.DataChannelMessage) {
@@ -113,20 +128,17 @@ func rpcMountBuiltInImage(filename string) error {
 	return mountImage(imagePath)
 }
 
-func getMassStorageMode() (bool, error) {
+func getMassStorageCDROMEnabled() (bool, error) {
 	massStorageFunctionPath, err := gadget.GetPath("mass_storage_lun0")
 	if err != nil {
 		return false, fmt.Errorf("failed to get mass storage path: %w", err)
 	}
-
-	data, err := os.ReadFile(path.Join(massStorageFunctionPath, "lun.0", "cdrom"))
+	data, err := os.ReadFile(path.Join(massStorageFunctionPath, "cdrom"))
 	if err != nil {
 		return false, fmt.Errorf("failed to read cdrom mode: %w", err)
 	}
-
 	// Trim any whitespace characters. It has a newline at the end
 	trimmedData := strings.TrimSpace(string(data))
-
 	return trimmedData == "1", nil
 }
 
@@ -191,6 +203,60 @@ func rpcUnmountImage() error {
 
 var httpRangeReader *httpreadat.RangeReader
 
+func getInitialVirtualMediaState() (*VirtualMediaState, error) {
+	cdromEnabled, err := getMassStorageCDROMEnabled()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get mass storage cdrom enabled: %w", err)
+	}
+
+	diskPath, err := getMassStorageImage()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get mass storage image: %w", err)
+	}
+
+	initialState := &VirtualMediaState{
+		Source: Storage,
+		Mode:   Disk,
+	}
+
+	if cdromEnabled {
+		initialState.Mode = CDROM
+	}
+
+	// TODO: check if it's WebRTC or HTTP
+	if diskPath == "" {
+		return nil, nil
+	} else if diskPath == "/dev/nbd0" {
+		initialState.Source = HTTP
+		initialState.URL = "/"
+		initialState.Size = 1
+	} else {
+		initialState.Filename = filepath.Base(diskPath)
+		// get size from file
+		logger.Info().Str("diskPath", diskPath).Msg("getting file size")
+		info, err := os.Stat(diskPath)
+		if err != nil {
+			return nil, fmt.Errorf("failed to get file info: %w", err)
+		}
+		initialState.Size = info.Size()
+	}
+
+	return initialState, nil
+}
+
+func setInitialVirtualMediaState() error {
+	virtualMediaStateMutex.Lock()
+	defer virtualMediaStateMutex.Unlock()
+	initialState, err := getInitialVirtualMediaState()
+	if err != nil {
+		return fmt.Errorf("failed to get initial virtual media state: %w", err)
+	}
+	currentVirtualMediaState = initialState
+
+	logger.Info().Interface("initial_virtual_media_state", initialState).Msg("initial virtual media state set")
+	return nil
+}
+
 func rpcMountWithHTTP(url string, mode VirtualMediaMode) error {
 	virtualMediaStateMutex.Lock()
 	if currentVirtualMediaState != nil {
@@ -204,6 +270,11 @@ func rpcMountWithHTTP(url string, mode VirtualMediaMode) error {
 		return fmt.Errorf("failed to use http url: %w", err)
 	}
 	logger.Info().Str("url", url).Int64("size", n).Msg("using remote url")
+
+	if err := setMassStorageMode(mode == CDROM); err != nil {
+		return fmt.Errorf("failed to set mass storage mode: %w", err)
+	}
+
 	currentVirtualMediaState = &VirtualMediaState{
 		Source: HTTP,
 		Mode:   mode,
@@ -243,6 +314,11 @@ func rpcMountWithWebRTC(filename string, size int64, mode VirtualMediaMode) erro
 		Size:     size,
 	}
 	virtualMediaStateMutex.Unlock()
+
+	if err := setMassStorageMode(mode == CDROM); err != nil {
+		return fmt.Errorf("failed to set mass storage mode: %w", err)
+	}
+
 	logger.Debug().Interface("currentVirtualMediaState", currentVirtualMediaState).Msg("currentVirtualMediaState")
 	logger.Debug().Msg("Starting nbd device")
 	nbdDevice = NewNBDDevice()
@@ -280,6 +356,10 @@ func rpcMountWithStorage(filename string, mode VirtualMediaMode) error {
 		return fmt.Errorf("failed to get file info: %w", err)
 	}
 
+	if err := setMassStorageMode(mode == CDROM); err != nil {
+		return fmt.Errorf("failed to set mass storage mode: %w", err)
+	}
+
 	err = setMassStorageImage(fullPath)
 	if err != nil {
 		return fmt.Errorf("failed to set mass storage image: %w", err)