diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8e8da80 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/githubnext/apm + +go 1.24.13 diff --git a/internal/utils/console/console.go b/internal/utils/console/console.go new file mode 100644 index 0000000..76b4195 --- /dev/null +++ b/internal/utils/console/console.go @@ -0,0 +1,156 @@ +// Package console provides console utility functions for formatted CLI output. +// +// All output is within printable ASCII (U+0020-U+007E). Color codes use ANSI +// escape sequences, disabled automatically when NO_COLOR is set or TERM=dumb. +package console + +import ( + "fmt" + "io" + "os" + "strings" +) + +// StatusSymbols maps semantic names to ASCII bracket notation. +var StatusSymbols = map[string]string{ + "success": "[*]", + "sparkles": "[*]", + "running": "[>]", + "gear": "[*]", + "info": "[i]", + "warning": "[!]", + "error": "[x]", + "check": "[+]", + "cross": "[x]", + "list": "[#]", + "preview": "[>]", + "robot": "[>]", + "metrics": "[#]", + "default": "[>]", + "eyes": "[>]", + "folder": "[>]", + "cogs": "[*]", + "plugin": "[>]", + "search": "[>]", + "download": "[>]", + "update": "[~]", + "remove": "[-]", + "equal": "[=]", +} + +// ANSI color codes. +const ( + ansiReset = "\033[0m" + ansiRed = "\033[31m" + ansiGreen = "\033[32m" + ansiYellow = "\033[33m" + ansiBlue = "\033[34m" + ansiCyan = "\033[36m" + ansiBold = "\033[1m" +) + +// colorEnabled returns true when ANSI color output is supported. +func colorEnabled() bool { + if os.Getenv("NO_COLOR") != "" { + return false + } + if os.Getenv("TERM") == "dumb" { + return false + } + return true +} + +// Echo writes a message to w (defaults to os.Stdout) with optional color and +// symbol prefix. color may be "red", "green", "yellow", "blue", "cyan", or +// empty for default terminal color. 
+func Echo(w io.Writer, message, color, symbol string, bold bool) { + if w == nil { + w = os.Stdout + } + if sym, ok := StatusSymbols[symbol]; ok && symbol != "" { + message = sym + " " + message + } + if colorEnabled() && color != "" { + code := colorCode(color) + if bold { + fmt.Fprintf(w, "%s%s%s%s\n", ansiBold, code, message, ansiReset) + } else { + fmt.Fprintf(w, "%s%s%s\n", code, message, ansiReset) + } + } else { + fmt.Fprintln(w, message) + } +} + +func colorCode(color string) string { + switch strings.ToLower(color) { + case "red": + return ansiRed + case "green": + return ansiGreen + case "yellow": + return ansiYellow + case "blue": + return ansiBlue + case "cyan": + return ansiCyan + default: + return "" + } +} + +// Success prints a success message (green, bold). +func Success(message, symbol string) { + Echo(os.Stdout, message, "green", symbol, true) +} + +// Error prints an error message (red). +func Error(message, symbol string) { + Echo(os.Stderr, message, "red", symbol, false) +} + +// Warning prints a warning message (yellow). +func Warning(message, symbol string) { + Echo(os.Stdout, message, "yellow", symbol, false) +} + +// Info prints an info message (blue). +func Info(message, symbol string) { + Echo(os.Stdout, message, "blue", symbol, false) +} + +// Panel prints content framed by a simple ASCII border with an optional title. +func Panel(content, title, style string) { + if title != "" { + fmt.Printf("\n--- %s ---\n", title) + } + fmt.Println(content) + if title != "" { + fmt.Println(strings.Repeat("-", len(title)+8)) + } +} + +// PrintFilesTable prints a simple two-column table of file name + description. 
+func PrintFilesTable(files [][]string, tableTitle string) { + if tableTitle != "" { + fmt.Println(tableTitle) + } + for _, row := range files { + name := "" + desc := "" + if len(row) > 0 { + name = row[0] + } + if len(row) > 1 { + desc = row[1] + } + fmt.Printf(" %-40s %s\n", name, desc) + } +} + +// DownloadSpinner prints a simple download-in-progress message and calls fn. +// Unlike Python's context-manager spinner, this is a function-based helper. +func DownloadSpinner(repoName string, fn func()) { + fmt.Printf("[>] Downloading %s...\n", repoName) + fn() +} diff --git a/internal/utils/console/console_test.go b/internal/utils/console/console_test.go new file mode 100644 index 0000000..8b7f3c2 --- /dev/null +++ b/internal/utils/console/console_test.go @@ -0,0 +1,47 @@ +package console_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/githubnext/apm/internal/utils/console" +) + +func TestStatusSymbols(t *testing.T) { + cases := map[string]string{ + "success": "[*]", + "error": "[x]", + "warning": "[!]", + "info": "[i]", + "check": "[+]", + } + for k, want := range cases { + if got := console.StatusSymbols[k]; got != want { + t.Errorf("StatusSymbols[%q] = %q, want %q", k, got, want) + } + } +} + +func TestEcho_noColor(t *testing.T) { + t.Setenv("NO_COLOR", "1") + var buf bytes.Buffer + console.Echo(&buf, "hello", "green", "", false) + if !strings.Contains(buf.String(), "hello") { + t.Errorf("expected 'hello' in output, got %q", buf.String()) + } +} + +func TestEcho_withSymbol(t *testing.T) { + t.Setenv("NO_COLOR", "1") + var buf bytes.Buffer + console.Echo(&buf, "done", "", "check", false) + if !strings.Contains(buf.String(), "[+]") { + t.Errorf("expected symbol [+] in output, got %q", buf.String()) + } +} + +func TestPrintFilesTable_smoke(t *testing.T) { + // Just ensure no panic. 
+ console.PrintFilesTable([][]string{{"file.go", "main source"}}, "Files") +} diff --git a/internal/utils/contenthash/contenthash.go b/internal/utils/contenthash/contenthash.go new file mode 100644 index 0000000..d1e746c --- /dev/null +++ b/internal/utils/contenthash/contenthash.go @@ -0,0 +1,151 @@ +// Package contenthash provides deterministic SHA-256 content hashing for +// package integrity verification. +package contenthash + +import ( + "crypto/sha256" + "fmt" + "io" + "os" + "path/filepath" + "sort" +) + +const ( + // MarkerFilename is the cache-pin marker excluded from package hashes. + MarkerFilename = ".apm-pin" +) + +var excludedDirs = map[string]bool{ + ".git": true, + "__pycache__": true, +} + +// emptyHash is the well-known hash for an empty or missing package. +var emptyHash = "sha256:" + func() string { + h := sha256.Sum256([]byte{}) + return fmt.Sprintf("%x", h) +}() + +// ComputePackageHash computes a deterministic SHA-256 hash of a package's +// file tree. The hash is computed over sorted file paths and their contents, +// making it independent of filesystem ordering and metadata. +// +// Returns a hash string in format "sha256:". +func ComputePackageHash(packagePath string) (string, error) { + info, err := os.Lstat(packagePath) + if err != nil || !info.IsDir() { + return emptyHash, nil + } + + var relFiles []string + err = filepath.WalkDir(packagePath, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + // Skip symlinks + if d.Type()&os.ModeSymlink != 0 { + return nil + } + rel, relErr := filepath.Rel(packagePath, path) + if relErr != nil { + return relErr + } + if rel == "." 
{ + return nil + } + // Skip excluded directories + parts := splitPath(rel) + for _, part := range parts { + if excludedDirs[part] { + if d.IsDir() { + return filepath.SkipDir + } + return nil + } + } + if d.IsDir() { + return nil + } + // Exclude root-level marker files + if len(parts) == 1 && parts[0] == MarkerFilename { + return nil + } + relFiles = append(relFiles, filepath.ToSlash(rel)) + return nil + }) + if err != nil { + return "", fmt.Errorf("contenthash: walking %s: %w", packagePath, err) + } + + if len(relFiles) == 0 { + return emptyHash, nil + } + + sort.Strings(relFiles) + + h := sha256.New() + for _, rel := range relFiles { + h.Write([]byte(rel)) + f, openErr := os.Open(filepath.Join(packagePath, filepath.FromSlash(rel))) + if openErr != nil { + return "", fmt.Errorf("contenthash: opening %s: %w", rel, openErr) + } + _, copyErr := io.Copy(h, f) + f.Close() + if copyErr != nil { + return "", fmt.Errorf("contenthash: reading %s: %w", rel, copyErr) + } + } + + return fmt.Sprintf("sha256:%x", h.Sum(nil)), nil +} + +// ComputeFileHash computes SHA-256 of a single file's contents. +// Returns "sha256:". Returns the empty-content hash when the +// path does not exist or is not a regular file. +func ComputeFileHash(filePath string) (string, error) { + info, err := os.Lstat(filePath) + if err != nil { + return emptyHash, nil + } + if !info.Mode().IsRegular() { + return emptyHash, nil + } + f, err := os.Open(filePath) + if err != nil { + return emptyHash, nil + } + defer f.Close() + h := sha256.New() + if _, err = io.Copy(h, f); err != nil { + return "", fmt.Errorf("contenthash: reading %s: %w", filePath, err) + } + return fmt.Sprintf("sha256:%x", h.Sum(nil)), nil +} + +// VerifyPackageHash verifies a package's content matches the expected hash. +// Returns true if hash matches. 
+func VerifyPackageHash(packagePath, expectedHash string) (bool, error) { + actual, err := ComputePackageHash(packagePath) + if err != nil { + return false, err + } + return actual == expectedHash, nil +} + +// splitPath splits a slash-separated relative path into its components. +func splitPath(p string) []string { + s := filepath.ToSlash(p) + var parts []string + start := 0 + for i := 0; i <= len(s); i++ { + if i == len(s) || s[i] == '/' { + if seg := s[start:i]; seg != "" && seg != "." { + parts = append(parts, seg) + } + start = i + 1 + } + } + return parts +} diff --git a/internal/utils/contenthash/contenthash_test.go b/internal/utils/contenthash/contenthash_test.go new file mode 100644 index 0000000..af5fb18 --- /dev/null +++ b/internal/utils/contenthash/contenthash_test.go @@ -0,0 +1,107 @@ +package contenthash_test + +import ( + "crypto/sha256" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/githubnext/apm/internal/utils/contenthash" +) + +func TestComputePackageHash_empty(t *testing.T) { + dir := t.TempDir() + h, err := contenthash.ComputePackageHash(dir) + if err != nil { + t.Fatal(err) + } + want := "sha256:" + fmt.Sprintf("%x", sha256.Sum256([]byte{})) + if h != want { + t.Errorf("empty dir: got %s, want %s", h, want) + } +} + +func TestComputePackageHash_nonexistent(t *testing.T) { + h, err := contenthash.ComputePackageHash("/nonexistent/path/xyz") + if err != nil { + t.Fatal(err) + } + want := "sha256:" + fmt.Sprintf("%x", sha256.Sum256([]byte{})) + if h != want { + t.Errorf("nonexistent: got %s, want %s", h, want) + } +} + +func TestComputePackageHash_deterministic(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "b.txt"), []byte("world"), 0o644); err != nil { + t.Fatal(err) + } + h1, err := contenthash.ComputePackageHash(dir) + if err != nil { + t.Fatal(err) + } + h2, err := 
contenthash.ComputePackageHash(dir) + if err != nil { + t.Fatal(err) + } + if h1 != h2 { + t.Errorf("not deterministic: %s != %s", h1, h2) + } +} + +func TestComputePackageHash_excludesMarker(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0o644); err != nil { + t.Fatal(err) + } + h1, _ := contenthash.ComputePackageHash(dir) + + if err := os.WriteFile(filepath.Join(dir, ".apm-pin"), []byte("marker"), 0o644); err != nil { + t.Fatal(err) + } + h2, _ := contenthash.ComputePackageHash(dir) + if h1 != h2 { + t.Errorf("marker should not affect hash: %s vs %s", h1, h2) + } +} + +func TestComputeFileHash(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "file.txt") + if err := os.WriteFile(path, []byte("content"), 0o644); err != nil { + t.Fatal(err) + } + h, err := contenthash.ComputeFileHash(path) + if err != nil { + t.Fatal(err) + } + sum := sha256.Sum256([]byte("content")) + want := fmt.Sprintf("sha256:%x", sum) + if h != want { + t.Errorf("got %s, want %s", h, want) + } +} + +func TestVerifyPackageHash(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "f.txt"), []byte("data"), 0o644); err != nil { + t.Fatal(err) + } + h, _ := contenthash.ComputePackageHash(dir) + ok, err := contenthash.VerifyPackageHash(dir, h) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Error("expected hash to verify") + } + ok2, _ := contenthash.VerifyPackageHash(dir, "sha256:wrong") + if ok2 { + t.Error("expected mismatch to fail") + } +} diff --git a/internal/utils/exclude/exclude.go b/internal/utils/exclude/exclude.go new file mode 100644 index 0000000..8ab5686 --- /dev/null +++ b/internal/utils/exclude/exclude.go @@ -0,0 +1,154 @@ +// Package exclude provides glob-style pattern matching for filtering paths +// against compilation.exclude patterns from apm.yml. +// +// Supports ** (recursive directory) wildcard matching with a bounded-recursion +// guard to prevent exponential blowup. 
+package exclude + +import ( + "fmt" + "path/filepath" + "strings" +) + +// MaxDoubleStarSegments is the maximum number of ** segments allowed in a +// single pattern to prevent exponential recursion blowup. +const MaxDoubleStarSegments = 5 + +// ValidateExcludePatterns validates and normalizes exclude patterns, rejecting +// dangerous ones. Returns the normalized patterns or an error if any pattern +// exceeds the ** segment safety limit. +func ValidateExcludePatterns(patterns []string) ([]string, error) { + if len(patterns) == 0 { + return nil, nil + } + validated := make([]string, 0, len(patterns)) + for _, pattern := range patterns { + normalized := strings.ReplaceAll(pattern, "\\", "/") + parts := strings.Split(normalized, "/") + // Collapse consecutive ** segments + collapsed := make([]string, 0, len(parts)) + for _, p := range parts { + if p == "**" && len(collapsed) > 0 && collapsed[len(collapsed)-1] == "**" { + continue + } + collapsed = append(collapsed, p) + } + normalized = strings.Join(collapsed, "/") + starCount := 0 + for _, p := range collapsed { + if p == "**" { + starCount++ + } + } + if starCount > MaxDoubleStarSegments { + return nil, fmt.Errorf( + "exclude: pattern %q has %d '**' segments (max %d); simplify the pattern", + pattern, starCount, MaxDoubleStarSegments, + ) + } + validated = append(validated, normalized) + } + return validated, nil +} + +// ShouldExclude checks whether a file path should be excluded based on the +// pre-validated patterns. baseDir is used to compute the relative path. 
+func ShouldExclude(filePath, baseDir string, excludePatterns []string) bool { + if len(excludePatterns) == 0 { + return false + } + absFile, err := filepath.Abs(filePath) + if err != nil { + absFile = filePath + } + absBase, err := filepath.Abs(baseDir) + if err != nil { + absBase = baseDir + } + rel, err := filepath.Rel(absBase, absFile) + if err != nil { + return false + } + relStr := filepath.ToSlash(rel) + if strings.HasPrefix(relStr, "../") { + return false + } + for _, pattern := range excludePatterns { + if matchesPattern(relStr, pattern) { + return true + } + } + return false +} + +// matchesPattern checks if a relative path string matches a single exclusion pattern. +func matchesPattern(relPathStr, pattern string) bool { + if strings.Contains(pattern, "**") { + pathParts := strings.Split(relPathStr, "/") + patternParts := strings.Split(pattern, "/") + return matchGlobRecursive(pathParts, patternParts) + } + ok, _ := filepath.Match(pattern, relPathStr) + if ok { + return true + } + // Directory prefix matching + if strings.HasSuffix(pattern, "/") { + return strings.HasPrefix(relPathStr, pattern) || relPathStr == strings.TrimSuffix(pattern, "/") + } + return strings.HasPrefix(relPathStr, pattern+"/") || relPathStr == pattern +} + +// matchGlobRecursive matches path components against pattern components with ** support. +func matchGlobRecursive(pathParts, patternParts []string) bool { + // Strip trailing empty parts + for len(patternParts) > 0 && patternParts[len(patternParts)-1] == "" { + patternParts = patternParts[:len(patternParts)-1] + } + + pi, xi := 0, 0 + for pi < len(patternParts) && xi < len(pathParts) { + part := patternParts[pi] + if part == "**" { + break + } + ok, _ := filepath.Match(part, pathParts[xi]) + if !ok { + return false + } + pi++ + xi++ + } + if pi == len(patternParts) { + return xi == len(pathParts) + } + return matchDoubleStar(pathParts[xi:], patternParts[pi:]) +} + +// matchDoubleStar handles ** segments with bounded recursion. 
+func matchDoubleStar(pathParts, patternParts []string) bool { + if len(patternParts) == 0 { + return len(pathParts) == 0 + } + if len(pathParts) == 0 { + for _, p := range patternParts { + if p != "**" && p != "" { + return false + } + } + return true + } + part := patternParts[0] + if part == "**" { + if matchDoubleStar(pathParts, patternParts[1:]) { + return true + } + return matchDoubleStar(pathParts[1:], patternParts) + } + ok, _ := filepath.Match(part, pathParts[0]) + if ok { + return matchDoubleStar(pathParts[1:], patternParts[1:]) + } + return false +} diff --git a/internal/utils/exclude/exclude_test.go b/internal/utils/exclude/exclude_test.go new file mode 100644 index 0000000..0757f8c --- /dev/null +++ b/internal/utils/exclude/exclude_test.go @@ -0,0 +1,72 @@ +package exclude_test + +import ( + "testing" + + "github.com/githubnext/apm/internal/utils/exclude" +) + +func TestValidateExcludePatterns_nil(t *testing.T) { + out, err := exclude.ValidateExcludePatterns(nil) + if err != nil || len(out) != 0 { + t.Errorf("nil input: got %v %v", out, err) + } +} + +func TestValidateExcludePatterns_normal(t *testing.T) { + patterns := []string{"docs/**", "build/", "*.log"} + out, err := exclude.ValidateExcludePatterns(patterns) + if err != nil { + t.Fatal(err) + } + if len(out) != 3 { + t.Errorf("expected 3, got %d", len(out)) + } +} + +func TestValidateExcludePatterns_tooManyStars(t *testing.T) { + pattern := "a/**/b/**/c/**/d/**/e/**/f/**" + _, err := exclude.ValidateExcludePatterns([]string{pattern}) + if err == nil { + t.Error("expected error for too many ** segments") + } +} + +func TestValidateExcludePatterns_collapsesConsecutiveStars(t *testing.T) { + out, err := exclude.ValidateExcludePatterns([]string{"a/**/**/b"}) + if err != nil { + t.Fatal(err) + } + if out[0] != "a/**/b" { + t.Errorf("expected collapsed pattern, got %s", out[0]) + } +} + +func TestShouldExclude_basic(t *testing.T) { + cases := []struct { + rel string + pattern string + want bool + }{ + 
{"docs/foo.md", "docs/**", true}, + {"src/main.go", "docs/**", false}, + {"build/out.bin", "build/", true}, + {"log.txt", "*.log", false}, + {"foo.log", "*.log", true}, + {"a/b/c.go", "a/**/c.go", true}, + {"a/x/y/c.go", "a/**/c.go", true}, + {"a/b/d.go", "a/**/c.go", false}, + } + for _, tc := range cases { + got := exclude.ShouldExclude("/base/"+tc.rel, "/base", []string{tc.pattern}) + if got != tc.want { + t.Errorf("ShouldExclude(%q, %q): got %v, want %v", tc.rel, tc.pattern, got, tc.want) + } + } +} + +func TestShouldExclude_noPatterns(t *testing.T) { + if exclude.ShouldExclude("/base/file.go", "/base", nil) { + t.Error("nil patterns should never exclude") + } +} diff --git a/internal/utils/fileops/fileops.go b/internal/utils/fileops/fileops.go new file mode 100644 index 0000000..0fc4508 --- /dev/null +++ b/internal/utils/fileops/fileops.go @@ -0,0 +1,183 @@ +// Package fileops provides retry-aware file operations for cross-platform +// reliability. +// +// On Windows, antivirus and endpoint-protection software briefly lock files +// while scanning them in temp directories. This package provides drop-in +// replacements for os.RemoveAll, filepath.WalkDir-based copy, and os.Copy +// that transparently retry on transient lock errors with exponential backoff. +package fileops + +import ( + "fmt" + "io" + "os" + "path/filepath" + "time" +) + +const ( + defaultMaxRetries = 5 + defaultInitialDelay = 100 * time.Millisecond + defaultMaxDelay = 2 * time.Second + defaultBackoffFactor = 2.0 +) + +// isTransientLockError returns true when err looks like a transient file-lock +// error. Platform-specific detection is in lock_unix.go / lock_windows.go. +// The function defined here handles the Unix EBUSY case; the build-tag files +// add Windows winerror 32/5 detection. + +// retryOnLock executes op, retrying on transient lock errors. 
+func retryOnLock(op func() error, desc string, maxRetries int, initial, maxDelay time.Duration, backoff float64, beforeRetry func()) error {
+	delay := initial
+	var lastErr error
+	for attempt := 0; attempt <= maxRetries; attempt++ {
+		err := op()
+		if err == nil {
+			return nil
+		}
+		lastErr = err
+		if !isTransientLockError(err) || attempt == maxRetries {
+			return err
+		}
+		debugFileOp(fmt.Sprintf("%s: transient lock (attempt %d/%d), retrying in %s -- %v",
+			desc, attempt+1, maxRetries, delay, err))
+		if beforeRetry != nil {
+			beforeRetry()
+		}
+		time.Sleep(delay)
+		next := time.Duration(float64(delay) * backoff)
+		if next > maxDelay {
+			next = maxDelay
+		}
+		delay = next
+	}
+	return lastErr
+}
+
+// debugFileOp prints debug output when APM_DEBUG is set.
+func debugFileOp(msg string) {
+	if os.Getenv("APM_DEBUG") != "" {
+		fmt.Fprintf(os.Stderr, "[DEBUG] %s\n", msg)
+	}
+}
+
+// RobustRemoveAll removes a directory tree, retrying on transient lock errors.
+// If ignoreErrors is true, any error after retries is silently discarded.
+func RobustRemoveAll(path string, ignoreErrors bool, maxRetries int) error {
+	if maxRetries <= 0 {
+		maxRetries = defaultMaxRetries
+	}
+	err := retryOnLock(func() error {
+		return removeAllWritable(path)
+	}, "rmtree "+path, maxRetries, defaultInitialDelay, defaultMaxDelay, defaultBackoffFactor, nil)
+	if err != nil && ignoreErrors {
+		return nil
+	}
+	return err
+}
+
+// removeAllWritable removes path, first chmod-ing read-only entries writable.
+func removeAllWritable(path string) error {
+	// chmod every entry 0o777 (dirs keep x so they stay traversable) so RemoveAll succeeds on read-only trees (e.g. git pack).
+	_ = filepath.WalkDir(path, func(p string, d os.DirEntry, err error) error {
+		if err != nil {
+			return nil
+		}
+		_ = os.Chmod(p, 0o777)
+		return nil
+	})
+	return os.RemoveAll(path)
+}
+
+// RobustCopyTree copies a directory tree from src to dst, retrying on
+// transient lock errors. Any partial dst is removed before each retry
+// unless dirsExistOK is true.
+func RobustCopyTree(src, dst string, symlinks, dirsExistOK bool, maxRetries int) error { + if maxRetries <= 0 { + maxRetries = defaultMaxRetries + } + var beforeRetry func() + if !dirsExistOK { + beforeRetry = func() { + _ = os.RemoveAll(dst) + } + } + return retryOnLock(func() error { + return copyTree(src, dst, symlinks, dirsExistOK) + }, fmt.Sprintf("copytree %s -> %s", src, dst), maxRetries, defaultInitialDelay, defaultMaxDelay, defaultBackoffFactor, beforeRetry) +} + +// copyTree is the inner copy-tree implementation (no retry). +func copyTree(src, dst string, symlinks, dirsExistOK bool) error { + return filepath.WalkDir(src, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + rel, relErr := filepath.Rel(src, path) + if relErr != nil { + return relErr + } + target := filepath.Join(dst, rel) + if d.IsDir() { + if mkErr := os.MkdirAll(target, 0o755); mkErr != nil && !dirsExistOK { + return mkErr + } + return nil + } + if d.Type()&os.ModeSymlink != 0 { + if symlinks { + link, readErr := os.Readlink(path) + if readErr != nil { + return readErr + } + return os.Symlink(link, target) + } + // Dereference symlink: stat the real file. + info, statErr := os.Stat(path) + if statErr != nil || !info.Mode().IsRegular() { + return nil + } + } + return copyFile(path, target) + }) +} + +// RobustCopy2 copies a single file with metadata, retrying on transient lock +// errors. +func RobustCopy2(src, dst string, maxRetries int) error { + if maxRetries <= 0 { + maxRetries = defaultMaxRetries + } + return retryOnLock(func() error { + return copyFile(src, dst) + }, fmt.Sprintf("copy2 %s -> %s", src, dst), maxRetries, defaultInitialDelay, defaultMaxDelay, defaultBackoffFactor, nil) +} + +// copyFile copies src to dst, preserving permissions. 
+func copyFile(src, dst string) error { + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return err + } + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + + info, err := in.Stat() + if err != nil { + return err + } + + out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, info.Mode()) + if err != nil { + return err + } + _, copyErr := io.Copy(out, in) + closeErr := out.Close() + if copyErr != nil { + return copyErr + } + return closeErr +} diff --git a/internal/utils/fileops/fileops_test.go b/internal/utils/fileops/fileops_test.go new file mode 100644 index 0000000..4486279 --- /dev/null +++ b/internal/utils/fileops/fileops_test.go @@ -0,0 +1,67 @@ +package fileops_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/githubnext/apm/internal/utils/fileops" +) + +func TestRobustRemoveAll(t *testing.T) { + dir := t.TempDir() + sub := filepath.Join(dir, "sub") + if err := os.MkdirAll(sub, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(sub, "f.txt"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + target := filepath.Join(dir, "target") + if err := os.Rename(sub, target); err != nil { + t.Fatal(err) + } + if err := fileops.RobustRemoveAll(target, false, 0); err != nil { + t.Fatal(err) + } + if _, err := os.Stat(target); !os.IsNotExist(err) { + t.Error("directory should have been removed") + } +} + +func TestRobustCopyTree(t *testing.T) { + src := t.TempDir() + dst := filepath.Join(t.TempDir(), "dst") + if err := os.WriteFile(filepath.Join(src, "a.txt"), []byte("hello"), 0o644); err != nil { + t.Fatal(err) + } + if err := fileops.RobustCopyTree(src, dst, false, false, 0); err != nil { + t.Fatal(err) + } + data, err := os.ReadFile(filepath.Join(dst, "a.txt")) + if err != nil { + t.Fatal(err) + } + if string(data) != "hello" { + t.Errorf("expected 'hello', got %q", data) + } +} + +func TestRobustCopy2(t *testing.T) { + dir := t.TempDir() + src := 
filepath.Join(dir, "src.txt")
+	dst := filepath.Join(dir, "dst.txt")
+	if err := os.WriteFile(src, []byte("content"), 0o644); err != nil {
+		t.Fatal(err)
+	}
+	if err := fileops.RobustCopy2(src, dst, 0); err != nil {
+		t.Fatal(err)
+	}
+	data, err := os.ReadFile(dst)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if string(data) != "content" {
+		t.Errorf("expected 'content', got %q", data)
+	}
+}
diff --git a/internal/utils/fileops/lock_unix.go b/internal/utils/fileops/lock_unix.go
new file mode 100644
index 0000000..a5f8a08
--- /dev/null
+++ b/internal/utils/fileops/lock_unix.go
@@ -0,0 +1,17 @@
+//go:build !windows
+
+package fileops
+
+import (
+	"errors"
+	"syscall"
+)
+
+// isTransientLockError returns true for EBUSY on Unix.
+func isTransientLockError(err error) bool {
+	var errno syscall.Errno
+	if errors.As(err, &errno) {
+		return errno == syscall.EBUSY
+	}
+	return false
+}
diff --git a/internal/utils/fileops/lock_windows.go b/internal/utils/fileops/lock_windows.go
new file mode 100644
index 0000000..d61cc53
--- /dev/null
+++ b/internal/utils/fileops/lock_windows.go
@@ -0,0 +1,19 @@
+//go:build windows
+
+package fileops
+
+import (
+	"errors"
+	"strings"
+	"syscall"
+)
+
+// isTransientLockError reports winerror 32 (sharing violation; stdlib syscall does not export the named constant) or 5 (access denied).
+func isTransientLockError(err error) bool {
+	var errno syscall.Errno
+	if errors.As(err, &errno) {
+		return errno == syscall.Errno(32) || errno == syscall.ERROR_ACCESS_DENIED
+	}
+	s := strings.ToLower(err.Error())
+	return strings.Contains(s, "used by another process") || strings.Contains(s, "access is denied")
+}