Skip to content

Commit

Permalink
Update to watsonx-go Client (#10)
Browse files Browse the repository at this point in the history
Signed-off-by: Robby <[email protected]>
Co-authored-by: Robby <[email protected]>
  • Loading branch information
h0rv and h0rv authored Jun 17, 2024
1 parent cc529cf commit 9323358
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 107 deletions.
65 changes: 40 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,41 +1,56 @@
# go-watsonx
# watsonx-go

Zero dependency [watsonx](https://www.ibm.com/watsonx) API Client for Go
`watsonx-go` is a [watsonx](https://www.ibm.com/watsonx) Client for Go

## Install

Install:

```sh
go get -u github.com/h0rv/go-watsonx
go get -u github.com/IBM/watsonx-go
```

Import:
## Usage

```go
import (
wx "github.com/h0rv/go-watsonx/pkg/models"
wx "github.com/IBM/watsonx-go/pkg/models"
)
```

## Example Usage
### Example Usage

```sh
export WATSONX_API_KEY="YOUR WATSONX API KEY"
export WATSONX_PROJECT_ID="YOUR WATSONX PROJECT ID"
```

Create a client:

```go
model, _ := wx.NewModel(
wx.WithIBMCloudAPIKey("YOUR IBM CLOUD API KEY"),
wx.WithWatsonxProjectID("YOUR WATSONX PROJECT ID"),
)

result, _ := model.GenerateText(
"meta-llama/llama-3-70b-instruct",
"Hi, who are you?",
wx.WithTemperature(0.9),
wx.WithTopP(.5),
wx.WithTopK(10),
wx.WithMaxNewTokens(512),
)

println(result.Text)
client, _ := wx.NewClient()
```

Or pass in the required secrets directly:

```go
client, err := wx.NewClient(
wx.WithWatsonxAPIKey(apiKey),
wx.WithWatsonxProjectID(projectID),
)
```

Generation:

```go
result, _ := client.GenerateText(
"meta-llama/llama-3-70b-instruct",
"Hi, who are you?",
wx.WithTemperature(0.9),
wx.WithTopP(.5),
wx.WithTopK(10),
wx.WithMaxNewTokens(512),
)

println(result.Text)
```

## Development Setup
Expand All @@ -45,7 +60,7 @@ import (
#### Setup

```sh
export IBMCLOUD_API_KEY="YOUR IBM CLOUD API KEY"
export WATSONX_API_KEY="YOUR WATSONX API KEY"
export WATSONX_PROJECT_ID="YOUR WATSONX PROJECT ID"
```

Expand All @@ -65,5 +80,5 @@ git config --local core.hooksPath .githooks/

## Resources

- [watsonx Python SDK Docs](https://ibm.github.io/watson-machine-learning-sdk)
- [watsonx REST API Docs](https://cloud.ibm.com/apidocs/watsonx-ai)
- [watsonx Python SDK Docs](https://ibm.github.io/watson-machine-learning-sdk)
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/h0rv/go-watsonx
module github.com/IBM/watsonx-go

go 1.21.2
go 1.21.4
62 changes: 45 additions & 17 deletions pkg/internal/tests/models/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,62 @@ import (
"os"
"testing"

wx "github.com/h0rv/go-watsonx/pkg/models"
wx "github.com/IBM/watsonx-go/pkg/models"
)

func getModel(t *testing.T) *wx.Model {
apiKey := os.Getenv(wx.WatsonxAPIKeyEnvVarName)
projectID := os.Getenv(wx.WatsonxProjectIDEnvVarName)
func TestClientCreationWithEnvVars(t *testing.T) {
_, err := wx.NewClient()

if err != nil {
t.Fatalf("Expected no error for creating client with environment variables, but got %v", err)
}
}

func TestClientCreationWithPassing(t *testing.T) {
apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName)

if apiKey == "" {
t.Fatal("No watsonx API key provided")
}
if projectID == "" {
t.Fatal("No watsonx project ID provided")
}

_, err := wx.NewClient(
wx.WithWatsonxAPIKey(apiKey),
wx.WithWatsonxProjectID(projectID),
)

if err != nil {
t.Fatalf("Expected no error for creating client with passing secrets, but got %v", err)
}
}

func getClient(t *testing.T) *wx.Client {
apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName)

if apiKey == "" {
t.Fatal("No watsonx API key provided")
}
if projectID == "" {
t.Fatal("No watsonx project ID provided")
}

model, err := wx.NewModel(
client, err := wx.NewClient(
wx.WithWatsonxAPIKey(apiKey),
wx.WithWatsonxProjectID(projectID),
)
if err != nil {
t.Fatalf("Failed to create model for testing. Error: %v", err)
t.Fatalf("Failed to create client for testing. Error: %v", err)
}

return model
return client
}

func TestEmptyPromptError(t *testing.T) {
model := getModel(t)
client := getClient(t)

_, err := model.GenerateText(
_, err := client.GenerateText(
"dumby model",
"",
)
Expand All @@ -41,9 +69,9 @@ func TestEmptyPromptError(t *testing.T) {
}

func TestNilOptions(t *testing.T) {
model := getModel(t)
client := getClient(t)

_, err := model.GenerateText(
_, err := client.GenerateText(
"meta-llama/llama-3-70b-instruct",
"What day is it?",
nil,
Expand All @@ -54,9 +82,9 @@ func TestNilOptions(t *testing.T) {
}

func TestValidPrompt(t *testing.T) {
model := getModel(t)
client := getClient(t)

_, err := model.GenerateText(
_, err := client.GenerateText(
"meta-llama/llama-3-70b-instruct",
"Test prompt",
)
Expand All @@ -66,9 +94,9 @@ func TestValidPrompt(t *testing.T) {
}

func TestGenerateText(t *testing.T) {
model := getModel(t)
client := getClient(t)

result, err := model.GenerateText(
result, err := client.GenerateText(
"meta-llama/llama-3-70b-instruct",
"Hi, who are you?",
wx.WithTemperature(0.9),
Expand All @@ -85,9 +113,9 @@ func TestGenerateText(t *testing.T) {
}

func TestGenerateTextWithNilOptions(t *testing.T) {
model := getModel(t)
client := getClient(t)

result, err := model.GenerateText(
result, err := client.GenerateText(
"meta-llama/llama-3-70b-instruct",
"Who are you?",
nil,
Expand Down
35 changes: 20 additions & 15 deletions pkg/models/model.go → pkg/models/client.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
package models

/*
* https://ibm.github.io/watson-machine-learning-sdk/_modules/ibm_watson_machine_learning/foundation_models/model.html#Model
*/

import (
"errors"
"fmt"
"net/http"
"os"
)

type Model struct {
type Client struct {
url string
region IBMCloudRegion
apiVersion string
Expand All @@ -22,9 +19,9 @@ type Model struct {
httpClient Doer
}

func NewModel(options ...ModelOption) (*Model, error) {
func NewClient(options ...ClientOption) (*Client, error) {

opts := defaulModelOptions()
opts := defaulClientOptions()
for _, opt := range options {
if opt != nil {
opt(opts)
Expand All @@ -36,13 +33,21 @@ func NewModel(options ...ModelOption) (*Model, error) {
opts.URL = buildBaseURL(opts.Region)
}

m := &Model{
if opts.apiKey == "" {
return nil, errors.New("no watsonx API key provided")
}

if opts.projectID == "" {
return nil, errors.New("no watsonx project ID provided")
}

m := &Client{
url: opts.URL,
region: opts.Region,
apiVersion: opts.APIVersion,

// token: set below
apiKey: opts.watsonxAPIKey,
apiKey: opts.apiKey,
projectID: opts.projectID,

httpClient: &http.Client{},
Expand All @@ -57,15 +62,15 @@ func NewModel(options ...ModelOption) (*Model, error) {
}

// CheckAndRefreshToken checks the IAM token if it expired; if it did, it refreshes it; nothing if not
func (m *Model) CheckAndRefreshToken() error {
func (m *Client) CheckAndRefreshToken() error {
if m.token.Expired() {
return m.RefreshToken()
}
return nil
}

// RefreshToken generates and sets the model with a new token
func (m *Model) RefreshToken() error {
func (m *Client) RefreshToken() error {
token, err := GenerateToken(m.httpClient, m.apiKey)
if err != nil {
return err
Expand All @@ -78,13 +83,13 @@ func buildBaseURL(region IBMCloudRegion) string {
return fmt.Sprintf(BaseURLFormatStr, region)
}

func defaulModelOptions() *ModelOptions {
return &ModelOptions{
func defaulClientOptions() *ClientOptions {
return &ClientOptions{
URL: "",
Region: DefaultRegion,
APIVersion: DefaultAPIVersion,

watsonxAPIKey: os.Getenv(WatsonxAPIKeyEnvVarName),
projectID: os.Getenv(WatsonxProjectIDEnvVarName),
apiKey: os.Getenv(WatsonxAPIKeyEnvVarName),
projectID: os.Getenv(WatsonxProjectIDEnvVarName),
}
}
42 changes: 42 additions & 0 deletions pkg/models/client_option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package models

type ClientOption func(*ClientOptions)

type ClientOptions struct {
URL string
Region IBMCloudRegion
APIVersion string

apiKey WatsonxAPIKey
projectID WatsonxProjectID
}

func WithURL(url string) ClientOption {
return func(o *ClientOptions) {
o.URL = url
}
}

func WithRegion(region IBMCloudRegion) ClientOption {
return func(o *ClientOptions) {
o.Region = region
}
}

func WithAPIVersion(apiVersion string) ClientOption {
return func(o *ClientOptions) {
o.APIVersion = apiVersion
}
}

func WithWatsonxAPIKey(watsonxAPIKey WatsonxAPIKey) ClientOption {
return func(o *ClientOptions) {
o.apiKey = watsonxAPIKey
}
}

func WithWatsonxProjectID(projectID WatsonxProjectID) ClientOption {
return func(o *ClientOptions) {
o.projectID = projectID
}
}
4 changes: 2 additions & 2 deletions pkg/models/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type generateTextResponse struct {
}

// GenerateText generates completion text based on a given prompt and parameters
func (m *Model) GenerateText(model, prompt string, options ...GenerateOption) (GenerateTextResult, error) {
func (m *Client) GenerateText(model, prompt string, options ...GenerateOption) (GenerateTextResult, error) {
m.CheckAndRefreshToken()

if prompt == "" {
Expand Down Expand Up @@ -86,7 +86,7 @@ func (m *Model) GenerateText(model, prompt string, options ...GenerateOption) (G

// generateTextRequest sends the generate request and handles the response using the http package.
// Returns error on non-2XX response
func (m *Model) generateTextRequest(payload GenerateTextPayload) (generateTextResponse, error) {
func (m *Client) generateTextRequest(payload GenerateTextPayload) (generateTextResponse, error) {
params := url.Values{
"version": {m.apiVersion},
}
Expand Down
Loading

0 comments on commit 9323358

Please sign in to comment.