diff --git a/main.go b/main.go index 25fbb3a..39b1427 100644 --- a/main.go +++ b/main.go @@ -77,6 +77,11 @@ func Main() { initUsbGadget() + err = setInitialVirtualMediaState() + if err != nil { + logger.Warn().Err(err).Msg("failed to set initial virtual media state") + } + go func() { time.Sleep(15 * time.Minute) for { diff --git a/usb_mass_storage.go b/usb_mass_storage.go index 8d4b1f1..3ecbdd8 100644 --- a/usb_mass_storage.go +++ b/usb_mass_storage.go @@ -133,7 +133,7 @@ func getMassStorageCDROMEnabled() (bool, error) { if err != nil { return false, fmt.Errorf("failed to get mass storage path: %w", err) } - data, err := os.ReadFile(path.Join(massStorageFunctionPath, "lun.0", "cdrom")) + data, err := os.ReadFile(path.Join(massStorageFunctionPath, "cdrom")) if err != nil { return false, fmt.Errorf("failed to read cdrom mode: %w", err) } @@ -214,23 +214,47 @@ func getInitialVirtualMediaState() (*VirtualMediaState, error) { return nil, fmt.Errorf("failed to get mass storage image: %w", err) } - source := Storage - // TODO: check if it's WebRTC or HTTP - if diskPath == "/dev/nbd0" { - source = HTTP + initialState := &VirtualMediaState{ + Source: Storage, + Mode: Disk, } - mode := Disk if cdromEnabled { - mode = CDROM + initialState.Mode = CDROM } - return &VirtualMediaState{ - Source: source, - Mode: mode, - URL: "", - Size: 0, - }, nil + // TODO: check if it's WebRTC or HTTP + if diskPath == "" { + return nil, nil + } else if diskPath == "/dev/nbd0" { + initialState.Source = HTTP + initialState.URL = "/" + initialState.Size = 1 + } else { + initialState.Filename = filepath.Base(diskPath) + // get size from file + logger.Info().Str("diskPath", diskPath).Msg("getting file size") + info, err := os.Stat(diskPath) + if err != nil { + return nil, fmt.Errorf("failed to get file info: %w", err) + } + initialState.Size = info.Size() + } + + return initialState, nil +} + +func setInitialVirtualMediaState() error { + virtualMediaStateMutex.Lock() + defer virtualMediaStateMutex.Unlock() + initialState, err := getInitialVirtualMediaState() + if err != nil { + return fmt.Errorf("failed to get initial virtual media state: %w", err) + } + currentVirtualMediaState = initialState + + logger.Info().Interface("initial_virtual_media_state", initialState).Msg("initial virtual media state set") + return nil } func rpcMountWithHTTP(url string, mode VirtualMediaMode) error {