Skip to content

Commit b5da475

Browse files
committed
pass variables through in http tool calls
1 parent 04a9758 commit b5da475

File tree

2 files changed

+43
-19
lines changed

2 files changed

+43
-19
lines changed

server/internal/gateway/proxy.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,10 @@ func (tp *ToolProxy) doHTTP(
359359
}
360360
}
361361

362-
// environment variable overrides on tool calls typically defined in the SDK
362+
// We prefer tool call specified arguments over user-specified config
363363
if toolCallBody.EnvironmentVariables != nil {
364364
for k, v := range toolCallBody.EnvironmentVariables {
365-
env.Set(k, v)
365+
env.UserConfig.Set(k, v)
366366
}
367367
}
368368

@@ -845,9 +845,15 @@ func reverseProxyRequest(ctx context.Context, opts ReverseProxyOptions) error {
845845
return nil
846846
}
847847

848-
func processServerEnvVars(ctx context.Context, logger *slog.Logger, tool *HTTPToolCallPlan, envVars *CaseInsensitiveEnv) string {
848+
func processServerEnvVars(ctx context.Context, logger *slog.Logger, tool *HTTPToolCallPlan, env ToolCallEnv) string {
849+
// IMPORTANT: when system environment variables exist, we _always_ disallow user-supplied
850+
// server URLs to prevent exfiltration of system environment variables to user-controlled servers.
851+
if len(env.SystemEnv.All()) > 0 {
852+
return ""
853+
}
854+
849855
if tool.ServerEnvVar != "" {
850-
envVar := envVars.Get(tool.ServerEnvVar)
856+
envVar := env.UserConfig.Get(tool.ServerEnvVar)
851857
if envVar != "" {
852858
return envVar
853859
} else {

server/internal/gateway/security.go

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ func processSecurity(
3333
serverURL string,
3434
toolCallLogger tm.ToolCallLogger,
3535
) bool {
36+
// Merge user config into system config: system env is base, user config fills gaps
37+
mergedEnv := NewCaseInsensitiveEnv()
38+
for k, v := range env.SystemEnv.All() {
39+
mergedEnv.Set(k, v)
40+
}
41+
for k, v := range env.UserConfig.All() {
42+
if mergedEnv.Get(k) == "" {
43+
mergedEnv.Set(k, v)
44+
}
45+
}
46+
3647
securityHeadersProcessed := make(map[string]string)
3748
setHeader := func(key, value string) {
3849
req.Header.Set(key, value)
@@ -52,18 +63,18 @@ func processSecurity(
5263
case "apiKey":
5364
if len(security.EnvVariables) == 0 {
5465
logger.ErrorContext(ctx, "no environment variables provided for api key auth", attr.SlogSecurityScheme(security.Scheme.Value))
55-
} else if envVars.Get(security.EnvVariables[0]) == "" {
66+
} else if mergedEnv.Get(security.EnvVariables[0]) == "" {
5667
logger.ErrorContext(ctx, "missing value for environment variable in api key auth", attr.SlogEnvVarName(security.EnvVariables[0]), attr.SlogSecurityScheme(security.Scheme.Value))
5768
} else if !security.Name.Valid || security.Name.Value == "" {
5869
logger.ErrorContext(ctx, "no name provided for api key auth", attr.SlogSecurityScheme(security.Scheme.Value))
5970
} else {
6071
key := security.EnvVariables[0]
6172
switch security.Placement.Value {
6273
case "header":
63-
setHeader(security.Name.Value, envVars.Get(key))
74+
setHeader(security.Name.Value, mergedEnv.Get(key))
6475
case "query":
6576
values := req.URL.Query()
66-
values.Set(security.Name.Value, envVars.Get(key))
77+
values.Set(security.Name.Value, mergedEnv.Get(key))
6778
req.URL.RawQuery = values.Encode()
6879
default:
6980
logger.ErrorContext(ctx, "unsupported api key placement", attr.SlogSecurityPlacement(security.Placement.Value))
@@ -74,10 +85,10 @@ func processSecurity(
7485
case "bearer":
7586
if len(security.EnvVariables) == 0 {
7687
logger.ErrorContext(ctx, "no environment variables provided for bearer auth", attr.SlogSecurityScheme(security.Scheme.Value))
77-
} else if envVars.Get(security.EnvVariables[0]) == "" {
88+
} else if mergedEnv.Get(security.EnvVariables[0]) == "" {
7889
logger.ErrorContext(ctx, "token value is empty for bearer auth", attr.SlogEnvVarName(security.EnvVariables[0]), attr.SlogSecurityScheme(security.Scheme.Value))
7990
} else {
80-
token := envVars.Get(security.EnvVariables[0])
91+
token := mergedEnv.Get(security.EnvVariables[0])
8192
setHeader("Authorization", formatForBearer(token))
8293
}
8394
case "basic":
@@ -87,9 +98,9 @@ func processSecurity(
8798
var username, password string
8899
for _, envVar := range security.EnvVariables {
89100
if strings.Contains(envVar, "USERNAME") {
90-
username = envVars.Get(envVar)
101+
username = mergedEnv.Get(envVar)
91102
} else if strings.Contains(envVar, "PASSWORD") {
92-
password = envVars.Get(envVar)
103+
password = mergedEnv.Get(envVar)
93104
}
94105
}
95106

@@ -110,7 +121,7 @@ func processSecurity(
110121
case "openIdConnect":
111122
for _, envVar := range security.EnvVariables {
112123
if strings.Contains(envVar, "ACCESS_TOKEN") {
113-
if token := envVars.Get(envVar); token == "" {
124+
if token := mergedEnv.Get(envVar); token == "" {
114125
logger.ErrorContext(ctx, "missing authorization code", attr.SlogEnvVarName(envVar))
115126
} else {
116127
setHeader("Authorization", formatForBearer(token))
@@ -127,15 +138,15 @@ func processSecurity(
127138
case "authorization_code", "implicit":
128139
for _, envVar := range security.EnvVariables {
129140
if strings.Contains(envVar, "ACCESS_TOKEN") {
130-
if token := envVars.Get(envVar); token == "" {
141+
if token := mergedEnv.Get(envVar); token == "" {
131142
logger.ErrorContext(ctx, "missing authorization code", attr.SlogEnvVarName(envVar))
132143
} else {
133144
setHeader("Authorization", formatForBearer(token))
134145
}
135146
}
136147
}
137148
case "client_credentials":
138-
token, err := processClientCredentials(ctx, logger, req, cacheImpl, tool, plan.SecurityScopes, security, envVars, serverURL)
149+
token, err := processClientCredentials(ctx, logger, req, cacheImpl, tool, plan.SecurityScopes, security, mergedEnv, serverURL)
139150
if err != nil {
140151
logger.ErrorContext(ctx, "could not process client credentials", attr.SlogError(err))
141152
if strings.Contains(err.Error(), "failed to make client credentials token request") {
@@ -161,6 +172,13 @@ func processSecurity(
161172
}
162173
}
163174

175+
for key, value := range env.SystemEnv.All() {
176+
canonicalKey := http.CanonicalHeaderKey(key)
177+
if _, alreadyProcessed := securityHeadersProcessed[canonicalKey]; !alreadyProcessed {
178+
req.Header.Set(key, value)
179+
}
180+
}
181+
164182
return true
165183
}
166184

@@ -225,21 +243,21 @@ type clientCredentialsTokenResponseCamelCase struct {
225243
ExpiresIn int `json:"expiresIn"`
226244
}
227245

228-
func processClientCredentials(ctx context.Context, logger *slog.Logger, req *http.Request, cacheImpl cache.Cache, tool *ToolDescriptor, planScopes map[string][]string, security *HTTPToolSecurity, envVars *CaseInsensitiveEnv, serverURL string) (string, error) {
246+
func processClientCredentials(ctx context.Context, logger *slog.Logger, req *http.Request, cacheImpl cache.Cache, tool *ToolDescriptor, planScopes map[string][]string, security *HTTPToolSecurity, mergedEnv *CaseInsensitiveEnv, serverURL string) (string, error) {
229247
// To discuss, currently we are taking the approach of exact scope match for reused tokens
230248
// We could look into enabling a prefix match feature for caches where we return multiple entries matching the projectID, clientID, tokenURL and then check scopes against all returned values
231249
// We would want to make sure any underlying cache implementation supports this feature
232250
tokenCache := cache.NewTypedObjectCache[clientCredentialsTokenCache](logger.With(attr.SlogCacheNamespace("client_credentials_token_cache")), cacheImpl, cache.SuffixNone)
233251
var clientSecret, clientID, tokenURLOverride, accessToken string
234252
for _, v := range security.EnvVariables {
235253
if strings.Contains(v, "CLIENT_SECRET") {
236-
clientSecret = envVars.Get(v)
254+
clientSecret = mergedEnv.Get(v)
237255
} else if strings.Contains(v, "CLIENT_ID") {
238-
clientID = envVars.Get(v)
256+
clientID = mergedEnv.Get(v)
239257
} else if strings.Contains(v, "TOKEN_URL") {
240-
tokenURLOverride = envVars.Get(v)
258+
tokenURLOverride = mergedEnv.Get(v)
241259
} else if strings.Contains(v, "ACCESS_TOKEN") {
242-
accessToken = envVars.Get(v)
260+
accessToken = mergedEnv.Get(v)
243261
}
244262
}
245263

0 commit comments

Comments
 (0)