diff --git a/tls.go b/tls.go index d17750f..de06d22 100644 --- a/tls.go +++ b/tls.go @@ -40,6 +40,11 @@ import ( "golang.org/x/crypto/hkdf" ) +type CloseWriteConn interface { + net.Conn + CloseWrite() error +} + type MirrorConn struct { *sync.Mutex net.Conn @@ -125,10 +130,11 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { } } - underlying := conn - if pc, ok := underlying.(*proxyproto.Conn); ok { - underlying = pc.Raw() // for TCP splicing in io.Copy() + raw := conn + if pc, ok := conn.(*proxyproto.Conn); ok { + raw = pc.Raw() // for TCP splicing in io.Copy() } + underlying := raw.(CloseWriteConn) // *net.TCPConn or *net.UnixConn mutex := new(sync.Mutex) @@ -334,6 +340,10 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { } conn.Write(s2cSaved) io.Copy(underlying, target) + // Here is bidirectional direct forwarding: + // client ---underlying--- server ---target--- dest + // Call `underlying.CloseWrite()` once `io.Copy()` returned + underlying.CloseWrite() } waitGroup.Done() }()