diff --git a/common.go b/common.go index a8b54c8..bca6d04 100644 --- a/common.go +++ b/common.go @@ -515,6 +515,8 @@ const ( // modified. A Config may be reused; the tls package will also not // modify it. type Config struct { + DialContext func(ctx context.Context, network, address string) (net.Conn, error) + Show bool Type string Dest string diff --git a/tls.go b/tls.go index 82bb320..d414a7e 100644 --- a/tls.go +++ b/tls.go @@ -115,13 +115,13 @@ func Value(vals ...byte) (value int) { // using conn as the underlying transport. // The configuration config must be non-nil and must include // at least one certificate or else set GetCertificate. -func Server(conn net.Conn, config *Config) (*Conn, error) { +func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { remoteAddr := conn.RemoteAddr().String() if config.Show { fmt.Printf("REALITY remoteAddr: %v\n", remoteAddr) } - target, err := net.Dial(config.Type, config.Dest) + target, err := config.DialContext(ctx, config.Type, config.Dest) if err != nil { conn.Close() return nil, errors.New("REALITY: failed to dial dest: " + err.Error()) @@ -140,7 +140,7 @@ func Server(conn net.Conn, config *Config) (*Conn, error) { underlying = pc.Raw() } - hs := serverHandshakeStateTLS13{ctx: context.TODO()} + hs := serverHandshakeStateTLS13{ctx: context.Background()} c2sSaved := make([]byte, 0, size) s2cSaved := make([]byte, 0, size) @@ -201,7 +201,7 @@ func Server(conn net.Conn, config *Config) (*Conn, error) { conn: readerConn, config: config, } - hs.clientHello, err = hs.c.readClientHello(context.TODO()) + hs.clientHello, err = hs.c.readClientHello(context.Background()) if err != nil || readerConn.Reader.Len() > 0 || readerConn.Written > 0 || readerConn.Closed { break } @@ -421,7 +421,7 @@ func (l *listener) Accept() (net.Conn, error) { if err != nil { return nil, err } - return Server(c, l.config) + return Server(context.Background(), c, l.config) } // NewListener creates a Listener which accepts connections from an inner