88 "errors"
99 "fmt"
1010 "log/slog"
11+ "maps"
1112 "mime"
1213 "net/http"
1314 "slices"
@@ -33,6 +34,7 @@ import (
3334 "github.com/speakeasy-api/gram/server/internal/rag"
3435 tm "github.com/speakeasy-api/gram/server/internal/thirdparty/toolmetrics"
3536 "github.com/speakeasy-api/gram/server/internal/toolsets"
37+ "github.com/speakeasy-api/gram/server/internal/urn"
3638)
3739
3840type toolsCallParams struct {
@@ -104,7 +106,6 @@ func handleToolsCall(
104106 }
105107
106108 toolsetHelpers := toolsets .NewToolsets (db )
107- envSlug := payload .environment
108109 var tool * types.Tool
109110
110111 for _ , t := range toolset .Tools {
@@ -124,53 +125,16 @@ func handleToolsCall(
124125 return nil , oops .E (oops .CodeUnexpected , err , "failed to get tool urn" ).Log (ctx , logger )
125126 }
126127
127- ciEnv := gateway .NewCaseInsensitiveEnv ()
128-
129- envRepo := repo .New (db )
130- sourceEnv , err := envRepo .GetEnvironmentForSource (ctx , repo.GetEnvironmentForSourceParams {
131- SourceKind : string (toolURN .Kind ),
132- SourceSlug : toolURN .Source ,
133- })
134- if err != nil && ! errors .Is (err , pgx .ErrNoRows ) {
135- return nil , oops .E (oops .CodeUnexpected , err , "failed to get environment from source" ).Log (ctx , logger )
136- }
137-
138- if err == nil {
139- sourceEnvVars , err := env .Load (ctx , payload .projectID , gateway .ID (sourceEnv .ID ))
140- if err != nil && ! errors .Is (err , gateway .ErrNotFound ) {
141- return nil , oops .E (oops .CodeUnexpected , err , "failed to load source environment variables" ).Log (ctx , logger )
142- }
143-
144- for k , v := range sourceEnvVars {
145- ciEnv .Set (k , v )
146- }
147- }
148-
149- // IMPORTANT: MCP servers accessed in a public manner or not gram authenticated, there is no concept of using stored environments for them
150- if envSlug != "" && payload .authenticated {
151- storedEnvVars , err := env .Load (ctx , payload .projectID , gateway .Slug (envSlug ))
152- switch {
153- case errors .Is (err , gateway .ErrNotFound ):
154- return nil , oops .E (oops .CodeBadRequest , err , "environment not found" ).Log (ctx , logger )
155- case err != nil :
156- return nil , oops .E (oops .CodeUnexpected , err , "failed to load environment" ).Log (ctx , logger )
157- }
158-
159- for k , v := range storedEnvVars {
160- ciEnv .Set (k , v )
161- }
162- }
163-
164- // user supplied variables comes after stored environment variables to allow overrides in the case of conflicts
165- for k , v := range payload .mcpEnvVariables {
166- ciEnv .Set (k , v )
167- }
168-
169128 plan , err := toolsetHelpers .GetToolCallPlanByURN (ctx , * toolURN , uuid .UUID (projectID ))
170129 if err != nil {
171130 return nil , oops .E (oops .CodeUnexpected , err , "failed get tool call plan" ).Log (ctx , logger )
172131 }
173132
133+ ciEnv , err := resolveEnvironment (ctx , logger , db , env , toolURN , uuid .UUID (projectID ), payload , plan )
134+ if err != nil {
135+ return nil , err
136+ }
137+
174138 descriptor := plan .Descriptor
175139 var toolType tm.ToolType
176140 switch plan .Kind {
@@ -200,19 +164,6 @@ func handleToolsCall(
200164 }
201165
202166 ctx , logger = o11y .EnrichToolCallContext (ctx , logger , descriptor .OrganizationSlug , descriptor .ProjectSlug )
203- if plan .Kind == gateway .ToolKindHTTP {
204- for _ , security := range plan .HTTP .Security {
205- for _ , token := range payload .oauthTokenInputs {
206- if (slices .Contains (security .OAuthTypes , "authorization_code" ) || security .Type .Value == "openIdConnect" ) && (len (token .securityKeys ) == 0 || slices .Contains (token .securityKeys , security .Key )) {
207- for _ , envVar := range security .EnvVariables {
208- if strings .HasSuffix (envVar , "ACCESS_TOKEN" ) {
209- ciEnv .Set (envVar , token .Token )
210- }
211- }
212- }
213- }
214- }
215- }
216167
217168 rw := & toolCallResponseWriter {
218169 headers : make (http.Header ),
@@ -321,6 +272,156 @@ func handleToolsCall(
321272 return bs , nil
322273}
323274
275+ func resolveEnvironment (
276+ ctx context.Context ,
277+ logger * slog.Logger ,
278+ db * pgxpool.Pool ,
279+ env gateway.EnvironmentLoader ,
280+ toolURN * urn.Tool ,
281+ projectID uuid.UUID ,
282+ payload * mcpInputs ,
283+ plan * gateway.ToolCallPlan ,
284+ ) (* gateway.CaseInsensitiveEnv , error ) {
285+ systemVars , err := resolveSystemVariables (ctx , logger , db , env , toolURN , projectID )
286+ if err != nil {
287+ return nil , err
288+ }
289+
290+ userConfig , err := resolveUserConfiguration (ctx , logger , env , payload )
291+ if err != nil {
292+ return nil , err
293+ }
294+
295+ // IMPORTANT: when we receive any system environment variables, we _always_ disallow passing
296+ // through a user-supplied server URL. System environment variables should be invisible to users
297+ // and allowing them to pass in URL would allow them to exfiltrate those variables to their own servers.
298+ allowServerUrl := len (systemVars ) == 0
299+ filteredUserConfig := filterUserConfiguration (userConfig , systemVars , plan , allowServerUrl )
300+
301+ ciEnv := gateway .NewCaseInsensitiveEnv ()
302+ for k , v := range systemVars {
303+ ciEnv .Set (k , v )
304+ }
305+ for k , v := range filteredUserConfig {
306+ ciEnv .Set (k , v )
307+ }
308+
309+ if plan .Kind == gateway .ToolKindHTTP {
310+ for _ , security := range plan .HTTP .Security {
311+ for _ , token := range payload .oauthTokenInputs {
312+ if (slices .Contains (security .OAuthTypes , "authorization_code" ) || security .Type .Value == "openIdConnect" ) && (len (token .securityKeys ) == 0 || slices .Contains (token .securityKeys , security .Key )) {
313+ for _ , envVar := range security .EnvVariables {
314+ if strings .HasSuffix (envVar , "ACCESS_TOKEN" ) {
315+ ciEnv .Set (envVar , token .Token )
316+ }
317+ }
318+ }
319+ }
320+ }
321+ }
322+
323+ return ciEnv , nil
324+ }
325+
326+ func resolveSystemVariables (
327+ ctx context.Context ,
328+ logger * slog.Logger ,
329+ db * pgxpool.Pool ,
330+ env gateway.EnvironmentLoader ,
331+ toolURN * urn.Tool ,
332+ projectID uuid.UUID ,
333+ ) (map [string ]string , error ) {
334+ envRepo := repo .New (db )
335+ sourceEnv , err := envRepo .GetEnvironmentForSource (ctx , repo.GetEnvironmentForSourceParams {
336+ SourceKind : string (toolURN .Kind ),
337+ SourceSlug : toolURN .Source ,
338+ ProjectID : projectID ,
339+ })
340+ if err != nil && ! errors .Is (err , pgx .ErrNoRows ) {
341+ return nil , oops .E (oops .CodeUnexpected , err , "failed to get environment from source" ).Log (ctx , logger )
342+ }
343+
344+ if errors .Is (err , pgx .ErrNoRows ) {
345+ return map [string ]string {}, nil
346+ }
347+
348+ sourceEnvVars , err := env .Load (ctx , projectID , gateway .ID (sourceEnv .ID ))
349+ if err != nil && ! errors .Is (err , gateway .ErrNotFound ) {
350+ return nil , oops .E (oops .CodeUnexpected , err , "failed to load source environment variables" ).Log (ctx , logger )
351+ }
352+
353+ return sourceEnvVars , nil
354+ }
355+
356+ func resolveUserConfiguration (
357+ ctx context.Context ,
358+ logger * slog.Logger ,
359+ env gateway.EnvironmentLoader ,
360+ payload * mcpInputs ,
361+ ) (map [string ]string , error ) {
362+ userConfig := make (map [string ]string )
363+
364+ // IMPORTANT: MCP servers accessed in a public manner or not gram authenticated, there is no concept of using stored environments for them
365+ if payload .environment != "" && payload .authenticated {
366+ storedEnvVars , err := env .Load (ctx , payload .projectID , gateway .Slug (payload .environment ))
367+ switch {
368+ case errors .Is (err , gateway .ErrNotFound ):
369+ return nil , oops .E (oops .CodeBadRequest , err , "environment not found" ).Log (ctx , logger )
370+ case err != nil :
371+ return nil , oops .E (oops .CodeUnexpected , err , "failed to load environment" ).Log (ctx , logger )
372+ }
373+
374+ maps .Copy (userConfig , storedEnvVars )
375+ }
376+
377+ maps .Copy (userConfig , payload .mcpEnvVariables )
378+
379+ return userConfig , nil
380+ }
381+
382+ func filterUserConfiguration (userConfig map [string ]string , systemVars map [string ]string , plan * gateway.ToolCallPlan , allowServerUrl bool ) map [string ]string {
383+ filtered := make (map [string ]string )
384+ allowedByPlan := make (map [string ]bool )
385+
386+ switch plan .Kind {
387+ case gateway .ToolKindFunction :
388+ if plan .Function != nil {
389+ for _ , varName := range plan .Function .Variables {
390+ allowedByPlan [varName ] = true
391+ }
392+ }
393+
394+ case gateway .ToolKindHTTP :
395+ if plan .HTTP != nil {
396+ for _ , security := range plan .HTTP .Security {
397+ for _ , envVar := range security .EnvVariables {
398+ allowedByPlan [envVar ] = true
399+ }
400+ }
401+
402+ if allowServerUrl && plan .HTTP .ServerEnvVar != "" {
403+ allowedByPlan [plan .HTTP .ServerEnvVar ] = true
404+ }
405+ }
406+
407+ case gateway .ToolKindPrompt :
408+ return map [string ]string {}
409+ }
410+
411+ for key , value := range userConfig {
412+ _ , isSystemVar := systemVars [key ]
413+ _ , isAllowedByPlan := allowedByPlan [key ]
414+
415+ if isSystemVar && ! isAllowedByPlan {
416+ continue
417+ }
418+
419+ filtered [key ] = value
420+ }
421+
422+ return filtered
423+ }
424+
324425func checkToolUsageLimits (ctx context.Context , logger * slog.Logger , orgID string , accountType string , billingRepository billing.Repository ) error {
325426 if accountType != string (billing .TierFree ) {
326427 return nil
0 commit comments