diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index bc1ae412f..eba99c506 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -1,8 +1,12 @@ package github import ( + "bytes" "encoding/json" + "fmt" + "io" "net/http" + "strings" "testing" "github.com/mark3labs/mcp-go/mcp" @@ -108,6 +112,74 @@ func mockResponse(t *testing.T, code int, body interface{}) http.HandlerFunc { } } +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func newLockdownMockGQLHTTPClient(t *testing.T, isPrivate bool, permissions map[string]string) *http.Client { + t.Helper() + + lowerPermissions := make(map[string]string, len(permissions)) + for user, perm := range permissions { + lowerPermissions[strings.ToLower(user)] = perm + } + + return &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + require.Equal(t, "/graphql", req.URL.Path) + + bodyBytes, err := io.ReadAll(req.Body) + require.NoError(t, err) + _ = req.Body.Close() + + var gqlRequest struct { + Variables map[string]any `json:"variables"` + } + require.NoError(t, json.Unmarshal(bodyBytes, &gqlRequest)) + + rawUsername, ok := gqlRequest.Variables["username"] + require.True(t, ok, "expected username variable in GraphQL request") + + username := fmt.Sprint(rawUsername) + permission := lowerPermissions[strings.ToLower(username)] + + edges := []any{} + if permission != "" { + edges = append(edges, map[string]any{ + "permission": permission, + "node": map[string]any{ + "login": username, + }, + }) + } + + response := map[string]any{ + "data": map[string]any{ + "repository": map[string]any{ + "isPrivate": isPrivate, + "collaborators": map[string]any{ + "edges": edges, + }, + }, + }, + } + + respBytes, err := json.Marshal(response) + require.NoError(t, err) + + res := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(respBytes)), + } + res.Header.Set("Content-Type", "application/json") + return res, nil + }), + } +} + // createMCPRequest is a helper function to create a MCP request with the given arguments. func createMCPRequest(args any) mcp.CallToolRequest { return mcp.CallToolRequest{ diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 1032d4d04..da8ebf6a9 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -299,9 +299,9 @@ Options are: case "get": return GetIssue(ctx, client, gqlClient, owner, repo, issueNumber, flags) case "get_comments": - return GetIssueComments(ctx, client, owner, repo, issueNumber, pagination, flags) + return GetIssueComments(ctx, client, gqlClient, owner, repo, issueNumber, pagination, flags) case "get_sub_issues": - return GetSubIssues(ctx, client, owner, repo, issueNumber, pagination, flags) + return GetSubIssues(ctx, client, gqlClient, owner, repo, issueNumber, pagination, flags) case "get_labels": return GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber, flags) default: @@ -355,7 +355,7 @@ func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Cl return mcp.NewToolResultText(string(r)), nil } -func GetIssueComments(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) { +func GetIssueComments(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) { opts := &github.IssueListCommentsOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -377,6 +377,24 @@ func GetIssueComments(ctx context.Context, client *github.Client, owner string, return mcp.NewToolResultError(fmt.Sprintf("failed to get issue comments: %s", string(body))), nil } + if flags.LockdownMode { + filtered := make([]*github.IssueComment, 0, len(comments)) + for _, comment := range comments { + if comment == nil || comment.User == nil || comment.User.Login == nil { + continue + } + shouldRemove, err := lockdown.ShouldRemoveContent(ctx, gqlClient, comment.User.GetLogin(), owner, repo) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil + } + if shouldRemove { + continue + } + filtered = append(filtered, comment) + } + comments = filtered + } + r, err := json.Marshal(comments) if err != nil { return nil, fmt.Errorf("failed to marshal response: %w", err) @@ -385,7 +403,7 @@ func GetIssueComments(ctx context.Context, client *github.Client, owner string, return mcp.NewToolResultText(string(r)), nil } -func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) { +func GetSubIssues(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) { opts := &github.IssueListOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -412,6 +430,24 @@ func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo return mcp.NewToolResultError(fmt.Sprintf("failed to list sub-issues: %s", string(body))), nil } + if flags.LockdownMode { + filtered := make([]*github.SubIssue, 0, len(subIssues)) + for _, subIssue := range subIssues { + if subIssue == nil || subIssue.User == nil || subIssue.User.Login == nil { + continue + } + shouldRemove, err := lockdown.ShouldRemoveContent(ctx, gqlClient, subIssue.User.GetLogin(), owner, repo) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil + } + if shouldRemove { + continue + } + filtered = append(filtered, subIssue) + } + subIssues = filtered + } + r, err := json.Marshal(subIssues) if err != nil { return nil, fmt.Errorf("failed to marshal response: %w", err) diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index d13b93e4b..2574933b0 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -52,6 +52,23 @@ func Test_GetIssue(t *testing.T) { }, } + mockPrivateIssue := &github.Issue{ + Number: github.Ptr(42), + Title: github.Ptr("Test Issue"), + Body: github.Ptr("This is a test issue"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/issues/42"), + User: &github.User{ + Login: github.Ptr("privateuser"), + }, + Repository: &github.Repository{ + Name: github.Ptr("repo"), + Owner: &github.User{ + Login: github.Ptr("owner"), + }, + }, + } + tests := []struct { name string mockedClient *http.Client @@ -101,7 +118,7 @@ func Test_GetIssue(t *testing.T) { mockedClient: mock.NewMockedHTTPClient( mock.WithRequestMatch( mock.GetReposIssuesByOwnerByRepoByIssueNumber, - mockIssue, + mockPrivateIssue, ), ), gqlHTTPClient: githubv4mock.NewMockedHTTPClient( @@ -122,7 +139,7 @@ func Test_GetIssue(t *testing.T) { map[string]any{ "owner": githubv4.String("owner"), "name": githubv4.String("repo"), - "username": githubv4.String("testuser"), + "username": githubv4.String("privateuser"), }, githubv4mock.DataResponse(map[string]any{ "repository": map[string]any{ @@ -140,7 +157,7 @@ func Test_GetIssue(t *testing.T) { "repo": "repo", "issue_number": float64(42), }, - expectedIssue: mockIssue, + expectedIssue: mockPrivateIssue, lockdownEnabled: true, }, { @@ -1746,10 +1763,12 @@ func Test_GetIssueComments(t *testing.T) { tests := []struct { name string mockedClient *http.Client + gqlHTTPClient *http.Client requestArgs map[string]interface{} expectError bool expectedComments []*github.IssueComment expectedErrMsg string + lockdownEnabled bool }{ { name: "successful comments retrieval", @@ -1765,7 +1784,6 @@ func Test_GetIssueComments(t *testing.T) { "repo": "repo", "issue_number": float64(42), }, - expectError: false, expectedComments: mockComments, }, { @@ -1792,6 +1810,27 @@ func Test_GetIssueComments(t *testing.T) { expectError: false, expectedComments: mockComments, }, + { + name: "lockdown enabled removes comments without push access", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposIssuesCommentsByOwnerByRepoByIssueNumber, + mockComments, + ), + ), + gqlHTTPClient: newLockdownMockGQLHTTPClient(t, false, map[string]string{ + "user1": "WRITE", + "user2": "READ", + }), + requestArgs: map[string]interface{}{ + "method": "get_comments", + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + }, + expectedComments: []*github.IssueComment{mockComments[0]}, + lockdownEnabled: true, + }, { name: "issue not found", mockedClient: mock.NewMockedHTTPClient( @@ -1815,8 +1854,14 @@ func Test_GetIssueComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + var gqlClient *githubv4.Client + if tc.gqlHTTPClient != nil { + gqlClient = githubv4.NewClient(tc.gqlHTTPClient) + } else { + gqlClient = githubv4.NewClient(nil) + } + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, flags) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1839,9 +1884,9 @@ func Test_GetIssueComments(t *testing.T) { err = json.Unmarshal([]byte(textContent.Text), &returnedComments) require.NoError(t, err) assert.Equal(t, len(tc.expectedComments), len(returnedComments)) - if len(returnedComments) > 0 { - assert.Equal(t, *tc.expectedComments[0].Body, *returnedComments[0].Body) - assert.Equal(t, *tc.expectedComments[0].User.Login, *returnedComments[0].User.Login) + for i := range returnedComments { + assert.Equal(t, *tc.expectedComments[i].Body, *returnedComments[i].Body) + assert.Equal(t, *tc.expectedComments[i].User.Login, *returnedComments[i].User.Login) } }) } @@ -2669,10 +2714,12 @@ func Test_GetSubIssues(t *testing.T) { tests := []struct { name string mockedClient *http.Client + gqlHTTPClient *http.Client requestArgs map[string]interface{} expectError bool expectedSubIssues []*github.Issue expectedErrMsg string + lockdownEnabled bool }{ { name: "successful sub-issues listing with minimal parameters", @@ -2712,7 +2759,6 @@ func Test_GetSubIssues(t *testing.T) { "page": float64(2), "perPage": float64(10), }, - expectError: false, expectedSubIssues: mockSubIssues, }, { @@ -2729,9 +2775,29 @@ func Test_GetSubIssues(t *testing.T) { "repo": "repo", "issue_number": float64(42), }, - expectError: false, expectedSubIssues: []*github.Issue{}, }, + { + name: "lockdown enabled filters sub-issues without push access", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposIssuesSubIssuesByOwnerByRepoByIssueNumber, + mockSubIssues, + ), + ), + gqlHTTPClient: newLockdownMockGQLHTTPClient(t, false, map[string]string{ + "user1": "WRITE", + "user2": "READ", + }), + requestArgs: map[string]interface{}{ + "method": "get_sub_issues", + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + }, + expectedSubIssues: []*github.Issue{mockSubIssues[0]}, + lockdownEnabled: true, + }, { name: "parent issue not found", mockedClient: mock.NewMockedHTTPClient( @@ -2815,8 +2881,14 @@ func Test_GetSubIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + var gqlClient *githubv4.Client + if tc.gqlHTTPClient != nil { + gqlClient = githubv4.NewClient(tc.gqlHTTPClient) + } else { + gqlClient = githubv4.NewClient(nil) + } + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, flags) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 117f92ecf..8de0bb32d 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -19,7 +19,7 @@ import ( ) // GetPullRequest creates a tool to get details of a specific pull request. -func PullRequestRead(getClient GetClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, server.ToolHandlerFunc) { +func PullRequestRead(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("pull_request_read", mcp.WithDescription(t("TOOL_PULL_REQUEST_READ_DESCRIPTION", "Get information on a specific pull request in GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -98,7 +98,11 @@ Possible options: case "get_reviews": return GetPullRequestReviews(ctx, client, owner, repo, pullNumber) case "get_comments": - return GetIssueComments(ctx, client, owner, repo, pullNumber, pagination, flags) + gqlClient, err := getGQLClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub graphql client: %w", err) + } + return GetIssueComments(ctx, client, gqlClient, owner, repo, pullNumber, pagination, flags) default: return nil, fmt.Errorf("unknown method: %s", method) } diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 4cc4480e9..13de91a88 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -21,7 +21,7 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -102,7 +102,8 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + gqlClient := githubv4.NewClient(nil) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1133,7 +1134,7 @@ func Test_SearchPullRequests(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1236,7 +1237,7 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1277,7 +1278,7 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1404,7 +1405,7 @@ func Test_GetPullRequestStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1566,7 +1567,7 @@ func Test_UpdatePullRequestBranch(t *testing.T) { func Test_GetPullRequestComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1658,7 +1659,7 @@ func Test_GetPullRequestComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1700,7 +1701,7 @@ func Test_GetPullRequestComments(t *testing.T) { func Test_GetPullRequestReviews(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1788,7 +1789,7 @@ func Test_GetPullRequestReviews(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -2789,7 +2790,7 @@ func TestGetPullRequestDiff(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -2847,7 +2848,7 @@ index 5d6e7b2..8a4f5c3 100644 // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 36c22e7a8..87bbf4834 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -224,7 +224,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG ) pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). AddReadTools( - toolsets.NewServerTool(PullRequestRead(getClient, t, flags)), + toolsets.NewServerTool(PullRequestRead(getClient, getGQLClient, t, flags)), toolsets.NewServerTool(ListPullRequests(getClient, t)), toolsets.NewServerTool(SearchPullRequests(getClient, t)), ). diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index 5a474f73c..1cce0a3ca 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -4,19 +4,70 @@ import ( "context" "fmt" "strings" + "sync" + "time" "github.com/shurcooL/githubv4" ) +type repoAccessKey struct { + owner string + repo string + username string +} + +type repoAccessEntry struct { + isPrivate bool + hasPush bool + loadedAt time.Time +} + +var ( + repoAccessCache sync.Map + repoAccessInfoFunc = repoAccessInfo + timeNow = time.Now +) + +// repoAccessRefreshInterval defines how long to cache repository access +// information before refreshing it. +const repoAccessRefreshInterval = 10 * time.Minute + +func newRepoAccessKey(username, owner, repo string) repoAccessKey { + return repoAccessKey{ + owner: strings.ToLower(owner), + repo: strings.ToLower(repo), + username: strings.ToLower(username), + } +} + // ShouldRemoveContent determines if content should be removed based on // lockdown mode rules. It checks if the repository is private and if the user // has push access to the repository. func ShouldRemoveContent(ctx context.Context, client *githubv4.Client, username, owner, repo string) (bool, error) { - isPrivate, hasPushAccess, err := repoAccessInfo(ctx, client, username, owner, repo) + key := newRepoAccessKey(username, owner, repo) + + now := timeNow() + if cached, ok := repoAccessCache.Load(key); ok { + entry := cached.(repoAccessEntry) + if now.Sub(entry.loadedAt) < repoAccessRefreshInterval { + if entry.isPrivate { + return false, nil + } + return !entry.hasPush, nil + } + } + + isPrivate, hasPushAccess, err := repoAccessInfoFunc(ctx, client, username, owner, repo) if err != nil { return false, err } + repoAccessCache.Store(key, repoAccessEntry{ + isPrivate: isPrivate, + hasPush: hasPushAccess, + loadedAt: timeNow(), + }) + // Do not filter content for private repositories if isPrivate { return false, nil @@ -25,6 +76,14 @@ func ShouldRemoveContent(ctx context.Context, client *githubv4.Client, username, return !hasPushAccess, nil } +// clearRepoAccessCache removes all cached repository access information; used by tests. +func clearRepoAccessCache() { + repoAccessCache.Range(func(key, _ any) bool { + repoAccessCache.Delete(key) + return true + }) +} + func repoAccessInfo(ctx context.Context, client *githubv4.Client, username, owner, repo string) (bool, bool, error) { if client == nil { return false, false, fmt.Errorf("nil GraphQL client") diff --git a/pkg/lockdown/lockdown_test.go b/pkg/lockdown/lockdown_test.go new file mode 100644 index 000000000..ade7c3af9 --- /dev/null +++ b/pkg/lockdown/lockdown_test.go @@ -0,0 +1,155 @@ +package lockdown + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/shurcooL/githubv4" +) + +func TestShouldRemoveContentCachesResultsWithinInterval(t *testing.T) { + clearRepoAccessCache() + defer clearRepoAccessCache() + + originalInfoFunc := repoAccessInfoFunc + defer func() { repoAccessInfoFunc = originalInfoFunc }() + + originalTimeNow := timeNow + defer func() { timeNow = originalTimeNow }() + + fixed := time.Now() + timeNow = func() time.Time { return fixed } + + callCount := 0 + repoAccessInfoFunc = func(_ context.Context, _ *githubv4.Client, _, _, _ string) (bool, bool, error) { + callCount++ + return false, true, nil + } + + ctx := context.Background() + + remove, err := ShouldRemoveContent(ctx, nil, "User", "Owner", "Repo") + if err != nil { + t.Fatalf("unexpected error on first call: %v", err) + } + if remove { + t.Fatalf("expected remove=false when user has push access") + } + + remove, err = ShouldRemoveContent(ctx, nil, "user", "owner", "repo") + if err != nil { + t.Fatalf("unexpected error on cached call: %v", err) + } + if remove { + t.Fatalf("expected remove=false when cached entry reused") + } + if callCount != 1 { + t.Fatalf("expected cached result to prevent additional repo access queries, got %d", callCount) + } +} + +func TestShouldRemoveContentRefreshesAfterInterval(t *testing.T) { + clearRepoAccessCache() + defer clearRepoAccessCache() + + originalInfoFunc := repoAccessInfoFunc + defer func() { repoAccessInfoFunc = originalInfoFunc }() + + originalTimeNow := timeNow + defer func() { timeNow = originalTimeNow }() + + base := time.Now() + current := base + timeNow = func() time.Time { return current } + + callCount := 0 + repoAccessInfoFunc = func(_ context.Context, _ *githubv4.Client, _, _, _ string) (bool, bool, error) { + callCount++ + if callCount == 1 { + return false, false, nil + } + return false, true, nil + } + + ctx := context.Background() + + remove, err := ShouldRemoveContent(ctx, nil, "user", "owner", "repo") + if err != nil { + t.Fatalf("unexpected error on first call: %v", err) + } + if !remove { + t.Fatalf("expected remove=true when user lacks push access") + } + if callCount != 1 { + t.Fatalf("expected first call to query once, got %d", callCount) + } + + current = base.Add(9 * time.Minute) + remove, err = ShouldRemoveContent(ctx, nil, "user", "owner", "repo") + if err != nil { + t.Fatalf("unexpected error before refresh interval: %v", err) + } + if !remove { + t.Fatalf("expected remove=true before refresh interval expires") + } + if callCount != 1 { + t.Fatalf("expected cached value before refresh interval, got %d calls", callCount) + } + + current = base.Add(11 * time.Minute) + remove, err = ShouldRemoveContent(ctx, nil, "user", "owner", "repo") + if err != nil { + t.Fatalf("unexpected error after refresh interval: %v", err) + } + if remove { + t.Fatalf("expected remove=false after permissions refreshed") + } + if callCount != 2 { + t.Fatalf("expected refreshed access info after interval, got %d calls", callCount) + } +} + +func TestShouldRemoveContentDoesNotCacheErrors(t *testing.T) { + clearRepoAccessCache() + defer clearRepoAccessCache() + + originalInfoFunc := repoAccessInfoFunc + defer func() { repoAccessInfoFunc = originalInfoFunc }() + + originalTimeNow := timeNow + defer func() { timeNow = originalTimeNow }() + + now := time.Now() + timeNow = func() time.Time { return now } + + callCount := 0 + repoAccessInfoFunc = func(_ context.Context, _ *githubv4.Client, _, _, _ string) (bool, bool, error) { + callCount++ + if callCount == 1 { + return false, false, errors.New("boom") + } + return false, false, nil + } + + ctx := context.Background() + + if _, err := ShouldRemoveContent(ctx, nil, "user", "owner", "repo"); err == nil { + t.Fatal("expected error on first call") + } + if callCount != 1 { + t.Fatalf("expected single call after error, got %d", callCount) + } + + remove, err := ShouldRemoveContent(ctx, nil, "user", "owner", "repo") + if err != nil { + t.Fatalf("unexpected error on retry: %v", err) + } + if !remove { + t.Fatalf("expected remove=true when user lacks push access") + } + if callCount != 2 { + t.Fatalf("expected repo access to be queried again after error, got %d calls", callCount) + } +}