diff --git a/proxy/vless/encryption/client.go b/proxy/vless/encryption/client.go index 4f86b6a2..a0ae0d2b 100644 --- a/proxy/vless/encryption/client.go +++ b/proxy/vless/encryption/client.go @@ -59,13 +59,13 @@ func (i *ClientInstance) Init(nfsEKeyBytes, xorPKeyBytes []byte, xorMode, minute if i.nfsEKey, err = mlkem.NewEncapsulationKey768(nfsEKeyBytes); err != nil { return } - hash32 := sha3.Sum256(nfsEKeyBytes) - copy(i.hash11[:], hash32[:]) if xorMode > 0 { i.xorMode = xorMode if i.xorPKey, err = ecdh.X25519().NewPublicKey(xorPKeyBytes); err != nil { return } + hash32 := sha3.Sum256(nfsEKeyBytes) + copy(i.hash11[:], hash32[:]) } i.minutes = time.Duration(minutes) * time.Minute return @@ -115,7 +115,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*ClientConn, error) { } // client can send more NFS AEAD paddings / messages if needed - _, t, l, err := ReadAndDiscardPaddings(c.Conn) // allow paddings before server hello + _, t, l, err := ReadAndDiscardPaddings(c.Conn, nil, nil) // allow paddings before server hello if err != nil { return nil, err } @@ -198,9 +198,9 @@ func (c *ClientConn) Read(b []byte) (int, error) { return 0, nil } if c.peerAEAD == nil { - _, t, l, err := ReadAndDiscardPaddings(c.Conn) // allow paddings before random hello + _, t, l, err := ReadAndDiscardPaddings(c.Conn, nil, nil) // allow paddings before random hello if err != nil { - if c.instance != nil && strings.HasPrefix(err.Error(), "invalid header: ") { // 0-RTT's 0-RTT + if c.instance != nil && strings.HasPrefix(err.Error(), "invalid header: ") { // 0-RTT c.instance.Lock() if bytes.Equal(c.ticket, c.instance.ticket) { c.instance.expire = time.Now() // expired diff --git a/proxy/vless/encryption/common.go b/proxy/vless/encryption/common.go index 6de517e0..4e2d4756 100644 --- a/proxy/vless/encryption/common.go +++ b/proxy/vless/encryption/common.go @@ -61,14 +61,21 @@ func ReadAndDecodeHeader(conn net.Conn) (h []byte, t byte, l int, err error) { return } -func ReadAndDiscardPaddings(conn net.Conn) (h []byte, t byte, l int, err error) { +func ReadAndDiscardPaddings(conn net.Conn, aead cipher.AEAD, nonce []byte) (h []byte, t byte, l int, err error) { for { if h, t, l, err = ReadAndDecodeHeader(conn); err != nil || t != 23 { return } - if _, err = io.ReadFull(conn, make([]byte, l)); err != nil { + padding := make([]byte, l) + if _, err = io.ReadFull(conn, padding); err != nil { return } + if aead != nil { + if _, err := aead.Open(nil, nonce, padding, h); err != nil { + return h, t, l, err + } + IncreaseNonce(nonce) + } } } diff --git a/proxy/vless/encryption/server.go b/proxy/vless/encryption/server.go index 8d91f415..336571b4 100644 --- a/proxy/vless/encryption/server.go +++ b/proxy/vless/encryption/server.go @@ -56,13 +56,13 @@ func (i *ServerInstance) Init(nfsDKeySeed, xorSKeyBytes []byte, xorMode, minutes if i.nfsDKey, err = mlkem.NewDecapsulationKey768(nfsDKeySeed); err != nil { return } - hash32 := sha3.Sum256(i.nfsDKey.EncapsulationKey().Bytes()) - copy(i.hash11[:], hash32[:]) if xorMode > 0 { i.xorMode = xorMode if i.xorSKey, err = ecdh.X25519().NewPrivateKey(xorSKeyBytes); err != nil { return } + hash32 := sha3.Sum256(i.nfsDKey.EncapsulationKey().Bytes()) + copy(i.hash11[:], hash32[:]) } if minutes > 0 { i.minutes = time.Duration(minutes) * time.Minute @@ -107,7 +107,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*ServerConn, error) { } c := &ServerConn{Conn: conn} - _, t, l, err := ReadAndDiscardPaddings(c.Conn) // allow paddings before client/ticket hello + _, t, l, err := ReadAndDiscardPaddings(c.Conn, nil, nil) // allow paddings before client/ticket hello if err != nil { return nil, err } @@ -171,11 +171,14 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*ServerConn, error) { if err != nil { return nil, err } + nfsAEAD := NewAEAD(c.cipher, nfsKey, pfsEKeyBytes, encapsulatedNfsKey) + nfsNonce := append([]byte{}, peerClientHello[:11+1]...) pfsKey, encapsulatedPfsKey := pfsEKey.Encapsulate() c.baseKey = append(pfsKey, nfsKey...) pfsAEAD := NewAEAD(c.cipher, c.baseKey, encapsulatedPfsKey, encapsulatedNfsKey) - c.ticket = append(i.hash11[:], pfsAEAD.Seal(nil, peerClientHello[:11+1], []byte("VLESS"), pfsEKeyBytes)...) - IncreaseNonce(peerClientHello[:11+1]) + pfsNonce := append([]byte{}, peerClientHello[:11+1]...) + c.ticket = append(i.hash11[:], pfsAEAD.Seal(nil, pfsNonce, []byte("VLESS"), pfsEKeyBytes)...) + IncreaseNonce(pfsNonce) serverHello := make([]byte, 5+1088+21+crypto.RandBetween(100, 1000)) EncodeHeader(serverHello, 1, 1088+21) @@ -184,20 +187,41 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*ServerConn, error) { padding := serverHello[5+1088+21:] rand.Read(padding) // important EncodeHeader(padding, 23, len(padding)-5) - pfsAEAD.Seal(padding[:5], peerClientHello[:11+1], padding[5:len(padding)-16], padding[:5]) + pfsAEAD.Seal(padding[:5], pfsNonce, padding[5:len(padding)-16], padding[:5]) if _, err := c.Conn.Write(serverHello); err != nil { return nil, err } // server can send more PFS AEAD paddings / messages if needed + _, t, l, err = ReadAndDiscardPaddings(c.Conn, nfsAEAD, nfsNonce) // allow paddings before ticket hello + if err != nil { + return nil, err + } + if t != 0 { + return nil, errors.New("unexpected type ", t, ", expect ticket hello") + } + peerTicketHello := make([]byte, 32+32) + if l != len(peerTicketHello) { + return nil, errors.New("unexpected length ", l, " for ticket hello") + } + if _, err := io.ReadFull(c.Conn, peerTicketHello); err != nil { + return nil, err + } + if !bytes.Equal(peerTicketHello[:32], c.ticket) { + return nil, errors.New("naughty boy") + } + c.peerRandom = peerTicketHello[32:] + if i.minutes > 0 { i.Lock() - i.sessions[[32]byte(c.ticket)] = &ServerSession{ + s := &ServerSession{ expire: time.Now().Add(i.minutes), cipher: c.cipher, baseKey: c.baseKey, } + s.randoms.Store([32]byte(c.peerRandom), true) + i.sessions[[32]byte(c.ticket)] = s i.Unlock() } @@ -209,26 +233,6 @@ func (c *ServerConn) Read(b []byte) (int, error) { return 0, nil } if c.peerAEAD == nil { - if c.peerRandom == nil { // 1-RTT's 0-RTT - _, t, l, err := ReadAndDiscardPaddings(c.Conn) // allow paddings before ticket hello - if err != nil { - return 0, err - } - if t != 0 { - return 0, errors.New("unexpected type ", t, ", expect ticket hello") - } - peerTicketHello := make([]byte, 32+32) - if l != len(peerTicketHello) { - return 0, errors.New("unexpected length ", l, " for ticket hello") - } - if _, err := io.ReadFull(c.Conn, peerTicketHello); err != nil { - return 0, err - } - if !bytes.Equal(peerTicketHello[:32], c.ticket) { - return 0, errors.New("naughty boy") - } - c.peerRandom = peerTicketHello[32:] - } c.peerAEAD = NewAEAD(c.cipher, c.baseKey, c.peerRandom, c.ticket) c.peerNonce = make([]byte, 12) } @@ -283,9 +287,6 @@ func (c *ServerConn) Write(b []byte) (int, error) { } n += len(b) if c.aead == nil { - if c.peerRandom == nil { - return 0, errors.New("empty c.peerRandom") - } data = make([]byte, 5+32+5+len(b)+16) EncodeHeader(data, 0, 32) rand.Read(data[5 : 5+32])