0
0
mirror of https://github.com/XTLS/REALITY.git synced 2025-08-27 17:05:36 +00:00

Allow fragmented REALITY Client Hello & Simplify logic

It's mainly for defending against certain attacks.
This commit is contained in:
RPRX 2023-08-28 17:12:59 +00:00 committed by GitHub
parent e07c3b04b9
commit e426190d57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

129
tls.go
View File

@ -41,31 +41,43 @@ import (
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
) )
type WeakConn struct { type MirrorConn struct {
*sync.Mutex
net.Conn net.Conn
Target net.Conn
} }
func (c *WeakConn) Read(b []byte) (int, error) { func (c *MirrorConn) Read(b []byte) (int, error) {
return 0, fmt.Errorf("Read(%v)", len(b)) c.Unlock()
runtime.Gosched()
n, err := c.Conn.Read(b)
c.Lock() // calling c.Lock() before c.Target.Write(), to make sure that this goroutine has the priority to make the next move
if n != 0 {
c.Target.Write(b[:n])
}
if err != nil {
c.Target.Close()
}
return n, err
} }
func (c *WeakConn) Write(b []byte) (int, error) { func (c *MirrorConn) Write(b []byte) (int, error) {
return 0, fmt.Errorf("Write(%v)", len(b)) return 0, fmt.Errorf("Write(%v)", len(b))
} }
func (c *WeakConn) Close() error { func (c *MirrorConn) Close() error {
return fmt.Errorf("Close()") return fmt.Errorf("Close()")
} }
func (c *WeakConn) SetDeadline(t time.Time) error { func (c *MirrorConn) SetDeadline(t time.Time) error {
return nil return nil
} }
func (c *WeakConn) SetReadDeadline(t time.Time) error { func (c *MirrorConn) SetReadDeadline(t time.Time) error {
return nil return nil
} }
func (c *WeakConn) SetWriteDeadline(t time.Time) error { func (c *MirrorConn) SetWriteDeadline(t time.Time) error {
return nil return nil
} }
@ -116,68 +128,33 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
underlying := conn underlying := conn
if pc, ok := underlying.(*proxyproto.Conn); ok { if pc, ok := underlying.(*proxyproto.Conn); ok {
underlying = pc.Raw() underlying = pc.Raw() // for TCP splicing in io.Copy()
} }
hs := serverHandshakeStateTLS13{ctx: context.Background()} mutex := new(sync.Mutex)
c2sSaved := make([]byte, 0, size) hs := serverHandshakeStateTLS13{
s2cSaved := make([]byte, 0, size) c: &Conn{
conn: &MirrorConn{
Mutex: mutex,
Conn: conn,
Target: target,
},
config: config,
},
ctx: context.Background(),
}
copying := false copying := false
handled := false
waitGroup := new(sync.WaitGroup) waitGroup := new(sync.WaitGroup)
waitGroup.Add(2) waitGroup.Add(2)
mutex := new(sync.Mutex)
go func() { go func() {
done := false
buf := make([]byte, size)
clientHelloLen := 0
for { for {
runtime.Gosched()
n, err := conn.Read(buf)
if n == 0 {
if err != nil {
target.Close()
waitGroup.Done()
return
}
continue
}
mutex.Lock() mutex.Lock()
c2sSaved = append(c2sSaved, buf[:n]...) hs.clientHello, err = hs.c.readClientHello(context.Background()) // TODO: Change some rules in this function.
if _, err = target.Write(buf[:n]); err != nil { if copying || err != nil || hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] {
done = true
break
}
if len(c2sSaved) > size || copying { // too long; follow
break
}
if clientHelloLen == 0 && len(c2sSaved) > recordHeaderLen {
if recordType(c2sSaved[0]) != recordTypeHandshake || Value(c2sSaved[1:3]...) != VersionTLS10 || c2sSaved[5] != typeClientHello {
break
}
clientHelloLen = recordHeaderLen + Value(c2sSaved[3:5]...)
}
if clientHelloLen > size { // too long
break
}
if clientHelloLen == 0 || len(c2sSaved) < clientHelloLen {
mutex.Unlock()
continue
}
hs.c = &Conn{
conn: &WeakConn{conn},
config: config,
rawInput: *bytes.NewBuffer(c2sSaved),
}
if hs.clientHello, err = hs.c.readClientHello(context.Background()); err != nil {
break
}
if hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] {
break break
} }
for i, keyShare := range hs.clientHello.keyShares { for i, keyShare := range hs.clientHello.keyShares {
@ -228,20 +205,17 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
if config.Show { if config.Show {
fmt.Printf("REALITY remoteAddr: %v\ths.c.conn == conn: %v\n", remoteAddr, hs.c.conn == conn) fmt.Printf("REALITY remoteAddr: %v\ths.c.conn == conn: %v\n", remoteAddr, hs.c.conn == conn)
} }
if hs.c.conn == conn {
done = true
}
break break
} }
mutex.Unlock() mutex.Unlock()
if !done { if hs.c.conn != conn {
io.CopyBuffer(target, underlying, buf) io.Copy(target, underlying)
} }
waitGroup.Done() waitGroup.Done()
}() }()
go func() { go func() {
done := false s2cSaved := make([]byte, 0, size)
buf := make([]byte, size) buf := make([]byte, size)
handshakeLen := 0 handshakeLen := 0
f: f:
@ -258,14 +232,10 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
} }
mutex.Lock() mutex.Lock()
s2cSaved = append(s2cSaved, buf[:n]...) s2cSaved = append(s2cSaved, buf[:n]...)
if hs.c == nil || hs.c.conn != conn { if hs.c.conn != conn {
copying = true copying = true // if the target already sent some data, just start bidirectional direct forwarding
if _, err = conn.Write(buf[:n]); err != nil {
done = true
}
break break
} }
done = true // special
if len(s2cSaved) > size { if len(s2cSaved) > size {
break break
} }
@ -349,12 +319,19 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
break break
} }
atomic.StoreUint32(&hs.c.handshakeStatus, 1) atomic.StoreUint32(&hs.c.handshakeStatus, 1)
handled = true
break break
} }
mutex.Unlock() mutex.Unlock()
if !done { if hs.c.out.handshakeLen[0] == 0 { // if the target sent an incorrect Server Hello, or before that
io.CopyBuffer(underlying, target, buf) if hs.c.conn == conn { // if we processed the Client Hello successfully but the target did not
waitGroup.Add(1)
go func() {
io.Copy(target, underlying)
waitGroup.Done()
}()
}
conn.Write(s2cSaved)
io.Copy(underlying, target)
} }
waitGroup.Done() waitGroup.Done()
}() }()
@ -362,13 +339,13 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
waitGroup.Wait() waitGroup.Wait()
target.Close() target.Close()
if config.Show { if config.Show {
fmt.Printf("REALITY remoteAddr: %v\thandled: %v\n", remoteAddr, handled) fmt.Printf("REALITY remoteAddr: %v\ths.c.handshakeStatus: %v\n", remoteAddr, atomic.LoadUint32(&hs.c.handshakeStatus))
} }
if handled { if atomic.LoadUint32(&hs.c.handshakeStatus) == 1 {
return hs.c, nil return hs.c, nil
} }
conn.Close() conn.Close()
return nil, errors.New("REALITY: processed invalid connection") return nil, errors.New("REALITY: processed invalid connection") // TODO: Add details.
/* /*
c := &Conn{ c := &Conn{