Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/api/toolsets.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func NewToolCallResult(content string, err error) *ToolCallResult {
type ToolHandlerParams struct {
context.Context
*internalk8s.Kubernetes
internalk8s.ManagerProvider
ToolCallRequest
ListOutput output.Output
}
Expand Down
11 changes: 11 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ import (
"github.com/BurntSushi/toml"
)

const (
ClusterProviderKubeConfig = "kubeconfig"
ClusterProviderInCluster = "in-cluster"
)

// StaticConfig is the configuration for the server.
// It allows to configure server specific settings and tools to be enabled or disabled.
type StaticConfig struct {
Expand Down Expand Up @@ -49,6 +54,12 @@ type StaticConfig struct {
StsScopes []string `toml:"sts_scopes,omitempty"`
CertificateAuthority string `toml:"certificate_authority,omitempty"`
ServerURL string `toml:"server_url,omitempty"`
// ClusterProviderStrategy is how the server finds clusters.
// If set to "kubeconfig", the clusters will be loaded from those in the kubeconfig.
// If set to "in-cluster", the server will use the in cluster config
ClusterProviderStrategy string `toml:"cluster_provider_strategy,omitempty"`
// ClusterContexts is which context should be used for each cluster
ClusterContexts map[string]string `toml:"cluster_contexts"`
}

func Default() *StaticConfig {
Expand Down
53 changes: 49 additions & 4 deletions pkg/http/authorization.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package http

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

Expand All @@ -20,7 +23,44 @@ import (

type KubernetesApiTokenVerifier interface {
// KubernetesApiVerifyToken TODO: clarify proper implementation
KubernetesApiVerifyToken(ctx context.Context, token, audience string) (*authenticationapiv1.UserInfo, []string, error)
KubernetesApiVerifyToken(ctx context.Context, token, audience, cluster string) (*authenticationapiv1.UserInfo, []string, error)
// GetTargetParameterName returns the parameter name used for target identification in MCP requests
GetTargetParameterName() string
}

// extractTargetFromRequest extracts cluster parameter from MCP request body
func extractTargetFromRequest(r *http.Request, targetName string) (string, error) {
if r.Body == nil {
return "", nil
}

// Read the body
body, err := io.ReadAll(r.Body)
if err != nil {
return "", err
}

// Restore the body for downstream handlers
r.Body = io.NopCloser(bytes.NewBuffer(body))

// Parse the MCP request
var mcpRequest struct {
Params struct {
Arguments map[string]interface{} `json:"arguments"`
} `json:"params"`
}

if err := json.Unmarshal(body, &mcpRequest); err != nil {
// If we can't parse the request, just return empty cluster (will use default)
return "", nil
}

// Extract target parameter
if cluster, ok := mcpRequest.Params.Arguments[targetName].(string); ok {
return cluster, nil
}

return "", nil
}

// AuthorizationMiddleware validates the OAuth flow for protected resources.
Expand Down Expand Up @@ -128,7 +168,12 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi
}
// Kubernetes API Server TokenReview validation
if err == nil && staticConfig.ValidateToken {
err = claims.ValidateWithKubernetesApi(r.Context(), staticConfig.OAuthAudience, verifier)
targetParameterName := verifier.GetTargetParameterName()
cluster, clusterErr := extractTargetFromRequest(r, targetParameterName)
if clusterErr != nil {
klog.V(2).Infof("Failed to extract cluster from request, using default: %v", clusterErr)
}
err = claims.ValidateWithKubernetesApi(r.Context(), staticConfig.OAuthAudience, cluster, verifier)
}
if err != nil {
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
Expand Down Expand Up @@ -198,9 +243,9 @@ func (c *JWTClaims) ValidateWithProvider(ctx context.Context, audience string, p
return nil
}

func (c *JWTClaims) ValidateWithKubernetesApi(ctx context.Context, audience string, verifier KubernetesApiTokenVerifier) error {
func (c *JWTClaims) ValidateWithKubernetesApi(ctx context.Context, audience, cluster string, verifier KubernetesApiTokenVerifier) error {
if verifier != nil {
_, _, err := verifier.KubernetesApiVerifyToken(ctx, c.Token, audience)
_, _, err := verifier.KubernetesApiVerifyToken(ctx, c.Token, audience, cluster)
if err != nil {
return fmt.Errorf("kubernetes API token validation error: %v", err)
}
Expand Down
60 changes: 38 additions & 22 deletions pkg/http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func TestHealthCheck(t *testing.T) {
})
})
// Health exposed even when require Authorization
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.HttpAddress))
if err != nil {
t.Fatalf("Failed to get health check endpoint with OAuth: %v", err)
Expand All @@ -313,7 +313,7 @@ func TestWellKnownReverseProxy(t *testing.T) {
".well-known/openid-configuration",
}
// With No Authorization URL configured
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
for _, path := range cases {
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
t.Cleanup(func() { _ = resp.Body.Close() })
Expand All @@ -333,7 +333,12 @@ func TestWellKnownReverseProxy(t *testing.T) {
_, _ = w.Write([]byte(`NOT A JSON PAYLOAD`))
}))
t.Cleanup(invalidPayloadServer.Close)
invalidPayloadConfig := &config.StaticConfig{AuthorizationURL: invalidPayloadServer.URL, RequireOAuth: true, ValidateToken: true}
invalidPayloadConfig := &config.StaticConfig{
AuthorizationURL: invalidPayloadServer.URL,
RequireOAuth: true,
ValidateToken: true,
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
}
testCaseWithContext(t, &httpContext{StaticConfig: invalidPayloadConfig}, func(ctx *httpContext) {
for _, path := range cases {
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
Expand All @@ -358,7 +363,12 @@ func TestWellKnownReverseProxy(t *testing.T) {
_, _ = w.Write([]byte(`{"issuer": "https://example.com","scopes_supported":["mcp-server"]}`))
}))
t.Cleanup(testServer.Close)
staticConfig := &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}
staticConfig := &config.StaticConfig{
AuthorizationURL: testServer.URL,
RequireOAuth: true,
ValidateToken: true,
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
}
testCaseWithContext(t, &httpContext{StaticConfig: staticConfig}, func(ctx *httpContext) {
for _, path := range cases {
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
Expand Down Expand Up @@ -401,7 +411,12 @@ func TestWellKnownOverrides(t *testing.T) {
}`))
}))
t.Cleanup(testServer.Close)
baseConfig := config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}
baseConfig := config.StaticConfig{
AuthorizationURL: testServer.URL,
RequireOAuth: true,
ValidateToken: true,
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
}
// With Dynamic Client Registration disabled
disableDynamicRegistrationConfig := baseConfig
disableDynamicRegistrationConfig.DisableDynamicClientRegistration = true
Expand Down Expand Up @@ -488,7 +503,7 @@ func TestMiddlewareLogging(t *testing.T) {

func TestAuthorizationUnauthorized(t *testing.T) {
// Missing Authorization header
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
Expand All @@ -513,7 +528,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
})
// Authorization header without Bearer prefix
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand All @@ -538,7 +553,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
})
// Invalid Authorization header
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand Down Expand Up @@ -569,7 +584,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
})
// Expired Authorization Bearer token
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand Down Expand Up @@ -600,7 +615,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
})
// Invalid audience claim Bearer token
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "expected-audience", ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "expected-audience", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand Down Expand Up @@ -633,7 +648,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
// Failed OIDC validation
oidcTestServer := NewOidcTestServer(t)
t.Cleanup(oidcTestServer.Close)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand Down Expand Up @@ -670,7 +685,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
"aud": "mcp-server"
}`
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand Down Expand Up @@ -703,7 +718,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
}

func TestAuthorizationRequireOAuthFalse(t *testing.T) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
Expand All @@ -728,7 +743,7 @@ func TestAuthorizationRawToken(t *testing.T) {
{"mcp-server", true}, // Audience set, validation enabled
}
for _, c := range cases {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: c.audience, ValidateToken: c.validateToken}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: c.audience, ValidateToken: c.validateToken, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
tokenReviewed := false
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
Expand Down Expand Up @@ -777,7 +792,7 @@ func TestAuthorizationOidcToken(t *testing.T) {
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
cases := []bool{false, true}
for _, validateToken := range cases {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
tokenReviewed := false
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
Expand Down Expand Up @@ -833,13 +848,14 @@ func TestAuthorizationOidcTokenExchange(t *testing.T) {
cases := []bool{false, true}
for _, validateToken := range cases {
staticConfig := &config.StaticConfig{
RequireOAuth: true,
OAuthAudience: "mcp-server",
ValidateToken: validateToken,
StsClientId: "test-sts-client-id",
StsClientSecret: "test-sts-client-secret",
StsAudience: "backend-audience",
StsScopes: []string{"backend-scope"},
RequireOAuth: true,
OAuthAudience: "mcp-server",
ValidateToken: validateToken,
StsClientId: "test-sts-client-id",
StsClientSecret: "test-sts-client-secret",
StsAudience: "backend-audience",
StsScopes: []string{"backend-scope"},
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
}
testCaseWithContext(t, &httpContext{StaticConfig: staticConfig, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
tokenReviewed := false
Expand Down
Loading