diff --git a/cmd/gh-aw/main.go b/cmd/gh-aw/main.go index 12b29b4cab6..9a5c6aa5c37 100644 --- a/cmd/gh-aw/main.go +++ b/cmd/gh-aw/main.go @@ -296,8 +296,8 @@ Examples: return err } - // Check for updates (non-blocking, runs once per day) - cli.CheckForUpdatesAsync(cmd.Context(), noCheckUpdate, verbose) + finishCompileUpdateCheck := cli.StartCompileUpdateCheck(cmd.Context(), noCheckUpdate, verbose) + defer finishCompileUpdateCheck() // If --fix is specified, run fix --write first if fix { diff --git a/pkg/cli/compile_update_check.go b/pkg/cli/compile_update_check.go new file mode 100644 index 00000000000..e7ea089e5fe --- /dev/null +++ b/pkg/cli/compile_update_check.go @@ -0,0 +1,384 @@ +package cli + +import ( + "context" + "fmt" + "net/http" + "os" + "path" + "strconv" + "strings" + "time" + + "golang.org/x/mod/semver" + + "github.com/github/gh-aw/pkg/console" + "github.com/github/gh-aw/pkg/logger" + "github.com/github/gh-aw/pkg/tty" + "github.com/github/gh-aw/pkg/workflow" +) + +var compileUpdateCheckLog = logger.New("cli:update_check") + +const ( + compileUpdateCheckDisableEnv = "GH_AW_DISABLE_UPDATE_CHECK" + compileUpdateCheckFileName = "gh-aw-last-compile-update-check" + compileUpdateCheckInterval = 24 * time.Hour + compileUpdateCheckTimeout = 3 * time.Second + compileUpdateCheckNoWait = 0 +) + +var ( + compileUpdateCheckLatestReleaseURL = "https://github.com/github/gh-aw/releases/latest" + compileUpdateCheckProbeURLFunc = func(tag string) string { + return fmt.Sprintf("https://raw.githubusercontent.com/github/gh-aw/refs/tags/%s/go.mod", tag) + } + compileUpdateCheckHTTPClientFactory = func() *http.Client { + return &http.Client{Timeout: compileUpdateCheckTimeout} + } + compileUpdateCheckIsTerminalFunc = tty.IsStderrTerminal + getCompileUpdateCheckFilePathFunc = getCompileUpdateCheckFilePathImpl +) + +type compileUpdateNotificationKind string + +const ( + compileUpdateNotificationMinorBehind compileUpdateNotificationKind = "minor_behind" + compileUpdateNotificationRemovedTag compileUpdateNotificationKind = "removed_tag" +) + +type compileUpdateNotification struct { + Kind compileUpdateNotificationKind + CurrentVersion string + LatestVersion string +} + +// StartCompileUpdateCheck begins a best-effort update check for the compile command. +// The returned function should be called once before the command exits to print any +// ready notification without blocking compilation for long. +func StartCompileUpdateCheck(ctx context.Context, noCheckUpdate bool, verbose bool) func() { + if !shouldRunCompileUpdateCheck(noCheckUpdate) { + return func() {} + } + updateCompileUpdateCheckTime() + + results := make(chan *compileUpdateNotification, 1) // buffered channel closed by sender goroutine via defer + + go func() { + defer close(results) + defer func() { + if r := recover(); r != nil { + compileUpdateCheckLog.Printf("Panic in compile update check (recovered): %v", r) + } + }() + + if ctx.Err() != nil { + compileUpdateCheckLog.Printf("Compile update check cancelled before starting: %v", ctx.Err()) + return + } + + result, err := runCompileUpdateCheck(ctx, compileUpdateCheckHTTPClientFactory()) + if err != nil { + compileUpdateCheckLog.Printf("Compile update check failed (ignoring): %v", err) + return + } + if result == nil { + if verbose { + compileUpdateCheckLog.Print("No compile update notification needed") + } + return + } + + select { + case results <- result: + case <-ctx.Done(): + } + }() + + return func() { + result := waitForCompileUpdateNotification(ctx, results, compileUpdateCheckNoWait) + if result != nil { + printCompileUpdateNotification(result) + } + } +} + +func shouldRunCompileUpdateCheck(noCheckUpdate bool) bool { + if noCheckUpdate { + compileUpdateCheckLog.Print("Update check disabled via --no-check-update flag") + return false + } + if os.Getenv(compileUpdateCheckDisableEnv) != "" { + compileUpdateCheckLog.Printf("Update check disabled via %s", compileUpdateCheckDisableEnv) + return false + } + if IsRunningInCI() { + compileUpdateCheckLog.Print("Update check disabled in CI environment") + return false + } + if isRunningAsMCPServer() { + compileUpdateCheckLog.Print("Update check disabled in MCP server mode") + return false + } + if !compileUpdateCheckIsTerminalFunc() { + compileUpdateCheckLog.Print("Update check disabled when stderr is not a terminal") + return false + } + + lastCheckFile := getCompileUpdateCheckFilePath() + if lastCheckFile == "" { + compileUpdateCheckLog.Print("Could not determine compile update check file path") + return false + } + + data, err := os.ReadFile(lastCheckFile) + if err != nil { + if !os.IsNotExist(err) { + compileUpdateCheckLog.Printf("Error reading compile update check file: %v", err) + } + return true + } + + lastCheck, err := time.Parse(time.RFC3339, strings.TrimSpace(string(data))) + if err != nil { + compileUpdateCheckLog.Printf("Error parsing compile update check time: %v", err) + return true + } + + elapsed := time.Since(lastCheck) + if elapsed < compileUpdateCheckInterval { + compileUpdateCheckLog.Printf("Last compile update check was %v ago, skipping", elapsed) + return false + } + return true +} + +func waitForCompileUpdateNotification(ctx context.Context, results <-chan *compileUpdateNotification, timeout time.Duration) *compileUpdateNotification { + if results == nil { + return nil + } + + if timeout <= 0 { + select { + case result, ok := <-results: + if !ok { + return nil + } + return result + default: + return nil + } + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case result, ok := <-results: + if !ok { + return nil + } + return result + case <-timer.C: + return nil + case <-ctx.Done(): + return nil + } +} + +func runCompileUpdateCheck(ctx context.Context, client *http.Client) (*compileUpdateNotification, error) { + currentVersion := GetVersion() + if !workflow.IsReleasedVersion(currentVersion) { + compileUpdateCheckLog.Print("Not a released version, skipping update check") + return nil, nil + } + + latestVersion, err := fetchLatestReleaseTag(ctx, client) + if err != nil { + return nil, err + } + if latestVersion == "" { + return nil, nil + } + + latestTagExists, err := downloadReleaseProbeFile(ctx, client, latestVersion) + if err != nil { + return nil, err + } + if !latestTagExists { + compileUpdateCheckLog.Printf("Latest release tag %s did not expose the probe file; skipping", latestVersion) + return nil, nil + } + + currentTagExists, err := downloadReleaseProbeFile(ctx, client, currentVersion) + if err != nil { + return nil, err + } + if !currentTagExists { + return &compileUpdateNotification{ + Kind: compileUpdateNotificationRemovedTag, + CurrentVersion: currentVersion, + LatestVersion: latestVersion, + }, nil + } + + if !isMinorVersionBehind(currentVersion, latestVersion) { + return nil, nil + } + + return &compileUpdateNotification{ + Kind: compileUpdateNotificationMinorBehind, + CurrentVersion: currentVersion, + LatestVersion: latestVersion, + }, nil +} + +func fetchLatestReleaseTag(ctx context.Context, client *http.Client) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, compileUpdateCheckLatestReleaseURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create latest release request: %w", err) + } + req.Header.Set("User-Agent", "gh-aw/"+GetVersion()) + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to query latest release: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", fmt.Errorf("latest release request returned status %d", resp.StatusCode) + } + + finalPath := resp.Request.URL.Path + if !strings.Contains(finalPath, "/releases/tag/") { + return "", fmt.Errorf("unexpected latest release path %q", finalPath) + } + + tag := path.Base(finalPath) + if tag == "" || tag == "." || tag == "latest" { + return "", fmt.Errorf("could not determine latest release tag from %q", finalPath) + } + + return tag, nil +} + +func downloadReleaseProbeFile(ctx context.Context, client *http.Client, tag string) (bool, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, compileUpdateCheckProbeURLFunc(tag), nil) + if err != nil { + return false, fmt.Errorf("failed to create probe request for %s: %w", tag, err) + } + req.Header.Set("User-Agent", "gh-aw/"+GetVersion()) + + resp, err := client.Do(req) + if err != nil { + return false, fmt.Errorf("failed to download probe file for %s: %w", tag, err) + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + return true, nil + case http.StatusNotFound: + return false, nil + default: + return false, fmt.Errorf("probe download for %s returned status %d", tag, resp.StatusCode) + } +} + +func getCompileUpdateCheckFilePath() string { + return getCompileUpdateCheckFilePathFunc() +} + +func getCompileUpdateCheckFilePathImpl() string { + return getLastCheckFilePathFor(compileUpdateCheckFileName) +} + +func updateCompileUpdateCheckTime() { + lastCheckFile := getCompileUpdateCheckFilePath() + if lastCheckFile == "" { + return + } + + timestamp := time.Now().Format(time.RFC3339) + if err := os.WriteFile(lastCheckFile, []byte(timestamp), 0600); err != nil { + compileUpdateCheckLog.Printf("Error writing compile update check time: %v", err) + } +} + +func isMinorVersionBehind(currentVersion string, latestVersion string) bool { + currentSV := ensureSemverPrefix(currentVersion) + latestSV := ensureSemverPrefix(latestVersion) + + if !semver.IsValid(currentSV) || !semver.IsValid(latestSV) { + return false + } + if semver.Compare(currentSV, latestSV) >= 0 { + return false + } + + currentMajor, currentMinor, ok := semverMajorMinorParts(currentSV) + if !ok { + return false + } + latestMajor, latestMinor, ok := semverMajorMinorParts(latestSV) + if !ok { + return false + } + + return currentMajor == latestMajor && latestMinor > currentMinor +} + +func semverMajorMinorParts(version string) (int, int, bool) { + trimmed := strings.TrimPrefix(version, "v") + trimmed = strings.SplitN(trimmed, "-", 2)[0] + trimmed = strings.SplitN(trimmed, "+", 2)[0] + + parts := strings.Split(trimmed, ".") + if len(parts) < 2 { + return 0, 0, false + } + + major, err := strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, false + } + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, false + } + + return major, minor, true +} + +func ensureSemverPrefix(version string) string { + if strings.HasPrefix(version, "v") { + return version + } + return "v" + version +} + +func printCompileUpdateNotification(notification *compileUpdateNotification) { + if notification == nil { + return + } + + fmt.Fprintln(os.Stderr) + + switch notification.Kind { + case compileUpdateNotificationRemovedTag: + fmt.Fprintln(os.Stderr, console.FormatWarningMessage(fmt.Sprintf( + "The installed gh-aw compiler version %s is no longer available as a repository tag.", notification.CurrentVersion, + ))) + fmt.Fprintln(os.Stderr, console.FormatWarningMessage(fmt.Sprintf( + "Update the compiler before recompiling workflows (latest release: %s).", notification.LatestVersion, + ))) + default: + fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf( + "Compiler upgrade recommended: gh-aw %s is behind the latest release %s.", notification.CurrentVersion, notification.LatestVersion, + ))) + fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Hint: upgrade the compiler with: gh extension upgrade github/gh-aw")) + } + + fmt.Fprintln(os.Stderr) +} diff --git a/pkg/cli/compile_update_check_test.go b/pkg/cli/compile_update_check_test.go new file mode 100644 index 00000000000..cba3aa0f270 --- /dev/null +++ b/pkg/cli/compile_update_check_test.go @@ -0,0 +1,446 @@ +//go:build !integration + +package cli + +import ( + "bytes" + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "slices" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/github/gh-aw/pkg/workflow" +) + +func TestShouldRunCompileUpdateCheck(t *testing.T) { + origGetFilePath := getCompileUpdateCheckFilePathFunc + origIsTerminal := compileUpdateCheckIsTerminalFunc + t.Cleanup(func() { + getCompileUpdateCheckFilePathFunc = origGetFilePath + compileUpdateCheckIsTerminalFunc = origIsTerminal + }) + + tmpDir := t.TempDir() + lastCheckFile := filepath.Join(tmpDir, compileUpdateCheckFileName) + getCompileUpdateCheckFilePathFunc = func() string { + return lastCheckFile + } + compileUpdateCheckIsTerminalFunc = func() bool { + return true + } + + t.Setenv("CI", "") + t.Setenv("CONTINUOUS_INTEGRATION", "") + t.Setenv("GITHUB_ACTIONS", "") + t.Setenv("GH_AW_MCP_SERVER", "") + t.Setenv(compileUpdateCheckDisableEnv, "") + assert.True(t, shouldRunCompileUpdateCheck(false), "check should run when not disabled") + + require.NoError( + t, + os.WriteFile(lastCheckFile, []byte(time.Now().Format(time.RFC3339)), 0600), + "recent compile update marker should be written", + ) + assert.False(t, shouldRunCompileUpdateCheck(false), "recent marker should suppress the background check") + + t.Setenv(compileUpdateCheckDisableEnv, "1") + assert.False(t, shouldRunCompileUpdateCheck(false), "check should be disabled by environment variable") + + t.Setenv(compileUpdateCheckDisableEnv, "") + assert.False(t, shouldRunCompileUpdateCheck(true), "check should be disabled by flag") + + compileUpdateCheckIsTerminalFunc = func() bool { + return false + } + assert.False(t, shouldRunCompileUpdateCheck(false), "check should be disabled in non-interactive environments") +} + +func TestRunCompileUpdateCheck(t *testing.T) { + originalVersion := GetVersion() + originalRelease := workflow.IsRelease() + originalLatestURL := compileUpdateCheckLatestReleaseURL + originalProbeURLFunc := compileUpdateCheckProbeURLFunc + defer func() { + SetVersionInfo(originalVersion) + workflow.SetIsRelease(originalRelease) + compileUpdateCheckLatestReleaseURL = originalLatestURL + compileUpdateCheckProbeURLFunc = originalProbeURLFunc + }() + + tests := []struct { + name string + currentVersion string + latestVersion string + existingTags map[string]bool + expected *compileUpdateNotification + }{ + { + name: "returns minor version upgrade hint", + currentVersion: "v1.2.3", + latestVersion: "v1.3.0", + existingTags: map[string]bool{ + "v1.2.3": true, + "v1.3.0": true, + }, + expected: &compileUpdateNotification{ + Kind: compileUpdateNotificationMinorBehind, + CurrentVersion: "v1.2.3", + LatestVersion: "v1.3.0", + }, + }, + { + name: "returns prominent notice when current tag is missing", + currentVersion: "v1.2.3", + latestVersion: "v1.3.0", + existingTags: map[string]bool{ + "v1.3.0": true, + }, + expected: &compileUpdateNotification{ + Kind: compileUpdateNotificationRemovedTag, + CurrentVersion: "v1.2.3", + LatestVersion: "v1.3.0", + }, + }, + { + name: "ignores patch-only difference", + currentVersion: "v1.2.3", + latestVersion: "v1.2.4", + existingTags: map[string]bool{ + "v1.2.3": true, + "v1.2.4": true, + }, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := newCompileUpdateCheckTestServer(t, tt.latestVersion, tt.existingTags) + defer server.Close() + + SetVersionInfo(tt.currentVersion) + workflow.SetIsRelease(true) + compileUpdateCheckLatestReleaseURL = server.URL + "/releases/latest" + compileUpdateCheckProbeURLFunc = func(tag string) string { + return fmt.Sprintf("%s/raw/%s/go.mod", server.URL, tag) + } + + got, err := runCompileUpdateCheck(context.Background(), server.Client()) + require.NoError(t, err, "runCompileUpdateCheck should not fail") + assert.Equal(t, tt.expected, got, "unexpected compile update notification") + }) + } +} + +func TestPrintCompileUpdateNotification(t *testing.T) { + tests := []struct { + name string + notification *compileUpdateNotification + expected []string + }{ + { + name: "minor version behind", + notification: &compileUpdateNotification{ + Kind: compileUpdateNotificationMinorBehind, + CurrentVersion: "v1.2.3", + LatestVersion: "v1.3.0", + }, + expected: []string{ + "Compiler upgrade recommended: gh-aw v1.2.3 is behind the latest release v1.3.0.", + "Hint: upgrade the compiler with: gh extension upgrade github/gh-aw", + }, + }, + { + name: "removed tag warning", + notification: &compileUpdateNotification{ + Kind: compileUpdateNotificationRemovedTag, + CurrentVersion: "v1.2.3", + LatestVersion: "v1.3.0", + }, + expected: []string{ + "The installed gh-aw compiler version v1.2.3 is no longer available as a repository tag.", + "Update the compiler before recompiling workflows (latest release: v1.3.0).", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + oldStderr := os.Stderr + r, w, err := os.Pipe() + require.NoError(t, err, "pipe creation should succeed") + defer r.Close() + os.Stderr = w + + printCompileUpdateNotification(tt.notification) + + require.NoError(t, w.Close(), "pipe writer should close cleanly") + os.Stderr = oldStderr + + var buf bytes.Buffer + _, err = buf.ReadFrom(r) + require.NoError(t, err, "pipe reader should capture stderr output") + output := buf.String() + + for _, expected := range tt.expected { + assert.Contains(t, output, expected, "output should contain expected message") + } + }) + } +} + +func TestRunCompileUpdateCheckUsesHEADRequests(t *testing.T) { + originalVersion := GetVersion() + originalRelease := workflow.IsRelease() + originalLatestURL := compileUpdateCheckLatestReleaseURL + originalProbeURLFunc := compileUpdateCheckProbeURLFunc + defer func() { + SetVersionInfo(originalVersion) + workflow.SetIsRelease(originalRelease) + compileUpdateCheckLatestReleaseURL = originalLatestURL + compileUpdateCheckProbeURLFunc = originalProbeURLFunc + }() + + SetVersionInfo("v1.2.3") + workflow.SetIsRelease(true) + + server, methods := newCompileUpdateCheckMethodServer(t, "v1.3.0", map[string]bool{ + "v1.2.3": true, + "v1.3.0": true, + }) + defer server.Close() + + compileUpdateCheckLatestReleaseURL = server.URL + "/releases/latest" + compileUpdateCheckProbeURLFunc = func(tag string) string { + return fmt.Sprintf("%s/raw/%s/go.mod", server.URL, tag) + } + + notification, err := runCompileUpdateCheck(context.Background(), server.Client()) + require.NoError(t, err, "runCompileUpdateCheck should not fail") + require.NotNil(t, notification, "runCompileUpdateCheck should return a notification") + + assert.Equal(t, []string{http.MethodHead}, methodsForPath(methods, "/releases/latest"), "latest release lookup should use HEAD") + + probeMethods := methodsForPrefix(methods, "/raw/") + require.NotEmpty(t, probeMethods, "probe lookups should be recorded") + for _, method := range probeMethods { + assert.Equal(t, http.MethodHead, method, "probe lookups should use HEAD") + } +} + +func TestStartCompileUpdateCheckDoesNotBlockShutdown(t *testing.T) { + originalClientFactory := compileUpdateCheckHTTPClientFactory + originalGetFilePath := getCompileUpdateCheckFilePathFunc + originalIsTerminal := compileUpdateCheckIsTerminalFunc + originalVersion := GetVersion() + originalRelease := workflow.IsRelease() + defer func() { + compileUpdateCheckHTTPClientFactory = originalClientFactory + getCompileUpdateCheckFilePathFunc = originalGetFilePath + compileUpdateCheckIsTerminalFunc = originalIsTerminal + SetVersionInfo(originalVersion) + workflow.SetIsRelease(originalRelease) + }() + + tempDir := t.TempDir() + SetVersionInfo("v1.2.3") + workflow.SetIsRelease(true) + t.Setenv("CI", "") + t.Setenv("CONTINUOUS_INTEGRATION", "") + t.Setenv("GITHUB_ACTIONS", "") + t.Setenv("GH_AW_MCP_SERVER", "") + t.Setenv(compileUpdateCheckDisableEnv, "") + getCompileUpdateCheckFilePathFunc = func() string { + return filepath.Join(tempDir, compileUpdateCheckFileName) + } + compileUpdateCheckIsTerminalFunc = func() bool { + return true + } + + unblockRequest := make(chan struct{}) // cleanup closes this to unblock any pending request + requestStarted := make(chan struct{}, 1) + t.Cleanup(func() { + close(unblockRequest) + }) + + compileUpdateCheckHTTPClientFactory = func() *http.Client { + return &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + select { + case requestStarted <- struct{}{}: + default: + } + <-unblockRequest + return nil, context.DeadlineExceeded + }), + } + } + + finish := StartCompileUpdateCheck(context.Background(), false, false) + select { + case <-requestStarted: + case <-time.After(time.Second): + t.Fatal("background update check did not start an HTTP request") + } + + start := time.Now() + finish() + assert.Less(t, time.Since(start), 100*time.Millisecond, "finish should not wait for a background update check") +} + +func TestStartCompileUpdateCheckSilentlyHandlesLockedDownNetwork(t *testing.T) { + originalClientFactory := compileUpdateCheckHTTPClientFactory + originalGetFilePath := getCompileUpdateCheckFilePathFunc + originalIsTerminal := compileUpdateCheckIsTerminalFunc + originalVersion := GetVersion() + originalRelease := workflow.IsRelease() + defer func() { + compileUpdateCheckHTTPClientFactory = originalClientFactory + getCompileUpdateCheckFilePathFunc = originalGetFilePath + compileUpdateCheckIsTerminalFunc = originalIsTerminal + SetVersionInfo(originalVersion) + workflow.SetIsRelease(originalRelease) + }() + + tempDir := t.TempDir() + SetVersionInfo("v1.2.3") + workflow.SetIsRelease(true) + t.Setenv("CI", "") + t.Setenv("CONTINUOUS_INTEGRATION", "") + t.Setenv("GITHUB_ACTIONS", "") + t.Setenv("GH_AW_MCP_SERVER", "") + t.Setenv(compileUpdateCheckDisableEnv, "") + getCompileUpdateCheckFilePathFunc = func() string { + return filepath.Join(tempDir, compileUpdateCheckFileName) + } + compileUpdateCheckIsTerminalFunc = func() bool { + return true + } + requestStarted := make(chan struct{}, 1) + compileUpdateCheckHTTPClientFactory = func() *http.Client { + return &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + select { + case requestStarted <- struct{}{}: + default: + } + return nil, context.DeadlineExceeded + }), + } + } + + oldStderr := os.Stderr + r, w, err := os.Pipe() + require.NoError(t, err, "pipe creation should succeed") + defer r.Close() + os.Stderr = w + defer func() { + os.Stderr = oldStderr + }() + + finish := StartCompileUpdateCheck(context.Background(), false, false) + select { + case <-requestStarted: + case <-time.After(time.Second): + t.Fatal("background update check did not attempt its network request") + } + finish() + + require.NoError(t, w.Close(), "pipe writer should close cleanly") + + var buf bytes.Buffer + _, err = buf.ReadFrom(r) + require.NoError(t, err, "pipe reader should capture stderr output") + assert.Empty(t, strings.TrimSpace(buf.String()), "locked-down network failures should not print user-facing output") +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func newCompileUpdateCheckTestServer(t *testing.T, latestVersion string, existingTags map[string]bool) *httptest.Server { + t.Helper() + + mux := http.NewServeMux() + mux.HandleFunc("/releases/latest", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/github/gh-aw/releases/tag/"+latestVersion, http.StatusFound) + }) + mux.HandleFunc("/github/gh-aw/releases/tag/"+latestVersion, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("release page")) + }) + mux.HandleFunc("/raw/", func(w http.ResponseWriter, r *http.Request) { + tag := r.URL.Path[len("/raw/"):] + tag = tag[:len(tag)-len("/go.mod")] + if existingTags[tag] { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("module github.com/github/gh-aw\n")) + return + } + http.NotFound(w, r) + }) + + return httptest.NewServer(mux) +} + +func newCompileUpdateCheckMethodServer(t *testing.T, latestVersion string, existingTags map[string]bool) (*httptest.Server, map[string][]string) { + t.Helper() + + var mu sync.Mutex + methods := map[string][]string{} + record := func(path string, method string) { + mu.Lock() + defer mu.Unlock() + methods[path] = append(methods[path], method) + } + + mux := http.NewServeMux() + mux.HandleFunc("/releases/latest", func(w http.ResponseWriter, r *http.Request) { + record(r.URL.Path, r.Method) + http.Redirect(w, r, "/github/gh-aw/releases/tag/"+latestVersion, http.StatusFound) + }) + mux.HandleFunc("/github/gh-aw/releases/tag/"+latestVersion, func(w http.ResponseWriter, r *http.Request) { + record(r.URL.Path, r.Method) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("release page")) + }) + mux.HandleFunc("/raw/", func(w http.ResponseWriter, r *http.Request) { + record(r.URL.Path, r.Method) + tag := r.URL.Path[len("/raw/"):] + tag = tag[:len(tag)-len("/go.mod")] + if existingTags[tag] { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("module github.com/github/gh-aw\n")) + return + } + http.NotFound(w, r) + }) + + return httptest.NewServer(mux), methods +} + +func methodsForPath(methods map[string][]string, path string) []string { + return slices.Clone(methods[path]) +} + +func methodsForPrefix(methods map[string][]string, prefix string) []string { + var collected []string + for path, values := range methods { + if strings.HasPrefix(path, prefix) { + collected = append(collected, values...) + } + } + return collected +} diff --git a/pkg/cli/update_check.go b/pkg/cli/update_check.go index 97bfb370a5f..7b59207f4a3 100644 --- a/pkg/cli/update_check.go +++ b/pkg/cli/update_check.go @@ -112,6 +112,10 @@ func getLastCheckFilePath() string { // getLastCheckFilePathImpl is the actual implementation func getLastCheckFilePathImpl() string { + return getLastCheckFilePathFor(lastCheckFileName) +} + +func getLastCheckFilePathFor(fileName string) string { // Use OS temp directory for cross-platform compatibility tmpDir := os.TempDir() if tmpDir == "" { @@ -126,7 +130,7 @@ func getLastCheckFilePathImpl() string { return "" } - return filepath.Join(ghAwTmpDir, lastCheckFileName) + return filepath.Join(ghAwTmpDir, fileName) } // updateLastCheckTime updates the timestamp of the last update check