diff --git a/infra/conf/vless.go b/infra/conf/vless.go index 44503912..b53208dd 100644 --- a/infra/conf/vless.go +++ b/infra/conf/vless.go @@ -72,36 +72,35 @@ func (c *VLessInboundConfig) Build() (proto.Message, error) { config.Decryption = c.Decryption if !func() bool { s := strings.Split(config.Decryption, ".") - if len(s) != 5 || s[2] != "mlkem768Seed" { + if len(s) < 4 || s[0] != "mlkem768x25519plus" { return false } - if s[0] != "1rtt" { - t := strings.TrimSuffix(s[0], "min") - if t == s[0] { - return false - } - i, err := strconv.Atoi(t) - if err != nil { - return false - } - config.Minutes = uint32(i) - } switch s[1] { case "native": - case "divide": + case "xorpub": config.XorMode = 1 case "random": config.XorMode = 2 default: return false } - if b, _ := base64.RawURLEncoding.DecodeString(s[3]); len(b) != 32 { - return false + if s[2] != "1rtt" { + t := strings.TrimSuffix(s[2], "s") + if t == s[2] { + return false + } + i, err := strconv.Atoi(t) + if err != nil { + return false + } + config.Seconds = uint32(i) } - if b, _ := base64.RawURLEncoding.DecodeString(s[4]); len(b) != 64 { - return false + for i := 3; i < len(s); i++ { + if b, _ := base64.RawURLEncoding.DecodeString(s[i]); len(b) != 32 && len(b) != 64 { + return false + } } - config.Decryption = s[4] + "." + s[3] + config.Decryption = config.Decryption[27+len(s[2]):] return true }() && config.Decryption != "none" { if config.Decryption == "" { @@ -220,36 +219,31 @@ func (c *VLessOutboundConfig) Build() (proto.Message, error) { if !func() bool { s := strings.Split(account.Encryption, ".") - if len(s) != 5 || s[2] != "mlkem768Client" { + if len(s) < 4 || s[0] != "mlkem768x25519plus" { return false } - if s[0] != "1rtt" { - t := strings.TrimSuffix(s[0], "min") - if t == s[0] { - return false - } - i, err := strconv.Atoi(t) - if err != nil { - return false - } - account.Minutes = uint32(i) - } switch s[1] { case "native": - case "divide": + case "xorpub": account.XorMode = 1 case "random": account.XorMode = 2 default: return false } - if b, _ := base64.RawURLEncoding.DecodeString(s[3]); len(b) != 32 { + switch s[2] { + case "1rtt": + case "0rtt": + account.Seconds = 1 + default: return false } - if b, _ := base64.RawURLEncoding.DecodeString(s[4]); len(b) != 1184 { - return false + for i := 3; i < len(s); i++ { + if b, _ := base64.RawURLEncoding.DecodeString(s[i]); len(b) != 32 && len(b) != 1184 { + return false + } } - account.Encryption = s[4] + "." + s[3] + account.Encryption = account.Encryption[27+len(s[2]):] return true }() && account.Encryption != "none" { if account.Encryption == "" { diff --git a/main/commands/all/curve25519.go b/main/commands/all/curve25519.go index c3c516ad..16ca8c7c 100644 --- a/main/commands/all/curve25519.go +++ b/main/commands/all/curve25519.go @@ -5,6 +5,8 @@ import ( "crypto/rand" "encoding/base64" "fmt" + + "lukechampine.com/blake3" ) func Curve25519Genkey(StdEncoding bool, input_base64 string) { @@ -40,7 +42,10 @@ func Curve25519Genkey(StdEncoding bool, input_base64 string) { fmt.Println(err.Error()) return } - fmt.Printf("PrivateKey: %v\nPassword: %v", + password := key.PublicKey().Bytes() + hash32 := blake3.Sum256(password) + fmt.Printf("PrivateKey: %v\nPassword: %v\nHash32: %v", encoding.EncodeToString(privateKey), - encoding.EncodeToString(key.PublicKey().Bytes())) + encoding.EncodeToString(password), + encoding.EncodeToString(hash32[:])) } diff --git a/main/commands/all/mlkem768.go b/main/commands/all/mlkem768.go index f3cb0f79..0f6e707b 100644 --- a/main/commands/all/mlkem768.go +++ b/main/commands/all/mlkem768.go @@ -3,11 +3,11 @@ package all import ( "crypto/mlkem" "crypto/rand" - "crypto/sha3" "encoding/base64" "fmt" "github.com/xtls/xray-core/main/commands/base" + "lukechampine.com/blake3" ) var cmdMLKEM768 = &base.Command{ @@ -42,9 +42,9 @@ func executeMLKEM768(cmd *base.Command, args []string) { } key, _ := mlkem.NewDecapsulationKey768(seed[:]) client := key.EncapsulationKey().Bytes() - hash32 := sha3.Sum256(client) - fmt.Printf("Seed: %v\nClient: %v\nHash11: %v", + hash32 := blake3.Sum256(client) + fmt.Printf("Seed: %v\nClient: %v\nHash32: %v", base64.RawURLEncoding.EncodeToString(seed[:]), base64.RawURLEncoding.EncodeToString(client), - base64.RawURLEncoding.EncodeToString(hash32[:11])) + base64.RawURLEncoding.EncodeToString(hash32[:])) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 8251849d..049d9fbd 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -530,19 +530,12 @@ func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) { var readCounter, writerCounter stats.Counter if conn != nil { isEncryption := false - if clientConn, ok := conn.(*encryption.ClientConn); ok { - conn = clientConn.Conn - isEncryption = true - } - if serverConn, ok := conn.(*encryption.ServerConn); ok { - conn = serverConn.Conn + if commonConn, ok := conn.(*encryption.CommonConn); ok { + conn = commonConn.Conn isEncryption = true } if xorConn, ok := conn.(*encryption.XorConn); ok { - if !xorConn.Divide { - return xorConn, nil, nil // full-random xorConn should not be penetrated - } - conn = xorConn.Conn + return xorConn, nil, nil // full-random xorConn should not be penetrated } if statConn, ok := conn.(*stat.CounterConnection); ok { conn = statConn.Connection diff --git a/proxy/vless/account.go b/proxy/vless/account.go index 9967c7e1..b1e09619 100644 --- a/proxy/vless/account.go +++ b/proxy/vless/account.go @@ -19,7 +19,7 @@ func (a *Account) AsAccount() (protocol.Account, error) { Flow: a.Flow, // needs parser here? Encryption: a.Encryption, // needs parser here? XorMode: a.XorMode, - Minutes: a.Minutes, + Seconds: a.Seconds, }, nil } @@ -32,7 +32,7 @@ type MemoryAccount struct { Encryption string XorMode uint32 - Minutes uint32 + Seconds uint32 } // Equals implements protocol.Account.Equals(). @@ -50,6 +50,6 @@ func (a *MemoryAccount) ToProto() proto.Message { Flow: a.Flow, Encryption: a.Encryption, XorMode: a.XorMode, - Minutes: a.Minutes, + Seconds: a.Seconds, } } diff --git a/proxy/vless/account.pb.go b/proxy/vless/account.pb.go index ca638de1..6048dc4e 100644 --- a/proxy/vless/account.pb.go +++ b/proxy/vless/account.pb.go @@ -31,7 +31,7 @@ type Account struct { Flow string `protobuf:"bytes,2,opt,name=flow,proto3" json:"flow,omitempty"` Encryption string `protobuf:"bytes,3,opt,name=encryption,proto3" json:"encryption,omitempty"` XorMode uint32 `protobuf:"varint,4,opt,name=xorMode,proto3" json:"xorMode,omitempty"` - Minutes uint32 `protobuf:"varint,5,opt,name=minutes,proto3" json:"minutes,omitempty"` + Seconds uint32 `protobuf:"varint,5,opt,name=seconds,proto3" json:"seconds,omitempty"` } func (x *Account) Reset() { @@ -92,9 +92,9 @@ func (x *Account) GetXorMode() uint32 { return 0 } -func (x *Account) GetMinutes() uint32 { +func (x *Account) GetSeconds() uint32 { if x != nil { - return x.Minutes + return x.Seconds } return 0 } @@ -111,8 +111,8 @@ var file_proxy_vless_account_proto_rawDesc = []byte{ 0x0a, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x18, 0x0a, 0x07, 0x78, 0x6f, 0x72, 0x4d, 0x6f, 0x64, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, - 0x78, 0x6f, 0x72, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x69, 0x6e, 0x75, 0x74, - 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, 0x6d, 0x69, 0x6e, 0x75, 0x74, 0x65, + 0x78, 0x6f, 0x72, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x63, 0x6f, 0x6e, + 0x64, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x42, 0x52, 0x0a, 0x14, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x76, 0x6c, 0x65, 0x73, 0x73, 0x50, 0x01, 0x5a, 0x25, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, diff --git a/proxy/vless/account.proto b/proxy/vless/account.proto index f91a2528..ebb1feff 100644 --- a/proxy/vless/account.proto +++ b/proxy/vless/account.proto @@ -14,5 +14,5 @@ message Account { string encryption = 3; uint32 xorMode = 4; - uint32 minutes = 5; + uint32 seconds = 5; } diff --git a/proxy/vless/encryption/client.go b/proxy/vless/encryption/client.go index a0ae0d2b..47a2408f 100644 --- a/proxy/vless/encryption/client.go +++ b/proxy/vless/encryption/client.go @@ -1,266 +1,216 @@ package encryption import ( - "bytes" "crypto/cipher" "crypto/ecdh" "crypto/mlkem" "crypto/rand" - "crypto/sha3" "io" "net" - "strings" "sync" "time" "github.com/xtls/xray-core/common/crypto" "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/common/protocol" + "lukechampine.com/blake3" ) -var ClientCipher byte - -func init() { - if protocol.HasAESGCMHardwareSupport { - ClientCipher = 1 - } -} - type ClientInstance struct { - sync.RWMutex - nfsEKey *mlkem.EncapsulationKey768 - hash11 [11]byte // no more capacity - xorMode uint32 - xorPKey *ecdh.PublicKey - minutes time.Duration - expire time.Time - baseKey []byte - ticket []byte + NfsPKeys []any + NfsPKeysBytes [][]byte + Hash32s [][32]byte + RelaysLength int + XorMode uint32 + Seconds uint32 + + RWLock sync.RWMutex + Expire time.Time + PfsKey []byte + Ticket []byte } -type ClientConn struct { - net.Conn - instance *ClientInstance - baseKey []byte - ticket []byte - random []byte - aead cipher.AEAD - nonce []byte - peerAEAD cipher.AEAD - peerNonce []byte - PeerCache []byte -} - -func (i *ClientInstance) Init(nfsEKeyBytes, xorPKeyBytes []byte, xorMode, minutes uint32) (err error) { - if i.nfsEKey != nil { +func (i *ClientInstance) Init(nfsPKeysBytes [][]byte, xorMode, seconds uint32) (err error) { + if i.NfsPKeys != nil { err = errors.New("already initialized") return } - if i.nfsEKey, err = mlkem.NewEncapsulationKey768(nfsEKeyBytes); err != nil { + l := len(nfsPKeysBytes) + if l == 0 { + err = errors.New("empty nfsPKeysBytes") return } - if xorMode > 0 { - i.xorMode = xorMode - if i.xorPKey, err = ecdh.X25519().NewPublicKey(xorPKeyBytes); err != nil { - return + i.NfsPKeys = make([]any, l) + i.NfsPKeysBytes = nfsPKeysBytes + i.Hash32s = make([][32]byte, l) + for j, k := range nfsPKeysBytes { + if len(k) == 32 { + if i.NfsPKeys[j], err = ecdh.X25519().NewPublicKey(k); err != nil { + return + } + i.RelaysLength += 32 + 32 + } else { + if i.NfsPKeys[j], err = mlkem.NewEncapsulationKey768(k); err != nil { + return + } + i.RelaysLength += 1088 + 32 } - hash32 := sha3.Sum256(nfsEKeyBytes) - copy(i.hash11[:], hash32[:]) + i.Hash32s[j] = blake3.Sum256(k) } - i.minutes = time.Duration(minutes) * time.Minute + i.RelaysLength -= 32 + i.XorMode = xorMode + i.Seconds = seconds return } -func (i *ClientInstance) Handshake(conn net.Conn) (*ClientConn, error) { - if i.nfsEKey == nil { +func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { + if i.NfsPKeys == nil { return nil, errors.New("uninitialized") } - if i.xorMode > 0 { - conn, _ = NewXorConn(conn, i.xorMode, i.xorPKey, nil) - } - c := &ClientConn{Conn: conn} + c := &CommonConn{Conn: conn} - if i.minutes > 0 { - i.RLock() - if time.Now().Before(i.expire) { - c.instance = i - c.baseKey = i.baseKey - c.ticket = i.ticket - i.RUnlock() + ivAndRealysLength := 16 + i.RelaysLength + pfsKeyExchangeLength := 18 + 1184 + 32 + 16 + paddingLength := int(crypto.RandBetween(100, 1000)) + clientHello := make([]byte, ivAndRealysLength+pfsKeyExchangeLength+paddingLength) + + iv := clientHello[:16] + rand.Read(iv) + relays := clientHello[16:ivAndRealysLength] + var nfsPublicKey, nfsKey []byte + var lastCTR cipher.Stream + for j, k := range i.NfsPKeys { + var index = 32 + if k, ok := k.(*ecdh.PublicKey); ok { + privateKey, _ := ecdh.X25519().GenerateKey(rand.Reader) + nfsPublicKey = privateKey.PublicKey().Bytes() + copy(relays, nfsPublicKey) + var err error + nfsKey, err = privateKey.ECDH(k) + if err != nil { + return nil, err + } + } + if k, ok := k.(*mlkem.EncapsulationKey768); ok { + nfsKey, nfsPublicKey = k.Encapsulate() + copy(relays, nfsPublicKey) + index = 1088 + } + if i.XorMode > 0 { // this xor can (others can't) be decrypted by client's config, revealing an X25519 public key / ML-KEM-768 ciphertext, but it is not important + NewCTR(i.NfsPKeysBytes[j], iv).XORKeyStream(relays, relays[:index]) // make X25519 public key / ML-KEM-768 ciphertext distinguishable from random bytes + } + if lastCTR != nil { + lastCTR.XORKeyStream(relays, relays[:32]) // make this relay irreplaceable + } + if j == len(i.NfsPKeys)-1 { + break + } + lastCTR = NewCTR(nfsKey, iv) + lastCTR.XORKeyStream(relays[index:], i.Hash32s[j+1][:]) + relays = relays[index+32:] + } + nfsGCM := NewGCM(nfsPublicKey, nfsKey) + + if i.Seconds > 0 { + i.RWLock.RLock() + if time.Now().Before(i.Expire) { + c.Client = i + c.UnitedKey = append(i.PfsKey, nfsKey...) + nfsGCM.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil) + nfsGCM.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil) + i.RWLock.RUnlock() + c.PreWrite = clientHello[:ivAndRealysLength+18+32] + c.GCM = NewGCM(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey) + if i.XorMode == 2 { + c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, iv), nil, len(c.PreWrite), 32) + } return c, nil } - i.RUnlock() + i.RWLock.RUnlock() } - pfsDKeySeed := make([]byte, 64) - rand.Read(pfsDKeySeed) - pfsDKey, _ := mlkem.NewDecapsulationKey768(pfsDKeySeed) - pfsEKeyBytes := pfsDKey.EncapsulationKey().Bytes() - nfsKey, encapsulatedNfsKey := i.nfsEKey.Encapsulate() - nfsAEAD := NewAEAD(ClientCipher, nfsKey, pfsEKeyBytes, encapsulatedNfsKey) + pfsKeyExchange := clientHello[ivAndRealysLength : ivAndRealysLength+pfsKeyExchangeLength] + nfsGCM.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil) + mlkem768DKey, _ := mlkem.GenerateKey768() + x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader) + pfsPublicKey := append(mlkem768DKey.EncapsulationKey().Bytes(), x25519SKey.PublicKey().Bytes()...) + nfsGCM.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil) - clientHello := make([]byte, 5+11+1+1184+1088+crypto.RandBetween(100, 1000)) - EncodeHeader(clientHello, 1, 11+1+1184+1088) - copy(clientHello[5:], i.hash11[:]) - clientHello[5+11] = ClientCipher - copy(clientHello[5+11+1:], pfsEKeyBytes) - copy(clientHello[5+11+1+1184:], encapsulatedNfsKey) - padding := clientHello[5+11+1+1184+1088:] - rand.Read(padding) // important - EncodeHeader(padding, 23, len(padding)-5) - nfsAEAD.Seal(padding[:5], clientHello[5:5+11+1], padding[5:len(padding)-16], padding[:5]) + padding := clientHello[ivAndRealysLength+pfsKeyExchangeLength:] + nfsGCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) + nfsGCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) - if _, err := c.Conn.Write(clientHello); err != nil { + if _, err := conn.Write(clientHello); err != nil { return nil, err } - // client can send more NFS AEAD paddings / messages if needed + // padding can be sent in a fragmented way, to create variable traffic pattern, before VLESS flow takes control - _, t, l, err := ReadAndDiscardPaddings(c.Conn, nil, nil) // allow paddings before server hello + encryptedLength := make([]byte, 18) + if _, err := io.ReadFull(conn, encryptedLength); err != nil { + return nil, err + } + if _, err := nfsGCM.Open(encryptedLength[:0], make([]byte, 12), encryptedLength, nil); err != nil { + return nil, err + } + length := DecodeLength(encryptedLength[:2]) + + if length < 1088+32+16 { // server may send more public keys + return nil, errors.New("too short length") + } + encryptedPfsPublicKey := make([]byte, length) + if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil { + return nil, err + } + nfsGCM.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil) + mlkem768Key, err := mlkem768DKey.Decapsulate(encryptedPfsPublicKey[:1088]) if err != nil { return nil, err } - - if t != 1 { - return nil, errors.New("unexpected type ", t, ", expect server hello") - } - peerServerHello := make([]byte, 1088+21) - if l != len(peerServerHello) { - return nil, errors.New("unexpected length ", l, " for server hello") - } - if _, err := io.ReadFull(c.Conn, peerServerHello); err != nil { - return nil, err - } - encapsulatedPfsKey := peerServerHello[:1088] - c.ticket = append(i.hash11[:], peerServerHello[1088:]...) - - pfsKey, err := pfsDKey.Decapsulate(encapsulatedPfsKey) + peerX25519PKey, err := ecdh.X25519().NewPublicKey(encryptedPfsPublicKey[1088 : 1088+32]) if err != nil { return nil, err } - c.baseKey = append(pfsKey, nfsKey...) + x25519Key, err := x25519SKey.ECDH(peerX25519PKey) + if err != nil { + return nil, err + } + pfsKey := append(mlkem768Key, x25519Key...) + c.UnitedKey = append(pfsKey, nfsKey...) + c.GCM = NewGCM(pfsPublicKey, c.UnitedKey) + c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1088+32], c.UnitedKey) - VLESS, _ := NewAEAD(ClientCipher, c.baseKey, encapsulatedPfsKey, encapsulatedNfsKey).Open(nil, append(i.hash11[:], ClientCipher), c.ticket[11:], pfsEKeyBytes) - if !bytes.Equal(VLESS, []byte("VLESS")) { - return nil, errors.New("invalid server").AtError() + encryptedTicket := make([]byte, 32) + if _, err := io.ReadFull(conn, encryptedTicket); err != nil { + return nil, err + } + if _, err := c.PeerGCM.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil { + return nil, err + } + seconds := DecodeLength(encryptedTicket) + + if i.Seconds > 0 && seconds > 0 { + i.RWLock.Lock() + i.Expire = time.Now().Add(time.Duration(seconds) * time.Second) + i.PfsKey = pfsKey + i.Ticket = encryptedTicket[:16] + i.RWLock.Unlock() } - if i.minutes > 0 { - i.Lock() - i.expire = time.Now().Add(i.minutes) - i.baseKey = c.baseKey - i.ticket = c.ticket - i.Unlock() + if _, err := io.ReadFull(conn, encryptedLength); err != nil { + return nil, err + } + if _, err := c.PeerGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { + return nil, err + } + encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2])) // TODO: move to Read() + if _, err := io.ReadFull(conn, encryptedPadding); err != nil { + return nil, err + } + if _, err := c.PeerGCM.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil { + return nil, err } + if i.XorMode == 2 { + c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, iv), NewCTR(c.UnitedKey, encryptedTicket[:16]), 0, 0) + } return c, nil } - -func (c *ClientConn) Write(b []byte) (int, error) { - if len(b) == 0 { - return 0, nil - } - var data []byte - for n := 0; n < len(b); { - b := b[n:] - if len(b) > 8192 { - b = b[:8192] // for avoiding another copy() in server's Read() - } - n += len(b) - if c.aead == nil { - data = make([]byte, 5+32+32+5+len(b)+16) - EncodeHeader(data, 0, 32+32) - copy(data[5:], c.ticket) - c.random = make([]byte, 32) - rand.Read(c.random) - copy(data[5+32:], c.random) - EncodeHeader(data[5+32+32:], 23, len(b)+16) - c.aead = NewAEAD(ClientCipher, c.baseKey, c.random, c.ticket) - c.nonce = make([]byte, 12) - c.aead.Seal(data[:5+32+32+5], c.nonce, b, data[5+32+32:5+32+32+5]) - } else { - data = make([]byte, 5+len(b)+16) - EncodeHeader(data, 23, len(b)+16) - c.aead.Seal(data[:5], c.nonce, b, data[:5]) - if bytes.Equal(c.nonce, MaxNonce) { - c.aead = NewAEAD(ClientCipher, c.baseKey, data[5:], data[:5]) - } - } - IncreaseNonce(c.nonce) - if _, err := c.Conn.Write(data); err != nil { - return 0, err - } - } - return len(b), nil -} - -func (c *ClientConn) Read(b []byte) (int, error) { - if len(b) == 0 { - return 0, nil - } - if c.peerAEAD == nil { - _, t, l, err := ReadAndDiscardPaddings(c.Conn, nil, nil) // allow paddings before random hello - if err != nil { - if c.instance != nil && strings.HasPrefix(err.Error(), "invalid header: ") { // 0-RTT - c.instance.Lock() - if bytes.Equal(c.ticket, c.instance.ticket) { - c.instance.expire = time.Now() // expired - } - c.instance.Unlock() - return 0, errors.New("new handshake needed") - } - return 0, err - } - if t != 0 { - return 0, errors.New("unexpected type ", t, ", expect random hello") - } - peerRandomHello := make([]byte, 32) - if l != len(peerRandomHello) { - return 0, errors.New("unexpected length ", l, " for random hello") - } - if _, err := io.ReadFull(c.Conn, peerRandomHello); err != nil { - return 0, err - } - if c.random == nil { - return 0, errors.New("empty c.random") - } - c.peerAEAD = NewAEAD(ClientCipher, c.baseKey, peerRandomHello, c.random) - c.peerNonce = make([]byte, 12) - } - if len(c.PeerCache) != 0 { - n := copy(b, c.PeerCache) - c.PeerCache = c.PeerCache[n:] - return n, nil - } - h, t, l, err := ReadAndDecodeHeader(c.Conn) // l: 17~17000 - if err != nil { - return 0, err - } - if t != 23 { - return 0, errors.New("unexpected type ", t, ", expect encrypted data") - } - peerData := make([]byte, l) - if _, err := io.ReadFull(c.Conn, peerData); err != nil { - return 0, err - } - dst := peerData[:l-16] - if len(dst) <= len(b) { - dst = b[:len(dst)] // avoids another copy() - } - var peerAEAD cipher.AEAD - if bytes.Equal(c.peerNonce, MaxNonce) { - peerAEAD = NewAEAD(ClientCipher, c.baseKey, peerData, h) - } - _, err = c.peerAEAD.Open(dst[:0], c.peerNonce, peerData, h) - if peerAEAD != nil { - c.peerAEAD = peerAEAD - } - IncreaseNonce(c.peerNonce) - if err != nil { - return 0, err - } - if len(dst) > len(b) { - c.PeerCache = dst[copy(b, dst):] - dst = b // for len(dst) - } - return len(dst), nil -} diff --git a/proxy/vless/encryption/common.go b/proxy/vless/encryption/common.go index 4e2d4756..a2418c1f 100644 --- a/proxy/vless/encryption/common.go +++ b/proxy/vless/encryption/common.go @@ -4,46 +4,175 @@ import ( "bytes" "crypto/aes" "crypto/cipher" - "crypto/hkdf" - "crypto/sha3" "fmt" "io" "net" + "strings" + "time" "github.com/xtls/xray-core/common/errors" - "golang.org/x/crypto/chacha20poly1305" + "lukechampine.com/blake3" ) +type CommonConn struct { + net.Conn + Client *ClientInstance + UnitedKey []byte + PreWrite []byte + GCM *GCM + PeerGCM *GCM + PeerCache []byte +} + +func (c *CommonConn) Write(b []byte) (int, error) { + if len(b) == 0 { + return 0, nil + } + var data []byte + for n := 0; n < len(b); { + b := b[n:] + if len(b) > 8192 { + b = b[:8192] // for avoiding another copy() in peer's Read() + } + n += len(b) + data = make([]byte, 5+len(b)+16) + EncodeHeader(data, len(b)+16) + aead := c.GCM + if bytes.Equal(c.GCM.Nonce[:], MaxNonce) { + aead = nil + } + c.GCM.Seal(data[:5], nil, b, data[:5]) + if aead == nil { + c.GCM = NewGCM(data[5:], c.UnitedKey) + } + if c.PreWrite != nil { + data = append(c.PreWrite, data...) + c.PreWrite = nil + } + if _, err := c.Conn.Write(data); err != nil { + return 0, err + } + } + return len(b), nil +} + +func (c *CommonConn) Read(b []byte) (int, error) { + if len(b) == 0 { + return 0, nil + } + if c.PeerGCM == nil { // client's 0-RTT + serverRandom := make([]byte, 32) + if _, err := io.ReadFull(c.Conn, serverRandom); err != nil { + return 0, err + } + c.PeerGCM = NewGCM(serverRandom, c.UnitedKey) + if xorConn, ok := c.Conn.(*XorConn); ok { + xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom[16:]) + } + } + if len(c.PeerCache) != 0 { + n := copy(b, c.PeerCache) + c.PeerCache = c.PeerCache[n:] + return n, nil + } + h, l, err := ReadAndDecodeHeader(c.Conn) // l: 17~17000 + if err != nil { + if c.Client != nil && strings.HasPrefix(err.Error(), "invalid header: ") { // client's 0-RTT + c.Client.RWLock.Lock() + if bytes.Equal(c.UnitedKey[:32], c.Client.PfsKey) { + c.Client.Expire = time.Now() // expired + } + c.Client.RWLock.Unlock() + return 0, errors.New("new handshake needed") + } + return 0, err + } + c.Client = nil + peerData := make([]byte, l) + if _, err := io.ReadFull(c.Conn, peerData); err != nil { + return 0, err + } + dst := peerData[:l-16] + if len(dst) <= len(b) { + dst = b[:len(dst)] // avoids another copy() + } + var peerAEAD *GCM + if bytes.Equal(c.PeerGCM.Nonce[:], MaxNonce) { + peerAEAD = NewGCM(peerData, c.UnitedKey) + } + _, err = c.PeerGCM.Open(dst[:0], nil, peerData, h) + if peerAEAD != nil { + c.PeerGCM = peerAEAD + } + if err != nil { + return 0, err + } + if len(dst) > len(b) { + c.PeerCache = dst[copy(b, dst):] + dst = b // for len(dst) + } + return len(dst), nil +} + +type GCM struct { + cipher.AEAD + Nonce [12]byte +} + +func NewGCM(ctx, key []byte) *GCM { + k := make([]byte, 32) + blake3.DeriveKey(k, string(ctx), key) + block, _ := aes.NewCipher(k) + aead, _ := cipher.NewGCM(block) + return &GCM{AEAD: aead} + //chacha20poly1305.New() +} + +func (a *GCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte { + if nonce == nil { + nonce = IncreaseNonce(a.Nonce[:]) + } + return a.AEAD.Seal(dst, nonce, plaintext, additionalData) +} + +func (a *GCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { + if nonce == nil { + nonce = IncreaseNonce(a.Nonce[:]) + } + return a.AEAD.Open(dst, nonce, ciphertext, additionalData) +} + +func IncreaseNonce(nonce []byte) []byte { + for i := range 12 { + nonce[11-i]++ + if nonce[11-i] != 0 { + break + } + } + return nonce +} + var MaxNonce = bytes.Repeat([]byte{255}, 12) -func EncodeHeader(h []byte, t byte, l int) { - switch t { - case 1: - h[0] = 1 - h[1] = 1 - h[2] = 1 - case 0: - h[0] = 0 - h[1] = 0 - h[2] = 0 - case 23: - h[0] = 23 - h[1] = 3 - h[2] = 3 - } +func EncodeLength(l int) []byte { + return []byte{byte(l >> 8), byte(l)} +} + +func DecodeLength(b []byte) int { + return int(b[0])<<8 | int(b[1]) +} + +func EncodeHeader(h []byte, l int) { + h[0] = 23 + h[1] = 3 + h[2] = 3 h[3] = byte(l >> 8) h[4] = byte(l) } -func DecodeHeader(h []byte) (t byte, l int, err error) { +func DecodeHeader(h []byte) (l int, err error) { l = int(h[3])<<8 | int(h[4]) - if h[0] == 23 && h[1] == 3 && h[2] == 3 { - t = 23 - } else if h[0] == 0 && h[1] == 0 && h[2] == 0 { - t = 0 - } else if h[0] == 1 && h[1] == 1 && h[2] == 1 { - t = 1 - } else { + if h[0] != 23 || h[1] != 3 || h[2] != 3 { l = 0 } if l < 17 || l > 17000 { // TODO: TLSv1.3 max length @@ -52,49 +181,22 @@ func DecodeHeader(h []byte) (t byte, l int, err error) { return } -func ReadAndDecodeHeader(conn net.Conn) (h []byte, t byte, l int, err error) { +func ReadAndDecodeHeader(conn net.Conn) (h []byte, l int, err error) { h = make([]byte, 5) if _, err = io.ReadFull(conn, h); err != nil { return } - t, l, err = DecodeHeader(h) + l, err = DecodeHeader(h) return } -func ReadAndDiscardPaddings(conn net.Conn, aead cipher.AEAD, nonce []byte) (h []byte, t byte, l int, err error) { +func ReadAndDiscardPaddings(conn net.Conn) (h []byte, l int, err error) { for { - if h, t, l, err = ReadAndDecodeHeader(conn); err != nil || t != 23 { + if h, l, err = ReadAndDecodeHeader(conn); err != nil { return } - padding := make([]byte, l) - if _, err = io.ReadFull(conn, padding); err != nil { + if _, err = io.ReadFull(conn, make([]byte, l)); err != nil { return } - if aead != nil { - if _, err := aead.Open(nil, nonce, padding, h); err != nil { - return h, t, l, err - } - IncreaseNonce(nonce) - } - } -} - -func NewAEAD(c byte, secret, salt, info []byte) (aead cipher.AEAD) { - key, _ := hkdf.Key(sha3.New256, secret, salt, string(info), 32) - if c&1 == 1 { - block, _ := aes.NewCipher(key) - aead, _ = cipher.NewGCM(block) - } else { - aead, _ = chacha20poly1305.New(key) - } - return -} - -func IncreaseNonce(nonce []byte) { - for i := range 12 { - nonce[11-i]++ - if nonce[11-i] != 0 { - break - } } } diff --git a/proxy/vless/encryption/server.go b/proxy/vless/encryption/server.go index 336571b4..4cf22516 100644 --- a/proxy/vless/encryption/server.go +++ b/proxy/vless/encryption/server.go @@ -6,7 +6,6 @@ import ( "crypto/ecdh" "crypto/mlkem" "crypto/rand" - "crypto/sha3" "fmt" "io" "net" @@ -15,73 +14,77 @@ import ( "github.com/xtls/xray-core/common/crypto" "github.com/xtls/xray-core/common/errors" + "lukechampine.com/blake3" ) type ServerSession struct { - expire time.Time - cipher byte - baseKey []byte - randoms sync.Map + Expire time.Time + PfsKey []byte + Replays sync.Map } type ServerInstance struct { - sync.RWMutex - nfsDKey *mlkem.DecapsulationKey768 - hash11 [11]byte // no more capacity - xorMode uint32 - xorSKey *ecdh.PrivateKey - minutes time.Duration - sessions map[[32]byte]*ServerSession - closed bool + NfsSKeys []any + NfsPKeysBytes [][]byte + Hash32s [][32]byte + RelaysLength int + XorMode uint32 + Seconds uint32 + + RWLock sync.RWMutex + Sessions map[[16]byte]*ServerSession + Closed bool } -type ServerConn struct { - net.Conn - cipher byte - baseKey []byte - ticket []byte - peerRandom []byte - peerAEAD cipher.AEAD - peerNonce []byte - PeerCache []byte - aead cipher.AEAD - nonce []byte -} - -func (i *ServerInstance) Init(nfsDKeySeed, xorSKeyBytes []byte, xorMode, minutes uint32) (err error) { - if i.nfsDKey != nil { +func (i *ServerInstance) Init(nfsSKeysBytes [][]byte, xorMode, seconds uint32) (err error) { + if i.NfsSKeys != nil { err = errors.New("already initialized") return } - if i.nfsDKey, err = mlkem.NewDecapsulationKey768(nfsDKeySeed); err != nil { + l := len(nfsSKeysBytes) + if l == 0 { + err = errors.New("empty nfsSKeysBytes") return } - if xorMode > 0 { - i.xorMode = xorMode - if i.xorSKey, err = ecdh.X25519().NewPrivateKey(xorSKeyBytes); err != nil { - return + i.NfsSKeys = make([]any, l) + i.NfsPKeysBytes = make([][]byte, l) + i.Hash32s = make([][32]byte, l) + for j, k := range nfsSKeysBytes { + if len(k) == 32 { + if i.NfsSKeys[j], err = ecdh.X25519().NewPrivateKey(k); err != nil { + return + } + i.NfsPKeysBytes[j] = i.NfsSKeys[j].(*ecdh.PrivateKey).PublicKey().Bytes() + i.RelaysLength += 32 + 32 + } else { + if i.NfsSKeys[j], err = mlkem.NewDecapsulationKey768(k); err != nil { + return + } + i.NfsPKeysBytes[j] = i.NfsSKeys[j].(*mlkem.DecapsulationKey768).EncapsulationKey().Bytes() + i.RelaysLength += 1088 + 32 } - hash32 := sha3.Sum256(i.nfsDKey.EncapsulationKey().Bytes()) - copy(i.hash11[:], hash32[:]) + i.Hash32s[j] = blake3.Sum256(i.NfsPKeysBytes[j]) } - if minutes > 0 { - i.minutes = time.Duration(minutes) * time.Minute - i.sessions = make(map[[32]byte]*ServerSession) + i.RelaysLength -= 32 + i.XorMode = xorMode + if seconds > 0 { + i.Seconds = seconds + i.Sessions = make(map[[16]byte]*ServerSession) go func() { for { time.Sleep(time.Minute) - i.Lock() - if i.closed { - i.Unlock() + i.RWLock.Lock() + if i.Closed { + i.RWLock.Unlock() return } now := time.Now() - for ticket, session := range i.sessions { - if now.After(session.expire) { - delete(i.sessions, ticket) + for ticket, session := range i.Sessions { + if now.After(session.Expire) { + delete(i.Sessions, ticket) } } - i.Unlock() + i.RWLock.Unlock() } }() } @@ -89,223 +92,190 @@ func (i *ServerInstance) Init(nfsDKeySeed, xorSKeyBytes []byte, xorMode, minutes } func (i *ServerInstance) Close() (err error) { - i.Lock() - i.closed = true - i.Unlock() + i.RWLock.Lock() + i.Closed = true + i.RWLock.Unlock() return } -func (i *ServerInstance) Handshake(conn net.Conn) (*ServerConn, error) { - if i.nfsDKey == nil { +func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { + if i.NfsSKeys == nil { return nil, errors.New("uninitialized") } - if i.xorMode > 0 { - var err error - if conn, err = NewXorConn(conn, i.xorMode, nil, i.xorSKey); err != nil { - return nil, err - } - } - c := &ServerConn{Conn: conn} + c := &CommonConn{Conn: conn} - _, t, l, err := ReadAndDiscardPaddings(c.Conn, nil, nil) // allow paddings before client/ticket hello - if err != nil { + ivAndRelays := make([]byte, 16+i.RelaysLength) + if _, err := io.ReadFull(conn, ivAndRelays); err != nil { return nil, err } + iv := ivAndRelays[:16] + relays := ivAndRelays[16:] + var nfsPublicKey, nfsKey []byte + var lastCTR cipher.Stream + for j, k := range i.NfsSKeys { + if lastCTR != nil { + lastCTR.XORKeyStream(relays, relays[:32]) // recover this relay + } + var index = 32 + if _, ok := k.(*mlkem.DecapsulationKey768); ok { + index = 1088 + } + if i.XorMode > 0 { + NewCTR(i.NfsPKeysBytes[j], iv).XORKeyStream(relays, relays[:index]) // we don't use buggy elligator, because we have PSK :) + } + nfsPublicKey = relays[:index] + if k, ok := k.(*ecdh.PrivateKey); ok { + publicKey, err := ecdh.X25519().NewPublicKey(nfsPublicKey) + if err != nil { + return nil, err + } + nfsKey, err = k.ECDH(publicKey) + if err != nil { + return nil, err + } + } + if k, ok := k.(*mlkem.DecapsulationKey768); ok { + var err error + nfsKey, err = k.Decapsulate(nfsPublicKey) + if err != nil { + return nil, err + } + } + if j == len(i.NfsSKeys)-1 { + break + } + relays = relays[index:] + lastCTR = NewCTR(nfsKey, iv) + lastCTR.XORKeyStream(relays, relays[:32]) + if !bytes.Equal(relays[:32], i.Hash32s[j+1][:]) { + return nil, errors.New("unexpected hash32: ", fmt.Sprintf("%v", relays[:32])) + } + relays = relays[32:] + } + nfsGCM := NewGCM(nfsPublicKey, nfsKey) - if t == 0 { - if i.minutes == 0 { + encryptedLength := make([]byte, 18) + if _, err := io.ReadFull(conn, encryptedLength); err != nil { + return nil, err + } + if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { + return nil, err + } + length := DecodeLength(encryptedLength[:2]) + + if length == 32 { + if i.Seconds == 0 { return nil, errors.New("0-RTT is not allowed") } - peerTicketHello := make([]byte, 32+32) - if l != len(peerTicketHello) { - return nil, errors.New("unexpected length ", l, " for ticket hello") - } - if _, err := io.ReadFull(c.Conn, peerTicketHello); err != nil { + encryptedTicket := make([]byte, 32) + if _, err := io.ReadFull(conn, encryptedTicket); err != nil { return nil, err } - if !bytes.Equal(peerTicketHello[:11], i.hash11[:]) { - return nil, errors.New("unexpected hash11: ", fmt.Sprintf("%v", peerTicketHello[:11])) + ticket, err := nfsGCM.Open(nil, nil, encryptedTicket, nil) + if err != nil { + return nil, err } - i.RLock() - s := i.sessions[[32]byte(peerTicketHello)] - i.RUnlock() + i.RWLock.RLock() + s := i.Sessions[[16]byte(ticket)] + i.RWLock.RUnlock() if s == nil { noises := make([]byte, crypto.RandBetween(100, 1000)) var err error for err == nil { rand.Read(noises) - _, _, err = DecodeHeader(noises) + _, err = DecodeHeader(noises) } - c.Conn.Write(noises) // make client do new handshake + conn.Write(noises) // make client do new handshake return nil, errors.New("expired ticket") } - if _, replay := s.randoms.LoadOrStore([32]byte(peerTicketHello[32:]), true); replay { + if _, replay := s.Replays.LoadOrStore([32]byte(encryptedTicket), true); replay { return nil, errors.New("replay detected") } - c.cipher = s.cipher - c.baseKey = s.baseKey - c.ticket = peerTicketHello[:32] - c.peerRandom = peerTicketHello[32:] + c.UnitedKey = append(s.PfsKey, nfsKey...) // the same key links the upload & download + c.PreWrite = make([]byte, 32) // always trust yourself, not the client + rand.Read(c.PreWrite) + c.GCM = NewGCM(c.PreWrite, c.UnitedKey) + c.PeerGCM = NewGCM(encryptedTicket, c.UnitedKey) + if i.XorMode == 2 { + c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite[16:]), NewCTR(c.UnitedKey, iv), 32, 0) + } return c, nil } - peerClientHello := make([]byte, 11+1+1184+1088) - if l != len(peerClientHello) { - return nil, errors.New("unexpected length ", l, " for client hello") + if length < 1184+32+16 { // client may send more public keys + return nil, errors.New("too short length") } - if _, err := io.ReadFull(c.Conn, peerClientHello); err != nil { + encryptedPfsPublicKey := make([]byte, length) + if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil { return nil, err } - if !bytes.Equal(peerClientHello[:11], i.hash11[:]) { - return nil, errors.New("unexpected hash11: ", fmt.Sprintf("%v", peerClientHello[:11])) + if _, err := nfsGCM.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil { + return nil, err } - c.cipher = peerClientHello[11] - pfsEKeyBytes := peerClientHello[11+1 : 11+1+1184] - encapsulatedNfsKey := peerClientHello[11+1+1184:] - - pfsEKey, err := mlkem.NewEncapsulationKey768(pfsEKeyBytes) + mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184]) if err != nil { return nil, err } - nfsKey, err := i.nfsDKey.Decapsulate(encapsulatedNfsKey) + mlkem768Key, encapsulatedPfsKey := mlkem768EKey.Encapsulate() + peerX25519PKey, err := ecdh.X25519().NewPublicKey(encryptedPfsPublicKey[1184 : 1184+32]) if err != nil { return nil, err } - nfsAEAD := NewAEAD(c.cipher, nfsKey, pfsEKeyBytes, encapsulatedNfsKey) - nfsNonce := append([]byte{}, peerClientHello[:11+1]...) - pfsKey, encapsulatedPfsKey := pfsEKey.Encapsulate() - c.baseKey = append(pfsKey, nfsKey...) - pfsAEAD := NewAEAD(c.cipher, c.baseKey, encapsulatedPfsKey, encapsulatedNfsKey) - pfsNonce := append([]byte{}, peerClientHello[:11+1]...) - c.ticket = append(i.hash11[:], pfsAEAD.Seal(nil, pfsNonce, []byte("VLESS"), pfsEKeyBytes)...) - IncreaseNonce(pfsNonce) - - serverHello := make([]byte, 5+1088+21+crypto.RandBetween(100, 1000)) - EncodeHeader(serverHello, 1, 1088+21) - copy(serverHello[5:], encapsulatedPfsKey) - copy(serverHello[5+1088:], c.ticket[11:]) - padding := serverHello[5+1088+21:] - rand.Read(padding) // important - EncodeHeader(padding, 23, len(padding)-5) - pfsAEAD.Seal(padding[:5], pfsNonce, padding[5:len(padding)-16], padding[:5]) - - if _, err := c.Conn.Write(serverHello); err != nil { - return nil, err - } - // server can send more PFS AEAD paddings / messages if needed - - _, t, l, err = ReadAndDiscardPaddings(c.Conn, nfsAEAD, nfsNonce) // allow paddings before ticket hello + x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader) + x25519Key, err := x25519SKey.ECDH(peerX25519PKey) if err != nil { return nil, err } - if t != 0 { - return nil, errors.New("unexpected type ", t, ", expect ticket hello") - } - peerTicketHello := make([]byte, 32+32) - if l != len(peerTicketHello) { - return nil, errors.New("unexpected length ", l, " for ticket hello") - } - if _, err := io.ReadFull(c.Conn, peerTicketHello); err != nil { + pfsKey := append(mlkem768Key, x25519Key...) + pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...) + c.UnitedKey = append(pfsKey, nfsKey...) + c.GCM = NewGCM(pfsPublicKey, c.UnitedKey) + c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1184+32], c.UnitedKey) + ticket := make([]byte, 16) + rand.Read(ticket) + copy(ticket, EncodeLength(int(i.Seconds*4/5))) + + pfsKeyExchangeLength := 18 + 1088 + 32 + 16 + encryptedTicketLength := 32 + paddingLength := int(crypto.RandBetween(100, 1000)) + serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength) + nfsGCM.Seal(serverHello[:0], make([]byte, 12), EncodeLength(pfsKeyExchangeLength-18), nil) // it is safe because our nonce starts from 1 + nfsGCM.Seal(serverHello[:18], MaxNonce, pfsPublicKey, nil) + c.GCM.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil) + padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:] + c.GCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) + c.GCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) + + if _, err := conn.Write(serverHello); err != nil { return nil, err } - if !bytes.Equal(peerTicketHello[:32], c.ticket) { - return nil, errors.New("naughty boy") - } - c.peerRandom = peerTicketHello[32:] + // padding can be sent in a fragmented way, to create variable traffic pattern, before VLESS flow takes control - if i.minutes > 0 { - i.Lock() - s := &ServerSession{ - expire: time.Now().Add(i.minutes), - cipher: c.cipher, - baseKey: c.baseKey, + if i.Seconds > 0 { + i.RWLock.Lock() + i.Sessions[[16]byte(ticket)] = &ServerSession{ + Expire: time.Now().Add(time.Duration(i.Seconds) * time.Second), + PfsKey: pfsKey, } - s.randoms.Store([32]byte(c.peerRandom), true) - i.sessions[[32]byte(c.ticket)] = s - i.Unlock() + i.RWLock.Unlock() } + if _, err := io.ReadFull(conn, encryptedLength); err != nil { + return nil, err + } + if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { + return nil, err + } + encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2])) + if _, err := io.ReadFull(conn, encryptedPadding); err != nil { + return nil, err + } + if _, err := nfsGCM.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil { + return nil, err + } + + if i.XorMode == 2 { + c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, ticket), NewCTR(c.UnitedKey, iv), 0, 0) + } return c, nil } - -func (c *ServerConn) Read(b []byte) (int, error) { - if len(b) == 0 { - return 0, nil - } - if c.peerAEAD == nil { - c.peerAEAD = NewAEAD(c.cipher, c.baseKey, c.peerRandom, c.ticket) - c.peerNonce = make([]byte, 12) - } - if len(c.PeerCache) != 0 { - n := copy(b, c.PeerCache) - c.PeerCache = c.PeerCache[n:] - return n, nil - } - h, t, l, err := ReadAndDecodeHeader(c.Conn) // l: 17~17000 - if err != nil { - return 0, err - } - if t != 23 { - return 0, errors.New("unexpected type ", t, ", expect encrypted data") - } - peerData := make([]byte, l) - if _, err := io.ReadFull(c.Conn, peerData); err != nil { - return 0, err - } - dst := peerData[:l-16] - if len(dst) <= len(b) { - dst = b[:len(dst)] // avoids another copy() - } - var peerAEAD cipher.AEAD - if bytes.Equal(c.peerNonce, MaxNonce) { - peerAEAD = NewAEAD(c.cipher, c.baseKey, peerData, h) - } - _, err = c.peerAEAD.Open(dst[:0], c.peerNonce, peerData, h) - if peerAEAD != nil { - c.peerAEAD = peerAEAD - } - IncreaseNonce(c.peerNonce) - if err != nil { - return 0, err - } - if len(dst) > len(b) { - c.PeerCache = dst[copy(b, dst):] - dst = b // for len(dst) - } - return len(dst), nil -} - -func (c *ServerConn) Write(b []byte) (int, error) { - if len(b) == 0 { - return 0, nil - } - var data []byte - for n := 0; n < len(b); { - b := b[n:] - if len(b) > 8192 { - b = b[:8192] // for avoiding another copy() in client's Read() - } - n += len(b) - if c.aead == nil { - data = make([]byte, 5+32+5+len(b)+16) - EncodeHeader(data, 0, 32) - rand.Read(data[5 : 5+32]) - EncodeHeader(data[5+32:], 23, len(b)+16) - c.aead = NewAEAD(c.cipher, c.baseKey, data[5:5+32], c.peerRandom) - c.nonce = make([]byte, 12) - c.aead.Seal(data[:5+32+5], c.nonce, b, data[5+32:5+32+5]) - } else { - data = make([]byte, 5+len(b)+16) - EncodeHeader(data, 23, len(b)+16) - c.aead.Seal(data[:5], c.nonce, b, data[:5]) - if bytes.Equal(c.nonce, MaxNonce) { - c.aead = NewAEAD(c.cipher, c.baseKey, data[5:], data[:5]) - } - } - IncreaseNonce(c.nonce) - if _, err := c.Conn.Write(data); err != nil { - return 0, err - } - } - return len(b), nil -} diff --git a/proxy/vless/encryption/xor.go b/proxy/vless/encryption/xor.go index c5586ae9..e435cb5c 100644 --- a/proxy/vless/encryption/xor.go +++ b/proxy/vless/encryption/xor.go @@ -3,135 +3,61 @@ package encryption import ( "crypto/aes" "crypto/cipher" - "crypto/ecdh" - "crypto/hkdf" - "crypto/rand" - "crypto/sha3" - "io" "net" - "github.com/xtls/xray-core/common/errors" + "lukechampine.com/blake3" ) -type XorConn struct { - net.Conn - Divide bool - - head []byte - key []byte - ctr cipher.Stream - peerCtr cipher.Stream - isHeader bool - skipNext bool - - out_after0 bool - out_header []byte - out_skip int - - in_after0 bool - in_header []byte - in_skip int -} - -func NewCTR(key, iv []byte, isServer bool) cipher.Stream { - info := "CLIENT" - if isServer { - info = "SERVER" // avoids attackers sending traffic back to the client, though the encryption layer has its own protection - } - key, _ = hkdf.Key(sha3.New256, key, iv, info, 32) // avoids using pKey directly if attackers sent the basepoint, or whaterver they like - block, _ := aes.NewCipher(key) +func NewCTR(key, iv []byte) cipher.Stream { + k := make([]byte, 32) + blake3.DeriveKey(k, "VLESS", key) // avoids using key directly + block, _ := aes.NewCipher(k) return cipher.NewCTR(block, iv) -} - -func NewXorConn(conn net.Conn, mode uint32, pKey *ecdh.PublicKey, sKey *ecdh.PrivateKey) (*XorConn, error) { - if mode == 0 || (pKey == nil && sKey == nil) || (pKey != nil && sKey != nil) { - return nil, errors.New("invalid parameters") - } - c := &XorConn{ - Conn: conn, - Divide: mode == 1, - isHeader: true, - out_header: make([]byte, 0, 5), // important - in_header: make([]byte, 0, 5), // important - } - if pKey != nil { - c.head = make([]byte, 16+32) - rand.Read(c.head) - eSKey, _ := ecdh.X25519().NewPrivateKey(c.head[16:]) - NewCTR(pKey.Bytes(), c.head[:16], false).XORKeyStream(c.head[16:], eSKey.PublicKey().Bytes()) // make X25519 public key distinguishable from random bytes - c.key, _ = eSKey.ECDH(pKey) - c.ctr = NewCTR(c.key, c.head[:16], false) - } - if sKey != nil { - peerHead := make([]byte, 16+32) - if _, err := io.ReadFull(c.Conn, peerHead); err != nil { - return nil, err - } - NewCTR(sKey.PublicKey().Bytes(), peerHead[:16], false).XORKeyStream(peerHead[16:], peerHead[16:]) // we don't use buggy elligator, because we have PSK :) - ePKey, err := ecdh.X25519().NewPublicKey(peerHead[16:]) - if err != nil { - return nil, err - } - key, err := sKey.ECDH(ePKey) - if err != nil { - return nil, err - } - c.peerCtr = NewCTR(key, peerHead[:16], false) - c.head = make([]byte, 16) - rand.Read(c.head) // make sure the server always replies random bytes even when received replays, though it is not important - c.ctr = NewCTR(key, c.head, true) // the same key links the upload & download, though the encryption layer has its own link - } - return c, nil //chacha20.NewUnauthenticatedCipher() } -func (c *XorConn) Write(b []byte) (int, error) { // whole one/two records +type XorConn struct { + net.Conn + CTR cipher.Stream + PeerCTR cipher.Stream + OutSkip int + OutHeader []byte + InSkip int + InHeader []byte +} + +func NewXorConn(conn net.Conn, ctr, peerCTR cipher.Stream, outSkip, inSkip int) *XorConn { + return &XorConn{ + Conn: conn, + CTR: ctr, + PeerCTR: peerCTR, + OutSkip: outSkip, + OutHeader: make([]byte, 0, 5), // important + InSkip: inSkip, + InHeader: make([]byte, 0, 5), // important + } +} + +func (c *XorConn) Write(b []byte) (int, error) { if len(b) == 0 { return 0, nil } - if !c.out_after0 { - t, l, _ := DecodeHeader(b) - if t == 23 { // single 23 - l = 5 - } else { // 1/0 + 23, or noises only - l += 10 - if t == 0 { - c.out_after0 = true - if c.Divide { - l -= 5 - } - } - } - c.ctr.XORKeyStream(b[:l], b[:l]) // caller MUST discard b - l = len(b) - if c.head != nil { - b = append(c.head, b...) - c.head = nil - } - if _, err := c.Conn.Write(b); err != nil { - return 0, err - } - return l, nil - } - if c.Divide { - return c.Conn.Write(b) - } - for p := b; ; { // for XTLS - if len(p) <= c.out_skip { - c.out_skip -= len(p) + for p := b; ; { + if len(p) <= c.OutSkip { + c.OutSkip -= len(p) break } - p = p[c.out_skip:] - c.out_skip = 0 - need := 5 - len(c.out_header) + p = p[c.OutSkip:] + c.OutSkip = 0 + need := 5 - len(c.OutHeader) if len(p) < need { - c.out_header = append(c.out_header, p...) - c.ctr.XORKeyStream(p, p) + c.OutHeader = append(c.OutHeader, p...) + c.CTR.XORKeyStream(p, p) break } - _, c.out_skip, _ = DecodeHeader(append(c.out_header, p[:need]...)) - c.out_header = c.out_header[:0] - c.ctr.XORKeyStream(p[:need], p[:need]) + c.OutSkip, _ = DecodeHeader(append(c.OutHeader, p[:need]...)) + c.OutHeader = c.OutHeader[:0] + c.CTR.XORKeyStream(p[:need], p[:need]) p = p[need:] } if _, err := c.Conn.Write(b); err != nil { @@ -140,60 +66,27 @@ func (c *XorConn) Write(b []byte) (int, error) { // whole one/two records return len(b), nil } -func (c *XorConn) Read(b []byte) (int, error) { // 5-bytes, data, 5-bytes... +func (c *XorConn) Read(b []byte) (int, error) { if len(b) == 0 { return 0, nil } - if !c.in_after0 || !c.isHeader { - if c.peerCtr == nil { // for client - peerIv := make([]byte, 16) - if _, err := io.ReadFull(c.Conn, peerIv); err != nil { - return 0, err - } - c.peerCtr = NewCTR(c.key, peerIv, true) - } - if _, err := io.ReadFull(c.Conn, b); err != nil { - return 0, err - } - if c.skipNext { - c.skipNext = false - return len(b), nil - } - c.peerCtr.XORKeyStream(b, b) - if c.isHeader { // always 5-bytes - if t, _, _ := DecodeHeader(b); t == 23 { - c.skipNext = true - } else { - c.isHeader = false - if t == 0 { - c.in_after0 = true - } - } - } else { - c.isHeader = true - } - return len(b), nil - } - if c.Divide { - return c.Conn.Read(b) - } n, err := c.Conn.Read(b) - for p := b[:n]; ; { // for XTLS - if len(p) <= c.in_skip { - c.in_skip -= len(p) + for p := b[:n]; ; { + if len(p) <= c.InSkip { + c.InSkip -= len(p) break } - p = p[c.in_skip:] - c.in_skip = 0 - need := 5 - len(c.in_header) + p = p[c.InSkip:] + c.InSkip = 0 + need := 5 - len(c.InHeader) if len(p) < need { - c.peerCtr.XORKeyStream(p, p) - c.in_header = append(c.in_header, p...) + c.PeerCTR.XORKeyStream(p, p) + c.InHeader = append(c.InHeader, p...) break } - c.peerCtr.XORKeyStream(p[:need], p[:need]) - _, c.in_skip, _ = DecodeHeader(append(c.in_header, p[:need]...)) - c.in_header = c.in_header[:0] + c.PeerCTR.XORKeyStream(p[:need], p[:need]) + c.InSkip, _ = DecodeHeader(append(c.InHeader, p[:need]...)) + c.InHeader = c.InHeader[:0] p = p[need:] } return n, err diff --git a/proxy/vless/inbound/config.pb.go b/proxy/vless/inbound/config.pb.go index 240c25d9..e3192cf8 100644 --- a/proxy/vless/inbound/config.pb.go +++ b/proxy/vless/inbound/config.pb.go @@ -115,7 +115,7 @@ type Config struct { Fallbacks []*Fallback `protobuf:"bytes,2,rep,name=fallbacks,proto3" json:"fallbacks,omitempty"` Decryption string `protobuf:"bytes,3,opt,name=decryption,proto3" json:"decryption,omitempty"` XorMode uint32 `protobuf:"varint,4,opt,name=xorMode,proto3" json:"xorMode,omitempty"` - Minutes uint32 `protobuf:"varint,5,opt,name=minutes,proto3" json:"minutes,omitempty"` + Seconds uint32 `protobuf:"varint,5,opt,name=seconds,proto3" json:"seconds,omitempty"` } func (x *Config) Reset() { @@ -176,9 +176,9 @@ func (x *Config) GetXorMode() uint32 { return 0 } -func (x *Config) GetMinutes() uint32 { +func (x *Config) GetSeconds() uint32 { if x != nil { - return x.Minutes + return x.Seconds } return 0 } @@ -211,9 +211,9 @@ var file_proxy_vless_inbound_config_proto_rawDesc = []byte{ 0x12, 0x1e, 0x0a, 0x0a, 0x64, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x64, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x18, 0x0a, 0x07, 0x78, 0x6f, 0x72, 0x4d, 0x6f, 0x64, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x0d, 0x52, 0x07, 0x78, 0x6f, 0x72, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x69, - 0x6e, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, 0x6d, 0x69, 0x6e, - 0x75, 0x74, 0x65, 0x73, 0x42, 0x6a, 0x0a, 0x1c, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, + 0x0d, 0x52, 0x07, 0x78, 0x6f, 0x72, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, + 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, 0x73, 0x65, 0x63, + 0x6f, 0x6e, 0x64, 0x73, 0x42, 0x6a, 0x0a, 0x1c, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x76, 0x6c, 0x65, 0x73, 0x73, 0x2e, 0x69, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x50, 0x01, 0x5a, 0x2d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, diff --git a/proxy/vless/inbound/config.proto b/proxy/vless/inbound/config.proto index 186d8588..e1ebc8d3 100644 --- a/proxy/vless/inbound/config.proto +++ b/proxy/vless/inbound/config.proto @@ -23,5 +23,5 @@ message Config { string decryption = 3; uint32 xorMode = 4; - uint32 minutes = 5; + uint32 seconds = 5; } diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index fc8dd243..2a25dc5c 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -84,12 +84,16 @@ func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Val validator: validator, } - if s := strings.Split(config.Decryption, "."); len(s) == 2 { - nfsDKeySeed, _ := base64.RawURLEncoding.DecodeString(s[0]) - xorSKeyBytes, _ := base64.RawURLEncoding.DecodeString(s[1]) + if config.Decryption != "none" { + s := strings.Split(config.Decryption, ".") + var nfsSKeysBytes [][]byte + for _, r := range s { + b, _ := base64.RawURLEncoding.DecodeString(r) + nfsSKeysBytes = append(nfsSKeysBytes, b) + } handler.decryption = &encryption.ServerInstance{} - if err := handler.decryption.Init(nfsDKeySeed, xorSKeyBytes, config.XorMode, config.Minutes); err != nil { - return nil, errors.New("failed to use mlkem768seed").Base(err).AtError() + if err := handler.decryption.Init(nfsSKeysBytes, config.XorMode, config.Seconds); err != nil { + return nil, errors.New("failed to use decryption").Base(err).AtError() } } @@ -498,9 +502,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s case protocol.RequestCommandMux: fallthrough // we will break Mux connections that contain TCP requests case protocol.RequestCommandTCP: - if serverConn, ok := connection.(*encryption.ServerConn); ok { + if serverConn, ok := connection.(*encryption.CommonConn); ok { peerCache = &serverConn.PeerCache - if xorConn, ok := serverConn.Conn.(*encryption.XorConn); (ok && !xorConn.Divide) || !proxy.IsRAWTransport(iConn) { + if _, ok := serverConn.Conn.(*encryption.XorConn); ok || !proxy.IsRAWTransport(iConn) { inbound.CanSpliceCopy = 3 // full-random xorConn / non-RAW transport can not use Linux Splice } break diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index 51974825..4611750e 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -69,12 +69,16 @@ func New(ctx context.Context, config *Config) (*Handler, error) { } a := handler.serverPicker.PickServer().PickUser().Account.(*vless.MemoryAccount) - if s := strings.Split(a.Encryption, "."); len(s) == 2 { - nfsEKeyBytes, _ := base64.RawURLEncoding.DecodeString(s[0]) - xorPKeyBytes, _ := base64.RawURLEncoding.DecodeString(s[1]) + if a.Encryption != "none" { + s := strings.Split(a.Encryption, ".") + var nfsPKeysBytes [][]byte + for _, r := range s { + b, _ := base64.RawURLEncoding.DecodeString(r) + nfsPKeysBytes = append(nfsPKeysBytes, b) + } handler.encryption = &encryption.ClientInstance{} - if err := handler.encryption.Init(nfsEKeyBytes, xorPKeyBytes, a.XorMode, a.Minutes); err != nil { - return nil, errors.New("failed to use mlkem768client").Base(err).AtError() + if err := handler.encryption.Init(nfsPKeysBytes, a.XorMode, a.Seconds); err != nil { + return nil, errors.New("failed to use encryption").Base(err).AtError() } } @@ -161,9 +165,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte case protocol.RequestCommandMux: fallthrough // let server break Mux connections that contain TCP requests case protocol.RequestCommandTCP: - if clientConn, ok := conn.(*encryption.ClientConn); ok { + if clientConn, ok := conn.(*encryption.CommonConn); ok { peerCache = &clientConn.PeerCache - if xorConn, ok := clientConn.Conn.(*encryption.XorConn); (ok && !xorConn.Divide) || !proxy.IsRAWTransport(iConn) { + if _, ok := clientConn.Conn.(*encryption.XorConn); ok || !proxy.IsRAWTransport(iConn) { ob.CanSpliceCopy = 3 // full-random xorConn / non-RAW transport can not use Linux Splice } break