From a12e081c2e8b0a4f47a240b0b603ae59584cdf7c Mon Sep 17 00:00:00 2001 From: Siyuan Miao Date: Mon, 7 Apr 2025 13:23:16 +0200 Subject: [PATCH] feat(tls): rewrite tls feature --- Makefile | 3 + config.go | 2 +- internal/websecure/selfsign.go | 10 +- internal/websecure/store.go | 45 +++++ jsonrpc.go | 39 +++- main.go | 1 + ntp.go | 33 ++++ .../devices.$id.settings.access._index.tsx | 177 +++++++++++++++--- web_tls.go | 74 +++++++- 9 files changed, 356 insertions(+), 28 deletions(-) diff --git a/Makefile b/Makefile index 2aefdea..5ef7804 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,9 @@ VERSION := 0.3.8 PROMETHEUS_TAG := github.com/prometheus/common/version KVM_PKG_NAME := github.com/jetkvm/kvm +PROMETHEUS_TAG := github.com/prometheus/common/version +KVM_PKG_NAME := github.com/jetkvm/kvm + GO_LDFLAGS := \ -s -w \ -X $(PROMETHEUS_TAG).Branch=$(BRANCH) \ diff --git a/config.go b/config.go index 642f113..7c8437d 100644 --- a/config.go +++ b/config.go @@ -31,7 +31,7 @@ type Config struct { DisplayMaxBrightness int `json:"display_max_brightness"` DisplayDimAfterSec int `json:"display_dim_after_sec"` DisplayOffAfterSec int `json:"display_off_after_sec"` - TLSMode string `json:"tls_mode"` + TLSMode string `json:"tls_mode"` // options: "self-signed", "user-defined", "" UsbConfig *usbgadget.Config `json:"usb_config"` UsbDevices *usbgadget.Devices `json:"usb_devices"` } diff --git a/internal/websecure/selfsign.go b/internal/websecure/selfsign.go index 04844ae..f93d719 100644 --- a/internal/websecure/selfsign.go +++ b/internal/websecure/selfsign.go @@ -29,7 +29,14 @@ type SelfSigner struct { DefaultOU string } -func NewSelfSigner(store *CertStore, log logging.LeveledLogger, defaultDomain, defaultOrg, defaultOU, caName string) *SelfSigner { +func NewSelfSigner( + store *CertStore, + log logging.LeveledLogger, + defaultDomain, + defaultOrg, + defaultOU, + caName string, +) *SelfSigner { return &SelfSigner{ store: store, log: log, @@ -177,6 +184,5 @@ func (s *SelfSigner) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate } cert := s.createSelfSignedCert(hostname) - return cert, nil } diff --git a/internal/websecure/store.go b/internal/websecure/store.go index d6d835e..3049c17 100644 --- a/internal/websecure/store.go +++ b/internal/websecure/store.go @@ -95,6 +95,51 @@ func (s *CertStore) loadCertificate(hostname string) { s.log.Infof("Loaded certificate for %s", hostname) } +// GetCertificate returns the certificate for the given hostname +// returns nil if the certificate is not found +func (s *CertStore) GetCertificate(hostname string) *tls.Certificate { + s.certLock.Lock() + defer s.certLock.Unlock() + + return s.certificates[hostname] +} + +// ValidateAndSaveCertificate validates the certificate and saves it to the store +// returns are: +// - error: if the certificate is invalid or if there's any error during saving the certificate +// - error: if there's any warning or error during saving the certificate +func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key string, ignoreWarning bool) (error, error) { + tlsCert, err := tls.X509KeyPair([]byte(cert), []byte(key)) + if err != nil { + return fmt.Errorf("Failed to parse certificate: %w", err), nil + } + + // this can be skipped as current implementation supports one custom certificate only + if tlsCert.Leaf != nil { + // add recover to avoid panic + defer func() { + if r := recover(); r != nil { + s.log.Errorf("Failed to verify hostname: %v", r) + } + }() + + if err = tlsCert.Leaf.VerifyHostname(hostname); err != nil { + if !ignoreWarning { + return nil, fmt.Errorf("Certificate does not match hostname: %w", err) + } + s.log.Warnf("Certificate does not match hostname: %v", err) + } + } + + s.certLock.Lock() + s.certificates[hostname] = &tlsCert + s.certLock.Unlock() + + s.saveCertificate(hostname) + + return nil, nil +} + func (s *CertStore) saveCertificate(hostname string) { // check if certificate already exists tlsCert := s.certificates[hostname] diff --git a/jsonrpc.go b/jsonrpc.go index 64935e1..fde5946 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -95,7 +95,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - //logger.Infof("Received RPC request: Method=%s, Params=%v, ID=%d", request.Method, request.Params, request.ID) + logger.Tracef("Received RPC request: Method=%s, Params=%v, ID=%w", request.Method, request.Params, request.ID) handler, ok := rpcHandlers[request.Method] if !ok { errorResponse := JSONRPCResponse{ @@ -110,6 +110,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } + logger.Tracef("Calling RPC handler: %s, ID=%w", request.Method, request.ID) result, err := callRPCHandler(handler, request.Params) if err != nil { errorResponse := JSONRPCResponse{ @@ -125,6 +126,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } + logger.Tracef("RPC handler returned: %v, ID=%w", result, request.ID) response := JSONRPCResponse{ JSONRPC: "2.0", Result: result, @@ -141,6 +143,30 @@ func rpcGetDeviceID() (string, error) { return GetDeviceID(), nil } +func rpcReboot(force bool) error { + logger.Info("Got reboot request from JSONRPC, rebooting...") + + args := []string{} + if force { + args = append(args, "-f") + } + + cmd := exec.Command("reboot", args...) + err := cmd.Start() + if err != nil { + logger.Errorf("failed to reboot: %v", err) + return fmt.Errorf("failed to reboot: %w", err) + } + + // If the reboot command is successful, exit the program after 5 seconds + go func() { + time.Sleep(5 * time.Second) + os.Exit(0) + }() + + return nil +} + var streamFactor = 1.0 func rpcGetStreamQualityFactor() (float64, error) { @@ -375,6 +401,14 @@ func rpcSetSSHKeyState(sshKey string) error { return nil } +func rpcGetTLSState() TLSState { + return getTLSState() +} + +func rpcSetTLSState(tlsState TLSState) error { + return setTLSState(tlsState) +} + func callRPCHandler(handler RPCHandler, params map[string]interface{}) (interface{}, error) { handlerValue := reflect.ValueOf(handler.Func) handlerType := handlerValue.Type() @@ -794,6 +828,7 @@ func rpcSetScrollSensitivity(sensitivity string) error { var rpcHandlers = map[string]RPCHandler{ "ping": {Func: rpcPing}, + "reboot": {Func: rpcReboot, Params: []string{"force"}}, "getDeviceID": {Func: rpcGetDeviceID}, "deregisterDevice": {Func: rpcDeregisterDevice}, "getCloudState": {Func: rpcGetCloudState}, @@ -822,6 +857,8 @@ var rpcHandlers = map[string]RPCHandler{ "setDevModeState": {Func: rpcSetDevModeState, Params: []string{"enabled"}}, "getSSHKeyState": {Func: rpcGetSSHKeyState}, "setSSHKeyState": {Func: rpcSetSSHKeyState, Params: []string{"sshKey"}}, + "getTLSState": {Func: rpcGetTLSState}, + "setTLSState": {Func: rpcSetTLSState, Params: []string{"state"}}, "setMassStorageMode": {Func: rpcSetMassStorageMode, Params: []string{"mode"}}, "getMassStorageMode": {Func: rpcGetMassStorageMode}, "isUpdatePending": {Func: rpcIsUpdatePending}, diff --git a/main.go b/main.go index b4d73ef..1debb0c 100644 --- a/main.go +++ b/main.go @@ -69,6 +69,7 @@ func Main() { }() //go RunFuseServer() go RunWebServer() + if config.TLSMode != "" { initCertStore() go RunWebSecureServer() diff --git a/ntp.go b/ntp.go index 39ea7af..27ec100 100644 --- a/ntp.go +++ b/ntp.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "os/exec" + "strconv" "time" "github.com/beevik/ntp" @@ -20,13 +21,41 @@ const ( ) var ( + builtTimestamp string timeSyncRetryInterval = 0 * time.Second + timeSyncSuccess = false defaultNTPServers = []string{ "time.cloudflare.com", "time.apple.com", } ) +func isTimeSyncNeeded() bool { + if builtTimestamp == "" { + logger.Warnf("Built timestamp is not set, time sync is needed") + return true + } + + ts, err := strconv.Atoi(builtTimestamp) + if err != nil { + logger.Warnf("Failed to parse built timestamp: %v", err) + return true + } + + // builtTimestamp is UNIX timestamp in seconds + builtTime := time.Unix(int64(ts), 0) + now := time.Now() + + logger.Tracef("Built time: %v, now: %v", builtTime, now) + + if now.Sub(builtTime) < 0 { + logger.Warnf("System time is behind the built time, time sync is needed") + return true + } + + return false +} + func TimeSyncLoop() { for { if !networkState.checked { @@ -40,6 +69,9 @@ func TimeSyncLoop() { continue } + // check if time sync is needed, but do nothing for now + isTimeSyncNeeded() + logger.Infof("Syncing system time") start := time.Now() err := SyncSystemTime() @@ -56,6 +88,7 @@ func TimeSyncLoop() { continue } + timeSyncSuccess = true logger.Infof("Time sync successful, now is: %v, time taken: %v", time.Now(), time.Since(start)) time.Sleep(timeSyncInterval) // after the first sync is done } diff --git a/ui/src/routes/devices.$id.settings.access._index.tsx b/ui/src/routes/devices.$id.settings.access._index.tsx index 0ed5862..8c944b7 100644 --- a/ui/src/routes/devices.$id.settings.access._index.tsx +++ b/ui/src/routes/devices.$id.settings.access._index.tsx @@ -18,6 +18,13 @@ import { isOnDevice } from "@/main"; import { LocalDevice } from "./devices.$id"; import { SettingsItem } from "./devices.$id.settings"; import { CloudState } from "./adopt"; +import { TextAreaWithLabel } from "@components/TextArea"; + +export interface TLSState { + mode: "selfsigned" | "custom" | "disabled"; + certificate?: string; + privateKey?: string; +}; export const loader = async () => { if (isOnDevice) { @@ -44,6 +51,9 @@ export default function SettingsAccessIndexRoute() { // Use a simple string identifier for the selected provider const [selectedProvider, setSelectedProvider] = useState("jetkvm"); + const [tlsMode, setTlsMode] = useState("self-signed"); + const [tlsCert, setTlsCert] = useState(""); + const [tlsKey, setTlsKey] = useState(""); const getCloudState = useCallback(() => { send("getCloudState", {}, resp => { @@ -66,6 +76,17 @@ export default function SettingsAccessIndexRoute() { }); }, [send]); + const getTLSState = useCallback(() => { + send("getTLSState", {}, resp => { + if ("error" in resp) return console.error(resp.error); + const tlsState = resp.result as TLSState; + + setTlsMode(tlsState.mode); + if (tlsState.certificate) setTlsCert(tlsState.certificate); + if (tlsState.privateKey) setTlsKey(tlsState.privateKey); + }); + }, [send]); + const deregisterDevice = async () => { send("deregisterDevice", {}, resp => { if ("error" in resp) { @@ -126,15 +147,51 @@ export default function SettingsAccessIndexRoute() { } }; + // Handle TLS mode change + const handleTlsModeChange = (value: string) => { + setTlsMode(value); + }; + + const handleTlsCertChange = (value: string) => { + setTlsCert(value); + }; + + const handleTlsKeyChange = (value: string) => { + setTlsKey(value); + }; + + const handleTlsUpdate = useCallback(() => { + send("setTLSState", { state: { mode: tlsMode, certificate: tlsCert, privateKey: tlsKey } as TLSState }, resp => { + if ("error" in resp) { + notifications.error(`Failed to update TLS settings: ${resp.error.data || "Unknown error"}`); + return; + } + + notifications.success("TLS settings updated successfully"); + }); + }, [send, tlsMode, tlsCert, tlsKey]); + + const handleReboot = useCallback(() => { + send("reboot", { force: false }, resp => { + if ("error" in resp) { + notifications.error(`Failed to reboot: ${resp.error.data || "Unknown error"}`); + return; + } + + notifications.success("Device will restart shortly, it might take a few seconds to boot up again."); + }); + }, [send]); + // Fetch device ID and cloud state on component mount useEffect(() => { getCloudState(); + getTLSState(); send("getDeviceID", {}, async resp => { if ("error" in resp) return console.error(resp.error); setDeviceId(resp.result as string); }); - }, [send, getCloudState]); + }, [send, getCloudState, getTLSState]); return (
@@ -150,30 +207,106 @@ export default function SettingsAccessIndexRoute() { title="Local" description="Manage the mode of local access to the device" /> - - {loaderData.authMode === "password" ? ( -
+ + + {tlsMode === "custom" && ( +
+
+ +
+ handleTlsCertChange(e.target.value)} + /> +
+ +
+
+ handleTlsKeyChange(e.target.value)} + /> +

+ Private key won't be shown again after saving. +

+
+
+
+
)} - + + {loaderData.authMode === "password" ? ( +