From 48f0b2d5ed6dd36a84cdf23ed5970dd54c22d8e1 Mon Sep 17 00:00:00 2001 From: RPRX <63339210+RPRX@users.noreply.github.com> Date: Fri, 12 Jul 2024 05:55:06 +0000 Subject: [PATCH] Call `underlying.CloseWrite()` once `io.Copy()` returned (#7) Co-authored-by: Fangliding --- tls.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tls.go b/tls.go index d17750f..de06d22 100644 --- a/tls.go +++ b/tls.go @@ -40,6 +40,11 @@ import ( "golang.org/x/crypto/hkdf" ) +type CloseWriteConn interface { + net.Conn + CloseWrite() error +} + type MirrorConn struct { *sync.Mutex net.Conn @@ -125,10 +130,11 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { } } - underlying := conn - if pc, ok := underlying.(*proxyproto.Conn); ok { - underlying = pc.Raw() // for TCP splicing in io.Copy() + raw := conn + if pc, ok := conn.(*proxyproto.Conn); ok { + raw = pc.Raw() // for TCP splicing in io.Copy() } + underlying := raw.(CloseWriteConn) // *net.TCPConn or *net.UnixConn mutex := new(sync.Mutex) @@ -334,6 +340,10 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { } conn.Write(s2cSaved) io.Copy(underlying, target) + // Here is bidirectional direct forwarding: + // client ---underlying--- server ---target--- dest + // Call `underlying.CloseWrite()` once `io.Copy()` returned + underlying.CloseWrite() } waitGroup.Done() }()