0
0
mirror of https://github.com/XTLS/Xray-core.git synced 2025-08-22 22:48:35 +00:00

Wireguard inbound: Fix context sharing problem (#4988)

* Try fix Wireguard inbound context sharing problem

* Shallow copy inbound and content

* Fix context passing

* Add notes for source address
This commit is contained in:
yuhan6665 2025-08-17 10:56:48 -04:00 committed by GitHub
parent 105b306d07
commit 337b4b814e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 47 additions and 61 deletions

View File

@ -118,9 +118,7 @@ func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.Bu
} }
func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error { func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error {
// deep-clone outbounds because it is going to be mutated concurrently ctx = session.SubContextFromMuxInbound(ctx)
// (Target and OriginalTarget)
ctx = session.ContextCloneOutboundsAndContent(ctx)
errors.LogInfo(ctx, "received request for ", meta.Target) errors.LogInfo(ctx, "received request for ", meta.Target)
{ {
msg := &log.AccessMessage{ msg := &log.AccessMessage{

View File

@ -16,15 +16,15 @@ const (
inboundSessionKey ctx.SessionKey = 1 inboundSessionKey ctx.SessionKey = 1
outboundSessionKey ctx.SessionKey = 2 outboundSessionKey ctx.SessionKey = 2
contentSessionKey ctx.SessionKey = 3 contentSessionKey ctx.SessionKey = 3
muxPreferredSessionKey ctx.SessionKey = 4 muxPreferredSessionKey ctx.SessionKey = 4 // unused
sockoptSessionKey ctx.SessionKey = 5 sockoptSessionKey ctx.SessionKey = 5 // used by dokodemo to only receive sockopt.Mark
trackedConnectionErrorKey ctx.SessionKey = 6 trackedConnectionErrorKey ctx.SessionKey = 6 // used by observer to get outbound error
dispatcherKey ctx.SessionKey = 7 dispatcherKey ctx.SessionKey = 7 // used by ss2022 inbounds to get dispatcher
timeoutOnlyKey ctx.SessionKey = 8 timeoutOnlyKey ctx.SessionKey = 8 // mux context's child contexts to only cancel when its own traffic times out
allowedNetworkKey ctx.SessionKey = 9 allowedNetworkKey ctx.SessionKey = 9 // muxcool server control incoming request tcp/udp
handlerSessionKey ctx.SessionKey = 10 handlerSessionKey ctx.SessionKey = 10 // unused
mitmAlpn11Key ctx.SessionKey = 11 mitmAlpn11Key ctx.SessionKey = 11 // used by TLS dialer
mitmServerNameKey ctx.SessionKey = 12 mitmServerNameKey ctx.SessionKey = 12 // used by TLS dialer
) )
func ContextWithInbound(ctx context.Context, inbound *Inbound) context.Context { func ContextWithInbound(ctx context.Context, inbound *Inbound) context.Context {
@ -42,18 +42,8 @@ func ContextWithOutbounds(ctx context.Context, outbounds []*Outbound) context.Co
return context.WithValue(ctx, outboundSessionKey, outbounds) return context.WithValue(ctx, outboundSessionKey, outbounds)
} }
func ContextCloneOutboundsAndContent(ctx context.Context) context.Context { func SubContextFromMuxInbound(ctx context.Context) context.Context {
outbounds := OutboundsFromContext(ctx) newOutbounds := []*Outbound{{}}
newOutbounds := make([]*Outbound, len(outbounds))
for i, ob := range outbounds {
if ob == nil {
continue
}
// copy outbound by value
v := *ob
newOutbounds[i] = &v
}
content := ContentFromContext(ctx) content := ContentFromContext(ctx)
newContent := Content{} newContent := Content{}

View File

@ -48,9 +48,9 @@ type Inbound struct {
User *protocol.MemoryUser User *protocol.MemoryUser
// VlessRoute is the user-sent VLESS UUID's last byte. // VlessRoute is the user-sent VLESS UUID's last byte.
VlessRoute net.Port VlessRoute net.Port
// Conn is actually internet.Connection. May be nil. // Used by splice copy. Conn is actually internet.Connection. May be nil.
Conn net.Conn Conn net.Conn
// Timer of the inbound buf copier. May be nil. // Used by splice copy. Timer of the inbound buf copier. May be nil.
Timer *signal.ActivityTimer Timer *signal.ActivityTimer
// CanSpliceCopy is a property for this connection // CanSpliceCopy is a property for this connection
// 1 = can, 2 = after processing protocol info should be able to, 3 = cannot // 1 = can, 2 = after processing protocol info should be able to, 3 = cannot
@ -69,31 +69,33 @@ type Outbound struct {
Tag string Tag string
// Name of the outbound proxy that handles the connection. // Name of the outbound proxy that handles the connection.
Name string Name string
// Conn is actually internet.Connection. May be nil. It is currently nil for outbound with proxySettings // Unused. Conn is actually internet.Connection. May be nil. It is currently nil for outbound with proxySettings
Conn net.Conn Conn net.Conn
// CanSpliceCopy is a property for this connection // CanSpliceCopy is a property for this connection
// 1 = can, 2 = after processing protocol info should be able to, 3 = cannot // 1 = can, 2 = after processing protocol info should be able to, 3 = cannot
CanSpliceCopy int CanSpliceCopy int
} }
// SniffingRequest controls the behavior of content sniffing. // SniffingRequest controls the behavior of content sniffing. They are from inbound config. Read-only
type SniffingRequest struct { type SniffingRequest struct {
ExcludeForDomain []string // read-only once set ExcludeForDomain []string
OverrideDestinationForProtocol []string // read-only once set OverrideDestinationForProtocol []string
Enabled bool Enabled bool
MetadataOnly bool MetadataOnly bool
RouteOnly bool RouteOnly bool
} }
// Content is the metadata of the connection content. // Content is the metadata of the connection content. Mainly used for routing.
type Content struct { type Content struct {
// Protocol of current content. // Protocol of current content.
Protocol string Protocol string
SniffingRequest SniffingRequest SniffingRequest SniffingRequest
// HTTP traffic sniffed headers
Attributes map[string]string Attributes map[string]string
// SkipDNSResolve is set from DNS module. the DOH remote server maybe a domain name, this prevents cycle resolving dead loop
SkipDNSResolve bool SkipDNSResolve bool
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
@ -33,7 +34,6 @@ type routingInfo struct {
ctx context.Context ctx context.Context
dispatcher routing.Dispatcher dispatcher routing.Dispatcher
inboundTag *session.Inbound inboundTag *session.Inbound
outboundTag *session.Outbound
contentTag *session.Content contentTag *session.Content
} }
@ -78,18 +78,11 @@ func (*Server) Network() []net.Network {
// Process implements proxy.Inbound. // Process implements proxy.Inbound.
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx)
inbound.Name = "wireguard"
inbound.CanSpliceCopy = 3
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds)-1]
s.info = routingInfo{ s.info = routingInfo{
ctx: core.ToBackgroundDetachedContext(ctx), ctx: ctx,
dispatcher: dispatcher, dispatcher: dispatcher,
inboundTag: session.InboundFromContext(ctx), inboundTag: session.InboundFromContext(ctx),
outboundTag: ob, contentTag: session.ContentFromContext(ctx),
contentTag: session.ContentFromContext(ctx),
} }
ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
@ -134,6 +127,25 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
defer conn.Close() defer conn.Close()
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
sid := session.NewID()
ctx = c.ContextWithID(ctx, sid)
inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs)
if s.info.inboundTag != nil {
inbound = *s.info.inboundTag
}
inbound.Name = "wireguard"
inbound.CanSpliceCopy = 3
// overwrite the source to use the tun address for each sub context.
// Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context
// Currently we have no way to link to the original source address
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
ctx = session.ContextWithInbound(ctx, &inbound)
if s.info.contentTag != nil {
ctx = session.ContextWithContent(ctx, s.info.contentTag)
}
ctx = session.SubContextFromMuxInbound(ctx)
plcy := s.policyManager.ForLevel(0) plcy := s.policyManager.ForLevel(0)
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle) timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
@ -144,25 +156,9 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
Reason: "", Reason: "",
}) })
if s.info.inboundTag != nil {
ctx = session.ContextWithInbound(ctx, s.info.inboundTag)
}
// what's this?
// Session information should not be shared between different connections
// why reuse them in server level? This will cause incorrect destoverride and unexpected routing behavior.
// Disable it temporarily. Maybe s.info should be removed.
// if s.info.outboundTag != nil {
// ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{s.info.outboundTag})
// }
// if s.info.contentTag != nil {
// ctx = session.ContextWithContent(ctx, s.info.contentTag)
// }
link, err := s.info.dispatcher.Dispatch(ctx, dest) link, err := s.info.dispatcher.Dispatch(ctx, dest)
if err != nil { if err != nil {
errors.LogErrorInner(s.info.ctx, err, "dispatch connection") errors.LogErrorInner(ctx, err, "dispatch connection")
} }
defer cancel() defer cancel()
@ -188,7 +184,7 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
if err := task.Run(ctx, requestDonePost, responseDone); err != nil { if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
common.Interrupt(link.Reader) common.Interrupt(link.Reader)
common.Interrupt(link.Writer) common.Interrupt(link.Writer)
errors.LogDebugInner(s.info.ctx, err, "connection ends") errors.LogDebugInner(ctx, err, "connection ends")
return return
} }
} }