Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 144 additions & 6 deletions auth/saml/metadata.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,164 @@
package saml

import (
"crypto/tls"
"encoding/xml"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"

"github.com/russellhaering/gosaml2/types"
)

const (
maxMetadataBytes = 2 << 20 // 2 MiB
maxURLLength = 2048
maxRedirects = 3

requestTimeout = 10 * time.Second
dialTimeout = 5 * time.Second
tlsHandshakeTimeout = 5 * time.Second
responseHeaderTimeout = 5 * time.Second
)

// Cloud instance metadata service IP (AWS/Azure/GCP commonly use this)
var blockedMetadataIPs = map[string]bool{
"169.254.169.254": true,
}

// GCP also commonly exposes metadata via this hostname (which resolves to 169.254.169.254)
var blockedMetadataHostnames = map[string]bool{
"metadata.google.internal": true,
"metadata.google.internal.": true,
}

func (s *saml) HasValidMetadataURL(metadataURL string) (bool, error) {
if strings.TrimSpace(metadataURL) == "" {
return false, errors.New("metadata URL is empty")
}
if len(metadataURL) > maxURLLength {
return false, fmt.Errorf("metadata URL too long (>%d chars)", maxURLLength)
}

metadataURLParsed, err := url.Parse(metadataURL)
if err != nil {
return false, fmt.Errorf("url parse error: %s", err)
return false, fmt.Errorf("url parse error: %w", err)
}

if err := basicURLChecks(metadataURLParsed); err != nil {
return false, err
}
if err := rejectCloudMetadata(metadataURLParsed); err != nil {
return false, err
}

_, err = getMetadata(metadataURLParsed.String())
if err != nil {
return false, fmt.Errorf("fetch metadata error: %s", err)
return false, fmt.Errorf("fetch metadata error: %w", err)
}
return true, nil
}

func basicURLChecks(u *url.URL) error {
if u.Scheme != "http" && u.Scheme != "https" {
return fmt.Errorf("unsupported URL scheme %q (only http/https allowed)", u.Scheme)
}
if u.Hostname() == "" {
return errors.New("metadata URL missing host")
}
if u.User != nil {
return errors.New("userinfo (username/password) not allowed in metadata URL")
}
if u.Fragment != "" {
return errors.New("fragments are not allowed in metadata URL")
}
return nil
}

func rejectCloudMetadata(u *url.URL) error {
host := strings.ToLower(u.Hostname())

if blockedMetadataHostnames[host] {
return fmt.Errorf("blocked cloud metadata hostname: %s", host)
}

if ip := net.ParseIP(host); ip != nil {
if blockedMetadataIPs[ip.String()] {
return fmt.Errorf("blocked cloud metadata IP: %s", ip.String())
}
return nil
}

ips, err := net.LookupIP(host)
if err != nil {
return fmt.Errorf("dns lookup failed for host %q: %w", host, err)
}
for _, ip := range ips {
if blockedMetadataIPs[ip.String()] {
return fmt.Errorf("blocked cloud metadata (host %q resolves to %s)", host, ip.String())
}
}
return nil
}

func getMetadata(metadataURL string) (types.EntityDescriptor, error) {
metadata := types.EntityDescriptor{}

res, err := http.Get(metadataURL)
client := &http.Client{
Timeout: requestTimeout,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: dialTimeout,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: tlsHandshakeTimeout,
ResponseHeaderTimeout: responseHeaderTimeout,
ExpectContinueTimeout: 1 * time.Second,
IdleConnTimeout: 30 * time.Second,
DisableKeepAlives: true,
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return fmt.Errorf("stopped after %d redirects", maxRedirects)
}
// Re-apply basic checks + metadata blocking to the redirect target
if err := basicURLChecks(req.URL); err != nil {
return err
}
if err := rejectCloudMetadata(req.URL); err != nil {
return err
}
return nil
},
}

req, err := http.NewRequest(http.MethodGet, metadataURL, nil)
if err != nil {
return metadata, fmt.Errorf("can't build request: %w", err)
}
req.Header.Set("Accept", "application/samlmetadata+xml, application/xml, text/xml;q=0.9, */*;q=0.1")
res, err := client.Do(req)
if err != nil {
return metadata, fmt.Errorf("can't retrieve saml metadata: %s", err)
return metadata, fmt.Errorf("can't retrieve saml metadata: %w", err)
}
defer res.Body.Close()

if res.StatusCode != http.StatusOK {
return metadata, fmt.Errorf("metadata fetch returned %s", res.Status)
}

rawMetadata, err := io.ReadAll(res.Body)
rawMetadata, err := readCapped(res.Body, maxMetadataBytes)
if err != nil {
return metadata, fmt.Errorf("can't read saml cert data: %s", err)
return metadata, err
}

err = xml.Unmarshal(rawMetadata, &metadata)
Expand All @@ -41,3 +167,15 @@ func getMetadata(metadataURL string) (types.EntityDescriptor, error) {
}
return metadata, nil
}

func readCapped(r io.Reader, max int64) ([]byte, error) {
lr := io.LimitReader(r, max+1)
b, err := io.ReadAll(lr)
if err != nil {
return nil, fmt.Errorf("can't read saml metadata: %w", err)
}
if int64(len(b)) > max {
return nil, fmt.Errorf("metadata too large (>%d bytes)", max)
}
return b, nil
}
71 changes: 71 additions & 0 deletions auth/saml/metadata_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package saml

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)

const minimalSAMLMetadata = `<?xml version="1.0" encoding="UTF-8"?>
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="urn:test:idp">
</EntityDescriptor>`

func TestHasValidMetadata(t *testing.T) {
s := &saml{}

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Fatalf("expected GET, got %s", r.Method)
}
if got := r.Header.Get("Accept"); got == "" {
t.Fatalf("expected Accept header to be set")
}

w.Header().Set("Content-Type", "application/samlmetadata+xml")
fmt.Fprint(w, minimalSAMLMetadata)
}))
defer ts.Close()

hasValidMetadata, err := s.HasValidMetadataURL(ts.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !hasValidMetadata {
t.Fatalf("expected metadata to be valid")
}
}

func TestHasValidMetadata_InvalidURL(t *testing.T) {
s := &saml{}

hasValidMetadata, err := s.HasValidMetadataURL("http://169.254.169.254/latest/meta-data/iam/security-credentials/")
if err == nil {
t.Fatalf("expected error for blocked cloud metadata URL, got nil")
}
if hasValidMetadata {
t.Fatalf("expected metadata to be invalid for blocked cloud metadata URL")
}
}
func TestGetMetadata(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Fatalf("expected GET, got %s", r.Method)
}
if got := r.Header.Get("Accept"); got == "" {
t.Fatalf("expected Accept header to be set")
}

w.Header().Set("Content-Type", "application/samlmetadata+xml")
fmt.Fprint(w, minimalSAMLMetadata)
}))
defer ts.Close()

metadata, err := getMetadata(ts.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if metadata.EntityID != "urn:test:idp" {
t.Fatalf("expected EntityID to be 'urn:test:idp', got %q", metadata.EntityID)
}
}
26 changes: 25 additions & 1 deletion rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io/fs"
"log"
"net/http"
"time"

"github.com/in4it/go-devops-platform/logging"
"github.com/in4it/go-devops-platform/storage"
Expand Down Expand Up @@ -40,7 +41,23 @@ func StartServer(httpPort, httpsPort int, storage storage.Iface, c *Context, ass
// HTTP Configuration
go func() { // start http server
log.Printf("Start http server on port %d", httpPort)
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", httpPort), certManager.HTTPHandler(c.loggingMiddleware(c.httpsRedirectMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody)))))))
httpServer := &http.Server{
Addr: fmt.Sprintf(":%d", httpPort),

Handler: certManager.HTTPHandler(c.loggingMiddleware(c.httpsRedirectMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))))),

ReadHeaderTimeout: 5 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,

MaxHeaderBytes: 1 << 20, // 1MB
}

err := httpServer.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
log.Fatalf("http server failed: %v", err)
}
}()

// TLS Configuration
Expand All @@ -59,6 +76,13 @@ func StartServer(httpPort, httpsPort int, storage storage.Iface, c *Context, ass
GetCertificate: certManager.GetCertificate,
},
Handler: c.loggingMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))),

ReadHeaderTimeout: 5 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,

MaxHeaderBytes: 1 << 20, // 1MB
}
c.Protocol = "https"
TLSWaiterCompleted = true
Expand Down