diff --git a/auth/saml/metadata.go b/auth/saml/metadata.go index 8be0b1f..fdfd3cd 100644 --- a/auth/saml/metadata.go +++ b/auth/saml/metadata.go @@ -1,38 +1,164 @@ package saml import ( + "crypto/tls" "encoding/xml" + "errors" "fmt" "io" + "net" "net/http" "net/url" + "strings" + "time" "github.com/russellhaering/gosaml2/types" ) +const ( + maxMetadataBytes = 2 << 20 // 2 MiB + maxURLLength = 2048 + maxRedirects = 3 + + requestTimeout = 10 * time.Second + dialTimeout = 5 * time.Second + tlsHandshakeTimeout = 5 * time.Second + responseHeaderTimeout = 5 * time.Second +) + +// Cloud instance metadata service IP (AWS/Azure/GCP commonly use this) +var blockedMetadataIPs = map[string]bool{ + "169.254.169.254": true, +} + +// GCP also commonly exposes metadata via this hostname (which resolves to 169.254.169.254) +var blockedMetadataHostnames = map[string]bool{ + "metadata.google.internal": true, + "metadata.google.internal.": true, +} + func (s *saml) HasValidMetadataURL(metadataURL string) (bool, error) { + if strings.TrimSpace(metadataURL) == "" { + return false, errors.New("metadata URL is empty") + } + if len(metadataURL) > maxURLLength { + return false, fmt.Errorf("metadata URL too long (>%d chars)", maxURLLength) + } + metadataURLParsed, err := url.Parse(metadataURL) if err != nil { - return false, fmt.Errorf("url parse error: %s", err) + return false, fmt.Errorf("url parse error: %w", err) } + + if err := basicURLChecks(metadataURLParsed); err != nil { + return false, err + } + if err := rejectCloudMetadata(metadataURLParsed); err != nil { + return false, err + } + _, err = getMetadata(metadataURLParsed.String()) if err != nil { - return false, fmt.Errorf("fetch metadata error: %s", err) + return false, fmt.Errorf("fetch metadata error: %w", err) } return true, nil } +func basicURLChecks(u *url.URL) error { + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("unsupported URL scheme %q (only http/https allowed)", u.Scheme) + } + if u.Hostname() == "" { + return errors.New("metadata URL missing host") + } + if u.User != nil { + return errors.New("userinfo (username/password) not allowed in metadata URL") + } + if u.Fragment != "" { + return errors.New("fragments are not allowed in metadata URL") + } + return nil +} + +func rejectCloudMetadata(u *url.URL) error { + host := strings.ToLower(u.Hostname()) + + if blockedMetadataHostnames[host] { + return fmt.Errorf("blocked cloud metadata hostname: %s", host) + } + + if ip := net.ParseIP(host); ip != nil { + if blockedMetadataIPs[ip.String()] { + return fmt.Errorf("blocked cloud metadata IP: %s", ip.String()) + } + return nil + } + + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("dns lookup failed for host %q: %w", host, err) + } + for _, ip := range ips { + if blockedMetadataIPs[ip.String()] { + return fmt.Errorf("blocked cloud metadata (host %q resolves to %s)", host, ip.String()) + } + } + return nil +} + func getMetadata(metadataURL string) (types.EntityDescriptor, error) { metadata := types.EntityDescriptor{} - res, err := http.Get(metadataURL) + client := &http.Client{ + Timeout: requestTimeout, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: dialTimeout, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: tlsHandshakeTimeout, + ResponseHeaderTimeout: responseHeaderTimeout, + ExpectContinueTimeout: 1 * time.Second, + IdleConnTimeout: 30 * time.Second, + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= maxRedirects { + return fmt.Errorf("stopped after %d redirects", maxRedirects) + } + // Re-apply basic checks + metadata blocking to the redirect target + if err := basicURLChecks(req.URL); err != nil { + return err + } + if err := rejectCloudMetadata(req.URL); err != nil { + return err + } + return nil + }, + } + + req, err := http.NewRequest(http.MethodGet, metadataURL, nil) + if err != nil { + return metadata, fmt.Errorf("can't build request: %w", err) + } + req.Header.Set("Accept", "application/samlmetadata+xml, application/xml, text/xml;q=0.9, */*;q=0.1") + res, err := client.Do(req) if err != nil { - return metadata, fmt.Errorf("can't retrieve saml metadata: %s", err) + return metadata, fmt.Errorf("can't retrieve saml metadata: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return metadata, fmt.Errorf("metadata fetch returned %s", res.Status) } - rawMetadata, err := io.ReadAll(res.Body) + rawMetadata, err := readCapped(res.Body, maxMetadataBytes) if err != nil { - return metadata, fmt.Errorf("can't read saml cert data: %s", err) + return metadata, err } err = xml.Unmarshal(rawMetadata, &metadata) @@ -41,3 +167,15 @@ func getMetadata(metadataURL string) (types.EntityDescriptor, error) { } return metadata, nil } + +func readCapped(r io.Reader, max int64) ([]byte, error) { + lr := io.LimitReader(r, max+1) + b, err := io.ReadAll(lr) + if err != nil { + return nil, fmt.Errorf("can't read saml metadata: %w", err) + } + if int64(len(b)) > max { + return nil, fmt.Errorf("metadata too large (>%d bytes)", max) + } + return b, nil +} diff --git a/auth/saml/metadata_test.go b/auth/saml/metadata_test.go new file mode 100644 index 0000000..0e83237 --- /dev/null +++ b/auth/saml/metadata_test.go @@ -0,0 +1,71 @@ +package saml + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +const minimalSAMLMetadata = ` + +` + +func TestHasValidMetadata(t *testing.T) { + s := &saml{} + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Fatalf("expected GET, got %s", r.Method) + } + if got := r.Header.Get("Accept"); got == "" { + t.Fatalf("expected Accept header to be set") + } + + w.Header().Set("Content-Type", "application/samlmetadata+xml") + fmt.Fprint(w, minimalSAMLMetadata) + })) + defer ts.Close() + + hasValidMetadata, err := s.HasValidMetadataURL(ts.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !hasValidMetadata { + t.Fatalf("expected metadata to be valid") + } +} + +func TestHasValidMetadata_InvalidURL(t *testing.T) { + s := &saml{} + + hasValidMetadata, err := s.HasValidMetadataURL("http://169.254.169.254/latest/meta-data/iam/security-credentials/") + if err == nil { + t.Fatalf("expected error for blocked cloud metadata URL, got nil") + } + if hasValidMetadata { + t.Fatalf("expected metadata to be invalid for blocked cloud metadata URL") + } +} +func TestGetMetadata(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Fatalf("expected GET, got %s", r.Method) + } + if got := r.Header.Get("Accept"); got == "" { + t.Fatalf("expected Accept header to be set") + } + + w.Header().Set("Content-Type", "application/samlmetadata+xml") + fmt.Fprint(w, minimalSAMLMetadata) + })) + defer ts.Close() + + metadata, err := getMetadata(ts.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if metadata.EntityID != "urn:test:idp" { + t.Fatalf("expected EntityID to be 'urn:test:idp', got %q", metadata.EntityID) + } +} diff --git a/rest/server.go b/rest/server.go index 80654e5..e7c2192 100644 --- a/rest/server.go +++ b/rest/server.go @@ -7,6 +7,7 @@ import ( "io/fs" "log" "net/http" + "time" "github.com/in4it/go-devops-platform/logging" "github.com/in4it/go-devops-platform/storage" @@ -40,7 +41,23 @@ func StartServer(httpPort, httpsPort int, storage storage.Iface, c *Context, ass // HTTP Configuration go func() { // start http server log.Printf("Start http server on port %d", httpPort) - log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", httpPort), certManager.HTTPHandler(c.loggingMiddleware(c.httpsRedirectMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))))))) + httpServer := &http.Server{ + Addr: fmt.Sprintf(":%d", httpPort), + + Handler: certManager.HTTPHandler(c.loggingMiddleware(c.httpsRedirectMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))))), + + ReadHeaderTimeout: 5 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 15 * time.Second, + IdleTimeout: 60 * time.Second, + + MaxHeaderBytes: 1 << 20, // 1MB + } + + err := httpServer.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + log.Fatalf("http server failed: %v", err) + } }() // TLS Configuration @@ -59,6 +76,13 @@ func StartServer(httpPort, httpsPort int, storage storage.Iface, c *Context, ass GetCertificate: certManager.GetCertificate, }, Handler: c.loggingMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))), + + ReadHeaderTimeout: 5 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 15 * time.Second, + IdleTimeout: 60 * time.Second, + + MaxHeaderBytes: 1 << 20, // 1MB } c.Protocol = "https" TLSWaiterCompleted = true