0
0
mirror of https://github.com/XTLS/REALITY.git synced 2025-08-22 14:38:35 +00:00

REALITY is REALITY now

Thank @yuhan6665 for testing
This commit is contained in:
RPRX 2023-02-09 11:59:09 +08:00 committed by GitHub
parent fb7fc93023
commit 5e6719eaf3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 501 additions and 23 deletions

View File

@ -3,4 +3,6 @@
### THE NEXT FUTURE ### THE NEXT FUTURE
Server side implementation of REALITY protocol, a fork of package tls in Go 1.19.5. Server side implementation of REALITY protocol, a fork of package tls in Go 1.19.5.
For client side, please follow https://github.com/XTLS/Xray-core. For client side, please follow https://github.com/XTLS/Xray-core/blob/main/transport/internet/reality/reality.go.
TODO List: TODO

View File

@ -515,6 +515,18 @@ const (
// modified. A Config may be reused; the tls package will also not // modified. A Config may be reused; the tls package will also not
// modify it. // modify it.
type Config struct { type Config struct {
Show bool
Type string
Dest string
Xver byte
ServerNames map[string]bool
PrivateKey []byte
MinClientVer []byte
MaxClientVer []byte
MaxTimeDiff time.Duration
ShortIds map[[8]byte]bool
// Rand provides the source of entropy for nonces and RSA blinding. // Rand provides the source of entropy for nonces and RSA blinding.
// If Rand is nil, TLS uses the cryptographic random reader in package // If Rand is nil, TLS uses the cryptographic random reader in package
// crypto/rand. // crypto/rand.

47
conn.go
View File

@ -25,6 +25,11 @@ import (
// A Conn represents a secured connection. // A Conn represents a secured connection.
// It implements the net.Conn interface. // It implements the net.Conn interface.
type Conn struct { type Conn struct {
AuthKey []byte
ClientVer [3]byte
ClientTime time.Time
ClientShortId [8]byte
// constant // constant
conn net.Conn conn net.Conn
isClient bool isClient bool
@ -162,6 +167,9 @@ func (c *Conn) NetConn() net.Conn {
// A halfConn represents one direction of the record layer // A halfConn represents one direction of the record layer
// connection, either sending or receiving. // connection, either sending or receiving.
type halfConn struct { type halfConn struct {
handshakeLen [7]int
handshakeBuf []byte
sync.Mutex sync.Mutex
err error // first permanent error err error // first permanent error
@ -514,9 +522,36 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
// Encrypt the actual ContentType and replace the plaintext one. // Encrypt the actual ContentType and replace the plaintext one.
record = append(record, record[0]) record = append(record, record[0])
padding := 0
if recordType(record[0]) == recordTypeHandshake && hc.handshakeLen[1] != 0 {
switch payload[0] {
case typeEncryptedExtensions:
padding = hc.handshakeLen[2]
hc.handshakeLen[2] = 0
case typeCertificate:
padding = hc.handshakeLen[3]
hc.handshakeLen[3] = 0
case typeCertificateVerify:
padding = hc.handshakeLen[4]
hc.handshakeLen[4] = 0
case typeFinished:
padding = hc.handshakeLen[5]
hc.handshakeLen[5] = 0
case typeNewSessionTicket:
padding = hc.handshakeLen[6]
hc.handshakeLen[6] = 0
record[5] = byte(recordTypeApplicationData)
record[6] = 0
}
padding -= len(record) + c.Overhead()
if padding < 0 {
return nil, fmt.Errorf("payload[0]: %v, padding: %v", payload[0], padding)
}
record = append(record, empty[:padding]...)
}
record[0] = byte(recordTypeApplicationData) record[0] = byte(recordTypeApplicationData)
n := len(payload) + 1 + c.Overhead() n := len(record) + c.Overhead() - recordHeaderLen
record[3] = byte(n >> 8) record[3] = byte(n >> 8)
record[4] = byte(n) record[4] = byte(n)
@ -1009,6 +1044,16 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
c.out.Lock() c.out.Lock()
defer c.out.Unlock() defer c.out.Unlock()
if typ == recordTypeHandshake && c.out.handshakeBuf != nil &&
len(data) > 0 && data[0] != typeServerHello {
c.out.handshakeBuf = append(c.out.handshakeBuf, data...)
if data[0] != typeFinished {
return len(data), nil
}
data = c.out.handshakeBuf
c.out.handshakeBuf = nil
}
return c.writeRecordLocked(typ, data) return c.writeRecordLocked(typ, data)
} }

7
go.mod
View File

@ -1,8 +1,9 @@
module reality module github.com/xtls/reality
go 1.19 go 1.19
require ( require (
golang.org/x/crypto v0.5.0 github.com/pires/go-proxyproto v0.6.2
golang.org/x/sys v0.4.0 golang.org/x/crypto v0.6.0
golang.org/x/sys v0.5.0
) )

10
go.sum
View File

@ -1,4 +1,6 @@
golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8=
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -8,12 +8,16 @@ import (
"bytes" "bytes"
"context" "context"
"crypto" "crypto"
"crypto/ed25519"
"crypto/hmac" "crypto/hmac"
"crypto/rsa" "crypto/rsa"
"crypto/sha512"
"crypto/x509"
"encoding/binary" "encoding/binary"
"errors" "errors"
"hash" "hash"
"io" "io"
"math/big"
"sync/atomic" "sync/atomic"
"time" "time"
) )
@ -50,15 +54,52 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
} }
// For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2.
/*
if err := hs.processClientHello(); err != nil { if err := hs.processClientHello(); err != nil {
return err return err
} }
*/
{
hs.suite = cipherSuiteTLS13ByID(hs.hello.cipherSuite)
c.cipherSuite = hs.suite.id
hs.transcript = hs.suite.hash.New()
/*
// For Go 1.20 TLS
key, _ := generateECDHEKey(c.config.rand(), X25519)
copy(hs.hello.serverShare.data, key.PublicKey().Bytes())
peerKey, _ := key.Curve().NewPublicKey(hs.clientHello.keyShares[hs.clientHello.keyShares[0].group].data)
hs.sharedKey, _ = key.ECDH(peerKey)
*/
// For Go 1.19 TLS
params, _ := generateECDHEParameters(c.config.rand(), X25519)
copy(hs.hello.serverShare.data, params.PublicKey())
hs.sharedKey = params.SharedKey(hs.clientHello.keyShares[hs.clientHello.keyShares[0].group].data)
c.serverName = hs.clientHello.serverName
}
/*
if err := hs.checkForResumption(); err != nil { if err := hs.checkForResumption(); err != nil {
return err return err
} }
if err := hs.pickCertificate(); err != nil { if err := hs.pickCertificate(); err != nil {
return err return err
} }
*/
{
certificate := x509.Certificate{SerialNumber: &big.Int{}}
pub, priv, _ := ed25519.GenerateKey(c.config.rand())
signedCert, _ := x509.CreateCertificate(c.config.rand(), &certificate, &certificate, pub, priv)
h := hmac.New(sha512.New, c.AuthKey)
h.Write(pub)
h.Sum(signedCert[:len(signedCert)-64])
hs.cert = &Certificate{
Certificate: [][]byte{signedCert},
PrivateKey: priv,
}
hs.sigAlg = Ed25519
}
c.buffering = true c.buffering = true
if err := hs.sendServerParameters(); err != nil { if err := hs.sendServerParameters(); err != nil {
return err return err
@ -69,6 +110,11 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
if err := hs.sendServerFinished(); err != nil { if err := hs.sendServerFinished(); err != nil {
return err return err
} }
if hs.c.out.handshakeLen[6] != 0 {
if _, err := c.writeRecord(recordTypeHandshake, []byte{typeNewSessionTicket}); err != nil {
return err
}
}
// Note that at this point we could start sending application data without // Note that at this point we could start sending application data without
// waiting for the client's second flight, but the application might not // waiting for the client's second flight, but the application might not
// expect the lack of replay protection of the ClientHello parameters. // expect the lack of replay protection of the ClientHello parameters.

374
tls.go
View File

@ -15,29 +15,399 @@ import (
"bytes" "bytes"
"context" "context"
"crypto" "crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/binary"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"runtime"
"strings" "strings"
"sync"
"time"
"github.com/pires/go-proxyproto"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/hkdf"
) )
type ReaderConn struct {
Conn net.Conn
Reader *bytes.Reader
Written int
Closed bool
}
func (c *ReaderConn) Read(b []byte) (int, error) {
if c.Closed {
return 0, errors.New("Closed")
}
n, err := c.Reader.Read(b)
if err == io.EOF {
return n, errors.New("io.EOF") // prevent looping
}
return n, err
}
func (c *ReaderConn) Write(b []byte) (int, error) {
if c.Closed {
return 0, errors.New("Closed")
}
c.Written += len(b)
return len(b), nil
}
func (c *ReaderConn) Close() error {
c.Closed = true
return nil
}
func (c *ReaderConn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
func (c *ReaderConn) RemoteAddr() net.Addr {
return c.Conn.RemoteAddr()
}
func (c *ReaderConn) SetDeadline(t time.Time) error {
return nil
}
func (c *ReaderConn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *ReaderConn) SetWriteDeadline(t time.Time) error {
return nil
}
var (
size = 8192
empty = make([]byte, size)
names = [7]string{
"Server Hello",
"Change Cipher Spec",
"Encrypted Extensions",
"Certificate",
"Certificate Verify",
"Finished",
"New Session Ticket",
}
)
func Value(vals ...byte) (value int) {
for i, val := range vals {
value |= int(val) << ((len(vals) - i - 1) * 8)
}
return
}
// Server returns a new TLS server side connection // Server returns a new TLS server side connection
// using conn as the underlying transport. // using conn as the underlying transport.
// The configuration config must be non-nil and must include // The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate. // at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config) *Conn { func Server(conn net.Conn, config *Config) (*Conn, error) {
remoteAddr := conn.RemoteAddr().String()
if config.Show {
fmt.Printf("REALITY remoteAddr: %v\n", remoteAddr)
}
target, err := net.Dial(config.Type, config.Dest)
if err != nil {
conn.Close()
return nil, errors.New("REALITY: failed to dial dest: " + err.Error())
}
if config.Xver == 1 || config.Xver == 2 {
if _, err = proxyproto.HeaderProxyFromAddrs(config.Xver, conn.RemoteAddr(), conn.LocalAddr()).WriteTo(target); err != nil {
target.Close()
conn.Close()
return nil, errors.New("REALITY: failed to send PROXY protocol: " + err.Error())
}
}
underlying := conn
if pc, ok := underlying.(*proxyproto.Conn); ok {
underlying = pc.Raw()
}
hs := serverHandshakeStateTLS13{ctx: context.TODO()}
c2sSaved := make([]byte, 0, size)
s2cSaved := make([]byte, 0, size)
copying := false
handled := false
waitGroup := new(sync.WaitGroup)
waitGroup.Add(2)
mutex := new(sync.Mutex)
go func() {
done := false
buf := make([]byte, size)
clientHelloLen := 0
for {
runtime.Gosched()
n, err := conn.Read(buf)
mutex.Lock()
if err != nil && err != io.EOF {
target.Close()
done = true
break
}
if n == 0 {
mutex.Unlock()
continue
}
c2sSaved = append(c2sSaved, buf[:n]...)
if _, err = target.Write(buf[:n]); err != nil {
done = true
break
}
if copying || len(c2sSaved) > size || len(s2cSaved) > 0 { // follow; too long; unexpected
break
}
if clientHelloLen == 0 && len(c2sSaved) > recordHeaderLen {
if recordType(c2sSaved[0]) != recordTypeHandshake || Value(c2sSaved[1:3]...) != VersionTLS10 || c2sSaved[5] != typeClientHello {
break
}
clientHelloLen = recordHeaderLen + Value(c2sSaved[3:5]...)
}
if clientHelloLen > size { // too long
break
}
if clientHelloLen == 0 || len(c2sSaved) < clientHelloLen {
mutex.Unlock()
continue
}
if len(c2sSaved) > clientHelloLen { // unexpected
break
}
readerConn := &ReaderConn{
Conn: underlying,
Reader: bytes.NewReader(c2sSaved),
}
hs.c = &Conn{
conn: readerConn,
config: config,
}
hs.clientHello, err = hs.c.readClientHello(context.TODO())
if err != nil || readerConn.Reader.Len() > 0 || readerConn.Written > 0 || readerConn.Closed {
break
}
if hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] {
break
}
for i, keyShare := range hs.clientHello.keyShares {
if keyShare.group != X25519 || len(keyShare.data) != 32 {
continue
}
if hs.c.AuthKey, err = curve25519.X25519(config.PrivateKey, keyShare.data); err != nil {
break
}
if _, err = hkdf.New(sha256.New, hs.c.AuthKey, hs.clientHello.random[:20], []byte("REALITY")).Read(hs.c.AuthKey); err != nil {
break
}
if config.Show {
fmt.Printf("REALITY remoteAddr: %v\ths.clientHello.sessionId: %v\n", remoteAddr, hs.clientHello.sessionId)
fmt.Printf("REALITY remoteAddr: %v\ths.c.AuthKey: %v\n", remoteAddr, hs.c.AuthKey)
}
block, _ := aes.NewCipher(hs.c.AuthKey)
aead, _ := cipher.NewGCM(block)
ciphertext := make([]byte, 32)
plainText := make([]byte, 32)
copy(ciphertext, hs.clientHello.sessionId)
copy(hs.clientHello.sessionId, plainText) // hs.clientHello.sessionId points to hs.clientHello.raw[39:]
if _, err = aead.Open(plainText[:0], hs.clientHello.random[20:], ciphertext, hs.clientHello.raw); err != nil {
break
}
copy(hs.clientHello.sessionId, ciphertext)
copy(hs.c.ClientVer[:], plainText)
copy(hs.c.ClientShortId[:], plainText[8:])
plainText[0] = 0
plainText[1] = 0
plainText[2] = 0
hs.c.ClientTime = time.Unix(int64(binary.BigEndian.Uint64(plainText)), 0)
if config.Show {
fmt.Printf("REALITY remoteAddr: %v\ths.c.ClientVer: %v\n", remoteAddr, hs.c.ClientVer)
fmt.Printf("REALITY remoteAddr: %v\ths.c.ClientTime: %v\n", remoteAddr, hs.c.ClientTime)
fmt.Printf("REALITY remoteAddr: %v\ths.c.ClientShortId: %v\n", remoteAddr, hs.c.ClientShortId)
}
if (config.MinClientVer == nil || Value(hs.c.ClientVer[:]...) >= Value(config.MinClientVer...)) &&
(config.MaxClientVer == nil || Value(hs.c.ClientVer[:]...) <= Value(config.MaxClientVer...)) &&
(config.MaxTimeDiff == 0 || time.Since(hs.c.ClientTime).Abs() <= config.MaxTimeDiff) &&
(config.ShortIds[hs.c.ClientShortId]) {
hs.c.conn = underlying
}
hs.clientHello.keyShares[0].group = CurveID(i)
break
}
if hs.c.conn == underlying {
if config.Show {
fmt.Printf("REALITY remoteAddr: %v\ths.c.conn: underlying\n", remoteAddr)
}
done = true
}
break
}
if done {
mutex.Unlock()
} else {
copying = true
mutex.Unlock()
io.Copy(target, underlying)
}
waitGroup.Done()
}()
go func() {
done := false
buf := make([]byte, size)
handshakeLen := 0
f:
for {
runtime.Gosched()
n, err := target.Read(buf)
mutex.Lock()
if err != nil && err != io.EOF {
conn.Close()
done = true
break
}
if n == 0 {
mutex.Unlock()
continue
}
s2cSaved = append(s2cSaved, buf[:n]...)
if hs.c == nil || hs.c.conn != underlying {
if _, err = conn.Write(buf[:n]); err != nil {
done = true
break
}
if copying || len(s2cSaved) > size { // follow; too long
break
}
mutex.Unlock()
continue
}
done = true // special
if len(s2cSaved) > size {
break
}
check := func(i int) int {
if hs.c.out.handshakeLen[i] != 0 {
return 0
}
if i == 6 && len(s2cSaved) == 0 {
return 0
}
if handshakeLen == 0 && len(s2cSaved) > recordHeaderLen {
if Value(s2cSaved[1:3]...) != VersionTLS12 ||
(i == 0 && (recordType(s2cSaved[0]) != recordTypeHandshake || s2cSaved[5] != typeServerHello)) ||
(i == 1 && (recordType(s2cSaved[0]) != recordTypeChangeCipherSpec || s2cSaved[5] != 1)) ||
(i > 1 && recordType(s2cSaved[0]) != recordTypeApplicationData) {
return -1
}
handshakeLen = recordHeaderLen + Value(s2cSaved[3:5]...)
}
if config.Show {
fmt.Printf("REALITY remoteAddr: %v\tlen(s2cSaved): %v\t%v: %v\n", remoteAddr, len(s2cSaved), names[i], handshakeLen)
}
if handshakeLen > size { // too long
return -1
}
if i == 1 && handshakeLen > 0 && handshakeLen != 6 {
return -1
}
if i == 2 && handshakeLen > 512 {
hs.c.out.handshakeLen[i] = handshakeLen
hs.c.out.handshakeBuf = s2cSaved[:0]
return 2
}
if i == 6 && handshakeLen > 0 {
hs.c.out.handshakeLen[i] = handshakeLen
return 0
}
if handshakeLen == 0 || len(s2cSaved) < handshakeLen {
mutex.Unlock()
return 1
}
if i == 0 {
hs.hello = new(serverHelloMsg)
if !hs.hello.unmarshal(s2cSaved[recordHeaderLen:handshakeLen]) ||
hs.hello.vers != VersionTLS12 || hs.hello.supportedVersion != VersionTLS13 ||
cipherSuiteTLS13ByID(hs.hello.cipherSuite) == nil ||
hs.hello.serverShare.group != X25519 || len(hs.hello.serverShare.data) != 32 {
return -1
}
}
hs.c.out.handshakeLen[i] = handshakeLen
s2cSaved = s2cSaved[handshakeLen:]
handshakeLen = 0
return 0
}
for i := 0; i < 7; i++ {
switch check(i) {
case 2:
goto handshake
case 1:
continue f
case 0:
continue
case -1:
break f
}
}
handshake:
err = hs.handshake()
if config.Show {
fmt.Printf("REALITY remoteAddr: %v\ths.handshake() err: %v\n", remoteAddr, err)
}
if err == nil {
handled = true
}
break
}
if done {
mutex.Unlock()
} else {
copying = true
mutex.Unlock()
io.Copy(underlying, target)
}
waitGroup.Done()
}()
waitGroup.Wait()
target.Close()
if config.Show {
fmt.Printf("REALITY remoteAddr: %v\thandled: %v\n", remoteAddr, handled)
}
if handled {
return hs.c, nil
}
conn.Close()
return nil, errors.New("REALITY: processed invalid connection")
/*
c := &Conn{ c := &Conn{
conn: conn, conn: conn,
config: config, config: config,
} }
c.handshakeFn = c.serverHandshake c.handshakeFn = c.serverHandshake
return c return c
*/
} }
// Client returns a new TLS client side connection // Client returns a new TLS client side connection
@ -67,7 +437,7 @@ func (l *listener) Accept() (net.Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Server(c, l.config), nil return Server(c, l.config)
} }
// NewListener creates a Listener which accepts connections from an inner // NewListener creates a Listener which accepts connections from an inner