Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions checks/checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type Checks struct {
BlockList *BlockList
Carbon *Carbon
Headers *Headers
Hsts *Hsts
IpAddress *Ip
LegacyRank *LegacyRank
LinkedPages *LinkedPages
Expand All @@ -28,6 +29,7 @@ func NewChecks() *Checks {
BlockList: NewBlockList(&ip.NetDNSLookup{}),
Carbon: NewCarbon(client),
Headers: NewHeaders(client),
Hsts: NewHsts(client),
IpAddress: NewIp(NewNetIp()),
LegacyRank: NewLegacyRank(legacyrank.NewInMemoryStore()),
LinkedPages: NewLinkedPages(client),
Expand Down
87 changes: 87 additions & 0 deletions checks/hsts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package checks

import (
"context"
"net/http"
"strconv"
"strings"
"unicode"
)

type HSTSResponse struct {
Message string `json:"message"`
Compatible bool `json:"compatible"`
HSTSHeader string `json:"hstsHeader"`
}

type Hsts struct {
client *http.Client
}

func NewHsts(client *http.Client) *Hsts {
return &Hsts{client: client}
}

func (h *Hsts) Validate(ctx context.Context, url string) (*HSTSResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err != nil {
return nil, err
}

resp, err := h.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

hstsHeader := resp.Header.Get("Strict-Transport-Security")
if hstsHeader == "" {
return &HSTSResponse{Message: "Site does not serve any HSTS headers."}, nil
}

if !strings.Contains(hstsHeader, "max-age") {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although max-age is not optional, I guess we shouldn't assume it'll have max-age in the headers?

return &HSTSResponse{Message: "HSTS max-age is less than 10886400."}, nil
}

var maxAgeString string
for _, h := range strings.Split(hstsHeader, " ") {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The OG way to extract max-age was through the use of regex and indexing into the slice to get the max age value. Is this a better approach than indexing?

if strings.Contains(h, "max-age=") {
maxAgeString = extractMaxAgeFromHeader(h)
}
}

maxAge, err := strconv.Atoi(maxAgeString)
if err != nil {
return nil, err
}

if maxAge < 10886400 {
return &HSTSResponse{Message: "HSTS max-age is less than 10886400."}, nil
}

if !strings.Contains(hstsHeader, "includeSubDomains") {
return &HSTSResponse{Message: "HSTS header does not include all subdomains."}, nil
}

if !strings.Contains(hstsHeader, "preload") {
return &HSTSResponse{Message: "HSTS header does not contain the preload directive."}, nil
}

return &HSTSResponse{
Message: "Site is compatible with the HSTS preload list!",
Compatible: true,
HSTSHeader: hstsHeader,
}, nil
}

func extractMaxAgeFromHeader(header string) string {
var maxAge strings.Builder

for _, b := range header {
if unicode.IsDigit(b) {
maxAge.WriteRune(b)
}
}

return maxAge.String()
}
125 changes: 125 additions & 0 deletions checks/hsts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package checks

import (
"context"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/xray-web/web-check-api/testutils"
)

func TestValidate(t *testing.T) {
t.Parallel()

t.Run("given an empty header", func(t *testing.T) {
t.Parallel()

client := testutils.MockClient(&http.Response{
Header: http.Header{"Strict-Transport-Security": []string{""}}})
h := NewHsts(client)

actual, err := h.Validate(context.Background(), "test.com")
assert.NoError(t, err)

assert.Equal(t, "Site does not serve any HSTS headers.", actual.Message)
assert.False(t, actual.Compatible)
assert.Empty(t, actual.HSTSHeader)
})

t.Run("given a header without max age", func(t *testing.T) {
t.Parallel()

client := testutils.MockClient(&http.Response{
Header: http.Header{"Strict-Transport-Security": []string{"includeSubDomains; preload"}}})
h := NewHsts(client)

actual, err := h.Validate(context.Background(), "test.com")
assert.NoError(t, err)

assert.Equal(t, "HSTS max-age is less than 10886400.", actual.Message)
assert.False(t, actual.Compatible)
assert.Empty(t, actual.HSTSHeader)
})

t.Run("given max age less than 10886400", func(t *testing.T) {
t.Parallel()

client := testutils.MockClient(&http.Response{
Header: http.Header{"Strict-Transport-Security": []string{"max-age=47; includeSubDomains; preload"}}})
h := NewHsts(client)

actual, err := h.Validate(context.Background(), "test.com")
assert.NoError(t, err)

assert.Equal(t, "HSTS max-age is less than 10886400.", actual.Message)
assert.False(t, actual.Compatible)
assert.Empty(t, actual.HSTSHeader)
})

t.Run("given a header without includeSubDomains", func(t *testing.T) {
t.Parallel()

client := testutils.MockClient(&http.Response{
Header: http.Header{"Strict-Transport-Security": []string{"max-age=47474747; preload"}}})
h := NewHsts(client)

actual, err := h.Validate(context.Background(), "test.com")
assert.NoError(t, err)

assert.Equal(t, "HSTS header does not include all subdomains.", actual.Message)
assert.False(t, actual.Compatible)
assert.Empty(t, actual.HSTSHeader)
})

t.Run("given a header without preload", func(t *testing.T) {
t.Parallel()

client := testutils.MockClient(&http.Response{
Header: http.Header{"Strict-Transport-Security": []string{"max-age=47474747; includeSubDomains"}}})
h := NewHsts(client)

actual, err := h.Validate(context.Background(), "test.com")
assert.NoError(t, err)

assert.Equal(t, "HSTS header does not contain the preload directive.", actual.Message)
assert.False(t, actual.Compatible)
assert.Empty(t, actual.HSTSHeader)
})

t.Run("given a valid header", func(t *testing.T) {
t.Parallel()

client := testutils.MockClient(&http.Response{
Header: http.Header{"Strict-Transport-Security": []string{"max-age=47474747; includeSubDomains; preload"}}})
h := NewHsts(client)

actual, err := h.Validate(context.Background(), "test.com")
assert.NoError(t, err)

assert.Equal(t, "Site is compatible with the HSTS preload list!", actual.Message)
assert.True(t, actual.Compatible)
assert.NotEmpty(t, actual.HSTSHeader)
})
}

func TestExtractMaxAgeFromHeader(t *testing.T) {
t.Parallel()

for _, tc := range []struct {
name string
header string
expected string
}{
{"give valid header", "max-age=47474747;", "47474747"},
{"given an empty header", "", ""},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

actual := extractMaxAgeFromHeader(tc.header)
assert.Equal(t, tc.expected, actual)
})
}
}
50 changes: 4 additions & 46 deletions handlers/hsts.go
Original file line number Diff line number Diff line change
@@ -1,62 +1,20 @@
package handlers

import (
"fmt"
"net/http"
"regexp"
"strings"
)

type HSTSResponse struct {
Message string `json:"message"`
Compatible bool `json:"compatible"`
HSTSHeader string `json:"hstsHeader"`
}

func checkHSTS(url string) (HSTSResponse, error) {
client := &http.Client{}

req, err := http.NewRequest("HEAD", url, nil)
if err != nil {
return HSTSResponse{}, fmt.Errorf("error creating request: %s", err.Error())
}

resp, err := client.Do(req)
if err != nil {
return HSTSResponse{}, fmt.Errorf("error making request: %s", err.Error())
}
defer resp.Body.Close()

hstsHeader := resp.Header.Get("strict-transport-security")
if hstsHeader == "" {
return HSTSResponse{Message: "Site does not serve any HSTS headers."}, nil
}

maxAgeMatch := regexp.MustCompile(`max-age=(\d+)`).FindStringSubmatch(hstsHeader)
if maxAgeMatch == nil || len(maxAgeMatch) < 2 || maxAgeMatch[1] == "" || maxAgeMatch[1] < "10886400" {
return HSTSResponse{Message: "HSTS max-age is less than 10886400."}, nil
}

if !strings.Contains(hstsHeader, "includeSubDomains") {
return HSTSResponse{Message: "HSTS header does not include all subdomains."}, nil
}

if !strings.Contains(hstsHeader, "preload") {
return HSTSResponse{Message: "HSTS header does not contain the preload directive."}, nil
}

return HSTSResponse{Message: "Site is compatible with the HSTS preload list!", Compatible: true, HSTSHeader: hstsHeader}, nil
}
"github.com/xray-web/web-check-api/checks"
)

func HandleHsts() http.Handler {
func HandleHsts(h *checks.Hsts) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rawURL, err := extractURL(r)
if err != nil {
JSONError(w, ErrMissingURLParameter, http.StatusBadRequest)
return
}

result, err := checkHSTS(rawURL.String())
result, err := h.Validate(r.Context(), rawURL.String())
if err != nil {
JSONError(w, err, http.StatusInternalServerError)
return
Expand Down
20 changes: 8 additions & 12 deletions handlers/hsts_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handlers

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -11,18 +10,15 @@ import (

func TestHandleHsts(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/check-hsts?url=example.com", nil)
rec := httptest.NewRecorder()
HandleHsts().ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Code)
t.Run("missing URL parameter", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "/check-hsts", nil)
rec := httptest.NewRecorder()

var response HSTSResponse
err := json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
HandleHsts(nil).ServeHTTP(rec, req)

assert.NotNil(t, response)
assert.Equal(t, "Site does not serve any HSTS headers.", response.Message)
assert.False(t, response.Compatible)
assert.Empty(t, response.HSTSHeader)
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.JSONEq(t, `{"error": "missing URL parameter"}`, rec.Body.String())
})
}
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (s *Server) routes() {
s.mux.Handle("GET /api/firewall", handlers.HandleFirewall())
s.mux.Handle("GET /api/get-ip", handlers.HandleGetIP(s.checks.IpAddress))
s.mux.Handle("GET /api/headers", handlers.HandleGetHeaders(s.checks.Headers))
s.mux.Handle("GET /api/hsts", handlers.HandleHsts())
s.mux.Handle("GET /api/hsts", handlers.HandleHsts(s.checks.Hsts))
s.mux.Handle("GET /api/http-security", handlers.HandleHttpSecurity())
s.mux.Handle("GET /api/legacy-rank", handlers.HandleLegacyRank(s.checks.LegacyRank))
s.mux.Handle("GET /api/linked-pages", handlers.HandleGetLinks(s.checks.LinkedPages))
Expand Down