From 2ade55b26c6e9246e4557b576f0acb43888391ca Mon Sep 17 00:00:00 2001
From: Alex Cohen
Date: Tue, 7 Oct 2025 12:44:15 -0400
Subject: [PATCH 1/2] Add LiteLLM Wrapper support (AST-113982)

---
 example/main.go                |  57 ++++++++++------
 internal/genaiProxyInternal.go | 119 ---------------------------------
 internal/gpt.go                |  16 ++---
 internal/litellm_wrapper.go    | 110 ++++++++++++++++++++++++++++++
 pkg/wrapper/litellm_wrapper.go |  54 +++++++++++++++
 5 files changed, 210 insertions(+), 146 deletions(-)
 delete mode 100644 internal/genaiProxyInternal.go
 create mode 100644 internal/litellm_wrapper.go
 create mode 100644 pkg/wrapper/litellm_wrapper.go

diff --git a/example/main.go b/example/main.go
index c634cdf..1d6d094 100644
--- a/example/main.go
+++ b/example/main.go
@@ -3,14 +3,15 @@ package main
 import (
 	"flag"
 	"fmt"
-	"github.com/Checkmarx/gen-ai-wrapper/pkg/connector"
+	"os"
+	"strings"
+
+	"github.com/Checkmarx/gen-ai-wrapper/internal"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/message"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/models"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/role"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/wrapper"
 	"github.com/google/uuid"
-	"os"
-	"strings"
 )
 
 const usage = `
@@ -23,9 +24,9 @@
 Options
 	-s, --system		system (or developer) prompt string
 	-u, --user		user prompt string
 	-id			chat conversation ID
-	-ai			AI server to use. Options: {OpenAI (default), CxOne}
-	-m, --model		model to use. Options: {gpt-4o (default), gpt-4, o1, o1-mini, ...}
+	-ai			AI server to use. Options: {OpenAI (default), CxOne, LiteLLM}
+	-m, --model		model to use. Options: {gpt-4o (default), gpt-4, o1, o1-mini, claude-3-5-sonnet-20241022, ...}
 	-f, --full-response	return full response from AI
 	-h, --help		show help
`
@@ -103,26 +103,46 @@ func CallAIandPrintResponse(aiServer, model, systemPrompt, userPrompt string, ch
 		return err
 	}
 
-	statefulWrapper, err := wrapper.NewStatefulWrapperNew(
-		connector.NewFileSystemConnector(""), aiEndpoint, aiKey, model, 4, 0)
+	var litellmWrapper wrapper.LitellmWrapper
+
+	// Use the LiteLLM wrapper when the LiteLLM server is selected
+	if strings.EqualFold(aiServer, "LiteLLM") {
+		litellmWrapper, err = wrapper.NewLitellmWrapper(aiEndpoint, aiKey, model)
+	} else {
+		// Other AI servers are not wired up in this example yet
+		return fmt.Errorf("unsupported AI server: %s", aiServer)
+	}
+
 	if err != nil {
 		return fmt.Errorf("error creating '%s' AI client: %v", aiServer, err)
 	}
 
 	newMessages := GetMessages(model, systemPrompt, userPrompt)
 
+	// Build the request metadata; these fields are forwarded as HTTP headers
+	metaData := &message.MetaData{
+		RequestID: "example-request-" + chatId.String(),
+		TenantID:  "default-tenant",
+		UserAgent: "gen-ai-wrapper-example",
+		Feature:   "chat-completion",
+	}
+
+	// Create the request
+	request := &internal.ChatCompletionRequest{
+		Model:    model,
+		Messages: newMessages,
+	}
+
+	// Make the call
+	response, err := litellmWrapper.Call(aiKey, metaData, request)
+	if err != nil {
+		return fmt.Errorf("error calling litellm: %v", err)
+	}
+
 	if fullResponse {
-		response, err := statefulWrapper.SecureCallReturningFullResponse("", nil, chatId, newMessages)
-		if err != nil {
-			return fmt.Errorf("error calling GPT: %v", err)
-		}
 		fmt.Printf("%+v\n", response)
 	} else {
-		response, err := statefulWrapper.Call(chatId, newMessages)
-		if err != nil {
-			return fmt.Errorf("error calling GPT: %v", err)
-		}
-		fmt.Println(getMessageContents(response))
+		fmt.Println(response.Choices[0].Message.Content)
 	}
 	return nil
 }
@@ -156,7 +176,7 @@ func getAIAccessKey(aiServer, model string) (string, error) {
 		}
 		return accessKey, nil
 	}
-	if strings.EqualFold(aiServer, "CxOne") {
+	if strings.EqualFold(aiServer, "CxOne") || strings.EqualFold(aiServer, "LiteLLM") {
 		accessKey, err := GetCxOneAIAccessKey()
 		if err != nil {
 			return "", fmt.Errorf("error getting CxOne AI API key: %v", err)
 		}
@@ -174,7 +194,7 @@ func getAIEndpoint(aiServer string) (string, error) {
 		}
 		return aiEndpoint, nil
 	}
-	if strings.EqualFold(aiServer, "CxOne") {
+	if strings.EqualFold(aiServer, "CxOne") || strings.EqualFold(aiServer, "LiteLLM") {
 		aiEndpoint, err := GetCxOneAIEndpoint()
 		if err != nil {
 			return "", fmt.Errorf("error getting CxOne AI endpoint: %v", err)
 		}
diff --git a/internal/genaiProxyInternal.go b/internal/genaiProxyInternal.go
deleted file mode 100644
index d6308ee..0000000
--- a/internal/genaiProxyInternal.go
+++ /dev/null
@@ -1,119 +0,0 @@
-package internal
-
-import (
-	"bytes"
-	"context"
-	"encoding/json"
-	"errors"
-	"io"
-	"net/http"
-
-	"github.com/Checkmarx/gen-ai-wrapper/internal/api/redirect_prompt"
-	"github.com/Checkmarx/gen-ai-wrapper/pkg/message"
-	"github.com/Checkmarx/gen-ai-wrapper/pkg/models"
-	"google.golang.org/grpc"
-	"google.golang.org/grpc/credentials/insecure"
-)
-
-type WrapperInternalImpl struct {
-	connection    *grpc.ClientConn
-	client        redirect_prompt.AiProxyServiceClient
-	dropLen       int
-	setupMessages []message.Message
-}
-
-func NewWrapperInternalImpl(endPoint string, dropLen int) (Wrapper, error) {
-	connection, err := grpc.NewClient(endPoint, grpc.WithTransportCredentials(insecure.NewCredentials()))
-	if err != nil {
-		return nil, err
-	}
-	client := redirect_prompt.NewAiProxyServiceClient(connection)
-	return &WrapperInternalImpl{
-		connection: connection,
-		client:     client,
-		dropLen:    dropLen,
-	}, nil
-}
-
-func (w *WrapperInternalImpl) SetupCall(messages []message.Message) {
-	w.setupMessages = messages
-}
-
-func (w *WrapperInternalImpl) Call(_ string, metaData *message.MetaData, request *ChatCompletionRequest) (*ChatCompletionResponse, error) {
-	if w.setupMessages != nil {
-		//true for GPT4
-		if request.Model == models.GPT4 {
-			request.Messages = append(w.setupMessages, request.Messages...)
-		} else {
-			userIndex := findLastUserIndex(request.Messages)
-			front := request.Messages[:userIndex]
-			back := request.Messages[userIndex:]
-			request.Messages = append(front, w.setupMessages...)
-			request.Messages = append(request.Messages, back...)
- } - } - - req, err := w.prepareRequest(metaData, request) - if err != nil { - return nil, err - } - - resp, err := w.client.RedirectPrompt(context.Background(), req) - if err != nil { - return nil, err - } - return w.handleGptResponse(metaData, request, resp) -} - -func (w *WrapperInternalImpl) prepareRequest(metaData *message.MetaData, requestBody *ChatCompletionRequest) (*redirect_prompt.RedirectPromptRequest, error) { - jsonData, err := json.Marshal(requestBody) - if err != nil { - return nil, err - } - if metaData == nil { - return nil, errors.New("metadata is nil") - } - req := &redirect_prompt.RedirectPromptRequest{ - Content: jsonData, - Tenant: metaData.TenantID, - RequestId: metaData.RequestID, - Origin: metaData.UserAgent, - Feature: metaData.Feature, - } - return req, nil -} - -func (w *WrapperInternalImpl) handleGptResponse(metaData *message.MetaData, requestBody *ChatCompletionRequest, resp *redirect_prompt.RedirectPromptResponse) (*ChatCompletionResponse, error) { - var err error - bodyBytes, err := io.ReadAll(bytes.NewBuffer(resp.Content)) - if err != nil { - return nil, err - } - if resp.GenAiErrorCode == http.StatusOK { - var responseBody = new(ChatCompletionResponse) - err = json.Unmarshal(bodyBytes, responseBody) - if err != nil { - return nil, err - } - return responseBody, nil - } - var errorResponse = new(ErrorResponse) - err = json.Unmarshal(bodyBytes, errorResponse) - if err != nil { - return nil, err - } - switch resp.GenAiErrorCode { - case http.StatusBadRequest: - if errorResponse.Error.Code == errorCodeMaxTokens { - return w.Call("", metaData, &ChatCompletionRequest{ - Model: requestBody.Model, - Messages: requestBody.Messages[w.dropLen:], - }) - } - } - return nil, fromResponse(int(resp.GenAiErrorCode), errorResponse) -} - -func (w *WrapperInternalImpl) Close() error { - return w.connection.Close() -} diff --git a/internal/gpt.go b/internal/gpt.go index 3e4f952..93b5489 100644 --- a/internal/gpt.go +++ b/internal/gpt.go @@ -3,9 +3,9 @@ package internal import ( "errors" "fmt" + "github.com/Checkmarx/gen-ai-wrapper/pkg/message" "github.com/Checkmarx/gen-ai-wrapper/pkg/role" - "net/url" ) // const gptByAzure = "https://cxgpt4.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2023-05-15" @@ -51,14 +51,12 @@ type Wrapper interface { } func NewWrapperFactory(endPoint, apiKey string, dropLen int) (Wrapper, error) { - endPointURL, err := url.Parse(endPoint) - if err != nil { - return nil, err - } - if endPointURL.Scheme == "http" || endPointURL.Scheme == "https" { - return NewWrapperImpl(endPoint, apiKey, dropLen), nil - } - return NewWrapperInternalImpl(endPoint, dropLen) + return NewWrapperImpl(endPoint, apiKey, dropLen), nil +} + +// NewLitellmWrapperFactory creates a new litellm wrapper factory +func NewLitellmWrapperFactory(endPoint, apiKey string, dropLen int) (Wrapper, error) { + return NewLitellmWrapper(endPoint, apiKey, dropLen), nil } func fromResponse(statusCode int, e *ErrorResponse) error { diff --git a/internal/litellm_wrapper.go b/internal/litellm_wrapper.go new file mode 100644 index 0000000..d53d7e2 --- /dev/null +++ b/internal/litellm_wrapper.go @@ -0,0 +1,110 @@ +package internal + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/Checkmarx/gen-ai-wrapper/pkg/message" +) + +// LitellmWrapper implements the Wrapper interface for litellm AI proxy service +type LitellmWrapper struct { + endPoint string + apiKey string + dropLen int +} + +// NewLitellmWrapper creates a new litellm wrapper instance 
+func NewLitellmWrapper(endPoint, apiKey string, dropLen int) Wrapper {
+	return &LitellmWrapper{
+		endPoint: endPoint,
+		apiKey:   apiKey,
+		dropLen:  dropLen,
+	}
+}
+
+// SetupCall sets up the wrapper with initial messages (no-op for litellm)
+func (w *LitellmWrapper) SetupCall(messages []message.Message) {
+	// No setup needed for litellm
+}
+
+// Call makes a request to the litellm AI proxy service
+func (w *LitellmWrapper) Call(cxAuth string, metaData *message.MetaData, request *ChatCompletionRequest) (*ChatCompletionResponse, error) {
+	// Prepare the request
+	req, err := w.prepareRequest(cxAuth, metaData, request)
+	if err != nil {
+		return nil, err
+	}
+
+	// Make the HTTP request
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	// Handle the response
+	return w.handleResponse(resp)
+}
+
+// prepareRequest creates the HTTP request
+func (w *LitellmWrapper) prepareRequest(cxAuth string, metaData *message.MetaData, requestBody *ChatCompletionRequest) (*http.Request, error) {
+	jsonData, err := json.Marshal(requestBody)
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest(http.MethodPost, w.endPoint, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, err
+	}
+
+	// Set headers
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", cxAuth))
+
+	// Set required headers for litellm service
+	req.Header.Set("X-Request-ID", metaData.RequestID)
+	req.Header.Set("X-Tenant-ID", metaData.TenantID)
+	req.Header.Set("User-Agent", metaData.UserAgent)
+	req.Header.Set("X-Feature", metaData.Feature)
+
+	return req, nil
+}
+
+// handleResponse processes the HTTP response
+func (w *LitellmWrapper) handleResponse(resp *http.Response) (*ChatCompletionResponse, error) {
+	bodyBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	// Handle successful response
+	if resp.StatusCode == http.StatusOK {
+		var responseBody = new(ChatCompletionResponse)
+		err = json.Unmarshal(bodyBytes, responseBody)
+		if err != nil {
+			return nil, err
+		}
+		return responseBody, nil
+	}
+
+	// Handle error responses
+	var errorResponse = new(ErrorResponse)
+	err = json.Unmarshal(bodyBytes, errorResponse)
+	if err != nil {
+		// If we can't parse the error response, return a generic error
+		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	// Return the parsed error
+	return nil, fromResponse(resp.StatusCode, errorResponse)
+}
+
+// Close closes the wrapper (no-op for HTTP client)
+func (w *LitellmWrapper) Close() error {
+	return nil
+}
diff --git a/pkg/wrapper/litellm_wrapper.go b/pkg/wrapper/litellm_wrapper.go
new file mode 100644
index 0000000..325b258
--- /dev/null
+++ b/pkg/wrapper/litellm_wrapper.go
@@ -0,0 +1,54 @@
+package wrapper
+
+import (
+	"github.com/Checkmarx/gen-ai-wrapper/internal"
+	"github.com/Checkmarx/gen-ai-wrapper/pkg/message"
+	"github.com/Checkmarx/gen-ai-wrapper/pkg/models"
+)
+
+// LitellmWrapper provides a simple wrapper for litellm AI proxy service
+type LitellmWrapper interface {
+	Call(cxAuth string, metaData *message.MetaData, request *internal.ChatCompletionRequest) (*internal.ChatCompletionResponse, error)
+	SetupCall(messages []message.Message)
+	Close() error
+}
+
+// LitellmWrapperImpl implements the LitellmWrapper interface
+type LitellmWrapperImpl struct {
+	wrapper internal.Wrapper
+	model   string
+}
+
+// NewLitellmWrapper creates a new litellm wrapper
+func NewLitellmWrapper(endPoint, apiKey, model string) (LitellmWrapper, error) {
+	if model == "" {
+		model = models.DefaultModel
+	}
+
+	wrapper := internal.NewLitellmWrapper(endPoint, apiKey, 0)
+
+	return &LitellmWrapperImpl{
+		wrapper: wrapper,
+		model:   model,
+	}, nil
+}
+
+// SetupCall sets up the wrapper with initial messages
+func (w *LitellmWrapperImpl) SetupCall(messages []message.Message) {
+	w.wrapper.SetupCall(messages)
+}
+
+// Call makes a request to the litellm service
+func (w *LitellmWrapperImpl) Call(cxAuth string, metaData *message.MetaData, request *internal.ChatCompletionRequest) (*internal.ChatCompletionResponse, error) {
+	// Set the model if not already set
+	if request.Model == "" {
+		request.Model = w.model
+	}
+
+	return w.wrapper.Call(cxAuth, metaData, request)
+}
+
+// Close closes the wrapper
+func (w *LitellmWrapperImpl) Close() error {
+	return w.wrapper.Close()
+}

From ce8c5462a1c9fab9ff1fe0247d526aa3c9fa371c Mon Sep 17 00:00:00 2001
From: Alex Cohen
Date: Thu, 9 Oct 2025 11:45:26 -0400
Subject: [PATCH 2/2] Add LiteLLM Wrapper support (AST-113982)

---
 example/cxoneai.go             | 9 +++++++--
 internal/gpt.go                | 4 ++--
 internal/litellm_wrapper.go    | 4 +---
 pkg/models/models.go           | 3 +++
 pkg/wrapper/litellm_wrapper.go | 2 +-
 5 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/example/cxoneai.go b/example/cxoneai.go
index dd974fe..4bb561f 100644
--- a/example/cxoneai.go
+++ b/example/cxoneai.go
@@ -50,9 +50,14 @@ func getOAuthAccessToken() (string, error) {
 	}
 
 	data := url.Values{}
-	data.Set("grant_type", "client_credentials")
+	data.Set("grant_type", "refresh_token")
 	data.Set("client_id", clientID)
-	data.Set("client_secret", clientSecret)
+	data.Set("refresh_token", clientSecret) // clientSecret carries a refresh token in this flow
+
+	// Use these lines instead if you have client credentials:
+	//data.Set("grant_type", "client_credentials")
+	//data.Set("client_id", clientID)
+	//data.Set("client_secret", clientSecret)
 
 	req, err := http.NewRequest("POST", openIDURL, strings.NewReader(data.Encode()))
 	if err != nil {
diff --git a/internal/gpt.go b/internal/gpt.go
index 93b5489..1a4a138 100644
--- a/internal/gpt.go
+++ b/internal/gpt.go
@@ -55,8 +55,8 @@ func NewWrapperFactory(endPoint, apiKey string, dropLen int) (Wrapper, error) {
 }
 
 // NewLitellmWrapperFactory creates a new litellm wrapper factory
-func NewLitellmWrapperFactory(endPoint, apiKey string, dropLen int) (Wrapper, error) {
-	return NewLitellmWrapper(endPoint, apiKey, dropLen), nil
+func NewLitellmWrapperFactory(endPoint, apiKey string) (Wrapper, error) {
+	return NewLitellmWrapper(endPoint, apiKey), nil
 }
 
 func fromResponse(statusCode int, e *ErrorResponse) error {
diff --git a/internal/litellm_wrapper.go b/internal/litellm_wrapper.go
index d53d7e2..1dd3dff 100644
--- a/internal/litellm_wrapper.go
+++ b/internal/litellm_wrapper.go
@@ -14,15 +14,13 @@ import (
 type LitellmWrapper struct {
 	endPoint string
 	apiKey   string
-	dropLen  int
 }
 
 // NewLitellmWrapper creates a new litellm wrapper instance
-func NewLitellmWrapper(endPoint, apiKey string, dropLen int) Wrapper {
+func NewLitellmWrapper(endPoint, apiKey string) Wrapper {
 	return &LitellmWrapper{
 		endPoint: endPoint,
 		apiKey:   apiKey,
-		dropLen:  dropLen,
 	}
 }
 
diff --git a/pkg/models/models.go b/pkg/models/models.go
index 55d5dbe..4328113 100644
--- a/pkg/models/models.go
+++ b/pkg/models/models.go
@@ -15,5 +15,8 @@ const (
 	GPT3TextDavinci001 = "text-davinci-001"
 	GPT3TextDavinci002 = "text-davinci-002"
 	GPT3TextDavinci003 = "text-davinci-003"
+	ClaudeSonnet37     = "claude-sonnet-3-7"
+	ClaudeSonnet4      = "claude-sonnet-4"
+	ClaudeSonnet45     = "claude-sonnet-4-5"
DefaultModel = GPT4o ) diff --git a/pkg/wrapper/litellm_wrapper.go b/pkg/wrapper/litellm_wrapper.go index 325b258..33a2f1e 100644 --- a/pkg/wrapper/litellm_wrapper.go +++ b/pkg/wrapper/litellm_wrapper.go @@ -25,7 +25,7 @@ func NewLitellmWrapper(endPoint, apiKey, model string) (LitellmWrapper, error) { model = models.DefaultModel } - wrapper := internal.NewLitellmWrapper(endPoint, apiKey, 0) + wrapper := internal.NewLitellmWrapper(endPoint, apiKey) return &LitellmWrapperImpl{ wrapper: wrapper,
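
For reference, the sketch below drives the new LiteLLM wrapper end to end. It is illustrative only: it assumes the file sits beside example/main.go in package main (so it can reuse the GetMessages, getAIEndpoint, and getAIAccessKey helpers shown above), and the model name, prompts, and metadata values are placeholders.

// litellm_demo.go - a minimal usage sketch for the new LiteLLM wrapper.
package main

import (
	"fmt"

	"github.com/Checkmarx/gen-ai-wrapper/internal"
	"github.com/Checkmarx/gen-ai-wrapper/pkg/message"
	"github.com/Checkmarx/gen-ai-wrapper/pkg/wrapper"
)

func demoLiteLLM() error {
	// Resolve endpoint and key the same way the example does for LiteLLM.
	endpoint, err := getAIEndpoint("LiteLLM")
	if err != nil {
		return err
	}
	apiKey, err := getAIAccessKey("LiteLLM", "gpt-4o")
	if err != nil {
		return err
	}

	// An empty model string would fall back to models.DefaultModel.
	llm, err := wrapper.NewLitellmWrapper(endpoint, apiKey, "gpt-4o")
	if err != nil {
		return fmt.Errorf("creating LiteLLM wrapper: %v", err)
	}
	defer llm.Close()

	// These fields are forwarded as the X-Request-ID, X-Tenant-ID,
	// User-Agent, and X-Feature headers by internal/litellm_wrapper.go.
	metaData := &message.MetaData{
		RequestID: "req-123",          // placeholder
		TenantID:  "tenant-a",         // placeholder
		UserAgent: "gen-ai-wrapper-example",
		Feature:   "chat-completion",
	}

	// Model is left empty here to exercise the default-model fallback in Call.
	request := &internal.ChatCompletionRequest{
		Messages: GetMessages("gpt-4o", "You are terse.", "Say hello."),
	}

	response, err := llm.Call(apiKey, metaData, request)
	if err != nil {
		return fmt.Errorf("calling LiteLLM: %v", err)
	}
	if len(response.Choices) > 0 {
		fmt.Println(response.Choices[0].Message.Content)
	}
	return nil
}

Note that Call takes the access key again because internal/litellm_wrapper.go builds the Authorization header from its cxAuth argument rather than from the stored apiKey field.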