Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 40 additions & 32 deletions cmd/cc-connect/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,24 +140,9 @@ func main() {
os.Exit(1)
}

// Wire providers if the agent supports it
if ps, ok := agent.(core.ProviderSwitcher); ok && len(proj.Agent.Providers) > 0 {
providers := make([]core.ProviderConfig, len(proj.Agent.Providers))
for i, p := range proj.Agent.Providers {
providers[i] = core.ProviderConfig{
Name: p.Name,
APIKey: p.APIKey,
BaseURL: p.BaseURL,
Model: p.Model,
Models: convertProviderModels(p.Models),
Thinking: p.Thinking,
Env: p.Env,
}
}
ps.SetProviders(providers)
if active, _ := proj.Agent.Options["provider"].(string); active != "" {
ps.SetActiveProvider(active)
}
if _, err := configureAgentProviders(agent, cfg, proj.Name, proj.Agent.Options); err != nil {
slog.Error("failed to configure providers", "project", proj.Name, "error", err)
os.Exit(1)
}

var platforms []core.Platform
Expand Down Expand Up @@ -1235,20 +1220,10 @@ func reloadConfig(configPath, projName string, engine *core.Engine) (*core.Confi
engine.SetAttachmentSendEnabled(cfg.AttachmentSend != "off")

// Reload providers
if ps, ok := engine.GetAgent().(core.ProviderSwitcher); ok {
providers := make([]core.ProviderConfig, len(proj.Agent.Providers))
for i, p := range proj.Agent.Providers {
providers[i] = core.ProviderConfig{
Name: p.Name, APIKey: p.APIKey, BaseURL: p.BaseURL,
Model: p.Model, Models: convertProviderModels(p.Models), Thinking: p.Thinking, Env: p.Env,
}
}
ps.SetProviders(providers)
result.ProvidersUpdated = len(providers)

if active, _ := proj.Agent.Options["provider"].(string); active != "" {
ps.SetActiveProvider(active)
}
if updated, err := configureAgentProviders(engine.GetAgent(), cfg, proj.Name, proj.Agent.Options); err != nil {
return nil, fmt.Errorf("reload providers: %w", err)
} else {
result.ProvidersUpdated = updated
}

// Reload custom commands
Expand Down Expand Up @@ -1328,6 +1303,39 @@ func convertProviderModels(ms []config.ProviderModelConfig) []core.ModelOption {
return opts
}

func configureAgentProviders(agent core.Agent, cfg *config.Config, projectName string, options map[string]any) (int, error) {
ps, ok := agent.(core.ProviderSwitcher)
if !ok {
return 0, nil
}

providerConfigs, err := config.GetEffectiveProjectProviders(cfg, projectName)
if err != nil {
return 0, err
}

providers := make([]core.ProviderConfig, len(providerConfigs))
for i, p := range providerConfigs {
providers[i] = core.ProviderConfig{
Name: p.Name,
APIKey: p.APIKey,
BaseURL: p.BaseURL,
Model: p.Model,
Models: convertProviderModels(p.Models),
Thinking: p.Thinking,
Env: p.Env,
}
}
ps.SetProviders(providers)

if active, _ := options["provider"].(string); active != "" {
ps.SetActiveProvider(active)
} else {
ps.SetActiveProvider("")
}
return len(providers), nil
}

func convertCoreModels(ms []core.ModelOption) []config.ProviderModelConfig {
if len(ms) == 0 {
return nil
Expand Down
170 changes: 170 additions & 0 deletions cmd/cc-connect/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"path/filepath"
"testing"

"github.com/chenhg5/cc-connect/config"
"github.com/chenhg5/cc-connect/core"
)

Expand Down Expand Up @@ -33,6 +34,43 @@ func (a *stubMainAgent) GetWorkDir() string {
return a.workDir
}

type stubMainProviderAgent struct {
stubMainAgent
providers []core.ProviderConfig
activeName string
}

func (a *stubMainProviderAgent) SetProviders(providers []core.ProviderConfig) {
a.providers = append([]core.ProviderConfig(nil), providers...)
}

func (a *stubMainProviderAgent) SetActiveProvider(name string) bool {
a.activeName = name
if name == "" {
return true
}
for _, provider := range a.providers {
if provider.Name == name {
return true
}
}
return false
}

func (a *stubMainProviderAgent) GetActiveProvider() *core.ProviderConfig {
for _, provider := range a.providers {
if provider.Name == a.activeName {
p := provider
return &p
}
}
return nil
}

func (a *stubMainProviderAgent) ListProviders() []core.ProviderConfig {
return append([]core.ProviderConfig(nil), a.providers...)
}

type stubMainAgentSession struct{}

func (s *stubMainAgentSession) Send(string, []core.ImageAttachment, []core.FileAttachment) error {
Expand Down Expand Up @@ -73,3 +111,135 @@ func TestApplyProjectStateOverride(t *testing.T) {
t.Fatalf("agent workDir = %q, want %q", agent.workDir, overrideDir)
}
}

func TestConfigureAgentProvidersUsesTopLevelProviders(t *testing.T) {
cfg := &config.Config{
Providers: []config.ProviderConfig{
{Name: "openai", APIKey: "sk-openai"},
{Name: "kimi", APIKey: "sk-kimi"},
},
Projects: []config.ProjectConfig{{
Name: "demo",
Agent: config.AgentConfig{
Type: "stub-main",
Options: map[string]any{
"provider": "openai",
},
Providers: []config.ProviderConfig{
{Name: "legacy", APIKey: "sk-legacy"},
},
},
Platforms: []config.PlatformConfig{{Type: "telegram", Options: map[string]any{"token": "x"}}},
}},
}

agent := &stubMainProviderAgent{}
updated, err := configureAgentProviders(agent, cfg, "demo", cfg.Projects[0].Agent.Options)
if err != nil {
t.Fatalf("configureAgentProviders() error: %v", err)
}
if updated != 2 {
t.Fatalf("updated = %d, want 2", updated)
}
if len(agent.providers) != 2 {
t.Fatalf("provider count = %d, want 2", len(agent.providers))
}
if agent.providers[0].Name != "openai" || agent.providers[1].Name != "kimi" {
t.Fatalf("providers = %#v, want top-level providers", agent.providers)
}
if agent.activeName != "openai" {
t.Fatalf("activeName = %q, want openai", agent.activeName)
}
}

func TestConfigureAgentProvidersFallsBackToLegacyProjectProviders(t *testing.T) {
cfg := &config.Config{
Projects: []config.ProjectConfig{{
Name: "demo",
Agent: config.AgentConfig{
Type: "stub-main",
Options: map[string]any{
"provider": "legacy",
},
Providers: []config.ProviderConfig{
{Name: "legacy", APIKey: "sk-legacy"},
{Name: "backup", APIKey: "sk-backup"},
},
},
Platforms: []config.PlatformConfig{{Type: "telegram", Options: map[string]any{"token": "x"}}},
}},
}

agent := &stubMainProviderAgent{}
updated, err := configureAgentProviders(agent, cfg, "demo", cfg.Projects[0].Agent.Options)
if err != nil {
t.Fatalf("configureAgentProviders() error: %v", err)
}
if updated != 2 {
t.Fatalf("updated = %d, want 2", updated)
}
if len(agent.providers) != 2 {
t.Fatalf("provider count = %d, want 2", len(agent.providers))
}
if agent.providers[0].Name != "legacy" || agent.providers[1].Name != "backup" {
t.Fatalf("providers = %#v, want legacy project providers", agent.providers)
}
if agent.activeName != "legacy" {
t.Fatalf("activeName = %q, want legacy", agent.activeName)
}
}

func TestReloadConfigUsesTopLevelProviders(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.toml")
content := `
[[providers]]
name = "openai"
api_key = "sk-openai"

[[providers]]
name = "kimi"
api_key = "sk-kimi"

[[projects]]
name = "demo"

[projects.agent]
type = "stub-main"

[projects.agent.options]
provider = "kimi"

[[projects.agent.providers]]
name = "legacy"
api_key = "sk-legacy"

[[projects.platforms]]
type = "telegram"

[projects.platforms.options]
token = "test-token"
`
if err := os.WriteFile(configPath, []byte(content), 0o644); err != nil {
t.Fatalf("write config: %v", err)
}

agent := &stubMainProviderAgent{}
engine := core.NewEngine("demo", agent, nil, filepath.Join(dir, "sessions.json"), core.LangEnglish)
result, err := reloadConfig(configPath, "demo", engine)
if err != nil {
t.Fatalf("reloadConfig() error: %v", err)
}
if result.ProvidersUpdated != 2 {
t.Fatalf("ProvidersUpdated = %d, want 2", result.ProvidersUpdated)
}
if len(agent.providers) != 2 {
t.Fatalf("provider count = %d, want 2", len(agent.providers))
}
if agent.providers[0].Name != "openai" || agent.providers[1].Name != "kimi" {
t.Fatalf("providers = %#v, want top-level providers", agent.providers)
}
if agent.activeName != "kimi" {
t.Fatalf("activeName = %q, want kimi", agent.activeName)
}
}
14 changes: 7 additions & 7 deletions cmd/cc-connect/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ func printProviderUsage() {
fmt.Println(`Usage: cc-connect provider <command> [options]

Commands:
add Add a new API provider to a project
list List providers for a project
remove Remove a provider from a project
add Add a shared API provider
list List effective providers for a project
remove Remove a shared API provider
import Import providers from cc-switch

Examples:
Expand All @@ -62,7 +62,7 @@ func initConfigPath(flagValue string) {
func runProviderAdd(args []string) {
fs := flag.NewFlagSet("provider add", flag.ExitOnError)
configFile := fs.String("config", "", "path to config file")
project := fs.String("project", "", "project name (required)")
project := fs.String("project", "", "project name (required for validation)")
name := fs.String("name", "", "provider name (required)")
apiKey := fs.String("api-key", "", "API key")
baseURL := fs.String("base-url", "", "API base URL (optional)")
Expand Down Expand Up @@ -93,7 +93,7 @@ func runProviderAdd(args []string) {
os.Exit(1)
}

fmt.Printf("✅ Provider %q added to project %q\n", *name, *project)
fmt.Printf("✅ Provider %q added to shared config for project %q\n", *name, *project)
if *baseURL != "" {
fmt.Printf(" Base URL: %s\n", *baseURL)
}
Expand Down Expand Up @@ -171,7 +171,7 @@ func listProjectProviders(projectName string) {
func runProviderRemove(args []string) {
fs := flag.NewFlagSet("provider remove", flag.ExitOnError)
configFile := fs.String("config", "", "path to config file")
project := fs.String("project", "", "project name (required)")
project := fs.String("project", "", "project name (required for validation)")
name := fs.String("name", "", "provider name (required)")
_ = fs.Parse(args)

Expand All @@ -188,7 +188,7 @@ func runProviderRemove(args []string) {
os.Exit(1)
}

fmt.Printf("✅ Provider %q removed from project %q\n", *name, *project)
fmt.Printf("✅ Provider %q removed from shared config for project %q\n", *name, *project)
}

// ── Import from cc-switch ──────────────────────────────────────
Expand Down
Loading
Loading