mirror of https://github.com/jetkvm/kvm.git
feat(tls): #330
This commit is contained in:
parent
4c37f7e079
commit
82c018a2f6
|
@ -90,7 +90,7 @@ type Config struct {
|
||||||
DisplayMaxBrightness int `json:"display_max_brightness"`
|
DisplayMaxBrightness int `json:"display_max_brightness"`
|
||||||
DisplayDimAfterSec int `json:"display_dim_after_sec"`
|
DisplayDimAfterSec int `json:"display_dim_after_sec"`
|
||||||
DisplayOffAfterSec int `json:"display_off_after_sec"`
|
DisplayOffAfterSec int `json:"display_off_after_sec"`
|
||||||
TLSMode string `json:"tls_mode"`
|
TLSMode string `json:"tls_mode"` // options: "self-signed", "user-defined", ""
|
||||||
UsbConfig *usbgadget.Config `json:"usb_config"`
|
UsbConfig *usbgadget.Config `json:"usb_config"`
|
||||||
UsbDevices *usbgadget.Devices `json:"usb_devices"`
|
UsbDevices *usbgadget.Devices `json:"usb_devices"`
|
||||||
}
|
}
|
||||||
|
@ -169,6 +169,8 @@ func SaveConfig() error {
|
||||||
configLock.Lock()
|
configLock.Lock()
|
||||||
defer configLock.Unlock()
|
defer configLock.Unlock()
|
||||||
|
|
||||||
|
logger.Trace().Str("path", configPath).Msg("Saving config")
|
||||||
|
|
||||||
file, err := os.Create(configPath)
|
file, err := os.Create(configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create config file: %w", err)
|
return fmt.Errorf("failed to create config file: %w", err)
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
package websecure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
var defaultLogger = zerolog.New(os.Stdout).With().Str("component", "websecure").Logger()
|
|
@ -0,0 +1,191 @@
|
||||||
|
package websecure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"golang.org/x/net/idna"
|
||||||
|
)
|
||||||
|
|
||||||
|
const selfSignerCAMagicName = "__ca__"
|
||||||
|
|
||||||
|
type SelfSigner struct {
|
||||||
|
store *CertStore
|
||||||
|
log *zerolog.Logger
|
||||||
|
|
||||||
|
caInfo pkix.Name
|
||||||
|
|
||||||
|
DefaultDomain string
|
||||||
|
DefaultOrg string
|
||||||
|
DefaultOU string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSelfSigner(
|
||||||
|
store *CertStore,
|
||||||
|
log *zerolog.Logger,
|
||||||
|
defaultDomain,
|
||||||
|
defaultOrg,
|
||||||
|
defaultOU,
|
||||||
|
caName string,
|
||||||
|
) *SelfSigner {
|
||||||
|
return &SelfSigner{
|
||||||
|
store: store,
|
||||||
|
log: log,
|
||||||
|
DefaultDomain: defaultDomain,
|
||||||
|
DefaultOrg: defaultOrg,
|
||||||
|
DefaultOU: defaultOU,
|
||||||
|
caInfo: pkix.Name{
|
||||||
|
CommonName: caName,
|
||||||
|
Organization: []string{defaultOrg},
|
||||||
|
OrganizationalUnit: []string{defaultOU},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SelfSigner) getCA() *tls.Certificate {
|
||||||
|
return s.createSelfSignedCert(selfSignerCAMagicName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate {
|
||||||
|
if tlsCert := s.store.certificates[hostname]; tlsCert != nil {
|
||||||
|
return tlsCert
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if hostname is the CA magic name
|
||||||
|
var ca *tls.Certificate
|
||||||
|
if hostname != selfSignerCAMagicName {
|
||||||
|
ca = s.getCA()
|
||||||
|
if ca == nil {
|
||||||
|
s.log.Error().Msg("Failed to get CA certificate")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Info().Str("hostname", hostname).Msg("Creating self-signed certificate")
|
||||||
|
|
||||||
|
// lock the store while creating the certificate (do not move upwards)
|
||||||
|
s.store.certLock.Lock()
|
||||||
|
defer s.store.certLock.Unlock()
|
||||||
|
|
||||||
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error().Err(err).Msg("Failed to generate private key")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
notBefore := time.Now()
|
||||||
|
notAfter := notBefore.AddDate(1, 0, 0)
|
||||||
|
|
||||||
|
serialNumber, err := generateSerialNumber()
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error().Err(err).Msg("Failed to generate serial number")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsName := hostname
|
||||||
|
ip := net.ParseIP(hostname)
|
||||||
|
if ip != nil {
|
||||||
|
dnsName = s.DefaultDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
// set up CSR
|
||||||
|
isCA := hostname == selfSignerCAMagicName
|
||||||
|
subject := pkix.Name{
|
||||||
|
CommonName: hostname,
|
||||||
|
Organization: []string{s.DefaultOrg},
|
||||||
|
OrganizationalUnit: []string{s.DefaultOU},
|
||||||
|
}
|
||||||
|
keyUsage := x509.KeyUsageDigitalSignature
|
||||||
|
extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
|
||||||
|
|
||||||
|
// check if hostname is the CA magic name, and if so, set the subject to the CA info
|
||||||
|
if isCA {
|
||||||
|
subject = s.caInfo
|
||||||
|
keyUsage |= x509.KeyUsageCertSign
|
||||||
|
extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageClientAuth)
|
||||||
|
notAfter = notBefore.AddDate(10, 0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := x509.Certificate{
|
||||||
|
SerialNumber: serialNumber,
|
||||||
|
Subject: subject,
|
||||||
|
NotBefore: notBefore,
|
||||||
|
NotAfter: notAfter,
|
||||||
|
IsCA: isCA,
|
||||||
|
KeyUsage: keyUsage,
|
||||||
|
ExtKeyUsage: extKeyUsage,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// set up DNS names and IP addresses
|
||||||
|
if !isCA {
|
||||||
|
cert.DNSNames = []string{dnsName}
|
||||||
|
if ip != nil {
|
||||||
|
cert.IPAddresses = []net.IP{ip}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set up parent certificate
|
||||||
|
parent := &cert
|
||||||
|
parentPriv := priv
|
||||||
|
if ca != nil {
|
||||||
|
parent, err = x509.ParseCertificate(ca.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error().Err(err).Msg("Failed to parse parent certificate")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
parentPriv = ca.PrivateKey.(*ecdsa.PrivateKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
certBytes, err := x509.CreateCertificate(rand.Reader, &cert, parent, &priv.PublicKey, parentPriv)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error().Err(err).Msg("Failed to create certificate")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsCert := &tls.Certificate{
|
||||||
|
Certificate: [][]byte{certBytes},
|
||||||
|
PrivateKey: priv,
|
||||||
|
}
|
||||||
|
if ca != nil {
|
||||||
|
tlsCert.Certificate = append(tlsCert.Certificate, ca.Certificate...)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.store.certificates[hostname] = tlsCert
|
||||||
|
s.store.saveCertificate(hostname)
|
||||||
|
|
||||||
|
return tlsCert
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCertificate returns the certificate for the given hostname
|
||||||
|
// returns nil if the certificate is not found
|
||||||
|
func (s *SelfSigner) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
var hostname string
|
||||||
|
if info.ServerName != "" && info.ServerName != selfSignerCAMagicName {
|
||||||
|
hostname = info.ServerName
|
||||||
|
} else {
|
||||||
|
hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Info().Str("hostname", hostname).Strs("supported_protos", info.SupportedProtos).Msg("TLS handshake")
|
||||||
|
|
||||||
|
// convert hostname to punycode
|
||||||
|
h, err := idna.Lookup.ToASCII(hostname)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Warn().Str("hostname", hostname).Err(err).Str("remote_addr", info.Conn.RemoteAddr().String()).Msg("Hostname is not valid")
|
||||||
|
hostname = s.DefaultDomain
|
||||||
|
} else {
|
||||||
|
hostname = h
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := s.createSelfSignedCert(hostname)
|
||||||
|
return cert, nil
|
||||||
|
}
|
|
@ -0,0 +1,175 @@
|
||||||
|
package websecure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CertStore struct {
|
||||||
|
certificates map[string]*tls.Certificate
|
||||||
|
certLock *sync.Mutex
|
||||||
|
|
||||||
|
storePath string
|
||||||
|
|
||||||
|
log *zerolog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCertStore(storePath string, log *zerolog.Logger) *CertStore {
|
||||||
|
if log == nil {
|
||||||
|
log = &defaultLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CertStore{
|
||||||
|
certificates: make(map[string]*tls.Certificate),
|
||||||
|
certLock: &sync.Mutex{},
|
||||||
|
|
||||||
|
storePath: storePath,
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CertStore) ensureStorePath() error {
|
||||||
|
// check if directory exists
|
||||||
|
stat, err := os.Stat(s.storePath)
|
||||||
|
if err == nil {
|
||||||
|
if stat.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("TLS store path exists but is not a directory: %s", s.storePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
s.log.Trace().Str("path", s.storePath).Msg("TLS store directory does not exist, creating directory")
|
||||||
|
err = os.MkdirAll(s.storePath, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to create TLS store path: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("Failed to check TLS store path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CertStore) LoadCertificates() {
|
||||||
|
err := s.ensureStorePath()
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error().Err(err).Msg("Failed to ensure store path")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
files, err := os.ReadDir(s.storePath)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error().Err(err).Msg("Failed to read TLS directory")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
if file.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(file.Name(), ".crt") {
|
||||||
|
s.loadCertificate(strings.TrimSuffix(file.Name(), ".crt"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CertStore) loadCertificate(hostname string) {
|
||||||
|
s.certLock.Lock()
|
||||||
|
defer s.certLock.Unlock()
|
||||||
|
|
||||||
|
keyFile := path.Join(s.storePath, hostname+".key")
|
||||||
|
crtFile := path.Join(s.storePath, hostname+".crt")
|
||||||
|
|
||||||
|
cert, err := tls.LoadX509KeyPair(crtFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error().Err(err).Str("hostname", hostname).Msg("Failed to load certificate")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.certificates[hostname] = &cert
|
||||||
|
|
||||||
|
s.log.Info().Str("hostname", hostname).Msg("Loaded certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCertificate returns the certificate for the given hostname
|
||||||
|
// returns nil if the certificate is not found
|
||||||
|
func (s *CertStore) GetCertificate(hostname string) *tls.Certificate {
|
||||||
|
s.certLock.Lock()
|
||||||
|
defer s.certLock.Unlock()
|
||||||
|
|
||||||
|
return s.certificates[hostname]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateAndSaveCertificate validates the certificate and saves it to the store
|
||||||
|
// returns are:
|
||||||
|
// - error: if the certificate is invalid or if there's any error during saving the certificate
|
||||||
|
// - error: if there's any warning or error during saving the certificate
|
||||||
|
func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key string, ignoreWarning bool) (error, error) {
|
||||||
|
tlsCert, err := tls.X509KeyPair([]byte(cert), []byte(key))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to parse certificate: %w", err), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// this can be skipped as current implementation supports one custom certificate only
|
||||||
|
if tlsCert.Leaf != nil {
|
||||||
|
// add recover to avoid panic
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
s.log.Error().Interface("recovered", r).Msg("Failed to verify hostname")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = tlsCert.Leaf.VerifyHostname(hostname); err != nil {
|
||||||
|
if !ignoreWarning {
|
||||||
|
return nil, fmt.Errorf("Certificate does not match hostname: %w", err)
|
||||||
|
}
|
||||||
|
s.log.Warn().Err(err).Msg("Certificate does not match hostname")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.certLock.Lock()
|
||||||
|
s.certificates[hostname] = &tlsCert
|
||||||
|
s.certLock.Unlock()
|
||||||
|
|
||||||
|
s.saveCertificate(hostname)
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CertStore) saveCertificate(hostname string) {
|
||||||
|
// check if certificate already exists
|
||||||
|
tlsCert := s.certificates[hostname]
|
||||||
|
if tlsCert == nil {
|
||||||
|
s.log.Error().Str("hostname", hostname).Msg("Certificate for hostname does not exist, skipping saving certificate")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.ensureStorePath()
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error().Err(err).Msg("Failed to ensure store path")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
keyFile := path.Join(s.storePath, hostname+".key")
|
||||||
|
crtFile := path.Join(s.storePath, hostname+".crt")
|
||||||
|
|
||||||
|
if err := keyToFile(tlsCert, keyFile); err != nil {
|
||||||
|
s.log.Error().Err(err).Msg("Failed to save key file")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := certToFile(tlsCert, crtFile); err != nil {
|
||||||
|
s.log.Error().Err(err).Msg("Failed to save certificate")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Info().Str("hostname", hostname).Msg("Saved certificate")
|
||||||
|
}
|
|
@ -0,0 +1,80 @@
|
||||||
|
package websecure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
var serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 4096)
|
||||||
|
|
||||||
|
func withSecretFile(filename string, f func(*os.File) error) error {
|
||||||
|
file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
return f(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
func keyToFile(cert *tls.Certificate, filename string) error {
|
||||||
|
var keyBlock pem.Block
|
||||||
|
switch k := cert.PrivateKey.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
keyBlock = pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(k),
|
||||||
|
}
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
b, e := x509.MarshalECPrivateKey(k)
|
||||||
|
if e != nil {
|
||||||
|
return fmt.Errorf("Failed to marshal EC private key: %v", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
keyBlock = pem.Block{
|
||||||
|
Type: "EC PRIVATE KEY",
|
||||||
|
Bytes: b,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("Unknown private key type: %T", k)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := withSecretFile(filename, func(file *os.File) error {
|
||||||
|
return pem.Encode(file, &keyBlock)
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to save private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func certToFile(cert *tls.Certificate, filename string) error {
|
||||||
|
return withSecretFile(filename, func(file *os.File) error {
|
||||||
|
for _, c := range cert.Certificate {
|
||||||
|
block := pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: c,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := pem.Encode(file, &block)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to save certificate: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSerialNumber() (*big.Int, error) {
|
||||||
|
return rand.Int(rand.Reader, serialNumberLimit)
|
||||||
|
}
|
48
jsonrpc.go
48
jsonrpc.go
|
@ -95,7 +95,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//logger.Infof("Received RPC request: Method=%s, Params=%v, ID=%d", request.Method, request.Params, request.ID)
|
logger.Trace().Str("method", request.Method).Interface("params", request.Params).Interface("id", request.ID).Msg("Received RPC request")
|
||||||
handler, ok := rpcHandlers[request.Method]
|
handler, ok := rpcHandlers[request.Method]
|
||||||
if !ok {
|
if !ok {
|
||||||
errorResponse := JSONRPCResponse{
|
errorResponse := JSONRPCResponse{
|
||||||
|
@ -110,6 +110,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Trace().Str("method", request.Method).Interface("id", request.ID).Msg("Calling RPC handler")
|
||||||
result, err := callRPCHandler(handler, request.Params)
|
result, err := callRPCHandler(handler, request.Params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorResponse := JSONRPCResponse{
|
errorResponse := JSONRPCResponse{
|
||||||
|
@ -125,6 +126,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Trace().Interface("result", result).Interface("id", request.ID).Msg("RPC handler returned")
|
||||||
response := JSONRPCResponse{
|
response := JSONRPCResponse{
|
||||||
JSONRPC: "2.0",
|
JSONRPC: "2.0",
|
||||||
Result: result,
|
Result: result,
|
||||||
|
@ -141,6 +143,30 @@ func rpcGetDeviceID() (string, error) {
|
||||||
return GetDeviceID(), nil
|
return GetDeviceID(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rpcReboot(force bool) error {
|
||||||
|
logger.Info().Msg("Got reboot request from JSONRPC, rebooting...")
|
||||||
|
|
||||||
|
args := []string{}
|
||||||
|
if force {
|
||||||
|
args = append(args, "-f")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("reboot", args...)
|
||||||
|
err := cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error().Err(err).Msg("failed to reboot")
|
||||||
|
return fmt.Errorf("failed to reboot: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the reboot command is successful, exit the program after 5 seconds
|
||||||
|
go func() {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
os.Exit(0)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var streamFactor = 1.0
|
var streamFactor = 1.0
|
||||||
|
|
||||||
func rpcGetStreamQualityFactor() (float64, error) {
|
func rpcGetStreamQualityFactor() (float64, error) {
|
||||||
|
@ -375,6 +401,23 @@ func rpcSetSSHKeyState(sshKey string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rpcGetTLSState() TLSState {
|
||||||
|
return getTLSState()
|
||||||
|
}
|
||||||
|
|
||||||
|
func rpcSetTLSState(state TLSState) error {
|
||||||
|
err := setTLSState(state)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set TLS state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := SaveConfig(); err != nil {
|
||||||
|
return fmt.Errorf("failed to save config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func callRPCHandler(handler RPCHandler, params map[string]interface{}) (interface{}, error) {
|
func callRPCHandler(handler RPCHandler, params map[string]interface{}) (interface{}, error) {
|
||||||
handlerValue := reflect.ValueOf(handler.Func)
|
handlerValue := reflect.ValueOf(handler.Func)
|
||||||
handlerType := handlerValue.Type()
|
handlerType := handlerValue.Type()
|
||||||
|
@ -892,6 +935,7 @@ func setKeyboardMacros(params KeyboardMacrosParams) (interface{}, error) {
|
||||||
|
|
||||||
var rpcHandlers = map[string]RPCHandler{
|
var rpcHandlers = map[string]RPCHandler{
|
||||||
"ping": {Func: rpcPing},
|
"ping": {Func: rpcPing},
|
||||||
|
"reboot": {Func: rpcReboot, Params: []string{"force"}},
|
||||||
"getDeviceID": {Func: rpcGetDeviceID},
|
"getDeviceID": {Func: rpcGetDeviceID},
|
||||||
"deregisterDevice": {Func: rpcDeregisterDevice},
|
"deregisterDevice": {Func: rpcDeregisterDevice},
|
||||||
"getCloudState": {Func: rpcGetCloudState},
|
"getCloudState": {Func: rpcGetCloudState},
|
||||||
|
@ -920,6 +964,8 @@ var rpcHandlers = map[string]RPCHandler{
|
||||||
"setDevModeState": {Func: rpcSetDevModeState, Params: []string{"enabled"}},
|
"setDevModeState": {Func: rpcSetDevModeState, Params: []string{"enabled"}},
|
||||||
"getSSHKeyState": {Func: rpcGetSSHKeyState},
|
"getSSHKeyState": {Func: rpcGetSSHKeyState},
|
||||||
"setSSHKeyState": {Func: rpcSetSSHKeyState, Params: []string{"sshKey"}},
|
"setSSHKeyState": {Func: rpcSetSSHKeyState, Params: []string{"sshKey"}},
|
||||||
|
"getTLSState": {Func: rpcGetTLSState},
|
||||||
|
"setTLSState": {Func: rpcSetTLSState, Params: []string{"state"}},
|
||||||
"setMassStorageMode": {Func: rpcSetMassStorageMode, Params: []string{"mode"}},
|
"setMassStorageMode": {Func: rpcSetMassStorageMode, Params: []string{"mode"}},
|
||||||
"getMassStorageMode": {Func: rpcGetMassStorageMode},
|
"getMassStorageMode": {Func: rpcGetMassStorageMode},
|
||||||
"isUpdatePending": {Func: rpcIsUpdatePending},
|
"isUpdatePending": {Func: rpcIsUpdatePending},
|
||||||
|
|
1
log.go
1
log.go
|
@ -50,6 +50,7 @@ var (
|
||||||
displayLogger = getLogger("display")
|
displayLogger = getLogger("display")
|
||||||
usbLogger = getLogger("usb")
|
usbLogger = getLogger("usb")
|
||||||
ginLogger = getLogger("gin")
|
ginLogger = getLogger("gin")
|
||||||
|
websecureLogger = getLogger("websecure")
|
||||||
)
|
)
|
||||||
|
|
||||||
func updateLogLevel() {
|
func updateLogLevel() {
|
||||||
|
|
6
main.go
6
main.go
|
@ -69,9 +69,13 @@ func Main() {
|
||||||
}()
|
}()
|
||||||
//go RunFuseServer()
|
//go RunFuseServer()
|
||||||
go RunWebServer()
|
go RunWebServer()
|
||||||
|
|
||||||
|
go RunWebSecureServer()
|
||||||
|
// Web secure server is started only if TLS mode is enabled
|
||||||
if config.TLSMode != "" {
|
if config.TLSMode != "" {
|
||||||
go RunWebSecureServer()
|
startWebSecureServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
// As websocket client already checks if the cloud token is set, we can start it here.
|
// As websocket client already checks if the cloud token is set, we can start it here.
|
||||||
go RunWebsocketClient()
|
go RunWebsocketClient()
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,18 @@ import notifications from "@/notifications";
|
||||||
import { DEVICE_API } from "@/ui.config";
|
import { DEVICE_API } from "@/ui.config";
|
||||||
import { useJsonRpc } from "@/hooks/useJsonRpc";
|
import { useJsonRpc } from "@/hooks/useJsonRpc";
|
||||||
import { isOnDevice } from "@/main";
|
import { isOnDevice } from "@/main";
|
||||||
|
import { TextAreaWithLabel } from "@components/TextArea";
|
||||||
|
|
||||||
import { LocalDevice } from "./devices.$id";
|
import { LocalDevice } from "./devices.$id";
|
||||||
import { SettingsItem } from "./devices.$id.settings";
|
import { SettingsItem } from "./devices.$id.settings";
|
||||||
import { CloudState } from "./adopt";
|
import { CloudState } from "./adopt";
|
||||||
|
|
||||||
|
export interface TLSState {
|
||||||
|
mode: "self-signed" | "custom" | "disabled";
|
||||||
|
certificate?: string;
|
||||||
|
privateKey?: string;
|
||||||
|
}
|
||||||
|
|
||||||
export const loader = async () => {
|
export const loader = async () => {
|
||||||
if (isOnDevice) {
|
if (isOnDevice) {
|
||||||
const status = await api
|
const status = await api
|
||||||
|
@ -44,6 +51,9 @@ export default function SettingsAccessIndexRoute() {
|
||||||
|
|
||||||
// Use a simple string identifier for the selected provider
|
// Use a simple string identifier for the selected provider
|
||||||
const [selectedProvider, setSelectedProvider] = useState<string>("jetkvm");
|
const [selectedProvider, setSelectedProvider] = useState<string>("jetkvm");
|
||||||
|
const [tlsMode, setTlsMode] = useState<string>("unknown");
|
||||||
|
const [tlsCert, setTlsCert] = useState<string>("");
|
||||||
|
const [tlsKey, setTlsKey] = useState<string>("");
|
||||||
|
|
||||||
const getCloudState = useCallback(() => {
|
const getCloudState = useCallback(() => {
|
||||||
send("getCloudState", {}, resp => {
|
send("getCloudState", {}, resp => {
|
||||||
|
@ -66,6 +76,17 @@ export default function SettingsAccessIndexRoute() {
|
||||||
});
|
});
|
||||||
}, [send]);
|
}, [send]);
|
||||||
|
|
||||||
|
const getTLSState = useCallback(() => {
|
||||||
|
send("getTLSState", {}, resp => {
|
||||||
|
if ("error" in resp) return console.error(resp.error);
|
||||||
|
const tlsState = resp.result as TLSState;
|
||||||
|
|
||||||
|
setTlsMode(tlsState.mode);
|
||||||
|
if (tlsState.certificate) setTlsCert(tlsState.certificate);
|
||||||
|
if (tlsState.privateKey) setTlsKey(tlsState.privateKey);
|
||||||
|
});
|
||||||
|
}, [send]);
|
||||||
|
|
||||||
const deregisterDevice = async () => {
|
const deregisterDevice = async () => {
|
||||||
send("deregisterDevice", {}, resp => {
|
send("deregisterDevice", {}, resp => {
|
||||||
if ("error" in resp) {
|
if ("error" in resp) {
|
||||||
|
@ -126,15 +147,62 @@ export default function SettingsAccessIndexRoute() {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Function to update TLS state - accepts a mode parameter
|
||||||
|
const updateTlsState = useCallback(
|
||||||
|
(mode: string, cert?: string, key?: string) => {
|
||||||
|
const state = { mode } as TLSState;
|
||||||
|
if (cert && key) {
|
||||||
|
state.certificate = cert;
|
||||||
|
state.privateKey = key;
|
||||||
|
}
|
||||||
|
|
||||||
|
send("setTLSState", { state }, resp => {
|
||||||
|
if ("error" in resp) {
|
||||||
|
notifications.error(
|
||||||
|
`Failed to update TLS settings: ${resp.error.data || "Unknown error"}`,
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
notifications.success("TLS settings updated successfully");
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[send],
|
||||||
|
);
|
||||||
|
|
||||||
|
// Handle TLS mode change
|
||||||
|
const handleTlsModeChange = (value: string) => {
|
||||||
|
setTlsMode(value);
|
||||||
|
|
||||||
|
// For "disabled" and "self-signed" modes, immediately apply the settings
|
||||||
|
if (value !== "custom") {
|
||||||
|
updateTlsState(value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleTlsCertChange = (value: string) => {
|
||||||
|
setTlsCert(value);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleTlsKeyChange = (value: string) => {
|
||||||
|
setTlsKey(value);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Update the custom TLS settings button click handler
|
||||||
|
const handleCustomTlsUpdate = () => {
|
||||||
|
updateTlsState(tlsMode, tlsCert, tlsKey);
|
||||||
|
};
|
||||||
|
|
||||||
// Fetch device ID and cloud state on component mount
|
// Fetch device ID and cloud state on component mount
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
getCloudState();
|
getCloudState();
|
||||||
|
getTLSState();
|
||||||
|
|
||||||
send("getDeviceID", {}, async resp => {
|
send("getDeviceID", {}, async resp => {
|
||||||
if ("error" in resp) return console.error(resp.error);
|
if ("error" in resp) return console.error(resp.error);
|
||||||
setDeviceId(resp.result as string);
|
setDeviceId(resp.result as string);
|
||||||
});
|
});
|
||||||
}, [send, getCloudState]);
|
}, [send, getCloudState, getTLSState]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-4">
|
<div className="space-y-4">
|
||||||
|
@ -150,30 +218,95 @@ export default function SettingsAccessIndexRoute() {
|
||||||
title="Local"
|
title="Local"
|
||||||
description="Manage the mode of local access to the device"
|
description="Manage the mode of local access to the device"
|
||||||
/>
|
/>
|
||||||
<SettingsItem
|
<>
|
||||||
title="Authentication Mode"
|
<SettingsItem
|
||||||
description={`Current mode: ${loaderData.authMode === "password" ? "Password protected" : "No password"}`}
|
title="HTTPS Mode"
|
||||||
>
|
badge="Experimental"
|
||||||
{loaderData.authMode === "password" ? (
|
description="Configure secure HTTPS access to your device"
|
||||||
<Button
|
>
|
||||||
|
<SelectMenuBasic
|
||||||
size="SM"
|
size="SM"
|
||||||
theme="light"
|
value={tlsMode}
|
||||||
text="Disable Protection"
|
onChange={e => handleTlsModeChange(e.target.value)}
|
||||||
onClick={() => {
|
disabled={tlsMode === "unknown"}
|
||||||
navigateTo("./local-auth", { state: { init: "deletePassword" } });
|
options={[
|
||||||
}}
|
{ value: "disabled", label: "Disabled" },
|
||||||
/>
|
{ value: "self-signed", label: "Self-signed" },
|
||||||
) : (
|
{ value: "custom", label: "Custom" },
|
||||||
<Button
|
]}
|
||||||
size="SM"
|
|
||||||
theme="light"
|
|
||||||
text="Enable Password"
|
|
||||||
onClick={() => {
|
|
||||||
navigateTo("./local-auth", { state: { init: "createPassword" } });
|
|
||||||
}}
|
|
||||||
/>
|
/>
|
||||||
|
</SettingsItem>
|
||||||
|
|
||||||
|
{tlsMode === "custom" && (
|
||||||
|
<div className="mt-4 space-y-4">
|
||||||
|
<div className="space-y-4">
|
||||||
|
<SettingsItem
|
||||||
|
title="TLS Certificate"
|
||||||
|
description="Paste your TLS certificate below. For certificate chains, include the entire chain (leaf, intermediate, and root certificates)."
|
||||||
|
/>
|
||||||
|
<div className="space-y-4">
|
||||||
|
<TextAreaWithLabel
|
||||||
|
label="Certificate"
|
||||||
|
rows={3}
|
||||||
|
placeholder={
|
||||||
|
"-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----"
|
||||||
|
}
|
||||||
|
value={tlsCert}
|
||||||
|
onChange={e => handleTlsCertChange(e.target.value)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-4">
|
||||||
|
<div className="space-y-4">
|
||||||
|
<TextAreaWithLabel
|
||||||
|
label="Private Key"
|
||||||
|
description="For security reasons, it will not be displayed after saving."
|
||||||
|
rows={3}
|
||||||
|
placeholder={
|
||||||
|
"-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----"
|
||||||
|
}
|
||||||
|
value={tlsKey}
|
||||||
|
onChange={e => handleTlsKeyChange(e.target.value)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-x-2">
|
||||||
|
<Button
|
||||||
|
size="SM"
|
||||||
|
theme="primary"
|
||||||
|
text="Update TLS Settings"
|
||||||
|
onClick={handleCustomTlsUpdate}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
)}
|
)}
|
||||||
</SettingsItem>
|
|
||||||
|
<SettingsItem
|
||||||
|
title="Authentication Mode"
|
||||||
|
description={`Current mode: ${loaderData.authMode === "password" ? "Password protected" : "No password"}`}
|
||||||
|
>
|
||||||
|
{loaderData.authMode === "password" ? (
|
||||||
|
<Button
|
||||||
|
size="SM"
|
||||||
|
theme="light"
|
||||||
|
text="Disable Protection"
|
||||||
|
onClick={() => {
|
||||||
|
navigateTo("./local-auth", { state: { init: "deletePassword" } });
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Button
|
||||||
|
size="SM"
|
||||||
|
theme="light"
|
||||||
|
text="Enable Password"
|
||||||
|
onClick={() => {
|
||||||
|
navigateTo("./local-auth", { state: { init: "createPassword" } });
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</SettingsItem>
|
||||||
|
</>
|
||||||
|
|
||||||
{loaderData.authMode === "password" && (
|
{loaderData.authMode === "password" && (
|
||||||
<SettingsItem
|
<SettingsItem
|
||||||
|
|
|
@ -246,6 +246,7 @@ export function SettingsItem({
|
||||||
children,
|
children,
|
||||||
className,
|
className,
|
||||||
loading,
|
loading,
|
||||||
|
badge,
|
||||||
}: {
|
}: {
|
||||||
title: string;
|
title: string;
|
||||||
description: string | React.ReactNode;
|
description: string | React.ReactNode;
|
||||||
|
@ -253,6 +254,7 @@ export function SettingsItem({
|
||||||
className?: string;
|
className?: string;
|
||||||
name?: string;
|
name?: string;
|
||||||
loading?: boolean;
|
loading?: boolean;
|
||||||
|
badge?: string;
|
||||||
}) {
|
}) {
|
||||||
return (
|
return (
|
||||||
<label
|
<label
|
||||||
|
@ -263,10 +265,17 @@ export function SettingsItem({
|
||||||
>
|
>
|
||||||
<div className="space-y-0.5">
|
<div className="space-y-0.5">
|
||||||
<div className="flex items-center gap-x-2">
|
<div className="flex items-center gap-x-2">
|
||||||
<h3 className="text-base font-semibold text-black dark:text-white">{title}</h3>
|
<div className="flex items-center text-base font-semibold text-black dark:text-white">
|
||||||
|
{title}
|
||||||
|
{badge && (
|
||||||
|
<span className="ml-2 rounded-full bg-red-500 px-2 py-1 text-[10px] font-medium leading-none text-white dark:border dark:border-red-700 dark:bg-red-800 dark:text-red-50">
|
||||||
|
{badge}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
{loading && <LoadingSpinner className="h-4 w-4 text-blue-500" />}
|
{loading && <LoadingSpinner className="h-4 w-4 text-blue-500" />}
|
||||||
</div>
|
</div>
|
||||||
<p className="text-sm text-slate-700 dark:text-slate-300">{description}</p>
|
<div className="text-sm text-slate-700 dark:text-slate-300">{description}</div>
|
||||||
</div>
|
</div>
|
||||||
{children ? <div>{children}</div> : null}
|
{children ? <div>{children}</div> : null}
|
||||||
</label>
|
</label>
|
||||||
|
|
283
web_tls.go
283
web_tls.go
|
@ -1,132 +1,211 @@
|
||||||
package kvm
|
package kvm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ecdsa"
|
"context"
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"math/big"
|
"errors"
|
||||||
"net"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
"github.com/jetkvm/kvm/internal/websecure"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
WebSecureListen = ":443"
|
tlsStorePath = "/userdata/jetkvm/tls"
|
||||||
WebSecureSelfSignedDefaultDomain = "jetkvm.local"
|
webSecureListen = ":443"
|
||||||
WebSecureSelfSignedDuration = 365 * 24 * time.Hour
|
webSecureSelfSignedDefaultDomain = "jetkvm.local"
|
||||||
|
webSecureSelfSignedCAName = "JetKVM Self-Signed CA"
|
||||||
|
webSecureSelfSignedOrganization = "JetKVM"
|
||||||
|
webSecureSelfSignedOU = "JetKVM Self-Signed"
|
||||||
|
webSecureCustomCertificateName = "user-defined"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
tlsCerts = make(map[string]*tls.Certificate)
|
certStore *websecure.CertStore
|
||||||
tlsCertLock = &sync.Mutex{}
|
certSigner *websecure.SelfSigner
|
||||||
|
)
|
||||||
|
|
||||||
|
type TLSState struct {
|
||||||
|
Mode string `json:"mode"`
|
||||||
|
Certificate string `json:"certificate"`
|
||||||
|
PrivateKey string `json:"privateKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func initCertStore() {
|
||||||
|
if certStore != nil {
|
||||||
|
websecureLogger.Warn().Msg("TLS store already initialized, it should not be initialized again")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
certStore = websecure.NewCertStore(tlsStorePath, &websecureLogger)
|
||||||
|
certStore.LoadCertificates()
|
||||||
|
|
||||||
|
certSigner = websecure.NewSelfSigner(
|
||||||
|
certStore,
|
||||||
|
&websecureLogger,
|
||||||
|
webSecureSelfSignedDefaultDomain,
|
||||||
|
webSecureSelfSignedOrganization,
|
||||||
|
webSecureSelfSignedOU,
|
||||||
|
webSecureSelfSignedCAName,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
if config.TLSMode == "self-signed" {
|
||||||
|
if isTimeSyncNeeded() || !timeSyncSuccess {
|
||||||
|
return nil, fmt.Errorf("time is not synced")
|
||||||
|
}
|
||||||
|
return certSigner.GetCertificate(info)
|
||||||
|
} else if config.TLSMode == "custom" {
|
||||||
|
return certStore.GetCertificate(webSecureCustomCertificateName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
websecureLogger.Info().Msg("TLS mode is disabled but WebSecure is running, returning nil")
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTLSState() TLSState {
|
||||||
|
s := TLSState{}
|
||||||
|
switch config.TLSMode {
|
||||||
|
case "disabled":
|
||||||
|
s.Mode = "disabled"
|
||||||
|
case "custom":
|
||||||
|
s.Mode = "custom"
|
||||||
|
cert := certStore.GetCertificate(webSecureCustomCertificateName)
|
||||||
|
if cert != nil {
|
||||||
|
var certPEM []byte
|
||||||
|
// convert to pem format
|
||||||
|
for _, c := range cert.Certificate {
|
||||||
|
block := pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: c,
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM = append(certPEM, pem.EncodeToMemory(&block)...)
|
||||||
|
}
|
||||||
|
s.Certificate = string(certPEM)
|
||||||
|
}
|
||||||
|
case "self-signed":
|
||||||
|
s.Mode = "self-signed"
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func setTLSState(s TLSState) error {
|
||||||
|
var isChanged = false
|
||||||
|
|
||||||
|
switch s.Mode {
|
||||||
|
case "disabled":
|
||||||
|
if config.TLSMode != "" {
|
||||||
|
isChanged = true
|
||||||
|
}
|
||||||
|
config.TLSMode = ""
|
||||||
|
case "custom":
|
||||||
|
if config.TLSMode == "" {
|
||||||
|
isChanged = true
|
||||||
|
}
|
||||||
|
// parse pem to cert and key
|
||||||
|
err, _ := certStore.ValidateAndSaveCertificate(webSecureCustomCertificateName, s.Certificate, s.PrivateKey, true)
|
||||||
|
// warn doesn't matter as ... we don't know the hostname yet
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to save certificate: %w", err)
|
||||||
|
}
|
||||||
|
config.TLSMode = "custom"
|
||||||
|
case "self-signed":
|
||||||
|
if config.TLSMode == "" {
|
||||||
|
isChanged = true
|
||||||
|
}
|
||||||
|
config.TLSMode = "self-signed"
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid TLS mode: %s", s.Mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isChanged {
|
||||||
|
websecureLogger.Info().Msg("TLS enabled state is not changed, not starting/stopping websecure server")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.TLSMode == "" {
|
||||||
|
websecureLogger.Info().Msg("Stopping websecure server, as TLS mode is disabled")
|
||||||
|
stopWebSecureServer()
|
||||||
|
} else {
|
||||||
|
websecureLogger.Info().Msg("Starting websecure server, as TLS mode is enabled")
|
||||||
|
startWebSecureServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
startTLS = make(chan struct{})
|
||||||
|
stopTLS = make(chan struct{})
|
||||||
|
tlsServiceLock = sync.Mutex{}
|
||||||
|
tlsStarted = false
|
||||||
)
|
)
|
||||||
|
|
||||||
// RunWebSecureServer runs a web server with TLS.
|
// RunWebSecureServer runs a web server with TLS.
|
||||||
func RunWebSecureServer() {
|
func runWebSecureServer() {
|
||||||
|
tlsServiceLock.Lock()
|
||||||
|
defer tlsServiceLock.Unlock()
|
||||||
|
|
||||||
|
tlsStarted = true
|
||||||
|
defer func() {
|
||||||
|
tlsStarted = false
|
||||||
|
}()
|
||||||
|
|
||||||
r := setupRouter()
|
r := setupRouter()
|
||||||
|
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
Addr: WebSecureListen,
|
Addr: webSecureListen,
|
||||||
Handler: r,
|
Handler: r,
|
||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
// TODO: cache certificate in persistent storage
|
MaxVersion: tls.VersionTLS13,
|
||||||
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
CurvePreferences: []tls.CurveID{},
|
||||||
var hostname string
|
GetCertificate: getCertificate,
|
||||||
if info.ServerName != "" {
|
|
||||||
hostname = info.ServerName
|
|
||||||
} else {
|
|
||||||
hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info().Str("hostname", hostname).Interface("SupportedProtos", info.SupportedProtos).Msg("TLS handshake")
|
|
||||||
|
|
||||||
cert := createSelfSignedCert(hostname)
|
|
||||||
|
|
||||||
return cert, nil
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
logger.Info().Str("listen", WebSecureListen).Msg("Starting websecure server")
|
websecureLogger.Info().Str("listen", webSecureListen).Msg("Starting websecure server")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for _ = range stopTLS {
|
||||||
|
websecureLogger.Info().Msg("Shutting down websecure server")
|
||||||
|
err := server.Shutdown(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
websecureLogger.Error().Err(err).Msg("Failed to shutdown websecure server")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
err := server.ListenAndServeTLS("", "")
|
err := server.ListenAndServeTLS("", "")
|
||||||
if err != nil {
|
if !errors.Is(err, http.ErrServerClosed) {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createSelfSignedCert(hostname string) *tls.Certificate {
|
func stopWebSecureServer() {
|
||||||
if tlsCert := tlsCerts[hostname]; tlsCert != nil {
|
if !tlsStarted {
|
||||||
return tlsCert
|
websecureLogger.Info().Msg("Websecure server is not running, not stopping it")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stopTLS <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startWebSecureServer() {
|
||||||
|
if tlsStarted {
|
||||||
|
websecureLogger.Info().Msg("Websecure server is already running, not starting it again")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
startTLS <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunWebSecureServer() {
|
||||||
|
for _ = range startTLS {
|
||||||
|
websecureLogger.Info().Msg("Starting websecure server, as we have received a start signal")
|
||||||
|
if certStore == nil {
|
||||||
|
initCertStore()
|
||||||
|
}
|
||||||
|
go runWebSecureServer()
|
||||||
}
|
}
|
||||||
tlsCertLock.Lock()
|
|
||||||
defer tlsCertLock.Unlock()
|
|
||||||
|
|
||||||
logger.Info().Str("hostname", hostname).Msg("Creating self-signed certificate")
|
|
||||||
|
|
||||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn().Err(err).Msg("Failed to generate private key")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
keyUsage := x509.KeyUsageDigitalSignature
|
|
||||||
|
|
||||||
notBefore := time.Now()
|
|
||||||
notAfter := notBefore.AddDate(1, 0, 0)
|
|
||||||
|
|
||||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
|
||||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn().Err(err).Msg("Failed to generate serial number")
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsName := hostname
|
|
||||||
ip := net.ParseIP(hostname)
|
|
||||||
if ip != nil {
|
|
||||||
dnsName = WebSecureSelfSignedDefaultDomain
|
|
||||||
}
|
|
||||||
|
|
||||||
template := x509.Certificate{
|
|
||||||
SerialNumber: serialNumber,
|
|
||||||
Subject: pkix.Name{
|
|
||||||
CommonName: hostname,
|
|
||||||
Organization: []string{"JetKVM"},
|
|
||||||
},
|
|
||||||
NotBefore: notBefore,
|
|
||||||
NotAfter: notAfter,
|
|
||||||
|
|
||||||
KeyUsage: keyUsage,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
|
|
||||||
DNSNames: []string{dnsName},
|
|
||||||
IPAddresses: []net.IP{},
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip != nil {
|
|
||||||
template.IPAddresses = append(template.IPAddresses, ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn().Err(err).Msg("Failed to create certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
|
||||||
if cert == nil {
|
|
||||||
logger.Warn().Msg("Failed to encode certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsCert := &tls.Certificate{
|
|
||||||
Certificate: [][]byte{derBytes},
|
|
||||||
PrivateKey: priv,
|
|
||||||
}
|
|
||||||
tlsCerts[hostname] = tlsCert
|
|
||||||
|
|
||||||
return tlsCert
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue