From fbe8ffbb69700d799cb544fbd29d0c39d1edfa51 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Sun, 10 Mar 2024 23:12:31 +0800 Subject: [PATCH] move hysteria2 server address fetch outside --- hysteria2/client.go | 37 ++++++++++++++++++++----------------- quic.go | 4 ++-- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/hysteria2/client.go b/hysteria2/client.go index f177712..bd415fc 100644 --- a/hysteria2/client.go +++ b/hysteria2/client.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "io" - "math/rand" "net" "net/http" "net/url" @@ -18,7 +17,6 @@ import ( qtls "github.com/metacubex/sing-quic" hyCC "github.com/metacubex/sing-quic/hysteria2/congestion" "github.com/metacubex/sing-quic/hysteria2/internal/protocol" - "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/baderror" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" @@ -31,8 +29,7 @@ type ClientOptions struct { Dialer N.Dialer Logger logger.Logger BrutalDebug bool - ServerAddress M.Socksaddr - ServerAddresses []M.Socksaddr + ServerAddress func(ctx context.Context) (*net.UDPAddr, error) HopInterval time.Duration SendBPS uint64 ReceiveBPS uint64 @@ -49,8 +46,7 @@ type Client struct { dialer N.Dialer logger logger.Logger brutalDebug bool - serverAddr atomic.TypedValue[M.Socksaddr] - serverAddrs []M.Socksaddr + serverAddress func(ctx context.Context) (*net.UDPAddr, error) hopInterval time.Duration sendBPS uint64 receiveBPS uint64 @@ -85,7 +81,7 @@ func NewClient(options ClientOptions) (*Client, error) { dialer: options.Dialer, logger: options.Logger, brutalDebug: options.BrutalDebug, - serverAddrs: options.ServerAddresses, + serverAddress: options.ServerAddress, hopInterval: options.HopInterval, sendBPS: options.SendBPS, receiveBPS: options.ReceiveBPS, @@ -97,27 +93,31 @@ func NewClient(options ClientOptions) (*Client, error) { cwnd: options.CWND, udpMTU: options.UdpMTU, } - client.serverAddr.Store(options.ServerAddress) return client, nil } func (c *Client) hopLoop(conn *clientQUICConnection) { ticker := time.NewTicker(c.hopInterval) defer ticker.Stop() - c.logger.Info("Entering hop loop ...") + c.logger.Debug("Entering hop loop ...") for { select { case <-ticker.C: - serverAddr := c.serverAddrs[rand.Intn(len(c.serverAddrs))] - c.serverAddr.Store(serverAddr) - conn.quicConn.SetRemoteAddr(serverAddr.UDPAddr()) - c.logger.Info("Hopped to ", serverAddr) + ctx, cancel := context.WithTimeout(context.Background(), c.hopInterval) + serverAddr, err := c.serverAddress(ctx) + cancel() + if err != nil { + c.logger.Warn("Hop loop fetch serverAddress error: '%s', ignored", err) + break + } + conn.quicConn.SetRemoteAddr(serverAddr) + c.logger.Debug("Hopped to ", serverAddr) continue case <-c.ctx.Done(): case <-conn.quicConn.Context().Done(): case <-conn.connDone: } - c.logger.Info("Exiting hop loop ...") + c.logger.Debug("Exiting hop loop ...") return } } @@ -141,8 +141,11 @@ func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) { } func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { - serverAddr := c.serverAddr.Load() - packetConn, err := c.dialer.ListenPacket(ctx, serverAddr) + serverAddr, err := c.serverAddress(ctx) + if err == nil { + return nil, err + } + packetConn, err := c.dialer.ListenPacket(ctx, M.SocksaddrFromNet(serverAddr)) if err != nil { return nil, err } @@ -202,7 +205,7 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { go c.loopMessages(conn) } c.conn = conn - if len(c.serverAddrs) > 0 { + if c.hopInterval > 0 { go c.hopLoop(conn) } return conn, nil diff --git a/quic.go b/quic.go index 5fdc5fc..ad06daf 100644 --- a/quic.go +++ b/quic.go @@ -43,13 +43,13 @@ func DialEarly(ctx context.Context, conn net.PacketConn, addr net.Addr, tlsConfi return quic.DialEarly(ctx, conn, addr, tlsConfig, quicConfig) } -func CreateTransport(conn net.PacketConn, quicConnPtr *quic.EarlyConnection, serverAddr M.Socksaddr, tlsConfig *tls.Config, quicConfig *quic.Config, enableDatagrams bool) (http.RoundTripper, error) { +func CreateTransport(conn net.PacketConn, quicConnPtr *quic.EarlyConnection, serverAddr *net.UDPAddr, tlsConfig *tls.Config, quicConfig *quic.Config, enableDatagrams bool) (http.RoundTripper, error) { return &http3.RoundTripper{ TLSClientConfig: tlsConfig, QuicConfig: quicConfig, EnableDatagrams: enableDatagrams, Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - quicConn, err := quic.DialEarly(ctx, conn, serverAddr.UDPAddr(), tlsCfg, cfg) + quicConn, err := quic.DialEarly(ctx, conn, serverAddr, tlsCfg, cfg) if err != nil { return nil, err }