diff --git a/common.go b/common.go index 6dfc7f6..5cdbf81 100644 --- a/common.go +++ b/common.go @@ -129,11 +129,13 @@ const ( scsvRenegotiation uint16 = 0x00ff ) -// CurveID is the type of a TLS identifier for an elliptic curve. See +// CurveID is the type of a TLS identifier for a key exchange mechanism. See // https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8. // -// In TLS 1.3, this type is called NamedGroup, but at this time this library -// only supports Elliptic Curve based groups. See RFC 8446, Section 4.2.7. +// In TLS 1.2, this registry used to support only elliptic curves. In TLS 1.3, +// it was extended to other groups and renamed NamedGroup. See RFC 8446, Section +// 4.2.7. It was then also extended to other mechanisms, such as hybrid +// post-quantum KEMs. type CurveID uint16 const ( @@ -141,6 +143,11 @@ const ( CurveP384 CurveID = 24 CurveP521 CurveID = 25 X25519 CurveID = 29 + + // Experimental codepoint for X25519Kyber768Draft00, specified in + // draft-tls-westerbaan-xyber768d00-03. Not exported, as support might be + // removed in the future. + x25519Kyber768Draft00 CurveID = 0x6399 // X25519Kyber768Draft00 ) // TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. @@ -301,6 +308,10 @@ type ConnectionState struct { // testingOnlyDidHRR is true if a HelloRetryRequest was sent/received. testingOnlyDidHRR bool + + // testingOnlyCurveID is the selected CurveID, or zero if an RSA exchanges + // is performed. + testingOnlyCurveID CurveID } // ExportKeyingMaterial returns length bytes of exported key material in a new @@ -374,7 +385,7 @@ type ClientSessionCache interface { Put(sessionKey string, cs *ClientSessionState) } -//go:generate stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go +//go:generate stringer -linecomment -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go // SignatureScheme identifies a signature algorithm supported by TLS. See // RFC 8446, Section 4.2.3. @@ -770,6 +781,10 @@ type Config struct { // an ECDHE handshake, in preference order. If empty, the default will // be used. The client will use the first preference as the type for // its key share in TLS 1.3. This may change in the future. + // + // From Go 1.23, the default includes the X25519Kyber768Draft00 hybrid + // post-quantum key exchange. To disable it, set CurvePreferences explicitly + // or use the GODEBUG=tlskyber=0 environment variable. CurvePreferences []CurveID // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. @@ -1099,20 +1114,25 @@ func supportedVersionsFromMax(maxVersion uint16) []uint16 { return versions } -var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} +var defaultCurvePreferences = []CurveID{x25519Kyber768Draft00, X25519, CurveP256, CurveP384, CurveP521} -func (c *Config) curvePreferences() []CurveID { +var defaultCurvePreferencesWithoutKyber = []CurveID{X25519, CurveP256, CurveP384, CurveP521} + +func (c *Config) curvePreferences(version uint16) []CurveID { if needFIPS() { return fipsCurvePreferences(c) } if c == nil || len(c.CurvePreferences) == 0 { + if version < VersionTLS13 || true /*tlskyber.Value() == "0"*/ { + return defaultCurvePreferencesWithoutKyber + } return defaultCurvePreferences } return c.CurvePreferences } -func (c *Config) supportsCurve(curve CurveID) bool { - for _, cc := range c.curvePreferences() { +func (c *Config) supportsCurve(version uint16, curve CurveID) bool { + for _, cc := range c.curvePreferences(version) { if cc == curve { return true } @@ -1271,7 +1291,7 @@ func (chi *ClientHelloInfo) SupportsCertificate(c *Certificate) error { } // The only signed key exchange we support is ECDHE. - if !supportsECDHE(config, chi.SupportedCurves, chi.SupportedPoints) { + if !supportsECDHE(config, vers, chi.SupportedCurves, chi.SupportedPoints) { return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange")) } @@ -1292,7 +1312,7 @@ func (chi *ClientHelloInfo) SupportsCertificate(c *Certificate) error { } var curveOk bool for _, c := range chi.SupportedCurves { - if c == curve && config.supportsCurve(c) { + if c == curve && config.supportsCurve(vers, c) { curveOk = true break } diff --git a/common_string.go b/common_string.go index 4f500e2..4ac388b 100644 --- a/common_string.go +++ b/common_string.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go"; DO NOT EDIT. +// Code generated by "stringer -linecomment -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go"; DO NOT EDIT. package reality @@ -71,11 +71,13 @@ func _() { _ = x[CurveP384-24] _ = x[CurveP521-25] _ = x[X25519-29] + _ = x[x25519Kyber768Draft00-25497] } const ( _CurveID_name_0 = "CurveP256CurveP384CurveP521" _CurveID_name_1 = "X25519" + _CurveID_name_2 = "X25519Kyber768Draft00" ) var ( @@ -89,6 +91,8 @@ func (i CurveID) String() string { return _CurveID_name_0[_CurveID_index_0[i]:_CurveID_index_0[i+1]] case i == 29: return _CurveID_name_1 + case i == 25497: + return _CurveID_name_2 default: return "CurveID(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/conn.go b/conn.go index 17a4046..7ded21c 100644 --- a/conn.go +++ b/conn.go @@ -54,6 +54,7 @@ type Conn struct { didResume bool // whether this connection was a session resumption didHRR bool // whether a HelloRetryRequest was sent/received cipherSuite uint16 + curveID CurveID ocspResponse []byte // stapled OCSP response scts [][]byte // signed certificate timestamps from server peerCertificates []*x509.Certificate @@ -1671,6 +1672,8 @@ func (c *Conn) connectionStateLocked() ConnectionState { state.NegotiatedProtocol = c.clientProtocol state.DidResume = c.didResume state.testingOnlyDidHRR = c.didHRR + // c.curveID is not set on TLS 1.0–1.2 resumptions. Fix that before exposing it. + state.testingOnlyCurveID = c.curveID state.NegotiatedProtocolIsMutual = true state.ServerName = c.serverName state.CipherSuite = c.cipherSuite diff --git a/handshake_client.go b/handshake_client.go index 7b6d598..333c884 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -8,12 +8,12 @@ import ( "bytes" "context" "crypto" - "crypto/ecdh" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/subtle" "crypto/x509" + "encoding/binary" "errors" "fmt" "hash" @@ -21,6 +21,8 @@ import ( "net" "strings" "time" + + "github.com/xtls/reality/mlkem768" ) type clientHandshakeState struct { @@ -37,7 +39,7 @@ type clientHandshakeState struct { var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme -func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { +func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, error) { config := c.config if len(config.ServerName) == 0 && !config.InsecureSkipVerify { return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") @@ -60,29 +62,30 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") } - clientHelloVersion := config.maxSupportedVersion(roleClient) - // The version at the beginning of the ClientHello was capped at TLS 1.2 - // for compatibility reasons. The supported_versions extension is used - // to negotiate versions now. See RFC 8446, Section 4.2.1. - if clientHelloVersion > VersionTLS12 { - clientHelloVersion = VersionTLS12 - } + maxVersion := config.maxSupportedVersion(roleClient) hello := &clientHelloMsg{ - vers: clientHelloVersion, + vers: maxVersion, compressionMethods: []uint8{compressionNone}, random: make([]byte, 32), extendedMasterSecret: true, ocspStapling: true, scts: true, serverName: hostnameInSNI(config.ServerName), - supportedCurves: config.curvePreferences(), + supportedCurves: config.curvePreferences(maxVersion), supportedPoints: []uint8{pointFormatUncompressed}, secureRenegotiationSupported: true, alpnProtocols: config.NextProtos, supportedVersions: supportedVersions, } + // The version at the beginning of the ClientHello was capped at TLS 1.2 + // for compatibility reasons. The supported_versions extension is used + // to negotiate versions now. See RFC 8446, Section 4.2.1. + if hello.vers > VersionTLS12 { + hello.vers = VersionTLS12 + } + if c.handshakes > 0 { hello.secureRenegotiation = c.clientFinished[:] } @@ -101,7 +104,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { } // Don't advertise TLS 1.2-only cipher suites unless // we're attempting TLS 1.2. - if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { + if maxVersion < VersionTLS12 && suite.flags&suiteTLS12 != 0 { continue } hello.cipherSuites = append(hello.cipherSuites, suiteId) @@ -124,14 +127,14 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { } } - if hello.vers >= VersionTLS12 { + if maxVersion >= VersionTLS12 { hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms() } if testingOnlyForceClientHelloSignatureAlgorithms != nil { hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms } - var key *ecdh.PrivateKey + var keyShareKeys *keySharePrivateKeys if hello.supportedVersions[0] == VersionTLS13 { // Reset the list of ciphers when the client only supports TLS 1.3. if len(hello.supportedVersions) == 1 { @@ -143,15 +146,40 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) } - curveID := config.curvePreferences()[0] - if _, ok := curveForCurveID(curveID); !ok { - return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + curveID := config.curvePreferences(maxVersion)[0] + keyShareKeys = &keySharePrivateKeys{curveID: curveID} + if curveID == x25519Kyber768Draft00 { + keyShareKeys.ecdhe, err = generateECDHEKey(config.rand(), X25519) + if err != nil { + return nil, nil, err + } + seed := make([]byte, mlkem768.SeedSize) + if _, err := io.ReadFull(config.rand(), seed); err != nil { + return nil, nil, err + } + keyShareKeys.kyber, err = mlkem768.NewKeyFromSeed(seed) + if err != nil { + return nil, nil, err + } + // For draft-tls-westerbaan-xyber768d00-03, we send both a hybrid + // and a standard X25519 key share, since most servers will only + // support the latter. We reuse the same X25519 ephemeral key for + // both, as allowed by draft-ietf-tls-hybrid-design-09, Section 3.2. + hello.keyShares = []keyShare{ + {group: x25519Kyber768Draft00, data: append(keyShareKeys.ecdhe.PublicKey().Bytes(), + keyShareKeys.kyber.EncapsulationKey()...)}, + {group: X25519, data: keyShareKeys.ecdhe.PublicKey().Bytes()}, + } + } else { + if _, ok := curveForCurveID(curveID); !ok { + return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + keyShareKeys.ecdhe, err = generateECDHEKey(config.rand(), curveID) + if err != nil { + return nil, nil, err + } + hello.keyShares = []keyShare{{group: curveID, data: keyShareKeys.ecdhe.PublicKey().Bytes()}} } - key, err = generateECDHEKey(config.rand(), curveID) - if err != nil { - return nil, nil, err - } - hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} } if c.quic != nil { @@ -165,7 +193,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { hello.quicTransportParameters = p } - return hello, key, nil + return hello, keyShareKeys, nil } func (c *Conn) clientHandshake(ctx context.Context) (err error) { @@ -177,7 +205,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { // need to be reset. c.didResume = false - hello, ecdheKey, err := c.makeClientHello() + hello, keyShareKeys, err := c.makeClientHello() if err != nil { return err } @@ -247,17 +275,15 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { if c.vers == VersionTLS13 { hs := &clientHandshakeStateTLS13{ - c: c, - ctx: ctx, - serverHello: serverHello, - hello: hello, - ecdheKey: ecdheKey, - session: session, - earlySecret: earlySecret, - binderKey: binderKey, + c: c, + ctx: ctx, + serverHello: serverHello, + hello: hello, + keyShareKeys: keyShareKeys, + session: session, + earlySecret: earlySecret, + binderKey: binderKey, } - - // In TLS 1.3, session tickets are delivered after the handshake. return hs.handshake() } @@ -269,11 +295,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { session: session, } - if err := hs.handshake(); err != nil { - return err - } - - return nil + return hs.handshake() } func (c *Conn) loadSession(hello *clientHelloMsg) ( @@ -596,6 +618,9 @@ func (hs *clientHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return err } + if len(skx.key) >= 3 && skx.key[0] == 3 /* named curve */ { + c.curveID = CurveID(binary.BigEndian.Uint16(skx.key[1:])) + } msg, err = c.readHandshake(&hs.finishedHash) if err != nil { diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 7b09281..3e69d59 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -8,20 +8,22 @@ import ( "bytes" "context" "crypto" - "crypto/ecdh" "crypto/hmac" "crypto/rsa" "errors" "hash" + "slices" "time" + + "github.com/xtls/reality/mlkem768" ) type clientHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - ecdheKey *ecdh.PrivateKey + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + keyShareKeys *keySharePrivateKeys session *SessionState earlySecret []byte @@ -36,7 +38,7 @@ type clientHandshakeStateTLS13 struct { trafficSecret []byte // client_application_traffic_secret_0 } -// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheKey, and, +// handshake requires hs.c, hs.hello, hs.serverHello, hs.keyShareKeys, and, // optionally, hs.session, hs.earlySecret and hs.binderKey to be set. func (hs *clientHandshakeStateTLS13) handshake() error { c := hs.c @@ -53,7 +55,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } // Consistency check on the presence of a keyShare and its parameters. - if hs.ecdheKey == nil || len(hs.hello.keyShares) != 1 { + if hs.keyShareKeys == nil || hs.keyShareKeys.ecdhe == nil || len(hs.hello.keyShares) == 0 { return c.sendAlert(alertInternalError) } @@ -221,21 +223,22 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { // a group we advertised but did not send a key share for, and send a key // share for it this time. if curveID := hs.serverHello.selectedGroup; curveID != 0 { - curveOK := false - for _, id := range hs.hello.supportedCurves { - if id == curveID { - curveOK = true - break - } - } - if !curveOK { + if !slices.Contains(hs.hello.supportedCurves, curveID) { c.sendAlert(alertIllegalParameter) return errors.New("tls: server selected unsupported group") } - if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); sentID == curveID { + if slices.ContainsFunc(hs.hello.keyShares, func(ks keyShare) bool { + return ks.group == curveID + }) { c.sendAlert(alertIllegalParameter) return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") } + // Note: we don't support selecting X25519Kyber768Draft00 in a HRR, + // because we currently only support it at all when CurvePreferences is + // empty, which will cause us to also send a key share for it. + // + // This will have to change once we support selecting hybrid KEMs + // without sending key shares for them. if _, ok := curveForCurveID(curveID); !ok { c.sendAlert(alertInternalError) return errors.New("tls: CurvePreferences includes unsupported curve") @@ -245,7 +248,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { c.sendAlert(alertInternalError) return err } - hs.ecdheKey = key + hs.keyShareKeys = &keySharePrivateKeys{curveID: curveID, ecdhe: key} hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} } @@ -333,7 +336,9 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: server did not send a key share") } - if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID { + if !slices.ContainsFunc(hs.hello.keyShares, func(ks keyShare) bool { + return ks.group == hs.serverHello.serverShare.group + }) { c.sendAlert(alertIllegalParameter) return errors.New("tls: server selected unsupported group") } @@ -372,16 +377,37 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { c := hs.c - peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data) + ecdhePeerData := hs.serverHello.serverShare.data + if hs.serverHello.serverShare.group == x25519Kyber768Draft00 { + if len(ecdhePeerData) != x25519PublicKeySize+mlkem768.CiphertextSize { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } + ecdhePeerData = hs.serverHello.serverShare.data[:x25519PublicKeySize] + } + peerKey, err := hs.keyShareKeys.ecdhe.Curve().NewPublicKey(ecdhePeerData) if err != nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid server key share") } - sharedKey, err := hs.ecdheKey.ECDH(peerKey) + sharedKey, err := hs.keyShareKeys.ecdhe.ECDH(peerKey) if err != nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid server key share") } + if hs.serverHello.serverShare.group == x25519Kyber768Draft00 { + if hs.keyShareKeys.kyber == nil { + return c.sendAlert(alertInternalError) + } + ciphertext := hs.serverHello.serverShare.data[x25519PublicKeySize:] + kyberShared, err := kyberDecapsulate(hs.keyShareKeys.kyber, ciphertext) + if err != nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid Kyber server key share") + } + sharedKey = append(sharedKey, kyberShared...) + } + c.curveID = hs.serverHello.serverShare.group earlySecret := hs.earlySecret if !hs.usingPSK { diff --git a/handshake_server.go b/handshake_server.go index 3a8d8cb..46bd06d 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -12,6 +12,7 @@ import ( "crypto/rsa" "crypto/subtle" "crypto/x509" + "encoding/binary" "errors" "fmt" "hash" @@ -242,7 +243,7 @@ func (hs *serverHandshakeState) processClientHello() error { hs.hello.scts = hs.cert.SignedCertificateTimestamps } - hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) + hs.ecdheOk = supportsECDHE(c.config, c.vers, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) if hs.ecdheOk && len(hs.clientHello.supportedPoints) > 0 { // Although omitting the ec_point_formats extension is permitted, some @@ -313,10 +314,10 @@ func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, erro // supportsECDHE returns whether ECDHE key exchanges can be used with this // pre-TLS 1.3 client. -func supportsECDHE(c *Config, supportedCurves []CurveID, supportedPoints []uint8) bool { +func supportsECDHE(c *Config, version uint16, supportedCurves []CurveID, supportedPoints []uint8) bool { supportsCurve := false for _, curve := range supportedCurves { - if c.supportsCurve(curve) { + if c.supportsCurve(version, curve) { supportsCurve = true break } @@ -577,6 +578,9 @@ func (hs *serverHandshakeState) doFullHandshake() error { return err } if skx != nil { + if len(skx.key) >= 3 && skx.key[0] == 3 /* named curve */ { + c.curveID = CurveID(binary.BigEndian.Uint16(skx.key[1:])) + } if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil { return err } diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 5ccf13b..1488af9 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -21,6 +21,8 @@ import ( "math/big" "slices" "time" + + "github.com/xtls/reality/mlkem768" ) // maxClientPSKIdentities is the number of client PSK identities the server will @@ -230,11 +232,11 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error { hs.hello.cipherSuite = hs.suite.id hs.transcript = hs.suite.hash.New() - // Pick the ECDHE group in server preference order, but give priority to - // groups with a key share, to avoid a HelloRetryRequest round-trip. + // Pick the key exchange method in server preference order, but give + // priority to key shares, to avoid a HelloRetryRequest round-trip. var selectedGroup CurveID var clientKeyShare *keyShare - preferredGroups := c.config.curvePreferences() + preferredGroups := c.config.curvePreferences(c.vers) for _, preferredGroup := range preferredGroups { ki := slices.IndexFunc(hs.clientHello.keyShares, func(ks keyShare) bool { return ks.group == preferredGroup @@ -262,23 +264,35 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error { return errors.New("tls: no ECDHE curve supported by both client and server") } if clientKeyShare == nil { - if err := hs.doHelloRetryRequest(selectedGroup); err != nil { + ks, err := hs.doHelloRetryRequest(selectedGroup) + if err != nil { return err } - clientKeyShare = &hs.clientHello.keyShares[0] + clientKeyShare = ks } + c.curveID = selectedGroup - if _, ok := curveForCurveID(selectedGroup); !ok { + ecdhGroup := selectedGroup + ecdhData := clientKeyShare.data + if selectedGroup == x25519Kyber768Draft00 { + ecdhGroup = X25519 + if len(ecdhData) != x25519PublicKeySize+mlkem768.EncapsulationKeySize { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid Kyber client key share") + } + ecdhData = ecdhData[:x25519PublicKeySize] + } + if _, ok := curveForCurveID(ecdhGroup); !ok { c.sendAlert(alertInternalError) return errors.New("tls: CurvePreferences includes unsupported curve") } - key, err := generateECDHEKey(c.config.rand(), selectedGroup) + key, err := generateECDHEKey(c.config.rand(), ecdhGroup) if err != nil { c.sendAlert(alertInternalError) return err } hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()} - peerKey, err := key.Curve().NewPublicKey(clientKeyShare.data) + peerKey, err := key.Curve().NewPublicKey(ecdhData) if err != nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid client key share") @@ -288,6 +302,15 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid client key share") } + if selectedGroup == x25519Kyber768Draft00 { + ciphertext, kyberShared, err := kyberEncapsulate(clientKeyShare.data[x25519PublicKeySize:]) + if err != nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid Kyber client key share") + } + hs.sharedKey = append(hs.sharedKey, kyberShared...) + hs.hello.serverShare.data = append(hs.hello.serverShare.data, ciphertext...) + } selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil) if err != nil { @@ -531,13 +554,13 @@ func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { return hs.c.writeChangeCipherRecord() } -func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { +func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) (*keyShare, error) { c := hs.c // The first ClientHello gets double-hashed into the transcript upon a // HelloRetryRequest. See RFC 8446, Section 4.4.1. if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { - return err + return nil, err } chHash := hs.transcript.Sum(nil) hs.transcript.Reset() @@ -555,43 +578,49 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) } if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil { - return err + return nil, err } if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err + return nil, err } // clientHelloMsg is not included in the transcript. msg, err := c.readHandshake(nil) if err != nil { - return err + return nil, err } clientHello, ok := msg.(*clientHelloMsg) if !ok { c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(clientHello, msg) + return nil, unexpectedMessageError(clientHello, msg) } - if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup { + if len(clientHello.keyShares) != 1 { c.sendAlert(alertIllegalParameter) - return errors.New("tls: client sent invalid key share in second ClientHello") + return nil, errors.New("tls: client didn't send one key share in second ClientHello") + } + ks := &clientHello.keyShares[0] + + if ks.group != selectedGroup { + c.sendAlert(alertIllegalParameter) + return nil, errors.New("tls: client sent unexpected key share in second ClientHello") } if clientHello.earlyData { c.sendAlert(alertIllegalParameter) - return errors.New("tls: client indicated early data in second ClientHello") + return nil, errors.New("tls: client indicated early data in second ClientHello") } if illegalClientHelloChange(clientHello, hs.clientHello) { c.sendAlert(alertIllegalParameter) - return errors.New("tls: client illegally modified second ClientHello") + return nil, errors.New("tls: client illegally modified second ClientHello") } c.didHRR = true hs.clientHello = clientHello - return nil + return ks, nil } // illegalClientHelloChange reports whether the two ClientHello messages are diff --git a/key_agreement.go b/key_agreement.go index d052b9e..8429d13 100644 --- a/key_agreement.go +++ b/key_agreement.go @@ -16,8 +16,8 @@ import ( "io" ) -// a keyAgreement implements the client and server side of a TLS key agreement -// protocol by generating and processing key exchange messages. +// A keyAgreement implements the client and server side of a TLS 1.0–1.2 key +// agreement protocol by generating and processing key exchange messages. type keyAgreement interface { // On the server side, the first two methods are called in order. @@ -126,7 +126,7 @@ func md5SHA1Hash(slices [][]byte) []byte { } // hashForServerKeyExchange hashes the given slices and returns their digest -// using the given hash function (for >= TLS 1.2) or using a default based on +// using the given hash function (for TLS 1.2) or using a default based on // the sigType (for earlier TLS versions). For Ed25519 signatures, which don't // do pre-hashing, it returns the concatenation of the slices. func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { @@ -169,7 +169,7 @@ type ecdheKeyAgreement struct { func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { var curveID CurveID for _, c := range clientHello.supportedCurves { - if config.supportsCurve(c) { + if config.supportsCurve(ka.version, c) { curveID = c break } diff --git a/key_schedule.go b/key_schedule.go index 46c7d4a..63636b5 100644 --- a/key_schedule.go +++ b/key_schedule.go @@ -14,6 +14,9 @@ import ( "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/hkdf" + "golang.org/x/crypto/sha3" + + "github.com/xtls/reality/mlkem768" ) // This file contains the functions necessary to compute the TLS 1.3 key @@ -117,6 +120,45 @@ func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript } } +type keySharePrivateKeys struct { + curveID CurveID + ecdhe *ecdh.PrivateKey + kyber *mlkem768.DecapsulationKey +} + +// kyberDecapsulate implements decapsulation according to Kyber Round 3. +func kyberDecapsulate(dk *mlkem768.DecapsulationKey, c []byte) ([]byte, error) { + K, err := mlkem768.Decapsulate(dk, c) + if err != nil { + return nil, err + } + return kyberSharedSecret(K, c), nil +} + +// kyberEncapsulate implements encapsulation according to Kyber Round 3. +func kyberEncapsulate(ek []byte) (c, ss []byte, err error) { + c, ss, err = mlkem768.Encapsulate(ek) + if err != nil { + return nil, nil, err + } + return c, kyberSharedSecret(ss, c), nil +} + +func kyberSharedSecret(K, c []byte) []byte { + // Package mlkem768 implements ML-KEM, which compared to Kyber removed a + // final hashing step. Compute SHAKE-256(K || SHA3-256(c), 32) to match Kyber. + // See https://words.filippo.io/mlkem768/#bonus-track-using-a-ml-kem-implementation-as-kyber-v3. + h := sha3.NewShake256() + h.Write(K) + ch := sha3.Sum256(c) + h.Write(ch[:]) + out := make([]byte, 32) + h.Read(out) + return out +} + +const x25519PublicKeySize = 32 + // generateECDHEKey returns a PrivateKey that implements Diffie-Hellman // according to RFC 8446, Section 4.2.8.2. func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) { diff --git a/mlkem768/mlkem768.go b/mlkem768/mlkem768.go new file mode 100644 index 0000000..24bedea --- /dev/null +++ b/mlkem768/mlkem768.go @@ -0,0 +1,886 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package mlkem768 implements the quantum-resistant key encapsulation method +// ML-KEM (formerly known as Kyber). +// +// Only the recommended ML-KEM-768 parameter set is provided. +// +// The version currently implemented is the one specified by [NIST FIPS 203 ipd], +// with the unintentional transposition of the matrix A reverted to match the +// behavior of [Kyber version 3.0]. Future versions of this package might +// introduce backwards incompatible changes to implement changes to FIPS 203. +// +// [Kyber version 3.0]: https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf +// [NIST FIPS 203 ipd]: https://doi.org/10.6028/NIST.FIPS.203.ipd +package mlkem768 + +// This package targets security, correctness, simplicity, readability, and +// reviewability as its primary goals. All critical operations are performed in +// constant time. +// +// Variable and function names, as well as code layout, are selected to +// facilitate reviewing the implementation against the NIST FIPS 203 ipd +// document. +// +// Reviewers unfamiliar with polynomials or linear algebra might find the +// background at https://words.filippo.io/kyber-math/ useful. + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/binary" + "errors" + + "golang.org/x/crypto/sha3" +) + +const ( + // ML-KEM global constants. + n = 256 + q = 3329 + + log2q = 12 + + // ML-KEM-768 parameters. The code makes assumptions based on these values, + // they can't be changed blindly. + k = 3 + η = 2 + du = 10 + dv = 4 + + // encodingSizeX is the byte size of a ringElement or nttElement encoded + // by ByteEncode_X (FIPS 203 (DRAFT), Algorithm 4). + encodingSize12 = n * log2q / 8 + encodingSize10 = n * du / 8 + encodingSize4 = n * dv / 8 + encodingSize1 = n * 1 / 8 + + messageSize = encodingSize1 + decryptionKeySize = k * encodingSize12 + encryptionKeySize = k*encodingSize12 + 32 + + CiphertextSize = k*encodingSize10 + encodingSize4 + EncapsulationKeySize = encryptionKeySize + DecapsulationKeySize = decryptionKeySize + encryptionKeySize + 32 + 32 + SharedKeySize = 32 + SeedSize = 32 + 32 +) + +// A DecapsulationKey is the secret key used to decapsulate a shared key from a +// ciphertext. It includes various precomputed values. +type DecapsulationKey struct { + dk [DecapsulationKeySize]byte + encryptionKey + decryptionKey +} + +// Bytes returns the extended encoding of the decapsulation key, according to +// FIPS 203 (DRAFT). +func (dk *DecapsulationKey) Bytes() []byte { + var b [DecapsulationKeySize]byte + copy(b[:], dk.dk[:]) + return b[:] +} + +// EncapsulationKey returns the public encapsulation key necessary to produce +// ciphertexts. +func (dk *DecapsulationKey) EncapsulationKey() []byte { + var b [EncapsulationKeySize]byte + copy(b[:], dk.dk[decryptionKeySize:]) + return b[:] +} + +// encryptionKey is the parsed and expanded form of a PKE encryption key. +type encryptionKey struct { + t [k]nttElement // ByteDecode₁₂(ek[:384k]) + A [k * k]nttElement // A[i*k+j] = sampleNTT(ρ, j, i) +} + +// decryptionKey is the parsed and expanded form of a PKE decryption key. +type decryptionKey struct { + s [k]nttElement // ByteDecode₁₂(dk[:decryptionKeySize]) +} + +// GenerateKey generates a new decapsulation key, drawing random bytes from +// crypto/rand. The decapsulation key must be kept secret. +func GenerateKey() (*DecapsulationKey, error) { + // The actual logic is in a separate function to outline this allocation. + dk := &DecapsulationKey{} + return generateKey(dk) +} + +func generateKey(dk *DecapsulationKey) (*DecapsulationKey, error) { + var d [32]byte + if _, err := rand.Read(d[:]); err != nil { + return nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error()) + } + var z [32]byte + if _, err := rand.Read(z[:]); err != nil { + return nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error()) + } + return kemKeyGen(dk, &d, &z), nil +} + +// NewKeyFromSeed deterministically generates a decapsulation key from a 64-byte +// seed in the "d || z" form. The seed must be uniformly random. +func NewKeyFromSeed(seed []byte) (*DecapsulationKey, error) { + // The actual logic is in a separate function to outline this allocation. + dk := &DecapsulationKey{} + return newKeyFromSeed(dk, seed) +} + +func newKeyFromSeed(dk *DecapsulationKey, seed []byte) (*DecapsulationKey, error) { + if len(seed) != SeedSize { + return nil, errors.New("mlkem768: invalid seed length") + } + d := (*[32]byte)(seed[:32]) + z := (*[32]byte)(seed[32:]) + return kemKeyGen(dk, d, z), nil +} + +// NewKeyFromExtendedEncoding parses a decapsulation key from its FIPS 203 +// (DRAFT) extended encoding. +func NewKeyFromExtendedEncoding(decapsulationKey []byte) (*DecapsulationKey, error) { + // The actual logic is in a separate function to outline this allocation. + dk := &DecapsulationKey{} + return newKeyFromExtendedEncoding(dk, decapsulationKey) +} + +func newKeyFromExtendedEncoding(dk *DecapsulationKey, dkBytes []byte) (*DecapsulationKey, error) { + if len(dkBytes) != DecapsulationKeySize { + return nil, errors.New("mlkem768: invalid decapsulation key length") + } + + // Note that we don't check that H(ek) matches ekPKE, as that's not + // specified in FIPS 203 (DRAFT). This is one reason to prefer the seed + // private key format. + dk.dk = [DecapsulationKeySize]byte(dkBytes) + + dkPKE := dkBytes[:decryptionKeySize] + if err := parseDK(&dk.decryptionKey, dkPKE); err != nil { + return nil, err + } + + ekPKE := dkBytes[decryptionKeySize : decryptionKeySize+encryptionKeySize] + if err := parseEK(&dk.encryptionKey, ekPKE); err != nil { + return nil, err + } + + return dk, nil +} + +// kemKeyGen generates a decapsulation key. +// +// It implements ML-KEM.KeyGen according to FIPS 203 (DRAFT), Algorithm 15, and +// K-PKE.KeyGen according to FIPS 203 (DRAFT), Algorithm 12. The two are merged +// to save copies and allocations. +func kemKeyGen(dk *DecapsulationKey, d, z *[32]byte) *DecapsulationKey { + if dk == nil { + dk = &DecapsulationKey{} + } + + G := sha3.Sum512(d[:]) + ρ, σ := G[:32], G[32:] + + A := &dk.A + for i := byte(0); i < k; i++ { + for j := byte(0); j < k; j++ { + // Note that this is consistent with Kyber round 3, rather than with + // the initial draft of FIPS 203, because NIST signaled that the + // change was involuntary and will be reverted. + A[i*k+j] = sampleNTT(ρ, j, i) + } + } + + var N byte + s := &dk.s + for i := range s { + s[i] = ntt(samplePolyCBD(σ, N)) + N++ + } + e := make([]nttElement, k) + for i := range e { + e[i] = ntt(samplePolyCBD(σ, N)) + N++ + } + + t := &dk.t + for i := range t { // t = A ◦ s + e + t[i] = e[i] + for j := range s { + t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j])) + } + } + + // dkPKE ← ByteEncode₁₂(s) + // ekPKE ← ByteEncode₁₂(t) || ρ + // ek ← ekPKE + // dk ← dkPKE || ek || H(ek) || z + dkB := dk.dk[:0] + + for i := range s { + dkB = polyByteEncode(dkB, s[i]) + } + + for i := range t { + dkB = polyByteEncode(dkB, t[i]) + } + dkB = append(dkB, ρ...) + + H := sha3.New256() + H.Write(dkB[decryptionKeySize:]) + dkB = H.Sum(dkB) + + dkB = append(dkB, z[:]...) + + if len(dkB) != len(dk.dk) { + panic("mlkem768: internal error: invalid decapsulation key size") + } + + return dk +} + +// Encapsulate generates a shared key and an associated ciphertext from an +// encapsulation key, drawing random bytes from crypto/rand. +// If the encapsulation key is not valid, Encapsulate returns an error. +// +// The shared key must be kept secret. +func Encapsulate(encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) { + // The actual logic is in a separate function to outline this allocation. + var cc [CiphertextSize]byte + return encapsulate(&cc, encapsulationKey) +} + +func encapsulate(cc *[CiphertextSize]byte, encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) { + if len(encapsulationKey) != EncapsulationKeySize { + return nil, nil, errors.New("mlkem768: invalid encapsulation key length") + } + var m [messageSize]byte + if _, err := rand.Read(m[:]); err != nil { + return nil, nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error()) + } + return kemEncaps(cc, encapsulationKey, &m) +} + +// kemEncaps generates a shared key and an associated ciphertext. +// +// It implements ML-KEM.Encaps according to FIPS 203 (DRAFT), Algorithm 16. +func kemEncaps(cc *[CiphertextSize]byte, ek []byte, m *[messageSize]byte) (c, K []byte, err error) { + if cc == nil { + cc = &[CiphertextSize]byte{} + } + + H := sha3.Sum256(ek[:]) + g := sha3.New512() + g.Write(m[:]) + g.Write(H[:]) + G := g.Sum(nil) + K, r := G[:SharedKeySize], G[SharedKeySize:] + var ex encryptionKey + if err := parseEK(&ex, ek[:]); err != nil { + return nil, nil, err + } + c = pkeEncrypt(cc, &ex, m, r) + return c, K, nil +} + +// parseEK parses an encryption key from its encoded form. +// +// It implements the initial stages of K-PKE.Encrypt according to FIPS 203 +// (DRAFT), Algorithm 13. +func parseEK(ex *encryptionKey, ekPKE []byte) error { + if len(ekPKE) != encryptionKeySize { + return errors.New("mlkem768: invalid encryption key length") + } + + for i := range ex.t { + var err error + ex.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12]) + if err != nil { + return err + } + ekPKE = ekPKE[encodingSize12:] + } + ρ := ekPKE + + for i := byte(0); i < k; i++ { + for j := byte(0); j < k; j++ { + // See the note in pkeKeyGen about the order of the indices being + // consistent with Kyber round 3. + ex.A[i*k+j] = sampleNTT(ρ, j, i) + } + } + + return nil +} + +// pkeEncrypt encrypt a plaintext message. +// +// It implements K-PKE.Encrypt according to FIPS 203 (DRAFT), Algorithm 13, +// although the computation of t and AT is done in parseEK. +func pkeEncrypt(cc *[CiphertextSize]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte { + var N byte + r, e1 := make([]nttElement, k), make([]ringElement, k) + for i := range r { + r[i] = ntt(samplePolyCBD(rnd, N)) + N++ + } + for i := range e1 { + e1[i] = samplePolyCBD(rnd, N) + N++ + } + e2 := samplePolyCBD(rnd, N) + + u := make([]ringElement, k) // NTT⁻¹(AT ◦ r) + e1 + for i := range u { + u[i] = e1[i] + for j := range r { + // Note that i and j are inverted, as we need the transposed of A. + u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.A[j*k+i], r[j]))) + } + } + + μ := ringDecodeAndDecompress1(m) + + var vNTT nttElement // t⊺ ◦ r + for i := range ex.t { + vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i])) + } + v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ) + + c := cc[:0] + for _, f := range u { + c = ringCompressAndEncode10(c, f) + } + c = ringCompressAndEncode4(c, v) + + return c +} + +// Decapsulate generates a shared key from a ciphertext and a decapsulation key. +// If the ciphertext is not valid, Decapsulate returns an error. +// +// The shared key must be kept secret. +func Decapsulate(dk *DecapsulationKey, ciphertext []byte) (sharedKey []byte, err error) { + if len(ciphertext) != CiphertextSize { + return nil, errors.New("mlkem768: invalid ciphertext length") + } + c := (*[CiphertextSize]byte)(ciphertext) + return kemDecaps(dk, c), nil +} + +// kemDecaps produces a shared key from a ciphertext. +// +// It implements ML-KEM.Decaps according to FIPS 203 (DRAFT), Algorithm 17. +func kemDecaps(dk *DecapsulationKey, c *[CiphertextSize]byte) (K []byte) { + h := dk.dk[decryptionKeySize+encryptionKeySize : decryptionKeySize+encryptionKeySize+32] + z := dk.dk[decryptionKeySize+encryptionKeySize+32:] + + m := pkeDecrypt(&dk.decryptionKey, c) + g := sha3.New512() + g.Write(m[:]) + g.Write(h) + G := g.Sum(nil) + Kprime, r := G[:SharedKeySize], G[SharedKeySize:] + J := sha3.NewShake256() + J.Write(z) + J.Write(c[:]) + Kout := make([]byte, SharedKeySize) + J.Read(Kout) + var cc [CiphertextSize]byte + c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r) + + subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime) + return Kout +} + +// parseDK parses a decryption key from its encoded form. +// +// It implements the computation of s from K-PKE.Decrypt according to FIPS 203 +// (DRAFT), Algorithm 14. +func parseDK(dx *decryptionKey, dkPKE []byte) error { + if len(dkPKE) != decryptionKeySize { + return errors.New("mlkem768: invalid decryption key length") + } + + for i := range dx.s { + f, err := polyByteDecode[nttElement](dkPKE[:encodingSize12]) + if err != nil { + return err + } + dx.s[i] = f + dkPKE = dkPKE[encodingSize12:] + } + + return nil +} + +// pkeDecrypt decrypts a ciphertext. +// +// It implements K-PKE.Decrypt according to FIPS 203 (DRAFT), Algorithm 14, +// although the computation of s is done in parseDK. +func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize]byte) []byte { + u := make([]ringElement, k) + for i := range u { + b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)]) + u[i] = ringDecodeAndDecompress10(b) + } + + b := (*[encodingSize4]byte)(c[encodingSize10*k:]) + v := ringDecodeAndDecompress4(b) + + var mask nttElement // s⊺ ◦ NTT(u) + for i := range dx.s { + mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i]))) + } + w := polySub(v, inverseNTT(mask)) + + return ringCompressAndEncode1(nil, w) +} + +// fieldElement is an integer modulo q, an element of ℤ_q. It is always reduced. +type fieldElement uint16 + +// fieldCheckReduced checks that a value a is < q. +func fieldCheckReduced(a uint16) (fieldElement, error) { + if a >= q { + return 0, errors.New("unreduced field element") + } + return fieldElement(a), nil +} + +// fieldReduceOnce reduces a value a < 2q. +func fieldReduceOnce(a uint16) fieldElement { + x := a - q + // If x underflowed, then x >= 2¹⁶ - q > 2¹⁵, so the top bit is set. + x += (x >> 15) * q + return fieldElement(x) +} + +func fieldAdd(a, b fieldElement) fieldElement { + x := uint16(a + b) + return fieldReduceOnce(x) +} + +func fieldSub(a, b fieldElement) fieldElement { + x := uint16(a - b + q) + return fieldReduceOnce(x) +} + +const ( + barrettMultiplier = 5039 // 2¹² * 2¹² / q + barrettShift = 24 // log₂(2¹² * 2¹²) +) + +// fieldReduce reduces a value a < 2q² using Barrett reduction, to avoid +// potentially variable-time division. +func fieldReduce(a uint32) fieldElement { + quotient := uint32((uint64(a) * barrettMultiplier) >> barrettShift) + return fieldReduceOnce(uint16(a - quotient*q)) +} + +func fieldMul(a, b fieldElement) fieldElement { + x := uint32(a) * uint32(b) + return fieldReduce(x) +} + +// fieldMulSub returns a * (b - c). This operation is fused to save a +// fieldReduceOnce after the subtraction. +func fieldMulSub(a, b, c fieldElement) fieldElement { + x := uint32(a) * uint32(b-c+q) + return fieldReduce(x) +} + +// fieldAddMul returns a * b + c * d. This operation is fused to save a +// fieldReduceOnce and a fieldReduce. +func fieldAddMul(a, b, c, d fieldElement) fieldElement { + x := uint32(a) * uint32(b) + x += uint32(c) * uint32(d) + return fieldReduce(x) +} + +// compress maps a field element uniformly to the range 0 to 2ᵈ-1, according to +// FIPS 203 (DRAFT), Definition 4.5. +func compress(x fieldElement, d uint8) uint16 { + // We want to compute (x * 2ᵈ) / q, rounded to nearest integer, with 1/2 + // rounding up (see FIPS 203 (DRAFT), Section 2.3). + + // Barrett reduction produces a quotient and a remainder in the range [0, 2q), + // such that dividend = quotient * q + remainder. + dividend := uint32(x) << d // x * 2ᵈ + quotient := uint32(uint64(dividend) * barrettMultiplier >> barrettShift) + remainder := dividend - quotient*q + + // Since the remainder is in the range [0, 2q), not [0, q), we need to + // portion it into three spans for rounding. + // + // [ 0, q/2 ) -> round to 0 + // [ q/2, q + q/2 ) -> round to 1 + // [ q + q/2, 2q ) -> round to 2 + // + // We can convert that to the following logic: add 1 if remainder > q/2, + // then add 1 again if remainder > q + q/2. + // + // Note that if remainder > x, then ⌊x⌋ - remainder underflows, and the top + // bit of the difference will be set. + quotient += (q/2 - remainder) >> 31 & 1 + quotient += (q + q/2 - remainder) >> 31 & 1 + + // quotient might have overflowed at this point, so reduce it by masking. + var mask uint32 = (1 << d) - 1 + return uint16(quotient & mask) +} + +// decompress maps a number x between 0 and 2ᵈ-1 uniformly to the full range of +// field elements, according to FIPS 203 (DRAFT), Definition 4.6. +func decompress(y uint16, d uint8) fieldElement { + // We want to compute (y * q) / 2ᵈ, rounded to nearest integer, with 1/2 + // rounding up (see FIPS 203 (DRAFT), Section 2.3). + + dividend := uint32(y) * q + quotient := dividend >> d // (y * q) / 2ᵈ + + // The d'th least-significant bit of the dividend (the most significant bit + // of the remainder) is 1 for the top half of the values that divide to the + // same quotient, which are the ones that round up. + quotient += dividend >> (d - 1) & 1 + + // quotient is at most (2¹¹-1) * q / 2¹¹ + 1 = 3328, so it didn't overflow. + return fieldElement(quotient) +} + +// ringElement is a polynomial, an element of R_q, represented as an array +// according to FIPS 203 (DRAFT), Section 2.4. +type ringElement [n]fieldElement + +// polyAdd adds two ringElements or nttElements. +func polyAdd[T ~[n]fieldElement](a, b T) (s T) { + for i := range s { + s[i] = fieldAdd(a[i], b[i]) + } + return s +} + +// polySub subtracts two ringElements or nttElements. +func polySub[T ~[n]fieldElement](a, b T) (s T) { + for i := range s { + s[i] = fieldSub(a[i], b[i]) + } + return s +} + +// polyByteEncode appends the 384-byte encoding of f to b. +// +// It implements ByteEncode₁₂, according to FIPS 203 (DRAFT), Algorithm 4. +func polyByteEncode[T ~[n]fieldElement](b []byte, f T) []byte { + out, B := sliceForAppend(b, encodingSize12) + for i := 0; i < n; i += 2 { + x := uint32(f[i]) | uint32(f[i+1])<<12 + B[0] = uint8(x) + B[1] = uint8(x >> 8) + B[2] = uint8(x >> 16) + B = B[3:] + } + return out +} + +// polyByteDecode decodes the 384-byte encoding of a polynomial, checking that +// all the coefficients are properly reduced. This achieves the "Modulus check" +// step of ML-KEM Encapsulation Input Validation. +// +// polyByteDecode is also used in ML-KEM Decapsulation, where the input +// validation is not required, but implicitly allowed by the specification. +// +// It implements ByteDecode₁₂, according to FIPS 203 (DRAFT), Algorithm 5. +func polyByteDecode[T ~[n]fieldElement](b []byte) (T, error) { + if len(b) != encodingSize12 { + return T{}, errors.New("mlkem768: invalid encoding length") + } + var f T + for i := 0; i < n; i += 2 { + d := uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 + const mask12 = 0b1111_1111_1111 + var err error + if f[i], err = fieldCheckReduced(uint16(d & mask12)); err != nil { + return T{}, errors.New("mlkem768: invalid polynomial encoding") + } + if f[i+1], err = fieldCheckReduced(uint16(d >> 12)); err != nil { + return T{}, errors.New("mlkem768: invalid polynomial encoding") + } + b = b[3:] + } + return f, nil +} + +// sliceForAppend takes a slice and a requested number of bytes. It returns a +// slice with the contents of the given slice followed by that many bytes and a +// second slice that aliases into it and contains only the extra bytes. If the +// original slice has sufficient capacity then no allocation is performed. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} + +// ringCompressAndEncode1 appends a 32-byte encoding of a ring element to s, +// compressing one coefficients per bit. +// +// It implements Compress₁, according to FIPS 203 (DRAFT), Definition 4.5, +// followed by ByteEncode₁, according to FIPS 203 (DRAFT), Algorithm 4. +func ringCompressAndEncode1(s []byte, f ringElement) []byte { + s, b := sliceForAppend(s, encodingSize1) + for i := range b { + b[i] = 0 + } + for i := range f { + b[i/8] |= uint8(compress(f[i], 1) << (i % 8)) + } + return s +} + +// ringDecodeAndDecompress1 decodes a 32-byte slice to a ring element where each +// bit is mapped to 0 or ⌈q/2⌋. +// +// It implements ByteDecode₁, according to FIPS 203 (DRAFT), Algorithm 5, +// followed by Decompress₁, according to FIPS 203 (DRAFT), Definition 4.6. +func ringDecodeAndDecompress1(b *[encodingSize1]byte) ringElement { + var f ringElement + for i := range f { + b_i := b[i/8] >> (i % 8) & 1 + const halfQ = (q + 1) / 2 // ⌈q/2⌋, rounded up per FIPS 203 (DRAFT), Section 2.3 + f[i] = fieldElement(b_i) * halfQ // 0 decompresses to 0, and 1 to ⌈q/2⌋ + } + return f +} + +// ringCompressAndEncode4 appends a 128-byte encoding of a ring element to s, +// compressing two coefficients per byte. +// +// It implements Compress₄, according to FIPS 203 (DRAFT), Definition 4.5, +// followed by ByteEncode₄, according to FIPS 203 (DRAFT), Algorithm 4. +func ringCompressAndEncode4(s []byte, f ringElement) []byte { + s, b := sliceForAppend(s, encodingSize4) + for i := 0; i < n; i += 2 { + b[i/2] = uint8(compress(f[i], 4) | compress(f[i+1], 4)<<4) + } + return s +} + +// ringDecodeAndDecompress4 decodes a 128-byte encoding of a ring element where +// each four bits are mapped to an equidistant distribution. +// +// It implements ByteDecode₄, according to FIPS 203 (DRAFT), Algorithm 5, +// followed by Decompress₄, according to FIPS 203 (DRAFT), Definition 4.6. +func ringDecodeAndDecompress4(b *[encodingSize4]byte) ringElement { + var f ringElement + for i := 0; i < n; i += 2 { + f[i] = fieldElement(decompress(uint16(b[i/2]&0b1111), 4)) + f[i+1] = fieldElement(decompress(uint16(b[i/2]>>4), 4)) + } + return f +} + +// ringCompressAndEncode10 appends a 320-byte encoding of a ring element to s, +// compressing four coefficients per five bytes. +// +// It implements Compress₁₀, according to FIPS 203 (DRAFT), Definition 4.5, +// followed by ByteEncode₁₀, according to FIPS 203 (DRAFT), Algorithm 4. +func ringCompressAndEncode10(s []byte, f ringElement) []byte { + s, b := sliceForAppend(s, encodingSize10) + for i := 0; i < n; i += 4 { + var x uint64 + x |= uint64(compress(f[i+0], 10)) + x |= uint64(compress(f[i+1], 10)) << 10 + x |= uint64(compress(f[i+2], 10)) << 20 + x |= uint64(compress(f[i+3], 10)) << 30 + b[0] = uint8(x) + b[1] = uint8(x >> 8) + b[2] = uint8(x >> 16) + b[3] = uint8(x >> 24) + b[4] = uint8(x >> 32) + b = b[5:] + } + return s +} + +// ringDecodeAndDecompress10 decodes a 320-byte encoding of a ring element where +// each ten bits are mapped to an equidistant distribution. +// +// It implements ByteDecode₁₀, according to FIPS 203 (DRAFT), Algorithm 5, +// followed by Decompress₁₀, according to FIPS 203 (DRAFT), Definition 4.6. +func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement { + b := bb[:] + var f ringElement + for i := 0; i < n; i += 4 { + x := uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 + b = b[5:] + f[i] = fieldElement(decompress(uint16(x>>0&0b11_1111_1111), 10)) + f[i+1] = fieldElement(decompress(uint16(x>>10&0b11_1111_1111), 10)) + f[i+2] = fieldElement(decompress(uint16(x>>20&0b11_1111_1111), 10)) + f[i+3] = fieldElement(decompress(uint16(x>>30&0b11_1111_1111), 10)) + } + return f +} + +// samplePolyCBD draws a ringElement from the special Dη distribution given a +// stream of random bytes generated by the PRF function, according to FIPS 203 +// (DRAFT), Algorithm 7 and Definition 4.1. +func samplePolyCBD(s []byte, b byte) ringElement { + prf := sha3.NewShake256() + prf.Write(s) + prf.Write([]byte{b}) + B := make([]byte, 128) + prf.Read(B) + + // SamplePolyCBD simply draws four (2η) bits for each coefficient, and adds + // the first two and subtracts the last two. + + var f ringElement + for i := 0; i < n; i += 2 { + b := B[i/2] + b_7, b_6, b_5, b_4 := b>>7, b>>6&1, b>>5&1, b>>4&1 + b_3, b_2, b_1, b_0 := b>>3&1, b>>2&1, b>>1&1, b&1 + f[i] = fieldSub(fieldElement(b_0+b_1), fieldElement(b_2+b_3)) + f[i+1] = fieldSub(fieldElement(b_4+b_5), fieldElement(b_6+b_7)) + } + return f +} + +// nttElement is an NTT representation, an element of T_q, represented as an +// array according to FIPS 203 (DRAFT), Section 2.4. +type nttElement [n]fieldElement + +// gammas are the values ζ^2BitRev7(i)+1 mod q for each index i. +var gammas = [128]fieldElement{17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 2110, 1219, 2935, 394, 885, 2444, 2154, 1175} + +// nttMul multiplies two nttElements. +// +// It implements MultiplyNTTs, according to FIPS 203 (DRAFT), Algorithm 10. +func nttMul(f, g nttElement) nttElement { + var h nttElement + // We use i += 2 for bounds check elimination. See https://go.dev/issue/66826. + for i := 0; i < 256; i += 2 { + a0, a1 := f[i], f[i+1] + b0, b1 := g[i], g[i+1] + h[i] = fieldAddMul(a0, b0, fieldMul(a1, b1), gammas[i/2]) + h[i+1] = fieldAddMul(a0, b1, a1, b0) + } + return h +} + +// zetas are the values ζ^BitRev7(k) mod q for each index k. +var zetas = [128]fieldElement{1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154} + +// ntt maps a ringElement to its nttElement representation. +// +// It implements NTT, according to FIPS 203 (DRAFT), Algorithm 8. +func ntt(f ringElement) nttElement { + k := 1 + for len := 128; len >= 2; len /= 2 { + for start := 0; start < 256; start += 2 * len { + zeta := zetas[k] + k++ + // Bounds check elimination hint. + f, flen := f[start:start+len], f[start+len:start+len+len] + for j := 0; j < len; j++ { + t := fieldMul(zeta, flen[j]) + flen[j] = fieldSub(f[j], t) + f[j] = fieldAdd(f[j], t) + } + } + } + return nttElement(f) +} + +// inverseNTT maps a nttElement back to the ringElement it represents. +// +// It implements NTT⁻¹, according to FIPS 203 (DRAFT), Algorithm 9. +func inverseNTT(f nttElement) ringElement { + k := 127 + for len := 2; len <= 128; len *= 2 { + for start := 0; start < 256; start += 2 * len { + zeta := zetas[k] + k-- + // Bounds check elimination hint. + f, flen := f[start:start+len], f[start+len:start+len+len] + for j := 0; j < len; j++ { + t := f[j] + f[j] = fieldAdd(t, flen[j]) + flen[j] = fieldMulSub(zeta, flen[j], t) + } + } + } + for i := range f { + f[i] = fieldMul(f[i], 3303) // 3303 = 128⁻¹ mod q + } + return ringElement(f) +} + +// sampleNTT draws a uniformly random nttElement from a stream of uniformly +// random bytes generated by the XOF function, according to FIPS 203 (DRAFT), +// Algorithm 6 and Definition 4.2. +func sampleNTT(rho []byte, ii, jj byte) nttElement { + B := sha3.NewShake128() + B.Write(rho) + B.Write([]byte{ii, jj}) + + // SampleNTT essentially draws 12 bits at a time from r, interprets them in + // little-endian, and rejects values higher than q, until it drew 256 + // values. (The rejection rate is approximately 19%.) + // + // To do this from a bytes stream, it draws three bytes at a time, and + // splits them into two uint16 appropriately masked. + // + // r₀ r₁ r₂ + // |- - - - - - - -|- - - - - - - -|- - - - - - - -| + // + // Uint16(r₀ || r₁) + // |- - - - - - - - - - - - - - - -| + // |- - - - - - - - - - - -| + // d₁ + // + // Uint16(r₁ || r₂) + // |- - - - - - - - - - - - - - - -| + // |- - - - - - - - - - - -| + // d₂ + // + // Note that in little-endian, the rightmost bits are the most significant + // bits (dropped with a mask) and the leftmost bits are the least + // significant bits (dropped with a right shift). + + var a nttElement + var j int // index into a + var buf [24]byte // buffered reads from B + off := len(buf) // index into buf, starts in a "buffer fully consumed" state + for { + if off >= len(buf) { + B.Read(buf[:]) + off = 0 + } + d1 := binary.LittleEndian.Uint16(buf[off:]) & 0b1111_1111_1111 + d2 := binary.LittleEndian.Uint16(buf[off+1:]) >> 4 + off += 3 + if d1 < q { + a[j] = fieldElement(d1) + j++ + } + if j >= len(a) { + break + } + if d2 < q { + a[j] = fieldElement(d2) + j++ + } + if j >= len(a) { + break + } + } + return a +}