0
0
mirror of https://github.com/XTLS/REALITY.git synced 2025-08-22 22:48:36 +00:00

Update hpye.go to 760f228

This commit is contained in:
yuhan6665 2025-05-04 23:20:28 -04:00
parent c14471f843
commit 02afebcf30

View File

@ -9,13 +9,14 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/ecdh" "crypto/ecdh"
"crypto/hkdf"
"crypto/rand" "crypto/rand"
"encoding/binary"
"errors" "errors"
"math/bits" "math/bits"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/hkdf"
"github.com/xtls/reality/byteorder"
) )
// testingOnlyGenerateKey is only used during testing, to provide // testingOnlyGenerateKey is only used during testing, to provide
@ -26,28 +27,23 @@ type hkdfKDF struct {
hash crypto.Hash hash crypto.Hash
} }
func (kdf *hkdfKDF) LabeledExtract(suiteID []byte, salt []byte, label string, inputKey []byte) []byte { func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) ([]byte, error) {
labeledIKM := make([]byte, 0, 7+len(suiteID)+len(label)+len(inputKey)) labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey))
labeledIKM = append(labeledIKM, []byte("HPKE-v1")...) labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
labeledIKM = append(labeledIKM, suiteID...) labeledIKM = append(labeledIKM, sid...)
labeledIKM = append(labeledIKM, label...) labeledIKM = append(labeledIKM, label...)
labeledIKM = append(labeledIKM, inputKey...) labeledIKM = append(labeledIKM, inputKey...)
return hkdf.Extract(kdf.hash.New, labeledIKM, salt) return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
} }
func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) []byte { func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) ([]byte, error) {
labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info)) labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
labeledInfo = binary.BigEndian.AppendUint16(labeledInfo, length) labeledInfo = byteorder.BEAppendUint16(labeledInfo, length)
labeledInfo = append(labeledInfo, []byte("HPKE-v1")...) labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
labeledInfo = append(labeledInfo, suiteID...) labeledInfo = append(labeledInfo, suiteID...)
labeledInfo = append(labeledInfo, label...) labeledInfo = append(labeledInfo, label...)
labeledInfo = append(labeledInfo, info...) labeledInfo = append(labeledInfo, info...)
out := make([]byte, length) return hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length))
n, err := hkdf.Expand(kdf.hash.New, randomKey, labeledInfo).Read(out)
if err != nil || n != int(length) {
panic("hpke: LabeledExpand failed unexpectedly")
}
return out
} }
// dhKEM implements the KEM specified in RFC 9180, Section 4.1. // dhKEM implements the KEM specified in RFC 9180, Section 4.1.
@ -59,13 +55,17 @@ type dhKEM struct {
nSecret uint16 nSecret uint16
} }
type KemID uint16
const DHKEM_X25519_HKDF_SHA256 = 0x0020
var SupportedKEMs = map[uint16]struct { var SupportedKEMs = map[uint16]struct {
curve ecdh.Curve curve ecdh.Curve
hash crypto.Hash hash crypto.Hash
nSecret uint16 nSecret uint16
}{ }{
// RFC 9180 Section 7.1 // RFC 9180 Section 7.1
0x0020: {ecdh.X25519(), crypto.SHA256, 32}, DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32},
} }
func newDHKem(kemID uint16) (*dhKEM, error) { func newDHKem(kemID uint16) (*dhKEM, error) {
@ -76,13 +76,16 @@ func newDHKem(kemID uint16) (*dhKEM, error) {
return &dhKEM{ return &dhKEM{
dh: suite.curve, dh: suite.curve,
kdf: hkdfKDF{suite.hash}, kdf: hkdfKDF{suite.hash},
suiteID: binary.BigEndian.AppendUint16([]byte("KEM"), kemID), suiteID: byteorder.BEAppendUint16([]byte("KEM"), kemID),
nSecret: suite.nSecret, nSecret: suite.nSecret,
}, nil }, nil
} }
func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) []byte { func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) ([]byte, error) {
eaePRK := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey) eaePRK, err := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
if err != nil {
return nil, err
}
return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret) return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
} }
@ -104,13 +107,28 @@ func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encap
encPubRecip := pubRecipient.Bytes() encPubRecip := pubRecipient.Bytes()
kemContext := append(encPubEph, encPubRecip...) kemContext := append(encPubEph, encPubRecip...)
sharedSecret, err = dh.ExtractAndExpand(dhVal, kemContext)
return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil if err != nil {
return nil, nil, err
}
return sharedSecret, encPubEph, nil
} }
type Sender struct { func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) {
pubEph, err := dh.dh.NewPublicKey(encPubEph)
if err != nil {
return nil, err
}
dhVal, err := secRecipient.ECDH(pubEph)
if err != nil {
return nil, err
}
kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...)
return dh.ExtractAndExpand(dhVal, kemContext)
}
type context struct {
aead cipher.AEAD aead cipher.AEAD
kem *dhKEM
sharedSecret []byte sharedSecret []byte
@ -123,6 +141,14 @@ type Sender struct {
seqNum uint128 seqNum uint128
} }
type Sender struct {
*context
}
type Recipient struct {
*context
}
var aesGCMNew = func(key []byte) (cipher.AEAD, error) { var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
if err != nil { if err != nil {
@ -131,102 +157,165 @@ var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
return cipher.NewGCM(block) return cipher.NewGCM(block)
} }
type AEADID uint16
const (
AEAD_AES_128_GCM = 0x0001
AEAD_AES_256_GCM = 0x0002
AEAD_ChaCha20Poly1305 = 0x0003
)
var SupportedAEADs = map[uint16]struct { var SupportedAEADs = map[uint16]struct {
keySize int keySize int
nonceSize int nonceSize int
aead func([]byte) (cipher.AEAD, error) aead func([]byte) (cipher.AEAD, error)
}{ }{
// RFC 9180, Section 7.3 // RFC 9180, Section 7.3
0x0001: {keySize: 16, nonceSize: 12, aead: aesGCMNew}, AEAD_AES_128_GCM: {keySize: 16, nonceSize: 12, aead: aesGCMNew},
0x0002: {keySize: 32, nonceSize: 12, aead: aesGCMNew}, AEAD_AES_256_GCM: {keySize: 32, nonceSize: 12, aead: aesGCMNew},
0x0003: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New}, AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
} }
type KDFID uint16
const KDF_HKDF_SHA256 = 0x0001
var SupportedKDFs = map[uint16]func() *hkdfKDF{ var SupportedKDFs = map[uint16]func() *hkdfKDF{
// RFC 9180, Section 7.2 // RFC 9180, Section 7.2
0x0001: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} }, KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
} }
func SetupSender(kemID, kdfID, aeadID uint16, pub crypto.PublicKey, info []byte) ([]byte, *Sender, error) { func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) {
suiteID := SuiteID(kemID, kdfID, aeadID) sid := suiteID(kemID, kdfID, aeadID)
kem, err := newDHKem(kemID)
if err != nil {
return nil, nil, err
}
pubRecipient, ok := pub.(*ecdh.PublicKey)
if !ok {
return nil, nil, errors.New("incorrect public key type")
}
sharedSecret, encapsulatedKey, err := kem.Encap(pubRecipient)
if err != nil {
return nil, nil, err
}
kdfInit, ok := SupportedKDFs[kdfID] kdfInit, ok := SupportedKDFs[kdfID]
if !ok { if !ok {
return nil, nil, errors.New("unsupported KDF id") return nil, errors.New("unsupported KDF id")
} }
kdf := kdfInit() kdf := kdfInit()
aeadInfo, ok := SupportedAEADs[aeadID] aeadInfo, ok := SupportedAEADs[aeadID]
if !ok { if !ok {
return nil, nil, errors.New("unsupported AEAD id") return nil, errors.New("unsupported AEAD id")
} }
pskIDHash := kdf.LabeledExtract(suiteID, nil, "psk_id_hash", nil) pskIDHash, err := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil)
infoHash := kdf.LabeledExtract(suiteID, nil, "info_hash", info) if err != nil {
return nil, err
}
infoHash, err := kdf.LabeledExtract(sid, nil, "info_hash", info)
if err != nil {
return nil, err
}
ksContext := append([]byte{0}, pskIDHash...) ksContext := append([]byte{0}, pskIDHash...)
ksContext = append(ksContext, infoHash...) ksContext = append(ksContext, infoHash...)
secret := kdf.LabeledExtract(suiteID, sharedSecret, "secret", nil) secret, err := kdf.LabeledExtract(sid, sharedSecret, "secret", nil)
if err != nil {
key := kdf.LabeledExpand(suiteID, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */) return nil, err
baseNonce := kdf.LabeledExpand(suiteID, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */) }
exporterSecret := kdf.LabeledExpand(suiteID, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/) key, err := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
if err != nil {
return nil, err
}
baseNonce, err := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
if err != nil {
return nil, err
}
exporterSecret, err := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
if err != nil {
return nil, err
}
aead, err := aeadInfo.aead(key) aead, err := aeadInfo.aead(key)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
return encapsulatedKey, &Sender{ return &context{
kem: kem,
aead: aead, aead: aead,
sharedSecret: sharedSecret, sharedSecret: sharedSecret,
suiteID: suiteID, suiteID: sid,
key: key, key: key,
baseNonce: baseNonce, baseNonce: baseNonce,
exporterSecret: exporterSecret, exporterSecret: exporterSecret,
}, nil }, nil
} }
func (s *Sender) nextNonce() []byte { func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) {
nonce := s.seqNum.bytes()[16-s.aead.NonceSize():] kem, err := newDHKem(kemID)
for i := range s.baseNonce { if err != nil {
nonce[i] ^= s.baseNonce[i] return nil, nil, err
} }
// Message limit is, according to the RFC, 2^95+1, which sharedSecret, encapsulatedKey, err := kem.Encap(pub)
// is somewhat confusing, but we do as we're told. if err != nil {
if s.seqNum.bitLen() >= (s.aead.NonceSize()*8)-1 { return nil, nil, err
panic("message limit reached") }
context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
if err != nil {
return nil, nil, err
}
return encapsulatedKey, &Sender{context}, nil
}
func SetupRecipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Recipient, error) {
kem, err := newDHKem(kemID)
if err != nil {
return nil, err
}
sharedSecret, err := kem.Decap(encPubEph, priv)
if err != nil {
return nil, err
}
context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
if err != nil {
return nil, err
}
return &Recipient{context}, nil
}
func (ctx *context) nextNonce() []byte {
nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():]
for i := range ctx.baseNonce {
nonce[i] ^= ctx.baseNonce[i]
} }
s.seqNum = s.seqNum.addOne()
return nonce return nonce
} }
func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) { func (ctx *context) incrementNonce() {
// Message limit is, according to the RFC, 2^95+1, which
// is somewhat confusing, but we do as we're told.
if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 {
panic("message limit reached")
}
ctx.seqNum = ctx.seqNum.addOne()
}
func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad) ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
s.incrementNonce()
return ciphertext, nil return ciphertext, nil
} }
func SuiteID(kemID, kdfID, aeadID uint16) []byte { func (r *Recipient) Open(aad, ciphertext []byte) ([]byte, error) {
plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad)
if err != nil {
return nil, err
}
r.incrementNonce()
return plaintext, nil
}
func suiteID(kemID, kdfID, aeadID uint16) []byte {
suiteID := make([]byte, 0, 4+2+2+2) suiteID := make([]byte, 0, 4+2+2+2)
suiteID = append(suiteID, []byte("HPKE")...) suiteID = append(suiteID, []byte("HPKE")...)
suiteID = binary.BigEndian.AppendUint16(suiteID, kemID) suiteID = byteorder.BEAppendUint16(suiteID, kemID)
suiteID = binary.BigEndian.AppendUint16(suiteID, kdfID) suiteID = byteorder.BEAppendUint16(suiteID, kdfID)
suiteID = binary.BigEndian.AppendUint16(suiteID, aeadID) suiteID = byteorder.BEAppendUint16(suiteID, aeadID)
return suiteID return suiteID
} }
@ -238,6 +327,14 @@ func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
return kemInfo.curve.NewPublicKey(bytes) return kemInfo.curve.NewPublicKey(bytes)
} }
func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) {
kemInfo, ok := SupportedKEMs[kemID]
if !ok {
return nil, errors.New("unsupported KEM id")
}
return kemInfo.curve.NewPrivateKey(bytes)
}
type uint128 struct { type uint128 struct {
hi, lo uint64 hi, lo uint64
} }
@ -253,7 +350,7 @@ func (u uint128) bitLen() int {
func (u uint128) bytes() []byte { func (u uint128) bytes() []byte {
b := make([]byte, 16) b := make([]byte, 16)
binary.BigEndian.PutUint64(b[0:], u.hi) byteorder.BEPutUint64(b[0:], u.hi)
binary.BigEndian.PutUint64(b[8:], u.lo) byteorder.BEPutUint64(b[8:], u.lo)
return b return b
} }