diff --git a/checks/checks.go b/checks/checks.go index 1f1cd93..6c92238 100644 --- a/checks/checks.go +++ b/checks/checks.go @@ -9,6 +9,7 @@ import ( type Checks struct { Carbon *Carbon + IpAddress *Ip LegacyRank *LegacyRank Rank *Rank SocialTags *SocialTags @@ -21,6 +22,7 @@ func NewChecks() *Checks { } return &Checks{ Carbon: NewCarbon(client), + IpAddress: NewIp(NewNetIp()), LegacyRank: NewLegacyRank(legacyrank.NewInMemoryStore()), Rank: NewRank(client), SocialTags: NewSocialTags(client), diff --git a/checks/getIP.go b/checks/getIP.go new file mode 100644 index 0000000..2df6c66 --- /dev/null +++ b/checks/getIP.go @@ -0,0 +1,63 @@ +package checks + +import ( + "context" + "net" +) + +type IpAddress struct { + Address net.IP `json:"ip"` + Family int `json:"family"` +} + +type IpGetter interface { + GetIp(ctx context.Context, host string) ([]IpAddress, error) +} + +type IpGetterFunc func(ctx context.Context, host string) ([]IpAddress, error) + +func (f IpGetterFunc) GetIp(ctx context.Context, host string) ([]IpAddress, error) { + return f(ctx, host) +} + +type NetIp struct{} + +func NewNetIp() *NetIp { + return &NetIp{} +} + +func (l *NetIp) GetIp(ctx context.Context, host string) ([]IpAddress, error) { + resolver := &net.Resolver{ + PreferGo: true, + } + ip4, err := resolver.LookupIP(ctx, "ip4", host) + if err != nil { + return nil, err + } + ip6, err := resolver.LookupIP(ctx, "ip6", host) + if err != nil { + return nil, err + } + + var ipAddresses []IpAddress + for _, ip := range ip4 { + ipAddresses = append(ipAddresses, IpAddress{Address: ip, Family: 4}) + } + for _, ip := range ip6 { + ipAddresses = append(ipAddresses, IpAddress{Address: ip, Family: 6}) + } + + return ipAddresses, nil +} + +type Ip struct { + getter IpGetter +} + +func NewIp(l IpGetter) *Ip { + return &Ip{getter: l} +} + +func (i *Ip) Lookup(ctx context.Context, host string) ([]IpAddress, error) { + return i.getter.GetIp(ctx, host) +} diff --git a/checks/getIP_test.go b/checks/getIP_test.go new file mode 100644 index 0000000..de99db0 --- /dev/null +++ b/checks/getIP_test.go @@ -0,0 +1,29 @@ +package checks + +import ( + "context" + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLookup(t *testing.T) { + t.Parallel() + + ipAddresses := []IpAddress{ + {net.ParseIP("216.58.201.110"), 4}, + {net.ParseIP("2a00:1450:4009:826::200e"), 6}, + } + i := NewIp(IpGetterFunc(func(ctx context.Context, host string) ([]IpAddress, error) { + return ipAddresses, nil + })) + actual, err := i.Lookup(context.Background(), "google.com") + assert.NoError(t, err) + + assert.Equal(t, ipAddresses[0].Address, actual[0].Address) + assert.Equal(t, 4, actual[0].Family) + + assert.Equal(t, ipAddresses[1].Address, actual[1].Address) + assert.Equal(t, 6, actual[1].Family) +} diff --git a/handlers/getIP.go b/handlers/getIP.go index 0b302be..52aaa8c 100644 --- a/handlers/getIP.go +++ b/handlers/getIP.go @@ -1,29 +1,12 @@ package handlers import ( - "net" "net/http" -) - -func lookupAsync(address string) (map[string]interface{}, error) { - ip, err := net.LookupIP(address) - if err != nil { - return nil, err - } - - result := make(map[string]interface{}) - if len(ip) > 0 { - result["ip"] = ip[0].String() - result["family"] = 4 - } else { - result["ip"] = "" - result["family"] = nil - } - return result, nil -} + "github.com/xray-web/web-check-api/checks" +) -func HandleGetIP() http.Handler { +func HandleGetIP(i *checks.Ip) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rawURL, err := extractURL(r) if err != nil { @@ -31,7 +14,7 @@ func HandleGetIP() http.Handler { return } - result, err := lookupAsync(rawURL.Hostname()) + result, err := i.Lookup(r.Context(), rawURL.Hostname()) if err != nil { JSONError(w, err, http.StatusInternalServerError) return diff --git a/handlers/getIP_test.go b/handlers/getIP_test.go index b6edb76..1967bfc 100644 --- a/handlers/getIP_test.go +++ b/handlers/getIP_test.go @@ -1,7 +1,6 @@ package handlers import ( - "encoding/json" "net/http" "net/http/httptest" "testing" @@ -10,21 +9,16 @@ import ( ) func TestHandleGetIP(t *testing.T) { - req := httptest.NewRequest("GET", "/get-ip?url=example.com", nil) - rec := httptest.NewRecorder() - HandleGetIP().ServeHTTP(rec, req) + t.Parallel() - assert.Equal(t, http.StatusOK, rec.Code) + t.Run("missing URL parameter", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodGet, "/get-ip", nil) + rec := httptest.NewRecorder() - var response map[string]interface{} - err := json.Unmarshal(rec.Body.Bytes(), &response) - assert.NoError(t, err) + HandleGetIP(nil).ServeHTTP(rec, req) - ip, ok := response["ip"].(string) - assert.True(t, ok, "IP address not found in response") - assert.NotEmpty(t, ip, "IP address is empty") - - family, ok := response["family"].(float64) - assert.True(t, ok, "Family field not found in response") - assert.Equal(t, float64(4), family, "Family field should be 4") + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.JSONEq(t, `{"error": "missing URL parameter"}`, rec.Body.String()) + }) } diff --git a/server/server.go b/server/server.go index 96ebf33..1b55c4b 100644 --- a/server/server.go +++ b/server/server.go @@ -36,7 +36,7 @@ func (s *Server) routes() { s.mux.Handle("GET /api/dns", handlers.HandleDNS()) s.mux.Handle("GET /api/dnssec", handlers.HandleDnsSec()) s.mux.Handle("GET /api/firewall", handlers.HandleFirewall()) - s.mux.Handle("GET /api/get-ip", handlers.HandleGetIP()) + s.mux.Handle("GET /api/get-ip", handlers.HandleGetIP(s.checks.IpAddress)) s.mux.Handle("GET /api/headers", handlers.HandleGetHeaders()) s.mux.Handle("GET /api/hsts", handlers.HandleHsts()) s.mux.Handle("GET /api/http-security", handlers.HandleHttpSecurity())