0
0
mirror of https://github.com/XTLS/REALITY.git synced 2025-08-23 15:08:37 +00:00

crypto/tls: add SessionState and use it on the server side

This change by itself is useless, because the application has no way to
access or provide SessionStates to crypto/tls, but they will be provided
in following CLs.

For #60105

Change-Id: I8d5de79b1eda0a778420134cf6f346246a1bb296
Reviewed-on: https://go-review.googlesource.com/c/go/+/496818
Reviewed-by: Marten Seemann <martenseemann@gmail.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
yuhan6665 2023-10-14 07:00:31 -04:00
parent 695d127f77
commit 5320b1b20a
3 changed files with 64 additions and 110 deletions

View File

@ -31,7 +31,7 @@ type serverHandshakeState struct {
ecSignOk bool ecSignOk bool
rsaDecryptOk bool rsaDecryptOk bool
rsaSignOk bool rsaSignOk bool
sessionState *sessionState sessionState *SessionState
finishedHash finishedHash finishedHash finishedHash
masterSecret []byte masterSecret []byte
cert *Certificate cert *Certificate
@ -410,11 +410,11 @@ func (hs *serverHandshakeState) checkForResumption() bool {
if plaintext == nil { if plaintext == nil {
return false return false
} }
hs.sessionState = &sessionState{} ss, err := ParseSessionState(plaintext)
ok := hs.sessionState.unmarshal(plaintext) if err != nil {
if !ok {
return false return false
} }
hs.sessionState = ss
// TLS 1.2 tickets don't natively have a lifetime, but we want to avoid // TLS 1.2 tickets don't natively have a lifetime, but we want to avoid
// re-wrapping the same master secret in different tickets over and over for // re-wrapping the same master secret in different tickets over and over for
@ -425,7 +425,7 @@ func (hs *serverHandshakeState) checkForResumption() bool {
} }
// Never resume a session for a different TLS version. // Never resume a session for a different TLS version.
if c.vers != hs.sessionState.vers { if c.vers != hs.sessionState.version {
return false return false
} }
@ -448,7 +448,7 @@ func (hs *serverHandshakeState) checkForResumption() bool {
return false return false
} }
sessionHasClientCerts := len(hs.sessionState.certificates) != 0 sessionHasClientCerts := len(hs.sessionState.certificate.Certificate) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth) needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts { if needClientCerts && !sessionHasClientCerts {
return false return false
@ -481,9 +481,7 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
return err return err
} }
if err := c.processCertsFromClient(Certificate{ if err := c.processCertsFromClient(hs.sessionState.certificate); err != nil {
Certificate: hs.sessionState.certificates,
}); err != nil {
return err return err
} }
@ -494,7 +492,7 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
} }
} }
hs.masterSecret = hs.sessionState.masterSecret hs.masterSecret = hs.sessionState.secret
return nil return nil
} }
@ -772,14 +770,18 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
for _, cert := range c.peerCertificates { for _, cert := range c.peerCertificates {
certsFromClient = append(certsFromClient, cert.Raw) certsFromClient = append(certsFromClient, cert.Raw)
} }
state := sessionState{ state := SessionState{
vers: c.vers, version: c.vers,
cipherSuite: hs.suite.id, cipherSuite: hs.suite.id,
createdAt: createdAt, createdAt: createdAt,
masterSecret: hs.masterSecret, secret: hs.masterSecret,
certificates: certsFromClient, certificate: Certificate{
Certificate: certsFromClient,
OCSPStaple: c.ocspResponse,
SignedCertificateTimestamps: c.scts,
},
} }
stateBytes, err := state.marshal() stateBytes, err := state.Bytes()
if err != nil { if err != nil {
return err return err
} }

View File

@ -331,8 +331,8 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
if plaintext == nil { if plaintext == nil {
continue continue
} }
sessionState := new(sessionStateTLS13) sessionState, err := ParseSessionState(plaintext)
if ok := sessionState.unmarshal(plaintext); !ok { if err != nil || sessionState.version != VersionTLS13 {
continue continue
} }
@ -362,9 +362,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
continue continue
} }
psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", hs.earlySecret = hs.suite.extract(sessionState.secret, nil)
nil, hs.suite.hash.Size())
hs.earlySecret = hs.suite.extract(psk, nil)
binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil)
// Clone the transcript in case a HelloRetryRequest was recorded. // Clone the transcript in case a HelloRetryRequest was recorded.
transcript := cloneHash(hs.transcript, hs.suite.hash) transcript := cloneHash(hs.transcript, hs.suite.hash)
@ -823,6 +821,10 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
resumptionSecret := hs.suite.deriveSecret(hs.masterSecret, resumptionSecret := hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript) resumptionLabel, hs.transcript)
// ticket_nonce, which must be unique per connection, is always left at
// zero because we only ever send one ticket per connection.
psk := hs.suite.expandLabel(resumptionSecret, "resumption",
nil, hs.suite.hash.Size())
m := new(newSessionTicketMsgTLS13) m := new(newSessionTicketMsgTLS13)
@ -830,17 +832,18 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
for _, cert := range c.peerCertificates { for _, cert := range c.peerCertificates {
certsFromClient = append(certsFromClient, cert.Raw) certsFromClient = append(certsFromClient, cert.Raw)
} }
state := sessionStateTLS13{ state := &SessionState{
version: c.vers,
cipherSuite: hs.suite.id, cipherSuite: hs.suite.id,
createdAt: uint64(c.config.time().Unix()), createdAt: uint64(c.config.time().Unix()),
resumptionSecret: resumptionSecret, secret: psk,
certificate: Certificate{ certificate: Certificate{
Certificate: certsFromClient, Certificate: certsFromClient,
OCSPStaple: c.ocspResponse, OCSPStaple: c.ocspResponse,
SignedCertificateTimestamps: c.scts, SignedCertificateTimestamps: c.scts,
}, },
} }
stateBytes, err := state.marshal() stateBytes, err := state.Bytes()
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return err return err
@ -861,9 +864,6 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
} }
m.ageAdd = binary.LittleEndian.Uint32(ageAdd) m.ageAdd = binary.LittleEndian.Uint32(ageAdd)
// ticket_nonce, which must be unique per connection, is always left at
// zero because we only ever send one ticket per connection.
if _, err := c.writeHandshakeRecord(m, nil); err != nil { if _, err := c.writeHandshakeRecord(m, nil); err != nil {
return err return err
} }

106
ticket.go
View File

@ -16,99 +16,51 @@ import (
"golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte"
) )
// sessionState contains the information that is serialized into a session // A SessionState is a resumable session.
// ticket in order to later resume a connection. type SessionState struct {
type sessionState struct { version uint16 // uint16 version;
vers uint16 // uint8 revision = 1;
cipherSuite uint16 cipherSuite uint16
createdAt uint64 createdAt uint64
masterSecret []byte // opaque master_secret<1..2^16-1>; secret []byte // opaque master_secret<1..2^8-1>;
// struct { opaque certificate<1..2^24-1> } Certificate;
certificates [][]byte // Certificate certificate_list<0..2^24-1>;
}
func (m *sessionState) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(m.vers)
b.AddUint16(m.cipherSuite)
addUint64(&b, m.createdAt)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.masterSecret)
})
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for _, cert := range m.certificates {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(cert)
})
}
})
return b.Bytes()
}
func (m *sessionState) unmarshal(data []byte) bool {
*m = sessionState{}
s := cryptobyte.String(data)
if ok := s.ReadUint16(&m.vers) &&
s.ReadUint16(&m.cipherSuite) &&
readUint64(&s, &m.createdAt) &&
readUint16LengthPrefixed(&s, &m.masterSecret) &&
len(m.masterSecret) != 0; !ok {
return false
}
var certList cryptobyte.String
if !s.ReadUint24LengthPrefixed(&certList) {
return false
}
for !certList.Empty() {
var cert []byte
if !readUint24LengthPrefixed(&certList, &cert) {
return false
}
m.certificates = append(m.certificates, cert)
}
return s.Empty()
}
// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first
// version (revision = 0) doesn't carry any of the information needed for 0-RTT
// validation and the nonce is always empty.
type sessionStateTLS13 struct {
// uint8 version = 0x0304;
// uint8 revision = 0;
cipherSuite uint16
createdAt uint64
resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>;
certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; certificate Certificate // CertificateEntry certificate_list<0..2^24-1>;
} }
func (m *sessionStateTLS13) marshal() ([]byte, error) { // Bytes encodes the session, including any private fields, so that it can be
// parsed by [ParseSessionState]. The encoding contains secret values.
//
// The specific encoding should be considered opaque and may change incompatibly
// between Go versions.
func (m *SessionState) Bytes() ([]byte, error) {
var b cryptobyte.Builder var b cryptobyte.Builder
b.AddUint16(VersionTLS13) b.AddUint16(m.version)
b.AddUint8(0) // revision b.AddUint8(1) // revision
b.AddUint16(m.cipherSuite) b.AddUint16(m.cipherSuite)
addUint64(&b, m.createdAt) addUint64(&b, m.createdAt)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.resumptionSecret) b.AddBytes(m.secret)
}) })
marshalCertificate(&b, m.certificate) marshalCertificate(&b, m.certificate)
return b.Bytes() return b.Bytes()
} }
func (m *sessionStateTLS13) unmarshal(data []byte) bool { // ParseSessionState parses a [SessionState] encoded by [SessionState.Bytes].
*m = sessionStateTLS13{} func ParseSessionState(data []byte) (*SessionState, error) {
ss := &SessionState{}
s := cryptobyte.String(data) s := cryptobyte.String(data)
var version uint16
var revision uint8 var revision uint8
return s.ReadUint16(&version) && if !s.ReadUint16(&ss.version) ||
version == VersionTLS13 && !s.ReadUint8(&revision) ||
s.ReadUint8(&revision) && revision != 1 ||
revision == 0 && !s.ReadUint16(&ss.cipherSuite) ||
s.ReadUint16(&m.cipherSuite) && !readUint64(&s, &ss.createdAt) ||
readUint64(&s, &m.createdAt) && !readUint8LengthPrefixed(&s, &ss.secret) ||
readUint8LengthPrefixed(&s, &m.resumptionSecret) && len(ss.secret) == 0 ||
len(m.resumptionSecret) != 0 && !unmarshalCertificate(&s, &ss.certificate) ||
unmarshalCertificate(&s, &m.certificate) && !s.Empty() {
s.Empty() return nil, errors.New("tls: invalid session encoding")
}
return ss, nil
} }
func (c *Conn) encryptTicket(state []byte) ([]byte, error) { func (c *Conn) encryptTicket(state []byte) ([]byte, error) {