9 changes: 7 additions & 2 deletions example/cxoneai.go
@@ -50,9 +50,14 @@ func getOAuthAccessToken() (string, error) {
 	}
 
 	data := url.Values{}
-	data.Set("grant_type", "client_credentials")
+	data.Set("grant_type", "refresh_token")
 	data.Set("client_id", clientID)
-	data.Set("client_secret", clientSecret)
+	data.Set("refresh_token", clientSecret)
+
+	//Use this if you have client credentials
+	//data.Set("grant_type", "client_credentials")
+	//data.Set("client_id", clientID)
+	//data.Set("client_secret", clientSecret)
 
 	req, err := http.NewRequest("POST", openIDURL, strings.NewReader(data.Encode()))
 	if err != nil {
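For context, this hunk switches the example from the client_credentials grant to the refresh_token grant, sending the CxOne API key as the refresh token. Below is a minimal sketch of that exchange, assuming a standard OIDC token endpoint; the helper name and package are illustrative, not part of this PR.

package example

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
)

// exchangeRefreshToken is a hypothetical helper mirroring the change above:
// the long-lived API key is presented as the refresh token.
func exchangeRefreshToken(openIDURL, clientID, apiKey string) (string, error) {
	data := url.Values{}
	data.Set("grant_type", "refresh_token")
	data.Set("client_id", clientID)
	data.Set("refresh_token", apiKey)

	resp, err := http.Post(openIDURL, "application/x-www-form-urlencoded",
		strings.NewReader(data.Encode()))
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, body)
	}
	// Callers would JSON-decode the access_token field out of body.
	return string(body), nil
}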
58 changes: 39 additions & 19 deletions example/main.go
@@ -3,14 +3,15 @@ package main
 import (
 	"flag"
 	"fmt"
-	"github.com/Checkmarx/gen-ai-wrapper/pkg/connector"
+	"os"
+	"strings"
+
+	"github.com/Checkmarx/gen-ai-wrapper/internal"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/message"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/models"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/role"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/wrapper"
 	"github.com/google/uuid"
-	"os"
-	"strings"
 )
 
 const usage = `
@@ -23,9 +24,8 @@ Options
 	-s, --system <system-prompt>	system (or developer) prompt string
 	-u, --user <user-prompt>	user prompt string
 	-id <conversation-id>		chat conversation ID
-	-ai <ai-server>			AI server to use. Options: {OpenAI (default), CxOne}
-	-m, --model <model>		model to use. Options: {gpt-4o (default), gpt-4, o1, o1-mini, ...}
-	-f, --full-response		return full response from AI
+	-ai <ai-server>			AI server to use. Options: {OpenAI (default), CxOne, LiteLLM}
+	-m, --model <model>		model to use. Options: {gpt-4o (default), gpt-4, o1, o1-mini, claude-3-5-sonnet-20241022, ...}
 	-h, --help			show help
`
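For instance, a hypothetical invocation against a LiteLLM server, assuming the example is run from the repository root with the flags documented above:

	go run ./example -ai LiteLLM -m claude-3-5-sonnet-20241022 -s "You are a helpful assistant." -u "Say hello."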

@@ -103,26 +103,46 @@ func CallAIandPrintResponse(aiServer, model, systemPrompt, userPrompt string, ch
 		return err
 	}
 
-	statefulWrapper, err := wrapper.NewStatefulWrapperNew(
-		connector.NewFileSystemConnector(""), aiEndpoint, aiKey, model, 4, 0)
+	var litellmWrapper wrapper.LitellmWrapper
+
+	// Use litellm wrapper for litellm server
+	if strings.EqualFold(aiServer, "LiteLLM") {
+		litellmWrapper, err = wrapper.NewLitellmWrapper(aiEndpoint, aiKey, model)
+	} else {
+		// For other servers, we'll need to implement or use existing wrappers
+		return fmt.Errorf("unsupported AI server: %s", aiServer)
+	}
+
 	if err != nil {
 		return fmt.Errorf("error creating '%s' AI client: %v", aiServer, err)
 	}
 
 	newMessages := GetMessages(model, systemPrompt, userPrompt)
 
+	// Create proper metadata for the request
+	metaData := &message.MetaData{
+		RequestID: "example-request-" + chatId.String(),
+		TenantID:  "default-tenant",
+		UserAgent: "gen-ai-wrapper-example",
+		Feature:   "chat-completion",
+	}
+
+	// Create the request
+	request := &internal.ChatCompletionRequest{
+		Model:    model,
+		Messages: newMessages,
+	}
+
+	// Make the call
+	response, err := litellmWrapper.Call(aiKey, metaData, request)
+	if err != nil {
+		return fmt.Errorf("error calling litellm: %v", err)
+	}
+
 	if fullResponse {
-		response, err := statefulWrapper.SecureCallReturningFullResponse("", nil, chatId, newMessages)
-		if err != nil {
-			return fmt.Errorf("error calling GPT: %v", err)
-		}
 		fmt.Printf("%+v\n", response)
 	} else {
-		response, err := statefulWrapper.Call(chatId, newMessages)
-		if err != nil {
-			return fmt.Errorf("error calling GPT: %v", err)
-		}
-		fmt.Println(getMessageContents(response))
+		fmt.Println(response.Choices[0].Message.Content)
 	}
 	return nil
 }
@@ -156,7 +176,7 @@ func getAIAccessKey(aiServer, model string) (string, error) {
 		}
 		return accessKey, nil
 	}
-	if strings.EqualFold(aiServer, "CxOne") {
+	if strings.EqualFold(aiServer, "CxOne") || strings.EqualFold(aiServer, "LiteLLM") {
 		accessKey, err := GetCxOneAIAccessKey()
 		if err != nil {
 			return "", fmt.Errorf("error getting CxOne AI API key: %v", err)
@@ -174,7 +194,7 @@ func getAIEndpoint(aiServer string) (string, error) {
 		}
 		return aiEndpoint, nil
 	}
-	if strings.EqualFold(aiServer, "CxOne") {
+	if strings.EqualFold(aiServer, "CxOne") || strings.EqualFold(aiServer, "LiteLLM") {
 		aiEndpoint, err := GetCxOneAIEndpoint()
 		if err != nil {
 			return "", fmt.Errorf("error getting CxOne AI endpoint: %v", err)
119 changes: 0 additions & 119 deletions internal/genaiProxyInternal.go

This file was deleted.

16 changes: 7 additions & 9 deletions internal/gpt.go
@@ -3,9 +3,9 @@ package internal
 import (
 	"errors"
 	"fmt"
+
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/message"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/role"
-	"net/url"
 )
 
 // const gptByAzure = "https://cxgpt4.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2023-05-15"
@@ -51,14 +51,12 @@ type Wrapper interface {
 }
 
 func NewWrapperFactory(endPoint, apiKey string, dropLen int) (Wrapper, error) {
-	endPointURL, err := url.Parse(endPoint)
-	if err != nil {
-		return nil, err
-	}
-	if endPointURL.Scheme == "http" || endPointURL.Scheme == "https" {
-		return NewWrapperImpl(endPoint, apiKey, dropLen), nil
-	}
-	return NewWrapperInternalImpl(endPoint, dropLen)
+	return NewWrapperImpl(endPoint, apiKey, dropLen), nil
+}
+
+// NewLitellmWrapperFactory creates a new litellm wrapper factory
+func NewLitellmWrapperFactory(endPoint, apiKey string) (Wrapper, error) {
+	return NewLitellmWrapper(endPoint, apiKey), nil
 }
 
 func fromResponse(statusCode int, e *ErrorResponse) error {
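With the internal transport gone, the factory now always returns the HTTP-backed wrapper. A hypothetical call, with placeholder values for the endpoint and key; the drop length of 4 mirrors what the example previously passed to the stateful wrapper:

// Sketch only: endpoint and apiKey are placeholders.
w, err := NewWrapperFactory("https://api.openai.com/v1/chat/completions", apiKey, 4)
if err != nil {
	return err
}
_ = w // ready for SetupCall / Call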
108 changes: 108 additions & 0 deletions internal/litellm_wrapper.go
@@ -0,0 +1,108 @@
+package internal
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/Checkmarx/gen-ai-wrapper/pkg/message"
+)
+
+// LitellmWrapper implements the Wrapper interface for litellm AI proxy service
+type LitellmWrapper struct {
+	endPoint string
+	apiKey   string
+}
+
+// NewLitellmWrapper creates a new litellm wrapper instance
+func NewLitellmWrapper(endPoint, apiKey string) Wrapper {
+	return &LitellmWrapper{
+		endPoint: endPoint,
+		apiKey:   apiKey,
+	}
+}
+
+// SetupCall sets up the wrapper with initial messages (no-op for litellm)
+func (w *LitellmWrapper) SetupCall(messages []message.Message) {
+	// No setup needed for litellm
+}
+
+// Call makes a request to the litellm AI proxy service
+func (w *LitellmWrapper) Call(cxAuth string, metaData *message.MetaData, request *ChatCompletionRequest) (*ChatCompletionResponse, error) {
+	// Prepare the request
+	req, err := w.prepareRequest(cxAuth, metaData, request)
+	if err != nil {
+		return nil, err
+	}
+
+	// Make the HTTP request
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	// Handle the response
+	return w.handleResponse(resp)
+}
+
+// prepareRequest creates the HTTP request
+func (w *LitellmWrapper) prepareRequest(cxAuth string, metaData *message.MetaData, requestBody *ChatCompletionRequest) (*http.Request, error) {
+	jsonData, err := json.Marshal(requestBody)
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest(http.MethodPost, w.endPoint, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, err
+	}
+
+	// Set headers
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", cxAuth))
+
+	// Set required headers for litellm service
+	req.Header.Set("X-Request-ID", metaData.RequestID)
+	req.Header.Set("X-Tenant-ID", metaData.TenantID)
+	req.Header.Set("User-Agent", metaData.UserAgent)
+	req.Header.Set("X-Feature", metaData.Feature)
+
+	return req, nil
+}
+
+// handleResponse processes the HTTP response
+func (w *LitellmWrapper) handleResponse(resp *http.Response) (*ChatCompletionResponse, error) {
+	bodyBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	// Handle successful response
+	if resp.StatusCode == http.StatusOK {
+		var responseBody = new(ChatCompletionResponse)
+		err = json.Unmarshal(bodyBytes, responseBody)
+		if err != nil {
+			return nil, err
+		}
+		return responseBody, nil
+	}
+
+	// Handle error responses
+	var errorResponse = new(ErrorResponse)
+	err = json.Unmarshal(bodyBytes, errorResponse)
+	if err != nil {
+		// If we can't parse the error response, return a generic error
+		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	// Return the parsed error
+	return nil, fromResponse(resp.StatusCode, errorResponse)
+}
+
+// Close closes the wrapper (no-op for HTTP client)
+func (w *LitellmWrapper) Close() error {
+	return nil
+}
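A minimal usage sketch of the new wrapper, mirroring what example/main.go does above; the endpoint, key, model string, and messages are placeholders, and the function itself is illustrative rather than part of the PR:

// callLitellmExample is a hypothetical in-package illustration: build the
// metadata and chat request, call the proxy, and print the first choice.
func callLitellmExample(endPoint, apiKey string, msgs []message.Message) error {
	w := NewLitellmWrapper(endPoint, apiKey)

	metaData := &message.MetaData{
		RequestID: "example-request-1",
		TenantID:  "default-tenant",
		UserAgent: "gen-ai-wrapper-example",
		Feature:   "chat-completion",
	}

	request := &ChatCompletionRequest{
		Model:    "claude-sonnet-4", // see the new constants in pkg/models
		Messages: msgs,
	}

	response, err := w.Call(apiKey, metaData, request)
	if err != nil {
		return err
	}
	fmt.Println(response.Choices[0].Message.Content)
	return nil
}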
3 changes: 3 additions & 0 deletions pkg/models/models.go
@@ -15,5 +15,8 @@
 	GPT3TextDavinci001 = "text-davinci-001"
 	GPT3TextDavinci002 = "text-davinci-002"
 	GPT3TextDavinci003 = "text-davinci-003"
+	ClaudeSonnet37     = "claude-sonnet-3-7"
+	ClaudeSonnet4      = "claude-sonnet-4"
+	ClaudeSonnet45     = "claude-sonnet-4-5"
 	DefaultModel       = GPT4o
 )
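For completeness, a hypothetical caller selecting one of the new constants, assuming the pkg/models import used elsewhere in the examples; preferClaude is an illustrative flag:

// Sketch: pick a Claude model instead of the default.
model := models.DefaultModel
if preferClaude {
	model = models.ClaudeSonnet45 // "claude-sonnet-4-5"
}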