diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0e31a2a..75fecaf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,7 @@ jobs: uses: DeterminateSystems/magic-nix-cache-action@565684385bcd71bad329742eefe8d12f2e765b39 # v13 - name: Run Checks (lint, vet, test) - run: nix develop --command check + run: nix flake check - name: Generate Coverage Report run: nix develop --command test-coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 120000 index 8a5689c..0000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1 +0,0 @@ -/nix/store/9ib8mw33w0z5l9wsvjvajnlvyazlrsxd-pre-commit-config.json \ No newline at end of file diff --git a/README.md b/README.md index f5add7b..22d0f05 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ A daemon that synchronizes certificates from Vault to HAProxy using the HAProxy | `HAPROXY_DATAPLANE_API_INSECURE` | `false` | Skip TLS certificate verification for HTTPS connections | | `CERTIFICATEE_UPDATE_INTERVAL` | `24h` | How often to check certificates for updates | | `CERTIFICATEE_RENEW_BEFORE_DAYS` | `30` | Update certificates expiring within this many days | +| `CERTIFICATEE_LOCAL_CERTS_DIR` | | Directory containing `.pem` bundles (cert + key) to use instead of Vault | | `VAULT_APPROLE_ROLE_ID` | (required) | Vault AppRole Role ID | | `NOMAD_TOKEN` | (required) | Used as Vault AppRole Secret ID | | `VAULT_KV_STORAGE_PATH` | `secret/data/certificator/` | Vault KV storage path for certificates | @@ -37,6 +38,8 @@ A daemon that synchronizes certificates from Vault to HAProxy using the HAProxy | `LOG_LEVEL` | `INFO` | Log level: `DEBUG`, `INFO`, `WARN`, `ERROR` | | `ENVIRONMENT` | `prod` | Environment name for metrics labels | +If `CERTIFICATEE_LOCAL_CERTS_DIR` is set, Vault configuration is optional and certificates are read from local PEM bundles instead. + ### Certificator Environment Variables | Variable | Default | Description | @@ -59,7 +62,7 @@ Certificatee uses the HAProxy Data Plane API to update certificates at runtime w - **Basic authentication**: Authenticate using username/password credentials - **Automatic retries**: Connections are retried with exponential backoff (default: 3 retries, 1-30s delays) - **Graceful degradation**: If one HAProxy instance is unreachable, the tool continues updating reachable instances -- **REST API**: Certificates are managed via the `/v2/services/haproxy/runtime/certs` endpoints +- **REST API**: Certificates are managed via the `/v2/services/haproxy/storage/ssl_certificates` endpoints ### HAProxy Data Plane API Configuration diff --git a/cmd/certificatee/certsource.go b/cmd/certificatee/certsource.go new file mode 100644 index 0000000..99f45e2 --- /dev/null +++ b/cmd/certificatee/certsource.go @@ -0,0 +1,117 @@ +package main + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "path/filepath" + + "github.com/vinted/certificator/pkg/certificate" + "github.com/vinted/certificator/pkg/vault" +) + +type CertSource interface { + GetCertificate(domain string) (*x509.Certificate, error) + GetPEMBundle(domain string) (string, error) +} + +type VaultCertSource struct { + client *vault.VaultClient +} + +func (v VaultCertSource) GetCertificate(domain string) (*x509.Certificate, error) { + return certificate.GetCertificate(domain, v.client) +} + +func (v VaultCertSource) GetPEMBundle(domain string) (string, error) { + certificateSecrets, err := v.client.KVRead(certificate.VaultCertLocation(domain)) + if err != nil { + return "", fmt.Errorf("failed to read certificate data from vault for %s: %w", domain, err) + } + + pemData, err := buildPEMBundle(certificateSecrets) + if err != nil { + return "", fmt.Errorf("failed to build PEM bundle for %s: %w", domain, err) + } + + return pemData, nil +} + +type LocalCertSource struct { + dir string +} + +func (l LocalCertSource) GetCertificate(domain string) (*x509.Certificate, error) { + data, err := os.ReadFile(l.certPath(domain)) + if err != nil { + return nil, fmt.Errorf("failed to read local certificate for %s: %w", domain, err) + } + + cert, err := parseCertificateFromPEM(data) + if err != nil { + return nil, fmt.Errorf("failed to parse local certificate for %s: %w", domain, err) + } + + return cert, nil +} + +func (l LocalCertSource) GetPEMBundle(domain string) (string, error) { + data, err := os.ReadFile(l.certPath(domain)) + if err != nil { + return "", fmt.Errorf("failed to read local PEM bundle for %s: %w", domain, err) + } + + return string(data), nil +} + +func (l LocalCertSource) certPath(domain string) string { + return filepath.Join(l.dir, domain+".pem") +} + +func parseCertificateFromPEM(pemData []byte) (*x509.Certificate, error) { + rest := pemData + for len(rest) > 0 { + var block *pem.Block + block, rest = pem.Decode(rest) + if block == nil { + break + } + if block.Type != "CERTIFICATE" { + continue + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, err + } + return cert, nil + } + + return nil, fmt.Errorf("no certificate PEM block found") +} + +// buildPEMBundle creates a PEM bundle from Vault certificate secrets +func buildPEMBundle(secrets map[string]any) (string, error) { + var pemData string + + // Add certificate + if cert, ok := secrets["certificate"].(string); ok && cert != "" { + pemData += cert + } else { + return "", fmt.Errorf("certificate not found in vault secrets") + } + + // Add newline between cert and key + if !endsWith(pemData, "\n") { + pemData += "\n" + } + + // Add private key + if key, ok := secrets["private_key"].(string); ok && key != "" { + pemData += key + } else { + return "", fmt.Errorf("private_key not found in vault secrets") + } + + return pemData, nil +} diff --git a/cmd/certificatee/haproxy_config.go b/cmd/certificatee/haproxy_config.go new file mode 100644 index 0000000..5ac94eb --- /dev/null +++ b/cmd/certificatee/haproxy_config.go @@ -0,0 +1,66 @@ +package main + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/vinted/certificator/pkg/haproxy" +) + +func sanitizeWildcardCertName(name string) string { + name = strings.ReplaceAll(name, "*", "_") + name = strings.ReplaceAll(name, "..", ".") + return name +} + +func ensureStorageCertificate(haproxyClient *haproxy.Client, certName, pemData string) error { + refs, err := haproxyClient.ListCertificateRefs() + if err != nil { + return err + } + + normalized := normalizeCertificateName(certName) + for _, ref := range refs { + if certRefMatches(ref, certName) || certRefMatches(ref, normalized) { + if err := haproxyClient.UpdateCertificate(certName, pemData); err == nil { + return nil + } + break + } + } + + if err := haproxyClient.CreateCertificate(certName, pemData); err != nil { + // Data Plane API normalizes storage names (e.g., replaces '*' and other chars with '_'), + // so a create may return 409 even if the exact requested name wasn't found above. + if strings.Contains(err.Error(), "already exists") { + if err := haproxyClient.UpdateCertificate(certName, pemData); err == nil { + return nil + } + } + return fmt.Errorf("failed to create certificate %s: %w", certName, err) + } + return nil +} + +func normalizeCertificateName(name string) string { + builder := strings.Builder{} + for _, r := range name { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' || r == '.' { + builder.WriteRune(r) + } else { + builder.WriteRune('_') + } + } + return builder.String() +} + +func certRefMatches(ref haproxy.CertificateRef, name string) bool { + if ref.DisplayName == name { + return true + } + if ref.FilePath == name { + return true + } + return filepath.Base(ref.FilePath) == name +} diff --git a/cmd/certificatee/integration_test.go b/cmd/certificatee/integration_test.go new file mode 100644 index 0000000..2607671 --- /dev/null +++ b/cmd/certificatee/integration_test.go @@ -0,0 +1,868 @@ +//go:build integration + +package main + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/vinted/certificator/pkg/haproxy" +) + +func TestCertificateeUpdatesCertViaDataPlane(t *testing.T) { + if testing.Short() { + t.Skip("integration test") + } + + haproxyPath := requireLookPath(t, "haproxy") + dataplanePath := requireLookPath(t, "dataplaneapi") + opensslPath := requireLookPath(t, "openssl") + + tempDir := t.TempDir() + certsDir := filepath.Join(tempDir, "haproxy-certs") + localCertsDir := filepath.Join(tempDir, "local-certs") + mapsDir := filepath.Join(tempDir, "maps") + spoeDir := filepath.Join(tempDir, "spoe") + storageDir := filepath.Join(tempDir, "storage") + runDir := mustMakeShortTempDir(t, "/tmp", "certificator-it-") + mustMkdirAll(t, certsDir, localCertsDir, mapsDir, spoeDir, storageDir) + + domain := "example-test" + haproxyCertPath := filepath.Join(certsDir, domain+".pem") + localCertPath := filepath.Join(localCertsDir, domain+".pem") + + haproxyCertPEM, haproxySerial := mustSelfSignedPEM(t, domain, time.Now().Add(365*24*time.Hour), big.NewInt(1)) + localCertPEM, localSerial := mustSelfSignedPEM(t, domain, time.Now().Add(24*time.Hour), big.NewInt(2)) + if haproxySerial == localSerial { + t.Fatalf("expected different serials for test certificates") + } + + mustWriteFile(t, haproxyCertPath, haproxyCertPEM) + mustWriteFile(t, localCertPath, localCertPEM) + + haproxyPort := freePort(t) + socketPath := filepath.Join(runDir, "haproxy.sock") + pidPath := filepath.Join(runDir, "haproxy.pid") + haproxyCfgPath := filepath.Join(tempDir, "haproxy.cfg") + haproxyCfg := fmt.Sprintf(`global + log stdout format raw local0 + stats socket %s mode 600 level admin + maxconn 256 +userlist controller + user admin insecure-password admin +defaults + mode http + timeout connect 5s + timeout client 5s + timeout server 5s +frontend fe_tls + bind 127.0.0.1:%d ssl crt %s + default_backend be +backend be + server s1 127.0.0.1:8080 +`, socketPath, haproxyPort, haproxyCertPath) + "\n" + mustWriteFile(t, haproxyCfgPath, []byte(haproxyCfg)) + + _, haproxyLogs := startProcess(t, haproxyPath, []string{"-f", haproxyCfgPath, "-db", "-p", pidPath}, nil, "") + waitForSocket(t, socketPath, 5*time.Second, haproxyLogs) + + reloadScript := mustWriteReloadScript(t, tempDir, haproxyPath, haproxyCfgPath, pidPath) + dataplaneCfgPath := filepath.Join(tempDir, "dataplaneapi.yaml") + dataplaneArgs := []string{ + "-f", dataplaneCfgPath, + "--scheme", "http", + "--host", "127.0.0.1", + "--config-file", haproxyCfgPath, + "--haproxy-bin", haproxyPath, + "--userlist", "controller", + "--ssl-certs-dir", certsDir, + "--general-storage-dir", storageDir, + "--maps-dir", mapsDir, + "--spoe-dir", spoeDir, + "--reload-strategy", "custom", + "--reload-cmd", reloadScript, + "--restart-cmd", reloadScript, + "--log-level", "info", + } + _, dataplaneLogs := startProcess(t, dataplanePath, dataplaneArgs, nil, "") + apiURL := waitForDataPlaneURL(t, dataplaneLogs, 20*time.Second) + waitForDataPlaneAPI(t, apiURL, "admin", "admin", domain+".pem", dataplaneLogs) + + pkgDir := mustGetwd(t) + repoRoot := filepath.Clean(filepath.Join(pkgDir, "../..")) + certificateeBin := filepath.Join(tempDir, "certificatee") + buildCmd := exec.Command("go", "build", "-o", certificateeBin, "./cmd/certificatee") + buildCmd.Dir = repoRoot + buildCmd.Stdout = os.Stdout + buildCmd.Stderr = os.Stderr + if err := buildCmd.Run(); err != nil { + t.Fatalf("failed to build certificatee: %v", err) + } + + certificateeEnv := []string{ + fmt.Sprintf("HAPROXY_DATAPLANE_API_URLS=%s", apiURL), + "HAPROXY_DATAPLANE_API_USER=admin", + "HAPROXY_DATAPLANE_API_PASSWORD=admin", + fmt.Sprintf("CERTIFICATEE_LOCAL_CERTS_DIR=%s", localCertsDir), + "CERTIFICATEE_RENEW_BEFORE_DAYS=30", + "CERTIFICATEE_UPDATE_INTERVAL=1h", + "LOG_LEVEL=DEBUG", + } + + _, certificateeLogs := startProcess(t, certificateeBin, nil, certificateeEnv, repoRoot) + waitForSerialMatch(t, haproxyCertPath, localSerial, 10*time.Second, certificateeLogs, haproxyLogs, dataplaneLogs) + waitForServerSerialMatch(t, opensslPath, fmt.Sprintf("127.0.0.1:%d", haproxyPort), domain, localSerial, 20*time.Second) +} + +func TestCertificateeUpdatesCertWhenMetadataDiffers(t *testing.T) { + if testing.Short() { + t.Skip("integration test") + } + + haproxyPath := requireLookPath(t, "haproxy") + dataplanePath := requireLookPath(t, "dataplaneapi") + opensslPath := requireLookPath(t, "openssl") + + tempDir := t.TempDir() + certsDir := filepath.Join(tempDir, "haproxy-certs") + localCertsDir := filepath.Join(tempDir, "local-certs") + mapsDir := filepath.Join(tempDir, "maps") + spoeDir := filepath.Join(tempDir, "spoe") + storageDir := filepath.Join(tempDir, "storage") + runDir := mustMakeShortTempDir(t, "/tmp", "certificator-it-") + mustMkdirAll(t, certsDir, localCertsDir, mapsDir, spoeDir, storageDir) + + domain := "example-metadata" + haproxyCertPath := filepath.Join(certsDir, domain+".pem") + localCertPath := filepath.Join(localCertsDir, domain+".pem") + + haproxyCertPEM, haproxySerial := mustSelfSignedPEM(t, domain, time.Now().Add(365*24*time.Hour), big.NewInt(101)) + localCertPEM, localSerial := mustSelfSignedPEM(t, domain, time.Now().Add(365*24*time.Hour), big.NewInt(102)) + if haproxySerial == localSerial { + t.Fatalf("expected different serials for test certificates") + } + + mustWriteFile(t, haproxyCertPath, haproxyCertPEM) + mustWriteFile(t, localCertPath, localCertPEM) + + haproxyPort := freePort(t) + socketPath := filepath.Join(runDir, "haproxy.sock") + pidPath := filepath.Join(runDir, "haproxy.pid") + haproxyCfgPath := filepath.Join(tempDir, "haproxy.cfg") + haproxyCfg := fmt.Sprintf(`global + log stdout format raw local0 + stats socket %s mode 600 level admin + maxconn 256 +userlist controller + user admin insecure-password admin +defaults + mode http + timeout connect 5s + timeout client 5s + timeout server 5s +frontend fe_tls + bind 127.0.0.1:%d ssl crt %s + default_backend be +backend be + server s1 127.0.0.1:8080 +`, socketPath, haproxyPort, haproxyCertPath) + "\n" + mustWriteFile(t, haproxyCfgPath, []byte(haproxyCfg)) + + _, haproxyLogs := startProcess(t, haproxyPath, []string{"-f", haproxyCfgPath, "-db", "-p", pidPath}, nil, "") + waitForSocket(t, socketPath, 5*time.Second, haproxyLogs) + + reloadScript := mustWriteReloadScript(t, tempDir, haproxyPath, haproxyCfgPath, pidPath) + dataplaneCfgPath := filepath.Join(tempDir, "dataplaneapi.yaml") + dataplaneArgs := []string{ + "-f", dataplaneCfgPath, + "--scheme", "http", + "--host", "127.0.0.1", + "--config-file", haproxyCfgPath, + "--haproxy-bin", haproxyPath, + "--userlist", "controller", + "--ssl-certs-dir", certsDir, + "--general-storage-dir", storageDir, + "--maps-dir", mapsDir, + "--spoe-dir", spoeDir, + "--reload-strategy", "custom", + "--reload-cmd", reloadScript, + "--restart-cmd", reloadScript, + "--log-level", "info", + } + _, dataplaneLogs := startProcess(t, dataplanePath, dataplaneArgs, nil, "") + apiURL := waitForDataPlaneURL(t, dataplaneLogs, 20*time.Second) + waitForDataPlaneAPI(t, apiURL, "admin", "admin", domain+".pem", dataplaneLogs) + + pkgDir := mustGetwd(t) + repoRoot := filepath.Clean(filepath.Join(pkgDir, "../..")) + certificateeBin := filepath.Join(tempDir, "certificatee") + buildCmd := exec.Command("go", "build", "-o", certificateeBin, "./cmd/certificatee") + buildCmd.Dir = repoRoot + buildCmd.Stdout = os.Stdout + buildCmd.Stderr = os.Stderr + if err := buildCmd.Run(); err != nil { + t.Fatalf("failed to build certificatee: %v", err) + } + + certificateeEnv := []string{ + fmt.Sprintf("HAPROXY_DATAPLANE_API_URLS=%s", apiURL), + "HAPROXY_DATAPLANE_API_USER=admin", + "HAPROXY_DATAPLANE_API_PASSWORD=admin", + fmt.Sprintf("CERTIFICATEE_LOCAL_CERTS_DIR=%s", localCertsDir), + "CERTIFICATEE_RENEW_BEFORE_DAYS=30", + "CERTIFICATEE_UPDATE_INTERVAL=1h", + "LOG_LEVEL=DEBUG", + } + + _, certificateeLogs := startProcess(t, certificateeBin, nil, certificateeEnv, repoRoot) + waitForSerialMatch(t, haproxyCertPath, localSerial, 10*time.Second, certificateeLogs, haproxyLogs, dataplaneLogs) + waitForServerSerialMatch(t, opensslPath, fmt.Sprintf("127.0.0.1:%d", haproxyPort), domain, localSerial, 20*time.Second) +} + +func TestHAProxyPrefersExactCertOverWildcard(t *testing.T) { + if testing.Short() { + t.Skip("integration test") + } + + haproxyPath := requireLookPath(t, "haproxy") + opensslPath := requireLookPath(t, "openssl") + + tempDir := t.TempDir() + certsDir := filepath.Join(tempDir, "haproxy-certs") + runDir := mustMakeShortTempDir(t, "/tmp", "certificator-it-") + mustMkdirAll(t, certsDir) + + domain := "example-test" + wildcardDomain := "*." + domain + + exactCertPEM, exactSerial := mustSelfSignedPEM(t, domain, time.Now().Add(365*24*time.Hour), big.NewInt(11)) + wildcardCertPEM, wildcardSerial := mustSelfSignedPEM(t, wildcardDomain, time.Now().Add(365*24*time.Hour), big.NewInt(12)) + if exactSerial == wildcardSerial { + t.Fatalf("expected different serials for test certificates") + } + + exactCertPath := filepath.Join(certsDir, domain+".pem") + wildcardCertPath := filepath.Join(certsDir, wildcardDomain+".pem") + mustWriteFile(t, exactCertPath, exactCertPEM) + mustWriteFile(t, wildcardCertPath, wildcardCertPEM) + + haproxyPort := freePort(t) + socketPath := filepath.Join(runDir, "haproxy.sock") + haproxyCfgPath := filepath.Join(tempDir, "haproxy.cfg") + haproxyCfg := fmt.Sprintf(`global + log stdout format raw local0 + stats socket %s mode 600 level admin + maxconn 256 +defaults + mode http + timeout connect 5s + timeout client 5s + timeout server 5s +frontend fe_tls + bind 127.0.0.1:%d ssl crt %s alpn h2,http/1.1 + default_backend be +backend be + server s1 127.0.0.1:8080 +`, socketPath, haproxyPort, certsDir) + "\n" + mustWriteFile(t, haproxyCfgPath, []byte(haproxyCfg)) + + _, haproxyLogs := startProcess(t, haproxyPath, []string{"-f", haproxyCfgPath, "-db"}, nil, "") + waitForSocket(t, socketPath, 5*time.Second, haproxyLogs) + + addr := fmt.Sprintf("127.0.0.1:%d", haproxyPort) + subject, err := opensslServerSubject(t, opensslPath, addr, domain) + if err != nil { + t.Fatalf("failed to get server cert subject: %v", err) + } + + if !strings.Contains(subject, "CN="+domain) && !strings.Contains(subject, "CN = "+domain) { + t.Fatalf("expected exact certificate CN %q, got subject %q", domain, subject) + } +} + +func TestDataPlaneAcceptsUnderscoreWildcardName(t *testing.T) { + if testing.Short() { + t.Skip("integration test") + } + + haproxyPath := requireLookPath(t, "haproxy") + dataplanePath := requireLookPath(t, "dataplaneapi") + + tempDir := t.TempDir() + certsDir := filepath.Join(tempDir, "haproxy-certs") + runDir := mustMakeShortTempDir(t, "/tmp", "certificator-it-") + mustMkdirAll(t, certsDir) + + domain := "example-test" + certName := "_." + domain + ".pem" + certPEM, _ := mustSelfSignedPEM(t, "*."+domain, time.Now().Add(365*24*time.Hour), big.NewInt(21)) + initialCertPath := filepath.Join(certsDir, "bootstrap.pem") + mustWriteFile(t, initialCertPath, certPEM) + + haproxyPort := freePort(t) + socketPath := filepath.Join(runDir, "haproxy.sock") + pidPath := filepath.Join(runDir, "haproxy.pid") + haproxyCfgPath := filepath.Join(tempDir, "haproxy.cfg") + haproxyCfg := fmt.Sprintf(`global + log stdout format raw local0 + stats socket %s mode 600 level admin + maxconn 256 +userlist controller + user admin insecure-password admin +defaults + mode http + timeout connect 5s + timeout client 5s + timeout server 5s +frontend fe_tls + bind 127.0.0.1:%d ssl crt %s + default_backend be +backend be + server s1 127.0.0.1:8080 +`, socketPath, haproxyPort, certsDir) + "\n" + mustWriteFile(t, haproxyCfgPath, []byte(haproxyCfg)) + + _, haproxyLogs := startProcess(t, haproxyPath, []string{"-f", haproxyCfgPath, "-db", "-p", pidPath}, nil, "") + waitForSocket(t, socketPath, 5*time.Second, haproxyLogs) + + reloadScript := mustWriteReloadScript(t, tempDir, haproxyPath, haproxyCfgPath, pidPath) + dataplaneCfgPath := filepath.Join(tempDir, "dataplaneapi.yaml") + dataplaneArgs := []string{ + "-f", dataplaneCfgPath, + "--scheme", "http", + "--host", "127.0.0.1", + "--config-file", haproxyCfgPath, + "--haproxy-bin", haproxyPath, + "--userlist", "controller", + "--ssl-certs-dir", certsDir, + "--general-storage-dir", filepath.Join(tempDir, "storage"), + "--maps-dir", filepath.Join(tempDir, "maps"), + "--spoe-dir", filepath.Join(tempDir, "spoe"), + "--reload-strategy", "custom", + "--reload-cmd", reloadScript, + "--restart-cmd", reloadScript, + "--log-level", "info", + } + _, dataplaneLogs := startProcess(t, dataplanePath, dataplaneArgs, nil, "") + apiURL := waitForDataPlaneURL(t, dataplaneLogs, 20*time.Second) + + logger := logrus.New() + client, err := haproxy.NewClient(haproxy.ClientConfig{ + BaseURL: apiURL, + Username: "admin", + Password: "admin", + }, logger) + if err != nil { + t.Fatalf("failed to create haproxy client: %v", err) + } + + if err := client.CreateCertificate(certName, string(certPEM)); err != nil { + t.Fatalf("failed to create certificate %s: %v", certName, err) + } + + refs, err := client.ListCertificateRefs() + if err != nil { + t.Fatalf("failed to list certificates: %v", err) + } + + found := false + for _, ref := range refs { + if ref.DisplayName == certName || filepath.Base(ref.FilePath) == certName { + found = true + break + } + } + if !found { + for _, ref := range refs { + if strings.HasPrefix(ref.DisplayName, "_") || strings.HasPrefix(filepath.Base(ref.FilePath), "_") { + found = true + break + } + } + } + if !found { + t.Fatalf("expected to find certificate with _ prefix in storage list") + } +} + +func TestWildcardReplacementSurvivesWildcardDeletion(t *testing.T) { + if testing.Short() { + t.Skip("integration test") + } + + haproxyPath := requireLookPath(t, "haproxy") + dataplanePath := requireLookPath(t, "dataplaneapi") + opensslPath := requireLookPath(t, "openssl") + + tempDir := t.TempDir() + certsDir := filepath.Join(tempDir, "haproxy-certs") + localCertsDir := filepath.Join(tempDir, "local-certs") + runDir := mustMakeShortTempDir(t, "/tmp", "certificator-it-") + mustMkdirAll(t, certsDir, localCertsDir) + + domain := "example-test" + wildcardDomain := "*." + domain + wildcardName := wildcardDomain + ".pem" + + oldPEM, oldSerial := mustSelfSignedPEM(t, wildcardDomain, time.Now().Add(24*time.Hour), big.NewInt(31)) + newPEM1, serial1 := mustSelfSignedPEM(t, wildcardDomain, time.Now().Add(24*time.Hour), big.NewInt(32)) + newPEM2, serial2 := mustSelfSignedPEM(t, wildcardDomain, time.Now().Add(24*time.Hour), big.NewInt(33)) + if serial1 == serial2 { + t.Fatalf("expected different serials for updated certificates") + } + + oldPath := filepath.Join(certsDir, wildcardName) + mustWriteFile(t, oldPath, oldPEM) + mustWriteFile(t, filepath.Join(localCertsDir, wildcardName), newPEM1) + + haproxyPort := freePort(t) + socketPath := filepath.Join(runDir, "haproxy.sock") + pidPath := filepath.Join(runDir, "haproxy.pid") + haproxyCfgPath := filepath.Join(tempDir, "haproxy.cfg") + haproxyCfg := fmt.Sprintf(`global + log stdout format raw local0 + stats socket %s mode 600 level admin + maxconn 256 +userlist controller + user admin insecure-password admin +defaults + mode http + timeout connect 5s + timeout client 5s + timeout server 5s +frontend fe_tls + bind 127.0.0.1:%d ssl crt %s + default_backend be +backend be + server s1 127.0.0.1:8080 +`, socketPath, haproxyPort, certsDir) + "\n" + mustWriteFile(t, haproxyCfgPath, []byte(haproxyCfg)) + + _, haproxyLogs := startProcess(t, haproxyPath, []string{"-f", haproxyCfgPath, "-db", "-p", pidPath}, nil, "") + waitForSocket(t, socketPath, 5*time.Second, haproxyLogs) + + reloadScript := mustWriteReloadScript(t, tempDir, haproxyPath, haproxyCfgPath, pidPath) + dataplaneCfgPath := filepath.Join(tempDir, "dataplaneapi.yaml") + dataplaneArgs := []string{ + "-f", dataplaneCfgPath, + "--scheme", "http", + "--host", "127.0.0.1", + "--config-file", haproxyCfgPath, + "--haproxy-bin", haproxyPath, + "--userlist", "controller", + "--ssl-certs-dir", certsDir, + "--general-storage-dir", filepath.Join(tempDir, "storage"), + "--maps-dir", filepath.Join(tempDir, "maps"), + "--spoe-dir", filepath.Join(tempDir, "spoe"), + "--reload-strategy", "custom", + "--reload-cmd", reloadScript, + "--restart-cmd", reloadScript, + "--log-level", "info", + } + _, dataplaneLogs := startProcess(t, dataplanePath, dataplaneArgs, nil, "") + apiURL := waitForDataPlaneURL(t, dataplaneLogs, 20*time.Second) + + logger := logrus.New() + client, err := haproxy.NewClient(haproxy.ClientConfig{ + BaseURL: apiURL, + Username: "admin", + Password: "admin", + }, logger) + if err != nil { + t.Fatalf("failed to create haproxy client: %v", err) + } + + certSource := LocalCertSource{dir: localCertsDir} + ref := haproxy.CertificateRef{ + DisplayName: wildcardName, + FilePath: oldPath, + } + + if err := updateCertificate(ref, wildcardDomain, certSource, client); err != nil { + t.Fatalf("failed to update wildcard certificate: %v", err) + } + + underscorePath := findUnderscoreCertPath(t, client) + serial, err := certSerialFromPEM(underscorePath) + if err != nil { + t.Fatalf("failed to read underscore cert serial: %v", err) + } + if serial != serial1 { + t.Fatalf("expected underscore cert serial %s, got %s", serial1, serial) + } + waitForServerSerialMatch(t, opensslPath, fmt.Sprintf("127.0.0.1:%d", haproxyPort), "foo."+domain, oldSerial, 20*time.Second) + + if err := os.Remove(oldPath); err != nil { + t.Fatalf("failed to remove old wildcard cert: %v", err) + } + + mustWriteFile(t, filepath.Join(localCertsDir, wildcardName), newPEM2) + if err := updateCertificate(ref, wildcardDomain, certSource, client); err != nil { + t.Fatalf("failed to update underscore certificate: %v", err) + } + + serial, err = certSerialFromPEM(underscorePath) + if err != nil { + t.Fatalf("failed to read underscore cert serial: %v", err) + } + if serial != serial2 { + t.Fatalf("expected underscore cert serial %s after update, got %s", serial2, serial) + } + waitForServerSerialMatch(t, opensslPath, fmt.Sprintf("127.0.0.1:%d", haproxyPort), "foo."+domain, serial2, 20*time.Second) +} + +func requireLookPath(t *testing.T, bin string) string { + t.Helper() + path, err := exec.LookPath(bin) + if err != nil { + t.Fatalf("required binary %s not found in PATH: %v", bin, err) + } + return path +} + +func mustMkdirAll(t *testing.T, dirs ...string) { + t.Helper() + for _, dir := range dirs { + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("failed to create dir %s: %v", dir, err) + } + } +} + +func mustWriteFile(t *testing.T, path string, data []byte) { + t.Helper() + if err := os.WriteFile(path, data, 0o600); err != nil { + t.Fatalf("failed to write %s: %v", path, err) + } +} + +func mustWriteExecutable(t *testing.T, path string, data []byte) { + t.Helper() + if err := os.WriteFile(path, data, 0o700); err != nil { + t.Fatalf("failed to write %s: %v", path, err) + } +} + +func mustSelfSignedPEM(t *testing.T, domain string, notAfter time.Time, serial *big.Int) ([]byte, string) { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: domain, + }, + DNSNames: []string{domain}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("failed to create certificate: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + + return append(certPEM, keyPEM...), template.SerialNumber.String() +} + +func startProcess(t *testing.T, binary string, args []string, env []string, dir string) (*exec.Cmd, *bytes.Buffer) { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + cmd := exec.CommandContext(ctx, binary, args...) + if dir != "" { + cmd.Dir = dir + } + if env != nil { + cmd.Env = append(os.Environ(), env...) + } + var output bytes.Buffer + cmd.Stdout = &output + cmd.Stderr = &output + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start %s: %v", binary, err) + } + t.Cleanup(func() { + cancel() + _ = cmd.Wait() + }) + return cmd, &output +} + +func waitForSocket(t *testing.T, socketPath string, timeout time.Duration, logs *bytes.Buffer) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if _, err := os.Stat(socketPath); err == nil { + return + } + time.Sleep(100 * time.Millisecond) + } + t.Fatalf("haproxy socket did not appear at %s; logs:\n%s", socketPath, logs.String()) +} + +func waitForDataPlaneAPI(t *testing.T, baseURL, user, pass, wantCert string, logs *bytes.Buffer) { + t.Helper() + httpClient := &http.Client{Timeout: 2 * time.Second} + deadline := time.Now().Add(30 * time.Second) + for time.Now().Before(deadline) { + req, err := http.NewRequest("GET", fmt.Sprintf("%s/v2/services/haproxy/storage/ssl_certificates", baseURL), nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.SetBasicAuth(user, pass) + resp, err := httpClient.Do(req) + if err == nil { + if resp.StatusCode == http.StatusOK { + var certs []haproxy.SSLCertificateEntry + if decodeErr := json.NewDecoder(resp.Body).Decode(&certs); decodeErr == nil { + for _, cert := range certs { + if cert.StorageName == wantCert || cert.File == wantCert || filepath.Base(cert.File) == wantCert { + _ = resp.Body.Close() + return + } + } + } + } + _ = resp.Body.Close() + } + time.Sleep(200 * time.Millisecond) + } + t.Fatalf("dataplane API did not list certificate %s; logs:\n%s", wantCert, logs.String()) +} + +func waitForDataPlaneURL(t *testing.T, logs *bytes.Buffer, timeout time.Duration) string { + t.Helper() + re := regexp.MustCompile(`Serving data plane at (http://[^\\s"]+)`) + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + matches := re.FindStringSubmatch(logs.String()) + if len(matches) == 2 { + return strings.TrimSpace(matches[1]) + } + time.Sleep(100 * time.Millisecond) + } + t.Fatalf("dataplane API URL not found in logs within %s; logs:\n%s", timeout, logs.String()) + return "" +} + +func waitForSerialMatch(t *testing.T, certPath, expectedSerial string, timeout time.Duration, logs ...*bytes.Buffer) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + serial, err := certSerialFromPEM(certPath) + if err == nil && serial == expectedSerial { + return + } + time.Sleep(200 * time.Millisecond) + } + var combined strings.Builder + for i, logBuf := range logs { + if logBuf == nil { + continue + } + if i > 0 { + combined.WriteString("\n") + } + combined.WriteString(logBuf.String()) + } + t.Fatalf("certificate at %s did not update to serial %s; logs:\n%s", certPath, expectedSerial, combined.String()) +} + +func certSerialFromPEM(path string) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + block, _ := pem.Decode(data) + if block == nil || block.Type != "CERTIFICATE" { + return "", fmt.Errorf("no certificate PEM found in %s", path) + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return "", err + } + return cert.SerialNumber.String(), nil +} + +func freePort(t *testing.T) int { + t.Helper() + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to get free port: %v", err) + } + defer func() { _ = listener.Close() }() + return listener.Addr().(*net.TCPAddr).Port +} + +func opensslServerSubject(t *testing.T, opensslPath, addr, serverName string) (string, error) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, opensslPath, "s_client", "-connect", addr, "-servername", serverName, "-showcerts") + cmd.Stdin = strings.NewReader("") + output, err := cmd.CombinedOutput() + if err != nil && ctx.Err() == context.DeadlineExceeded { + return "", fmt.Errorf("openssl s_client timed out: %s", output) + } + + certPEM, err := extractFirstCertPEM(string(output)) + if err != nil { + return "", err + } + + certCmd := exec.CommandContext(ctx, opensslPath, "x509", "-noout", "-subject") + certCmd.Stdin = strings.NewReader(certPEM) + subjectOut, err := certCmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("openssl x509 failed: %s", subjectOut) + } + + return strings.TrimSpace(string(subjectOut)), nil +} + +func extractFirstCertPEM(output string) (string, error) { + begin := strings.Index(output, "-----BEGIN CERTIFICATE-----") + end := strings.Index(output, "-----END CERTIFICATE-----") + if begin == -1 || end == -1 || end < begin { + return "", fmt.Errorf("certificate PEM not found in output") + } + end += len("-----END CERTIFICATE-----") + return output[begin:end] + "\n", nil +} + +func waitForServerSerialMatch(t *testing.T, opensslPath, addr, serverName, expectedSerial string, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + serial, err := opensslServerSerial(opensslPath, addr, serverName) + if err == nil && serial == expectedSerial { + return + } + time.Sleep(200 * time.Millisecond) + } + t.Fatalf("server at %s did not present serial %s within %s", addr, expectedSerial, timeout) +} + +func opensslServerSerial(opensslPath, addr, serverName string) (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, opensslPath, "s_client", "-connect", addr, "-servername", serverName, "-showcerts") + cmd.Stdin = strings.NewReader("") + output, err := cmd.CombinedOutput() + if err != nil && ctx.Err() == context.DeadlineExceeded { + return "", fmt.Errorf("openssl s_client timed out: %s", output) + } + + certPEM, err := extractFirstCertPEM(string(output)) + if err != nil { + return "", err + } + + serial, err := serialFromPEM(certPEM) + if err != nil { + return "", err + } + return serial, nil +} + +func serialFromPEM(pemData string) (string, error) { + block, _ := pem.Decode([]byte(pemData)) + if block == nil || block.Type != "CERTIFICATE" { + return "", fmt.Errorf("no certificate PEM found") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return "", err + } + return cert.SerialNumber.String(), nil +} + +func findUnderscoreCertPath(t *testing.T, client *haproxy.Client) string { + t.Helper() + refs, err := client.ListCertificateRefs() + if err != nil { + t.Fatalf("failed to list certificates: %v", err) + } + + for _, ref := range refs { + if strings.HasPrefix(ref.DisplayName, "_") { + if ref.FilePath != "" { + return ref.FilePath + } + return ref.DisplayName + } + if strings.HasPrefix(filepath.Base(ref.FilePath), "_") { + return ref.FilePath + } + } + + t.Fatalf("could not find underscore certificate in storage list") + return "" +} + +func mustGetwd(t *testing.T) string { + t.Helper() + dir, err := os.Getwd() + if err != nil { + t.Fatalf("failed to get working directory: %v", err) + } + return dir +} + +func mustMakeShortTempDir(t *testing.T, base, pattern string) string { + t.Helper() + dir, err := os.MkdirTemp(base, pattern) + if err != nil { + t.Fatalf("failed to create temp dir in %s: %v", base, err) + } + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + return dir +} + +func mustWriteReloadScript(t *testing.T, dir, haproxyPath, cfgPath, pidPath string) string { + t.Helper() + scriptPath := filepath.Join(dir, "haproxy-reload.sh") + script := fmt.Sprintf(`#!/bin/sh +set -eu +if [ -f %q ]; then + oldpid="$(cat %q || true)" +else + oldpid="" +fi +if [ -n "$oldpid" ]; then + exec %q -f %q -p %q -sf "$oldpid" +else + exec %q -f %q -p %q +fi +`, pidPath, pidPath, haproxyPath, cfgPath, pidPath, haproxyPath, cfgPath, pidPath) + mustWriteExecutable(t, scriptPath, []byte(script)) + return scriptPath +} diff --git a/cmd/certificatee/main.go b/cmd/certificatee/main.go index 1083ac5..1126637 100644 --- a/cmd/certificatee/main.go +++ b/cmd/certificatee/main.go @@ -3,11 +3,12 @@ package main import ( "errors" "fmt" + "path/filepath" + "strings" "time" legoLog "github.com/go-acme/lego/v4/log" "github.com/sirupsen/logrus" - "github.com/vinted/certificator/pkg/certificate" "github.com/vinted/certificator/pkg/certmetrics" "github.com/vinted/certificator/pkg/config" "github.com/vinted/certificator/pkg/haproxy" @@ -35,10 +36,17 @@ func main() { certmetrics.StartMetricsServer(logger, cfg.Metrics.ListenAddress) defer certmetrics.PushMetrics(logger, cfg.Metrics.PushUrl) - vaultClient, err := vault.NewVaultClient(cfg.Vault.ApproleRoleID, - cfg.Vault.ApproleSecretID, cfg.Environment, cfg.Vault.KVStoragePath, logger) - if err != nil { - logger.Fatal(err) + var certSource CertSource + if cfg.Certificatee.LocalCertsDir != "" { + certSource = LocalCertSource{dir: cfg.Certificatee.LocalCertsDir} + logger.Warnf("Using local certificate directory: %s", cfg.Certificatee.LocalCertsDir) + } else { + vaultClient, err := vault.NewVaultClient(cfg.Vault.ApproleRoleID, + cfg.Vault.ApproleSecretID, cfg.Environment, cfg.Vault.KVStoragePath, logger) + if err != nil { + logger.Fatal(err) + } + certSource = VaultCertSource{client: vaultClient} } haproxyClients, err := createHAProxyClients(cfg, logger) @@ -58,25 +66,25 @@ func main() { defer certmetrics.Up.WithLabelValues("certificatee", version, cfg.Hostname, cfg.Environment).Set(0) // Initial run - if err := maybeUpdateCertificates(logger, cfg, vaultClient, haproxyClients); err != nil { + if err := maybeUpdateCertificates(logger, cfg, certSource, haproxyClients); err != nil { logger.Error(err) } for range ticker.C { - if err := maybeUpdateCertificates(logger, cfg, vaultClient, haproxyClients); err != nil { + if err := maybeUpdateCertificates(logger, cfg, certSource, haproxyClients); err != nil { logger.Error(err) } } } -func maybeUpdateCertificates(logger *logrus.Logger, cfg config.Config, vaultClient *vault.VaultClient, haproxyClients []*haproxy.Client) error { +func maybeUpdateCertificates(logger *logrus.Logger, cfg config.Config, certSource CertSource, haproxyClients []*haproxy.Client) error { var allErrs []error for _, haproxyClient := range haproxyClients { endpoint := haproxyClient.Endpoint() logger.Infof("Processing HAProxy endpoint: %s", endpoint) - if err := processHAProxyEndpoint(logger, cfg, vaultClient, haproxyClient); err != nil { + if err := processHAProxyEndpoint(logger, cfg, certSource, haproxyClient); err != nil { allErrs = append(allErrs, fmt.Errorf("endpoint %s: %w", endpoint, err)) logger.Errorf("Failed to process endpoint %s: %v", endpoint, err) } @@ -85,7 +93,7 @@ func maybeUpdateCertificates(logger *logrus.Logger, cfg config.Config, vaultClie return errors.Join(allErrs...) } -func processHAProxyEndpoint(logger *logrus.Logger, cfg config.Config, vaultClient *vault.VaultClient, haproxyClient *haproxy.Client) error { +func processHAProxyEndpoint(logger *logrus.Logger, cfg config.Config, certSource CertSource, haproxyClient *haproxy.Client) error { endpoint := haproxyClient.Endpoint() // Get list of certificates from HAProxy with file paths for lookups @@ -107,15 +115,23 @@ func processHAProxyEndpoint(logger *logrus.Logger, cfg config.Config, vaultClien for _, ref := range certRefs { certPath := ref.DisplayName + if certPath == "" { + certPath = filepath.Base(ref.FilePath) + } logger.Infof("[%s] Checking certificate: %s", endpoint, certPath) // Extract domain name from certificate path domain := haproxy.ExtractDomainFromPath(certPath) logger.Debugf("[%s] Extracted domain '%s' from path '%s'", endpoint, domain, certPath) - // Check if certificate needs update (uses Vault as source of truth for cert details, - // since HAProxy Data Plane API doesn't provide certificate metadata) - shouldUpdate, reason, isExpiring, err := shouldUpdateCertificate(domain, vaultClient, cfg.Certificatee.RenewBeforeDays) + certName := certPath + if strings.Contains(certName, "*") || strings.Contains(ref.FilePath, "*") { + certName = sanitizeWildcardCertName(certName) + } + + // Check if certificate needs update by comparing certificate source to + // Data Plane API metadata for the stored certificate. + shouldUpdate, reason, isExpiring, err := shouldUpdateCertificate(domain, certName, certSource, haproxyClient, cfg.Certificatee.RenewBeforeDays) if err != nil { errs = append(errs, err) logger.Errorf("[%s] %v", endpoint, err) @@ -130,7 +146,7 @@ func processHAProxyEndpoint(logger *logrus.Logger, cfg config.Config, vaultClien if shouldUpdate { logger.Infof("[%s] Certificate %s needs update: %s", endpoint, certPath, reason) - if err := updateCertificate(certPath, domain, vaultClient, haproxyClient); err != nil { + if err := updateCertificate(ref, domain, certSource, haproxyClient); err != nil { errs = append(errs, err) logger.Errorf("[%s] %v", endpoint, err) certmetrics.CertificatesUpdateFailures.WithLabelValues(endpoint, domain).Inc() @@ -149,22 +165,37 @@ func processHAProxyEndpoint(logger *logrus.Logger, cfg config.Config, vaultClien return errors.Join(errs...) } -func shouldUpdateCertificate(domain string, vaultClient *vault.VaultClient, renewBeforeDays int) (shouldUpdate bool, reason string, isExpiring bool, err error) { - // Get certificate from Vault - this is the source of truth for certificate details +func shouldUpdateCertificate(domain, certName string, certSource CertSource, haproxyClient *haproxy.Client, renewBeforeDays int) (shouldUpdate bool, reason string, isExpiring bool, err error) { + // Get certificate from source - this is the source of truth for certificate details // (HAProxy Data Plane API doesn't provide certificate metadata like expiry or serial) - vaultCert, err := certificate.GetCertificate(domain, vaultClient) + vaultCert, err := certSource.GetCertificate(domain) if err != nil { - return false, "", false, fmt.Errorf("failed to get certificate %s from vault: %w", domain, err) + return false, "", false, fmt.Errorf("failed to get certificate %s from source: %w", domain, err) } if vaultCert == nil { - return false, "", false, fmt.Errorf("certificate for %s does not exist in vault", domain) + return false, "", false, fmt.Errorf("certificate for %s does not exist in source", domain) } - // Check if Vault certificate is expiring + // Check if source certificate is expiring threshold := time.Now().AddDate(0, 0, renewBeforeDays) isExpiring = vaultCert.NotAfter.Before(threshold) + storedCert, err := haproxyClient.GetCertificateDetails(certName) + if err != nil { + return false, "", false, fmt.Errorf("failed to get certificate %s metadata: %w", certName, err) + } + + sourceSerial := haproxy.NormalizeSerial(vaultCert.SerialNumber.Text(16)) + storedSerial := haproxy.NormalizeSerial(storedCert.Serial) + if sourceSerial != storedSerial { + return true, fmt.Sprintf("serial differs (source %s vs haproxy %s)", sourceSerial, storedSerial), isExpiring, nil + } + + if !vaultCert.NotAfter.Equal(storedCert.NotAfter) { + return true, fmt.Sprintf("expiry differs (source %s vs haproxy %s)", vaultCert.NotAfter.Format(time.RFC3339), storedCert.NotAfter.Format(time.RFC3339)), isExpiring, nil + } + if isExpiring { // Certificate is expiring, sync to HAProxy (likely was recently renewed) return true, fmt.Sprintf("certificate expires on %s (within %d days)", vaultCert.NotAfter.Format(time.RFC3339), renewBeforeDays), true, nil @@ -173,51 +204,31 @@ func shouldUpdateCertificate(domain string, vaultClient *vault.VaultClient, rene return false, "", false, nil } -func updateCertificate(certPath, domain string, vaultClient *vault.VaultClient, haproxyClient *haproxy.Client) error { - // Read certificate data from Vault - certificateSecrets, err := vaultClient.KVRead(certificate.VaultCertLocation(domain)) - if err != nil { - return fmt.Errorf("failed to read certificate data from vault for %s: %w", domain, err) - } - +func updateCertificate(ref haproxy.CertificateRef, domain string, certSource CertSource, haproxyClient *haproxy.Client) error { // Build PEM bundle (certificate + private key) - pemData, err := buildPEMBundle(certificateSecrets) + pemData, err := certSource.GetPEMBundle(domain) if err != nil { - return fmt.Errorf("failed to build PEM bundle for %s: %w", domain, err) + return fmt.Errorf("failed to load PEM bundle for %s: %w", domain, err) } - // Update certificate in HAProxy - if err := haproxyClient.UpdateCertificate(certPath, pemData); err != nil { - return fmt.Errorf("failed to update certificate %s in HAProxy: %w", certPath, err) - } - - return nil -} - -// buildPEMBundle creates a PEM bundle from Vault certificate secrets -func buildPEMBundle(secrets map[string]any) (string, error) { - var pemData string - - // Add certificate - if cert, ok := secrets["certificate"].(string); ok && cert != "" { - pemData += cert - } else { - return "", fmt.Errorf("certificate not found in vault secrets") + certName := ref.DisplayName + if certName == "" { + certName = filepath.Base(ref.FilePath) } - // Add newline between cert and key - if !endsWith(pemData, "\n") { - pemData += "\n" + if strings.Contains(certName, "*") || strings.Contains(ref.FilePath, "*") { + if err := ensureStorageCertificate(haproxyClient, sanitizeWildcardCertName(certName), pemData); err != nil { + return err + } + return nil } - // Add private key - if key, ok := secrets["private_key"].(string); ok && key != "" { - pemData += key - } else { - return "", fmt.Errorf("private_key not found in vault secrets") + // Update certificate in HAProxy (storage API) + if err := haproxyClient.UpdateCertificate(certName, pemData); err != nil { + return fmt.Errorf("failed to update certificate %s in HAProxy: %w", certName, err) } - return pemData, nil + return nil } // endsWith checks if a string ends with a suffix diff --git a/flake.nix b/flake.nix index ca012fc..3bcbf86 100644 --- a/flake.nix +++ b/flake.nix @@ -35,16 +35,16 @@ # HAProxy Data Plane API - built from source dataplaneapi = pkgs.buildGoModule rec { pname = "dataplaneapi"; - version = "3.0.2"; + version = "2.9.21"; src = pkgs.fetchFromGitHub { owner = "haproxytech"; repo = "dataplaneapi"; rev = "v${version}"; - hash = "sha256-SFI7WKPxF31b97Q4EWbsTbp3laXHcUfdg4hlFUiml5A="; + hash = "sha256-HDSHdrObZopQtG7qHEv/NjKLkalF7hTRyuN7Vf6lHvY="; }; - vendorHash = "sha256-vm+NUf8OCW+jCiPY13d/MjQpy3/NxEwx7Zol2bP+eF4="; + vendorHash = "sha256-Mh9/C5V6Q/VJbPY4wqiXzzoZ0cs7hIqbdyTPxHe9GVA="; # Skip tests as they require network access doCheck = false; @@ -73,6 +73,30 @@ in { + checks.integration = pkgs.buildGoModule { + pname = "certificator-integration-tests"; + version = "0.0.0"; + src = ./.; + + vendorHash = "sha256-wqQj0P3cc9NX+gIFicGQUBi4+y5Fg/CuYPeFOwvJ8Jg="; + + subPackages = [ "cmd/certificatee" ]; + + nativeBuildInputs = [ + pkgs.haproxy + dataplaneapi + pkgs.openssl + ]; + + doCheck = true; + checkPhase = '' + export HOME=$TMPDIR + export GOCACHE=$TMPDIR/go-build + export PATH=${pkgs.haproxy}/bin:${dataplaneapi}/bin:$PATH + go test -tags=integration ./cmd/certificatee + ''; + }; + devShells.default = pkgs.devshell.mkShell { name = "certificator"; diff --git a/pkg/config/config.go b/pkg/config/config.go index 2f378a5..4697716 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -58,6 +58,8 @@ type Config struct { type Certificatee struct { UpdateInterval time.Duration `envconfig:"CERTIFICATEE_UPDATE_INTERVAL" default:"10m"` RenewBeforeDays int `envconfig:"CERTIFICATEE_RENEW_BEFORE_DAYS" default:"30"` + // LocalCertsDir points to a directory with .pem bundles (cert + key) to use instead of Vault. + LocalCertsDir string `envconfig:"CERTIFICATEE_LOCAL_CERTS_DIR" default:""` // HAProxyDataPlaneAPIURLs is a comma-separated list of HAProxy Data Plane API URLs // Example: "http://127.0.0.1:5555,https://haproxy2.local:5555" HAProxyDataPlaneAPIURLs []string `envconfig:"HAPROXY_DATAPLANE_API_URLS" default:"127.0.0.1:5555"` diff --git a/pkg/haproxy/client.go b/pkg/haproxy/client.go index ef6c035..ecaff06 100644 --- a/pkg/haproxy/client.go +++ b/pkg/haproxy/client.go @@ -9,6 +9,7 @@ import ( "mime/multipart" "net/http" "regexp" + "strconv" "strings" "time" @@ -137,6 +138,20 @@ type SSLCertificateEntry struct { Description string `json:"description"` } +// SSLCertificateDetails represents an SSL certificate entry with metadata. +type SSLCertificateDetails struct { + File string + StorageName string + Description string + Serial string + NotBefore time.Time + NotAfter time.Time + Domains []string + Issuers []string + Subject string + Size int64 +} + // CertificateRef holds both display name and file path for a certificate type CertificateRef struct { // DisplayName is the storage_name or filename for display purposes @@ -161,22 +176,10 @@ func (c *Client) ListCertificates() ([]string, error) { // ListCertificateRefs returns a list of certificate references with both display names and file paths func (c *Client) ListCertificateRefs() ([]CertificateRef, error) { - // Use storage API endpoint for listing SSL certificates - resp, err := c.doRequest("GET", "/v2/services/haproxy/storage/ssl_certificates", nil, "") + certs, err := c.listCertificateRefsV2() if err != nil { return nil, err } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, errors.Errorf("failed to list certificates: status %d, body: %s", resp.StatusCode, string(body)) - } - - var certs []SSLCertificateEntry - if err := json.NewDecoder(resp.Body).Decode(&certs); err != nil { - return nil, errors.Wrap(err, "failed to decode certificate list") - } var refs []CertificateRef for _, cert := range certs { @@ -197,28 +200,100 @@ func (c *Client) ListCertificateRefs() ([]CertificateRef, error) { return refs, nil } -// UpdateCertificate uploads and commits a certificate update via Data Plane API -func (c *Client) UpdateCertificate(certName, pemData string) error { - // Create multipart form data - var buf bytes.Buffer - writer := multipart.NewWriter(&buf) +func (c *Client) listCertificateRefsV2() ([]SSLCertificateEntry, error) { + resp, err := c.doRequest("GET", "/v2/services/haproxy/storage/ssl_certificates", nil, "") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() - // Add file part - part, err := writer.CreateFormFile("file_upload", certName) + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, errors.Errorf("failed to list certificates: status %d, body: %s", resp.StatusCode, string(body)) + } + + var certs []SSLCertificateEntry + if err := json.NewDecoder(resp.Body).Decode(&certs); err != nil { + return nil, errors.Wrap(err, "failed to decode certificate list") + } + + return certs, nil +} + +func (c *Client) GetCertificateDetails(certName string) (*SSLCertificateDetails, error) { + path := fmt.Sprintf("/v2/services/haproxy/storage/ssl_certificates/%s", certName) + resp, err := c.doRequest("GET", path, nil, "") if err != nil { - return errors.Wrap(err, "failed to create form file") + return nil, err } - if _, err := part.Write([]byte(pemData)); err != nil { - return errors.Wrap(err, "failed to write certificate data") + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, errors.Errorf("failed to get certificate %s: status %d, body: %s", certName, resp.StatusCode, string(body)) } - if err := writer.Close(); err != nil { - return errors.Wrap(err, "failed to close multipart writer") + var payload struct { + File string `json:"file"` + StorageName string `json:"storage_name"` + Description string `json:"description"` + Serial string `json:"serial"` + NotBefore string `json:"not_before"` + NotAfter string `json:"not_after"` + Domains json.RawMessage `json:"domains"` + Issuers json.RawMessage `json:"issuers"` + Subject string `json:"subject"` + Size int64 `json:"size"` + } + + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return nil, errors.Wrap(err, "failed to decode certificate details") + } + + notBefore, err := parseDataPlaneTime(payload.NotBefore) + if err != nil { + return nil, errors.Wrap(err, "failed to parse not_before") + } + notAfter, err := parseDataPlaneTime(payload.NotAfter) + if err != nil { + return nil, errors.Wrap(err, "failed to parse not_after") + } + + domains, err := parseStringOrArray(payload.Domains) + if err != nil { + return nil, errors.Wrap(err, "failed to parse domains") + } + issuers, err := parseStringOrArray(payload.Issuers) + if err != nil { + return nil, errors.Wrap(err, "failed to parse issuers") } - // Send PUT request to replace certificate - path := fmt.Sprintf("/v2/services/haproxy/runtime/certs/%s", certName) - resp, err := c.doRequest("PUT", path, &buf, writer.FormDataContentType()) + return &SSLCertificateDetails{ + File: payload.File, + StorageName: payload.StorageName, + Description: payload.Description, + Serial: payload.Serial, + NotBefore: notBefore, + NotAfter: notAfter, + Domains: domains, + Issuers: issuers, + Subject: payload.Subject, + Size: payload.Size, + }, nil +} + +// UpdateCertificate uploads and commits a certificate update via Data Plane API +func (c *Client) UpdateCertificate(certName, pemData string) error { + version, err := c.getConfigurationVersion() + if err != nil { + return err + } + return c.updateCertificateStorageV2(certName, pemData, version) +} + +func (c *Client) updateCertificateStorageV2(certName, pemData string, version int) error { + path := fmt.Sprintf("/v2/services/haproxy/storage/ssl_certificates/%s?version=%d", certName, version) + resp, err := c.doRequest("PUT", path, strings.NewReader(pemData), "text/plain") if err != nil { return err } @@ -235,11 +310,14 @@ func (c *Client) UpdateCertificate(certName, pemData string) error { // CreateCertificate creates a new certificate entry via Data Plane API func (c *Client) CreateCertificate(certName, pemData string) error { + version, err := c.getConfigurationVersion() + if err != nil { + return err + } + // Create multipart form data var buf bytes.Buffer writer := multipart.NewWriter(&buf) - - // Add file part part, err := writer.CreateFormFile("file_upload", certName) if err != nil { return errors.Wrap(err, "failed to create form file") @@ -247,19 +325,18 @@ func (c *Client) CreateCertificate(certName, pemData string) error { if _, err := part.Write([]byte(pemData)); err != nil { return errors.Wrap(err, "failed to write certificate data") } - if err := writer.Close(); err != nil { return errors.Wrap(err, "failed to close multipart writer") } - // Send POST request to create certificate - resp, err := c.doRequest("POST", "/v2/services/haproxy/runtime/certs", &buf, writer.FormDataContentType()) + path := fmt.Sprintf("/v2/services/haproxy/storage/ssl_certificates?version=%d", version) + resp, err := c.doRequest("POST", path, &buf, writer.FormDataContentType()) if err != nil { return err } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { body, _ := io.ReadAll(resp.Body) return errors.Errorf("failed to create certificate %s: status %d, body: %s", certName, resp.StatusCode, string(body)) } @@ -270,7 +347,12 @@ func (c *Client) CreateCertificate(certName, pemData string) error { // DeleteCertificate deletes a certificate entry via Data Plane API func (c *Client) DeleteCertificate(certName string) error { - path := fmt.Sprintf("/v2/services/haproxy/runtime/certs/%s", certName) + version, err := c.getConfigurationVersion() + if err != nil { + return err + } + + path := fmt.Sprintf("/v2/services/haproxy/storage/ssl_certificates/%s?version=%d", certName, version) resp, err := c.doRequest("DELETE", path, nil, "") if err != nil { return err @@ -286,6 +368,83 @@ func (c *Client) DeleteCertificate(certName string) error { return nil } +// ListFrontends returns frontend names from HAProxy configuration. +func (c *Client) ListFrontends() ([]string, error) { + resp, err := c.doRequest("GET", "/v2/services/haproxy/configuration/frontends", nil, "") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, errors.Errorf("failed to list frontends: status %d, body: %s", resp.StatusCode, string(body)) + } + + data, err := decodeDataArray(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to decode frontends list") + } + + var names []string + for _, item := range data { + if name, ok := item["name"].(string); ok && name != "" { + names = append(names, name) + } + } + return names, nil +} + +// ListBinds returns bind objects for the given frontend. +func (c *Client) ListBinds(frontend string) ([]map[string]any, error) { + path := fmt.Sprintf("/v2/services/haproxy/configuration/binds?frontend=%s", frontend) + resp, err := c.doRequest("GET", path, nil, "") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, errors.Errorf("failed to list binds: status %d, body: %s", resp.StatusCode, string(body)) + } + + data, err := decodeDataArray(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to decode binds list") + } + + return data, nil +} + +// UpdateBind updates a bind in the HAProxy configuration. +func (c *Client) UpdateBind(frontend, bindName string, bind map[string]any) error { + version, err := c.getConfigurationVersion() + if err != nil { + return err + } + + bind["name"] = bindName + payload, err := json.Marshal(bind) + if err != nil { + return errors.Wrap(err, "failed to encode bind payload") + } + + path := fmt.Sprintf("/v2/services/haproxy/configuration/frontends/%s/binds/%s?version=%d", frontend, bindName, version) + resp, err := c.doRequest("PUT", path, strings.NewReader(string(payload)), "application/json") + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return errors.Errorf("failed to update bind %s on frontend %s: status %d, body: %s", bindName, frontend, resp.StatusCode, string(body)) + } + + return nil +} + // ExtractDomainFromPath extracts the domain name from a certificate path // Example: /etc/haproxy/certs/example.com.pem -> example.com func ExtractDomainFromPath(certPath string) string { @@ -322,6 +481,46 @@ func NormalizeSerial(serial string) string { return strings.ToUpper(re.ReplaceAllString(serial, "")) } +func parseDataPlaneTime(value string) (time.Time, error) { + if value == "" { + return time.Time{}, errors.New("empty time value") + } + layouts := []string{ + time.RFC3339, + time.RFC3339Nano, + "2006-01-02 15:04:05", + "2006-01-02 15:04:05 -0700", + "2006-01-02 15:04:05 -0700 MST", + } + for _, layout := range layouts { + if parsed, err := time.Parse(layout, value); err == nil { + return parsed, nil + } + } + return time.Time{}, errors.Errorf("unsupported time format %q", value) +} + +func parseStringOrArray(raw json.RawMessage) ([]string, error) { + if len(raw) == 0 || string(raw) == "null" { + return nil, nil + } + + var list []string + if err := json.Unmarshal(raw, &list); err == nil { + return list, nil + } + + var single string + if err := json.Unmarshal(raw, &single); err == nil { + if single == "" { + return nil, nil + } + return []string{single}, nil + } + + return nil, errors.Errorf("unsupported value %s", string(raw)) +} + // logrusLeveledLogger wraps a logrus.Logger to implement retryablehttp.LeveledLogger type logrusLeveledLogger struct { logger *logrus.Logger @@ -355,3 +554,74 @@ func toLogrusFields(keysAndValues []any) logrus.Fields { } return fields } + +func (c *Client) getConfigurationVersion() (int, error) { + resp, err := c.doRequest("GET", "/v2/services/haproxy/configuration/version", nil, "") + if err != nil { + return 0, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return 0, errors.Errorf("failed to get configuration version: status %d, body: %s", resp.StatusCode, string(body)) + } + + var versionValue any + if err := json.NewDecoder(resp.Body).Decode(&versionValue); err != nil { + return 0, errors.Wrap(err, "failed to decode configuration version") + } + + switch v := versionValue.(type) { + case float64: + return int(v), nil + case string: + return strconv.Atoi(v) + case map[string]any: + if raw, ok := v["_version"]; ok { + return parseVersionValue(raw) + } + if raw, ok := v["version"]; ok { + return parseVersionValue(raw) + } + } + + return 0, errors.New("unsupported configuration version response") +} + +func parseVersionValue(value any) (int, error) { + switch v := value.(type) { + case float64: + return int(v), nil + case string: + return strconv.Atoi(v) + default: + return 0, errors.New("unsupported configuration version value") + } +} + +func decodeDataArray(r io.Reader) ([]map[string]any, error) { + var envelope map[string]any + if err := json.NewDecoder(r).Decode(&envelope); err != nil { + return nil, err + } + + rawData, ok := envelope["data"] + if !ok { + return nil, errors.New("response missing data field") + } + + items, ok := rawData.([]any) + if !ok { + return nil, errors.New("data field is not a list") + } + + out := make([]map[string]any, 0, len(items)) + for _, item := range items { + if obj, ok := item.(map[string]any); ok { + out = append(out, obj) + } + } + + return out, nil +} diff --git a/pkg/haproxy/client_test.go b/pkg/haproxy/client_test.go index d952f7f..936422e 100644 --- a/pkg/haproxy/client_test.go +++ b/pkg/haproxy/client_test.go @@ -1,10 +1,6 @@ package haproxy import ( - "encoding/json" - "io" - "net/http" - "net/http/httptest" "strings" "testing" "time" @@ -219,15 +215,7 @@ func TestNewClients(t *testing.T) { wantErr: true, }, { - name: "single endpoint", - configs: []ClientConfig{ - {BaseURL: "http://localhost:5555", Username: "admin", Password: "secret"}, - }, - wantErr: false, - wantCount: 1, - }, - { - name: "multiple endpoints", + name: "valid configs", configs: []ClientConfig{ {BaseURL: "http://haproxy1:5555", Username: "admin", Password: "secret"}, {BaseURL: "http://haproxy2:5555", Username: "admin", Password: "secret"}, @@ -237,17 +225,17 @@ func TestNewClients(t *testing.T) { wantCount: 3, }, { - name: "with empty baseURLs", + name: "some invalid configs", configs: []ClientConfig{ {BaseURL: "http://haproxy1:5555", Username: "admin", Password: "secret"}, {BaseURL: ""}, {BaseURL: "http://haproxy2:5555", Username: "admin", Password: "secret"}, }, wantErr: false, - wantCount: 2, // empty baseURLs are skipped + wantCount: 2, }, { - name: "only empty baseURLs", + name: "all invalid configs", configs: []ClientConfig{ {BaseURL: ""}, {BaseURL: ""}, @@ -264,488 +252,12 @@ func TestNewClients(t *testing.T) { return } if !tt.wantErr && len(clients) != tt.wantCount { - t.Errorf("NewClients() returned %d clients, want %d", len(clients), tt.wantCount) - } - }) - } -} - -func TestClientEndpoint(t *testing.T) { - logger := logrus.New() - logger.SetLevel(logrus.PanicLevel) - - client, err := NewClient(ClientConfig{ - BaseURL: "http://localhost:5555", - Username: "admin", - Password: "secret", - }, logger) - if err != nil { - t.Fatalf("NewClient() error = %v", err) - } - - if got := client.Endpoint(); got != "http://localhost:5555" { - t.Errorf("Endpoint() = %q, want %q", got, "http://localhost:5555") - } -} - -// mockDataPlaneAPI simulates the HAProxy Data Plane API -type mockDataPlaneAPI struct { - server *httptest.Server - handlers map[string]http.HandlerFunc - authRequired bool - username string - password string - t *testing.T -} - -// newMockDataPlaneAPI creates a mock Data Plane API server -func newMockDataPlaneAPI(t *testing.T) *mockDataPlaneAPI { - m := &mockDataPlaneAPI{ - handlers: make(map[string]http.HandlerFunc), - t: t, - } - - m.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check basic auth if required - if m.authRequired { - user, pass, ok := r.BasicAuth() - if !ok || user != m.username || pass != m.password { - w.WriteHeader(http.StatusUnauthorized) - return - } - } - - // Find matching handler by method + path - key := r.Method + " " + r.URL.Path - if handler, ok := m.handlers[key]; ok { - handler(w, r) - return - } - - // Try prefix matching for dynamic paths (e.g., /v2/services/haproxy/runtime/certs/example.com.pem) - for pattern, handler := range m.handlers { - if strings.HasPrefix(key, pattern) { - handler(w, r) - return - } - } - - // Default 404 - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "not found"}`)) - })) - - return m -} - -func (m *mockDataPlaneAPI) URL() string { - return m.server.URL -} - -func (m *mockDataPlaneAPI) Close() { - m.server.Close() -} - -func (m *mockDataPlaneAPI) SetAuth(username, password string) { - m.authRequired = true - m.username = username - m.password = password -} - -func (m *mockDataPlaneAPI) SetHandler(method, path string, handler http.HandlerFunc) { - m.handlers[method+" "+path] = handler -} - -func TestListCertificates(t *testing.T) { - logger := logrus.New() - logger.SetLevel(logrus.PanicLevel) - - tests := []struct { - name string - response []SSLCertificateEntry - statusCode int - want []string - wantErr bool - }{ - { - name: "normal response with multiple certs", - response: []SSLCertificateEntry{ - {File: "/etc/haproxy/certs/site1.pem", StorageName: "site1.pem"}, - {File: "/etc/haproxy/certs/site2.pem", StorageName: "site2.pem"}, - }, - statusCode: http.StatusOK, - want: []string{"site1.pem", "site2.pem"}, - wantErr: false, - }, - { - name: "empty response", - response: []SSLCertificateEntry{}, - statusCode: http.StatusOK, - want: nil, - wantErr: false, - }, - { - name: "certs with storage_name", - response: []SSLCertificateEntry{ - {StorageName: "example.com.pem"}, - {StorageName: "test.com.pem"}, - }, - statusCode: http.StatusOK, - want: []string{"example.com.pem", "test.com.pem"}, - wantErr: false, - }, - { - name: "single certificate", - response: []SSLCertificateEntry{ - {File: "/etc/haproxy/certs/only.pem", StorageName: "only.pem"}, - }, - statusCode: http.StatusOK, - want: []string{"only.pem"}, - wantErr: false, - }, - { - name: "server error", - response: nil, - statusCode: http.StatusInternalServerError, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mock := newMockDataPlaneAPI(t) - defer mock.Close() - - mock.SetHandler("GET", "/v2/services/haproxy/storage/ssl_certificates", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(tt.statusCode) - if tt.response != nil { - _ = json.NewEncoder(w).Encode(tt.response) - } else { - _, _ = w.Write([]byte(`{"message": "error"}`)) - } - }) - - client, err := NewClient(ClientConfig{BaseURL: mock.URL()}, logger) - if err != nil { - t.Fatalf("NewClient() error = %v", err) - } - - got, err := client.ListCertificates() - if (err != nil) != tt.wantErr { - t.Errorf("ListCertificates() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if len(got) != len(tt.want) { - t.Errorf("ListCertificates() = %v, want %v", got, tt.want) - return - } - for i := range got { - if got[i] != tt.want[i] { - t.Errorf("ListCertificates()[%d] = %q, want %q", i, got[i], tt.want[i]) - } + t.Errorf("NewClients() got %d clients, want %d", len(clients), tt.wantCount) } }) } } -func TestUpdateCertificate(t *testing.T) { - logger := logrus.New() - logger.SetLevel(logrus.PanicLevel) - - tests := []struct { - name string - certName string - pemData string - statusCode int - wantErr bool - }{ - { - name: "success - certificate updated", - certName: "example.com.pem", - pemData: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----", - statusCode: http.StatusOK, - wantErr: false, - }, - { - name: "success - accepted", - certName: "example.com.pem", - pemData: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----", - statusCode: http.StatusAccepted, - wantErr: false, - }, - { - name: "error - not found", - certName: "notfound.pem", - pemData: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----", - statusCode: http.StatusNotFound, - wantErr: true, - }, - { - name: "error - bad request", - certName: "bad.pem", - pemData: "invalid pem", - statusCode: http.StatusBadRequest, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mock := newMockDataPlaneAPI(t) - defer mock.Close() - - mock.SetHandler("PUT", "/v2/services/haproxy/runtime/certs/"+tt.certName, func(w http.ResponseWriter, r *http.Request) { - // Verify content type is multipart - contentType := r.Header.Get("Content-Type") - if !strings.Contains(contentType, "multipart/form-data") { - t.Errorf("Expected multipart/form-data content type, got %s", contentType) - } - - // Read the multipart form - err := r.ParseMultipartForm(10 << 20) // 10 MB - if err != nil { - t.Errorf("Failed to parse multipart form: %v", err) - } - - // Verify file was uploaded - file, _, err := r.FormFile("file_upload") - if err != nil { - t.Errorf("Failed to get file from form: %v", err) - } else { - defer func() { _ = file.Close() }() - data, _ := io.ReadAll(file) - if string(data) != tt.pemData { - t.Errorf("File data = %q, want %q", string(data), tt.pemData) - } - } - - w.WriteHeader(tt.statusCode) - }) - - client, err := NewClient(ClientConfig{BaseURL: mock.URL()}, logger) - if err != nil { - t.Fatalf("NewClient() error = %v", err) - } - - err = client.UpdateCertificate(tt.certName, tt.pemData) - if (err != nil) != tt.wantErr { - t.Errorf("UpdateCertificate() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestCreateCertificate(t *testing.T) { - logger := logrus.New() - logger.SetLevel(logrus.PanicLevel) - - tests := []struct { - name string - certName string - pemData string - statusCode int - wantErr bool - }{ - { - name: "success - certificate created", - certName: "new.example.com.pem", - pemData: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----", - statusCode: http.StatusCreated, - wantErr: false, - }, - { - name: "success - OK status", - certName: "new.example.com.pem", - pemData: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----", - statusCode: http.StatusOK, - wantErr: false, - }, - { - name: "error - conflict (already exists)", - certName: "existing.pem", - pemData: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----", - statusCode: http.StatusConflict, - wantErr: true, - }, - { - name: "error - bad request", - certName: "bad.pem", - pemData: "invalid pem", - statusCode: http.StatusBadRequest, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mock := newMockDataPlaneAPI(t) - defer mock.Close() - - mock.SetHandler("POST", "/v2/services/haproxy/runtime/certs", func(w http.ResponseWriter, r *http.Request) { - // Verify content type is multipart - contentType := r.Header.Get("Content-Type") - if !strings.Contains(contentType, "multipart/form-data") { - t.Errorf("Expected multipart/form-data content type, got %s", contentType) - } - - // Read the multipart form - err := r.ParseMultipartForm(10 << 20) // 10 MB - if err != nil { - t.Errorf("Failed to parse multipart form: %v", err) - } - - // Verify file was uploaded - file, header, err := r.FormFile("file_upload") - if err != nil { - t.Errorf("Failed to get file from form: %v", err) - } else { - defer func() { _ = file.Close() }() - if header.Filename != tt.certName { - t.Errorf("Filename = %q, want %q", header.Filename, tt.certName) - } - data, _ := io.ReadAll(file) - if string(data) != tt.pemData { - t.Errorf("File data = %q, want %q", string(data), tt.pemData) - } - } - - w.WriteHeader(tt.statusCode) - }) - - client, err := NewClient(ClientConfig{BaseURL: mock.URL()}, logger) - if err != nil { - t.Fatalf("NewClient() error = %v", err) - } - - err = client.CreateCertificate(tt.certName, tt.pemData) - if (err != nil) != tt.wantErr { - t.Errorf("CreateCertificate() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestDeleteCertificate(t *testing.T) { - logger := logrus.New() - logger.SetLevel(logrus.PanicLevel) - - tests := []struct { - name string - certName string - statusCode int - wantErr bool - }{ - { - name: "success - no content", - certName: "example.com.pem", - statusCode: http.StatusNoContent, - wantErr: false, - }, - { - name: "success - OK", - certName: "example.com.pem", - statusCode: http.StatusOK, - wantErr: false, - }, - { - name: "error - not found", - certName: "notfound.pem", - statusCode: http.StatusNotFound, - wantErr: true, - }, - { - name: "error - server error", - certName: "error.pem", - statusCode: http.StatusInternalServerError, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mock := newMockDataPlaneAPI(t) - defer mock.Close() - - mock.SetHandler("DELETE", "/v2/services/haproxy/runtime/certs/"+tt.certName, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(tt.statusCode) - }) - - client, err := NewClient(ClientConfig{BaseURL: mock.URL()}, logger) - if err != nil { - t.Fatalf("NewClient() error = %v", err) - } - - err = client.DeleteCertificate(tt.certName) - if (err != nil) != tt.wantErr { - t.Errorf("DeleteCertificate() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestBasicAuth(t *testing.T) { - logger := logrus.New() - logger.SetLevel(logrus.PanicLevel) - - mock := newMockDataPlaneAPI(t) - defer mock.Close() - mock.SetAuth("admin", "secret") - - mock.SetHandler("GET", "/v2/services/haproxy/storage/ssl_certificates", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode([]SSLCertificateEntry{}) - }) - - t.Run("valid credentials", func(t *testing.T) { - client, err := NewClient(ClientConfig{ - BaseURL: mock.URL(), - Username: "admin", - Password: "secret", - }, logger) - if err != nil { - t.Fatalf("NewClient() error = %v", err) - } - - _, err = client.ListCertificates() - if err != nil { - t.Errorf("ListCertificates() with valid auth error = %v", err) - } - }) - - t.Run("invalid credentials", func(t *testing.T) { - client, err := NewClient(ClientConfig{ - BaseURL: mock.URL(), - Username: "admin", - Password: "wrong", - }, logger) - if err != nil { - t.Fatalf("NewClient() error = %v", err) - } - - _, err = client.ListCertificates() - if err == nil { - t.Error("ListCertificates() with invalid auth expected error, got nil") - } - }) - - t.Run("no credentials", func(t *testing.T) { - client, err := NewClient(ClientConfig{ - BaseURL: mock.URL(), - }, logger) - if err != nil { - t.Fatalf("NewClient() error = %v", err) - } - - _, err = client.ListCertificates() - if err == nil { - t.Error("ListCertificates() without auth expected error, got nil") - } - }) -} - func TestConnectionError(t *testing.T) { logger := logrus.New() logger.SetLevel(logrus.PanicLevel) @@ -761,39 +273,6 @@ func TestConnectionError(t *testing.T) { _, err = client.ListCertificates() if err == nil { - t.Error("ListCertificates() expected connection error, got nil") - } -} - -func TestRetryOnConnectionFailure(t *testing.T) { - logger := logrus.New() - logger.SetLevel(logrus.PanicLevel) - - // Create client pointing to non-existent server - client, err := NewClient(ClientConfig{ - BaseURL: "http://127.0.0.1:59998", - Timeout: 50 * time.Millisecond, - }, logger) - if err != nil { - t.Fatalf("NewClient() error = %v", err) - } - - start := time.Now() - _, err = client.ListCertificates() - elapsed := time.Since(start) - - // Should fail after retries - if err == nil { - t.Error("Expected connection error, got nil") - } - - // Should have taken some time for retries (at least 2 retries with 10ms delays) - if elapsed < 10*time.Millisecond { - t.Errorf("Retries should have taken longer, elapsed: %v", elapsed) - } - - // Error message should mention retry attempts - if !strings.Contains(err.Error(), "after") { - t.Errorf("Error should mention retry attempts: %v", err) + t.Error("Expected error for connection, got nil") } }