mirror of
https://github.com/XTLS/REALITY.git
synced 2025-08-27 08:55:32 +00:00
Allow fragmented REALITY Client Hello & Simplify logic
It's mainly for defending against certain attacks.
This commit is contained in:
parent
e07c3b04b9
commit
e426190d57
129
tls.go
129
tls.go
@ -41,31 +41,43 @@ import (
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
type WeakConn struct {
|
||||
type MirrorConn struct {
|
||||
*sync.Mutex
|
||||
net.Conn
|
||||
Target net.Conn
|
||||
}
|
||||
|
||||
func (c *WeakConn) Read(b []byte) (int, error) {
|
||||
return 0, fmt.Errorf("Read(%v)", len(b))
|
||||
func (c *MirrorConn) Read(b []byte) (int, error) {
|
||||
c.Unlock()
|
||||
runtime.Gosched()
|
||||
n, err := c.Conn.Read(b)
|
||||
c.Lock() // calling c.Lock() before c.Target.Write(), to make sure that this goroutine has the priority to make the next move
|
||||
if n != 0 {
|
||||
c.Target.Write(b[:n])
|
||||
}
|
||||
if err != nil {
|
||||
c.Target.Close()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *WeakConn) Write(b []byte) (int, error) {
|
||||
func (c *MirrorConn) Write(b []byte) (int, error) {
|
||||
return 0, fmt.Errorf("Write(%v)", len(b))
|
||||
}
|
||||
|
||||
func (c *WeakConn) Close() error {
|
||||
func (c *MirrorConn) Close() error {
|
||||
return fmt.Errorf("Close()")
|
||||
}
|
||||
|
||||
func (c *WeakConn) SetDeadline(t time.Time) error {
|
||||
func (c *MirrorConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeakConn) SetReadDeadline(t time.Time) error {
|
||||
func (c *MirrorConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeakConn) SetWriteDeadline(t time.Time) error {
|
||||
func (c *MirrorConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -116,68 +128,33 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
|
||||
|
||||
underlying := conn
|
||||
if pc, ok := underlying.(*proxyproto.Conn); ok {
|
||||
underlying = pc.Raw()
|
||||
underlying = pc.Raw() // for TCP splicing in io.Copy()
|
||||
}
|
||||
|
||||
hs := serverHandshakeStateTLS13{ctx: context.Background()}
|
||||
mutex := new(sync.Mutex)
|
||||
|
||||
c2sSaved := make([]byte, 0, size)
|
||||
s2cSaved := make([]byte, 0, size)
|
||||
hs := serverHandshakeStateTLS13{
|
||||
c: &Conn{
|
||||
conn: &MirrorConn{
|
||||
Mutex: mutex,
|
||||
Conn: conn,
|
||||
Target: target,
|
||||
},
|
||||
config: config,
|
||||
},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
|
||||
copying := false
|
||||
handled := false
|
||||
|
||||
waitGroup := new(sync.WaitGroup)
|
||||
waitGroup.Add(2)
|
||||
|
||||
mutex := new(sync.Mutex)
|
||||
|
||||
go func() {
|
||||
done := false
|
||||
buf := make([]byte, size)
|
||||
clientHelloLen := 0
|
||||
for {
|
||||
runtime.Gosched()
|
||||
n, err := conn.Read(buf)
|
||||
if n == 0 {
|
||||
if err != nil {
|
||||
target.Close()
|
||||
waitGroup.Done()
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
mutex.Lock()
|
||||
c2sSaved = append(c2sSaved, buf[:n]...)
|
||||
if _, err = target.Write(buf[:n]); err != nil {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
if len(c2sSaved) > size || copying { // too long; follow
|
||||
break
|
||||
}
|
||||
if clientHelloLen == 0 && len(c2sSaved) > recordHeaderLen {
|
||||
if recordType(c2sSaved[0]) != recordTypeHandshake || Value(c2sSaved[1:3]...) != VersionTLS10 || c2sSaved[5] != typeClientHello {
|
||||
break
|
||||
}
|
||||
clientHelloLen = recordHeaderLen + Value(c2sSaved[3:5]...)
|
||||
}
|
||||
if clientHelloLen > size { // too long
|
||||
break
|
||||
}
|
||||
if clientHelloLen == 0 || len(c2sSaved) < clientHelloLen {
|
||||
mutex.Unlock()
|
||||
continue
|
||||
}
|
||||
hs.c = &Conn{
|
||||
conn: &WeakConn{conn},
|
||||
config: config,
|
||||
rawInput: *bytes.NewBuffer(c2sSaved),
|
||||
}
|
||||
if hs.clientHello, err = hs.c.readClientHello(context.Background()); err != nil {
|
||||
break
|
||||
}
|
||||
if hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] {
|
||||
hs.clientHello, err = hs.c.readClientHello(context.Background()) // TODO: Change some rules in this function.
|
||||
if copying || err != nil || hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] {
|
||||
break
|
||||
}
|
||||
for i, keyShare := range hs.clientHello.keyShares {
|
||||
@ -228,20 +205,17 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
|
||||
if config.Show {
|
||||
fmt.Printf("REALITY remoteAddr: %v\ths.c.conn == conn: %v\n", remoteAddr, hs.c.conn == conn)
|
||||
}
|
||||
if hs.c.conn == conn {
|
||||
done = true
|
||||
}
|
||||
break
|
||||
}
|
||||
mutex.Unlock()
|
||||
if !done {
|
||||
io.CopyBuffer(target, underlying, buf)
|
||||
if hs.c.conn != conn {
|
||||
io.Copy(target, underlying)
|
||||
}
|
||||
waitGroup.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
done := false
|
||||
s2cSaved := make([]byte, 0, size)
|
||||
buf := make([]byte, size)
|
||||
handshakeLen := 0
|
||||
f:
|
||||
@ -258,14 +232,10 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
|
||||
}
|
||||
mutex.Lock()
|
||||
s2cSaved = append(s2cSaved, buf[:n]...)
|
||||
if hs.c == nil || hs.c.conn != conn {
|
||||
copying = true
|
||||
if _, err = conn.Write(buf[:n]); err != nil {
|
||||
done = true
|
||||
}
|
||||
if hs.c.conn != conn {
|
||||
copying = true // if the target already sent some data, just start bidirectional direct forwarding
|
||||
break
|
||||
}
|
||||
done = true // special
|
||||
if len(s2cSaved) > size {
|
||||
break
|
||||
}
|
||||
@ -349,12 +319,19 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
|
||||
break
|
||||
}
|
||||
atomic.StoreUint32(&hs.c.handshakeStatus, 1)
|
||||
handled = true
|
||||
break
|
||||
}
|
||||
mutex.Unlock()
|
||||
if !done {
|
||||
io.CopyBuffer(underlying, target, buf)
|
||||
if hs.c.out.handshakeLen[0] == 0 { // if the target sent an incorrect Server Hello, or before that
|
||||
if hs.c.conn == conn { // if we processed the Client Hello successfully but the target did not
|
||||
waitGroup.Add(1)
|
||||
go func() {
|
||||
io.Copy(target, underlying)
|
||||
waitGroup.Done()
|
||||
}()
|
||||
}
|
||||
conn.Write(s2cSaved)
|
||||
io.Copy(underlying, target)
|
||||
}
|
||||
waitGroup.Done()
|
||||
}()
|
||||
@ -362,13 +339,13 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
|
||||
waitGroup.Wait()
|
||||
target.Close()
|
||||
if config.Show {
|
||||
fmt.Printf("REALITY remoteAddr: %v\thandled: %v\n", remoteAddr, handled)
|
||||
fmt.Printf("REALITY remoteAddr: %v\ths.c.handshakeStatus: %v\n", remoteAddr, atomic.LoadUint32(&hs.c.handshakeStatus))
|
||||
}
|
||||
if handled {
|
||||
if atomic.LoadUint32(&hs.c.handshakeStatus) == 1 {
|
||||
return hs.c, nil
|
||||
}
|
||||
conn.Close()
|
||||
return nil, errors.New("REALITY: processed invalid connection")
|
||||
return nil, errors.New("REALITY: processed invalid connection") // TODO: Add details.
|
||||
|
||||
/*
|
||||
c := &Conn{
|
||||
|
Loading…
Reference in New Issue
Block a user