diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 69318de8..59890f21 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -295,37 +295,11 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti } if h.senderSettings.Via != nil { - outbounds := session.OutboundsFromContext(ctx) ob := outbounds[len(outbounds)-1] - var domain string - addr := h.senderSettings.Via.AsAddress() - domain = h.senderSettings.Via.GetDomain() - switch { - case h.senderSettings.ViaCidr != "": - ob.Gateway = ParseRandomIP(addr, h.senderSettings.ViaCidr) - - case domain == "origin": - if inbound := session.InboundFromContext(ctx); inbound != nil { - if inbound.Local.IsValid() && inbound.Local.Address.Family().IsIP() { - ob.Gateway = inbound.Local.Address - errors.LogDebug(ctx, "use inbound local ip as sendthrough: ", inbound.Local.Address.String()) - } - } - case domain == "srcip": - if inbound := session.InboundFromContext(ctx); inbound != nil { - if inbound.Source.IsValid() && inbound.Source.Address.Family().IsIP() { - ob.Gateway = inbound.Source.Address - errors.LogDebug(ctx, "use inbound source ip as sendthrough: ", inbound.Source.Address.String()) - } - } - //case addr.Family().IsDomain(): - default: - ob.Gateway = addr - - } - + h.SetOutboundGateway(ctx, ob) } + } if conn, err := h.getUoTConnection(ctx, dest); err != os.ErrInvalid { @@ -340,6 +314,38 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti return conn, err } +func (h *Handler) SetOutboundGateway(ctx context.Context, ob *session.Outbound) { + if ob.Gateway == nil && h.senderSettings != nil && h.senderSettings.Via != nil { + var domain string + addr := h.senderSettings.Via.AsAddress() + domain = h.senderSettings.Via.GetDomain() + switch { + case h.senderSettings.ViaCidr != "": + ob.Gateway = ParseRandomIP(addr, h.senderSettings.ViaCidr) + + case domain == "origin": + if inbound := session.InboundFromContext(ctx); inbound != nil { + if inbound.Local.IsValid() && inbound.Local.Address.Family().IsIP() { + ob.Gateway = inbound.Local.Address + errors.LogDebug(ctx, "use inbound local ip as sendthrough: ", inbound.Local.Address.String()) + } + } + case domain == "srcip": + if inbound := session.InboundFromContext(ctx); inbound != nil { + if inbound.Source.IsValid() && inbound.Source.Address.Family().IsIP() { + ob.Gateway = inbound.Source.Address + errors.LogDebug(ctx, "use inbound source ip as sendthrough: ", inbound.Source.Address.String()) + } + } + //case addr.Family().IsDomain(): + default: + ob.Gateway = addr + + } + + } +} + func (h *Handler) getStatCouterConnection(conn stat.Connection) stat.Connection { if h.uplinkCounter != nil || h.downlinkCounter != nil { return &stat.CounterConnection{ diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index 75b53114..bb8877b9 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -89,6 +89,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte destination := ob.Target origTargetAddr := ob.OriginalTarget.Address + dialer.SetOutboundGateway(ctx, ob) outGateway := ob.Gateway UDPOverride := net.UDPDestination(nil, 0) if h.config.DestinationOverride != nil { @@ -481,7 +482,10 @@ func (w *NoisePacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { if err != nil { return err } - w.Writer.WriteMultiBuffer(buf.MultiBuffer{buf.FromBytes(noise)}) + err = w.Writer.WriteMultiBuffer(buf.MultiBuffer{buf.FromBytes(noise)}) + if err != nil { + return err + } if n.DelayMin != 0 || n.DelayMax != 0 { time.Sleep(time.Duration(crypto.RandBetween(int64(n.DelayMin), int64(n.DelayMax))) * time.Millisecond) diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 9a1e44c0..43c18b3e 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -26,6 +26,9 @@ type Dialer interface { // DestIpAddress returns the ip of proxy server. It is useful in case of Android client, which prepare an IP before proxy connection is established DestIpAddress() net.IP + + // SetOutboundGateway set outbound gateway + SetOutboundGateway(ctx context.Context, ob *session.Outbound) } // dialFunc is an interface to dial network connection to a specific destination.