11package saml
22
33import (
4+ "crypto/tls"
45 "encoding/xml"
6+ "errors"
57 "fmt"
68 "io"
9+ "net"
710 "net/http"
811 "net/url"
12+ "strings"
13+ "time"
914
1015 "github.com/russellhaering/gosaml2/types"
1116)
1217
18+ const (
19+ maxMetadataBytes = 2 << 20 // 2 MiB
20+ maxURLLength = 2048
21+ maxRedirects = 3
22+
23+ requestTimeout = 10 * time .Second
24+ dialTimeout = 5 * time .Second
25+ tlsHandshakeTimeout = 5 * time .Second
26+ responseHeaderTimeout = 5 * time .Second
27+ )
28+
29+ // Cloud instance metadata service IP (AWS/Azure/GCP commonly use this)
30+ var blockedMetadataIPs = map [string ]bool {
31+ "169.254.169.254" : true ,
32+ }
33+
34+ // GCP also commonly exposes metadata via this hostname (which resolves to 169.254.169.254)
35+ var blockedMetadataHostnames = map [string ]bool {
36+ "metadata.google.internal" : true ,
37+ "metadata.google.internal." : true ,
38+ }
39+
1340func (s * saml ) HasValidMetadataURL (metadataURL string ) (bool , error ) {
41+ if strings .TrimSpace (metadataURL ) == "" {
42+ return false , errors .New ("metadata URL is empty" )
43+ }
44+ if len (metadataURL ) > maxURLLength {
45+ return false , fmt .Errorf ("metadata URL too long (>%d chars)" , maxURLLength )
46+ }
47+
1448 metadataURLParsed , err := url .Parse (metadataURL )
1549 if err != nil {
16- return false , fmt .Errorf ("url parse error: %s " , err )
50+ return false , fmt .Errorf ("url parse error: %w " , err )
1751 }
52+
53+ if err := basicURLChecks (metadataURLParsed ); err != nil {
54+ return false , err
55+ }
56+ if err := rejectCloudMetadata (metadataURLParsed ); err != nil {
57+ return false , err
58+ }
59+
1860 _ , err = getMetadata (metadataURLParsed .String ())
1961 if err != nil {
20- return false , fmt .Errorf ("fetch metadata error: %s " , err )
62+ return false , fmt .Errorf ("fetch metadata error: %w " , err )
2163 }
2264 return true , nil
2365}
2466
67+ func basicURLChecks (u * url.URL ) error {
68+ if u .Scheme != "http" && u .Scheme != "https" {
69+ return fmt .Errorf ("unsupported URL scheme %q (only http/https allowed)" , u .Scheme )
70+ }
71+ if u .Hostname () == "" {
72+ return errors .New ("metadata URL missing host" )
73+ }
74+ if u .User != nil {
75+ return errors .New ("userinfo (username/password) not allowed in metadata URL" )
76+ }
77+ if u .Fragment != "" {
78+ return errors .New ("fragments are not allowed in metadata URL" )
79+ }
80+ return nil
81+ }
82+
83+ func rejectCloudMetadata (u * url.URL ) error {
84+ host := strings .ToLower (u .Hostname ())
85+
86+ if blockedMetadataHostnames [host ] {
87+ return fmt .Errorf ("blocked cloud metadata hostname: %s" , host )
88+ }
89+
90+ if ip := net .ParseIP (host ); ip != nil {
91+ if blockedMetadataIPs [ip .String ()] {
92+ return fmt .Errorf ("blocked cloud metadata IP: %s" , ip .String ())
93+ }
94+ return nil
95+ }
96+
97+ ips , err := net .LookupIP (host )
98+ if err != nil {
99+ return fmt .Errorf ("dns lookup failed for host %q: %w" , host , err )
100+ }
101+ for _ , ip := range ips {
102+ if blockedMetadataIPs [ip .String ()] {
103+ return fmt .Errorf ("blocked cloud metadata (host %q resolves to %s)" , host , ip .String ())
104+ }
105+ }
106+ return nil
107+ }
108+
25109func getMetadata (metadataURL string ) (types.EntityDescriptor , error ) {
26110 metadata := types.EntityDescriptor {}
27111
28- res , err := http .Get (metadataURL )
112+ client := & http.Client {
113+ Timeout : requestTimeout ,
114+ Transport : & http.Transport {
115+ Proxy : http .ProxyFromEnvironment ,
116+ DialContext : (& net.Dialer {
117+ Timeout : dialTimeout ,
118+ KeepAlive : 30 * time .Second ,
119+ }).DialContext ,
120+ TLSHandshakeTimeout : tlsHandshakeTimeout ,
121+ ResponseHeaderTimeout : responseHeaderTimeout ,
122+ ExpectContinueTimeout : 1 * time .Second ,
123+ IdleConnTimeout : 30 * time .Second ,
124+ DisableKeepAlives : true ,
125+ TLSClientConfig : & tls.Config {
126+ MinVersion : tls .VersionTLS12 ,
127+ },
128+ },
129+ CheckRedirect : func (req * http.Request , via []* http.Request ) error {
130+ if len (via ) >= maxRedirects {
131+ return fmt .Errorf ("stopped after %d redirects" , maxRedirects )
132+ }
133+ // Re-apply basic checks + metadata blocking to the redirect target
134+ if err := basicURLChecks (req .URL ); err != nil {
135+ return err
136+ }
137+ if err := rejectCloudMetadata (req .URL ); err != nil {
138+ return err
139+ }
140+ return nil
141+ },
142+ }
143+
144+ req , err := http .NewRequest (http .MethodGet , metadataURL , nil )
145+ if err != nil {
146+ return metadata , fmt .Errorf ("can't build request: %w" , err )
147+ }
148+ req .Header .Set ("Accept" , "application/samlmetadata+xml, application/xml, text/xml;q=0.9, */*;q=0.1" )
149+ res , err := client .Do (req )
29150 if err != nil {
30- return metadata , fmt .Errorf ("can't retrieve saml metadata: %s" , err )
151+ return metadata , fmt .Errorf ("can't retrieve saml metadata: %w" , err )
152+ }
153+ defer res .Body .Close ()
154+
155+ if res .StatusCode != http .StatusOK {
156+ return metadata , fmt .Errorf ("metadata fetch returned %s" , res .Status )
31157 }
32158
33- rawMetadata , err := io . ReadAll (res .Body )
159+ rawMetadata , err := readCapped (res .Body , maxMetadataBytes )
34160 if err != nil {
35- return metadata , fmt . Errorf ( "can't read saml cert data: %s" , err )
161+ return metadata , err
36162 }
37163
38164 err = xml .Unmarshal (rawMetadata , & metadata )
@@ -41,3 +167,15 @@ func getMetadata(metadataURL string) (types.EntityDescriptor, error) {
41167 }
42168 return metadata , nil
43169}
170+
171+ func readCapped (r io.Reader , max int64 ) ([]byte , error ) {
172+ lr := io .LimitReader (r , max + 1 )
173+ b , err := io .ReadAll (lr )
174+ if err != nil {
175+ return nil , fmt .Errorf ("can't read saml metadata: %w" , err )
176+ }
177+ if int64 (len (b )) > max {
178+ return nil , fmt .Errorf ("metadata too large (>%d bytes)" , max )
179+ }
180+ return b , nil
181+ }
0 commit comments