From c939a02f33fa60e11fdfa3673e53bb5608faa16f Mon Sep 17 00:00:00 2001 From: Siyuan Miao Date: Tue, 18 Mar 2025 17:25:03 +0100 Subject: [PATCH 1/9] feat(tls): store tls certificates --- internal/websecure/log.go | 5 + internal/websecure/selfsign.go | 182 +++++++++++++++++++++++++++++++++ internal/websecure/store.go | 126 +++++++++++++++++++++++ internal/websecure/utils.go | 80 +++++++++++++++ main.go | 1 + web_tls.go | 132 +++++------------------- 6 files changed, 422 insertions(+), 104 deletions(-) create mode 100644 internal/websecure/log.go create mode 100644 internal/websecure/selfsign.go create mode 100644 internal/websecure/store.go create mode 100644 internal/websecure/utils.go diff --git a/internal/websecure/log.go b/internal/websecure/log.go new file mode 100644 index 0000000..7643ac1 --- /dev/null +++ b/internal/websecure/log.go @@ -0,0 +1,5 @@ +package websecure + +import "github.com/pion/logging" + +var defaultLogger = logging.NewDefaultLoggerFactory().NewLogger("websecure") diff --git a/internal/websecure/selfsign.go b/internal/websecure/selfsign.go new file mode 100644 index 0000000..04844ae --- /dev/null +++ b/internal/websecure/selfsign.go @@ -0,0 +1,182 @@ +package websecure + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "log" + "net" + "strings" + "time" + + "github.com/pion/logging" + "golang.org/x/net/idna" +) + +const selfSignerCAMagicName = "__ca__" + +type SelfSigner struct { + store *CertStore + log logging.LeveledLogger + + caInfo pkix.Name + + DefaultDomain string + DefaultOrg string + DefaultOU string +} + +func NewSelfSigner(store *CertStore, log logging.LeveledLogger, defaultDomain, defaultOrg, defaultOU, caName string) *SelfSigner { + return &SelfSigner{ + store: store, + log: log, + DefaultDomain: defaultDomain, + DefaultOrg: defaultOrg, + DefaultOU: defaultOU, + caInfo: pkix.Name{ + CommonName: caName, + Organization: []string{defaultOrg}, + OrganizationalUnit: []string{defaultOU}, + }, + } +} + +func (s *SelfSigner) getCA() *tls.Certificate { + return s.createSelfSignedCert(selfSignerCAMagicName) +} + +func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate { + if tlsCert := s.store.certificates[hostname]; tlsCert != nil { + return tlsCert + } + + // check if hostname is the CA magic name + var ca *tls.Certificate + if hostname != selfSignerCAMagicName { + ca = s.getCA() + if ca == nil { + s.log.Errorf("Failed to get CA certificate") + return nil + } + } + + s.log.Infof("Creating self-signed certificate for %s", hostname) + + // lock the store while creating the certificate (do not move upwards) + s.store.certLock.Lock() + defer s.store.certLock.Unlock() + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + log.Fatalf("Failed to generate private key: %v", err) + } + + notBefore := time.Now() + notAfter := notBefore.AddDate(1, 0, 0) + + serialNumber, err := generateSerialNumber() + if err != nil { + s.log.Errorf("Failed to generate serial number: %v", err) + } + + dnsName := hostname + ip := net.ParseIP(hostname) + if ip != nil { + dnsName = s.DefaultDomain + } + + // set up CSR + isCA := hostname == selfSignerCAMagicName + subject := pkix.Name{ + CommonName: hostname, + Organization: []string{s.DefaultOrg}, + OrganizationalUnit: []string{s.DefaultOU}, + } + keyUsage := x509.KeyUsageDigitalSignature + extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + + // check if hostname is the CA magic name, and if so, set the subject to the CA info + if isCA { + subject = s.caInfo + keyUsage |= x509.KeyUsageCertSign + extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageClientAuth) + notAfter = notBefore.AddDate(10, 0, 0) + } + + cert := x509.Certificate{ + SerialNumber: serialNumber, + Subject: subject, + NotBefore: notBefore, + NotAfter: notAfter, + IsCA: isCA, + KeyUsage: keyUsage, + ExtKeyUsage: extKeyUsage, + BasicConstraintsValid: true, + } + + // set up DNS names and IP addresses + if !isCA { + cert.DNSNames = []string{dnsName} + if ip != nil { + cert.IPAddresses = []net.IP{ip} + } + } + + // set up parent certificate + parent := &cert + parentPriv := priv + if ca != nil { + parent, err = x509.ParseCertificate(ca.Certificate[0]) + if err != nil { + s.log.Errorf("Failed to parse parent certificate: %v", err) + return nil + } + parentPriv = ca.PrivateKey.(*ecdsa.PrivateKey) + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &cert, parent, &priv.PublicKey, parentPriv) + if err != nil { + s.log.Errorf("Failed to create certificate: %v", err) + return nil + } + + tlsCert := &tls.Certificate{ + Certificate: [][]byte{certBytes}, + PrivateKey: priv, + } + if ca != nil { + tlsCert.Certificate = append(tlsCert.Certificate, ca.Certificate...) + } + + s.store.certificates[hostname] = tlsCert + s.store.saveCertificate(hostname) + + return tlsCert +} + +func (s *SelfSigner) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + hostname := s.DefaultDomain + if info.ServerName != "" && info.ServerName != selfSignerCAMagicName { + hostname = info.ServerName + } else { + hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0] + } + + s.log.Infof("TLS handshake for %s, SupportedProtos: %v", hostname, info.SupportedProtos) + + // convert hostname to punycode + h, err := idna.Lookup.ToASCII(hostname) + if err != nil { + s.log.Warnf("Hostname %s is not valid: %w, from %s", hostname, err, info.Conn.RemoteAddr()) + hostname = s.DefaultDomain + } else { + hostname = h + } + + cert := s.createSelfSignedCert(hostname) + + return cert, nil +} diff --git a/internal/websecure/store.go b/internal/websecure/store.go new file mode 100644 index 0000000..d6d835e --- /dev/null +++ b/internal/websecure/store.go @@ -0,0 +1,126 @@ +package websecure + +import ( + "crypto/tls" + "fmt" + "os" + "path" + "strings" + "sync" + + "github.com/pion/logging" +) + +type CertStore struct { + certificates map[string]*tls.Certificate + certLock *sync.Mutex + + storePath string + + log logging.LeveledLogger +} + +func NewCertStore(storePath string) *CertStore { + return &CertStore{ + certificates: make(map[string]*tls.Certificate), + certLock: &sync.Mutex{}, + + storePath: storePath, + log: defaultLogger, + } +} + +func (s *CertStore) ensureStorePath() error { + // check if directory exists + stat, err := os.Stat(s.storePath) + if err == nil { + if stat.IsDir() { + return nil + } + + return fmt.Errorf("TLS store path exists but is not a directory: %s", s.storePath) + } + + if os.IsNotExist(err) { + s.log.Tracef("TLS store directory does not exist, creating directory") + err = os.MkdirAll(s.storePath, 0755) + if err != nil { + return fmt.Errorf("Failed to create TLS store path: %w", err) + } + return nil + } + + return fmt.Errorf("Failed to check TLS store path: %w", err) +} + +func (s *CertStore) LoadCertificates() { + err := s.ensureStorePath() + if err != nil { + s.log.Errorf(err.Error()) + return + } + + files, err := os.ReadDir(s.storePath) + if err != nil { + s.log.Errorf("Failed to read TLS directory: %v", err) + return + } + + for _, file := range files { + if file.IsDir() { + continue + } + + if strings.HasSuffix(file.Name(), ".crt") { + s.loadCertificate(strings.TrimSuffix(file.Name(), ".crt")) + } + } +} + +func (s *CertStore) loadCertificate(hostname string) { + s.certLock.Lock() + defer s.certLock.Unlock() + + keyFile := path.Join(s.storePath, hostname+".key") + crtFile := path.Join(s.storePath, hostname+".crt") + + cert, err := tls.LoadX509KeyPair(crtFile, keyFile) + if err != nil { + s.log.Errorf("Failed to load certificate for %s: %w", hostname, err) + return + } + + s.certificates[hostname] = &cert + + s.log.Infof("Loaded certificate for %s", hostname) +} + +func (s *CertStore) saveCertificate(hostname string) { + // check if certificate already exists + tlsCert := s.certificates[hostname] + if tlsCert == nil { + s.log.Errorf("Certificate for %s does not exist, skipping saving certificate", hostname) + return + } + + err := s.ensureStorePath() + if err != nil { + s.log.Errorf(err.Error()) + return + } + + keyFile := path.Join(s.storePath, hostname+".key") + crtFile := path.Join(s.storePath, hostname+".crt") + + if keyToFile(tlsCert, keyFile); err != nil { + s.log.Errorf(err.Error()) + return + } + + if certToFile(tlsCert, crtFile); err != nil { + s.log.Errorf(err.Error()) + return + } + + s.log.Infof("Saved certificate for %s", hostname) +} diff --git a/internal/websecure/utils.go b/internal/websecure/utils.go new file mode 100644 index 0000000..de29c73 --- /dev/null +++ b/internal/websecure/utils.go @@ -0,0 +1,80 @@ +package websecure + +import ( + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "math/big" + "os" +) + +var serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 4096) + +func withSecretFile(filename string, f func(*os.File) error) error { + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer file.Close() + + return f(file) +} + +func keyToFile(cert *tls.Certificate, filename string) error { + var keyBlock pem.Block + switch k := cert.PrivateKey.(type) { + case *rsa.PrivateKey: + keyBlock = pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(k), + } + case *ecdsa.PrivateKey: + b, e := x509.MarshalECPrivateKey(k) + if e != nil { + return fmt.Errorf("Failed to marshal EC private key: %v", e) + } + + keyBlock = pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: b, + } + default: + return fmt.Errorf("Unknown private key type: %T", k) + } + + err := withSecretFile(filename, func(file *os.File) error { + return pem.Encode(file, &keyBlock) + }) + + if err != nil { + return fmt.Errorf("Failed to save private key: %w", err) + } + + return nil +} + +func certToFile(cert *tls.Certificate, filename string) error { + return withSecretFile(filename, func(file *os.File) error { + for _, c := range cert.Certificate { + block := pem.Block{ + Type: "CERTIFICATE", + Bytes: c, + } + + err := pem.Encode(file, &block) + if err != nil { + return fmt.Errorf("Failed to save certificate: %w", err) + } + } + + return nil + }) +} + +func generateSerialNumber() (*big.Int, error) { + return rand.Int(rand.Reader, serialNumberLimit) +} diff --git a/main.go b/main.go index 98748c7..48ee8cb 100644 --- a/main.go +++ b/main.go @@ -70,6 +70,7 @@ func Main() { //go RunFuseServer() go RunWebServer() if config.TLSMode != "" { + initCertStore() go RunWebSecureServer() } // As websocket client already checks if the cloud token is set, we can start it here. diff --git a/web_tls.go b/web_tls.go index 1ef4d31..dbbad70 100644 --- a/web_tls.go +++ b/web_tls.go @@ -1,56 +1,51 @@ package kvm import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "math/big" - "net" "net/http" - "os" - "strings" - "sync" - "time" + + "github.com/jetkvm/kvm/internal/websecure" ) const ( - WebSecureListen = ":443" - WebSecureSelfSignedDefaultDomain = "jetkvm.local" - WebSecureSelfSignedDuration = 365 * 24 * time.Hour + tlsStorePath = "/userdata/jetkvm/tls" + webSecureListen = ":443" + webSecureSelfSignedDefaultDomain = "jetkvm.local" + webSecureSelfSignedCAName = "JetKVM Self-Signed CA" + webSecureSelfSignedOrganization = "JetKVM" + webSecureSelfSignedOU = "JetKVM Self-Signed" ) var ( - tlsCerts = make(map[string]*tls.Certificate) - tlsCertLock = &sync.Mutex{} + certStore *websecure.CertStore + certSigner *websecure.SelfSigner ) +func initCertStore() { + certStore = websecure.NewCertStore(tlsStorePath) + certStore.LoadCertificates() + + certSigner = websecure.NewSelfSigner( + certStore, + logger, + webSecureSelfSignedDefaultDomain, + webSecureSelfSignedOrganization, + webSecureSelfSignedOU, + webSecureSelfSignedCAName, + ) +} + // RunWebSecureServer runs a web server with TLS. func RunWebSecureServer() { r := setupRouter() server := &http.Server{ - Addr: WebSecureListen, + Addr: webSecureListen, Handler: r, TLSConfig: &tls.Config{ - // TODO: cache certificate in persistent storage - GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - var hostname string - if info.ServerName != "" { - hostname = info.ServerName - } else { - hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0] - } - - logger.Info().Str("hostname", hostname).Interface("SupportedProtos", info.SupportedProtos).Msg("TLS handshake") - - cert := createSelfSignedCert(hostname) - - return cert, nil - }, + MaxVersion: tls.VersionTLS13, + CurvePreferences: []tls.CurveID{}, + GetCertificate: certSigner.GetCertificate, }, } logger.Info().Str("listen", WebSecureListen).Msg("Starting websecure server") @@ -59,74 +54,3 @@ func RunWebSecureServer() { panic(err) } } - -func createSelfSignedCert(hostname string) *tls.Certificate { - if tlsCert := tlsCerts[hostname]; tlsCert != nil { - return tlsCert - } - tlsCertLock.Lock() - defer tlsCertLock.Unlock() - - logger.Info().Str("hostname", hostname).Msg("Creating self-signed certificate") - - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - logger.Warn().Err(err).Msg("Failed to generate private key") - os.Exit(1) - } - keyUsage := x509.KeyUsageDigitalSignature - - notBefore := time.Now() - notAfter := notBefore.AddDate(1, 0, 0) - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - logger.Warn().Err(err).Msg("Failed to generate serial number") - } - - dnsName := hostname - ip := net.ParseIP(hostname) - if ip != nil { - dnsName = WebSecureSelfSignedDefaultDomain - } - - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - CommonName: hostname, - Organization: []string{"JetKVM"}, - }, - NotBefore: notBefore, - NotAfter: notAfter, - - KeyUsage: keyUsage, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - - DNSNames: []string{dnsName}, - IPAddresses: []net.IP{}, - } - - if ip != nil { - template.IPAddresses = append(template.IPAddresses, ip) - } - - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - logger.Warn().Err(err).Msg("Failed to create certificate") - } - - cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - if cert == nil { - logger.Warn().Msg("Failed to encode certificate") - } - - tlsCert := &tls.Certificate{ - Certificate: [][]byte{derBytes}, - PrivateKey: priv, - } - tlsCerts[hostname] = tlsCert - - return tlsCert -} From 34067c13c72ac5f4faa4c1885db36d550d88683e Mon Sep 17 00:00:00 2001 From: Siyuan Miao Date: Mon, 7 Apr 2025 13:23:16 +0200 Subject: [PATCH 2/9] feat(tls): rewrite tls feature --- Makefile | 3 + config.go | 2 +- internal/websecure/selfsign.go | 10 +- internal/websecure/store.go | 45 +++++ jsonrpc.go | 39 +++- main.go | 1 + .../devices.$id.settings.access._index.tsx | 177 +++++++++++++++--- web_tls.go | 72 ++++++- 8 files changed, 322 insertions(+), 27 deletions(-) diff --git a/Makefile b/Makefile index 2aefdea..5ef7804 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,9 @@ VERSION := 0.3.8 PROMETHEUS_TAG := github.com/prometheus/common/version KVM_PKG_NAME := github.com/jetkvm/kvm +PROMETHEUS_TAG := github.com/prometheus/common/version +KVM_PKG_NAME := github.com/jetkvm/kvm + GO_LDFLAGS := \ -s -w \ -X $(PROMETHEUS_TAG).Branch=$(BRANCH) \ diff --git a/config.go b/config.go index c38f1ed..cc98fb9 100644 --- a/config.go +++ b/config.go @@ -90,7 +90,7 @@ type Config struct { DisplayMaxBrightness int `json:"display_max_brightness"` DisplayDimAfterSec int `json:"display_dim_after_sec"` DisplayOffAfterSec int `json:"display_off_after_sec"` - TLSMode string `json:"tls_mode"` + TLSMode string `json:"tls_mode"` // options: "self-signed", "user-defined", "" UsbConfig *usbgadget.Config `json:"usb_config"` UsbDevices *usbgadget.Devices `json:"usb_devices"` } diff --git a/internal/websecure/selfsign.go b/internal/websecure/selfsign.go index 04844ae..f93d719 100644 --- a/internal/websecure/selfsign.go +++ b/internal/websecure/selfsign.go @@ -29,7 +29,14 @@ type SelfSigner struct { DefaultOU string } -func NewSelfSigner(store *CertStore, log logging.LeveledLogger, defaultDomain, defaultOrg, defaultOU, caName string) *SelfSigner { +func NewSelfSigner( + store *CertStore, + log logging.LeveledLogger, + defaultDomain, + defaultOrg, + defaultOU, + caName string, +) *SelfSigner { return &SelfSigner{ store: store, log: log, @@ -177,6 +184,5 @@ func (s *SelfSigner) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate } cert := s.createSelfSignedCert(hostname) - return cert, nil } diff --git a/internal/websecure/store.go b/internal/websecure/store.go index d6d835e..3049c17 100644 --- a/internal/websecure/store.go +++ b/internal/websecure/store.go @@ -95,6 +95,51 @@ func (s *CertStore) loadCertificate(hostname string) { s.log.Infof("Loaded certificate for %s", hostname) } +// GetCertificate returns the certificate for the given hostname +// returns nil if the certificate is not found +func (s *CertStore) GetCertificate(hostname string) *tls.Certificate { + s.certLock.Lock() + defer s.certLock.Unlock() + + return s.certificates[hostname] +} + +// ValidateAndSaveCertificate validates the certificate and saves it to the store +// returns are: +// - error: if the certificate is invalid or if there's any error during saving the certificate +// - error: if there's any warning or error during saving the certificate +func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key string, ignoreWarning bool) (error, error) { + tlsCert, err := tls.X509KeyPair([]byte(cert), []byte(key)) + if err != nil { + return fmt.Errorf("Failed to parse certificate: %w", err), nil + } + + // this can be skipped as current implementation supports one custom certificate only + if tlsCert.Leaf != nil { + // add recover to avoid panic + defer func() { + if r := recover(); r != nil { + s.log.Errorf("Failed to verify hostname: %v", r) + } + }() + + if err = tlsCert.Leaf.VerifyHostname(hostname); err != nil { + if !ignoreWarning { + return nil, fmt.Errorf("Certificate does not match hostname: %w", err) + } + s.log.Warnf("Certificate does not match hostname: %v", err) + } + } + + s.certLock.Lock() + s.certificates[hostname] = &tlsCert + s.certLock.Unlock() + + s.saveCertificate(hostname) + + return nil, nil +} + func (s *CertStore) saveCertificate(hostname string) { // check if certificate already exists tlsCert := s.certificates[hostname] diff --git a/jsonrpc.go b/jsonrpc.go index de29e08..4b72b5d 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -95,7 +95,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - //logger.Infof("Received RPC request: Method=%s, Params=%v, ID=%d", request.Method, request.Params, request.ID) + logger.Tracef("Received RPC request: Method=%s, Params=%v, ID=%w", request.Method, request.Params, request.ID) handler, ok := rpcHandlers[request.Method] if !ok { errorResponse := JSONRPCResponse{ @@ -110,6 +110,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } + logger.Tracef("Calling RPC handler: %s, ID=%w", request.Method, request.ID) result, err := callRPCHandler(handler, request.Params) if err != nil { errorResponse := JSONRPCResponse{ @@ -125,6 +126,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } + logger.Tracef("RPC handler returned: %v, ID=%w", result, request.ID) response := JSONRPCResponse{ JSONRPC: "2.0", Result: result, @@ -141,6 +143,30 @@ func rpcGetDeviceID() (string, error) { return GetDeviceID(), nil } +func rpcReboot(force bool) error { + logger.Info("Got reboot request from JSONRPC, rebooting...") + + args := []string{} + if force { + args = append(args, "-f") + } + + cmd := exec.Command("reboot", args...) + err := cmd.Start() + if err != nil { + logger.Errorf("failed to reboot: %v", err) + return fmt.Errorf("failed to reboot: %w", err) + } + + // If the reboot command is successful, exit the program after 5 seconds + go func() { + time.Sleep(5 * time.Second) + os.Exit(0) + }() + + return nil +} + var streamFactor = 1.0 func rpcGetStreamQualityFactor() (float64, error) { @@ -375,6 +401,14 @@ func rpcSetSSHKeyState(sshKey string) error { return nil } +func rpcGetTLSState() TLSState { + return getTLSState() +} + +func rpcSetTLSState(tlsState TLSState) error { + return setTLSState(tlsState) +} + func callRPCHandler(handler RPCHandler, params map[string]interface{}) (interface{}, error) { handlerValue := reflect.ValueOf(handler.Func) handlerType := handlerValue.Type() @@ -892,6 +926,7 @@ func setKeyboardMacros(params KeyboardMacrosParams) (interface{}, error) { var rpcHandlers = map[string]RPCHandler{ "ping": {Func: rpcPing}, + "reboot": {Func: rpcReboot, Params: []string{"force"}}, "getDeviceID": {Func: rpcGetDeviceID}, "deregisterDevice": {Func: rpcDeregisterDevice}, "getCloudState": {Func: rpcGetCloudState}, @@ -920,6 +955,8 @@ var rpcHandlers = map[string]RPCHandler{ "setDevModeState": {Func: rpcSetDevModeState, Params: []string{"enabled"}}, "getSSHKeyState": {Func: rpcGetSSHKeyState}, "setSSHKeyState": {Func: rpcSetSSHKeyState, Params: []string{"sshKey"}}, + "getTLSState": {Func: rpcGetTLSState}, + "setTLSState": {Func: rpcSetTLSState, Params: []string{"state"}}, "setMassStorageMode": {Func: rpcSetMassStorageMode, Params: []string{"mode"}}, "getMassStorageMode": {Func: rpcGetMassStorageMode}, "isUpdatePending": {Func: rpcIsUpdatePending}, diff --git a/main.go b/main.go index 48ee8cb..62db7f6 100644 --- a/main.go +++ b/main.go @@ -69,6 +69,7 @@ func Main() { }() //go RunFuseServer() go RunWebServer() + if config.TLSMode != "" { initCertStore() go RunWebSecureServer() diff --git a/ui/src/routes/devices.$id.settings.access._index.tsx b/ui/src/routes/devices.$id.settings.access._index.tsx index 0ed5862..8c944b7 100644 --- a/ui/src/routes/devices.$id.settings.access._index.tsx +++ b/ui/src/routes/devices.$id.settings.access._index.tsx @@ -18,6 +18,13 @@ import { isOnDevice } from "@/main"; import { LocalDevice } from "./devices.$id"; import { SettingsItem } from "./devices.$id.settings"; import { CloudState } from "./adopt"; +import { TextAreaWithLabel } from "@components/TextArea"; + +export interface TLSState { + mode: "selfsigned" | "custom" | "disabled"; + certificate?: string; + privateKey?: string; +}; export const loader = async () => { if (isOnDevice) { @@ -44,6 +51,9 @@ export default function SettingsAccessIndexRoute() { // Use a simple string identifier for the selected provider const [selectedProvider, setSelectedProvider] = useState("jetkvm"); + const [tlsMode, setTlsMode] = useState("self-signed"); + const [tlsCert, setTlsCert] = useState(""); + const [tlsKey, setTlsKey] = useState(""); const getCloudState = useCallback(() => { send("getCloudState", {}, resp => { @@ -66,6 +76,17 @@ export default function SettingsAccessIndexRoute() { }); }, [send]); + const getTLSState = useCallback(() => { + send("getTLSState", {}, resp => { + if ("error" in resp) return console.error(resp.error); + const tlsState = resp.result as TLSState; + + setTlsMode(tlsState.mode); + if (tlsState.certificate) setTlsCert(tlsState.certificate); + if (tlsState.privateKey) setTlsKey(tlsState.privateKey); + }); + }, [send]); + const deregisterDevice = async () => { send("deregisterDevice", {}, resp => { if ("error" in resp) { @@ -126,15 +147,51 @@ export default function SettingsAccessIndexRoute() { } }; + // Handle TLS mode change + const handleTlsModeChange = (value: string) => { + setTlsMode(value); + }; + + const handleTlsCertChange = (value: string) => { + setTlsCert(value); + }; + + const handleTlsKeyChange = (value: string) => { + setTlsKey(value); + }; + + const handleTlsUpdate = useCallback(() => { + send("setTLSState", { state: { mode: tlsMode, certificate: tlsCert, privateKey: tlsKey } as TLSState }, resp => { + if ("error" in resp) { + notifications.error(`Failed to update TLS settings: ${resp.error.data || "Unknown error"}`); + return; + } + + notifications.success("TLS settings updated successfully"); + }); + }, [send, tlsMode, tlsCert, tlsKey]); + + const handleReboot = useCallback(() => { + send("reboot", { force: false }, resp => { + if ("error" in resp) { + notifications.error(`Failed to reboot: ${resp.error.data || "Unknown error"}`); + return; + } + + notifications.success("Device will restart shortly, it might take a few seconds to boot up again."); + }); + }, [send]); + // Fetch device ID and cloud state on component mount useEffect(() => { getCloudState(); + getTLSState(); send("getDeviceID", {}, async resp => { if ("error" in resp) return console.error(resp.error); setDeviceId(resp.result as string); }); - }, [send, getCloudState]); + }, [send, getCloudState, getTLSState]); return (
@@ -150,30 +207,106 @@ export default function SettingsAccessIndexRoute() { title="Local" description="Manage the mode of local access to the device" /> - - {loaderData.authMode === "password" ? ( -
+ + + {tlsMode === "custom" && ( +
+
+ +
+ handleTlsCertChange(e.target.value)} + /> +
+ +
+
+ handleTlsKeyChange(e.target.value)} + /> +

+ Private key won't be shown again after saving. +

+
+
+
+
)} - + + {loaderData.authMode === "password" ? ( +