Skip to content

Commit 4cb2352

Browse files
committed
http configuration improvements for metadata and server
1 parent daf65ea commit 4cb2352

3 files changed

Lines changed: 240 additions & 7 deletions

File tree

auth/saml/metadata.go

Lines changed: 144 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,164 @@
11
package saml
22

33
import (
4+
"crypto/tls"
45
"encoding/xml"
6+
"errors"
57
"fmt"
68
"io"
9+
"net"
710
"net/http"
811
"net/url"
12+
"strings"
13+
"time"
914

1015
"github.com/russellhaering/gosaml2/types"
1116
)
1217

18+
const (
19+
maxMetadataBytes = 2 << 20 // 2 MiB
20+
maxURLLength = 2048
21+
maxRedirects = 3
22+
23+
requestTimeout = 10 * time.Second
24+
dialTimeout = 5 * time.Second
25+
tlsHandshakeTimeout = 5 * time.Second
26+
responseHeaderTimeout = 5 * time.Second
27+
)
28+
29+
// Cloud instance metadata service IP (AWS/Azure/GCP commonly use this)
30+
var blockedMetadataIPs = map[string]bool{
31+
"169.254.169.254": true,
32+
}
33+
34+
// GCP also commonly exposes metadata via this hostname (which resolves to 169.254.169.254)
35+
var blockedMetadataHostnames = map[string]bool{
36+
"metadata.google.internal": true,
37+
"metadata.google.internal.": true,
38+
}
39+
1340
func (s *saml) HasValidMetadataURL(metadataURL string) (bool, error) {
41+
if strings.TrimSpace(metadataURL) == "" {
42+
return false, errors.New("metadata URL is empty")
43+
}
44+
if len(metadataURL) > maxURLLength {
45+
return false, fmt.Errorf("metadata URL too long (>%d chars)", maxURLLength)
46+
}
47+
1448
metadataURLParsed, err := url.Parse(metadataURL)
1549
if err != nil {
16-
return false, fmt.Errorf("url parse error: %s", err)
50+
return false, fmt.Errorf("url parse error: %w", err)
1751
}
52+
53+
if err := basicURLChecks(metadataURLParsed); err != nil {
54+
return false, err
55+
}
56+
if err := rejectCloudMetadata(metadataURLParsed); err != nil {
57+
return false, err
58+
}
59+
1860
_, err = getMetadata(metadataURLParsed.String())
1961
if err != nil {
20-
return false, fmt.Errorf("fetch metadata error: %s", err)
62+
return false, fmt.Errorf("fetch metadata error: %w", err)
2163
}
2264
return true, nil
2365
}
2466

67+
func basicURLChecks(u *url.URL) error {
68+
if u.Scheme != "http" && u.Scheme != "https" {
69+
return fmt.Errorf("unsupported URL scheme %q (only http/https allowed)", u.Scheme)
70+
}
71+
if u.Hostname() == "" {
72+
return errors.New("metadata URL missing host")
73+
}
74+
if u.User != nil {
75+
return errors.New("userinfo (username/password) not allowed in metadata URL")
76+
}
77+
if u.Fragment != "" {
78+
return errors.New("fragments are not allowed in metadata URL")
79+
}
80+
return nil
81+
}
82+
83+
func rejectCloudMetadata(u *url.URL) error {
84+
host := strings.ToLower(u.Hostname())
85+
86+
if blockedMetadataHostnames[host] {
87+
return fmt.Errorf("blocked cloud metadata hostname: %s", host)
88+
}
89+
90+
if ip := net.ParseIP(host); ip != nil {
91+
if blockedMetadataIPs[ip.String()] {
92+
return fmt.Errorf("blocked cloud metadata IP: %s", ip.String())
93+
}
94+
return nil
95+
}
96+
97+
ips, err := net.LookupIP(host)
98+
if err != nil {
99+
return fmt.Errorf("dns lookup failed for host %q: %w", host, err)
100+
}
101+
for _, ip := range ips {
102+
if blockedMetadataIPs[ip.String()] {
103+
return fmt.Errorf("blocked cloud metadata (host %q resolves to %s)", host, ip.String())
104+
}
105+
}
106+
return nil
107+
}
108+
25109
func getMetadata(metadataURL string) (types.EntityDescriptor, error) {
26110
metadata := types.EntityDescriptor{}
27111

28-
res, err := http.Get(metadataURL)
112+
client := &http.Client{
113+
Timeout: requestTimeout,
114+
Transport: &http.Transport{
115+
Proxy: http.ProxyFromEnvironment,
116+
DialContext: (&net.Dialer{
117+
Timeout: dialTimeout,
118+
KeepAlive: 30 * time.Second,
119+
}).DialContext,
120+
TLSHandshakeTimeout: tlsHandshakeTimeout,
121+
ResponseHeaderTimeout: responseHeaderTimeout,
122+
ExpectContinueTimeout: 1 * time.Second,
123+
IdleConnTimeout: 30 * time.Second,
124+
DisableKeepAlives: true,
125+
TLSClientConfig: &tls.Config{
126+
MinVersion: tls.VersionTLS12,
127+
},
128+
},
129+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
130+
if len(via) >= maxRedirects {
131+
return fmt.Errorf("stopped after %d redirects", maxRedirects)
132+
}
133+
// Re-apply basic checks + metadata blocking to the redirect target
134+
if err := basicURLChecks(req.URL); err != nil {
135+
return err
136+
}
137+
if err := rejectCloudMetadata(req.URL); err != nil {
138+
return err
139+
}
140+
return nil
141+
},
142+
}
143+
144+
req, err := http.NewRequest(http.MethodGet, metadataURL, nil)
145+
if err != nil {
146+
return metadata, fmt.Errorf("can't build request: %w", err)
147+
}
148+
req.Header.Set("Accept", "application/samlmetadata+xml, application/xml, text/xml;q=0.9, */*;q=0.1")
149+
res, err := client.Do(req)
29150
if err != nil {
30-
return metadata, fmt.Errorf("can't retrieve saml metadata: %s", err)
151+
return metadata, fmt.Errorf("can't retrieve saml metadata: %w", err)
152+
}
153+
defer res.Body.Close()
154+
155+
if res.StatusCode != http.StatusOK {
156+
return metadata, fmt.Errorf("metadata fetch returned %s", res.Status)
31157
}
32158

33-
rawMetadata, err := io.ReadAll(res.Body)
159+
rawMetadata, err := readCapped(res.Body, maxMetadataBytes)
34160
if err != nil {
35-
return metadata, fmt.Errorf("can't read saml cert data: %s", err)
161+
return metadata, err
36162
}
37163

38164
err = xml.Unmarshal(rawMetadata, &metadata)
@@ -41,3 +167,15 @@ func getMetadata(metadataURL string) (types.EntityDescriptor, error) {
41167
}
42168
return metadata, nil
43169
}
170+
171+
func readCapped(r io.Reader, max int64) ([]byte, error) {
172+
lr := io.LimitReader(r, max+1)
173+
b, err := io.ReadAll(lr)
174+
if err != nil {
175+
return nil, fmt.Errorf("can't read saml metadata: %w", err)
176+
}
177+
if int64(len(b)) > max {
178+
return nil, fmt.Errorf("metadata too large (>%d bytes)", max)
179+
}
180+
return b, nil
181+
}

auth/saml/metadata_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package saml
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
)
9+
10+
const minimalSAMLMetadata = `<?xml version="1.0" encoding="UTF-8"?>
11+
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="urn:test:idp">
12+
</EntityDescriptor>`
13+
14+
func TestHasValidMetadata(t *testing.T) {
15+
s := &saml{}
16+
17+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
18+
if r.Method != http.MethodGet {
19+
t.Fatalf("expected GET, got %s", r.Method)
20+
}
21+
if got := r.Header.Get("Accept"); got == "" {
22+
t.Fatalf("expected Accept header to be set")
23+
}
24+
25+
w.Header().Set("Content-Type", "application/samlmetadata+xml")
26+
fmt.Fprint(w, minimalSAMLMetadata)
27+
}))
28+
defer ts.Close()
29+
30+
hasValidMetadata, err := s.HasValidMetadataURL(ts.URL)
31+
if err != nil {
32+
t.Fatalf("unexpected error: %v", err)
33+
}
34+
if !hasValidMetadata {
35+
t.Fatalf("expected metadata to be valid")
36+
}
37+
}
38+
39+
func TestHasValidMetadata_InvalidURL(t *testing.T) {
40+
s := &saml{}
41+
42+
hasValidMetadata, err := s.HasValidMetadataURL("http://169.254.169.254/latest/meta-data/iam/security-credentials/")
43+
if err == nil {
44+
t.Fatalf("expected error for blocked cloud metadata URL, got nil")
45+
}
46+
if hasValidMetadata {
47+
t.Fatalf("expected metadata to be invalid for blocked cloud metadata URL")
48+
}
49+
}
50+
func TestGetMetadata(t *testing.T) {
51+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
52+
if r.Method != http.MethodGet {
53+
t.Fatalf("expected GET, got %s", r.Method)
54+
}
55+
if got := r.Header.Get("Accept"); got == "" {
56+
t.Fatalf("expected Accept header to be set")
57+
}
58+
59+
w.Header().Set("Content-Type", "application/samlmetadata+xml")
60+
fmt.Fprint(w, minimalSAMLMetadata)
61+
}))
62+
defer ts.Close()
63+
64+
metadata, err := getMetadata(ts.URL)
65+
if err != nil {
66+
t.Fatalf("unexpected error: %v", err)
67+
}
68+
if metadata.EntityID != "urn:test:idp" {
69+
t.Fatalf("expected EntityID to be 'urn:test:idp', got %q", metadata.EntityID)
70+
}
71+
}

rest/server.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io/fs"
88
"log"
99
"net/http"
10+
"time"
1011

1112
"github.com/in4it/go-devops-platform/logging"
1213
"github.com/in4it/go-devops-platform/storage"
@@ -40,7 +41,23 @@ func StartServer(httpPort, httpsPort int, storage storage.Iface, c *Context, ass
4041
// HTTP Configuration
4142
go func() { // start http server
4243
log.Printf("Start http server on port %d", httpPort)
43-
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", httpPort), certManager.HTTPHandler(c.loggingMiddleware(c.httpsRedirectMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody)))))))
44+
httpServer := &http.Server{
45+
Addr: fmt.Sprintf(":%d", httpPort),
46+
47+
Handler: certManager.HTTPHandler(c.loggingMiddleware(c.httpsRedirectMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))))),
48+
49+
ReadHeaderTimeout: 5 * time.Second,
50+
ReadTimeout: 10 * time.Second,
51+
WriteTimeout: 15 * time.Second,
52+
IdleTimeout: 60 * time.Second,
53+
54+
MaxHeaderBytes: 1 << 20, // 1MB
55+
}
56+
57+
err := httpServer.ListenAndServe()
58+
if err != nil && err != http.ErrServerClosed {
59+
log.Fatalf("http server failed: %v", err)
60+
}
4461
}()
4562

4663
// TLS Configuration
@@ -59,6 +76,13 @@ func StartServer(httpPort, httpsPort int, storage storage.Iface, c *Context, ass
5976
GetCertificate: certManager.GetCertificate,
6077
},
6178
Handler: c.loggingMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))),
79+
80+
ReadHeaderTimeout: 5 * time.Second,
81+
ReadTimeout: 10 * time.Second,
82+
WriteTimeout: 15 * time.Second,
83+
IdleTimeout: 60 * time.Second,
84+
85+
MaxHeaderBytes: 1 << 20, // 1MB
6286
}
6387
c.Protocol = "https"
6488
TLSWaiterCompleted = true

0 commit comments

Comments
 (0)