From c939a02f33fa60e11fdfa3673e53bb5608faa16f Mon Sep 17 00:00:00 2001
From: Siyuan Miao <i@xswan.net>
Date: Tue, 18 Mar 2025 17:25:03 +0100
Subject: [PATCH] feat(tls): store tls certificates

---
 internal/websecure/log.go      |   5 +
 internal/websecure/selfsign.go | 182 +++++++++++++++++++++++++++++++++
 internal/websecure/store.go    | 126 +++++++++++++++++++++++
 internal/websecure/utils.go    |  80 +++++++++++++++
 main.go                        |   1 +
 web_tls.go                     | 132 +++++-------------------
 6 files changed, 422 insertions(+), 104 deletions(-)
 create mode 100644 internal/websecure/log.go
 create mode 100644 internal/websecure/selfsign.go
 create mode 100644 internal/websecure/store.go
 create mode 100644 internal/websecure/utils.go

diff --git a/internal/websecure/log.go b/internal/websecure/log.go
new file mode 100644
index 0000000..7643ac1
--- /dev/null
+++ b/internal/websecure/log.go
@@ -0,0 +1,5 @@
+package websecure
+
+import "github.com/pion/logging"
+
+var defaultLogger = logging.NewDefaultLoggerFactory().NewLogger("websecure")
diff --git a/internal/websecure/selfsign.go b/internal/websecure/selfsign.go
new file mode 100644
index 0000000..04844ae
--- /dev/null
+++ b/internal/websecure/selfsign.go
@@ -0,0 +1,182 @@
+package websecure
+
+import (
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"log"
+	"net"
+	"strings"
+	"time"
+
+	"github.com/pion/logging"
+	"golang.org/x/net/idna"
+)
+
+const selfSignerCAMagicName = "__ca__"
+
+type SelfSigner struct {
+	store *CertStore
+	log   logging.LeveledLogger
+
+	caInfo pkix.Name
+
+	DefaultDomain string
+	DefaultOrg    string
+	DefaultOU     string
+}
+
+func NewSelfSigner(store *CertStore, log logging.LeveledLogger, defaultDomain, defaultOrg, defaultOU, caName string) *SelfSigner {
+	return &SelfSigner{
+		store:         store,
+		log:           log,
+		DefaultDomain: defaultDomain,
+		DefaultOrg:    defaultOrg,
+		DefaultOU:     defaultOU,
+		caInfo: pkix.Name{
+			CommonName:         caName,
+			Organization:       []string{defaultOrg},
+			OrganizationalUnit: []string{defaultOU},
+		},
+	}
+}
+
+func (s *SelfSigner) getCA() *tls.Certificate {
+	return s.createSelfSignedCert(selfSignerCAMagicName)
+}
+
+func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate {
+	if tlsCert := s.store.certificates[hostname]; tlsCert != nil {
+		return tlsCert
+	}
+
+	// check if hostname is the CA magic name
+	var ca *tls.Certificate
+	if hostname != selfSignerCAMagicName {
+		ca = s.getCA()
+		if ca == nil {
+			s.log.Errorf("Failed to get CA certificate")
+			return nil
+		}
+	}
+
+	s.log.Infof("Creating self-signed certificate for %s", hostname)
+
+	// lock the store while creating the certificate (do not move upwards)
+	s.store.certLock.Lock()
+	defer s.store.certLock.Unlock()
+
+	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	if err != nil {
+		log.Fatalf("Failed to generate private key: %v", err)
+	}
+
+	notBefore := time.Now()
+	notAfter := notBefore.AddDate(1, 0, 0)
+
+	serialNumber, err := generateSerialNumber()
+	if err != nil {
+		s.log.Errorf("Failed to generate serial number: %v", err)
+	}
+
+	dnsName := hostname
+	ip := net.ParseIP(hostname)
+	if ip != nil {
+		dnsName = s.DefaultDomain
+	}
+
+	// set up CSR
+	isCA := hostname == selfSignerCAMagicName
+	subject := pkix.Name{
+		CommonName:         hostname,
+		Organization:       []string{s.DefaultOrg},
+		OrganizationalUnit: []string{s.DefaultOU},
+	}
+	keyUsage := x509.KeyUsageDigitalSignature
+	extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
+
+	// check if hostname is the CA magic name, and if so, set the subject to the CA info
+	if isCA {
+		subject = s.caInfo
+		keyUsage |= x509.KeyUsageCertSign
+		extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageClientAuth)
+		notAfter = notBefore.AddDate(10, 0, 0)
+	}
+
+	cert := x509.Certificate{
+		SerialNumber:          serialNumber,
+		Subject:               subject,
+		NotBefore:             notBefore,
+		NotAfter:              notAfter,
+		IsCA:                  isCA,
+		KeyUsage:              keyUsage,
+		ExtKeyUsage:           extKeyUsage,
+		BasicConstraintsValid: true,
+	}
+
+	// set up DNS names and IP addresses
+	if !isCA {
+		cert.DNSNames = []string{dnsName}
+		if ip != nil {
+			cert.IPAddresses = []net.IP{ip}
+		}
+	}
+
+	// set up parent certificate
+	parent := &cert
+	parentPriv := priv
+	if ca != nil {
+		parent, err = x509.ParseCertificate(ca.Certificate[0])
+		if err != nil {
+			s.log.Errorf("Failed to parse parent certificate: %v", err)
+			return nil
+		}
+		parentPriv = ca.PrivateKey.(*ecdsa.PrivateKey)
+	}
+
+	certBytes, err := x509.CreateCertificate(rand.Reader, &cert, parent, &priv.PublicKey, parentPriv)
+	if err != nil {
+		s.log.Errorf("Failed to create certificate: %v", err)
+		return nil
+	}
+
+	tlsCert := &tls.Certificate{
+		Certificate: [][]byte{certBytes},
+		PrivateKey:  priv,
+	}
+	if ca != nil {
+		tlsCert.Certificate = append(tlsCert.Certificate, ca.Certificate...)
+	}
+
+	s.store.certificates[hostname] = tlsCert
+	s.store.saveCertificate(hostname)
+
+	return tlsCert
+}
+
+func (s *SelfSigner) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
+	hostname := s.DefaultDomain
+	if info.ServerName != "" && info.ServerName != selfSignerCAMagicName {
+		hostname = info.ServerName
+	} else {
+		hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0]
+	}
+
+	s.log.Infof("TLS handshake for %s, SupportedProtos: %v", hostname, info.SupportedProtos)
+
+	// convert hostname to punycode
+	h, err := idna.Lookup.ToASCII(hostname)
+	if err != nil {
+		s.log.Warnf("Hostname %s is not valid: %w, from %s", hostname, err, info.Conn.RemoteAddr())
+		hostname = s.DefaultDomain
+	} else {
+		hostname = h
+	}
+
+	cert := s.createSelfSignedCert(hostname)
+
+	return cert, nil
+}
diff --git a/internal/websecure/store.go b/internal/websecure/store.go
new file mode 100644
index 0000000..d6d835e
--- /dev/null
+++ b/internal/websecure/store.go
@@ -0,0 +1,126 @@
+package websecure
+
+import (
+	"crypto/tls"
+	"fmt"
+	"os"
+	"path"
+	"strings"
+	"sync"
+
+	"github.com/pion/logging"
+)
+
+type CertStore struct {
+	certificates map[string]*tls.Certificate
+	certLock     *sync.Mutex
+
+	storePath string
+
+	log logging.LeveledLogger
+}
+
+func NewCertStore(storePath string) *CertStore {
+	return &CertStore{
+		certificates: make(map[string]*tls.Certificate),
+		certLock:     &sync.Mutex{},
+
+		storePath: storePath,
+		log:       defaultLogger,
+	}
+}
+
+func (s *CertStore) ensureStorePath() error {
+	// check if directory exists
+	stat, err := os.Stat(s.storePath)
+	if err == nil {
+		if stat.IsDir() {
+			return nil
+		}
+
+		return fmt.Errorf("TLS store path exists but is not a directory: %s", s.storePath)
+	}
+
+	if os.IsNotExist(err) {
+		s.log.Tracef("TLS store directory does not exist, creating directory")
+		err = os.MkdirAll(s.storePath, 0755)
+		if err != nil {
+			return fmt.Errorf("Failed to create TLS store path: %w", err)
+		}
+		return nil
+	}
+
+	return fmt.Errorf("Failed to check TLS store path: %w", err)
+}
+
+func (s *CertStore) LoadCertificates() {
+	err := s.ensureStorePath()
+	if err != nil {
+		s.log.Errorf(err.Error())
+		return
+	}
+
+	files, err := os.ReadDir(s.storePath)
+	if err != nil {
+		s.log.Errorf("Failed to read TLS directory: %v", err)
+		return
+	}
+
+	for _, file := range files {
+		if file.IsDir() {
+			continue
+		}
+
+		if strings.HasSuffix(file.Name(), ".crt") {
+			s.loadCertificate(strings.TrimSuffix(file.Name(), ".crt"))
+		}
+	}
+}
+
+func (s *CertStore) loadCertificate(hostname string) {
+	s.certLock.Lock()
+	defer s.certLock.Unlock()
+
+	keyFile := path.Join(s.storePath, hostname+".key")
+	crtFile := path.Join(s.storePath, hostname+".crt")
+
+	cert, err := tls.LoadX509KeyPair(crtFile, keyFile)
+	if err != nil {
+		s.log.Errorf("Failed to load certificate for %s: %w", hostname, err)
+		return
+	}
+
+	s.certificates[hostname] = &cert
+
+	s.log.Infof("Loaded certificate for %s", hostname)
+}
+
+func (s *CertStore) saveCertificate(hostname string) {
+	// check if certificate already exists
+	tlsCert := s.certificates[hostname]
+	if tlsCert == nil {
+		s.log.Errorf("Certificate for %s does not exist, skipping saving certificate", hostname)
+		return
+	}
+
+	err := s.ensureStorePath()
+	if err != nil {
+		s.log.Errorf(err.Error())
+		return
+	}
+
+	keyFile := path.Join(s.storePath, hostname+".key")
+	crtFile := path.Join(s.storePath, hostname+".crt")
+
+	if keyToFile(tlsCert, keyFile); err != nil {
+		s.log.Errorf(err.Error())
+		return
+	}
+
+	if certToFile(tlsCert, crtFile); err != nil {
+		s.log.Errorf(err.Error())
+		return
+	}
+
+	s.log.Infof("Saved certificate for %s", hostname)
+}
diff --git a/internal/websecure/utils.go b/internal/websecure/utils.go
new file mode 100644
index 0000000..de29c73
--- /dev/null
+++ b/internal/websecure/utils.go
@@ -0,0 +1,80 @@
+package websecure
+
+import (
+	"crypto/ecdsa"
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/tls"
+	"crypto/x509"
+	"encoding/pem"
+	"fmt"
+	"math/big"
+	"os"
+)
+
+var serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 4096)
+
+func withSecretFile(filename string, f func(*os.File) error) error {
+	file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
+	if err != nil {
+		return err
+	}
+	defer file.Close()
+
+	return f(file)
+}
+
+func keyToFile(cert *tls.Certificate, filename string) error {
+	var keyBlock pem.Block
+	switch k := cert.PrivateKey.(type) {
+	case *rsa.PrivateKey:
+		keyBlock = pem.Block{
+			Type:  "RSA PRIVATE KEY",
+			Bytes: x509.MarshalPKCS1PrivateKey(k),
+		}
+	case *ecdsa.PrivateKey:
+		b, e := x509.MarshalECPrivateKey(k)
+		if e != nil {
+			return fmt.Errorf("Failed to marshal EC private key: %v", e)
+		}
+
+		keyBlock = pem.Block{
+			Type:  "EC PRIVATE KEY",
+			Bytes: b,
+		}
+	default:
+		return fmt.Errorf("Unknown private key type: %T", k)
+	}
+
+	err := withSecretFile(filename, func(file *os.File) error {
+		return pem.Encode(file, &keyBlock)
+	})
+
+	if err != nil {
+		return fmt.Errorf("Failed to save private key: %w", err)
+	}
+
+	return nil
+}
+
+func certToFile(cert *tls.Certificate, filename string) error {
+	return withSecretFile(filename, func(file *os.File) error {
+		for _, c := range cert.Certificate {
+			block := pem.Block{
+				Type:  "CERTIFICATE",
+				Bytes: c,
+			}
+
+			err := pem.Encode(file, &block)
+			if err != nil {
+				return fmt.Errorf("Failed to save certificate: %w", err)
+			}
+		}
+
+		return nil
+	})
+}
+
+func generateSerialNumber() (*big.Int, error) {
+	return rand.Int(rand.Reader, serialNumberLimit)
+}
diff --git a/main.go b/main.go
index 98748c7..48ee8cb 100644
--- a/main.go
+++ b/main.go
@@ -70,6 +70,7 @@ func Main() {
 	//go RunFuseServer()
 	go RunWebServer()
 	if config.TLSMode != "" {
+		initCertStore()
 		go RunWebSecureServer()
 	}
 	// As websocket client already checks if the cloud token is set, we can start it here.
diff --git a/web_tls.go b/web_tls.go
index 1ef4d31..dbbad70 100644
--- a/web_tls.go
+++ b/web_tls.go
@@ -1,56 +1,51 @@
 package kvm
 
 import (
-	"crypto/ecdsa"
-	"crypto/elliptic"
-	"crypto/rand"
 	"crypto/tls"
-	"crypto/x509"
-	"crypto/x509/pkix"
-	"encoding/pem"
-	"math/big"
-	"net"
 	"net/http"
-	"os"
-	"strings"
-	"sync"
-	"time"
+
+	"github.com/jetkvm/kvm/internal/websecure"
 )
 
 const (
-	WebSecureListen                  = ":443"
-	WebSecureSelfSignedDefaultDomain = "jetkvm.local"
-	WebSecureSelfSignedDuration      = 365 * 24 * time.Hour
+	tlsStorePath                     = "/userdata/jetkvm/tls"
+	webSecureListen                  = ":443"
+	webSecureSelfSignedDefaultDomain = "jetkvm.local"
+	webSecureSelfSignedCAName        = "JetKVM Self-Signed CA"
+	webSecureSelfSignedOrganization  = "JetKVM"
+	webSecureSelfSignedOU            = "JetKVM Self-Signed"
 )
 
 var (
-	tlsCerts    = make(map[string]*tls.Certificate)
-	tlsCertLock = &sync.Mutex{}
+	certStore  *websecure.CertStore
+	certSigner *websecure.SelfSigner
 )
 
+func initCertStore() {
+	certStore = websecure.NewCertStore(tlsStorePath)
+	certStore.LoadCertificates()
+
+	certSigner = websecure.NewSelfSigner(
+		certStore,
+		logger,
+		webSecureSelfSignedDefaultDomain,
+		webSecureSelfSignedOrganization,
+		webSecureSelfSignedOU,
+		webSecureSelfSignedCAName,
+	)
+}
+
 // RunWebSecureServer runs a web server with TLS.
 func RunWebSecureServer() {
 	r := setupRouter()
 
 	server := &http.Server{
-		Addr:    WebSecureListen,
+		Addr:    webSecureListen,
 		Handler: r,
 		TLSConfig: &tls.Config{
-			// TODO: cache certificate in persistent storage
-			GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
-				var hostname string
-				if info.ServerName != "" {
-					hostname = info.ServerName
-				} else {
-					hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0]
-				}
-
-				logger.Info().Str("hostname", hostname).Interface("SupportedProtos", info.SupportedProtos).Msg("TLS handshake")
-
-				cert := createSelfSignedCert(hostname)
-
-				return cert, nil
-			},
+			MaxVersion:       tls.VersionTLS13,
+			CurvePreferences: []tls.CurveID{},
+			GetCertificate:   certSigner.GetCertificate,
 		},
 	}
 	logger.Info().Str("listen", WebSecureListen).Msg("Starting websecure server")
@@ -59,74 +54,3 @@ func RunWebSecureServer() {
 		panic(err)
 	}
 }
-
-func createSelfSignedCert(hostname string) *tls.Certificate {
-	if tlsCert := tlsCerts[hostname]; tlsCert != nil {
-		return tlsCert
-	}
-	tlsCertLock.Lock()
-	defer tlsCertLock.Unlock()
-
-	logger.Info().Str("hostname", hostname).Msg("Creating self-signed certificate")
-
-	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-	if err != nil {
-		logger.Warn().Err(err).Msg("Failed to generate private key")
-		os.Exit(1)
-	}
-	keyUsage := x509.KeyUsageDigitalSignature
-
-	notBefore := time.Now()
-	notAfter := notBefore.AddDate(1, 0, 0)
-
-	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
-	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
-	if err != nil {
-		logger.Warn().Err(err).Msg("Failed to generate serial number")
-	}
-
-	dnsName := hostname
-	ip := net.ParseIP(hostname)
-	if ip != nil {
-		dnsName = WebSecureSelfSignedDefaultDomain
-	}
-
-	template := x509.Certificate{
-		SerialNumber: serialNumber,
-		Subject: pkix.Name{
-			CommonName:   hostname,
-			Organization: []string{"JetKVM"},
-		},
-		NotBefore: notBefore,
-		NotAfter:  notAfter,
-
-		KeyUsage:              keyUsage,
-		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
-		BasicConstraintsValid: true,
-
-		DNSNames:    []string{dnsName},
-		IPAddresses: []net.IP{},
-	}
-
-	if ip != nil {
-		template.IPAddresses = append(template.IPAddresses, ip)
-	}
-
-	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
-	if err != nil {
-		logger.Warn().Err(err).Msg("Failed to create certificate")
-	}
-
-	cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
-	if cert == nil {
-		logger.Warn().Msg("Failed to encode certificate")
-	}
-
-	tlsCert := &tls.Certificate{
-		Certificate: [][]byte{derBytes},
-		PrivateKey:  priv,
-	}
-	tlsCerts[hostname] = tlsCert
-
-	return tlsCert
-}