diff --git a/pkg/inventory/bitmap.go b/pkg/inventory/bitmap.go new file mode 100644 index 000000000..6f83ed84a --- /dev/null +++ b/pkg/inventory/bitmap.go @@ -0,0 +1,116 @@ +package inventory + +import ( + "math/bits" +) + +// ToolBitmap represents a set of tools as a bitmap for O(1) set operations. +// Uses 4 uint64s (256 bits) to support up to 256 tools. +// Current tool count is ~130, so this provides headroom for growth. +type ToolBitmap [4]uint64 + +// EmptyBitmap returns an empty bitmap with no bits set. +func EmptyBitmap() ToolBitmap { + return ToolBitmap{} +} + +// SetBit returns a new bitmap with the bit at position i set. +func (b ToolBitmap) SetBit(i int) ToolBitmap { + if i < 0 || i >= 256 { + return b + } + word, bit := i/64, uint(i%64) //nolint:gosec // bounds checked above + result := b + result[word] |= 1 << bit + return result +} + +// ClearBit returns a new bitmap with the bit at position i cleared. +func (b ToolBitmap) ClearBit(i int) ToolBitmap { + if i < 0 || i >= 256 { + return b + } + word, bit := i/64, uint(i%64) //nolint:gosec // bounds checked above + result := b + result[word] &^= 1 << bit + return result +} + +// IsSet returns true if the bit at position i is set. +func (b ToolBitmap) IsSet(i int) bool { + if i < 0 || i >= 256 { + return false + } + word, bit := i/64, uint(i%64) //nolint:gosec // bounds checked above + return (b[word] & (1 << bit)) != 0 +} + +// Or returns the union of two bitmaps. +func (b ToolBitmap) Or(other ToolBitmap) ToolBitmap { + return ToolBitmap{ + b[0] | other[0], + b[1] | other[1], + b[2] | other[2], + b[3] | other[3], + } +} + +// And returns the intersection of two bitmaps. +func (b ToolBitmap) And(other ToolBitmap) ToolBitmap { + return ToolBitmap{ + b[0] & other[0], + b[1] & other[1], + b[2] & other[2], + b[3] & other[3], + } +} + +// AndNot returns b AND NOT other (bits in b that are not in other). +func (b ToolBitmap) AndNot(other ToolBitmap) ToolBitmap { + return ToolBitmap{ + b[0] &^ other[0], + b[1] &^ other[1], + b[2] &^ other[2], + b[3] &^ other[3], + } +} + +// PopCount returns the number of set bits (population count). +func (b ToolBitmap) PopCount() int { + return bits.OnesCount64(b[0]) + + bits.OnesCount64(b[1]) + + bits.OnesCount64(b[2]) + + bits.OnesCount64(b[3]) +} + +// IsEmpty returns true if no bits are set. +func (b ToolBitmap) IsEmpty() bool { + return b[0] == 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 +} + +// Iterate calls fn for each set bit position. Stops if fn returns false. +func (b ToolBitmap) Iterate(fn func(position int) bool) { + for word := 0; word < 4; word++ { + v := b[word] + base := word * 64 + for v != 0 { + // Find position of lowest set bit + tz := bits.TrailingZeros64(v) + if !fn(base + tz) { + return + } + // Clear the lowest set bit + v &= v - 1 + } + } +} + +// Positions returns a slice of all set bit positions. +func (b ToolBitmap) Positions() []int { + result := make([]int, 0, b.PopCount()) + b.Iterate(func(pos int) bool { + result = append(result, pos) + return true + }) + return result +} diff --git a/pkg/inventory/bitmap_test.go b/pkg/inventory/bitmap_test.go new file mode 100644 index 000000000..f18acae48 --- /dev/null +++ b/pkg/inventory/bitmap_test.go @@ -0,0 +1,216 @@ +package inventory + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestToolBitmap_SetAndIsSet(t *testing.T) { + t.Parallel() + + var bm ToolBitmap + + // Initially empty + assert.False(t, bm.IsSet(0)) + assert.False(t, bm.IsSet(63)) + assert.False(t, bm.IsSet(64)) + assert.False(t, bm.IsSet(127)) + + // Set some bits + bm = bm.SetBit(0) + bm = bm.SetBit(63) + bm = bm.SetBit(64) + bm = bm.SetBit(127) + bm = bm.SetBit(200) + + assert.True(t, bm.IsSet(0)) + assert.True(t, bm.IsSet(63)) + assert.True(t, bm.IsSet(64)) + assert.True(t, bm.IsSet(127)) + assert.True(t, bm.IsSet(200)) + + // Unset bits should still be false + assert.False(t, bm.IsSet(1)) + assert.False(t, bm.IsSet(62)) + assert.False(t, bm.IsSet(128)) +} + +func TestToolBitmap_ClearBit(t *testing.T) { + t.Parallel() + + bm := ToolBitmap{}.SetBit(5).SetBit(10).SetBit(100) + assert.True(t, bm.IsSet(5)) + assert.True(t, bm.IsSet(10)) + assert.True(t, bm.IsSet(100)) + + bm = bm.ClearBit(10) + assert.True(t, bm.IsSet(5)) + assert.False(t, bm.IsSet(10)) + assert.True(t, bm.IsSet(100)) +} + +func TestToolBitmap_Or(t *testing.T) { + t.Parallel() + + a := ToolBitmap{}.SetBit(1).SetBit(3).SetBit(65) + b := ToolBitmap{}.SetBit(2).SetBit(3).SetBit(130) + + result := a.Or(b) + + assert.True(t, result.IsSet(1)) + assert.True(t, result.IsSet(2)) + assert.True(t, result.IsSet(3)) + assert.True(t, result.IsSet(65)) + assert.True(t, result.IsSet(130)) + assert.False(t, result.IsSet(0)) + assert.False(t, result.IsSet(4)) +} + +func TestToolBitmap_And(t *testing.T) { + t.Parallel() + + a := ToolBitmap{}.SetBit(1).SetBit(3).SetBit(5).SetBit(65) + b := ToolBitmap{}.SetBit(3).SetBit(5).SetBit(7).SetBit(65) + + result := a.And(b) + + assert.False(t, result.IsSet(1)) // only in a + assert.True(t, result.IsSet(3)) // in both + assert.True(t, result.IsSet(5)) // in both + assert.False(t, result.IsSet(7)) // only in b + assert.True(t, result.IsSet(65)) // in both +} + +func TestToolBitmap_AndNot(t *testing.T) { + t.Parallel() + + a := ToolBitmap{}.SetBit(1).SetBit(3).SetBit(5).SetBit(65) + b := ToolBitmap{}.SetBit(3).SetBit(7).SetBit(65) + + result := a.AndNot(b) + + assert.True(t, result.IsSet(1)) // in a, not in b + assert.False(t, result.IsSet(3)) // in both, removed + assert.True(t, result.IsSet(5)) // in a, not in b + assert.False(t, result.IsSet(7)) // not in a + assert.False(t, result.IsSet(65)) // in both, removed +} + +func TestToolBitmap_PopCount(t *testing.T) { + t.Parallel() + + assert.Equal(t, 0, ToolBitmap{}.PopCount()) + assert.Equal(t, 1, ToolBitmap{}.SetBit(0).PopCount()) + assert.Equal(t, 2, ToolBitmap{}.SetBit(0).SetBit(100).PopCount()) + assert.Equal(t, 4, ToolBitmap{}.SetBit(0).SetBit(63).SetBit(64).SetBit(255).PopCount()) +} + +func TestToolBitmap_IsEmpty(t *testing.T) { + t.Parallel() + + assert.True(t, ToolBitmap{}.IsEmpty()) + assert.False(t, ToolBitmap{}.SetBit(0).IsEmpty()) + assert.False(t, ToolBitmap{}.SetBit(200).IsEmpty()) +} + +func TestToolBitmap_Iterate(t *testing.T) { + t.Parallel() + + bm := ToolBitmap{}.SetBit(3).SetBit(10).SetBit(64).SetBit(100).SetBit(200) + + var positions []int + bm.Iterate(func(pos int) bool { + positions = append(positions, pos) + return true + }) + + assert.Equal(t, []int{3, 10, 64, 100, 200}, positions) +} + +func TestToolBitmap_Iterate_EarlyStop(t *testing.T) { + t.Parallel() + + bm := ToolBitmap{}.SetBit(1).SetBit(5).SetBit(10).SetBit(20) + + var positions []int + bm.Iterate(func(pos int) bool { + positions = append(positions, pos) + return pos < 10 // stop after 10 + }) + + assert.Equal(t, []int{1, 5, 10}, positions) +} + +func TestToolBitmap_Positions(t *testing.T) { + t.Parallel() + + bm := ToolBitmap{}.SetBit(5).SetBit(63).SetBit(64).SetBit(128) + positions := bm.Positions() + + assert.Equal(t, []int{5, 63, 64, 128}, positions) +} + +func TestToolBitmap_BoundaryConditions(t *testing.T) { + t.Parallel() + + var bm ToolBitmap + + // Negative index should be no-op + bm = bm.SetBit(-1) + assert.False(t, bm.IsSet(-1)) + + // Index >= 256 should be no-op + bm = bm.SetBit(256) + assert.False(t, bm.IsSet(256)) + + // Maximum valid index + bm = bm.SetBit(255) + assert.True(t, bm.IsSet(255)) +} + +func BenchmarkToolBitmap_Or(b *testing.B) { + a := ToolBitmap{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0} + c := ToolBitmap{0, 0, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = a.Or(c) + } +} + +func BenchmarkToolBitmap_And(b *testing.B) { + a := ToolBitmap{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xAAAAAAAAAAAAAAAA, 0} + c := ToolBitmap{0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = a.And(c) + } +} + +func BenchmarkToolBitmap_PopCount(b *testing.B) { + bm := ToolBitmap{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0xFFFFFFFF00000000, 0x00000000FFFFFFFF} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = bm.PopCount() + } +} + +func BenchmarkToolBitmap_Iterate130Bits(b *testing.B) { + // Simulate ~130 tools + var bm ToolBitmap + for i := 0; i < 130; i++ { + bm = bm.SetBit(i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + bm.Iterate(func(_ int) bool { + count++ + return true + }) + } +} diff --git a/pkg/inventory/tool_index.go b/pkg/inventory/tool_index.go new file mode 100644 index 000000000..a9716be4f --- /dev/null +++ b/pkg/inventory/tool_index.go @@ -0,0 +1,364 @@ +package inventory + +import ( + "context" + "sort" +) + +// ToolIndex provides O(1) bitmap-based filtering for tools. +// +// Instead of iterating through all tools and checking each filter condition, +// this pre-computes bitmaps for each filter dimension. Queries become fast +// bitmap AND/OR operations, and we only need to run dynamic Enabled() checks +// on the tools that survive static filtering. +// +// Memory: ~1-2KB for 130 tools with 15 toolsets and 10 feature flags. +// Query time: O(toolsets + features) bitmap ops + O(surviving tools) for dynamic checks. +type ToolIndex struct { + // tools stores the actual tool data, indexed by bit position + tools []ServerTool + + // toolPosition maps tool name to bitmap position for O(1) lookup + toolPosition map[string]int + + // allTools has all bits set for tools in the index + allTools ToolBitmap + + // Toolset indexes - each toolset maps to bitmap of tools in that toolset + byToolset map[ToolsetID]ToolBitmap + + // Read-only filtering + readOnlyTools ToolBitmap // tools with ReadOnlyHint=true + writeTools ToolBitmap // tools with ReadOnlyHint=false (write tools) + + // Feature flag indexes + // requiresFeature[flag] = tools that require this flag to be ON + requiresFeature map[string]ToolBitmap + // disabledByFeature[flag] = tools that are disabled when this flag is ON + disabledByFeature map[string]ToolBitmap + + // Dynamic check tracking + // hasDynamicCheck contains tools that have a non-nil Enabled function + hasDynamicCheck ToolBitmap +} + +// BuildToolIndex creates a ToolIndex from a slice of ServerTools. +// This should be called once at startup; the index is then reused for all queries. +func BuildToolIndex(tools []ServerTool) *ToolIndex { + idx := &ToolIndex{ + tools: make([]ServerTool, len(tools)), + toolPosition: make(map[string]int, len(tools)), + byToolset: make(map[ToolsetID]ToolBitmap), + requiresFeature: make(map[string]ToolBitmap), + disabledByFeature: make(map[string]ToolBitmap), + } + + // Sort tools for deterministic ordering (by toolset, then name) + sortedTools := make([]ServerTool, len(tools)) + copy(sortedTools, tools) + sort.Slice(sortedTools, func(i, j int) bool { + if sortedTools[i].Toolset.ID != sortedTools[j].Toolset.ID { + return sortedTools[i].Toolset.ID < sortedTools[j].Toolset.ID + } + return sortedTools[i].Tool.Name < sortedTools[j].Tool.Name + }) + + for i, tool := range sortedTools { + idx.tools[i] = tool + idx.toolPosition[tool.Tool.Name] = i + idx.allTools = idx.allTools.SetBit(i) + + // Index by toolset + idx.byToolset[tool.Toolset.ID] = idx.byToolset[tool.Toolset.ID].SetBit(i) + + // Index by read-only status + if tool.IsReadOnly() { + idx.readOnlyTools = idx.readOnlyTools.SetBit(i) + } else { + idx.writeTools = idx.writeTools.SetBit(i) + } + + // Index by feature flags + if tool.FeatureFlagEnable != "" { + idx.requiresFeature[tool.FeatureFlagEnable] = + idx.requiresFeature[tool.FeatureFlagEnable].SetBit(i) + } + if tool.FeatureFlagDisable != "" { + idx.disabledByFeature[tool.FeatureFlagDisable] = + idx.disabledByFeature[tool.FeatureFlagDisable].SetBit(i) + } + + // Track tools with dynamic checks + if tool.Enabled != nil { + idx.hasDynamicCheck = idx.hasDynamicCheck.SetBit(i) + } + } + + return idx +} + +// QueryConfig specifies the filter criteria for a tool query. +type QueryConfig struct { + // Toolset filtering + AllToolsets bool // if true, include all toolsets + EnabledToolsets []ToolsetID // specific toolsets to include + + // Additional tools to include (bypass toolset filter) + AdditionalTools []string + + // Read-only mode - if true, exclude write tools + ReadOnly bool + + // Feature flag states - map of flag name to enabled state + // Only flags that are explicitly set are considered + EnabledFeatures []string // features that are ON + DisabledFeatures []string // features that are OFF (explicit) +} + +// QueryResult contains the result of a tool query. +type QueryResult struct { + // Bitmap of tools that passed all static filters + StaticFiltered ToolBitmap + + // Bitmap of tools in StaticFiltered that need dynamic Enabled() checks + NeedsDynamicCheck ToolBitmap + + // Tools that passed static filters and have no dynamic check (immediately available) + Guaranteed ToolBitmap +} + +// Query executes a filter query and returns which tools match. +// This performs only bitmap operations - no iteration over tools. +// +// The result indicates: +// - StaticFiltered: all tools that passed static criteria +// - NeedsDynamicCheck: subset that requires runtime Enabled() evaluation +// - Guaranteed: subset that is definitely included (no dynamic check needed) +func (idx *ToolIndex) Query(cfg QueryConfig) QueryResult { + var result ToolBitmap + + // Step 1: Toolset filtering - O(|toolsets|) bitmap ORs + if cfg.AllToolsets { + result = idx.allTools + } else { + for _, ts := range cfg.EnabledToolsets { + if bm, ok := idx.byToolset[ts]; ok { + result = result.Or(bm) + } + } + } + + // Step 2: Add additional tools - O(|additional|) bit sets + for _, name := range cfg.AdditionalTools { + if pos, ok := idx.toolPosition[name]; ok { + result = result.SetBit(pos) + } + } + + // Step 3: Read-only filter - O(1) bitmap AND + if cfg.ReadOnly { + result = result.And(idx.readOnlyTools) + } + + // Step 4: Feature flag filtering - O(|features|) bitmap operations + // Remove tools that require a feature that's OFF + for _, flag := range cfg.DisabledFeatures { + if bm, ok := idx.requiresFeature[flag]; ok { + result = result.AndNot(bm) + } + } + // Remove tools that are disabled by a feature that's ON + for _, flag := range cfg.EnabledFeatures { + if bm, ok := idx.disabledByFeature[flag]; ok { + result = result.AndNot(bm) + } + } + + // For tools with FeatureFlagEnable that isn't in EnabledFeatures, filter them out + // (They require the flag, but it's not enabled) + enabledSet := make(map[string]bool, len(cfg.EnabledFeatures)) + for _, f := range cfg.EnabledFeatures { + enabledSet[f] = true + } + for flag, bm := range idx.requiresFeature { + if !enabledSet[flag] { + // This flag is required but not enabled, remove these tools + result = result.AndNot(bm) + } + } + + // Compute which tools need dynamic checks + needsDynamic := result.And(idx.hasDynamicCheck) + guaranteed := result.AndNot(idx.hasDynamicCheck) + + return QueryResult{ + StaticFiltered: result, + NeedsDynamicCheck: needsDynamic, + Guaranteed: guaranteed, + } +} + +// GetTool returns the tool at the given bitmap position. +func (idx *ToolIndex) GetTool(position int) *ServerTool { + if position < 0 || position >= len(idx.tools) { + return nil + } + return &idx.tools[position] +} + +// GetToolByName returns the tool with the given name and its position. +func (idx *ToolIndex) GetToolByName(name string) (*ServerTool, int, bool) { + if pos, ok := idx.toolPosition[name]; ok { + return &idx.tools[pos], pos, true + } + return nil, -1, false +} + +// Materialize converts a QueryResult into actual tools, running dynamic checks as needed. +// Only tools in NeedsDynamicCheck have their Enabled() function called. +// Returns pointers to the cached tools - callers should NOT modify them. +func (idx *ToolIndex) Materialize(ctx context.Context, qr QueryResult) []*ServerTool { + // Pre-allocate with capacity = guaranteed + potential dynamic + capacity := qr.Guaranteed.PopCount() + qr.NeedsDynamicCheck.PopCount() + result := make([]*ServerTool, 0, capacity) + + // Add all guaranteed tools (no dynamic check needed) + qr.Guaranteed.Iterate(func(pos int) bool { + result = append(result, &idx.tools[pos]) + return true + }) + + // Check and add tools that need dynamic evaluation + qr.NeedsDynamicCheck.Iterate(func(pos int) bool { + tool := &idx.tools[pos] + if tool.Enabled != nil { + enabled, err := tool.Enabled(ctx) + if err != nil || !enabled { + return true // skip this tool, continue iteration + } + } + result = append(result, tool) + return true + }) + + // Sort result for deterministic output + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// ToolsetBitmap returns the bitmap for a specific toolset. +func (idx *ToolIndex) ToolsetBitmap(id ToolsetID) ToolBitmap { + return idx.byToolset[id] +} + +// AllToolsBitmap returns a bitmap with all tools. +func (idx *ToolIndex) AllToolsBitmap() ToolBitmap { + return idx.allTools +} + +// ToolCount returns the number of tools in the index. +func (idx *ToolIndex) ToolCount() int { + return len(idx.tools) +} + +// DynamicCheckCount returns how many tools have dynamic Enabled() checks. +func (idx *ToolIndex) DynamicCheckCount() int { + return idx.hasDynamicCheck.PopCount() +} + +// ToolsetIDs returns all toolset IDs in the index. +func (idx *ToolIndex) ToolsetIDs() []ToolsetID { + ids := make([]ToolsetID, 0, len(idx.byToolset)) + for id := range idx.byToolset { + ids = append(ids, id) + } + sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) + return ids +} + +// UniqueFeatureFlags returns all unique feature flags referenced by tools in the index. +// This allows callers to check only the flags that matter, once per request. +func (idx *ToolIndex) UniqueFeatureFlags() []string { + seen := make(map[string]bool) + for flag := range idx.requiresFeature { + seen[flag] = true + } + for flag := range idx.disabledByFeature { + seen[flag] = true + } + + flags := make([]string, 0, len(seen)) + for flag := range seen { + flags = append(flags, flag) + } + sort.Strings(flags) + return flags +} + +// QueryWithFeatureChecker performs a query, automatically checking feature flags. +// Each unique flag is checked exactly once via the checker function. +// This minimizes feature flag service calls from O(tools) to O(unique flags). +func (idx *ToolIndex) QueryWithFeatureChecker(ctx context.Context, cfg QueryConfigWithChecker) QueryResult { + // Check each unique feature flag exactly once + var enabledFeatures, disabledFeatures []string + + for flag := range idx.requiresFeature { + enabled, err := cfg.FeatureChecker(ctx, flag) + if err != nil || !enabled { + disabledFeatures = append(disabledFeatures, flag) + } else { + enabledFeatures = append(enabledFeatures, flag) + } + } + + for flag := range idx.disabledByFeature { + // Only check if we haven't already + alreadyChecked := false + for _, f := range enabledFeatures { + if f == flag { + alreadyChecked = true + break + } + } + for _, f := range disabledFeatures { + if f == flag { + alreadyChecked = true + break + } + } + if !alreadyChecked { + enabled, err := cfg.FeatureChecker(ctx, flag) + if err != nil || !enabled { + disabledFeatures = append(disabledFeatures, flag) + } else { + enabledFeatures = append(enabledFeatures, flag) + } + } + } + + // Delegate to the standard Query with resolved flags + return idx.Query(QueryConfig{ + AllToolsets: cfg.AllToolsets, + EnabledToolsets: cfg.EnabledToolsets, + AdditionalTools: cfg.AdditionalTools, + ReadOnly: cfg.ReadOnly, + EnabledFeatures: enabledFeatures, + DisabledFeatures: disabledFeatures, + }) +} + +// QueryConfigWithChecker is like QueryConfig but uses a checker function +// instead of pre-resolved feature flag lists. +type QueryConfigWithChecker struct { + AllToolsets bool + EnabledToolsets []ToolsetID + AdditionalTools []string + ReadOnly bool + FeatureChecker FeatureFlagChecker // Reuses the existing FeatureFlagChecker type from filters.go +} diff --git a/pkg/inventory/tool_index_test.go b/pkg/inventory/tool_index_test.go new file mode 100644 index 000000000..b02612c0a --- /dev/null +++ b/pkg/inventory/tool_index_test.go @@ -0,0 +1,732 @@ +package inventory + +import ( + "context" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" +) + +// mockServerToolInToolset creates a mock tool with a specific toolset +func mockServerToolInToolset(name string, toolsetID ToolsetID, readOnly bool) ServerTool { + var annotations *mcp.ToolAnnotations + if readOnly { + annotations = &mcp.ToolAnnotations{ReadOnlyHint: true} + } + return ServerTool{ + Tool: mcp.Tool{ + Name: name, + Description: "Test tool: " + name, + Annotations: annotations, + }, + Toolset: ToolsetMetadata{ID: toolsetID}, + } +} + +// mockServerToolWithFeatureFlag creates a mock tool that requires a feature flag +func mockServerToolWithFeatureFlag(name string, toolsetID ToolsetID, enableFlag string, disableFlag string) ServerTool { + return ServerTool{ + Tool: mcp.Tool{ + Name: name, + Description: "Test tool: " + name, + }, + Toolset: ToolsetMetadata{ID: toolsetID}, + FeatureFlagEnable: enableFlag, + FeatureFlagDisable: disableFlag, + } +} + +// mockServerToolWithDynamicCheck creates a mock tool with a custom Enabled function +func mockServerToolWithDynamicCheck(name string, toolsetID ToolsetID, enabledFn func(context.Context) (bool, error)) ServerTool { + return ServerTool{ + Tool: mcp.Tool{ + Name: name, + Description: "Test tool: " + name, + }, + Toolset: ToolsetMetadata{ID: toolsetID}, + Enabled: enabledFn, + } +} + +func TestBuildToolIndex(t *testing.T) { + t.Parallel() + + // Create test tools in different toolsets + testTools := []ServerTool{ + mockServerToolInToolset("get_me", "users", true), + mockServerToolInToolset("list_issues", "issues", true), + mockServerToolInToolset("create_issue", "issues", false), + mockServerToolInToolset("list_pull_requests", "pull_requests", true), + mockServerToolInToolset("create_pull_request", "pull_requests", false), + } + + index := BuildToolIndex(testTools) + + assert.NotNil(t, index) + assert.Equal(t, 5, index.ToolCount()) + assert.Equal(t, 3, len(index.ToolsetIDs())) // users, issues, pull_requests +} + +func TestToolIndex_Query_AllToolsets(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("tool1", "set_a", true), + mockServerToolInToolset("tool2", "set_a", true), + mockServerToolInToolset("tool3", "set_b", true), + } + + index := BuildToolIndex(testTools) + + // Query for all toolsets + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"set_a", "set_b"}, + ReadOnly: false, + }) + + // All 3 tools should be in the result + assert.Equal(t, 3, result.Guaranteed.PopCount()) + assert.True(t, result.NeedsDynamicCheck.IsEmpty()) +} + +func TestToolIndex_Query_SingleToolset(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("tool1", "set_a", true), + mockServerToolInToolset("tool2", "set_a", true), + mockServerToolInToolset("tool3", "set_b", true), + } + + index := BuildToolIndex(testTools) + + // Query for only set_a + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"set_a"}, + ReadOnly: false, + }) + + // Only 2 tools should be in the result + assert.Equal(t, 2, result.Guaranteed.PopCount()) +} + +func TestToolIndex_Query_ReadOnlyMode(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("get_thing", "things", true), // read-only + mockServerToolInToolset("create_thing", "things", false), // write + mockServerToolInToolset("delete_thing", "things", false), // write + mockServerToolInToolset("list_things", "things", true), // read-only + } + + index := BuildToolIndex(testTools) + + // Query in read-only mode + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"things"}, + ReadOnly: true, + }) + + // Only read-only tools should be in the result + assert.Equal(t, 2, result.Guaranteed.PopCount()) + + // Materialize and verify + ctx := context.Background() + tools := index.Materialize(ctx, result) + + names := make([]string, len(tools)) + for i, tool := range tools { + names[i] = tool.Tool.Name + } + assert.Contains(t, names, "get_thing") + assert.Contains(t, names, "list_things") + assert.NotContains(t, names, "create_thing") + assert.NotContains(t, names, "delete_thing") +} + +func TestToolIndex_Query_FeatureFlags(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("basic_tool", "tools", true), + mockServerToolWithFeatureFlag("advanced_tool", "tools", "advanced_features", ""), + mockServerToolWithFeatureFlag("experimental_tool", "tools", "experimental_features", ""), + } + + index := BuildToolIndex(testTools) + + // Query with no features enabled + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"tools"}, + EnabledFeatures: []string{}, + ReadOnly: false, + }) + + // Only basic_tool should be guaranteed (advanced requires flag) + assert.Equal(t, 1, result.Guaranteed.PopCount()) + + // Materialize to verify + ctx := context.Background() + tools := index.Materialize(ctx, result) + + assert.Len(t, tools, 1) + assert.Equal(t, "basic_tool", tools[0].Tool.Name) + + // Query with advanced_features enabled + result = index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"tools"}, + EnabledFeatures: []string{"advanced_features"}, + ReadOnly: false, + }) + + assert.Equal(t, 2, result.Guaranteed.PopCount()) +} + +func TestToolIndex_Query_FeatureFlagDisables(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("standard_tool", "tools", true), + mockServerToolWithFeatureFlag("legacy_tool", "tools", "", "new_mode"), + } + + index := BuildToolIndex(testTools) + + // Query with new_mode OFF - legacy tool should be available + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"tools"}, + EnabledFeatures: []string{}, + ReadOnly: false, + }) + assert.Equal(t, 2, result.Guaranteed.PopCount()) + + // Query with new_mode ON - legacy tool should be disabled + result = index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"tools"}, + EnabledFeatures: []string{"new_mode"}, + ReadOnly: false, + }) + assert.Equal(t, 1, result.Guaranteed.PopCount()) + + ctx := context.Background() + tools := index.Materialize(ctx, result) + + assert.Len(t, tools, 1) + assert.Equal(t, "standard_tool", tools[0].Tool.Name) +} + +func TestToolIndex_Query_DynamicChecks(t *testing.T) { + t.Parallel() + + // Tool with dynamic Enabled check + dynamicTool := mockServerToolWithDynamicCheck("dynamic_tool", "tools", func(_ context.Context) (bool, error) { + return true, nil + }) + + testTools := []ServerTool{ + mockServerToolInToolset("static_tool", "tools", true), + dynamicTool, + } + + index := BuildToolIndex(testTools) + + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"tools"}, + ReadOnly: false, + }) + + // static_tool is guaranteed, dynamic_tool needs check + assert.Equal(t, 1, result.Guaranteed.PopCount()) + assert.Equal(t, 1, result.NeedsDynamicCheck.PopCount()) + assert.Equal(t, 2, result.StaticFiltered.PopCount()) +} + +func TestToolIndex_Query_DynamicChecksFilteredByToolset(t *testing.T) { + t.Parallel() + + // Key test: dynamic check tool should NOT appear in NeedsDynamicCheck + // if it's already filtered out by toolset + + dynamicTool := mockServerToolWithDynamicCheck("dynamic_tool", "set_b", func(_ context.Context) (bool, error) { + return true, nil + }) + + testTools := []ServerTool{ + mockServerToolInToolset("static_tool", "set_a", true), + dynamicTool, + } + + index := BuildToolIndex(testTools) + + // Query only for set_a - dynamic_tool is in set_b, so it's filtered out + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"set_a"}, + ReadOnly: false, + }) + + // static_tool is guaranteed + assert.Equal(t, 1, result.Guaranteed.PopCount()) + // dynamic_tool should NOT be in NeedsDynamicCheck because it's already filtered + assert.True(t, result.NeedsDynamicCheck.IsEmpty()) +} + +func TestToolIndex_Materialize_NoDynamicChecks(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("tool1", "all", true), + mockServerToolInToolset("tool2", "all", true), + mockServerToolInToolset("tool3", "all", true), + } + + index := BuildToolIndex(testTools) + + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"all"}, + ReadOnly: false, + }) + + ctx := context.Background() + materializedTools := index.Materialize(ctx, result) + + assert.Len(t, materializedTools, 3) + + // Verify tool names + names := make([]string, len(materializedTools)) + for i, tool := range materializedTools { + names[i] = tool.Tool.Name + } + assert.Contains(t, names, "tool1") + assert.Contains(t, names, "tool2") + assert.Contains(t, names, "tool3") +} + +func TestToolIndex_Materialize_WithDynamicChecks(t *testing.T) { + t.Parallel() + + enabledDynamic := mockServerToolWithDynamicCheck("enabled_dynamic", "all", func(_ context.Context) (bool, error) { + return true, nil + }) + + disabledDynamic := mockServerToolWithDynamicCheck("disabled_dynamic", "all", func(_ context.Context) (bool, error) { + return false, nil + }) + + testTools := []ServerTool{ + mockServerToolInToolset("static_tool", "all", true), + enabledDynamic, + disabledDynamic, + } + + index := BuildToolIndex(testTools) + + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"all"}, + ReadOnly: false, + }) + + ctx := context.Background() + materializedTools := index.Materialize(ctx, result) + + // Should have static_tool and enabled_dynamic (disabled_dynamic returns false) + assert.Len(t, materializedTools, 2) + + names := make([]string, len(materializedTools)) + for i, tool := range materializedTools { + names[i] = tool.Tool.Name + } + assert.Contains(t, names, "static_tool") + assert.Contains(t, names, "enabled_dynamic") + assert.NotContains(t, names, "disabled_dynamic") +} + +func TestToolIndex_Materialize_DynamicCheckError(t *testing.T) { + t.Parallel() + + errorTool := mockServerToolWithDynamicCheck("error_tool", "all", func(_ context.Context) (bool, error) { + return false, context.DeadlineExceeded + }) + + testTools := []ServerTool{ + mockServerToolInToolset("static_tool", "all", true), + errorTool, + } + + index := BuildToolIndex(testTools) + + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"all"}, + ReadOnly: false, + }) + + ctx := context.Background() + materializedTools := index.Materialize(ctx, result) + + // Current implementation skips on error - verify the result + // Only static_tool should be present + // The implementation silently skips errors + assert.Len(t, materializedTools, 1) + assert.Equal(t, "static_tool", materializedTools[0].Tool.Name) +} + +func TestToolIndex_Query_AllToolsetsFlag(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("tool1", "set_a", true), + mockServerToolInToolset("tool2", "set_b", true), + mockServerToolInToolset("tool3", "set_c", true), + } + + index := BuildToolIndex(testTools) + + // Query with AllToolsets flag + result := index.Query(QueryConfig{ + AllToolsets: true, + ReadOnly: false, + }) + + assert.Equal(t, 3, result.Guaranteed.PopCount()) +} + +func TestToolIndex_Query_AdditionalTools(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("tool1", "set_a", true), + mockServerToolInToolset("tool2", "set_b", true), + mockServerToolInToolset("special_tool", "set_c", true), + } + + index := BuildToolIndex(testTools) + + // Query for set_a only, but include special_tool via AdditionalTools + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"set_a"}, + AdditionalTools: []string{"special_tool"}, + ReadOnly: false, + }) + + // Should have tool1 (from set_a) and special_tool (from additional) + assert.Equal(t, 2, result.Guaranteed.PopCount()) + + ctx := context.Background() + tools := index.Materialize(ctx, result) + + names := make([]string, len(tools)) + for i, tool := range tools { + names[i] = tool.Tool.Name + } + assert.Contains(t, names, "tool1") + assert.Contains(t, names, "special_tool") + assert.NotContains(t, names, "tool2") +} + +func TestToolIndex_GetTool(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("alpha", "set", true), + mockServerToolInToolset("beta", "set", true), + } + + index := BuildToolIndex(testTools) + + // Get tool by position + tool := index.GetTool(0) + assert.NotNil(t, tool) + + // Out of bounds + tool = index.GetTool(-1) + assert.Nil(t, tool) + tool = index.GetTool(100) + assert.Nil(t, tool) +} + +func TestToolIndex_GetToolByName(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("alpha", "set", true), + mockServerToolInToolset("beta", "set", true), + } + + index := BuildToolIndex(testTools) + + // Get tool by name + tool, pos, ok := index.GetToolByName("alpha") + assert.True(t, ok) + assert.NotNil(t, tool) + assert.Equal(t, "alpha", tool.Tool.Name) + assert.GreaterOrEqual(t, pos, 0) + + // Non-existent + tool, pos, ok = index.GetToolByName("nonexistent") + assert.False(t, ok) + assert.Nil(t, tool) + assert.Equal(t, -1, pos) +} + +func TestToolIndex_ToolsetBitmap(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("tool1", "set_a", true), + mockServerToolInToolset("tool2", "set_a", true), + mockServerToolInToolset("tool3", "set_b", true), + } + + index := BuildToolIndex(testTools) + + bmA := index.ToolsetBitmap("set_a") + assert.Equal(t, 2, bmA.PopCount()) + + bmB := index.ToolsetBitmap("set_b") + assert.Equal(t, 1, bmB.PopCount()) + + // Non-existent toolset returns empty bitmap + bmX := index.ToolsetBitmap("nonexistent") + assert.True(t, bmX.IsEmpty()) +} + +func TestToolIndex_UniqueFeatureFlags(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("basic_tool", "tools", true), + mockServerToolWithFeatureFlag("advanced_tool", "tools", "feature_a", ""), + mockServerToolWithFeatureFlag("experimental_tool", "tools", "feature_b", ""), + mockServerToolWithFeatureFlag("legacy_tool", "tools", "", "feature_c"), + mockServerToolWithFeatureFlag("complex_tool", "tools", "feature_a", "feature_d"), // reuses feature_a + } + + index := BuildToolIndex(testTools) + + flags := index.UniqueFeatureFlags() + + // Should have 4 unique flags: feature_a, feature_b, feature_c, feature_d + assert.Len(t, flags, 4) + assert.Contains(t, flags, "feature_a") + assert.Contains(t, flags, "feature_b") + assert.Contains(t, flags, "feature_c") + assert.Contains(t, flags, "feature_d") +} + +func TestToolIndex_QueryWithFeatureChecker(t *testing.T) { + t.Parallel() + + testTools := []ServerTool{ + mockServerToolInToolset("basic_tool", "tools", true), + mockServerToolWithFeatureFlag("needs_feature_a", "tools", "feature_a", ""), + mockServerToolWithFeatureFlag("needs_feature_b", "tools", "feature_b", ""), + mockServerToolWithFeatureFlag("disabled_by_feature_c", "tools", "", "feature_c"), + } + + index := BuildToolIndex(testTools) + + // Track which flags were checked + checkedFlags := make(map[string]int) + + checker := func(_ context.Context, flag string) (bool, error) { + checkedFlags[flag]++ + // Enable feature_a, disable feature_b, enable feature_c + switch flag { + case "feature_a": + return true, nil + case "feature_b": + return false, nil + case "feature_c": + return true, nil // This will disable "disabled_by_feature_c" + default: + return false, nil + } + } + + ctx := context.Background() + result := index.QueryWithFeatureChecker(ctx, QueryConfigWithChecker{ + EnabledToolsets: []ToolsetID{"tools"}, + ReadOnly: false, + FeatureChecker: checker, + }) + + // Each flag should be checked exactly once + assert.Equal(t, 1, checkedFlags["feature_a"], "feature_a should be checked once") + assert.Equal(t, 1, checkedFlags["feature_b"], "feature_b should be checked once") + assert.Equal(t, 1, checkedFlags["feature_c"], "feature_c should be checked once") + + // Materialize and verify results + tools := index.Materialize(ctx, result) + names := make([]string, len(tools)) + for i, tool := range tools { + names[i] = tool.Tool.Name + } + + // basic_tool: no flags, included + // needs_feature_a: feature_a enabled, included + // needs_feature_b: feature_b disabled, excluded + // disabled_by_feature_c: feature_c enabled, excluded + assert.Len(t, tools, 2) + assert.Contains(t, names, "basic_tool") + assert.Contains(t, names, "needs_feature_a") + assert.NotContains(t, names, "needs_feature_b") + assert.NotContains(t, names, "disabled_by_feature_c") +} + +func TestToolIndex_QueryWithFeatureChecker_MinimizesChecks(t *testing.T) { + t.Parallel() + + // Create 50 tools that all use the same 3 feature flags + testTools := make([]ServerTool, 50) + for i := 0; i < 50; i++ { + flag := []string{"flag_a", "flag_b", "flag_c"}[i%3] + testTools[i] = mockServerToolWithFeatureFlag( + "tool_"+string(rune('a'+i%26)), + "tools", + flag, // All tools require one of 3 flags + "", + ) + } + + index := BuildToolIndex(testTools) + + checkCount := 0 + checker := func(_ context.Context, _ string) (bool, error) { + checkCount++ + return true, nil + } + + ctx := context.Background() + _ = index.QueryWithFeatureChecker(ctx, QueryConfigWithChecker{ + EnabledToolsets: []ToolsetID{"tools"}, + FeatureChecker: checker, + }) + + // Should only check 3 unique flags, not 50 tools + assert.Equal(t, 3, checkCount, "Should check each unique flag exactly once") +} + +func BenchmarkBuildToolIndex_130Tools(b *testing.B) { + // Create realistic toolset distribution + toolsets := []ToolsetID{"repos", "issues", "pull_requests", "users", "actions", "code_security", "projects", "notifications", "discussions", "experiments"} + + testTools := make([]ServerTool, 130) + for i := 0; i < 130; i++ { + toolset := toolsets[i%len(toolsets)] + readOnly := i%3 != 0 // 2/3 are read-only + testTools[i] = mockServerToolInToolset("tool_"+string(rune('a'+i%26))+string(rune('0'+i/26)), toolset, readOnly) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = BuildToolIndex(testTools) + } +} + +func BenchmarkToolIndex_Query_SmallConfig(b *testing.B) { + toolsets := []ToolsetID{"repos", "issues", "pull_requests", "users", "actions", "code_security", "projects", "notifications", "discussions", "experiments"} + + testTools := make([]ServerTool, 130) + for i := 0; i < 130; i++ { + toolset := toolsets[i%len(toolsets)] + readOnly := i%3 != 0 + testTools[i] = mockServerToolInToolset("tool_"+string(rune('a'+i%26))+string(rune('0'+i/26)), toolset, readOnly) + } + + index := BuildToolIndex(testTools) + + config := QueryConfig{ + EnabledToolsets: []ToolsetID{"repos", "issues"}, + ReadOnly: false, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = index.Query(config) + } +} + +func BenchmarkToolIndex_Query_AllToolsets(b *testing.B) { + toolsets := []ToolsetID{"repos", "issues", "pull_requests", "users", "actions", "code_security", "projects", "notifications", "discussions", "experiments"} + + testTools := make([]ServerTool, 130) + for i := 0; i < 130; i++ { + toolset := toolsets[i%len(toolsets)] + readOnly := i%3 != 0 + testTools[i] = mockServerToolInToolset("tool_"+string(rune('a'+i%26))+string(rune('0'+i/26)), toolset, readOnly) + } + + index := BuildToolIndex(testTools) + + config := QueryConfig{ + AllToolsets: true, + ReadOnly: true, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = index.Query(config) + } +} + +func BenchmarkToolIndex_Materialize_NoDynamic(b *testing.B) { + testTools := make([]ServerTool, 50) + for i := 0; i < 50; i++ { + testTools[i] = mockServerToolInToolset("tool_"+string(rune('a'+i%26)), "all", true) + } + + index := BuildToolIndex(testTools) + + result := index.Query(QueryConfig{ + EnabledToolsets: []ToolsetID{"all"}, + ReadOnly: false, + }) + + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = index.Materialize(ctx, result) + } +} + +func BenchmarkToolIndex_QueryAndMaterialize_Realistic(b *testing.B) { + // Simulate realistic scenario: ~130 tools, 10 toolsets, some with dynamic checks + toolsets := []ToolsetID{"repos", "issues", "pull_requests", "users", "actions", "code_security", "projects", "notifications", "discussions", "experiments"} + + testTools := make([]ServerTool, 130) + for i := 0; i < 130; i++ { + toolset := toolsets[i%len(toolsets)] + readOnly := i%3 != 0 + + // 10% have dynamic checks + if i%10 == 0 { + testTools[i] = mockServerToolWithDynamicCheck( + "tool_"+string(rune('a'+i%26))+string(rune('0'+i/26)), + toolset, + func(_ context.Context) (bool, error) { return true, nil }, + ) + if readOnly { + testTools[i].Tool.Annotations = &mcp.ToolAnnotations{ReadOnlyHint: true} + } + } else { + testTools[i] = mockServerToolInToolset("tool_"+string(rune('a'+i%26))+string(rune('0'+i/26)), toolset, readOnly) + } + } + + index := BuildToolIndex(testTools) + + config := QueryConfig{ + EnabledToolsets: []ToolsetID{"repos", "issues", "pull_requests", "actions"}, + ReadOnly: false, + } + + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + result := index.Query(config) + _ = index.Materialize(ctx, result) + } +} diff --git a/pkg/inventory/tool_variants.go b/pkg/inventory/tool_variants.go new file mode 100644 index 000000000..5b22f5465 --- /dev/null +++ b/pkg/inventory/tool_variants.go @@ -0,0 +1,111 @@ +package inventory + +import "context" + +// ToolOverride allows replacing a tool's definition based on runtime conditions. +// Use this for the small number of tools that have different schemas/handlers +// depending on features, capabilities, or environment. +// +// Example usage: +// +// // Define overrides for tools that have enterprise-specific variants +// overrides := ToolOverrides{ +// "create_issue": { +// ToolName: "create_issue", +// Condition: func(ctx context.Context) (bool, error) { +// // Check if enterprise features are enabled for this request +// return isEnterpriseEnabled(ctx), nil +// }, +// Override: ServerTool{ +// Tool: mcp.Tool{ +// Name: "create_issue", +// Description: "Create an issue (Enterprise)", +// InputSchema: enterpriseCreateIssueSchema, // has extra fields +// }, +// Handler: enterpriseCreateIssueHandler, // different handler +// }, +// }, +// } +// +// // Apply overrides after building the base tool list +// tools := index.Materialize(ctx, queryResult) +// tools = overrides.ApplyToTools(ctx, tools) +type ToolOverride struct { + // ToolName is the canonical tool name to override + ToolName string + + // Condition returns true if this override should apply + Condition func(ctx context.Context) (bool, error) + + // Override is the replacement tool definition + Override ServerTool +} + +// ToolOverrides is a simple map for the few tools that need variant handling. +// Key is the tool name, value is the override to check. +type ToolOverrides map[string]ToolOverride + +// Apply checks if an override should be used for the given tool. +// Returns the override if condition matches, nil otherwise. +func (o ToolOverrides) Apply(ctx context.Context, toolName string) *ServerTool { + override, ok := o[toolName] + if !ok { + return nil + } + + if override.Condition == nil { + return &override.Override + } + + matches, err := override.Condition(ctx) + if err != nil || !matches { + return nil + } + + return &override.Override +} + +// ApplyToTools applies overrides to a list of tools, returning a new list +// with overridden tools replaced. Tools without overrides are unchanged. +// If no overrides match, returns the original slice (no allocation). +func (o ToolOverrides) ApplyToTools(ctx context.Context, tools []*ServerTool) []*ServerTool { + if len(o) == 0 { + return tools + } + + // First pass: check if any overrides apply (avoid allocation if not) + var result []*ServerTool + for i, tool := range tools { + override, hasOverride := o[tool.Tool.Name] + if !hasOverride { + if result != nil { + result[i] = tool + } + continue + } + + // Check condition + var applies bool + if override.Condition == nil { + applies = true + } else if matches, err := override.Condition(ctx); err == nil && matches { + applies = true + } + + if applies { + // Lazy allocation only when we find a match + if result == nil { + result = make([]*ServerTool, len(tools)) + copy(result[:i], tools[:i]) + } + result[i] = &override.Override + } else if result != nil { + result[i] = tool + } + } + + if result == nil { + return tools // No overrides matched, return original + } + return result +} diff --git a/pkg/inventory/tool_variants_test.go b/pkg/inventory/tool_variants_test.go new file mode 100644 index 000000000..343b7497b --- /dev/null +++ b/pkg/inventory/tool_variants_test.go @@ -0,0 +1,185 @@ +package inventory + +import ( + "context" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" +) + +func makeTool(name string) ServerTool { + return ServerTool{ + Tool: mcp.Tool{ + Name: name, + Description: "Tool " + name, + }, + } +} + +func makeOverride(name, desc string) ServerTool { + return ServerTool{ + Tool: mcp.Tool{ + Name: name, + Description: desc, + }, + } +} + +func TestToolOverrides_Apply(t *testing.T) { + t.Parallel() + + overrides := ToolOverrides{ + "create_issue": { + ToolName: "create_issue", + Condition: func(_ context.Context) (bool, error) { return true, nil }, + Override: makeOverride("create_issue", "Enterprise variant"), + }, + } + + ctx := context.Background() + + // Tool with override + result := overrides.Apply(ctx, "create_issue") + assert.NotNil(t, result) + assert.Equal(t, "Enterprise variant", result.Tool.Description) + + // Tool without override + result = overrides.Apply(ctx, "list_repos") + assert.Nil(t, result) +} + +func TestToolOverrides_Apply_ConditionFalse(t *testing.T) { + t.Parallel() + + overrides := ToolOverrides{ + "create_issue": { + ToolName: "create_issue", + Condition: func(_ context.Context) (bool, error) { return false, nil }, + Override: makeOverride("create_issue", "Enterprise variant"), + }, + } + + ctx := context.Background() + + // Condition doesn't match - no override + result := overrides.Apply(ctx, "create_issue") + assert.Nil(t, result) +} + +func TestToolOverrides_Apply_NilCondition(t *testing.T) { + t.Parallel() + + overrides := ToolOverrides{ + "create_issue": { + ToolName: "create_issue", + // nil Condition - always applies + Override: makeOverride("create_issue", "Always applied"), + }, + } + + ctx := context.Background() + + result := overrides.Apply(ctx, "create_issue") + assert.NotNil(t, result) + assert.Equal(t, "Always applied", result.Tool.Description) +} + +func TestToolOverrides_ApplyToTools(t *testing.T) { + t.Parallel() + + tools := []*ServerTool{ + ptr(makeTool("create_issue")), + ptr(makeTool("list_repos")), + ptr(makeTool("get_me")), + } + + overrides := ToolOverrides{ + "create_issue": { + ToolName: "create_issue", + Condition: func(_ context.Context) (bool, error) { return true, nil }, + Override: makeOverride("create_issue", "Enterprise create_issue"), + }, + } + + ctx := context.Background() + result := overrides.ApplyToTools(ctx, tools) + + assert.Len(t, result, 3) + assert.Equal(t, "Enterprise create_issue", result[0].Tool.Description) + assert.Equal(t, "Tool list_repos", result[1].Tool.Description) + assert.Equal(t, "Tool get_me", result[2].Tool.Description) +} + +func TestToolOverrides_ApplyToTools_Empty(t *testing.T) { + t.Parallel() + + tools := []*ServerTool{ + ptr(makeTool("create_issue")), + } + + overrides := ToolOverrides{} + + ctx := context.Background() + result := overrides.ApplyToTools(ctx, tools) + + // Empty overrides returns original slice + assert.Equal(t, tools, result) +} + +func ptr(t ServerTool) *ServerTool { + return &t +} + +func BenchmarkToolOverrides_ApplyToTools(b *testing.B) { + // 130 tools, 2 overrides (realistic) + tools := make([]*ServerTool, 130) + for i := range tools { + tools[i] = ptr(makeTool("tool_" + string(rune('a'+i%26)))) + } + + overrides := ToolOverrides{ + "tool_a": { + ToolName: "tool_a", + Condition: func(_ context.Context) (bool, error) { return true, nil }, + Override: makeOverride("tool_a", "Override A"), + }, + "tool_b": { + ToolName: "tool_b", + Condition: func(_ context.Context) (bool, error) { return true, nil }, + Override: makeOverride("tool_b", "Override B"), + }, + } + + ctx := context.Background() + + b.ReportAllocs() // Only count allocs in the hot loop + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = overrides.ApplyToTools(ctx, tools) + } +} + +func BenchmarkToolOverrides_ApplyToTools_NoMatch(b *testing.B) { + // 130 tools, overrides don't match any - should be zero alloc + tools := make([]*ServerTool, 130) + for i := range tools { + tools[i] = ptr(makeTool("tool_" + string(rune('a'+i%26)))) + } + + // Override exists but for a tool not in list + overrides := ToolOverrides{ + "nonexistent_tool": { + ToolName: "nonexistent_tool", + Override: makeOverride("nonexistent_tool", "Override"), + }, + } + + ctx := context.Background() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = overrides.ApplyToTools(ctx, tools) + } +}