diff --git a/cloud.go b/cloud.go index 5088ec7..3520e2f 100644 --- a/cloud.go +++ b/cloud.go @@ -7,13 +7,14 @@ import ( "fmt" "net/http" "net/url" - "github.com/coder/websocket/wsjson" "time" + "github.com/coder/websocket/wsjson" + "github.com/coreos/go-oidc/v3/oidc" - "github.com/gin-gonic/gin" "github.com/coder/websocket" + "github.com/gin-gonic/gin" ) type CloudRegisterRequest struct { @@ -192,7 +193,11 @@ func handleSessionRequest(ctx context.Context, c *websocket.Conn, req WebRTCSess return fmt.Errorf("google identity mismatch") } - session, err := newSession() + session, err := newSession(SessionConfig{ + ICEServers: req.ICEServers, + LocalIP: req.IP, + IsCloud: true, + }) if err != nil { _ = wsjson.Write(context.Background(), c, gin.H{"error": err}) return err diff --git a/web.go b/web.go index 64f8de7..02c7eea 100644 --- a/web.go +++ b/web.go @@ -17,8 +17,10 @@ import ( var staticFiles embed.FS type WebRTCSessionRequest struct { - Sd string `json:"sd"` - OidcGoogle string `json:"OidcGoogle,omitempty"` + Sd string `json:"sd"` + OidcGoogle string `json:"OidcGoogle,omitempty"` + IP string `json:"ip,omitempty"` + ICEServers []string `json:"iceServers,omitempty"` } type SetPasswordRequest struct { @@ -116,7 +118,7 @@ func handleWebRTCSession(c *gin.Context) { return } - session, err := newSession() + session, err := newSession(SessionConfig{}) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err}) return diff --git a/webrtc.go b/webrtc.go index 20ffb99..27084fc 100644 --- a/webrtc.go +++ b/webrtc.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "net" "strings" "github.com/pion/webrtc/v4" @@ -19,6 +20,12 @@ type Session struct { shouldUmountVirtualMedia bool } +type SessionConfig struct { + ICEServers []string + LocalIP string + IsCloud bool +} + func (s *Session) ExchangeOffer(offerStr string) (string, error) { b, err := base64.StdEncoding.DecodeString(offerStr) if err != nil { @@ -61,9 +68,29 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) { return base64.StdEncoding.EncodeToString(localDescription), nil } -func newSession() (*Session, error) { - peerConnection, err := webrtc.NewPeerConnection(webrtc.Configuration{ - ICEServers: []webrtc.ICEServer{{}}, +func newSession(config SessionConfig) (*Session, error) { + webrtcSettingEngine := webrtc.SettingEngine{} + iceServer := webrtc.ICEServer{} + + if config.IsCloud { + if config.ICEServers == nil { + fmt.Printf("ICE Servers not provided by cloud") + } else { + iceServer.URLs = config.ICEServers + fmt.Printf("Using ICE Servers provided by cloud: %v\n", iceServer.URLs) + } + + if config.LocalIP == "" || net.ParseIP(config.LocalIP) == nil { + fmt.Printf("Local IP address %v not provided or invalid, won't set NAT1To1IPs\n", config.LocalIP) + } else { + webrtcSettingEngine.SetNAT1To1IPs([]string{config.LocalIP}, webrtc.ICECandidateTypeSrflx) + fmt.Printf("Setting NAT1To1IPs to %s\n", config.LocalIP) + } + } + + api := webrtc.NewAPI(webrtc.WithSettingEngine(webrtcSettingEngine)) + peerConnection, err := api.NewPeerConnection(webrtc.Configuration{ + ICEServers: []webrtc.ICEServer{iceServer}, }) if err != nil { return nil, err