0
0
mirror of https://github.com/XTLS/REALITY.git synced 2025-08-22 22:48:36 +00:00

crypto/tls: add WrapSession and UnwrapSession

There was a bug in TestResumption: the first ExpiredSessionTicket was
inserting a ticket far in the future, so the second ExpiredSessionTicket
wasn't actually supposed to fail. However, there was a bug in
checkForResumption->sendSessionTicket, too: if a session was not resumed
because it was too old, its createdAt was still persisted in the next
ticket. The two bugs used to cancel each other out.

For #60105
Fixes #19199

Change-Id: Ic9b2aab943dcbf0de62b8758a6195319dc286e2f
Reviewed-on: https://go-review.googlesource.com/c/go/+/496821
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
yuhan6665 2023-10-14 08:04:22 -04:00
parent a10c90ae93
commit 600acbbb6b
4 changed files with 162 additions and 51 deletions

View File

@ -687,6 +687,35 @@ type Config struct {
// session resumption. It is only used by clients. // session resumption. It is only used by clients.
ClientSessionCache ClientSessionCache ClientSessionCache ClientSessionCache
// UnwrapSession is called on the server to turn a ticket/identity
// previously produced by [WrapSession] into a usable session.
//
// UnwrapSession will usually either decrypt a session state in the ticket
// (for example with [Config.EncryptTicket]), or use the ticket as a handle
// to recover a previously stored state. It must use [ParseSessionState] to
// deserialize the session state.
//
// If UnwrapSession returns an error, the connection is terminated. If it
// returns (nil, nil), the session is ignored. crypto/tls may still choose
// not to resume the returned session.
UnwrapSession func(identity []byte, cs ConnectionState) (*SessionState, error)
// WrapSession is called on the server to produce a session ticket/identity.
//
// WrapSession must serialize the session state with [SessionState.Bytes].
// It may then encrypt the serialized state (for example with
// [Config.DecryptTicket]) and use it as the ticket, or store the state and
// return a handle for it.
//
// If WrapSession returns an error, the connection is terminated.
//
// Warning: the return value will be exposed on the wire and to clients in
// plaintext. The application is in charge of encrypting and authenticating
// it (and rotating keys) or returning high-entropy identifiers. Failing to
// do so correctly can compromise current, previous, and future connections
// depending on the protocol version.
WrapSession func(ConnectionState, *SessionState) ([]byte, error)
// MinVersion contains the minimum TLS version that is acceptable. // MinVersion contains the minimum TLS version that is acceptable.
// //
// By default, TLS 1.2 is currently used as the minimum when acting as a // By default, TLS 1.2 is currently used as the minimum when acting as a
@ -819,6 +848,8 @@ func (c *Config) Clone() *Config {
SessionTicketsDisabled: c.SessionTicketsDisabled, SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey, SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: c.ClientSessionCache, ClientSessionCache: c.ClientSessionCache,
UnwrapSession: c.UnwrapSession,
WrapSession: c.WrapSession,
MinVersion: c.MinVersion, MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion, MaxVersion: c.MaxVersion,
CurvePreferences: c.CurvePreferences, CurvePreferences: c.CurvePreferences,

View File

@ -70,9 +70,11 @@ func (hs *serverHandshakeState) handshake() error {
// For an overview of TLS handshaking, see RFC 5246, Section 7.3. // For an overview of TLS handshaking, see RFC 5246, Section 7.3.
c.buffering = true c.buffering = true
if hs.checkForResumption() { if err := hs.checkForResumption(); err != nil {
return err
}
if hs.sessionState != nil {
// The client has included a session ticket and so we do an abbreviated handshake. // The client has included a session ticket and so we do an abbreviated handshake.
c.didResume = true
if err := hs.doResumeHandshake(); err != nil { if err := hs.doResumeHandshake(); err != nil {
return err return err
} }
@ -399,65 +401,80 @@ func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool {
} }
// checkForResumption reports whether we should perform resumption on this connection. // checkForResumption reports whether we should perform resumption on this connection.
func (hs *serverHandshakeState) checkForResumption() bool { func (hs *serverHandshakeState) checkForResumption() error {
c := hs.c c := hs.c
if c.config.SessionTicketsDisabled { if c.config.SessionTicketsDisabled {
return false return nil
} }
plaintext := c.decryptTicket(hs.clientHello.sessionTicket) var sessionState *SessionState
if plaintext == nil { if c.config.UnwrapSession != nil {
return false ss, err := c.config.UnwrapSession(hs.clientHello.sessionTicket, c.connectionStateLocked())
if err != nil {
return err
}
if ss == nil {
return nil
}
sessionState = ss
} else {
plaintext := c.config.decryptTicket(hs.clientHello.sessionTicket, c.ticketKeys)
if plaintext == nil {
return nil
}
ss, err := ParseSessionState(plaintext)
if err != nil {
return nil
}
sessionState = ss
} }
ss, err := ParseSessionState(plaintext)
if err != nil {
return false
}
hs.sessionState = ss
// TLS 1.2 tickets don't natively have a lifetime, but we want to avoid // TLS 1.2 tickets don't natively have a lifetime, but we want to avoid
// re-wrapping the same master secret in different tickets over and over for // re-wrapping the same master secret in different tickets over and over for
// too long, weakening forward secrecy. // too long, weakening forward secrecy.
createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) createdAt := time.Unix(int64(sessionState.createdAt), 0)
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
return false return nil
} }
// Never resume a session for a different TLS version. // Never resume a session for a different TLS version.
if c.vers != hs.sessionState.version { if c.vers != sessionState.version {
return false return nil
} }
cipherSuiteOk := false cipherSuiteOk := false
// Check that the client is still offering the ciphersuite in the session. // Check that the client is still offering the ciphersuite in the session.
for _, id := range hs.clientHello.cipherSuites { for _, id := range hs.clientHello.cipherSuites {
if id == hs.sessionState.cipherSuite { if id == sessionState.cipherSuite {
cipherSuiteOk = true cipherSuiteOk = true
break break
} }
} }
if !cipherSuiteOk { if !cipherSuiteOk {
return false return nil
} }
// Check that we also support the ciphersuite from the session. // Check that we also support the ciphersuite from the session.
hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, suite := selectCipherSuite([]uint16{sessionState.cipherSuite},
c.config.cipherSuites(), hs.cipherSuiteOk) c.config.cipherSuites(), hs.cipherSuiteOk)
if hs.suite == nil { if suite == nil {
return false return nil
} }
sessionHasClientCerts := len(hs.sessionState.peerCertificates) != 0 sessionHasClientCerts := len(sessionState.peerCertificates) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth) needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts { if needClientCerts && !sessionHasClientCerts {
return false return nil
} }
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
return false return nil
} }
return true hs.sessionState = sessionState
hs.suite = suite
c.didResume = true
return nil
} }
func (hs *serverHandshakeState) doResumeHandshake() error { func (hs *serverHandshakeState) doResumeHandshake() error {
@ -769,13 +786,20 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
// the original time it was created. // the original time it was created.
state.createdAt = hs.sessionState.createdAt state.createdAt = hs.sessionState.createdAt
} }
stateBytes, err := state.Bytes() if c.config.WrapSession != nil {
if err != nil { m.ticket, err = c.config.WrapSession(c.connectionStateLocked(), state)
return err if err != nil {
} return err
m.ticket, err = c.encryptTicket(stateBytes) }
if err != nil { } else {
return err stateBytes, err := state.Bytes()
if err != nil {
return err
}
m.ticket, err = c.config.encryptTicket(stateBytes, c.ticketKeys)
if err != nil {
return err
}
} }
if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil { if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {

View File

@ -327,12 +327,29 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
break break
} }
plaintext := c.decryptTicket(identity.label) var sessionState *SessionState
if plaintext == nil { if c.config.UnwrapSession != nil {
continue var err error
sessionState, err = c.config.UnwrapSession(identity.label, c.connectionStateLocked())
if err != nil {
return err
}
if sessionState == nil {
continue
}
} else {
plaintext := c.config.decryptTicket(identity.label, c.ticketKeys)
if plaintext == nil {
continue
}
var err error
sessionState, err = ParseSessionState(plaintext)
if err != nil {
continue
}
} }
sessionState, err := ParseSessionState(plaintext)
if err != nil || sessionState.version != VersionTLS13 { if sessionState.version != VersionTLS13 {
continue continue
} }
@ -833,14 +850,21 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
return err return err
} }
state.secret = psk state.secret = psk
stateBytes, err := state.Bytes() if c.config.WrapSession != nil {
if err != nil { m.label, err = c.config.WrapSession(c.connectionStateLocked(), state)
c.sendAlert(alertInternalError) if err != nil {
return err return err
} }
m.label, err = c.encryptTicket(stateBytes) } else {
if err != nil { stateBytes, err := state.Bytes()
return err if err != nil {
c.sendAlert(alertInternalError)
return err
}
m.label, err = c.config.encryptTicket(stateBytes, c.ticketKeys)
if err != nil {
return err
}
} }
m.lifetime = uint32(maxSessionTicketLifetime / time.Second) m.lifetime = uint32(maxSessionTicketLifetime / time.Second)

View File

@ -228,8 +228,21 @@ func (c *Conn) sessionState() (*SessionState, error) {
}, nil }, nil
} }
func (c *Conn) encryptTicket(state []byte) ([]byte, error) { // EncryptTicket encrypts a ticket with the Config's configured (or default)
if len(c.ticketKeys) == 0 { // session ticket keys. It can be used as a [Config.WrapSession] implementation.
func (c *Config) EncryptTicket(cs ConnectionState, ss *SessionState) ([]byte, error) {
ticketKeys := c.ticketKeys(nil)
stateBytes, err := ss.Bytes()
if err != nil {
return nil, err
}
return c.encryptTicket(stateBytes, ticketKeys)
}
var _ = &Config{WrapSession: (&Config{}).EncryptTicket}
func (c *Config) encryptTicket(state []byte, ticketKeys []ticketKey) ([]byte, error) {
if len(ticketKeys) == 0 {
return nil, errors.New("tls: internal error: session ticket keys unavailable") return nil, errors.New("tls: internal error: session ticket keys unavailable")
} }
@ -239,10 +252,10 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
authenticated := encrypted[:len(encrypted)-sha256.Size] authenticated := encrypted[:len(encrypted)-sha256.Size]
macBytes := encrypted[len(encrypted)-sha256.Size:] macBytes := encrypted[len(encrypted)-sha256.Size:]
if _, err := io.ReadFull(c.config.rand(), iv); err != nil { if _, err := io.ReadFull(c.rand(), iv); err != nil {
return nil, err return nil, err
} }
key := c.ticketKeys[0] key := ticketKeys[0]
block, err := aes.NewCipher(key.aesKey[:]) block, err := aes.NewCipher(key.aesKey[:])
if err != nil { if err != nil {
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
@ -256,7 +269,26 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
return encrypted, nil return encrypted, nil
} }
func (c *Conn) decryptTicket(encrypted []byte) []byte { // DecryptTicket decrypts a ticket encrypted by [Config.EncryptTicket]. It can
// be used as a [Config.UnwrapSession] implementation.
//
// If the ticket can't be decrypted or parsed, DecryptTicket returns (nil, nil).
func (c *Config) DecryptTicket(identity []byte, cs ConnectionState) (*SessionState, error) {
ticketKeys := c.ticketKeys(nil)
stateBytes := c.decryptTicket(identity, ticketKeys)
if stateBytes == nil {
return nil, nil
}
s, err := ParseSessionState(stateBytes)
if err != nil {
return nil, nil // drop unparsable tickets on the floor
}
return s, nil
}
var _ = &Config{UnwrapSession: (&Config{}).DecryptTicket}
func (c *Config) decryptTicket(encrypted []byte, ticketKeys []ticketKey) []byte {
if len(encrypted) < aes.BlockSize+sha256.Size { if len(encrypted) < aes.BlockSize+sha256.Size {
return nil return nil
} }
@ -265,7 +297,7 @@ func (c *Conn) decryptTicket(encrypted []byte) []byte {
ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size] ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
authenticated := encrypted[:len(encrypted)-sha256.Size] authenticated := encrypted[:len(encrypted)-sha256.Size]
macBytes := encrypted[len(encrypted)-sha256.Size:] macBytes := encrypted[len(encrypted)-sha256.Size:]
for _, key := range c.ticketKeys { for _, key := range ticketKeys {
mac := hmac.New(sha256.New, key.hmacKey[:]) mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(authenticated) mac.Write(authenticated)
expected := mac.Sum(nil) expected := mac.Sum(nil)