diff --git a/pkg/github/labels.go b/pkg/github/labels.go new file mode 100644 index 00000000..ab49d554 --- /dev/null +++ b/pkg/github/labels.go @@ -0,0 +1,571 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// ListLabels creates a tool to list labels in a GitHub repository. +func ListLabels(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_labels", + mcp.WithDescription(t("TOOL_LIST_LABELS_DESCRIPTION", "List labels for a repository")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + WithPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + opts := &github.ListOptions{ + Page: pagination.page, + PerPage: pagination.perPage, + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + labels, resp, err := client.Issues.ListLabels(ctx, owner, repo, opts) + if err != nil { + return nil, fmt.Errorf("failed to list labels: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list labels: %s", string(body))), nil + } + + r, err := json.Marshal(labels) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// GetLabel creates a tool to get a specific label in a GitHub repository. +func GetLabel(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_label", + mcp.WithDescription(t("TOOL_GET_LABEL_DESCRIPTION", "Get a specific label from a repository")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithString("name", + mcp.Required(), + mcp.Description("Label name"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + name, err := requiredParam[string](request, "name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + label, resp, err := client.Issues.GetLabel(ctx, owner, repo, name) + if err != nil { + return nil, fmt.Errorf("failed to get label: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get label: %s", string(body))), nil + } + + r, err := json.Marshal(label) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// CreateLabel creates a tool to create a new label in a GitHub repository. +func CreateLabel(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("create_label", + mcp.WithDescription(t("TOOL_CREATE_LABEL_DESCRIPTION", "Create a label in a GitHub repository")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithString("name", + mcp.Required(), + mcp.Description("Label name"), + ), + mcp.WithString("color", + mcp.Required(), + mcp.Description("The hexadecimal color code for the label, without the leading #"), + ), + mcp.WithString("description", + mcp.Description("A short description of the label"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + name, err := requiredParam[string](request, "name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + color, err := requiredParam[string](request, "color") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + description, err := OptionalParam[string](request, "description") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Create the label request + labelRequest := &github.Label{ + Name: github.Ptr(name), + Color: github.Ptr(color), + Description: github.Ptr(description), + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + label, resp, err := client.Issues.CreateLabel(ctx, owner, repo, labelRequest) + if err != nil { + return nil, fmt.Errorf("failed to create label: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to create label: %s", string(body))), nil + } + + r, err := json.Marshal(label) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// UpdateLabel creates a tool to update an existing label in a GitHub repository. +func UpdateLabel(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("update_label", + mcp.WithDescription(t("TOOL_UPDATE_LABEL_DESCRIPTION", "Update a label in a GitHub repository")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithString("name", + mcp.Required(), + mcp.Description("Current label name"), + ), + mcp.WithString("new_name", + mcp.Description("New label name"), + ), + mcp.WithString("color", + mcp.Description("The hexadecimal color code for the label, without the leading #"), + ), + mcp.WithString("description", + mcp.Description("A short description of the label"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + name, err := requiredParam[string](request, "name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Create the label update + labelRequest := &github.Label{} + updateNeeded := false + + newName, err := OptionalParam[string](request, "new_name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if newName != "" { + labelRequest.Name = github.Ptr(newName) + updateNeeded = true + } + + color, err := OptionalParam[string](request, "color") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if color != "" { + labelRequest.Color = github.Ptr(color) + updateNeeded = true + } + + description, err := OptionalParam[string](request, "description") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if description != "" { + labelRequest.Description = github.Ptr(description) + updateNeeded = true + } + + if !updateNeeded { + return mcp.NewToolResultError("No update parameters provided."), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + label, resp, err := client.Issues.EditLabel(ctx, owner, repo, name, labelRequest) + if err != nil { + return nil, fmt.Errorf("failed to update label: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update label: %s", string(body))), nil + } + + r, err := json.Marshal(label) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// DeleteLabel creates a tool to delete a label from a GitHub repository. +func DeleteLabel(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("delete_label", + mcp.WithDescription(t("TOOL_DELETE_LABEL_DESCRIPTION", "Delete a label from a GitHub repository")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithString("name", + mcp.Required(), + mcp.Description("Label name"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + name, err := requiredParam[string](request, "name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + resp, err := client.Issues.DeleteLabel(ctx, owner, repo, name) + if err != nil { + return nil, fmt.Errorf("failed to delete label: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusNoContent { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to delete label: %s", string(body))), nil + } + + result := fmt.Sprintf("Label '%s' successfully deleted from %s/%s", name, owner, repo) + return mcp.NewToolResultText(result), nil + } +} + +// ListLabelsForIssue creates a tool to list labels on an issue. +func ListLabelsForIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_labels_for_issue", + mcp.WithDescription(t("TOOL_LIST_LABELS_FOR_ISSUE_DESCRIPTION", "List labels for an issue")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithNumber("issue_number", + mcp.Required(), + mcp.Description("Issue number"), + ), + WithPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := RequiredInt(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + opts := &github.ListOptions{ + Page: pagination.page, + PerPage: pagination.perPage, + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + labels, resp, err := client.Issues.ListLabelsByIssue(ctx, owner, repo, issueNumber, opts) + if err != nil { + return nil, fmt.Errorf("failed to list labels for issue: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list labels for issue: %s", string(body))), nil + } + + r, err := json.Marshal(labels) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// AddLabelsToIssue creates a tool to add labels to an issue. +func AddLabelsToIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("add_labels_to_issue", + mcp.WithDescription(t("TOOL_ADD_LABELS_TO_ISSUE_DESCRIPTION", "Add labels to an issue")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithNumber("issue_number", + mcp.Required(), + mcp.Description("Issue number"), + ), + mcp.WithArray("labels", + mcp.Required(), + mcp.Description("Labels to add to the issue"), + mcp.Items( + map[string]interface{}{ + "type": "string", + }, + ), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := RequiredInt(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Get labels + labels, err := RequiredStringArrayParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + addedLabels, resp, err := client.Issues.AddLabelsToIssue(ctx, owner, repo, issueNumber, labels) + if err != nil { + return nil, fmt.Errorf("failed to add labels to issue: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to add labels to issue: %s", string(body))), nil + } + + r, err := json.Marshal(addedLabels) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// RemoveLabelFromIssue creates a tool to remove a label from an issue. +func RemoveLabelFromIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("remove_label_from_issue", + mcp.WithDescription(t("TOOL_REMOVE_LABEL_FROM_ISSUE_DESCRIPTION", "Remove a label from an issue")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithNumber("issue_number", + mcp.Required(), + mcp.Description("Issue number"), + ), + mcp.WithString("name", + mcp.Required(), + mcp.Description("Label name"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := RequiredInt(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + name, err := requiredParam[string](request, "name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + resp, err := client.Issues.RemoveLabelForIssue(ctx, owner, repo, issueNumber, name) + if err != nil { + return nil, fmt.Errorf("failed to remove label from issue: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to remove label from issue: %s", string(body))), nil + } + + result := fmt.Sprintf("Label '%s' successfully removed from issue #%d in %s/%s", name, issueNumber, owner, repo) + return mcp.NewToolResultText(result), nil + } +} diff --git a/pkg/github/labels_test.go b/pkg/github/labels_test.go new file mode 100644 index 00000000..486f4fc5 --- /dev/null +++ b/pkg/github/labels_test.go @@ -0,0 +1,1054 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ListLabels(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := ListLabels(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "list_labels", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "perPage") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + // Setup mock labels for success case + mockLabels := []*github.Label{ + { + ID: github.Ptr(int64(1)), + Name: github.Ptr("bug"), + Description: github.Ptr("Something isn't working"), + Color: github.Ptr("f29513"), + URL: github.Ptr("https://api.github.com/repos/octocat/Hello-World/labels/bug"), + Default: github.Ptr(true), + }, + { + ID: github.Ptr(int64(2)), + Name: github.Ptr("enhancement"), + Description: github.Ptr("New feature or request"), + Color: github.Ptr("a2eeef"), + URL: github.Ptr("https://api.github.com/repos/octocat/Hello-World/labels/enhancement"), + Default: github.Ptr(false), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedLabels []*github.Label + expectedErrMsg string + }{ + { + name: "successful labels listing", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposLabelsByOwnerByRepo, + mockLabels, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: false, + expectedLabels: mockLabels, + }, + { + name: "labels listing with pagination", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposLabelsByOwnerByRepo, + expectQueryParams(t, map[string]string{ + "page": "2", + "per_page": "10", + }).andThen( + mockResponse(t, http.StatusOK, mockLabels), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "page": float64(2), + "perPage": float64(10), + }, + expectError: false, + expectedLabels: mockLabels, + }, + { + name: "labels listing fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposLabelsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"message": "Internal Server Error"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "failed to list labels", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := ListLabels(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedLabels []*github.Label + err = json.Unmarshal([]byte(textContent.Text), &returnedLabels) + require.NoError(t, err) + assert.Len(t, returnedLabels, len(tc.expectedLabels)) + + for i, label := range returnedLabels { + assert.Equal(t, *tc.expectedLabels[i].Name, *label.Name) + assert.Equal(t, *tc.expectedLabels[i].Color, *label.Color) + assert.Equal(t, *tc.expectedLabels[i].Description, *label.Description) + assert.Equal(t, *tc.expectedLabels[i].Default, *label.Default) + } + }) + } +} + +func Test_GetLabel(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := GetLabel(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "get_label", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "name") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "name"}) + + // Setup mock label for success case + mockLabel := &github.Label{ + ID: github.Ptr(int64(1)), + Name: github.Ptr("bug"), + Description: github.Ptr("Something isn't working"), + Color: github.Ptr("f29513"), + URL: github.Ptr("https://api.github.com/repos/octocat/Hello-World/labels/bug"), + Default: github.Ptr(true), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedLabel *github.Label + expectedErrMsg string + }{ + { + name: "successful label retrieval", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposLabelsByOwnerByRepoByName, + mockLabel, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "name": "bug", + }, + expectError: false, + expectedLabel: mockLabel, + }, + { + name: "label retrieval fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposLabelsByOwnerByRepoByName, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Label not found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "name": "nonexistent", + }, + expectError: true, + expectedErrMsg: "failed to get label", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := GetLabel(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedLabel *github.Label + err = json.Unmarshal([]byte(textContent.Text), &returnedLabel) + require.NoError(t, err) + assert.Equal(t, *tc.expectedLabel.Name, *returnedLabel.Name) + assert.Equal(t, *tc.expectedLabel.Color, *returnedLabel.Color) + assert.Equal(t, *tc.expectedLabel.Description, *returnedLabel.Description) + assert.Equal(t, *tc.expectedLabel.Default, *returnedLabel.Default) + }) + } +} + +func Test_CreateLabel(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := CreateLabel(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "create_label", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "name") + assert.Contains(t, tool.InputSchema.Properties, "color") + assert.Contains(t, tool.InputSchema.Properties, "description") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "name", "color"}) + + // Setup mock created label for success case + mockLabel := &github.Label{ + ID: github.Ptr(int64(3)), + Name: github.Ptr("documentation"), + Description: github.Ptr("Improvements or additions to documentation"), + Color: github.Ptr("0075ca"), + URL: github.Ptr("https://api.github.com/repos/octocat/Hello-World/labels/documentation"), + Default: github.Ptr(false), + } + + labelRequest := map[string]interface{}{ + "name": "documentation", + "description": "Improvements or additions to documentation", + "color": "0075ca", + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedLabel *github.Label + expectedErrMsg string + }{ + { + name: "successful label creation", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposLabelsByOwnerByRepo, + expectRequestBody(t, labelRequest).andThen( + mockResponse(t, http.StatusCreated, mockLabel), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "name": "documentation", + "color": "0075ca", + "description": "Improvements or additions to documentation", + }, + expectError: false, + expectedLabel: mockLabel, + }, + { + name: "label creation fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposLabelsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Validation failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "name": "documentation", + "color": "invalid-color", + "description": "Improvements or additions to documentation", + }, + expectError: true, + expectedErrMsg: "failed to create label", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := CreateLabel(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedLabel *github.Label + err = json.Unmarshal([]byte(textContent.Text), &returnedLabel) + require.NoError(t, err) + assert.Equal(t, *tc.expectedLabel.Name, *returnedLabel.Name) + assert.Equal(t, *tc.expectedLabel.Color, *returnedLabel.Color) + assert.Equal(t, *tc.expectedLabel.Description, *returnedLabel.Description) + }) + } +} + +func Test_UpdateLabel(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := UpdateLabel(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "update_label", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "name") + assert.Contains(t, tool.InputSchema.Properties, "new_name") + assert.Contains(t, tool.InputSchema.Properties, "color") + assert.Contains(t, tool.InputSchema.Properties, "description") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "name"}) + + // Setup mock updated label for success case + mockLabel := &github.Label{ + ID: github.Ptr(int64(1)), + Name: github.Ptr("bug :bug:"), + Description: github.Ptr("Small bug fix required"), + Color: github.Ptr("b01f26"), + URL: github.Ptr("https://api.github.com/repos/octocat/Hello-World/labels/bug%20:bug:"), + Default: github.Ptr(true), + } + + labelRequest := map[string]interface{}{ + "name": "bug :bug:", + "description": "Small bug fix required", + "color": "b01f26", + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedLabel *github.Label + expectedErrMsg string + }{ + { + name: "successful label update", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposLabelsByOwnerByRepoByName, + expectRequestBody(t, labelRequest).andThen( + mockResponse(t, http.StatusOK, mockLabel), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "name": "bug", + "new_name": "bug :bug:", + "color": "b01f26", + "description": "Small bug fix required", + }, + expectError: false, + expectedLabel: mockLabel, + }, + { + name: "label update fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposLabelsByOwnerByRepoByName, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Label not found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "name": "nonexistent", + "new_name": "bug :bug:", + "color": "b01f26", + "description": "Small bug fix required", + }, + expectError: true, + expectedErrMsg: "failed to update label", + }, + { + name: "no update parameters provided", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "name": "bug", + }, + expectError: false, + expectedErrMsg: "No update parameters provided.", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := UpdateLabel(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Special case for no update parameters - we return a tool result error, not a Go error + if tc.name == "no update parameters provided" { + require.NoError(t, err) + require.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + return + } + + // Verify results for other cases + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedLabel *github.Label + err = json.Unmarshal([]byte(textContent.Text), &returnedLabel) + require.NoError(t, err) + assert.Equal(t, *tc.expectedLabel.Name, *returnedLabel.Name) + assert.Equal(t, *tc.expectedLabel.Color, *returnedLabel.Color) + assert.Equal(t, *tc.expectedLabel.Description, *returnedLabel.Description) + }) + } +} + +func Test_DeleteLabel(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := DeleteLabel(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "delete_label", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "name") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "name"}) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult string + expectedErrMsg string + }{ + { + name: "successful label deletion", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.DeleteReposLabelsByOwnerByRepoByName, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "name": "bug", + }, + expectError: false, + expectedResult: "successfully deleted", + }, + { + name: "label deletion fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.DeleteReposLabelsByOwnerByRepoByName, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Label not found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "name": "nonexistent", + }, + expectError: true, + expectedErrMsg: "failed to delete label", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := DeleteLabel(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Verify the result + assert.Contains(t, textContent.Text, tc.expectedResult) + }) + } +} + +func Test_ListLabelsForIssue(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := ListLabelsForIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "list_labels_for_issue", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "issue_number") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "perPage") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "issue_number"}) + + // Setup mock labels for success case + mockLabels := []*github.Label{ + { + ID: github.Ptr(int64(1)), + Name: github.Ptr("bug"), + Description: github.Ptr("Something isn't working"), + Color: github.Ptr("f29513"), + URL: github.Ptr("https://api.github.com/repos/octocat/Hello-World/labels/bug"), + Default: github.Ptr(true), + }, + { + ID: github.Ptr(int64(2)), + Name: github.Ptr("enhancement"), + Description: github.Ptr("New feature or request"), + Color: github.Ptr("a2eeef"), + URL: github.Ptr("https://api.github.com/repos/octocat/Hello-World/labels/enhancement"), + Default: github.Ptr(false), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedLabels []*github.Label + expectedErrMsg string + }{ + { + name: "successful labels listing for issue", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposIssuesLabelsByOwnerByRepoByIssueNumber, + mockLabels, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + }, + expectError: false, + expectedLabels: mockLabels, + }, + { + name: "labels listing for issue with pagination", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposIssuesLabelsByOwnerByRepoByIssueNumber, + expectQueryParams(t, map[string]string{ + "page": "2", + "per_page": "10", + }).andThen( + mockResponse(t, http.StatusOK, mockLabels), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + "page": float64(2), + "perPage": float64(10), + }, + expectError: false, + expectedLabels: mockLabels, + }, + { + name: "labels listing for issue fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposIssuesLabelsByOwnerByRepoByIssueNumber, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Issue not found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(999), + }, + expectError: true, + expectedErrMsg: "failed to list labels for issue", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := ListLabelsForIssue(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedLabels []*github.Label + err = json.Unmarshal([]byte(textContent.Text), &returnedLabels) + require.NoError(t, err) + assert.Len(t, returnedLabels, len(tc.expectedLabels)) + + for i, label := range returnedLabels { + assert.Equal(t, *tc.expectedLabels[i].Name, *label.Name) + assert.Equal(t, *tc.expectedLabels[i].Color, *label.Color) + assert.Equal(t, *tc.expectedLabels[i].Description, *label.Description) + assert.Equal(t, *tc.expectedLabels[i].Default, *label.Default) + } + }) + } +} + +func Test_AddLabelsToIssue(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := AddLabelsToIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "add_labels_to_issue", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "issue_number") + assert.Contains(t, tool.InputSchema.Properties, "labels") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "issue_number", "labels"}) + + // Setup mock labels for success case + mockLabels := []*github.Label{ + { + ID: github.Ptr(int64(1)), + Name: github.Ptr("bug"), + Description: github.Ptr("Something isn't working"), + Color: github.Ptr("f29513"), + URL: github.Ptr("https://api.github.com/repos/octocat/Hello-World/labels/bug"), + Default: github.Ptr(true), + }, + { + ID: github.Ptr(int64(2)), + Name: github.Ptr("enhancement"), + Description: github.Ptr("New feature or request"), + Color: github.Ptr("a2eeef"), + URL: github.Ptr("https://api.github.com/repos/octocat/Hello-World/labels/enhancement"), + Default: github.Ptr(false), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedLabels []*github.Label + expectedErrMsg string + }{ + { + name: "successful labels addition to issue", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposIssuesLabelsByOwnerByRepoByIssueNumber, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return success status and expected labels + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(mockLabels) + _, _ = w.Write(data) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + "labels": []interface{}{"bug", "enhancement"}, + }, + expectError: false, + expectedLabels: mockLabels, + }, + { + name: "labels addition to issue fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposIssuesLabelsByOwnerByRepoByIssueNumber, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Validation failed"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + "labels": []interface{}{"invalid-label"}, + }, + expectError: true, + expectedErrMsg: "failed to add labels to issue", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := AddLabelsToIssue(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedLabels []*github.Label + err = json.Unmarshal([]byte(textContent.Text), &returnedLabels) + require.NoError(t, err) + assert.Len(t, returnedLabels, len(tc.expectedLabels)) + + for i, label := range returnedLabels { + assert.Equal(t, *tc.expectedLabels[i].Name, *label.Name) + assert.Equal(t, *tc.expectedLabels[i].Color, *label.Color) + assert.Equal(t, *tc.expectedLabels[i].Description, *label.Description) + assert.Equal(t, *tc.expectedLabels[i].Default, *label.Default) + } + }) + } +} + +func Test_RemoveLabelFromIssue(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := RemoveLabelFromIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "remove_label_from_issue", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "issue_number") + assert.Contains(t, tool.InputSchema.Properties, "name") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "issue_number", "name"}) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResult string + expectedErrMsg string + }{ + { + name: "successful label removal from issue", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.DeleteReposIssuesLabelsByOwnerByRepoByIssueNumberByName, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`[]`)) // GitHub returns an empty array on successful removal + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + "name": "bug", + }, + expectError: false, + expectedResult: "successfully removed", + }, + { + name: "label removal from issue fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.DeleteReposIssuesLabelsByOwnerByRepoByIssueNumberByName, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Label or issue not found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(999), + "name": "nonexistent", + }, + expectError: true, + expectedErrMsg: "failed to remove label from issue", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := RemoveLabelFromIssue(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Verify the result + assert.Contains(t, textContent.Text, tc.expectedResult) + }) + } +} + +func Test_RequiredStringArrayParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected []string + expectError bool + }{ + { + name: "parameter not in request", + params: map[string]any{}, + paramName: "flag", + expected: nil, + expectError: true, + }, + { + name: "empty any array parameter", + params: map[string]any{ + "flag": []any{}, + }, + paramName: "flag", + expected: nil, + expectError: true, + }, + { + name: "empty string array parameter", + params: map[string]any{ + "flag": []string{}, + }, + paramName: "flag", + expected: nil, + expectError: true, + }, + { + name: "valid any array parameter", + params: map[string]any{ + "flag": []any{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "valid string array parameter", + params: map[string]any{ + "flag": []string{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "nil parameter", + params: map[string]any{ + "flag": nil, + }, + paramName: "flag", + expected: nil, + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]any{ + "flag": 1, + }, + paramName: "flag", + expected: nil, + expectError: true, + }, + { + name: "wrong slice type parameter", + params: map[string]any{ + "flag": []any{"foo", 2}, + }, + paramName: "flag", + expected: nil, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := RequiredStringArrayParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index e4c24171..fbc51ecc 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -173,6 +173,43 @@ func OptionalStringArrayParam(r mcp.CallToolRequest, p string) ([]string, error) } } +// RequiredStringArrayParam gets a required array of strings from the request. +// Returns an error if the parameter is missing or an empty array. +func RequiredStringArrayParam(request mcp.CallToolRequest, name string) ([]string, error) { + v, ok := request.Params.Arguments[name] + if !ok { + return nil, fmt.Errorf("missing required parameter %s", name) + } + + if v == nil { + return nil, fmt.Errorf("parameter %s is nil", name) + } + + switch value := v.(type) { + case []string: + if len(value) == 0 { + return nil, fmt.Errorf("parameter %s cannot be empty", name) + } + return value, nil + case []interface{}: + if len(value) == 0 { + return nil, fmt.Errorf("parameter %s cannot be empty", name) + } + + result := make([]string, len(value)) + for i, val := range value { + str, ok := val.(string) + if !ok { + return nil, fmt.Errorf("parameter %s[%d] is not a string", name, i) + } + result[i] = str + } + return result, nil + default: + return nil, fmt.Errorf("parameter %s is not an array of strings", name) + } +} + // WithPagination returns a ToolOption that adds "page" and "perPage" parameters to the tool. // The "page" parameter is optional, min 1. The "perPage" parameter is optional, min 1, max 100. func WithPagination() mcp.ToolOption { diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 58bcb9db..03c25e44 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -435,6 +435,101 @@ func TestOptionalStringArrayParam(t *testing.T) { } } +func TestRequiredStringArrayParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected []string + expectError bool + }{ + { + name: "parameter not in request", + params: map[string]any{}, + paramName: "flag", + expected: nil, + expectError: true, + }, + { + name: "empty any array parameter", + params: map[string]any{ + "flag": []any{}, + }, + paramName: "flag", + expected: nil, + expectError: true, + }, + { + name: "empty string array parameter", + params: map[string]any{ + "flag": []string{}, + }, + paramName: "flag", + expected: nil, + expectError: true, + }, + { + name: "valid any array parameter", + params: map[string]any{ + "flag": []any{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "valid string array parameter", + params: map[string]any{ + "flag": []string{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "nil parameter", + params: map[string]any{ + "flag": nil, + }, + paramName: "flag", + expected: nil, + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]any{ + "flag": 1, + }, + paramName: "flag", + expected: nil, + expectError: true, + }, + { + name: "wrong slice type parameter", + params: map[string]any{ + "flag": []any{"foo", 2}, + }, + paramName: "flag", + expected: nil, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := RequiredStringArrayParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + func TestOptionalPaginationParams(t *testing.T) { tests := []struct { name string diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 35dabaef..e421b3e9 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -41,11 +41,19 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, toolsets.NewServerTool(SearchIssues(getClient, t)), toolsets.NewServerTool(ListIssues(getClient, t)), toolsets.NewServerTool(GetIssueComments(getClient, t)), + toolsets.NewServerTool(ListLabels(getClient, t)), + toolsets.NewServerTool(GetLabel(getClient, t)), + toolsets.NewServerTool(ListLabelsForIssue(getClient, t)), ). AddWriteTools( toolsets.NewServerTool(CreateIssue(getClient, t)), toolsets.NewServerTool(AddIssueComment(getClient, t)), toolsets.NewServerTool(UpdateIssue(getClient, t)), + toolsets.NewServerTool(CreateLabel(getClient, t)), + toolsets.NewServerTool(UpdateLabel(getClient, t)), + toolsets.NewServerTool(DeleteLabel(getClient, t)), + toolsets.NewServerTool(AddLabelsToIssue(getClient, t)), + toolsets.NewServerTool(RemoveLabelFromIssue(getClient, t)), ) users := toolsets.NewToolset("users", "GitHub User related tools"). AddReadTools(