From 775439a7933639a4b3fc90a0da0489907f2d8b31 Mon Sep 17 00:00:00 2001 From: woorui Date: Wed, 11 Dec 2024 22:59:47 +0800 Subject: [PATCH] fix(llm-bridge): retrive a provider if name is empty --- pkg/bridge/ai/provider/provider.go | 15 ++++++++++++++- pkg/bridge/ai/provider/provider_test.go | 23 +++++++++++++++-------- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/pkg/bridge/ai/provider/provider.go b/pkg/bridge/ai/provider/provider.go index b3befdb1a..31a3b365d 100644 --- a/pkg/bridge/ai/provider/provider.go +++ b/pkg/bridge/ai/provider/provider.go @@ -56,8 +56,21 @@ func getProvider(name string) LLMProvider { return nil } -// GetProvider returns the llm provider by name +// GetProvider returns the llm provider by name, +// if name is empty, it will return the first provider that has been registered func GetProvider(name string) (LLMProvider, error) { + if name == "" { + var provider LLMProvider + providers.Range(func(key, _ any) bool { + name = key.(string) + provider = getProvider(name) + return false + }) + if provider != nil { + return provider, nil + } + return nil, ErrNotExistsProvider + } provider := getProvider(name) if provider != nil { return provider, nil diff --git a/pkg/bridge/ai/provider/provider_test.go b/pkg/bridge/ai/provider/provider_test.go index bd55bdc18..349c05d8a 100644 --- a/pkg/bridge/ai/provider/provider_test.go +++ b/pkg/bridge/ai/provider/provider_test.go @@ -21,15 +21,22 @@ func TestProviders(t *testing.T) { assert.ElementsMatch(t, expected, val) }) - t.Run("GetProvider error", func(t *testing.T) { - _, err := GetProvider("name-not-exist") - assert.ErrorIs(t, err, ErrNotExistsProvider) - }) - t.Run("GetProvider", func(t *testing.T) { - p, err := GetProvider("name-1") - assert.NoError(t, err) - assert.Equal(t, p1, p) + t.Run("ok", func(t *testing.T) { + p, err := GetProvider("name-1") + assert.NoError(t, err) + assert.Equal(t, p1, p) + }) + t.Run("name is empty", func(t *testing.T) { + p, err := GetProvider("") + assert.NoError(t, err) + assert.Equal(t, p1, p) + }) + t.Run("not found", func(t *testing.T) { + p, err := GetProvider("name-x") + assert.ErrorIs(t, err, ErrNotExistsProvider) + assert.Nil(t, p) + }) }) }