diff --git a/tls.go b/tls.go index d0eaba4..02cede4 100644 --- a/tls.go +++ b/tls.go @@ -43,7 +43,6 @@ import ( "errors" "fmt" "io" - "math" "net" "os" "runtime" @@ -120,18 +119,20 @@ func (c *RatelimitedConn) Read(b []byte) (int, error) { return n, err } -func NewBucketWithRate(bytesPerSec uint64, burstBytesPerSec uint64) *ratelimit.Bucket { +func NewRatelimitedConn(con net.Conn, config *Config) *RatelimitedConn { + bytesPerSec := config.LimitFallbackUpload.BytesPerSec + burstBytesPerSec := config.LimitFallbackUpload.BurstBytesPerSec + afterBytes := config.LimitFallbackUpload.AfterBytes + if burstBytesPerSec < bytesPerSec { burstBytesPerSec = bytesPerSec } - return ratelimit.NewBucketWithRate(float64(bytesPerSec), ToInt64(burstBytesPerSec)) -} -func ToInt64(u uint64) int64 { - if u > math.MaxInt64 { - return math.MaxInt64 + return &RatelimitedConn{ + Conn: con, + Bucket: ratelimit.NewBucketWithRate(float64(bytesPerSec), int64(burstBytesPerSec)), + LimitAfter: int64(afterBytes) - int64(burstBytesPerSec), } - return int64(u) } var ( @@ -266,11 +267,7 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { io.Copy(target, underlying) } else { // Limit upload speed for fallback connection - io.Copy(target, &RatelimitedConn{ - Conn: underlying, - Bucket: NewBucketWithRate(config.LimitFallbackUpload.BytesPerSec, config.LimitFallbackUpload.BurstBytesPerSec), - LimitAfter: ToInt64(config.LimitFallbackUpload.AfterBytes) - ToInt64(config.LimitFallbackUpload.BurstBytesPerSec), - }) + io.Copy(target, NewRatelimitedConn(underlying, config)) } } waitGroup.Done() @@ -406,11 +403,7 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { io.Copy(target, underlying) } else { // Limit upload speed for fallback connection (handshake ok but hello failed) - io.Copy(target, &RatelimitedConn{ - Conn: underlying, - Bucket: NewBucketWithRate(config.LimitFallbackUpload.BytesPerSec, config.LimitFallbackUpload.BurstBytesPerSec), - LimitAfter: ToInt64(config.LimitFallbackUpload.AfterBytes) - ToInt64(config.LimitFallbackUpload.BurstBytesPerSec), - }) + io.Copy(target, NewRatelimitedConn(underlying, config)) } waitGroup.Done() }() @@ -420,11 +413,7 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { io.Copy(underlying, target) } else { // Limit download speed for fallback connection - io.Copy(underlying, &RatelimitedConn{ - Conn: target, - Bucket: NewBucketWithRate(config.LimitFallbackDownload.BytesPerSec, config.LimitFallbackDownload.BurstBytesPerSec), - LimitAfter: ToInt64(config.LimitFallbackDownload.AfterBytes) - ToInt64(config.LimitFallbackDownload.BurstBytesPerSec), - }) + io.Copy(underlying, NewRatelimitedConn(target, config)) } // Here is bidirectional direct forwarding: // client ---underlying--- server ---target--- dest