Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pkg/sliceutil/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ All functions in this package are pure: they never modify their input. They are
| `FilterMapKeys` | `func[K comparable, V any](m map[K]V, predicate func(K, V) bool) []K` | Returns map keys for which `predicate(key, value)` is `true`; order is not guaranteed |
| `Any` | `func[T any](slice []T, predicate func(T) bool) bool` | Returns `true` if at least one element satisfies `predicate`; returns `false` for nil or empty slices |
| `Deduplicate` | `func[T comparable](slice []T) []T` | Returns a new slice with duplicate elements removed, preserving order of first occurrence |
| `MergeUnique` | `func[T comparable](base []T, extra ...T) []T` | Returns a deduplicated slice starting with `base` and appending unseen values from `extra` |
| `Exclude` | `func[T comparable](base []T, exclude ...T) []T` | Returns a new slice with all `exclude` values removed while preserving order |

## Usage Examples

Expand All @@ -38,12 +40,20 @@ upper := sliceutil.Map(names, strings.ToUpper)
items := []string{"a", "b", "a", "c"}
unique := sliceutil.Deduplicate(items)
// unique = ["a", "b", "c"]

// Merge unique values
merged := sliceutil.MergeUnique([]string{"a", "b"}, "b", "c")
// merged = ["a", "b", "c"]

// Exclude values
filtered := sliceutil.Exclude([]string{"a", "b", "c"}, "b")
// filtered = ["a", "c"]
```

## Design Notes

- `Any` is implemented via `slices.ContainsFunc` from the standard library.
- `Deduplicate` uses a `map[T]bool` for O(n) time complexity.
- `Deduplicate`, `MergeUnique`, and `Exclude` use hash sets (`map[T]struct{}`) for O(n) behavior.
- None of these functions sort their output; callers that require sorted results should call `slices.Sort` on the returned slice.

---
Expand Down
42 changes: 42 additions & 0 deletions pkg/sliceutil/sliceutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,45 @@ func Deduplicate[T comparable](slice []T) []T {
}
return result
}

// MergeUnique returns a deduplicated slice that starts with base and appends any
// items from extra that are not already present in base. Order is preserved.
func MergeUnique[T comparable](base []T, extra ...T) []T {
seen := make(map[T]struct{}, len(base)+len(extra))
result := make([]T, 0, len(base)+len(extra))
for _, item := range base {
if _, exists := seen[item]; !exists {
seen[item] = struct{}{}
result = append(result, item)
}
}
for _, item := range extra {
if _, exists := seen[item]; !exists {
seen[item] = struct{}{}
result = append(result, item)
}
}
return result
}

// Exclude returns a new slice containing the items from base that do not appear
// in the exclude set. Order of remaining items is preserved.
// Always returns a fresh slice (never aliases base) even when no items are removed.
func Exclude[T comparable](base []T, exclude ...T) []T {
if len(exclude) == 0 {
return append([]T(nil), base...)
}

excluded := make(map[T]struct{}, len(exclude))
for _, item := range exclude {
excluded[item] = struct{}{}
}

result := make([]T, 0, len(base))
for _, item := range base {
if _, isExcluded := excluded[item]; !isExcluded {
result = append(result, item)
}
}
return result
}
61 changes: 61 additions & 0 deletions pkg/sliceutil/sliceutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,64 @@ func TestAny_StopsEarly(t *testing.T) {
})
assert.Equal(t, 2, callCount, "Any should stop evaluating after first match")
}

func TestMergeUnique(t *testing.T) {
tests := []struct {
name string
base []string
extra []string
expected []string
}{
{
name: "deduplicates base and extra preserving first seen order",
base: []string{"a", "b", "a"},
extra: []string{"b", "c", "a", "d"},
expected: []string{"a", "b", "c", "d"},
},
{
name: "nil base with extra values",
base: nil,
extra: []string{"x", "x", "y"},
expected: []string{"x", "y"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := MergeUnique(tt.base, tt.extra...)
assert.Equal(t, tt.expected, result, "MergeUnique should return deduplicated merged slice")
})
}
}

func TestExclude(t *testing.T) {
tests := []struct {
name string
base []string
exclude []string
expected []string
}{
{
name: "excludes matching values while preserving order",
base: []string{"a", "b", "c", "b"},
exclude: []string{"b"},
expected: []string{"a", "c"},
},
{
name: "no excludes returns cloned slice",
base: []string{"a", "b"},
exclude: nil,
expected: []string{"a", "b"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := Exclude(tt.base, tt.exclude...)
assert.Equal(t, tt.expected, result, "Exclude should remove excluded elements")
if len(tt.exclude) == 0 && len(tt.base) > 0 {
assert.NotSame(t, &tt.base[0], &result[0], "Exclude should always return a fresh slice copy")
}
})
}
}
97 changes: 97 additions & 0 deletions pkg/stringutil/sanitize.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ package stringutil

import (
"regexp"
"slices"
"strings"

"github.com/github/gh-aw/pkg/logger"
)

var sanitizeLog = logger.New("stringutil:sanitize")

var multipleHyphens = regexp.MustCompile(`-+`)

// Regex patterns for detecting potential secret key names
var (
// Match uppercase snake_case identifiers that look like secret names (e.g., MY_SECRET_KEY, GITHUB_TOKEN, API_KEY)
Expand Down Expand Up @@ -46,6 +49,100 @@ var (
}
)

// SanitizeOptions configures the behavior of the SanitizeName function.
type SanitizeOptions struct {
// PreserveSpecialChars is a list of special characters to preserve during sanitization.
// Common characters include '.', '_'. If nil or empty, only alphanumeric and hyphens are preserved.
PreserveSpecialChars []rune

// TrimHyphens controls whether leading and trailing hyphens are removed from the result.
// When true, hyphens at the start and end of the sanitized name are trimmed.
TrimHyphens bool

// DefaultValue is returned when the sanitized name is empty after all transformations.
// If empty string, no default is applied.
DefaultValue string
}

// SanitizeName sanitizes a string for use as an identifier, file name, or similar context.
// It provides configurable behavior through the SanitizeOptions parameter.
func SanitizeName(name string, opts *SanitizeOptions) string {
if sanitizeLog.Enabled() {
preserveCount := 0
trimHyphens := false
if opts != nil {
preserveCount = len(opts.PreserveSpecialChars)
trimHyphens = opts.TrimHyphens
}
sanitizeLog.Printf("Sanitizing name: input=%q, preserve_chars=%d, trim_hyphens=%t",
name, preserveCount, trimHyphens)
}

// Handle nil options
if opts == nil {
opts = &SanitizeOptions{}
}

// Convert to lowercase
result := strings.ToLower(name)

// Replace common separators with hyphens
result = strings.ReplaceAll(result, ":", "-")
result = strings.ReplaceAll(result, "\\", "-")
result = strings.ReplaceAll(result, "/", "-")
result = strings.ReplaceAll(result, " ", "-")

// Check if underscores should be preserved
preserveUnderscore := slices.Contains(opts.PreserveSpecialChars, '_')

// Replace underscores with hyphens if not preserved
if !preserveUnderscore {
result = strings.ReplaceAll(result, "_", "-")
}

// Build character preservation pattern based on options
var preserveChars strings.Builder
preserveChars.WriteString("a-z0-9-") // Always preserve alphanumeric and hyphens
if len(opts.PreserveSpecialChars) > 0 {
for _, char := range opts.PreserveSpecialChars {
// Escape special regex characters
switch char {
case '.', '_':
preserveChars.WriteRune(char)
}
}
}

// Create pattern for characters to remove/replace
pattern := regexp.MustCompile(`[^` + preserveChars.String() + `]+`)

// Replace unwanted characters with hyphens or empty based on context
if len(opts.PreserveSpecialChars) > 0 {
// Replace with hyphens (SanitizeWorkflowName behavior)
result = pattern.ReplaceAllString(result, "-")
} else {
// Remove completely (SanitizeIdentifier behavior)
result = pattern.ReplaceAllString(result, "")
}

// Consolidate multiple consecutive hyphens into a single hyphen
result = multipleHyphens.ReplaceAllString(result, "-")

// Optionally trim leading/trailing hyphens
if opts.TrimHyphens {
result = strings.Trim(result, "-")
}

// Return default value if result is empty
if result == "" && opts.DefaultValue != "" {
sanitizeLog.Printf("Sanitized name is empty, using default: %q", opts.DefaultValue)
return opts.DefaultValue
}

sanitizeLog.Printf("Sanitized name result: %q", result)
return result
}

// SanitizeErrorMessage removes potential secret key names from error messages to prevent
// information disclosure via logs. This prevents exposing details about an organization's
// security infrastructure by redacting secret key names that might appear in error messages.
Expand Down
42 changes: 42 additions & 0 deletions pkg/stringutil/sanitize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -668,3 +668,45 @@ func BenchmarkSanitizeForFilename(b *testing.B) {
SanitizeForFilename(slug)
}
}

func TestSanitizeName(t *testing.T) {
tests := []struct {
name string
input string
opts *SanitizeOptions
expected string
}{
{
name: "nil options remove special chars",
input: "My Workflow@123",
opts: nil,
expected: "my-workflow123",
},
{
name: "preserve dot and underscore",
input: "My.Workflow_Name",
opts: &SanitizeOptions{
PreserveSpecialChars: []rune{'.', '_'},
},
expected: "my.workflow_name",
},
{
name: "trim and default when empty",
input: "@@@",
opts: &SanitizeOptions{
TrimHyphens: true,
DefaultValue: "default-name",
},
expected: "default-name",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeName(tt.input, tt.opts)
if result != tt.expected {
t.Errorf("SanitizeName(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
14 changes: 2 additions & 12 deletions pkg/workflow/call_workflow_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ package workflow
import (
"fmt"
"os"
"path/filepath"

"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/parser"
"github.com/goccy/go-yaml"
)

var callWorkflowPermissionsLog = logger.New("workflow:call_workflow_permissions")
Expand Down Expand Up @@ -85,17 +83,9 @@ func extractCallWorkflowPermissions(workflowName, markdownPath string) (*Permiss
// extractPermissionsFromYAMLFile reads a .lock.yml or .yml workflow file, parses it,
// and returns the merged permissions from all its jobs.
func extractPermissionsFromYAMLFile(filePath string) (*Permissions, error) {
cleanPath := filepath.Clean(filePath)
// filePath originates from findWorkflowFile(), which validates all paths via
// isPathWithinDir() to prevent directory traversal before returning them.
content, err := os.ReadFile(cleanPath) // #nosec G304 -- path pre-validated by findWorkflowFile() via isPathWithinDir()
workflow, err := readWorkflowYAML(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read workflow file %s: %w", filePath, err)
}

var workflow map[string]any
if err := yaml.Unmarshal(content, &workflow); err != nil {
return nil, fmt.Errorf("failed to parse workflow file %s: %w", filePath, err)
return nil, err
}

perms := extractJobPermissionsFromParsedWorkflow(workflow)
Expand Down
15 changes: 2 additions & 13 deletions pkg/workflow/call_workflow_secrets.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@ package workflow

import (
"fmt"
"os"
"path/filepath"
"sort"

"github.com/github/gh-aw/pkg/logger"
"github.com/goccy/go-yaml"
)

var callWorkflowSecretsLog = logger.New("workflow:call_workflow_secrets")
Expand Down Expand Up @@ -45,17 +42,9 @@ func extractCallWorkflowSecrets(workflowName, markdownPath string) ([]string, er
// extractSecretsFromWorkflowFile parses a .lock.yml or .yml workflow file and returns
// the secret names declared in its on.workflow_call.secrets section.
func extractSecretsFromWorkflowFile(filePath string) ([]string, error) {
cleanPath := filepath.Clean(filePath)
// filePath originates from findWorkflowFile(), which validates all paths via
// isPathWithinDir() to prevent directory traversal before returning them.
content, err := os.ReadFile(cleanPath) // #nosec G304 -- path pre-validated by findWorkflowFile() via isPathWithinDir()
workflow, err := readWorkflowYAML(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read workflow file %s: %w", filePath, err)
}

var workflow map[string]any
if err := yaml.Unmarshal(content, &workflow); err != nil {
return nil, fmt.Errorf("failed to parse workflow file %s: %w", filePath, err)
return nil, err
}

secrets := extractWorkflowCallSecretsFromParsed(workflow)
Expand Down
14 changes: 2 additions & 12 deletions pkg/workflow/call_workflow_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,9 @@ func (c *Compiler) validateCallWorkflow(data *WorkflowData, workflowPath string)
// extractWorkflowCallInputs parses a workflow file and extracts the workflow_call inputs schema.
// Returns a map of input definitions that can be used to generate MCP tool schemas.
func extractWorkflowCallInputs(workflowPath string) (map[string]any, error) {
cleanPath := filepath.Clean(workflowPath)
if !filepath.IsAbs(cleanPath) {
return nil, fmt.Errorf("workflow path must be absolute: %s", workflowPath)
}

workflowContent, err := os.ReadFile(cleanPath) // #nosec G304 -- Path is sanitized above
workflow, err := readWorkflowYAML(workflowPath)
if err != nil {
return nil, fmt.Errorf("failed to read workflow file %s: %w", workflowPath, err)
}

var workflow map[string]any
if err := yaml.Unmarshal(workflowContent, &workflow); err != nil {
return nil, fmt.Errorf("failed to parse workflow file %s: %w", workflowPath, err)
return nil, err
}

return extractWorkflowCallInputsFromParsed(workflow), nil
Expand Down
Loading
Loading