Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions pkg/github/helper_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package github

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"

"github.com/mark3labs/mcp-go/mcp"
Expand Down Expand Up @@ -108,6 +112,74 @@ func mockResponse(t *testing.T, code int, body interface{}) http.HandlerFunc {
}
}

type roundTripperFunc func(*http.Request) (*http.Response, error)

func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

func newLockdownMockGQLHTTPClient(t *testing.T, isPrivate bool, permissions map[string]string) *http.Client {
t.Helper()

lowerPermissions := make(map[string]string, len(permissions))
for user, perm := range permissions {
lowerPermissions[strings.ToLower(user)] = perm
}

return &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
require.Equal(t, "/graphql", req.URL.Path)

bodyBytes, err := io.ReadAll(req.Body)
require.NoError(t, err)
_ = req.Body.Close()

var gqlRequest struct {
Variables map[string]any `json:"variables"`
}
require.NoError(t, json.Unmarshal(bodyBytes, &gqlRequest))

rawUsername, ok := gqlRequest.Variables["username"]
require.True(t, ok, "expected username variable in GraphQL request")

username := fmt.Sprint(rawUsername)
permission := lowerPermissions[strings.ToLower(username)]

edges := []any{}
if permission != "" {
edges = append(edges, map[string]any{
"permission": permission,
"node": map[string]any{
"login": username,
},
})
}

response := map[string]any{
"data": map[string]any{
"repository": map[string]any{
"isPrivate": isPrivate,
"collaborators": map[string]any{
"edges": edges,
},
},
},
}

respBytes, err := json.Marshal(response)
require.NoError(t, err)

res := &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(respBytes)),
}
res.Header.Set("Content-Type", "application/json")
return res, nil
}),
}
}

// createMCPRequest is a helper function to create a MCP request with the given arguments.
func createMCPRequest(args any) mcp.CallToolRequest {
return mcp.CallToolRequest{
Expand Down
44 changes: 40 additions & 4 deletions pkg/github/issues.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,9 @@ Options are:
case "get":
return GetIssue(ctx, client, gqlClient, owner, repo, issueNumber, flags)
case "get_comments":
return GetIssueComments(ctx, client, owner, repo, issueNumber, pagination, flags)
return GetIssueComments(ctx, client, gqlClient, owner, repo, issueNumber, pagination, flags)
case "get_sub_issues":
return GetSubIssues(ctx, client, owner, repo, issueNumber, pagination, flags)
return GetSubIssues(ctx, client, gqlClient, owner, repo, issueNumber, pagination, flags)
case "get_labels":
return GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber, flags)
default:
Expand Down Expand Up @@ -355,7 +355,7 @@ func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Cl
return mcp.NewToolResultText(string(r)), nil
}

func GetIssueComments(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) {
func GetIssueComments(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) {
opts := &github.IssueListCommentsOptions{
ListOptions: github.ListOptions{
Page: pagination.Page,
Expand All @@ -377,6 +377,24 @@ func GetIssueComments(ctx context.Context, client *github.Client, owner string,
return mcp.NewToolResultError(fmt.Sprintf("failed to get issue comments: %s", string(body))), nil
}

if flags.LockdownMode {
filtered := make([]*github.IssueComment, 0, len(comments))
for _, comment := range comments {
if comment == nil || comment.User == nil || comment.User.Login == nil {
continue
}
shouldRemove, err := lockdown.ShouldRemoveContent(ctx, gqlClient, comment.User.GetLogin(), owner, repo)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil
}
if shouldRemove {
continue
}
filtered = append(filtered, comment)
}
comments = filtered
}

r, err := json.Marshal(comments)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
Expand All @@ -385,7 +403,7 @@ func GetIssueComments(ctx context.Context, client *github.Client, owner string,
return mcp.NewToolResultText(string(r)), nil
}

func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) {
func GetSubIssues(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) {
opts := &github.IssueListOptions{
ListOptions: github.ListOptions{
Page: pagination.Page,
Expand All @@ -412,6 +430,24 @@ func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo
return mcp.NewToolResultError(fmt.Sprintf("failed to list sub-issues: %s", string(body))), nil
}

if flags.LockdownMode {
filtered := make([]*github.SubIssue, 0, len(subIssues))
for _, subIssue := range subIssues {
if subIssue == nil || subIssue.User == nil || subIssue.User.Login == nil {
continue
}
shouldRemove, err := lockdown.ShouldRemoveContent(ctx, gqlClient, subIssue.User.GetLogin(), owner, repo)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil
}
if shouldRemove {
continue
}
filtered = append(filtered, subIssue)
}
subIssues = filtered
}

r, err := json.Marshal(subIssues)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
Expand Down
98 changes: 85 additions & 13 deletions pkg/github/issues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ func Test_GetIssue(t *testing.T) {
},
}

mockPrivateIssue := &github.Issue{
Number: github.Ptr(42),
Title: github.Ptr("Test Issue"),
Body: github.Ptr("This is a test issue"),
State: github.Ptr("open"),
HTMLURL: github.Ptr("https://github.com/owner/repo/issues/42"),
User: &github.User{
Login: github.Ptr("privateuser"),
},
Repository: &github.Repository{
Name: github.Ptr("repo"),
Owner: &github.User{
Login: github.Ptr("owner"),
},
},
}

tests := []struct {
name string
mockedClient *http.Client
Expand Down Expand Up @@ -101,7 +118,7 @@ func Test_GetIssue(t *testing.T) {
mockedClient: mock.NewMockedHTTPClient(
mock.WithRequestMatch(
mock.GetReposIssuesByOwnerByRepoByIssueNumber,
mockIssue,
mockPrivateIssue,
),
),
gqlHTTPClient: githubv4mock.NewMockedHTTPClient(
Expand All @@ -122,7 +139,7 @@ func Test_GetIssue(t *testing.T) {
map[string]any{
"owner": githubv4.String("owner"),
"name": githubv4.String("repo"),
"username": githubv4.String("testuser"),
"username": githubv4.String("privateuser"),
},
githubv4mock.DataResponse(map[string]any{
"repository": map[string]any{
Expand All @@ -140,7 +157,7 @@ func Test_GetIssue(t *testing.T) {
"repo": "repo",
"issue_number": float64(42),
},
expectedIssue: mockIssue,
expectedIssue: mockPrivateIssue,
lockdownEnabled: true,
},
{
Expand Down Expand Up @@ -1746,10 +1763,12 @@ func Test_GetIssueComments(t *testing.T) {
tests := []struct {
name string
mockedClient *http.Client
gqlHTTPClient *http.Client
requestArgs map[string]interface{}
expectError bool
expectedComments []*github.IssueComment
expectedErrMsg string
lockdownEnabled bool
}{
{
name: "successful comments retrieval",
Expand All @@ -1765,7 +1784,6 @@ func Test_GetIssueComments(t *testing.T) {
"repo": "repo",
"issue_number": float64(42),
},
expectError: false,
expectedComments: mockComments,
},
{
Expand All @@ -1792,6 +1810,27 @@ func Test_GetIssueComments(t *testing.T) {
expectError: false,
expectedComments: mockComments,
},
{
name: "lockdown enabled removes comments without push access",
mockedClient: mock.NewMockedHTTPClient(
mock.WithRequestMatch(
mock.GetReposIssuesCommentsByOwnerByRepoByIssueNumber,
mockComments,
),
),
gqlHTTPClient: newLockdownMockGQLHTTPClient(t, false, map[string]string{
"user1": "WRITE",
"user2": "READ",
}),
requestArgs: map[string]interface{}{
"method": "get_comments",
"owner": "owner",
"repo": "repo",
"issue_number": float64(42),
},
expectedComments: []*github.IssueComment{mockComments[0]},
lockdownEnabled: true,
},
{
name: "issue not found",
mockedClient: mock.NewMockedHTTPClient(
Expand All @@ -1815,8 +1854,14 @@ func Test_GetIssueComments(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// Setup client with mock
client := github.NewClient(tc.mockedClient)
gqlClient := githubv4.NewClient(nil)
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
var gqlClient *githubv4.Client
if tc.gqlHTTPClient != nil {
gqlClient = githubv4.NewClient(tc.gqlHTTPClient)
} else {
gqlClient = githubv4.NewClient(nil)
}
flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled})
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, flags)

// Create call request
request := createMCPRequest(tc.requestArgs)
Expand All @@ -1839,9 +1884,9 @@ func Test_GetIssueComments(t *testing.T) {
err = json.Unmarshal([]byte(textContent.Text), &returnedComments)
require.NoError(t, err)
assert.Equal(t, len(tc.expectedComments), len(returnedComments))
if len(returnedComments) > 0 {
assert.Equal(t, *tc.expectedComments[0].Body, *returnedComments[0].Body)
assert.Equal(t, *tc.expectedComments[0].User.Login, *returnedComments[0].User.Login)
for i := range returnedComments {
assert.Equal(t, *tc.expectedComments[i].Body, *returnedComments[i].Body)
assert.Equal(t, *tc.expectedComments[i].User.Login, *returnedComments[i].User.Login)
}
})
}
Expand Down Expand Up @@ -2669,10 +2714,12 @@ func Test_GetSubIssues(t *testing.T) {
tests := []struct {
name string
mockedClient *http.Client
gqlHTTPClient *http.Client
requestArgs map[string]interface{}
expectError bool
expectedSubIssues []*github.Issue
expectedErrMsg string
lockdownEnabled bool
}{
{
name: "successful sub-issues listing with minimal parameters",
Expand Down Expand Up @@ -2712,7 +2759,6 @@ func Test_GetSubIssues(t *testing.T) {
"page": float64(2),
"perPage": float64(10),
},
expectError: false,
expectedSubIssues: mockSubIssues,
},
{
Expand All @@ -2729,9 +2775,29 @@ func Test_GetSubIssues(t *testing.T) {
"repo": "repo",
"issue_number": float64(42),
},
expectError: false,
expectedSubIssues: []*github.Issue{},
},
{
name: "lockdown enabled filters sub-issues without push access",
mockedClient: mock.NewMockedHTTPClient(
mock.WithRequestMatch(
mock.GetReposIssuesSubIssuesByOwnerByRepoByIssueNumber,
mockSubIssues,
),
),
gqlHTTPClient: newLockdownMockGQLHTTPClient(t, false, map[string]string{
"user1": "WRITE",
"user2": "READ",
}),
requestArgs: map[string]interface{}{
"method": "get_sub_issues",
"owner": "owner",
"repo": "repo",
"issue_number": float64(42),
},
expectedSubIssues: []*github.Issue{mockSubIssues[0]},
lockdownEnabled: true,
},
{
name: "parent issue not found",
mockedClient: mock.NewMockedHTTPClient(
Expand Down Expand Up @@ -2815,8 +2881,14 @@ func Test_GetSubIssues(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// Setup client with mock
client := github.NewClient(tc.mockedClient)
gqlClient := githubv4.NewClient(nil)
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
var gqlClient *githubv4.Client
if tc.gqlHTTPClient != nil {
gqlClient = githubv4.NewClient(tc.gqlHTTPClient)
} else {
gqlClient = githubv4.NewClient(nil)
}
flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled})
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, flags)

// Create call request
request := createMCPRequest(tc.requestArgs)
Expand Down
8 changes: 6 additions & 2 deletions pkg/github/pullrequests.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
)

// GetPullRequest creates a tool to get details of a specific pull request.
func PullRequestRead(getClient GetClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, server.ToolHandlerFunc) {
func PullRequestRead(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, server.ToolHandlerFunc) {
return mcp.NewTool("pull_request_read",
mcp.WithDescription(t("TOOL_PULL_REQUEST_READ_DESCRIPTION", "Get information on a specific pull request in GitHub repository.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
Expand Down Expand Up @@ -98,7 +98,11 @@ Possible options:
case "get_reviews":
return GetPullRequestReviews(ctx, client, owner, repo, pullNumber)
case "get_comments":
return GetIssueComments(ctx, client, owner, repo, pullNumber, pagination, flags)
gqlClient, err := getGQLClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub graphql client: %w", err)
}
return GetIssueComments(ctx, client, gqlClient, owner, repo, pullNumber, pagination, flags)
default:
return nil, fmt.Errorf("unknown method: %s", method)
}
Expand Down
Loading
Loading