diff --git a/usb.go b/usb.go index 8f413c7..0be8bd0 100644 --- a/usb.go +++ b/usb.go @@ -1,6 +1,7 @@ package kvm import ( + "bytes" "errors" "fmt" "log" @@ -20,8 +21,89 @@ const gadgetPath = "/sys/kernel/config/usb_gadget" const kvmGadgetPath = "/sys/kernel/config/usb_gadget/jetkvm" const configC1Path = "/sys/kernel/config/usb_gadget/jetkvm/configs/c.1" +type gadgetConfigItem struct { + path []string + attrs gadgetAttributes + configAttrs gadgetAttributes + configPath string + reportDesc []byte +} + +type gadgetAttributes map[string]string + +var gadgetConfig = map[string]gadgetConfigItem{ + "base": { + attrs: gadgetAttributes{ + "bcdUSB": "0x0200", // USB 2.0 + "idVendor": "0x1d6b", // The Linux Foundation + "idProduct": "0104", // Multifunction Composite Gadget¬ + "bcdDevice": "0100", + }, + configAttrs: gadgetAttributes{ + "MaxPower": "250", // in unit of 2mA + }, + }, + "base_info": { + path: []string{"strings", "0x409"}, + attrs: gadgetAttributes{ + "serialnumber": GetDeviceID(), + "manufacturer": "JetKVM", + "product": "JetKVM USB Emulation Device", + }, + configAttrs: gadgetAttributes{ + "configuration": "Config 1: HID", + }, + }, + // keyboard HID + "keyboard": { + path: []string{"functions", "hid.usb0"}, + configPath: path.Join(configC1Path, "hid.usb0"), + attrs: gadgetAttributes{ + "protocol": "1", + "subclass": "1", + "report_length": "8", + }, + reportDesc: KeyboardReportDesc, + }, + // mouse HID + "absolute_mouse": { + path: []string{"functions", "hid.usb1"}, + configPath: path.Join(configC1Path, "hid.usb1"), + attrs: gadgetAttributes{ + "protocol": "2", + "subclass": "1", + "report_length": "6", + }, + reportDesc: CombinedMouseReportDesc, + }, + // mass storage + "mass_storage_base": { + path: []string{"functions", "mass_storage.usb0"}, + configPath: path.Join(configC1Path, "mass_storage.usb0"), + attrs: gadgetAttributes{ + "stall": "1", + }, + }, + + "mass_storage_usb0": { + path: []string{"functions", "mass_storage.usb0", "lun.0"}, + attrs: gadgetAttributes{ + "cdrom": "1", + "ro": "1", + "removable": "1", + "file": "\n", + "inquiry_string": "JetKVM Virtual Media", + }, + }, +} + func mountConfigFS() error { _, err := os.Stat(gadgetPath) + // TODO: check if it's mounted properly + if err == nil { + return nil + } + if os.IsNotExist(err) { err = exec.Command("mount", "-t", "configfs", "none", configFSPath).Run() if err != nil { @@ -33,6 +115,93 @@ func mountConfigFS() error { return nil } +func writeIfDifferent(filePath string, content []byte, permMode os.FileMode) error { + if _, err := os.Stat(filePath); err == nil { + oldContent, err := os.ReadFile(filePath) + if err == nil { + if bytes.Equal(oldContent, content) { + logger.Tracef("skipping writing to %s as it already has the correct content", filePath) + return nil + } + + if len(oldContent) == len(content)+1 && + bytes.Equal(oldContent[:len(content)], content) && + oldContent[len(content)] == 10 { + logger.Tracef("skipping writing to %s as it already has the correct content", filePath) + return nil + } + + logger.Tracef("writing to %s as it has different content %v %v", filePath, oldContent, content) + } + } + return os.WriteFile(filePath, content, permMode) +} + +func writeGadgetItemConfig(item gadgetConfigItem) error { + // create directory for the item + gadgetItemPathArr := append([]string{kvmGadgetPath}, item.path...) + gadgetItemPath := filepath.Join(gadgetItemPathArr...) + err := os.MkdirAll(gadgetItemPath, 0755) + if err != nil { + return fmt.Errorf("failed to create path %s: %w", gadgetItemPath, err) + } + + if len(item.configAttrs) > 0 { + configItemPathArr := append([]string{configC1Path}, item.path...) + configItemPath := filepath.Join(configItemPathArr...) + err = os.MkdirAll(configItemPath, 0755) + if err != nil { + return fmt.Errorf("failed to create path %s: %w", config, err) + } + + err = writeGadgetAttrs(configItemPath, item.configAttrs) + if err != nil { + return fmt.Errorf("failed to write config attributes for %s: %w", configItemPath, err) + } + } + + if len(item.attrs) > 0 { + // write attributes for the item + err = writeGadgetAttrs(gadgetItemPath, item.attrs) + if err != nil { + return fmt.Errorf("failed to write attributes for %s: %w", gadgetItemPath, err) + } + } + + // write report descriptor if available + if item.reportDesc != nil { + err = writeIfDifferent(path.Join(gadgetItemPath, "report_desc"), item.reportDesc, 0644) + if err != nil { + return err + } + } + + // create symlink if configPath is set + if item.configPath != "" { + logger.Tracef("Creating symlink from %s to %s", item.configPath, gadgetItemPath) + + // check if the symlink already exists, if yes, check if it points to the correct path + if _, err := os.Lstat(item.configPath); err == nil { + linkPath, err := os.Readlink(item.configPath) + if err != nil || linkPath != gadgetItemPath { + err = os.Remove(item.configPath) + if err != nil { + return fmt.Errorf("failed to remove existing symlink %s: %w", item.configPath, err) + } + } + } else if !os.IsNotExist(err) { + return fmt.Errorf("failed to check if symlink exists: %w", err) + } + + err = os.Symlink(gadgetItemPath, item.configPath) + if err != nil { + return fmt.Errorf("failed to create symlink from %s to %s: %w", item.configPath, gadgetItemPath, err) + } + } + + return nil +} + func init() { ensureConfigLoaded() @@ -45,13 +214,15 @@ func init() { udc = udcs[0] _, err := os.Stat(kvmGadgetPath) if err == nil { - logger.Info("usb gadget already exists, skipping usb gadget initialization") - return + logger.Info("usb gadget already exists") } err = mountConfigFS() if err != nil { logger.Errorf("failed to mount configfs: %v, usb stack might not function properly", err) } + + loadGadgetConfigFromUsbConfig() + err = writeGadgetConfig() if err != nil { logger.Errorf("failed to start gadget: %v", err) @@ -60,37 +231,23 @@ func init() { //TODO: read hid reports(capslock, numlock, etc) from keyboardHidFile } +func loadGadgetConfigFromUsbConfig() { + gadgetConfig["base"].attrs["idVendor"] = config.UsbConfig.VendorId + gadgetConfig["base"].attrs["idProduct"] = config.UsbConfig.ProductId + + gadgetConfig["base_info"].attrs["serialnumber"] = config.UsbConfig.SerialNumber + gadgetConfig["base_info"].attrs["manufacturer"] = config.UsbConfig.Manufacturer + gadgetConfig["base_info"].attrs["product"] = config.UsbConfig.Product +} + func UpdateGadgetConfig() error { - LoadConfig() - gadgetAttrs := [][]string{ - {"idVendor", config.UsbConfig.VendorId}, - {"idProduct", config.UsbConfig.ProductId}, - } - err := writeGadgetAttrs(kvmGadgetPath, gadgetAttrs) + loadGadgetConfigFromUsbConfig() + err := writeGadgetConfig() if err != nil { - return err + logger.Errorf("failed to update gadget: %v", err) } - log.Printf("Successfully updated usb gadget attributes: %v", gadgetAttrs) - - strAttrs := [][]string{ - {"serialnumber", config.UsbConfig.SerialNumber}, - {"manufacturer", config.UsbConfig.Manufacturer}, - {"product", config.UsbConfig.Product}, - } - gadgetStringsPath := filepath.Join(kvmGadgetPath, "strings", "0x409") - err = os.MkdirAll(gadgetStringsPath, 0755) - if err != nil { - return err - } - err = writeGadgetAttrs(gadgetStringsPath, strAttrs) - if err != nil { - return err - } - - log.Printf("Successfully updated usb string attributes: %s", strAttrs) - - err = rebindUsb() + err = rebindUsb(false) if err != nil { return err } @@ -98,10 +255,10 @@ func UpdateGadgetConfig() error { return nil } -func writeGadgetAttrs(basePath string, attrs [][]string) error { - for _, item := range attrs { - filePath := filepath.Join(basePath, item[0]) - err := os.WriteFile(filePath, []byte(item[1]), 0644) +func writeGadgetAttrs(basePath string, attrs gadgetAttributes) error { + for key, val := range attrs { + filePath := filepath.Join(basePath, key) + err := writeIfDifferent(filePath, []byte(val), 0644) if err != nil { return fmt.Errorf("failed to write to %s: %w", filePath, err) } @@ -119,145 +276,32 @@ func writeGadgetConfig() error { return err } - err = writeGadgetAttrs(kvmGadgetPath, [][]string{ - {"bcdUSB", "0x0200"}, //USB 2.0 - {"idVendor", config.UsbConfig.VendorId}, //The Linux Foundation - {"idProduct", config.UsbConfig.ProductId}, //Multifunction Composite Gadget¬ - {"bcdDevice", "0100"}, - }) - if err != nil { - return err - } - - gadgetStringsPath := filepath.Join(kvmGadgetPath, "strings", "0x409") - err = os.MkdirAll(gadgetStringsPath, 0755) - if err != nil { - return err - } - - err = writeGadgetAttrs(gadgetStringsPath, [][]string{ - {"serialnumber", GetDeviceID()}, - {"manufacturer", config.UsbConfig.Manufacturer}, - {"product", config.UsbConfig.Product}, - }) - if err != nil { - return err - } - - configC1StringsPath := path.Join(configC1Path, "strings", "0x409") - err = os.MkdirAll(configC1StringsPath, 0755) - if err != nil { - return err - } - - err = writeGadgetAttrs(configC1Path, [][]string{ - {"MaxPower", "250"}, //in unit of 2mA - }) - if err != nil { - return err - } - - err = writeGadgetAttrs(configC1StringsPath, [][]string{ - {"configuration", "Config 1: HID"}, - }) - if err != nil { - return err - } - - //keyboard HID - hid0Path := path.Join(kvmGadgetPath, "functions", "hid.usb0") - err = os.MkdirAll(hid0Path, 0755) - if err != nil { - return err - } - err = writeGadgetAttrs(hid0Path, [][]string{ - {"protocol", "1"}, - {"subclass", "1"}, - {"report_length", "8"}, - }) - if err != nil { - return err - } - - err = os.WriteFile(path.Join(hid0Path, "report_desc"), KeyboardReportDesc, 0644) - if err != nil { - return err - } - - //mouse HID - hid1Path := path.Join(kvmGadgetPath, "functions", "hid.usb1") - err = os.MkdirAll(hid1Path, 0755) - if err != nil { - return err - } - err = writeGadgetAttrs(hid1Path, [][]string{ - {"protocol", "2"}, - {"subclass", "1"}, - {"report_length", "6"}, - }) - if err != nil { - return err - } - - err = os.WriteFile(path.Join(hid1Path, "report_desc"), CombinedMouseReportDesc, 0644) - if err != nil { - return err - } - //mass storage - massStoragePath := path.Join(kvmGadgetPath, "functions", "mass_storage.usb0") - err = os.MkdirAll(massStoragePath, 0755) - if err != nil { - return err - } - - err = writeGadgetAttrs(massStoragePath, [][]string{ - {"stall", "1"}, - }) - if err != nil { - return err - } - lun0Path := path.Join(massStoragePath, "lun.0") - err = os.MkdirAll(lun0Path, 0755) - if err != nil { - return err - } - err = writeGadgetAttrs(lun0Path, [][]string{ - {"cdrom", "1"}, - {"ro", "1"}, - {"removable", "1"}, - {"file", "\n"}, - {"inquiry_string", "JetKVM Virtual Media"}, - }) - if err != nil { - return err - } - - err = os.Symlink(hid0Path, path.Join(configC1Path, "hid.usb0")) - if err != nil { - return err - } - - err = os.Symlink(hid1Path, path.Join(configC1Path, "hid.usb1")) - if err != nil { - return err - } - - err = os.Symlink(massStoragePath, path.Join(configC1Path, "mass_storage.usb0")) - if err != nil { - return err + logger.Tracef("writing gadget config") + for key, item := range gadgetConfig { + logger.Tracef("writing gadget config: %s", key) + err = writeGadgetItemConfig(item) + if err != nil { + return err + } } + logger.Tracef("writing UDC") err = os.WriteFile(path.Join(kvmGadgetPath, "UDC"), []byte(udc), 0644) if err != nil { return err } + err = rebindUsb(true) + if err != nil { + logger.Infof("failed to rebind usb: %v", err) + } + return nil } -func rebindUsb() error { +func rebindUsb(ignoreUnbindError bool) error { err := os.WriteFile("/sys/bus/platform/drivers/dwc3/unbind", []byte(udc), 0644) - if err != nil { + if err != nil && !ignoreUnbindError { return err } err = os.WriteFile("/sys/bus/platform/drivers/dwc3/bind", []byte(udc), 0644)