Skip to content

Commit 479a210

Browse files
Copilotpelikhan
andauthored
Harden go-gh remote fetch callsites with escaped contents paths/refs and bounded REST clients (#41036)
* Initial plan * Fix go-gh content ref escaping and REST client timeout options Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> * Address review feedback on path and host normalization Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> * Simplify secret API host normalization fallback Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> Co-authored-by: Peli de Halleux <pelikhan@users.noreply.github.com>
1 parent 46ed40e commit 479a210

6 files changed

Lines changed: 153 additions & 51 deletions

File tree

pkg/cli/secret_set_command.go

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@ import (
77
"errors"
88
"fmt"
99
"io"
10+
"net/url"
1011
"os"
1112
"strings"
1213

1314
"github.com/cli/go-gh/v2/pkg/api"
1415
"github.com/github/gh-aw/pkg/console"
16+
"github.com/github/gh-aw/pkg/constants"
1517
"github.com/github/gh-aw/pkg/logger"
1618
"github.com/github/gh-aw/pkg/repoutil"
19+
"github.com/github/gh-aw/pkg/stringutil"
1720
"github.com/github/gh-aw/pkg/tty"
1821
"github.com/spf13/cobra"
1922
"golang.org/x/crypto/nacl/box"
@@ -88,11 +91,7 @@ The secret value can be provided in three ways:
8891
}
8992

9093
// Create GitHub REST client using go-gh
91-
opts := api.ClientOptions{}
92-
if flagAPIBase != "" {
93-
opts.Host = strings.TrimPrefix(strings.TrimPrefix(flagAPIBase, "https://"), "http://")
94-
}
95-
client, err := api.NewRESTClient(opts)
94+
client, err := api.NewRESTClient(secretSetClientOptions(flagAPIBase))
9695
if err != nil {
9796
return fmt.Errorf("cannot create GitHub client: %w", err)
9897
}
@@ -123,6 +122,44 @@ The secret value can be provided in three ways:
123122
return cmd
124123
}
125124

125+
func secretSetClientOptions(apiBase string) api.ClientOptions {
126+
opts := api.ClientOptions{
127+
Timeout: constants.DefaultHTTPClientTimeout,
128+
}
129+
if apiBase != "" {
130+
opts.Host = normalizeSecretSetAPIHost(apiBase)
131+
}
132+
return opts
133+
}
134+
135+
func normalizeSecretSetAPIHost(apiBase string) string {
136+
raw := strings.TrimSpace(apiBase)
137+
if raw == "" {
138+
return ""
139+
}
140+
141+
candidates := []string{raw}
142+
if !strings.Contains(raw, "://") {
143+
candidates = append(candidates, "https://"+raw)
144+
}
145+
for _, candidate := range candidates {
146+
parsed, err := url.Parse(candidate)
147+
if err != nil || parsed.Hostname() == "" {
148+
continue
149+
}
150+
if parsed.Hostname() == "api.github.com" {
151+
return "github.com"
152+
}
153+
return parsed.Hostname()
154+
}
155+
156+
legacy := stringutil.ExtractDomainFromURL(raw)
157+
if legacy == "api.github.com" {
158+
return "github.com"
159+
}
160+
return legacy
161+
}
162+
126163
func resolveSecretValueForSet(fromEnv, fromFlag string) (string, error) {
127164
if fromEnv != "" {
128165
v := os.Getenv(fromEnv)

pkg/cli/secret_set_command_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"strings"
99
"testing"
1010

11+
"github.com/github/gh-aw/pkg/constants"
1112
"golang.org/x/crypto/nacl/box"
1213
)
1314

@@ -165,3 +166,45 @@ func TestEncryptDecryptRoundTrip(t *testing.T) {
165166
t.Errorf("decrypted = %q, want %q", string(decrypted), plaintext)
166167
}
167168
}
169+
170+
func TestSecretSetClientOptions(t *testing.T) {
171+
t.Run("defaults include timeout", func(t *testing.T) {
172+
opts := secretSetClientOptions("")
173+
if opts.Host != "" {
174+
t.Fatalf("expected empty host, got %q", opts.Host)
175+
}
176+
if opts.Timeout != constants.DefaultHTTPClientTimeout {
177+
t.Fatalf("expected timeout %s, got %s", constants.DefaultHTTPClientTimeout, opts.Timeout)
178+
}
179+
})
180+
181+
t.Run("normalizes host when api-url is provided", func(t *testing.T) {
182+
opts := secretSetClientOptions("https://ghe.example.com")
183+
if opts.Host != "ghe.example.com" {
184+
t.Fatalf("expected host ghe.example.com, got %q", opts.Host)
185+
}
186+
if opts.Timeout != constants.DefaultHTTPClientTimeout {
187+
t.Fatalf("expected timeout %s, got %s", constants.DefaultHTTPClientTimeout, opts.Timeout)
188+
}
189+
})
190+
191+
t.Run("strips API path from api-url", func(t *testing.T) {
192+
opts := secretSetClientOptions("https://ghe.example.com/api/v3")
193+
if opts.Host != "ghe.example.com" {
194+
t.Fatalf("expected host ghe.example.com, got %q", opts.Host)
195+
}
196+
if opts.Timeout != constants.DefaultHTTPClientTimeout {
197+
t.Fatalf("expected timeout %s, got %s", constants.DefaultHTTPClientTimeout, opts.Timeout)
198+
}
199+
})
200+
201+
t.Run("maps api.github.com to github.com", func(t *testing.T) {
202+
opts := secretSetClientOptions("https://api.github.com")
203+
if opts.Host != "github.com" {
204+
t.Fatalf("expected host github.com, got %q", opts.Host)
205+
}
206+
if opts.Timeout != constants.DefaultHTTPClientTimeout {
207+
t.Fatalf("expected timeout %s, got %s", constants.DefaultHTTPClientTimeout, opts.Timeout)
208+
}
209+
})
210+
}

pkg/cli/update_check.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ func getLatestRelease(includePrereleases bool) (string, error) {
239239
// Always target github.com explicitly: gh-aw is only published to github.com,
240240
// and users in mixed-host environments (e.g. a GHE active auth host) must
241241
// still reach the canonical registry to get the correct release metadata.
242-
client, err := api.NewRESTClient(api.ClientOptions{Host: "github.com"})
242+
client, err := api.NewRESTClient(gitHubDotComRESTClientOptions())
243243
if err != nil {
244244
return "", fmt.Errorf("failed to create GitHub client: %w", err)
245245
}
@@ -274,6 +274,13 @@ func getLatestRelease(includePrereleases bool) (string, error) {
274274
return release.TagName, nil
275275
}
276276

277+
func gitHubDotComRESTClientOptions() api.ClientOptions {
278+
return api.ClientOptions{
279+
Host: "github.com",
280+
Timeout: constants.DefaultHTTPClientTimeout,
281+
}
282+
}
283+
277284
// findLatestPublishedReleaseTag returns the first non-draft release tag from the
278285
// releases API response, skipping entries without tag names.
279286
func findLatestPublishedReleaseTag(releases []Release) string {

pkg/cli/update_check_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"testing"
1010
"time"
1111

12+
"github.com/github/gh-aw/pkg/constants"
1213
"github.com/stretchr/testify/assert"
1314
)
1415

@@ -440,3 +441,9 @@ func TestIsCurrentVersionAtLeastLatest(t *testing.T) {
440441
})
441442
}
442443
}
444+
445+
func TestGitHubDotComRESTClientOptions(t *testing.T) {
446+
opts := gitHubDotComRESTClientOptions()
447+
assert.Equal(t, "github.com", opts.Host)
448+
assert.Equal(t, constants.DefaultHTTPClientTimeout, opts.Timeout)
449+
}

pkg/parser/remote_fetch.go

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,7 @@ func downloadFileViaGitClone(owner, repo, path, ref, host string) ([]byte, error
744744
// Returns the symlink target and true if it is a symlink, or empty string and false otherwise.
745745
// A nil error with false means the path is not a symlink (e.g., it's a directory or file).
746746
func checkRemoteSymlink(client *api.RESTClient, owner, repo, dirPath, ref string) (string, bool, error) {
747-
endpoint := fmt.Sprintf("repos/%s/%s/contents/%s?ref=%s", owner, repo, dirPath, ref)
747+
endpoint := buildContentsAPIPath(owner, repo, dirPath, ref)
748748
remoteLog.Printf("Checking if path component is symlink: %s/%s/%s@%s", owner, repo, dirPath, ref)
749749

750750
// The Contents API returns a JSON object for files/symlinks but a JSON array for directories.
@@ -951,14 +951,29 @@ func downloadFileFromGitHubWithDepth(owner, repo, path, ref string, symlinkDepth
951951
}
952952

953953
func createRESTClientForHost(host string) (*api.RESTClient, error) {
954+
opts := api.ClientOptions{Timeout: constants.DefaultHTTPClientTimeout}
954955
if host != "" {
955-
return api.NewRESTClient(api.ClientOptions{Host: host})
956+
opts.Host = host
956957
}
957-
return api.DefaultRESTClient()
958+
return api.NewRESTClient(opts)
959+
}
960+
961+
func buildContentsAPIPath(owner, repo, path, ref string) string {
962+
pathSegments := strings.Split(path, "/")
963+
for i := range pathSegments {
964+
pathSegments[i] = url.PathEscape(pathSegments[i])
965+
}
966+
return fmt.Sprintf(
967+
"repos/%s/%s/contents/%s?ref=%s",
968+
owner,
969+
repo,
970+
strings.Join(pathSegments, "/"),
971+
url.QueryEscape(ref),
972+
)
958973
}
959974

960975
func fetchRemoteFileContent(client *api.RESTClient, owner, repo, path, ref string, fileContent any) error {
961-
return client.Get(fmt.Sprintf("repos/%s/%s/contents/%s?ref=%s", owner, repo, path, ref), fileContent)
976+
return client.Get(buildContentsAPIPath(owner, repo, path, ref), fileContent)
962977
}
963978

964979
// downloadFileViaPublicAPI downloads a file from a public GitHub repository
@@ -1020,16 +1035,7 @@ func ListWorkflowFilesForHost(owner, repo, ref, workflowPath, host string) ([]st
10201035
func listWorkflowFilesForHost(owner, repo, ref, workflowPath, host string) ([]string, error) {
10211036
remoteLog.Printf("Listing workflow files for %s/%s@%s (path: %s)", owner, repo, ref, workflowPath)
10221037

1023-
// Create REST client
1024-
var (
1025-
client *api.RESTClient
1026-
err error
1027-
)
1028-
if host != "" {
1029-
client, err = api.NewRESTClient(api.ClientOptions{Host: host})
1030-
} else {
1031-
client, err = api.DefaultRESTClient()
1032-
}
1038+
client, err := createRESTClientForHost(host)
10331039
if err != nil {
10341040
remoteLog.Printf("Failed to create REST client, attempting git fallback: %v", err)
10351041
return listWorkflowFilesViaGitForHost(owner, repo, ref, workflowPath, host)
@@ -1043,7 +1049,7 @@ func listWorkflowFilesForHost(owner, repo, ref, workflowPath, host string) ([]st
10431049
}
10441050

10451051
// Fetch directory contents from GitHub API
1046-
endpoint := fmt.Sprintf("repos/%s/%s/contents/%s?ref=%s", owner, repo, workflowPath, ref)
1052+
endpoint := buildContentsAPIPath(owner, repo, workflowPath, ref)
10471053
err = client.Get(endpoint, &contents)
10481054
if err != nil {
10491055
errStr := err.Error()
@@ -1088,15 +1094,7 @@ func ListDirAllFilesForHost(owner, repo, ref, dirPath, host string) ([]string, e
10881094
func listDirAllFilesForHost(owner, repo, ref, dirPath, host string) ([]string, error) {
10891095
remoteLog.Printf("Listing all files in dir for %s/%s@%s (path: %s)", owner, repo, ref, dirPath)
10901096

1091-
var (
1092-
client *api.RESTClient
1093-
err error
1094-
)
1095-
if host != "" {
1096-
client, err = api.NewRESTClient(api.ClientOptions{Host: host})
1097-
} else {
1098-
client, err = api.DefaultRESTClient()
1099-
}
1097+
client, err := createRESTClientForHost(host)
11001098
if err != nil {
11011099
remoteLog.Printf("Failed to create REST client, attempting git fallback: %v", err)
11021100
return listDirAllFilesViaGitForHost(owner, repo, ref, dirPath, host)
@@ -1108,7 +1106,7 @@ func listDirAllFilesForHost(owner, repo, ref, dirPath, host string) ([]string, e
11081106
Type string `json:"type"`
11091107
}
11101108

1111-
endpoint := fmt.Sprintf("repos/%s/%s/contents/%s?ref=%s", owner, repo, dirPath, ref)
1109+
endpoint := buildContentsAPIPath(owner, repo, dirPath, ref)
11121110
err = client.Get(endpoint, &contents)
11131111
if err != nil {
11141112
errStr := err.Error()
@@ -1209,15 +1207,7 @@ func ListDirAllFilesRecursivelyForHost(owner, repo, ref, dirPath, host string) (
12091207
func listDirAllFilesRecursivelyForHost(owner, repo, ref, dirPath, host string) ([]string, error) {
12101208
remoteLog.Printf("Listing all files recursively in dir for %s/%s@%s (path: %s)", owner, repo, ref, dirPath)
12111209

1212-
var (
1213-
client *api.RESTClient
1214-
err error
1215-
)
1216-
if host != "" {
1217-
client, err = api.NewRESTClient(api.ClientOptions{Host: host})
1218-
} else {
1219-
client, err = api.DefaultRESTClient()
1220-
}
1210+
client, err := createRESTClientForHost(host)
12211211
if err != nil {
12221212
remoteLog.Printf("Failed to create REST client, attempting git fallback: %v", err)
12231213
return listDirAllFilesRecursivelyViaGitForHost(owner, repo, ref, dirPath, host)
@@ -1262,7 +1252,7 @@ func listContentsRecursivelyWithDepth(client *api.RESTClient, owner, repo, ref,
12621252
Type string `json:"type"`
12631253
}
12641254

1265-
endpoint := fmt.Sprintf("repos/%s/%s/contents/%s?ref=%s", owner, repo, dirPath, ref)
1255+
endpoint := buildContentsAPIPath(owner, repo, dirPath, ref)
12661256
if err := client.Get(endpoint, &contents); err != nil {
12671257
return nil, fmt.Errorf("failed to list dir files from %s/%s (path: %s): %w", owner, repo, dirPath, err)
12681258
}
@@ -1363,15 +1353,7 @@ func ListDirSubdirsForHost(owner, repo, ref, dirPath, host string) ([]string, er
13631353
func listDirSubdirsForHost(owner, repo, ref, dirPath, host string) ([]string, error) {
13641354
remoteLog.Printf("Listing subdirs in %s/%s@%s (path: %s)", owner, repo, ref, dirPath)
13651355

1366-
var (
1367-
client *api.RESTClient
1368-
err error
1369-
)
1370-
if host != "" {
1371-
client, err = api.NewRESTClient(api.ClientOptions{Host: host})
1372-
} else {
1373-
client, err = api.DefaultRESTClient()
1374-
}
1356+
client, err := createRESTClientForHost(host)
13751357
if err != nil {
13761358
remoteLog.Printf("Failed to create REST client, attempting git fallback: %v", err)
13771359
return listDirSubdirsViaGitForHost(owner, repo, ref, dirPath, host)
@@ -1383,7 +1365,7 @@ func listDirSubdirsForHost(owner, repo, ref, dirPath, host string) ([]string, er
13831365
Type string `json:"type"`
13841366
}
13851367

1386-
endpoint := fmt.Sprintf("repos/%s/%s/contents/%s?ref=%s", owner, repo, dirPath, ref)
1368+
endpoint := buildContentsAPIPath(owner, repo, dirPath, ref)
13871369
err = client.Get(endpoint, &contents)
13881370
if err != nil {
13891371
errStr := err.Error()

pkg/parser/remote_fetch_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,32 @@ func TestBuildCommitLookupAPIPath(t *testing.T) {
2525
})
2626
}
2727

28+
func TestBuildContentsAPIPath(t *testing.T) {
29+
t.Run("escapes refs with reserved query chars", func(t *testing.T) {
30+
got := buildContentsAPIPath("owner", "repo", ".github/workflows/demo.md", "release+candidate#1")
31+
want := "repos/owner/repo/contents/.github/workflows/demo.md?ref=release%2Bcandidate%231"
32+
if got != want {
33+
t.Fatalf("buildContentsAPIPath() = %q, want %q", got, want)
34+
}
35+
})
36+
37+
t.Run("keeps plain refs readable", func(t *testing.T) {
38+
got := buildContentsAPIPath("owner", "repo", ".github/workflows/demo.md", "main")
39+
want := "repos/owner/repo/contents/.github/workflows/demo.md?ref=main"
40+
if got != want {
41+
t.Fatalf("buildContentsAPIPath() = %q, want %q", got, want)
42+
}
43+
})
44+
45+
t.Run("escapes path segments with reserved chars", func(t *testing.T) {
46+
got := buildContentsAPIPath("owner", "repo", "skills/path with spaces/file#100%.md", "main")
47+
want := "repos/owner/repo/contents/skills/path%20with%20spaces/file%23100%25.md?ref=main"
48+
if got != want {
49+
t.Fatalf("buildContentsAPIPath() = %q, want %q", got, want)
50+
}
51+
})
52+
}
53+
2854
func TestGitFallbackRequiresNonEmptyRef(t *testing.T) {
2955
t.Run("all files fallback validates ref", func(t *testing.T) {
3056
_, err := listDirAllFilesViaGitForHost("owner", "repo", "", "skills/demo", "")

0 commit comments

Comments
 (0)