kvm/vendor/github.com/pion/dtls/v3/state.go

301 lines
8.5 KiB
Go

// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package dtls
import (
"bytes"
"encoding/gob"
"errors"
"sync/atomic"
"github.com/pion/dtls/v3/pkg/crypto/elliptic"
"github.com/pion/dtls/v3/pkg/crypto/prf"
"github.com/pion/dtls/v3/pkg/crypto/signaturehash"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
"github.com/pion/transport/v3/replaydetector"
)
// State holds the dtls connection state and implements both encoding.BinaryMarshaler and
// encoding.BinaryUnmarshaler.
type State struct {
localEpoch, remoteEpoch atomic.Value
localSequenceNumber []uint64 // uint48
localRandom, remoteRandom handshake.Random
masterSecret []byte
cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen
CipherSuiteID CipherSuiteID
srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile
remoteSRTPMasterKeyIdentifier []byte
PeerCertificates [][]byte
IdentityHint []byte
SessionID []byte
// Connection Identifiers must be negotiated afresh on session resumption.
// https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension
// localConnectionID is the locally generated connection ID that is expected
// to be received from the remote endpoint.
// For a server, this is the connection ID sent in ServerHello.
// For a client, this is the connection ID sent in the ClientHello.
localConnectionID atomic.Value
// remoteConnectionID is the connection ID that the remote endpoint
// specifies should be sent.
// For a server, this is the connection ID received in the ClientHello.
// For a client, this is the connection ID received in the ServerHello.
remoteConnectionID []byte
isClient bool
preMasterSecret []byte
extendedMasterSecret bool
namedCurve elliptic.Curve
localKeypair *elliptic.Keypair
cookie []byte
handshakeSendSequence int
handshakeRecvSequence int
serverName string
remoteCertRequestAlgs []signaturehash.Algorithm
remoteRequestedCertificate bool // Did we get a CertificateRequest
localCertificatesVerify []byte // cache CertificateVerify
localVerifyData []byte // cached VerifyData
localKeySignature []byte // cached keySignature
peerCertificatesVerified bool
replayDetector []replaydetector.ReplayDetector
peerSupportedProtocols []string
NegotiatedProtocol string
}
type serializedState struct {
LocalEpoch uint16
RemoteEpoch uint16
LocalRandom [handshake.RandomLength]byte
RemoteRandom [handshake.RandomLength]byte
CipherSuiteID uint16
MasterSecret []byte
SequenceNumber uint64
SRTPProtectionProfile uint16
PeerCertificates [][]byte
IdentityHint []byte
SessionID []byte
LocalConnectionID []byte
RemoteConnectionID []byte
IsClient bool
NegotiatedProtocol string
}
var errCipherSuiteNotSet = &InternalError{Err: errors.New("cipher suite not set")} //nolint:goerr113
func (s *State) clone() (*State, error) {
serialized, err := s.serialize()
if err != nil {
return nil, err
}
state := &State{}
state.deserialize(*serialized)
return state, err
}
func (s *State) serialize() (*serializedState, error) {
if s.cipherSuite == nil {
return nil, errCipherSuiteNotSet
}
cipherSuiteID := uint16(s.cipherSuite.ID())
// Marshal random values
localRnd := s.localRandom.MarshalFixed()
remoteRnd := s.remoteRandom.MarshalFixed()
epoch := s.getLocalEpoch()
return &serializedState{
LocalEpoch: s.getLocalEpoch(),
RemoteEpoch: s.getRemoteEpoch(),
CipherSuiteID: cipherSuiteID,
MasterSecret: s.masterSecret,
SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]),
LocalRandom: localRnd,
RemoteRandom: remoteRnd,
SRTPProtectionProfile: uint16(s.getSRTPProtectionProfile()),
PeerCertificates: s.PeerCertificates,
IdentityHint: s.IdentityHint,
SessionID: s.SessionID,
LocalConnectionID: s.getLocalConnectionID(),
RemoteConnectionID: s.remoteConnectionID,
IsClient: s.isClient,
NegotiatedProtocol: s.NegotiatedProtocol,
}, nil
}
func (s *State) deserialize(serialized serializedState) {
// Set epoch values
epoch := serialized.LocalEpoch
s.localEpoch.Store(serialized.LocalEpoch)
s.remoteEpoch.Store(serialized.RemoteEpoch)
for len(s.localSequenceNumber) <= int(epoch) {
s.localSequenceNumber = append(s.localSequenceNumber, uint64(0))
}
// Set random values
localRandom := &handshake.Random{}
localRandom.UnmarshalFixed(serialized.LocalRandom)
s.localRandom = *localRandom
remoteRandom := &handshake.Random{}
remoteRandom.UnmarshalFixed(serialized.RemoteRandom)
s.remoteRandom = *remoteRandom
s.isClient = serialized.IsClient
// Set master secret
s.masterSecret = serialized.MasterSecret
// Set cipher suite
s.CipherSuiteID = CipherSuiteID(serialized.CipherSuiteID)
s.cipherSuite = cipherSuiteForID(s.CipherSuiteID, nil)
atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber)
s.setSRTPProtectionProfile(SRTPProtectionProfile(serialized.SRTPProtectionProfile))
// Set remote certificate
s.PeerCertificates = serialized.PeerCertificates
s.IdentityHint = serialized.IdentityHint
// Set local and remote connection IDs
s.setLocalConnectionID(serialized.LocalConnectionID)
s.remoteConnectionID = serialized.RemoteConnectionID
s.SessionID = serialized.SessionID
s.NegotiatedProtocol = serialized.NegotiatedProtocol
}
func (s *State) initCipherSuite() error {
if s.cipherSuite.IsInitialized() {
return nil
}
localRandom := s.localRandom.MarshalFixed()
remoteRandom := s.remoteRandom.MarshalFixed()
var err error
if s.isClient {
err = s.cipherSuite.Init(s.masterSecret, localRandom[:], remoteRandom[:], true)
} else {
err = s.cipherSuite.Init(s.masterSecret, remoteRandom[:], localRandom[:], false)
}
if err != nil {
return err
}
return nil
}
// MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation.
func (s *State) MarshalBinary() ([]byte, error) {
serialized, err := s.serialize()
if err != nil {
return nil, err
}
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(*serialized); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// UnmarshalBinary is a binary.BinaryUnmarshaler.UnmarshalBinary implementation.
func (s *State) UnmarshalBinary(data []byte) error {
enc := gob.NewDecoder(bytes.NewBuffer(data))
var serialized serializedState
if err := enc.Decode(&serialized); err != nil {
return err
}
s.deserialize(serialized)
return s.initCipherSuite()
}
// ExportKeyingMaterial returns length bytes of exported key material in a new
// slice as defined in RFC 5705.
// This allows protocols to use DTLS for key establishment, but
// then use some of the keying material for their own purposes.
func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
if s.getLocalEpoch() == 0 {
return nil, errHandshakeInProgress
} else if len(context) != 0 {
return nil, errContextUnsupported
} else if _, ok := invalidKeyingLabels()[label]; ok {
return nil, errReservedExportKeyingMaterial
}
localRandom := s.localRandom.MarshalFixed()
remoteRandom := s.remoteRandom.MarshalFixed()
seed := []byte(label)
if s.isClient {
seed = append(append(seed, localRandom[:]...), remoteRandom[:]...)
} else {
seed = append(append(seed, remoteRandom[:]...), localRandom[:]...)
}
return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc())
}
func (s *State) getRemoteEpoch() uint16 {
if remoteEpoch, ok := s.remoteEpoch.Load().(uint16); ok {
return remoteEpoch
}
return 0
}
func (s *State) getLocalEpoch() uint16 {
if localEpoch, ok := s.localEpoch.Load().(uint16); ok {
return localEpoch
}
return 0
}
func (s *State) setSRTPProtectionProfile(profile SRTPProtectionProfile) {
s.srtpProtectionProfile.Store(profile)
}
func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile {
if val, ok := s.srtpProtectionProfile.Load().(SRTPProtectionProfile); ok {
return val
}
return 0
}
func (s *State) getLocalConnectionID() []byte {
if val, ok := s.localConnectionID.Load().([]byte); ok {
return val
}
return nil
}
func (s *State) setLocalConnectionID(v []byte) {
s.localConnectionID.Store(v)
}
// RemoteRandomBytes returns the remote client hello random bytes.
func (s *State) RemoteRandomBytes() [handshake.RandomBytesLength]byte {
return s.remoteRandom.RandomBytes
}