diff --git a/pkg/sliceutil/README.md b/pkg/sliceutil/README.md index a50e1a0537a..04a5a373680 100644 --- a/pkg/sliceutil/README.md +++ b/pkg/sliceutil/README.md @@ -18,6 +18,8 @@ All functions in this package are pure: they never modify their input. They are | `FilterMapKeys` | `func[K comparable, V any](m map[K]V, predicate func(K, V) bool) []K` | Returns map keys for which `predicate(key, value)` is `true`; order is not guaranteed | | `Any` | `func[T any](slice []T, predicate func(T) bool) bool` | Returns `true` if at least one element satisfies `predicate`; returns `false` for nil or empty slices | | `Deduplicate` | `func[T comparable](slice []T) []T` | Returns a new slice with duplicate elements removed, preserving order of first occurrence | +| `MergeUnique` | `func[T comparable](base []T, extra ...T) []T` | Returns a deduplicated slice starting with `base` and appending unseen values from `extra` | +| `Exclude` | `func[T comparable](base []T, exclude ...T) []T` | Returns a new slice with all `exclude` values removed while preserving order | ## Usage Examples @@ -38,12 +40,20 @@ upper := sliceutil.Map(names, strings.ToUpper) items := []string{"a", "b", "a", "c"} unique := sliceutil.Deduplicate(items) // unique = ["a", "b", "c"] + +// Merge unique values +merged := sliceutil.MergeUnique([]string{"a", "b"}, "b", "c") +// merged = ["a", "b", "c"] + +// Exclude values +filtered := sliceutil.Exclude([]string{"a", "b", "c"}, "b") +// filtered = ["a", "c"] ``` ## Design Notes - `Any` is implemented via `slices.ContainsFunc` from the standard library. -- `Deduplicate` uses a `map[T]bool` for O(n) time complexity. +- `Deduplicate`, `MergeUnique`, and `Exclude` use hash sets (`map[T]struct{}`) for O(n) behavior. - None of these functions sort their output; callers that require sorted results should call `slices.Sort` on the returned slice. 
// MergeUnique returns a deduplicated slice that starts with the items of base
// and then appends every item of extra that has not been seen before. The first
// occurrence of a value decides its position, so relative order is preserved.
//
// The result is always a freshly allocated, non-nil slice; neither base nor
// extra is modified.
func MergeUnique[T comparable](base []T, extra ...T) []T {
	seen := make(map[T]struct{}, len(base)+len(extra))
	result := make([]T, 0, len(base)+len(extra))
	// base is walked before extra so that values from base keep their position.
	for _, items := range [2][]T{base, extra} {
		for _, item := range items {
			if _, exists := seen[item]; exists {
				continue
			}
			seen[item] = struct{}{}
			result = append(result, item)
		}
	}
	return result
}

// Exclude returns a new slice containing the items of base that do not appear
// in exclude. The relative order of the remaining items is preserved, and
// duplicates in base that are not excluded are kept.
//
// The result is always a freshly allocated, non-nil slice that never aliases
// base. This holds when nothing is removed and when base is nil or empty, so
// the nil-ness of the result does not depend on whether exclude was supplied
// (matching MergeUnique).
func Exclude[T comparable](base []T, exclude ...T) []T {
	result := make([]T, 0, len(base))
	if len(exclude) == 0 {
		// Nothing to filter: a plain copy avoids building the lookup set.
		return append(result, base...)
	}

	excluded := make(map[T]struct{}, len(exclude))
	for _, item := range exclude {
		excluded[item] = struct{}{}
	}

	for _, item := range base {
		if _, isExcluded := excluded[item]; !isExcluded {
			result = append(result, item)
		}
	}
	return result
}
+ assert.Equal(t, tt.expected, result, "Exclude should remove excluded elements") + if len(tt.exclude) == 0 && len(tt.base) > 0 { + assert.NotSame(t, &tt.base[0], &result[0], "Exclude should always return a fresh slice copy") + } + }) + } +} diff --git a/pkg/stringutil/sanitize.go b/pkg/stringutil/sanitize.go index 038cb8a9767..77be54f9e1b 100644 --- a/pkg/stringutil/sanitize.go +++ b/pkg/stringutil/sanitize.go @@ -2,6 +2,7 @@ package stringutil import ( "regexp" + "slices" "strings" "github.com/github/gh-aw/pkg/logger" @@ -9,6 +10,8 @@ import ( var sanitizeLog = logger.New("stringutil:sanitize") +var multipleHyphens = regexp.MustCompile(`-+`) + // Regex patterns for detecting potential secret key names var ( // Match uppercase snake_case identifiers that look like secret names (e.g., MY_SECRET_KEY, GITHUB_TOKEN, API_KEY) @@ -46,6 +49,100 @@ var ( } ) +// SanitizeOptions configures the behavior of the SanitizeName function. +type SanitizeOptions struct { + // PreserveSpecialChars is a list of special characters to preserve during sanitization. + // Common characters include '.', '_'. If nil or empty, only alphanumeric and hyphens are preserved. + PreserveSpecialChars []rune + + // TrimHyphens controls whether leading and trailing hyphens are removed from the result. + // When true, hyphens at the start and end of the sanitized name are trimmed. + TrimHyphens bool + + // DefaultValue is returned when the sanitized name is empty after all transformations. + // If empty string, no default is applied. + DefaultValue string +} + +// SanitizeName sanitizes a string for use as an identifier, file name, or similar context. +// It provides configurable behavior through the SanitizeOptions parameter. 
+func SanitizeName(name string, opts *SanitizeOptions) string { + if sanitizeLog.Enabled() { + preserveCount := 0 + trimHyphens := false + if opts != nil { + preserveCount = len(opts.PreserveSpecialChars) + trimHyphens = opts.TrimHyphens + } + sanitizeLog.Printf("Sanitizing name: input=%q, preserve_chars=%d, trim_hyphens=%t", + name, preserveCount, trimHyphens) + } + + // Handle nil options + if opts == nil { + opts = &SanitizeOptions{} + } + + // Convert to lowercase + result := strings.ToLower(name) + + // Replace common separators with hyphens + result = strings.ReplaceAll(result, ":", "-") + result = strings.ReplaceAll(result, "\\", "-") + result = strings.ReplaceAll(result, "/", "-") + result = strings.ReplaceAll(result, " ", "-") + + // Check if underscores should be preserved + preserveUnderscore := slices.Contains(opts.PreserveSpecialChars, '_') + + // Replace underscores with hyphens if not preserved + if !preserveUnderscore { + result = strings.ReplaceAll(result, "_", "-") + } + + // Build character preservation pattern based on options + var preserveChars strings.Builder + preserveChars.WriteString("a-z0-9-") // Always preserve alphanumeric and hyphens + if len(opts.PreserveSpecialChars) > 0 { + for _, char := range opts.PreserveSpecialChars { + // Escape special regex characters + switch char { + case '.', '_': + preserveChars.WriteRune(char) + } + } + } + + // Create pattern for characters to remove/replace + pattern := regexp.MustCompile(`[^` + preserveChars.String() + `]+`) + + // Replace unwanted characters with hyphens or empty based on context + if len(opts.PreserveSpecialChars) > 0 { + // Replace with hyphens (SanitizeWorkflowName behavior) + result = pattern.ReplaceAllString(result, "-") + } else { + // Remove completely (SanitizeIdentifier behavior) + result = pattern.ReplaceAllString(result, "") + } + + // Consolidate multiple consecutive hyphens into a single hyphen + result = multipleHyphens.ReplaceAllString(result, "-") + + // Optionally 
trim leading/trailing hyphens + if opts.TrimHyphens { + result = strings.Trim(result, "-") + } + + // Return default value if result is empty + if result == "" && opts.DefaultValue != "" { + sanitizeLog.Printf("Sanitized name is empty, using default: %q", opts.DefaultValue) + return opts.DefaultValue + } + + sanitizeLog.Printf("Sanitized name result: %q", result) + return result +} + // SanitizeErrorMessage removes potential secret key names from error messages to prevent // information disclosure via logs. This prevents exposing details about an organization's // security infrastructure by redacting secret key names that might appear in error messages. diff --git a/pkg/stringutil/sanitize_test.go b/pkg/stringutil/sanitize_test.go index 839f79a15a7..b8e194b4936 100644 --- a/pkg/stringutil/sanitize_test.go +++ b/pkg/stringutil/sanitize_test.go @@ -668,3 +668,45 @@ func BenchmarkSanitizeForFilename(b *testing.B) { SanitizeForFilename(slug) } } + +func TestSanitizeName(t *testing.T) { + tests := []struct { + name string + input string + opts *SanitizeOptions + expected string + }{ + { + name: "nil options remove special chars", + input: "My Workflow@123", + opts: nil, + expected: "my-workflow123", + }, + { + name: "preserve dot and underscore", + input: "My.Workflow_Name", + opts: &SanitizeOptions{ + PreserveSpecialChars: []rune{'.', '_'}, + }, + expected: "my.workflow_name", + }, + { + name: "trim and default when empty", + input: "@@@", + opts: &SanitizeOptions{ + TrimHyphens: true, + DefaultValue: "default-name", + }, + expected: "default-name", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeName(tt.input, tt.opts) + if result != tt.expected { + t.Errorf("SanitizeName(%q) = %q; want %q", tt.input, result, tt.expected) + } + }) + } +} diff --git a/pkg/workflow/call_workflow_permissions.go b/pkg/workflow/call_workflow_permissions.go index e8551a82baf..704d68e9c8c 100644 --- a/pkg/workflow/call_workflow_permissions.go 
+++ b/pkg/workflow/call_workflow_permissions.go @@ -3,11 +3,9 @@ package workflow import ( "fmt" "os" - "path/filepath" "github.com/github/gh-aw/pkg/logger" "github.com/github/gh-aw/pkg/parser" - "github.com/goccy/go-yaml" ) var callWorkflowPermissionsLog = logger.New("workflow:call_workflow_permissions") @@ -85,17 +83,9 @@ func extractCallWorkflowPermissions(workflowName, markdownPath string) (*Permiss // extractPermissionsFromYAMLFile reads a .lock.yml or .yml workflow file, parses it, // and returns the merged permissions from all its jobs. func extractPermissionsFromYAMLFile(filePath string) (*Permissions, error) { - cleanPath := filepath.Clean(filePath) - // filePath originates from findWorkflowFile(), which validates all paths via - // isPathWithinDir() to prevent directory traversal before returning them. - content, err := os.ReadFile(cleanPath) // #nosec G304 -- path pre-validated by findWorkflowFile() via isPathWithinDir() + workflow, err := readWorkflowYAML(filePath) if err != nil { - return nil, fmt.Errorf("failed to read workflow file %s: %w", filePath, err) - } - - var workflow map[string]any - if err := yaml.Unmarshal(content, &workflow); err != nil { - return nil, fmt.Errorf("failed to parse workflow file %s: %w", filePath, err) + return nil, err } perms := extractJobPermissionsFromParsedWorkflow(workflow) diff --git a/pkg/workflow/call_workflow_secrets.go b/pkg/workflow/call_workflow_secrets.go index 04d1b4366e5..82779f561b9 100644 --- a/pkg/workflow/call_workflow_secrets.go +++ b/pkg/workflow/call_workflow_secrets.go @@ -2,12 +2,9 @@ package workflow import ( "fmt" - "os" - "path/filepath" "sort" "github.com/github/gh-aw/pkg/logger" - "github.com/goccy/go-yaml" ) var callWorkflowSecretsLog = logger.New("workflow:call_workflow_secrets") @@ -45,17 +42,9 @@ func extractCallWorkflowSecrets(workflowName, markdownPath string) ([]string, er // extractSecretsFromWorkflowFile parses a .lock.yml or .yml workflow file and returns // the secret names declared 
in its on.workflow_call.secrets section. func extractSecretsFromWorkflowFile(filePath string) ([]string, error) { - cleanPath := filepath.Clean(filePath) - // filePath originates from findWorkflowFile(), which validates all paths via - // isPathWithinDir() to prevent directory traversal before returning them. - content, err := os.ReadFile(cleanPath) // #nosec G304 -- path pre-validated by findWorkflowFile() via isPathWithinDir() + workflow, err := readWorkflowYAML(filePath) if err != nil { - return nil, fmt.Errorf("failed to read workflow file %s: %w", filePath, err) - } - - var workflow map[string]any - if err := yaml.Unmarshal(content, &workflow); err != nil { - return nil, fmt.Errorf("failed to parse workflow file %s: %w", filePath, err) + return nil, err } secrets := extractWorkflowCallSecretsFromParsed(workflow) diff --git a/pkg/workflow/call_workflow_validation.go b/pkg/workflow/call_workflow_validation.go index fc5b7910825..8d6c2d275e8 100644 --- a/pkg/workflow/call_workflow_validation.go +++ b/pkg/workflow/call_workflow_validation.go @@ -181,19 +181,9 @@ func (c *Compiler) validateCallWorkflow(data *WorkflowData, workflowPath string) // extractWorkflowCallInputs parses a workflow file and extracts the workflow_call inputs schema. // Returns a map of input definitions that can be used to generate MCP tool schemas. 
func extractWorkflowCallInputs(workflowPath string) (map[string]any, error) { - cleanPath := filepath.Clean(workflowPath) - if !filepath.IsAbs(cleanPath) { - return nil, fmt.Errorf("workflow path must be absolute: %s", workflowPath) - } - - workflowContent, err := os.ReadFile(cleanPath) // #nosec G304 -- Path is sanitized above + workflow, err := readWorkflowYAML(workflowPath) if err != nil { - return nil, fmt.Errorf("failed to read workflow file %s: %w", workflowPath, err) - } - - var workflow map[string]any - if err := yaml.Unmarshal(workflowContent, &workflow); err != nil { - return nil, fmt.Errorf("failed to parse workflow file %s: %w", workflowPath, err) + return nil, err } return extractWorkflowCallInputsFromParsed(workflow), nil diff --git a/pkg/workflow/compiler_safe_outputs_config.go b/pkg/workflow/compiler_safe_outputs_config.go index 6510f62f81b..ba337cafb21 100644 --- a/pkg/workflow/compiler_safe_outputs_config.go +++ b/pkg/workflow/compiler_safe_outputs_config.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/github/gh-aw/pkg/logger" + "github.com/github/gh-aw/pkg/sliceutil" ) // ======================================== @@ -69,8 +70,8 @@ func (c *Compiler) addHandlerManagerConfigEnvVar(steps *[]string, data *Workflow excludeFiles := extractStringSliceFromConfig(handlerConfig, "_protected_files_exclude") delete(handlerConfig, "_protected_files_exclude") - handlerConfig["protected_files"] = excludeFromSlice(fullManifestFiles, excludeFiles...) - filteredPrefixes := excludeFromSlice(fullPathPrefixes, excludeFiles...) + handlerConfig["protected_files"] = sliceutil.Exclude(fullManifestFiles, excludeFiles...) + filteredPrefixes := sliceutil.Exclude(fullPathPrefixes, excludeFiles...) 
if len(filteredPrefixes) > 0 { handlerConfig["protected_path_prefixes"] = filteredPrefixes } else { diff --git a/pkg/workflow/dispatch_workflow_validation.go b/pkg/workflow/dispatch_workflow_validation.go index 418ed283f78..8fb96fec6f7 100644 --- a/pkg/workflow/dispatch_workflow_validation.go +++ b/pkg/workflow/dispatch_workflow_validation.go @@ -142,19 +142,9 @@ func (c *Compiler) validateDispatchWorkflow(data *WorkflowData, workflowPath str // Returns a map of input definitions that can be used to generate MCP tool schemas func extractWorkflowDispatchInputs(workflowPath string) (map[string]any, error) { dispatchWorkflowValidationLog.Printf("Extracting workflow_dispatch inputs from: %s", workflowPath) - cleanPath := filepath.Clean(workflowPath) - if !filepath.IsAbs(cleanPath) { - return nil, fmt.Errorf("workflow path must be absolute: %s", workflowPath) - } - - workflowContent, err := os.ReadFile(cleanPath) // #nosec G304 -- Path is sanitized above + workflow, err := readWorkflowYAML(workflowPath) if err != nil { - return nil, fmt.Errorf("failed to read workflow file %s: %w", workflowPath, err) - } - - var workflow map[string]any - if err := yaml.Unmarshal(workflowContent, &workflow); err != nil { - return nil, fmt.Errorf("failed to parse workflow file %s: %w", workflowPath, err) + return nil, err } onSection, hasOn := workflow["on"] diff --git a/pkg/workflow/imports.go b/pkg/workflow/imports.go index 5220177f92c..8042aa8da23 100644 --- a/pkg/workflow/imports.go +++ b/pkg/workflow/imports.go @@ -10,6 +10,7 @@ import ( "github.com/github/gh-aw/pkg/logger" "github.com/github/gh-aw/pkg/parser" + "github.com/github/gh-aw/pkg/sliceutil" ) var importsLog = logger.New("workflow:imports") @@ -247,7 +248,7 @@ func (c *Compiler) MergeSafeOutputs(topSafeOutputs *SafeOutputsConfig, importedS if handlerCfg, ok := config[key].(map[string]any); ok { if pf, ok := handlerCfg["protected-files"].(map[string]any); ok { if excludeFiles := parseStringSliceAny(pf["exclude"], 
importsLog); len(excludeFiles) > 0 { - accumulatedExclude[key] = mergeUnique(accumulatedExclude[key], excludeFiles...) + accumulatedExclude[key] = sliceutil.MergeUnique(accumulatedExclude[key], excludeFiles...) importsLog.Printf("Saved protected-files exclude from overridden import %s: %v", key, excludeFiles) } } @@ -294,7 +295,7 @@ func (c *Compiler) MergeSafeOutputs(topSafeOutputs *SafeOutputsConfig, importedS if len(accumulatedExclude) > 0 { if result.CreatePullRequests != nil { if excludeFiles, ok := accumulatedExclude["create-pull-request"]; ok && len(excludeFiles) > 0 { - result.CreatePullRequests.ProtectedFilesExclude = mergeUnique( + result.CreatePullRequests.ProtectedFilesExclude = sliceutil.MergeUnique( result.CreatePullRequests.ProtectedFilesExclude, excludeFiles..., ) @@ -303,7 +304,7 @@ func (c *Compiler) MergeSafeOutputs(topSafeOutputs *SafeOutputsConfig, importedS } if result.PushToPullRequestBranch != nil { if excludeFiles, ok := accumulatedExclude["push-to-pull-request-branch"]; ok && len(excludeFiles) > 0 { - result.PushToPullRequestBranch.ProtectedFilesExclude = mergeUnique( + result.PushToPullRequestBranch.ProtectedFilesExclude = sliceutil.MergeUnique( result.PushToPullRequestBranch.ProtectedFilesExclude, excludeFiles..., ) @@ -461,7 +462,7 @@ func mergeSafeOutputConfig(result *SafeOutputsConfig, config map[string]any, c * } else if result.CreatePullRequests != nil && importedConfig.CreatePullRequests != nil { // Merge protected-files exclude lists as a set so that imports can extend exclusions // without replacing the top-level configuration entirely. 
- result.CreatePullRequests.ProtectedFilesExclude = mergeUnique( + result.CreatePullRequests.ProtectedFilesExclude = sliceutil.MergeUnique( result.CreatePullRequests.ProtectedFilesExclude, importedConfig.CreatePullRequests.ProtectedFilesExclude..., ) @@ -516,7 +517,7 @@ func mergeSafeOutputConfig(result *SafeOutputsConfig, config map[string]any, c * } else if result.PushToPullRequestBranch != nil && importedConfig.PushToPullRequestBranch != nil { // Merge protected-files exclude lists as a set so that imports can extend exclusions // without replacing the top-level configuration entirely. - result.PushToPullRequestBranch.ProtectedFilesExclude = mergeUnique( + result.PushToPullRequestBranch.ProtectedFilesExclude = sliceutil.MergeUnique( result.PushToPullRequestBranch.ProtectedFilesExclude, importedConfig.PushToPullRequestBranch.ProtectedFilesExclude..., ) @@ -612,7 +613,7 @@ func mergeSafeOutputConfig(result *SafeOutputsConfig, config map[string]any, c * result.RunsOn = importedConfig.RunsOn } if len(importedConfig.Needs) > 0 { - result.Needs = mergeUnique(result.Needs, importedConfig.Needs...) + result.Needs = sliceutil.MergeUnique(result.Needs, importedConfig.Needs...) } // Merge Messages configuration at field level (main workflow entries override imported entries) diff --git a/pkg/workflow/runtime_definitions.go b/pkg/workflow/runtime_definitions.go index 2e9e40850a1..b7fc67637d3 100644 --- a/pkg/workflow/runtime_definitions.go +++ b/pkg/workflow/runtime_definitions.go @@ -3,6 +3,7 @@ package workflow import ( "github.com/github/gh-aw/pkg/constants" "github.com/github/gh-aw/pkg/logger" + "github.com/github/gh-aw/pkg/sliceutil" ) var runtimeDefLog = logger.New("workflow:runtime_definitions") @@ -205,7 +206,7 @@ func getAllManifestFiles(extra ...string) []string { files = append(files, runtime.ManifestFiles...) } files = append(files, securityConfigFiles...) - return mergeUnique(files, extra...) + return sliceutil.MergeUnique(files, extra...) 
} // getProtectedPathPrefixes returns non-dot path prefixes (relative to repo root) @@ -223,7 +224,7 @@ func getProtectedPathPrefixes(extra ...string) []string { nonDot = append(nonDot, p) } } - return mergeUnique(nil, nonDot...) + return sliceutil.MergeUnique([]string(nil), nonDot...) } // getDotFolderExcludes returns the subset of excludeFiles that are top-level @@ -242,46 +243,6 @@ func getDotFolderExcludes(excludeFiles []string) []string { return result } -// excludeFromSlice returns a new slice containing the items from base -// that do not appear in the exclude set. Order of remaining items is preserved. -// Always returns a fresh slice (never aliases base) even when no items are removed. -func excludeFromSlice(base []string, exclude ...string) []string { - if len(exclude) == 0 { - return append([]string(nil), base...) - } - excluded := make(map[string]bool, len(exclude)) - for _, v := range exclude { - excluded[v] = true - } - result := make([]string, 0, len(base)) - for _, v := range base { - if !excluded[v] { - result = append(result, v) - } - } - return result -} - -// mergeUnique returns a deduplicated slice that starts with base and appends any -// items from extra that are not already present in base. Order is preserved. 
-func mergeUnique(base []string, extra ...string) []string { - seen := make(map[string]bool, len(base)+len(extra)) - result := make([]string, 0, len(base)+len(extra)) - for _, v := range base { - if !seen[v] { - seen[v] = true - result = append(result, v) - } - } - for _, v := range extra { - if !seen[v] { - seen[v] = true - result = append(result, v) - } - } - return result -} - // findRuntimeByID finds a runtime configuration by its ID func findRuntimeByID(id string) *Runtime { runtimeDefLog.Printf("Finding runtime by ID: %s", id) diff --git a/pkg/workflow/safe_outputs_config_generation.go b/pkg/workflow/safe_outputs_config_generation.go index 532fb0af532..96299d1f617 100644 --- a/pkg/workflow/safe_outputs_config_generation.go +++ b/pkg/workflow/safe_outputs_config_generation.go @@ -6,6 +6,7 @@ import ( "sort" "strings" + "github.com/github/gh-aw/pkg/sliceutil" "github.com/github/gh-aw/pkg/stringutil" ) @@ -45,8 +46,8 @@ func generateSafeOutputsConfig(data *WorkflowData) (string, error) { if _, hasProtectedFiles := handlerCfg["protected_files"]; hasProtectedFiles { fullManifestFiles := getAllManifestFiles(engineManifestFiles...) fullPathPrefixes := getProtectedPathPrefixes(engineManifestPathPrefixes...) - handlerCfg["protected_files"] = excludeFromSlice(fullManifestFiles, excludeFiles...) - filteredPrefixes := excludeFromSlice(fullPathPrefixes, excludeFiles...) + handlerCfg["protected_files"] = sliceutil.Exclude(fullManifestFiles, excludeFiles...) + filteredPrefixes := sliceutil.Exclude(fullPathPrefixes, excludeFiles...) 
if len(filteredPrefixes) > 0 { handlerCfg["protected_path_prefixes"] = filteredPrefixes } else { diff --git a/pkg/workflow/strings.go b/pkg/workflow/strings.go index 231df47694b..f1718d73188 100644 --- a/pkg/workflow/strings.go +++ b/pkg/workflow/strings.go @@ -85,7 +85,6 @@ import ( "encoding/hex" "fmt" "regexp" - "slices" "strings" "github.com/github/gh-aw/pkg/logger" @@ -94,124 +93,12 @@ import ( var stringsLog = logger.New("workflow:strings") -var multipleHyphens = regexp.MustCompile(`-+`) - // SanitizeOptions configures the behavior of the SanitizeName function. -type SanitizeOptions struct { - // PreserveSpecialChars is a list of special characters to preserve during sanitization. - // Common characters include '.', '_'. If nil or empty, only alphanumeric and hyphens are preserved. - PreserveSpecialChars []rune - - // TrimHyphens controls whether leading and trailing hyphens are removed from the result. - // When true, hyphens at the start and end of the sanitized name are trimmed. - TrimHyphens bool - - // DefaultValue is returned when the sanitized name is empty after all transformations. - // If empty string, no default is applied. - DefaultValue string -} +type SanitizeOptions = stringutil.SanitizeOptions // SanitizeName sanitizes a string for use as an identifier, file name, or similar context. -// It provides configurable behavior through the SanitizeOptions parameter. 
-// -// The function performs the following transformations: -// - Converts to lowercase -// - Replaces common separators (colons, slashes, backslashes, spaces) with hyphens -// - Replaces underscores with hyphens unless preserved in opts.PreserveSpecialChars -// - Removes or replaces characters based on opts.PreserveSpecialChars -// - Consolidates multiple consecutive hyphens into a single hyphen -// - Optionally trims leading/trailing hyphens (controlled by opts.TrimHyphens) -// - Returns opts.DefaultValue if the result is empty (controlled by opts.DefaultValue) -// -// Example: -// -// // Preserve dots and underscores (like SanitizeWorkflowName) -// opts := &SanitizeOptions{ -// PreserveSpecialChars: []rune{'.', '_'}, -// } -// SanitizeName("My.Workflow_Name", opts) // returns "my.workflow_name" -// -// // Trim hyphens and use default (like SanitizeIdentifier) -// opts := &SanitizeOptions{ -// TrimHyphens: true, -// DefaultValue: "default-name", -// } -// SanitizeName("@@@", opts) // returns "default-name" func SanitizeName(name string, opts *SanitizeOptions) string { - if stringsLog.Enabled() { - preserveCount := 0 - trimHyphens := false - if opts != nil { - preserveCount = len(opts.PreserveSpecialChars) - trimHyphens = opts.TrimHyphens - } - stringsLog.Printf("Sanitizing name: input=%q, preserve_chars=%d, trim_hyphens=%t", - name, preserveCount, trimHyphens) - } - - // Handle nil options - if opts == nil { - opts = &SanitizeOptions{} - } - - // Convert to lowercase - result := strings.ToLower(name) - - // Replace common separators with hyphens - result = strings.ReplaceAll(result, ":", "-") - result = strings.ReplaceAll(result, "\\", "-") - result = strings.ReplaceAll(result, "/", "-") - result = strings.ReplaceAll(result, " ", "-") - - // Check if underscores should be preserved - preserveUnderscore := slices.Contains(opts.PreserveSpecialChars, '_') - - // Replace underscores with hyphens if not preserved - if !preserveUnderscore { - result = 
strings.ReplaceAll(result, "_", "-") - } - - // Build character preservation pattern based on options - var preserveChars strings.Builder - preserveChars.WriteString("a-z0-9-") // Always preserve alphanumeric and hyphens - if len(opts.PreserveSpecialChars) > 0 { - for _, char := range opts.PreserveSpecialChars { - // Escape special regex characters - switch char { - case '.', '_': - preserveChars.WriteRune(char) - } - } - } - - // Create pattern for characters to remove/replace - pattern := regexp.MustCompile(`[^` + preserveChars.String() + `]+`) - - // Replace unwanted characters with hyphens or empty based on context - if len(opts.PreserveSpecialChars) > 0 { - // Replace with hyphens (SanitizeWorkflowName behavior) - result = pattern.ReplaceAllString(result, "-") - } else { - // Remove completely (SanitizeIdentifier behavior) - result = pattern.ReplaceAllString(result, "") - } - - // Consolidate multiple consecutive hyphens into a single hyphen - result = multipleHyphens.ReplaceAllString(result, "-") - - // Optionally trim leading/trailing hyphens - if opts.TrimHyphens { - result = strings.Trim(result, "-") - } - - // Return default value if result is empty - if result == "" && opts.DefaultValue != "" { - stringsLog.Printf("Sanitized name is empty, using default: %q", opts.DefaultValue) - return opts.DefaultValue - } - - stringsLog.Printf("Sanitized name result: %q", result) - return result + return stringutil.SanitizeName(name, opts) } // SanitizeWorkflowName sanitizes a workflow name for use in artifact names and file paths. 
diff --git a/pkg/workflow/yaml.go b/pkg/workflow/yaml.go index d7644742c6f..456487418e7 100644 --- a/pkg/workflow/yaml.go +++ b/pkg/workflow/yaml.go @@ -85,6 +85,8 @@ package workflow import ( "fmt" + "os" + "path/filepath" "regexp" "slices" "sort" @@ -104,6 +106,29 @@ var yamlNullPattern = regexp.MustCompile(`:\s*null\s*$`) // unquoteYAMLKeyCache caches compiled regexes for UnquoteYAMLKey by key name var unquoteYAMLKeyCache sync.Map +// readWorkflowYAML reads and parses a trusted workflow YAML file path. +// The caller is responsible for repository-boundary validation (for example via +// findWorkflowFile/isPathWithinDir) before passing workflowPath. +func readWorkflowYAML(workflowPath string) (map[string]any, error) { + cleanPath := filepath.Clean(workflowPath) + absPath, err := filepath.Abs(cleanPath) + if err != nil { + return nil, fmt.Errorf("failed to resolve workflow path %s: %w", workflowPath, err) + } + + content, err := os.ReadFile(absPath) // #nosec G304 -- Caller provides trusted path, and path is normalized/absolute-resolved above + if err != nil { + return nil, fmt.Errorf("failed to read workflow file %s: %w", workflowPath, err) + } + + var workflow map[string]any + if err := yaml.Unmarshal(content, &workflow); err != nil { + return nil, fmt.Errorf("failed to parse workflow file %s: %w", workflowPath, err) + } + + return workflow, nil +} + // UnquoteYAMLKey removes quotes from a YAML key at the start of a line. 
// // The YAML marshaler automatically adds quotes around YAML reserved words and keywords diff --git a/pkg/workflow/yaml_read_test.go b/pkg/workflow/yaml_read_test.go new file mode 100644 index 00000000000..e9332d5f158 --- /dev/null +++ b/pkg/workflow/yaml_read_test.go @@ -0,0 +1,64 @@ +//go:build !integration + +package workflow + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReadWorkflowYAML(t *testing.T) { + t.Run("reads valid workflow yaml from absolute path", func(t *testing.T) { + tmpDir := t.TempDir() + workflowPath := filepath.Join(tmpDir, "workflow.yml") + content := "on:\n workflow_call: {}\n" + require.NoError(t, os.WriteFile(workflowPath, []byte(content), 0644), "Should write workflow file") + + workflow, err := readWorkflowYAML(workflowPath) + require.NoError(t, err, "Should read workflow YAML without error") + require.NotNil(t, workflow, "Should return parsed workflow map") + + onSection, ok := workflow["on"].(map[string]any) + require.True(t, ok, "Parsed workflow should contain on map") + _, hasWorkflowCall := onSection["workflow_call"] + assert.True(t, hasWorkflowCall, "Parsed workflow should include workflow_call trigger") + }) + + t.Run("reads relative path by resolving to absolute path", func(t *testing.T) { + tmpDir := t.TempDir() + workflowPath := filepath.Join(tmpDir, "workflow.yml") + content := "on:\n workflow_dispatch: {}\n" + require.NoError(t, os.WriteFile(workflowPath, []byte(content), 0644), "Should write workflow file") + + cwd, err := os.Getwd() + require.NoError(t, err, "Should get current working directory") + require.NoError(t, os.Chdir(tmpDir), "Should change working directory for relative path test") + t.Cleanup(func() { + require.NoError(t, os.Chdir(cwd), "Should restore working directory") + }) + + workflow, err := readWorkflowYAML("workflow.yml") + require.NoError(t, err, "Relative path should resolve and parse successfully") + 
require.NotNil(t, workflow, "Relative paths should return workflow data") + onSection, ok := workflow["on"].(map[string]any) + require.True(t, ok, "Parsed workflow should contain on map") + _, hasWorkflowDispatch := onSection["workflow_dispatch"] + assert.True(t, hasWorkflowDispatch, "Parsed workflow should include workflow_dispatch trigger") + }) + + t.Run("returns parse error for invalid yaml", func(t *testing.T) { + tmpDir := t.TempDir() + workflowPath := filepath.Join(tmpDir, "invalid.yml") + invalid := "on:\n workflow_call: [\n" + require.NoError(t, os.WriteFile(workflowPath, []byte(invalid), 0644), "Should write invalid workflow file") + + workflow, err := readWorkflowYAML(workflowPath) + assert.Nil(t, workflow, "Invalid YAML should not return workflow data") + require.Error(t, err, "Invalid YAML should return an error") + assert.Contains(t, err.Error(), "failed to parse workflow file", "Should wrap parse error consistently") + }) +}