mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-08-26 00:15:31 +00:00
203 lines
4.2 KiB
Go
203 lines
4.2 KiB
Go
package encryption
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/xtls/xray-core/common/errors"
|
|
"lukechampine.com/blake3"
|
|
)
|
|
|
|
type CommonConn struct {
|
|
net.Conn
|
|
Client *ClientInstance
|
|
UnitedKey []byte
|
|
PreWrite []byte
|
|
GCM *GCM
|
|
PeerGCM *GCM
|
|
PeerCache []byte
|
|
}
|
|
|
|
func (c *CommonConn) Write(b []byte) (int, error) {
|
|
if len(b) == 0 {
|
|
return 0, nil
|
|
}
|
|
var data []byte
|
|
for n := 0; n < len(b); {
|
|
b := b[n:]
|
|
if len(b) > 8192 {
|
|
b = b[:8192] // for avoiding another copy() in peer's Read()
|
|
}
|
|
n += len(b)
|
|
data = make([]byte, 5+len(b)+16)
|
|
EncodeHeader(data, len(b)+16)
|
|
aead := c.GCM
|
|
if bytes.Equal(c.GCM.Nonce[:], MaxNonce) {
|
|
aead = nil
|
|
}
|
|
c.GCM.Seal(data[:5], nil, b, data[:5])
|
|
if aead == nil {
|
|
c.GCM = NewGCM(data[5:], c.UnitedKey)
|
|
}
|
|
if c.PreWrite != nil {
|
|
data = append(c.PreWrite, data...)
|
|
c.PreWrite = nil
|
|
}
|
|
if _, err := c.Conn.Write(data); err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
return len(b), nil
|
|
}
|
|
|
|
func (c *CommonConn) Read(b []byte) (int, error) {
|
|
if len(b) == 0 {
|
|
return 0, nil
|
|
}
|
|
if c.PeerGCM == nil { // client's 0-RTT
|
|
serverRandom := make([]byte, 32)
|
|
if _, err := io.ReadFull(c.Conn, serverRandom); err != nil {
|
|
return 0, err
|
|
}
|
|
c.PeerGCM = NewGCM(serverRandom, c.UnitedKey)
|
|
if xorConn, ok := c.Conn.(*XorConn); ok {
|
|
xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom[16:])
|
|
}
|
|
}
|
|
if len(c.PeerCache) != 0 {
|
|
n := copy(b, c.PeerCache)
|
|
c.PeerCache = c.PeerCache[n:]
|
|
return n, nil
|
|
}
|
|
h, l, err := ReadAndDecodeHeader(c.Conn) // l: 17~17000
|
|
if err != nil {
|
|
if c.Client != nil && strings.HasPrefix(err.Error(), "invalid header: ") { // client's 0-RTT
|
|
c.Client.RWLock.Lock()
|
|
if bytes.Equal(c.UnitedKey[:32], c.Client.PfsKey) {
|
|
c.Client.Expire = time.Now() // expired
|
|
}
|
|
c.Client.RWLock.Unlock()
|
|
return 0, errors.New("new handshake needed")
|
|
}
|
|
return 0, err
|
|
}
|
|
c.Client = nil
|
|
peerData := make([]byte, l)
|
|
if _, err := io.ReadFull(c.Conn, peerData); err != nil {
|
|
return 0, err
|
|
}
|
|
dst := peerData[:l-16]
|
|
if len(dst) <= len(b) {
|
|
dst = b[:len(dst)] // avoids another copy()
|
|
}
|
|
var peerAEAD *GCM
|
|
if bytes.Equal(c.PeerGCM.Nonce[:], MaxNonce) {
|
|
peerAEAD = NewGCM(peerData, c.UnitedKey)
|
|
}
|
|
_, err = c.PeerGCM.Open(dst[:0], nil, peerData, h)
|
|
if peerAEAD != nil {
|
|
c.PeerGCM = peerAEAD
|
|
}
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if len(dst) > len(b) {
|
|
c.PeerCache = dst[copy(b, dst):]
|
|
dst = b // for len(dst)
|
|
}
|
|
return len(dst), nil
|
|
}
|
|
|
|
// GCM couples a cipher.AEAD with a 12-byte nonce counter that Seal and Open
// advance themselves when the caller passes a nil nonce.
type GCM struct {
	cipher.AEAD

	// Nonce is the current counter value; it starts at zero.
	Nonce [12]byte
}
|
|
|
|
func NewGCM(ctx, key []byte) *GCM {
|
|
k := make([]byte, 32)
|
|
blake3.DeriveKey(k, string(ctx), key)
|
|
block, _ := aes.NewCipher(k)
|
|
aead, _ := cipher.NewGCM(block)
|
|
return &GCM{AEAD: aead}
|
|
//chacha20poly1305.New()
|
|
}
|
|
|
|
func (a *GCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
|
|
if nonce == nil {
|
|
nonce = IncreaseNonce(a.Nonce[:])
|
|
}
|
|
return a.AEAD.Seal(dst, nonce, plaintext, additionalData)
|
|
}
|
|
|
|
func (a *GCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
|
|
if nonce == nil {
|
|
nonce = IncreaseNonce(a.Nonce[:])
|
|
}
|
|
return a.AEAD.Open(dst, nonce, ciphertext, additionalData)
|
|
}
|
|
|
|
// IncreaseNonce treats the first 12 bytes of nonce as a big-endian counter,
// adds one to it in place (wrapping from all 0xff to all zero), and returns
// the same slice. It panics if nonce is shorter than 12 bytes.
func IncreaseNonce(nonce []byte) []byte {
	for pos := 11; pos >= 0; pos-- {
		nonce[pos]++
		if nonce[pos] != 0 {
			// No carry out of this byte, so the higher bytes stay untouched.
			break
		}
	}
	return nonce
}
|
|
|
|
// MaxNonce is the 12-byte all-0xff counter value. Read and Write compare the
// current nonce against it to decide when to derive a fresh cipher.
var MaxNonce = bytes.Repeat([]byte{0xff}, 12)
|
|
|
|
// EncodeLength returns the low 16 bits of l as two big-endian bytes.
func EncodeLength(l int) []byte {
	high, low := byte(l>>8), byte(l)
	return []byte{high, low}
}
|
|
|
|
// DecodeLength reads the first two bytes of b as a big-endian unsigned 16-bit
// value. It panics if b holds fewer than two bytes.
func DecodeLength(b []byte) int {
	high := int(b[0])
	low := int(b[1])
	return high<<8 | low
}
|
|
|
|
// EncodeHeader fills the first five bytes of h with a record header: the fixed
// prefix 23, 3, 3 followed by the low 16 bits of l in big-endian order.
// h must be at least five bytes long.
//
// NOTE(review): the prefix looks like a TLS application-data record header
// (type 23, version 3.3) - confirm against the protocol description.
func EncodeHeader(h []byte, l int) {
	// Fixed prefix.
	h[0] = 23
	h[1] = 3
	h[2] = 3

	// Big-endian payload length.
	high, low := byte(l>>8), byte(l)
	h[3] = high
	h[4] = low
}
|
|
|
|
func DecodeHeader(h []byte) (l int, err error) {
|
|
l = int(h[3])<<8 | int(h[4])
|
|
if h[0] != 23 || h[1] != 3 || h[2] != 3 {
|
|
l = 0
|
|
}
|
|
if l < 17 || l > 17000 { // TODO: TLSv1.3 max length
|
|
err = errors.New("invalid header: ", fmt.Sprintf("%v", h[:5])) // DO NOT CHANGE: relied by client's Read()
|
|
}
|
|
return
|
|
}
|
|
|
|
func ReadAndDecodeHeader(conn net.Conn) (h []byte, l int, err error) {
|
|
h = make([]byte, 5)
|
|
if _, err = io.ReadFull(conn, h); err != nil {
|
|
return
|
|
}
|
|
l, err = DecodeHeader(h)
|
|
return
|
|
}
|
|
|
|
func ReadAndDiscardPaddings(conn net.Conn) (h []byte, l int, err error) {
|
|
for {
|
|
if h, l, err = ReadAndDecodeHeader(conn); err != nil {
|
|
return
|
|
}
|
|
if _, err = io.ReadFull(conn, make([]byte, l)); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|