feat(tls): store tls certificates

This commit is contained in:
Siyuan Miao 2025-03-18 17:25:03 +01:00
parent 934bb687cc
commit c939a02f33
6 changed files with 422 additions and 104 deletions

View File

@ -0,0 +1,5 @@
package websecure
import "github.com/pion/logging"
var defaultLogger = logging.NewDefaultLoggerFactory().NewLogger("websecure")

View File

@ -0,0 +1,182 @@
package websecure
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"log"
"net"
"strings"
"time"
"github.com/pion/logging"
"golang.org/x/net/idna"
)
const selfSignerCAMagicName = "__ca__"
type SelfSigner struct {
store *CertStore
log logging.LeveledLogger
caInfo pkix.Name
DefaultDomain string
DefaultOrg string
DefaultOU string
}
func NewSelfSigner(store *CertStore, log logging.LeveledLogger, defaultDomain, defaultOrg, defaultOU, caName string) *SelfSigner {
return &SelfSigner{
store: store,
log: log,
DefaultDomain: defaultDomain,
DefaultOrg: defaultOrg,
DefaultOU: defaultOU,
caInfo: pkix.Name{
CommonName: caName,
Organization: []string{defaultOrg},
OrganizationalUnit: []string{defaultOU},
},
}
}
func (s *SelfSigner) getCA() *tls.Certificate {
return s.createSelfSignedCert(selfSignerCAMagicName)
}
func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate {
if tlsCert := s.store.certificates[hostname]; tlsCert != nil {
return tlsCert
}
// check if hostname is the CA magic name
var ca *tls.Certificate
if hostname != selfSignerCAMagicName {
ca = s.getCA()
if ca == nil {
s.log.Errorf("Failed to get CA certificate")
return nil
}
}
s.log.Infof("Creating self-signed certificate for %s", hostname)
// lock the store while creating the certificate (do not move upwards)
s.store.certLock.Lock()
defer s.store.certLock.Unlock()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
log.Fatalf("Failed to generate private key: %v", err)
}
notBefore := time.Now()
notAfter := notBefore.AddDate(1, 0, 0)
serialNumber, err := generateSerialNumber()
if err != nil {
s.log.Errorf("Failed to generate serial number: %v", err)
}
dnsName := hostname
ip := net.ParseIP(hostname)
if ip != nil {
dnsName = s.DefaultDomain
}
// set up CSR
isCA := hostname == selfSignerCAMagicName
subject := pkix.Name{
CommonName: hostname,
Organization: []string{s.DefaultOrg},
OrganizationalUnit: []string{s.DefaultOU},
}
keyUsage := x509.KeyUsageDigitalSignature
extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
// check if hostname is the CA magic name, and if so, set the subject to the CA info
if isCA {
subject = s.caInfo
keyUsage |= x509.KeyUsageCertSign
extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageClientAuth)
notAfter = notBefore.AddDate(10, 0, 0)
}
cert := x509.Certificate{
SerialNumber: serialNumber,
Subject: subject,
NotBefore: notBefore,
NotAfter: notAfter,
IsCA: isCA,
KeyUsage: keyUsage,
ExtKeyUsage: extKeyUsage,
BasicConstraintsValid: true,
}
// set up DNS names and IP addresses
if !isCA {
cert.DNSNames = []string{dnsName}
if ip != nil {
cert.IPAddresses = []net.IP{ip}
}
}
// set up parent certificate
parent := &cert
parentPriv := priv
if ca != nil {
parent, err = x509.ParseCertificate(ca.Certificate[0])
if err != nil {
s.log.Errorf("Failed to parse parent certificate: %v", err)
return nil
}
parentPriv = ca.PrivateKey.(*ecdsa.PrivateKey)
}
certBytes, err := x509.CreateCertificate(rand.Reader, &cert, parent, &priv.PublicKey, parentPriv)
if err != nil {
s.log.Errorf("Failed to create certificate: %v", err)
return nil
}
tlsCert := &tls.Certificate{
Certificate: [][]byte{certBytes},
PrivateKey: priv,
}
if ca != nil {
tlsCert.Certificate = append(tlsCert.Certificate, ca.Certificate...)
}
s.store.certificates[hostname] = tlsCert
s.store.saveCertificate(hostname)
return tlsCert
}
func (s *SelfSigner) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
hostname := s.DefaultDomain
if info.ServerName != "" && info.ServerName != selfSignerCAMagicName {
hostname = info.ServerName
} else {
hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0]
}
s.log.Infof("TLS handshake for %s, SupportedProtos: %v", hostname, info.SupportedProtos)
// convert hostname to punycode
h, err := idna.Lookup.ToASCII(hostname)
if err != nil {
s.log.Warnf("Hostname %s is not valid: %w, from %s", hostname, err, info.Conn.RemoteAddr())
hostname = s.DefaultDomain
} else {
hostname = h
}
cert := s.createSelfSignedCert(hostname)
return cert, nil
}

126
internal/websecure/store.go Normal file
View File

@ -0,0 +1,126 @@
package websecure
import (
"crypto/tls"
"fmt"
"os"
"path"
"strings"
"sync"
"github.com/pion/logging"
)
type CertStore struct {
certificates map[string]*tls.Certificate
certLock *sync.Mutex
storePath string
log logging.LeveledLogger
}
func NewCertStore(storePath string) *CertStore {
return &CertStore{
certificates: make(map[string]*tls.Certificate),
certLock: &sync.Mutex{},
storePath: storePath,
log: defaultLogger,
}
}
func (s *CertStore) ensureStorePath() error {
// check if directory exists
stat, err := os.Stat(s.storePath)
if err == nil {
if stat.IsDir() {
return nil
}
return fmt.Errorf("TLS store path exists but is not a directory: %s", s.storePath)
}
if os.IsNotExist(err) {
s.log.Tracef("TLS store directory does not exist, creating directory")
err = os.MkdirAll(s.storePath, 0755)
if err != nil {
return fmt.Errorf("Failed to create TLS store path: %w", err)
}
return nil
}
return fmt.Errorf("Failed to check TLS store path: %w", err)
}
func (s *CertStore) LoadCertificates() {
err := s.ensureStorePath()
if err != nil {
s.log.Errorf(err.Error())
return
}
files, err := os.ReadDir(s.storePath)
if err != nil {
s.log.Errorf("Failed to read TLS directory: %v", err)
return
}
for _, file := range files {
if file.IsDir() {
continue
}
if strings.HasSuffix(file.Name(), ".crt") {
s.loadCertificate(strings.TrimSuffix(file.Name(), ".crt"))
}
}
}
func (s *CertStore) loadCertificate(hostname string) {
s.certLock.Lock()
defer s.certLock.Unlock()
keyFile := path.Join(s.storePath, hostname+".key")
crtFile := path.Join(s.storePath, hostname+".crt")
cert, err := tls.LoadX509KeyPair(crtFile, keyFile)
if err != nil {
s.log.Errorf("Failed to load certificate for %s: %w", hostname, err)
return
}
s.certificates[hostname] = &cert
s.log.Infof("Loaded certificate for %s", hostname)
}
func (s *CertStore) saveCertificate(hostname string) {
// check if certificate already exists
tlsCert := s.certificates[hostname]
if tlsCert == nil {
s.log.Errorf("Certificate for %s does not exist, skipping saving certificate", hostname)
return
}
err := s.ensureStorePath()
if err != nil {
s.log.Errorf(err.Error())
return
}
keyFile := path.Join(s.storePath, hostname+".key")
crtFile := path.Join(s.storePath, hostname+".crt")
if keyToFile(tlsCert, keyFile); err != nil {
s.log.Errorf(err.Error())
return
}
if certToFile(tlsCert, crtFile); err != nil {
s.log.Errorf(err.Error())
return
}
s.log.Infof("Saved certificate for %s", hostname)
}

View File

@ -0,0 +1,80 @@
package websecure
import (
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"os"
)
var serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 4096)
func withSecretFile(filename string, f func(*os.File) error) error {
file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer file.Close()
return f(file)
}
func keyToFile(cert *tls.Certificate, filename string) error {
var keyBlock pem.Block
switch k := cert.PrivateKey.(type) {
case *rsa.PrivateKey:
keyBlock = pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(k),
}
case *ecdsa.PrivateKey:
b, e := x509.MarshalECPrivateKey(k)
if e != nil {
return fmt.Errorf("Failed to marshal EC private key: %v", e)
}
keyBlock = pem.Block{
Type: "EC PRIVATE KEY",
Bytes: b,
}
default:
return fmt.Errorf("Unknown private key type: %T", k)
}
err := withSecretFile(filename, func(file *os.File) error {
return pem.Encode(file, &keyBlock)
})
if err != nil {
return fmt.Errorf("Failed to save private key: %w", err)
}
return nil
}
func certToFile(cert *tls.Certificate, filename string) error {
return withSecretFile(filename, func(file *os.File) error {
for _, c := range cert.Certificate {
block := pem.Block{
Type: "CERTIFICATE",
Bytes: c,
}
err := pem.Encode(file, &block)
if err != nil {
return fmt.Errorf("Failed to save certificate: %w", err)
}
}
return nil
})
}
func generateSerialNumber() (*big.Int, error) {
return rand.Int(rand.Reader, serialNumberLimit)
}

View File

@ -70,6 +70,7 @@ func Main() {
//go RunFuseServer()
go RunWebServer()
if config.TLSMode != "" {
initCertStore()
go RunWebSecureServer()
}
// As websocket client already checks if the cloud token is set, we can start it here.

View File

@ -1,56 +1,51 @@
package kvm
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/jetkvm/kvm/internal/websecure"
)
const (
WebSecureListen = ":443"
WebSecureSelfSignedDefaultDomain = "jetkvm.local"
WebSecureSelfSignedDuration = 365 * 24 * time.Hour
tlsStorePath = "/userdata/jetkvm/tls"
webSecureListen = ":443"
webSecureSelfSignedDefaultDomain = "jetkvm.local"
webSecureSelfSignedCAName = "JetKVM Self-Signed CA"
webSecureSelfSignedOrganization = "JetKVM"
webSecureSelfSignedOU = "JetKVM Self-Signed"
)
var (
tlsCerts = make(map[string]*tls.Certificate)
tlsCertLock = &sync.Mutex{}
certStore *websecure.CertStore
certSigner *websecure.SelfSigner
)
func initCertStore() {
certStore = websecure.NewCertStore(tlsStorePath)
certStore.LoadCertificates()
certSigner = websecure.NewSelfSigner(
certStore,
logger,
webSecureSelfSignedDefaultDomain,
webSecureSelfSignedOrganization,
webSecureSelfSignedOU,
webSecureSelfSignedCAName,
)
}
// RunWebSecureServer runs a web server with TLS.
func RunWebSecureServer() {
r := setupRouter()
server := &http.Server{
Addr: WebSecureListen,
Addr: webSecureListen,
Handler: r,
TLSConfig: &tls.Config{
// TODO: cache certificate in persistent storage
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
var hostname string
if info.ServerName != "" {
hostname = info.ServerName
} else {
hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0]
}
logger.Info().Str("hostname", hostname).Interface("SupportedProtos", info.SupportedProtos).Msg("TLS handshake")
cert := createSelfSignedCert(hostname)
return cert, nil
},
MaxVersion: tls.VersionTLS13,
CurvePreferences: []tls.CurveID{},
GetCertificate: certSigner.GetCertificate,
},
}
logger.Info().Str("listen", WebSecureListen).Msg("Starting websecure server")
@ -59,74 +54,3 @@ func RunWebSecureServer() {
panic(err)
}
}
func createSelfSignedCert(hostname string) *tls.Certificate {
if tlsCert := tlsCerts[hostname]; tlsCert != nil {
return tlsCert
}
tlsCertLock.Lock()
defer tlsCertLock.Unlock()
logger.Info().Str("hostname", hostname).Msg("Creating self-signed certificate")
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
logger.Warn().Err(err).Msg("Failed to generate private key")
os.Exit(1)
}
keyUsage := x509.KeyUsageDigitalSignature
notBefore := time.Now()
notAfter := notBefore.AddDate(1, 0, 0)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
logger.Warn().Err(err).Msg("Failed to generate serial number")
}
dnsName := hostname
ip := net.ParseIP(hostname)
if ip != nil {
dnsName = WebSecureSelfSignedDefaultDomain
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: hostname,
Organization: []string{"JetKVM"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: keyUsage,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{dnsName},
IPAddresses: []net.IP{},
}
if ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
logger.Warn().Err(err).Msg("Failed to create certificate")
}
cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
if cert == nil {
logger.Warn().Msg("Failed to encode certificate")
}
tlsCert := &tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: priv,
}
tlsCerts[hostname] = tlsCert
return tlsCert
}