0
0
mirror of https://github.com/XTLS/REALITY.git synced 2025-08-22 14:38:35 +00:00

crypto/tls: support QUIC as a transport

Add a QUICConn type for use by QUIC implementations.

A QUICConn provides unencrypted handshake bytes and connection
secrets to the QUIC layer, and receives handshake bytes.

For #44886

Change-Id: I859dda4cc6d466a1df2fb863a69d3a2a069110d5
Reviewed-on: https://go-review.googlesource.com/c/go/+/493655
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
yuhan6665 2023-10-14 05:30:46 -04:00
parent 9a462df048
commit 8bde0136fd
9 changed files with 644 additions and 49 deletions

View File

@ -6,6 +6,16 @@ package reality
import "strconv" import "strconv"
// An AlertError is a TLS alert.
//
// When using a QUIC transport, QUICConn methods will return an error
// which wraps AlertError rather than sending a TLS alert.
type AlertError uint8
func (e AlertError) Error() string {
return alert(e).String()
}
type alert uint8 type alert uint8
const ( const (

View File

@ -99,6 +99,7 @@ const (
extensionCertificateAuthorities uint16 = 47 extensionCertificateAuthorities uint16 = 47
extensionSignatureAlgorithmsCert uint16 = 50 extensionSignatureAlgorithmsCert uint16 = 50
extensionKeyShare uint16 = 51 extensionKeyShare uint16 = 51
extensionQUICTransportParameters uint16 = 57
extensionRenegotiationInfo uint16 = 0xff01 extensionRenegotiationInfo uint16 = 0xff01
) )

117
conn.go
View File

@ -34,6 +34,7 @@ type Conn struct {
conn net.Conn conn net.Conn
isClient bool isClient bool
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
quic *quicState // nil for non-QUIC connections
// isHandshakeComplete is true if the connection is currently transferring // isHandshakeComplete is true if the connection is currently transferring
// application data (i.e. is not currently processing a handshake). // application data (i.e. is not currently processing a handshake).
@ -184,7 +185,8 @@ type halfConn struct {
nextCipher any // next encryption state nextCipher any // next encryption state
nextMac hash.Hash // next MAC algorithm nextMac hash.Hash // next MAC algorithm
trafficSecret []byte // current TLS 1.3 traffic secret level QUICEncryptionLevel // current QUIC encryption level
trafficSecret []byte // current TLS 1.3 traffic secret
} }
type permanentError struct { type permanentError struct {
@ -229,8 +231,9 @@ func (hc *halfConn) changeCipherSpec() error {
return nil return nil
} }
func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
hc.trafficSecret = secret hc.trafficSecret = secret
hc.level = level
key, iv := suite.trafficKey(secret) key, iv := suite.trafficKey(secret)
hc.cipher = suite.aead(key, iv) hc.cipher = suite.aead(key, iv)
for i := range hc.seq { for i := range hc.seq {
@ -648,6 +651,10 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
} }
c.input.Reset(nil) c.input.Reset(nil)
if c.quic != nil {
return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
}
// Read header, payload. // Read header, payload.
if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil { if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
@ -737,6 +744,9 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
case recordTypeAlert: case recordTypeAlert:
if c.quic != nil {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
if len(data) != 2 { if len(data) != 2 {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
} }
@ -854,6 +864,9 @@ func (c *Conn) readFromUntil(r io.Reader, n int) error {
// sendAlertLocked sends a TLS alert message. // sendAlertLocked sends a TLS alert message.
func (c *Conn) sendAlertLocked(err alert) error { func (c *Conn) sendAlertLocked(err alert) error {
if c.quic != nil {
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
switch err { switch err {
case alertNoRenegotiation, alertCloseNotify: case alertNoRenegotiation, alertCloseNotify:
c.tmp[0] = alertLevelWarning c.tmp[0] = alertLevelWarning
@ -988,6 +1001,19 @@ var outBufPool = sync.Pool{
// writeRecordLocked writes a TLS record with the given type and payload to the // writeRecordLocked writes a TLS record with the given type and payload to the
// connection and updates the record layer state. // connection and updates the record layer state.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
if c.quic != nil {
if typ != recordTypeHandshake {
return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
}
c.quicWriteCryptoData(c.out.level, data)
if !c.buffering {
if _, err := c.flush(); err != nil {
return 0, err
}
}
return len(data), nil
}
outBufPtr := outBufPool.Get().(*[]byte) outBufPtr := outBufPool.Get().(*[]byte)
outBuf := *outBufPtr outBuf := *outBufPtr
defer func() { defer func() {
@ -1101,14 +1127,25 @@ func (c *Conn) writeChangeCipherRecord() error {
return err return err
} }
// readHandshakeBytes reads handshake data until c.hand contains at least n bytes.
func (c *Conn) readHandshakeBytes(n int) error {
if c.quic != nil {
return c.quicReadHandshakeBytes(n)
}
for c.hand.Len() < n {
if err := c.readRecord(); err != nil {
return err
}
}
return nil
}
// readHandshake reads the next handshake message from // readHandshake reads the next handshake message from
// the record layer. If transcript is non-nil, the message // the record layer. If transcript is non-nil, the message
// is written to the passed transcriptHash. // is written to the passed transcriptHash.
func (c *Conn) readHandshake(transcript transcriptHash) (any, error) { func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
for c.hand.Len() < 4 { if err := c.readHandshakeBytes(4); err != nil {
if err := c.readRecord(); err != nil { return nil, err
return nil, err
}
} }
data := c.hand.Bytes() data := c.hand.Bytes()
@ -1117,12 +1154,14 @@ func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
c.sendAlertLocked(alertInternalError) c.sendAlertLocked(alertInternalError)
return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
} }
for c.hand.Len() < 4+n { if err := c.readHandshakeBytes(4 + n); err != nil {
if err := c.readRecord(); err != nil { return nil, err
return nil, err
}
} }
data = c.hand.Next(4 + n) data = c.hand.Next(4 + n)
return c.unmarshalHandshakeMessage(data, transcript)
}
func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
var m handshakeMessage var m handshakeMessage
switch data[0] { switch data[0] {
case typeHelloRequest: case typeHelloRequest:
@ -1313,7 +1352,6 @@ func (c *Conn) handlePostHandshakeMessage() error {
if err != nil { if err != nil {
return err return err
} }
c.retryCount++ c.retryCount++
if c.retryCount > maxUselessRecords { if c.retryCount > maxUselessRecords {
c.sendAlert(alertUnexpectedMessage) c.sendAlert(alertUnexpectedMessage)
@ -1325,20 +1363,28 @@ func (c *Conn) handlePostHandshakeMessage() error {
return c.handleNewSessionTicket(msg) return c.handleNewSessionTicket(msg)
case *keyUpdateMsg: case *keyUpdateMsg:
return c.handleKeyUpdate(msg) return c.handleKeyUpdate(msg)
default:
c.sendAlert(alertUnexpectedMessage)
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
} }
// The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
// as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
// unexpected_message alert here doesn't provide it with enough information to distinguish
// this condition from other unexpected messages. This is probably fine.
c.sendAlert(alertUnexpectedMessage)
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
} }
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
if c.quic != nil {
c.sendAlert(alertUnexpectedMessage)
return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
}
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
if cipherSuite == nil { if cipherSuite == nil {
return c.in.setErrorLocked(c.sendAlert(alertInternalError)) return c.in.setErrorLocked(c.sendAlert(alertInternalError))
} }
newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
c.in.setTrafficSecret(cipherSuite, newSecret) c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
if keyUpdate.updateRequested { if keyUpdate.updateRequested {
c.out.Lock() c.out.Lock()
@ -1357,7 +1403,7 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
} }
newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
c.out.setTrafficSecret(cipherSuite, newSecret) c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
} }
return nil return nil
@ -1518,12 +1564,15 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
// this cancellation. In the former case, we need to close the connection. // this cancellation. In the former case, we need to close the connection.
defer cancel() defer cancel()
// Start the "interrupter" goroutine, if this context might be canceled. if c.quic != nil {
// (The background context cannot). c.quic.cancelc = handshakeCtx.Done()
// c.quic.cancel = cancel
// The interrupter goroutine waits for the input context to be done and } else if ctx.Done() != nil {
// closes the connection if this happens before the function returns. // Start the "interrupter" goroutine, if this context might be canceled.
if ctx.Done() != nil { // (The background context cannot).
//
// The interrupter goroutine waits for the input context to be done and
// closes the connection if this happens before the function returns.
done := make(chan struct{}) done := make(chan struct{})
interruptRes := make(chan error, 1) interruptRes := make(chan error, 1)
defer func() { defer func() {
@ -1574,6 +1623,30 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
panic("tls: internal error: handshake returned an error but is marked successful") panic("tls: internal error: handshake returned an error but is marked successful")
} }
if c.quic != nil {
if c.handshakeErr == nil {
c.quicHandshakeComplete()
// Provide the 1-RTT read secret now that the handshake is complete.
// The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing
// the handshake (RFC 9001, Section 5.7).
c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
} else {
var a alert
c.out.Lock()
if !errors.As(c.out.err, &a) {
a = alertInternalError
}
c.out.Unlock()
// Return an error which wraps both the handshake error and
// any alert error we may have sent, or alertInternalError
// if we didn't send an alert.
// Truncate the text of the alert to 0 characters.
c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
}
close(c.quic.blockedc)
close(c.quic.signalc)
}
return c.handshakeErr return c.handshakeErr
} }

View File

@ -71,7 +71,6 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
vers: clientHelloVersion, vers: clientHelloVersion,
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
random: make([]byte, 32), random: make([]byte, 32),
sessionId: make([]byte, 32),
ocspStapling: true, ocspStapling: true,
scts: true, scts: true,
serverName: hostnameInSNI(config.ServerName), serverName: hostnameInSNI(config.ServerName),
@ -114,8 +113,13 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
// A random session ID is used to detect when the server accepted a ticket // A random session ID is used to detect when the server accepted a ticket
// and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as
// a compatibility measure (see RFC 8446, Section 4.1.2). // a compatibility measure (see RFC 8446, Section 4.1.2).
if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { //
return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) // The session ID is not set for QUIC connections (see RFC 9001, Section 8.4).
if c.quic == nil {
hello.sessionId = make([]byte, 32)
if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
}
} }
if hello.vers >= VersionTLS12 { if hello.vers >= VersionTLS12 {
@ -144,6 +148,17 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
} }
if c.quic != nil {
p, err := c.quicGetTransportParameters()
if err != nil {
return nil, nil, err
}
if p == nil {
p = []byte{}
}
hello.quicTransportParameters = p
}
return hello, key, nil return hello, key, nil
} }
@ -271,7 +286,10 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
} }
// Try to resume a previously negotiated TLS session, if available. // Try to resume a previously negotiated TLS session, if available.
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) cacheKey = c.clientSessionCacheKey()
if cacheKey == "" {
return "", nil, nil, nil, nil
}
session, ok := c.config.ClientSessionCache.Get(cacheKey) session, ok := c.config.ClientSessionCache.Get(cacheKey)
if !ok || session == nil { if !ok || session == nil {
return cacheKey, nil, nil, nil, nil return cacheKey, nil, nil, nil, nil
@ -722,7 +740,7 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
} }
} }
if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil { if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol, false); err != nil {
c.sendAlert(alertUnsupportedExtension) c.sendAlert(alertUnsupportedExtension)
return false, err return false, err
} }
@ -760,8 +778,12 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
// checkALPN ensure that the server's choice of ALPN protocol is compatible with // checkALPN ensure that the server's choice of ALPN protocol is compatible with
// the protocols that we advertised in the Client Hello. // the protocols that we advertised in the Client Hello.
func checkALPN(clientProtos []string, serverProto string) error { func checkALPN(clientProtos []string, serverProto string, quic bool) error {
if serverProto == "" { if serverProto == "" {
if quic && len(clientProtos) > 0 {
// RFC 9001, Section 8.1
return errors.New("tls: server did not select an ALPN protocol")
}
return nil return nil
} }
if len(clientProtos) == 0 { if len(clientProtos) == 0 {
@ -1003,11 +1025,14 @@ func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate,
// clientSessionCacheKey returns a key used to cache sessionTickets that could // clientSessionCacheKey returns a key used to cache sessionTickets that could
// be used to resume previously negotiated TLS sessions with a server. // be used to resume previously negotiated TLS sessions with a server.
func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { func (c *Conn) clientSessionCacheKey() string {
if len(config.ServerName) > 0 { if len(c.config.ServerName) > 0 {
return config.ServerName return c.config.ServerName
} }
return serverAddr.String() if c.conn != nil {
return c.conn.RemoteAddr().String()
}
return ""
} }
// hostnameInSNI converts name into an appropriate hostname for SNI. // hostnameInSNI converts name into an appropriate hostname for SNI.

View File

@ -172,6 +172,9 @@ func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility // sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. // with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.c.quic != nil {
return nil
}
if hs.sentDummyCCS { if hs.sentDummyCCS {
return nil return nil
} }
@ -383,10 +386,18 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
clientSecret := hs.suite.deriveSecret(handshakeSecret, clientSecret := hs.suite.deriveSecret(handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript) clientHandshakeTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, clientSecret) c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
serverSecret := hs.suite.deriveSecret(handshakeSecret, serverSecret := hs.suite.deriveSecret(handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript) serverHandshakeTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, serverSecret) c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
}
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
if err != nil { if err != nil {
@ -419,12 +430,30 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error {
return unexpectedMessageError(encryptedExtensions, msg) return unexpectedMessageError(encryptedExtensions, msg)
} }
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol, c.quic != nil); err != nil {
c.sendAlert(alertUnsupportedExtension) // RFC 8446 specifies that no_application_protocol is sent by servers, but
// does not specify how clients handle the selection of an incompatible protocol.
// RFC 9001 Section 8.1 specifies that QUIC clients send no_application_protocol
// in this case. Always sending no_application_protocol seems reasonable.
c.sendAlert(alertNoApplicationProtocol)
return err return err
} }
c.clientProtocol = encryptedExtensions.alpnProtocol c.clientProtocol = encryptedExtensions.alpnProtocol
if c.quic != nil {
if encryptedExtensions.quicTransportParameters == nil {
// RFC 9001 Section 8.2.
c.sendAlert(alertMissingExtension)
return errors.New("tls: server did not send a quic_transport_parameters extension")
}
c.quicSetTransportParameters(encryptedExtensions.quicTransportParameters)
} else {
if encryptedExtensions.quicTransportParameters != nil {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent an unexpected quic_transport_parameters extension")
}
}
return nil return nil
} }
@ -552,7 +581,7 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error {
clientApplicationTrafficLabel, hs.transcript) clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret, serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript) serverApplicationTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, serverSecret) c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
if err != nil { if err != nil {
@ -648,13 +677,20 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
return err return err
} }
c.out.setTrafficSecret(hs.suite, hs.trafficSecret) c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript) resumptionLabel, hs.transcript)
} }
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret)
}
return nil return nil
} }
@ -702,8 +738,10 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
scts: c.scts, scts: c.scts,
} }
cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) cacheKey := c.clientSessionCacheKey()
c.config.ClientSessionCache.Put(cacheKey, session) if cacheKey != "" {
c.config.ClientSessionCache.Put(cacheKey, session)
}
return nil return nil
} }

View File

@ -93,6 +93,7 @@ type clientHelloMsg struct {
pskModes []uint8 pskModes []uint8
pskIdentities []pskIdentity pskIdentities []pskIdentity
pskBinders [][]byte pskBinders [][]byte
quicTransportParameters []byte
} }
func (m *clientHelloMsg) marshal() ([]byte, error) { func (m *clientHelloMsg) marshal() ([]byte, error) {
@ -246,6 +247,13 @@ func (m *clientHelloMsg) marshal() ([]byte, error) {
}) })
}) })
} }
if m.quicTransportParameters != nil { // marshal zero-length parameters when present
// RFC 9001, Section 8.2
exts.AddUint16(extensionQUICTransportParameters)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.quicTransportParameters)
})
}
if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
// RFC 8446, Section 4.2.11 // RFC 8446, Section 4.2.11
exts.AddUint16(extensionPreSharedKey) exts.AddUint16(extensionPreSharedKey)
@ -560,6 +568,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
if !readUint8LengthPrefixed(&extData, &m.pskModes) { if !readUint8LengthPrefixed(&extData, &m.pskModes) {
return false return false
} }
case extensionQUICTransportParameters:
m.quicTransportParameters = make([]byte, len(extData))
if !extData.CopyBytes(m.quicTransportParameters) {
return false
}
case extensionPreSharedKey: case extensionPreSharedKey:
// RFC 8446, Section 4.2.11 // RFC 8446, Section 4.2.11
if !extensions.Empty() { if !extensions.Empty() {
@ -860,8 +873,9 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
} }
type encryptedExtensionsMsg struct { type encryptedExtensionsMsg struct {
raw []byte raw []byte
alpnProtocol string alpnProtocol string
quicTransportParameters []byte
} }
func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
@ -883,6 +897,13 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
}) })
}) })
} }
if m.quicTransportParameters != nil { // marshal zero-length parameters when present
// draft-ietf-quic-tls-32, Section 8.2
b.AddUint16(extensionQUICTransportParameters)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.quicTransportParameters)
})
}
}) })
}) })
@ -921,6 +942,11 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
return false return false
} }
m.alpnProtocol = string(proto) m.alpnProtocol = string(proto)
case extensionQUICTransportParameters:
m.quicTransportParameters = make([]byte, len(extData))
if !extData.CopyBytes(m.quicTransportParameters) {
return false
}
default: default:
// Ignore unknown extensions. // Ignore unknown extensions.
continue continue

View File

@ -218,7 +218,7 @@ func (hs *serverHandshakeState) processClientHello() error {
c.serverName = hs.clientHello.serverName c.serverName = hs.clientHello.serverName
} }
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, false)
if err != nil { if err != nil {
c.sendAlert(alertNoApplicationProtocol) c.sendAlert(alertNoApplicationProtocol)
return err return err
@ -279,8 +279,12 @@ func (hs *serverHandshakeState) processClientHello() error {
// negotiateALPN picks a shared ALPN protocol that both sides support in server // negotiateALPN picks a shared ALPN protocol that both sides support in server
// preference order. If ALPN is not configured or the peer doesn't support it, // preference order. If ALPN is not configured or the peer doesn't support it,
// it returns "" and no error. // it returns "" and no error.
func negotiateALPN(serverProtos, clientProtos []string) (string, error) { func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, error) {
if len(serverProtos) == 0 || len(clientProtos) == 0 { if len(serverProtos) == 0 || len(clientProtos) == 0 {
if quic && len(serverProtos) != 0 {
// RFC 9001, Section 8.1
return "", fmt.Errorf("tls: client did not request an application protocol")
}
return "", nil return "", nil
} }
var http11fallback bool var http11fallback bool

View File

@ -278,6 +278,20 @@ GroupSelection:
return errors.New("tls: invalid client key share") return errors.New("tls: invalid client key share")
} }
if c.quic != nil {
if hs.clientHello.quicTransportParameters == nil {
// RFC 9001 Section 8.2.
c.sendAlert(alertMissingExtension)
return errors.New("tls: client did not send a quic_transport_parameters extension")
}
c.quicSetTransportParameters(hs.clientHello.quicTransportParameters)
} else {
if hs.clientHello.quicTransportParameters != nil {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: client sent an unexpected quic_transport_parameters extension")
}
}
c.serverName = hs.clientHello.serverName c.serverName = hs.clientHello.serverName
return nil return nil
} }
@ -449,6 +463,9 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility // sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. // with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.c.quic != nil {
return nil
}
if hs.sentDummyCCS { if hs.sentDummyCCS {
return nil return nil
} }
@ -600,10 +617,18 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript) clientHandshakeTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, clientSecret) c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript) serverHandshakeTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, serverSecret) c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
}
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
if err != nil { if err != nil {
@ -618,7 +643,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
encryptedExtensions := new(encryptedExtensionsMsg) encryptedExtensions := new(encryptedExtensionsMsg)
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
if err != nil { if err != nil {
c.sendAlert(alertNoApplicationProtocol) c.sendAlert(alertNoApplicationProtocol)
return err return err
@ -626,6 +651,14 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
encryptedExtensions.alpnProtocol = selectedProto encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto c.clientProtocol = selectedProto
if c.quic != nil {
p, err := c.quicGetTransportParameters()
if err != nil {
return err
}
encryptedExtensions.quicTransportParameters = p
}
if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil { if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil {
return err return err
} }
@ -724,7 +757,15 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
clientApplicationTrafficLabel, hs.transcript) clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret, serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript) serverApplicationTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, serverSecret) c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
// TODO: Handle this in setTrafficSecret?
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret)
}
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
if err != nil { if err != nil {
@ -939,7 +980,7 @@ func (hs *serverHandshakeStateTLS13) readClientFinished() error {
return errors.New("tls: invalid client finished hash") return errors.New("tls: invalid client finished hash")
} }
c.in.setTrafficSecret(hs.suite, hs.trafficSecret) c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
return nil return nil
} }

377
quic.go Normal file
View File

@ -0,0 +1,377 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package reality
import (
"context"
"errors"
"fmt"
)
// QUICEncryptionLevel represents a QUIC encryption level used to transmit
// handshake messages.
type QUICEncryptionLevel int
const (
QUICEncryptionLevelInitial = QUICEncryptionLevel(iota)
QUICEncryptionLevelHandshake
QUICEncryptionLevelApplication
)
func (l QUICEncryptionLevel) String() string {
switch l {
case QUICEncryptionLevelInitial:
return "Initial"
case QUICEncryptionLevelHandshake:
return "Handshake"
case QUICEncryptionLevelApplication:
return "Application"
default:
return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
}
}
// A QUICConn represents a connection which uses a QUIC implementation as the underlying
// transport as described in RFC 9001.
//
// Methods of QUICConn are not safe for concurrent use.
type QUICConn struct {
conn *Conn
}
// A QUICConfig configures a QUICConn.
type QUICConfig struct {
TLSConfig *Config
}
// A QUICEventKind is a type of operation on a QUIC connection.
type QUICEventKind int
const (
// QUICNoEvent indicates that there are no events available.
QUICNoEvent QUICEventKind = iota
// QUICSetReadSecret and QUICSetWriteSecret provide the read and write
// secrets for a given encryption level.
// QUICEvent.Level, QUICEvent.Data, and QUICEvent.Suite are set.
//
// Secrets for the Initial encryption level are derived from the initial
// destination connection ID, and are not provided by the QUICConn.
QUICSetReadSecret
QUICSetWriteSecret
// QUICWriteData provides data to send to the peer in CRYPTO frames.
// QUICEvent.Data is set.
QUICWriteData
// QUICTransportParameters provides the peer's QUIC transport parameters.
// QUICEvent.Data is set.
QUICTransportParameters
// QUICTransportParametersRequired indicates that the caller must provide
// QUIC transport parameters to send to the peer. The caller should set
// the transport parameters with QUICConn.SetTransportParameters and call
// QUICConn.NextEvent again.
//
// If transport parameters are set before calling QUICConn.Start, the
// connection will never generate a QUICTransportParametersRequired event.
QUICTransportParametersRequired
// QUICHandshakeDone indicates that the TLS handshake has completed.
QUICHandshakeDone
)
// A QUICEvent is an event occurring on a QUIC connection.
//
// The type of event is specified by the Kind field.
// The contents of the other fields are kind-specific.
type QUICEvent struct {
Kind QUICEventKind
// Set for QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
Level QUICEncryptionLevel
// Set for QUICTransportParameters, QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
// The contents are owned by crypto/tls, and are valid until the next NextEvent call.
Data []byte
// Set for QUICSetReadSecret and QUICSetWriteSecret.
Suite uint16
}
type quicState struct {
events []QUICEvent
nextEvent int
// eventArr is a statically allocated event array, large enough to handle
// the usual maximum number of events resulting from a single call:
// transport parameters, Initial data, Handshake write and read secrets,
// Handshake data, Application write secret, Application data.
eventArr [7]QUICEvent
started bool
signalc chan struct{} // handshake data is available to be read
blockedc chan struct{} // handshake is waiting for data, closed when done
cancelc <-chan struct{} // handshake has been canceled
cancel context.CancelFunc
// readbuf is shared between HandleData and the handshake goroutine.
// HandshakeCryptoData passes ownership to the handshake goroutine by
// reading from signalc, and reclaims ownership by reading from blockedc.
readbuf []byte
transportParams []byte // to send to the peer
}
// QUICClient returns a new TLS client side connection using QUICTransport as the
// underlying transport. The config cannot be nil.
//
// The config's MinVersion must be at least TLS 1.3.
func QUICClient(config *QUICConfig) *QUICConn {
return newQUICConn(Client(nil, config.TLSConfig))
}
// QUICServer returns a new TLS server side connection using QUICTransport as the
// underlying transport. The config cannot be nil.
//
// The config's MinVersion must be at least TLS 1.3.
func QUICServer(config *QUICConfig) *QUICConn {
c, _ := Server(context.Background(), nil, config.TLSConfig)
return newQUICConn(c)
}
func newQUICConn(conn *Conn) *QUICConn {
conn.quic = &quicState{
signalc: make(chan struct{}),
blockedc: make(chan struct{}),
}
conn.quic.events = conn.quic.eventArr[:0]
return &QUICConn{
conn: conn,
}
}
// Start starts the client or server handshake protocol.
// It may produce connection events, which may be read with NextEvent.
//
// Start must be called at most once.
func (q *QUICConn) Start(ctx context.Context) error {
if q.conn.quic.started {
return quicError(errors.New("tls: Start called more than once"))
}
q.conn.quic.started = true
if q.conn.config.MinVersion < VersionTLS13 {
return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.13"))
}
go q.conn.HandshakeContext(ctx)
if _, ok := <-q.conn.quic.blockedc; !ok {
return q.conn.handshakeErr
}
return nil
}
// NextEvent returns the next event occurring on the connection.
// It returns an event with a Kind of QUICNoEvent when no events are available.
func (q *QUICConn) NextEvent() QUICEvent {
qs := q.conn.quic
if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
// Write over some of the previous event's data,
// to catch callers erroniously retaining it.
qs.events[last].Data[0] = 0
}
if qs.nextEvent >= len(qs.events) {
qs.events = qs.events[:0]
qs.nextEvent = 0
return QUICEvent{Kind: QUICNoEvent}
}
e := qs.events[qs.nextEvent]
qs.events[qs.nextEvent] = QUICEvent{} // zero out references to data
qs.nextEvent++
return e
}
// Close closes the connection and stops any in-progress handshake.
func (q *QUICConn) Close() error {
if q.conn.quic.cancel == nil {
return nil // never started
}
q.conn.quic.cancel()
for range q.conn.quic.blockedc {
// Wait for the handshake goroutine to return.
}
return q.conn.handshakeErr
}
// HandleData handles handshake bytes received from the peer.
// It may produce connection events, which may be read with NextEvent.
func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
c := q.conn
if c.in.level != level {
return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
}
c.quic.readbuf = data
<-c.quic.signalc
_, ok := <-c.quic.blockedc
if ok {
// The handshake goroutine is waiting for more data.
return nil
}
// The handshake goroutine has exited.
c.hand.Write(c.quic.readbuf)
c.quic.readbuf = nil
for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
b := q.conn.hand.Bytes()
n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if 4+n < len(b) {
return nil
}
if err := q.conn.handlePostHandshakeMessage(); err != nil {
return quicError(err)
}
}
if q.conn.handshakeErr != nil {
return quicError(q.conn.handshakeErr)
}
return nil
}
// ConnectionState returns basic TLS details about the connection.
func (q *QUICConn) ConnectionState() ConnectionState {
return q.conn.ConnectionState()
}
// SetTransportParameters sets the transport parameters to send to the peer.
//
// Server connections may delay setting the transport parameters until after
// receiving the client's transport parameters. See QUICTransportParametersRequired.
func (q *QUICConn) SetTransportParameters(params []byte) {
if params == nil {
params = []byte{}
}
q.conn.quic.transportParams = params
if q.conn.quic.started {
<-q.conn.quic.signalc
<-q.conn.quic.blockedc
}
}
// quicError ensures err is an AlertError.
// If err is not already, quicError wraps it with alertInternalError.
func quicError(err error) error {
if err == nil {
return nil
}
var ae AlertError
if errors.As(err, &ae) {
return err
}
var a alert
if !errors.As(err, &a) {
a = alertInternalError
}
// Return an error wrapping the original error and an AlertError.
// Truncate the text of the alert to 0 characters.
return fmt.Errorf("%w%.0w", err, AlertError(a))
}
func (c *Conn) quicReadHandshakeBytes(n int) error {
for c.hand.Len() < n {
if err := c.quicWaitForSignal(); err != nil {
return err
}
}
return nil
}
func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICSetReadSecret,
Level: level,
Suite: suite,
Data: secret,
})
}
func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICSetWriteSecret,
Level: level,
Suite: suite,
Data: secret,
})
}
func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
var last *QUICEvent
if len(c.quic.events) > 0 {
last = &c.quic.events[len(c.quic.events)-1]
}
if last == nil || last.Kind != QUICWriteData || last.Level != level {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICWriteData,
Level: level,
})
last = &c.quic.events[len(c.quic.events)-1]
}
last.Data = append(last.Data, data...)
}
func (c *Conn) quicSetTransportParameters(params []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICTransportParameters,
Data: params,
})
}
func (c *Conn) quicGetTransportParameters() ([]byte, error) {
if c.quic.transportParams == nil {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICTransportParametersRequired,
})
}
for c.quic.transportParams == nil {
if err := c.quicWaitForSignal(); err != nil {
return nil, err
}
}
return c.quic.transportParams, nil
}
func (c *Conn) quicHandshakeComplete() {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICHandshakeDone,
})
}
// quicWaitForSignal notifies the QUICConn that handshake progress is blocked,
// and waits for a signal that the handshake should proceed.
//
// The handshake may become blocked waiting for handshake bytes
// or for the user to provide transport parameters.
func (c *Conn) quicWaitForSignal() error {
// Drop the handshake mutex while blocked to allow the user
// to call ConnectionState before the handshake completes.
c.handshakeMutex.Unlock()
defer c.handshakeMutex.Lock()
// Send on blockedc to notify the QUICConn that the handshake is blocked.
// Exported methods of QUICConn wait for the handshake to become blocked
// before returning to the user.
select {
case c.quic.blockedc <- struct{}{}:
case <-c.quic.cancelc:
return c.sendAlertLocked(alertCloseNotify)
}
// The QUICConn reads from signalc to notify us that the handshake may
// be able to proceed. (The QUICConn reads, because we close signalc to
// indicate that the handshake has completed.)
select {
case c.quic.signalc <- struct{}{}:
c.hand.Write(c.quic.readbuf)
c.quic.readbuf = nil
case <-c.quic.cancelc:
return c.sendAlertLocked(alertCloseNotify)
}
return nil
}