diff --git a/config.go b/config.go index c38f1ed..f19b6e0 100644 --- a/config.go +++ b/config.go @@ -90,7 +90,7 @@ type Config struct { DisplayMaxBrightness int `json:"display_max_brightness"` DisplayDimAfterSec int `json:"display_dim_after_sec"` DisplayOffAfterSec int `json:"display_off_after_sec"` - TLSMode string `json:"tls_mode"` + TLSMode string `json:"tls_mode"` // options: "self-signed", "user-defined", "" UsbConfig *usbgadget.Config `json:"usb_config"` UsbDevices *usbgadget.Devices `json:"usb_devices"` } @@ -169,6 +169,8 @@ func SaveConfig() error { configLock.Lock() defer configLock.Unlock() + logger.Trace().Str("path", configPath).Msg("Saving config") + file, err := os.Create(configPath) if err != nil { return fmt.Errorf("failed to create config file: %w", err) diff --git a/internal/websecure/log.go b/internal/websecure/log.go new file mode 100644 index 0000000..f45767e --- /dev/null +++ b/internal/websecure/log.go @@ -0,0 +1,9 @@ +package websecure + +import ( + "os" + + "github.com/rs/zerolog" +) + +var defaultLogger = zerolog.New(os.Stdout).With().Str("component", "websecure").Logger() diff --git a/internal/websecure/selfsign.go b/internal/websecure/selfsign.go new file mode 100644 index 0000000..77efa37 --- /dev/null +++ b/internal/websecure/selfsign.go @@ -0,0 +1,191 @@ +package websecure + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "net" + "strings" + "time" + + "github.com/rs/zerolog" + "golang.org/x/net/idna" +) + +const selfSignerCAMagicName = "__ca__" + +type SelfSigner struct { + store *CertStore + log *zerolog.Logger + + caInfo pkix.Name + + DefaultDomain string + DefaultOrg string + DefaultOU string +} + +func NewSelfSigner( + store *CertStore, + log *zerolog.Logger, + defaultDomain, + defaultOrg, + defaultOU, + caName string, +) *SelfSigner { + return &SelfSigner{ + store: store, + log: log, + DefaultDomain: defaultDomain, + DefaultOrg: defaultOrg, + DefaultOU: defaultOU, + caInfo: pkix.Name{ + CommonName: caName, + Organization: []string{defaultOrg}, + OrganizationalUnit: []string{defaultOU}, + }, + } +} + +func (s *SelfSigner) getCA() *tls.Certificate { + return s.createSelfSignedCert(selfSignerCAMagicName) +} + +func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate { + if tlsCert := s.store.certificates[hostname]; tlsCert != nil { + return tlsCert + } + + // check if hostname is the CA magic name + var ca *tls.Certificate + if hostname != selfSignerCAMagicName { + ca = s.getCA() + if ca == nil { + s.log.Error().Msg("Failed to get CA certificate") + return nil + } + } + + s.log.Info().Str("hostname", hostname).Msg("Creating self-signed certificate") + + // lock the store while creating the certificate (do not move upwards) + s.store.certLock.Lock() + defer s.store.certLock.Unlock() + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + s.log.Error().Err(err).Msg("Failed to generate private key") + return nil + } + + notBefore := time.Now() + notAfter := notBefore.AddDate(1, 0, 0) + + serialNumber, err := generateSerialNumber() + if err != nil { + s.log.Error().Err(err).Msg("Failed to generate serial number") + return nil + } + + dnsName := hostname + ip := net.ParseIP(hostname) + if ip != nil { + dnsName = s.DefaultDomain + } + + // set up CSR + isCA := hostname == selfSignerCAMagicName + subject := pkix.Name{ + CommonName: hostname, + Organization: []string{s.DefaultOrg}, + OrganizationalUnit: []string{s.DefaultOU}, + } + keyUsage := x509.KeyUsageDigitalSignature + extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + + // check if hostname is the CA magic name, and if so, set the subject to the CA info + if isCA { + subject = s.caInfo + keyUsage |= x509.KeyUsageCertSign + extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageClientAuth) + notAfter = notBefore.AddDate(10, 0, 0) + } + + cert := x509.Certificate{ + SerialNumber: serialNumber, + Subject: subject, + NotBefore: notBefore, + NotAfter: notAfter, + IsCA: isCA, + KeyUsage: keyUsage, + ExtKeyUsage: extKeyUsage, + BasicConstraintsValid: true, + } + + // set up DNS names and IP addresses + if !isCA { + cert.DNSNames = []string{dnsName} + if ip != nil { + cert.IPAddresses = []net.IP{ip} + } + } + + // set up parent certificate + parent := &cert + parentPriv := priv + if ca != nil { + parent, err = x509.ParseCertificate(ca.Certificate[0]) + if err != nil { + s.log.Error().Err(err).Msg("Failed to parse parent certificate") + return nil + } + parentPriv = ca.PrivateKey.(*ecdsa.PrivateKey) + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &cert, parent, &priv.PublicKey, parentPriv) + if err != nil { + s.log.Error().Err(err).Msg("Failed to create certificate") + return nil + } + + tlsCert := &tls.Certificate{ + Certificate: [][]byte{certBytes}, + PrivateKey: priv, + } + if ca != nil { + tlsCert.Certificate = append(tlsCert.Certificate, ca.Certificate...) + } + + s.store.certificates[hostname] = tlsCert + s.store.saveCertificate(hostname) + + return tlsCert +} + +// GetCertificate returns the certificate for the given hostname +// returns nil if the certificate is not found +func (s *SelfSigner) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + var hostname string + if info.ServerName != "" && info.ServerName != selfSignerCAMagicName { + hostname = info.ServerName + } else { + hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0] + } + + s.log.Info().Str("hostname", hostname).Strs("supported_protos", info.SupportedProtos).Msg("TLS handshake") + + // convert hostname to punycode + h, err := idna.Lookup.ToASCII(hostname) + if err != nil { + s.log.Warn().Str("hostname", hostname).Err(err).Str("remote_addr", info.Conn.RemoteAddr().String()).Msg("Hostname is not valid") + hostname = s.DefaultDomain + } else { + hostname = h + } + + cert := s.createSelfSignedCert(hostname) + return cert, nil +} diff --git a/internal/websecure/store.go b/internal/websecure/store.go new file mode 100644 index 0000000..7da2dee --- /dev/null +++ b/internal/websecure/store.go @@ -0,0 +1,175 @@ +package websecure + +import ( + "crypto/tls" + "fmt" + "os" + "path" + "strings" + "sync" + + "github.com/rs/zerolog" +) + +type CertStore struct { + certificates map[string]*tls.Certificate + certLock *sync.Mutex + + storePath string + + log *zerolog.Logger +} + +func NewCertStore(storePath string, log *zerolog.Logger) *CertStore { + if log == nil { + log = &defaultLogger + } + + return &CertStore{ + certificates: make(map[string]*tls.Certificate), + certLock: &sync.Mutex{}, + + storePath: storePath, + log: log, + } +} + +func (s *CertStore) ensureStorePath() error { + // check if directory exists + stat, err := os.Stat(s.storePath) + if err == nil { + if stat.IsDir() { + return nil + } + + return fmt.Errorf("TLS store path exists but is not a directory: %s", s.storePath) + } + + if os.IsNotExist(err) { + s.log.Trace().Str("path", s.storePath).Msg("TLS store directory does not exist, creating directory") + err = os.MkdirAll(s.storePath, 0755) + if err != nil { + return fmt.Errorf("Failed to create TLS store path: %w", err) + } + return nil + } + + return fmt.Errorf("Failed to check TLS store path: %w", err) +} + +func (s *CertStore) LoadCertificates() { + err := s.ensureStorePath() + if err != nil { + s.log.Error().Err(err).Msg("Failed to ensure store path") + return + } + + files, err := os.ReadDir(s.storePath) + if err != nil { + s.log.Error().Err(err).Msg("Failed to read TLS directory") + return + } + + for _, file := range files { + if file.IsDir() { + continue + } + + if strings.HasSuffix(file.Name(), ".crt") { + s.loadCertificate(strings.TrimSuffix(file.Name(), ".crt")) + } + } +} + +func (s *CertStore) loadCertificate(hostname string) { + s.certLock.Lock() + defer s.certLock.Unlock() + + keyFile := path.Join(s.storePath, hostname+".key") + crtFile := path.Join(s.storePath, hostname+".crt") + + cert, err := tls.LoadX509KeyPair(crtFile, keyFile) + if err != nil { + s.log.Error().Err(err).Str("hostname", hostname).Msg("Failed to load certificate") + return + } + + s.certificates[hostname] = &cert + + s.log.Info().Str("hostname", hostname).Msg("Loaded certificate") +} + +// GetCertificate returns the certificate for the given hostname +// returns nil if the certificate is not found +func (s *CertStore) GetCertificate(hostname string) *tls.Certificate { + s.certLock.Lock() + defer s.certLock.Unlock() + + return s.certificates[hostname] +} + +// ValidateAndSaveCertificate validates the certificate and saves it to the store +// returns are: +// - error: if the certificate is invalid or if there's any error during saving the certificate +// - error: if there's any warning or error during saving the certificate +func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key string, ignoreWarning bool) (error, error) { + tlsCert, err := tls.X509KeyPair([]byte(cert), []byte(key)) + if err != nil { + return fmt.Errorf("Failed to parse certificate: %w", err), nil + } + + // this can be skipped as current implementation supports one custom certificate only + if tlsCert.Leaf != nil { + // add recover to avoid panic + defer func() { + if r := recover(); r != nil { + s.log.Error().Interface("recovered", r).Msg("Failed to verify hostname") + } + }() + + if err = tlsCert.Leaf.VerifyHostname(hostname); err != nil { + if !ignoreWarning { + return nil, fmt.Errorf("Certificate does not match hostname: %w", err) + } + s.log.Warn().Err(err).Msg("Certificate does not match hostname") + } + } + + s.certLock.Lock() + s.certificates[hostname] = &tlsCert + s.certLock.Unlock() + + s.saveCertificate(hostname) + + return nil, nil +} + +func (s *CertStore) saveCertificate(hostname string) { + // check if certificate already exists + tlsCert := s.certificates[hostname] + if tlsCert == nil { + s.log.Error().Str("hostname", hostname).Msg("Certificate for hostname does not exist, skipping saving certificate") + return + } + + err := s.ensureStorePath() + if err != nil { + s.log.Error().Err(err).Msg("Failed to ensure store path") + return + } + + keyFile := path.Join(s.storePath, hostname+".key") + crtFile := path.Join(s.storePath, hostname+".crt") + + if err := keyToFile(tlsCert, keyFile); err != nil { + s.log.Error().Err(err).Msg("Failed to save key file") + return + } + + if err := certToFile(tlsCert, crtFile); err != nil { + s.log.Error().Err(err).Msg("Failed to save certificate") + return + } + + s.log.Info().Str("hostname", hostname).Msg("Saved certificate") +} diff --git a/internal/websecure/utils.go b/internal/websecure/utils.go new file mode 100644 index 0000000..de29c73 --- /dev/null +++ b/internal/websecure/utils.go @@ -0,0 +1,80 @@ +package websecure + +import ( + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "math/big" + "os" +) + +var serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 4096) + +func withSecretFile(filename string, f func(*os.File) error) error { + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer file.Close() + + return f(file) +} + +func keyToFile(cert *tls.Certificate, filename string) error { + var keyBlock pem.Block + switch k := cert.PrivateKey.(type) { + case *rsa.PrivateKey: + keyBlock = pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(k), + } + case *ecdsa.PrivateKey: + b, e := x509.MarshalECPrivateKey(k) + if e != nil { + return fmt.Errorf("Failed to marshal EC private key: %v", e) + } + + keyBlock = pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: b, + } + default: + return fmt.Errorf("Unknown private key type: %T", k) + } + + err := withSecretFile(filename, func(file *os.File) error { + return pem.Encode(file, &keyBlock) + }) + + if err != nil { + return fmt.Errorf("Failed to save private key: %w", err) + } + + return nil +} + +func certToFile(cert *tls.Certificate, filename string) error { + return withSecretFile(filename, func(file *os.File) error { + for _, c := range cert.Certificate { + block := pem.Block{ + Type: "CERTIFICATE", + Bytes: c, + } + + err := pem.Encode(file, &block) + if err != nil { + return fmt.Errorf("Failed to save certificate: %w", err) + } + } + + return nil + }) +} + +func generateSerialNumber() (*big.Int, error) { + return rand.Int(rand.Reader, serialNumberLimit) +} diff --git a/jsonrpc.go b/jsonrpc.go index de29e08..d56b8ea 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -95,7 +95,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - //logger.Infof("Received RPC request: Method=%s, Params=%v, ID=%d", request.Method, request.Params, request.ID) + logger.Trace().Str("method", request.Method).Interface("params", request.Params).Interface("id", request.ID).Msg("Received RPC request") handler, ok := rpcHandlers[request.Method] if !ok { errorResponse := JSONRPCResponse{ @@ -110,6 +110,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } + logger.Trace().Str("method", request.Method).Interface("id", request.ID).Msg("Calling RPC handler") result, err := callRPCHandler(handler, request.Params) if err != nil { errorResponse := JSONRPCResponse{ @@ -125,6 +126,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } + logger.Trace().Interface("result", result).Interface("id", request.ID).Msg("RPC handler returned") response := JSONRPCResponse{ JSONRPC: "2.0", Result: result, @@ -141,6 +143,30 @@ func rpcGetDeviceID() (string, error) { return GetDeviceID(), nil } +func rpcReboot(force bool) error { + logger.Info().Msg("Got reboot request from JSONRPC, rebooting...") + + args := []string{} + if force { + args = append(args, "-f") + } + + cmd := exec.Command("reboot", args...) + err := cmd.Start() + if err != nil { + logger.Error().Err(err).Msg("failed to reboot") + return fmt.Errorf("failed to reboot: %w", err) + } + + // If the reboot command is successful, exit the program after 5 seconds + go func() { + time.Sleep(5 * time.Second) + os.Exit(0) + }() + + return nil +} + var streamFactor = 1.0 func rpcGetStreamQualityFactor() (float64, error) { @@ -375,6 +401,23 @@ func rpcSetSSHKeyState(sshKey string) error { return nil } +func rpcGetTLSState() TLSState { + return getTLSState() +} + +func rpcSetTLSState(state TLSState) error { + err := setTLSState(state) + if err != nil { + return fmt.Errorf("failed to set TLS state: %w", err) + } + + if err := SaveConfig(); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + + return nil +} + func callRPCHandler(handler RPCHandler, params map[string]interface{}) (interface{}, error) { handlerValue := reflect.ValueOf(handler.Func) handlerType := handlerValue.Type() @@ -892,6 +935,7 @@ func setKeyboardMacros(params KeyboardMacrosParams) (interface{}, error) { var rpcHandlers = map[string]RPCHandler{ "ping": {Func: rpcPing}, + "reboot": {Func: rpcReboot, Params: []string{"force"}}, "getDeviceID": {Func: rpcGetDeviceID}, "deregisterDevice": {Func: rpcDeregisterDevice}, "getCloudState": {Func: rpcGetCloudState}, @@ -920,6 +964,8 @@ var rpcHandlers = map[string]RPCHandler{ "setDevModeState": {Func: rpcSetDevModeState, Params: []string{"enabled"}}, "getSSHKeyState": {Func: rpcGetSSHKeyState}, "setSSHKeyState": {Func: rpcSetSSHKeyState, Params: []string{"sshKey"}}, + "getTLSState": {Func: rpcGetTLSState}, + "setTLSState": {Func: rpcSetTLSState, Params: []string{"state"}}, "setMassStorageMode": {Func: rpcSetMassStorageMode, Params: []string{"mode"}}, "getMassStorageMode": {Func: rpcGetMassStorageMode}, "isUpdatePending": {Func: rpcIsUpdatePending}, diff --git a/log.go b/log.go index 6824a3f..5dac1f6 100644 --- a/log.go +++ b/log.go @@ -50,6 +50,7 @@ var ( displayLogger = getLogger("display") usbLogger = getLogger("usb") ginLogger = getLogger("gin") + websecureLogger = getLogger("websecure") ) func updateLogLevel() { diff --git a/main.go b/main.go index 98748c7..d74b1ef 100644 --- a/main.go +++ b/main.go @@ -69,9 +69,13 @@ func Main() { }() //go RunFuseServer() go RunWebServer() + + go RunWebSecureServer() + // Web secure server is started only if TLS mode is enabled if config.TLSMode != "" { - go RunWebSecureServer() + startWebSecureServer() } + // As websocket client already checks if the cloud token is set, we can start it here. go RunWebsocketClient() diff --git a/ui/src/routes/devices.$id.settings.access._index.tsx b/ui/src/routes/devices.$id.settings.access._index.tsx index 0ed5862..d8eebf9 100644 --- a/ui/src/routes/devices.$id.settings.access._index.tsx +++ b/ui/src/routes/devices.$id.settings.access._index.tsx @@ -14,11 +14,18 @@ import notifications from "@/notifications"; import { DEVICE_API } from "@/ui.config"; import { useJsonRpc } from "@/hooks/useJsonRpc"; import { isOnDevice } from "@/main"; +import { TextAreaWithLabel } from "@components/TextArea"; import { LocalDevice } from "./devices.$id"; import { SettingsItem } from "./devices.$id.settings"; import { CloudState } from "./adopt"; +export interface TLSState { + mode: "self-signed" | "custom" | "disabled"; + certificate?: string; + privateKey?: string; +} + export const loader = async () => { if (isOnDevice) { const status = await api @@ -44,6 +51,9 @@ export default function SettingsAccessIndexRoute() { // Use a simple string identifier for the selected provider const [selectedProvider, setSelectedProvider] = useState("jetkvm"); + const [tlsMode, setTlsMode] = useState("unknown"); + const [tlsCert, setTlsCert] = useState(""); + const [tlsKey, setTlsKey] = useState(""); const getCloudState = useCallback(() => { send("getCloudState", {}, resp => { @@ -66,6 +76,17 @@ export default function SettingsAccessIndexRoute() { }); }, [send]); + const getTLSState = useCallback(() => { + send("getTLSState", {}, resp => { + if ("error" in resp) return console.error(resp.error); + const tlsState = resp.result as TLSState; + + setTlsMode(tlsState.mode); + if (tlsState.certificate) setTlsCert(tlsState.certificate); + if (tlsState.privateKey) setTlsKey(tlsState.privateKey); + }); + }, [send]); + const deregisterDevice = async () => { send("deregisterDevice", {}, resp => { if ("error" in resp) { @@ -126,15 +147,62 @@ export default function SettingsAccessIndexRoute() { } }; + // Function to update TLS state - accepts a mode parameter + const updateTlsState = useCallback( + (mode: string, cert?: string, key?: string) => { + const state = { mode } as TLSState; + if (cert && key) { + state.certificate = cert; + state.privateKey = key; + } + + send("setTLSState", { state }, resp => { + if ("error" in resp) { + notifications.error( + `Failed to update TLS settings: ${resp.error.data || "Unknown error"}`, + ); + return; + } + + notifications.success("TLS settings updated successfully"); + }); + }, + [send], + ); + + // Handle TLS mode change + const handleTlsModeChange = (value: string) => { + setTlsMode(value); + + // For "disabled" and "self-signed" modes, immediately apply the settings + if (value !== "custom") { + updateTlsState(value); + } + }; + + const handleTlsCertChange = (value: string) => { + setTlsCert(value); + }; + + const handleTlsKeyChange = (value: string) => { + setTlsKey(value); + }; + + // Update the custom TLS settings button click handler + const handleCustomTlsUpdate = () => { + updateTlsState(tlsMode, tlsCert, tlsKey); + }; + // Fetch device ID and cloud state on component mount useEffect(() => { getCloudState(); + getTLSState(); send("getDeviceID", {}, async resp => { if ("error" in resp) return console.error(resp.error); setDeviceId(resp.result as string); }); - }, [send, getCloudState]); + }, [send, getCloudState, getTLSState]); return (
@@ -150,30 +218,95 @@ export default function SettingsAccessIndexRoute() { title="Local" description="Manage the mode of local access to the device" /> - - {loaderData.authMode === "password" ? ( -
+ )} - + + + {loaderData.authMode === "password" ? ( +