diff --git a/handshake_server.go b/handshake_server.go index fd78e13..3d24bc2 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -31,7 +31,7 @@ type serverHandshakeState struct { ecSignOk bool rsaDecryptOk bool rsaSignOk bool - sessionState *sessionState + sessionState *SessionState finishedHash finishedHash masterSecret []byte cert *Certificate @@ -410,11 +410,11 @@ func (hs *serverHandshakeState) checkForResumption() bool { if plaintext == nil { return false } - hs.sessionState = &sessionState{} - ok := hs.sessionState.unmarshal(plaintext) - if !ok { + ss, err := ParseSessionState(plaintext) + if err != nil { return false } + hs.sessionState = ss // TLS 1.2 tickets don't natively have a lifetime, but we want to avoid // re-wrapping the same master secret in different tickets over and over for @@ -425,7 +425,7 @@ func (hs *serverHandshakeState) checkForResumption() bool { } // Never resume a session for a different TLS version. - if c.vers != hs.sessionState.vers { + if c.vers != hs.sessionState.version { return false } @@ -448,7 +448,7 @@ func (hs *serverHandshakeState) checkForResumption() bool { return false } - sessionHasClientCerts := len(hs.sessionState.certificates) != 0 + sessionHasClientCerts := len(hs.sessionState.certificate.Certificate) != 0 needClientCerts := requiresClientCert(c.config.ClientAuth) if needClientCerts && !sessionHasClientCerts { return false @@ -481,9 +481,7 @@ func (hs *serverHandshakeState) doResumeHandshake() error { return err } - if err := c.processCertsFromClient(Certificate{ - Certificate: hs.sessionState.certificates, - }); err != nil { + if err := c.processCertsFromClient(hs.sessionState.certificate); err != nil { return err } @@ -494,7 +492,7 @@ func (hs *serverHandshakeState) doResumeHandshake() error { } } - hs.masterSecret = hs.sessionState.masterSecret + hs.masterSecret = hs.sessionState.secret return nil } @@ -772,14 +770,18 @@ func (hs *serverHandshakeState) sendSessionTicket() error { for _, cert := range c.peerCertificates { certsFromClient = append(certsFromClient, cert.Raw) } - state := sessionState{ - vers: c.vers, - cipherSuite: hs.suite.id, - createdAt: createdAt, - masterSecret: hs.masterSecret, - certificates: certsFromClient, + state := SessionState{ + version: c.vers, + cipherSuite: hs.suite.id, + createdAt: createdAt, + secret: hs.masterSecret, + certificate: Certificate{ + Certificate: certsFromClient, + OCSPStaple: c.ocspResponse, + SignedCertificateTimestamps: c.scts, + }, } - stateBytes, err := state.marshal() + stateBytes, err := state.Bytes() if err != nil { return err } diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 9b9a7a9..7589d99 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -331,8 +331,8 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { if plaintext == nil { continue } - sessionState := new(sessionStateTLS13) - if ok := sessionState.unmarshal(plaintext); !ok { + sessionState, err := ParseSessionState(plaintext) + if err != nil || sessionState.version != VersionTLS13 { continue } @@ -362,9 +362,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { continue } - psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", - nil, hs.suite.hash.Size()) - hs.earlySecret = hs.suite.extract(psk, nil) + hs.earlySecret = hs.suite.extract(sessionState.secret, nil) binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) // Clone the transcript in case a HelloRetryRequest was recorded. transcript := cloneHash(hs.transcript, hs.suite.hash) @@ -823,6 +821,10 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { resumptionSecret := hs.suite.deriveSecret(hs.masterSecret, resumptionLabel, hs.transcript) + // ticket_nonce, which must be unique per connection, is always left at + // zero because we only ever send one ticket per connection. + psk := hs.suite.expandLabel(resumptionSecret, "resumption", + nil, hs.suite.hash.Size()) m := new(newSessionTicketMsgTLS13) @@ -830,17 +832,18 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { for _, cert := range c.peerCertificates { certsFromClient = append(certsFromClient, cert.Raw) } - state := sessionStateTLS13{ - cipherSuite: hs.suite.id, - createdAt: uint64(c.config.time().Unix()), - resumptionSecret: resumptionSecret, + state := &SessionState{ + version: c.vers, + cipherSuite: hs.suite.id, + createdAt: uint64(c.config.time().Unix()), + secret: psk, certificate: Certificate{ Certificate: certsFromClient, OCSPStaple: c.ocspResponse, SignedCertificateTimestamps: c.scts, }, } - stateBytes, err := state.marshal() + stateBytes, err := state.Bytes() if err != nil { c.sendAlert(alertInternalError) return err @@ -861,9 +864,6 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { } m.ageAdd = binary.LittleEndian.Uint32(ageAdd) - // ticket_nonce, which must be unique per connection, is always left at - // zero because we only ever send one ticket per connection. - if _, err := c.writeHandshakeRecord(m, nil); err != nil { return err } diff --git a/ticket.go b/ticket.go index 918e78a..004b7da 100644 --- a/ticket.go +++ b/ticket.go @@ -16,99 +16,51 @@ import ( "golang.org/x/crypto/cryptobyte" ) -// sessionState contains the information that is serialized into a session -// ticket in order to later resume a connection. -type sessionState struct { - vers uint16 - cipherSuite uint16 - createdAt uint64 - masterSecret []byte // opaque master_secret<1..2^16-1>; - // struct { opaque certificate<1..2^24-1> } Certificate; - certificates [][]byte // Certificate certificate_list<0..2^24-1>; +// A SessionState is a resumable session. +type SessionState struct { + version uint16 // uint16 version; + // uint8 revision = 1; + cipherSuite uint16 + createdAt uint64 + secret []byte // opaque master_secret<1..2^8-1>; + certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; } -func (m *sessionState) marshal() ([]byte, error) { +// Bytes encodes the session, including any private fields, so that it can be +// parsed by [ParseSessionState]. The encoding contains secret values. +// +// The specific encoding should be considered opaque and may change incompatibly +// between Go versions. +func (m *SessionState) Bytes() ([]byte, error) { var b cryptobyte.Builder - b.AddUint16(m.vers) - b.AddUint16(m.cipherSuite) - addUint64(&b, m.createdAt) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.masterSecret) - }) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - for _, cert := range m.certificates { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(cert) - }) - } - }) - return b.Bytes() -} - -func (m *sessionState) unmarshal(data []byte) bool { - *m = sessionState{} - s := cryptobyte.String(data) - if ok := s.ReadUint16(&m.vers) && - s.ReadUint16(&m.cipherSuite) && - readUint64(&s, &m.createdAt) && - readUint16LengthPrefixed(&s, &m.masterSecret) && - len(m.masterSecret) != 0; !ok { - return false - } - var certList cryptobyte.String - if !s.ReadUint24LengthPrefixed(&certList) { - return false - } - for !certList.Empty() { - var cert []byte - if !readUint24LengthPrefixed(&certList, &cert) { - return false - } - m.certificates = append(m.certificates, cert) - } - return s.Empty() -} - -// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first -// version (revision = 0) doesn't carry any of the information needed for 0-RTT -// validation and the nonce is always empty. -type sessionStateTLS13 struct { - // uint8 version = 0x0304; - // uint8 revision = 0; - cipherSuite uint16 - createdAt uint64 - resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; - certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; -} - -func (m *sessionStateTLS13) marshal() ([]byte, error) { - var b cryptobyte.Builder - b.AddUint16(VersionTLS13) - b.AddUint8(0) // revision + b.AddUint16(m.version) + b.AddUint8(1) // revision b.AddUint16(m.cipherSuite) addUint64(&b, m.createdAt) b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.resumptionSecret) + b.AddBytes(m.secret) }) marshalCertificate(&b, m.certificate) return b.Bytes() } -func (m *sessionStateTLS13) unmarshal(data []byte) bool { - *m = sessionStateTLS13{} +// ParseSessionState parses a [SessionState] encoded by [SessionState.Bytes]. +func ParseSessionState(data []byte) (*SessionState, error) { + ss := &SessionState{} s := cryptobyte.String(data) - var version uint16 var revision uint8 - return s.ReadUint16(&version) && - version == VersionTLS13 && - s.ReadUint8(&revision) && - revision == 0 && - s.ReadUint16(&m.cipherSuite) && - readUint64(&s, &m.createdAt) && - readUint8LengthPrefixed(&s, &m.resumptionSecret) && - len(m.resumptionSecret) != 0 && - unmarshalCertificate(&s, &m.certificate) && - s.Empty() + if !s.ReadUint16(&ss.version) || + !s.ReadUint8(&revision) || + revision != 1 || + !s.ReadUint16(&ss.cipherSuite) || + !readUint64(&s, &ss.createdAt) || + !readUint8LengthPrefixed(&s, &ss.secret) || + len(ss.secret) == 0 || + !unmarshalCertificate(&s, &ss.certificate) || + !s.Empty() { + return nil, errors.New("tls: invalid session encoding") + } + return ss, nil } func (c *Conn) encryptTicket(state []byte) ([]byte, error) {