diff --git a/record_detect.go b/record_detect.go index 2e0fc2b..923d179 100644 --- a/record_detect.go +++ b/record_detect.go @@ -12,50 +12,48 @@ import ( utls "github.com/refraction-networking/utls" ) -var GlobalPostHandshakeRecordsLock sync.Mutex +var GlobalPostHandshakeRecordsLens sync.Map -var GlobalPostHandshakeRecordsLens map[*Config]map[string][]int - -func DetectPostHandshakeRecordsLens(config *Config) map[string][]int { - GlobalPostHandshakeRecordsLock.Lock() - defer GlobalPostHandshakeRecordsLock.Unlock() - if GlobalPostHandshakeRecordsLens == nil { - GlobalPostHandshakeRecordsLens = make(map[*Config]map[string][]int) - } - if GlobalPostHandshakeRecordsLens[config] == nil { - GlobalPostHandshakeRecordsLens[config] = make(map[string][]int) - for sni := range config.ServerNames { - target, err := net.Dial("tcp", config.Dest) - if err != nil { - continue - } - if config.Xver == 1 || config.Xver == 2 { - if _, err = proxyproto.HeaderProxyFromAddrs(config.Xver, target.LocalAddr(), target.RemoteAddr()).WriteTo(target); err != nil { - continue +func DetectPostHandshakeRecordsLens(config *Config) { + for sni := range config.ServerNames { + key := config.Dest + " " + sni + if _, loaded := GlobalPostHandshakeRecordsLens.LoadOrStore(key, false); !loaded { + go func() { + defer func() { + val, _ := GlobalPostHandshakeRecordsLens.Load(key) + if _, ok := val.(bool); ok { + GlobalPostHandshakeRecordsLens.Store(key, []int{}) + } + }() + target, err := net.Dial("tcp", config.Dest) + if err != nil { + return } - } - detectConn := &DetectConn{ - Conn: target, - PostHandshakeRecordsLens: GlobalPostHandshakeRecordsLens[config], - Sni: sni, - } - uConn := utls.UClient(detectConn, &utls.Config{ - ServerName: sni, - }, utls.HelloChrome_Auto) - if err = uConn.Handshake(); err != nil { - continue - } - io.Copy(io.Discard, uConn) + if config.Xver == 1 || config.Xver == 2 { + if _, err = proxyproto.HeaderProxyFromAddrs(config.Xver, target.LocalAddr(), target.RemoteAddr()).WriteTo(target); err != nil { + return + } + } + detectConn := &DetectConn{ + Conn: target, + Key: key, + } + uConn := utls.UClient(detectConn, &utls.Config{ + ServerName: sni, // needs new loopvar behaviour + }, utls.HelloChrome_Auto) + if err = uConn.Handshake(); err != nil { + return + } + io.Copy(io.Discard, uConn) + }() } } - return GlobalPostHandshakeRecordsLens[config] } type DetectConn struct { net.Conn - PostHandshakeRecordsLens map[string][]int - Sni string - CcsSent bool + Key string + CcsSent bool } func (c *DetectConn) Write(b []byte) (n int, err error) { @@ -71,14 +69,16 @@ func (c *DetectConn) Read(b []byte) (n int, err error) { } c.Conn.SetReadDeadline(time.Now().Add(5 * time.Second)) data, _ := io.ReadAll(c.Conn) + var postHandshakeRecordsLens []int for { if len(data) >= 5 && bytes.Equal(data[:3], []byte{23, 3, 3}) { length := int(binary.BigEndian.Uint16(data[3:5])) + 5 - c.PostHandshakeRecordsLens[c.Sni] = append(c.PostHandshakeRecordsLens[c.Sni], length) + postHandshakeRecordsLens = append(postHandshakeRecordsLens, length) data = data[length:] } else { break } } + GlobalPostHandshakeRecordsLens.Store(c.Key, postHandshakeRecordsLens) return 0, io.EOF } diff --git a/tls.go b/tls.go index f3f409e..adcb037 100644 --- a/tls.go +++ b/tls.go @@ -157,13 +157,9 @@ func Value(vals ...byte) (value int) { return } -// Server returns a new TLS server side connection -// using conn as the underlying transport. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. +// You MUST call `DetectPostHandshakeRecordsLens(config)` in advance manually +// if you don't use REALITY's listener, e.g., Xray-core's RAW transport. func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { - postHandshakeRecordsLens := DetectPostHandshakeRecordsLens(config) - remoteAddr := conn.RemoteAddr().String() if config.Show { fmt.Printf("REALITY remoteAddr: %v\n", remoteAddr) @@ -374,20 +370,28 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if err != nil { break } - for _, length := range postHandshakeRecordsLens[hs.clientHello.serverName] { - plainText := make([]byte, length-16) - plainText[0] = 23 - plainText[1] = 3 - plainText[2] = 3 - plainText[3] = byte((length - 5) >> 8) - plainText[4] = byte((length - 5)) - plainText[5] = 23 - postHandshakeRecord := hs.c.out.cipher.(aead).Seal(plainText[:5], hs.c.out.seq[:], plainText[5:], plainText[:5]) - hs.c.out.incSeq() - hs.c.write(postHandshakeRecord) - if config.Show { - fmt.Printf("REALITY remoteAddr: %v\tlen(postHandshakeRecord): %v\n", remoteAddr, len(postHandshakeRecord)) + for { + if val, ok := GlobalPostHandshakeRecordsLens.Load(config.Dest + " " + hs.clientHello.serverName); ok { + if postHandshakeRecordsLens, ok := val.([]int); ok { + for _, length := range postHandshakeRecordsLens { + plainText := make([]byte, length-16) + plainText[0] = 23 + plainText[1] = 3 + plainText[2] = 3 + plainText[3] = byte((length - 5) >> 8) + plainText[4] = byte((length - 5)) + plainText[5] = 23 + postHandshakeRecord := hs.c.out.cipher.(aead).Seal(plainText[:5], hs.c.out.seq[:], plainText[5:], plainText[:5]) + hs.c.out.incSeq() + hs.c.write(postHandshakeRecord) + if config.Show { + fmt.Printf("REALITY remoteAddr: %v\tlen(postHandshakeRecord): %v\n", remoteAddr, len(postHandshakeRecord)) + } + } + break + } } + time.Sleep(5 * time.Second) } hs.c.isHandshakeComplete.Store(true) break