diff --git a/main.go b/main.go index 30319b12..9ad11ce7 100644 --- a/main.go +++ b/main.go @@ -22,10 +22,11 @@ func init() { func main() { var ( - showVersion = flag.Bool("version", false, "Print version information.") - listenAddress = flag.String("web.listen-address", ":9237", "Address to listen on for web interface and telemetry.") - metricsPath = flag.String("web.telemetry-path", "/metrics", "Path under which to expose metrics.") - configFile = flag.String("config.file", os.Getenv("CONFIG"), "SQL Exporter configuration file name.") + showVersion = flag.Bool("version", false, "Print version information.") + listenAddress = flag.String("web.listen-address", ":9237", "Address to listen on for web interface and telemetry.") + metricsPath = flag.String("web.telemetry-path", "/metrics", "Path under which to expose metrics.") + configFile = flag.String("config.file", os.Getenv("CONFIG"), "SQL Exporter configuration file name.") + dbConnectivityAsHealthCheck = flag.Bool("db.connectivity-as-healthz", false, "Use database connectivity check as healthz probe") ) flag.Parse() @@ -66,7 +67,49 @@ func main() { // setup and start webserver http.Handle(*metricsPath, promhttp.Handler()) - http.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { http.Error(w, "OK", http.StatusOK) }) + + if *dbConnectivityAsHealthCheck { + http.HandleFunc("/healthz", + func(w http.ResponseWriter, r *http.Request) { + for _, job := range exporter.jobs { + + if job == nil { + continue + } + + for _, connection := range job.conns { + + if connection == nil { + continue + } + + if connection.conn != nil { + if err := connection.conn.Ping(); err != nil { + // if any of the connections fails to be established/verified, fail the /healthz request + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + // otherwise we've successfully verified the connection, continue to the next one + continue + } + + if err := connection.connect(job); err != nil { + // if any of the connections fails to be established, fail the /healthz request + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + } + } + + // otherwise return OK + http.Error(w, "OK", http.StatusOK) + + }) + } else { + http.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { http.Error(w, "OK", http.StatusOK) }) + } + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`