From 93233589098fb6d4a216aa5768308b40687d89de Mon Sep 17 00:00:00 2001 From: Robby <45851384+h0rv@users.noreply.github.com> Date: Mon, 17 Jun 2024 16:17:13 -0400 Subject: [PATCH] Update to `watsonx-go` Client (#10) Signed-off-by: Robby Co-authored-by: Robby --- README.md | 65 +++++++++++++--------- go.mod | 4 +- pkg/internal/tests/models/generate_test.go | 62 +++++++++++++++------ pkg/models/{model.go => client.go} | 35 +++++++----- pkg/models/client_option.go | 42 ++++++++++++++ pkg/models/generate.go | 4 +- pkg/models/model_option.go | 42 -------------- pkg/models/types.go | 4 -- 8 files changed, 151 insertions(+), 107 deletions(-) rename pkg/models/{model.go => client.go} (66%) create mode 100644 pkg/models/client_option.go delete mode 100644 pkg/models/model_option.go diff --git a/README.md b/README.md index 17bdd8b..e464914 100644 --- a/README.md +++ b/README.md @@ -1,41 +1,56 @@ -# go-watsonx +# watsonx-go -Zero dependency [watsonx](https://www.ibm.com/watsonx) API Client for Go +`watsonx-go` is a [watsonx](https://www.ibm.com/watsonx) Client for Go ## Install -Install: - ```sh -go get -u github.com/h0rv/go-watsonx +go get -u github.com/IBM/watsonx-go ``` -Import: +## Usage ```go import ( - wx "github.com/h0rv/go-watsonx/pkg/models" + wx "github.com/IBM/watsonx-go/pkg/models" ) ``` -## Example Usage +### Example Usage + +```sh +export WATSONX_API_KEY="YOUR WATSONX API KEY" +export WATSONX_PROJECT_ID="YOUR WATSONX PROJECT ID" +``` + +Create a client: ```go - model, _ := wx.NewModel( - wx.WithIBMCloudAPIKey("YOUR IBM CLOUD API KEY"), - wx.WithWatsonxProjectID("YOUR WATSONX PROJECT ID"), - ) - - result, _ := model.GenerateText( - "meta-llama/llama-3-70b-instruct", - "Hi, who are you?", - wx.WithTemperature(0.9), - wx.WithTopP(.5), - wx.WithTopK(10), - wx.WithMaxNewTokens(512), - ) - - println(result.Text) +client, _ := wx.NewClient() +``` + +Or pass in the required secrets directly: + +```go +client, err := wx.NewClient( + wx.WithWatsonxAPIKey(apiKey), + wx.WithWatsonxProjectID(projectID), +) +``` + +Generation: + +```go +result, _ := client.GenerateText( + "meta-llama/llama-3-70b-instruct", + "Hi, who are you?", + wx.WithTemperature(0.9), + wx.WithTopP(.5), + wx.WithTopK(10), + wx.WithMaxNewTokens(512), +) + +println(result.Text) ``` ## Development Setup @@ -45,7 +60,7 @@ import ( #### Setup ```sh -export IBMCLOUD_API_KEY="YOUR IBM CLOUD API KEY" +export WATSONX_API_KEY="YOUR WATSONX API KEY" export WATSONX_PROJECT_ID="YOUR WATSONX PROJECT ID" ``` @@ -65,5 +80,5 @@ git config --local core.hooksPath .githooks/ ## Resources -- [watsonx Python SDK Docs](https://ibm.github.io/watson-machine-learning-sdk) - [watsonx REST API Docs](https://cloud.ibm.com/apidocs/watsonx-ai) +- [watsonx Python SDK Docs](https://ibm.github.io/watson-machine-learning-sdk) diff --git a/go.mod b/go.mod index 9c844db..17536f6 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/h0rv/go-watsonx +module github.com/IBM/watsonx-go -go 1.21.2 +go 1.21.4 diff --git a/pkg/internal/tests/models/generate_test.go b/pkg/internal/tests/models/generate_test.go index a21c62f..4599267 100644 --- a/pkg/internal/tests/models/generate_test.go +++ b/pkg/internal/tests/models/generate_test.go @@ -4,12 +4,40 @@ import ( "os" "testing" - wx "github.com/h0rv/go-watsonx/pkg/models" + wx "github.com/IBM/watsonx-go/pkg/models" ) -func getModel(t *testing.T) *wx.Model { - apiKey := os.Getenv(wx.WatsonxAPIKeyEnvVarName) - projectID := os.Getenv(wx.WatsonxProjectIDEnvVarName) +func TestClientCreationWithEnvVars(t *testing.T) { + _, err := wx.NewClient() + + if err != nil { + t.Fatalf("Expected no error for creating client with environment variables, but got %v", err) + } +} + +func TestClientCreationWithPassing(t *testing.T) { + apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName) + + if apiKey == "" { + t.Fatal("No watsonx API key provided") + } + if projectID == "" { + t.Fatal("No watsonx project ID provided") + } + + _, err := wx.NewClient( + wx.WithWatsonxAPIKey(apiKey), + wx.WithWatsonxProjectID(projectID), + ) + + if err != nil { + t.Fatalf("Expected no error for creating client with passing secrets, but got %v", err) + } +} + +func getClient(t *testing.T) *wx.Client { + apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName) + if apiKey == "" { t.Fatal("No watsonx API key provided") } @@ -17,21 +45,21 @@ func getModel(t *testing.T) *wx.Model { t.Fatal("No watsonx project ID provided") } - model, err := wx.NewModel( + client, err := wx.NewClient( wx.WithWatsonxAPIKey(apiKey), wx.WithWatsonxProjectID(projectID), ) if err != nil { - t.Fatalf("Failed to create model for testing. Error: %v", err) + t.Fatalf("Failed to create client for testing. Error: %v", err) } - return model + return client } func TestEmptyPromptError(t *testing.T) { - model := getModel(t) + client := getClient(t) - _, err := model.GenerateText( + _, err := client.GenerateText( "dumby model", "", ) @@ -41,9 +69,9 @@ func TestEmptyPromptError(t *testing.T) { } func TestNilOptions(t *testing.T) { - model := getModel(t) + client := getClient(t) - _, err := model.GenerateText( + _, err := client.GenerateText( "meta-llama/llama-3-70b-instruct", "What day is it?", nil, @@ -54,9 +82,9 @@ func TestNilOptions(t *testing.T) { } func TestValidPrompt(t *testing.T) { - model := getModel(t) + client := getClient(t) - _, err := model.GenerateText( + _, err := client.GenerateText( "meta-llama/llama-3-70b-instruct", "Test prompt", ) @@ -66,9 +94,9 @@ func TestValidPrompt(t *testing.T) { } func TestGenerateText(t *testing.T) { - model := getModel(t) + client := getClient(t) - result, err := model.GenerateText( + result, err := client.GenerateText( "meta-llama/llama-3-70b-instruct", "Hi, who are you?", wx.WithTemperature(0.9), @@ -85,9 +113,9 @@ func TestGenerateText(t *testing.T) { } func TestGenerateTextWithNilOptions(t *testing.T) { - model := getModel(t) + client := getClient(t) - result, err := model.GenerateText( + result, err := client.GenerateText( "meta-llama/llama-3-70b-instruct", "Who are you?", nil, diff --git a/pkg/models/model.go b/pkg/models/client.go similarity index 66% rename from pkg/models/model.go rename to pkg/models/client.go index f61aeee..9c45c2c 100644 --- a/pkg/models/model.go +++ b/pkg/models/client.go @@ -1,16 +1,13 @@ package models -/* - * https://ibm.github.io/watson-machine-learning-sdk/_modules/ibm_watson_machine_learning/foundation_models/model.html#Model - */ - import ( + "errors" "fmt" "net/http" "os" ) -type Model struct { +type Client struct { url string region IBMCloudRegion apiVersion string @@ -22,9 +19,9 @@ type Model struct { httpClient Doer } -func NewModel(options ...ModelOption) (*Model, error) { +func NewClient(options ...ClientOption) (*Client, error) { - opts := defaulModelOptions() + opts := defaulClientOptions() for _, opt := range options { if opt != nil { opt(opts) @@ -36,13 +33,21 @@ func NewModel(options ...ModelOption) (*Model, error) { opts.URL = buildBaseURL(opts.Region) } - m := &Model{ + if opts.apiKey == "" { + return nil, errors.New("no watsonx API key provided") + } + + if opts.projectID == "" { + return nil, errors.New("no watsonx project ID provided") + } + + m := &Client{ url: opts.URL, region: opts.Region, apiVersion: opts.APIVersion, // token: set below - apiKey: opts.watsonxAPIKey, + apiKey: opts.apiKey, projectID: opts.projectID, httpClient: &http.Client{}, @@ -57,7 +62,7 @@ func NewModel(options ...ModelOption) (*Model, error) { } // CheckAndRefreshToken checks the IAM token if it expired; if it did, it refreshes it; nothing if not -func (m *Model) CheckAndRefreshToken() error { +func (m *Client) CheckAndRefreshToken() error { if m.token.Expired() { return m.RefreshToken() } @@ -65,7 +70,7 @@ func (m *Model) CheckAndRefreshToken() error { } // RefreshToken generates and sets the model with a new token -func (m *Model) RefreshToken() error { +func (m *Client) RefreshToken() error { token, err := GenerateToken(m.httpClient, m.apiKey) if err != nil { return err @@ -78,13 +83,13 @@ func buildBaseURL(region IBMCloudRegion) string { return fmt.Sprintf(BaseURLFormatStr, region) } -func defaulModelOptions() *ModelOptions { - return &ModelOptions{ +func defaulClientOptions() *ClientOptions { + return &ClientOptions{ URL: "", Region: DefaultRegion, APIVersion: DefaultAPIVersion, - watsonxAPIKey: os.Getenv(WatsonxAPIKeyEnvVarName), - projectID: os.Getenv(WatsonxProjectIDEnvVarName), + apiKey: os.Getenv(WatsonxAPIKeyEnvVarName), + projectID: os.Getenv(WatsonxProjectIDEnvVarName), } } diff --git a/pkg/models/client_option.go b/pkg/models/client_option.go new file mode 100644 index 0000000..67c3e2e --- /dev/null +++ b/pkg/models/client_option.go @@ -0,0 +1,42 @@ +package models + +type ClientOption func(*ClientOptions) + +type ClientOptions struct { + URL string + Region IBMCloudRegion + APIVersion string + + apiKey WatsonxAPIKey + projectID WatsonxProjectID +} + +func WithURL(url string) ClientOption { + return func(o *ClientOptions) { + o.URL = url + } +} + +func WithRegion(region IBMCloudRegion) ClientOption { + return func(o *ClientOptions) { + o.Region = region + } +} + +func WithAPIVersion(apiVersion string) ClientOption { + return func(o *ClientOptions) { + o.APIVersion = apiVersion + } +} + +func WithWatsonxAPIKey(watsonxAPIKey WatsonxAPIKey) ClientOption { + return func(o *ClientOptions) { + o.apiKey = watsonxAPIKey + } +} + +func WithWatsonxProjectID(projectID WatsonxProjectID) ClientOption { + return func(o *ClientOptions) { + o.projectID = projectID + } +} diff --git a/pkg/models/generate.go b/pkg/models/generate.go index 5fc43fd..15eb99e 100644 --- a/pkg/models/generate.go +++ b/pkg/models/generate.go @@ -49,7 +49,7 @@ type generateTextResponse struct { } // GenerateText generates completion text based on a given prompt and parameters -func (m *Model) GenerateText(model, prompt string, options ...GenerateOption) (GenerateTextResult, error) { +func (m *Client) GenerateText(model, prompt string, options ...GenerateOption) (GenerateTextResult, error) { m.CheckAndRefreshToken() if prompt == "" { @@ -86,7 +86,7 @@ func (m *Model) GenerateText(model, prompt string, options ...GenerateOption) (G // generateTextRequest sends the generate request and handles the response using the http package. // Returns error on non-2XX response -func (m *Model) generateTextRequest(payload GenerateTextPayload) (generateTextResponse, error) { +func (m *Client) generateTextRequest(payload GenerateTextPayload) (generateTextResponse, error) { params := url.Values{ "version": {m.apiVersion}, } diff --git a/pkg/models/model_option.go b/pkg/models/model_option.go deleted file mode 100644 index af77e14..0000000 --- a/pkg/models/model_option.go +++ /dev/null @@ -1,42 +0,0 @@ -package models - -type ModelOption func(*ModelOptions) - -type ModelOptions struct { - URL string - Region IBMCloudRegion - APIVersion string - - watsonxAPIKey WatsonxAPIKey - projectID WatsonxProjectID -} - -func WithURL(url string) ModelOption { - return func(o *ModelOptions) { - o.URL = url - } -} - -func WithRegion(region IBMCloudRegion) ModelOption { - return func(o *ModelOptions) { - o.Region = region - } -} - -func WithAPIVersion(apiVersion string) ModelOption { - return func(o *ModelOptions) { - o.APIVersion = apiVersion - } -} - -func WithWatsonxAPIKey(watsonxAPIKey WatsonxAPIKey) ModelOption { - return func(o *ModelOptions) { - o.watsonxAPIKey = watsonxAPIKey - } -} - -func WithWatsonxProjectID(projectID WatsonxProjectID) ModelOption { - return func(o *ModelOptions) { - o.projectID = projectID - } -} diff --git a/pkg/models/types.go b/pkg/models/types.go index 2bc0e0e..c0d4ce7 100644 --- a/pkg/models/types.go +++ b/pkg/models/types.go @@ -4,10 +4,6 @@ import ( "net/http" ) -/* - * https://ibm.github.io/watson-machine-learning-sdk/model.html#ibm_watson_machine_learning.foundation_models.utils.enums.ModelTypes - */ - type ( WatsonxAPIKey = string WatsonxProjectID = string