Skip to content

Commit

Permalink
RF: getIP (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
syywu committed Jun 14, 2024
1 parent d9e2e28 commit dd5e713
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 37 deletions.
2 changes: 2 additions & 0 deletions checks/checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

type Checks struct {
Carbon *Carbon
IpAddress *Ip
LegacyRank *LegacyRank
Rank *Rank
SocialTags *SocialTags
Expand All @@ -21,6 +22,7 @@ func NewChecks() *Checks {
}
return &Checks{
Carbon: NewCarbon(client),
IpAddress: NewIp(NewNetIp()),
LegacyRank: NewLegacyRank(legacyrank.NewInMemoryStore()),
Rank: NewRank(client),
SocialTags: NewSocialTags(client),
Expand Down
63 changes: 63 additions & 0 deletions checks/getIP.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package checks

import (
"context"
"net"
)

type IpAddress struct {
Address net.IP `json:"ip"`
Family int `json:"family"`
}

type IpGetter interface {
GetIp(ctx context.Context, host string) ([]IpAddress, error)
}

type IpGetterFunc func(ctx context.Context, host string) ([]IpAddress, error)

func (f IpGetterFunc) GetIp(ctx context.Context, host string) ([]IpAddress, error) {
return f(ctx, host)
}

type NetIp struct{}

func NewNetIp() *NetIp {
return &NetIp{}
}

func (l *NetIp) GetIp(ctx context.Context, host string) ([]IpAddress, error) {
resolver := &net.Resolver{
PreferGo: true,
}
ip4, err := resolver.LookupIP(ctx, "ip4", host)
if err != nil {
return nil, err
}
ip6, err := resolver.LookupIP(ctx, "ip6", host)
if err != nil {
return nil, err
}

var ipAddresses []IpAddress
for _, ip := range ip4 {
ipAddresses = append(ipAddresses, IpAddress{Address: ip, Family: 4})
}
for _, ip := range ip6 {
ipAddresses = append(ipAddresses, IpAddress{Address: ip, Family: 6})
}

return ipAddresses, nil
}

type Ip struct {
getter IpGetter
}

func NewIp(l IpGetter) *Ip {
return &Ip{getter: l}
}

func (i *Ip) Lookup(ctx context.Context, host string) ([]IpAddress, error) {
return i.getter.GetIp(ctx, host)
}
29 changes: 29 additions & 0 deletions checks/getIP_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package checks

import (
"context"
"net"
"testing"

"github.com/stretchr/testify/assert"
)

func TestLookup(t *testing.T) {
t.Parallel()

ipAddresses := []IpAddress{
{net.ParseIP("216.58.201.110"), 4},
{net.ParseIP("2a00:1450:4009:826::200e"), 6},
}
i := NewIp(IpGetterFunc(func(ctx context.Context, host string) ([]IpAddress, error) {
return ipAddresses, nil
}))
actual, err := i.Lookup(context.Background(), "google.com")
assert.NoError(t, err)

assert.Equal(t, ipAddresses[0].Address, actual[0].Address)
assert.Equal(t, 4, actual[0].Family)

assert.Equal(t, ipAddresses[1].Address, actual[1].Address)
assert.Equal(t, 6, actual[1].Family)
}
25 changes: 4 additions & 21 deletions handlers/getIP.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,20 @@
package handlers

import (
"net"
"net/http"
)

func lookupAsync(address string) (map[string]interface{}, error) {
ip, err := net.LookupIP(address)
if err != nil {
return nil, err
}

result := make(map[string]interface{})
if len(ip) > 0 {
result["ip"] = ip[0].String()
result["family"] = 4
} else {
result["ip"] = ""
result["family"] = nil
}

return result, nil
}
"github.com/xray-web/web-check-api/checks"
)

func HandleGetIP() http.Handler {
func HandleGetIP(i *checks.Ip) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rawURL, err := extractURL(r)
if err != nil {
JSONError(w, ErrMissingURLParameter, http.StatusBadRequest)
return
}

result, err := lookupAsync(rawURL.Hostname())
result, err := i.Lookup(r.Context(), rawURL.Hostname())
if err != nil {
JSONError(w, err, http.StatusInternalServerError)
return
Expand Down
24 changes: 9 additions & 15 deletions handlers/getIP_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handlers

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -10,21 +9,16 @@ import (
)

func TestHandleGetIP(t *testing.T) {
req := httptest.NewRequest("GET", "/get-ip?url=example.com", nil)
rec := httptest.NewRecorder()
HandleGetIP().ServeHTTP(rec, req)
t.Parallel()

assert.Equal(t, http.StatusOK, rec.Code)
t.Run("missing URL parameter", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "/get-ip", nil)
rec := httptest.NewRecorder()

var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
HandleGetIP(nil).ServeHTTP(rec, req)

ip, ok := response["ip"].(string)
assert.True(t, ok, "IP address not found in response")
assert.NotEmpty(t, ip, "IP address is empty")

family, ok := response["family"].(float64)
assert.True(t, ok, "Family field not found in response")
assert.Equal(t, float64(4), family, "Family field should be 4")
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.JSONEq(t, `{"error": "missing URL parameter"}`, rec.Body.String())
})
}
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (s *Server) routes() {
s.mux.Handle("GET /api/dns", handlers.HandleDNS())
s.mux.Handle("GET /api/dnssec", handlers.HandleDnsSec())
s.mux.Handle("GET /api/firewall", handlers.HandleFirewall())
s.mux.Handle("GET /api/get-ip", handlers.HandleGetIP())
s.mux.Handle("GET /api/get-ip", handlers.HandleGetIP(s.checks.IpAddress))
s.mux.Handle("GET /api/headers", handlers.HandleGetHeaders())
s.mux.Handle("GET /api/hsts", handlers.HandleHsts())
s.mux.Handle("GET /api/http-security", handlers.HandleHttpSecurity())
Expand Down

0 comments on commit dd5e713

Please sign in to comment.