0
0
mirror of https://github.com/XTLS/REALITY.git synced 2025-08-22 14:38:35 +00:00

REALITY practice: Support X25519MLKEM768 for TLS' communication

Thank https://github.com/XTLS/REALITY/pull/14 @yuhan6665
This commit is contained in:
RPRX 2025-05-12 20:18:51 +00:00 committed by GitHub
parent ce2747b9b0
commit f07c896f71
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 7 deletions

View File

@ -94,11 +94,31 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
c.cipherSuite = hs.suite.id
hs.transcript = hs.suite.hash.New()
var peerData []byte
for _, keyShare := range hs.clientHello.keyShares {
if keyShare.group == hs.hello.serverShare.group {
peerData = keyShare.data
break
}
}
var peerPub = peerData
if hs.hello.serverShare.group == X25519MLKEM768 {
peerPub = peerData[mlkem.EncapsulationKeySize768:]
}
key, _ := generateECDHEKey(c.config.rand(), X25519)
copy(hs.hello.serverShare.data, key.PublicKey().Bytes())
peerKey, _ := key.Curve().NewPublicKey(hs.clientHello.keyShares[hs.clientHello.keyShares[0].group].data)
peerKey, _ := key.Curve().NewPublicKey(peerPub)
hs.sharedKey, _ = key.ECDH(peerKey)
if hs.hello.serverShare.group == X25519MLKEM768 {
k, _ := mlkem.NewEncapsulationKey768(peerData[:mlkem.EncapsulationKeySize768])
mlkemSharedSecret, ciphertext := k.Encapsulate()
hs.sharedKey = append(mlkemSharedSecret, hs.sharedKey...)
copy(hs.hello.serverShare.data, append(ciphertext, hs.hello.serverShare.data[:32]...))
}
c.serverName = hs.clientHello.serverName
}
/*

9
tls.go
View File

@ -34,6 +34,7 @@ import (
"crypto/cipher"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/mlkem"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
@ -54,8 +55,8 @@ import (
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/hkdf"
"github.com/xtls/reality/gcm"
fipsaes "github.com/xtls/reality/aes"
"github.com/xtls/reality/gcm"
)
type CloseWriteConn interface {
@ -180,7 +181,7 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
if copying || err != nil || hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] {
break
}
for i, keyShare := range hs.clientHello.keyShares {
for _, keyShare := range hs.clientHello.keyShares {
if keyShare.group != X25519 || len(keyShare.data) != 32 {
continue
}
@ -222,7 +223,6 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
(config.ShortIds[hs.c.ClientShortId]) {
hs.c.conn = conn
}
hs.clientHello.keyShares[0].group = CurveID(i)
break
}
if config.Show {
@ -308,7 +308,8 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
if !hs.hello.unmarshal(s2cSaved[recordHeaderLen:handshakeLen]) ||
hs.hello.vers != VersionTLS12 || hs.hello.supportedVersion != VersionTLS13 ||
cipherSuiteTLS13ByID(hs.hello.cipherSuite) == nil ||
hs.hello.serverShare.group != X25519 || len(hs.hello.serverShare.data) != 32 {
(!(hs.hello.serverShare.group == X25519 && len(hs.hello.serverShare.data) == 32) &&
!(hs.hello.serverShare.group == X25519MLKEM768 && len(hs.hello.serverShare.data) == mlkem.CiphertextSize768+32)) {
break f
}
}