Skip to content

Commit 9f61990

Browse files
committed
modify schema for environments in tool call proxy
move ci env switch environment parameter for tool proxy
1 parent 405adfc commit 9f61990

File tree

5 files changed

+70
-31
lines changed

5 files changed

+70
-31
lines changed

server/internal/environments/shared.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,31 @@ func (e *EnvironmentEntries) Load(ctx context.Context, projectID uuid.UUID, envI
6666
return envMap, nil
6767
}
6868

69+
func (e *EnvironmentEntries) LoadSourceEnv(ctx context.Context, projectID uuid.UUID, sourceKind string, sourceSlug string) (map[string]string, error) {
70+
sourceEnv, err := e.repo.GetEnvironmentForSource(ctx, repo.GetEnvironmentForSourceParams{
71+
SourceKind: sourceKind,
72+
SourceSlug: sourceSlug,
73+
ProjectID: projectID,
74+
})
75+
if err != nil {
76+
if errors.Is(err, sql.ErrNoRows) {
77+
return map[string]string{}, nil
78+
}
79+
return nil, fmt.Errorf("get environment for source: %w", err)
80+
}
81+
82+
entries, err := e.ListEnvironmentEntries(ctx, projectID, sourceEnv.ID, false)
83+
if err != nil {
84+
return nil, fmt.Errorf("list environment entries: %w", err)
85+
}
86+
87+
envMap := make(map[string]string, len(entries))
88+
for _, entry := range entries {
89+
envMap[entry.Name] = entry.Value
90+
}
91+
return envMap, nil
92+
}
93+
6994
func (e *EnvironmentEntries) ListEnvironmentEntries(ctx context.Context, projectID uuid.UUID, environmentID uuid.UUID, redacted bool) ([]repo.EnvironmentEntry, error) {
7095
entries, err := e.repo.ListEnvironmentEntries(ctx, repo.ListEnvironmentEntriesParams{
7196
ProjectID: projectID,

server/internal/gateway/env.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package gateway
33
import (
44
"context"
55
"fmt"
6+
"strings"
67

78
"github.com/google/uuid"
89
)
@@ -14,6 +15,13 @@ type EnvironmentLoader interface {
1415
// * [ErrNotFound]: when the environment does not exist.
1516
// * `error`: when an unrecognized error occurs.
1617
Load(ctx context.Context, projectID uuid.UUID, environmentID SlugOrID) (map[string]string, error)
18+
19+
// LoadSourceEnv retrieves the environment variables associated with a tool's source.
20+
// Returns an empty map if no source environment exists.
21+
//
22+
// # Errors
23+
// * `error`: when an unrecognized error occurs.
24+
LoadSourceEnv(ctx context.Context, projectID uuid.UUID, sourceKind string, sourceSlug string) (map[string]string, error)
1725
}
1826

1927
type SlugOrID struct {
@@ -51,3 +59,32 @@ func (s *SlugOrID) String() string {
5159
func (s *SlugOrID) IsEmpty() bool {
5260
return s.ID == uuid.Nil && s.Slug == ""
5361
}
62+
63+
type CaseInsensitiveEnv struct {
64+
data map[string]string
65+
}
66+
67+
func NewCaseInsensitiveEnv() *CaseInsensitiveEnv {
68+
return &CaseInsensitiveEnv{data: make(map[string]string)}
69+
}
70+
71+
func CIEnvFrom(vars map[string]string) *CaseInsensitiveEnv {
72+
env := NewCaseInsensitiveEnv()
73+
for k, v := range vars {
74+
env.Set(k, v)
75+
}
76+
return env
77+
}
78+
79+
func (c *CaseInsensitiveEnv) Get(key string) string {
80+
return c.data[strings.ToLower(key)]
81+
}
82+
83+
func (c *CaseInsensitiveEnv) Set(key, value string) {
84+
c.data[strings.ToLower(key)] = value
85+
}
86+
87+
type ToolCallEnv struct {
88+
SystemEnv *CaseInsensitiveEnv
89+
UserConfig *CaseInsensitiveEnv
90+
}

server/internal/gateway/proxy.go

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -74,26 +74,6 @@ type ToolCallBody struct {
7474
GramRequestSummary string `json:"gram-request-summary"`
7575
}
7676

77-
// CaseInsensitiveEnv provides case-insensitive environment variable lookup.
78-
type CaseInsensitiveEnv struct {
79-
data map[string]string
80-
}
81-
82-
// NewCaseInsensitiveEnv creates a new empty case-insensitive environment.
83-
func NewCaseInsensitiveEnv() *CaseInsensitiveEnv {
84-
return &CaseInsensitiveEnv{data: make(map[string]string)}
85-
}
86-
87-
// Get retrieves an environment variable value by key (case-insensitive).
88-
func (c *CaseInsensitiveEnv) Get(key string) string {
89-
return c.data[strings.ToLower(key)]
90-
}
91-
92-
// Set sets an environment variable value by key (case-insensitive).
93-
func (c *CaseInsensitiveEnv) Set(key, value string) {
94-
c.data[strings.ToLower(key)] = value
95-
}
96-
9777
type toolcallErrorSchema struct {
9878
Error string `json:"error"`
9979
}
@@ -146,7 +126,7 @@ func (tp *ToolProxy) Do(
146126
ctx context.Context,
147127
w http.ResponseWriter,
148128
requestBody io.Reader,
149-
env *CaseInsensitiveEnv,
129+
env ToolCallEnv,
150130
plan *ToolCallPlan,
151131
toolCallLogger tm.ToolCallLogger,
152132
) (err error) {
@@ -172,10 +152,6 @@ func (tp *ToolProxy) Do(
172152
attr.SlogToolCallSource(string(tp.source)),
173153
)
174154

175-
if env == nil {
176-
env = NewCaseInsensitiveEnv()
177-
}
178-
179155
switch plan.Kind {
180156
case "":
181157
return oops.E(oops.CodeInvariantViolation, nil, "tool kind is not set").Log(ctx, tp.logger)
@@ -195,7 +171,7 @@ func (tp *ToolProxy) doFunction(
195171
logger *slog.Logger,
196172
w http.ResponseWriter,
197173
requestBody io.Reader,
198-
env *CaseInsensitiveEnv,
174+
env ToolCallEnv,
199175
descriptor *ToolDescriptor,
200176
plan *FunctionToolCallPlan,
201177
toolCallLogger tm.ToolCallLogger,
@@ -307,7 +283,7 @@ func (tp *ToolProxy) doHTTP(
307283
logger *slog.Logger,
308284
w http.ResponseWriter,
309285
requestBody io.Reader,
310-
env *CaseInsensitiveEnv,
286+
env ToolCallEnv,
311287
descriptor *ToolDescriptor,
312288
plan *HTTPToolCallPlan,
313289
toolCallLogger tm.ToolCallLogger,
@@ -570,7 +546,7 @@ type promptGetParams struct {
570546
Arguments map[string]any `json:"arguments"`
571547
}
572548

573-
func (tp *ToolProxy) doPrompt(ctx context.Context, logger *slog.Logger, w http.ResponseWriter, requestBody io.Reader, env *CaseInsensitiveEnv, descriptor *ToolDescriptor, plan *PromptToolCallPlan) error {
549+
func (tp *ToolProxy) doPrompt(ctx context.Context, logger *slog.Logger, w http.ResponseWriter, requestBody io.Reader, env ToolCallEnv, descriptor *ToolDescriptor, plan *PromptToolCallPlan) error {
574550
var params promptGetParams
575551
if err := json.NewDecoder(requestBody).Decode(&params); err != nil {
576552
return oops.E(oops.CodeBadRequest, err, "failed to parse get prompt request").Log(ctx, logger)
@@ -622,7 +598,7 @@ func retryWithBackoff(
622598
if err != nil {
623599
continue
624600
}
625-
601+
// check if we should retry based on method and status code
626602
if !slices.Contains(retryBackoff.methods, resp.Request.Method) || !slices.Contains(retryBackoff.statusCodes, resp.StatusCode) {
627603
return resp, err
628604
}

server/internal/gateway/security.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func processSecurity(
2929
tool *ToolDescriptor,
3030
plan *HTTPToolCallPlan,
3131
cacheImpl cache.Cache,
32-
envVars *CaseInsensitiveEnv,
32+
env ToolCallEnv,
3333
serverURL string,
3434
toolCallLogger tm.ToolCallLogger,
3535
) bool {

server/internal/mcp/rpc_tools_call.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,8 @@ func resolveUserConfiguration(
361361
) (map[string]string, error) {
362362
userConfig := make(map[string]string)
363363

364-
// IMPORTANT: MCP servers accessed in a public manner or not gram authenticated, there is no concept of using stored environments for them
364+
// IMPORTANT: we must only attach gram environments to authenticated payloads. Gram environments contain
365+
// secrets owned by Gram projects and should not be usable by public clients
365366
if payload.environment != "" && payload.authenticated {
366367
storedEnvVars, err := env.Load(ctx, payload.projectID, gateway.Slug(payload.environment))
367368
switch {

0 commit comments

Comments
 (0)