diff --git a/record_detect.go b/record_detect.go index 3aa117c..56510b7 100644 --- a/record_detect.go +++ b/record_detect.go @@ -6,76 +6,78 @@ import ( "encoding/binary" "io" "net" + "sync" "time" + + "github.com/pires/go-proxyproto" ) -func DetectRecordFingerprint(target string) ([]int, error) { - NetConn, err := net.Dial("tcp", target) - if err != nil { - return nil, err +var lock sync.Mutex + +var PostHandshakeRecordsLen map[*Config]map[string][]int + +func DetectPostHandshakeRecords(config *Config) { + lock.Lock() + if PostHandshakeRecordsLen == nil { + PostHandshakeRecordsLen = make(map[*Config]map[string][]int) } - conn := &detectConn{ - Conn: NetConn, - resultChan: make(chan []int, 1), - } - host, _, err := net.SplitHostPort(target) - if err != nil { - return nil, err - } - tlsConfig := &tls.Config{ - ServerName: host, - } - tlsConn := tls.Client(conn, tlsConfig) - err = tlsConn.Handshake() - if err != nil { - return nil, err - } - go func() { - io.Copy(io.Discard, tlsConn) - }() - select { - case result := <-conn.resultChan: - return result, nil - case <-time.After(2 * time.Second): - return nil, nil + if PostHandshakeRecordsLen[config] == nil { + PostHandshakeRecordsLen[config] = make(map[string][]int) + for sni := range config.ServerNames { + target, err := net.Dial("tcp", config.Dest) + if err != nil { + return + } + if config.Xver == 1 || config.Xver == 2 { + if _, err = proxyproto.HeaderProxyFromAddrs(config.Xver, target.LocalAddr(), target.RemoteAddr()).WriteTo(target); err != nil { + return + } + } + detectConn := &DetectConn{ + Conn: target, + config: config, + sni: sni, + } + tlsConn := tls.Client(detectConn, &tls.Config{ + ServerName: sni, + }) + if err = tlsConn.Handshake(); err != nil { + return + } + io.Copy(io.Discard, tlsConn) + } } + lock.Unlock() } -type detectConn struct { +type DetectConn struct { net.Conn - ccsSent bool - done bool - resultChan chan ([]int) + config *Config + sni string + ccsSent bool } -func (c *detectConn) Write(b []byte) (n int, err error) { +func (c *DetectConn) Write(b []byte) (n int, err error) { if len(b) >= 3 && bytes.Equal(b[:3], []byte{20, 3, 3}) { c.ccsSent = true } return c.Conn.Write(b) } -func (c *detectConn) Read(b []byte) (n int, err error) { - n, err = c.Conn.Read(b) - if c.ccsSent && !c.done { - data := make([]byte, len(b)) - copy(data, b) - var result []int - for { - if len(data) > 3 && bytes.Equal(data[:3], []byte{23, 3, 3}) { - length := int(binary.BigEndian.Uint16(data[3:5])) - if len(data) > length+5 { - result = append(result, int(length)) - data = data[length+5:] - } - } else { - break - } - } - if len(result) != 1 { - c.done = true - c.resultChan <- result +func (c *DetectConn) Read(b []byte) (n int, err error) { + if !c.ccsSent { + return c.Conn.Read(b) + } + c.Conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + data, _ := io.ReadAll(c.Conn) + for { + if len(data) >= 5 && bytes.Equal(data[:3], []byte{23, 3, 3}) { + length := int(binary.BigEndian.Uint16(data[3:5])) + 5 + PostHandshakeRecordsLen[c.config][c.sni] = append(PostHandshakeRecordsLen[c.config][c.sni], length) + data = data[length:] + } else { + break } } - return n, err + return 0, io.EOF } diff --git a/tls.go b/tls.go index e0af111..932a868 100644 --- a/tls.go +++ b/tls.go @@ -126,6 +126,8 @@ func Value(vals ...byte) (value int) { // The configuration config must be non-nil and must include // at least one certificate or else set GetCertificate. func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { + DetectPostHandshakeRecords(config) + remoteAddr := conn.RemoteAddr().String() if config.Show { fmt.Printf("REALITY remoteAddr: %v\n", remoteAddr) @@ -336,6 +338,19 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if err != nil { break } + for _, length := range PostHandshakeRecordsLen[config][hs.clientHello.serverName] { + plainText := make([]byte, length-16) + plainText[0] = 23 + plainText[1] = 3 + plainText[2] = 3 + plainText[3] = byte((length - 5) >> 8) + plainText[4] = byte((length - 5)) + plainText[5] = 23 + postHandshakeRecord := hs.c.out.cipher.(aead).Seal(plainText[:5], hs.c.out.seq[:], plainText[5:], plainText[:5]) + hs.c.out.incSeq() + hs.c.write(postHandshakeRecord) + fmt.Printf("REALITY remoteAddr: %v\tlen(postHandshakeRecord): %v\n", remoteAddr, len(postHandshakeRecord)) + } hs.c.isHandshakeComplete.Store(true) break }