Skip to content

Commit 405adfc

Browse files
committed
extract environment stuff to separate functions
move oauth token into environement resolution exclude server url when system environments are used pass through all environment variables not specified by plan
1 parent 7905137 commit 405adfc

File tree

1 file changed

+157
-56
lines changed

1 file changed

+157
-56
lines changed

server/internal/mcp/rpc_tools_call.go

Lines changed: 157 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"errors"
99
"fmt"
1010
"log/slog"
11+
"maps"
1112
"mime"
1213
"net/http"
1314
"slices"
@@ -33,6 +34,7 @@ import (
3334
"github.com/speakeasy-api/gram/server/internal/rag"
3435
tm "github.com/speakeasy-api/gram/server/internal/thirdparty/toolmetrics"
3536
"github.com/speakeasy-api/gram/server/internal/toolsets"
37+
"github.com/speakeasy-api/gram/server/internal/urn"
3638
)
3739

3840
type toolsCallParams struct {
@@ -104,7 +106,6 @@ func handleToolsCall(
104106
}
105107

106108
toolsetHelpers := toolsets.NewToolsets(db)
107-
envSlug := payload.environment
108109
var tool *types.Tool
109110

110111
for _, t := range toolset.Tools {
@@ -124,53 +125,16 @@ func handleToolsCall(
124125
return nil, oops.E(oops.CodeUnexpected, err, "failed to get tool urn").Log(ctx, logger)
125126
}
126127

127-
ciEnv := gateway.NewCaseInsensitiveEnv()
128-
129-
envRepo := repo.New(db)
130-
sourceEnv, err := envRepo.GetEnvironmentForSource(ctx, repo.GetEnvironmentForSourceParams{
131-
SourceKind: string(toolURN.Kind),
132-
SourceSlug: toolURN.Source,
133-
})
134-
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
135-
return nil, oops.E(oops.CodeUnexpected, err, "failed to get environment from source").Log(ctx, logger)
136-
}
137-
138-
if err == nil {
139-
sourceEnvVars, err := env.Load(ctx, payload.projectID, gateway.ID(sourceEnv.ID))
140-
if err != nil && !errors.Is(err, gateway.ErrNotFound) {
141-
return nil, oops.E(oops.CodeUnexpected, err, "failed to load source environment variables").Log(ctx, logger)
142-
}
143-
144-
for k, v := range sourceEnvVars {
145-
ciEnv.Set(k, v)
146-
}
147-
}
148-
149-
// IMPORTANT: MCP servers accessed in a public manner or not gram authenticated, there is no concept of using stored environments for them
150-
if envSlug != "" && payload.authenticated {
151-
storedEnvVars, err := env.Load(ctx, payload.projectID, gateway.Slug(envSlug))
152-
switch {
153-
case errors.Is(err, gateway.ErrNotFound):
154-
return nil, oops.E(oops.CodeBadRequest, err, "environment not found").Log(ctx, logger)
155-
case err != nil:
156-
return nil, oops.E(oops.CodeUnexpected, err, "failed to load environment").Log(ctx, logger)
157-
}
158-
159-
for k, v := range storedEnvVars {
160-
ciEnv.Set(k, v)
161-
}
162-
}
163-
164-
// user supplied variables comes after stored environment variables to allow overrides in the case of conflicts
165-
for k, v := range payload.mcpEnvVariables {
166-
ciEnv.Set(k, v)
167-
}
168-
169128
plan, err := toolsetHelpers.GetToolCallPlanByURN(ctx, *toolURN, uuid.UUID(projectID))
170129
if err != nil {
171130
return nil, oops.E(oops.CodeUnexpected, err, "failed get tool call plan").Log(ctx, logger)
172131
}
173132

133+
ciEnv, err := resolveEnvironment(ctx, logger, db, env, toolURN, uuid.UUID(projectID), payload, plan)
134+
if err != nil {
135+
return nil, err
136+
}
137+
174138
descriptor := plan.Descriptor
175139
var toolType tm.ToolType
176140
switch plan.Kind {
@@ -200,19 +164,6 @@ func handleToolsCall(
200164
}
201165

202166
ctx, logger = o11y.EnrichToolCallContext(ctx, logger, descriptor.OrganizationSlug, descriptor.ProjectSlug)
203-
if plan.Kind == gateway.ToolKindHTTP {
204-
for _, security := range plan.HTTP.Security {
205-
for _, token := range payload.oauthTokenInputs {
206-
if (slices.Contains(security.OAuthTypes, "authorization_code") || security.Type.Value == "openIdConnect") && (len(token.securityKeys) == 0 || slices.Contains(token.securityKeys, security.Key)) {
207-
for _, envVar := range security.EnvVariables {
208-
if strings.HasSuffix(envVar, "ACCESS_TOKEN") {
209-
ciEnv.Set(envVar, token.Token)
210-
}
211-
}
212-
}
213-
}
214-
}
215-
}
216167

217168
rw := &toolCallResponseWriter{
218169
headers: make(http.Header),
@@ -321,6 +272,156 @@ func handleToolsCall(
321272
return bs, nil
322273
}
323274

275+
func resolveEnvironment(
276+
ctx context.Context,
277+
logger *slog.Logger,
278+
db *pgxpool.Pool,
279+
env gateway.EnvironmentLoader,
280+
toolURN *urn.Tool,
281+
projectID uuid.UUID,
282+
payload *mcpInputs,
283+
plan *gateway.ToolCallPlan,
284+
) (*gateway.CaseInsensitiveEnv, error) {
285+
systemVars, err := resolveSystemVariables(ctx, logger, db, env, toolURN, projectID)
286+
if err != nil {
287+
return nil, err
288+
}
289+
290+
userConfig, err := resolveUserConfiguration(ctx, logger, env, payload)
291+
if err != nil {
292+
return nil, err
293+
}
294+
295+
// IMPORTANT: when we receive any system environment variables, we _always_ disallow passing
296+
// through a user-supplied server URL. System environment variables should be invisible to users
297+
// and allowing them to pass in URL would allow them to exfiltrate those variables to their own servers.
298+
allowServerUrl := len(systemVars) == 0
299+
filteredUserConfig := filterUserConfiguration(userConfig, systemVars, plan, allowServerUrl)
300+
301+
ciEnv := gateway.NewCaseInsensitiveEnv()
302+
for k, v := range systemVars {
303+
ciEnv.Set(k, v)
304+
}
305+
for k, v := range filteredUserConfig {
306+
ciEnv.Set(k, v)
307+
}
308+
309+
if plan.Kind == gateway.ToolKindHTTP {
310+
for _, security := range plan.HTTP.Security {
311+
for _, token := range payload.oauthTokenInputs {
312+
if (slices.Contains(security.OAuthTypes, "authorization_code") || security.Type.Value == "openIdConnect") && (len(token.securityKeys) == 0 || slices.Contains(token.securityKeys, security.Key)) {
313+
for _, envVar := range security.EnvVariables {
314+
if strings.HasSuffix(envVar, "ACCESS_TOKEN") {
315+
ciEnv.Set(envVar, token.Token)
316+
}
317+
}
318+
}
319+
}
320+
}
321+
}
322+
323+
return ciEnv, nil
324+
}
325+
326+
func resolveSystemVariables(
327+
ctx context.Context,
328+
logger *slog.Logger,
329+
db *pgxpool.Pool,
330+
env gateway.EnvironmentLoader,
331+
toolURN *urn.Tool,
332+
projectID uuid.UUID,
333+
) (map[string]string, error) {
334+
envRepo := repo.New(db)
335+
sourceEnv, err := envRepo.GetEnvironmentForSource(ctx, repo.GetEnvironmentForSourceParams{
336+
SourceKind: string(toolURN.Kind),
337+
SourceSlug: toolURN.Source,
338+
ProjectID: projectID,
339+
})
340+
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
341+
return nil, oops.E(oops.CodeUnexpected, err, "failed to get environment from source").Log(ctx, logger)
342+
}
343+
344+
if errors.Is(err, pgx.ErrNoRows) {
345+
return map[string]string{}, nil
346+
}
347+
348+
sourceEnvVars, err := env.Load(ctx, projectID, gateway.ID(sourceEnv.ID))
349+
if err != nil && !errors.Is(err, gateway.ErrNotFound) {
350+
return nil, oops.E(oops.CodeUnexpected, err, "failed to load source environment variables").Log(ctx, logger)
351+
}
352+
353+
return sourceEnvVars, nil
354+
}
355+
356+
func resolveUserConfiguration(
357+
ctx context.Context,
358+
logger *slog.Logger,
359+
env gateway.EnvironmentLoader,
360+
payload *mcpInputs,
361+
) (map[string]string, error) {
362+
userConfig := make(map[string]string)
363+
364+
// IMPORTANT: MCP servers accessed in a public manner or not gram authenticated, there is no concept of using stored environments for them
365+
if payload.environment != "" && payload.authenticated {
366+
storedEnvVars, err := env.Load(ctx, payload.projectID, gateway.Slug(payload.environment))
367+
switch {
368+
case errors.Is(err, gateway.ErrNotFound):
369+
return nil, oops.E(oops.CodeBadRequest, err, "environment not found").Log(ctx, logger)
370+
case err != nil:
371+
return nil, oops.E(oops.CodeUnexpected, err, "failed to load environment").Log(ctx, logger)
372+
}
373+
374+
maps.Copy(userConfig, storedEnvVars)
375+
}
376+
377+
maps.Copy(userConfig, payload.mcpEnvVariables)
378+
379+
return userConfig, nil
380+
}
381+
382+
func filterUserConfiguration(userConfig map[string]string, systemVars map[string]string, plan *gateway.ToolCallPlan, allowServerUrl bool) map[string]string {
383+
filtered := make(map[string]string)
384+
allowedByPlan := make(map[string]bool)
385+
386+
switch plan.Kind {
387+
case gateway.ToolKindFunction:
388+
if plan.Function != nil {
389+
for _, varName := range plan.Function.Variables {
390+
allowedByPlan[varName] = true
391+
}
392+
}
393+
394+
case gateway.ToolKindHTTP:
395+
if plan.HTTP != nil {
396+
for _, security := range plan.HTTP.Security {
397+
for _, envVar := range security.EnvVariables {
398+
allowedByPlan[envVar] = true
399+
}
400+
}
401+
402+
if allowServerUrl && plan.HTTP.ServerEnvVar != "" {
403+
allowedByPlan[plan.HTTP.ServerEnvVar] = true
404+
}
405+
}
406+
407+
case gateway.ToolKindPrompt:
408+
return map[string]string{}
409+
}
410+
411+
for key, value := range userConfig {
412+
_, isSystemVar := systemVars[key]
413+
_, isAllowedByPlan := allowedByPlan[key]
414+
415+
if isSystemVar && !isAllowedByPlan {
416+
continue
417+
}
418+
419+
filtered[key] = value
420+
}
421+
422+
return filtered
423+
}
424+
324425
func checkToolUsageLimits(ctx context.Context, logger *slog.Logger, orgID string, accountType string, billingRepository billing.Repository) error {
325426
if accountType != string(billing.TierFree) {
326427
return nil

0 commit comments

Comments
 (0)