diff --git a/tls.go b/tls.go index 182736a..5b666ca 100644 --- a/tls.go +++ b/tls.go @@ -211,11 +211,23 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if copying || err != nil || hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] { break } + var peerPub []byte for _, keyShare := range hs.clientHello.keyShares { - if keyShare.group != X25519 || len(keyShare.data) != 32 { - continue + if keyShare.group == X25519 && len(keyShare.data) == 32 { + peerPub = keyShare.data + break } - if hs.c.AuthKey, err = curve25519.X25519(config.PrivateKey, keyShare.data); err != nil { + } + if peerPub == nil { + for _, keyShare := range hs.clientHello.keyShares { + if keyShare.group == X25519MLKEM768 && len(keyShare.data) == mlkem.EncapsulationKeySize768+32 { + peerPub = keyShare.data[mlkem.EncapsulationKeySize768:] + break + } + } + } + for peerPub != nil { + if hs.c.AuthKey, err = curve25519.X25519(config.PrivateKey, peerPub); err != nil { break } if _, err = hkdf.New(sha256.New, hs.c.AuthKey, hs.clientHello.random[:20], []byte("REALITY")).Read(hs.c.AuthKey); err != nil { @@ -426,7 +438,7 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { waitGroup.Wait() target.Close() if config.Show { - fmt.Printf("REALITY remoteAddr: %v\ths.c.handshakeStatus: %v\n", remoteAddr, hs.c.isHandshakeComplete.Load()) + fmt.Printf("REALITY remoteAddr: %v\ths.c.isHandshakeComplete.Load(): %v\n", remoteAddr, hs.c.isHandshakeComplete.Load()) } if hs.c.isHandshakeComplete.Load() { return hs.c, nil