From 5e6719eaf3bc70e1977bb724570c2567c92636b9 Mon Sep 17 00:00:00 2001 From: RPRX <63339210+RPRX@users.noreply.github.com> Date: Thu, 9 Feb 2023 11:59:09 +0800 Subject: [PATCH] REALITY is REALITY now Thank @yuhan6665 for testing --- README.md | 4 +- common.go | 12 ++ conn.go | 47 ++++- go.mod | 7 +- go.sum | 10 +- handshake_server_tls13.go | 60 +++++- tls.go | 384 +++++++++++++++++++++++++++++++++++++- 7 files changed, 501 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 5ad29a2..e6a6e66 100644 --- a/README.md +++ b/README.md @@ -3,4 +3,6 @@ ### THE NEXT FUTURE Server side implementation of REALITY protocol, a fork of package tls in Go 1.19.5. -For client side, please follow https://github.com/XTLS/Xray-core. +For client side, please follow https://github.com/XTLS/Xray-core/blob/main/transport/internet/reality/reality.go. + +TODO List: TODO diff --git a/common.go b/common.go index 163d9be..a8b54c8 100644 --- a/common.go +++ b/common.go @@ -515,6 +515,18 @@ const ( // modified. A Config may be reused; the tls package will also not // modify it. type Config struct { + Show bool + Type string + Dest string + Xver byte + + ServerNames map[string]bool + PrivateKey []byte + MinClientVer []byte + MaxClientVer []byte + MaxTimeDiff time.Duration + ShortIds map[[8]byte]bool + // Rand provides the source of entropy for nonces and RSA blinding. // If Rand is nil, TLS uses the cryptographic random reader in package // crypto/rand. diff --git a/conn.go b/conn.go index 33ecb0e..47a60c1 100644 --- a/conn.go +++ b/conn.go @@ -25,6 +25,11 @@ import ( // A Conn represents a secured connection. // It implements the net.Conn interface. type Conn struct { + AuthKey []byte + ClientVer [3]byte + ClientTime time.Time + ClientShortId [8]byte + // constant conn net.Conn isClient bool @@ -162,6 +167,9 @@ func (c *Conn) NetConn() net.Conn { // A halfConn represents one direction of the record layer // connection, either sending or receiving. type halfConn struct { + handshakeLen [7]int + handshakeBuf []byte + sync.Mutex err error // first permanent error @@ -514,9 +522,36 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err // Encrypt the actual ContentType and replace the plaintext one. record = append(record, record[0]) + padding := 0 + if recordType(record[0]) == recordTypeHandshake && hc.handshakeLen[1] != 0 { + switch payload[0] { + case typeEncryptedExtensions: + padding = hc.handshakeLen[2] + hc.handshakeLen[2] = 0 + case typeCertificate: + padding = hc.handshakeLen[3] + hc.handshakeLen[3] = 0 + case typeCertificateVerify: + padding = hc.handshakeLen[4] + hc.handshakeLen[4] = 0 + case typeFinished: + padding = hc.handshakeLen[5] + hc.handshakeLen[5] = 0 + case typeNewSessionTicket: + padding = hc.handshakeLen[6] + hc.handshakeLen[6] = 0 + record[5] = byte(recordTypeApplicationData) + record[6] = 0 + } + padding -= len(record) + c.Overhead() + if padding < 0 { + return nil, fmt.Errorf("payload[0]: %v, padding: %v", payload[0], padding) + } + record = append(record, empty[:padding]...) + } record[0] = byte(recordTypeApplicationData) - n := len(payload) + 1 + c.Overhead() + n := len(record) + c.Overhead() - recordHeaderLen record[3] = byte(n >> 8) record[4] = byte(n) @@ -1009,6 +1044,16 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { c.out.Lock() defer c.out.Unlock() + if typ == recordTypeHandshake && c.out.handshakeBuf != nil && + len(data) > 0 && data[0] != typeServerHello { + c.out.handshakeBuf = append(c.out.handshakeBuf, data...) + if data[0] != typeFinished { + return len(data), nil + } + data = c.out.handshakeBuf + c.out.handshakeBuf = nil + } + return c.writeRecordLocked(typ, data) } diff --git a/go.mod b/go.mod index ff7927b..7389307 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,9 @@ -module reality +module github.com/xtls/reality go 1.19 require ( - golang.org/x/crypto v0.5.0 - golang.org/x/sys v0.4.0 + github.com/pires/go-proxyproto v0.6.2 + golang.org/x/crypto v0.6.0 + golang.org/x/sys v0.5.0 ) diff --git a/go.sum b/go.sum index 5317d3e..213eae6 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ -golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= -golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= -golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8= +github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= +golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index a4b38a5..1b5e026 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -8,12 +8,16 @@ import ( "bytes" "context" "crypto" + "crypto/ed25519" "crypto/hmac" "crypto/rsa" + "crypto/sha512" + "crypto/x509" "encoding/binary" "errors" "hash" "io" + "math/big" "sync/atomic" "time" ) @@ -50,14 +54,51 @@ func (hs *serverHandshakeStateTLS13) handshake() error { } // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. - if err := hs.processClientHello(); err != nil { - return err + /* + if err := hs.processClientHello(); err != nil { + return err + } + */ + { + hs.suite = cipherSuiteTLS13ByID(hs.hello.cipherSuite) + c.cipherSuite = hs.suite.id + hs.transcript = hs.suite.hash.New() + /* + // For Go 1.20 TLS + key, _ := generateECDHEKey(c.config.rand(), X25519) + copy(hs.hello.serverShare.data, key.PublicKey().Bytes()) + peerKey, _ := key.Curve().NewPublicKey(hs.clientHello.keyShares[hs.clientHello.keyShares[0].group].data) + hs.sharedKey, _ = key.ECDH(peerKey) + */ + // For Go 1.19 TLS + params, _ := generateECDHEParameters(c.config.rand(), X25519) + copy(hs.hello.serverShare.data, params.PublicKey()) + hs.sharedKey = params.SharedKey(hs.clientHello.keyShares[hs.clientHello.keyShares[0].group].data) + + c.serverName = hs.clientHello.serverName } - if err := hs.checkForResumption(); err != nil { - return err - } - if err := hs.pickCertificate(); err != nil { - return err + /* + if err := hs.checkForResumption(); err != nil { + return err + } + if err := hs.pickCertificate(); err != nil { + return err + } + */ + { + certificate := x509.Certificate{SerialNumber: &big.Int{}} + pub, priv, _ := ed25519.GenerateKey(c.config.rand()) + signedCert, _ := x509.CreateCertificate(c.config.rand(), &certificate, &certificate, pub, priv) + + h := hmac.New(sha512.New, c.AuthKey) + h.Write(pub) + h.Sum(signedCert[:len(signedCert)-64]) + + hs.cert = &Certificate{ + Certificate: [][]byte{signedCert}, + PrivateKey: priv, + } + hs.sigAlg = Ed25519 } c.buffering = true if err := hs.sendServerParameters(); err != nil { @@ -69,6 +110,11 @@ func (hs *serverHandshakeStateTLS13) handshake() error { if err := hs.sendServerFinished(); err != nil { return err } + if hs.c.out.handshakeLen[6] != 0 { + if _, err := c.writeRecord(recordTypeHandshake, []byte{typeNewSessionTicket}); err != nil { + return err + } + } // Note that at this point we could start sending application data without // waiting for the client's second flight, but the application might not // expect the lack of replay protection of the ClientHello parameters. diff --git a/tls.go b/tls.go index 5118ce3..6ac677e 100644 --- a/tls.go +++ b/tls.go @@ -15,29 +15,399 @@ import ( "bytes" "context" "crypto" + "crypto/aes" + "crypto/cipher" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" + "crypto/sha256" "crypto/x509" + "encoding/binary" "encoding/pem" "errors" "fmt" + "io" "net" "os" + "runtime" "strings" + "sync" + "time" + + "github.com/pires/go-proxyproto" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" ) +type ReaderConn struct { + Conn net.Conn + Reader *bytes.Reader + Written int + Closed bool +} + +func (c *ReaderConn) Read(b []byte) (int, error) { + if c.Closed { + return 0, errors.New("Closed") + } + n, err := c.Reader.Read(b) + if err == io.EOF { + return n, errors.New("io.EOF") // prevent looping + } + return n, err +} + +func (c *ReaderConn) Write(b []byte) (int, error) { + if c.Closed { + return 0, errors.New("Closed") + } + c.Written += len(b) + return len(b), nil +} + +func (c *ReaderConn) Close() error { + c.Closed = true + return nil +} + +func (c *ReaderConn) LocalAddr() net.Addr { + return c.Conn.LocalAddr() +} + +func (c *ReaderConn) RemoteAddr() net.Addr { + return c.Conn.RemoteAddr() +} + +func (c *ReaderConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *ReaderConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *ReaderConn) SetWriteDeadline(t time.Time) error { + return nil +} + +var ( + size = 8192 + empty = make([]byte, size) + names = [7]string{ + "Server Hello", + "Change Cipher Spec", + "Encrypted Extensions", + "Certificate", + "Certificate Verify", + "Finished", + "New Session Ticket", + } +) + +func Value(vals ...byte) (value int) { + for i, val := range vals { + value |= int(val) << ((len(vals) - i - 1) * 8) + } + return +} + // Server returns a new TLS server side connection // using conn as the underlying transport. // The configuration config must be non-nil and must include // at least one certificate or else set GetCertificate. -func Server(conn net.Conn, config *Config) *Conn { - c := &Conn{ - conn: conn, - config: config, +func Server(conn net.Conn, config *Config) (*Conn, error) { + remoteAddr := conn.RemoteAddr().String() + if config.Show { + fmt.Printf("REALITY remoteAddr: %v\n", remoteAddr) } - c.handshakeFn = c.serverHandshake - return c + + target, err := net.Dial(config.Type, config.Dest) + if err != nil { + conn.Close() + return nil, errors.New("REALITY: failed to dial dest: " + err.Error()) + } + + if config.Xver == 1 || config.Xver == 2 { + if _, err = proxyproto.HeaderProxyFromAddrs(config.Xver, conn.RemoteAddr(), conn.LocalAddr()).WriteTo(target); err != nil { + target.Close() + conn.Close() + return nil, errors.New("REALITY: failed to send PROXY protocol: " + err.Error()) + } + } + + underlying := conn + if pc, ok := underlying.(*proxyproto.Conn); ok { + underlying = pc.Raw() + } + + hs := serverHandshakeStateTLS13{ctx: context.TODO()} + + c2sSaved := make([]byte, 0, size) + s2cSaved := make([]byte, 0, size) + + copying := false + handled := false + + waitGroup := new(sync.WaitGroup) + waitGroup.Add(2) + + mutex := new(sync.Mutex) + + go func() { + done := false + buf := make([]byte, size) + clientHelloLen := 0 + for { + runtime.Gosched() + n, err := conn.Read(buf) + mutex.Lock() + if err != nil && err != io.EOF { + target.Close() + done = true + break + } + if n == 0 { + mutex.Unlock() + continue + } + c2sSaved = append(c2sSaved, buf[:n]...) + if _, err = target.Write(buf[:n]); err != nil { + done = true + break + } + if copying || len(c2sSaved) > size || len(s2cSaved) > 0 { // follow; too long; unexpected + break + } + if clientHelloLen == 0 && len(c2sSaved) > recordHeaderLen { + if recordType(c2sSaved[0]) != recordTypeHandshake || Value(c2sSaved[1:3]...) != VersionTLS10 || c2sSaved[5] != typeClientHello { + break + } + clientHelloLen = recordHeaderLen + Value(c2sSaved[3:5]...) + } + if clientHelloLen > size { // too long + break + } + if clientHelloLen == 0 || len(c2sSaved) < clientHelloLen { + mutex.Unlock() + continue + } + if len(c2sSaved) > clientHelloLen { // unexpected + break + } + readerConn := &ReaderConn{ + Conn: underlying, + Reader: bytes.NewReader(c2sSaved), + } + hs.c = &Conn{ + conn: readerConn, + config: config, + } + hs.clientHello, err = hs.c.readClientHello(context.TODO()) + if err != nil || readerConn.Reader.Len() > 0 || readerConn.Written > 0 || readerConn.Closed { + break + } + if hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] { + break + } + for i, keyShare := range hs.clientHello.keyShares { + if keyShare.group != X25519 || len(keyShare.data) != 32 { + continue + } + if hs.c.AuthKey, err = curve25519.X25519(config.PrivateKey, keyShare.data); err != nil { + break + } + if _, err = hkdf.New(sha256.New, hs.c.AuthKey, hs.clientHello.random[:20], []byte("REALITY")).Read(hs.c.AuthKey); err != nil { + break + } + if config.Show { + fmt.Printf("REALITY remoteAddr: %v\ths.clientHello.sessionId: %v\n", remoteAddr, hs.clientHello.sessionId) + fmt.Printf("REALITY remoteAddr: %v\ths.c.AuthKey: %v\n", remoteAddr, hs.c.AuthKey) + } + block, _ := aes.NewCipher(hs.c.AuthKey) + aead, _ := cipher.NewGCM(block) + ciphertext := make([]byte, 32) + plainText := make([]byte, 32) + copy(ciphertext, hs.clientHello.sessionId) + copy(hs.clientHello.sessionId, plainText) // hs.clientHello.sessionId points to hs.clientHello.raw[39:] + if _, err = aead.Open(plainText[:0], hs.clientHello.random[20:], ciphertext, hs.clientHello.raw); err != nil { + break + } + copy(hs.clientHello.sessionId, ciphertext) + copy(hs.c.ClientVer[:], plainText) + copy(hs.c.ClientShortId[:], plainText[8:]) + plainText[0] = 0 + plainText[1] = 0 + plainText[2] = 0 + hs.c.ClientTime = time.Unix(int64(binary.BigEndian.Uint64(plainText)), 0) + if config.Show { + fmt.Printf("REALITY remoteAddr: %v\ths.c.ClientVer: %v\n", remoteAddr, hs.c.ClientVer) + fmt.Printf("REALITY remoteAddr: %v\ths.c.ClientTime: %v\n", remoteAddr, hs.c.ClientTime) + fmt.Printf("REALITY remoteAddr: %v\ths.c.ClientShortId: %v\n", remoteAddr, hs.c.ClientShortId) + } + if (config.MinClientVer == nil || Value(hs.c.ClientVer[:]...) >= Value(config.MinClientVer...)) && + (config.MaxClientVer == nil || Value(hs.c.ClientVer[:]...) <= Value(config.MaxClientVer...)) && + (config.MaxTimeDiff == 0 || time.Since(hs.c.ClientTime).Abs() <= config.MaxTimeDiff) && + (config.ShortIds[hs.c.ClientShortId]) { + hs.c.conn = underlying + } + hs.clientHello.keyShares[0].group = CurveID(i) + break + } + if hs.c.conn == underlying { + if config.Show { + fmt.Printf("REALITY remoteAddr: %v\ths.c.conn: underlying\n", remoteAddr) + } + done = true + } + break + } + if done { + mutex.Unlock() + } else { + copying = true + mutex.Unlock() + io.Copy(target, underlying) + } + waitGroup.Done() + }() + + go func() { + done := false + buf := make([]byte, size) + handshakeLen := 0 + f: + for { + runtime.Gosched() + n, err := target.Read(buf) + mutex.Lock() + if err != nil && err != io.EOF { + conn.Close() + done = true + break + } + if n == 0 { + mutex.Unlock() + continue + } + s2cSaved = append(s2cSaved, buf[:n]...) + if hs.c == nil || hs.c.conn != underlying { + if _, err = conn.Write(buf[:n]); err != nil { + done = true + break + } + if copying || len(s2cSaved) > size { // follow; too long + break + } + mutex.Unlock() + continue + } + done = true // special + if len(s2cSaved) > size { + break + } + check := func(i int) int { + if hs.c.out.handshakeLen[i] != 0 { + return 0 + } + if i == 6 && len(s2cSaved) == 0 { + return 0 + } + if handshakeLen == 0 && len(s2cSaved) > recordHeaderLen { + if Value(s2cSaved[1:3]...) != VersionTLS12 || + (i == 0 && (recordType(s2cSaved[0]) != recordTypeHandshake || s2cSaved[5] != typeServerHello)) || + (i == 1 && (recordType(s2cSaved[0]) != recordTypeChangeCipherSpec || s2cSaved[5] != 1)) || + (i > 1 && recordType(s2cSaved[0]) != recordTypeApplicationData) { + return -1 + } + handshakeLen = recordHeaderLen + Value(s2cSaved[3:5]...) + } + if config.Show { + fmt.Printf("REALITY remoteAddr: %v\tlen(s2cSaved): %v\t%v: %v\n", remoteAddr, len(s2cSaved), names[i], handshakeLen) + } + if handshakeLen > size { // too long + return -1 + } + if i == 1 && handshakeLen > 0 && handshakeLen != 6 { + return -1 + } + if i == 2 && handshakeLen > 512 { + hs.c.out.handshakeLen[i] = handshakeLen + hs.c.out.handshakeBuf = s2cSaved[:0] + return 2 + } + if i == 6 && handshakeLen > 0 { + hs.c.out.handshakeLen[i] = handshakeLen + return 0 + } + if handshakeLen == 0 || len(s2cSaved) < handshakeLen { + mutex.Unlock() + return 1 + } + if i == 0 { + hs.hello = new(serverHelloMsg) + if !hs.hello.unmarshal(s2cSaved[recordHeaderLen:handshakeLen]) || + hs.hello.vers != VersionTLS12 || hs.hello.supportedVersion != VersionTLS13 || + cipherSuiteTLS13ByID(hs.hello.cipherSuite) == nil || + hs.hello.serverShare.group != X25519 || len(hs.hello.serverShare.data) != 32 { + return -1 + } + } + hs.c.out.handshakeLen[i] = handshakeLen + s2cSaved = s2cSaved[handshakeLen:] + handshakeLen = 0 + return 0 + } + for i := 0; i < 7; i++ { + switch check(i) { + case 2: + goto handshake + case 1: + continue f + case 0: + continue + case -1: + break f + } + } + handshake: + err = hs.handshake() + if config.Show { + fmt.Printf("REALITY remoteAddr: %v\ths.handshake() err: %v\n", remoteAddr, err) + } + if err == nil { + handled = true + } + break + } + if done { + mutex.Unlock() + } else { + copying = true + mutex.Unlock() + io.Copy(underlying, target) + } + waitGroup.Done() + }() + + waitGroup.Wait() + target.Close() + if config.Show { + fmt.Printf("REALITY remoteAddr: %v\thandled: %v\n", remoteAddr, handled) + } + if handled { + return hs.c, nil + } + conn.Close() + return nil, errors.New("REALITY: processed invalid connection") + + /* + c := &Conn{ + conn: conn, + config: config, + } + c.handshakeFn = c.serverHandshake + return c + */ } // Client returns a new TLS client side connection @@ -67,7 +437,7 @@ func (l *listener) Accept() (net.Conn, error) { if err != nil { return nil, err } - return Server(c, l.config), nil + return Server(c, l.config) } // NewListener creates a Listener which accepts connections from an inner