diff --git a/internal/websecure/log.go b/internal/websecure/log.go new file mode 100644 index 0000000..7643ac1 --- /dev/null +++ b/internal/websecure/log.go @@ -0,0 +1,5 @@ +package websecure + +import "github.com/pion/logging" + +var defaultLogger = logging.NewDefaultLoggerFactory().NewLogger("websecure") diff --git a/internal/websecure/selfsign.go b/internal/websecure/selfsign.go new file mode 100644 index 0000000..04844ae --- /dev/null +++ b/internal/websecure/selfsign.go @@ -0,0 +1,182 @@ +package websecure + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "log" + "net" + "strings" + "time" + + "github.com/pion/logging" + "golang.org/x/net/idna" +) + +const selfSignerCAMagicName = "__ca__" + +type SelfSigner struct { + store *CertStore + log logging.LeveledLogger + + caInfo pkix.Name + + DefaultDomain string + DefaultOrg string + DefaultOU string +} + +func NewSelfSigner(store *CertStore, log logging.LeveledLogger, defaultDomain, defaultOrg, defaultOU, caName string) *SelfSigner { + return &SelfSigner{ + store: store, + log: log, + DefaultDomain: defaultDomain, + DefaultOrg: defaultOrg, + DefaultOU: defaultOU, + caInfo: pkix.Name{ + CommonName: caName, + Organization: []string{defaultOrg}, + OrganizationalUnit: []string{defaultOU}, + }, + } +} + +func (s *SelfSigner) getCA() *tls.Certificate { + return s.createSelfSignedCert(selfSignerCAMagicName) +} + +func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate { + if tlsCert := s.store.certificates[hostname]; tlsCert != nil { + return tlsCert + } + + // check if hostname is the CA magic name + var ca *tls.Certificate + if hostname != selfSignerCAMagicName { + ca = s.getCA() + if ca == nil { + s.log.Errorf("Failed to get CA certificate") + return nil + } + } + + s.log.Infof("Creating self-signed certificate for %s", hostname) + + // lock the store while creating the certificate (do not move upwards) + s.store.certLock.Lock() + defer s.store.certLock.Unlock() + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + log.Fatalf("Failed to generate private key: %v", err) + } + + notBefore := time.Now() + notAfter := notBefore.AddDate(1, 0, 0) + + serialNumber, err := generateSerialNumber() + if err != nil { + s.log.Errorf("Failed to generate serial number: %v", err) + } + + dnsName := hostname + ip := net.ParseIP(hostname) + if ip != nil { + dnsName = s.DefaultDomain + } + + // set up CSR + isCA := hostname == selfSignerCAMagicName + subject := pkix.Name{ + CommonName: hostname, + Organization: []string{s.DefaultOrg}, + OrganizationalUnit: []string{s.DefaultOU}, + } + keyUsage := x509.KeyUsageDigitalSignature + extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + + // check if hostname is the CA magic name, and if so, set the subject to the CA info + if isCA { + subject = s.caInfo + keyUsage |= x509.KeyUsageCertSign + extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageClientAuth) + notAfter = notBefore.AddDate(10, 0, 0) + } + + cert := x509.Certificate{ + SerialNumber: serialNumber, + Subject: subject, + NotBefore: notBefore, + NotAfter: notAfter, + IsCA: isCA, + KeyUsage: keyUsage, + ExtKeyUsage: extKeyUsage, + BasicConstraintsValid: true, + } + + // set up DNS names and IP addresses + if !isCA { + cert.DNSNames = []string{dnsName} + if ip != nil { + cert.IPAddresses = []net.IP{ip} + } + } + + // set up parent certificate + parent := &cert + parentPriv := priv + if ca != nil { + parent, err = x509.ParseCertificate(ca.Certificate[0]) + if err != nil { + s.log.Errorf("Failed to parse parent certificate: %v", err) + return nil + } + parentPriv = ca.PrivateKey.(*ecdsa.PrivateKey) + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &cert, parent, &priv.PublicKey, parentPriv) + if err != nil { + s.log.Errorf("Failed to create certificate: %v", err) + return nil + } + + tlsCert := &tls.Certificate{ + Certificate: [][]byte{certBytes}, + PrivateKey: priv, + } + if ca != nil { + tlsCert.Certificate = append(tlsCert.Certificate, ca.Certificate...) + } + + s.store.certificates[hostname] = tlsCert + s.store.saveCertificate(hostname) + + return tlsCert +} + +func (s *SelfSigner) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + hostname := s.DefaultDomain + if info.ServerName != "" && info.ServerName != selfSignerCAMagicName { + hostname = info.ServerName + } else { + hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0] + } + + s.log.Infof("TLS handshake for %s, SupportedProtos: %v", hostname, info.SupportedProtos) + + // convert hostname to punycode + h, err := idna.Lookup.ToASCII(hostname) + if err != nil { + s.log.Warnf("Hostname %s is not valid: %w, from %s", hostname, err, info.Conn.RemoteAddr()) + hostname = s.DefaultDomain + } else { + hostname = h + } + + cert := s.createSelfSignedCert(hostname) + + return cert, nil +} diff --git a/internal/websecure/store.go b/internal/websecure/store.go new file mode 100644 index 0000000..d6d835e --- /dev/null +++ b/internal/websecure/store.go @@ -0,0 +1,126 @@ +package websecure + +import ( + "crypto/tls" + "fmt" + "os" + "path" + "strings" + "sync" + + "github.com/pion/logging" +) + +type CertStore struct { + certificates map[string]*tls.Certificate + certLock *sync.Mutex + + storePath string + + log logging.LeveledLogger +} + +func NewCertStore(storePath string) *CertStore { + return &CertStore{ + certificates: make(map[string]*tls.Certificate), + certLock: &sync.Mutex{}, + + storePath: storePath, + log: defaultLogger, + } +} + +func (s *CertStore) ensureStorePath() error { + // check if directory exists + stat, err := os.Stat(s.storePath) + if err == nil { + if stat.IsDir() { + return nil + } + + return fmt.Errorf("TLS store path exists but is not a directory: %s", s.storePath) + } + + if os.IsNotExist(err) { + s.log.Tracef("TLS store directory does not exist, creating directory") + err = os.MkdirAll(s.storePath, 0755) + if err != nil { + return fmt.Errorf("Failed to create TLS store path: %w", err) + } + return nil + } + + return fmt.Errorf("Failed to check TLS store path: %w", err) +} + +func (s *CertStore) LoadCertificates() { + err := s.ensureStorePath() + if err != nil { + s.log.Errorf(err.Error()) + return + } + + files, err := os.ReadDir(s.storePath) + if err != nil { + s.log.Errorf("Failed to read TLS directory: %v", err) + return + } + + for _, file := range files { + if file.IsDir() { + continue + } + + if strings.HasSuffix(file.Name(), ".crt") { + s.loadCertificate(strings.TrimSuffix(file.Name(), ".crt")) + } + } +} + +func (s *CertStore) loadCertificate(hostname string) { + s.certLock.Lock() + defer s.certLock.Unlock() + + keyFile := path.Join(s.storePath, hostname+".key") + crtFile := path.Join(s.storePath, hostname+".crt") + + cert, err := tls.LoadX509KeyPair(crtFile, keyFile) + if err != nil { + s.log.Errorf("Failed to load certificate for %s: %w", hostname, err) + return + } + + s.certificates[hostname] = &cert + + s.log.Infof("Loaded certificate for %s", hostname) +} + +func (s *CertStore) saveCertificate(hostname string) { + // check if certificate already exists + tlsCert := s.certificates[hostname] + if tlsCert == nil { + s.log.Errorf("Certificate for %s does not exist, skipping saving certificate", hostname) + return + } + + err := s.ensureStorePath() + if err != nil { + s.log.Errorf(err.Error()) + return + } + + keyFile := path.Join(s.storePath, hostname+".key") + crtFile := path.Join(s.storePath, hostname+".crt") + + if keyToFile(tlsCert, keyFile); err != nil { + s.log.Errorf(err.Error()) + return + } + + if certToFile(tlsCert, crtFile); err != nil { + s.log.Errorf(err.Error()) + return + } + + s.log.Infof("Saved certificate for %s", hostname) +} diff --git a/internal/websecure/utils.go b/internal/websecure/utils.go new file mode 100644 index 0000000..de29c73 --- /dev/null +++ b/internal/websecure/utils.go @@ -0,0 +1,80 @@ +package websecure + +import ( + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "math/big" + "os" +) + +var serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 4096) + +func withSecretFile(filename string, f func(*os.File) error) error { + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer file.Close() + + return f(file) +} + +func keyToFile(cert *tls.Certificate, filename string) error { + var keyBlock pem.Block + switch k := cert.PrivateKey.(type) { + case *rsa.PrivateKey: + keyBlock = pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(k), + } + case *ecdsa.PrivateKey: + b, e := x509.MarshalECPrivateKey(k) + if e != nil { + return fmt.Errorf("Failed to marshal EC private key: %v", e) + } + + keyBlock = pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: b, + } + default: + return fmt.Errorf("Unknown private key type: %T", k) + } + + err := withSecretFile(filename, func(file *os.File) error { + return pem.Encode(file, &keyBlock) + }) + + if err != nil { + return fmt.Errorf("Failed to save private key: %w", err) + } + + return nil +} + +func certToFile(cert *tls.Certificate, filename string) error { + return withSecretFile(filename, func(file *os.File) error { + for _, c := range cert.Certificate { + block := pem.Block{ + Type: "CERTIFICATE", + Bytes: c, + } + + err := pem.Encode(file, &block) + if err != nil { + return fmt.Errorf("Failed to save certificate: %w", err) + } + } + + return nil + }) +} + +func generateSerialNumber() (*big.Int, error) { + return rand.Int(rand.Reader, serialNumberLimit) +} diff --git a/main.go b/main.go index 98748c7..48ee8cb 100644 --- a/main.go +++ b/main.go @@ -70,6 +70,7 @@ func Main() { //go RunFuseServer() go RunWebServer() if config.TLSMode != "" { + initCertStore() go RunWebSecureServer() } // As websocket client already checks if the cloud token is set, we can start it here. diff --git a/web_tls.go b/web_tls.go index 1ef4d31..dbbad70 100644 --- a/web_tls.go +++ b/web_tls.go @@ -1,56 +1,51 @@ package kvm import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "math/big" - "net" "net/http" - "os" - "strings" - "sync" - "time" + + "github.com/jetkvm/kvm/internal/websecure" ) const ( - WebSecureListen = ":443" - WebSecureSelfSignedDefaultDomain = "jetkvm.local" - WebSecureSelfSignedDuration = 365 * 24 * time.Hour + tlsStorePath = "/userdata/jetkvm/tls" + webSecureListen = ":443" + webSecureSelfSignedDefaultDomain = "jetkvm.local" + webSecureSelfSignedCAName = "JetKVM Self-Signed CA" + webSecureSelfSignedOrganization = "JetKVM" + webSecureSelfSignedOU = "JetKVM Self-Signed" ) var ( - tlsCerts = make(map[string]*tls.Certificate) - tlsCertLock = &sync.Mutex{} + certStore *websecure.CertStore + certSigner *websecure.SelfSigner ) +func initCertStore() { + certStore = websecure.NewCertStore(tlsStorePath) + certStore.LoadCertificates() + + certSigner = websecure.NewSelfSigner( + certStore, + logger, + webSecureSelfSignedDefaultDomain, + webSecureSelfSignedOrganization, + webSecureSelfSignedOU, + webSecureSelfSignedCAName, + ) +} + // RunWebSecureServer runs a web server with TLS. func RunWebSecureServer() { r := setupRouter() server := &http.Server{ - Addr: WebSecureListen, + Addr: webSecureListen, Handler: r, TLSConfig: &tls.Config{ - // TODO: cache certificate in persistent storage - GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - var hostname string - if info.ServerName != "" { - hostname = info.ServerName - } else { - hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0] - } - - logger.Info().Str("hostname", hostname).Interface("SupportedProtos", info.SupportedProtos).Msg("TLS handshake") - - cert := createSelfSignedCert(hostname) - - return cert, nil - }, + MaxVersion: tls.VersionTLS13, + CurvePreferences: []tls.CurveID{}, + GetCertificate: certSigner.GetCertificate, }, } logger.Info().Str("listen", WebSecureListen).Msg("Starting websecure server") @@ -59,74 +54,3 @@ func RunWebSecureServer() { panic(err) } } - -func createSelfSignedCert(hostname string) *tls.Certificate { - if tlsCert := tlsCerts[hostname]; tlsCert != nil { - return tlsCert - } - tlsCertLock.Lock() - defer tlsCertLock.Unlock() - - logger.Info().Str("hostname", hostname).Msg("Creating self-signed certificate") - - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - logger.Warn().Err(err).Msg("Failed to generate private key") - os.Exit(1) - } - keyUsage := x509.KeyUsageDigitalSignature - - notBefore := time.Now() - notAfter := notBefore.AddDate(1, 0, 0) - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - logger.Warn().Err(err).Msg("Failed to generate serial number") - } - - dnsName := hostname - ip := net.ParseIP(hostname) - if ip != nil { - dnsName = WebSecureSelfSignedDefaultDomain - } - - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - CommonName: hostname, - Organization: []string{"JetKVM"}, - }, - NotBefore: notBefore, - NotAfter: notAfter, - - KeyUsage: keyUsage, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - - DNSNames: []string{dnsName}, - IPAddresses: []net.IP{}, - } - - if ip != nil { - template.IPAddresses = append(template.IPAddresses, ip) - } - - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - logger.Warn().Err(err).Msg("Failed to create certificate") - } - - cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - if cert == nil { - logger.Warn().Msg("Failed to encode certificate") - } - - tlsCert := &tls.Certificate{ - Certificate: [][]byte{derBytes}, - PrivateKey: priv, - } - tlsCerts[hostname] = tlsCert - - return tlsCert -}