diff --git a/common.go b/common.go index 8cd7f91..bd14b98 100644 --- a/common.go +++ b/common.go @@ -538,9 +538,9 @@ const ( ) type LimitFallback struct { - AfterBytes int64 - BytesPerSec int64 - BurstBytesPerSec int64 + AfterBytes uint64 + BytesPerSec uint64 + BurstBytesPerSec uint64 } // A Config structure is used to configure a TLS client or server. diff --git a/tls.go b/tls.go index bbe0ec4..0038fed 100644 --- a/tls.go +++ b/tls.go @@ -43,6 +43,7 @@ import ( "errors" "fmt" "io" + "math" "net" "os" "runtime" @@ -119,11 +120,18 @@ func (c *RatelimitedConn) Write(b []byte) (int, error) { return n, err } -func NewBucketWithRate(bytesPerSec int64, burstBytesPerSec int64) *ratelimit.Bucket { +func NewBucketWithRate(bytesPerSec uint64, burstBytesPerSec uint64) *ratelimit.Bucket { if burstBytesPerSec < bytesPerSec { burstBytesPerSec = bytesPerSec } - return ratelimit.NewBucketWithRate(float64(bytesPerSec), burstBytesPerSec) + return ratelimit.NewBucketWithRate(float64(bytesPerSec), ToInt64(burstBytesPerSec)) +} + +func ToInt64(u uint64) int64 { + if u > math.MaxInt64 { + return math.MaxInt64 + } + return int64(u) } var ( @@ -254,14 +262,14 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if config.Show && hs.clientHello != nil { fmt.Printf("REALITY remoteAddr: %v\tforwarded SNI: %v\n", remoteAddr, hs.clientHello.serverName) } - if config.LimitFallbackUpload.BytesPerSec == 0 || config.LimitFallbackUpload.BurstBytesPerSec == 0 { + if config.LimitFallbackUpload.BytesPerSec == 0 { io.Copy(target, underlying) } else { // Limit upload speed for fallback connection io.Copy(&RatelimitedConn{ Conn: target, Bucket: NewBucketWithRate(config.LimitFallbackUpload.BytesPerSec, config.LimitFallbackUpload.BurstBytesPerSec), - LimitAfter: config.LimitFallbackUpload.AfterBytes - config.LimitFallbackUpload.BurstBytesPerSec, + LimitAfter: ToInt64(config.LimitFallbackUpload.AfterBytes) - ToInt64(config.LimitFallbackUpload.BurstBytesPerSec), }, underlying) } } @@ -394,28 +402,28 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if hs.c.conn == conn { // if we processed the Client Hello successfully but the target did not waitGroup.Add(1) go func() { - if config.LimitFallbackUpload.BytesPerSec == 0 || config.LimitFallbackUpload.BurstBytesPerSec == 0 { + if config.LimitFallbackUpload.BytesPerSec == 0 { io.Copy(target, underlying) } else { // Limit upload speed for fallback connection (handshake ok but hello failed) io.Copy(&RatelimitedConn{ Conn: target, Bucket: NewBucketWithRate(config.LimitFallbackUpload.BytesPerSec, config.LimitFallbackUpload.BurstBytesPerSec), - LimitAfter: config.LimitFallbackUpload.AfterBytes - config.LimitFallbackUpload.BurstBytesPerSec, + LimitAfter: ToInt64(config.LimitFallbackUpload.AfterBytes) - ToInt64(config.LimitFallbackUpload.BurstBytesPerSec), }, underlying) } waitGroup.Done() }() } conn.Write(s2cSaved) - if config.LimitFallbackDownload.BytesPerSec == 0 || config.LimitFallbackDownload.BurstBytesPerSec == 0 { + if config.LimitFallbackDownload.BytesPerSec == 0 { io.Copy(underlying, target) } else { // Limit download speed for fallback connection io.Copy(&RatelimitedConn{ Conn: underlying, Bucket: NewBucketWithRate(config.LimitFallbackDownload.BytesPerSec, config.LimitFallbackDownload.BurstBytesPerSec), - LimitAfter: config.LimitFallbackDownload.AfterBytes - config.LimitFallbackDownload.BurstBytesPerSec, + LimitAfter: ToInt64(config.LimitFallbackDownload.AfterBytes) - ToInt64(config.LimitFallbackDownload.BurstBytesPerSec), }, target) } // Here is bidirectional direct forwarding: