diff --git a/common/dialer/default.go b/common/dialer/default.go index 77536c43db..50244ac55d 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -35,7 +35,6 @@ type DefaultDialer struct { udpListener net.ListenConfig udpAddr4 string udpAddr6 string - isWireGuardListener bool networkManager adapter.NetworkManager networkStrategy *C.NetworkStrategy defaultNetworkStrategy bool @@ -183,11 +182,6 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial } setMultiPathTCP(&dialer4) } - if options.IsWireGuardListener { - for _, controlFn := range WgControlFns { - listener.Control = control.Append(listener.Control, controlFn) - } - } tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen) if err != nil { return nil, err @@ -204,7 +198,6 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial udpListener: listener, udpAddr4: udpAddr4, udpAddr6: udpAddr6, - isWireGuardListener: options.IsWireGuardListener, networkManager: networkManager, networkStrategy: networkStrategy, defaultNetworkStrategy: defaultNetworkStrategy, diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go index 93803fdb47..812e557557 100644 --- a/common/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -16,59 +16,82 @@ import ( "github.com/sagernet/sing/service" ) +type Options struct { + Context context.Context + Options option.DialerOptions + RemoteIsDomain bool + DirectResolver bool + ResolverOnDetour bool +} + +// TODO: merge with NewWithOptions func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) { - if options.IsWireGuardListener { - return NewDefault(ctx, options) - } + return NewWithOptions(Options{ + Context: ctx, + Options: options, + RemoteIsDomain: remoteIsDomain, + }) +} + +func NewWithOptions(options Options) (N.Dialer, error) { + dialOptions := options.Options var ( dialer N.Dialer err error ) - if options.Detour != "" { - outboundManager := service.FromContext[adapter.OutboundManager](ctx) + if dialOptions.Detour != "" { + outboundManager := service.FromContext[adapter.OutboundManager](options.Context) if outboundManager == nil { return nil, E.New("missing outbound manager") } - dialer = NewDetour(outboundManager, options.Detour) + dialer = NewDetour(outboundManager, dialOptions.Detour) } else { - dialer, err = NewDefault(ctx, options) + dialer, err = NewDefault(options.Context, dialOptions) if err != nil { return nil, err } } - if remoteIsDomain && options.Detour == "" { - networkManager := service.FromContext[adapter.NetworkManager](ctx) - dnsTransport := service.FromContext[adapter.DNSTransportManager](ctx) + if options.RemoteIsDomain && (dialOptions.Detour == "" || options.ResolverOnDetour) { + networkManager := service.FromContext[adapter.NetworkManager](options.Context) + dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context) var defaultOptions adapter.NetworkOptions if networkManager != nil { defaultOptions = networkManager.DefaultOptions() } var ( + server string dnsQueryOptions adapter.DNSQueryOptions resolveFallbackDelay time.Duration ) - if options.DomainResolver != nil && options.DomainResolver.Server != "" { - transport, loaded := dnsTransport.Transport(options.DomainResolver.Server) - if !loaded { - return nil, E.New("domain resolver not found: " + options.DomainResolver.Server) + if dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "" { + var transport adapter.DNSTransport + if !options.DirectResolver { + var loaded bool + transport, loaded = dnsTransport.Transport(dialOptions.DomainResolver.Server) + if !loaded { + return nil, E.New("domain resolver not found: " + dialOptions.DomainResolver.Server) + } } var strategy C.DomainStrategy - if options.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) { - strategy = C.DomainStrategy(options.DomainResolver.Strategy) + if dialOptions.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) { + strategy = C.DomainStrategy(dialOptions.DomainResolver.Strategy) } else if //nolint:staticcheck - options.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) { + dialOptions.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) { //nolint:staticcheck - strategy = C.DomainStrategy(options.DomainStrategy) + strategy = C.DomainStrategy(dialOptions.DomainStrategy) } + server = dialOptions.DomainResolver.Server dnsQueryOptions = adapter.DNSQueryOptions{ Transport: transport, Strategy: strategy, - DisableCache: options.DomainResolver.DisableCache, - RewriteTTL: options.DomainResolver.RewriteTTL, - ClientSubnet: options.DomainResolver.ClientSubnet.Build(netip.Prefix{}), + DisableCache: dialOptions.DomainResolver.DisableCache, + RewriteTTL: dialOptions.DomainResolver.RewriteTTL, + ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}), } - resolveFallbackDelay = time.Duration(options.FallbackDelay) + resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay) + } else if options.DirectResolver { + return nil, E.New("missing domain resolver for domain server address") } else if defaultOptions.DomainResolver != "" { dnsQueryOptions = defaultOptions.DomainResolveOptions transport, loaded := dnsTransport.Transport(defaultOptions.DomainResolver) @@ -76,68 +99,15 @@ func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) return nil, E.New("default domain resolver not found: " + defaultOptions.DomainResolver) } dnsQueryOptions.Transport = transport - resolveFallbackDelay = time.Duration(options.FallbackDelay) + resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay) } else { - deprecated.Report(ctx, deprecated.OptionMissingDomainResolver) - } - dialer = NewResolveDialer( - ctx, - dialer, - options.Detour == "" && !options.TCPFastOpen, - "", - dnsQueryOptions, - resolveFallbackDelay, - ) - } - return dialer, nil -} - -func NewDNS(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) { - var ( - dialer N.Dialer - err error - ) - if options.Detour != "" { - outboundManager := service.FromContext[adapter.OutboundManager](ctx) - if outboundManager == nil { - return nil, E.New("missing outbound manager") - } - dialer = NewDetour(outboundManager, options.Detour) - } else { - dialer, err = NewDefault(ctx, options) - if err != nil { - return nil, err - } - } - if remoteIsDomain { - var ( - dnsQueryOptions adapter.DNSQueryOptions - resolveFallbackDelay time.Duration - ) - if options.DomainResolver == nil || options.DomainResolver.Server == "" { - return nil, E.New("missing domain resolver for domain server address") - } - var strategy C.DomainStrategy - if options.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) { - strategy = C.DomainStrategy(options.DomainResolver.Strategy) - } else if - //nolint:staticcheck - options.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) { - //nolint:staticcheck - strategy = C.DomainStrategy(options.DomainStrategy) - } - dnsQueryOptions = adapter.DNSQueryOptions{ - Strategy: strategy, - DisableCache: options.DomainResolver.DisableCache, - RewriteTTL: options.DomainResolver.RewriteTTL, - ClientSubnet: options.DomainResolver.ClientSubnet.Build(netip.Prefix{}), + deprecated.Report(options.Context, deprecated.OptionMissingDomainResolver) } - resolveFallbackDelay = time.Duration(options.FallbackDelay) dialer = NewResolveDialer( - ctx, + options.Context, dialer, - options.Detour == "" && !options.TCPFastOpen, - options.DomainResolver.Server, + dialOptions.Detour == "" && !dialOptions.TCPFastOpen, + server, dnsQueryOptions, resolveFallbackDelay, ) diff --git a/dns/transport_dialer.go b/dns/transport_dialer.go index 5fe2949d6f..0b15c7eaab 100644 --- a/dns/transport_dialer.go +++ b/dns/transport_dialer.go @@ -19,7 +19,11 @@ func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) ( if options.LegacyDefaultDialer { return dialer.NewDefaultOutbound(ctx), nil } else { - return dialer.NewDNS(ctx, options.DialerOptions, false) + return dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: options.DialerOptions, + DirectResolver: true, + }) } } @@ -38,7 +42,12 @@ func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) } return transportDialer, nil } else { - return dialer.NewDNS(ctx, options.DialerOptions, options.ServerIsDomain()) + return dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: options.DialerOptions, + RemoteIsDomain: options.ServerIsDomain(), + DirectResolver: true, + }) } } diff --git a/protocol/wireguard/endpoint.go b/protocol/wireguard/endpoint.go index 0485d63b66..e167bec166 100644 --- a/protocol/wireguard/endpoint.go +++ b/protocol/wireguard/endpoint.go @@ -53,7 +53,14 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL if options.Detour == "" { options.IsWireGuardListener = true } - outboundDialer, err := dialer.New(ctx, options.DialerOptions, false) + outboundDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: options.DialerOptions, + RemoteIsDomain: common.Any(options.Peers, func(it option.WireGuardPeer) bool { + return !M.ParseAddr(it.Address).IsValid() + }), + ResolverOnDetour: true, + }) if err != nil { return nil, err } diff --git a/protocol/wireguard/outbound.go b/protocol/wireguard/outbound.go index d4eea10bbf..edd8184c03 100644 --- a/protocol/wireguard/outbound.go +++ b/protocol/wireguard/outbound.go @@ -56,7 +56,14 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL } else if options.GSO { return nil, E.New("gso is conflict with detour") } - outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain()) + outboundDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: options.DialerOptions, + RemoteIsDomain: options.ServerIsDomain() || common.Any(options.Peers, func(it option.LegacyWireGuardPeer) bool { + return it.ServerIsDomain() + }), + ResolverOnDetour: true, + }) if err != nil { return nil, err }