diff --git a/templates/templates.go b/templates/templates.go index 4af4496d19a1..ea381a5f68ce 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -6,6 +6,9 @@ package templates import ( "bytes" "encoding/json" + "fmt" + "reflect" + "sort" "strings" "text/template" ) @@ -26,7 +29,7 @@ var basicFunctions = template.FuncMap{ return strings.TrimSpace(buf.String()) }, "split": strings.Split, - "join": strings.Join, + "join": joinElements, "title": strings.Title, //nolint:nolintlint,staticcheck // strings.Title is deprecated, but we only use it for ASCII, so replacing with golang.org/x/text is out of scope "lower": strings.ToLower, "upper": strings.ToUpper, @@ -103,3 +106,40 @@ func truncateWithLength(source string, length int) string { } return source[:length] } + +// joinElements joins a slice of items with the given separator. It uses +// [strings.Join] if it's a slice of strings, otherwise uses [fmt.Sprint] +// to join each item to the output. +func joinElements(elems any, sep string) (string, error) { + if elems == nil { + return "", nil + } + + if ss, ok := elems.([]string); ok { + return strings.Join(ss, sep), nil + } + + switch rv := reflect.ValueOf(elems); rv.Kind() { //nolint:exhaustive // ignore: too many options to make exhaustive + case reflect.Array, reflect.Slice: + var b strings.Builder + for i := range rv.Len() { + if i > 0 { + b.WriteString(sep) + } + _, _ = fmt.Fprint(&b, rv.Index(i).Interface()) + } + return b.String(), nil + + case reflect.Map: + var out []string + for _, k := range rv.MapKeys() { + out = append(out, fmt.Sprint(rv.MapIndex(k).Interface())) + } + // Not ideal, but trying to keep a consistent order + sort.Strings(out) + return strings.Join(out, sep), nil + + default: + return "", fmt.Errorf("expected slice, got %T", elems) + } +} diff --git a/templates/templates_test.go b/templates/templates_test.go index e9dbaefd0e5e..ed1ee5b95d13 100644 --- a/templates/templates_test.go +++ b/templates/templates_test.go @@ -3,6 +3,7 @@ package templates import ( "bytes" "testing" + "text/template" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" @@ -139,3 +140,92 @@ func TestHeaderFunctions(t *testing.T) { }) } } + +type stringerString string + +func (s stringerString) String() string { + return "stringer" + string(s) +} + +type stringerAndError string + +func (s stringerAndError) String() string { + return "stringer" + string(s) +} + +func (s stringerAndError) Error() string { + return "error" + string(s) +} + +func TestJoinElements(t *testing.T) { + tests := []struct { + doc string + data any + expOut string + expErr string + }{ + { + doc: "nil", + data: nil, + expOut: `output: ""`, + }, + { + doc: "non-slice", + data: "hello", + expOut: `output: "`, + expErr: `error calling join: expected slice, got string`, + }, + { + doc: "structs", + data: []struct{ A, B string }{{"1", "2"}, {"3", "4"}}, + expOut: `output: "{1 2}, {3 4}"`, + }, + { + doc: "map with strings", + data: map[string]string{"A": "1", "B": "2", "C": "3"}, + expOut: `output: "1, 2, 3"`, + }, + { + doc: "map with stringers", + data: map[string]stringerString{"A": "1", "B": "2", "C": "3"}, + expOut: `output: "stringer1, stringer2, stringer3"`, + }, + { + doc: "map with errors", + data: []stringerAndError{"1", "2", "3"}, + expOut: `output: "error1, error2, error3"`, + }, + { + doc: "stringers", + data: []stringerString{"1", "2", "3"}, + expOut: `output: "stringer1, stringer2, stringer3"`, + }, + { + doc: "stringer with errors", + data: []stringerAndError{"1", "2", "3"}, + expOut: `output: "error1, error2, error3"`, + }, + { + doc: "slice of bools", + data: []bool{true, false, true}, + expOut: `output: "true, false, true"`, + }, + } + + const formatStr = `output: "{{- join . ", " -}}"` + tmpl, err := New("my-template").Funcs(template.FuncMap{"join": joinElements}).Parse(formatStr) + assert.NilError(t, err) + + for _, tc := range tests { + t.Run(tc.doc, func(t *testing.T) { + var b bytes.Buffer + err := tmpl.Execute(&b, tc.data) + if tc.expErr != "" { + assert.ErrorContains(t, err, tc.expErr) + } else { + assert.NilError(t, err) + } + assert.Equal(t, b.String(), tc.expOut) + }) + } +}