From cbd3dbba503759b5c128371b607596bdc690ca2c Mon Sep 17 00:00:00 2001 From: Siyuan Miao Date: Tue, 25 Feb 2025 16:00:49 +0100 Subject: [PATCH] feat(tls): add a simple tls support --- main.go | 1 + web_tls.go | 133 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 web_tls.go diff --git a/main.go b/main.go index e23e9c8..332ee3b 100644 --- a/main.go +++ b/main.go @@ -66,6 +66,7 @@ func Main() { }() //go RunFuseServer() go RunWebServer() + go RunWebSecureServer() // If the cloud token isn't set, the client won't be started by default. // However, if the user adopts the device via the web interface, handleCloudRegister will start the client. if config.CloudToken != "" { diff --git a/web_tls.go b/web_tls.go new file mode 100644 index 0000000..8ab6b5e --- /dev/null +++ b/web_tls.go @@ -0,0 +1,133 @@ +package kvm + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "log" + "math/big" + "net" + "net/http" + "sync" + "time" +) + +const ( + WebSecureListen = ":443" + WebSecureSelfSignedDefaultDomain = "jetkvm.local" + WebSecureSelfSignedDuration = 365 * 24 * time.Hour +) + +var ( + tlsCerts = make(map[string]*tls.Certificate) + tlsCertLock = &sync.Mutex{} +) + +// RunWebSecureServer runs a web server with TLS. +func RunWebSecureServer() { + r := setupRouter() + + server := &http.Server{ + Addr: WebSecureListen, + Handler: r, + TLSConfig: &tls.Config{ + // TODO: cache certificate in persistent storage + // TODO: use net.Conn to get server IP when SNI is not available (e.g. Browser won't send SNI for IP address) + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + hostname := WebSecureSelfSignedDefaultDomain + if info.ServerName != "" { + hostname = info.ServerName + } + + logger.Infof("TLS handshake for %s, SupportedProtos: %v", hostname, info.SupportedProtos) + + cert := createSelfSignedCert(hostname) + + return cert, nil + }, + }, + } + logger.Infof("Starting websecure server on %s", RunWebSecureServer) + err := server.ListenAndServeTLS("", "") + if err != nil { + panic(err) + } + return +} + +func createSelfSignedCert(hostname string) *tls.Certificate { + if tlsCert := tlsCerts[hostname]; tlsCert != nil { + return tlsCert + } + tlsCertLock.Lock() + defer tlsCertLock.Unlock() + + logger.Infof("Creating self-signed certificate for %s", hostname) + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + log.Fatalf("Failed to generate private key: %v", err) + } + keyUsage := x509.KeyUsageDigitalSignature + + notBefore := time.Now() + notAfter := notBefore.AddDate(1, 0, 0) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + log.Fatalf("Failed to generate serial number: %v", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: hostname, + Organization: []string{"JetKVM"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: keyUsage, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + + DNSNames: []string{hostname}, + } + + ip := net.ParseIP(hostname) + if ip != nil { + template.IPAddresses = []net.IP{ip} + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + log.Fatalf("Failed to create certificate: %v", err) + } + + cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + if cert == nil { + log.Fatalf("Failed to encode certificate") + } + + // privBytes := x509.MarshalECPrivateKey(priv) + // if privBytes == nil { + // log.Fatalf("Failed to marshal private key") + // } + // key := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}) + // if key == nil { + // log.Fatalf("Failed to encode private key") + // } + + tlsCert := &tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: priv, + } + tlsCerts[hostname] = tlsCert + + return tlsCert +}