diff --git a/tls.go b/tls.go index 455419d..fef2e1c 100644 --- a/tls.go +++ b/tls.go @@ -120,22 +120,19 @@ func (c *RatelimitedConn) Read(b []byte) (int, error) { } func NewRatelimitedConn(conn net.Conn, limit *LimitFallback) net.Conn { - bytesPerSec := limit.BytesPerSec - burstBytesPerSec := limit.BurstBytesPerSec - afterBytes := limit.AfterBytes - - if bytesPerSec == 0 { + if limit.BytesPerSec == 0 { return conn } - if burstBytesPerSec < bytesPerSec { - burstBytesPerSec = bytesPerSec + burstBytesPerSec := limit.BurstBytesPerSec + if burstBytesPerSec < limit.BytesPerSec { + burstBytesPerSec = limit.BytesPerSec } return &RatelimitedConn{ Conn: conn, - Bucket: ratelimit.NewBucketWithRate(float64(bytesPerSec), int64(burstBytesPerSec)), - LimitAfter: int64(afterBytes), + Bucket: ratelimit.NewBucketWithRate(float64(limit.BytesPerSec), int64(burstBytesPerSec)), + LimitAfter: int64(limit.AfterBytes), } } @@ -267,7 +264,6 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if config.Show && hs.clientHello != nil { fmt.Printf("REALITY remoteAddr: %v\tforwarded SNI: %v\n", remoteAddr, hs.clientHello.serverName) } - // Limit upload speed for fallback connection io.Copy(target, NewRatelimitedConn(underlying, &config.LimitFallbackUpload)) } waitGroup.Done() @@ -399,13 +395,11 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if hs.c.conn == conn { // if we processed the Client Hello successfully but the target did not waitGroup.Add(1) go func() { - // Limit upload speed for fallback connection (handshake ok but hello failed) io.Copy(target, NewRatelimitedConn(underlying, &config.LimitFallbackUpload)) waitGroup.Done() }() } conn.Write(s2cSaved) - // Limit download speed for fallback connection io.Copy(underlying, NewRatelimitedConn(target, &config.LimitFallbackDownload)) // Here is bidirectional direct forwarding: // client ---underlying--- server ---target--- dest