diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 2242e5c..9031489 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -25,7 +25,7 @@ import ( "slices" "sort" "time" - + "github.com/xtls/reality/fips140tls" "github.com/xtls/reality/hpke" "github.com/xtls/reality/tls13" @@ -93,12 +93,32 @@ func (hs *serverHandshakeStateTLS13) handshake() error { hs.suite = cipherSuiteTLS13ByID(hs.hello.cipherSuite) c.cipherSuite = hs.suite.id hs.transcript = hs.suite.hash.New() - + + var peerData []byte + for _, keyShare := range hs.clientHello.keyShares { + if keyShare.group == hs.hello.serverShare.group { + peerData = keyShare.data + break + } + } + + var peerPub = peerData + if hs.hello.serverShare.group == X25519MLKEM768 { + peerPub = peerData[mlkem.EncapsulationKeySize768:] + } + key, _ := generateECDHEKey(c.config.rand(), X25519) copy(hs.hello.serverShare.data, key.PublicKey().Bytes()) - peerKey, _ := key.Curve().NewPublicKey(hs.clientHello.keyShares[hs.clientHello.keyShares[0].group].data) + peerKey, _ := key.Curve().NewPublicKey(peerPub) hs.sharedKey, _ = key.ECDH(peerKey) + if hs.hello.serverShare.group == X25519MLKEM768 { + k, _ := mlkem.NewEncapsulationKey768(peerData[:mlkem.EncapsulationKeySize768]) + mlkemSharedSecret, ciphertext := k.Encapsulate() + hs.sharedKey = append(mlkemSharedSecret, hs.sharedKey...) + copy(hs.hello.serverShare.data, append(ciphertext, hs.hello.serverShare.data[:32]...)) + } + c.serverName = hs.clientHello.serverName } /* diff --git a/tls.go b/tls.go index 2b47035..5ee87e6 100644 --- a/tls.go +++ b/tls.go @@ -34,6 +34,7 @@ import ( "crypto/cipher" "crypto/ecdsa" "crypto/ed25519" + "crypto/mlkem" "crypto/rsa" "crypto/sha256" "crypto/x509" @@ -54,8 +55,8 @@ import ( "golang.org/x/crypto/curve25519" "golang.org/x/crypto/hkdf" - "github.com/xtls/reality/gcm" fipsaes "github.com/xtls/reality/aes" + "github.com/xtls/reality/gcm" ) type CloseWriteConn interface { @@ -180,7 +181,7 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if copying || err != nil || hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] { break } - for i, keyShare := range hs.clientHello.keyShares { + for _, keyShare := range hs.clientHello.keyShares { if keyShare.group != X25519 || len(keyShare.data) != 32 { continue } @@ -222,7 +223,6 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { (config.ShortIds[hs.c.ClientShortId]) { hs.c.conn = conn } - hs.clientHello.keyShares[0].group = CurveID(i) break } if config.Show { @@ -308,7 +308,8 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if !hs.hello.unmarshal(s2cSaved[recordHeaderLen:handshakeLen]) || hs.hello.vers != VersionTLS12 || hs.hello.supportedVersion != VersionTLS13 || cipherSuiteTLS13ByID(hs.hello.cipherSuite) == nil || - hs.hello.serverShare.group != X25519 || len(hs.hello.serverShare.data) != 32 { + (!(hs.hello.serverShare.group == X25519 && len(hs.hello.serverShare.data) == 32) && + !(hs.hello.serverShare.group == X25519MLKEM768 && len(hs.hello.serverShare.data) == mlkem.CiphertextSize768+32)) { break f } }