diff --git a/checks/checks.go b/checks/checks.go index ebc64d5..eaab3af 100644 --- a/checks/checks.go +++ b/checks/checks.go @@ -12,7 +12,7 @@ type Checks struct { BlockList *BlockList Carbon *Carbon Headers *Headers - IpAddress *Ip + IpAddress *NetIp LegacyRank *LegacyRank LinkedPages *LinkedPages Rank *Rank @@ -28,7 +28,7 @@ func NewChecks() *Checks { BlockList: NewBlockList(&ip.NetDNSLookup{}), Carbon: NewCarbon(client), Headers: NewHeaders(client), - IpAddress: NewIp(NewNetIp()), + IpAddress: NewNetIp(&ip.NetLookup{}), LegacyRank: NewLegacyRank(legacyrank.NewInMemoryStore()), LinkedPages: NewLinkedPages(client), Rank: NewRank(client), diff --git a/checks/getIP.go b/checks/getIP.go index 2df6c66..0adc766 100644 --- a/checks/getIP.go +++ b/checks/getIP.go @@ -3,6 +3,8 @@ package checks import ( "context" "net" + + "github.com/xray-web/web-check-api/checks/clients/ip" ) type IpAddress struct { @@ -10,32 +12,21 @@ type IpAddress struct { Family int `json:"family"` } -type IpGetter interface { - GetIp(ctx context.Context, host string) ([]IpAddress, error) +type NetIp struct { + lookup ip.Lookup } -type IpGetterFunc func(ctx context.Context, host string) ([]IpAddress, error) - -func (f IpGetterFunc) GetIp(ctx context.Context, host string) ([]IpAddress, error) { - return f(ctx, host) -} - -type NetIp struct{} - -func NewNetIp() *NetIp { - return &NetIp{} +func NewNetIp(lookup ip.Lookup) *NetIp { + return &NetIp{lookup: lookup} } func (l *NetIp) GetIp(ctx context.Context, host string) ([]IpAddress, error) { - resolver := &net.Resolver{ - PreferGo: true, - } - ip4, err := resolver.LookupIP(ctx, "ip4", host) + ip4, err := l.lookup.LookupIP(ctx, "ip4", host) if err != nil { - return nil, err + // do nothing } - ip6, err := resolver.LookupIP(ctx, "ip6", host) - if err != nil { + ip6, err := l.lookup.LookupIP(ctx, "ip6", host) + if err != nil && len(ip4) == 0 && len(ip6) == 0 { return nil, err } @@ -49,15 +40,3 @@ func (l *NetIp) GetIp(ctx context.Context, host string) ([]IpAddress, error) { return ipAddresses, nil } - -type Ip struct { - getter IpGetter -} - -func NewIp(l IpGetter) *Ip { - return &Ip{getter: l} -} - -func (i *Ip) Lookup(ctx context.Context, host string) ([]IpAddress, error) { - return i.getter.GetIp(ctx, host) -} diff --git a/checks/getIP_test.go b/checks/getIP_test.go index de99db0..023e4e8 100644 --- a/checks/getIP_test.go +++ b/checks/getIP_test.go @@ -6,24 +6,17 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/xray-web/web-check-api/checks/clients/ip" ) func TestLookup(t *testing.T) { t.Parallel() - ipAddresses := []IpAddress{ - {net.ParseIP("216.58.201.110"), 4}, - {net.ParseIP("2a00:1450:4009:826::200e"), 6}, - } - i := NewIp(IpGetterFunc(func(ctx context.Context, host string) ([]IpAddress, error) { - return ipAddresses, nil + n := NewNetIp(ip.LookupFunc(func(ctx context.Context, network string, host string) ([]net.IP, error) { + return []net.IP{net.ParseIP("216.58.201.110")}, nil })) - actual, err := i.Lookup(context.Background(), "google.com") + actual, err := n.GetIp(context.Background(), "google.com") assert.NoError(t, err) - assert.Equal(t, ipAddresses[0].Address, actual[0].Address) - assert.Equal(t, 4, actual[0].Family) - - assert.Equal(t, ipAddresses[1].Address, actual[1].Address) - assert.Equal(t, 6, actual[1].Family) + assert.Contains(t, actual, IpAddress{Address: net.ParseIP("216.58.201.110"), Family: 4}) } diff --git a/handlers/getIP.go b/handlers/getIP.go index 52aaa8c..1eea76a 100644 --- a/handlers/getIP.go +++ b/handlers/getIP.go @@ -6,7 +6,7 @@ import ( "github.com/xray-web/web-check-api/checks" ) -func HandleGetIP(i *checks.Ip) http.Handler { +func HandleGetIP(i *checks.NetIp) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rawURL, err := extractURL(r) if err != nil { @@ -14,7 +14,7 @@ func HandleGetIP(i *checks.Ip) http.Handler { return } - result, err := i.Lookup(r.Context(), rawURL.Hostname()) + result, err := i.GetIp(r.Context(), rawURL.Hostname()) if err != nil { JSONError(w, err, http.StatusInternalServerError) return