diff --git a/cloud.go b/cloud.go index f91085a..fe77482 100644 --- a/cloud.go +++ b/cloud.go @@ -11,6 +11,7 @@ import ( "time" "github.com/coder/websocket/wsjson" + "github.com/pion/webrtc/v4" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -325,22 +326,63 @@ func runWebsocketClient() error { // ignore non-text messages continue } - var req WebRTCSessionRequest - err = json.Unmarshal(msg, &req) + + var message struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + } + + err = json.Unmarshal(msg, &message) if err != nil { cloudLogger.Warnf("unable to parse ws message: %v", string(msg)) continue } - cloudLogger.Infof("new session request: %v", req.OidcGoogle) - cloudLogger.Tracef("session request info: %v", req) + if message.Type == "offer" { + cloudLogger.Infof("new session request received") + var req WebRTCSessionRequest + err = json.Unmarshal(message.Data, &req) + if err != nil { + cloudLogger.Warnf("unable to parse session request data: %v", string(message.Data)) + continue + } - metricCloudConnectionSessionRequestCount.Inc() - metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime() - err = handleSessionRequest(runCtx, c, req) - if err != nil { - cloudLogger.Infof("error starting new session: %v", err) - continue + cloudLogger.Infof("new session request: %v", req.OidcGoogle) + cloudLogger.Tracef("session request info: %v", req) + + metricCloudConnectionSessionRequestCount.Inc() + metricCloudConnectionLastSessionRequestTimestamp.SetToCurrentTime() + err = handleSessionRequest(runCtx, c, req) + if err != nil { + cloudLogger.Infof("error starting new session: %v", err) + continue + } + } else if message.Type == "new-ice-candidate" { + cloudLogger.Infof("client has sent us a new ICE candidate: %v", string(message.Data)) + var candidate webrtc.ICECandidateInit + + // Attempt to unmarshal as a ICECandidateInit + if err := json.Unmarshal(message.Data, &candidate); err != nil { + cloudLogger.Warnf("unable to parse ICE candidate data: %v", string(message.Data)) + continue + } + + if candidate.Candidate == "" { + cloudLogger.Warnf("empty ICE candidate, skipping") + continue + } + + cloudLogger.Infof("unmarshalled ICE candidate: %v", candidate) + + if currentSession == nil { + cloudLogger.Infof("no current session, skipping ICE candidate") + continue + } + + cloudLogger.Infof("adding ICE candidate to current session: %v", candidate) + if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil { + cloudLogger.Warnf("failed to add ICE candidate: %v", err) + } } } } @@ -383,6 +425,7 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess ICEServers: req.ICEServers, LocalIP: req.IP, IsCloud: true, + ws: c, }) if err != nil { _ = wsjson.Write(context.Background(), c, gin.H{"error": err}) @@ -406,7 +449,7 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess cloudLogger.Info("new session accepted") cloudLogger.Tracef("new session accepted: %v", session) currentSession = session - _ = wsjson.Write(context.Background(), c, gin.H{"sd": sd}) + _ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd}) return nil } diff --git a/ui/package-lock.json b/ui/package-lock.json index e9caa20..ebce148 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -30,6 +30,7 @@ "react-icons": "^5.4.0", "react-router-dom": "^6.22.3", "react-simple-keyboard": "^3.7.112", + "react-use-websocket": "^4.13.0", "react-xtermjs": "^1.0.9", "recharts": "^2.15.0", "tailwind-merge": "^2.5.5", @@ -5180,6 +5181,11 @@ "react-dom": ">=16.6.0" } }, + "node_modules/react-use-websocket": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/react-use-websocket/-/react-use-websocket-4.13.0.tgz", + "integrity": "sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw==" + }, "node_modules/react-xtermjs": { "version": "1.0.9", "resolved": "https://registry.npmjs.org/react-xtermjs/-/react-xtermjs-1.0.9.tgz", diff --git a/ui/package.json b/ui/package.json index f8f1c7a..a248616 100644 --- a/ui/package.json +++ b/ui/package.json @@ -40,6 +40,7 @@ "react-icons": "^5.4.0", "react-router-dom": "^6.22.3", "react-simple-keyboard": "^3.7.112", + "react-use-websocket": "^4.13.0", "react-xtermjs": "^1.0.9", "recharts": "^2.15.0", "tailwind-merge": "^2.5.5", diff --git a/ui/src/routes/devices.$id.tsx b/ui/src/routes/devices.$id.tsx index d2662fc..08ed023 100644 --- a/ui/src/routes/devices.$id.tsx +++ b/ui/src/routes/devices.$id.tsx @@ -14,6 +14,7 @@ import { import { useInterval } from "usehooks-ts"; import FocusTrap from "focus-trap-react"; import { motion, AnimatePresence } from "framer-motion"; +import useWebSocket, { ReadyState } from "react-use-websocket"; import { cx } from "@/cva.config"; import { @@ -117,7 +118,6 @@ const loader = async ({ params }: LoaderFunctionArgs) => { export default function KvmIdRoute() { const loaderResp = useLoaderData() as LocalLoaderResp | CloudLoaderResp; - // Depending on the mode, we set the appropriate variables const user = "user" in loaderResp ? loaderResp.user : null; const deviceName = "deviceName" in loaderResp ? loaderResp.deviceName : null; @@ -169,87 +169,70 @@ export default function KvmIdRoute() { useEffect(() => { connectionFailedRef.current = connectionFailed; }, [connectionFailed]); - const signalingAttempts = useRef(0); - const syncRemoteSessionDescription = useCallback( - async function syncRemoteSessionDescription(pc: RTCPeerConnection) { - try { - if (!pc) return; + const setRemoteSessionDescription = useCallback( + async function setRemoteSessionDescription( + pc: RTCPeerConnection, + remoteDescription: RTCSessionDescriptionInit, + ) { + setLoadingMessage("Setting remote description"); + pc.setRemoteDescription(new RTCSessionDescription(remoteDescription)); + console.log("Waiting for remote description to be set"); + let attempts = 0; - const sd = btoa(JSON.stringify(pc.localDescription)); + setLoadingMessage("Establishing secure connection..."); + const checkInterval = setInterval(() => { + attempts++; - const sessionUrl = isOnDevice - ? `${DEVICE_API}/webrtc/session` - : `${CLOUD_API}/webrtc/session`; - - console.log("Trying to get remote session description"); - setLoadingMessage( - `Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`, - ); - const res = await api.POST(sessionUrl, { - sd, - // When on device, we don't need to specify the device id, as it's already known - ...(isOnDevice ? {} : { id: params.id }), - }); - - const json = await res.json(); - if (res.status === 401) return navigate(isOnDevice ? "/login-local" : "/login"); - if (!res.ok) { - console.error("Error getting SDP", { status: res.status, json }); - throw new Error("Error getting SDP"); - } - - console.log("Successfully got Remote Session Description. Setting."); - setLoadingMessage("Setting remote session description..."); - - const decodedSd = atob(json.sd); - const parsedSd = JSON.parse(decodedSd); - pc.setRemoteDescription(new RTCSessionDescription(parsedSd)); - - await new Promise((resolve, reject) => { - console.log("Waiting for remote description to be set"); - const maxAttempts = 10; - const interval = 1000; - let attempts = 0; - - const checkInterval = setInterval(() => { - attempts++; - // When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects - if (pc.sctp?.state === "connected") { - console.log("Remote description set"); - clearInterval(checkInterval); - resolve(true); - } else if (attempts >= maxAttempts) { - console.log( - `Failed to get remote description after ${maxAttempts} attempts`, - ); - closePeerConnection(); - clearInterval(checkInterval); - reject( - new Error( - `Failed to get remote description after ${maxAttempts} attempts`, - ), - ); - } else { - console.log("Waiting for remote description to be set"); - } - }, interval); - }); - } catch (error) { - console.error("Error getting SDP", { error }); - console.log("Connection failed", connectionFailedRef.current); - if (connectionFailedRef.current) return; - if (signalingAttempts.current < 5) { - signalingAttempts.current++; - await new Promise(resolve => setTimeout(resolve, 500)); - console.log("Attempting to get SDP again", signalingAttempts.current); - syncRemoteSessionDescription(pc); - } else { + // When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects + if (pc.sctp?.state === "connected") { + console.log("Remote description set"); + clearInterval(checkInterval); + } else if (attempts >= 10) { + console.log("Failed to get remote description after 10 attempts"); closePeerConnection(); + clearInterval(checkInterval); + } else { + console.log("Waiting for remote description to be set"); } - } + }, 1000); }, - [closePeerConnection, navigate, params.id], + [closePeerConnection], + ); + + // TODO: Handle auth!!! The old signaling http request could get a 401 on local and on cloud + const { sendMessage, readyState } = useWebSocket( + isOnDevice ? `${DEVICE_API}/client` : `${CLOUD_API}/client?id=${params.id}`, + { + heartbeat: true, + onMessage: message => { + if (message.data === "pong") return; + if (!peerConnection) return; + + const parsedMessage = JSON.parse(message.data); + if (parsedMessage.type === "answer") { + console.log("Setting remote description", parsedMessage.data); + const sd = atob(parsedMessage.data); + const remoteSessionDescription = JSON.parse(sd); + + setRemoteSessionDescription( + peerConnection, + new RTCSessionDescription(remoteSessionDescription), + ); + } else if (parsedMessage.type === "new-ice-candidate") { + console.log("Received new ICE candidate", parsedMessage.data); + const candidate = parsedMessage.data; + peerConnection.addIceCandidate(candidate); + } + }, + }, + ); + + const sendWebRTCSignal = useCallback( + (type: string, data: string | RTCIceCandidate) => { + sendMessage(JSON.stringify({ type, data })); + }, + [sendMessage], ); const setupPeerConnection = useCallback(async () => { @@ -267,6 +250,7 @@ export default function KvmIdRoute() { ? { iceServers: [iceConfig?.iceServers] } : {}), }); + console.log("Peer connection created", pc); setLoadingMessage("Peer connection created"); } catch (e) { @@ -282,21 +266,24 @@ export default function KvmIdRoute() { console.log("Connection state changed", pc.connectionState); }; - pc.onicegatheringstatechange = event => { - const pc = event.currentTarget as RTCPeerConnection; - console.log("ICE Gathering State Changed", pc.iceGatheringState); - if (pc.iceGatheringState === "complete") { - console.log("ICE Gathering completed"); - setLoadingMessage("ICE Gathering completed"); - - // We can now start the https/ws connection to get the remote session description from the KVM device - syncRemoteSessionDescription(pc); - } else if (pc.iceGatheringState === "gathering") { - console.log("ICE Gathering Started"); - setLoadingMessage("Gathering ICE candidates..."); + pc.onnegotiationneeded = async () => { + try { + const offer = await pc.createOffer(); + await pc.setLocalDescription(offer); + const sd = btoa(JSON.stringify(pc.localDescription)); + sendWebRTCSignal("offer", sd); + } catch (e) { + console.error(`Error creating offer: ${e}`, new Date().toISOString()); + closePeerConnection(); } }; + pc.onicecandidate = async ({ candidate }) => { + if (!candidate) return; + if (candidate.candidate === "") return; + sendWebRTCSignal("new-ice-candidate", candidate); + }; + pc.ontrack = function (event) { setMediaMediaStream(event.streams[0]); }; @@ -314,31 +301,24 @@ export default function KvmIdRoute() { }; setPeerConnection(pc); - - try { - const offer = await pc.createOffer(); - await pc.setLocalDescription(offer); - } catch (e) { - console.error(`Error creating offer: ${e}`, new Date().toISOString()); - closePeerConnection(); - } }, [ closePeerConnection, iceConfig?.iceServers, + sendWebRTCSignal, setDiskChannel, setMediaMediaStream, setPeerConnection, setRpcDataChannel, setTransceiver, - syncRemoteSessionDescription, ]); // On boot, if the connection state is undefined, we connect to the WebRTC useEffect(() => { + if (readyState !== ReadyState.OPEN) return; if (peerConnection?.connectionState === undefined) { setupPeerConnection(); } - }, [setupPeerConnection, peerConnection?.connectionState]); + }, [readyState, setupPeerConnection, peerConnection?.connectionState]); // Cleanup effect const clearInboundRtpStats = useRTCStore(state => state.clearInboundRtpStats); @@ -593,7 +573,7 @@ export default function KvmIdRoute() { />